#14517: Porting from defines to compile time arguments in kernel code
jvegaTT committed Nov 8, 2024
1 parent 3231aa2 commit dc21b52
Showing 4 changed files with 50 additions and 42 deletions.
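
In broad strokes, the commit moves the tile sizes out of host-injected preprocessor defines (INPUT_TILE_SIZE / OUTPUT_TILE_SIZE) and into the kernels' compile-time argument lists, which is why the SHARDED_MEM_LAYOUT argument indices that follow them shift upward. A minimal sketch of the pattern, assuming the setup shown in the hunks below (worker_ct_args is a placeholder name; the vector contents and indices are illustrative):

    // Host side, before: value handed to the kernel as a define
    // worker_defines["OUTPUT_TILE_SIZE"] = std::to_string(output_tensor_config->get_tile_size());

    // Host side, after: value appended to the kernel's compile-time args
    std::vector<uint32_t> worker_ct_args = {
        /* existing args at indices 0..6 */
        static_cast<uint32_t>(output_tensor_config->get_tile_size()),  // index 7
    };

    // Kernel side, after: read the value back as a constexpr compile-time arg
    constexpr uint32_t output_tile_size = get_compile_time_arg_val(7);
    InterleavedAddrGenFast<dst_is_dram, output_tile_size> d = {
        .bank_base_address = dst_addr,
        .page_size = output_page_size,
        .data_format = in0_df};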
@@ -53,18 +53,20 @@ void kernel_main() {
constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(4);
constexpr uint32_t ring_size = get_compile_time_arg_val(5);
constexpr bool fuse_op = get_compile_time_arg_val(6);
+constexpr uint32_t output_tile_size = get_compile_time_arg_val(7);


ASSERT(half_cb_n_pages > rem_num_pages);

#ifdef SHARDED_MEM_LAYOUT
-constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(7));
-constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(8);
-constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(9);
-constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(10);
-constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(11);
-constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(12);
-constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(13);
-constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(14) != 0;
+constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(8));
+constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(9);
+constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(10);
+constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(11);
+constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(12);
+constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(13);
+constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(14);
+constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(15) != 0;
#endif

constexpr uint32_t cb_id_in0 = tt::CB::c_in0;
@@ -94,7 +96,7 @@ void kernel_main() {
#ifdef INTERLEAVED_MEM_LAYOUT
const DataFormat in0_df = get_dataformat(cb_id_in0);

-InterleavedAddrGenFast<dst_is_dram,OUTPUT_TILE_SIZE> d = {
+InterleavedAddrGenFast<dst_is_dram,output_tile_size> d = {
.bank_base_address = dst_addr,
.page_size = output_page_size,
.data_format = in0_df
@@ -55,25 +55,27 @@ void kernel_main() {
uint32_t sem_addr = get_semaphore(get_compile_time_arg_val(5));
constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(6);
constexpr uint32_t ring_size = get_compile_time_arg_val(7);
+constexpr uint32_t input_tile_size = get_compile_time_arg_val(8);
+constexpr uint32_t output_tile_size = get_compile_time_arg_val(9);
#ifdef SHARDED_MEM_LAYOUT

-constexpr tt::tt_metal::TensorMemoryLayout input_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(8));
-constexpr uint32_t input_tensor_shard_grid_height = get_compile_time_arg_val(9);
-constexpr uint32_t input_tensor_shard_grid_width = get_compile_time_arg_val(10);
-constexpr uint32_t input_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(11);
-constexpr uint32_t input_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(12);
-constexpr uint32_t input_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(13);
-constexpr uint32_t input_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(14);
-constexpr bool input_tensor_shard_grid_transposed = get_compile_time_arg_val(15) != 0;
-
-constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(16));
-constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(17);
-constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(18);
-constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(19);
-constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(20);
-constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(21);
-constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(22);
-constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(23) != 0;
+constexpr tt::tt_metal::TensorMemoryLayout input_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(10));
+constexpr uint32_t input_tensor_shard_grid_height = get_compile_time_arg_val(11);
+constexpr uint32_t input_tensor_shard_grid_width = get_compile_time_arg_val(12);
+constexpr uint32_t input_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(13);
+constexpr uint32_t input_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(14);
+constexpr uint32_t input_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(15);
+constexpr uint32_t input_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(16);
+constexpr bool input_tensor_shard_grid_transposed = get_compile_time_arg_val(17) != 0;
+
+constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(18));
+constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(19);
+constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(20);
+constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(21);
+constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(22);
+constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(23);
+constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(24);
+constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(25) != 0;
#endif

ASSERT(half_cb_n_pages > rem_num_pages);
@@ -121,13 +123,13 @@ void kernel_main() {
#ifdef INTERLEAVED_MEM_LAYOUT
const DataFormat in0_df = get_dataformat(cb_id_in0);

-const InterleavedAddrGenFast<src_is_dram, INPUT_TILE_SIZE> s = {
+const InterleavedAddrGenFast<src_is_dram, input_tile_size> s = {
.bank_base_address = src_addr,
.page_size = page_size,
.data_format = in0_df
};

-InterleavedAddrGenFast<dst_is_dram, OUTPUT_TILE_SIZE> d = {
+InterleavedAddrGenFast<dst_is_dram, output_tile_size> d = {
.bank_base_address = dst_addr,
.page_size = output_page_size,
.data_format = in0_df
@@ -55,18 +55,19 @@ void kernel_main() {
volatile uint32_t *const writer_send_sem_ptr = reinterpret_cast<volatile uint32_t *const >(get_semaphore(get_compile_time_arg_val(4)));
constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(5);
constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(6);
+constexpr uint32_t output_tile_size = get_compile_time_arg_val(7);

ASSERT(half_cb_n_pages > rem_num_pages);

#ifdef SHARDED_MEM_LAYOUT
-constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(7));
-constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(8);
-constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(9);
-constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(10);
-constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(11);
-constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(12);
-constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(13);
-constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(14) != 0;
+constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(8));
+constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(9);
+constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(10);
+constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(11);
+constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(12);
+constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(13);
+constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(14);
+constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(15) != 0;
#endif

constexpr uint32_t cb_id_in0 = tt::CB::c_in0;
@@ -95,7 +96,7 @@ void kernel_main() {
#ifdef INTERLEAVED_MEM_LAYOUT
const DataFormat in0_df = get_dataformat(cb_id_in0);

-const InterleavedAddrGenFast<dst_is_dram, OUTPUT_TILE_SIZE> d = {
+const InterleavedAddrGenFast<dst_is_dram, output_tile_size> d = {
.bank_base_address = dst_addr,
.page_size = output_page_size,
.data_format = in0_df
@@ -290,8 +290,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
tt::DataFormat df = datatype_to_dataformat_converter(input_tensor.get_dtype());

std::map<string, string> worker_defines;
worker_defines["INPUT_TILE_SIZE"] = std::to_string(input_tensor_config->get_tile_size());
worker_defines["OUTPUT_TILE_SIZE"] = std::to_string(output_tensor_config->get_tile_size());
if (rm) {
worker_defines["ROW_MAJOR_LAYOUT"] = "1";
} else {
@@ -371,6 +369,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
log_trace(tt::LogOp, "input_page_size: {}", input_page_size);
uint32_t src0_cb_index = tt::CB::c_in0;
const uint32_t cb_n_packets = 2;
+const uint32_t cb_size_in_pages = cb_n_packets * max_pages_per_chunk;
const uint32_t CB_buffer_size = cb_n_packets * max_buffer_per_chunk;
log_trace(tt::LogOp, "max_pages_per_chunk: {}", max_pages_per_chunk);
CircularBufferConfig cb_src0_config = CircularBufferConfig(CB_buffer_size, {{src0_cb_index, df}})
@@ -398,7 +397,9 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
static_cast<uint32_t>(ring_index),
static_cast<uint32_t>(sender_worker_reader_semaphore_id),
static_cast<uint32_t>(max_pages_per_chunk),
-static_cast<uint32_t>(ring_size)
+static_cast<uint32_t>(ring_size),
+static_cast<uint32_t>(input_tensor_config->get_tile_size()),
+static_cast<uint32_t>(output_tensor_config->get_tile_size())
};
if (is_sharded) {
emit_sharded_tensor_kernel_ct_args(device, input_tensor, worker_reader_sender_ct_args, input_pages_per_shard_y, input_pages_per_shard_x);
@@ -442,6 +443,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
static_cast<uint32_t>(sender_worker_writer_semaphore_id),
static_cast<uint32_t>(max_pages_per_chunk),
static_cast<uint32_t>(num_edm_buffers_per_channel),
+static_cast<uint32_t>(output_tensor_config->get_tile_size())
};

if (is_sharded) {
@@ -503,7 +505,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
static_cast<uint32_t>(sender_worker_reader_semaphore_id),
static_cast<uint32_t>(max_pages_per_chunk),
static_cast<uint32_t>(ring_size),
-static_cast<bool>(fuse_op)
+static_cast<bool>(fuse_op),
+static_cast<uint32_t>(output_tensor_config->get_tile_size())
};

if (is_sharded) {
@@ -690,7 +693,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
}
if (rem_pages != 0) {
rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_workers_per_link()) = rem_pages;
-TT_ASSERT(rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_workers_per_link()) <= max_pages_per_chunk * 2);
+TT_ASSERT(rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_workers_per_link()) <= cb_size_in_pages);
}
{ // Logging
log_trace(tt::LogOp, "num_full_chunks, remaining pages per worker (clockwise):");