Add support for resharding width-sharded tensors to/from DRAM
This commit introduces a new width-shard to width-shard reshard kernel,
modeled after the height-sharded tensor reshard special case implemented
in `reshard_multi_core_same_width`.

Supported operations include:

- L1 to L1
- L1 to DRAM
- DRAM to L1

Currently, only row-major tensors are supported. For unsupported cases,
we fall back to the generalized reshard implementation.

Unit tests have been added to validate the new kernel functionality.
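For illustration, a minimal sketch of driving the new path from Python (not part of this commit; it assumes the standard ttnn sharding types used in the tests below, and the exact ttnn.ShardSpec constructor signature varies between ttnn versions):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    # Row-major [1, 1, 16, 256] tensor, width-sharded as (16, 128) over 2 DRAM shards.
    input_shard_spec = ttnn.ShardSpec(
        ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
        (16, 128),
        ttnn.ShardOrientation.ROW_MAJOR,
    )
    input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.DRAM, input_shard_spec)

    # Target: width-sharded as (16, 64) over 4 L1 cores (DRAM -> L1).
    output_shard_spec = ttnn.ShardSpec(
        ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
        (16, 64),
        ttnn.ShardOrientation.ROW_MAJOR,
    )
    output_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, output_shard_spec)

    input_tensor = ttnn.from_torch(
        torch.rand(1, 1, 16, 256, dtype=torch.bfloat16),
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=input_mem_config,
    )
    output_tensor = ttnn.reshard(input_tensor, output_mem_config)

    ttnn.close_device(device)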
esmalTT committed Nov 28, 2024
1 parent f3108ee commit 2440b50
Showing 5 changed files with 511 additions and 40 deletions.
157 changes: 156 additions & 1 deletion tests/tt_eager/python_api_testing/unit_testing/misc/test_reshard.py
@@ -170,6 +170,30 @@ def run_reshard_test(
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
        ),
        (
            [1, 1, 4, 256],
            ttnn.ROW_MAJOR_LAYOUT,
            [[(0, 0), (3, 0)]],
            (4, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            [[(0, 0), (0, 1)]],
            (4, 128),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
        ),
        (
            [1, 1, 4, 256],
            ttnn.ROW_MAJOR_LAYOUT,
            [[(0, 0), (1, 0)]],
            (4, 128),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            [[(0, 0), (0, 7)]],
            (4, 32),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
        ),
    ],
)
@pytest.mark.parametrize("tt_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@@ -434,6 +458,62 @@ def test_reshard_with_program_cache(
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            ttnn.BufferType.L1,
        ),
        (
            [1, 1, 1, 96],
            ttnn.ROW_MAJOR_LAYOUT,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(2, 0))}),
            (1, 32),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.L1,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
            (1, 48),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.DRAM,
        ),
        (
            [1, 1, 32, 512],
            ttnn.ROW_MAJOR_LAYOUT,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
            (32, 128),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.L1,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
            (32, 256),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.DRAM,
        ),
        (
            [1, 1, 2, 256],
            ttnn.ROW_MAJOR_LAYOUT,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
            (2, 128),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.L1,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
            (2, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.DRAM,
        ),
        (
            [1, 1, 16, 256],
            ttnn.ROW_MAJOR_LAYOUT,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
            (16, 128),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.DRAM,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
            (16, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.L1,
        ),
    ],
)
def test_dram_reshard(
@@ -466,4 +546,79 @@ def test_dram_reshard(

    passing, output_log = comp_equal(input, output)

    assert passing, output_log
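All of the width-sharded DRAM cases above keep the shard height equal to the tensor height and only regroup the tensor width across cores; the shard width is just the tensor width divided by the core count. A quick arithmetic check (plain Python, no ttnn required):

    def shard_width(tensor_width: int, num_cores: int) -> int:
        # Width sharding splits the innermost dimension evenly across cores.
        assert tensor_width % num_cores == 0
        return tensor_width // num_cores

    assert shard_width(512, 4) == 128  # [1, 1, 32, 512] input: 4 L1 cores
    assert shard_width(512, 2) == 256  # [1, 1, 32, 512] output: 2 DRAM shards
    assert shard_width(96, 3) == 32    # [1, 1, 1, 96] input: 3 L1 cores
    assert shard_width(96, 2) == 48    # [1, 1, 1, 96] output: 2 DRAM shards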

@skip_for_blackhole("GH Issue #15234")
@pytest.mark.parametrize(
    "input_shape, input_layout, input_shard_grid, input_shard_shape, input_shard_orientation, input_sharding_scheme, input_buffer_type, output_shard_grid, output_shard_shape, output_shard_orientation, output_sharding_scheme, output_buffer_type",
    [
        (  # tests reshard_multi_core_same_width
            [1, 1, 768, 64],
            ttnn.TILE_LAYOUT,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 0))}),
            (96, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            ttnn.BufferType.DRAM,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 2))}),
            (32, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            ttnn.BufferType.L1,
        ),
        (  # tests reshard_multi_core_same_height
            [1, 1, 16, 256],
            ttnn.ROW_MAJOR_LAYOUT,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
            (16, 128),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.DRAM,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
            (16, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.L1,
        ),
    ],
)
def test_dram_reshard_with_program_cache(
    use_program_cache,
    device,
    input_shape,
    input_layout,
    input_shard_grid,
    input_shard_shape,
    input_shard_orientation,
    input_sharding_scheme,
    input_buffer_type,
    output_shard_grid,
    output_shard_shape,
    output_shard_orientation,
    output_sharding_scheme,
    output_buffer_type,
):
    dtype = ttnn.bfloat8_b
    for _ in range(4):
        # Allocate dummy L1 tensors between runs so cached program launches see
        # different buffer addresses (guards against stale-address reuse).
        dummy_tensor = (
            ttnn.Tensor(torch.rand([1, 1, 128, 512]), dtype).to(ttnn.TILE_LAYOUT).to(device, ttnn.L1_MEMORY_CONFIG)
        )
        test_dram_reshard(
            device,
            input_shape,
            input_layout,
            input_shard_grid,
            input_shard_shape,
            input_shard_orientation,
            input_sharding_scheme,
            input_buffer_type,
            output_shard_grid,
            output_shard_shape,
            output_shard_orientation,
            output_sharding_scheme,
            output_buffer_type,
        )
        dummy_tensor = (
            ttnn.Tensor(torch.rand([2, 2, 128, 64]), dtype).to(ttnn.TILE_LAYOUT).to(device, ttnn.L1_MEMORY_CONFIG)
        )

    assert device.num_program_cache_entries() == 1
@@ -0,0 +1,40 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include "dataflow_api.h"

void kernel_main() {
    constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0);

    const uint32_t total_num_sticks = get_arg_val<uint32_t>(0);
    const uint32_t local_stride_bytes = get_arg_val<uint32_t>(1);
    const uint32_t remote_stride_bytes = get_arg_val<uint32_t>(2);
    const uint32_t base_read_addr = get_arg_val<uint32_t>(3);
    const uint32_t num_segments = get_arg_val<uint32_t>(4);

    uint32_t args_idx = 0;
    tt_l1_ptr uint32_t* args = (tt_l1_ptr uint32_t*)(get_arg_addr(5));

    uint32_t base_write_addr = get_read_ptr(shard_cb_id);

    for (uint32_t i = 0; i < num_segments; ++i) {
        // Per-segment args: read size, local write offset, remote core (x, y), remote read offset.
        uint32_t read_size = args[args_idx++];

        uint32_t write_offset = args[args_idx++];
        uint32_t l1_write_addr = base_write_addr + write_offset;

        uint32_t x_coord = args[args_idx++];
        uint32_t y_coord = args[args_idx++];
        uint32_t read_offset = base_read_addr + args[args_idx++];
        uint64_t noc_read_addr = get_noc_addr(x_coord, y_coord, read_offset);

        // Copy one stick per iteration, advancing both addresses by their row pitch.
        for (uint32_t j = 0; j < total_num_sticks; ++j) {
            noc_async_read(noc_read_addr, l1_write_addr, read_size);
            l1_write_addr += local_stride_bytes;
            noc_read_addr += remote_stride_bytes;
        }
    }
    noc_async_read_barrier();
}
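The reader consumes five scalar runtime args followed by a flat array of five words per segment. A hypothetical host-side packing helper showing that layout (illustrative Python, not the actual program factory):

    def pack_reader_args(total_num_sticks, local_stride_bytes, remote_stride_bytes,
                         base_read_addr, segments):
        # segments: list of (read_size, write_offset, x, y, read_offset) tuples,
        # one per segment, matching the five words the kernel pops per iteration.
        args = [total_num_sticks, local_stride_bytes, remote_stride_bytes,
                base_read_addr, len(segments)]
        for read_size, write_offset, x, y, read_offset in segments:
            args += [read_size, write_offset, x, y, read_offset]
        return args

    # One segment: 4 sticks of 64 bytes each, read from NoC core (1, 0).
    example_args = pack_reader_args(4, 64, 256, 0x10000, [(64, 0, 1, 0, 0)])
    assert len(example_args) == 5 + 5 * 1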
@@ -0,0 +1,40 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include "dataflow_api.h"

void kernel_main() {
    constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0);

    const uint32_t total_num_sticks = get_arg_val<uint32_t>(0);
    const uint32_t local_stride_bytes = get_arg_val<uint32_t>(1);
    const uint32_t remote_stride_bytes = get_arg_val<uint32_t>(2);
    const uint32_t base_write_addr = get_arg_val<uint32_t>(3);
    const uint32_t num_segments = get_arg_val<uint32_t>(4);

    uint32_t args_idx = 0;
    tt_l1_ptr uint32_t* args = (tt_l1_ptr uint32_t*)(get_arg_addr(5));

    uint32_t base_l1_read_addr = get_read_ptr(shard_cb_id);

    for (uint32_t i = 0; i < num_segments; ++i) {
        // Per-segment args: write size, local read offset, remote core (x, y), remote write offset.
        uint32_t write_size = args[args_idx++];

        uint32_t read_offset = args[args_idx++];
        uint32_t l1_read_addr = base_l1_read_addr + read_offset;

        uint32_t x_coord = args[args_idx++];
        uint32_t y_coord = args[args_idx++];
        uint32_t write_offset = base_write_addr + args[args_idx++];
        uint64_t noc_write_addr = get_noc_addr(x_coord, y_coord, write_offset);

        // Copy one stick per iteration, advancing both addresses by their row pitch.
        for (uint32_t j = 0; j < total_num_sticks; ++j) {
            noc_async_write(l1_read_addr, noc_write_addr, write_size);
            l1_read_addr += local_stride_bytes;
            noc_write_addr += remote_stride_bytes;
        }
    }
    noc_async_write_barrier();
}
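The writer is the mirror of the reader: it pushes total_num_sticks rows of write_size bytes from the local shard to a remote core or bank, stepping by the local and remote row pitches. A hand-worked reading of those values for the [1, 1, 4, 256] bfloat16 test case above, under the assumption (mine, for illustration) that each segment maps one source core's rows onto one destination shard:

    ELEM_BYTES = 2  # bfloat16

    # [1, 1, 4, 256] resharded from 2 cores of (4, 128) to 8 cores of (4, 32).
    local_stride_bytes = 128 * ELEM_BYTES   # 256: source shard row pitch
    remote_stride_bytes = 32 * ELEM_BYTES   # 64: destination shard row pitch
    write_size = remote_stride_bytes        # each write fills one destination row
    total_num_sticks = 4                    # shard height is unchanged
    num_segments = local_stride_bytes // write_size

    assert num_segments == 4  # each source core feeds four destination shards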
@@ -5,17 +5,19 @@
#include "reshard_op.hpp"

#include <magic_enum.hpp>
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"

#include "reshard_program_factory.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/common/work_split.hpp"
#include "tt_metal/host_api.hpp"

using namespace tt::constants;
using namespace tt::tt_metal;

namespace ttnn::operations::data_movement {

void ReshardDeviceOperation::validate_with_output_tensors(
    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
    const auto& input_tensor = input_tensors.at(0);
    TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to shard need to be on device!");
    TT_FATAL(input_tensor.buffer() != nullptr, "Operands to shard need to be allocated in buffers on device!");
@@ -27,35 +29,53 @@
        TT_FATAL(input_tensor.get_dtype() == output_tensor.get_dtype(), "Error");
        TT_FATAL(input_tensor.get_layout() == output_tensor.get_layout(), "Error");
    }
    const auto& out_mem_config =
        has_output_tensor ? output_tensors[0].value().memory_config() : this->output_mem_config;
    TT_FATAL(out_mem_config.is_sharded(), "output must be sharded");
    if ((input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED &&
         out_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
        TT_FATAL(
            (input_tensor.memory_config().buffer_type == BufferType::L1 ||
             out_mem_config.buffer_type == BufferType::L1),
            "Resharding height shard to height shard must have at least one buffer in L1");
    } else if ((input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED &&
                out_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)) {
        TT_FATAL(
            (input_tensor.memory_config().buffer_type == BufferType::L1 ||
             out_mem_config.buffer_type == BufferType::L1),
            "Resharding width shard to width shard must have at least one buffer in L1");
    } else {
        TT_FATAL(out_mem_config.buffer_type == BufferType::L1, "Resharding requires output buffer to be in L1");
    }
    if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
        if (input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
            bool same_row_size =
                input_tensor.memory_config().shard_spec.value().shape[0] == out_mem_config.shard_spec.value().shape[0];
            TT_FATAL(same_row_size, "row major must have shard_spec[0] be the same on both input and output");
        } else {
            bool same_row_size =
                input_tensor.memory_config().shard_spec.value().shape[1] == out_mem_config.shard_spec.value().shape[1];
            TT_FATAL(same_row_size, "row major must have shard_spec[1] be the same on both input and output");
        }
    }
}
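Put differently: for row-major tensors the kernels only regroup data along the sharded dimension, so the orthogonal shard dimension must be identical on input and output — shard height (shape[0]) when width-sharded, shard width (shape[1]) otherwise. A compact restatement of the new check (illustrative Python, not library code):

    def row_major_reshard_ok(memory_layout, in_shard, out_shard):
        # in_shard / out_shard are (height, width) shard shapes.
        if memory_layout == "WIDTH_SHARDED":
            return in_shard[0] == out_shard[0]  # shard height must match
        return in_shard[1] == out_shard[1]      # shard width must match

    assert row_major_reshard_ok("WIDTH_SHARDED", (4, 128), (4, 32))
    assert not row_major_reshard_ok("WIDTH_SHARDED", (4, 128), (2, 256))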

std::vector<tt::tt_metal::LegacyShape> ReshardDeviceOperation::compute_output_shapes(
    const std::vector<Tensor>& input_tensors) const {
    const auto& input_tensor = input_tensors.at(0);
    return {input_tensor.get_legacy_shape()};
}

operation::ProgramWithCallbacks ReshardDeviceOperation::create_program(
    const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const {
    const auto& input_tensor = input_tensors.at(0);
    auto& output_tensor = output_tensors.at(0);
    // each tensor has its respective shard_spec within its memory_config
    return detail::reshard_multi_core(input_tensor, output_tensor);
}

std::vector<Tensor> ReshardDeviceOperation::create_output_tensors(
    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
    const auto& input_tensor = input_tensors.at(0);
    if (output_tensors.size() == 1 && output_tensors[0].has_value()) {
        return {output_tensors[0].value()};
@@ -67,10 +87,8 @@
            input_tensor.get_dtype(),
            input_tensor.get_layout(),
            input_tensor.device(),
            mem_config)};
    }
}

}  // namespace ttnn::operations::data_movement
