From 0fb2138d06f6ec09f4bac70b938d51531c674dc8 Mon Sep 17 00:00:00 2001 From: Nicholas Smith Date: Sat, 16 Mar 2024 16:41:44 -0500 Subject: [PATCH] #6462: Upsample kernel opt Introduced a reader which just trivially splits the work, reader does first half, writer does second half. Allocate the full output buffer up front, barrier once at end of kernel. --- .../unit_tests/operations/test_upsample.py | 14 ++-- .../writer_upsample_multi_core_sharded.cpp | 69 +++++++++++-------- .../multi_core/upsample_op_multi_core.cpp | 16 ++++- 3 files changed, 60 insertions(+), 39 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_upsample.py b/tests/ttnn/unit_tests/operations/test_upsample.py index 4cf02d42297..1b4cdf5eeab 100644 --- a/tests/ttnn/unit_tests/operations/test_upsample.py +++ b/tests/ttnn/unit_tests/operations/test_upsample.py @@ -11,14 +11,14 @@ import ttnn from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import skip_for_wormhole_b0 TILE_WIDTH = 32 -def get_shard_grid_from_num_cores(ncores: Union[int, Tuple[int, int]]) -> ttnn.experimental.tensor.CoreRangeSet: - max_grid_size = (9, 12) ## (y, x) +def get_shard_grid_from_num_cores(device, ncores: Union[int, Tuple[int, int]]) -> ttnn.experimental.tensor.CoreRangeSet: + device_grid = device.compute_with_storage_grid_size() + max_grid_size = (device_grid.y, device_grid.x) if isinstance(ncores, int): if ncores % max_grid_size[1] == 0: core_grid = ttnn.CoreGrid(y=ncores // max_grid_size[1], x=max_grid_size[1]) @@ -62,7 +62,6 @@ def get_shard_grid_from_num_cores(ncores: Union[int, Tuple[int, int]]) -> ttnn.e raise ValueError("Invalid ncores") -@skip_for_wormhole_b0() @pytest.mark.parametrize( "input_shapes", [ @@ -105,7 +104,6 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w): assert isequal -@skip_for_wormhole_b0() @pytest.mark.parametrize( "input_shape", [ @@ -114,6 +112,7 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w): [2, 1280, 8, 8], # 512x512 [2, 1280, 16, 16], [1, 64, 132, 10], + [1, 32, 8, 8], ], ) @pytest.mark.parametrize("scale_h", [2]) @@ -137,7 +136,8 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate ## calculate ncores, corresponding grid_size and in_shard_shape based on the input_shape ncores = None - max_grid_size = (9, 12) ## (y, x) + device_grid = device.compute_with_storage_grid_size() + max_grid_size = (device_grid.y, device_grid.x) if shard_strategy == ttnn.ShardStrategy.HEIGHT: ## nsticks per shard should be divisible by in_w max_nshards = min(batch_size * height, max_grid_size[0] * max_grid_size[1]) @@ -180,7 +180,7 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate # use_height_and_width_as_shard_shape=False ##shard_strategy == ttnn.ShardStrategy.HEIGHT, # ) - shard_grid = get_shard_grid_from_num_cores(ncores) + shard_grid = get_shard_grid_from_num_cores(device, ncores) shard_orientation = ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR if shard_strategy == ttnn.ShardStrategy.BLOCK: diff --git a/tt_eager/tt_dnn/op_library/upsample/kernels/dataflow/writer_upsample_multi_core_sharded.cpp b/tt_eager/tt_dnn/op_library/upsample/kernels/dataflow/writer_upsample_multi_core_sharded.cpp index 044aa27b4f9..03530ea7433 100644 --- a/tt_eager/tt_dnn/op_library/upsample/kernels/dataflow/writer_upsample_multi_core_sharded.cpp +++ b/tt_eager/tt_dnn/op_library/upsample/kernels/dataflow/writer_upsample_multi_core_sharded.cpp @@ -8,50 +8,59 @@ void kernel_main() { uint32_t stick_nbytes = get_arg_val(0); - uint32_t in_nsticks_local = get_arg_val(1); + uint32_t in_image_rows_per_core = get_arg_val(1); uint32_t scale_h = get_arg_val(2); uint32_t scale_w = get_arg_val(3); uint32_t in_w = get_arg_val(4); uint32_t out_w = get_arg_val(5); - uint32_t start_in_stick_id = get_arg_val(6); constexpr uint32_t in_cb_id = get_compile_time_arg_val(0); constexpr uint32_t out_cb_id = get_compile_time_arg_val(1); + constexpr uint32_t is_reader = get_compile_time_arg_val(2); - uint32_t l1_read_addr = get_read_ptr(in_cb_id); - uint32_t l1_write_addr = get_write_ptr(out_cb_id); + uint32_t in_image_row_nbytes = in_w * stick_nbytes; + uint32_t out_image_row_nbytes = out_w * stick_nbytes; + uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2; + uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2; + uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core; + uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core; + uint32_t l1_read_addr = get_read_ptr(in_cb_id) + image_row_begin * in_image_row_nbytes; + uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * out_image_row_nbytes; - // cb_wait_front(in_cb_id, in_nsticks_local); + cb_reserve_back(out_cb_id, out_w); - uint32_t in_stick_row_id = start_in_stick_id / in_w; // assuming shard begins with a new row. TODO: generalize? - uint32_t l1_write_addr_stick = l1_write_addr; - // for each input stick - for (uint32_t i = start_in_stick_id; i < start_in_stick_id + in_nsticks_local; ++ i) { - cb_reserve_back(out_cb_id, scale_h * scale_w); - - uint32_t l1_write_addr_local = l1_write_addr_stick; - for (uint32_t j = 0; j < scale_h; ++j) { - l1_write_addr_local = l1_write_addr_stick + j * out_w * stick_nbytes; - // replicate stick scale_h times. - for (size_t k = 0; k < scale_w; ++k) { + // assuming shard begins with a new row. TODO: generalize? + for (uint32_t image_row = image_row_begin; image_row < image_row_end; ++image_row) { + uint32_t l1_write_addr_image_row_start = l1_write_addr; + for (uint32_t i = 0; i < in_w; ++i) { + // replicate stick scale_w times. + for (uint32_t sw = 0; sw < scale_w; ++sw) { // replicate stick scale_w times. - uint64_t dst_noc_addr = get_noc_addr(l1_write_addr_local); - noc_async_write(l1_read_addr, dst_noc_addr, stick_nbytes); - l1_write_addr_local += stick_nbytes; + if constexpr (is_reader) { + uint64_t src_noc_addr = get_noc_addr(l1_read_addr); + noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes); + } else { + uint64_t dst_noc_addr = get_noc_addr(l1_write_addr); + noc_async_write(l1_read_addr, dst_noc_addr, stick_nbytes); + } + l1_write_addr += stick_nbytes; } + l1_read_addr += stick_nbytes; } - // move to the next input stick - l1_read_addr += stick_nbytes; - // move to the next output stick - l1_write_addr_stick += stick_nbytes * scale_w; - - // if this is the end of a row, skip the next (scale_h - 1) rows - l1_write_addr_stick += (i == (in_w * (in_stick_row_id + 1) - 1)) * out_w * stick_nbytes * (scale_h - 1); - in_stick_row_id += (i == (in_w * (in_stick_row_id + 1) - 1)); - noc_async_write_barrier(); - cb_push_back(out_cb_id, scale_h * scale_w); + // Duplicate the whole image row in one shot + if constexpr (is_reader) { + uint64_t src_noc_addr = get_noc_addr(l1_write_addr_image_row_start); + noc_async_read(src_noc_addr, l1_write_addr, out_image_row_nbytes); + } else { + uint64_t dst_noc_addr = get_noc_addr(l1_write_addr); + noc_async_write(l1_write_addr_image_row_start, dst_noc_addr, out_image_row_nbytes); + } + l1_write_addr += out_image_row_nbytes; } - // cb_pop_front(in_cb_id, in_nsticks_local); + cb_push_back(out_cb_id, out_w); + + noc_async_write_barrier(); + noc_async_read_barrier(); } diff --git a/tt_eager/tt_dnn/op_library/upsample/multi_core/upsample_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/upsample/multi_core/upsample_op_multi_core.cpp index 0c356d916c3..190cee30d82 100644 --- a/tt_eager/tt_dnn/op_library/upsample/multi_core/upsample_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/upsample/multi_core/upsample_op_multi_core.cpp @@ -69,6 +69,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& // TODO: Support non-multiple case TT_FATAL(in_nsticks_per_core == input_nsticks_per_core, "Input sticks per shard {} should be same as input sticks per core {}", in_nsticks_per_core, input_nsticks_per_core); TT_FATAL(out_nsticks_per_core == output_nsticks_per_core, "Output sticks per shard {} should be same as output sticks per core {}", out_nsticks_per_core, output_nsticks_per_core); + TT_FATAL(input_nsticks_per_core % in_w == 0); // CBs @@ -109,12 +110,21 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& std::vector writer_compile_time_args = { in_cb_id, out_cb_id, + false, }; auto writer_kernel_fname = std::string("tt_eager/tt_dnn/op_library/upsample/kernels/dataflow/writer_upsample_multi_core_sharded.cpp"); auto writer_kernel = CreateKernel(program, writer_kernel_fname, all_cores, WriterDataMovementConfig(writer_compile_time_args)); - // no reader kernel + std::vector reader_compile_time_args = { + in_cb_id, + out_cb_id, + true, + }; + auto reader_kernel_fname = std::string("tt_eager/tt_dnn/op_library/upsample/kernels/dataflow/writer_upsample_multi_core_sharded.cpp"); + auto reader_kernel = + CreateKernel(program, reader_kernel_fname, all_cores, ReaderDataMovementConfig(reader_compile_time_args)); + // no compute kernel // runtime args @@ -122,7 +132,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& uint32_t writer_nargs = 7; vector writer_rt_args(writer_nargs); writer_rt_args[0] = input_stick_nbytes; - writer_rt_args[1] = input_nsticks_per_core; + writer_rt_args[1] = input_nsticks_per_core / in_w; writer_rt_args[2] = scale_factor_h; writer_rt_args[3] = scale_factor_w; writer_rt_args[4] = in_w; @@ -136,6 +146,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& CoreCoord core_coord(core_x, core); // logical writer_rt_args[6] = start_input_stick_id; SetRuntimeArgs(program, writer_kernel, core_coord, writer_rt_args); + SetRuntimeArgs(program, reader_kernel, core_coord, writer_rt_args); } start_input_stick_id += input_nsticks_per_core; } @@ -144,6 +155,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& CoreCoord core_coord(core % ncores_x, core / ncores_x); // logical writer_rt_args[6] = start_input_stick_id; SetRuntimeArgs(program, writer_kernel, core_coord, writer_rt_args); + SetRuntimeArgs(program, reader_kernel, core_coord, writer_rt_args); start_input_stick_id += input_nsticks_per_core; } } else {