From 9bcea52f13abbe4ff14eaa942ef5a650299e8c40 Mon Sep 17 00:00:00 2001 From: Borys Bradel Date: Tue, 26 Nov 2024 14:44:10 +0000 Subject: [PATCH] #13875: fix tilize and attn matmul on BH --- .../unit_testing/misc/test_attn_matmul.py | 8 +------- .../blackhole/metal/llk_api/llk_unpack_tilize_api.h | 6 ++++++ .../kernels/compute/transformer_group_attn_matmul.cpp | 7 ++++--- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_attn_matmul.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_attn_matmul.py index ffb55fe0d72..68bdc6501ad 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_attn_matmul.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_attn_matmul.py @@ -8,7 +8,7 @@ import ttnn from models.utility_functions import comp_pcc -from models.utility_functions import is_grayskull, skip_for_blackhole +from models.utility_functions import is_grayskull import ttnn @@ -30,7 +30,6 @@ def generate_input_shapes(): yield [q_len, q_heads, batch_size, K], [batch_size, kv_heads, K, seq_len] -@skip_for_blackhole("Hanging on BH, see #12349") @pytest.mark.parametrize("in0_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @pytest.mark.parametrize("in1_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @pytest.mark.parametrize("out_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @@ -71,7 +70,6 @@ def test_attn_matmul(num_loops, enable_async, in0_dtype, in1_dtype, out_dtype, d device.enable_async(False) -@skip_for_blackhole("Hanging on BH, see #12349") @pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32") @pytest.mark.parametrize("in_dtype", [ttnn.float32, ttnn.bfloat16, ttnn.bfloat8_b]) @pytest.mark.parametrize( @@ -117,7 +115,6 @@ def test_attn_matmul_fp32(num_loops, enable_async, in_dtype, device): device.enable_async(False) -@skip_for_blackhole("Hanging on BH, see #12349") @pytest.mark.parametrize("in0_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @pytest.mark.parametrize("in1_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @pytest.mark.parametrize("out_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @@ -156,7 +153,6 @@ def test_attn_matmul_with_program_cache( device.enable_async(False) -@skip_for_blackhole("Hanging on BH, see #12349") @pytest.mark.parametrize( "shard_orientation", (ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR), @@ -278,7 +274,6 @@ def test_group_attn_matmul( device.enable_async(False) -@skip_for_blackhole("Hanging on BH, see #12349") @pytest.mark.parametrize("sharded", [False, True]) @pytest.mark.parametrize("output_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @pytest.mark.parametrize("in1_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @@ -373,7 +368,6 @@ def test_group_attn_matmul_with_program_cache( device.enable_async(False) -@skip_for_blackhole("Hanging on BH, see #12349") @pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32") @pytest.mark.parametrize("in_dtype", [ttnn.float32, ttnn.bfloat16]) @pytest.mark.parametrize( diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_unpack_tilize_api.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_unpack_tilize_api.h index a0aed4572a8..42b3965f246 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_unpack_tilize_api.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_unpack_tilize_api.h @@ -52,8 +52,14 @@ inline void llk_unpack_tilize_init(const std::uint32_t operand, const std::uint3 } inline void llk_unpack_tilize_uninit(const std::uint32_t operand, const std::uint32_t face_r_dim = FACE_R_DIM) { + // Revert X dim value to default. TT_SETADCXX(p_setadc::UNP_A, face_r_dim * FACE_C_DIM - 1, 0x0); TT_SETADCXX(p_setadc::UNP_B, face_r_dim * FACE_C_DIM - 1, 0x0); + + // Revert Z dim value back to default. + const uint Tile_z_dim = get_operand_num_faces(operand); + cfg_reg_rmw_tensix(Tile_z_dim); + std::uint32_t operand_id = get_operand_id(operand); unpack_config_u config = {0}; diff --git a/ttnn/cpp/ttnn/operations/experimental/matmul/group_attn_matmul/device/kernels/compute/transformer_group_attn_matmul.cpp b/ttnn/cpp/ttnn/operations/experimental/matmul/group_attn_matmul/device/kernels/compute/transformer_group_attn_matmul.cpp index 7fb014c9f08..4065bb9afdf 100644 --- a/ttnn/cpp/ttnn/operations/experimental/matmul/group_attn_matmul/device/kernels/compute/transformer_group_attn_matmul.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/matmul/group_attn_matmul/device/kernels/compute/transformer_group_attn_matmul.cpp @@ -75,6 +75,9 @@ void MAIN { cb_wait_front(cb_in1, in1_block_num_tiles); cb_pop_front(cb_in1, num_kv_heads_skip); + // This init changes DEST mapping, hence needs to be called before MATH does any processing, so that it has correct DEST mapping. + pack_untilize_dst_init_short(cb_intermed0); + for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) { // TODO: Must be 1; need to review inner dim blocking and the untilizing uint32_t in1_index_subblock_offset = 0; @@ -133,14 +136,12 @@ void MAIN { in1_index_subblock_offset += out_subblock_w; } // in1_num_subblocks loop cb_pop_front(cb_in1, num_kv_heads_remaining); - // TODO: Review inner dim blocking, untilizing, and in1_num_subblocks > 1 (with pack_untilize, can only untilize up to dst num tiles) // This should normally be inside subblock loop and we pack out out_subblock_num_tiles - pack_untilize_dst_init_short(cb_intermed0); cb_reserve_back(cb_intermed0, intermediate_num_tiles); tile_regs_wait(); pack_untilize_dst(cb_intermed0); - pack_untilize_uninit(); + pack_untilize_uninit(cb_intermed0); tile_regs_release(); cb_push_back(cb_intermed0, intermediate_num_tiles);