From 46e2f9b0be39eb380c567b135486109a70f32a43 Mon Sep 17 00:00:00 2001
From: Nguyen Truong Thanh
Date: Tue, 22 Oct 2024 14:02:04 +0700
Subject: [PATCH] #13670: Implement full_like operation (#13671)

---
 .../unit_tests/operations/test_full_like.py   | 122 ++++++++++++++++++
 ttnn/CMakeLists.txt                           |  26 ++--
 ttnn/cpp/pybind11/operations/__init__.hpp     |   4 +
 .../device/full_like_device_operation.cpp     |  76 +++++++++++
 .../device/full_like_device_operation.hpp     |  72 ++++++++++
 .../full_like/device/full_like_factory.cpp    | 118 +++++++++++++++++
 .../ttnn/operations/full_like/full_like.cpp   |  20 +++
 .../ttnn/operations/full_like/full_like.hpp   |  25 ++++
 .../operations/full_like/full_like_pybind.cpp |  46 +++++++
 .../operations/full_like/full_like_pybind.hpp |  13 ++
 10 files changed, 511 insertions(+), 11 deletions(-)
 create mode 100644 tests/ttnn/unit_tests/operations/test_full_like.py
 create mode 100644 ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp
 create mode 100644 ttnn/cpp/ttnn/operations/full_like/device/full_like_factory.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/full_like/full_like.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/full_like/full_like.hpp
 create mode 100644 ttnn/cpp/ttnn/operations/full_like/full_like_pybind.cpp
 create mode 100644 ttnn/cpp/ttnn/operations/full_like/full_like_pybind.hpp

diff --git a/tests/ttnn/unit_tests/operations/test_full_like.py b/tests/ttnn/unit_tests/operations/test_full_like.py
new file mode 100644
index 00000000000..92aa32fd5e0
--- /dev/null
+++ b/tests/ttnn/unit_tests/operations/test_full_like.py
@@ -0,0 +1,122 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+import copy
+import torch
+import torch.nn as nn
+import ttnn
+from models.utility_functions import comp_allclose
+from loguru import logger
+
+from tests.ttnn.utils_for_testing import assert_equal
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        [32, 32],  # single tile
+        [5, 96, 64],  # multiple tiles
+    ],
+)
+@pytest.mark.parametrize(
+    "fill_value",
+    [3, -1],
+)
+@pytest.mark.parametrize(
+    "layout",
+    [
+        ttnn.TILE_LAYOUT,  # Currently only tile layout is supported
+    ],
+)
+def test_full_like_int(device, input_shape, fill_value, layout):
+    torch_input_tensor = torch.randint(0, 100, input_shape, dtype=torch.int32)
+    torch_output_tensor = torch.full_like(torch_input_tensor, fill_value)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device)
+    input_tensor = ttnn.to_device(input_tensor, device)
+    output_tensor = ttnn.moreh_full_like(input_tensor, fill_value)
+    assert ttnn.is_tensor_storage_on_device(output_tensor)
+    output_tensor = ttnn.from_device(output_tensor)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_equal(torch_output_tensor, output_tensor)
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        [32, 32],  # single tile
+        [5, 96, 64],  # multiple tiles
+        [3, 91, 67, 77],  # not multiple of 32
+    ],
+)
+@pytest.mark.parametrize(
+    "fill_value",
+    [0.15, -1.2],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.bfloat16,
+        torch.float32,
+    ],
+)
+@pytest.mark.parametrize(
+    "layout",
+    [
+        ttnn.TILE_LAYOUT,  # Currently only tile layout is supported
+    ],
+)
+def test_full_like_float(device, input_shape, fill_value, dtype, layout):
+    torch_input_tensor = torch.rand(input_shape, dtype=dtype)
+    torch_output_tensor = torch.full_like(torch_input_tensor, fill_value)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device)
+    input_tensor = ttnn.to_device(input_tensor, device)
+    output_tensor = ttnn.moreh_full_like(input_tensor, fill_value)
+    assert ttnn.is_tensor_storage_on_device(output_tensor)
+    output_tensor = ttnn.from_device(output_tensor)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_equal(torch_output_tensor, output_tensor)
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        [32, 32],  # single tile
+    ],
+)
+@pytest.mark.parametrize(
+    "fill_value",
+    [3],
+)
+@pytest.mark.parametrize(
+    "layout",
+    [
+        ttnn.TILE_LAYOUT,  # Currently only tile layout is supported
+    ],
+)
+def test_full_like_callback(device, input_shape, fill_value, layout, use_program_cache):
+    for i in range(2):
+        torch_input_tensor = torch.randint(0, 100, input_shape, dtype=torch.int32)
+        torch_output_tensor = torch.full_like(torch_input_tensor, fill_value)
+
+        input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device)
+        input_tensor = ttnn.to_device(input_tensor, device)
+        output_tensor = ttnn.moreh_full_like(input_tensor, fill_value)
+        assert ttnn.is_tensor_storage_on_device(output_tensor)
+        output_tensor = ttnn.from_device(output_tensor)
+        output_tensor = ttnn.to_torch(output_tensor)
+        if i == 0:
+            num_program_cache_entries = device.num_program_cache_entries()
+            assert num_program_cache_entries > 0
+        else:
+            assert device.num_program_cache_entries() == num_program_cache_entries
+        torch_dummy = torch.randn([32, 32])
+        tt_dummy = ttnn.from_torch(torch_dummy, device=device)
+
+        assert_equal(torch_output_tensor, output_tensor)
diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt
index 754049b6dfb..8c301e0b035 100644
--- a/ttnn/CMakeLists.txt
+++ b/ttnn/CMakeLists.txt
@@ -256,6 +256,10 @@ set(ALL_TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full/device/full_program_factory.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full/full_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full/full.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full_like/full_like.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full_like/full_like_pybind.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full_like/device/full_like_factory.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/loss/loss.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/loss/loss_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -372,8 +376,8 @@ set(ALL_TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adamw/device/moreh_adamw_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adamw/device/multi_core_program_factory.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw_pybind.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_program_factory.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange_pybind.cpp
@@ -430,12 +434,12 @@ set(ALL_TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_linear/moreh_linear_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_linear/moreh_linear.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward_pybind.cpp
@@ -475,14 +479,6 @@ set(ALL_TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sgd/device/moreh_sgd_program_factory.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/moreh_softmax.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/moreh_softmax_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/moreh_softmax_backward_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_c_large/softmax_backward_c_large.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_large/softmax_backward_h_large.cpp
@@ -492,4 +488,12 @@ set(ALL_TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax_backward/moreh_softmax_backward_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax_backward/moreh_softmax_backward.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/moreh_softmax_pybind.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/moreh_softmax.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_program_factory.cpp
diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp
index 30795ff56c4..361a6c71867 100644
--- a/ttnn/cpp/pybind11/operations/__init__.hpp
+++ b/ttnn/cpp/pybind11/operations/__init__.hpp
@@ -27,6 +27,7 @@
 #include "ttnn/operations/embedding/embedding_pybind.hpp"
 #include "ttnn/operations/embedding_backward/embedding_backward_pybind.hpp"
 #include "ttnn/operations/examples/examples_pybind.hpp"
 #include "ttnn/operations/experimental/experimental_pybind.hpp"
 #include "ttnn/operations/full/full_pybind.hpp"
+#include "ttnn/operations/full_like/full_like_pybind.hpp"
 #include "ttnn/operations/kv_cache/kv_cache_pybind.hpp"
@@ -140,6 +141,9 @@ void py_module(py::module& module) {
 
     auto m_moreh = module.def_submodule("moreh", "moreh operations");
     moreh::bind_moreh_operations(m_moreh);
+
+    auto m_full_like = module.def_submodule("full_like", "full_like operation");
+    full_like::bind_full_like_operation(m_full_like);
 }
 
 }  // namespace operations
diff --git a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp
new file mode 100644
index 00000000000..7ae59a16637
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp
@@ -0,0 +1,76 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "full_like_device_operation.hpp"
+
+#include <optional>
+
+#include "ttnn/tensor/tensor.hpp"
+
+namespace ttnn::operations::full_like {
+
+FullLikeOperation::program_factory_t FullLikeOperation::select_program_factory(
+    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
+    return ProgramFactory{};
+}
+
+void FullLikeOperation::validate(const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
+    const auto& input = tensor_args.input;
+    if (operation_attributes.dtype != input.get_dtype())
+        TT_FATAL(
+            input.get_layout() == Layout::TILE, "Full Like: Data type conversion is only supported with tile layout");
+    TT_FATAL(input.storage_type() == StorageType::DEVICE, "Full Like: Input must be on device");
+    TT_FATAL(input.buffer() != nullptr, "Full Like: Input must be allocated in buffer on device");
+    TT_FATAL(
+        input.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
+        "Full Like: Not currently supporting sharding");
+    TT_FATAL(
+        operation_attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED,
+        "Full Like: Not currently supporting sharding");
+    TT_FATAL(operation_attributes.layout == Layout::TILE, "Full Like: Not currently supporting row major layout");
+}
+
+void FullLikeOperation::validate_on_program_cache_miss(
+    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
+    validate(operation_attributes, tensor_args);
+}
+
+void FullLikeOperation::validate_on_program_cache_hit(
+    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
+    validate(operation_attributes, tensor_args);
+}
+
+FullLikeOperation::shape_return_value_t FullLikeOperation::compute_output_shapes(
+    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
+    return tensor_args.input.get_logical_shape();
+}
+
+FullLikeOperation::tensor_return_value_t FullLikeOperation::create_output_tensors(
+    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
+    const auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
+    const auto& input = tensor_args.input;
+    return create_device_tensor(
+        output_shape,
+        operation_attributes.dtype,
+        operation_attributes.layout,
+        input.device(),
+        operation_attributes.memory_config);
+}
+
+std::tuple<FullLikeOperation::operation_attributes_t, FullLikeOperation::tensor_args_t> FullLikeOperation::invoke(
+    const Tensor& input,
+    const std::variant<int, float> fill_value,
+    const std::optional<DataType>& dtype,
+    const std::optional<Layout>& layout,
+    const std::optional<MemoryConfig>& memory_config) {
+    return {
+        operation_attributes_t{
+            fill_value,
+            dtype.value_or(input.tensor_attributes->dtype),
+            layout.value_or(input.tensor_attributes->layout),
+            memory_config.value_or(input.memory_config())},
+        tensor_args_t{input}};
+}
+
+}  // namespace ttnn::operations::full_like
diff --git a/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp
new file mode 100644
index 00000000000..c4c59a724d1
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp
@@ -0,0 +1,72 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+#include <optional>
+#include <variant>
+
+#include "ttnn/decorators.hpp"
+#include "ttnn/device_operation.hpp"
+#include "ttnn/tensor/tensor.hpp"
+#include "ttnn/types.hpp"
+
+namespace ttnn::operations::full_like {
+
+struct FullLikeOperation {
+    struct operation_attributes_t {
+        const std::variant<int, float> fill_value;
+        const DataType dtype;
+        const Layout layout;
+        const MemoryConfig memory_config;
+    };
+
+    struct tensor_args_t {
+        const Tensor& input;
+    };
+
+    using shape_return_value_t = SimpleShape;
+    using tensor_return_value_t = Tensor;
+
+    struct ProgramFactory {
+        struct shared_variables_t {
+            KernelHandle writer_kernel_id;
+            std::size_t num_cores;
+            std::size_t num_cores_y;
+        };
+        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;
+
+        static cached_program_t create(
+            const operation_attributes_t& operation_attributes,
+            const tensor_args_t& tensor_args,
+            tensor_return_value_t& output);
+
+        static void override_runtime_arguments(
+            cached_program_t& cached_program,
+            const operation_attributes_t& operation_attributes,
+            const tensor_args_t& tensor_args,
+            tensor_return_value_t& output);
+    };
+
+    using program_factory_t = std::variant<ProgramFactory>;
+    static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
+    static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
+    static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
+    static void validate(const operation_attributes_t&, const tensor_args_t&);
+    static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
+
+    static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
+    static std::tuple<operation_attributes_t, tensor_args_t> invoke(
+        const Tensor& input,
+        const std::variant<int, float> fill_value,
+        const std::optional<DataType>& dtype,
+        const std::optional<Layout>& layout,
+        const std::optional<MemoryConfig>& memory_config);
+};
+
+}  // namespace ttnn::operations::full_like
+
+namespace ttnn::prim {
+constexpr auto moreh_full_like =
+    ttnn::register_operation<"ttnn::prim::moreh_full_like", ttnn::operations::full_like::FullLikeOperation>();
+}  // namespace ttnn::prim
diff --git a/ttnn/cpp/ttnn/operations/full_like/device/full_like_factory.cpp b/ttnn/cpp/ttnn/operations/full_like/device/full_like_factory.cpp
new file mode 100644
index 00000000000..4ae5bbbf264
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/full_like/device/full_like_factory.cpp
@@ -0,0 +1,118 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <map>
+
+#include "common/constants.hpp"
+#include "full_like_device_operation.hpp"
+#include "host_api.hpp"
+#include "impl/buffers/circular_buffer_types.hpp"
+#include "tt_metal/common/work_split.hpp"
+#include "ttnn/tensor/types.hpp"
+
+namespace ttnn::operations::full_like {
+
+using namespace tt;
+using namespace tt::tt_metal;
+using namespace tt::constants;
+using namespace tt::tt_metal::detail;
+
+FullLikeOperation::ProgramFactory::cached_program_t FullLikeOperation::ProgramFactory::create(
+    const operation_attributes_t& operation_attributes,
+    const tensor_args_t& tensor_args,
+    tensor_return_value_t& output) {
+    auto input = tensor_args.input;
+    auto fill_value = operation_attributes.fill_value;
+    // Reinterpret the int/float fill value as its raw 32-bit pattern for the kernel.
+    union datatype {
+        uint32_t u32;
+        float f32;
+    } u;
+    if (std::holds_alternative<int>(fill_value)) {
+        u.u32 = std::get<int>(fill_value);
+    } else if (std::holds_alternative<float>(fill_value)) {
+        u.f32 = std::get<float>(fill_value);
+    }
+    DataType dtype{operation_attributes.dtype};
+    Layout layout{operation_attributes.layout};
+    Device* device = input.device();
+    MemoryConfig memory_config{operation_attributes.memory_config};
+
+    auto num_tiles = input.volume() / TILE_HW;
+
+    Program program{};
+
+    auto data_format = datatype_to_dataformat_converter(dtype);
+    uint32_t single_tile_size = TileSize(data_format);
+
+    const auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
+    const uint32_t num_cores_x = compute_with_storage_grid_size.x;
+    const uint32_t num_cores_y = compute_with_storage_grid_size.y;
+
+    auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
+        tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles);
+
+    constexpr CB cb_fill_value_id = CB::c_intermed0;
+
+    CircularBufferConfig cb_value_config = CircularBufferConfig(single_tile_size, {{cb_fill_value_id, data_format}})
+                                               .set_page_size(cb_fill_value_id, single_tile_size);
+    auto cb_fill_value = CreateCircularBuffer(program, all_cores, cb_value_config);
+    std::map<std::string, std::string> writer_defines;
+
+    switch (dtype) {
+        case DataType::BFLOAT16: writer_defines["OUTPUT_DTYPE_BFLOAT16"] = "1"; break;
+        case DataType::INT32: writer_defines["OUTPUT_DTYPE_INT32"] = "1"; break;
+        case DataType::FLOAT32: writer_defines["OUTPUT_DTYPE_FLOAT32"] = "1"; break;
+        default: break;
+    }
+
+    std::vector<uint32_t> writer_compile_time_args = {(uint32_t) cb_fill_value_id};
+
+    auto writer_id = CreateKernel(
+        program,
+        "ttnn/cpp/ttnn/operations/full/device/kernels/writer_full.cpp",
+        all_cores,
+        WriterDataMovementConfig(writer_compile_time_args, writer_defines));
+
+    uint32_t tiles_offset = 0;
+    for (uint32_t i = 0; i < num_cores; i++) {
+        const CoreCoord core(i / num_cores_y, i % num_cores_y);
+
+        uint32_t num_tiles_per_core = 0;
+        if (core_group_1.core_coord_in_core_ranges(core)) {
+            num_tiles_per_core = num_tiles_per_core_group_1;
+        } else if (core_group_2.core_coord_in_core_ranges(core)) {
+            num_tiles_per_core = num_tiles_per_core_group_2;
+        } else {
+            TT_ASSERT(false, "Core not in specified core ranges");
+        }
+        SetRuntimeArgs(program, writer_id, core, {u.u32, output.buffer()->address(), num_tiles_per_core, tiles_offset});
+
+        tiles_offset += num_tiles_per_core;
+    }
+
+    return {std::move(program), {writer_id, num_cores, num_cores_y}};
+}
+
+void FullLikeOperation::ProgramFactory::override_runtime_arguments(
+    cached_program_t& cached_program,
+    const operation_attributes_t& operation_attributes,
+    const tensor_args_t& tensor_args,
+    tensor_return_value_t& output) {
+    auto& program = cached_program.program;
+    auto& writer_kernel_id = cached_program.shared_variables.writer_kernel_id;
+    auto& num_cores = cached_program.shared_variables.num_cores;
+    auto& num_cores_y = cached_program.shared_variables.num_cores_y;
+
+    auto output_buffer_address = output.buffer()->address();
+    for (uint32_t i = 0; i < num_cores; i++) {
+        CoreCoord core = {i / num_cores_y, i % num_cores_y};
+        {
+            auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
+            runtime_args[1] = output_buffer_address;
+        }
+    }
+}
+
+}  // namespace ttnn::operations::full_like
diff --git a/ttnn/cpp/ttnn/operations/full_like/full_like.cpp b/ttnn/cpp/ttnn/operations/full_like/full_like.cpp
new file mode 100644
index 00000000000..03b5595ac04
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/full_like/full_like.cpp
@@ -0,0 +1,20 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "full_like.hpp"
+
+#include "ttnn/operations/full_like/device/full_like_device_operation.hpp"
+
+namespace ttnn::operations::full_like {
+
+Tensor FullLike::invoke(
+    const Tensor &input,
+    const std::variant<int, float> fill_value,
+    const std::optional<DataType> &dtype,
+    const std::optional<Layout> &layout,
+    const std::optional<MemoryConfig> &memory_config) {
+    return ttnn::prim::moreh_full_like(input, fill_value, dtype, layout, memory_config);
+}
+
+}  // namespace ttnn::operations::full_like
diff --git a/ttnn/cpp/ttnn/operations/full_like/full_like.hpp b/ttnn/cpp/ttnn/operations/full_like/full_like.hpp
new file mode 100644
index 00000000000..29167c52f5a
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/full_like/full_like.hpp
@@ -0,0 +1,25 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+#include <variant>
+
+#include "ttnn/decorators.hpp"
+
+namespace ttnn::operations::full_like {
+
+struct FullLike {
+    static Tensor invoke(
+        const Tensor &input,
+        const std::variant<int, float> fill_value,
+        const std::optional<DataType> &dtype,
+        const std::optional<Layout> &layout,
+        const std::optional<MemoryConfig> &memory_config);
+};
+}  // namespace ttnn::operations::full_like
+
+namespace ttnn {
+constexpr auto moreh_full_like =
+    ttnn::register_operation_with_auto_launch_op<"ttnn::moreh_full_like", ttnn::operations::full_like::FullLike>();
+}  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/full_like/full_like_pybind.cpp b/ttnn/cpp/ttnn/operations/full_like/full_like_pybind.cpp
new file mode 100644
index 00000000000..1d2fcb1346d
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/full_like/full_like_pybind.cpp
@@ -0,0 +1,46 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "full_like_pybind.hpp"
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "full_like.hpp"
+#include "ttnn/cpp/pybind11/decorators.hpp"
+#include "ttnn/operations/full_like/device/full_like_device_operation.hpp"
+
+namespace py = pybind11;
+
+namespace ttnn::operations::full_like {
+
+void bind_full_like_operation(py::module& module) {
+    auto doc =
+        R"doc(full_like(input: Tensor, fill_value: int or float, dtype: DataType, layout: Layout, memory_config: MemoryConfig) -> Tensor
+
+    Creates a tensor with the same shape as the given tensor, filled with the given fill_value, using the specified `memory_config` and converting its data type to `dtype`.
+    This operation currently supports only TILE_LAYOUT.
+
+    Args:
+        * :attr:`input`: The tensor whose shape the output tensor will match.
+        * :attr:`fill_value`: The value used to fill the output tensor.
+        * :attr:`dtype`: The target data type of the output tensor.
+        * :attr:`layout`: The target layout of the output tensor.
+        * :attr:`memory_config`: The memory configuration for the output tensor.
+    )doc";
+
+    bind_registered_operation(
+        module,
+        ttnn::moreh_full_like,
+        doc,
+        ttnn::pybind_arguments_t{
+            py::arg("input"),
+            py::arg("fill_value"),
+            py::kw_only(),
+            py::arg("dtype") = std::nullopt,
+            py::arg("layout") = std::nullopt,
+            py::arg("memory_config") = std::nullopt});
+}
+
+}  // namespace ttnn::operations::full_like
diff --git a/ttnn/cpp/ttnn/operations/full_like/full_like_pybind.hpp b/ttnn/cpp/ttnn/operations/full_like/full_like_pybind.hpp
new file mode 100644
index 00000000000..805755d7b06
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/full_like/full_like_pybind.hpp
@@ -0,0 +1,13 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "pybind11/pybind_fwd.hpp"
+
+namespace py = pybind11;
+
+namespace ttnn::operations::full_like {
+void bind_full_like_operation(py::module& module);
+}  // namespace ttnn::operations::full_like
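
Reviewer note: below is a minimal usage sketch of the new binding, mirroring the unit tests above. It is illustrative only; the device id and the open/close calls are assumptions, not part of this patch (the tests obtain `device` from a fixture instead).

import torch
import ttnn

# Assumes device 0 is available.
device = ttnn.open_device(device_id=0)

# full_like currently requires TILE_LAYOUT and an interleaved, device-resident input.
torch_input = torch.rand((5, 96, 64), dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Fill a same-shape tensor with a scalar; fill_value may be an int or a float.
output_tensor = ttnn.moreh_full_like(input_tensor, 0.15)

# Compare against the torch reference, as the tests do.
result = ttnn.to_torch(ttnn.from_device(output_tensor))
assert torch.equal(result, torch.full_like(torch_input, 0.15))

ttnn.close_device(device)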
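Reviewer note: the program factory hands fill_value to the writer kernel as a raw 32-bit word (the `datatype` union in full_like_factory.cpp, passed as the first runtime argument). A Python sketch of that same bit reinterpretation, for reference only:

import struct

def fill_value_bits(fill_value):
    # int: the two's-complement bits the union stores via u.u32
    if isinstance(fill_value, int):
        return fill_value & 0xFFFFFFFF
    # float: the IEEE-754 single-precision bit pattern stored via u.f32
    return struct.unpack("<I", struct.pack("<f", float(fill_value)))[0]

assert fill_value_bits(3) == 3
assert fill_value_bits(-1) == 0xFFFFFFFF
assert fill_value_bits(-1.2) == 0xBF99999A  # float32 bits of -1.2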