#13670: Implement full_like operation (#13671)
thanhnguyen-moreh authored Oct 22, 2024
1 parent 219678c commit 46e2f9b
Showing 10 changed files with 510 additions and 11 deletions.
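
The new ttnn.moreh_full_like mirrors torch.full_like: it returns a tensor with the input's shape, filled with a scalar. A minimal usage sketch distilled from the tests below; the open_device/close_device scaffolding is an assumption of this sketch, not part of the commit:

import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.randint(0, 100, [32, 32], dtype=torch.int32)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Fill a like-shaped tensor with a scalar; only tile layout is currently supported.
output_tensor = ttnn.moreh_full_like(input_tensor, 3)
result = ttnn.to_torch(ttnn.from_device(output_tensor))
assert torch.equal(result, torch.full_like(torch_input, 3))

ttnn.close_device(device)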
122 changes: 122 additions & 0 deletions tests/ttnn/unit_tests/operations/test_full_like.py
@@ -0,0 +1,122 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import copy
import torch
import torch.nn as nn
import ttnn
from models.utility_functions import comp_allclose
from loguru import logger

from tests.ttnn.utils_for_testing import assert_equal


@pytest.mark.parametrize(
"input_shape",
[
[32, 32], # single tile
[5, 96, 64], # multiple tiles
],
)
@pytest.mark.parametrize(
"fill_value",
[3, -1],
)
@pytest.mark.parametrize(
"layout",
[
ttnn.TILE_LAYOUT, # Currently only tile layout is supported
],
)
def test_full_like_int(device, input_shape, fill_value, layout):
torch_input_tensor = torch.randint(0, 100, (input_shape), dtype=torch.int32)
torch_output_tensor = torch.full_like(torch_input_tensor, fill_value)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device)
input_tensor = ttnn.to_device(input_tensor, device)
output_tensor = ttnn.moreh_full_like(input_tensor, fill_value)
assert ttnn.is_tensor_storage_on_device(output_tensor)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)

assert_equal(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
"input_shape",
[
[32, 32], # single tile
[5, 96, 64], # multiple tiles
[3, 91, 67, 77], # not multiple of 32
],
)
@pytest.mark.parametrize(
"fill_value",
[0.15, -1.2],
)
@pytest.mark.parametrize(
"dtype",
[
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"layout",
[
ttnn.TILE_LAYOUT, # Currently only tile layout is supported
],
)
def test_full_like_float(device, input_shape, fill_value, dtype, layout):
torch_input_tensor = torch.rand((input_shape), dtype=dtype)
torch_output_tensor = torch.full_like(torch_input_tensor, fill_value)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device)
input_tensor = ttnn.to_device(input_tensor, device)
output_tensor = ttnn.moreh_full_like(input_tensor, fill_value)
assert ttnn.is_tensor_storage_on_device(output_tensor)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)

assert_equal(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
"input_shape",
[
[32, 32], # single tile
],
)
@pytest.mark.parametrize(
"fill_value",
[3],
)
@pytest.mark.parametrize(
"layout",
[
ttnn.TILE_LAYOUT, # Currently only tile layout is supported
],
)
def test_full_like_callback(device, input_shape, fill_value, layout, use_program_cache):
for i in range(2):
torch_input_tensor = torch.randint(0, 100, (input_shape), dtype=torch.int32)
torch_output_tensor = torch.full_like(torch_input_tensor, fill_value)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device)
input_tensor = ttnn.to_device(input_tensor, device)
output_tensor = ttnn.moreh_full_like(input_tensor, fill_value)
assert ttnn.is_tensor_storage_on_device(output_tensor)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)
if i == 0:
num_program_cache_entries = device.num_program_cache_entries()
assert num_program_cache_entries > 0
else:
assert device.num_program_cache_entries() == num_program_cache_entries
torch_dummy = torch.randn([32, 32])
tt_dummy = ttnn.from_torch(torch_dummy, device=device)

assert_equal(torch_output_tensor, output_tensor)
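
A note on the callback test: allocating a fresh dummy tensor each iteration shifts subsequent buffer addresses, so the second pass appears intended to exercise override_runtime_arguments on a program-cache hit (the entry count must stay constant) rather than recompiling the program.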
26 changes: 15 additions & 11 deletions ttnn/CMakeLists.txt
@@ -256,6 +256,10 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full/device/full_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full/full_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full/full.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full_like/full_like.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full_like/full_like_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full_like/device/full_like_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/loss/loss.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/loss/loss_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -372,8 +376,8 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adamw/device/moreh_adamw_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adamw/device/multi_core_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adamw/moreh_adamw.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange_pybind.cpp
@@ -430,12 +434,12 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_linear/moreh_linear_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_linear/moreh_linear.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward_pybind.cpp
@@ -475,14 +479,6 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sgd/device/moreh_sgd_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sgd/moreh_sgd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/moreh_softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/moreh_softmax_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/moreh_softmax_backward_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_c_large/softmax_backward_c_large.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_large/softmax_backward_h_large.cpp
@@ -492,12 +488,20 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax_backward/moreh_softmax_backward_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax_backward/moreh_softmax_backward.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/moreh_softmax_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_softmax/moreh_softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_program_factory.cpp
4 changes: 4 additions & 0 deletions ttnn/cpp/pybind11/operations/__init__.hpp
@@ -27,6 +27,7 @@
#include "ttnn/operations/embedding/embedding_pybind.hpp"
#include "ttnn/operations/embedding_backward/embedding_backward_pybind.hpp"
#include "ttnn/operations/examples/examples_pybind.hpp"
#include "ttnn/operations/full_like/full_like_pybind.hpp"
#include "ttnn/operations/experimental/experimental_pybind.hpp"
#include "ttnn/operations/full/full_pybind.hpp"
#include "ttnn/operations/kv_cache/kv_cache_pybind.hpp"
@@ -140,6 +141,9 @@ void py_module(py::module& module) {

auto m_moreh = module.def_submodule("moreh", "moreh operations");
moreh::bind_moreh_operations(m_moreh);

auto m_full_like = module.def_submodule("full_like", "full_like operation");
full_like::bind_full_like_operation(m_full_like);
}
} // namespace operations
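
(The operation itself is exposed to users as ttnn.moreh_full_like, as exercised by the tests above; the binding body lives in full_like_pybind.cpp, which is among the changed files but not shown in this excerpt.)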

76 changes: 76 additions & 0 deletions ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp
@@ -0,0 +1,76 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "full_like_device_operation.hpp"

#include <optional>

#include "ttnn/tensor/tensor.hpp"

namespace ttnn::operations::full_like {

FullLikeOperation::program_factory_t FullLikeOperation::select_program_factory(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return ProgramFactory{};
}

void FullLikeOperation::validate(const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& input = tensor_args.input;
if (operation_attributes.dtype != input.get_dtype())
TT_FATAL(
input.get_layout() == Layout::TILE, "Full Like: Data type conversion is only supported with tile layout");
TT_FATAL(input.storage_type() == StorageType::DEVICE, "Full Like: Input must be on device");
TT_FATAL(input.buffer() != nullptr, "Full Like: Input must be allocated in buffer on device");
TT_FATAL(
input.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
"Full Like: Not currently supporting sharding");
TT_FATAL(
operation_attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED,
"Full Like: Not currently supporting sharding");
TT_FATAL(operation_attributes.layout == Layout::TILE, "Full Like: Not currently supporting row major layout");
}
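
// Note: with the defaults applied in invoke() below, `layout` falls back to the
// input's layout, so a row-major input fails the tile-layout check in validate().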

void FullLikeOperation::validate_on_program_cache_miss(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
validate(operation_attributes, tensor_args);
}

void FullLikeOperation::validate_on_program_cache_hit(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
validate(operation_attributes, tensor_args);
}

FullLikeOperation::shape_return_value_t FullLikeOperation::compute_output_shapes(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return tensor_args.input.get_logical_shape();
}

FullLikeOperation::tensor_return_value_t FullLikeOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
const auto& input = tensor_args.input;
return create_device_tensor(
output_shape,
operation_attributes.dtype,
operation_attributes.layout,
input.device(),
operation_attributes.memory_config);
}

std::tuple<FullLikeOperation::operation_attributes_t, FullLikeOperation::tensor_args_t> FullLikeOperation::invoke(
const Tensor& input,
const std::variant<float, int> fill_value,
const std::optional<DataType>& dtype,
const std::optional<Layout>& layout,
const std::optional<MemoryConfig>& memory_config) {
return {
operation_attributes_t{
fill_value,
dtype.value_or(input.tensor_attributes->dtype),
layout.value_or(input.tensor_attributes->layout),
memory_config.value_or(input.memory_config())},
tensor_args_t{input}};
}

} // namespace ttnn::operations::full_like
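
Since invoke() falls back to the input's dtype, layout, and memory_config when the optionals are empty, the Python call in the tests needs only a tensor and a fill value. A hedged sketch of overriding the dtype, assuming the binding exposes keyword arguments mirroring these optional parameters (the keyword names are an assumption, not confirmed by this excerpt):

# Assumed keyword name mirroring invoke()'s optional dtype parameter.
out = ttnn.moreh_full_like(input_tensor, 0.5, dtype=ttnn.bfloat16)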
71 changes: 71 additions & 0 deletions ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp
@@ -0,0 +1,71 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <optional>
#include <variant>

#include "ttnn/decorators.hpp"
#include "ttnn/device_operation.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/types.hpp"

namespace ttnn::operations::full_like {

struct FullLikeOperation {
struct operation_attributes_t {
const std::variant<float, int> fill_value;
const DataType dtype;
const Layout layout;
const MemoryConfig memory_config;
};

struct tensor_args_t {
const Tensor& input;
};

using shape_return_value_t = SimpleShape;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
struct shared_variables_t {
KernelHandle writer_kernel_id;
std::size_t num_cores;
std::size_t num_cores_y;
};
using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

static cached_program_t create(
const operation_attributes_t& operation_attributes,
const tensor_args_t& tensor_args,
tensor_return_value_t& output);

static void override_runtime_arguments(
cached_program_t& cached_program,
const operation_attributes_t& operation_attributes,
const tensor_args_t& tensor_args,
tensor_return_value_t& output);
};

using program_factory_t = std::variant<ProgramFactory>;
static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static void validate(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);

static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
static std::tuple<operation_attributes_t, tensor_args_t> invoke(
const Tensor& input,
const std::variant<float, int> fill_value,
const std::optional<DataType>& dtype,
const std::optional<Layout>& layout,
const std::optional<MemoryConfig>& memory_config);
};

} // namespace ttnn::operations::full_like

namespace ttnn::prim {
constexpr auto moreh_full_like =
ttnn::register_operation<"ttnn::prim::moreh_full_like", ttnn::operations::full_like::FullLikeOperation>();
} // namespace ttnn::prim
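
Registering under ttnn::prim exposes the raw device operation to C++ callers; the ttnn.moreh_full_like entry point used by the tests is presumably the higher-level wrapper defined in full_like.cpp and bound in full_like_pybind.cpp.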
(The remaining five changed files are not shown in this excerpt.)