From 8c0efe86b12e63405ca5a11ae45dcee5b5820ca2 Mon Sep 17 00:00:00 2001 From: Evan Smal Date: Tue, 19 Nov 2024 15:58:28 +0000 Subject: [PATCH] Add support for resharding width-sharded tensors to/from DRAM This commit introduces a new width-shard to width-shard reshard kernel, modeled after the height-sharded tensor reshard special case implemented in `reshard_multi_core_same_width`. Supported operations include: - L1 to L1 - L1 to DRAM - DRAM to L1 Currently, only row-major tensors are supported. For unsupported cases, we fall back to the generalized reshard implementation. Unit tests have been added to validate the new kernel functionality. --- .../unit_testing/misc/test_reshard.py | 157 +++++++++++- .../dataflow/reshard_same_height_reader.cpp | 40 +++ .../dataflow/reshard_same_height_writer.cpp | 40 +++ .../sharded/reshard/device/reshard_op.cpp | 26 +- .../device/reshard_program_factory.cpp | 230 +++++++++++++++++- 5 files changed, 483 insertions(+), 10 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_reader.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_writer.cpp diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_reshard.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_reshard.py index f976ecb1d11..79c8408fa96 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_reshard.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_reshard.py @@ -170,6 +170,30 @@ def run_reshard_test( ttnn.ShardOrientation.ROW_MAJOR, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ), + ( + [1, 1, 4, 256], + ttnn.ROW_MAJOR_LAYOUT, + [[(0, 0), (3, 0)]], + (4, 64), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + [[(0, 0), (0, 1)]], + (4, 128), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ), + ( + [1, 1, 4, 256], + ttnn.ROW_MAJOR_LAYOUT, + [[(0, 0), (1, 0)]], + (4, 128), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + [[(0, 0), (0, 7)]], + (4, 32), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ), ], ) @pytest.mark.parametrize("tt_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @@ -434,6 +458,62 @@ def test_reshard_with_program_cache( ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, ), + ( + [1, 1, 1, 96], + ttnn.ROW_MAJOR_LAYOUT, + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(2, 0))}), + (1, 32), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}), + (1, 48), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.DRAM, + ), + ( + [1, 1, 32, 512], + ttnn.ROW_MAJOR_LAYOUT, + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}), + (32, 128), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}), + (32, 256), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.DRAM, + ), + ( + [1, 1, 2, 256], + ttnn.ROW_MAJOR_LAYOUT, + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}), + (2, 128), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + 
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}), + (2, 64), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.DRAM, + ), + ( + [1, 1, 16, 256], + ttnn.ROW_MAJOR_LAYOUT, + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}), + (16, 128), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.DRAM, + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}), + (16, 64), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ), ], ) def test_dram_reshard( @@ -466,4 +546,79 @@ def test_dram_reshard( passing, output_log = comp_equal(input, output) - assert passing, output_log + +@skip_for_blackhole("GH Issue #15234") +@pytest.mark.parametrize( + "input_shape, input_layout, input_shard_grid, input_shard_shape, input_shard_orientation, input_sharding_scheme, input_buffer_type, output_shard_grid, output_shard_shape, output_shard_orientation, output_sharding_scheme, output_buffer_type", + [ + ( # tests reshard_multi_core_same_width + [1, 1, 768, 64], + ttnn.TILE_LAYOUT, + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 0))}), + (96, 64), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.BufferType.DRAM, + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 2))}), + (32, 64), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.BufferType.L1, + ), + ( # test reshard_multi_core_same_height + [1, 1, 16, 256], + ttnn.ROW_MAJOR_LAYOUT, + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}), + (16, 128), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.DRAM, + ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}), + (16, 64), + ttnn.ShardOrientation.ROW_MAJOR, + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ), + ], +) +def test_dram_reshard_with_program_cache( + use_program_cache, + device, + input_shape, + input_layout, + input_shard_grid, + input_shard_shape, + input_shard_orientation, + input_sharding_scheme, + input_buffer_type, + output_shard_grid, + output_shard_shape, + output_shard_orientation, + output_sharding_scheme, + output_buffer_type, +): + dtype = ttnn.bfloat8_b + for _ in range(4): + dummy_tensor = ( + ttnn.Tensor(torch.rand([1, 1, 128, 512]), dtype).to(ttnn.TILE_LAYOUT).to(device, ttnn.L1_MEMORY_CONFIG) + ) + test_dram_reshard( + device, + input_shape, + input_layout, + input_shard_grid, + input_shard_shape, + input_shard_orientation, + input_sharding_scheme, + input_buffer_type, + output_shard_grid, + output_shard_shape, + output_shard_orientation, + output_sharding_scheme, + output_buffer_type, + ) + dummy_tensor = ( + ttnn.Tensor(torch.rand([2, 2, 128, 64]), dtype).to(ttnn.TILE_LAYOUT).to(device, ttnn.L1_MEMORY_CONFIG) + ) + + assert device.num_program_cache_entries() == 1 diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_reader.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_reader.cpp new file mode 100644 index 00000000000..174f71e22b7 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_reader.cpp @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" + +void kernel_main() { + constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0); + + const uint32_t total_num_sticks = get_arg_val(0); + const uint32_t local_stride_bytes = get_arg_val(1); + const uint32_t remote_stride_bytes = get_arg_val(2); + const uint32_t base_read_addr = get_arg_val(3); + const uint32_t num_segments = get_arg_val(4); + + uint32_t args_idx = 0; + tt_l1_ptr uint32_t* args = (tt_l1_ptr uint32_t*)(get_arg_addr(5)); + + uint32_t base_write_addr = get_read_ptr(shard_cb_id); + + for (uint32_t i = 0; i < num_segments; ++i) { + uint32_t read_size = args[args_idx++]; + + uint32_t write_offset = args[args_idx++]; + uint32_t l1_write_addr = base_write_addr + write_offset; + + uint32_t x_coord = args[args_idx++]; + uint32_t y_coord = args[args_idx++]; + uint32_t read_offset = base_read_addr + args[args_idx++]; + uint64_t noc_read_addr = get_noc_addr(x_coord, y_coord, read_offset); + + for (uint32_t j = 0; j < total_num_sticks; ++j) { + noc_async_read(noc_read_addr, l1_write_addr, read_size); + l1_write_addr += local_stride_bytes; + noc_read_addr += remote_stride_bytes; + } + } + noc_async_write_barrier(); +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_writer.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_writer.cpp new file mode 100644 index 00000000000..a52753903a5 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_writer.cpp @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" + +void kernel_main() { + constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0); + + const uint32_t total_num_sticks = get_arg_val(0); + const uint32_t local_stride_bytes = get_arg_val(1); + const uint32_t remote_stride_bytes = get_arg_val(2); + const uint32_t base_write_addr = get_arg_val(3); + const uint32_t num_segments = get_arg_val(4); + + uint32_t args_idx = 0; + tt_l1_ptr uint32_t* args = (tt_l1_ptr uint32_t*)(get_arg_addr(5)); + + uint32_t base_l1_read_addr = get_read_ptr(shard_cb_id); + + for (uint32_t i = 0; i < num_segments; ++i) { + uint32_t write_size = args[args_idx++]; + + uint32_t read_offset = args[args_idx++]; + uint32_t l1_read_addr = base_l1_read_addr + read_offset; + + uint32_t x_coord = args[args_idx++]; + uint32_t y_coord = args[args_idx++]; + uint32_t write_offset = base_write_addr + args[args_idx++]; + uint64_t noc_write_addr = get_noc_addr(x_coord, y_coord, write_offset); + + for (uint32_t j = 0; j < total_num_sticks; ++j) { + noc_async_write(l1_read_addr, noc_write_addr, write_size); + l1_read_addr += local_stride_bytes; + noc_write_addr += remote_stride_bytes; + } + } + noc_async_write_barrier(); +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_op.cpp index 754fd60b784..6ec78ed579b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_op.cpp @@ -5,10 +5,11 @@ #include "reshard_op.hpp" #include -#include "tt_metal/common/constants.hpp" -#include "tt_metal/host_api.hpp" + #include "reshard_program_factory.hpp" +#include "tt_metal/common/constants.hpp" #include "tt_metal/common/work_split.hpp" 
+#include "tt_metal/host_api.hpp"
 
 using namespace tt::constants;
 using namespace tt::tt_metal;
@@ -21,6 +22,7 @@ void ReshardDeviceOperation::validate_with_output_tensors(
     TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to shard need to be on device!");
     TT_FATAL(input_tensor.buffer() != nullptr, "Operands to shard need to be allocated in buffers on device!");
     TT_FATAL(input_tensor.is_sharded(), "input must be sharded");
+
     bool has_output_tensor = output_tensors.size() == 1 && output_tensors[0].has_value();
     if (has_output_tensor) {
         const auto& output_tensor = output_tensors[0].value();
@@ -31,19 +33,33 @@ void ReshardDeviceOperation::validate_with_output_tensors(
     const auto& out_mem_config =
         has_output_tensor ? output_tensors[0].value().memory_config() : this->output_mem_config;
     TT_FATAL(out_mem_config.is_sharded(), "output must be sharded");
+
     if ((input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED &&
          out_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
         TT_FATAL(
             (input_tensor.memory_config().buffer_type == BufferType::L1 ||
              out_mem_config.buffer_type == BufferType::L1),
             "Resharding height shard to height shard must have at least one buffer in L1");
+    } else if ((input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED &&
+                out_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)) {
+        TT_FATAL(
+            (input_tensor.memory_config().buffer_type == BufferType::L1 ||
+             out_mem_config.buffer_type == BufferType::L1),
+            "Resharding width shard to width shard must have at least one buffer in L1");
     } else {
         TT_FATAL(out_mem_config.buffer_type == BufferType::L1, "Resharding requires output buffer to be in L1");
     }
+
     if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
-        bool same_row_size =
-            input_tensor.memory_config().shard_spec.value().shape[1] == out_mem_config.shard_spec.value().shape[1];
-        TT_FATAL(same_row_size, "row major must have shard_spec[1] be the same on both input and output");
+        if (input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
+            bool same_row_size =
+                input_tensor.memory_config().shard_spec.value().shape[0] == out_mem_config.shard_spec.value().shape[0];
+            TT_FATAL(same_row_size, "row major must have shard_spec[0] be the same on both input and output");
+        } else {
+            bool same_height_size =
+                input_tensor.memory_config().shard_spec.value().shape[1] == out_mem_config.shard_spec.value().shape[1];
+            TT_FATAL(same_height_size, "row major must have shard_spec[1] be the same on both input and output");
+        }
     }
 }
 
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp
index 847321782cc..713e8aebb4e 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp
@@ -2,15 +2,16 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
+#include "reshard_program_factory.hpp"
+
 #include 
 
-#include "ttnn/operations/math.hpp"
-#include "ttnn/cpp/ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/device/interleaved_to_sharded_partial_op.hpp"
-#include "ttnn/cpp/ttnn/operations/data_movement/sharded_partial/sharded_to_interleaved_partial/device/sharded_to_interleaved_partial_op.hpp"
 #include "tt_metal/common/constants.hpp"
 #include "tt_metal/detail/util.hpp"
 #include "tt_metal/host_api.hpp"
-#include "reshard_program_factory.hpp"
+#include "ttnn/cpp/ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/device/interleaved_to_sharded_partial_op.hpp"
+#include "ttnn/cpp/ttnn/operations/data_movement/sharded_partial/sharded_to_interleaved_partial/device/sharded_to_interleaved_partial_op.hpp"
+#include "ttnn/operations/math.hpp"
 
 using namespace tt::constants;
 using namespace tt::tt_metal;
@@ -536,6 +537,218 @@ operation::ProgramWithCallbacks reshard_multi_core_generic(const Tensor& input,
     return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
 }
 
+struct WidthShardedRuntimeArgs {
+    uint32_t write_size;
+    uint32_t read_offset;
+    uint32_t x_coord;
+    uint32_t y_coord;
+    uint32_t write_offset;
+};
+
+std::tuple<std::vector<std::vector<WidthShardedRuntimeArgs>>, uint32_t, uint32_t, uint32_t>
+compute_width_sharded_reshard_runtime_args(
+    const std::array<uint32_t, 2>& local_shard_shape,
+    const std::array<uint32_t, 2>& remote_shard_shape,
+    const std::vector<CoreCoord>& local_cores,
+    const std::vector<CoreCoord>& remote_cores,
+    const BufferType& remote_buffer_type,
+    const CoreType& remote_core_type,
+    Device* device,
+    uint32_t element_size) {
+    const uint32_t num_local_shards = local_cores.size();
+    const uint32_t num_remote_shards = remote_cores.size();
+
+    const uint32_t local_shard_height = local_shard_shape[0];
+    const uint32_t local_shard_width = local_shard_shape[1];
+    const uint32_t remote_shard_height = remote_shard_shape[0];
+    const uint32_t remote_shard_width = remote_shard_shape[1];
+
+    using WidthShardedRuntimeArgsForSingleCore = std::vector<WidthShardedRuntimeArgs>;
+
+    TT_FATAL(local_shard_height == remote_shard_height, "Unexpected mismatch in shard heights");
+    TT_FATAL(
+        local_shard_width * num_local_shards == remote_shard_width * num_remote_shards,
+        "Unexpected mismatch in tensor widths");
+
+    const uint32_t total_num_sticks = local_shard_height;
+    const uint32_t local_stride_bytes = element_size * local_shard_width;
+    const uint32_t remote_stride_bytes = element_size * remote_shard_width;
+    const uint32_t total_elements_bytes = local_shard_width * num_local_shards;
+
+    std::vector<WidthShardedRuntimeArgsForSingleCore> runtime_args_for_each_core;
+
+    uint32_t local_shard_offset = 0;
+    uint32_t remote_shard_offset = 0;
+    uint32_t current_remote_core_idx = 0;
+    uint32_t total_bytes_to_transfer = 0;
+    for (const auto& core : local_cores) {
+        WidthShardedRuntimeArgsForSingleCore core_args;
+        while (local_shard_offset < local_shard_width) {
+            const uint32_t remaining_input = local_shard_width - local_shard_offset;
+            const uint32_t remaining_output = remote_shard_width - remote_shard_offset;
+            const uint32_t transfer_size = std::min(remaining_input, remaining_output);
+
+            auto bank_id =
+                device->bank_ids_from_logical_core(remote_buffer_type, remote_cores[current_remote_core_idx])[0];
+            auto bank_offset = device->bank_offset(remote_buffer_type, bank_id);
+            const auto& remote_core =
+                device->physical_core_from_logical_core(remote_cores[current_remote_core_idx], remote_core_type);
+
+            core_args.emplace_back(
+                element_size * transfer_size,
+                element_size * local_shard_offset,
+                remote_core.x,
+                remote_core.y,
+                element_size * remote_shard_offset + bank_offset);
+
+            local_shard_offset += transfer_size;
+            remote_shard_offset += transfer_size;
+            total_bytes_to_transfer += transfer_size;
+
+            // If the current output shard is full, move to the next one
+            if (remote_shard_offset == remote_shard_width) {
+                ++current_remote_core_idx;
+                remote_shard_offset = 0;
+            }
+        }
+        local_shard_offset = 0;
+        runtime_args_for_each_core.push_back(core_args);
+    }
+
+    TT_FATAL(
+        runtime_args_for_each_core.size() == num_local_shards,
+        "Expect to have one set of runtime args per local core");  // sanity check
+    TT_FATAL(
+        total_bytes_to_transfer == total_elements_bytes,
+        "Expect to transfer all elements from input to output");  // sanity check
+
+    return {runtime_args_for_each_core, total_num_sticks, local_stride_bytes, remote_stride_bytes};
+}
+
+template <bool is_reader>
+operation::ProgramWithCallbacks reshard_multi_core_same_height(const Tensor& input, Tensor& output) {
+    auto device = input.device();
+
+    tt::tt_metal::Program program{};
+
+    const auto& local_tensor = is_reader ? output : input;
+    const auto& remote_tensor = is_reader ? input : output;
+
+    const auto local_shard_spec = local_tensor.shard_spec().value();
+    const auto remote_shard_spec = remote_tensor.shard_spec().value();
+    const auto& all_cores = local_shard_spec.grid;
+
+    const auto local_core_type = local_tensor.buffer()->core_type();
+    const auto remote_core_type = remote_tensor.buffer()->core_type();
+
+    const auto local_cores = corerange_to_cores(
+        local_shard_spec.grid, std::nullopt, local_shard_spec.orientation == ShardOrientation::ROW_MAJOR);
+    const auto remote_cores = corerange_to_cores(
+        remote_shard_spec.grid, std::nullopt, remote_shard_spec.orientation == ShardOrientation::ROW_MAJOR);
+
+    const auto data_format = tt::tt_metal::datatype_to_dataformat_converter(local_tensor.get_dtype());
+    const uint32_t element_size = tt::datum_size(data_format);
+
+    TT_FATAL(local_tensor.get_layout() == Layout::ROW_MAJOR, "Expected row major tensor");
+    const uint32_t unit_size = local_shard_spec.shape[1] * local_tensor.element_size();  // width * element size
+    const uint32_t local_units_per_shard = local_shard_spec.shape[0];    // height
+    const uint32_t remote_units_per_shard = remote_shard_spec.shape[0];  // height
+    const uint32_t total_size = remote_units_per_shard * unit_size;
+
+    constexpr uint32_t cb_index = tt::CBIndex::c_0;
+    tt::tt_metal::CircularBufferConfig cb_config =
+        tt::tt_metal::CircularBufferConfig(total_size, {{cb_index, data_format}})
+            .set_page_size(cb_index, unit_size)
+            .set_globally_allocated_address(*local_tensor.buffer());
+    auto cb_0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_config);
+
+    const std::string kernel_name =
+        is_reader
+            ? "ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_reader.cpp"
+            : "ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_height_writer.cpp";
+
+    tt::tt_metal::KernelHandle kernel_id_0 =
+        tt::tt_metal::CreateKernel(program, kernel_name, all_cores, tt::tt_metal::ReaderDataMovementConfig({cb_index}));
+
+    tt::tt_metal::KernelHandle kernel_id_1 =
+        tt::tt_metal::CreateKernel(program, kernel_name, all_cores, tt::tt_metal::WriterDataMovementConfig({cb_index}));
+
+    uint32_t remote_address = remote_tensor.buffer()->address();
+    auto remote_buffer_type = remote_tensor.buffer()->buffer_type();
+
+    // Generate all read/write offsets for each core
+    auto [runtime_args_for_each_core, total_num_sticks, local_stride_bytes, remote_stride_bytes] =
+        compute_width_sharded_reshard_runtime_args(
+            local_shard_spec.shape,
+            remote_shard_spec.shape,
+            local_cores,
+            remote_cores,
+            remote_buffer_type,
+            remote_core_type,
+            device,
+            element_size);  // local_core_idx -> runtime args[]
+
+    // Split work across each kernel along tensor height since this is the best way to split work evenly
+    const uint32_t total_num_sticks_kernel_0 = total_num_sticks / 2;
+    const uint32_t total_num_sticks_kernel_1 = total_num_sticks - total_num_sticks_kernel_0;
+
+    // Here all we do is convert pre-computed offsets into vectors so they can be passed as runtime arguments
+    for (uint32_t core_idx = 0; core_idx < local_cores.size(); core_idx++) {
+        const auto& args_for_all_segments = runtime_args_for_each_core[core_idx];
+        std::vector<uint32_t> runtime_args_0 = {
+            total_num_sticks_kernel_0,
+            local_stride_bytes,
+            remote_stride_bytes,
+            remote_address,
+            args_for_all_segments.size()};
+        std::vector<uint32_t> runtime_args_1 = {
+            total_num_sticks_kernel_1,
+            local_stride_bytes,
+            remote_stride_bytes,
+            remote_address,
+            args_for_all_segments.size()};
+        for (const auto& args : args_for_all_segments) {
+            const std::vector<uint32_t> segment_kernel_0 = {
+                args.write_size, args.read_offset, args.x_coord, args.y_coord, args.write_offset};
+            runtime_args_0.insert(runtime_args_0.end(), segment_kernel_0.begin(), segment_kernel_0.end());
+
+            // Adjust read and write offsets to the correct stick address because we are splitting work across 2 kernels
+            const uint32_t adjusted_read_offset = args.read_offset + total_num_sticks_kernel_0 * local_stride_bytes;
+            const uint32_t adjusted_write_offset = args.write_offset + total_num_sticks_kernel_0 * remote_stride_bytes;
+
+            const std::vector<uint32_t> segment_kernel_1 = {
+                args.write_size, adjusted_read_offset, args.x_coord, args.y_coord, adjusted_write_offset};
+            runtime_args_1.insert(runtime_args_1.end(), segment_kernel_1.begin(), segment_kernel_1.end());
+        }
+        SetRuntimeArgs(program, kernel_id_0, local_cores[core_idx], runtime_args_0);
+        SetRuntimeArgs(program, kernel_id_1, local_cores[core_idx], runtime_args_1);
+    }
+
+    auto override_runtime_arguments_callback = [kernel_id_0, kernel_id_1, cb_0, local_cores](
+                                                   const void* operation,
+                                                   Program& program,
+                                                   const std::vector<Tensor>& input_tensors,
+                                                   const std::vector<std::optional<const Tensor>>&,
+                                                   const std::vector<Tensor>& output_tensors) {
+        const auto& input = input_tensors.at(0);
+        const auto& output = output_tensors.at(0);
+        const auto& local_tensor = is_reader ? output : input;
+        const auto& remote_tensor = is_reader ? input : output;
+        uint32_t remote_address = remote_tensor.buffer()->address();
+        auto& runtime_args_0_by_core = GetRuntimeArgs(program, kernel_id_0);
+        auto& runtime_args_1_by_core = GetRuntimeArgs(program, kernel_id_1);
+        for (auto core : local_cores) {
+            auto& runtime_args_0 = runtime_args_0_by_core[core.x][core.y];
+            auto& runtime_args_1 = runtime_args_1_by_core[core.x][core.y];
+            runtime_args_0[3] = remote_address;
+            runtime_args_1[3] = remote_address;
+        }
+        UpdateDynamicCircularBufferAddress(program, cb_0, *local_tensor.buffer());
+    };
+
+    return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
+}
+
 operation::ProgramWithCallbacks reshard_multi_core(const Tensor& input, Tensor& output) {
     if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED &&
         output.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
@@ -544,6 +757,15 @@ operation::ProgramWithCallbacks reshard_multi_core(const Tensor& input, Tensor&
         } else {
             return reshard_multi_core_same_width<false>(input, output);
         }
+    } else if (
+        input.layout() == Layout::ROW_MAJOR &&
+        input.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED &&
+        output.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
+        if (output.memory_config().buffer_type == BufferType::L1) {
+            return reshard_multi_core_same_height<true>(input, output);
+        } else {
+            return reshard_multi_core_same_height<false>(input, output);
+        }
     } else {
         return reshard_multi_core_generic(input, output);
     }
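Note (illustration only, not part of the patch): the host-side planning in `compute_width_sharded_reshard_runtime_args` can be summarized by the short Python sketch below. It mirrors only the segment-splitting loop, works in element counts rather than bytes, and omits the NoC coordinates and DRAM bank offsets that the real function also records per segment.

    def split_width_segments(local_shard_width, num_local_shards, remote_shard_width, num_remote_shards):
        # Each local (width) shard is chopped into contiguous segments, each of which lands
        # entirely inside one remote shard. Every segment later becomes one 5-tuple of kernel
        # runtime args (transfer size, local offset, remote core x/y, remote offset).
        assert local_shard_width * num_local_shards == remote_shard_width * num_remote_shards
        segments_per_local_core = []
        remote_idx, remote_offset = 0, 0
        for _ in range(num_local_shards):
            segments, local_offset = [], 0
            while local_offset < local_shard_width:
                transfer = min(local_shard_width - local_offset, remote_shard_width - remote_offset)
                segments.append((transfer, local_offset, remote_idx, remote_offset))
                local_offset += transfer
                remote_offset += transfer
                if remote_offset == remote_shard_width:  # current remote shard is full
                    remote_idx, remote_offset = remote_idx + 1, 0
            segments_per_local_core.append(segments)
        return segments_per_local_core

    # Example: 2 local shards of width 128 resharded onto 4 remote shards of width 64
    # -> each local core gets two 64-element segments.
    print(split_width_segments(128, 2, 64, 4))

Each segment is then replayed `total_num_sticks` times by the kernels, advancing by `local_stride_bytes` on the L1 side and `remote_stride_bytes` on the remote side; the sticks are split roughly in half between the reader- and writer-configured kernels on each core.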
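Note (illustration only, not part of the patch): a minimal end-to-end sketch of exercising the new width-sharded DRAM-to-L1 path from Python, modeled on the [1, 1, 16, 256] unit test above. The ttnn calls (ttnn.ShardSpec, ttnn.MemoryConfig, ttnn.from_torch, ttnn.reshard) are assumed to behave as in the existing reshard tests; exact constructor signatures may differ between ttnn versions (e.g. older ShardSpec bindings also take a trailing halo flag).

    import torch
    import ttnn

    def width_reshard_dram_to_l1(device):
        torch_input = torch.rand(1, 1, 16, 256)

        # Input: row-major, width-sharded over 2 DRAM banks, shard shape (16, 128).
        input_mem_config = ttnn.MemoryConfig(
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.DRAM,
            ttnn.ShardSpec(
                ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
                (16, 128),
                ttnn.ShardOrientation.ROW_MAJOR,
            ),
        )
        # Output: same shard height (16), width re-split over 4 L1 cores, shard shape (16, 64).
        output_mem_config = ttnn.MemoryConfig(
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.L1,
            ttnn.ShardSpec(
                ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
                (16, 64),
                ttnn.ShardOrientation.ROW_MAJOR,
            ),
        )

        input_tensor = ttnn.from_torch(
            torch_input,
            dtype=ttnn.bfloat16,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            device=device,
            memory_config=input_mem_config,
        )
        # Both configs are row-major, width-sharded, and share the same shard height,
        # so this dispatches to reshard_multi_core_same_height rather than the generic path.
        output_tensor = ttnn.reshard(input_tensor, output_mem_config)
        return ttnn.to_torch(output_tensor)

Tiled inputs or non-width-sharded layouts continue to fall back to reshard_multi_core_generic, as before.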