#12330: changed flash decode op to use simple shape and added optimizations
caixunshiren committed Oct 23, 2024
1 parent 1938e0a commit cb08e0e
Showing 7 changed files with 55 additions and 65 deletions.
@@ -425,8 +425,9 @@ void MAIN {
}

// Get cur_pos
-uint32_t cur_pos = St*32-1; // default to non-causal, which we do attention on the entire kv cache. In this case we set cur_pos to the last position
-if (is_causal) {
+constexpr uint32_t cur_pos_base = St*32-1;
+uint32_t cur_pos = cur_pos_base; // default to non-causal, which we do attention on the entire kv cache. In this case we set cur_pos to the last position
+if constexpr(is_causal) {
// using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list
if (cur_pos_arg != UINT32_MAX){
cur_pos = cur_pos_arg;
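The hunk above swaps a runtime if (is_causal) for if constexpr(is_causal) and hoists the non-causal default into a constexpr value, so the causal branch is compiled out entirely when it cannot be taken. Below is a minimal standalone sketch of that pattern; is_causal and St are illustrative compile-time constants standing in for the kernel's compile-time arguments, not the real kernel code.

#include <cstdint>
#include <cstdio>

constexpr bool is_causal = true;                 // stand-in for a kernel compile-time arg
constexpr uint32_t St = 4;                       // illustrative KV-cache length in tiles
constexpr uint32_t cur_pos_base = St * 32 - 1;   // same default as in the diff

uint32_t resolve_cur_pos(uint32_t cur_pos_arg) {
    uint32_t cur_pos = cur_pos_base;
    if constexpr (is_causal) {
        // UINT32_MAX flags "cur_pos not provided", mirroring the kernel's convention
        if (cur_pos_arg != UINT32_MAX) {
            cur_pos = cur_pos_arg;
        }
    }
    return cur_pos;  // non-causal instantiations reduce to just the default
}

int main() {
    std::printf("%u %u\n", resolve_cur_pos(UINT32_MAX), resolve_cur_pos(7));
    return 0;
}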
@@ -27,6 +27,31 @@ uint32_t virtual_seq_tile_id_to_physical_tile_id(uint32_t seq_tile_idx, uint32_t
return physical_block * block_stride + head_offset + block_offset;
}

+template<uint32_t cb_mask_in, uint32_t mask_chunk_tiles, uint32_t mask_tile_bytes, uint32_t barrier_threshold, uint32_t PNHt, uint32_t Sk_chunk_t>
+uint32_t read_mask_chunk(uint32_t PSt, uint32_t mask_start_tile_id, const InterleavedAddrGenFast<true> mask_reader) {
+// Read mask chunk
+cb_reserve_back(cb_mask_in, mask_chunk_tiles);
+uint32_t mask_write_ptr = get_write_ptr(cb_mask_in);
+uint32_t barrier_count = 0;
+for (uint32_t row = 0; row < PNHt; ++row) {
+uint32_t mask_tile_id = mask_start_tile_id + row * PSt;
+for (uint32_t col = 0; col < Sk_chunk_t; ++col) {
+noc_async_read_tile(mask_tile_id, mask_reader, mask_write_ptr);
+mask_tile_id++;
+mask_write_ptr += mask_tile_bytes;
+
+if (++barrier_count == barrier_threshold) {
+noc_async_read_barrier();
+barrier_count = 0;
+}
+}
+}
+noc_async_read_barrier();
+cb_push_back(cb_mask_in, mask_chunk_tiles);
+mask_start_tile_id += mask_chunk_tiles;
+return mask_start_tile_id;
+}

void kernel_main() {
/*
In DRAM, Q is (B, PNHt, DHt), K is (B, St, DHt), V is (B, St, DHt), mask is (B, PNHt, PSt)
@@ -74,8 +99,9 @@ void kernel_main() {
return;
}
// Get cur_pos
-uint32_t cur_pos = St*32-1; // default to non-causal, which we do attention on the entire kv cache. In this case we set cur_pos to the last position
-if (is_causal) {
+constexpr uint32_t cur_pos_base = St*32-1;
+uint32_t cur_pos = cur_pos_base; // default to non-causal, which we do attention on the entire kv cache. In this case we set cur_pos to the last position
+if constexpr(is_causal) {
// using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list
if (cur_pos_arg != UINT32_MAX){
cur_pos = cur_pos_arg;
@@ -247,26 +273,7 @@ void kernel_main() {
cb_push_back(cb_k_in, k_chunk_tiles);

if constexpr(use_attention_mask){
-// Read mask chunk
-cb_reserve_back(cb_mask_in, mask_chunk_tiles);
-uint32_t mask_write_ptr = get_write_ptr(cb_mask_in);
-barrier_count = 0;
-for (uint32_t row = 0; row < PNHt; ++row) {
-uint32_t mask_tile_id = mask_start_tile_id + row * PSt;
-for (uint32_t col = 0; col < Sk_chunk_t; ++col) {
-noc_async_read_tile(mask_tile_id, mask_reader, mask_write_ptr);
-mask_tile_id++;
-mask_write_ptr += mask_tile_bytes;
-
-if (++barrier_count == barrier_threshold) {
-noc_async_read_barrier();
-barrier_count = 0;
-}
-}
-}
-noc_async_read_barrier();
-cb_push_back(cb_mask_in, mask_chunk_tiles);
-mask_start_tile_id += mask_chunk_tiles;
+mask_start_tile_id = read_mask_chunk<cb_mask_in, mask_chunk_tiles, mask_tile_bytes, barrier_threshold, PNHt, Sk_chunk_t>(PSt, mask_start_tile_id, mask_reader);
}

// Read V chunk in row major order, write in row-major order
@@ -330,26 +337,7 @@ void kernel_main() {
k_start_tile_id += k_chunk_tiles;

if constexpr(use_attention_mask){
-// Read mask chunk
-cb_reserve_back(cb_mask_in, mask_chunk_tiles);
-uint32_t mask_write_ptr = get_write_ptr(cb_mask_in);
-barrier_count = 0;
-for (uint32_t row = 0; row < PNHt; ++row) {
-uint32_t mask_tile_id = mask_start_tile_id + row * PSt;
-for (uint32_t col = 0; col < Sk_chunk_t; ++col) {
-noc_async_read_tile(mask_tile_id, mask_reader, mask_write_ptr);
-mask_tile_id++;
-mask_write_ptr += mask_tile_bytes;
-
-if (++barrier_count == barrier_threshold) {
-noc_async_read_barrier();
-barrier_count = 0;
-}
-}
-}
-noc_async_read_barrier();
-cb_push_back(cb_mask_in, mask_chunk_tiles);
-mask_start_tile_id += mask_chunk_tiles;
+mask_start_tile_id = read_mask_chunk<cb_mask_in, mask_chunk_tiles, mask_tile_bytes, barrier_threshold, PNHt, Sk_chunk_t>(PSt, mask_start_tile_id, mask_reader);
}

// Read V chunk
@@ -263,9 +263,10 @@ void kernel_main() {
return;
}
// Get cur_pos
-uint32_t cur_pos = St*32-1; // default to non-causal, which we do attention on the entire kv cache. In this case we set cur_pos to the last position
+constexpr uint32_t cur_pos_base = St*32-1;
+uint32_t cur_pos = cur_pos_base; // default to non-causal, which we do attention on the entire kv cache. In this case we set cur_pos to the last position
// using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list
-if (is_causal) {
+if constexpr(is_causal) {
if (cur_pos_arg != UINT32_MAX){
cur_pos = cur_pos_arg;
}
@@ -23,10 +23,10 @@ void ScaledDotProductAttentionDecode::validate(const std::vector<Tensor>& input_
input_tensor.get_dtype());
}

-const auto q_shape = input_tensors.at(0).get_legacy_shape();
-const auto q_shape_unpadded = input_tensors.at(0).get_shape();
-const auto k_shape = input_tensors.at(1).get_legacy_shape();
-const auto v_shape = input_tensors.at(2).get_legacy_shape();
+const auto q_shape = input_tensors.at(0).get_padded_shape();
+const auto q_shape_unpadded = input_tensors.at(0).get_logical_shape();
+const auto k_shape = input_tensors.at(1).get_padded_shape();
+const auto v_shape = input_tensors.at(2).get_padded_shape();

// Input 0 must be sharded by height or DRAM interleaved. All other inputs must be in DRAM.
const auto Q_memcfg = input_tensors.at(0).memory_config();
@@ -48,8 +48,8 @@ void ScaledDotProductAttentionDecode::validate(const std::vector<Tensor>& input_
if (optional_input_tensors.at(2).has_value()){
// Causal attention verification
const auto& mask_tensor = optional_input_tensors.at(2).value();
-const auto mask_shape = mask_tensor.get_legacy_shape();
-const auto mask_shape_unpadded = mask_tensor.get_shape();
+const auto mask_shape = mask_tensor.get_padded_shape();
+const auto mask_shape_unpadded = mask_tensor.get_logical_shape();

TT_FATAL(mask_shape[2] == q_shape[2], "Expect same number of padded heads in mask as in Q, got {} and {}", mask_shape[2], q_shape[2]);
TT_FATAL(mask_shape_unpadded[2] == q_shape_unpadded[2], "Expect same number of heads in mask as in Q, got {} and {}", mask_shape_unpadded[3], q_shape_unpadded[2]);
@@ -78,7 +78,7 @@ void ScaledDotProductAttentionDecode::validate(const std::vector<Tensor>& input_
const auto& cur_pos_tensor = optional_input_tensors.at(0).value();
TT_FATAL(cur_pos_tensor.get_dtype() == DataType::INT32, "Expect cur_pos to be INT32, got {}", cur_pos_tensor.get_dtype());
TT_FATAL(cur_pos_tensor.get_layout() == Layout::ROW_MAJOR, "Expect cur_pos to be ROW_MAJOR, got {}", cur_pos_tensor.get_layout());
-const auto cur_pos_shape = cur_pos_tensor.get_legacy_shape();
+const auto cur_pos_shape = cur_pos_tensor.get_padded_shape();
TT_FATAL(cur_pos_shape[0] == B, "cur_pos must have batch size equal to Q, got {} and {}", cur_pos_shape[0], B);
}

@@ -88,7 +88,7 @@ void ScaledDotProductAttentionDecode::validate(const std::vector<Tensor>& input_
TT_FATAL(page_table_tensor.get_dtype() == DataType::INT32, "Error");
TT_FATAL(page_table_tensor.get_layout() == Layout::ROW_MAJOR, "Error");

-const auto page_table_shape = page_table_tensor.get_legacy_shape();
+const auto page_table_shape = page_table_tensor.get_padded_shape();

TT_FATAL(page_table_shape[0] == B, "page_table must have hidden size equal to Q");

@@ -149,9 +149,9 @@ void ScaledDotProductAttentionDecode::validate(const std::vector<Tensor>& input_
}
}

-std::vector<tt::tt_metal::LegacyShape> ScaledDotProductAttentionDecode::compute_output_shapes(
+std::vector<ttnn::SimpleShape> ScaledDotProductAttentionDecode::compute_output_shapes(
const std::vector<Tensor>& input_tensors) const {
-return {input_tensors.at(0).get_legacy_shape()};
+return {input_tensors.at(0).get_padded_shape()};
}

std::vector<Tensor> ScaledDotProductAttentionDecode::create_output_tensors(
@@ -174,7 +174,7 @@ operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program(

auto scale = this->scale;
if (not scale.has_value()) {
-scale = 1.0f / std::sqrt(static_cast<float>(input_tensor_q.get_legacy_shape()[-1]));
+scale = 1.0f / std::sqrt(static_cast<float>(input_tensor_q.get_padded_shape()[-1]));
}

return detail::sdpa_decode_multi_core(
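The shape-accessor changes in this file follow one rule: calls to get_legacy_shape() become explicit requests for the padded shape, and calls to get_shape() become requests for the logical (unpadded) shape. A hedged illustration of the distinction for a tile-padded head dimension; the numbers and the round-up helper are made up for this sketch, not taken from ttnn.

#include <array>
#include <cstdint>
#include <cstdio>

constexpr uint32_t TILE_HEIGHT = 32;

constexpr uint32_t round_up_to_tile(uint32_t n) {
    return ((n + TILE_HEIGHT - 1) / TILE_HEIGHT) * TILE_HEIGHT;
}

int main() {
    // e.g. 5 query heads, head_dim 128: the logical shape keeps the true head count,
    // the padded shape rounds it up to the 32-row tile boundary that kernels index against.
    std::array<uint32_t, 4> logical = {1, 1, 5, 128};
    std::array<uint32_t, 4> padded  = {1, 1, round_up_to_tile(5), 128};
    std::printf("logical heads %u, padded heads %u\n", logical[2], padded[2]);
    // validate() above checks both views: padded head counts must match between
    // mask and Q, and so must the unpadded (logical) head counts.
    return 0;
}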
@@ -27,7 +27,7 @@ struct ScaledDotProductAttentionDecode {
void validate(const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;

-std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;

std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;

@@ -48,9 +48,9 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(

const bool is_paged_attention = page_table_tensor.has_value();

-const auto q_shape = input_tensor_q.get_legacy_shape();
-const auto q_shape_unpadded = input_tensor_q.get_shape();
-const auto k_shape = input_tensor_k.get_legacy_shape();
+const auto q_shape = input_tensor_q.get_padded_shape();
+const auto q_shape_unpadded = input_tensor_q.get_logical_shape();
+const auto k_shape = input_tensor_k.get_padded_shape();
// Use k_shape for S and DH since Q might be different for decode
uint32_t B = q_shape[1], PNH = q_shape[2], S = k_shape[2], DH = k_shape[3];

@@ -62,7 +62,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
uint32_t block_size = k_shape[2];
page_block_size_t = block_size / TILE_HEIGHT;
// get real S using the page_table_tensor
-S = page_table_tensor.value().get_legacy_shape()[-1]*S;
+S = page_table_tensor.value().get_padded_shape()[-1]*S;
}
uint32_t Bkv = k_shape[0];
uint32_t St = S/TILE_HEIGHT;
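In the paged branch above, block_size comes from k_shape[2], the page table's last (padded) dimension gives blocks per user, and the full sequence length S is their product. A small worked example with assumed numbers (64 rows per block, 8 blocks per user), not values from any real configuration:

#include <cstdint>
#include <cstdio>

int main() {
    constexpr uint32_t TILE_HEIGHT = 32;
    uint32_t block_size      = 64;   // assumed k_shape[2] in the paged case
    uint32_t blocks_per_user = 8;    // assumed page_table padded shape[-1]

    uint32_t page_block_size_t = block_size / TILE_HEIGHT;    // 2 tiles per block
    uint32_t S  = blocks_per_user * block_size;               // 512 rows of KV cache
    uint32_t St = S / TILE_HEIGHT;                            // 16 tiles

    std::printf("S=%u St=%u page_block_size_t=%u\n", S, St, page_block_size_t);
    return 0;
}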
@@ -9,7 +9,7 @@
#include "ttnn/run_operation.hpp"

namespace {
-uint32_t get_chunk_size(uint32_t s) {
+inline uint32_t get_chunk_size(uint32_t s) {
/*
# find maximum power of 2 divisor of s
for i in range(1, s):
@@ -43,7 +43,7 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke(
std::optional<DeviceComputeKernelConfig> compute_kernel_config) {
auto arch = input_tensor_q.storage_type() == StorageType::DEVICE ? input_tensor_q.device()->arch()
: ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch();
-uint32_t s = input_tensor_k.get_shape()[-2];
+uint32_t s = input_tensor_k.get_logical_shape()[-2];
uint32_t k_chunk_size = get_chunk_size(s);
if (program_config.has_value() && program_config.value().k_chunk_size > 0) {
k_chunk_size = program_config.value().k_chunk_size;
@@ -118,7 +118,7 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke(
std::optional<DeviceComputeKernelConfig> compute_kernel_config) {
auto arch = input_tensor_q.storage_type() == StorageType::DEVICE ? input_tensor_q.device()->arch()
: ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch();
-uint32_t s = input_tensor_k.get_shape()[-2];
+uint32_t s = input_tensor_k.get_logical_shape()[-2];
uint32_t k_chunk_size = get_chunk_size(s);
if (program_config.has_value() && program_config.value().k_chunk_size > 0) {
k_chunk_size = program_config.value().k_chunk_size;
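The body of get_chunk_size is cut off in this view; its docstring says it finds the maximum power-of-two divisor of s, which invoke() then uses as the default k_chunk_size. A standalone sketch of that idea, based only on the docstring: the real function may also clamp the result, so treat this as an illustration rather than the actual implementation.

#include <cstdint>
#include <cstdio>

// Largest power of two that divides s, per the docstring above.
uint32_t max_pow2_divisor(uint32_t s) {
    if (s == 0) return 1;                        // guard; sequence lengths here are at least one tile
    uint32_t chunk = 1;
    while (chunk <= s / 2 && s % (chunk * 2) == 0) {
        chunk *= 2;
    }
    return chunk;
}

int main() {
    std::printf("%u %u %u\n",
                max_pow2_divisor(96),    // -> 32
                max_pow2_divisor(1024),  // -> 1024
                max_pow2_divisor(97));   // odd length -> 1
    return 0;
}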
