#13365: added program caching for page tensor for flash decode
caixunshiren committed Oct 2, 2024
1 parent f33dcd9 commit 3109518
Showing 5 changed files with 28 additions and 23 deletions.
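The pattern running through these changes is to move values that can differ between otherwise identical calls (the page-table page size and its buffer address) out of the kernels' compile-time arguments and the program hash, and into runtime arguments, so an already-compiled program can be reused from the cache with only its argument table patched. A generic sketch of that split follows; it is a hypothetical kernel and argument layout for illustration, not one of the files changed here.

#include "dataflow_api.h"

void kernel_main() {
    // Compile-time args are baked into the compiled kernel, so anything passed this way
    // must be stable across all calls that are meant to share one program-cache entry.
    constexpr uint32_t k_chunk_size = get_compile_time_arg_val(0);

    // Runtime args can be rewritten by the host on every launch (see the
    // override-runtime-arguments hunk further down), so buffer addresses and per-call
    // sizes may change without recompiling or re-hashing the program.
    uint32_t arg_idx = 0;
    const uint32_t page_table_addr = get_arg_val<uint32_t>(arg_idx++);
    const uint32_t page_table_page_size = get_arg_val<uint32_t>(arg_idx++);

    (void)k_chunk_size;
    (void)page_table_addr;
    (void)page_table_page_size;  // a real reader would use these to drive NoC reads
}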
@@ -753,6 +753,8 @@ def test_sdpa_decode_paged_attention(
sharded_out=False,
)

+ assert device.num_program_cache_entries() == 3


@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.parametrize(
@@ -896,6 +898,7 @@ def test_sdpa_decode_program_cache(device, b, nh, nkv, s, d, dtype, use_program_
sharded_in=False,
sharded_out=False,
start_indices=start_indices,
+ cur_pos_tensor=True,
)
run_test_sdpa_decode_single_iter(
device,
@@ -910,6 +913,7 @@ def test_sdpa_decode_program_cache(device, b, nh, nkv, s, d, dtype, use_program_
sharded_in=True,
sharded_out=False,
start_indices=start_indices,
+ cur_pos_tensor=False,
)
run_test_sdpa_decode_single_iter(
device,
@@ -924,6 +928,7 @@ def test_sdpa_decode_program_cache(device, b, nh, nkv, s, d, dtype, use_program_
sharded_in=True,
sharded_out=True,
start_indices=start_indices,
+ cur_pos_tensor=False,
)
run_test_sdpa_decode_single_iter(
device,
@@ -938,6 +943,7 @@ def test_sdpa_decode_program_cache(device, b, nh, nkv, s, d, dtype, use_program_
sharded_in=False,
sharded_out=True,
start_indices=start_indices,
+ cur_pos_tensor=True,
)

assert device.num_program_cache_entries() == 4
@@ -42,24 +42,22 @@ void kernel_main() {
constexpr bool is_q_sharded = get_compile_time_arg_val(6);
constexpr uint32_t num_cores_per_batch = get_compile_time_arg_val(7);
constexpr uint32_t k_chunk_size = get_compile_time_arg_val(8);
- constexpr uint32_t log_base_2_of_page_size = get_compile_time_arg_val(9);
- constexpr uint32_t index_stick_size_B = get_compile_time_arg_val(10);
- constexpr bool is_paged_attention = get_compile_time_arg_val(11) == 1;
- constexpr uint32_t num_kv_heads = get_compile_time_arg_val(12);
- constexpr uint32_t block_size_t = get_compile_time_arg_val(13);
- constexpr uint32_t log2_page_table_page_size = get_compile_time_arg_val(14);
- constexpr uint32_t page_table_page_size = get_compile_time_arg_val(15);
- constexpr uint32_t Bkv = get_compile_time_arg_val(16);
- constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(17);
- constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(18);
- constexpr uint32_t num_output_cores = get_compile_time_arg_val(19);
+ constexpr uint32_t index_stick_size_B = get_compile_time_arg_val(9);
+ constexpr bool is_paged_attention = get_compile_time_arg_val(10) == 1;
+ constexpr uint32_t num_kv_heads = get_compile_time_arg_val(11);
+ constexpr uint32_t block_size_t = get_compile_time_arg_val(12);
+ constexpr uint32_t Bkv = get_compile_time_arg_val(13);
+ constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(14);
+ constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(15);
+ constexpr uint32_t num_output_cores = get_compile_time_arg_val(16);

uint32_t arg_idx = 0;
const uint32_t q_addr = get_arg_val<uint32_t>(arg_idx++);
const uint32_t k_addr = get_arg_val<uint32_t>(arg_idx++);
const uint32_t v_addr = get_arg_val<uint32_t>(arg_idx++);
const uint32_t pos_addr = get_arg_val<uint32_t>(arg_idx++);
const uint32_t page_table_addr = get_arg_val<uint32_t>(arg_idx++);
+ const uint32_t page_table_page_size = get_arg_val<uint32_t>(arg_idx++);
const bool is_worker = get_arg_val<uint32_t>(arg_idx++) == 0;
const bool is_output_core = get_arg_val<uint32_t>(arg_idx++) == 1;
const uint32_t cur_head_group = get_arg_val<uint32_t>(arg_idx++);
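With the page-table page size delivered as a runtime argument, the reader no longer needs a power-of-two page size fixed at compile time to address the page table. A minimal sketch of reading one page-table row under that scheme is below; the helper name, circular-buffer id, and DRAM placement are assumptions for illustration and are not taken from the actual reader kernel.

#include "dataflow_api.h"

// Hypothetical helper: fetch one page-table row whose byte size is only known at
// runtime, e.g. the page_table_page_size runtime argument read above.
void read_page_table_row(uint32_t page_table_addr, uint32_t page_table_page_size,
                         uint32_t row_id, uint32_t cb_id_page_table) {
    const InterleavedAddrGen<true> page_table_gen = {
        .bank_base_address = page_table_addr,
        .page_size = page_table_page_size};  // arbitrary runtime value, no log2 restriction

    cb_reserve_back(cb_id_page_table, 1);
    uint32_t l1_write_addr = get_write_ptr(cb_id_page_table);
    uint64_t src_noc_addr = get_noc_addr(row_id, page_table_gen);  // row index -> NoC address
    noc_async_read(src_noc_addr, l1_write_addr, page_table_page_size);
    noc_async_read_barrier();
    cb_push_back(cb_id_page_table, 1);
}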
@@ -156,10 +156,6 @@ operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program(
scale = 1.0f / std::sqrt(static_cast<float>(input_tensor_q.get_legacy_shape()[-1]));
}

- // TODO: get this from program_config
- std::size_t q_chunk_size = program_config.has_value() ? program_config->q_chunk_size : 32;
- std::size_t k_chunk_size = program_config.has_value() ? program_config->k_chunk_size : 32;

return detail::sdpa_decode_multi_core(
input_tensor_q,
input_tensor_k,
@@ -183,8 +179,7 @@ operation::Hash ScaledDotProductAttentionDecode::compute_program_hash(const std:
this->compute_kernel_config,
this->k_chunk_size,
this->paged_attention,
- input_tensors,
- optional_input_tensors);
+ input_tensors);
}

} // namespace ttnn::operations::transformer
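Dropping optional_input_tensors from the hash is what makes the cache hit possible: the cur_pos and page-table tensors may change between calls, and those changes are now reconciled through runtime arguments rather than by compiling a new program. A rough sketch of the idea on a hypothetical op (not the real attribute list of ScaledDotProductAttentionDecode) is:

// Hypothetical op for illustration; the Tensor type, operation::Hash, and
// operation::hash_operation come from the surrounding ttnn/tt-metal headers.
struct MyDecodeOp {
    uint32_t k_chunk_size;
    bool paged_attention;

    operation::Hash compute_program_hash(
        const std::vector<Tensor>& input_tensors,
        const std::vector<std::optional<const Tensor>>& /*optional_input_tensors*/) const {
        // Hash only op attributes and the mandatory inputs; the optional cur_pos and
        // page_table tensors are left out so that per-call changes to them reuse the
        // cached program and are applied via the runtime-argument override instead.
        return operation::hash_operation<MyDecodeOp>(k_chunk_size, paged_attention, input_tensors);
    }
};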
@@ -58,10 +58,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
uint32_t page_block_size_t = 0;

if (is_paged_attention) {
- const auto page_table_shape = page_table_tensor.value().get_legacy_shape();
- uint32_t max_blocks_per_seq = page_table_shape[1];
uint32_t block_size = k_shape[2];
- S = max_blocks_per_seq * block_size;
page_block_size_t = block_size / TILE_HEIGHT;
}
uint32_t Bkv = k_shape[0];
@@ -507,9 +504,9 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(

std::vector<uint32_t> reader_compile_time_args_common = {
B, PNHt, St, DHt, Sk_chunk_t, num_active_cores, is_q_sharded,
- num_cores_per_batch, k_chunk_size, log2_page_size, index_stick_size,
+ num_cores_per_batch, k_chunk_size, index_stick_size,
(uint32_t)is_paged_attention, num_kv_heads, page_block_size_t,
- log2_page_table_page_size, page_table_stick_size, Bkv, num_cores_per_head, num_heads_per_core, num_output_cores
+ Bkv, num_cores_per_head, num_heads_per_core, num_output_cores
};

std::vector<uint32_t> writer_compile_time_args_common = {
@@ -613,7 +610,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
log_debug("cur_pos: {}", cur_pos);

// reader runtime args
- std::vector<uint32_t> reader_rt_args = { q_addr, k_addr, v_addr, pos_addr, page_table_addr, do_reduce, do_output, cur_head, cur_batch, core_num_in_reduce, core_num_in_output, cur_pos};
+ std::vector<uint32_t> reader_rt_args = { q_addr, k_addr, v_addr, pos_addr, page_table_addr, page_table_stick_size, do_reduce, do_output, cur_head, cur_batch, core_num_in_reduce, core_num_in_output, cur_pos};
reader_rt_args.insert(reader_rt_args.end(), output_core_physical_xs.begin(), output_core_physical_xs.end());
reader_rt_args.insert(reader_rt_args.end(), output_core_physical_ys.begin(), output_core_physical_ys.end());

@@ -684,6 +681,8 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
uint32_t v_addr = v_buffer->address();
uint32_t pos_addr = use_cur_pos_tensor ? optional_input_tensors.at(0).value().buffer()->address() : 0;
uint32_t page_table_addr = is_paged_attention ? optional_input_tensors.at(1).value().buffer()->address() : 0;
+ auto page_table_buffer = is_paged_attention ? optional_input_tensors.at(1).value().buffer() : nullptr;
+ uint32_t page_table_stick_size = is_paged_attention ? page_table_buffer->aligned_page_size() : 0;
uint32_t out_addr = out0_buffer->address();

auto& reader_args_by_core = GetRuntimeArgs(program, reader_kernels_id);
@@ -714,6 +713,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
reader_args[arg_idx++] = v_addr;
reader_args[arg_idx++] = pos_addr;
reader_args[arg_idx++] = page_table_addr;
+ reader_args[arg_idx++] = page_table_stick_size;
reader_args[arg_idx++] = do_reduce;
reader_args[arg_idx++] = do_output;
reader_args[arg_idx++] = cur_head;
@@ -97,6 +97,12 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke(
: ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch();
//uint32_t max_cur_pos = *std::max_element(cur_pos.begin(), cur_pos.end());
uint32_t k_chunk_size = 512; //get_chunk_size(max_cur_pos + 1);
+ if (program_config.has_value() && program_config.value().k_chunk_size > 0) {
+     k_chunk_size = program_config.value().k_chunk_size;
+     // assert chunk size must be power of 2 and multiple of 32
+     TT_FATAL((k_chunk_size & (k_chunk_size - 1)) == 0, "User provided k_chunk_size must be power of 2, got: {}", k_chunk_size);
+     TT_FATAL(k_chunk_size % 32 == 0, "User provided k_chunk_size must be multiple of 32, got: {}", k_chunk_size);
+ }

// get chunk size and then pass to sdpa decode as an attribute for prgm cache
auto kernel_config_val = init_device_compute_kernel_config(
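The two TT_FATAL checks above gate a user-supplied k_chunk_size. The bit trick works because a power of two has exactly one set bit, so subtracting one sets every lower bit and the AND clears to zero. A standalone demonstration in plain C++, outside the codebase:

#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
    for (uint32_t k_chunk_size : {32u, 128u, 512u}) {
        // e.g. 512 = 0b10'0000'0000 and 511 = 0b01'1111'1111, so 512 & 511 == 0
        assert((k_chunk_size & (k_chunk_size - 1)) == 0);  // power of two
        assert(k_chunk_size % 32 == 0);                    // whole number of 32-wide tiles
    }
    // 96 is a multiple of 32 but not a power of two, so the first check would reject it
    assert((96u & (96u - 1)) != 0);
    return 0;
}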
