Fix s2i op when shard grid is larger than actual used grid (#15113)
### Ticket

#15004 (comment)
#14902

### Problem description
A bug occurred when running a matmul + s2i op test: the s2i op does not take into account that some cores in the passed-in shard grid are unused, and writes extra data to the output tensor, causing L1 corruption.

The fix is to make s2i use the correct num_cores from the shard spec, and to make matmul produce the exact shard grid.
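
As an illustration only (not the code in this diff), here is a minimal C++ sketch of the core-count check the s2i side now performs for a height-sharded tensor; `div_up` is a local stand-in for the tt-metal helper and the numbers are made up:

```cpp
#include <cstdint>
#include <cstdio>

// Local stand-in for tt-metal's div_up (ceiling division).
static uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    // Hypothetical height-sharded tensor: 96 tile-rows, 32 tile-rows per shard.
    uint32_t num_units_height = 96;
    uint32_t num_units_per_shard_height = 32;

    // Cores listed in the shard spec's grid (one more than the data needs).
    uint32_t num_cores = 4;

    // Cores the data actually occupies.
    uint32_t num_cores_unpadded = div_up(num_units_height, num_units_per_shard_height);  // 3

    // The op now requires the producer to hand it the exact grid; otherwise the
    // extra core would be read/written and corrupt L1.
    if (num_cores_unpadded != num_cores) {
        std::printf("shard grid has %u cores but only %u hold data\n", num_cores, num_cores_unpadded);
    }
    return 0;
}
```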


### Checklist
- [x] Post commit CI
https://github.com/tenstorrent/tt-metal/actions/runs/11894930806
- [x] Blackhole Post commit
https://github.com/tenstorrent/tt-metal/actions/runs/11862361147
- [x] nightly
https://github.com/tenstorrent/tt-metal/actions/runs/11894947971
- [x] model perf
https://github.com/tenstorrent/tt-metal/actions/runs/11894959587
- [x] t3k freq
https://github.com/tenstorrent/tt-metal/actions/runs/11894967497
yugaoTT authored Nov 21, 2024
1 parent 919d46a commit 3332f11
Showing 5 changed files with 252 additions and 76 deletions.
146 changes: 143 additions & 3 deletions tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py
@@ -11,7 +11,14 @@
comp_equal,
comp_pcc,
)
from models.utility_functions import is_wormhole_b0, is_wormhole_b0, is_blackhole, skip_for_blackhole
from models.utility_functions import (
is_wormhole_b0,
is_wormhole_b0,
is_blackhole,
skip_for_blackhole,
skip_for_grayskull,
run_for_wormhole_b0,
)
from loguru import logger
from models.utility_functions import torch2tt_tensor, tt2torch_tensor, pad_by_zero, roundup32

@@ -682,8 +689,7 @@ def test_bcast_hw(device, num_cores, in0_height_sharded, out_height_sharded, in_
out_mem_config = ttnn.DRAM_MEMORY_CONFIG

if in0_height_sharded:
compute_with_storage_grid_size = device.compute_with_storage_grid_size()
device_grid_size = ttnn.CoreGrid(y=compute_with_storage_grid_size.y, x=compute_with_storage_grid_size.x)
device_grid_size = ttnn.CoreGrid(y=8, x=8) if num_cores == 64 else ttnn.CoreGrid(y=1, x=1)

tt_in0_height_sharded = ttnn.to_memory_config(
tt_in0_dram,
@@ -2418,3 +2424,137 @@ def test_interleaved_2_sharded_DRAM(device, dtype, y):
)

yt = ttnn.interleaved_to_sharded(xt, shard_grid, (y // 8, 18 * 32), shard_scheme, ttnn.ShardOrientation.ROW_MAJOR)


@run_for_wormhole_b0()
@pytest.mark.parametrize(
"seq_len",
(32,),
)
def test_llama_mlp_width_sharded_to_interleaved_pcc_err(device, seq_len, use_program_cache):
dim_in = 4096
dim_hidden = int(3.5 * dim_in / 4) # 3584
dim_out = dim_in
# Create random input tensor
input_tensor = torch.randn(1, 1, int(seq_len), dim_in)
# Create random weight matrices
w1 = torch.randn(dim_hidden, dim_in)
w2 = torch.randn(dim_out, dim_hidden)
# Pytorch reference implementation
## First linear layer
hidden = torch.matmul(input_tensor, w1.t())
## Second linear layer
output_w2 = torch.matmul(hidden, w2.t())
## Add residual connection
reference_output = output_w2 + input_tensor
# TTNN implementation
input_mem_config = ttnn.create_sharded_memory_config(
(
32,
128,
), # Shard shape: [32, 128] -> 1 shard per core
ttnn.CoreGrid(x=8, y=4),
ttnn.ShardStrategy.WIDTH,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
w1_out_reshard_mem_config = ttnn.create_sharded_memory_config(
(
32,
128,
), # Shard shape: [32, 128] -> 1 shard per core
ttnn.CoreGrid(x=7, y=4),
ttnn.ShardStrategy.WIDTH,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
dram_core_range_set = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(11, 0),
),
}
)
w1_w3_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
ttnn.ShardSpec(dram_core_range_set, (4096, 320), ttnn.ShardOrientation.ROW_MAJOR, False),
)
w2_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
ttnn.ShardSpec(dram_core_range_set, (3584, 352), ttnn.ShardOrientation.ROW_MAJOR, False),
)
pc_1 = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
in0_block_w=4,
per_core_M=1,
per_core_N=4,
fused_activation=None,
)
pc_2 = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
in0_block_w=4,
per_core_M=1,
per_core_N=5,
fused_activation=None,
)
## convert input tensor and weights to TTNN tensors
tt_input = ttnn.from_torch(
input_tensor,
device=device,
dtype=ttnn.bfloat8_b,
memory_config=input_mem_config,
layout=ttnn.TILE_LAYOUT,
)
as_sharded_tensor = lambda w, type, dim, mem_config: ttnn.as_tensor(
w, # Grab only the wX part of the name
dtype=type,
device=device,
layout=ttnn.TILE_LAYOUT,
memory_config=mem_config,
)
# Sharded weights
tt_w1 = as_sharded_tensor(w1.t(), ttnn.bfloat8_b, dim=-1, mem_config=w1_w3_mem_config)
tt_w2 = as_sharded_tensor(w2.t(), ttnn.bfloat8_b, dim=-2, mem_config=w2_mem_config)
## MLP takes replicated inputs and produces fractured outputs
logger.info(f"tt_input shape: {tt_input.shape}")
logger.info(f"tt_input memory config: {tt_input.memory_config()}")
w1_out = ttnn.linear(
tt_input,
tt_w1,
core_grid=None,
dtype=ttnn.bfloat16,
program_config=pc_1,
memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
)
logger.info(f"w1_out shape: {w1_out.shape}")
logger.info(f"w1_out memory config: {w1_out.memory_config()}")
w1_out = ttnn.reshard(w1_out, w1_out_reshard_mem_config)
w2_out = ttnn.linear(
w1_out,
tt_w2,
core_grid=None,
dtype=ttnn.bfloat16,
program_config=pc_2,
memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
)
logger.info(f"w2_out shape: {w2_out.shape}")
logger.info(f"w2_out memory config: {w2_out.memory_config()}")
w2_out = ttnn.sharded_to_interleaved(w2_out, ttnn.L1_MEMORY_CONFIG)
tt_input = ttnn.sharded_to_interleaved(tt_input, ttnn.L1_MEMORY_CONFIG)

# ## Add residual connection
tt_input_torch = ttnn.to_torch(tt_input)
tt_w2_out_torch = ttnn.to_torch(w2_out)
tt_output = ttnn.add(tt_input, w2_out)
tt_output_torch = ttnn.to_torch(tt_output)
pcc_required = 0.99
passing_w2_out, pcc_message_w2_out = comp_pcc(output_w2, tt_w2_out_torch)
passing_input, pcc_message_input = comp_pcc(input_tensor, tt_input_torch)
passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required)
logger.info(f"w2_out PCC: {pcc_message_w2_out}")
logger.info(f"input PCC: {pcc_message_input}")
logger.info(f"residual PCC: {pcc_message}")
assert passing_w2_out
assert passing_input
assert passing
@@ -18,7 +18,7 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
tt_metal::Program program{};

uint32_t num_units, num_units_per_shard, input_unit_size, output_unit_size, num_units_per_shard_width,
num_units_per_shard_height, num_units_offset, num_units_per_row, num_units_per_shard_height_last,
num_units_per_shard_height, num_units_offset, num_units_per_row, num_units_height, num_units_per_shard_height_last,
num_units_per_shard_width_last;

tt_metal::Device* device = input.device();
@@ -30,7 +30,12 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
auto shard_strategy = input.memory_config().memory_layout;

bool rm_orientation = shard_spec.orientation == ShardOrientation::ROW_MAJOR;
CoreCoord end_core = (*shard_spec.grid.ranges().rbegin()).end_coord;
auto& all_cores = shard_spec.grid;
uint32_t num_cores = all_cores.num_cores();
uint32_t num_cores_unpadded = num_cores;
const auto cores = corerange_to_cores(all_cores, std::nullopt, rm_orientation);

CoreCoord end_core = cores[num_cores - 1];
if (output.get_layout() == Layout::TILE) {
num_units = input.volume() / TILE_HW;
input_unit_size = tt_metal::detail::TileSize(input_cb_data_format);
@@ -40,7 +45,7 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
num_units_per_shard = num_units_per_shard_height * num_units_per_shard_width;
num_units_per_row = output.get_legacy_shape()[-1] / TILE_WIDTH;
num_units_offset = num_units_per_row;
uint32_t num_units_height = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT / num_slices;
num_units_height = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT / num_slices;
num_units_per_shard_height_last =
num_units_per_shard_height - (round_up(num_units_height, num_units_per_shard_height) - num_units_height);
num_units_per_shard_width_last =
@@ -55,17 +60,26 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
num_units_per_shard = num_units_per_shard_height * num_units_per_shard_width;
num_units_per_row = output.get_legacy_shape()[-1] * output.element_size();
num_units_offset = 1;
uint32_t num_units_height = input.volume() / input.get_legacy_shape()[-1];
num_units_height = input.volume() / input.get_legacy_shape()[-1];
num_units_per_shard_height_last =
num_units_per_shard_height - (round_up(num_units_height, num_units_per_shard_height) - num_units_height);
num_units_per_shard_width_last =
output_unit_size - (round_up(num_units_per_row, output_unit_size) - num_units_per_row);
}

bool convert_df = input_cb_data_format != output_cb_data_format;
// re-calculate end_core in the case shard grid is larger than used grid
if (shard_strategy == TensorMemoryLayout::HEIGHT_SHARDED) {
num_cores_unpadded = div_up(num_units_height, num_units_per_shard_height);
} else if (shard_strategy == TensorMemoryLayout::WIDTH_SHARDED) {
if (output.get_layout() == Layout::TILE) {
num_cores_unpadded = div_up(num_units_per_row, num_units_per_shard_width);
} else {
num_cores_unpadded = div_up(num_units_per_row, output_unit_size);
}
}
TT_ASSERT(num_cores_unpadded == num_cores, "number of cores {} in shard spec not equal to the unpadded number of cores {}", num_cores_unpadded, num_cores);

auto& all_cores = shard_spec.grid;
uint32_t num_cores = all_cores.num_cores();
bool convert_df = input_cb_data_format != output_cb_data_format;

uint32_t src0_cb_index = CB::c_in0;
uint32_t out_cb_index = src0_cb_index;
@@ -141,13 +155,12 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
uint32_t curr_idx_h = 0;
uint32_t curr_idx_w = 0;

const auto cores = corerange_to_cores(all_cores, std::nullopt, rm_orientation);
uint32_t padded_offset_bytes;

for (const auto& core : cores) {
uint32_t shard_height = num_units_per_shard_height;
uint32_t shard_width = input.get_layout() == Layout::TILE ? num_units_per_shard_width : output_unit_size;
if (input.get_layout() == Layout::TILE) {
uint32_t shard_height = num_units_per_shard_height;
uint32_t shard_width = num_units_per_shard_width;
if (shard_strategy == TensorMemoryLayout::HEIGHT_SHARDED) {
if (core.x == end_core.x && core.y == end_core.y) {
shard_height = num_units_per_shard_height_last;
@@ -192,8 +205,6 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
curr_idx_h += num_units_per_row * num_units_per_shard_height;
}
} else {
uint32_t shard_height = num_units_per_shard_height;
uint32_t shard_width = output_unit_size;
if (shard_strategy == TensorMemoryLayout::HEIGHT_SHARDED) {
if (core.x == end_core.x && core.y == end_core.y) {
shard_height = num_units_per_shard_height_last;
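
For the width-sharded path exercised by the new llama MLP test, the same check boils down to the arithmetic below (a standalone sketch using the test's tile counts, not the factory code itself):

```cpp
#include <cstdint>
#include <cstdio>

// Local stand-in for tt-metal's div_up (ceiling division).
static uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    // w2_out in the test: 4096 / 32 = 128 tile-columns, 5 tiles per shard (per_core_N = 5).
    uint32_t num_units_per_row = 128;
    uint32_t num_units_per_shard_width = 5;

    // Before the fix, matmul handed s2i input A's full 8x4 grid; after the fix the
    // output grid holds exactly this many cores, so the assert in the factory passes.
    uint32_t num_cores_unpadded = div_up(num_units_per_row, num_units_per_shard_width);  // 26
    std::printf("width-sharded output uses %u cores\n", num_cores_unpadded);
    return 0;
}
```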
12 changes: 10 additions & 2 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1182,6 +1182,7 @@ void Matmul::validate(

// No padding
TT_FATAL(M == per_core_M, "Error");
TT_FATAL(M == 1, "currently only support in0 tensor height of tile height");
TT_FATAL(per_core_M == (shard_shape[0] / in0_tile_shape[0]), "Error");
TT_FATAL(K % program_config.in0_block_w == 0, "Error");
TT_FATAL((shard_shape[1] / in0_tile_shape[1]) % program_config.in0_block_w == 0, "Error");
@@ -1406,7 +1407,7 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp
} else if constexpr (std::is_same_v<
ProgramConfigType,
MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig>) {
uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1];
uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0];
uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1];
auto input_tensor_b_shape = input_tensor_b.get_legacy_shape();

@@ -1415,7 +1416,14 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp

TT_FATAL(per_core_N % tile_width_ratio == 0, "per_core_N must be divisible by override output tile width");

CoreRangeSet all_cores = input_tensor_a.shard_spec().value().grid;
uint32_t num_blocks_y = (M - 1) / per_core_M + 1;
uint32_t num_blocks_x = (N - 1) / per_core_N + 1;
uint32_t num_blocks_total = num_blocks_y * num_blocks_x;
uint32_t num_cores = num_blocks_x * num_blocks_y;
auto end_core = input_tensor_a.shard_spec()->grid.bounding_box().end_coord;
auto grid_size = CoreCoord{end_core.x + 1, end_core.y + 1};
CoreRangeSet all_cores =
num_cores_to_corerangeset(num_cores, grid_size, true);
ShardSpec shard_spec = ShardSpec{
all_cores, {per_core_M * in0_tile_shape[0], per_core_N * in1_tile_shape[1]}, ShardOrientation::ROW_MAJOR};
auto mem_config = this->output_mem_config;
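
On the matmul side, the change above sizes the output CoreRangeSet from the number of output blocks instead of reusing input A's shard grid. A rough sketch of that sizing with the w1 values from the test (local `CoreCoord` struct, illustrative only; in the op the result feeds `num_cores_to_corerangeset(num_cores, grid_size, true)`):

```cpp
#include <cstdint>
#include <cstdio>

// Local stand-in for tt-metal's CoreCoord.
struct CoreCoord { uint32_t x; uint32_t y; };

int main() {
    // w1 in the test: M = 1 tile-row, N = 3584 / 32 = 112 tile-columns,
    // per_core_M = 1, per_core_N = 4.
    uint32_t M = 1, N = 112;
    uint32_t per_core_M = 1, per_core_N = 4;

    uint32_t num_blocks_y = (M - 1) / per_core_M + 1;  // 1
    uint32_t num_blocks_x = (N - 1) / per_core_N + 1;  // 28
    uint32_t num_cores = num_blocks_x * num_blocks_y;  // 28

    // Grid to carve those cores out of: bounding box of input A's shard grid (8x4 here).
    CoreCoord end_core{7, 3};
    CoreCoord grid_size{end_core.x + 1, end_core.y + 1};

    // 28 of the 32 cores get an output shard (the test then reshards to an explicit 7x4 grid).
    std::printf("output shard grid: %u cores carved from a %ux%u grid\n",
                num_cores, grid_size.x, grid_size.y);
    return 0;
}
```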