From 4e1fef96b1fc5543455911fef6fcf7664f04d9b3 Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Thu, 17 Oct 2024 16:45:24 -0700 Subject: [PATCH] #13707: Port reshape op to SimpleShape, related refactoring (#13838) * #13707: Port reshape op to SimpleShape, related refactoring * #13707: Rework * #13707: Fix failing tests * #13707: Review fixes: simplification & better error messages * #13707: Trying to fix stable diffusion * #13707: Rebase fix --- ..._functional_resnet50_large_new_conv_api.py | 30 +- .../ttnn_functional_resnet50_new_conv_api.py | 18 +- ...functional_resnet50_xlarge_new_conv_api.py | 26 +- ...ctional_resnet50_xlarge_new_conv_api_24.py | 30 +- ...unctional_resnet50_xxlarge_new_conv_api.py | 30 +- ...tional_unet_2d_condition_model_new_conv.py | 2 +- ...test_tilize_zero_padding_channels_last.cpp | 4 +- .../tensors/test_async_tensor_apis.cpp | 4 +- .../unit_tests/gtests/test_multi_device.cpp | 4 +- ttnn/cpp/pybind11/pytensor.cpp | 18 +- ttnn/cpp/pybind11/types.hpp | 3 +- ttnn/cpp/ttnn/distributed/api.cpp | 16 +- .../ttnn/operations/conv/conv2d/conv2d.cpp | 2 +- ttnn/cpp/ttnn/operations/core/core.cpp | 11 +- ttnn/cpp/ttnn/operations/core/core.hpp | 23 -- .../core/to_layout/to_layout_op.cpp | 5 +- .../operations/data_movement/fold/fold.cpp | 8 +- .../data_movement/permute/permute.cpp | 4 +- .../reshape_on_device/device/reshape_op.cpp | 20 +- .../reshape_on_device/device/reshape_op.hpp | 4 +- .../device/reshape_program_factory.cpp | 162 +-------- .../device/reshape_program_factory.hpp | 5 +- .../reshape_on_device/reshape.cpp | 51 +-- .../reshape_on_device/reshape.hpp | 16 +- .../reshape_on_device/reshape_pybind.cpp | 2 +- .../data_movement/reshape_view/reshape.cpp | 65 +--- .../data_movement/reshape_view/reshape.hpp | 11 +- .../operations/data_movement/split/split.cpp | 4 +- .../data_movement/squeeze/squeeze.cpp | 7 +- .../data_movement/unsqueeze/unsqueeze.cpp | 11 +- .../binary/device/binary_composite_op.cpp | 4 +- .../unary/device/unary_composite_op.cpp | 4 +- .../ttnn/operations/embedding/embedding.hpp | 4 +- .../embedding_backward/embedding_backward.cpp | 2 +- .../device/moreh_getitem_rm_factory.cpp | 2 +- ...ple_bilinear_program_factory_multicore.cpp | 2 +- .../reduction/generic/generic_reductions.cpp | 2 +- .../sdpa_decode_gqa/sdpa_decode_gqa.cpp | 2 +- .../split_query_key_value_and_split_heads.cpp | 16 +- ttnn/cpp/ttnn/tensor/serialization.cpp | 27 +- ttnn/cpp/ttnn/tensor/tensor.cpp | 8 +- ttnn/cpp/ttnn/tensor/tensor.hpp | 4 +- ttnn/cpp/ttnn/tensor/tensor_impl.cpp | 2 +- ttnn/cpp/ttnn/tensor/tensor_ops.cpp | 22 +- ttnn/cpp/ttnn/tensor/tensor_ops.hpp | 5 +- ttnn/cpp/ttnn/tensor/tensor_utils.cpp | 59 ++- ttnn/cpp/ttnn/tensor/tensor_utils.hpp | 4 +- ttnn/cpp/ttnn/tensor/types.cpp | 25 ++ ttnn/cpp/ttnn/tensor/types.hpp | 343 ++++++++---------- 49 files changed, 444 insertions(+), 689 deletions(-) diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py index dc3acd8a114..c1aef6d5f68 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py @@ -75,10 +75,8 @@ def ResnetLinear( """ matmul_config = hardcoded_matmul_config_linear[batch_size] - weight_shape = weight.shape.with_tile_padding() - weight = weight.reshape(1, 1, weight_shape[-2], weight_shape[-1]) - bias_shape = bias.shape.with_tile_padding() - bias = bias.reshape(1, 1, bias_shape[-2], bias_shape[-1]) + weight = weight.reshape(weight.shape.to_rank(4)) + bias = bias.reshape(bias.shape.to_rank(4)) def linear_(act): output = ttnn.linear( @@ -717,9 +715,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) @@ -806,9 +804,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) # for _, tensor in conv_op_cache["reader_patterns_cache"]["conv"].items(): @@ -948,9 +946,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) @@ -1037,9 +1035,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py index 3cf1f60fd57..8ba3e0b112c 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py @@ -67,10 +67,8 @@ def ResnetLinear( """ matmul_config = hardcoded_matmul_config_linear[batch_size] - weight_shape = weight.shape.with_tile_padding() - weight = weight.reshape(1, 1, weight_shape[-2], weight_shape[-1]) - bias_shape = bias.shape.with_tile_padding() - bias = bias.reshape(1, 1, bias_shape[-2], bias_shape[-1]) + weight = weight.reshape(weight.shape.to_rank(4)) + bias = bias.reshape(bias.shape.to_rank(4)) def linear_(act): output = ttnn.linear( @@ -1136,9 +1134,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x, ( self.batch_size, - x.shape.with_tile_padding()[1], - x.shape.with_tile_padding()[2] // self.batch_size, - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) @@ -1207,9 +1205,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py index 8cf40b334d2..dbb1bec2df9 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py @@ -75,10 +75,8 @@ def ResnetLinear( """ matmul_config = hardcoded_matmul_config_linear[batch_size] - weight_shape = weight.shape.with_tile_padding() - weight = weight.reshape(1, 1, weight_shape[-2], weight_shape[-1]) - bias_shape = bias.shape.with_tile_padding() - bias = bias.reshape(1, 1, bias_shape[-2], bias_shape[-1]) + weight = weight.reshape(weight.shape.to_rank(4)) + bias = bias.reshape(bias.shape.to_rank(4)) def linear_(act): output = ttnn.linear( @@ -713,9 +711,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] / self.batch_size, + x.shape[3], ), ) @@ -802,9 +800,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) @@ -930,9 +928,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) @@ -1020,7 +1018,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c ( self.batch_size, x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), + x.shape.with_tile_padding()[2] // self.batch_size, x.shape.with_tile_padding()[3], ), ) diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py index f6afdd893ca..18a190e3aa1 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py @@ -76,10 +76,8 @@ def ResnetLinear( """ matmul_config = hardcoded_matmul_config_linear[batch_size] - weight_shape = weight.shape.with_tile_padding() - weight = weight.reshape(1, 1, weight_shape[-2], weight_shape[-1]) - bias_shape = bias.shape.with_tile_padding() - bias = bias.reshape(1, 1, bias_shape[-2], bias_shape[-1]) + weight = weight.reshape(weight.shape.to_rank(4)) + bias = bias.reshape(bias.shape.to_rank(4)) def linear_(act): output = ttnn.linear( @@ -743,9 +741,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) @@ -832,9 +830,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) @@ -956,9 +954,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) @@ -1045,9 +1043,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py index d4a6e9e6e9a..7f741dc91e8 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py @@ -76,10 +76,8 @@ def ResnetLinear( """ matmul_config = hardcoded_matmul_config_linear[batch_size] - weight_shape = weight.shape.with_tile_padding() - weight = weight.reshape(1, 1, weight_shape[-2], weight_shape[-1]) - bias_shape = bias.shape.with_tile_padding() - bias = bias.reshape(1, 1, bias_shape[-2], bias_shape[-1]) + weight = weight.reshape(weight.shape.to_rank(4)) + bias = bias.reshape(bias.shape.to_rank(4)) def linear_(act): output = ttnn.linear( @@ -801,9 +799,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) @@ -897,9 +895,9 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) @@ -1030,9 +1028,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) @@ -1126,9 +1124,9 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c x, ( self.batch_size, - x.shape.with_tile_padding()[1], - (int)(x.shape.with_tile_padding()[2] / self.batch_size), - x.shape.with_tile_padding()[3], + x.shape[1], + x.shape[2] // self.batch_size, + x.shape[3], ), ) diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py index 553818f57c0..9cbdfff2f48 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py @@ -681,7 +681,7 @@ def __call__( self.batch_size, self.conv_out_input_height, self.conv_out_input_width, - 32, # Padded to tile dim + -1, ), ) sample = ttnn.permute(sample, (0, 3, 1, 2)) # permute from NHWC to NCHW diff --git a/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp b/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp index 06daced3e04..9f2158720fd 100644 --- a/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp +++ b/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp @@ -39,9 +39,9 @@ int main(int argc, char **argv) { //////////////////////////////////////////////////////////////////////////// // Application Setup //////////////////////////////////////////////////////////////////////////// - tt::tt_metal::LegacyShape shape = {1, 32, 61, 32}; + ttnn::SimpleShape shape{1, 32, 61, 32}; // Allocates a DRAM buffer on device populated with values specified by initialize - Tensor a = ttnn::numpy::arange(0, tt_metal::compute_volume(shape), 1).reshape(shape).to(device); + Tensor a = ttnn::numpy::arange(0, shape.volume(), 1).reshape(shape).to(device); Tensor b = ttnn::tilize_with_zero_padding(a); Tensor c = b.cpu(); //////////////////////////////////////////////////////////////////////////// diff --git a/tests/tt_eager/tensors/test_async_tensor_apis.cpp b/tests/tt_eager/tensors/test_async_tensor_apis.cpp index 372ddea5f43..4803efbf2f1 100644 --- a/tests/tt_eager/tensors/test_async_tensor_apis.cpp +++ b/tests/tt_eager/tensors/test_async_tensor_apis.cpp @@ -51,7 +51,7 @@ TEST_F(CommonFixture, TestTensorOwnershipSanity) { }, host_tensor.get_storage()); // Send tensor to device, read it back and copy it to empty tensor initialized by main thread - Tensor reshaped_tensor = host_tensor.reshape(1, 1, 32, 128); + Tensor reshaped_tensor = host_tensor.reshape(ttnn::SimpleShape{1, 1, 32, 128}); auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device); auto thread_local_tensor = device_tensor.cpu().to(Layout::ROW_MAJOR); readback_tensor.set_storage(thread_local_tensor.get_storage()); @@ -262,7 +262,7 @@ TEST_F(CommonFixture, TestTensorAsyncDataMovement) { }, host_tensor.get_storage()); - Tensor reshaped_tensor = host_tensor.reshape(1, 1, 32, tensor_stop / 32); + Tensor reshaped_tensor = host_tensor.reshape(ttnn::SimpleShape{1, 1, 32, tensor_stop / 32}); auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device); auto thread_local_tensor = device_tensor.cpu().to(Layout::ROW_MAJOR); log_info(LogTest, "Worker populating empty host readback_tensor"); diff --git a/tests/ttnn/unit_tests/gtests/test_multi_device.cpp b/tests/ttnn/unit_tests/gtests/test_multi_device.cpp index 4ecd79b3311..6d6863305f6 100644 --- a/tests/ttnn/unit_tests/gtests/test_multi_device.cpp +++ b/tests/ttnn/unit_tests/gtests/test_multi_device.cpp @@ -13,11 +13,11 @@ using namespace tt::tt_metal; Tensor create_host_multi_device_tensor(const Tensor& tensor, const ReplicateTensor& strategy) { std::vector owned_buffers; - std::vector shapes; + std::vector shapes; for (int i = 0; i < strategy.replication_factor; i++) { owned_buffers.push_back(std::get(tensor.get_storage()).buffer); - shapes.push_back(tensor.get_legacy_shape()); + shapes.push_back(tensor.get_shape()); } return Tensor{ diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index e064b78a762..9a8f28d42bf 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -419,11 +419,11 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona tt_shards.push_back(detail::convert_python_tensor_to_tt_tensor(shard, data_type, tile, false)); } std::vector host_owned_buffers; - std::vector host_owned_shapes; + std::vector host_owned_shapes; for (const auto &shard : tt_shards) { TT_ASSERT(std::holds_alternative(shard.get_storage()), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(shard.get_storage())); host_owned_buffers.push_back(std::get(shard.get_storage()).buffer); - host_owned_shapes.push_back(shard.get_legacy_shape()); + host_owned_shapes.push_back(shard.shape()); } auto distributed_tensor_config = get_distributed_tensor_config(strategy); auto storage = MultiDeviceHostStorage{distributed_tensor_config, std::move(host_owned_buffers), host_owned_shapes}; @@ -1593,7 +1593,7 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( "reshape", - [](Tensor &self, int N, int C, int H, int W) { return self.reshape(N, C, H, W); }, + [](Tensor &self, int N, int C, int H, int W) { return self.reshape(infer_dims_for_reshape(self, {N, C, H, W})); }, R"doc( Reshapes TT tensor @@ -1603,7 +1603,7 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( "reshape", - [](Tensor &self, const tt::tt_metal::LegacyShape &shape) -> Tensor { return self.reshape(shape); }, + [](Tensor &self, const ttnn::Shape &shape) -> Tensor { return self.reshape(shape); }, R"doc( Reshapes TT tensor @@ -1611,6 +1611,16 @@ void pytensor_module(py::module &m_tensor) { reshaped_tensor = tt_tensor.reshape((4, 3, 32)) )doc") + .def( + "reshape", + [](Tensor &self, const std::vector &shape) -> Tensor { return self.reshape(infer_dims_for_reshape(self, shape)); }, + R"doc( + Reshapes TT tensor + + .. code-block:: python + + reshaped_tensor = tt_tensor.reshape((4, -1, 32)) + )doc") .def_property( "tensor_id", [](const Tensor &self) { return self.tensor_id; }, diff --git a/ttnn/cpp/pybind11/types.hpp b/ttnn/cpp/pybind11/types.hpp index ba41f2fa636..1ae033ac867 100644 --- a/ttnn/cpp/pybind11/types.hpp +++ b/ttnn/cpp/pybind11/types.hpp @@ -56,7 +56,8 @@ void py_module(py::module& module) { return ss.str(); }) .def_property_readonly("rank", [](const Shape& self) -> std::size_t { return self.rank(); }) - .def("with_tile_padding", [](const Shape& self) { return self.with_tile_padding(); }); + .def("with_tile_padding", [](const Shape& self) { return self.with_tile_padding(); }) + .def("to_rank", [](const Shape& self, std::size_t rank) { return self.to_rank(rank); }); [&PyShape](std::index_sequence) { ( diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index 2be19453d2d..e48c3ceca14 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -58,24 +58,24 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards) // we want to use MultiDeviceHostStorage or MultiDeviceStorage StorageType storage_type = tensor_shards.at(0).storage_type(); if (storage_type == StorageType::OWNED) { - std::vector shapes; + std::vector shapes; std::vector host_owned_buffers; for (const auto &shard : tensor_shards) { host_owned_buffers.push_back(std::get(shard.get_storage()).buffer); - shapes.push_back(shard.get_legacy_shape()); + shapes.push_back(shard.get_shape()); } auto storage = MultiDeviceHostStorage{AllGatherTensor(), std::move(host_owned_buffers), shapes}; return Tensor(std::move(storage), tensor_shards.at(0).get_legacy_shape(), tensor_shards.at(0).get_dtype(), tensor_shards.at(0).get_layout()); } else { std::vector ordered_device_ids; - std::unordered_map shapes; + std::unordered_map shapes; std::unordered_map device_buffers; for (const auto &shard : tensor_shards) { Device* device = std::get(shard.get_storage()).buffer->device(); auto device_id = device->id(); ordered_device_ids.push_back(device_id); device_buffers.insert({device->id(), std::get(shard.get_storage()).buffer}); - shapes.insert({device->id(), shard.get_legacy_shape()}); + shapes.insert({device->id(), shard.get_shape()}); } auto storage = MultiDeviceStorage{AllGatherTensor(), ordered_device_ids, std::move(device_buffers), shapes}; return Tensor(std::move(storage), tensor_shards.at(0).get_legacy_shape(), tensor_shards.at(0).get_dtype(), tensor_shards.at(0).get_layout()); @@ -199,7 +199,7 @@ Tensor create_multi_device_tensor( if (storage_type == StorageType::MULTI_DEVICE) { std::vector ordered_device_ids; - std::unordered_map shapes; + std::unordered_map shapes; std::unordered_map device_buffers; for (const auto& tensor : tensors) { TT_ASSERT(std::holds_alternative(tensor.get_storage()), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(tensor.get_storage())); @@ -207,7 +207,7 @@ Tensor create_multi_device_tensor( auto device_id = device->id(); ordered_device_ids.push_back(device_id); device_buffers.insert({device_id, std::get(tensor.get_storage()).buffer}); - shapes.insert({device_id, tensor.get_legacy_shape()}); + shapes.insert({device_id, tensor.get_shape()}); } return Tensor{ MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, shapes}, @@ -216,11 +216,11 @@ Tensor create_multi_device_tensor( tensors.at(0).get_layout()}; } else if (storage_type == StorageType::MULTI_DEVICE_HOST) { std::vector owned_buffers; - std::vector shapes; + std::vector shapes; for (const auto& tensor : tensors) { TT_ASSERT(std::holds_alternative(tensor.get_storage()), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(tensor.get_storage())); owned_buffers.push_back(std::get(tensor.get_storage()).buffer); - shapes.push_back(tensor.get_legacy_shape()); + shapes.push_back(tensor.get_shape()); } return Tensor{ MultiDeviceHostStorage{strategy, owned_buffers, shapes}, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index 736ee518c24..84c9edcd219 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -508,7 +508,7 @@ std::tuple shard_or_reshard_tensor_if_requir // reshape to [1, 1, N*H*W, C] input_tensor = ttnn::reshape( input_tensor, - ttnn::Shape(std::array{ + ttnn::SimpleShape(std::array{ 1, 1, input_tensor.get_shape()[0] * input_tensor.get_shape()[1] * input_tensor.get_shape()[2], diff --git a/ttnn/cpp/ttnn/operations/core/core.cpp b/ttnn/cpp/ttnn/operations/core/core.cpp index fc616733f24..b0523f77339 100644 --- a/ttnn/cpp/ttnn/operations/core/core.cpp +++ b/ttnn/cpp/ttnn/operations/core/core.cpp @@ -28,7 +28,7 @@ ttnn::Tensor unsqueeze_to_4D(const ttnn::Tensor& tensor) { TT_THROW("Tensor rank is greater than 4"); } - const auto tensor_shape_4D = tensor_shape.to_rank<4>(); + const auto tensor_shape_4D = tensor_shape.to_rank(4); return ttnn::reshape(tensor, tensor_shape_4D); } @@ -47,13 +47,10 @@ ttnn::Tensor squeeze_from_4D(const ttnn::Tensor& tensor, const int rank) { } } - switch (rank) { - case 1: return ttnn::reshape(tensor, shape.to_rank<1>()); - case 2: return ttnn::reshape(tensor, shape.to_rank<2>()); - case 3: return ttnn::reshape(tensor, shape.to_rank<3>()); - case 4: return tensor; - default: TT_THROW("Invalid choice!"); + if (rank == 4) { + return tensor; } + return ttnn::reshape(tensor, shape.to_rank(rank)); } ttnn::Tensor to_device(const ttnn::Tensor& tensor, Device* device, const std::optional& memory_config) { diff --git a/ttnn/cpp/ttnn/operations/core/core.hpp b/ttnn/cpp/ttnn/operations/core/core.hpp index fad0fb0c9d7..28aece80116 100644 --- a/ttnn/cpp/ttnn/operations/core/core.hpp +++ b/ttnn/cpp/ttnn/operations/core/core.hpp @@ -20,29 +20,6 @@ namespace ttnn { namespace operations { namespace core { -template -ttnn::Tensor reshape(const ttnn::Tensor& tensor, const std::array& shape) { - std::int64_t new_volume = 1; - std::int64_t index_of_negative_1 = -1; - for (auto index = 0; index < Rank; ++index) { - if (shape[index] == -1) { - if (index_of_negative_1 != -1) { - TT_THROW("Shape cannot have more than 1 elements that is set to -1!"); - } - index_of_negative_1 = index; - } - new_volume *= shape[index]; - } - - std::array new_shape{}; - std::copy(shape.begin(), shape.end(), new_shape.begin()); - if (new_volume < 0) { - const auto volume = tensor.volume(); - new_shape[index_of_negative_1] = volume / (-new_volume); - } - return ttnn::reshape(tensor, ttnn::Shape(new_shape)); -} - ttnn::Tensor unsqueeze_to_4D(const ttnn::Tensor& tensor); ttnn::Tensor squeeze_from_4D(const ttnn::Tensor& tensor, const int rank); diff --git a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp index 402ed54e73f..6b23191b0c3 100644 --- a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp +++ b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp @@ -151,7 +151,7 @@ Tensor to_layout_impl( tensor = ttnn::untilize_with_unpadding(tensor, output_tensor_end, output_memory_config, use_multicore_untilize); - return ttnn::reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape})); + return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape}); } else if (layout == ttnn::TILE_LAYOUT) { std::vector padded_output_shape; @@ -190,7 +190,7 @@ Tensor to_layout_impl( } else if (layout == ttnn::ROW_MAJOR_LAYOUT) { tensor = device ? tensor.to(layout, device) : tensor.to(layout); tensor = tensor.unpad_from_tile(tensor.get_logical_shape()); - return ttnn::reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape})); + return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape}); } else if (layout == ttnn::TILE_LAYOUT) { std::vector padded_output_shape; std::vector padded_input_start; @@ -205,7 +205,6 @@ Tensor to_layout_impl( tensor = tensor.pad(padded_output_shape, ttnn::SimpleShape(std::move(padded_input_start)), 0); tensor = device ? tensor.to(layout, device) : tensor.to(layout); return ttnn::reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape, padded_output_shape})); - } else { TT_THROW("ttnn::to_layout: Unsupported output layout: {}!", layout); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp index 73d34468f44..22bc09a88b3 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp @@ -67,7 +67,7 @@ std::vector fold_with_transpose_( // reshape n = transpose_hc_output.shape()[0], w = transpose_hc_output.shape()[1], c = transpose_hc_output.shape()[2], h = transpose_hc_output.shape()[3]; - auto reshape_hc_output = ttnn::reshape_on_device(transpose_hc_output, n, (w / stride_w), (c * stride_w), h, L1_mem_config); + auto reshape_hc_output = ttnn::reshape_on_device(transpose_hc_output, ttnn::SimpleShape{n, (w / stride_w), (c * stride_w), h}, L1_mem_config); tt::log_debug("reshape_hc_output: {}", reshape_hc_output.shape()); @@ -78,7 +78,7 @@ std::vector fold_with_transpose_( // reshape n = transpose_hw_output2.shape()[0], w = transpose_hw_output2.shape()[1], h = transpose_hw_output2.shape()[2], c = transpose_hw_output2.shape()[3]; - auto reshape_hw_output = ttnn::reshape_on_device(transpose_hw_output2, n, w, (h / stride_h), (c * stride_h), L1_mem_config); + auto reshape_hw_output = ttnn::reshape_on_device(transpose_hw_output2, ttnn::SimpleShape{n, w, (h / stride_h), (c * stride_h)}, L1_mem_config); tt::log_debug("reshape_hw_output: {}", reshape_hw_output.shape()); @@ -212,7 +212,7 @@ std::vector fold_with_transpose_sharded_( // reshape n = tt_output_tensor.shape()[0], w = tt_output_tensor.shape()[1], c = tt_output_tensor.shape()[2], h = tt_output_tensor.shape()[3]; - tt_output_tensor = tt_output_tensor.reshape(n, (w / stride_w), (c * stride_w), h); + tt_output_tensor = tt_output_tensor.reshape(ttnn::SimpleShape{n, (w / stride_w), (c * stride_w), h}); tt::log_debug("reshape_hc_output: {}", tt_output_tensor.shape()); @@ -228,7 +228,7 @@ std::vector fold_with_transpose_sharded_( // reshape n = tt_output_tensor.shape()[0], w = tt_output_tensor.shape()[1], h = tt_output_tensor.shape()[2], c = tt_output_tensor.shape()[3]; - tt_output_tensor = tt_output_tensor.reshape(n, w, (h / stride_h), (c * stride_h)); + tt_output_tensor = tt_output_tensor.reshape(ttnn::SimpleShape{n, w, (h / stride_h), (c * stride_h)}); tt::log_debug("reshape_hw_output: {}", tt_output_tensor.shape()); diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp index 07c71c89d7c..29a07c7e71d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp @@ -195,8 +195,8 @@ ttnn::Tensor ExecutePermute::invoke( output_tensor = ttnn::to_layout(output_tensor, input_layout, std::nullopt, std::nullopt, (Device*)nullptr); if (input_rank < 4) { - const auto shape = output_tensor.get_logical_shape(); - const auto full_shape = output_tensor.get_padded_shape(); + const auto shape = output_tensor.get_shape(); + const auto full_shape = output_tensor.get_shape().with_tile_padding(); std::vector shape_vec{}; std::vector full_shape_vec{}; int i = 0; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp index 93c71eaf236..3aff0667de1 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.cpp @@ -21,44 +21,40 @@ void ReshapeDeviceOperation::validate(const std::vector &input_tensors) TT_FATAL(input_tensor_a.get_layout() == Layout::TILE || input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Only tile and row major reshape supported!"); - auto output_shape = infer_dims_for_reshape(this->N, this->C, this->H, this->W, input_tensor_a.volume()); - TT_FATAL(input_tensor_a.volume() == output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3], "New shape volume must match old shape volume"); - TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Reshape does not currently support sharding"); TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Reshape does not currently support sharding"); if (input_tensor_a.get_layout() == Layout::TILE) { TT_FATAL(input_tensor_a.volume() % TILE_HW == 0, "Error"); - TT_FATAL(output_shape[2] % TILE_HEIGHT == 0 && output_shape[3] % TILE_WIDTH == 0, "Expected a multiple of 32 for H, W (or -1 evaluating to such) for reshape!"); } else if (input_tensor_a.get_layout() == Layout::ROW_MAJOR) { uint32_t ROW_MAJOR_WIDTH = 8; - TT_FATAL(input_tensor_a.get_legacy_shape()[3] % ROW_MAJOR_WIDTH == 0 && output_shape[3] % ROW_MAJOR_WIDTH == 0, "Operand/target width must be a multiple of 8"); + auto padded_output_shape = output_shape.padded_shape(); + TT_FATAL(input_tensor_a.get_legacy_shape()[3] % ROW_MAJOR_WIDTH == 0 && padded_output_shape[3] % ROW_MAJOR_WIDTH == 0, "Operand/target width must be a multiple of 8"); uint32_t num_old_sticks = input_tensor_a.get_legacy_shape()[0] * input_tensor_a.get_legacy_shape()[1] * input_tensor_a.get_legacy_shape()[2]; - uint32_t num_new_sticks = output_shape[0] * output_shape[1] * output_shape[2]; + uint32_t num_new_sticks = padded_output_shape[0] * padded_output_shape[1] * padded_output_shape[2]; } else { TT_THROW("Unsupported layout for reshape"); } } -std::vector ReshapeDeviceOperation::compute_output_shapes(const std::vector &input_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - return {tt::tt_metal::infer_dims_for_reshape(this->N, this->C, this->H, this->W, input_tensor_a.volume())}; +std::vector ReshapeDeviceOperation::compute_output_shapes(const std::vector &input_tensors) const { + return {output_shape.logical_shape()}; } std::vector ReshapeDeviceOperation::create_output_tensors(const std::vector &input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); - return operation::generic_create_output_tensors(*this, input_tensors, input_tensor_a.get_dtype(), input_tensor_a.get_layout(), this->output_mem_config); + return {create_device_tensor(output_shape, input_tensor_a.get_dtype(), input_tensor_a.get_layout(), input_tensor_a.device(), this->output_mem_config, input_tensor_a.tile())}; } operation::ProgramWithCallbacks ReshapeDeviceOperation::create_program(const std::vector& input_tensors, std::vector &output_tensors) const { const auto& input_tensor_a = input_tensors.at(0); auto& output_tensor = output_tensors.at(0); if (input_tensor_a.get_layout() == Layout::ROW_MAJOR) { - return {detail::reshape_rm_multi_core(input_tensor_a, output_tensor, this->N, this->C, this->H, this->W)}; + return {detail::reshape_rm_multi_core(input_tensor_a, output_tensor)}; } else if (input_tensor_a.get_layout() == Layout::TILE) { - return {detail::reshape_tile_single_core(input_tensor_a, output_tensor, this->N, this->C, this->H, this->W)}; + return {detail::reshape_tile_single_core(input_tensor_a, output_tensor)}; } else { TT_ASSERT(false, "Unsupported layout for reshape"); return {}; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.hpp index 3a4e88fe7a1..6dc04e3519e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_op.hpp @@ -11,11 +11,11 @@ namespace ttnn::operations::data_movement { struct ReshapeDeviceOperation { - int N, C, H, W; + const ttnn::Shape output_shape; const MemoryConfig output_mem_config; void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; }; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.cpp index 63ae487fb9c..0efebdd2653 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.cpp @@ -11,7 +11,7 @@ namespace ttnn::operations::data_movement::detail { -operation::ProgramWithCallbacks reshape_tile_single_core(const Tensor &a, Tensor &output, int N, int C, int H, int W) { +operation::ProgramWithCallbacks reshape_tile_single_core(const Tensor &a, Tensor &output) { tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); @@ -105,156 +105,6 @@ operation::ProgramWithCallbacks reshape_tile_single_core(const Tensor &a, Tensor return {std::move(program), override_runtime_args_callback}; } -operation::ProgramWithCallbacks reshape_rm_single_core(const Tensor &a, Tensor& output, int N, int C, int H, int W) { - - tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); - CoreRange core({0, 0}, {0, 0}); - - // This should allocate a DRAM buffer on the device - tt::tt_metal::Device *device = a.device(); - tt::tt_metal::LegacyShape output_shape = output.get_legacy_shape(); - tt::tt_metal::Buffer *src0_buffer = a.buffer(); - tt::tt_metal::Buffer *dst_buffer = output.buffer(); - - uint32_t num_old_sticks = a.get_legacy_shape()[0] * a.get_legacy_shape()[1] * a.get_legacy_shape()[2]; - uint32_t num_new_sticks = output_shape[0] * output_shape[1] * output_shape[2]; - - uint32_t old_stick_size = a.get_legacy_shape()[3] * 2; // Assuming bfloat16 data format - uint32_t new_stick_size = output_shape[3] * 2; // Assuming bfloat16 data format - - tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - uint32_t single_tile_size = tt::tt_metal::detail::TileSize(cb_data_format); - uint32_t src0_cb_index = 0; - uint32_t num_input_tiles = (a.get_legacy_shape()[1] * a.get_legacy_shape()[2] * a.get_legacy_shape()[3] / tt::constants::TILE_HW); - uint32_t num_output_tiles = (output_shape[1] * output_shape[2] * output_shape[3] / tt::constants::TILE_HW); - - // Currently added to support Bert large, TODO: Make op more generic, parallelize - uint32_t available_l1 = device->l1_size_per_core() - device->get_base_allocator_addr(HalMemType::L1); - if (num_input_tiles * single_tile_size + num_output_tiles * single_tile_size > available_l1) { - if (old_stick_size >= new_stick_size) { - if (old_stick_size % new_stick_size == 0) { - // Maximize L1 usage. Is this needed or do we just need to double buffer 32 sticks (64) - // Evenly divide L1 between input/output - uint32_t w_tiles = a.get_legacy_shape()[3] / tt::constants::TILE_WIDTH; - num_input_tiles = ((available_l1 / 2) / single_tile_size) / w_tiles * w_tiles; - num_output_tiles = num_input_tiles; - } else { - // Not needed for Bert large at the moment so will trigger L1 OOM assert - } - } else { - if (new_stick_size % old_stick_size == 0) { - // Maximize L1 usage. Is this needed or do we just need to double buffer 32 sticks (64) - // Evenly divide L1 between input/output - uint32_t w_tiles = (output_shape[3] / tt::constants::TILE_WIDTH); - num_output_tiles = ((available_l1 / 2) / single_tile_size) / w_tiles * w_tiles; - num_input_tiles = num_output_tiles; - } else { - // Not needed for Bert large at the moment so will trigger L1 OOM assert - } - } - TT_ASSERT(num_input_tiles > 0 && num_output_tiles > 0, "Cannot fit input/output rows into L1"); - } - - tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}}) - .set_page_size(src0_cb_index, single_tile_size); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); - - uint32_t output_cb_index = 16; // output operands start at index 16 - tt::tt_metal::CircularBufferConfig cb_output_config = tt::tt_metal::CircularBufferConfig(num_output_tiles * single_tile_size, {{output_cb_index, cb_data_format}}) - .set_page_size(output_cb_index, single_tile_size); - auto cb_output = tt::tt_metal::CreateCircularBuffer(program, core, cb_output_config); - - // Reader compile-time args - bool src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - bool old_stick_size_is_power_of_two = tt::tt_metal::is_power_of_two_at_least_32(old_stick_size); - vector reader_kernel_args = {src0_buffer->address(), num_old_sticks, old_stick_size}; - std::vector reader_compile_time_args = {src0_is_dram}; - if (old_stick_size_is_power_of_two) { - reader_kernel_args.push_back(log2(old_stick_size)); - - // Use the fast stick size power of 2 path (get noc addr uses just shift operations, no slow multiply algorithm) - reader_compile_time_args.push_back(1); - } else { - reader_compile_time_args.push_back(0); - } - - // Writer compile-time args - bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - bool new_stick_size_is_power_of_two = tt::tt_metal::is_power_of_two_at_least_32(new_stick_size); - vector writer_kernel_args = {dst_buffer->address(), num_new_sticks, new_stick_size}; - std::vector writer_compile_time_args {dst_is_dram}; - if (new_stick_size_is_power_of_two) { - writer_kernel_args.push_back(log2(new_stick_size)); - writer_compile_time_args.push_back(1); - } else { - writer_compile_time_args.push_back(0); - } - - tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/kernels/dataflow/reader_unary_reshape_stick_layout_interleaved.cpp", - core, - tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); - - tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/kernels/dataflow/writer_unary_reshape_stick_layout_interleaved.cpp", - core, - tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - - // No compute required, so using blank kernel - vector compute_args = { - uint(a.volume() / tt::constants::TILE_HW), // per_core_block_cnt - 1 // per_core_block_size - }; - - auto eltwise_unary_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/eltwise_copy.cpp", - core, - tt::tt_metal::ComputeConfig{.compile_args = compute_args} - ); - - tt::tt_metal::SetRuntimeArgs( - program, - unary_reader_kernel_id, - core, - reader_kernel_args - ); - - tt::tt_metal::SetRuntimeArgs( - program, - unary_writer_kernel_id, - core, - writer_kernel_args - ); - - auto override_runtime_args_callback = [unary_reader_kernel_id, unary_writer_kernel_id]( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - - auto src_buffer = input_buffers.at(0); - - auto dst_buffer = output_buffers.at(0); - - CoreCoord core = {0, 0}; - - { - auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, unary_writer_kernel_id, core); - runtime_args[0] = dst_buffer->address(); - } - }; - - return {std::move(program), override_runtime_args_callback}; -} - std::vector, std::vector > > get_runtime_args_rm_multi_core(const Tensor &input_tensor, Tensor &output_tensor, uint32_t num_cores_total, @@ -358,8 +208,7 @@ std::vector, std::vector > > get_runti return ret_val; } -operation::ProgramWithCallbacks reshape_rm_multi_core(const Tensor &a, Tensor& output, int N, int C, int H, int W) { - +operation::ProgramWithCallbacks reshape_rm_multi_core(const Tensor &a, Tensor& output) { TT_FATAL(a.get_dtype() == output.get_dtype(), "Error"); tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); @@ -378,11 +227,8 @@ operation::ProgramWithCallbacks reshape_rm_multi_core(const Tensor &a, Tensor& o uint32_t old_stick_size = a.get_legacy_shape()[3] * a.element_size(); uint32_t new_stick_size = output_shape[3] * output.element_size(); - if (old_stick_size > new_stick_size) { - TT_FATAL(old_stick_size % new_stick_size == 0, "Error"); - } else { - TT_FATAL(new_stick_size % old_stick_size == 0, "Error"); - } + TT_FATAL(std::max(old_stick_size, new_stick_size) % std::min(old_stick_size, new_stick_size) == 0, + "Last dimension of the old shape ({}) should be divisible by the last dimension of the new shape ({}) or vice versa", old_stick_size, new_stick_size); auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); uint32_t num_cores_x = compute_with_storage_grid_size.x; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.hpp index 961ddabdfad..9f6fa7ebae4 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/device/reshape_program_factory.hpp @@ -6,8 +6,7 @@ namespace ttnn::operations::data_movement::detail { -operation::ProgramWithCallbacks reshape_tile_single_core(const Tensor &a, Tensor &output, int N, int C, int H, int W); -operation::ProgramWithCallbacks reshape_rm_single_core(const Tensor &a, Tensor& output, int N, int C, int H, int W); -operation::ProgramWithCallbacks reshape_rm_multi_core(const Tensor &a, Tensor& output, int N, int C, int H, int W); +operation::ProgramWithCallbacks reshape_tile_single_core(const Tensor &a, Tensor &output); +operation::ProgramWithCallbacks reshape_rm_multi_core(const Tensor &a, Tensor& output); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp index e2c5b11e8ca..93532abcfb9 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp @@ -19,14 +19,13 @@ namespace detail { static Tensor manual_insertion( const Tensor& input_tensor, - const tt::tt_metal::LegacyShape& shape, + const ttnn::Shape& shape, Device* device, const MemoryConfig& output_mem_config ) { TT_ASSERT(input_tensor.get_layout() == Layout::ROW_MAJOR); - TT_ASSERT( - shape[0] * shape[1] * shape[2] * shape[3] == input_tensor.volume(), - "Required shape volume must match old shape volume"); + TT_ASSERT(shape.logical_shape().volume() == input_tensor.get_logical_volume(), + "Required shape volume ({}) must match old shape volume ({})", shape.logical_shape().volume(), input_tensor.get_logical_volume()); auto device_buffer = input_tensor.device_buffer(); uint32_t size_in_bytes = device_buffer->size(); std::vector data_vec; @@ -51,31 +50,28 @@ namespace detail { ttnn::Tensor ReshapeOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - int N, - int C, - int H, - int W, + ttnn::Shape output_shape, const std::optional& memory_config_arg) { using namespace tt::constants; auto output_mem_config = memory_config_arg.value_or(input_tensor.memory_config()); + auto padded_output_shape = output_shape.padded_shape(); // No-op (Will do a tensor copy) - tt::tt_metal::LegacyShape output_shape = tt::tt_metal::infer_dims_for_reshape(N, C, H, W, input_tensor.volume()); if ( - ((input_tensor.get_layout() == Layout::TILE or input_tensor.get_layout() == Layout::ROW_MAJOR) && output_shape[3] == input_tensor.get_legacy_shape()[3]) + ((input_tensor.get_layout() == Layout::TILE or input_tensor.get_layout() == Layout::ROW_MAJOR) && padded_output_shape[3] == input_tensor.get_padded_shape()[3]) ) { // Don't need to do a check here to see the H and W both divisible by 32 // since handled within the tensor reshape method - return input_tensor.reshape(N, C, H, W); + return input_tensor.reshape(output_shape); } - if (input_tensor.get_legacy_shape() == output_shape) { + if (input_tensor.get_padded_shape() == padded_output_shape) { return ttnn::operations::experimental::auto_format::AutoFormat::move_tensor_to_mem_config(input_tensor, output_mem_config); } uint32_t ROW_MAJOR_WIDTH = 8; if (input_tensor.get_layout() == Layout::ROW_MAJOR && (input_tensor.get_legacy_shape()[3] % ROW_MAJOR_WIDTH != 0 || - output_shape[3] % ROW_MAJOR_WIDTH != 0) && - ((compute_volume(output_shape) / output_shape[-1]) % TILE_HEIGHT != 0 - || output_shape[-1] % TILE_WIDTH != 0 + padded_output_shape[3] % ROW_MAJOR_WIDTH != 0) && + ((padded_output_shape.volume() / padded_output_shape[-1]) % TILE_HEIGHT != 0 + || padded_output_shape[-1] % TILE_WIDTH != 0 || input_tensor.get_legacy_shape()[-1] % TILE_WIDTH != 0 || (input_tensor.volume() / input_tensor.get_legacy_shape()[-1]) % TILE_HEIGHT != 0)) { TT_FATAL(input_tensor.get_dtype()==DataType::BFLOAT16, "Error"); @@ -83,22 +79,31 @@ ttnn::Tensor ReshapeOperation::invoke( return detail::manual_insertion((tt::tt_metal::Tensor)input_tensor, output_shape, input_tensor.device(), output_mem_config); } std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; - return operation::run(ReshapeDeviceOperation{N, C, H, W, output_mem_config}, {input_tensor}).at(0); + return operation::run(ReshapeDeviceOperation{output_shape, output_mem_config}, {input_tensor}).at(0); } ttnn::Tensor ReshapeOperation::invoke( const ttnn::Tensor& input_tensor, - int N, - int C, - int H, - int W, + ttnn::Shape shape, const std::optional& memory_config) { - return invoke(DefaultQueueId, input_tensor, N, C, H, W, memory_config); + return invoke(DefaultQueueId, input_tensor, shape, memory_config); } -ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, int N, int C, int H, int W) { - return invoke(DefaultQueueId, input_tensor, N, C, H, W,std::nullopt); +ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape) { + return invoke(DefaultQueueId, input_tensor, shape, std::nullopt); +} + +ttnn::Tensor ReshapeOperation::invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, const std::vector & shape_vector, const std::optional& memory_config_arg) { + return invoke(queue_id, input_tensor, ttnn::Shape(infer_dims_for_reshape(input_tensor, shape_vector).as_vector()), memory_config_arg); +} + +ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, const std::vector& shape_vector, const std::optional& memory_config_arg) { + return invoke(DefaultQueueId, input_tensor, shape_vector, memory_config_arg); +} + +ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, const std::vector& shape_vector) { + return invoke(input_tensor, shape_vector, std::nullopt); } } // ttnn::operations::data_movement namespace diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp index 276b039f0d0..e6044087170 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp @@ -14,21 +14,19 @@ struct ReshapeOperation { static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - int N, - int C, - int H, - int W, + ttnn::Shape shape, const std::optional& memory_config_arg); static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - int N, - int C, - int H, - int W, + ttnn::Shape shape, const std::optional& memory_config); - static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, int N, int C, int H, int W); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape); + + static ttnn::Tensor invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, const std::vector& shape_vector, const std::optional& memory_config_arg); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const std::vector& shape_vector, const std::optional& memory_config_arg); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const std::vector& shape_vector); }; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape_pybind.cpp index 196e814881d..b3e9f335567 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape_pybind.cpp @@ -30,7 +30,7 @@ void bind_reshape(pybind11::module& module, const data_movement_operation_t& ope int X, const std::optional& memory_config, uint8_t queue_id) -> ttnn::Tensor { - return self(queue_id, input_tensor, W, Z, Y, X, memory_config); + return self(queue_id, input_tensor, std::vector{W, Z, Y, X}, memory_config); }, py::arg("input_tensor"), py::arg("W"), diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index 568fa286987..c19662e5b21 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -23,7 +23,7 @@ namespace detail { ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { if (!ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE)) { - return tensor.reshape(shape.value); + return tensor.reshape(shape); } auto tensor_shape = tensor.shape(); auto layout = tensor.layout(); @@ -44,7 +44,7 @@ ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) host_tensor_4d = ttnn::slice(host_tensor_4d, begins, ends, step, std::nullopt); host_tensor = squeeze_from_4D(host_tensor_4d, tensor_shape.rank()); } - auto host_reshape_tensor = rm_tensor.reshape(shape.value); + auto host_reshape_tensor = rm_tensor.reshape(shape); auto final_layout_tensor = ttnn::to_layout(host_reshape_tensor, layout, std::nullopt, std::nullopt, (Device *)nullptr); auto device_tensor = ttnn::data_transfer_to_device(final_layout_tensor, device, memory_config); return device_tensor; @@ -52,7 +52,9 @@ ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) ttnn::Tensor row_major_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { const auto layout = tensor.get_layout(); + auto shape_with_padding = shape.padded_shape(); auto tensor_shape = tensor.get_shape(); + auto tensor_shape_with_padding = tensor_shape.padded_shape(); //Constraint in device kernel uint32_t ROW_MAJOR_WIDTH = 8; @@ -61,24 +63,24 @@ ttnn::Tensor row_major_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& sh auto rm_tensor = ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device *)nullptr); if (rm_tensor.is_contiguous()) { // Page size depends on the width, so only modify the shape if the width is the same - if (tensor_shape.with_tile_padding()[-1] == shape.with_tile_padding()[-1]) { - return rm_tensor.reshape(shape.value); + if (tensor_shape_with_padding[-1] == shape_with_padding[-1]) { + return rm_tensor.reshape(shape); } //Different page width, going to use device kernel that does transpose else { auto original_rank = shape.rank(); auto tensor_4d = unsqueeze_to_4D(rm_tensor); - const auto shape_4d = shape.to_rank<4>(); - auto reshaped_tensor = ttnn::reshape_on_device(tensor_4d, shape_4d[0], shape_4d[1], shape_4d[2], shape_4d[3], tensor.memory_config()); + const auto shape_4d = shape.to_rank(4); + auto reshaped_tensor = ttnn::reshape_on_device(tensor_4d, ttnn::SimpleShape{shape_4d[0], shape_4d[1], shape_4d[2], shape_4d[3]}, tensor.memory_config()); reshaped_rm_tensor = squeeze_from_4D(reshaped_tensor, original_rank); } } else if (tensor_shape.rank() >= 2 and shape.rank() >= 2) { // Handle the case when the tensor is not contiguous but the last two dimensions are the same and so reshape // is possible if (tensor_shape[-1] == shape[-1] and tensor_shape[-2] == shape[-2] and - tensor_shape.with_tile_padding()[-1] == shape.with_tile_padding()[-1] and - tensor_shape.with_tile_padding()[-2] == shape.with_tile_padding()[-2]) { - reshaped_rm_tensor = rm_tensor.reshape(shape.value); + tensor_shape_with_padding[-1] == shape_with_padding[-1] and + tensor_shape_with_padding[-2] == shape_with_padding[-2]) { + reshaped_rm_tensor = rm_tensor.reshape(shape); } } else { reshaped_rm_tensor = host_reshape(tensor, shape); @@ -96,49 +98,15 @@ ttnn::Tensor row_major_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& sh else { return reshaped_rm_tensor; } - -} - -ttnn::Shape get_shape_from_vector_with_possible_negative_values(const ttnn::Tensor& tensor, const std::vector & shape) { - std::int64_t new_volume = 1; - std::int64_t index_of_negative_1 = -1; - - for (auto index = 0; index < shape.size(); ++index) { - if (shape[index] == -1) { - if (index_of_negative_1 != -1) { - std::string error_msg = "Shape cannot have more than 1 elements that is set to -1! Shape used: ("; - for(auto & s: shape) { - error_msg += std::to_string(s) + ","; - } - error_msg += ")"; - TT_THROW("{}", error_msg); - } - index_of_negative_1 = index; - } - new_volume *= shape[index]; - } - - std::vector new_shape(shape.size()); - std::copy(shape.begin(), shape.end(), new_shape.begin()); - if (new_volume < 0) { - const auto volume = tensor.volume(); - new_shape[index_of_negative_1] = volume / (-new_volume); - } - return ttnn::Shape(new_shape); } } -ttnn::Tensor ReshapeViewOperation::invoke( - const ttnn::Tensor& tensor, - const ttnn::Shape& shape - ) { - +ttnn::Tensor ReshapeViewOperation::invoke(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { auto layout = tensor.get_layout(); auto tensor_shape = tensor.get_shape(); - // First Case, No reshape Required if (tensor_shape == shape) { return tensor; @@ -151,22 +119,23 @@ ttnn::Tensor ReshapeViewOperation::invoke( // For Tensors already on host we can do the tensor.reshape (changing of view) if (!(ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE)) or tile_tensor_view_reshape_possible) { - return tensor.reshape(shape.value); + return tensor.reshape(shape); } // Catch-all // Do the reshape in row-major return detail::row_major_reshape(tensor, shape); +} +ttnn::Tensor ReshapeViewOperation::invoke(const ttnn::Tensor& tensor, const ttnn::SimpleShape& shape) { + return invoke(tensor, ttnn::Shape(shape.as_vector())); } ttnn::Tensor ReshapeViewOperation::invoke( const ttnn::Tensor& tensor, const std::vector & shape_vector ) { - - auto shape = detail::get_shape_from_vector_with_possible_negative_values(tensor, shape_vector); - return invoke(tensor, shape); + return invoke(tensor, tt::tt_metal::infer_dims_for_reshape(tensor, shape_vector)); } } // ttnn::operations::data_movement namespace diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp index 311055263a3..c6c1941d2a2 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp @@ -11,14 +11,9 @@ namespace ttnn { namespace operations::data_movement { struct ReshapeViewOperation { - static ttnn::Tensor invoke( - const ttnn::Tensor& input_tensor, - const ttnn::Shape& shape - ); - static ttnn::Tensor invoke( - const ttnn::Tensor& input_tensor, - const std::vector & shape_vector - ); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& logical_shape); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const std::vector & shape_vector); }; diff --git a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp index 0a9b034dc8c..eb2bbdbed53 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp @@ -102,13 +102,13 @@ namespace detail { } const int W = 1, Z = shape[0] * shape[1], Y = shape[2], X = shape[3]; - const Tensor &reshaped_tensor = ttnn::reshape_on_device(input_tensor, 1, -1, Y, X, mem_config); + const Tensor &reshaped_tensor = ttnn::reshape_on_device(input_tensor, std::vector{1, -1, Y, X}, mem_config); auto part_reshaped = impl_split_last_dim_two_chunks_tiled(reshaped_tensor, mem_config); std::vector results; results.reserve(part_reshaped.size()); - for (auto &part : part_reshaped) results.emplace_back(ttnn::reshape_on_device(part, -1, shape[1], Y, X / 2, mem_config)); + for (auto &part : part_reshaped) results.emplace_back(ttnn::reshape_on_device(part, std::vector{-1, (int32_t)shape[1], Y, X / 2}, mem_config)); return results; } diff --git a/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp b/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp index 189e6ed9bfe..d0674b0bb8e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp @@ -8,11 +8,7 @@ namespace ttnn::operations::data_movement { -ttnn::Tensor SqueezeOperation::invoke( - const ttnn::Tensor& input_tensor, - const int dim - ) { - +ttnn::Tensor SqueezeOperation::invoke(const ttnn::Tensor& input_tensor, const int dim) { const auto original_logical_shape = input_tensor.get_shape(); const auto padded_shape = input_tensor.get_shape().with_tile_padding(); const auto input_tensor_rank = original_logical_shape.rank(); @@ -40,7 +36,6 @@ ttnn::Tensor SqueezeOperation::invoke( } return ttnn::reshape(input_tensor, ttnn::Shape(original_logical_shape_vector, padded_shape_vector)); - } } // ttnn::operations::data_movement namespace diff --git a/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp b/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp index 06fdb657d70..d1fbf47978f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp @@ -8,11 +8,7 @@ namespace ttnn::operations::data_movement { -ttnn::Tensor UnsqueezeOperation::invoke( - const ttnn::Tensor& input_tensor, - const int dim - ) { - +ttnn::Tensor UnsqueezeOperation::invoke(const ttnn::Tensor& input_tensor, const int dim) { const auto tensor_shape = input_tensor.get_shape(); const auto rank = tensor_shape.rank(); std::vector output_shape_vector; @@ -38,10 +34,7 @@ ttnn::Tensor UnsqueezeOperation::invoke( output_shape_vector.push_back(1); } - ttnn::Shape output_shape(output_shape_vector); - return ttnn::reshape(input_tensor, output_shape); - - + return ttnn::reshape(input_tensor, ttnn::SimpleShape(std::move(output_shape_vector))); } } // ttnn::operations::data_movement namespace diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp index b71c58f041b..1303233ab57 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp @@ -365,10 +365,10 @@ Tensor _outer(const Tensor& input_a, const Tensor& input_b, const std::optional< Tensor b_slim = input_b; if(!skip_reshape_a) { - a_slim = ttnn::reshape(input_a, ttnn::Shape{std::array{1, 1, input_a.volume(), 1}}); + a_slim = ttnn::reshape(input_a, ttnn::SimpleShape{std::array{1, 1, input_a.volume(), 1}}); } if(!skip_reshape_b) { - b_slim = ttnn::reshape(input_b, ttnn::Shape{std::array{1, 1, 1, input_b.volume()}}); + b_slim = ttnn::reshape(input_b, ttnn::SimpleShape{std::array{1, 1, 1, input_b.volume()}}); } a_slim = ttnn::to_layout(a_slim, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); b_slim = ttnn::to_layout(b_slim, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp index 6489de03deb..2d703769419 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp @@ -791,7 +791,7 @@ Tensor _make_global_from_hw_impl(HWFunctionT fn, const Tensor& y, const std::op // format to HW Tensor y_hw = ttnn::reshape_on_device( - y, 1, 1, y.get_legacy_shape()[2], y.get_legacy_shape()[3] * y.get_legacy_shape()[1] * y.get_legacy_shape()[0]); + y, ttnn::SimpleShape{1, 1, y.get_legacy_shape()[2], y.get_legacy_shape()[3] * y.get_legacy_shape()[1] * y.get_legacy_shape()[0]}); // compute @fn Tensor z_0 = fn(y_hw, output_mem_config); @@ -800,7 +800,7 @@ Tensor _make_global_from_hw_impl(HWFunctionT fn, const Tensor& y, const std::op // reformat Tensor z_1 = ttnn::reshape_on_device( - z_0, y.get_legacy_shape()[0], y.get_legacy_shape()[1], y.get_legacy_shape()[2], y.get_legacy_shape()[3]); + z_0, ttnn::SimpleShape{y.get_legacy_shape()[0], y.get_legacy_shape()[1], y.get_legacy_shape()[2], y.get_legacy_shape()[3]}); z_0.deallocate(); return z_1; diff --git a/ttnn/cpp/ttnn/operations/embedding/embedding.hpp b/ttnn/cpp/ttnn/operations/embedding/embedding.hpp index befdba2a43b..52439fd693d 100644 --- a/ttnn/cpp/ttnn/operations/embedding/embedding.hpp +++ b/ttnn/cpp/ttnn/operations/embedding/embedding.hpp @@ -38,7 +38,7 @@ struct EmbeddingOperation { auto batch_size = input_tensor_arg.get_shape()[0]; auto sentence_size = input_tensor_arg.get_shape()[-1]; auto input_tensor = - ttnn::reshape(input_tensor_arg, ttnn::Shape{std::array{batch_size, 1, 1, sentence_size}}); + ttnn::reshape(input_tensor_arg, ttnn::SimpleShape{std::array{batch_size, 1, 1, sentence_size}}); bool tilized = layout == ttnn::TILE_LAYOUT; auto embeddings = operation::run( @@ -51,7 +51,7 @@ struct EmbeddingOperation { {input_tensor, weight}) .at(0); embeddings = ttnn::reshape( - embeddings, ttnn::Shape{std::array{batch_size, sentence_size, hidden_embedding_dim}}); + embeddings, ttnn::SimpleShape{std::array{batch_size, sentence_size, hidden_embedding_dim}}); return embeddings; } diff --git a/ttnn/cpp/ttnn/operations/embedding_backward/embedding_backward.cpp b/ttnn/cpp/ttnn/operations/embedding_backward/embedding_backward.cpp index 074a736ab4a..7786461848f 100644 --- a/ttnn/cpp/ttnn/operations/embedding_backward/embedding_backward.cpp +++ b/ttnn/cpp/ttnn/operations/embedding_backward/embedding_backward.cpp @@ -24,7 +24,7 @@ Tensor EmbeddingBackwardOperation::invoke( auto batch_size = input_tensor_arg.get_shape()[0]; auto sentence_size = input_tensor_arg.get_shape()[-1]; auto input_tensor = - ttnn::reshape(input_tensor_arg, ttnn::Shape{std::array{batch_size, 1, 1, sentence_size}}); + ttnn::reshape(input_tensor_arg, ttnn::SimpleShape{std::array{batch_size, 1, 1, sentence_size}}); auto input_gradient = operation::run( diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp index f87a96a0ff8..a6c5bddad33 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp @@ -55,7 +55,7 @@ MorehGetItemOperation::MorehGetItemRmFactory::cached_program_t MorehGetItemOpera uint32_t index_end_dim = index_dims.back(); Tensor input_5d = input; - input_5d = input_5d.reshape(input_5d_shape.value); + input_5d = input_5d.reshape(input_5d_shape); auto input_5d_shape_without_padding = input_5d_shape.value.without_padding(); diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp index f36f5c1efca..4864d6a34f1 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp @@ -48,7 +48,7 @@ Tensor HaloTensorCreation(const Tensor &input){ input_tensor = ttnn::reshape( input_tensor, - Shape(std::array{ + SimpleShape(std::array{ 1, 1, input.get_shape()[0] * input.get_shape()[1] * input.get_shape()[2], diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp index b82fdddc49e..5cce6fa248b 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp @@ -173,7 +173,7 @@ static Tensor reduce_impl( } if (reshape) { - output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{tt::tt_metal::LegacyShape{output_shape, padded_output_shape}}); + output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{output_shape, padded_output_shape}); } return output_tensor; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa.cpp index 90fe0c05811..490f4ba68dd 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode_gqa/sdpa_decode_gqa.cpp @@ -60,7 +60,7 @@ ttnn::Tensor ExecuteScaledDotProductAttentionGQADecode::invoke( input_tensor_q_gqa = ttnn::to_layout(input_tensor_q, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device *)nullptr); input_tensor_q_gqa = ttnn::transpose(input_tensor_q_gqa, 1, 2); - input_tensor_q_gqa = ttnn::reshape(input_tensor_q_gqa, ttnn::Shape{std::array{1, Bq, NQH, D}}); + input_tensor_q_gqa = ttnn::reshape(input_tensor_q_gqa, ttnn::SimpleShape{std::array{1, Bq, NQH, D}}); input_tensor_q_gqa = ttnn::to_layout(input_tensor_q_gqa, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (Device *)nullptr); } diff --git a/ttnn/cpp/ttnn/operations/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.cpp b/ttnn/cpp/ttnn/operations/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.cpp index 533c8eeae23..2612e60d109 100644 --- a/ttnn/cpp/ttnn/operations/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.cpp @@ -97,11 +97,11 @@ std::tuple SplitQueryKeyValueAndSplitHeadsOperation::inv head_size, padded_head_size); - const auto input_4d = input_tensor.reshape( + const auto input_4d = input_tensor.reshape(ttnn::SimpleShape{ input_shape.with_tile_padding()[0], 1, input_shape.with_tile_padding()[1], - input_shape.with_tile_padding()[2]); + input_shape.with_tile_padding()[2]}); auto outputs = ttnn::experimental::nlp_create_qkv_heads_falcon7b(input_4d, memory_config.value_or(input_tensor.memory_config())); return detail::reshape_outputs_of_split_query_key_value_and_split_heads( @@ -145,11 +145,11 @@ std::tuple SplitQueryKeyValueAndSplitHeadsOperation::inv TT_FATAL(!input_tensor_kv.has_value(), "Invalid operation: KV tensor should not be provided when the input tensor is sharded. Please ensure that the KV tensor is only used in non-sharded configurations."); - const auto input_tensor_4d = input_tensor.reshape( + const auto input_tensor_4d = input_tensor.reshape(ttnn::SimpleShape{ input_shape.with_tile_padding()[0], 1, input_shape.with_tile_padding()[1], - input_shape.with_tile_padding()[2]); + input_shape.with_tile_padding()[2]}); return detail::reshape_outputs_of_split_query_key_value_and_split_heads( ttnn::experimental::create_qkv_heads( input_tensor_4d, @@ -161,16 +161,16 @@ std::tuple SplitQueryKeyValueAndSplitHeadsOperation::inv sequence_size_padded, transpose_key); } else { - const auto input_tensor_4d = input_tensor.reshape( + const auto input_tensor_4d = input_tensor.reshape(ttnn::SimpleShape{ input_shape.with_tile_padding()[0], 1, input_shape.with_tile_padding()[1], - input_shape.with_tile_padding()[2]); + input_shape.with_tile_padding()[2]}); std::optional input_tensor_kv_4d = std::nullopt; if (input_tensor_kv.has_value()) { auto padded_input_shape_kv = input_tensor_kv.value().get_shape().with_tile_padding(); - input_tensor_kv_4d = input_tensor_kv.value().reshape( - padded_input_shape_kv[0], 1, padded_input_shape_kv[1], padded_input_shape_kv[2]); + input_tensor_kv_4d = input_tensor_kv.value().reshape(ttnn::SimpleShape{ + padded_input_shape_kv[0], 1, padded_input_shape_kv[1], padded_input_shape_kv[2]}); } const auto outputs = ttnn::experimental::nlp_create_qkv_heads( input_tensor_4d, diff --git a/ttnn/cpp/ttnn/tensor/serialization.cpp b/ttnn/cpp/ttnn/tensor/serialization.cpp index ca283f679d1..aaee6f406a1 100644 --- a/ttnn/cpp/ttnn/tensor/serialization.cpp +++ b/ttnn/cpp/ttnn/tensor/serialization.cpp @@ -65,8 +65,7 @@ void dump_multi_device_host_storage(std::ofstream& output_stream, const MultiDev output_stream.write(reinterpret_cast(buffer.begin()), sizeof(T) * size); }, storage.get_buffer(0) ); - output_stream.write(reinterpret_cast(&storage.shapes.at(0)), sizeof(tt::tt_metal::LegacyShape)); - + output_stream.write(reinterpret_cast(&storage.shapes.at(0)), sizeof(ttnn::Shape)); } else { for (int i = 0; i < num_buffers; i++) { std::visit( @@ -79,7 +78,7 @@ void dump_multi_device_host_storage(std::ofstream& output_stream, const MultiDev ); } for (const auto& shape : storage.shapes) { - output_stream.write(reinterpret_cast(&shape), sizeof(tt::tt_metal::LegacyShape)); + output_stream.write(reinterpret_cast(&shape), sizeof(ttnn::Shape)); } } } @@ -102,14 +101,14 @@ MultiDeviceHostStorage load_multi_device_host_storage(std::ifstream& input_strea input_stream.read(reinterpret_cast(&strategy), sizeof(DistributedTensorConfig)); std::vector buffers; - std::vector shapes; + std::vector shapes; if (std::holds_alternative(strategy)) { std::size_t size = 0; input_stream.read(reinterpret_cast(&size), sizeof(std::size_t)); auto buffer = owned_buffer::create(size); - auto shape = tt::tt_metal::LegacyShape{}; + auto shape = ttnn::Shape{}; input_stream.read(reinterpret_cast(buffer.begin()), sizeof(T) * size); - input_stream.read(reinterpret_cast(&shape), sizeof(tt::tt_metal::LegacyShape)); + input_stream.read(reinterpret_cast(&shape), sizeof(ttnn::Shape)); buffers.push_back(buffer); shapes.push_back(shape); @@ -129,8 +128,8 @@ MultiDeviceHostStorage load_multi_device_host_storage(std::ifstream& input_strea buffers.push_back(std::move(buffer)); } for (std::size_t i = 0; i < num_buffers; ++i) { - auto shape = tt::tt_metal::LegacyShape{}; - input_stream.read(reinterpret_cast(&shape), sizeof(tt::tt_metal::LegacyShape)); + auto shape = ttnn::Shape{}; + input_stream.read(reinterpret_cast(&shape), sizeof(ttnn::Shape)); shapes.push_back(shape); } } @@ -203,14 +202,14 @@ void dump_tensor(const std::string& file_name, const Tensor& tensor, const std:: throw std::runtime_error(fmt::format("Cannot open \"{}\"", file_name)); } - auto shape = tensor.get_legacy_shape(); + auto shape = tensor.get_shape(); auto data_type = tensor.get_dtype(); auto layout = tensor.get_layout(); auto storage_type = tensor.storage_type(); output_stream.write(reinterpret_cast(&detail::SENTINEL_VALUE), sizeof(std::size_t)); output_stream.write(reinterpret_cast(&VERSION_ID), sizeof(std::uint8_t)); - output_stream.write(reinterpret_cast(&shape), sizeof(tt::tt_metal::LegacyShape)); + output_stream.write(reinterpret_cast(&shape), sizeof(ttnn::Shape)); output_stream.write(reinterpret_cast(&data_type), sizeof(DataType)); output_stream.write(reinterpret_cast(&layout), sizeof(Layout)); output_stream.write(reinterpret_cast(&storage_type), sizeof(StorageType)); @@ -273,11 +272,11 @@ Tensor load_tensor_helper(const std::string& file_name, T device) { if (version_id > VERSION_ID) { throw std::runtime_error(fmt::format("Serialized tensor with version_id: {}. Loader version: {}", version_id, VERSION_ID)); } - auto shape = tt::tt_metal::LegacyShape{}; + auto shape = ttnn::Shape{}; DataType data_type; Layout layout; StorageType storage_type; - input_stream.read(reinterpret_cast(&shape), sizeof(tt::tt_metal::LegacyShape)); + input_stream.read(reinterpret_cast(&shape), sizeof(ttnn::Shape)); input_stream.read(reinterpret_cast(&data_type), sizeof(DataType)); input_stream.read(reinterpret_cast(&layout), sizeof(Layout)); input_stream.read(reinterpret_cast(&storage_type), sizeof(StorageType)); @@ -306,10 +305,10 @@ Tensor load_tensor_helper(const std::string& file_name, T device) { } else { input_stream.seekg(0, std::ios::beg); // No sentinel found, assume it's an older format and rewind - auto shape = tt::tt_metal::LegacyShape{}; + auto shape = ttnn::Shape{}; DataType data_type; Layout layout; - input_stream.read(reinterpret_cast(&shape), sizeof(tt::tt_metal::LegacyShape)); + input_stream.read(reinterpret_cast(&shape), sizeof(ttnn::Shape)); input_stream.read(reinterpret_cast(&data_type), sizeof(DataType)); input_stream.read(reinterpret_cast(&layout), sizeof(Layout)); diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index b3cea279364..2e57a99e1b6 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -179,7 +179,7 @@ Tensor::Tensor( std::get(this->tensor_attributes->storage).buffers = std::vector(num_buffers, OwnedBuffer()); std::get(this->tensor_attributes->storage).shapes = - std::vector(num_buffers, this->tensor_attributes->shape.value); + std::vector(num_buffers, this->tensor_attributes->shape); } this->tensor_attributes->num_shards_to_be_populated = num_buffers; } @@ -578,11 +578,11 @@ const bool Tensor::is_sharded() const { uint32_t Tensor::element_size() const { return tensor_impl::element_size_bytes(this->get_dtype()); } -Tensor Tensor::reshape(int N, int C, int H, int W) const { - return tensor_ops::tensor_reshape(*this, N, C, H, W); +Tensor Tensor::reshape(const ttnn::SimpleShape& new_shape) const { + return tensor_ops::tensor_reshape(*this, new_shape); } -Tensor Tensor::reshape(const tt::tt_metal::LegacyShape& new_shape) const { +Tensor Tensor::reshape(const ttnn::Shape& new_shape) const { return tensor_ops::tensor_reshape(*this, new_shape); } diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index cbc97f7a237..e4a8849ef2e 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -175,8 +175,8 @@ struct Tensor { // ====================================================================================== // Low Level APIs // ====================================================================================== - Tensor reshape(int N, int C, int H, int W) const; - Tensor reshape(const tt::tt_metal::LegacyShape &new_shape) const; + Tensor reshape(const ttnn::SimpleShape &new_shape) const; + Tensor reshape(const ttnn::Shape &new_shape) const; // ====================================================================================== // Getters diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp index cd2a6fb6090..b3a013e1b38 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp @@ -1004,7 +1004,7 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) { return OwnedStorage{output_buffer}; } else if constexpr (std::is_same_v) { std::vector output_buffers; - std::vector output_shapes; + std::vector output_shapes; for (int i = 0; i < storage.num_buffers(); i++) { const auto input_data = owned_buffer::get_as(storage.get_buffer(i)); auto output_buffer = owned_buffer::create(std::move(convert(input_data))); diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 9d0f9d8f373..8c9364e77ac 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -348,24 +348,14 @@ Tensor tensor_unpad_from_tile(const Tensor& input_tensor, const ttnn::SimpleShap return output; } -Tensor tensor_reshape(const Tensor& input_tensor, int N, int C, int H, int W) { - ZoneScoped; - GraphTracker::instance().track_function_start("Tensor::reshape", input_tensor, N, C, H, W); - auto new_shape = infer_dims_for_reshape(N, C, H, W, input_tensor.volume()); - auto output = input_tensor.reshape(new_shape); - output = tt::tt_metal::set_tensor_id(output); - GraphTracker::instance().track_function_end(output); - return output; -} - -Tensor tensor_reshape(const Tensor& input_tensor, const tt::tt_metal::LegacyShape& new_shape) { +Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::reshape", input_tensor, new_shape); TT_ASSERT( - input_tensor.volume() == tt::tt_metal::compute_volume(new_shape), + input_tensor.volume() == new_shape.padded_shape().volume(), "{} != {}", input_tensor.volume(), - tt::tt_metal::compute_volume(new_shape)); + new_shape.padded_shape().volume()); if (input_tensor.get_layout() == Layout::TILE) { TT_ASSERT( new_shape[-2] % constants::TILE_HEIGHT == 0 && new_shape[-1] % constants::TILE_WIDTH == 0 && @@ -384,7 +374,7 @@ Tensor tensor_reshape(const Tensor& input_tensor, const tt::tt_metal::LegacyShap } if constexpr (std::is_same_v) { MultiDeviceStorage updated_storage = std::get(tensor.get_storage()); - std::unordered_map new_shapes; + std::unordered_map new_shapes; for (auto device_id : updated_storage.ordered_device_ids) { new_shapes.insert({device_id, new_shape}); @@ -436,4 +426,8 @@ Tensor tensor_reshape(const Tensor& input_tensor, const tt::tt_metal::LegacyShap return output; } +Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::SimpleShape& new_shape) { + return tensor_reshape(input_tensor, ttnn::Shape(new_shape.as_vector())); +} + } diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp index 7cfa92e3fd9..4d7cacc13b4 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp @@ -42,8 +42,7 @@ Tensor tensor_pad_to_tile(const Tensor& input_tensor, float pad_value); Tensor tensor_unpad_from_tile(const Tensor& input_tensor, const ttnn::SimpleShape& output_tensor_shape); -Tensor tensor_reshape(const Tensor& input_tensor, int N, int C, int H, int W); - -Tensor tensor_reshape(const Tensor& input_tensor, const tt::tt_metal::LegacyShape& new_shape); +Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::SimpleShape& new_shape); +Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape); } diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp index be238446c19..3ff6eb4452b 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp @@ -456,42 +456,37 @@ Tensor convert_conv_weight_tensor_to_depthwise_layout( TT_THROW("Unsupported weight data type given when trying to add zero padding to weight tensor"); } -const tt::tt_metal::LegacyShape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volume) { - vector ns{N, C, H, W}; - int neg_idx = -1; - for (int i = 0; i < ns.size(); i++) { - if (ns[i] == -1) { - TT_ASSERT(neg_idx == -1, "Only one -1 is allowed in reshape"); - neg_idx = i; +const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, const std::vector& shape) { + int64_t old_volume = tensor.get_logical_volume(); + int64_t new_volume = 1; + int64_t index_of_negative_1 = -1; + for (auto index = 0; index < shape.size(); ++index) { + if (shape[index] == -1) { + if (index_of_negative_1 != -1) { + std::string error_msg = "Shape cannot have more than 1 elements that is set to -1! Shape used: ("; + for(auto & s: shape) { + error_msg += std::to_string(s) + ","; + } + error_msg += ")"; + TT_THROW("{}", error_msg); + } + index_of_negative_1 = index; } else { - TT_ASSERT(ns[i] > 0, "New shape entries can only have -1 or positive values"); + TT_FATAL(shape[index] > 0, "New shape entries can only have -1 or positive values"); + new_volume *= shape[index]; } } - switch (neg_idx) { - case 0: - TT_ASSERT(old_volume % C * H * W == 0); - N = old_volume / (C * H * W); - break; - case 1: - TT_ASSERT(old_volume % N * H * W == 0); - C = old_volume / (N * H * W); - break; - case 2: - TT_ASSERT(old_volume % N * C * W == 0); - H = old_volume / (N * C * W); - break; - case 3: - TT_ASSERT(old_volume % N * C * H == 0); - W = old_volume / (N * C * H); - break; - case -1: // In case where there is no negative value in ns - TT_ASSERT(N * C * H * W == old_volume); - break; - default: TT_ASSERT(false && "Unexpected neg_idx in reshape!"); + std::vector new_shape(shape.size()); + std::copy(shape.begin(), shape.end(), new_shape.begin()); + if (index_of_negative_1 == -1) { + TT_FATAL(new_volume == old_volume, "Invalid arguments to reshape"); + } else { + TT_FATAL(old_volume % new_volume == 0, "Invalid arguments to reshape"); + new_shape[index_of_negative_1] = old_volume / new_volume; } - return {(uint32_t)N, (uint32_t)C, (uint32_t)H, (uint32_t)W}; + return ttnn::SimpleShape(std::move(new_shape)); } bool is_arch_gs(const tt::ARCH& arch) { return arch == tt::ARCH::GRAYSKULL; } @@ -597,12 +592,12 @@ void insert_buffer_and_shape_for_device( s.insert_buffer_and_shape_for_device( buffer_index.value(), std::get(shard.tensor_attributes->storage).get_buffer(), - shard.tensor_attributes->shape.value); + shard.tensor_attributes->shape); } else if constexpr (std::is_same_v) { s.insert_buffer_and_shape_for_device( target_device, std::get(shard.tensor_attributes->storage).get_buffer(), - shard.tensor_attributes->shape.value); + shard.tensor_attributes->shape); } else if constexpr (std::is_same_v) { s.insert_buffer(std::get(shard.tensor_attributes->storage).get_buffer()); } else if constexpr (std::is_same_v) { diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp index 239d3f2e4ec..692e5b361fa 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp @@ -34,9 +34,7 @@ Tensor convert_conv_weight_tensor_to_grouped_layout(Tensor conv_weight_tensor, u // Converts convolution weights to depthwise layout with broadcasted weights Tensor convert_conv_weight_tensor_to_depthwise_layout(Tensor conv_weight_tensor, uint32_t act_block_h_ntiles, DataType output_dtype); -const tt::tt_metal::LegacyShape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volume); - -const tt::tt_metal::LegacyShape infer_dims_for_reshape_RM(int N, int C, int H, int W, uint32_t old_volume); +const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, const std::vector& shape); // TODO: Remove this once we switch to SimpleShape .volume() static std::size_t compute_volume(const tt::tt_metal::LegacyShape& shape) { diff --git a/ttnn/cpp/ttnn/tensor/types.cpp b/ttnn/cpp/ttnn/tensor/types.cpp index 38ef1b6ed0b..14d89db8a9b 100644 --- a/ttnn/cpp/ttnn/tensor/types.cpp +++ b/ttnn/cpp/ttnn/tensor/types.cpp @@ -36,6 +36,31 @@ SimpleShape get_physical_shape(const SimpleShape& logical_shape, DataType data_t return physical_shape; } +namespace types { + +const Shape Shape::to_rank(size_t new_rank) const { + auto padded_shape = value; + auto shape = value.without_padding(); + + std::vector new_shape(new_rank, 1); + std::vector new_padded_shape(new_rank, 1); + + int cur_idx = static_cast(rank()) - 1; + int new_idx = static_cast(new_rank) - 1; + for(;cur_idx >= 0 && new_idx >= 0; cur_idx--, new_idx--) { + new_shape[new_idx] = shape[cur_idx]; + new_padded_shape[new_idx] = padded_shape[cur_idx]; + } + for(;cur_idx >= 0; cur_idx--) { + TT_FATAL(shape[cur_idx] == 1, "Can't convert shape rank"); + TT_FATAL(padded_shape[cur_idx] == 1, "Can't convert shape rank"); + } + + return Shape(std::move(new_shape), std::move(new_padded_shape)); +} + +} + } namespace tt { diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp index 33c497dd522..7c62798c2dd 100644 --- a/ttnn/cpp/ttnn/tensor/types.hpp +++ b/ttnn/cpp/ttnn/tensor/types.hpp @@ -465,10 +465,161 @@ struct BorrowedStorage { inline bool is_allocated() const { return true; } }; +} // namespace tt_metal +} // namespace tt + + +namespace ttnn { +namespace types { + +namespace detail { +template +static tt::tt_metal::LegacyShape compute_ttl_shape( + const std::array &shape, const std::array, Rank> &padding) { + auto ttl_shape = std::array{}; + for (auto index = 0; index < Rank; index++) { + ttl_shape[index] = shape[index] + padding[index][0] + padding[index][1]; + } + return tt::tt_metal::LegacyShape{ttl_shape, tt::tt_metal::Padding{padding, tt::tt_metal::Padding::PadValue::Any}}; +} + +} // namespace detail + + +struct Shape { + // ttnn::Shape is a wrapper around tt::tt_metal::LegacyShape + // It is used to flip the default value of operator[] to return the shape without padding + tt::tt_metal::LegacyShape value; + + Shape(const std::initializer_list dimensions) : value{dimensions} {} + + Shape(const tt::tt_metal::LegacyShape &shape) : value{shape} {} + + template + Shape(const std::array &shape) : value{shape} {} + + template + explicit Shape(const std::array &shape, const std::array &shape_with_tile_padding) : + value{tt::tt_metal::LegacyShape{shape, shape_with_tile_padding}} {} + + template + explicit Shape( + const std::array &shape, const std::array, Rank> &tile_padding) : + value{detail::compute_ttl_shape(shape, tile_padding)} {} + + Shape(const std::vector &shape) : value{tt::tt_metal::LegacyShape{shape}} {} + + explicit Shape(const std::vector &shape, const std::vector &shape_with_tile_padding) : + value{tt::tt_metal::LegacyShape{shape, shape_with_tile_padding}} {} + + explicit Shape(const std::vector &shape, const Padding &padding) : + value{tt::tt_metal::LegacyShape{shape, padding}} {} + + explicit Shape(const Shape &shape, const Padding &padding) : + value{tt::tt_metal::LegacyShape{shape.value, padding}} {} + + Shape(const SimpleShape& shape): value{shape.as_vector()} {} + + const auto rank() const { return this->value.rank(); } + + const size_t size() const { return this->rank(); } + + // Returns the padded shape, padding information is stripped + [[deprecated("Replaced by padded_shape()")]] + const tt::tt_metal::Padding &padding() const { return this->value.padding(); } + + const uint32_t get_normalized_index(std::int64_t index) const { return this->value.get_normalized_index(index); } + + Shape with_tile_padding() const { + return Shape{tt::tt_metal::LegacyShape{this->value, tt::tt_metal::Padding{this->value.rank()}}}; + } + + SimpleShape padded_shape() const { + std::vector values(rank()); + for (size_t i = 0; i < values.size(); i++) { + values[i] = this->value[i]; // value stored LegacyShape, its operator[] returns padded value + } + return SimpleShape(std::move(values)); + } + + // Returns the shape without padding, padding information is stripped + SimpleShape logical_shape() const { + std::vector values(this->rank()); + for (size_t i = 0; i < values.size(); i++) { + values[i] = this->operator[](i); // operator[] returns the shape without padding + } + return SimpleShape(std::move(values)); + } + + bool has_tile_padding() const { + auto rank = this->rank(); + for (auto index = 0; index < rank; index++) { + if (this->has_tile_padding(index)) { + return true; + } + } + return false; + } + + bool has_tile_padding(int dim) const { + return this->value.padding()[dim].front > 0 or this->value.padding()[dim].back > 0; + } + + bool operator==(const Shape &other) const { + const auto &shape_a = this->value; + const auto &shape_b = other.value; + // tt::tt_metal::LegacyShape comparison doesn't take padding into account + return (shape_a == shape_b and shape_a.without_padding() == shape_b.without_padding()); + } + + template + bool operator==(const std::array &other) const { + return Shape{this->value.without_padding()} == Shape{other}; + } + + bool operator!=(const Shape &other) const { return not(*this == other); } + + // Returns value without padding + uint32_t operator[](std::int64_t index) const; + + const Shape to_rank(size_t new_rank) const; + + static constexpr auto attribute_names = std::forward_as_tuple("value"); + const auto attribute_values() const { return std::forward_as_tuple(this->value); } +}; + +static std::ostream &operator<<(std::ostream &os, const Shape &shape) { + const auto shape_with_tile_padding = shape.with_tile_padding(); + const auto &padding = shape.value.padding(); + os << "ttnn.Shape(["; + for (auto i = 0; i < shape.rank(); ++i) { + if (i > 0) { + os << ", "; + } + os << shape[i]; + if (padding[i].back > 0) { + os << "[" << shape_with_tile_padding[i] << "]"; + } + } + os << "])"; + return os; +} + +} // namespace types + +using types::Shape; + +SimpleShape get_physical_shape(const SimpleShape& logical_shape, DataType data_type, Layout layout, const std::optional& tile = std::nullopt); + +} // namespace ttnn + +namespace tt { +namespace tt_metal { + struct MultiDeviceHostStorage { DistributedTensorConfig strategy; std::vector buffers; - std::vector shapes; + std::vector shapes; mutable std::mutex mtx; friend void swap(MultiDeviceHostStorage &first, MultiDeviceHostStorage &second) { @@ -483,7 +634,7 @@ struct MultiDeviceHostStorage { MultiDeviceHostStorage() = default; MultiDeviceHostStorage( - DistributedTensorConfig strategy_, std::vector buffers_, std::vector shapes_) : + DistributedTensorConfig strategy_, std::vector buffers_, std::vector shapes_) : strategy(strategy_), buffers(buffers_), shapes(shapes_) {} MultiDeviceHostStorage(MultiDeviceHostStorage &&other) { swap(*this, other); } // unfotunately we need to have this code written manually. @@ -514,7 +665,7 @@ struct MultiDeviceHostStorage { // Helper Functions - Getters and setters to get/modify storage attributes. These are needed to // preinitialize empty tensor handles and use/populate them in the worker threads. - void insert_buffer_and_shape_for_device(int buffer_index, const OwnedBuffer &buffer, const LegacyShape shape) { + void insert_buffer_and_shape_for_device(int buffer_index, const OwnedBuffer &buffer, const ttnn::Shape shape) { std::lock_guard lock(mtx); buffers[buffer_index] = buffer; shapes[buffer_index] = shape; @@ -532,7 +683,7 @@ struct MultiDeviceHostStorage { return buffers[buffer_index]; } - LegacyShape get_tensor_shape(int shape_index) const { + ttnn::Shape get_tensor_shape(int shape_index) const { std::lock_guard lock(mtx); TT_ASSERT(shape_index < shapes.size(), "Buffer not found for device {}", shape_index); return shapes[shape_index]; @@ -558,7 +709,7 @@ struct MultiDeviceStorage { DistributedTensorConfig strategy; std::vector ordered_device_ids; std::unordered_map buffers; - std::unordered_map shapes; + std::unordered_map shapes; mutable std::mutex buffer_mtx; mutable std::mutex shape_mtx; MultiDeviceStorage() = default; @@ -576,7 +727,7 @@ struct MultiDeviceStorage { DistributedTensorConfig strategy_, std::vector ordered_device_ids_, std::unordered_map buffers_, - std::unordered_map shapes_) : + std::unordered_map shapes_) : strategy(std::move(strategy_)), ordered_device_ids(std::move(ordered_device_ids_)), buffers(std::move(buffers_)), @@ -631,7 +782,7 @@ struct MultiDeviceStorage { // Helper Functions - Getters and setters to get/modify storage attributes. These are needed to // preinitialize empty tensor handles and use/populate them in the worker threads. - inline void insert_buffer_and_shape_for_device(Device *device, const DeviceBuffer buffer, const LegacyShape shape) { + inline void insert_buffer_and_shape_for_device(Device *device, const DeviceBuffer buffer, const ttnn::Shape shape) { std::scoped_lock lock(buffer_mtx, shape_mtx); TT_ASSERT( device == buffer->device(), @@ -665,7 +816,7 @@ struct MultiDeviceStorage { return buffers.at(device_id); } - inline LegacyShape get_tensor_shape_for_device(Device *device) const { + inline ttnn::Shape get_tensor_shape_for_device(Device *device) const { std::lock_guard lock(shape_mtx); TT_ASSERT( shapes.find(device->id()) != shapes.end(), "Shape not found for device {}", device->id()); @@ -708,179 +859,3 @@ constexpr void raise_unsupported_storage() { } // namespace tt_metal } // namespace tt - -namespace ttnn { -namespace types { - -namespace detail { -template -static tt::tt_metal::LegacyShape compute_ttl_shape( - const std::array &shape, const std::array, Rank> &padding) { - auto ttl_shape = std::array{}; - for (auto index = 0; index < Rank; index++) { - ttl_shape[index] = shape[index] + padding[index][0] + padding[index][1]; - } - return tt::tt_metal::LegacyShape{ttl_shape, tt::tt_metal::Padding{padding, tt::tt_metal::Padding::PadValue::Any}}; -} - -} // namespace detail - - -struct Shape { - // ttnn::Shape is a wrapper around tt::tt_metal::LegacyShape - // It is used to flip the default value of operator[] to return the shape without padding - tt::tt_metal::LegacyShape value; - - Shape(const std::initializer_list dimensions) : value{dimensions} {} - - Shape(const tt::tt_metal::LegacyShape &shape) : value{shape} {} - - template - Shape(const std::array &shape) : value{shape} {} - - template - explicit Shape(const std::array &shape, const std::array &shape_with_tile_padding) : - value{tt::tt_metal::LegacyShape{shape, shape_with_tile_padding}} {} - - template - explicit Shape( - const std::array &shape, const std::array, Rank> &tile_padding) : - value{detail::compute_ttl_shape(shape, tile_padding)} {} - - Shape(const std::vector &shape) : value{tt::tt_metal::LegacyShape{shape}} {} - - explicit Shape(const std::vector &shape, const std::vector &shape_with_tile_padding) : - value{tt::tt_metal::LegacyShape{shape, shape_with_tile_padding}} {} - - explicit Shape(const std::vector &shape, const Padding &padding) : - value{tt::tt_metal::LegacyShape{shape, padding}} {} - - explicit Shape(const Shape &shape, const Padding &padding) : - value{tt::tt_metal::LegacyShape{shape.value, padding}} {} - - const auto rank() const { return this->value.rank(); } - - const size_t size() const { return this->rank(); } - - // Returns the padded shape, padding information is stripped - [[deprecated("Replaced by padded_shape()")]] - const tt::tt_metal::Padding &padding() const { return this->value.padding(); } - - const uint32_t get_normalized_index(std::int64_t index) const { return this->value.get_normalized_index(index); } - - Shape with_tile_padding() const { - return Shape{tt::tt_metal::LegacyShape{this->value, tt::tt_metal::Padding{this->value.rank()}}}; - } - - SimpleShape padded_shape() const { - std::vector values(rank()); - for (size_t i = 0; i < values.size(); i++) { - values[i] = this->value[i]; // value stored LegacyShape, its operator[] returns padded value - } - return SimpleShape(std::move(values)); - } - - // Returns the shape without padding, padding information is stripped - SimpleShape logical_shape() const { - std::vector values(this->rank()); - for (size_t i = 0; i < values.size(); i++) { - values[i] = this->operator[](i); // operator[] returns the shape without padding - } - return SimpleShape(std::move(values)); - } - - bool has_tile_padding() const { - auto rank = this->rank(); - for (auto index = 0; index < rank; index++) { - if (this->has_tile_padding(index)) { - return true; - } - } - return false; - } - - bool has_tile_padding(int dim) const { - return this->value.padding()[dim].front > 0 or this->value.padding()[dim].back > 0; - } - - bool operator==(const Shape &other) const { - const auto &shape_a = this->value; - const auto &shape_b = other.value; - // tt::tt_metal::LegacyShape comparison doesn't take padding into account - return (shape_a == shape_b and shape_a.without_padding() == shape_b.without_padding()); - } - - template - bool operator==(const std::array &other) const { - return Shape{this->value.without_padding()} == Shape{other}; - } - - bool operator!=(const Shape &other) const { return not(*this == other); } - - // Returns value without padding - uint32_t operator[](std::int64_t index) const; - - template - const Shape to_rank() const { - auto rank = this->rank(); - auto &shape = *this; - auto shape_with_tile_padding = shape.with_tile_padding(); - - std::array new_shape{}; - std::array new_padded_shape{}; - if (rank == NewRank) { - return Shape(shape); - } else if (rank > NewRank) { - auto num_extra_dims = rank - NewRank; - - for (auto index = 0; index < num_extra_dims; index++) { - TT_ASSERT(shape[index] == 1); - TT_ASSERT(shape_with_tile_padding[index] == 1); - } - - for (auto index = 0; index < NewRank; index++) { - new_shape[index] = shape[index + num_extra_dims]; - new_padded_shape[index] = shape_with_tile_padding[index + num_extra_dims]; - } - } else { - auto num_missing_dims = NewRank - rank; - - new_shape.fill(1); - new_padded_shape.fill(1); - - for (auto index = 0; index < rank; index++) { - new_shape[index + num_missing_dims] = shape[index]; - new_padded_shape[index + num_missing_dims] = shape_with_tile_padding[index]; - } - } - return Shape(new_shape, new_padded_shape); - } - - static constexpr auto attribute_names = std::forward_as_tuple("value"); - const auto attribute_values() const { return std::forward_as_tuple(this->value); } -}; - -static std::ostream &operator<<(std::ostream &os, const Shape &shape) { - const auto shape_with_tile_padding = shape.with_tile_padding(); - const auto &padding = shape.value.padding(); - os << "ttnn.Shape(["; - for (auto i = 0; i < shape.rank(); ++i) { - if (i > 0) { - os << ", "; - } - os << shape[i]; - if (padding[i].back > 0) { - os << "[" << shape_with_tile_padding[i] << "]"; - } - } - os << "])"; - return os; -} - -} // namespace types - -using types::Shape; - -SimpleShape get_physical_shape(const SimpleShape& logical_shape, DataType data_type, Layout layout, const std::optional& tile = std::nullopt); - -} // namespace ttnn