Skip to content

Commit

Permalink
#6462: Upsample kernel opt
Browse files Browse the repository at this point in the history
Introduced a reader kernel that trivially splits the work: the reader
does the first half and the writer does the second half.

Allocate the full output buffer up front and barrier once at the end of the kernel.
  • Loading branch information
nsmithtt committed Mar 20, 2024
1 parent c79f2b8 commit 0fb2138
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 39 deletions.
14 changes: 7 additions & 7 deletions tests/ttnn/unit_tests/operations/test_upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_wormhole_b0


TILE_WIDTH = 32


def get_shard_grid_from_num_cores(ncores: Union[int, Tuple[int, int]]) -> ttnn.experimental.tensor.CoreRangeSet:
max_grid_size = (9, 12) ## (y, x)
def get_shard_grid_from_num_cores(device, ncores: Union[int, Tuple[int, int]]) -> ttnn.experimental.tensor.CoreRangeSet:
device_grid = device.compute_with_storage_grid_size()
max_grid_size = (device_grid.y, device_grid.x)
if isinstance(ncores, int):
if ncores % max_grid_size[1] == 0:
core_grid = ttnn.CoreGrid(y=ncores // max_grid_size[1], x=max_grid_size[1])
Expand Down Expand Up @@ -62,7 +62,6 @@ def get_shard_grid_from_num_cores(ncores: Union[int, Tuple[int, int]]) -> ttnn.e
raise ValueError("Invalid ncores")


@skip_for_wormhole_b0()
@pytest.mark.parametrize(
"input_shapes",
[
Expand Down Expand Up @@ -105,7 +104,6 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w):
assert isequal


@skip_for_wormhole_b0()
@pytest.mark.parametrize(
"input_shape",
[
Expand All @@ -114,6 +112,7 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w):
[2, 1280, 8, 8], # 512x512
[2, 1280, 16, 16],
[1, 64, 132, 10],
[1, 32, 8, 8],
],
)
@pytest.mark.parametrize("scale_h", [2])
Expand All @@ -137,7 +136,8 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate

## calculate ncores, corresponding grid_size and in_shard_shape based on the input_shape
ncores = None
max_grid_size = (9, 12) ## (y, x)
device_grid = device.compute_with_storage_grid_size()
max_grid_size = (device_grid.y, device_grid.x)
if shard_strategy == ttnn.ShardStrategy.HEIGHT:
## nsticks per shard should be divisible by in_w
max_nshards = min(batch_size * height, max_grid_size[0] * max_grid_size[1])
Expand Down Expand Up @@ -180,7 +180,7 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate
# use_height_and_width_as_shard_shape=False ##shard_strategy == ttnn.ShardStrategy.HEIGHT,
# )

shard_grid = get_shard_grid_from_num_cores(ncores)
shard_grid = get_shard_grid_from_num_cores(device, ncores)
shard_orientation = ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR

if shard_strategy == ttnn.ShardStrategy.BLOCK:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,50 +8,59 @@
void kernel_main() {

uint32_t stick_nbytes = get_arg_val<uint32_t>(0);
uint32_t in_nsticks_local = get_arg_val<uint32_t>(1);
uint32_t in_image_rows_per_core = get_arg_val<uint32_t>(1);
uint32_t scale_h = get_arg_val<uint32_t>(2);
uint32_t scale_w = get_arg_val<uint32_t>(3);
uint32_t in_w = get_arg_val<uint32_t>(4);
uint32_t out_w = get_arg_val<uint32_t>(5);
uint32_t start_in_stick_id = get_arg_val<uint32_t>(6);

constexpr uint32_t in_cb_id = get_compile_time_arg_val(0);
constexpr uint32_t out_cb_id = get_compile_time_arg_val(1);
constexpr uint32_t is_reader = get_compile_time_arg_val(2);

uint32_t l1_read_addr = get_read_ptr(in_cb_id);
uint32_t l1_write_addr = get_write_ptr(out_cb_id);
uint32_t in_image_row_nbytes = in_w * stick_nbytes;
uint32_t out_image_row_nbytes = out_w * stick_nbytes;
uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2;
uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2;
uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core;
uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core;
uint32_t l1_read_addr = get_read_ptr(in_cb_id) + image_row_begin * in_image_row_nbytes;
uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * out_image_row_nbytes;

// cb_wait_front(in_cb_id, in_nsticks_local);
cb_reserve_back(out_cb_id, out_w);

uint32_t in_stick_row_id = start_in_stick_id / in_w; // assuming shard begins with a new row. TODO: generalize?
uint32_t l1_write_addr_stick = l1_write_addr;
// for each input stick
for (uint32_t i = start_in_stick_id; i < start_in_stick_id + in_nsticks_local; ++ i) {
cb_reserve_back(out_cb_id, scale_h * scale_w);

uint32_t l1_write_addr_local = l1_write_addr_stick;
for (uint32_t j = 0; j < scale_h; ++j) {
l1_write_addr_local = l1_write_addr_stick + j * out_w * stick_nbytes;
// replicate stick scale_h times.
for (size_t k = 0; k < scale_w; ++k) {
// assuming shard begins with a new row. TODO: generalize?
for (uint32_t image_row = image_row_begin; image_row < image_row_end; ++image_row) {
uint32_t l1_write_addr_image_row_start = l1_write_addr;
for (uint32_t i = 0; i < in_w; ++i) {
// replicate stick scale_w times.
for (uint32_t sw = 0; sw < scale_w; ++sw) {
// replicate stick scale_w times.
uint64_t dst_noc_addr = get_noc_addr(l1_write_addr_local);
noc_async_write(l1_read_addr, dst_noc_addr, stick_nbytes);
l1_write_addr_local += stick_nbytes;
if constexpr (is_reader) {
uint64_t src_noc_addr = get_noc_addr(l1_read_addr);
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
} else {
uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
noc_async_write(l1_read_addr, dst_noc_addr, stick_nbytes);
}
l1_write_addr += stick_nbytes;
}
l1_read_addr += stick_nbytes;
}
// move to the next input stick
l1_read_addr += stick_nbytes;
// move to the next output stick
l1_write_addr_stick += stick_nbytes * scale_w;

// if this is the end of a row, skip the next (scale_h - 1) rows
l1_write_addr_stick += (i == (in_w * (in_stick_row_id + 1) - 1)) * out_w * stick_nbytes * (scale_h - 1);
in_stick_row_id += (i == (in_w * (in_stick_row_id + 1) - 1));

noc_async_write_barrier();
cb_push_back(out_cb_id, scale_h * scale_w);
// Duplicate the whole image row in one shot
if constexpr (is_reader) {
uint64_t src_noc_addr = get_noc_addr(l1_write_addr_image_row_start);
noc_async_read(src_noc_addr, l1_write_addr, out_image_row_nbytes);
} else {
uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
noc_async_write(l1_write_addr_image_row_start, dst_noc_addr, out_image_row_nbytes);
}
l1_write_addr += out_image_row_nbytes;
}

// cb_pop_front(in_cb_id, in_nsticks_local);
cb_push_back(out_cb_id, out_w);

noc_async_write_barrier();
noc_async_read_barrier();
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
// TODO: Support non-multiple case
TT_FATAL(in_nsticks_per_core == input_nsticks_per_core, "Input sticks per shard {} should be same as input sticks per core {}", in_nsticks_per_core, input_nsticks_per_core);
TT_FATAL(out_nsticks_per_core == output_nsticks_per_core, "Output sticks per shard {} should be same as output sticks per core {}", out_nsticks_per_core, output_nsticks_per_core);
TT_FATAL(input_nsticks_per_core % in_w == 0);

// CBs

Expand Down Expand Up @@ -109,20 +110,29 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
std::vector<uint32_t> writer_compile_time_args = {
in_cb_id,
out_cb_id,
false,
};
auto writer_kernel_fname = std::string("tt_eager/tt_dnn/op_library/upsample/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
auto writer_kernel =
CreateKernel(program, writer_kernel_fname, all_cores, WriterDataMovementConfig(writer_compile_time_args));

// no reader kernel
std::vector<uint32_t> reader_compile_time_args = {
in_cb_id,
out_cb_id,
true,
};
auto reader_kernel_fname = std::string("tt_eager/tt_dnn/op_library/upsample/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
auto reader_kernel =
CreateKernel(program, reader_kernel_fname, all_cores, ReaderDataMovementConfig(reader_compile_time_args));

// no compute kernel

// runtime args

uint32_t writer_nargs = 7;
vector<uint32_t> writer_rt_args(writer_nargs);
writer_rt_args[0] = input_stick_nbytes;
writer_rt_args[1] = input_nsticks_per_core;
writer_rt_args[1] = input_nsticks_per_core / in_w;
writer_rt_args[2] = scale_factor_h;
writer_rt_args[3] = scale_factor_w;
writer_rt_args[4] = in_w;
Expand All @@ -136,6 +146,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
CoreCoord core_coord(core_x, core); // logical
writer_rt_args[6] = start_input_stick_id;
SetRuntimeArgs(program, writer_kernel, core_coord, writer_rt_args);
SetRuntimeArgs(program, reader_kernel, core_coord, writer_rt_args);
}
start_input_stick_id += input_nsticks_per_core;
}
Expand All @@ -144,6 +155,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
CoreCoord core_coord(core % ncores_x, core / ncores_x); // logical
writer_rt_args[6] = start_input_stick_id;
SetRuntimeArgs(program, writer_kernel, core_coord, writer_rt_args);
SetRuntimeArgs(program, reader_kernel, core_coord, writer_rt_args);
start_input_stick_id += input_nsticks_per_core;
}
} else {
Expand Down

0 comments on commit 0fb2138

Please sign in to comment.