Skip to content

Commit

Permalink
Add support for resharding width-sharded tensors to/from DRAM
Browse files Browse the repository at this point in the history
This commit introduces a new width-shard to width-shard reshard kernel,
modeled after the height-sharded tensor reshard special case implemented
in `reshard_multi_core_same_width`.

Supported operations include:

- L1 to L1
- L1 to DRAM
- DRAM to L1

Currently, only row-major tensors are supported. For unsupported cases,
we fall back to the generalized reshard implementation.

Unit tests have been added to validate the new kernel functionality.
  • Loading branch information
esmalTT committed Nov 28, 2024
1 parent 4584bc3 commit 8c0efe8
Show file tree
Hide file tree
Showing 5 changed files with 483 additions and 10 deletions.
157 changes: 156 additions & 1 deletion tests/tt_eager/python_api_testing/unit_testing/misc/test_reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,30 @@ def run_reshard_test(
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
),
(
[1, 1, 4, 256],
ttnn.ROW_MAJOR_LAYOUT,
[[(0, 0), (3, 0)]],
(4, 64),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
[[(0, 0), (0, 1)]],
(4, 128),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
),
(
[1, 1, 4, 256],
ttnn.ROW_MAJOR_LAYOUT,
[[(0, 0), (1, 0)]],
(4, 128),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
[[(0, 0), (0, 7)]],
(4, 32),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
),
],
)
@pytest.mark.parametrize("tt_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
Expand Down Expand Up @@ -434,6 +458,62 @@ def test_reshard_with_program_cache(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.BufferType.L1,
),
(
[1, 1, 1, 96],
ttnn.ROW_MAJOR_LAYOUT,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(2, 0))}),
(1, 32),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.L1,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
(1, 48),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
),
(
[1, 1, 32, 512],
ttnn.ROW_MAJOR_LAYOUT,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
(32, 128),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.L1,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
(32, 256),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
),
(
[1, 1, 2, 256],
ttnn.ROW_MAJOR_LAYOUT,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
(2, 128),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.L1,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
(2, 64),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
),
(
[1, 1, 16, 256],
ttnn.ROW_MAJOR_LAYOUT,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
(16, 128),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
(16, 64),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.L1,
),
],
)
def test_dram_reshard(
Expand Down Expand Up @@ -466,4 +546,79 @@ def test_dram_reshard(

passing, output_log = comp_equal(input, output)

assert passing, output_log

@skip_for_blackhole("GH Issue #15234")
@pytest.mark.parametrize(
    "input_shape, input_layout, input_shard_grid, input_shard_shape, input_shard_orientation, input_sharding_scheme, input_buffer_type, output_shard_grid, output_shard_shape, output_shard_orientation, output_sharding_scheme, output_buffer_type",
    [
        (  # tests reshard_multi_core_same_width (height-sharded, shard widths match)
            [1, 1, 768, 64],
            ttnn.TILE_LAYOUT,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 0))}),
            (96, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            ttnn.BufferType.DRAM,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 2))}),
            (32, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            ttnn.BufferType.L1,
        ),
        (  # test reshard_multi_core_same_height (width-sharded, shard heights match)
            [1, 1, 16, 256],
            ttnn.ROW_MAJOR_LAYOUT,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
            (16, 128),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.DRAM,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
            (16, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.L1,
        ),
    ],
)
def test_dram_reshard_with_program_cache(
    use_program_cache,
    device,
    input_shape,
    input_layout,
    input_shard_grid,
    input_shard_shape,
    input_shard_orientation,
    input_sharding_scheme,
    input_buffer_type,
    output_shard_grid,
    output_shard_shape,
    output_shard_orientation,
    output_sharding_scheme,
    output_buffer_type,
):
    """Run the same DRAM/L1 reshard repeatedly and verify the compiled program
    is cached: after four identical runs exactly one cache entry must exist.
    """
    dtype = ttnn.bfloat8_b
    for _ in range(4):
        # Allocate a throwaway L1 tensor before each run so buffer addresses
        # shift between iterations; a correctly cached program must still
        # produce correct results with the new addresses.
        dummy_tensor = (
            ttnn.Tensor(torch.rand([1, 1, 128, 512]), dtype).to(ttnn.TILE_LAYOUT).to(device, ttnn.L1_MEMORY_CONFIG)
        )
        # Delegates correctness checking to the non-cached test body.
        test_dram_reshard(
            device,
            input_shape,
            input_layout,
            input_shard_grid,
            input_shard_shape,
            input_shard_orientation,
            input_sharding_scheme,
            input_buffer_type,
            output_shard_grid,
            output_shard_shape,
            output_shard_orientation,
            output_sharding_scheme,
            output_buffer_type,
        )
        # Second dummy allocation (different shape) further perturbs the L1
        # layout before the next cached invocation.
        dummy_tensor = (
            ttnn.Tensor(torch.rand([2, 2, 128, 64]), dtype).to(ttnn.TILE_LAYOUT).to(device, ttnn.L1_MEMORY_CONFIG)
        )

    # All four runs used identical parameters, so only one program should
    # have been compiled and cached.
    assert device.num_program_cache_entries() == 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include "dataflow_api.h"

void kernel_main() {
constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0);

const uint32_t total_num_sticks = get_arg_val<uint32_t>(0);
const uint32_t local_stride_bytes = get_arg_val<uint32_t>(1);
const uint32_t remote_stride_bytes = get_arg_val<uint32_t>(2);
const uint32_t base_read_addr = get_arg_val<uint32_t>(3);
const uint32_t num_segments = get_arg_val<uint32_t>(4);

uint32_t args_idx = 0;
tt_l1_ptr uint32_t* args = (tt_l1_ptr uint32_t*)(get_arg_addr(5));

uint32_t base_write_addr = get_read_ptr(shard_cb_id);

for (uint32_t i = 0; i < num_segments; ++i) {
uint32_t read_size = args[args_idx++];

uint32_t write_offset = args[args_idx++];
uint32_t l1_write_addr = base_write_addr + write_offset;

uint32_t x_coord = args[args_idx++];
uint32_t y_coord = args[args_idx++];
uint32_t read_offset = base_read_addr + args[args_idx++];
uint64_t noc_read_addr = get_noc_addr(x_coord, y_coord, read_offset);

for (uint32_t j = 0; j < total_num_sticks; ++j) {
noc_async_read(noc_read_addr, l1_write_addr, read_size);
l1_write_addr += local_stride_bytes;
noc_read_addr += remote_stride_bytes;
}
}
noc_async_write_barrier();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include "dataflow_api.h"

void kernel_main() {
constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0);

const uint32_t total_num_sticks = get_arg_val<uint32_t>(0);
const uint32_t local_stride_bytes = get_arg_val<uint32_t>(1);
const uint32_t remote_stride_bytes = get_arg_val<uint32_t>(2);
const uint32_t base_write_addr = get_arg_val<uint32_t>(3);
const uint32_t num_segments = get_arg_val<uint32_t>(4);

uint32_t args_idx = 0;
tt_l1_ptr uint32_t* args = (tt_l1_ptr uint32_t*)(get_arg_addr(5));

uint32_t base_l1_read_addr = get_read_ptr(shard_cb_id);

for (uint32_t i = 0; i < num_segments; ++i) {
uint32_t write_size = args[args_idx++];

uint32_t read_offset = args[args_idx++];
uint32_t l1_read_addr = base_l1_read_addr + read_offset;

uint32_t x_coord = args[args_idx++];
uint32_t y_coord = args[args_idx++];
uint32_t write_offset = base_write_addr + args[args_idx++];
uint64_t noc_write_addr = get_noc_addr(x_coord, y_coord, write_offset);

for (uint32_t j = 0; j < total_num_sticks; ++j) {
noc_async_write(l1_read_addr, noc_write_addr, write_size);
l1_read_addr += local_stride_bytes;
noc_write_addr += remote_stride_bytes;
}
}
noc_async_write_barrier();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
#include "reshard_op.hpp"

#include <magic_enum.hpp>
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"

#include "reshard_program_factory.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/common/work_split.hpp"
#include "tt_metal/host_api.hpp"

using namespace tt::constants;
using namespace tt::tt_metal;
Expand All @@ -21,6 +22,7 @@ void ReshardDeviceOperation::validate_with_output_tensors(
TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to shard need to be on device!");
TT_FATAL(input_tensor.buffer() != nullptr, "Operands to shard need to be allocated in buffers on device!");
TT_FATAL(input_tensor.is_sharded(), "input must be sharded");

bool has_output_tensor = output_tensors.size() == 1 && output_tensors[0].has_value();
if (has_output_tensor) {
const auto& output_tensor = output_tensors[0].value();
Expand All @@ -31,19 +33,33 @@ void ReshardDeviceOperation::validate_with_output_tensors(
const auto& out_mem_config =
has_output_tensor ? output_tensors[0].value().memory_config() : this->output_mem_config;
TT_FATAL(out_mem_config.is_sharded(), "output must be sharded");

if ((input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED &&
out_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
TT_FATAL(
(input_tensor.memory_config().buffer_type == BufferType::L1 ||
out_mem_config.buffer_type == BufferType::L1),
"Resharding height shard to height shard must have at least one buffer in L1");
} else if ((input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED &&
out_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)) {
TT_FATAL(
(input_tensor.memory_config().buffer_type == BufferType::L1 ||
out_mem_config.buffer_type == BufferType::L1),
"Resharding width shard to width shard must have at least one buffer in L1");
} else {
TT_FATAL(out_mem_config.buffer_type == BufferType::L1, "Resharding requires output buffer to be in L1");
}

if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
bool same_row_size =
input_tensor.memory_config().shard_spec.value().shape[1] == out_mem_config.shard_spec.value().shape[1];
TT_FATAL(same_row_size, "row major must have shard_spec[1] be the same on both input and output");
if (input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
bool same_row_size =
input_tensor.memory_config().shard_spec.value().shape[0] == out_mem_config.shard_spec.value().shape[0];
TT_FATAL(same_row_size, "row major must have shard_spec[0] be the same on both input and output");
} else {
bool same_height_size =
input_tensor.memory_config().shard_spec.value().shape[1] == out_mem_config.shard_spec.value().shape[1];
TT_FATAL(same_height_size, "row major must have shard_spec[1] be the same on both input and output");
}
}
}

Expand Down
Loading

0 comments on commit 8c0efe8

Please sign in to comment.