Skip to content

Commit

Permalink
#8865: Optimize bcast_h and bcast_w binary kernel override_runtime_arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
nemanjagrujic committed Aug 15, 2024
1 parent 76f94a4 commit d45096f
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 102 deletions.
10 changes: 8 additions & 2 deletions tests/tt_eager/python_api_testing/sweep_tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,11 +668,17 @@ def _gen_tt_nn_rmsnorm_shapes(shape):

def _gen_tt_nn_bcast_shapes(shape):
shape_type = random.randint(0, 2)
second_shape = shape.copy()

if shape_type == 0:
second_shape = [1]
second_shape[-2] = 1
second_shape[-1] = 1
elif shape_type == 1:
second_shape = [shape[-1]]
second_shape[-2] = shape[-2]
second_shape[-1] = 1
elif shape_type == 2:
second_shape[-2] = 1
second_shape[-1] = shape[-1]
# elif shape_type == 2:
# second_shape = [shape[-2], shape[-1]]
# elif shape_type == 3:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,68 +258,75 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCore::override_runtime_argument

auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);

auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id);
auto& cached_eltwise_args = GetRuntimeArgs(program, bcast_kernel_id);
auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id);

for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_total; i++) {
const CoreCoord& core = cores.at(i);
uint32_t Ht_per_core;

auto& binary_reader_args = cached_reader_args.at(core.x).at(core.y);
auto& bcast_kernel_args = cached_eltwise_args.at(core.x).at(core.y);
auto& unary_writer_args = cached_writer_args.at(core.x).at(core.y);

if (core_group_1.core_coord_in_core_ranges(core)) {
Ht_per_core = Ht_per_core_group_1;
} else if (core_group_2.core_coord_in_core_ranges(core)) {
Ht_per_core = Ht_per_core_group_2;
} else {
tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector<uint32_t>(15, 0));
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector<uint32_t>(3, 0));
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector<uint32_t>(9, 0));
binary_reader_args[3] = 0;
binary_reader_args[7] = 0;
binary_reader_args[8] = 0;
binary_reader_args[9] = 0;
binary_reader_args[10] = 0;
binary_reader_args[11] = 0;
binary_reader_args[12] = 0;
binary_reader_args[13] = 0;
binary_reader_args[14] = 0;

bcast_kernel_args[0] = 0;
bcast_kernel_args[1] = 0;
bcast_kernel_args[2] = 0;

unary_writer_args[3] = 0;
unary_writer_args[4] = 0;
unary_writer_args[5] = 0;
unary_writer_args[7] = 0;
unary_writer_args[8] = 0;
continue;
}
uint32_t num_tensor_tiles_per_core = NC * Ht_per_core * Wt;

tt_metal::SetRuntimeArgs(
program,
binary_reader_kernel_id,
core,
{
src_dram_buffer_a->address(), // 0
0, // 1
0, // 2
num_tensor_tiles_per_core, // 3
src_dram_buffer_b->address(), // 4
0, // 5
0, // 6
num_btensor_tiles, // 7
num_tensor_tiles_per_core, // 8
NC, // 9
Ht_per_core, // 10
Wt, // 11
bnc1, // 12
num_Wtiles_read, // 13
Ht * Wt, // 14
});

tt_metal::SetRuntimeArgs(
program,
bcast_kernel_id,
core,
{
NC, // B
Ht_per_core, // Ht
Wt // Wt
});

tt_metal::SetRuntimeArgs(
program,
unary_writer_kernel_id,
core,
{
dst_dram_buffer->address(),
0,
0,
Ht_per_core,
Wt,
num_Wtiles_read,
0,
NC,
Ht * Wt,
});
binary_reader_args[0] = src_dram_buffer_a->address();
// binary_reader_args[1] = 0;
// binary_reader_args[2] = 0;
binary_reader_args[3] = num_tensor_tiles_per_core;
binary_reader_args[4] = src_dram_buffer_b->address();
// binary_reader_args[5] = 0;
// binary_reader_args[6] = 0;
binary_reader_args[7] = num_btensor_tiles;
binary_reader_args[8] = num_tensor_tiles_per_core;
binary_reader_args[9] = NC;
binary_reader_args[10] = Ht_per_core;
binary_reader_args[11] = Wt;
binary_reader_args[12] = bnc1;
binary_reader_args[13] = num_Wtiles_read;
binary_reader_args[14] = Ht * Wt;

bcast_kernel_args[0] = NC;
bcast_kernel_args[1] = Ht_per_core;
bcast_kernel_args[2] = Wt;

unary_writer_args[0] = dst_dram_buffer->address();
// unary_writer_args[1] = 0;
// unary_writer_args[2] = 0;
unary_writer_args[3] = Ht_per_core;
unary_writer_args[4] = Wt;
unary_writer_args[5] = num_Wtiles_read;
// unary_writer_args[6] = 0;
unary_writer_args[7] = NC;
unary_writer_args[8] = Ht * Wt;

num_Wtiles_read += Ht_per_core * Wt;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,70 +258,79 @@ void BinaryDeviceOperation::BroadcastWidthMultiCore::override_runtime_arguments(

auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);

auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id);
auto& cached_eltwise_args = GetRuntimeArgs(program, bcast_kernel_id);
auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id);

for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_total; i++) {
const CoreCoord& core = cores.at(i);
uint32_t Wt_per_core;

auto& binary_reader_args = cached_reader_args.at(core.x).at(core.y);
auto& bcast_kernel_args = cached_eltwise_args.at(core.x).at(core.y);
auto& unary_writer_args = cached_writer_args.at(core.x).at(core.y);

if (core_group_1.core_coord_in_core_ranges(core)) {
Wt_per_core = Wt_per_core_group_1;
} else if (core_group_2.core_coord_in_core_ranges(core)) {
Wt_per_core = Wt_per_core_group_2;
} else {
tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector<uint32_t>(16, 0));
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector<uint32_t>(3, 0));
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector<uint32_t>(9, 0));
binary_reader_args[3] = 0;
binary_reader_args[7] = 0;
binary_reader_args[8] = 0;
binary_reader_args[9] = 0;
binary_reader_args[10] = 0;
binary_reader_args[11] = 0;
binary_reader_args[12] = 0;
binary_reader_args[13] = 0;
binary_reader_args[14] = 0;
binary_reader_args[15] = 0;

bcast_kernel_args[0] = 0;
bcast_kernel_args[1] = 0;
bcast_kernel_args[2] = 0;

unary_writer_args[3] = 0;
unary_writer_args[4] = 0;
unary_writer_args[5] = 0;
unary_writer_args[7] = 0;
unary_writer_args[8] = 0;
continue;
}
uint32_t num_tensor_tiles_per_core = NC * Ht * Wt_per_core;
uint32_t Wt_skip = Wt - Wt_per_core;

tt_metal::SetRuntimeArgs(
program,
binary_reader_kernel_id,
core,
{
src_dram_buffer_a->address(), // 0
0, // 1
0, // 2
num_tensor_tiles_per_core, // 3
src_dram_buffer_b->address(), // 4
0, // 5
0, // 6
num_btensor_tiles, // 7
num_tensor_tiles_per_core, // 8
NC, // 9
Ht, // 10
Wt_per_core, // 11
bnc1, // 12
num_Wtiles_read, // 13
Ht * Wt, // 14
Wt_skip, // 15
});

tt_metal::SetRuntimeArgs(
program,
bcast_kernel_id,
core,
{
NC, // B
Ht, // Ht
Wt_per_core // Wt
});
binary_reader_args[0] = src_dram_buffer_a->address();
// binary_reader_args[1] = 0;
// binary_reader_args[2] = 0;
binary_reader_args[3] = num_tensor_tiles_per_core;
binary_reader_args[4] = src_dram_buffer_b->address();
// binary_reader_args[5] = 0;
// binary_reader_args[6] = 0;
binary_reader_args[7] = num_btensor_tiles;
binary_reader_args[8] = num_tensor_tiles_per_core;
binary_reader_args[9] = NC;
binary_reader_args[10] = Ht;
binary_reader_args[11] = Wt_per_core;
binary_reader_args[12] = bnc1;
binary_reader_args[13] = num_Wtiles_read;
binary_reader_args[14] = Ht * Wt;
binary_reader_args[15] = Wt_skip;

bcast_kernel_args[0] = NC;
bcast_kernel_args[1] = Ht;
bcast_kernel_args[2] = Wt_per_core;

unary_writer_args[0] = dst_dram_buffer->address();
// unary_writer_args[1] = 0;
// unary_writer_args[2] = 0;
unary_writer_args[3] = Ht;
unary_writer_args[4] = Wt_per_core;
unary_writer_args[5] = num_Wtiles_read;
unary_writer_args[6] = Wt_skip;
unary_writer_args[7] = NC;
unary_writer_args[8] = Ht * Wt;

tt_metal::SetRuntimeArgs(
program,
unary_writer_kernel_id,
core,
{
dst_dram_buffer->address(),
0,
0,
Ht,
Wt_per_core,
num_Wtiles_read,
Wt_skip,
NC,
Ht * Wt,
});
num_Wtiles_read += Wt_per_core;
}
}
Expand Down

0 comments on commit d45096f

Please sign in to comment.