Skip to content

Commit

Permalink
#13875: fix tilize and attn matmul on BH (#15459)
Browse files Browse the repository at this point in the history
### Ticket
Link to Github Issue #13875

### Problem description
- the tilize and attn matmul ops produced failures on Blackhole (BH)

### What's changed
- updated llk commit
- reverted z dim in tilize uninit
- moved untilize init earlier

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12050639818
- [x] Blackhole Post commit (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12050643335
- [x] Model regression CI testing passes (if applicable) — the run fails in
the same way as main, so the failures are pre-existing and unrelated to this change:
https://github.com/tenstorrent/tt-metal/actions/runs/12050659167 vs. main's
https://github.com/tenstorrent/tt-metal/actions/runs/12045311806
- [x] Device performance regression CI testing passes (if applicable) — the
initial run hit unrelated flakiness in
https://github.com/tenstorrent/tt-metal/actions/runs/12050647494; the re-run succeeded.
- [x] New/Existing tests provide coverage for changes

Based on
main...nvelickovic/fix_attn_matmul
  • Loading branch information
bbradelTT authored Nov 27, 2024
1 parent 5785a3d commit 62e8824
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import ttnn
from models.utility_functions import comp_pcc
from models.utility_functions import is_grayskull, skip_for_blackhole
from models.utility_functions import is_grayskull
import ttnn


Expand All @@ -30,7 +30,6 @@ def generate_input_shapes():
yield [q_len, q_heads, batch_size, K], [batch_size, kv_heads, K, seq_len]


@skip_for_blackhole("Hanging on BH, see #12349")
@pytest.mark.parametrize("in0_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize("in1_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize("out_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
Expand Down Expand Up @@ -71,7 +70,6 @@ def test_attn_matmul(num_loops, enable_async, in0_dtype, in1_dtype, out_dtype, d
device.enable_async(False)


@skip_for_blackhole("Hanging on BH, see #12349")
@pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32")
@pytest.mark.parametrize("in_dtype", [ttnn.float32, ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -117,7 +115,6 @@ def test_attn_matmul_fp32(num_loops, enable_async, in_dtype, device):
device.enable_async(False)


@skip_for_blackhole("Hanging on BH, see #12349")
@pytest.mark.parametrize("in0_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize("in1_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize("out_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
Expand Down Expand Up @@ -156,7 +153,6 @@ def test_attn_matmul_with_program_cache(
device.enable_async(False)


@skip_for_blackhole("Hanging on BH, see #12349")
@pytest.mark.parametrize(
"shard_orientation",
(ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR),
Expand Down Expand Up @@ -278,7 +274,6 @@ def test_group_attn_matmul(
device.enable_async(False)


@skip_for_blackhole("Hanging on BH, see #12349")
@pytest.mark.parametrize("sharded", [False, True])
@pytest.mark.parametrize("output_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize("in1_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
Expand Down Expand Up @@ -373,7 +368,6 @@ def test_group_attn_matmul_with_program_cache(
device.enable_async(False)


@skip_for_blackhole("Hanging on BH, see #12349")
@pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32")
@pytest.mark.parametrize("in_dtype", [ttnn.float32, ttnn.bfloat16])
@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,14 @@ inline void llk_unpack_tilize_init(const std::uint32_t operand, const std::uint3
}

inline void llk_unpack_tilize_uninit(const std::uint32_t operand, const std::uint32_t face_r_dim = FACE_R_DIM) {
// Revert X dim value to default.
TT_SETADCXX(p_setadc::UNP_A, face_r_dim * FACE_C_DIM - 1, 0x0);
TT_SETADCXX(p_setadc::UNP_B, face_r_dim * FACE_C_DIM - 1, 0x0);

// Revert Z dim value back to default.
const uint Tile_z_dim = get_operand_num_faces(operand);
cfg_reg_rmw_tensix<THCON_SEC0_REG0_TileDescriptor_ADDR32+1, 16, 0xffff0000>(Tile_z_dim);

std::uint32_t operand_id = get_operand_id(operand);
unpack_config_u config = {0};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ void MAIN {
cb_wait_front(cb_in1, in1_block_num_tiles);
cb_pop_front(cb_in1, num_kv_heads_skip);

// This init changes DEST mapping, hence needs to be called before MATH does any processing, so that it has correct DEST mapping.
pack_untilize_dst_init_short<intermediate_num_tiles>(cb_intermed0);

for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) { // TODO: Must be 1; need to review inner dim blocking and the untilizing
uint32_t in1_index_subblock_offset = 0;

Expand Down Expand Up @@ -133,14 +136,12 @@ void MAIN {
in1_index_subblock_offset += out_subblock_w;
} // in1_num_subblocks loop
cb_pop_front(cb_in1, num_kv_heads_remaining);

// TODO: Review inner dim blocking, untilizing, and in1_num_subblocks > 1 (with pack_untilize, can only untilize up to dst num tiles)
// This should normally be inside subblock loop and we pack out out_subblock_num_tiles
pack_untilize_dst_init_short<intermediate_num_tiles>(cb_intermed0);
cb_reserve_back(cb_intermed0, intermediate_num_tiles);
tile_regs_wait();
pack_untilize_dst<intermediate_num_tiles>(cb_intermed0);
pack_untilize_uninit();
pack_untilize_uninit(cb_intermed0);

tile_regs_release();
cb_push_back(cb_intermed0, intermediate_num_tiles);
Expand Down

0 comments on commit 62e8824

Please sign in to comment.