From cb08e0e2345a942a63949429ad4eeab1fd0f8146 Mon Sep 17 00:00:00 2001
From: Jack Cai
Date: Wed, 23 Oct 2024 15:17:19 +0000
Subject: [PATCH] #12330: changed flash decode op to use simple shape and
 added optimizations

---
 .../kernels/compute/sdpa_flash_decode.cpp     |  5 +-
 .../kernels/dataflow/reader_decode_all.cpp    | 72 ++++++++-----------
 .../kernels/dataflow/writer_decode_all.cpp    |  5 +-
 .../sdpa_decode/device/sdpa_decode_op.cpp     | 22 +++---
 .../sdpa_decode/device/sdpa_decode_op.hpp     |  2 +-
 .../device/sdpa_decode_program_factory.cpp    |  8 +--
 .../transformer/sdpa_decode/sdpa_decode.cpp   |  6 +-
 7 files changed, 55 insertions(+), 65 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp
index 78045f01fce6..d21dc41a84fd 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp
@@ -425,8 +425,9 @@ void MAIN {
     }
 
     // Get cur_pos
-    uint32_t cur_pos = St*32-1; // default to non-causal, which we do attention on the entire kv cache. In this case we set cur_pos to the last position
-    if (is_causal) {
+    constexpr uint32_t cur_pos_base = St*32-1;
+    uint32_t cur_pos = cur_pos_base; // default to non-causal, where we attend to the entire KV cache; cur_pos is set to the last position
+    if constexpr (is_causal) {
         // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list
         if (cur_pos_arg != UINT32_MAX){
             cur_pos = cur_pos_arg;
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp
index 32a8f28cf757..5e82ebcfa62d 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp
@@ -27,6 +27,31 @@ uint32_t virtual_seq_tile_id_to_physical_tile_id(uint32_t seq_tile_idx, uint32_t
     return physical_block * block_stride + head_offset + block_offset;
 }
 
+template <uint32_t cb_mask_in, uint32_t mask_chunk_tiles, uint32_t mask_tile_bytes, uint32_t barrier_threshold, uint32_t PNHt, uint32_t Sk_chunk_t>
+uint32_t read_mask_chunk(uint32_t PSt, uint32_t mask_start_tile_id, const InterleavedAddrGenFast<true> mask_reader) {
+    // Read mask chunk
+    cb_reserve_back(cb_mask_in, mask_chunk_tiles);
+    uint32_t mask_write_ptr = get_write_ptr(cb_mask_in);
+    uint32_t barrier_count = 0;
+    for (uint32_t row = 0; row < PNHt; ++row) {
+        uint32_t mask_tile_id = mask_start_tile_id + row * PSt;
+        for (uint32_t col = 0; col < Sk_chunk_t; ++col) {
+            noc_async_read_tile(mask_tile_id, mask_reader, mask_write_ptr);
+            mask_tile_id++;
+            mask_write_ptr += mask_tile_bytes;
+
+            if (++barrier_count == barrier_threshold) {
+                noc_async_read_barrier();
+                barrier_count = 0;
+            }
+        }
+    }
+    noc_async_read_barrier();
+    cb_push_back(cb_mask_in, mask_chunk_tiles);
+    mask_start_tile_id += mask_chunk_tiles;
+    return mask_start_tile_id;
+}
+
 void kernel_main() {
     /*
     In DRAM, Q is (B, PNHt, DHt), K is (B, St, DHt), V is (B, St, DHt), mask is (B, PNHt, PSt)
@@ -74,8 +99,9 @@ void kernel_main() {
         return;
     }
     // Get cur_pos
-    uint32_t cur_pos = St*32-1; // default to non-causal, which we do attention on the entire kv cache. In this case we set cur_pos to the last position
-    if (is_causal) {
+    constexpr uint32_t cur_pos_base = St*32-1;
+    uint32_t cur_pos = cur_pos_base; // default to non-causal, where we attend to the entire KV cache; cur_pos is set to the last position
+    if constexpr (is_causal) {
         // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list
         if (cur_pos_arg != UINT32_MAX){
             cur_pos = cur_pos_arg;
@@ -247,26 +273,7 @@ void kernel_main() {
             cb_push_back(cb_k_in, k_chunk_tiles);
 
             if constexpr(use_attention_mask){
-                // Read mask chunk
-                cb_reserve_back(cb_mask_in, mask_chunk_tiles);
-                uint32_t mask_write_ptr = get_write_ptr(cb_mask_in);
-                barrier_count = 0;
-                for (uint32_t row = 0; row < PNHt; ++row) {
-                    uint32_t mask_tile_id = mask_start_tile_id + row * PSt;
-                    for (uint32_t col = 0; col < Sk_chunk_t; ++col) {
-                        noc_async_read_tile(mask_tile_id, mask_reader, mask_write_ptr);
-                        mask_tile_id++;
-                        mask_write_ptr += mask_tile_bytes;
-
-                        if (++barrier_count == barrier_threshold) {
-                            noc_async_read_barrier();
-                            barrier_count = 0;
-                        }
-                    }
-                }
-                noc_async_read_barrier();
-                cb_push_back(cb_mask_in, mask_chunk_tiles);
-                mask_start_tile_id += mask_chunk_tiles;
+                mask_start_tile_id = read_mask_chunk<cb_mask_in, mask_chunk_tiles, mask_tile_bytes, barrier_threshold, PNHt, Sk_chunk_t>(PSt, mask_start_tile_id, mask_reader);
             }
 
             // Read V chunk in row major order, write in row-major order
@@ -330,26 +337,7 @@ void kernel_main() {
             k_start_tile_id += k_chunk_tiles;
 
             if constexpr(use_attention_mask){
-                // Read mask chunk
-                cb_reserve_back(cb_mask_in, mask_chunk_tiles);
-                uint32_t mask_write_ptr = get_write_ptr(cb_mask_in);
-                barrier_count = 0;
-                for (uint32_t row = 0; row < PNHt; ++row) {
-                    uint32_t mask_tile_id = mask_start_tile_id + row * PSt;
-                    for (uint32_t col = 0; col < Sk_chunk_t; ++col) {
-                        noc_async_read_tile(mask_tile_id, mask_reader, mask_write_ptr);
-                        mask_tile_id++;
-                        mask_write_ptr += mask_tile_bytes;
-
-                        if (++barrier_count == barrier_threshold) {
-                            noc_async_read_barrier();
-                            barrier_count = 0;
-                        }
-                    }
-                }
-                noc_async_read_barrier();
-                cb_push_back(cb_mask_in, mask_chunk_tiles);
-                mask_start_tile_id += mask_chunk_tiles;
+                mask_start_tile_id = read_mask_chunk<cb_mask_in, mask_chunk_tiles, mask_tile_bytes, barrier_threshold, PNHt, Sk_chunk_t>(PSt, mask_start_tile_id, mask_reader);
             }
 
             // Read V chunk
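
Note on the hunks above: the mask-read loop that previously appeared twice inside kernel_main is hoisted into read_mask_chunk, whose chunk geometry (circular-buffer ID, tile counts, stride, barrier threshold) travels as non-type template parameters so it stays compile-time constant after the hoist. Because such parameters cannot be deduced from the run-time arguments, each call site lists them explicitly. Below is a minimal host-side sketch of the same pattern; the names are illustrative, not the kernel's real API:

    // Sketch: hoisting a duplicated loop into a helper whose bounds are
    // non-type template parameters, mirroring how read_mask_chunk keeps the
    // kernel's compile-time constants compile-time. Hypothetical names only.
    #include <cstdint>
    #include <cstdio>

    template <uint32_t rows, uint32_t cols>
    uint32_t read_chunk(uint32_t row_stride, uint32_t start_tile_id) {
        for (uint32_t r = 0; r < rows; ++r) {
            uint32_t tile_id = start_tile_id + r * row_stride;
            for (uint32_t c = 0; c < cols; ++c) {
                std::printf("read tile %u\n", tile_id++);  // stand-in for noc_async_read_tile
            }
        }
        return start_tile_id + rows * cols;  // advance the caller's cursor, as read_mask_chunk does
    }

    int main() {
        constexpr uint32_t PNHt = 2, Sk_chunk_t = 4;  // compile-time constants, as in the kernel
        uint32_t next = read_chunk<PNHt, Sk_chunk_t>(/*row_stride=*/8, /*start_tile_id=*/0);
        std::printf("next chunk starts at tile %u\n", next);
        return 0;
    }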
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp
index fdc083b391c5..4059ad847360 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp
@@ -263,9 +263,10 @@ void kernel_main() {
         return;
     }
     // Get cur_pos
-    uint32_t cur_pos = St*32-1; // default to non-causal, which we do attention on the entire kv cache. In this case we set cur_pos to the last position
+    constexpr uint32_t cur_pos_base = St*32-1;
+    uint32_t cur_pos = cur_pos_base; // default to non-causal, where we attend to the entire KV cache; cur_pos is set to the last position
     // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list
-    if (is_causal) {
+    if constexpr (is_causal) {
         if (cur_pos_arg != UINT32_MAX){
             cur_pos = cur_pos_arg;
         }
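
Note on the cur_pos changes (reader, writer, and compute kernels alike): the branch on is_causal is now resolved at compile time, and UINT32_MAX doubles as a sentinel meaning "cur_pos arrives via a tensor rather than as a scalar runtime arg". A standalone sketch of that resolution order, with the tensor-read branch being an inference from that comment:

    // Sketch of the decode kernels' cur_pos resolution. The cur_pos_tensor
    // branch is an assumption drawn from "not provided as a list"; the real
    // kernels read the position from an on-device buffer, not a raw pointer.
    #include <cstdint>
    #include <cstdio>

    template <bool is_causal, uint32_t St>
    uint32_t resolve_cur_pos(uint32_t cur_pos_arg, const uint32_t* cur_pos_tensor, uint32_t batch) {
        uint32_t cur_pos = St * 32 - 1;  // non-causal default: attend to the whole KV cache
        if constexpr (is_causal) {
            if (cur_pos_arg != UINT32_MAX) {
                cur_pos = cur_pos_arg;            // scalar position passed as a runtime arg
            } else if (cur_pos_tensor != nullptr) {
                cur_pos = cur_pos_tensor[batch];  // per-user position from the cur_pos tensor
            }
        }
        return cur_pos;
    }

    int main() {
        const uint32_t positions[2] = {17, 99};
        std::printf("%u\n", resolve_cur_pos<true, 4>(UINT32_MAX, positions, 1));  // prints 99
        std::printf("%u\n", resolve_cur_pos<false, 4>(UINT32_MAX, nullptr, 0));   // prints 127
        return 0;
    }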
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp
index 8baed071307b..b192917fe09d 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp
@@ -23,10 +23,10 @@ void ScaledDotProductAttentionDecode::validate(const std::vector<Tensor>& input_
             input_tensor.get_dtype());
     }
 
-    const auto q_shape = input_tensors.at(0).get_legacy_shape();
-    const auto q_shape_unpadded = input_tensors.at(0).get_shape();
-    const auto k_shape = input_tensors.at(1).get_legacy_shape();
-    const auto v_shape = input_tensors.at(2).get_legacy_shape();
+    const auto q_shape = input_tensors.at(0).get_padded_shape();
+    const auto q_shape_unpadded = input_tensors.at(0).get_logical_shape();
+    const auto k_shape = input_tensors.at(1).get_padded_shape();
+    const auto v_shape = input_tensors.at(2).get_padded_shape();
 
     // Input 0 must be sharded by height or DRAM interleaved. All other inputs must be in DRAM.
     const auto Q_memcfg = input_tensors.at(0).memory_config();
@@ -48,8 +48,8 @@ void ScaledDotProductAttentionDecode::validate(const std::vector<Tensor>& input_
     if (optional_input_tensors.at(2).has_value()){
         // Causal attention verification
         const auto& mask_tensor = optional_input_tensors.at(2).value();
-        const auto mask_shape = mask_tensor.get_legacy_shape();
-        const auto mask_shape_unpadded = mask_tensor.get_shape();
+        const auto mask_shape = mask_tensor.get_padded_shape();
+        const auto mask_shape_unpadded = mask_tensor.get_logical_shape();
 
         TT_FATAL(mask_shape[2] == q_shape[2], "Expect same number of padded heads in mask as in Q, got {} and {}", mask_shape[2], q_shape[2]);
         TT_FATAL(mask_shape_unpadded[2] == q_shape_unpadded[2], "Expect same number of heads in mask as in Q, got {} and {}", mask_shape_unpadded[3], q_shape_unpadded[2]);
@@ -78,7 +78,7 @@ void ScaledDotProductAttentionDecode::validate(const std::vector<Tensor>& input_
         const auto& cur_pos_tensor = optional_input_tensors.at(0).value();
         TT_FATAL(cur_pos_tensor.get_dtype() == DataType::INT32, "Expect cur_pos to be INT32, got {}", cur_pos_tensor.get_dtype());
         TT_FATAL(cur_pos_tensor.get_layout() == Layout::ROW_MAJOR, "Expect cur_pos to be ROW_MAJOR, got {}", cur_pos_tensor.get_layout());
-        const auto cur_pos_shape = cur_pos_tensor.get_legacy_shape();
+        const auto cur_pos_shape = cur_pos_tensor.get_padded_shape();
         TT_FATAL(cur_pos_shape[0] == B, "cur_pos must have batch size equal to Q, got {} and {}", cur_pos_shape[0], B);
     }
 
@@ -88,7 +88,7 @@ void ScaledDotProductAttentionDecode::validate(const std::vector<Tensor>& input_
         TT_FATAL(page_table_tensor.get_dtype() == DataType::INT32, "Error");
         TT_FATAL(page_table_tensor.get_layout() == Layout::ROW_MAJOR, "Error");
 
-        const auto page_table_shape = page_table_tensor.get_legacy_shape();
+        const auto page_table_shape = page_table_tensor.get_padded_shape();
 
         TT_FATAL(page_table_shape[0] == B, "page_table must have hidden size equal to Q");
 
@@ -149,9 +149,9 @@ void ScaledDotProductAttentionDecode::validate(const std::vector<Tensor>& input_
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> ScaledDotProductAttentionDecode::compute_output_shapes(
+std::vector<ttnn::SimpleShape> ScaledDotProductAttentionDecode::compute_output_shapes(
     const std::vector<Tensor>& input_tensors) const {
-    return {input_tensors.at(0).get_legacy_shape()};
+    return {input_tensors.at(0).get_padded_shape()};
 }
 
 std::vector<Tensor> ScaledDotProductAttentionDecode::create_output_tensors(
@@ -174,7 +174,7 @@ operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program(
 
     auto scale = this->scale;
     if (not scale.has_value()) {
-        scale = 1.0f / std::sqrt(static_cast<float>(input_tensor_q.get_legacy_shape()[-1]));
+        scale = 1.0f / std::sqrt(static_cast<float>(input_tensor_q.get_padded_shape()[-1]));
     }
 
     return detail::sdpa_decode_multi_core(
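
Note on the shape accessors: validate() now distinguishes get_padded_shape() (tile-aligned dimensions, checked against the padded head count) from get_logical_shape() (user-visible dimensions, checked against the logical head count), where get_legacy_shape()/get_shape() previously blurred that distinction. A self-contained illustration of the two views, using hypothetical sizes and the usual 32x32 tile alignment:

    // Sketch: logical vs padded shape under 32x32 tilization. The sizes are
    // made up; only the rounding rule reflects the op's data model.
    #include <array>
    #include <cstdint>
    #include <cstdio>

    constexpr uint32_t TILE = 32;
    constexpr uint32_t round_up_to_tile(uint32_t d) { return ((d + TILE - 1) / TILE) * TILE; }

    int main() {
        // Q in decode, indexed as the factory does: B = q_shape[1], PNH = q_shape[2]
        std::array<uint32_t, 4> logical = {1, 1, 8, 128};  // e.g. 8 heads of width 128
        std::array<uint32_t, 4> padded = logical;
        padded[2] = round_up_to_tile(logical[2]);  // 8 logical heads -> 32 padded rows
        padded[3] = round_up_to_tile(logical[3]);  // 128 is already tile-aligned
        std::printf("heads: logical=%u padded=%u\n", logical[2], padded[2]);
        // validate() compares mask_shape[2] == q_shape[2] on the padded view and
        // mask_shape_unpadded[2] == q_shape_unpadded[2] on the logical view.
        return 0;
    }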
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp
index 7993a1a96b2d..c055bbb77b9a 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp
@@ -27,7 +27,7 @@ struct ScaledDotProductAttentionDecode {
     void validate(const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
 
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
 
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
index 92d88c002f40..3d5ba1ae74fc 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
@@ -48,9 +48,9 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
 
     const bool is_paged_attention = page_table_tensor.has_value();
 
-    const auto q_shape = input_tensor_q.get_legacy_shape();
-    const auto q_shape_unpadded = input_tensor_q.get_shape();
-    const auto k_shape = input_tensor_k.get_legacy_shape();
+    const auto q_shape = input_tensor_q.get_padded_shape();
+    const auto q_shape_unpadded = input_tensor_q.get_logical_shape();
+    const auto k_shape = input_tensor_k.get_padded_shape();
     // Use k_shape for S and DH since Q might be different for decode
     uint32_t B = q_shape[1], PNH = q_shape[2], S = k_shape[2], DH = k_shape[3];
 
@@ -62,7 +62,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
         uint32_t block_size = k_shape[2];
         page_block_size_t = block_size / TILE_HEIGHT;
         // get real S using the page_table_tensor
-        S = page_table_tensor.value().get_legacy_shape()[-1]*S;
+        S = page_table_tensor.value().get_padded_shape()[-1]*S;
     }
     uint32_t Bkv = k_shape[0];
     uint32_t St = S/TILE_HEIGHT;
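
Note on the paged-attention sizing: in paged mode k_shape[2] holds the rows per page block rather than the full sequence length, and the page table's last dimension gives the number of blocks per user, so the factory reconstructs the real S as page_table_shape[-1] * block_size before converting to tiles. A worked example with made-up sizes:

    // Sketch of the factory's sequence-length arithmetic; the numbers are
    // hypothetical, the formulas come from the hunk above.
    #include <cstdint>
    #include <cstdio>

    constexpr uint32_t TILE_HEIGHT = 32;

    int main() {
        uint32_t block_size = 64;    // k_shape[2] in paged mode: rows per page block
        uint32_t num_blocks = 128;   // page_table shape[-1]
        uint32_t S = num_blocks * block_size;                    // 8192 rows of KV cache
        uint32_t St = S / TILE_HEIGHT;                           // 256 tiles along the sequence
        uint32_t page_block_size_t = block_size / TILE_HEIGHT;   // 2 tiles per page block
        std::printf("S=%u St=%u page_block_size_t=%u\n", S, St, page_block_size_t);
        return 0;
    }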
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp
index 6035b3876815..d1da3bc4c990 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp
@@ -9,7 +9,7 @@
 #include "ttnn/run_operation.hpp"
 
 namespace {
-uint32_t get_chunk_size(uint32_t s) {
+inline uint32_t get_chunk_size(uint32_t s) {
     /*
     # find maximum power of 2 divisor of s
     for i in range(1, s):
@@ -43,7 +43,7 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke(
     std::optional<DeviceComputeKernelConfig> compute_kernel_config) {
     auto arch = input_tensor_q.storage_type() == StorageType::DEVICE ? input_tensor_q.device()->arch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch();
-    uint32_t s = input_tensor_k.get_shape()[-2];
+    uint32_t s = input_tensor_k.get_logical_shape()[-2];
     uint32_t k_chunk_size = get_chunk_size(s);
     if (program_config.has_value() && program_config.value().k_chunk_size > 0) {
         k_chunk_size = program_config.value().k_chunk_size;
     }
@@ -118,7 +118,7 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke(
     std::optional<DeviceComputeKernelConfig> compute_kernel_config) {
     auto arch = input_tensor_q.storage_type() == StorageType::DEVICE ? input_tensor_q.device()->arch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch();
-    uint32_t s = input_tensor_k.get_shape()[-2];
+    uint32_t s = input_tensor_k.get_logical_shape()[-2];
     uint32_t k_chunk_size = get_chunk_size(s);
     if (program_config.has_value() && program_config.value().k_chunk_size > 0) {
         k_chunk_size = program_config.value().k_chunk_size;
     }
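
Note on get_chunk_size: the hunk shows only its pseudocode docstring ("find maximum power of 2 divisor of s"); the body itself is elided. One possible C++ rendering of that pseudocode, offered as a sketch rather than the function's actual implementation (the real function may additionally clamp the result to a maximum chunk size):

    // Sketch of the documented intent: the largest power-of-two divisor of s.
    // Assumes s > 0 (s is a tile-aligned sequence length in this op).
    #include <cstdint>
    #include <cstdio>
    #include <initializer_list>

    uint32_t max_pow2_divisor(uint32_t s) {
        uint32_t chunk = 1;
        while (s % (chunk * 2) == 0) {  // mirrors: if s % (2**(i+1)) != 0: break
            chunk *= 2;
        }
        return chunk;
    }

    int main() {
        for (uint32_t s : {32u, 96u, 448u}) {
            std::printf("s=%u -> chunk=%u\n", s, max_pow2_divisor(s));
        }
        return 0;
    }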