Skip to content

Commit

Permalink
#15483: initial setup - binary sfpu
Browse files Browse the repository at this point in the history
  • Loading branch information
KalaivaniMCW committed Nov 28, 2024
1 parent 18de632 commit 6a3008b
Show file tree
Hide file tree
Showing 15 changed files with 928 additions and 60 deletions.
126 changes: 126 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_fp32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn

import pytest
from models.utility_functions import skip_for_grayskull
from tests.ttnn.utils_for_testing import assert_with_pcc


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"ttnn_function",
[
ttnn.sub,
],
)
def test_sub_fp32(device, ttnn_function):
torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
x_torch = torch.tensor([[1]], dtype=torch.float32)
y_torch = torch.tensor([[0.00030171126]], dtype=torch.float32)
z_torch = x_torch - y_torch
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_sub = ttnn.subtract(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_sub)

status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
assert status


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"ttnn_function",
[
ttnn.add,
],
)
def test_add_fp32(device, ttnn_function):
torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
x_torch = torch.tensor([[1]], dtype=torch.float32)
y_torch = torch.tensor([[0.00030171126]], dtype=torch.float32)
z_torch = x_torch + y_torch
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_add = ttnn.add(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_add)

status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
assert status


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"ttnn_function",
[
ttnn.mul,
],
)
def test_mul_fp32(device, ttnn_function):
torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
x_torch = torch.tensor([[2]], dtype=torch.float32)
y_torch = torch.tensor([[0.00030171126]], dtype=torch.float32)
z_torch = x_torch * y_torch
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_mul = ttnn.mul(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_mul)

status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
assert status

# currently failing as div sfpu tile is performing multiplication


# @skip_for_grayskull("Unsupported dtype for Grayskull")
# @pytest.mark.parametrize(
# "ttnn_function",
# [
# ttnn.div,
# ],
# )
# def test_div_fp32(device, ttnn_function):
# torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
# x_torch = torch.tensor([[1.00030171126]], dtype=torch.float32)
# y_torch = torch.tensor([[2]], dtype=torch.float32)
# z_torch = x_torch / y_torch
# x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
# y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
# z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
# z_tt_div = ttnn.divide(x_tt, y_tt)
# tt_out = ttnn.to_torch(z_tt_div)
# print("inputs a, b", x_torch, y_torch)
# print(z_torch, ttnn.to_torch(z_tt), tt_out)
# # print("torch out", z_torch, )
# print("torch out in ttnn", ttnn.to_torch(z_tt))
# print("tt out in torch", tt_out)
# status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
# assert status


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"ttnn_function",
[
ttnn.pow,
],
)
def test_pow_fp32(device, ttnn_function):
torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
x_torch = torch.tensor([[1.55, 2.25]], dtype=torch.float32)
y_torch = torch.tensor([[2, 3]], dtype=torch.float32)
z_torch = torch.pow(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_pow = ttnn.pow(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_pow)

status = ttnn.ttnn.pearson_correlation_coefficient(z_torch, tt_out) >= 0.999
assert status
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_sfpu_pgm_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_sharded_optimized_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_sharded_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp
Expand Down
54 changes: 54 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,57 @@ Tensor InplaceBinaryOperation<binary_op_type>::invoke(
input_tensor_a, scalar, std::nullopt, std::nullopt, input_tensor_a, activations, input_tensor_a_activation);
}

template <BinaryOpType binary_op_type>
Tensor BinaryOperationSfpu<binary_op_type>::invoke(
uint8_t queue_id,
const Tensor& input_tensor_a_arg,
const Tensor& input_tensor_b_arg,
const std::optional<const DataType>& output_dtype,
const std::optional<MemoryConfig>& memory_config,
const std::optional<Tensor>& optional_output_tensor,
const std::optional<unary::FusedActivations>& activations,
const std::optional<unary::UnaryWithParam>& input_tensor_a_activation) {
auto [input_tensor_a, input_tensor_b] =
detail::preprocess_inputs<binary_op_type>(input_tensor_a_arg, input_tensor_b_arg);

auto output_memory_config = memory_config.value_or(input_tensor_a.memory_config());
DataType dtype = output_dtype.value_or(input_tensor_a.get_dtype());
if (optional_output_tensor.has_value()) {
dtype = optional_output_tensor.value().get_dtype();
}

return ttnn::prim::binary(
queue_id,
input_tensor_a,
input_tensor_b,
binary_op_type,
output_dtype,
memory_config,
optional_output_tensor,
activations,
input_tensor_a_activation);
}

template <BinaryOpType binary_op_type>
Tensor BinaryOperationSfpu<binary_op_type>::invoke(
const Tensor& input_tensor_a_arg,
const Tensor& input_tensor_b_arg,
const std::optional<const DataType>& output_dtype,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> optional_output_tensor,
std::optional<unary::FusedActivations> activations,
std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
return invoke(
DefaultQueueId,
input_tensor_a_arg,
input_tensor_b_arg,
output_dtype,
memory_config,
optional_output_tensor,
activations,
input_tensor_a_activation);
}

template struct BinaryOperation<BinaryOpType::ADD>;
template struct InplaceBinaryOperation<BinaryOpType::ADD>;
template struct BinaryOperation<BinaryOpType::SUB>;
Expand Down Expand Up @@ -422,4 +473,7 @@ template struct InplaceLogicalBinary<BinaryOpType::LOGICAL_AND>;
template struct InplaceLogicalBinary<BinaryOpType::LOGICAL_OR>;
template struct InplaceLogicalBinary<BinaryOpType::LOGICAL_XOR>;

template struct BinaryOperationSfpu<BinaryOpType::RSUB>;
template struct BinaryOperationSfpu<BinaryOpType::POWER>;

} // namespace ttnn::operations::binary
29 changes: 29 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,28 @@ struct InplaceBinaryOperation {
std::optional<unary::UnaryWithParam> input_tensor_a_activation = std::nullopt);
};

template <BinaryOpType binary_op_type>
struct BinaryOperationSfpu {
static Tensor invoke(
uint8_t queue_id,
const Tensor& input_tensor_a_arg,
const Tensor& input_tensor_b_arg,
const std::optional<const DataType>& output_dtype = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt,
const std::optional<unary::FusedActivations>& activations = std::nullopt,
const std::optional<unary::UnaryWithParam>& input_tensor_a_activation = std::nullopt);

static Tensor invoke(
const Tensor& input_tensor_a_arg,
const Tensor& input_tensor_b_arg,
const std::optional<const DataType>& output_dtype = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt,
std::optional<unary::FusedActivations> activations = std::nullopt,
std::optional<unary::UnaryWithParam> input_tensor_a_activation = std::nullopt);
};

} // namespace binary
} // namespace operations

Expand Down Expand Up @@ -227,6 +249,13 @@ constexpr auto ne_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::ne_",
operations::binary::InplaceRelationalBinary<operations::binary::BinaryOpType::NE>>();

constexpr auto rsub_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::rsub_binary",
operations::binary::BinaryOperationSfpu<operations::binary::BinaryOpType::RSUB>>();
constexpr auto power_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::power_binary",
operations::binary::BinaryOperationSfpu<operations::binary::BinaryOpType::POWER>>();

template <typename InputBType>
ttnn::Tensor operator+(const ttnn::Tensor& input_tensor_a, InputBType scalar) {
return add(input_tensor_a, scalar);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ enum class BinaryOpType {
LOGICAL_XOR,
LDEXP,
LOGADDEXP2,
DIV_FAST
DIV_FAST,
RSUB,
POWER
};
}
26 changes: 26 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,30 @@ std::map<std::string, std::string> get_defines(
return defines;
}

std::map<std::string, std::string> get_defines_fp32(
BinaryOpType op_type,
const std::optional<tt::tt_metal::DataType> input_dtype,
const std::optional<tt::tt_metal::DataType> output_dtype,
const std::optional<std::vector<UnaryWithParam>> fused_activations,
const std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
std::map<std::string, std::string> new_defines;
std::string op_name = "sub_binary_tile";
std::string idst1 = "i*2";
std::string idst2 = "i*2+1";

switch (op_type) {
case BinaryOpType::ADD: op_name = "add_binary_tile"; break;
case BinaryOpType::SUB: op_name = "sub_binary_tile"; break;
case BinaryOpType::MUL: op_name = "mul_binary_tile"; break;
case BinaryOpType::DIV_FAST: op_name = "div_binary_tile"; break;
case BinaryOpType::RSUB: op_name = "rsub_binary_tile"; break;
case BinaryOpType::POWER: op_name = "power_binary_tile"; break;
default: TT_ASSERT(false && "Undefined op type");
}

new_defines.insert({"BINARY_SFPU_OP", fmt::format("{}({}, {});", op_name, idst1, idst2)});

return new_defines;
}

} // namespace ttnn::operations::binary::utils
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,10 @@ std::map<std::string, std::string> get_defines(
const std::optional<ttnn::operations::unary::FusedActivations>& fused_activations = std::nullopt,
const std::optional<ttnn::operations::unary::UnaryWithParam>& input_tensor_a_activation = std::nullopt);

std::map<std::string, std::string> get_defines_fp32(
BinaryOpType op_type,
const std::optional<tt::tt_metal::DataType> in_dtype = std::nullopt,
const std::optional<tt::tt_metal::DataType> out_dtype = std::nullopt,
const std::optional<ttnn::operations::unary::FusedActivations> fused_activations = std::nullopt,
const std::optional<ttnn::operations::unary::UnaryWithParam> input_tensor_a_activation = std::nullopt);
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ BinaryDeviceOperation::program_factory_t BinaryDeviceOperation::select_program_f
auto width_b = input_shape_b[-1];

if (height_a == height_b and width_a == width_b) {
return ElementWiseMultiCore{};
if(tensor_args.input_tensor_a.get_dtype() == DataType::FLOAT32 && tensor_args.input_tensor_b->get_dtype() == DataType::FLOAT32){
return ElementWiseMultiCoreSfpu{};
} else {
return ElementWiseMultiCore{};
}
}
if (height_b == 1 or width_b == 1) {
if (height_b == 1 and width_b == 1) {
Expand Down Expand Up @@ -192,7 +196,7 @@ BinaryDeviceOperation::spec_return_value_t BinaryDeviceOperation::compute_output
auto output_shape = compute_broadcasted_output(input_shape_a, input_shape_b);

auto program_factory = select_program_factory(operation_attributes, tensor_args);
if (std::holds_alternative<ElementWiseMultiCore>(program_factory)) {
if (std::holds_alternative<ElementWiseMultiCore>(program_factory) or std::holds_alternative<ElementWiseMultiCoreSfpu>(program_factory)) {
const auto& input_tensor_b = *tensor_args.input_tensor_b;
if (operation_attributes.memory_config.is_sharded()) {
ShardSpec shard_spec{CoreRangeSet(), {0, 0}};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,32 @@ struct BinaryDeviceOperation {
tensor_return_value_t& tensor_return_value);
};

struct ElementWiseMultiCoreSfpu {
struct shared_variables_t {
KernelHandle binary_reader_kernel_id;
KernelHandle unary_writer_kernel_id;
KernelHandle eltwise_binary_kernel_id;
CBHandle cb_src0;
CBHandle cb_src1;
CBHandle cb_output;
CoreCoord compute_with_storage_grid_size;
uint32_t src0_single_tile_size;
uint32_t src1_single_tile_size;
uint32_t dst_single_tile_size;
};
using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;

static cached_program_t create(
const operation_attributes_t& operation_attributes,
const tensor_args_t& tensor_args,
tensor_return_value_t& tensor_return_value);

static void override_runtime_arguments(
cached_program_t& cached_program,
const operation_attributes_t& operation_attributes,
const tensor_args_t& tensor_args,
tensor_return_value_t& tensor_return_value);
};
struct BroadcastWidthMultiCore {
struct shared_variables_t {
KernelHandle binary_reader_kernel_id;
Expand Down Expand Up @@ -192,6 +218,7 @@ struct BinaryDeviceOperation {

using program_factory_t = std::variant<
ElementWiseMultiCore,
ElementWiseMultiCoreSfpu,
BroadcastWidthMultiCore,
BroadcastHeightMultiCore,
BroadcastHeightAndWidthMultiCore,
Expand Down
Loading

0 comments on commit 6a3008b

Please sign in to comment.