Commit 46e2f9b (1 parent: 219678c)
Showing 10 changed files with 510 additions and 11 deletions.
@@ -0,0 +1,122 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import copy
import torch
import torch.nn as nn
import ttnn
from models.utility_functions import comp_allclose
from loguru import logger

from tests.ttnn.utils_for_testing import assert_equal


@pytest.mark.parametrize(
    "input_shape",
    [
        [32, 32],  # single tile
        [5, 96, 64],  # multiple tiles
    ],
)
@pytest.mark.parametrize(
    "fill_value",
    [3, -1],
)
@pytest.mark.parametrize(
    "layout",
    [
        ttnn.TILE_LAYOUT,  # Currently only support tile layout
    ],
)
def test_full_like_int(device, input_shape, fill_value, layout):
    torch_input_tensor = torch.randint(0, 100, (input_shape), dtype=torch.int32)
    torch_output_tensor = torch.full_like(torch_input_tensor, fill_value)

    input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device)
    input_tensor = ttnn.to_device(input_tensor, device)
    output_tensor = ttnn.moreh_full_like(input_tensor, fill_value)
    assert ttnn.is_tensor_storage_on_device(output_tensor)
    output_tensor = ttnn.from_device(output_tensor)
    output_tensor = ttnn.to_torch(output_tensor)

    assert_equal(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
    "input_shape",
    [
        [32, 32],  # single tile
        [5, 96, 64],  # multiple tiles
        [3, 91, 67, 77],  # not multiple of 32
    ],
)
@pytest.mark.parametrize(
    "fill_value",
    [0.15, -1.2],
)
@pytest.mark.parametrize(
    "dtype",
    [
        torch.bfloat16,
        torch.float32,
    ],
)
@pytest.mark.parametrize(
    "layout",
    [
        ttnn.TILE_LAYOUT,  # Currently only support tile layout
    ],
)
def test_full_like_float(device, input_shape, fill_value, dtype, layout):
    torch_input_tensor = torch.rand((input_shape), dtype=dtype)
    torch_output_tensor = torch.full_like(torch_input_tensor, fill_value)

    input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device)
    input_tensor = ttnn.to_device(input_tensor, device)
    output_tensor = ttnn.moreh_full_like(input_tensor, fill_value)
    assert ttnn.is_tensor_storage_on_device(output_tensor)
    output_tensor = ttnn.from_device(output_tensor)
    output_tensor = ttnn.to_torch(output_tensor)

    assert_equal(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
    "input_shape",
    [
        [32, 32],  # single tile
    ],
)
@pytest.mark.parametrize(
    "fill_value",
    [3],
)
@pytest.mark.parametrize(
    "layout",
    [
        ttnn.TILE_LAYOUT,  # Currently only support tile layout
    ],
)
def test_full_like_callback(device, input_shape, fill_value, layout, use_program_cache):
    for i in range(2):
        torch_input_tensor = torch.randint(0, 100, (input_shape), dtype=torch.int32)
        torch_output_tensor = torch.full_like(torch_input_tensor, fill_value)

        input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device)
        input_tensor = ttnn.to_device(input_tensor, device)
        output_tensor = ttnn.moreh_full_like(input_tensor, fill_value)
        assert ttnn.is_tensor_storage_on_device(output_tensor)
        output_tensor = ttnn.from_device(output_tensor)
        output_tensor = ttnn.to_torch(output_tensor)
        if i == 0:
            num_program_cache_entries = device.num_program_cache_entries()
            assert num_program_cache_entries > 0
        else:
            assert device.num_program_cache_entries() == num_program_cache_entries
        torch_dummy = torch.randn([32, 32])
        tt_dummy = ttnn.from_torch(torch_dummy, device=device)

    assert_equal(torch_output_tensor, output_tensor)
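For quick reference, the pattern every test above follows is: build a torch tensor on the host, move it to the device with ttnn.from_torch, call ttnn.moreh_full_like, bring the result back with ttnn.to_torch, and compare against torch.full_like. Below is a minimal sketch of that round trip, assuming the device is opened directly with ttnn.open_device rather than through the pytest device fixture; the shape and fill value are arbitrary.

import torch
import ttnn

device = ttnn.open_device(device_id=0)  # the tests above receive `device` from a pytest fixture

torch_input = torch.randint(0, 100, [32, 32], dtype=torch.int32)
expected = torch.full_like(torch_input, 7)

tt_input = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
tt_output = ttnn.moreh_full_like(tt_input, 7)  # every element of the output is the fill value
result = ttnn.to_torch(ttnn.from_device(tt_output))  # back to a host torch tensor

assert torch.equal(expected, result)
ttnn.close_device(device)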
ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.cpp (76 additions, 0 deletions)
@@ -0,0 +1,76 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "full_like_device_operation.hpp"

#include <optional>

#include "ttnn/tensor/tensor.hpp"

namespace ttnn::operations::full_like {

FullLikeOperation::program_factory_t FullLikeOperation::select_program_factory(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    return ProgramFactory{};
}

void FullLikeOperation::validate(const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    const auto& input = tensor_args.input;
    if (operation_attributes.dtype != input.get_dtype())
        TT_FATAL(
            input.get_layout() == Layout::TILE, "Full Like: Data type conversion is only supported with tile layout");
    TT_FATAL(input.storage_type() == StorageType::DEVICE, "Full Like: Input must be on device");
    TT_FATAL(input.buffer() != nullptr, "Full Like: Input must be allocated in buffer on device");
    TT_FATAL(
        input.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
        "Full Like: Not currently supporting sharding");
    TT_FATAL(
        operation_attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED,
        "Full Like: Not currently supporting sharding");
    TT_FATAL(operation_attributes.layout == Layout::TILE, "Full Like: Not currently supporting row major layout");
}

void FullLikeOperation::validate_on_program_cache_miss(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    validate(operation_attributes, tensor_args);
}

void FullLikeOperation::validate_on_program_cache_hit(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    validate(operation_attributes, tensor_args);
}

FullLikeOperation::shape_return_value_t FullLikeOperation::compute_output_shapes(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    return tensor_args.input.get_logical_shape();
}

FullLikeOperation::tensor_return_value_t FullLikeOperation::create_output_tensors(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    const auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
    const auto& input = tensor_args.input;
    return create_device_tensor(
        output_shape,
        operation_attributes.dtype,
        operation_attributes.layout,
        input.device(),
        operation_attributes.memory_config);
}

std::tuple<FullLikeOperation::operation_attributes_t, FullLikeOperation::tensor_args_t> FullLikeOperation::invoke(
    const Tensor& input,
    const std::variant<float, int> fill_value,
    const std::optional<DataType>& dtype,
    const std::optional<Layout>& layout,
    const std::optional<MemoryConfig>& memory_config) {
    return {
        operation_attributes_t{
            fill_value,
            dtype.value_or(input.tensor_attributes->dtype),
            layout.value_or(input.tensor_attributes->layout),
            memory_config.value_or(input.memory_config())},
        tensor_args_t{input}};
}

}  // namespace ttnn::operations::full_like
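In invoke above, dtype, layout, and memory_config are optional and default to the input tensor's attributes, while validate restricts the operation to tile layout and interleaved memory. The following is a hedged sketch of how those options might be passed from Python, assuming the user-facing binding forwards the same keyword arguments as FullLikeOperation::invoke; the binding is not part of this diff excerpt, so the keyword names are an assumption.

# Hypothetical keyword names, inferred from FullLikeOperation::invoke.
out = ttnn.moreh_full_like(
    tt_input,
    2.0,
    dtype=ttnn.bfloat16,  # defaults to the input dtype when omitted
    layout=ttnn.TILE_LAYOUT,  # validate() rejects row-major layout
    memory_config=ttnn.DRAM_MEMORY_CONFIG,  # must stay interleaved; sharding is rejected
)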
ttnn/cpp/ttnn/operations/full_like/device/full_like_device_operation.hpp (71 additions, 0 deletions)
@@ -0,0 +1,71 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <optional>
#include <variant>

#include "ttnn/decorators.hpp"
#include "ttnn/device_operation.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/types.hpp"

namespace ttnn::operations::full_like {

struct FullLikeOperation {
    struct operation_attributes_t {
        const std::variant<float, int> fill_value;
        const DataType dtype;
        const Layout layout;
        const MemoryConfig memory_config;
    };

    struct tensor_args_t {
        const Tensor& input;
    };

    using shape_return_value_t = SimpleShape;
    using tensor_return_value_t = Tensor;

    struct ProgramFactory {
        struct shared_variables_t {
            KernelHandle writer_kernel_id;
            std::size_t num_cores;
            std::size_t num_cores_y;
        };
        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

        static cached_program_t create(
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& output);

        static void override_runtime_arguments(
            cached_program_t& cached_program,
            const operation_attributes_t& operation_attributes,
            const tensor_args_t& tensor_args,
            tensor_return_value_t& output);
    };

    using program_factory_t = std::variant<ProgramFactory>;
    static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
    static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
    static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
    static void validate(const operation_attributes_t&, const tensor_args_t&);
    static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);

    static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
    static std::tuple<operation_attributes_t, tensor_args_t> invoke(
        const Tensor& input,
        const std::variant<float, int> fill_value,
        const std::optional<DataType>& dtype,
        const std::optional<Layout>& layout,
        const std::optional<MemoryConfig>& memory_config);
};

}  // namespace ttnn::operations::full_like

namespace ttnn::prim {
constexpr auto moreh_full_like =
    ttnn::register_operation<"ttnn::prim::moreh_full_like", ttnn::operations::full_like::FullLikeOperation>();
}  // namespace ttnn::prim