Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for resharding width-sharded tensors to/from DRAM #15526

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 156 additions & 1 deletion tests/tt_eager/python_api_testing/unit_testing/misc/test_reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,30 @@ def run_reshard_test(
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
),
(
[1, 1, 4, 256],
ttnn.ROW_MAJOR_LAYOUT,
[[(0, 0), (3, 0)]],
(4, 64),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
[[(0, 0), (0, 1)]],
(4, 128),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
),
(
[1, 1, 4, 256],
ttnn.ROW_MAJOR_LAYOUT,
[[(0, 0), (1, 0)]],
(4, 128),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
[[(0, 0), (0, 7)]],
(4, 32),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
),
],
)
@pytest.mark.parametrize("tt_dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
Expand Down Expand Up @@ -434,6 +458,62 @@ def test_reshard_with_program_cache(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.BufferType.L1,
),
(
[1, 1, 1, 96],
ttnn.ROW_MAJOR_LAYOUT,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(2, 0))}),
(1, 32),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.L1,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
(1, 48),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
),
(
[1, 1, 32, 512],
ttnn.ROW_MAJOR_LAYOUT,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
(32, 128),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.L1,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
(32, 256),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
),
(
[1, 1, 2, 256],
ttnn.ROW_MAJOR_LAYOUT,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
(2, 128),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.L1,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
(2, 64),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
),
(
[1, 1, 16, 256],
ttnn.ROW_MAJOR_LAYOUT,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
(16, 128),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
(16, 64),
ttnn.ShardOrientation.ROW_MAJOR,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.L1,
),
],
)
def test_dram_reshard(
Expand Down Expand Up @@ -466,4 +546,79 @@ def test_dram_reshard(

passing, output_log = comp_equal(input, output)

assert passing, output_log

@skip_for_blackhole("GH Issue #15234")
@pytest.mark.parametrize(
    "input_shape, input_layout, input_shard_grid, input_shard_shape, input_shard_orientation, input_sharding_scheme, input_buffer_type, output_shard_grid, output_shard_shape, output_shard_orientation, output_sharding_scheme, output_buffer_type",
    [
        (  # tests reshard_multi_core_same_width
            [1, 1, 768, 64],
            ttnn.TILE_LAYOUT,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 0))}),
            (96, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            ttnn.BufferType.DRAM,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 2))}),
            (32, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            ttnn.BufferType.L1,
        ),
        (  # test reshard_multi_core_same_height
            [1, 1, 16, 256],
            ttnn.ROW_MAJOR_LAYOUT,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
            (16, 128),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.DRAM,
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 0))}),
            (16, 64),
            ttnn.ShardOrientation.ROW_MAJOR,
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.L1,
        ),
    ],
)
def test_dram_reshard_with_program_cache(
    use_program_cache,
    device,
    input_shape,
    input_layout,
    input_shard_grid,
    input_shard_shape,
    input_shard_orientation,
    input_sharding_scheme,
    input_buffer_type,
    output_shard_grid,
    output_shard_shape,
    output_shard_orientation,
    output_sharding_scheme,
    output_buffer_type,
):
    """Run the DRAM reshard test repeatedly and verify the program cache is hit.

    Executes ``test_dram_reshard`` several times with identical arguments;
    after all iterations exactly one program-cache entry must exist, i.e.
    every run after the first reused the cached program.
    """
    dtype = ttnn.bfloat8_b
    num_iterations = 4
    for _iteration in range(num_iterations):
        # Allocate a throwaway L1 tensor before/after each run — presumably to
        # perturb buffer addresses between iterations so the cache hit is not
        # an artifact of identical allocations (TODO confirm intent).
        dummy_tensor = (
            ttnn.Tensor(torch.rand([1, 1, 128, 512]), dtype).to(ttnn.TILE_LAYOUT).to(device, ttnn.L1_MEMORY_CONFIG)
        )
        test_dram_reshard(
            device,
            input_shape,
            input_layout,
            input_shard_grid,
            input_shard_shape,
            input_shard_orientation,
            input_sharding_scheme,
            input_buffer_type,
            output_shard_grid,
            output_shard_shape,
            output_shard_orientation,
            output_sharding_scheme,
            output_buffer_type,
        )
        dummy_tensor = (
            ttnn.Tensor(torch.rand([2, 2, 128, 64]), dtype).to(ttnn.TILE_LAYOUT).to(device, ttnn.L1_MEMORY_CONFIG)
        )

    # All iterations after the first should have hit the cache.
    assert device.num_program_cache_entries() == 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include "dataflow_api.h"

void kernel_main() {
constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0);

const uint32_t total_num_sticks = get_arg_val<uint32_t>(0);
const uint32_t local_stride_bytes = get_arg_val<uint32_t>(1);
const uint32_t remote_stride_bytes = get_arg_val<uint32_t>(2);
const uint32_t base_read_addr = get_arg_val<uint32_t>(3);
const uint32_t num_segments = get_arg_val<uint32_t>(4);

uint32_t args_idx = 0;
tt_l1_ptr uint32_t* args = (tt_l1_ptr uint32_t*)(get_arg_addr(5));

uint32_t base_write_addr = get_read_ptr(shard_cb_id);

for (uint32_t i = 0; i < num_segments; ++i) {
uint32_t read_size = args[args_idx++];

uint32_t write_offset = args[args_idx++];
uint32_t l1_write_addr = base_write_addr + write_offset;

uint32_t x_coord = args[args_idx++];
uint32_t y_coord = args[args_idx++];
uint32_t read_offset = base_read_addr + args[args_idx++];
uint64_t noc_read_addr = get_noc_addr(x_coord, y_coord, read_offset);

for (uint32_t j = 0; j < total_num_sticks; ++j) {
noc_async_read(noc_read_addr, l1_write_addr, read_size);
l1_write_addr += local_stride_bytes;
noc_read_addr += remote_stride_bytes;
}
}
noc_async_write_barrier();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include "dataflow_api.h"

void kernel_main() {
constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0);

const uint32_t total_num_sticks = get_arg_val<uint32_t>(0);
const uint32_t local_stride_bytes = get_arg_val<uint32_t>(1);
const uint32_t remote_stride_bytes = get_arg_val<uint32_t>(2);
const uint32_t base_write_addr = get_arg_val<uint32_t>(3);
const uint32_t num_segments = get_arg_val<uint32_t>(4);

uint32_t args_idx = 0;
tt_l1_ptr uint32_t* args = (tt_l1_ptr uint32_t*)(get_arg_addr(5));

uint32_t base_l1_read_addr = get_read_ptr(shard_cb_id);

for (uint32_t i = 0; i < num_segments; ++i) {
uint32_t write_size = args[args_idx++];

uint32_t read_offset = args[args_idx++];
uint32_t l1_read_addr = base_l1_read_addr + read_offset;

uint32_t x_coord = args[args_idx++];
uint32_t y_coord = args[args_idx++];
uint32_t write_offset = base_write_addr + args[args_idx++];
uint64_t noc_write_addr = get_noc_addr(x_coord, y_coord, write_offset);

for (uint32_t j = 0; j < total_num_sticks; ++j) {
noc_async_write(l1_read_addr, noc_write_addr, write_size);
l1_read_addr += local_stride_bytes;
noc_write_addr += remote_stride_bytes;
}
}
noc_async_write_barrier();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
#include "reshard_op.hpp"

#include <magic_enum.hpp>
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"

#include "reshard_program_factory.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/common/work_split.hpp"
#include "tt_metal/host_api.hpp"

using namespace tt::constants;
using namespace tt::tt_metal;
Expand All @@ -21,6 +22,7 @@ void ReshardDeviceOperation::validate_with_output_tensors(
TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to shard need to be on device!");
TT_FATAL(input_tensor.buffer() != nullptr, "Operands to shard need to be allocated in buffers on device!");
TT_FATAL(input_tensor.is_sharded(), "input must be sharded");

bool has_output_tensor = output_tensors.size() == 1 && output_tensors[0].has_value();
if (has_output_tensor) {
const auto& output_tensor = output_tensors[0].value();
Expand All @@ -31,19 +33,33 @@ void ReshardDeviceOperation::validate_with_output_tensors(
const auto& out_mem_config =
has_output_tensor ? output_tensors[0].value().memory_config() : this->output_mem_config;
TT_FATAL(out_mem_config.is_sharded(), "output must be sharded");

if ((input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED &&
out_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
TT_FATAL(
(input_tensor.memory_config().buffer_type == BufferType::L1 ||
out_mem_config.buffer_type == BufferType::L1),
"Resharding height shard to height shard must have at least one buffer in L1");
} else if ((input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED &&
out_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)) {
TT_FATAL(
(input_tensor.memory_config().buffer_type == BufferType::L1 ||
out_mem_config.buffer_type == BufferType::L1),
"Resharding width shard to width shard must have at least one buffer in L1");
} else {
TT_FATAL(out_mem_config.buffer_type == BufferType::L1, "Resharding requires output buffer to be in L1");
}

if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
bool same_row_size =
input_tensor.memory_config().shard_spec.value().shape[1] == out_mem_config.shard_spec.value().shape[1];
TT_FATAL(same_row_size, "row major must have shard_spec[1] be the same on both input and output");
if (input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
bool same_row_size =
input_tensor.memory_config().shard_spec.value().shape[0] == out_mem_config.shard_spec.value().shape[0];
TT_FATAL(same_row_size, "row major must have shard_spec[0] be the same on both input and output");
} else {
bool same_height_size =
input_tensor.memory_config().shard_spec.value().shape[1] == out_mem_config.shard_spec.value().shape[1];
TT_FATAL(same_height_size, "row major must have shard_spec[1] be the same on both input and output");
}
}
}

Expand Down
Loading
Loading