From 9ed0c9c7d24af7c7bb8a34bbd80ccd353fc138b4 Mon Sep 17 00:00:00 2001
From: Jay Kruer
Date: Wed, 27 Nov 2024 23:04:39 +0000
Subject: [PATCH] #0: Add n-dim'l tilize implementation.

---
 .../data_movement/tilize/tilize.cpp | 46 +++++++++++++++----
 1 file changed, 36 insertions(+), 10 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp
index a20a2711878..36abb379028 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp
@@ -7,10 +7,33 @@
 #include "device/tilize_op.hpp"
 #include "ttnn/common/constants.hpp"
 #include "ttnn/run_operation.hpp"
+#include "ttnn/operations/data_movement/common/common.hpp"
+#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"
 
 using namespace tt::tt_metal;
 
 namespace ttnn::operations::data_movement {
+using OwnedTilizeArgs = std::tuple<ttnn::Tensor>;
+using BaseTilizeType = std::function<ttnn::Tensor(const ttnn::Tensor&)>;
+
+using MassagedTilize = MassagedOperation<ttnn::Tensor, const ttnn::Tensor&>;
+using MassagedTilizeParams = MassagedOperationParams<ttnn::Tensor, const ttnn::Tensor&>;
+
+MassagedTilize build_ndiml_tilize(BaseTilizeType base_tilize) {
+    auto original_shape = std::make_shared<ttnn::Shape>(ttnn::Shape{});
+    return MassagedTilize(MassagedTilizeParams{
+        .predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
+        .pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedTilizeArgs {
+            *original_shape = input_tensor.get_shape();
+            ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
+            return std::make_tuple(squeezed_tensor);
+        },
+        .post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
+            auto unsqueezed_tensor = ttnn::reshape(output, *original_shape);
+            return unsqueezed_tensor;
+        },
+        .operation = base_tilize});
+}
 
 ttnn::Tensor ExecuteTilize::invoke(
     uint8_t queue_id,
@@ -18,16 +41,19 @@ ttnn::Tensor ExecuteTilize::invoke(
     const std::optional<MemoryConfig>& memory_config,
     std::optional<DataType> output_dtype,
     bool use_multicore) {
-    return operation::run(
-               Tilize{
-                   memory_config.value_or(input_tensor.memory_config()),
-                   output_dtype.value_or(input_tensor.get_dtype()),
-                   use_multicore},
-               {input_tensor},
-               {},
-               {},
-               queue_id)
-        .at(0);
+    auto base_tilize = [=](const ttnn::Tensor& input_tensor) {
+        return operation::run(
+                   Tilize{
+                       memory_config.value_or(input_tensor.memory_config()),
+                       output_dtype.value_or(input_tensor.get_dtype()),
+                       use_multicore},
+                   {input_tensor},
+                   {},
+                   {},
+                   queue_id)[0];
+    };
+
+    return build_ndiml_tilize(base_tilize)(input_tensor);
 }
 
 ttnn::Tensor ExecuteTilize::invoke(
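
The MassagedOperation machinery pulled in from ttnn/operations/data_movement/common/common.hpp is not shown in this diff. As a rough, self-contained sketch of the composition the patch relies on (FakeTensor, SketchParams, and SketchOperation below are invented names for illustration only, not the real ttnn definitions): when the predicate matches, the wrapper runs pre_transform, hands the massaged arguments to the wrapped operation, and then runs post_transform on the result; otherwise it calls the operation directly.

// Illustrative only: a minimal stand-in for the MassagedOperation pattern used above.
// FakeTensor, SketchParams, and SketchOperation are hypothetical, not ttnn APIs.
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <tuple>
#include <vector>

struct FakeTensor {
    std::vector<int> shape;                                    // stand-in for ttnn::Shape
    int rank() const { return static_cast<int>(shape.size()); }
};

// Simplified analogue of MassagedOperationParams: the predicate decides whether the
// pre/post hooks wrap the base operation.
struct SketchParams {
    std::function<bool(const FakeTensor&)> predicate;
    std::function<std::tuple<FakeTensor>(const FakeTensor&)> pre_transform;
    std::function<FakeTensor(const FakeTensor&)> post_transform;
    std::function<FakeTensor(const FakeTensor&)> operation;
};

struct SketchOperation {
    SketchParams p;
    FakeTensor operator()(const FakeTensor& input) const {
        if (!p.predicate(input)) {
            return p.operation(input);                         // fast path: no massaging
        }
        auto [massaged] = p.pre_transform(input);              // squeeze
        return p.post_transform(p.operation(massaged));        // run op, then unsqueeze
    }
};

int main() {
    // Mirrors build_ndiml_tilize: remember the shape, fold leading dims to get <= 4-D,
    // run the base operation, then restore the original shape.
    auto original_shape = std::make_shared<std::vector<int>>();
    SketchOperation ndiml_op{SketchParams{
        .predicate = [](const FakeTensor& t) { return t.rank() > 4; },
        .pre_transform = [=](const FakeTensor& t) {
            *original_shape = t.shape;
            int folded = 1;                                    // product of all leading dims
            for (std::size_t i = 0; i + 3 < t.shape.size(); ++i) {
                folded *= t.shape[i];
            }
            std::size_t n = t.shape.size();
            FakeTensor squeezed{{folded, t.shape[n - 3], t.shape[n - 2], t.shape[n - 1]}};
            return std::make_tuple(squeezed);
        },
        .post_transform = [=](const FakeTensor& out) {
            (void)out;                                         // the real code reshapes `out`
            return FakeTensor{*original_shape};
        },
        .operation = [](const FakeTensor& t) { return t; }}};  // identity stand-in for tilize

    FakeTensor five_d{{2, 3, 4, 32, 32}};
    FakeTensor result = ndiml_op(five_d);
    assert(result.rank() == 5);                                // original rank restored
    return 0;
}

One detail worth noting in the patch itself: build_ndiml_tilize keeps the original shape in a std::shared_ptr so that the pre_transform and post_transform lambdas, which capture by value, share the shape recorded before squeezing.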