From 43ebbcb8d9574a4cbd104aca19ef8fbe93d2870b Mon Sep 17 00:00:00 2001
From: Stanislav Minakov
Date: Thu, 21 Nov 2024 14:56:49 -0800
Subject: [PATCH] #0: Use TensorLayout in Tensor (#15028)

### Ticket

### Problem description
We're migrating the TTNN infrastructure to use the new TensorLayout everywhere, including Tensor.

### What's changed
- Introduced `TensorSpec`, which represents `logical_shape` + `tensor_layout`.
- Used `TensorSpec` in Tensor instead of `(shape, dtype, layout, tile)`.
- Refactored Tensor constructors.
- Used `set_tensor_spec` instead of a collection of separate setters, removing the need to manually set `metadata_populated = true`.

### Checklist
- [x] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/11960140767)
- [ ] Blackhole Post commit (if applicable)
- [x] [Model regression CI testing passes](https://github.com/tenstorrent/tt-metal/actions/runs/11956334891)
- [x] [Device performance regression CI testing passes](https://github.com/tenstorrent/tt-metal/actions/runs/11956337771)
- [ ] New/Existing tests provide coverage for changes
---
 .../tensors/test_async_tensor_apis.cpp | 13 +-
 .../tensor/common_tensor_test_utils.cpp | 2 +-
 .../gtests/tensor/test_create_tensor.cpp | 6 +-
 .../tensor/test_create_tensor_with_layout.cpp | 2 +-
 .../unit_tests/gtests/test_async_runtime.cpp | 13 +-
 .../unit_tests/gtests/test_ccl_on_galaxy.cpp | 12 +-
 .../gtests/test_multi_cq_multi_dev.cpp | 16 +-
 .../gtests/test_multiprod_queue.cpp | 12 +-
 .../unit_tests/operations/test_creation.py | 9 +-
 .../operations/test_paged_update_cache.py | 28 +-
 tests/ttnn/unit_tests/test_deallocate.py | 2 +-
 tt_metal/tt_stl/reflection.hpp | 3 +-
 ttnn/cpp/pybind11/pytensor.cpp | 45 ++-
 ttnn/cpp/ttnn/distributed/api.cpp | 14 +-
 ttnn/cpp/ttnn/operation.hpp | 7 +-
 .../ccl/all_gather/device/all_gather_op.cpp | 5 +-
 ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp | 12 +-
 ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp | 10 +-
 .../ccl/ccl_host_datastructures.cpp | 2 +-
 .../ttnn/operations/conv/conv2d/conv2d.cpp | 9 +-
 .../core/to_layout/to_layout_op.cpp | 16 +-
 .../data_movement/concat/concat.cpp | 4 +-
 .../device/permute_device_operation.cpp | 4 +-
 .../data_movement/repeat/repeat.cpp | 18 +-
 .../reshape_on_device/device/reshape_op.cpp | 3 +-
 .../data_movement/slice/device/slice_op.cpp | 4 +-
 .../transpose/device/transpose_op.cpp | 2 +-
 .../device/transpose_program_factory.cpp | 7 +-
 .../binary/device/binary_device_operation.cpp | 4 +-
 .../device/example_device_operation.cpp | 6 +-
 ...ample_multiple_return_device_operation.cpp | 10 +-
 .../experimental/reduction/argmax/argmax.cpp | 4 +-
 .../device/full_like_device_operation.cpp | 4 +-
 .../device/index_fill_device_operation.cpp | 4 +-
 .../operations/matmul/device/matmul_op.cpp | 37 +--
 ...ti_core_reuse_mcast_1d_program_factory.cpp | 14 +-
 ...ti_core_reuse_mcast_2d_program_factory.cpp | 12 +-
 ...use_mcast_dram_sharded_program_factory.cpp | 14 +-
 ...i_core_reuse_optimized_program_factory.cpp | 8 +-
 .../device/moreh_cumsum_device_operation.cpp | 2 +-
 .../device/moreh_dot_device_operation.cpp | 4 +-
 ttnn/cpp/ttnn/operations/numpy/functions.hpp | 14 +-
 ttnn/cpp/ttnn/run_operation.cpp | 7 +-
 ttnn/cpp/ttnn/run_operation_inl.hpp | 6 +-
 ttnn/cpp/ttnn/tensor/layout/page_config.cpp | 39 ++-
 ttnn/cpp/ttnn/tensor/layout/page_config.hpp | 37 ++-
 ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp | 9 +-
 ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp | 16 +-
 ttnn/cpp/ttnn/tensor/shape/shape_base.hpp | 2 +-
 ttnn/cpp/ttnn/tensor/tensor.cpp | 282 ++++++++++--------
ttnn/cpp/ttnn/tensor/tensor.hpp | 102 +++---- ttnn/cpp/ttnn/tensor/tensor_impl.cpp | 104 +++---- ttnn/cpp/ttnn/tensor/tensor_impl.hpp | 2 +- ttnn/cpp/ttnn/tensor/tensor_ops.cpp | 47 ++- ttnn/cpp/ttnn/tensor/tensor_spec.hpp | 81 +++++ ttnn/cpp/ttnn/tensor/tensor_utils.cpp | 13 +- 56 files changed, 641 insertions(+), 523 deletions(-) create mode 100644 ttnn/cpp/ttnn/tensor/tensor_spec.hpp diff --git a/tests/tt_eager/tensors/test_async_tensor_apis.cpp b/tests/tt_eager/tensors/test_async_tensor_apis.cpp index 46d7d79c269..0418df6b535 100644 --- a/tests/tt_eager/tensors/test_async_tensor_apis.cpp +++ b/tests/tt_eager/tensors/test_async_tensor_apis.cpp @@ -42,7 +42,7 @@ TEST_F(CommonFixture, TestTensorOwnershipSanity) { // Ensure that tensor data is copied and owned as expected Device* device = this->devices_[0]; Tensor host_tensor = ttnn::numpy::arange(0, 32 * 32 * 4, 1); - Tensor readback_tensor({}, 1); + Tensor readback_tensor(1); auto func = [device, host_tensor, readback_tensor]() mutable { // Ensure that both the lambda and global scope have ownership to this tensor @@ -67,9 +67,7 @@ TEST_F(CommonFixture, TestTensorOwnershipSanity) { auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device); auto thread_local_tensor = device_tensor.cpu().to(Layout::ROW_MAJOR); readback_tensor.set_storage(thread_local_tensor.get_storage()); - readback_tensor.set_shape(thread_local_tensor.get_shape()); - readback_tensor.set_dtype(thread_local_tensor.get_dtype()); - readback_tensor.set_layout(thread_local_tensor.get_layout()); + readback_tensor.set_tensor_spec(thread_local_tensor.get_tensor_spec()); readback_tensor.tensor_attributes->metadata_populated = true; readback_tensor.tensor_attributes->num_workers_completed++; // Ensure that the readback buffer is owned inside and outside the lambda @@ -240,8 +238,7 @@ TEST_F(CommonFixture, TestTensorAsyncDataMovement) { uint32_t tensor_start = 0; uint32_t num_tiles = 128; uint32_t tensor_stop = TILE_HEIGHT * TILE_WIDTH * num_tiles; - Tensor readback_tensor({}, 1); - ; + Tensor readback_tensor(1); std::thread worker; { @@ -278,9 +275,7 @@ TEST_F(CommonFixture, TestTensorAsyncDataMovement) { auto thread_local_tensor = device_tensor.cpu().to(Layout::ROW_MAJOR); log_info(LogTest, "Worker populating empty host readback_tensor"); readback_tensor.set_storage(thread_local_tensor.get_storage()); - readback_tensor.set_shape(thread_local_tensor.get_shape()); - readback_tensor.set_dtype(thread_local_tensor.get_dtype()); - readback_tensor.set_layout(thread_local_tensor.get_layout()); + readback_tensor.set_tensor_spec(thread_local_tensor.get_tensor_spec()); readback_tensor.tensor_attributes->metadata_populated = true; readback_tensor.tensor_attributes->num_workers_completed++; // Ensure that this buffer is currently owned by both the thread_local and read_back tensors diff --git a/tests/ttnn/unit_tests/gtests/tensor/common_tensor_test_utils.cpp b/tests/ttnn/unit_tests/gtests/tensor/common_tensor_test_utils.cpp index c78373747a8..a23647b0d04 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/common_tensor_test_utils.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/common_tensor_test_utils.cpp @@ -27,7 +27,7 @@ void test_tensor_on_device(const ttnn::SimpleShape& input_shape, const TensorLay host_data[i] = i % random_prime_number; } - auto tensor = tt::tt_metal::create_device_tensor(input_shape, layout, device); + auto tensor = tt::tt_metal::create_device_tensor(TensorSpec(input_shape, layout), device); ttnn::queue_synchronize(device->command_queue(io_cq)); 
ttnn::write_buffer(io_cq, tensor, {host_data}); diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor.cpp index c02e5bcd0b6..8ba7a619866 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor.cpp @@ -37,9 +37,9 @@ void run_create_tensor_test(tt::tt_metal::Device* device, ttnn::SimpleShape inpu host_data[i] = 1; } - tt::tt_metal::TensorLayout tensor_layout(dtype, PageConfig(Layout::TILE), mem_cfg); - ASSERT_EQ(input_buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(input_shape)); - auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, input_shape, tensor_layout); + TensorSpec tensor_spec(input_shape, TensorLayout(dtype, PageConfig(Layout::TILE), mem_cfg)); + ASSERT_EQ(input_buf_size_datums * datum_size_bytes, tensor_spec.compute_packed_buffer_size_bytes()); + auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_spec); auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_with_layout.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_with_layout.cpp index c65d42f845e..92d64a79f68 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_with_layout.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_with_layout.cpp @@ -38,7 +38,7 @@ class CreateTensorWithLayoutTest : public ttnn::TTNNFixtureWithDevice, public :: TEST_P(CreateTensorWithLayoutTest, Tile) { CreateTensorParams params = GetParam(); - auto tensor = tt::tt_metal::create_device_tensor(params.inputs.shape, params.inputs.layout, device_); + auto tensor = tt::tt_metal::create_device_tensor(TensorSpec(params.inputs.shape, params.inputs.layout), device_); EXPECT_EQ(tensor.get_padded_shape(), params.expected.padded_shape); EXPECT_EQ(tensor.get_logical_shape(), params.inputs.shape); } diff --git a/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp b/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp index d57f2950f0f..0d7607635b2 100644 --- a/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp +++ b/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp @@ -58,8 +58,8 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) { tt_metal::TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); ASSERT_EQ(input_buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(input_shape.padded_shape())); ASSERT_EQ(output_buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(np_out.get_padded_shape())); - auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, input_shape.padded_shape(), tensor_layout); - auto output_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, np_out.get_padded_shape(), tensor_layout); + auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, TensorSpec(input_shape.padded_shape(), tensor_layout)); + auto output_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, TensorSpec(np_out.get_padded_shape(), tensor_layout)); auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; auto output_storage = tt::tt_metal::DeviceStorage{output_buffer}; Tensor input_tensor = Tensor(input_storage, input_shape, DataType::BFLOAT16, Layout::TILE); @@ -124,7 +124,7 @@ 
TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeAllocatedBuffers) { auto workload_event = std::make_shared(); TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); ASSERT_EQ(buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(shape)); - auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, shape, tensor_layout); + auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, TensorSpec(shape, tensor_layout)); auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; Tensor input_tensor = Tensor(input_storage, shape, DataType::BFLOAT16, Layout::TILE); ttnn::write_buffer(io_cq, input_tensor, {host_data}); // Write using cq 1 @@ -134,10 +134,10 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeAllocatedBuffers) { // Run operation on cq 0 Tensor output_tensor = ttnn::sqrt(workload_dispatch_cq, input_tensor); - auto dummy_buffer_0 = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, shape, tensor_layout); + auto dummy_buffer_0 = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, TensorSpec(shape, tensor_layout)); output_tensor = ttnn::neg(workload_dispatch_cq, output_tensor); // Allocate this buffer to stress test async allocation across op execution and explicit allocation - auto dummy_buffer_1 = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, shape, tensor_layout); + auto dummy_buffer_1 = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, TensorSpec(shape, tensor_layout)); // Record cq 0 prog execution ttnn::record_event(device->command_queue(workload_dispatch_cq), workload_event); // Wait until cq 0 prog execution is done @@ -169,9 +169,10 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeBufferDestructor) { // This will asynchronously allocate the buffer, wait for the allocation to complete (address to be assigned to the buffer), destroy the buffer (which will asynchronously // deallocate the buffer) in a loop TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); + TensorSpec tensor_spec(shape, tensor_layout); for (int loop = 0; loop < 100000; loop++) { { - auto input_buffer_dummy = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, shape, tensor_layout); + auto input_buffer_dummy = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_spec); device->synchronize(); } } diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp index 8224d307ae6..67c0e772c2a 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp @@ -156,9 +156,9 @@ TEST(GalaxyTests, TestAllGatherDeadlock) { log_info(LogTest, "Running iteration {}", i); } for (auto& dev : devs) { - tt::tt_metal::TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); - ASSERT_EQ(buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(shape)); - auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(dev, shape, tensor_layout); + TensorSpec tensor_spec(shape, TensorLayout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg)); + ASSERT_EQ(buf_size_datums * datum_size_bytes, tensor_spec.compute_packed_buffer_size_bytes()); + auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(dev, tensor_spec); auto input_storage = DeviceStorage{input_buffer}; Tensor input_tensor = 
Tensor(input_storage, shape, DataType::BFLOAT16, Layout::TILE); // Push inputs. @@ -253,10 +253,10 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) { log_info(LogTest, "Running iteration {}", i); } - TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); - ASSERT_EQ(buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(shape)); + TensorSpec tensor_spec(shape, TensorLayout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg)); + ASSERT_EQ(buf_size_datums * datum_size_bytes, tensor_spec.compute_packed_buffer_size_bytes()); for (auto& dev : ring_devices) { - auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(dev, shape, tensor_layout); + auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(dev, tensor_spec); auto input_storage = DeviceStorage{input_buffer}; Tensor input_tensor = Tensor(input_storage, shape, DataType::BFLOAT16, Layout::TILE); // Push inputs. diff --git a/tests/ttnn/unit_tests/gtests/test_multi_cq_multi_dev.cpp b/tests/ttnn/unit_tests/gtests/test_multi_cq_multi_dev.cpp index 9e1b1cb4f37..8497d9dd890 100644 --- a/tests/ttnn/unit_tests/gtests/test_multi_cq_multi_dev.cpp +++ b/tests/ttnn/unit_tests/gtests/test_multi_cq_multi_dev.cpp @@ -60,9 +60,9 @@ TEST_F(MultiCommandQueueT3KFixture, Test2CQMultiDeviceProgramsOnCQ1) { for (int j = 0; j < buf_size_datums; j++) { host_data[j] = bfloat16(static_cast(i + dev_idx)); } - tt_metal::TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); - ASSERT_EQ(buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(shape)); - auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, shape, tensor_layout); + TensorSpec tensor_spec(shape, TensorLayout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg)); + ASSERT_EQ(buf_size_datums * datum_size_bytes, tensor_spec.compute_packed_buffer_size_bytes()); + auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_spec); auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; Tensor input_tensor = Tensor(input_storage, shape, DataType::BFLOAT16, Layout::TILE); @@ -102,8 +102,8 @@ TEST_F(MultiCommandQueueT3KFixture, Test2CQMultiDeviceProgramsOnCQ0) { auto host_data = std::shared_ptr(new bfloat16[buf_size_datums]); auto readback_data = std::shared_ptr(new bfloat16[buf_size_datums]); - TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); - ASSERT_EQ(buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(shape)); + TensorSpec tensor_spec(shape, TensorLayout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg)); + ASSERT_EQ(buf_size_datums * datum_size_bytes, tensor_spec.compute_packed_buffer_size_bytes()); for (int outer_loop = 0; outer_loop < 5; outer_loop++) { log_info(LogTest, "Running outer loop {}", outer_loop); for (int i = 0; i < 30; i++) { @@ -115,7 +115,7 @@ TEST_F(MultiCommandQueueT3KFixture, Test2CQMultiDeviceProgramsOnCQ0) { for (int j = 0; j < buf_size_datums; j++) { host_data[j] = bfloat16(static_cast(i + dev_idx)); } - auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, shape, tensor_layout); + auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_spec); auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; Tensor input_tensor = Tensor(input_storage, shape, DataType::BFLOAT16, Layout::TILE); @@ -167,8 +167,8 @@ TEST_F(MultiCommandQueueT3KFixture, 
Test2CQMultiDeviceWithCQ1Only) { host_data[j] = bfloat16(static_cast(i + dev_idx)); } - TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); - auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, shape, tensor_layout); + TensorSpec tensor_spec(shape, TensorLayout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg)); + auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_spec); auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; Tensor input_tensor = Tensor(input_storage, shape, DataType::BFLOAT16, Layout::TILE); diff --git a/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp b/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp index e242ff39a94..98d43f94c1e 100644 --- a/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp +++ b/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp @@ -54,8 +54,9 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestMultiProducerLockBasedQueue) { } // Allocate and write buffer tt_metal::TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); - ASSERT_EQ(tensor_buf_size * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(tensor_shape)); - auto t0_input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_shape, tensor_layout); + tt_metal::TensorSpec tensor_spec(tensor_shape, tensor_layout); + ASSERT_EQ(tensor_buf_size * datum_size_bytes, tensor_spec.compute_packed_buffer_size_bytes()); + auto t0_input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_spec); auto t0_input_storage = tt::tt_metal::DeviceStorage{t0_input_buffer}; Tensor t0_input_tensor = Tensor(t0_input_storage, tensor_shape, DataType::BFLOAT16, Layout::TILE); ttnn::write_buffer(t0_io_cq, t0_input_tensor, {t0_host_data}); @@ -71,12 +72,13 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestMultiProducerLockBasedQueue) { std::thread t1([&]() { TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); - ASSERT_EQ(tensor_buf_size * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(tensor_shape)); + TensorSpec tensor_spec(tensor_shape, tensor_layout); + ASSERT_EQ(tensor_buf_size * datum_size_bytes, tensor_spec.compute_packed_buffer_size_bytes()); for (int j = 0; j < 100; j++) { for (int i = 0; i < tensor_buf_size; i++) { t1_host_data[i] = bfloat16(static_cast(4 + j)); } - auto t1_input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_shape, tensor_layout); + auto t1_input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_spec); auto t1_input_storage = tt::tt_metal::DeviceStorage{t1_input_buffer}; Tensor t1_input_tensor = Tensor(t1_input_storage, tensor_shape, DataType::BFLOAT16, Layout::TILE); @@ -124,7 +126,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestMultiAppThreadSync) { ttnn::SimpleShape tensor_shape{1, 1, 1024, 1024}; auto host_data = std::shared_ptr(new bfloat16[tensor_buf_size]); TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); - auto allocated_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_shape, tensor_layout); + auto allocated_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, TensorSpec(tensor_shape, tensor_layout)); auto allocated_storage = tt::tt_metal::DeviceStorage{allocated_buffer}; auto allocated_tensor = Tensor(allocated_storage, tensor_shape, DataType::BFLOAT16, Layout::TILE); auto readback_data 
= std::shared_ptr(new bfloat16[tensor_buf_size]); diff --git a/tests/ttnn/unit_tests/operations/test_creation.py b/tests/ttnn/unit_tests/operations/test_creation.py index c3ede6b5f13..07f13d5708f 100644 --- a/tests/ttnn/unit_tests/operations/test_creation.py +++ b/tests/ttnn/unit_tests/operations/test_creation.py @@ -259,16 +259,21 @@ def test_zeros(device, input_shape): [ [32, 32], [5, 96, 64], + [1, 50257], ], ) @pytest.mark.parametrize( "fill_value", [-5.25, 0, 1.0], ) -def test_full(device, input_shape, fill_value): +@pytest.mark.parametrize( + "layout", + [ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE], +) +def test_full(device, input_shape, fill_value, layout): torch_tensor = torch.full(input_shape, dtype=torch.bfloat16, fill_value=fill_value) - tensor = ttnn.full(input_shape, device=device, fill_value=fill_value) + tensor = ttnn.full(input_shape, device=device, fill_value=fill_value, layout=layout) assert ttnn.is_tensor_storage_on_device(tensor) tensor = ttnn.to_torch(tensor) diff --git a/tests/ttnn/unit_tests/operations/test_paged_update_cache.py b/tests/ttnn/unit_tests/operations/test_paged_update_cache.py index bfeccc29cfc..90f4f9f2798 100644 --- a/tests/ttnn/unit_tests/operations/test_paged_update_cache.py +++ b/tests/ttnn/unit_tests/operations/test_paged_update_cache.py @@ -41,12 +41,11 @@ def run_test_update_cache_decode( input_shard_spec = ttnn.ShardSpec( shard_grid, [ - xt.logical_volume() // xt.shape[-1] // num_cores, + xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, xt.shape.with_tile_padding()[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, - ttnn.ShardMode.LOGICAL, ) input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec) xt = xt.to(device, input_mem_config) @@ -152,12 +151,11 @@ def test_update_cache_decode( input_shard_spec = ttnn.ShardSpec( shard_grid, [ - xt.logical_volume() // xt.shape[-1] // num_cores, - xt.shape[-1], + xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, + xt.shape.with_tile_padding()[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, - ttnn.ShardMode.LOGICAL, ) input_mem_config = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec @@ -236,12 +234,11 @@ def test_update_cache_decode_program_cache( input_shard_spec = ttnn.ShardSpec( shard_grid, [ - xt.logical_volume() // xt.shape[-1] // num_cores, - xt.shape[-1], + xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, + xt.shape.with_tile_padding()[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, - ttnn.ShardMode.LOGICAL, ) input_mem_config = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec @@ -279,12 +276,11 @@ def run_test_tensor_index_update_cache_decode( input_shard_spec = ttnn.ShardSpec( shard_grid, [ - xt.logical_volume() // xt.shape[-1] // num_cores, - xt.shape[-1], + xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, + xt.shape.with_tile_padding()[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, - ttnn.ShardMode.LOGICAL, ) input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec) xt = xt.to(device, input_mem_config) @@ -418,12 +414,11 @@ def run_test_paged_update_cache_decode( input_shard_spec = ttnn.ShardSpec( shard_grid, [ - xt.logical_volume() // xt.shape[-1] // num_cores, - xt.shape[-1], + xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, + xt.shape.with_tile_padding()[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, - ttnn.ShardMode.LOGICAL, 
) input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec) xt = xt.to(device, input_mem_config) @@ -548,12 +543,11 @@ def test_paged_update_cache_decode_program_caching( input_shard_spec = ttnn.ShardSpec( shard_grid, [ - xt.logical_volume() // xt.shape[-1] // num_cores, - xt.shape[-1], + xt.volume() // xt.shape.with_tile_padding()[-1] // num_cores, + xt.shape.with_tile_padding()[-1], ], ttnn.ShardOrientation.ROW_MAJOR, False, - ttnn.ShardMode.LOGICAL, ) input_mem_config = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec diff --git a/tests/ttnn/unit_tests/test_deallocate.py b/tests/ttnn/unit_tests/test_deallocate.py index 315067ff7d2..bd98ca975c9 100644 --- a/tests/ttnn/unit_tests/test_deallocate.py +++ b/tests/ttnn/unit_tests/test_deallocate.py @@ -27,4 +27,4 @@ def test_deallocate(device, h, w): with pytest.raises(RuntimeError) as exception: output_tensor_reference + output_tensor_reference - assert "MemoryConfig can only be obtained if the buffer is not null" in str(exception.value) + assert "Cannot get the device from a tensor without an allocated buffer" in str(exception.value) diff --git a/tt_metal/tt_stl/reflection.hpp b/tt_metal/tt_stl/reflection.hpp index 45f9631d6ee..8cff985522c 100644 --- a/tt_metal/tt_stl/reflection.hpp +++ b/tt_metal/tt_stl/reflection.hpp @@ -1041,7 +1041,8 @@ inline hash_t hash_object(const std::variant& variant) noexcept { if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { fmt::print("Hashing std::variant: {}\n", variant); } - return std::visit([](const auto& value) { return hash_object(value); }, variant); + auto active_variant = variant.index(); + return std::visit([&](const auto& value) { return hash_objects(active_variant, value); }, variant); } template diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index c0964d46861..2e526cc0556 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -229,27 +229,29 @@ Tensor convert_torch_tensor_to_tt_tensor( auto data_ptr = reinterpret_cast(torch_data_ptr); auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); + auto tile = optional_tile.value_or(Tile()); auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) .to(Layout::TILE); auto output_float_data = owned_buffer::get_as(tensor).get(); auto output_packed_data = pack_fp32_vec_as_bfp8_tiles( - output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tensor.get_tile()); + output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); return Tensor( - std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tensor.get_tile()); + std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile); } case DataType::BFLOAT4_B: { auto data_ptr = reinterpret_cast(torch_data_ptr); auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); + auto tile = optional_tile.value_or(Tile()); auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) .to(Layout::TILE); auto output_float_data = owned_buffer::get_as(tensor).get(); auto output_packed_data = pack_fp32_vec_as_bfp4_tiles( - output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tensor.get_tile()); + output_float_data, 
/*row_major_input=*/false, /*is_exp_a=*/false, tile); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); return Tensor( - std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tensor.get_tile()); + std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile); } default: { TT_THROW("Unsupported DataType: {}", data_type); @@ -388,27 +390,27 @@ Tensor convert_numpy_tensor_to_tt_tensor( auto data_ptr = reinterpret_cast(np_data_ptr); auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); + auto tile = optional_tile.value_or(Tile()); auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) .to(Layout::TILE); auto output_float_data = owned_buffer::get_as(tensor).get(); auto output_packed_data = pack_fp32_vec_as_bfp8_tiles( - output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tensor.get_tile()); + output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor( - std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tensor.get_tile()); + return Tensor(std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile); } case DataType::BFLOAT4_B: { auto data_ptr = reinterpret_cast(np_data_ptr); auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); + auto tile = optional_tile.value_or(Tile()); auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) .to(Layout::TILE); auto output_float_data = owned_buffer::get_as(tensor).get(); auto output_packed_data = pack_fp32_vec_as_bfp4_tiles( - output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tensor.get_tile()); + output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor( - std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tensor.get_tile()); + return Tensor(std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile); } default: { TT_THROW("Unsupported DataType: {}", data_type); @@ -465,12 +467,7 @@ Tensor convert_python_tensors_to_tt_tensors( auto distributed_tensor_config = get_distributed_tensor_config(strategy); auto storage = MultiDeviceHostStorage{distributed_tensor_config, std::move(host_owned_buffers), host_owned_shapes}; - auto output = Tensor( - std::move(storage), - tt_shards.at(0).get_legacy_shape(), - tt_shards.at(0).get_dtype(), - tt_shards.at(0).get_layout(), - tt_shards.at(0).get_tile()); + auto output = Tensor(std::move(storage), tt_shards.at(0).get_tensor_spec()); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; @@ -530,7 +527,7 @@ py::object convert_tt_tensor_to_torch_tensor(const Tensor &tt_tensor) { }, tt_tensor.get_storage()); - const auto &tile = tt_tensor.get_tile(); + const auto tile = tt_tensor.get_tensor_spec().tile(); auto tt_dtype = tt_tensor.get_dtype(); if (tt_dtype == DataType::BFLOAT8_B) { TT_ASSERT( @@ -546,7 +543,7 @@ py::object convert_tt_tensor_to_torch_tensor(const Tensor &tt_tensor) { tt_tensor.get_shape(), DataType::FLOAT32, tt_tensor.get_layout(), - tt_tensor.get_tile()) + tile) .to(Layout::ROW_MAJOR); auto output_float_data = 
owned_buffer::get_as(float_tensor).get(); buffer = owned_buffer::create(std::move(output_float_data)); @@ -566,7 +563,7 @@ py::object convert_tt_tensor_to_torch_tensor(const Tensor &tt_tensor) { tt_tensor.get_shape(), DataType::FLOAT32, tt_tensor.get_layout(), - tt_tensor.get_tile()) + tile) .to(Layout::ROW_MAJOR); auto output_float_data = owned_buffer::get_as(float_tensor).get(); buffer = owned_buffer::create(std::move(output_float_data)); @@ -631,7 +628,7 @@ py::object convert_tt_tensor_to_numpy_tensor(const Tensor &tt_tensor) { }, tt_tensor.get_storage()); - const auto &tile = tt_tensor.get_tile(); + const auto tile = tt_tensor.get_tensor_spec().tile(); auto tt_dtype = tt_tensor.get_dtype(); if (tt_dtype == DataType::BFLOAT8_B) { TT_ASSERT( @@ -647,7 +644,7 @@ py::object convert_tt_tensor_to_numpy_tensor(const Tensor &tt_tensor) { tt_tensor.get_shape(), DataType::FLOAT32, tt_tensor.get_layout(), - tt_tensor.get_tile()) + tile) .to(Layout::ROW_MAJOR); auto output_float_data = owned_buffer::get_as(float_tensor).get(); buffer = owned_buffer::create(std::move(output_float_data)); @@ -667,7 +664,7 @@ py::object convert_tt_tensor_to_numpy_tensor(const Tensor &tt_tensor) { tt_tensor.get_shape(), DataType::FLOAT32, tt_tensor.get_layout(), - tt_tensor.get_tile()) + tile) .to(Layout::ROW_MAJOR); auto output_float_data = owned_buffer::get_as(float_tensor).get(); buffer = owned_buffer::create(std::move(output_float_data)); @@ -1053,7 +1050,7 @@ void pytensor_module(py::module &m_tensor) { .def_property_readonly("shape", [](const Tensor &self) { return self.get_shape(); }) .def_property_readonly("dtype", [](const Tensor &self) { return self.get_dtype(); }) .def_property_readonly("layout", [](const Tensor &self) { return self.get_layout(); }) - .def_property_readonly("tile", [](const Tensor &self) { return self.get_tile(); }) + .def_property_readonly("tile", [](const Tensor &self) { return self.get_tensor_spec().tile(); }) .def( "deallocate", [](Tensor &self, bool force) { return self.deallocate(force); }, @@ -1668,7 +1665,7 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( - "get_tile", [](const Tensor &self) { return self.get_tile(); }, R"doc( + "get_tile", [](const Tensor &self) { return self.get_tensor_spec().tile(); }, R"doc( Get tile dims of TT Tensor. .. 
code-block:: python diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index 6b691de8bca..572c9171b85 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -25,7 +25,7 @@ std::vector get_device_tensors(const ttnn::Tensor& tensor) { if (std::holds_alternative(tensor.get_storage())) { std::vector tensors; auto& host_storage = std::get(tensor.get_storage()); - const Tile tile = tensor.get_tile(); + const Tile tile = tensor.get_tensor_spec().tile(); for (int i = 0; i < host_storage.num_buffers(); ++i) { tensors.push_back(Tensor{OwnedStorage{host_storage.get_buffer(i)}, host_storage.shapes[i], tensor.get_dtype(), tensor.get_layout(),tile}); @@ -58,16 +58,17 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards) // Based whether the first tensor shard has OwnedBuffer or Device buffer, // we want to use MultiDeviceHostStorage or MultiDeviceStorage StorageType storage_type = tensor_shards.at(0).storage_type(); - Tile tile = tensor_shards.at(0).get_tile(); + Tile tile = tensor_shards.at(0).get_tensor_spec().tile(); if (storage_type == StorageType::OWNED) { std::vector shapes; std::vector host_owned_buffers; for (const auto &shard : tensor_shards) { host_owned_buffers.push_back(std::get(shard.get_storage()).buffer); shapes.push_back(shard.get_shape()); - if (shard.get_tile() != tile) { + Tile shard_tile = shard.get_tensor_spec().tile(); + if (shard_tile != tile) { TT_THROW("Error aggregating multichip tensors: Attempting to aggregate tensors with different tiling configurations. Device {} has tiling ({}x{}) while device {} has tiling {}x{}." - ,tensor_shards.at(0).device()->id(), tile.get_height(), tile.get_width(), shard.device()->id(), shard.get_tile().get_height(), shard.get_tile().get_width()); + ,tensor_shards.at(0).device()->id(), tile.get_height(), tile.get_width(), shard.device()->id(), shard_tile.get_height(), shard_tile.get_width()); } } auto storage = MultiDeviceHostStorage{AllGatherTensor(), std::move(host_owned_buffers), shapes}; @@ -82,9 +83,10 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards) ordered_device_ids.push_back(device_id); device_buffers.insert({device->id(), std::get(shard.get_storage()).buffer}); shapes.insert({device->id(), shard.get_shape()}); - if (shard.get_tile() != tile) { + Tile shard_tile = shard.get_tensor_spec().tile(); + if (shard_tile != tile) { TT_THROW("Error aggregating multichip tensors: Attempting to aggregate tensors with different tiling configurations. Device {} has tiling ({}x{}) while device {} has tiling {}x{}." 
- ,tensor_shards.at(0).device()->id(), tile.get_height(), tile.get_width(), shard.device()->id(), shard.get_tile().get_height(), shard.get_tile().get_width()); + ,tensor_shards.at(0).device()->id(), tile.get_height(), tile.get_width(), shard.device()->id(), shard_tile.get_height(), shard_tile.get_width()); } } auto storage = MultiDeviceStorage{AllGatherTensor(), ordered_device_ids, std::move(device_buffers), shapes}; diff --git a/ttnn/cpp/ttnn/operation.hpp b/ttnn/cpp/ttnn/operation.hpp index 267571da734..d4b57665801 100644 --- a/ttnn/cpp/ttnn/operation.hpp +++ b/ttnn/cpp/ttnn/operation.hpp @@ -413,11 +413,8 @@ auto default_create_output_tensors( const auto& device = input_tensors.at(0).device(); const auto& output_specs = operation.compute_output_specs(input_tensors); output_tensors.reserve(output_specs.size()); - for (const auto& [output_shape, output_layout] : output_specs) { - output_tensors.emplace_back(create_device_tensor( - output_shape, - output_layout, - device)); + for (const auto& output_spec : output_specs) { + output_tensors.emplace_back(create_device_tensor(output_spec, device)); } return output_tensors; } diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index 4957355cf7e..32fc7afb01a 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -152,6 +152,7 @@ std::vector AllGather::compute_output_shapes(const std::vecto std::vector AllGather::create_output_tensors(const std::vector &input_tensors) const { const auto& input_tensor = input_tensors[0]; + auto tile = input_tensor.get_tensor_spec().tile(); if(this->output_mem_config.is_sharded()) { return {create_device_tensor( this->compute_output_shapes(input_tensors).at(0), @@ -159,10 +160,10 @@ std::vector AllGather::create_output_tensors(const std::vector & input_tensor.get_layout(), input_tensor.device(), this->output_mem_config, - input_tensor.get_tile() + tile )}; } else { - return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config, input_tensor.get_tile()); + return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config, tile); } } diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp index 6c49072b809..85b94a0505f 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp @@ -110,7 +110,7 @@ CclOpTensorConfig::CclOpTensorConfig(Tensor const& tensor) : buffer_start_address(tensor.buffer()->address()), df(tt::tt_metal::datatype_to_dataformat_converter(tensor.get_dtype())) { if (tensor.get_layout() == Layout::TILE) { - this->tile = tensor.get_tile(); + this->tile = tensor.get_tensor_spec().tile(); this->page_size =this->tile.get_tile_size(this->df); this->tile_size = this->tile.get_tile_hw(); } else { @@ -327,8 +327,9 @@ RingReduceScatterBaseTensorSlicer::RingReduceScatterBaseTensor output_shape.begin() + slice_dim, output_shape.end() - 1, 1, std::multiplies()) - num_rows; } else { - const uint32_t num_tiles_x = input_tensor.get_legacy_shape()[-1] / input_tensor.get_tile().get_width(); - uint32_t num_tiles_y = (input_tensor.get_legacy_shape()[-2] / input_tensor.get_tile().get_height()); + auto input_tile = input_tensor.get_tensor_spec().tile(); + const uint32_t 
num_tiles_x = input_tensor.get_legacy_shape()[-1] / input_tile.get_width(); + uint32_t num_tiles_y = (input_tensor.get_legacy_shape()[-2] / input_tile.get_height()); for (std::size_t i = 0; input_tensor.get_legacy_shape().rank() > 2 && i < input_tensor.get_legacy_shape().rank() - 2; i++) { num_tiles_y *= input_tensor.get_legacy_shape()[i]; } @@ -369,11 +370,12 @@ RingReduceScatterBaseTensorSlicer::RingReduceScatterBaseTensor input_tensor.get_legacy_shape()[0] * input_tensor.get_legacy_shape()[1] * input_tensor.get_legacy_shape()[2]}; } else { + auto input_tile = input_tensor.get_tensor_spec().tile(); this->flattened_tensor_shape = tt_xy_pair{ - input_tensor.get_legacy_shape()[3] /input_tensor.get_tile().get_width(), + input_tensor.get_legacy_shape()[3] / input_tile.get_width(), (input_tensor.get_legacy_shape()[0] * input_tensor.get_legacy_shape()[1] * input_tensor.get_legacy_shape()[2]) / - input_tensor.get_tile().get_height()}; + input_tile.get_height()}; } this->worker_slice_offsets = DERIVED_SLICER_T::compute_worker_slice_offsets(this->worker_slice_shapes, this->tensor_slice_shape); diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp index 3f71a810bb2..a5e6d49c184 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp @@ -415,17 +415,19 @@ class InterleavedRingAllGatherTensorSlicer : public LegacyCclTensorSlicer { output_shape.begin() + slice_dim, output_shape.end() - 1, 1, std::multiplies()) - num_rows; } else { - this->num_cols = input_tensor.get_legacy_shape()[-1] / input_tensor.get_tile().get_width(); auto input_shape = input_tensor.get_legacy_shape(); auto output_shape = output_tensor.get_legacy_shape(); - uint32_t num_output_cols = output_tensor.get_legacy_shape()[-1] / output_tensor.get_tile().get_width(); + auto input_tile = input_tensor.tensor_spec().tile(); + auto output_tile = output_tensor.tensor_spec().tile(); + this->num_cols = input_shape[-1] / input_tile.get_width(); + uint32_t num_output_cols = output_tensor.get_legacy_shape()[-1] / output_tile.get_width(); this->num_rows = std::accumulate( input_shape.begin() + slice_dim, input_shape.end() - 1, 1, std::multiplies()) / - input_tensor.get_tile().get_height(); + input_tile.get_height(); this->row_offset = (std::accumulate( - output_shape.begin() + slice_dim, output_shape.end() - 1, 1, std::multiplies()) / output_tensor.get_tile().get_height() - num_rows) * + output_shape.begin() + slice_dim, output_shape.end() - 1, 1, std::multiplies()) / output_tile.get_height() - num_rows) * num_output_cols; this->col_offset = num_output_cols - num_cols; this->num_tiles = num_rows * num_cols; diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp index a7d663625df..e20bc28435d 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp @@ -63,7 +63,7 @@ CCLOpConfig::CCLOpConfig( topology(topology), is_row_major(input_tensors.at(0).get_layout() == Layout::ROW_MAJOR) { if(input_tensors.at(0).get_layout() == Layout::TILE) { - this->tile = input_tensors.at(0).get_tile(); + this->tile = input_tensors.at(0).tensor_spec().tile(); this->page_size = this->tile.get_tile_size(this->df); //this->page_size = input_tensors.at(0).buffer()->page_size(); } else { diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index 4dca42fbd88..3b7a325c4ad 
100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -696,17 +696,14 @@ std::pair> prepare_conv_weights_biases int32_t weight_matrix_height_padding = weight_tensor_.shape()[2] - weight_matrix_height; TT_FATAL(weight_matrix_height_padding >= 0," Matrix Height Padding can't be negative"); - // convert_conv_weight_tensor adds the padding to the base shape. - // Reshape the weights to remove padding from the base shape. - weight_tensor_.set_shape( - ttnn::Shape(std::array{1, 1, weight_matrix_height, out_channels}, + auto target_shape = ttnn::Shape(std::array{1, 1, weight_matrix_height, out_channels}, std::array, 4>{ std::array{0, 0}, std::array{0, 0}, std::array{0, weight_matrix_height_padding}, std::array{0, out_channel_padding} - })); - + }); + weight_tensor_ = ttnn::reshape(weight_tensor_, target_shape); weight_tensor_ = ttnn::operations::core::to_device(weight_tensor_, device, std::nullopt); if (bias_tensor.has_value()) { bias_tensor_ = bias_tensor.value(); diff --git a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp index 811748c7f8b..c2dcae940bf 100644 --- a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp +++ b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp @@ -80,20 +80,20 @@ Tensor to_layout_impl( const auto padded_shape = shape.with_tile_padding(); if (layout == ttnn::ROW_MAJOR_LAYOUT and intended_shape != padded_shape) { return true; - } else if ( - auto tile = tensor.tile(); - layout == ttnn::TILE_LAYOUT and (padded_shape.rank() < 2 or padded_shape[-1] % tile.get_tile_shape()[1] != 0 or - padded_shape[-2] % tile.get_tile_shape()[0] != 0)) { - return true; - } else { - return false; } + if (layout == ttnn::TILE_LAYOUT) { + auto tile_shape = tensor.tensor_spec().tile().get_tile_shape(); + if (padded_shape.rank() < 2 or padded_shape[-1] % tile_shape[1] != 0 or padded_shape[-2] % tile_shape[0] != 0) { + return true; + } + } + return false; }; const auto intended_shape = tensor_arg.get_shape(); auto tensor = tensor_arg; - const auto& tile = tensor.tile(); + const auto tile = tensor.get_tensor_spec().tile(); SmallVector output_shape; if (layout == ttnn::TILE_LAYOUT and intended_shape.rank() < 2) { diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp index 301b224c6ea..b4e66c61433 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp @@ -109,8 +109,8 @@ MassagedConcat build_untilize_rm_retilize_concat(uint8_t queue_id, const MemoryC TT_FATAL(input_tensor.get_layout() == ttnn::TILE_LAYOUT, "ttnn.concat: expected all input tensors to be in tile layout"); auto untilized_tensor = ttnn::untilize(input_tensor); // untilized, so now we have a padded rm tensor - untilized_tensor.set_shape(ttnn::Shape {input_tensor.get_logical_shape().view(), - untilized_tensor.get_padded_shape().view()}); + untilized_tensor = ttnn::reshape(untilized_tensor, + ttnn::Shape {input_tensor.get_logical_shape().view(), untilized_tensor.get_padded_shape().view()}); return untilized_tensor; } ); diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp index c99e6f62b3d..d3b41a7ff31 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp +++ 
b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp @@ -48,8 +48,8 @@ PermuteDeviceOperation::tensor_return_value_t PermuteDeviceOperation::create_out const auto& input_tensor = tensor_args.input_tensor; return create_device_tensor( output_shape, - input_tensor.tensor_attributes->dtype, - input_tensor.tensor_attributes->layout, + input_tensor.dtype(), + input_tensor.layout(), input_tensor.device()); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp index c35700720bc..d9c23666d66 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp @@ -2,17 +2,18 @@ // // SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/common/constants.hpp" -#include "ttnn/run_operation.hpp" -#include "ttnn/decorators.hpp" #include "ttnn/operations/data_movement/repeat/repeat.hpp" + #include "device/repeat_op.hpp" -#include "ttnn/operations/data_movement/untilize/untilize.hpp" -#include "ttnn/operations/data_movement/tilize/tilize.hpp" -#include "ttnn/operations/data_movement/slice/slice.hpp" #include "tt_metal/common/math.hpp" +#include "ttnn/common/constants.hpp" +#include "ttnn/decorators.hpp" #include "ttnn/operations/data_movement/pad/pad.hpp" +#include "ttnn/operations/data_movement/reshape_view/reshape.hpp" +#include "ttnn/operations/data_movement/slice/slice.hpp" +#include "ttnn/operations/data_movement/tilize/tilize.hpp" +#include "ttnn/operations/data_movement/untilize/untilize.hpp" +#include "ttnn/run_operation.hpp" namespace ttnn::operations::data_movement { @@ -99,8 +100,7 @@ ttnn::Tensor RepeatOperation::invoke( auto padded_to_tiled_shape = ttnn::Shape(sliced_logical_shape.view(), tiled_output.get_padded_shape().view()); - tiled_output.set_shape(padded_to_tiled_shape); - return tiled_output; + return ttnn::reshape(tiled_output, padded_to_tiled_shape); } else { return ttnn::slice(output_tensors[0], zero_indices, end_indices, step, input_tensor.memory_config(), std::nullopt); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp index 20c7045b827..21ff349a32c 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp @@ -45,7 +45,8 @@ std::vector ReshapeDeviceOperation::compute_output_shapes(con std::vector ReshapeDeviceOperation::create_output_tensors(const std::vector &input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); - return {create_device_tensor(output_shape, input_tensor_a.get_dtype(), input_tensor_a.get_layout(), input_tensor_a.device(), this->output_mem_config, input_tensor_a.tile())}; + return {create_device_tensor(output_shape, input_tensor_a.get_dtype(), input_tensor_a.get_layout(), input_tensor_a.device(), + this->output_mem_config, input_tensor_a.get_tensor_spec().tile())}; } operation::ProgramWithCallbacks ReshapeDeviceOperation::create_program(const std::vector& input_tensors, std::vector &output_tensors) const { diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp index 8ded07491f0..6398986a22d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp @@ -82,11 
+82,11 @@ void SliceDeviceOperation::validate_with_output_tensors( TT_FATAL(this->slice_start[i] <= this->slice_end[i], "Error"); } if(!output_tensors.empty() && output_tensors[0].has_value()){ - const auto output_shape_required = std::get<0>(this->compute_output_specs(input_tensors)[0]); + const auto output_shape_required = compute_output_specs(input_tensors)[0].logical_shape(); const auto& out_tensor = output_tensors[0].value(); TT_FATAL(out_tensor.get_padded_shape() == output_shape_required, "The input tensors need a shape of {}, however the output tensor is only {}", output_shape_required, out_tensor.get_padded_shape()); } - auto output_tensor_shape = std::get<0>(this->compute_output_specs(input_tensors)[0]); + auto output_tensor_shape = this->compute_output_specs(input_tensors)[0].logical_shape(); if (has_step) { // if all ones modify before passing in to function TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Strided slice is only supported for row major layout"); TT_FATAL(!input_tensor_a.is_sharded(), "Strided slice is not supported for sharded tensor"); diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp index 6ac01f85546..49766f75387 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp @@ -100,7 +100,7 @@ std::vector Transpose::compute_output_specs(const std::vector< break; } else { uint32_t C = output_shape[1]; - uint32_t C_p = tt::round_up(C, input_tensor.get_tile().get_height()); + uint32_t C_p = tt::round_up(C, input_tensor.get_tensor_spec().tile().get_height()); uint32_t H = output_shape[2]; output_shape[1] = H; output_shape[2] = C; diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp index 69758fdc15e..9c43c604645 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp @@ -458,7 +458,7 @@ void override_runtime_args_mc_hc_tiled_interleaved( auto input_buffer = input_tensor.buffer(); auto output_buffer = output_tensor.buffer(); - auto tile_shape = input_tensor.get_tile().get_tile_shape(); + auto tile_shape = input_tensor.get_tensor_spec().tile().get_tile_shape(); auto tile_hw = (tile_shape[0] * tile_shape[1]); uint32_t num_tensor_tiles = input_tensor.volume() / tile_hw; uint32_t num_output_tiles = output_tensor.volume() / tile_hw; @@ -548,8 +548,9 @@ operation::ProgramWithCallbacks transpose_hc_multi_core_tiled_interleaved(const TT_ASSERT(a.buffer() != nullptr, "Operand to transpose_hc needs to be allocated in a buffer on device!"); tt::tt_metal::Program program = tt::tt_metal::Program(); - auto tile_shape = a.get_tile().get_tile_shape(); - auto face_shape = a.get_tile().get_face_shape(); + auto tile = a.get_tensor_spec().tile(); + auto tile_shape = tile.get_tile_shape(); + auto face_shape = tile.get_face_shape(); uint32_t num_tensor_tiles = a.volume() / (tile_shape[0] * tile_shape[1]); uint32_t num_output_tiles = output.volume() / (tile_shape[0] * tile_shape[1]); uint32_t W = a.get_logical_shape()[3], H = a.get_logical_shape()[2], C = a.get_logical_shape()[1], N = a.get_logical_shape()[0]; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp 
b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp index a0437deb247..bc89424c50a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp @@ -152,9 +152,9 @@ void BinaryDeviceOperation::validate_on_program_cache_hit( BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_output_shapes( const operation_attributes_t&, const tensor_args_t& tensor_args) { - const auto input_shape_a = tensor_args.input_tensor_a.tensor_attributes->shape; + const auto input_shape_a = tensor_args.input_tensor_a.shape(); const auto& tensor_b = tensor_args.input_tensor_b; - const auto input_shape_b = tensor_b.has_value() ? tensor_b->tensor_attributes->shape : ttnn::Shape{1, 1}; + const auto input_shape_b = tensor_b.has_value() ? tensor_b->shape() : ttnn::Shape{1, 1}; const int rank_a = input_shape_a.rank(); const int rank_b = input_shape_b.rank(); diff --git a/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp b/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp index 02dd5f1e205..1a880b2b24b 100644 --- a/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp @@ -23,7 +23,7 @@ void ExampleDeviceOperation::validate_on_program_cache_hit( ExampleDeviceOperation::shape_return_value_t ExampleDeviceOperation::compute_output_shapes( const operation_attributes_t&, const tensor_args_t& tensor_args) { - return tensor_args.input_tensor.tensor_attributes->shape; + return tensor_args.input_tensor.shape(); } ExampleDeviceOperation::tensor_return_value_t ExampleDeviceOperation::create_output_tensors( @@ -32,8 +32,8 @@ ExampleDeviceOperation::tensor_return_value_t ExampleDeviceOperation::create_out const auto& input_tensor = tensor_args.input_tensor; return create_device_tensor( output_shape, - input_tensor.tensor_attributes->dtype, - input_tensor.tensor_attributes->layout, + input_tensor.dtype(), + input_tensor.layout(), input_tensor.device()); } diff --git a/ttnn/cpp/ttnn/operations/examples/example_multiple_return/device/example_multiple_return_device_operation.cpp b/ttnn/cpp/ttnn/operations/examples/example_multiple_return/device/example_multiple_return_device_operation.cpp index b6b9ffeaac3..4c195f2962b 100644 --- a/ttnn/cpp/ttnn/operations/examples/example_multiple_return/device/example_multiple_return_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/examples/example_multiple_return/device/example_multiple_return_device_operation.cpp @@ -26,7 +26,7 @@ void ExampleMultipleReturnDeviceOperation::validate_on_program_cache_hit( ExampleMultipleReturnDeviceOperation::shape_return_value_t ExampleMultipleReturnDeviceOperation::compute_output_shapes( const operation_attributes_t&, const tensor_args_t& tensor_args) { - return {tensor_args.input_tensor.tensor_attributes->shape, tensor_args.input_tensor.tensor_attributes->shape}; + return {tensor_args.input_tensor.shape(), tensor_args.input_tensor.shape()}; } ExampleMultipleReturnDeviceOperation::tensor_return_value_t ExampleMultipleReturnDeviceOperation::create_output_tensors( @@ -42,14 +42,14 @@ ExampleMultipleReturnDeviceOperation::tensor_return_value_t ExampleMultipleRetur const auto& input_tensor = tensor_args.input_tensor; auto output1 = create_device_tensor( output1_shape, - input_tensor.tensor_attributes->dtype, - input_tensor.tensor_attributes->layout, + 
input_tensor.dtype(), + input_tensor.layout(), input_tensor.device()); auto output2 = create_device_tensor( output2_shape, - input_tensor.tensor_attributes->dtype, - input_tensor.tensor_attributes->layout, + input_tensor.dtype(), + input_tensor.layout(), input_tensor.device()); diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp b/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp index 3380cff9e69..004ef31c01b 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp @@ -15,7 +15,7 @@ namespace ttnn::operations::experimental::reduction { Tensor create_mask(const Tensor& input_a, const std::optional& output_mem_config) { - auto& padded_shape = input_a.get_legacy_shape(); + auto padded_shape = input_a.get_legacy_shape(); auto& unpadded_shape = padded_shape.without_padding(); if (padded_shape == unpadded_shape) return input_a; @@ -35,7 +35,7 @@ Tensor ArgmaxOperation::invoke(const Tensor& input_t, int64_t _dim, bool all, co const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { const auto& input = input_tensors.at(0); - auto& input_shape = input.get_legacy_shape(); + auto input_shape = input.get_legacy_shape(); TT_FATAL(input_shape.rank() == 4, "supported for rank-4 tensors at this time"); Tensor input_a = create_mask(input, output_memory_config); diff --git a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp index 7ae59a16637..0ae4b2f9256 100644 --- a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp @@ -67,8 +67,8 @@ std::tupledtype), - layout.value_or(input.tensor_attributes->layout), + dtype.value_or(input.dtype()), + layout.value_or(input.layout()), memory_config.value_or(input.memory_config())}, tensor_args_t{input}}; } diff --git a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp index 6b8ff0ba570..13666bfb9f1 100644 --- a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_device_operation.cpp @@ -50,8 +50,8 @@ IndexFillOperation::tensor_return_value_t IndexFillOperation::create_output_tens const auto& input = tensor_args.input; return create_device_tensor( output_shape, - input.tensor_attributes->dtype, - input.tensor_attributes->layout, + input.dtype(), + input.layout(), input.device(), operation_attributes.memory_config); } diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp index 86df11f5769..3603768f368 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp @@ -323,8 +323,8 @@ MatmulMultiCoreReuseMultiCast1DProgramConfig get_mcast_1d_config( uint32_t K = input_tensor_a.get_legacy_shape()[-1]; uint32_t N = input_tensor_b.get_legacy_shape()[-1]; uint32_t per_core_M, per_core_N; - auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); - auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); + auto in0_tile_shape = input_tensor_a.get_tensor_spec().tile().get_tile_shape(); + auto in1_tile_shape = input_tensor_b.get_tensor_spec().tile().get_tile_shape(); if 
(mcast_in0) { per_core_M = M / in0_tile_shape[0]; per_core_N = div_up(div_up(N, grid_size.x * grid_size.y), in1_tile_shape[1]); @@ -367,8 +367,8 @@ inline MatmulProgramConfig create_simple_matmul_program_config( uint32_t batch_size_a = get_batch_size(ashape); uint32_t num_output_tiles = batch_size_a * ashape[-2] * bshape[-1] / TILE_HW; // Output M x N - auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); - auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); + auto in0_tile_shape = input_tensor_a.get_tensor_spec().tile().get_tile_shape(); + auto in1_tile_shape = input_tensor_b.get_tensor_spec().tile().get_tile_shape(); // Parameters for large matmul with reuse uint32_t B = batch_size_a; @@ -635,8 +635,8 @@ MatmulProgramConfig get_matmul_program_config( // generic sharded output tensor creation auto grid_size = input_tensor_a.shard_spec().value().grid.bounding_box().grid_size(); - auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); - auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); + auto in0_tile_shape = input_tensor_a.get_tensor_spec().tile().get_tile_shape(); + auto in1_tile_shape = input_tensor_b.get_tensor_spec().tile().get_tile_shape(); // MCAST matmuls only support input_b in INTERLEAVED if (matmul) { @@ -982,8 +982,8 @@ Matmul create_matmul_struct( bool broadcast_batch = parameters.bcast_batch.value_or(get_broadcast_batch(input_tensor_a, input_tensor_b, parameters.program_config)); TT_FATAL(!(has_user_grid && has_program_config), "Cannot use both user core grid/coordinates and a program config"); - const auto& in0_tile = input_tensor_a.get_tile(); - const auto& in1_tile = input_tensor_b.get_tile(); + auto in0_tile = input_tensor_a.get_tensor_spec().tile(); + auto in1_tile = input_tensor_b.get_tensor_spec().tile(); tt::tt_metal::Tile output_tile = get_output_tile( parameters.output_mem_config, in0_tile, in1_tile, parameters.output_tile); @@ -1048,20 +1048,20 @@ void Matmul::validate( const auto& input_tensor_b = input_tensors.at(1); const auto& a_shape = input_tensor_a.get_shape(); const auto& b_shape = input_tensor_b.get_shape(); - auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); - auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); + auto in0_tile_shape = input_tensor_a.get_tensor_spec().tile().get_tile_shape(); + auto in1_tile_shape = input_tensor_b.get_tensor_spec().tile().get_tile_shape(); if (input_tensor_a.device()->arch() == tt::ARCH::GRAYSKULL) { TT_FATAL( - (input_tensor_a.get_tile().get_tile_shape()[1] == TILE_WIDTH && input_tensor_a.get_tile().get_tile_shape()[0] == TILE_HEIGHT), + (in0_tile_shape[1] == TILE_WIDTH && in0_tile_shape[0] == TILE_HEIGHT), "Grayskull does not support tiny tile"); TT_FATAL( - (input_tensor_b.get_tile().get_tile_shape()[1] == TILE_WIDTH && input_tensor_b.get_tile().get_tile_shape()[0] == TILE_HEIGHT), + (in1_tile_shape[1] == TILE_WIDTH && in1_tile_shape[0] == TILE_HEIGHT), "Grayskull does not support tiny tile"); } TT_FATAL( - (input_tensor_a.get_tile().get_tile_shape()[1] == TILE_WIDTH && in1_tile_shape[0] == TILE_WIDTH), + (in0_tile_shape[1] == TILE_WIDTH && in1_tile_shape[0] == TILE_WIDTH), "Input tile dims must have inner dim equal to 32 due to llk constraints"); TT_FATAL( @@ -1103,11 +1103,12 @@ void Matmul::validate( TT_FATAL(optional_input_tensors.size() == 1, "Error"); const auto& optional_bias = optional_input_tensors.at(0); if (optional_bias.has_value()) { + const auto& bias = optional_bias.value(); + auto bias_tile_shape = 
bias.tensor_spec().tile().get_tile_shape(); TT_FATAL( - (optional_bias->get_tile().get_tile_shape()[0] == input_tensor_a.get_tile().get_tile_shape()[0] && - optional_bias->get_tile().get_tile_shape()[1] == in1_tile_shape[1]), + (bias_tile_shape[0] == in0_tile_shape[0] && + bias_tile_shape[1] == in1_tile_shape[1]), "Input tile dims must have inner dim equal to 32 due to llk constraints"); - const auto& bias = optional_bias.value(); TT_FATAL(bias.get_layout() == Layout::TILE, "Unsupported input layout"); const auto& bias_shape = bias.get_shape(); uint32_t bias_batch_size = get_batch_size(bias_shape); @@ -1422,8 +1423,8 @@ std::vector Matmul::compute_output_shapes(const std::vector Matmul::create_output_tensors(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); const auto& input_tensor_b = input_tensors.at(1); - auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); - auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); + auto in0_tile_shape = input_tensor_a.get_tensor_spec().tile().get_tile_shape(); + auto in1_tile_shape = input_tensor_b.get_tensor_spec().tile().get_tile_shape(); auto output_tile = this->output_tile.value(); auto tile_width_ratio = output_tile.get_tile_shape()[1] / in1_tile_shape[1]; auto output_layout = this->untilize_out ? Layout::ROW_MAJOR : Layout::TILE; diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp index 516031c4e27..1ef14eaf848 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp @@ -1648,12 +1648,12 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( bool untilize_out, std::optional &fused_op_signaler) { const auto &ashape = a.get_legacy_shape(), bshape = b.get_legacy_shape(); - auto in0_tile = a.get_tile(); - auto in1_tile = b.get_tile(); + auto in0_tile = a.get_tensor_spec().tile(); + auto in1_tile = b.get_tensor_spec().tile(); // cannot use the output tensor tile directly as that might be changed by user override - auto output_tile = tt::tt_metal::Tile({in0_tile.get_tile_shape()[0], in1_tile.get_tile_shape()[1]}); - auto in0_tile_shape = a.get_tile().get_tile_shape(); - auto in1_tile_shape = b.get_tile().get_tile_shape(); + auto in0_tile_shape = in0_tile.get_tile_shape(); + auto in1_tile_shape = in1_tile.get_tile_shape(); + auto output_tile = tt::tt_metal::Tile({in0_tile_shape[0], in1_tile_shape[1]}); // CB dataformats tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); // in0 @@ -1768,7 +1768,7 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( out_buffer, in0_tile, in1_tile, - bias.has_value() ? bias->get_tile() : output_tile, + bias.has_value() ? bias->get_tensor_spec().tile() : output_tile, output_tile, in0_data_format, in1_data_format, @@ -1805,7 +1805,7 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( out_buffer, in0_tile, in1_tile, - bias.has_value() ? bias->get_tile() : output_tile, + bias.has_value() ? 
bias->get_tensor_spec().tile() : output_tile, output_tile, in0_data_format, in1_data_format, diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp index 16ab7359316..70736290e0d 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp @@ -1313,12 +1313,12 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_( bool untilize_out, std::optional &fused_op_signaler) { const auto &ashape = a.get_legacy_shape(), bshape = b.get_legacy_shape(); - auto in0_tile = a.get_tile(); - auto in1_tile = b.get_tile(); + auto in0_tile = a.get_tensor_spec().tile(); + auto in1_tile = b.get_tensor_spec().tile(); + auto in0_tile_shape = in0_tile.get_tile_shape(); + auto in1_tile_shape = in1_tile.get_tile_shape(); // cannot use the output tensor tile directly as that might be changed by user override - auto output_tile = tt::tt_metal::Tile({in0_tile.get_tile_shape()[0], in1_tile.get_tile_shape()[1]}); - auto in0_tile_shape = a.get_tile().get_tile_shape(); - auto in1_tile_shape = b.get_tile().get_tile_shape(); + auto output_tile = tt::tt_metal::Tile({in0_tile_shape[0], in1_tile_shape[1]}); // CB dataformats tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); // in0 @@ -1433,7 +1433,7 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_( out_buffer, in0_tile, in1_tile, - bias.has_value() ? bias->get_tile() : output_tile, + bias.has_value() ? bias->get_tensor_spec().tile() : output_tile, output_tile, in0_data_format, in1_data_format, diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp index 72203d67390..a4996e3b19c 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp @@ -1266,12 +1266,12 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_( bool skip_in0_mcast, bool skip_write_back) { const auto &ashape = a.get_legacy_shape(), bshape = b.get_legacy_shape(); - auto in0_tile = a.get_tile(); - auto in1_tile = b.get_tile(); + auto in0_tile = a.get_tensor_spec().tile(); + auto in1_tile = b.get_tensor_spec().tile(); + auto in0_tile_shape = in0_tile.get_tile_shape(); + auto in1_tile_shape = in1_tile.get_tile_shape(); // cannot use the output tensor tile directly as that might be changed by user override auto output_tile = tt::tt_metal::Tile({in0_tile.get_tile_shape()[0], in1_tile.get_tile_shape()[1]}); - auto in0_tile_shape = a.get_tile().get_tile_shape(); - auto in1_tile_shape = b.get_tile().get_tile_shape(); // CB dataformats tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); // in0 @@ -1296,8 +1296,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_( TT_FATAL(a.shard_spec().has_value() && output.shard_spec().has_value(), "Error"); CoreRangeSet all_cores_storage = a.shard_spec().value().grid; - uint32_t in0_single_tile_size = a.get_tile().get_tile_size(in0_data_format); - uint32_t 
in1_single_tile_size = b.get_tile().get_tile_size(in1_data_format); + uint32_t in0_single_tile_size = in0_tile.get_tile_size(in0_data_format); + uint32_t in1_single_tile_size = in1_tile.get_tile_size(in1_data_format); tt_metal::Buffer* in0_buffer = a.buffer(); tt_metal::Buffer* in1_buffer = b.buffer(); TT_FATAL(in0_buffer->size() % in0_single_tile_size == 0, "Error"); @@ -1356,7 +1356,7 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_( out_buffer, in0_tile, in1_tile, - bias.has_value() ? bias->get_tile() : output_tile, + bias.has_value() ? bias->get_tensor_spec().tile() : output_tile, output_tile, in0_data_format, in1_data_format, diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_optimized_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_optimized_program_factory.cpp index b635b938e6e..614d1627844 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_optimized_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_optimized_program_factory.cpp @@ -61,8 +61,8 @@ operation::ProgramWithCallbacks create_program( ? (fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b) : (fp32_dest_acc_en ? tt::DataFormat::Float32 : output_data_format); - auto in0_tile = in0.get_tile(); - auto in1_tile = in1.get_tile(); + auto in0_tile = in0.get_tensor_spec().tile(); + auto in1_tile = in1.get_tensor_spec().tile(); // currently only support transpose of the full tile bool in1_transpose_tile = in1_tile.get_transpose_of_faces() && in1_tile.get_transpose_within_face(); auto in1_tile_shape = in1_tile.get_tile_shape(); @@ -498,8 +498,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_optimized_( bool untilize_out) { const auto& ashape = a.get_legacy_shape(); const auto& bshape = b.get_legacy_shape(); - auto in0_tile_shape = a.get_tile().get_tile_shape(); - auto in1_tile_shape = b.get_tile().get_tile_shape(); + auto in0_tile_shape = a.get_tensor_spec().tile().get_tile_shape(); + auto in1_tile_shape = b.get_tensor_spec().tile().get_tile_shape(); TT_FATAL( (bcast_batch == false) or (ashape[0] == 1) or (ashape.rank() == 2), diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.cpp index 08390f56752..03391bbec7b 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.cpp @@ -70,7 +70,7 @@ MorehCumsumDeviceOperation::tensor_return_value_t MorehCumsumDeviceOperation::cr auto output_shape = compute_output_shapes(operation_attributes, tensor_args); return create_device_tensor( - output_shape, input.tensor_attributes->dtype, input.tensor_attributes->layout, input.device()); + output_shape, input.dtype(), input.layout(), input.device()); } std::tuple diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp index d8c67f67898..21c2013f6d6 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp @@ -67,8 +67,8 @@ MorehDotOperation::tensor_return_value_t MorehDotOperation::create_output_tensor const auto& input_tensor = tensor_args.input_a; return 
create_device_tensor( output_shape, - input_tensor.tensor_attributes->dtype, - input_tensor.tensor_attributes->layout, + input_tensor.dtype(), + input_tensor.layout(), input_tensor.device(), operation_attributes.memory_config); } diff --git a/ttnn/cpp/ttnn/operations/numpy/functions.hpp b/ttnn/cpp/ttnn/operations/numpy/functions.hpp index fcd92f10f6d..9b329555420 100644 --- a/ttnn/cpp/ttnn/operations/numpy/functions.hpp +++ b/ttnn/cpp/ttnn/operations/numpy/functions.hpp @@ -58,7 +58,8 @@ static Tensor full( .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, std::optional optional_output_tensor = std::nullopt) { constexpr DataType data_type = detail::get_data_type(); - auto owned_buffer = tt::tt_metal::owned_buffer::create(tt::tt_metal::compute_volume(shape)); + TensorSpec tensor_spec(shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout), MemoryConfig{}, shape)); + auto owned_buffer = tt::tt_metal::owned_buffer::create(tensor_spec.padded_shape().volume()); std::fill(std::begin(owned_buffer), std::end(owned_buffer), value); if (!optional_output_tensor.has_value()){ @@ -227,7 +228,8 @@ static Tensor arange( owned_buffer[index++] = static_cast(value); } } - auto output = Tensor(OwnedStorage{owned_buffer}, ttnn::SimpleShape{1, 1, 1, static_cast(size)}, data_type, layout); + auto output = Tensor(OwnedStorage{owned_buffer}, ttnn::SimpleShape{1, 1, 1, static_cast(size)}, data_type, Layout::ROW_MAJOR) + .to(layout); if (device != nullptr) { output = output.to(device, output_mem_config); } @@ -444,7 +446,7 @@ static Tensor fill_first_val_into_tensor( owned_buffer[i] = input_buffer[0]; } const tt::tt_metal::LegacyShape& s_a = input_tensor.get_legacy_shape(); - auto output = Tensor(OwnedStorage{owned_buffer}, s_a, data_type, layout).to(layout); + auto output = Tensor(OwnedStorage{owned_buffer}, s_a, data_type, Layout::ROW_MAJOR).to(layout); if (device != nullptr) { output = output.to(device, output_mem_config); } @@ -493,7 +495,7 @@ static Tensor prod_result_computation_GS( } owned_buffer[0] = result; // store the result at the first position of the tensor,and the rest of the values as // 0.0f - auto output = Tensor(OwnedStorage{owned_buffer}, s_a, data_type, layout).to(layout); + auto output = Tensor(OwnedStorage{owned_buffer}, s_a, data_type, Layout::ROW_MAJOR).to(layout); if (device != nullptr) { output = output.to(device, output_mem_config); } @@ -546,7 +548,7 @@ static Tensor prod_result_computation_WH_B0( } owned_buffer[0] = result; // store the result at the first position of the tensor,and the rest of the values as // 0.0f - auto output = Tensor(OwnedStorage{owned_buffer}, s_a, data_type, layout).to(layout); + auto output = Tensor(OwnedStorage{owned_buffer}, s_a, data_type, Layout::ROW_MAJOR).to(layout); if (device != nullptr) { output = output.to(device, output_mem_config); } @@ -710,7 +712,7 @@ static Tensor uniform(T low, T high, const tt::tt_metal::LegacyShape& shape, con } } - return Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout); + return Tensor(OwnedStorage{owned_buffer}, shape, data_type, Layout::ROW_MAJOR).to(layout); } static Tensor random( diff --git a/ttnn/cpp/ttnn/run_operation.cpp b/ttnn/cpp/ttnn/run_operation.cpp index c06cc37e3e5..dbacda85c7d 100644 --- a/ttnn/cpp/ttnn/run_operation.cpp +++ b/ttnn/cpp/ttnn/run_operation.cpp @@ -315,9 +315,10 @@ std::vector extract_legacy_shapes( std::vector legacy_shapes; legacy_shapes.reserve(tensor_specs.size()); for (size_t idx = 0; idx < tensor_specs.size(); idx++) { - const 
auto& [simple_shape, output_layout] = tensor_specs[idx]; - TensorLayout tensor_layout = use_tensor_layout_from_tensor_spec ? output_layout : layout_provider(idx); - legacy_shapes.emplace_back(simple_shape.view(), tensor_layout.compute_padded_shape(simple_shape).view()); + const auto& tensor_spec = tensor_specs[idx]; + TensorLayout tensor_layout = use_tensor_layout_from_tensor_spec ? tensor_spec.tensor_layout() : layout_provider(idx); + auto logical_shape = tensor_spec.logical_shape(); + legacy_shapes.emplace_back(logical_shape.view(), tensor_layout.compute_padded_shape(logical_shape).view()); } return legacy_shapes; } else { diff --git a/ttnn/cpp/ttnn/run_operation_inl.hpp b/ttnn/cpp/ttnn/run_operation_inl.hpp index 687a821f115..3c992c3e23c 100644 --- a/ttnn/cpp/ttnn/run_operation_inl.hpp +++ b/ttnn/cpp/ttnn/run_operation_inl.hpp @@ -232,11 +232,7 @@ void launch_op( insert_buffer_and_shape_for_device(target_device, *local_tensor, *output_tensor); int num_workers_completed = (output_tensor->tensor_attributes->num_workers_completed)++; if (not num_workers_completed) { - output_tensor->tensor_attributes->shape = local_tensor->tensor_attributes->shape; - output_tensor->tensor_attributes->dtype = local_tensor->tensor_attributes->dtype; - output_tensor->tensor_attributes->layout = local_tensor->tensor_attributes->layout; - output_tensor->tensor_attributes->tile = local_tensor->tensor_attributes->tile; - output_tensor->tensor_attributes->metadata_populated = true; + output_tensor->set_tensor_spec(local_tensor->tensor_spec()); } } } diff --git a/ttnn/cpp/ttnn/tensor/layout/page_config.cpp b/ttnn/cpp/ttnn/tensor/layout/page_config.cpp index e5bb55a1e28..3c9307f97de 100644 --- a/ttnn/cpp/ttnn/tensor/layout/page_config.cpp +++ b/ttnn/cpp/ttnn/tensor/layout/page_config.cpp @@ -8,7 +8,7 @@ namespace tt::tt_metal { namespace { namespace CMAKE_UNIQUE_NAMESPACE { -size_t element_size_bytes(DataType dtype) { +size_t rm_element_size_bytes(DataType dtype) { switch (dtype) { case DataType::BFLOAT16: return sizeof(bfloat16); case DataType::FLOAT32: return sizeof(float); @@ -18,7 +18,8 @@ size_t element_size_bytes(DataType dtype) { case DataType::UINT8: return sizeof(uint8_t); case DataType::BFLOAT8_B: case DataType::BFLOAT4_B: - TT_THROW("element_size_bytes() should not be used for BFLOAT8_B and BFLOAT4_B types becaues of how they are packed"); + // To store block floats in RowMajor layout, we use a fallback and store full floats instead + return sizeof(float); default: TT_THROW("Unsupported data type!"); @@ -37,10 +38,13 @@ PageConfig::PageConfig(Layout layout) PageConfig::PageConfig(Layout layout, const std::optional& tile) { if(layout == Layout::ROW_MAJOR) { - config_ = RowMajorPageConfig(); + if (tile.has_value()) { + tt::log_warning("Specifying tile shape for a row major layout is deprecated, and will be removed soon"); + } + config_ = RowMajorPageConfig(tile.value_or(Tile())); } else { - config_ = TilePageConfig(tile.value_or(Tile())); + config_ = TilePageConfig(tile.value_or(Tile())); } } @@ -60,17 +64,16 @@ size_t PageConfig::get_page_size_bytes(const Size& page_shape, DataType dtype) c return std::visit([&](const auto& config) constexpr { return config.get_page_size_bytes(page_shape, dtype); }, config_); } -bool PageConfig::is_row_major() const { - return std::holds_alternative(config_); +Layout PageConfig::get_layout() const { + if (std::holds_alternative(config_)) { + return Layout::ROW_MAJOR; + } + return Layout::TILE; } -std::optional PageConfig::get_tile() const -{ - 
if(std::holds_alternative(config_)) { - return std::get(config_).get_tile(); - } - return std::nullopt; +Tile PageConfig::get_tile() const { + return std::visit([&](const auto& config) { return config.get_tile(); }, config_); } @@ -115,10 +118,10 @@ const Tile& TilePageConfig::get_tile() const { return tile_; } +RowMajorPageConfig::RowMajorPageConfig(const Tile& tile) : tile_(tile) {} + Alignment RowMajorPageConfig::create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const { { - TT_FATAL(dtype != DataType::BFLOAT4_B && dtype != DataType::BFLOAT8_B, "BFLOAT4_B and BFLOAT8_B data types are not supported for ROW_MAJOR layout"); - uint32_t width_alignment = 1; if (memory_config.shard_spec.has_value()) { const auto& shard_spec = memory_config.shard_spec.value(); @@ -151,7 +154,7 @@ void RowMajorPageConfig::validate_alignment(const Alignment& alignment, DataType Size RowMajorPageConfig::get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional& physical_shard_size) const { if (physical_size.height() == 0 || physical_size.width() == 0) { - return Size(1, sizeof(uint32_t) / CMAKE_UNIQUE_NAMESPACE::element_size_bytes(dtype)); + return Size(1, sizeof(uint32_t) / CMAKE_UNIQUE_NAMESPACE::rm_element_size_bytes(dtype)); } if(memory_config.memory_layout == TensorMemoryLayout::SINGLE_BANK) { @@ -168,8 +171,12 @@ Size RowMajorPageConfig::get_page_shape(const Size& physical_size, DataType dtyp } size_t RowMajorPageConfig::get_page_size_bytes(const Size& page_shape, DataType dtype) const { - const auto size = page_shape.height() * page_shape.width() * CMAKE_UNIQUE_NAMESPACE::element_size_bytes(dtype); + const auto size = page_shape.height() * page_shape.width() * CMAKE_UNIQUE_NAMESPACE::rm_element_size_bytes(dtype); return size; } +const Tile& RowMajorPageConfig::get_tile() const { + return tile_; +} + } // namespace tt::tt_metal diff --git a/ttnn/cpp/ttnn/tensor/layout/page_config.hpp b/ttnn/cpp/ttnn/tensor/layout/page_config.hpp index 0338722bc03..fc43041e0e2 100644 --- a/ttnn/cpp/ttnn/tensor/layout/page_config.hpp +++ b/ttnn/cpp/ttnn/tensor/layout/page_config.hpp @@ -20,11 +20,28 @@ namespace tt::tt_metal { class RowMajorPageConfig { public: + RowMajorPageConfig(const Tile& tile = Tile()); + Alignment create_default_alignment(DataType dtype, const MemoryConfig& memory_config) const; void validate_alignment(const Alignment& alignment, DataType dtype, const MemoryConfig& memory_config) const; Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional& physical_shard_size) const; size_t get_page_size_bytes(const Size& page_size, DataType dtype) const; + + const Tile& get_tile() const; + + bool operator==(const RowMajorPageConfig&) const = default; + bool operator!=(const RowMajorPageConfig&) const = default; + + static constexpr auto attribute_names = std::forward_as_tuple("tile"); + const auto attribute_values() const { + return std::forward_as_tuple(tile_); + } + +private: + // This is currently needed for compatibility reasons. + // Each time tile is specified, a warning will be issued. This should be removed soon. 
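+ // A default-constructed Tile is used when none is supplied, so PageConfig::get_tile()
+ // can still answer for ROW_MAJOR tensors; new code should not depend on this member.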
+ Tile tile_; }; class TilePageConfig { @@ -39,6 +56,14 @@ class TilePageConfig { const Tile& get_tile() const; + bool operator==(const TilePageConfig&) const = default; + bool operator!=(const TilePageConfig&) const = default; + + static constexpr auto attribute_names = std::forward_as_tuple("tile"); + const auto attribute_values() const { + return std::forward_as_tuple(tile_); + } + private: Tile tile_; }; @@ -57,9 +82,17 @@ class PageConfig { Size get_page_shape(const Size& physical_size, DataType dtype, const MemoryConfig& memory_config, const std::optional& physical_shard_size) const; size_t get_page_size_bytes(const Size& page_size, DataType dtype) const; - std::optional get_tile() const; + Tile get_tile() const; + + Layout get_layout() const; + + bool operator==(const PageConfig&) const = default; + bool operator!=(const PageConfig&) const = default; - bool is_row_major() const; + static constexpr auto attribute_names = std::forward_as_tuple("config"); + const auto attribute_values() const { + return std::forward_as_tuple(config_); + } private: Config config_; diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp index b12189faa04..78f1587004e 100644 --- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp +++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp @@ -33,7 +33,7 @@ Alignment legacyShapeToAlignment(const ttnn::Shape& shape, const PageConfig& pag // SHARDED if (memory_config.shard_spec.has_value()) { TT_FATAL(alignment_can_be_2D, "Tensor with shape {} cannot be sharded because alignment will have rank greater than 2!", shape); - if (page_config.is_row_major()) { + if (page_config.get_layout() == Layout::ROW_MAJOR) { const auto& shard_spec = memory_config.shard_spec.value(); if (shard_spec.physical_shard_shape.has_value()) { return Alignment{shard_spec.physical_shard_shape.value()[1]}; @@ -67,6 +67,12 @@ Alignment legacyShapeToAlignment(const ttnn::Shape& shape, const PageConfig& pag values[i] = legacy_padded_shape[i] * values[i + 1]; } + for (auto& value : values) { + if (value == 0) { + value = 1; + } + } + Alignment result(std::move(values)); return result; } @@ -128,7 +134,6 @@ std::optional TensorLayout::compute_shard_spec_buffer(const ttn switch (shard_spec.mode) { case ShardMode::PHYSICAL: - TT_FATAL(shard_spec.shape[0] % alignment_[-2] == 0 and shard_spec.shape[1] % alignment_[-1] == 0, "In shard mode {}, physical shard shape {} is not compatible with alignment {}!", shard_spec.mode, shard_spec.shape, alignment_); break; case ShardMode::LOGICAL: { const auto& physical_shard_shape = compute_physical_shard_shape(shard_spec.shape); diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp index e7dc7d8865d..0e0ab95e996 100644 --- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp +++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp @@ -28,7 +28,7 @@ class TensorLayout { [[deprecated("Use of Legacy Padded Shape is deprecated")]] static TensorLayout fromLegacyPaddedShape(DataType dtype, const PageConfig& page_config, const MemoryConfig& memory_config, const ttnn::Shape& legacy_shape); - Layout get_layout() const { return page_config_.is_row_major() ? 
Layout::ROW_MAJOR : Layout::TILE; } + Layout get_layout() const { return page_config_.get_layout(); } PageConfig get_page_config() const { return page_config_; } DataType get_data_type() const { return dtype_; } const MemoryConfig& get_memory_config() const { return memory_config_; } @@ -51,6 +51,20 @@ class TensorLayout { // H is all dimensions except W multiplied and aligned to tile and shard height Size compute_physical_shape(const ttnn::SimpleShape& shape) const; + TensorLayout with_memory_config(MemoryConfig memory_config) const { + TensorLayout result = *this; + result.memory_config_ = std::move(memory_config); + return result; + } + + bool operator==(const TensorLayout&) const = default; + bool operator!=(const TensorLayout&) const = default; + + static constexpr auto attribute_names = std::forward_as_tuple("dtype", "page_config", "memory_config", "alignment"); + const auto attribute_values() const { + return std::forward_as_tuple(dtype_, page_config_, memory_config_, alignment_); + } + private: // Private to not expose alignment parameter to the public API TensorLayout(DataType dtype, const PageConfig& page_config, const MemoryConfig& memory_config, const Alignment& alignment); diff --git a/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp b/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp index 89a6f1228c5..8ca389ba754 100644 --- a/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp +++ b/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp @@ -16,7 +16,7 @@ class ShapeBase { public: using Container = SmallVector; - ShapeBase() = default; + ShapeBase() { init(); }; explicit ShapeBase(const Container& shape) : value_(shape) { init(); } explicit ShapeBase(Container&& shape) : value_(std::move(shape)) { init(); } explicit ShapeBase(std::initializer_list ilist) : value_(ilist) { init(); } diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index eb499ec8c4e..92b5f4b359d 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -33,6 +33,33 @@ namespace tt { namespace tt_metal { +namespace { +namespace CMAKE_UNIQUE_NAMESPACE { +MemoryConfig extract_memory_config(const Storage& storage) { + return std::visit( + [](const auto &storage) -> MemoryConfig { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return storage.memory_config(); + } else if constexpr (std::is_same_v) { + return storage.memory_config(); + } else { + return MemoryConfig{}; + } + }, + storage); +} +} +} + +Tensor::TensorAttributes::TensorAttributes(): tensor_spec( + ttnn::SimpleShape(std::array{0xff, 0xff, 0xff, 0xff}), + TensorLayout(DataType::INVALID, PageConfig(Layout::INVALID), MemoryConfig{})) {} + +Tensor::TensorAttributes::TensorAttributes(Storage storage, TensorSpec tensor_spec) : + storage(std::move(storage)), tensor_spec(std::move(tensor_spec)), metadata_populated(true) { +} + void Tensor::TensorAttributes::increment_main_thread_ref_count(Device *worker) { if (worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS and not tt::tt_metal::detail::InWorkerThread()) { main_thread_ref_count++; @@ -74,40 +101,46 @@ void Tensor::TensorAttributes::update_main_thread_ref_count(Device *worker, uint } } -Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout, const std::optional& tile) : - tensor_id{std::nullopt}, - deallocate_through_destructor(false) { - +Tensor::Tensor(Storage storage, const ttnn::Shape& shape, DataType dtype, Layout layout, const std::optional& tile) { if (tile.has_value()) { - tensor_attributes = std::make_shared(storage, 
shape, dtype, layout, tile.value()); - if (tile->get_tile_shape()[0] != TILE_WIDTH or tile->get_tile_shape()[1] != TILE_HEIGHT) { tt::log_warning("only matmul op and ccl all-gather currently supports the customized tile shape: {}", tile->get_tile_shape()); } - } else { - tensor_attributes = std::make_shared(storage, shape, dtype, layout); } + auto memory_config = CMAKE_UNIQUE_NAMESPACE::extract_memory_config(storage); + init(std::move(storage), TensorSpec(shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(layout, tile), memory_config, shape))); +} + +Tensor::Tensor(Storage storage, TensorSpec tensor_spec) { + init(std::move(storage), std::move(tensor_spec)); +} + +void Tensor::init(Storage storage, TensorSpec tensor_spec) { + tensor_attributes = std::make_shared(std::move(storage), std::move(tensor_spec)); ZoneScoped; std::visit( [&](auto&& storage) { using StorageType = std::decay_t; if constexpr (std::is_same_v) { - this->tensor_attributes->num_shards_to_be_populated = 1; + tensor_attributes->num_shards_to_be_populated = 1; } else if constexpr (std::is_same_v) { TT_ASSERT(storage.buffer->device() != nullptr); workers = {storage.buffer->device()}; - tensor_impl::validate_on_device_dtype_and_layout(storage.buffer->device(), shape.padded_shape(), dtype, layout); + tensor_impl::validate_on_device_dtype_and_layout(storage.buffer->device(), + tensor_attributes->tensor_spec.padded_shape(), + tensor_attributes->tensor_spec.data_type(), + tensor_attributes->tensor_spec.layout()); // Increment main thread ref count for all tensors on device - this->tensor_attributes->increment_main_thread_ref_count(this->workers.at(0)); + tensor_attributes->increment_main_thread_ref_count(this->workers.at(0)); // This tensor is being created from scratch in a worker. Track this and allow it to be explicitly // deallocated inside the worker (composite ops do this). if (tt::tt_metal::detail::InWorkerThread()) { - this->tensor_attributes->main_thread_tensor = false; + tensor_attributes->main_thread_tensor = false; } - this->tensor_attributes->num_shards_to_be_populated = 1; + tensor_attributes->num_shards_to_be_populated = 1; } else if constexpr (std::is_same_v) { - this->tensor_attributes->num_shards_to_be_populated = 1; + tensor_attributes->num_shards_to_be_populated = 1; } else if constexpr (std::is_same_v) { workers.reserve(storage.num_buffers()); for (int i = 0; i < storage.ordered_device_ids.size(); i++) { @@ -115,76 +148,78 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L auto buffer = storage.get_buffer_for_device_id(device_id); TT_ASSERT(buffer->device() != nullptr); TT_ASSERT(buffer->device()->id() == device_id); - tensor_impl::validate_on_device_dtype_and_layout(buffer->device(), shape.padded_shape(), dtype, layout); + tensor_impl::validate_on_device_dtype_and_layout(buffer->device(), + tensor_attributes->tensor_spec.padded_shape(), + tensor_attributes->tensor_spec.data_type(), + tensor_attributes->tensor_spec.layout()); workers.push_back(buffer->device()); } // Increment main thread ref count for all tensors on cluster - this->tensor_attributes->increment_main_thread_ref_count(this->workers.at(0)); + tensor_attributes->increment_main_thread_ref_count(this->workers.at(0)); // This tensor is being created from scratch in a worker. Track this and allow it to be explicitly // deallocated inside the worker (composite ops do this). 
if (tt::tt_metal::detail::InWorkerThread()) { - this->tensor_attributes->main_thread_tensor = false; + tensor_attributes->main_thread_tensor = false; } - this->tensor_attributes->num_shards_to_be_populated = storage.num_buffers(); + tensor_attributes->num_shards_to_be_populated = storage.num_buffers(); } else if constexpr (std::is_same_v) { - this->tensor_attributes->num_shards_to_be_populated = storage.num_buffers(); + tensor_attributes->num_shards_to_be_populated = storage.num_buffers(); } else { raise_unsupported_storage(); } }, - storage); - this->tensor_attributes->num_workers_completed = this->tensor_attributes->num_shards_to_be_populated; - this->tensor_attributes->metadata_populated = true; + tensor_attributes->storage); + tensor_attributes->num_workers_completed = this->tensor_attributes->num_shards_to_be_populated; } -Tensor::Tensor( - const std::vector& workers, - uint32_t num_buffers, - std::optional distributed_tensor_config) : - tensor_id(std::nullopt), +Tensor::Tensor(const std::vector& workers): tensor_attributes(std::make_shared()), - workers(workers), - deallocate_through_destructor(false) { - // When creating a device tensor, specify workers. - // When creating a host tensor, specify num_buffers. - // If neither are specified, a dummy tensor is being created. Do nothing. - if (workers.size()) { - if (not tt::tt_metal::detail::InWorkerThread()) { - this->tensor_attributes->increment_main_thread_ref_count(this->workers.at(0)); - } else { - // This tensor is being created from scratch in a worker. Track this and allow it to be explicitly - // deallocated inside the worker (composite ops do this). - this->tensor_attributes->main_thread_tensor = false; - } + workers(workers) { + if (workers.empty()) { + return; + } + + tensor_attributes->storage = [&](){ if (workers.size() == 1) { - this->tensor_attributes->storage = DeviceStorage(); - } else if (workers.size() > 1) { - this->tensor_attributes->storage = MultiDeviceStorage(); - std::transform( - workers.cbegin(), - workers.cend(), - std::back_inserter( - std::get(this->tensor_attributes->storage).ordered_device_ids), - [](const Device *worker) { return worker->id(); }); + return Storage(DeviceStorage()); } - this->tensor_attributes->num_shards_to_be_populated = workers.size(); - } else if (num_buffers) { + MultiDeviceStorage storage; + std::transform( + workers.cbegin(), + workers.cend(), + std::back_inserter(storage.ordered_device_ids), + [](const Device *worker) { return worker->id(); }); + return Storage(std::move(storage)); + }(); + tensor_attributes->num_shards_to_be_populated = workers.size(); + if (!tt::tt_metal::detail::InWorkerThread()) { + tensor_attributes->increment_main_thread_ref_count(this->workers.at(0)); + } else { + // This tensor is being created from scratch in a worker. Track this and allow it to be explicitly + // deallocated inside the worker (composite ops do this). 
+ tensor_attributes->main_thread_tensor = false; + } +} + +Tensor::Tensor(uint32_t num_buffers, std::optional distributed_tensor_config): + tensor_attributes(std::make_shared()) { + if(num_buffers == 0) { + return; + } + + tensor_attributes->storage = [&]() { if (num_buffers == 1) { - this->tensor_attributes->storage = OwnedStorage(); - } else { - this->tensor_attributes->storage = MultiDeviceHostStorage(); - // Preallocate buffer and shape vector for MultiDeviceHostStorage - if (distributed_tensor_config.has_value()) { - std::get(this->tensor_attributes->storage).strategy = - distributed_tensor_config.value(); - } - std::get(this->tensor_attributes->storage).buffers = - std::vector(num_buffers, OwnedBuffer()); - std::get(this->tensor_attributes->storage).shapes = - std::vector(num_buffers, this->tensor_attributes->shape); + return Storage(OwnedStorage()); } - this->tensor_attributes->num_shards_to_be_populated = num_buffers; - } + MultiDeviceHostStorage storage; + if (distributed_tensor_config.has_value()) { + storage.strategy = distributed_tensor_config.value(); + } + storage.buffers = std::vector(num_buffers, OwnedBuffer()); + storage.shapes = std::vector(num_buffers, ttnn::Shape{}); + return Storage(std::move(storage)); + }(); + tensor_attributes->num_shards_to_be_populated = num_buffers; } Tensor &Tensor::operator=(const Tensor &other) { @@ -391,13 +426,9 @@ void Tensor::deepcopy(const Tensor& other) { // Wait until the tensor being copied is populated other.wait_for_tensor_data_populated(); // Populate tensor metadata - this->set_shape(other.get_shape()); this->set_storage(other.get_storage()); - this->set_dtype(other.get_dtype()); - this->set_layout(other.get_layout()); - this->set_tile(other.get_tile()); + this->set_tensor_spec(other.get_tensor_spec()); // Set metadata populated flag for getters - this->tensor_attributes->metadata_populated = true; this->tensor_attributes->num_workers_completed++; } @@ -405,11 +436,7 @@ void Tensor::populate_buffers_and_metadata(const Tensor& other) { ZoneScoped; // Similar to deepcopy, but to be applied on a tensor that has an empty storage // container initialized. Require tensor storage to be correctly initialized. 
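// Note: set_tensor_spec() below consolidates the old shape/dtype/layout/tile setters and
// marks the metadata as populated, so the explicit flag assignment is no longer needed.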
- this->set_shape(other.get_shape()); - this->set_dtype(other.get_dtype()); - this->set_layout(other.get_layout()); - this->set_tile(other.get_tile()); - + this->set_tensor_spec(other.get_tensor_spec()); // Populate storage container with buffers + shapes std::visit( [this](auto&& storage) { @@ -425,7 +452,6 @@ void Tensor::populate_buffers_and_metadata(const Tensor& other) { }, other.get_storage()); // Non blocking storage query, since this is done for tensors that get created inside the // worker thread - this->tensor_attributes->metadata_populated = true; this->tensor_attributes->num_workers_completed++; } @@ -482,26 +508,41 @@ std::vector Tensor::get_workers(bool blocking) const { } // Getters - Spin until tensor is populated before querying tensor metadata -const tt::tt_metal::LegacyShape& Tensor::get_legacy_shape() const { - this->wait_for_tensor_metadata_populated(); - return this->tensor_attributes->shape.value; +tt::tt_metal::LegacyShape Tensor::get_legacy_shape() const { + wait_for_tensor_metadata_populated(); + return legacy_shape(); } -const ttnn::Shape& Tensor::get_shape() const { - this->wait_for_tensor_metadata_populated(); - return this->tensor_attributes->shape; +ttnn::Shape Tensor::get_shape() const { + wait_for_tensor_metadata_populated(); + return shape(); +} +DataType Tensor::get_dtype() const { + wait_for_tensor_metadata_populated(); + return dtype(); } -const DataType& Tensor::get_dtype() const { - this->wait_for_tensor_metadata_populated(); - return this->tensor_attributes->dtype; +Layout Tensor::get_layout() const { + wait_for_tensor_metadata_populated(); + return layout(); +} + +const TensorSpec& Tensor::get_tensor_spec() const { + wait_for_tensor_metadata_populated(); + return tensor_spec(); } -const Layout& Tensor::get_layout() const { - this->wait_for_tensor_metadata_populated(); - return this->tensor_attributes->layout; + +const ttnn::SimpleShape& Tensor::get_logical_shape() const { + wait_for_tensor_metadata_populated(); + return logical_shape(); +} + +const ttnn::SimpleShape& Tensor::get_padded_shape() const { + wait_for_tensor_metadata_populated(); + return padded_shape(); } -const Tile& Tensor::get_tile() const { - this->wait_for_tensor_metadata_populated(); - return this->tensor_attributes->tile; + +tt::tt_metal::Padding Tensor::get_padding() const { + return get_legacy_shape().padding(); } const Storage& Tensor::get_storage() const { @@ -647,25 +688,17 @@ bool Tensor::is_scalar() const { return logical_shape.rank() == 0 || logical_shape.volume() == 1; } -ttnn::SimpleShape Tensor::get_logical_shape() const { - return this->get_shape().logical_shape(); -} - -ttnn::SimpleShape Tensor::get_padded_shape() const { - return this->get_shape().padded_shape(); -} - -tt::tt_metal::Padding Tensor::get_padding() const { - return this->get_legacy_shape().padding(); -} - -Tensor create_device_tensor( - const ttnn::SimpleShape& shape, const TensorLayout& tensor_layout, Device* device) { +Tensor create_device_tensor(const TensorSpec& tensor_spec, Device* device) { ZoneScoped; - GraphTracker::instance().track_function_start("tt::tt_metal::create_device_tensor", shape, tensor_layout.get_data_type(), tensor_layout.get_layout(), device, tensor_layout.get_memory_config()); - - auto device_buffer = tensor_impl::allocate_buffer_on_device(device, shape, tensor_layout); - auto output = Tensor(DeviceStorage{device_buffer}, ttnn::Shape(shape.view(), tensor_layout.compute_padded_shape(shape).view()), tensor_layout.get_data_type(), tensor_layout.get_layout(), 
tensor_layout.get_page_config().get_tile()); + GraphTracker::instance().track_function_start("tt::tt_metal::create_device_tensor", + tensor_spec.logical_shape(), + tensor_spec.tensor_layout().get_data_type(), + tensor_spec.tensor_layout().get_layout(), + device, + tensor_spec.tensor_layout().get_memory_config()); + + auto device_buffer = tensor_impl::allocate_buffer_on_device(device, tensor_spec); + auto output = Tensor(DeviceStorage{device_buffer}, tensor_spec); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); @@ -674,12 +707,12 @@ Tensor create_device_tensor( } Tensor create_device_tensor(const ttnn::SimpleShape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional& tile) { - return create_device_tensor(shape, TensorLayout(data_type, PageConfig(layout, tile), memory_config), device); + return create_device_tensor(TensorSpec(shape, TensorLayout(data_type, PageConfig(layout, tile), memory_config)), device); } Tensor create_device_tensor( const ttnn::Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional& tile) { - return create_device_tensor(shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout, tile), memory_config, shape), device); + return create_device_tensor(TensorSpec(shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout, tile), memory_config, shape)), device); } namespace detail { @@ -831,6 +864,7 @@ Tensor allocate_tensor_on_device( const std::optional& tile) { // Top level wrapper to asynchronously create a device tensor (multi-device) Tensor device_tensor = Tensor(mesh_device->get_devices()); + TensorSpec tensor_spec(shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout, tile), memory_config, shape)); // Save the ref count to later re-set it: // 1. device_tensor is copied in the lambda by the main thread, which increments the ref count. 
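For reference, a minimal caller-side sketch of the new create_device_tensor(TensorSpec, Device*) overload shown above; the helper name, shape, dtype, and memory config here are placeholder values, not part of this patch:

    // Minimal sketch: allocate a 32x32 bfloat16 tile-layout tensor from a single TensorSpec.
    Tensor make_example_tensor(Device* device) {
        TensorSpec spec(
            ttnn::SimpleShape{1, 1, 32, 32},
            TensorLayout(DataType::BFLOAT16, PageConfig(Layout::TILE), MemoryConfig{}));
        // Shape, dtype, layout, and tile no longer travel as separate arguments.
        return create_device_tensor(spec, device);
    }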
@@ -841,20 +875,15 @@ Tensor allocate_tensor_on_device( for (int worker_index = 0; worker_index < num_workers; ++worker_index) { auto& worker = workers[worker_index]; - worker->push_work( - [shape, data_type, layout, worker, memory_config, tile, device_tensor, worker_index]() mutable { - auto local_tensor = create_device_tensor(shape, data_type, layout, worker, memory_config, tile); - insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index); - - uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++; - if (not num_workers_completed) { - device_tensor.set_shape(local_tensor.get_shape()); - device_tensor.set_dtype(local_tensor.get_dtype()); - device_tensor.set_layout(local_tensor.get_layout()); - device_tensor.set_tile(local_tensor.get_tile()); - device_tensor.tensor_attributes->metadata_populated = true; - } - }); + worker->push_work([worker, device_tensor, tensor_spec, worker_index]() mutable { + auto local_tensor = create_device_tensor(tensor_spec, worker); + insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index); + + uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++; + if (not num_workers_completed) { + device_tensor.set_tensor_spec(tensor_spec); + } + }); } device_tensor.tensor_attributes->update_main_thread_ref_count(workers.at(0), device_tensor_ref_count); return device_tensor; @@ -881,8 +910,7 @@ void write_tensor(Tensor host_tensor, Tensor device_tensor, uint8_t cq_id) { "write_tensor only supports host_tensor to device_tensor data transfer"); TT_FATAL(async_safe_tensor.get_shape() == device_tensor.get_shape(), "Error"); TT_FATAL(async_safe_tensor.get_dtype() == device_tensor.get_dtype(), "Error"); - TT_FATAL(async_safe_tensor.get_layout() == device_tensor.get_layout(), "Error"); - TT_FATAL(async_safe_tensor.get_tile() == device_tensor.get_tile(), "Error"); + TT_FATAL(async_safe_tensor.get_tensor_spec().page_config() == device_tensor.get_tensor_spec().page_config(), "Error"); std::visit( [worker_index, worker, cq_id, &async_safe_tensor](auto&& s) { void* host_data = nullptr; diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index 788a3070805..1d627f4d449 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -17,6 +17,7 @@ #include "common/tt_backend_api_types.hpp" #include "ttnn/common/constants.hpp" #include "ttnn/tensor/types.hpp" +#include "ttnn/tensor/tensor_spec.hpp" #include "ttnn/tensor/layout/tensor_layout.hpp" #include "tt_metal/impl/buffers/buffer.hpp" #include "tt_metal/impl/tile/tile.hpp" @@ -32,13 +33,11 @@ namespace tt_metal { namespace distributed { class MeshDevice; } + struct Tensor { struct TensorAttributes : public std::enable_shared_from_this { Storage storage; - ttnn::Shape shape; - DataType dtype; - Layout layout; - Tile tile; + TensorSpec tensor_spec; uint32_t num_shards_to_be_populated = 0; uint32_t main_thread_ref_count = 0; std::atomic num_sibling_workers_sharing_tensor = 0; @@ -48,10 +47,8 @@ struct Tensor { bool deallocated = false; // Set to true if device side storage was deallocated bool dynamic_storage = false; // Storage type can change, depending on op behaviour bool track_ref_count = false; - TensorAttributes(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout, Tile tile = std::array{32, 32}) : - storage(storage), shape(shape), dtype(dtype), layout(layout), tile(tile) {} - TensorAttributes() : - shape(std::array{0xff, 0xff, 
0xff, 0xff}), dtype(DataType::INVALID), layout(Layout::INVALID), tile(std::array{32, 32}) {} + TensorAttributes(Storage storage, TensorSpec tensor_spec); + TensorAttributes(); ~TensorAttributes() = default; // Use these functions to manage the main_thread_ref_count for a tensor attr instance. @@ -86,21 +83,16 @@ struct Tensor { // ====================================================================================== // Hi Level APIs // ====================================================================================== - explicit Tensor() : - tensor_id(std::nullopt), - tensor_attributes(nullptr), - workers(std::vector{}), - deallocate_through_destructor(false) {} + explicit Tensor() = default; - Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout, const std::optional& tile = std::nullopt); - Tensor(const Storage storage, const ttnn::SimpleShape& shape, DataType dtype, Layout layout, const std::optional& tile = std::nullopt); + Tensor(Storage storage, const ttnn::Shape& shape, DataType dtype, Layout layout, const std::optional& tile = std::nullopt); + Tensor(Storage storage, const ttnn::SimpleShape& shape, DataType dtype, Layout layout, const std::optional& tile = std::nullopt); + Tensor(Storage storage, TensorSpec tensor_spec); - // Constructor to initialize unpopulated tensor with workers and storage specified. Use this when creating tensor + // Constructors to initialize unpopulated tensor with workers and storage specified. Use this when creating tensor // handles in async mode. - Tensor( - const std::vector& workers, - uint32_t num_buffers = 0, - std::optional distributed_tensor_config = std::nullopt); + explicit Tensor(uint32_t num_buffers, std::optional distributed_tensor_config = std::nullopt); + explicit Tensor(const std::vector& workers); Tensor(const Tensor &other); @@ -183,35 +175,38 @@ struct Tensor { // Getters // ====================================================================================== const Storage &get_storage() const; + DataType get_dtype() const; + Layout get_layout() const; + const ttnn::SimpleShape& get_logical_shape() const; + const ttnn::SimpleShape& get_padded_shape() const; + const TensorSpec& get_tensor_spec() const; + // [[deprecated("Use get_shape() instead.")]] - const tt::tt_metal::LegacyShape &get_legacy_shape() const; - const ttnn::Shape &get_shape() const; - const DataType &get_dtype() const; - const Layout &get_layout() const; - const Tile &get_tile() const; - - ttnn::SimpleShape get_logical_shape() const; - ttnn::SimpleShape get_padded_shape() const; + tt::tt_metal::LegacyShape get_legacy_shape() const; + ttnn::Shape get_shape() const; tt::tt_metal::Padding get_padding() const; // ====================================================================================== // Non-Blocking Getters. 
Query attributes directly, without waiting for worker completion // ====================================================================================== inline const Storage &storage() const { return this->tensor_attributes->storage; }; - inline const tt::tt_metal::LegacyShape &legacy_shape() const { return this->tensor_attributes->shape.value; }; - inline const ttnn::Shape &shape() const { return this->tensor_attributes->shape; }; - inline const DataType &dtype() const { return this->tensor_attributes->dtype; }; - inline const Layout &layout() const { return this->tensor_attributes->layout; }; - inline const Tile &tile() const { return this->tensor_attributes->tile; }; + inline tt::tt_metal::LegacyShape legacy_shape() const { return this->tensor_attributes->tensor_spec.shape().value; }; + inline ttnn::Shape shape() const { return this->tensor_attributes->tensor_spec.shape(); }; + inline const ttnn::SimpleShape& logical_shape() const { return this->tensor_attributes->tensor_spec.logical_shape(); }; + inline const ttnn::SimpleShape& padded_shape() const { return this->tensor_attributes->tensor_spec.padded_shape(); }; + inline DataType dtype() const { return this->tensor_attributes->tensor_spec.tensor_layout().get_data_type(); }; + inline Layout layout() const { return this->tensor_attributes->tensor_spec.tensor_layout().get_layout(); }; + inline const TensorSpec& tensor_spec() const { return this->tensor_attributes->tensor_spec; } // ====================================================================================== // Setters // ====================================================================================== inline void set_storage(const Storage &storage) { this->tensor_attributes->storage = storage; } - inline void set_shape(const ttnn::Shape &shape) { this->tensor_attributes->shape = shape; } - inline void set_dtype(const DataType &dtype) { this->tensor_attributes->dtype = dtype; } - inline void set_layout(const Layout &layout) { this->tensor_attributes->layout = layout; } - inline void set_tile(const Tile &tile) { this->tensor_attributes->tile = tile; } + // We intend to remove this API once we migrate all ops to compute_output_specs, and provide TensorSpec at creation + inline void set_tensor_spec(const TensorSpec& tensor_spec) { + this->tensor_attributes->tensor_spec = tensor_spec; + this->tensor_attributes->metadata_populated = true; + } // ====================================================================================== // Extra Helper Functions // ====================================================================================== @@ -273,30 +268,17 @@ struct Tensor { } } - const MemoryConfig memory_config() const { - return std::visit( - [](const auto &storage) -> MemoryConfig { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return storage.memory_config(); - } else if constexpr (std::is_same_v) { - return storage.memory_config(); - } else { - TT_THROW("MemoryConfig can only be obtained for a tensor with DeviceStorage"); - } - }, - this->get_storage()); - } - const std::optional shard_spec() const { return this->memory_config().shard_spec; } + const MemoryConfig& memory_config() const { return get_tensor_spec().tensor_layout().get_memory_config(); } + const std::optional& shard_spec() const { return this->memory_config().shard_spec; } const bool is_sharded() const; // Size in bytes of a single element held in tensor uint32_t element_size() const; - static constexpr auto attribute_names = std::forward_as_tuple("storage", "shape", "dtype", "layout", 
"tile"); + static constexpr auto attribute_names = std::forward_as_tuple("storage", "tensor_spec"); const auto attribute_values() const { - return std::forward_as_tuple(this->tensor_attributes->storage, this->tensor_attributes->shape, this->tensor_attributes->dtype, this->tensor_attributes->layout, this->tensor_attributes->tile); + return std::forward_as_tuple(this->tensor_attributes->storage, this->tensor_attributes->tensor_spec); } std::vector host_page_ordering(); @@ -316,12 +298,12 @@ struct Tensor { while (not this->tensor_attributes->metadata_populated) { } } + +private: + void init(Storage storage, TensorSpec tensor_spec); }; -Tensor create_device_tensor( - const ttnn::SimpleShape &logical_shape, - const TensorLayout& layout, - Device *device); +Tensor create_device_tensor(const TensorSpec& tensor_spec, Device *device); [[deprecated]] Tensor create_device_tensor( @@ -333,7 +315,7 @@ Tensor create_device_tensor( const std::optional& tile = std::nullopt); // TODO: Remove once ALL ops switch over to return ttnn::SimpleShape in compute_output_shapes -[[deprecated("Use create_device_tensor(const ttnn::SimpleShape&, const TensorLayout&, Device*) instead")]] +[[deprecated("Use create_device_tensor(const TensorSpec&, Device*) instead")]] Tensor create_device_tensor( const ttnn::Shape &shape, DataType dtype, @@ -389,6 +371,6 @@ bool validate_worker_modes(const std::vector &workers); namespace ttnn { using Tensor = tt::tt_metal::Tensor; -using TensorSpec = std::pair; +using TensorSpec = tt::tt_metal::TensorSpec; } // namespace ttnn diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp index 1c587510215..df198691c5b 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp @@ -123,11 +123,11 @@ void validate_sharded_buffer_allocation( } } -DeviceBuffer allocate_buffer_on_device(Device* device, const ttnn::SimpleShape& shape, const TensorLayout& layout) { - auto buffer_size_bytes = layout.compute_packed_buffer_size_bytes(shape); - auto page_size_bytes = layout.compute_page_size_bytes(shape); - auto shard_spec_buffer = layout.compute_shard_spec_buffer(shape); - auto memory_config = layout.get_memory_config(); +DeviceBuffer allocate_buffer_on_device(Device* device, const TensorSpec& tensor_spec) { + auto buffer_size_bytes = tensor_spec.compute_packed_buffer_size_bytes(); + auto page_size_bytes = tensor_spec.compute_page_size_bytes(); + auto shard_spec_buffer = tensor_spec.compute_shard_spec_buffer(); + auto memory_config = tensor_spec.tensor_layout().get_memory_config(); return Buffer::create(device, buffer_size_bytes, page_size_bytes, memory_config.buffer_type, memory_config.memory_layout, shard_spec_buffer); } @@ -169,7 +169,7 @@ void validate_on_device_dtype_and_layout(Device* device, const ttnn::SimpleShape Tensor pad_bfloat8_b( const Tensor& tensor, const tt::tt_metal::LegacyShape& output_tensor_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value) { - const auto& tile = tensor.get_tile(); + auto tile = tensor.get_tensor_spec().tile(); // TODO(arakhmati): do not convert to FLOAT32 // Convert to FLOAT32 tensor and pad @@ -178,7 +178,7 @@ Tensor pad_bfloat8_b( unpack_bfp8_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); auto float_tensor = - Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tensor.get_tile()) + 
Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tile) .pad(output_tensor_shape, input_tensor_start, pad_value); // Convert back to BFLOAT8_B @@ -191,11 +191,11 @@ Tensor pad_bfloat8_b( float_tensor.get_legacy_shape(), DataType::BFLOAT8_B, tensor.get_layout(), - tensor.get_tile()); + tile); } Tensor unpad_bfloat8_b(const Tensor& tensor, const ttnn::SimpleShape& output_tensor_start, const ttnn::SimpleShape& output_tensor_end) { - const auto& tile = tensor.get_tile(); + auto tile = tensor.get_tensor_spec().tile(); // TODO(arakhmati): do not convert to FLOAT32 // Convert to FLOAT32 tensor and unpad @@ -204,7 +204,7 @@ Tensor unpad_bfloat8_b(const Tensor& tensor, const ttnn::SimpleShape& output_ten unpack_bfp8_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); auto float_tensor = - Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tensor.get_tile()) + Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tile) .unpad(output_tensor_start, output_tensor_end); // Convert back to BFLOAT8_B @@ -217,12 +217,12 @@ Tensor unpad_bfloat8_b(const Tensor& tensor, const ttnn::SimpleShape& output_ten float_tensor.get_legacy_shape(), DataType::BFLOAT8_B, tensor.get_layout(), - tensor.get_tile()); + tile); } Tensor pad_bfloat4_b( const Tensor& tensor, const tt::tt_metal::LegacyShape& output_tensor_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value) { - const auto& tile = tensor.get_tile(); + auto tile = tensor.get_tensor_spec().tile(); // TODO(arakhmati): do not convert to FLOAT32 // Convert to FLOAT32 tensor and pad @@ -231,7 +231,7 @@ Tensor pad_bfloat4_b( unpack_bfp4_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); auto float_tensor = - Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tensor.get_tile()) + Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tile) .pad(output_tensor_shape, input_tensor_start, pad_value); // Convert back to BFLOAT4_B @@ -244,11 +244,11 @@ Tensor pad_bfloat4_b( float_tensor.get_legacy_shape(), DataType::BFLOAT4_B, tensor.get_layout(), - tensor.get_tile()); + tile); } Tensor unpad_bfloat4_b(const Tensor& tensor, const ttnn::SimpleShape& output_tensor_start, const ttnn::SimpleShape& output_tensor_end) { - const auto& tile = tensor.get_tile(); + auto tile = tensor.get_tensor_spec().tile(); // TODO(arakhmati): do not convert to FLOAT32 // Convert to FLOAT32 tensor and unpad @@ -257,7 +257,7 @@ Tensor unpad_bfloat4_b(const Tensor& tensor, const ttnn::SimpleShape& output_ten unpack_bfp4_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); auto float_tensor = - Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tensor.get_tile()) + Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tile) .unpad(output_tensor_start, output_tensor_end); // Convert back to BFLOAT4_B @@ -270,7 +270,7 @@ Tensor unpad_bfloat4_b(const Tensor& 
tensor, const ttnn::SimpleShape& output_ten float_tensor.get_legacy_shape(), DataType::BFLOAT4_B, tensor.get_layout(), - tensor.get_tile()); + tile); } // ====================================================================================== @@ -446,7 +446,7 @@ std::string to_string(const BufferType& buffer, const tt::tt_metal::LegacyShape& template std::string to_string(const Tensor& tensor, std::optional original_dtype) { - const auto& tile = tensor.get_tile(); + const auto tile = tensor.get_tensor_spec().tile(); const auto shape = tensor.get_legacy_shape(); const auto dtype = original_dtype.value_or(tensor.get_dtype()); const auto layout = tensor.get_layout(); @@ -479,7 +479,7 @@ std::string to_string(const Tensor& tensor, std::optional original_dty tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), - tensor.get_tile()); + tile); return to_string(float_tensor, tensor.get_dtype()); } @@ -494,7 +494,7 @@ std::string to_string(const Tensor& tensor, std::optional original_dty tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), - tensor.get_tile()); + tile); return to_string(float_tensor, tensor.get_dtype()); } const auto buffer = owned_buffer::get_as(storage.buffer); @@ -562,7 +562,7 @@ Tensor to_host_helper(const Tensor& tensor, bool blocking = true, uint8_t cq_id read_data_from_device_buffer(device_buffer, data_vec); } auto output_buffer = owned_buffer::create(std::move(data_vec)); - return Tensor(OwnedStorage{output_buffer}, tensor.get_legacy_shape(), tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + return Tensor(OwnedStorage{output_buffer}, tensor.get_tensor_spec()); } template @@ -571,15 +571,12 @@ Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { return to_host_helper(tensor, blocking, cq_id); } else if (tensor.storage_type() == StorageType::MULTI_DEVICE) { auto devices = get_devices(tensor); - Tensor host_tensor({}, devices.size()); + Tensor host_tensor(devices.size()); + host_tensor.set_tensor_spec(tensor.get_tensor_spec()); for (int device_index = 0; device_index < devices.size(); ++device_index) { const auto& device = devices[device_index]; auto shard = get_shard_for_device(tensor, device); shard = to_host_helper(shard, blocking, cq_id); - host_tensor.set_shape(tensor.get_shape()); - host_tensor.set_dtype(tensor.get_dtype()); - host_tensor.set_layout(tensor.get_layout()); - host_tensor.set_tile(tensor.get_tile()); insert_buffer_and_shape_for_device(device, shard, host_tensor, device_index); } return host_tensor; @@ -622,7 +619,7 @@ Tensor to_host_sharded(const Tensor& tensor) { } ::detail::ReadFromBuffer(*device_buffer, data_vec, true); auto output_buffer = owned_buffer::create(std::move(data_vec)); - return Tensor(OwnedStorage{output_buffer}, tensor.get_legacy_shape(), tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + return Tensor(OwnedStorage{output_buffer}, tensor.get_tensor_spec()); } template Tensor to_host_sharded(const Tensor& tensor); @@ -682,14 +679,13 @@ template typename BufferType> DeviceBuffer initialize_data_on_device( BufferType& data_to_write, Device* device, - const ttnn::SimpleShape& shape, - const TensorLayout& tensor_layout, + const TensorSpec& tensor_spec, std::optional> queue = std::nullopt) { ZoneScoped; TT_ASSERT(device != nullptr); - auto device_buffer = allocate_buffer_on_device(device, shape, tensor_layout); + auto device_buffer = allocate_buffer_on_device(device, tensor_spec); const char* TT_METAL_SLOW_DISPATCH_MODE = std::getenv("TT_METAL_SLOW_DISPATCH_MODE"); if 
(TT_METAL_SLOW_DISPATCH_MODE == nullptr) { @@ -705,15 +701,14 @@ template DeviceBuffer to_device_buffer( const Storage& storage, Device* device, - const ttnn::SimpleShape& shape, - const TensorLayout& tensor_layout, + const TensorSpec& tensor_spec, std::optional> queue) { return std::visit( - [&device, &shape, &tensor_layout, &queue](auto&& storage) -> DeviceBuffer { + [&device, &tensor_spec, &queue](auto&& storage) -> DeviceBuffer { using StorageType = std::decay_t; if constexpr (std::is_same_v or std::is_same_v) { auto data_to_write = host_buffer::get_as(storage.buffer); - return initialize_data_on_device(data_to_write, device, shape, tensor_layout, queue); + return initialize_data_on_device(data_to_write, device, tensor_spec, queue); } else if constexpr (std::is_same_v) { TT_THROW("Device storage doesn't support to_device_buffer"); } else if constexpr (std::is_same_v) { @@ -742,16 +737,9 @@ Tensor to_device(const Tensor& tensor, Device* target_device, const MemoryConfig TT_FATAL(target_device != nullptr, "Need target device in order to move tensor to device!"); TT_FATAL(tensor.is_allocated(), "Need data to exist in order to move it to device"); - auto shape = tensor.get_shape(); - auto logical_shape = tensor.get_logical_shape(); - auto data_type = tensor.get_dtype(); - auto layout = tensor.get_layout(); - auto tile = tensor.get_tile(); - TensorLayout tensor_layout = TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout, tile), memory_config, shape); - - auto device_buffer = tensor_impl::to_device_buffer(tensor.get_storage(), target_device, logical_shape, tensor_layout, queue); - - return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout, tile); + TensorSpec tensor_spec(tensor.get_logical_shape(), tensor.get_tensor_spec().tensor_layout().with_memory_config(memory_config)); + auto device_buffer = tensor_impl::to_device_buffer(tensor.get_storage(), target_device, tensor_spec, queue); + return Tensor(DeviceStorage{device_buffer}, tensor_spec); } template Tensor to_device( @@ -815,7 +803,7 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) { auto shape = tensor.get_legacy_shape(); auto source_layout = tensor.get_layout(); - auto tile = tensor.tile(); + auto tile = tensor.tensor_spec().tile(); auto convert = [tile, &shape, source_layout, target_layout](const auto& input_data) -> std::vector { switch (source_layout) { case Layout::ROW_MAJOR: @@ -870,13 +858,10 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) { return std::visit( [&tensor, &target_layout](auto&& storage) -> Tensor { using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return Tensor(storage, tensor.get_legacy_shape(), tensor.get_dtype(), target_layout, tensor.get_tile()); - } else if constexpr (std::is_same_v) { - return Tensor(storage, tensor.get_legacy_shape(), tensor.get_dtype(), target_layout, tensor.get_tile()); - } else { + if constexpr (!std::is_same_v && !std::is_same_v) { raise_unsupported_storage(); } + return Tensor(storage, tensor.get_legacy_shape(), tensor.get_dtype(), target_layout, tensor.get_tensor_spec().tile()); }, output_storage); } @@ -931,14 +916,14 @@ Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout) { if (tensor.get_layout() == target_layout) { return tensor; } + auto tile = tensor.get_tensor_spec().tile(); return std::visit( - [&tensor, &target_layout](auto&& storage) -> Tensor { + [&tensor, &target_layout, &tile](auto&& storage) -> Tensor { using StorageType = std::decay_t; if constexpr (std::is_same_v) { std::vector 
output_buffers; for (int i = 0; i < storage.num_buffers(); i++) { // Convert to FLOAT32 tensor and change layout - const auto& tile = tensor.get_tile(); auto input_packed_data = owned_buffer::get_as(storage.get_buffer(i)).get(); auto input_float_data = unpack_bfloat_tiles_into_float_vec( T{}, input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); @@ -948,7 +933,7 @@ Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout) { tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), - tensor.get_tile()) + tile) .to(target_layout); // Convert back to BFLOAT8_B @@ -963,11 +948,10 @@ Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout) { tensor.get_legacy_shape(), bfloat_enum::value, target_layout, - tensor.get_tile()); + tile); } else { // Convert to FLOAT32 tensor and change layout - const auto& tile = tensor.get_tile(); auto input_packed_data = owned_buffer::get_as(tensor).get(); auto input_float_data = unpack_bfloat_tiles_into_float_vec( T{}, input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); @@ -977,7 +961,7 @@ Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout) { tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), - tensor.get_tile()) + tile) .to(target_layout); // Convert back to BFLOAT @@ -990,7 +974,7 @@ Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout) { tensor.get_legacy_shape(), bfloat_enum::value, target_layout, - tensor.get_tile()); + tile); } }, tensor.get_storage()); @@ -1097,7 +1081,7 @@ Tensor pad(const Tensor& tensor, const tt::tt_metal::LegacyShape& output_shape, } }, tensor.get_storage()); - return Tensor(OwnedStorage{output_buffer}, output_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + return Tensor(OwnedStorage{output_buffer}, output_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tensor_spec().tile()); } template Tensor pad( @@ -1185,7 +1169,7 @@ Tensor unpad(const Tensor& tensor, const ttnn::SimpleShape& output_tensor_start, } }, tensor.get_storage()); - return Tensor(OwnedStorage{output_buffer}, output_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + return Tensor(OwnedStorage{output_buffer}, output_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tensor_spec().tile()); } template Tensor unpad(const Tensor& tensor, const ttnn::SimpleShape& output_tensor_start, const ttnn::SimpleShape& output_tensor_end); @@ -1219,7 +1203,7 @@ Tensor extract_shard(const Tensor& tensor, const uint32_t& core_id) { ::detail::ReadShard(*buffer, device_data, core_id); auto output_buffer = owned_buffer::create(std::move(device_data)); - return Tensor(OwnedStorage{output_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + return Tensor(OwnedStorage{output_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tensor_spec().tile()); } template Tensor extract_shard(const Tensor& tensor, const uint32_t& core_id); diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp index d7010255c35..ea4a2dfd4aa 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp @@ -141,7 +141,7 @@ void validate_sharded_buffer_allocation( // Data reader, writer, and initializers // ====================================================================================== -DeviceBuffer allocate_buffer_on_device(Device* device, const ttnn::SimpleShape& shape, const TensorLayout& layout); +DeviceBuffer 
allocate_buffer_on_device(Device* device, const TensorSpec& tensor_spec); template inline void read_data_from_device_buffer( diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 44ffb91cb52..bd06cf3370c 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -81,11 +81,8 @@ Tensor tensor_to(const Tensor& input_tensor, const std::vector& workers insert_buffer_and_shape_for_device(worker, shard, device_tensor, worker_index); uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++; if (not num_workers_completed) { - device_tensor.set_shape(input_tensor.get_shape()); - device_tensor.set_dtype(input_tensor.get_dtype()); - device_tensor.set_layout(input_tensor.get_layout()); - device_tensor.set_tile(input_tensor.get_tile()); - device_tensor.tensor_attributes->metadata_populated = true; + device_tensor.set_tensor_spec(TensorSpec(input_tensor.get_logical_shape(), + input_tensor.get_tensor_spec().tensor_layout().with_memory_config(mem_config))); } }); } @@ -110,11 +107,11 @@ Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id) { } TT_FATAL( validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); - Tensor host_tensor({}, workers.size()); + Tensor host_tensor(workers.size()); uint32_t original_tensor_ref_count = input_tensor.tensor_attributes->record_main_thread_ref_count(); for (int worker_index = 0; worker_index < workers.size(); worker_index++) { auto target_device = workers[worker_index]; - target_device->push_work([host_tensor, blocking, target_device, input_tensor, workers, worker_index, cq_id]() mutable { + target_device->push_work([host_tensor, blocking, target_device, input_tensor, worker_index, cq_id]() mutable { TT_ASSERT( input_tensor.storage_type() == StorageType::DEVICE or input_tensor.storage_type() == StorageType::MULTI_DEVICE, "Can only use worker queue for cpu call if tensor is on device."); @@ -123,11 +120,7 @@ Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id) { insert_buffer_and_shape_for_device(target_device, shard, host_tensor, worker_index); uint32_t num_workers_completed = (host_tensor.tensor_attributes->num_workers_completed)++; if (not num_workers_completed) { - host_tensor.set_shape(input_tensor.get_shape()); - host_tensor.set_dtype(input_tensor.get_dtype()); - host_tensor.set_layout(input_tensor.get_layout()); - host_tensor.set_tile(input_tensor.get_tile()); - host_tensor.tensor_attributes->metadata_populated = true; + host_tensor.set_tensor_spec(input_tensor.get_tensor_spec()); } }); } @@ -158,7 +151,7 @@ Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, Device* worke if (worker and worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS) { // Tensor can be using borrowed storage. If so, when running in async mode, copy this tensor to owned storage. 
Tensor async_safe_tensor = copy_borrowed_tensor_in_async_mode(worker, input_tensor); - Tensor tensor_modified_layout = Tensor({}, 1); + Tensor tensor_modified_layout = Tensor(1); worker->push_work([async_safe_tensor, tensor_modified_layout, target_layout]() mutable { TT_ASSERT( async_safe_tensor.storage_type() == StorageType::OWNED or @@ -196,7 +189,7 @@ Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, distributed:: auto& host_storage = std::get(input_tensor.get_storage()); distributed_config = host_storage.strategy; } - Tensor tensor_modified_layout = Tensor({}, workers.size(), distributed_config); + Tensor tensor_modified_layout = Tensor(workers.size(), distributed_config); for (int worker_index = 0; worker_index < workers.size(); ++worker_index) { auto& worker = workers[worker_index]; worker->push_work([input_tensor, tensor_modified_layout, target_layout, worker, worker_index]() mutable { @@ -212,12 +205,10 @@ Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, distributed:: insert_buffer_and_shape_for_device(worker, shard, tensor_modified_layout, worker_index); uint32_t num_workers_completed = (tensor_modified_layout.tensor_attributes->num_workers_completed)++; if (not num_workers_completed) { - tensor_modified_layout.set_shape(input_tensor.get_shape()); - tensor_modified_layout.set_dtype(input_tensor.get_dtype()); - tensor_modified_layout.set_layout(target_layout); - tensor_modified_layout.set_tile(input_tensor.get_tile()); - tensor_modified_layout.tensor_attributes->metadata_populated = true; - }; + auto orig_layout = input_tensor.get_tensor_spec().tensor_layout(); + auto upd_layout = TensorLayout(orig_layout.get_data_type(), PageConfig(target_layout), orig_layout.get_memory_config()); + tensor_modified_layout.set_tensor_spec(TensorSpec(input_tensor.get_logical_shape(), upd_layout)); + } }); } tensor_modified_layout = tt::tt_metal::set_tensor_id(tensor_modified_layout); @@ -337,7 +328,7 @@ Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) ZoneScoped; GraphTracker::instance().track_function_start("Tensor::reshape", input_tensor, new_shape); const auto& new_padded_shape = new_shape.padded_shape(); - const auto& tile = input_tensor.get_tile(); + const auto tile = input_tensor.get_tensor_spec().tile(); TT_ASSERT( input_tensor.volume() == new_padded_shape.volume(), "{} != {}", @@ -349,7 +340,7 @@ Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) "Expected a multiple of 32 for H, W (or -1 evaluating to such) in Tensor::reshape()!"); } auto output = std::visit( - [&input_tensor, &new_shape](auto&& storage) -> Tensor { + [&input_tensor, &new_shape, &tile](auto&& storage) -> Tensor { using T = std::decay_t; const auto& tensor = input_tensor; if constexpr (std::is_same_v) { @@ -357,7 +348,7 @@ Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) for (int i = 0; i < updated_storage.shapes.size(); i++) { updated_storage.shapes[i] = new_shape; } - return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); } if constexpr (std::is_same_v) { MultiDeviceStorage updated_storage = std::get(tensor.get_storage()); @@ -367,7 +358,7 @@ Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) new_shapes.insert({device_id, new_shape}); } updated_storage.shapes = new_shapes; - return Tensor(updated_storage, new_shape, 
tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); } if constexpr (std::is_same_v) { if (input_tensor.get_layout() == Layout::ROW_MAJOR) { @@ -376,7 +367,7 @@ Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) DeviceBuffer device_buffer = device_storage.get_buffer(); device_buffer->set_page_size(new_shape[-1] * tensor.element_size()); device_storage.insert_buffer(device_buffer); - return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); } else { DeviceStorage device_storage = std::get(tensor.get_storage()); DeviceBuffer device_buffer = device_storage.get_buffer(); @@ -398,13 +389,13 @@ Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) device_buffer->set_shard_spec(shard_spec_buffer); device_storage.insert_buffer(device_buffer); - return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); } } else { - return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile); } } else { - return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile); } }, input_tensor.get_storage()); diff --git a/ttnn/cpp/ttnn/tensor/tensor_spec.hpp b/ttnn/cpp/ttnn/tensor/tensor_spec.hpp new file mode 100644 index 00000000000..bc4729ae04a --- /dev/null +++ b/ttnn/cpp/ttnn/tensor/tensor_spec.hpp @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/tensor/types.hpp" +#include "ttnn/tensor/layout/tensor_layout.hpp" + +namespace tt::tt_metal { + +class TensorSpec final { +public: + TensorSpec(ttnn::SimpleShape logical_shape, TensorLayout tensor_layout): + logical_shape_(std::move(logical_shape)), + tensor_layout_(std::move(tensor_layout)), + cached_padded_shape_(tensor_layout_.compute_padded_shape(logical_shape_)), + cached_physical_shape_(tensor_layout_.compute_physical_shape(logical_shape_)) {} + TensorSpec(TensorSpec&&) noexcept = default; + TensorSpec& operator=(TensorSpec&&) = default; + TensorSpec(const TensorSpec&) = default; + TensorSpec& operator=(const TensorSpec&) = default; + bool operator==(const TensorSpec&) const = default; + bool operator!=(const TensorSpec&) const = default; + + const ttnn::SimpleShape& logical_shape() const { + return logical_shape_; + } + const TensorLayout& tensor_layout() const { + return tensor_layout_; + } + DataType data_type() const { + return tensor_layout_.get_data_type(); + } + Layout layout() const { + return tensor_layout_.get_layout(); + } + PageConfig page_config() const { + return tensor_layout_.get_page_config(); + } + const ttnn::SimpleShape& padded_shape() const { + return cached_padded_shape_; + } + const Size& physical_shape() const { + return cached_physical_shape_; + } + ttnn::Shape shape() const { + return ttnn::Shape(logical_shape_.view(), cached_padded_shape_.view()); + } + + Tile tile() const { + return tensor_layout_.get_page_config().get_tile(); + } + + Strides 
compute_strides() const { + return tensor_layout_.compute_strides(logical_shape_); + } + std::optional compute_shard_spec_buffer() const { + return tensor_layout_.compute_shard_spec_buffer(logical_shape_); + } + size_t compute_packed_buffer_size_bytes() const { + return tensor_layout_.compute_packed_buffer_size_bytes(logical_shape_); + } + size_t compute_page_size_bytes() const { + return tensor_layout_.compute_page_size_bytes(logical_shape_); + } + + static constexpr auto attribute_names = std::forward_as_tuple("logical_shape", "tensor_layout"); + const auto attribute_values() const { + return std::forward_as_tuple(logical_shape_, tensor_layout_); + } + +private: + ttnn::SimpleShape logical_shape_; + TensorLayout tensor_layout_; + + ttnn::SimpleShape cached_padded_shape_; + Size cached_physical_shape_; +}; + +} diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp index b22263e013d..2afd19567e3 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp @@ -562,7 +562,7 @@ Tensor get_shard_for_device(const Tensor& tensor, Device* target_device, std::op // Stalling reads for tensor data-type and layout are needed here // since some worker might have raced ahead to these lookups, while // another worker is populating this metadata. - const Tile tile = tensor.get_tile(); + const Tile tile = tensor.get_tensor_spec().tile(); if constexpr (std::is_same_v) { shard = Tensor{ DeviceStorage{s.get_buffer_for_device(target_device)}, @@ -599,12 +599,12 @@ void insert_buffer_and_shape_for_device( s.insert_buffer_and_shape_for_device( buffer_index.value(), std::get(shard.tensor_attributes->storage).get_buffer(), - shard.tensor_attributes->shape); + shard.tensor_attributes->tensor_spec.shape()); } else if constexpr (std::is_same_v) { s.insert_buffer_and_shape_for_device( target_device, std::get(shard.tensor_attributes->storage).get_buffer(), - shard.tensor_attributes->shape); + shard.tensor_attributes->tensor_spec.shape()); } else if constexpr (std::is_same_v) { s.insert_buffer(std::get(shard.tensor_attributes->storage).get_buffer()); } else if constexpr (std::is_same_v) { @@ -634,12 +634,7 @@ Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor) [&owned_tensor, &tensor](auto&& buffer) { using BorrowedStorageType = std::vector>; auto owned_buf = owned_buffer::create(BorrowedStorageType(buffer.begin(), buffer.end())); - owned_tensor = Tensor( - OwnedStorage{owned_buf}, - tensor.get_shape(), - tensor.get_dtype(), - tensor.get_layout(), - tensor.get_tile()); + owned_tensor = Tensor(OwnedStorage{owned_buf}, tensor.get_tensor_spec()); }, borrowed_buffer); return owned_tensor;
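
Note for reviewers: the snippet below is not part of the patch; it is a minimal illustrative sketch of how the new pieces fit together, based only on the declarations introduced above (`TensorSpec(ttnn::SimpleShape, TensorLayout)`, the `TensorLayout(DataType, PageConfig, MemoryConfig)` constructor used in `tensor_ops.cpp`, and the new `create_device_tensor(const TensorSpec&, Device*)` overload). The shape, dtype, and memory config values are made up for illustration.

```cpp
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_spec.hpp"

using namespace tt::tt_metal;

// Hypothetical helper: build a tiled bfloat16 device tensor through the new
// TensorSpec entry point instead of the deprecated (shape, dtype, layout, tile) overloads.
Tensor make_tiled_device_tensor(Device* device) {
    ttnn::SimpleShape logical_shape({1, 1, 32, 32});

    // dtype + page config (layout/tile) + memory config now live together in TensorLayout.
    TensorLayout layout(DataType::BFLOAT16, PageConfig(Layout::TILE), MemoryConfig{});

    // TensorSpec = logical shape + TensorLayout; padded/physical shapes are
    // computed once in its constructor and cached.
    TensorSpec spec(logical_shape, std::move(layout));

    return create_device_tensor(spec, device);
}
```

Because `TensorSpec` caches the padded and physical shapes at construction, repeated `padded_shape()` / `physical_shape()` queries on a tensor are cheap, and ops that previously re-derived these from `(shape, dtype, layout, tile)` can read them directly from the spec.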