From 7f6ab6993bf7173142ed5ccfb4a8630480c05095 Mon Sep 17 00:00:00 2001 From: Michael Chiou <156848643+ttmchiou@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:00:15 -0800 Subject: [PATCH] Revert "#12184: Alignment fix for BH in I2S and S2I" This reverts commit cf1c75e4a0106ef0081fb390f867f29a2616a21a. --- .../unit_testing/misc/test_sharded.py | 14 +-- tests/ttnn/unit_tests/operations/test_core.py | 94 ------------------- .../unit_tests/operations/test_maxpool2d.py | 3 - ttnn/cpp/ttnn/operations/core/core.cpp | 37 +------- .../data_movement/common/kernels/debug.hpp | 20 ---- ...ut_sharded_blocks_interleaved_start_id.cpp | 18 +--- ...ut_sharded_blocks_interleaved_start_id.cpp | 12 --- .../device/interleaved_to_sharded_op.cpp | 15 +-- ...interleaved_to_sharded_program_factory.cpp | 43 +++------ .../device/sharded_to_interleaved_op.cpp | 3 +- ...sharded_to_interleaved_program_factory.cpp | 13 +-- 11 files changed, 29 insertions(+), 243 deletions(-) delete mode 100644 ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py index 5df2d752340..b3e41058c67 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py @@ -101,7 +101,6 @@ def test_sharded_tile( # TODO (7735): Switch to new interleaved_to_sharded with sharded_mem_config input and re-enable BLOCK sharded tests -@skip_for_blackhole("WIP") @pytest.mark.parametrize( "input_shape, shard_scheme, shard_size, num_cores", [ @@ -181,7 +180,7 @@ def test_sharded_rm( assert passing -@skip_for_blackhole("BH LLK issue with untilize, #14594") +@skip_for_blackhole("Mismatching on BH, see #12349") @pytest.mark.parametrize("H, num_cores", [[100352, 98], [25088, 98]]) @pytest.mark.parametrize("in_sharded", [True, False]) @pytest.mark.parametrize("out_sharded", [True, False]) @@ -257,7 +256,7 @@ def test_sharded_untilize(H, num_cores, in_sharded, out_sharded, dtype, device, assert passing -@skip_for_blackhole("Mismatching on BH, see #14609") +@skip_for_blackhole("Mismatching on BH, see #12349") @pytest.mark.parametrize("H, num_cores", [[25088, 98]]) @pytest.mark.parametrize("output_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) def test_sharded_tilize(H, num_cores, output_dtype, device, function_level_defaults): @@ -896,7 +895,6 @@ def test_partial_sharded_op_binary( assert passing -@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745") @pytest.mark.parametrize("in0_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"]) @pytest.mark.parametrize("in1_sharded", [True, False], ids=["in1_sharded", "in1_unsharded"]) @pytest.mark.parametrize("out_sharded", [True, False], ids=["out_sharded", "out_unsharded"]) @@ -1337,7 +1335,6 @@ def test_sharded_matmul_2d_transposed( assert passing -@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745") def test_resharded_binary_to_matmul(device, function_level_defaults): grid_size_binary = device.compute_with_storage_grid_size() num_cores_binary = 98 @@ -1429,7 +1426,6 @@ def test_resharded_binary_to_matmul(device, function_level_defaults): assert passing -@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745") @pytest.mark.parametrize("in_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"]) @pytest.mark.parametrize("out_sharded", [False], ids=["out_unsharded"]) @pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @@ -1505,7 +1501,6 @@ def test_sharded_untilize_padded_shard(in_sharded, out_sharded, dtype, device, f assert passing -@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745") @pytest.mark.parametrize("in_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"]) @pytest.mark.parametrize("out_sharded", [False], ids=["out_unsharded"]) @pytest.mark.parametrize("activations_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @@ -1696,7 +1691,6 @@ def test_block_sharded_untilize_with_unpadding(in_sharded, out_sharded, dtype, d "unbatched_16_shape_out_interleaved", ], ) -@skip_for_blackhole("BH Issue with untilize LLK, see #14594") @pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) def test_width_sharded_untilize_with_unpadding( shape, output_H, in_sharded, out_sharded, dtype, device, function_level_defaults @@ -1767,7 +1761,7 @@ def test_width_sharded_untilize_with_unpadding( assert passing -@skip_for_blackhole("BH LLK Issue with tilize, #14609") +@skip_for_blackhole("Mismatching on BH, see #12349") @pytest.mark.parametrize("input_shape", [[8, 1, 49, 2048], [1, 1, 8, 2048], [16, 1, 49, 2048], [1, 1, 16, 2048]]) @pytest.mark.parametrize("sharding_config", [(True, True), (False, False)], ids=["both_sharded", "both_interleaved"]) @pytest.mark.parametrize("output_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @@ -1839,6 +1833,7 @@ def test_sharded_tilize_with_val_padding(input_shape, sharding_config, output_dt assert passing +@skip_for_blackhole("Mismatching on BH, see #12349") @pytest.mark.parametrize("N", [8, 16]) @pytest.mark.parametrize("in_sharded", [True], ids=["in0_sharded"]) @pytest.mark.parametrize("out_sharded", [True], ids=["out_sharded"]) @@ -2069,7 +2064,6 @@ def test_sharded_matmul_1d_in1_wormhole(device, function_level_defaults): assert passing -@pytest.mark.skipif(is_blackhole(), reason="BH ND hang, see issue #14745") @pytest.mark.parametrize("in0_sharded", [True, False], ids=["in0_sharded", "in0_unsharded"]) @pytest.mark.parametrize("in1_sharded", [True, False], ids=["in1_sharded", "in1_unsharded"]) @pytest.mark.parametrize("out_sharded", [True, False], ids=["out_sharded", "out_unsharded"]) diff --git a/tests/ttnn/unit_tests/operations/test_core.py b/tests/ttnn/unit_tests/operations/test_core.py index c39154379df..23b9d1f8459 100644 --- a/tests/ttnn/unit_tests/operations/test_core.py +++ b/tests/ttnn/unit_tests/operations/test_core.py @@ -439,97 +439,3 @@ def test_create_sharded_memory_config(device, shape, strategy, orientation, core passing = torch.equal(input_data, output_data) assert passing - - -@pytest.mark.parametrize( - "shape, shard_shape, strategy, orientation, core_grid", - [ - ([1, 1, 2, 16], None, ttnn.ShardStrategy.WIDTH, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=1, x=1)), - ([1, 1, 2, 16], None, ttnn.ShardStrategy.WIDTH, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)), - ([1, 1, 32, 16], None, ttnn.ShardStrategy.HEIGHT, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)), - ([1, 1, 64, 16], None, ttnn.ShardStrategy.HEIGHT, ttnn.ShardOrientation.ROW_MAJOR, ttnn.CoreGrid(y=2, x=1)), - ( - [1, 1, 2, 16], - [2, 16], - ttnn.ShardStrategy.HEIGHT, - ttnn.ShardOrientation.ROW_MAJOR, - ttnn.CoreRangeSet( - { - ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0)), - } - ), - ), - ( - [1, 1, 5280, 16], - [5280, 16], - ttnn.ShardStrategy.HEIGHT, - ttnn.ShardOrientation.ROW_MAJOR, - ttnn.CoreRangeSet( - { - ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0)), - } - ), - ), - # TODO: Add this test back by checking for core grid size and skipping if we can't do it - # ( - # [1, 1, 675840, 16], - # [5280, 16], - # ttnn.ShardStrategy.HEIGHT, - # ttnn.ShardOrientation.ROW_MAJOR, - # ttnn.CoreRangeSet( - # { - # ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(11, 9)), # 120 - # ttnn.CoreRange(ttnn.CoreCoord(12, 0), ttnn.CoreCoord(12, 7)), # 8 - # } - # ), - # ), - ], -) -@pytest.mark.parametrize( - "input_buffer_type", - [ - ttnn.L1_MEMORY_CONFIG, - ttnn.DRAM_MEMORY_CONFIG, - ], -) -@pytest.mark.parametrize( - "output_buffer_type", - [ - ttnn.L1_MEMORY_CONFIG, - ttnn.DRAM_MEMORY_CONFIG, - ], -) -def test_bh_alignment_i2s( - device, shape, shard_shape, strategy, orientation, core_grid, input_buffer_type, output_buffer_type -): - torch.manual_seed(0) - input_data = torch.randn(shape, dtype=torch.bfloat16) - if shard_shape == None: - shard_config = ttnn.create_sharded_memory_config( - shape=shape, - core_grid=core_grid, - strategy=strategy, - orientation=orientation, - use_height_and_width_as_shard_shape=False, - ) - else: - shard_config = ttnn.create_sharded_memory_config( - shape=shard_shape, - core_grid=core_grid, - strategy=strategy, - orientation=orientation, - use_height_and_width_as_shard_shape=True, - ) - x_t = ttnn.from_torch( - input_data, - device=device, - layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=input_buffer_type, - dtype=ttnn.bfloat16, - ) - x_t_sharded = ttnn.to_memory_config(x_t, shard_config) - x_t = ttnn.to_memory_config(x_t_sharded, output_buffer_type) - output_data = ttnn.from_device(x_t) - output_data = ttnn.to_torch(output_data) - passing = torch.equal(input_data, output_data) - assert passing diff --git a/tests/ttnn/unit_tests/operations/test_maxpool2d.py b/tests/ttnn/unit_tests/operations/test_maxpool2d.py index 04903485f40..43fa209acb0 100644 --- a/tests/ttnn/unit_tests/operations/test_maxpool2d.py +++ b/tests/ttnn/unit_tests/operations/test_maxpool2d.py @@ -183,9 +183,6 @@ def run_max_pool( output_host = output.cpu() output_pytorch_padded = torch.Tensor(ttnn.to_torch(output_host)) output_pytorch = output_pytorch_padded[:, :, :, :in_c] - torch.set_printoptions(profile="full") - print("output_pytorch" + str(output_pytorch)) - torch.set_printoptions(profile="default") # reset ## reference golden_pytorch = torch.nn.MaxPool2d( diff --git a/ttnn/cpp/ttnn/operations/core/core.cpp b/ttnn/cpp/ttnn/operations/core/core.cpp index b61567ab540..dba2edf328b 100644 --- a/ttnn/cpp/ttnn/operations/core/core.cpp +++ b/ttnn/cpp/ttnn/operations/core/core.cpp @@ -11,8 +11,6 @@ #include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp" #include "ttnn/operations/data_movement/data_transfer/data_transfer.hpp" #include "ttnn/distributed/types.hpp" -#include "ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp" -#include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp" namespace ttnn::operations::core { @@ -56,29 +54,12 @@ ttnn::Tensor squeeze_from_4D(const ttnn::Tensor& tensor, const int rank) { } ttnn::Tensor to_device(const ttnn::Tensor& tensor, Device* device, const std::optional& memory_config) { - auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); - if(mem_config.is_sharded () and (device->arch() == tt::ARCH::BLACKHOLE)) { - auto interleaved_tensor = tensor.to(device, ttnn::DRAM_MEMORY_CONFIG); - return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt); - } - else { - return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); - } + return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); } ttnn::Tensor to_device( const ttnn::Tensor& tensor, MeshDevice* mesh_device, const std::optional& memory_config) { - - auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); - // Currently no direct sharded write support in BLACKHOLE due to alignment issue - if(mem_config.is_sharded () and (mesh_device->arch() == tt::ARCH::BLACKHOLE)) { - auto interleaved_tensor = tensor.to(mesh_device, ttnn::DRAM_MEMORY_CONFIG); - return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt); - } - else { - return tensor.to(mesh_device, mem_config); - } - + return tensor.to(mesh_device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); } ttnn::Tensor allocate_tensor_on_device( @@ -105,19 +86,7 @@ void copy_host_to_device_tensor(ttnn::Tensor host_tensor, ttnn::Tensor device_te tt::tt_metal::write_tensor(host_tensor, device_tensor, cq_id); } -ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id) { - - // Currently no direct sharded read support in BLACKHOLE due to alignment issue - if(tensor.is_sharded () and (tensor.device()->arch() == tt::ARCH::BLACKHOLE)) { - auto interleaved_tensor = ttnn::sharded_to_interleaved(cq_id, tensor, ttnn::DRAM_MEMORY_CONFIG, std::nullopt); - return interleaved_tensor.cpu(blocking, cq_id); - } - else { - return tensor.cpu(blocking, cq_id); - - } - -} +ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id) { return tensor.cpu(blocking, cq_id); } void deallocate(Tensor& tensor, bool force) { tensor.deallocate(force); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp deleted file mode 100644 index 25c95ab1888..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -// This file contains common kernel functions used for debugging -#pragma once -#include "debug/dprint.h" -namespace tt::data_movement::common { - -inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { - volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; - for (uint32_t page = 0; page < npages; ++ page) { - DPRINT << start + page << ": "; - for (uint32_t j = 0; j < pagelen; ++ j, ++ ptr) { - DPRINT << BF16(*ptr) << " "; - } - DPRINT << ENDL(); - } -} -} diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reader_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reader_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp index 16b8820e61a..c132e643ad5 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reader_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reader_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp @@ -5,12 +5,6 @@ #include #include "dataflow_api.h" -//#define DEBUG - -#ifdef DEBUG -#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp" -#endif - void kernel_main() { const uint32_t src_addr = get_arg_val(0); @@ -44,20 +38,15 @@ void kernel_main() { uint32_t stick_id = start_id; cb_reserve_back(cb_id_in0, block_height); uint32_t l1_write_addr = get_write_ptr(cb_id_in0); - uint32_t l1_write_addr_base = l1_write_addr; if (aligned) { for (uint32_t h = 0; h < block_height; ++h) { uint64_t src_noc_addr = get_noc_addr(stick_id, s0); noc_async_read(src_noc_addr, l1_write_addr, block_width_bytes); stick_id++; -#ifdef DEBUG - noc_async_read_barrier(); - tt::data_movement::common::print_pages(l1_write_addr, block_width_bytes >> 1, 1); -#endif l1_write_addr += padded_block_width_bytes; } } else { - cb_reserve_back(cb_id_in1, 4); + cb_reserve_back(cb_id_in1, 1); uint32_t scratch_l1_write_addr = get_write_ptr(cb_id_in1); uint64_t scratch_l1_noc_read_addr = get_noc_addr(scratch_l1_write_addr + aligned_offset); for (uint32_t h = 0; h < block_height; ++h) { @@ -65,15 +54,10 @@ void kernel_main() { noc_async_read(src_noc_addr, scratch_l1_write_addr, aligned_block_width_bytes); noc_async_read_barrier(); noc_async_read(scratch_l1_noc_read_addr, l1_write_addr, block_width_bytes); -#ifdef DEBUG - noc_async_read_barrier(); - tt::data_movement::common::print_pages(l1_write_addr, block_width_bytes >> 1, 1); -#endif stick_id++; l1_write_addr += padded_block_width_bytes; } } - noc_async_read_barrier(); cb_push_back(cb_id_in0, block_height); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/writer_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/writer_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp index 03820991b77..aed1d42e19f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/writer_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/writer_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp @@ -5,12 +5,6 @@ #include #include "dataflow_api.h" -//#define DEBUG - -#ifdef DEBUG -#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp" -#endif - void kernel_main() { const uint32_t dst_addr = get_arg_val(0); @@ -40,15 +34,9 @@ void kernel_main() { uint32_t stick_id = start_id; cb_wait_front(cb_id_out0, block_height); uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - - for (uint32_t h = 0; h < block_height; ++h) { uint64_t dst_noc_addr = get_noc_addr(stick_id, s0); noc_async_write(l1_read_addr, dst_noc_addr, block_width_bytes); -#ifdef DEBUG - noc_async_read_barrier(); - tt::data_movement::common::print_pages(l1_read_addr, block_width_bytes >> 1, 1); -#endif stick_id++; l1_read_addr += padded_block_width_bytes; noc_async_write_barrier(); diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.cpp index 2bbcb4f4574..b899760c02a 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.cpp @@ -36,19 +36,8 @@ std::vector InterleavedToShardedDeviceOperation::comp std::vector InterleavedToShardedDeviceOperation::create_output_tensors(const std::vector &input_tensors) const { const auto& input_tensor = input_tensors.at(0); - //return operation::generic_create_output_tensors( - // *this, input_tensors, this->output_dtype, input_tensor.get_layout(), this->output_mem_config); - - - auto mem_config = this->output_mem_config; - - return {create_device_tensor( - this->compute_output_shapes(input_tensors).at(0), - input_tensor.get_dtype(), - input_tensor.get_layout(), - input_tensor.device(), - mem_config - )}; + return operation::generic_create_output_tensors( + *this, input_tensors, this->output_dtype, input_tensor.get_layout(), this->output_mem_config); } operation::ProgramWithCallbacks InterleavedToShardedDeviceOperation::create_program(const std::vector& input_tensors, std::vector &output_tensors) const { diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp index e51e67fc92a..d41cadcf1d1 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp @@ -32,14 +32,6 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core( bool rm_orientation = shard_spec.orientation == ShardOrientation::ROW_MAJOR; CoreCoord end_core = (*shard_spec.grid.ranges().rbegin()).end_coord; - - bool convert_df = input_cb_data_format != output_cb_data_format; - auto src_buffer = input.buffer(); - auto dst_buffer = output.buffer(); - bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - bool is_blackhole = (input.device()->arch() == tt::ARCH::BLACKHOLE); - bool is_blackhole_and_dram = (input.device()->arch() == tt::ARCH::BLACKHOLE) and src_is_dram; - if (input.get_layout() == Layout::TILE) { num_units = input.volume() / TILE_HW; input_unit_size = tt::tt_metal::detail::TileSize(input_cb_data_format); @@ -74,6 +66,13 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core( padded_offset_bytes = align(input_unit_size, input.buffer()->alignment()); } + bool convert_df = input_cb_data_format != output_cb_data_format; + + auto src_buffer = input.buffer(); + + auto dst_buffer = output.buffer(); + + bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; auto all_cores = shard_spec.grid; uint32_t input_cb_index = tt::CB::c_in0; @@ -95,17 +94,10 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core( .set_globally_allocated_address(*output.buffer()); auto cb_output = tt::tt_metal::CreateCircularBuffer(program, all_cores, output_cb_out_config); uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM); - if (src_is_dram && input_unit_size % dram_alignment != 0 or is_blackhole_and_dram) { - uint32_t scratch_cb_page_size; - //scratchpad going to be used to align DRAM (64B) to L1 (16B) - if (is_blackhole_and_dram) { - scratch_cb_page_size = align(input_unit_size, hal.get_alignment(HalMemType::L1)); - } - else { - scratch_cb_page_size = align(input_unit_size, dram_alignment); - } + if (src_is_dram && input_unit_size % dram_alignment != 0) { + uint32_t scratch_cb_page_size = align(input_unit_size, dram_alignment); tt::tt_metal::CircularBufferConfig scratch_cb_out_config = - tt::tt_metal::CircularBufferConfig(4 * scratch_cb_page_size, {{scratch_cb_index, input_cb_data_format}}) + tt::tt_metal::CircularBufferConfig(1 * scratch_cb_page_size, {{scratch_cb_index, input_cb_data_format}}) .set_page_size(scratch_cb_index, scratch_cb_page_size); auto cb_scratch = tt::tt_metal::CreateCircularBuffer(program, all_cores, scratch_cb_out_config); } @@ -244,17 +236,10 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core( } uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM); - uint32_t l1_alignment = hal.get_alignment(HalMemType::L1); - bool aligned = (src_is_dram ? curr_idx_w % dram_alignment == 0 : true); - aligned = aligned and !(is_blackhole_and_dram); + bool aligned = src_is_dram ? curr_idx_w % dram_alignment == 0 : true; uint32_t aligned_width_offset, aligned_shard_width, aligned_offset; if (!aligned) { - if(src_is_dram) { - aligned_width_offset = tt::round_down(curr_idx_w, dram_alignment); - } - else { - aligned_width_offset = tt::round_down(curr_idx_w, l1_alignment); - } + aligned_width_offset = tt::round_down(curr_idx_w, dram_alignment); aligned_offset = curr_idx_w - aligned_width_offset; aligned_shard_width = aligned_offset + shard_width; } else { @@ -271,7 +256,7 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core( num_units_per_row, shard_height, shard_width, - (is_blackhole) ? shard_width : padded_offset_bytes, + padded_offset_bytes, static_cast(aligned), aligned_width_offset, aligned_shard_width, @@ -320,4 +305,6 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core( return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } + + } diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp index f736258f7d6..55b32e3c00a 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp @@ -20,8 +20,9 @@ void ShardedToInterleavedDeviceOperation::validate(const std::vector& in TT_FATAL(input_tensor.memory_config().buffer_type == BufferType::L1, "Input tensor must be in L1"); TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Output memory config must be Interleaved"); if (input_tensor.get_layout() == Layout::ROW_MAJOR) { + uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM); uint32_t l1_alignment = hal.get_alignment(HalMemType::L1); - TT_FATAL((*input_tensor.memory_config().shard_spec).shape[1] * input_tensor.element_size() % (l1_alignment) == 0, "Shard page size must be aligned to {}B for L1 Tensor", l1_alignment); + TT_FATAL((*input_tensor.memory_config().shard_spec).shape[1] * input_tensor.element_size() % (this->output_mem_config.buffer_type == BufferType::DRAM ? dram_alignment : l1_alignment) == 0, "Shard page size must be aligned to {}B for L1 Tensor, or {}B for DRAM tensor", l1_alignment, dram_alignment); } if (input_tensor.get_dtype() != this->output_dtype) { TT_FATAL(input_tensor.get_layout() == Layout::TILE, "If diff output type, tensor must be TILED"); diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp index 2cb58883bf1..6d585e65a13 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp @@ -98,7 +98,6 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core( tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool is_blackhole = (input.device()->arch() == tt::ARCH::BLACKHOLE); tt_metal::KernelHandle unary_writer_kernel_id; if (input.get_layout() == Layout::TILE) { @@ -142,8 +141,7 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core( uint32_t curr_idx_w = 0; const auto cores = corerange_to_cores(all_cores, std::nullopt, rm_orientation); - uint32_t padded_offset_bytes; - + uint32_t padded_shard_width = align(output_unit_size, dst_buffer->alignment()); for (const auto& core : cores) { if (input.get_layout() == Layout::TILE) { uint32_t shard_height = num_units_per_shard_height; @@ -219,13 +217,6 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core( } } } - uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM); - uint32_t l1_alignment = hal.get_alignment(HalMemType::L1); - uint32_t padded_shard_width = align(output_unit_size, dst_buffer->alignment()); - if(is_blackhole) { - if(!dst_is_dram) - padded_shard_width = align(output_unit_size, l1_alignment); - } tt_metal::SetRuntimeArgs( program, unary_writer_kernel_id, @@ -234,7 +225,7 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core( num_units_per_row, shard_height, shard_width, - (is_blackhole) ? shard_width : padded_shard_width, + padded_shard_width, curr_idx_w, curr_idx_h}); curr_idx_w += output_unit_size;