Commit

#0: Add n-dim'l tilize implementation.
jaykru-tt committed Nov 27, 2024
1 parent da0dd61 commit 9ed0c9c
Showing 1 changed file with 36 additions and 10 deletions.
ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp: 46 changes (36 additions & 10 deletions)
@@ -7,27 +7,53 @@
 #include "device/tilize_op.hpp"
 #include "ttnn/common/constants.hpp"
 #include "ttnn/run_operation.hpp"
+#include "ttnn/operations/data_movement/common/common.hpp"
+#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"
 
 using namespace tt::tt_metal;
 
 namespace ttnn::operations::data_movement {
+using OwnedTilizeArgs = std::tuple<ttnn::Tensor>;
+using BaseTilizeType = std::function<ttnn::Tensor(const ttnn::Tensor&)>;
+
+using MassagedTilize = MassagedOperation<ttnn::Tensor, const ttnn::Tensor&>;
+using MassagedTilizeParams = MassagedOperationParams<ttnn::Tensor, const ttnn::Tensor&>;
+
+MassagedTilize build_ndiml_tilize(BaseTilizeType base_tilize) {
+    auto original_shape = std::make_shared<ttnn::Shape>(ttnn::Shape{});
+    return MassagedTilize(MassagedTilizeParams{
+        .predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
+        .pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedTilizeArgs {
+            *original_shape = input_tensor.get_shape();
+            ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
+            return std::make_tuple(squeezed_tensor);
+        },
+        .post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
+            auto unsqueezed_tensor = ttnn::reshape(output, *original_shape);
+            return unsqueezed_tensor;
+        },
+        .operation = base_tilize});
+}
 
 ttnn::Tensor ExecuteTilize::invoke(
     uint8_t queue_id,
     const ttnn::Tensor& input_tensor,
     const std::optional<MemoryConfig>& memory_config,
     std::optional<DataType> output_dtype,
     bool use_multicore) {
-    return operation::run(
-               Tilize{
-                   memory_config.value_or(input_tensor.memory_config()),
-                   output_dtype.value_or(input_tensor.get_dtype()),
-                   use_multicore},
-               {input_tensor},
-               {},
-               {},
-               queue_id)
-        .at(0);
+    auto base_tilize = [=](const ttnn::Tensor& input_tensor) {
+        return operation::run(
+                   Tilize{
+                       memory_config.value_or(input_tensor.memory_config()),
+                       output_dtype.value_or(input_tensor.get_dtype()),
+                       use_multicore},
+                   {input_tensor},
+                   {},
+                   {},
+                   queue_id)[0];
+    };
+
+    return build_ndiml_tilize(base_tilize)(input_tensor);
 }
 
 ttnn::Tensor ExecuteTilize::invoke(
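For readers unfamiliar with the MassagedOperation pattern the commit relies on, the following standalone sketch mirrors its control flow: a predicate decides whether any shape massaging is needed, a pre-transform folds a rank-greater-than-4 shape down to at most 4 dimensions, the base operation runs on the squeezed tensor, and a post-transform reshapes the result back. FakeTensor, WrappedOpParams, and run_wrapped are illustrative names invented here, not part of ttnn; the real interface is MassagedOperation in ttnn/operations/data_movement/common/common.hpp, and the actual folding is done by squeeze_to_le_4D.

    // Minimal sketch of the pre/post-transform wrapper pattern; names are hypothetical.
    #include <functional>
    #include <iostream>
    #include <memory>
    #include <vector>

    // Stand-in for a tensor: just a shape, so the squeeze/unsqueeze round trip is visible.
    struct FakeTensor {
        std::vector<int> shape;
    };

    struct WrappedOpParams {
        std::function<bool(const FakeTensor&)> predicate;        // should the massaging run at all?
        std::function<FakeTensor(const FakeTensor&)> pre;        // e.g. squeeze >4-D shapes down to 4-D
        std::function<FakeTensor(const FakeTensor&)> post;       // e.g. reshape back to the original shape
        std::function<FakeTensor(const FakeTensor&)> operation;  // the base op (tilize in this commit)
    };

    FakeTensor run_wrapped(const WrappedOpParams& p, const FakeTensor& input) {
        if (!p.predicate(input)) {
            return p.operation(input);  // fast path: no massaging needed
        }
        return p.post(p.operation(p.pre(input)));
    }

    int main() {
        // Shared state captured by pre and post, mirroring the shared_ptr<Shape> in build_ndiml_tilize.
        auto original_shape = std::make_shared<std::vector<int>>();

        WrappedOpParams params{
            .predicate = [](const FakeTensor& t) { return t.shape.size() > 4; },
            .pre = [=](const FakeTensor& t) {
                *original_shape = t.shape;
                // Fold all leading dims into one so the base op only ever sees <= 4 dims.
                int folded = 1;
                for (size_t i = 0; i + 3 < t.shape.size(); ++i) folded *= t.shape[i];
                std::vector<int> squeezed{folded};
                squeezed.insert(squeezed.end(), t.shape.end() - 3, t.shape.end());
                return FakeTensor{squeezed};
            },
            .post = [=](const FakeTensor&) { return FakeTensor{*original_shape}; },
            .operation = [](const FakeTensor& t) { return t; },  // identity stand-in for tilize
        };

        FakeTensor out = run_wrapped(params, FakeTensor{{2, 3, 4, 32, 32}});
        std::cout << "output rank: " << out.shape.size() << "\n";  // 5, restored by post
    }

Factoring the original invoke body into the base_tilize lambda lets the existing 4-D tilize path stay untouched while the wrapper handles higher-rank inputs around it.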
