#13527: Update ttnn.clamp logic to match PyTorch API (#13530)
* #13527: Cleanup clamp

* #13527: Update clamp to match PyTorch API
VirdhatchaniKN authored Oct 16, 2024
1 parent abf2adb commit a73be8a
Showing 6 changed files with 108 additions and 26 deletions.
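The change replaces the float-only `clamp` composite with a PyTorch-style signature in which `min` and `max` are optional and at least one must be supplied. A minimal usage sketch of the post-change API (device setup, tensor shape, and values are illustrative, not taken from the commit):

```python
import torch
import ttnn

# Illustrative sketch of the PyTorch-style ttnn.clamp API introduced by this commit.
device = ttnn.open_device(device_id=0)

torch_input = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

lower_only = ttnn.clamp(input_tensor, min=-0.5)             # clamp from below only
upper_only = ttnn.clamp(input_tensor, max=0.5)              # clamp from above only
both_bounds = ttnn.clamp(input_tensor, min=-1.0, max=1.0)   # clamp on both sides

# Omitting both bounds is rejected:
# ttnn.clamp(input_tensor)  # RuntimeError: Only one of 'min' or 'max' can be None. Please provide one value

ttnn.close_device(device)
```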
35 changes: 26 additions & 9 deletions tests/ttnn/unit_tests/operations/eltwise/test_composite.py
@@ -93,16 +93,33 @@ def test_unary_composite_cbrt_ttnn(input_shapes, device):
(torch.Size([1, 3, 320, 384])),
),
)
def test_unary_composite_clamp_ttnn(input_shapes, device):
@pytest.mark.parametrize(
"min, max",
[
(-10, 10),
(1, -1),
(0, 0),
(-1.0, None),
(None, 1.0),
(None, None),
(-0.5, None),
(None, -0.5),
(1.0, 0.0),
(0.0, 1.0),
],
)
def test_unary_composite_clamp_ttnn(input_shapes, min, max, device):
in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
min = -10
max = 10
output_tensor = ttnn.clamp(input_tensor1, min, max)
golden_function = ttnn.get_golden_function(ttnn.clamp)
golden_tensor = golden_function(in_data1, min, max)

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass
if min is None and max is None:
with pytest.raises(RuntimeError, match="Only one of 'min' or 'max' can be None. Please provide one value"):
ttnn.clamp(input_tensor1, min=min, max=max)
assert True
else:
output_tensor = ttnn.clamp(input_tensor1, min=min, max=max)
golden_function = ttnn.get_golden_function(ttnn.clamp)
golden_tensor = golden_function(in_data1, min=min, max=max)
comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass


@pytest.mark.parametrize(
@@ -477,8 +477,24 @@ Tensor _clip(const Tensor& a, float low, float high, const std::optional<MemoryC
}

// clamp
Tensor _clamp(const Tensor& a, float low, float high, const std::optional<MemoryConfig>& output_mem_config) {
return _clip(a, low, high, output_mem_config);
Tensor ExecuteUnaryCompositeClamp::invoke(const Tensor& a, std::optional<float> min, std::optional<float> max, const std::optional<MemoryConfig>& output_mem_config) {
auto output_memory_config = output_mem_config.value_or(a.memory_config());
TT_FATAL((max.has_value() || min.has_value()), "Only one of 'min' or 'max' can be None. Please provide one value");
if (!max.has_value()) {
return ttnn::where(ttnn::ge(a, min.value(), std::nullopt, output_memory_config), a, min.value(), output_memory_config);
} else if (!min.has_value()) {
return ttnn::where(ttnn::le(a, max.value(), std::nullopt, output_memory_config), a, max.value(), output_memory_config);
} else if (min.value() > max.value()) {
return full_like(a, max.value());
}
const Tensor h_const = full_like(a, max.value());
Tensor a_max = ttnn::minimum(a, h_const, output_memory_config);
if (min.value() == 0.0f) {
return ttnn::relu(a_max, output_memory_config);
} else {
const Tensor l_const = full_like(a, min.value());
return ttnn::maximum(a_max, l_const, output_memory_config);
}
}

// hardtanh
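For reference, the new composite decomposition covers four regimes: lower bound only, upper bound only, `min > max` (every element becomes `max`, matching PyTorch), and the general two-sided case (a `minimum` against `max` followed by a `maximum` against `min`, with a ReLU shortcut when `min == 0`). A pure-PyTorch sketch of the same branch structure (illustrative only, not the ttnn implementation):

```python
import torch

def clamp_decomposition(a: torch.Tensor, min=None, max=None) -> torch.Tensor:
    # Sketch mirroring the branches of ExecuteUnaryCompositeClamp::invoke above.
    if min is None and max is None:
        raise RuntimeError("Only one of 'min' or 'max' can be None. Please provide one value")
    if max is None:
        # Lower bound only: keep a where a >= min, otherwise substitute min.
        return torch.where(a >= min, a, torch.full_like(a, min))
    if min is None:
        # Upper bound only: keep a where a <= max, otherwise substitute max.
        return torch.where(a <= max, a, torch.full_like(a, max))
    if min > max:
        # PyTorch semantics: when min > max, every element becomes max.
        return torch.full_like(a, max)
    a_max = torch.minimum(a, torch.full_like(a, max))
    if min == 0.0:
        return torch.relu(a_max)  # min == 0 reduces to a ReLU
    return torch.maximum(a_max, torch.full_like(a, min))
```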
@@ -38,7 +38,6 @@ enum class UnaryCompositeOpType {
HARDSIGMOID,
HARDTANH,
CLIP,
CLAMP,
SELU,
THRESHOLD,
GLU,
@@ -86,7 +85,6 @@ Tensor _hardswish(const Tensor&, float scale = 1.0f/6.0f, float shift = 0.5f, c
Tensor _hardsigmoid(const Tensor&, float scale = 1.0f/6.0f, float shift = 0.5f, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
Tensor _hardtanh(const Tensor&, float min = -1, float max = 1, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
Tensor _clip(const Tensor&, float, float, const std::optional<MemoryConfig>& );
Tensor _clamp(const Tensor&, float, float, const std::optional<MemoryConfig>& );
Tensor _selu(const Tensor&, float scale = 1.0507, float alpha = 1.67326, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
Tensor _threshold(const Tensor&, float, float, const std::optional<MemoryConfig>& );
Tensor _glu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
@@ -280,13 +278,6 @@ struct OpHandler<UnaryCompositeOpType::CLIP> {
}
};

template <>
struct OpHandler<UnaryCompositeOpType::CLAMP> {
static Tensor handle(const Tensor& t1, float low, float high, const std::optional<MemoryConfig>& mem_cfg ) {
return _clamp(t1, low, high, mem_cfg);
}
};

template <>
struct OpHandler<UnaryCompositeOpType::SELU> {
static Tensor handle(const Tensor& t1, float scale, float alpha, const std::optional<MemoryConfig>& mem_cfg ) {
10 changes: 9 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
@@ -115,6 +115,14 @@ struct ExecuteUnaryCompositeOpWithFloats {
}
};

struct ExecuteUnaryCompositeClamp {
static Tensor invoke(
const Tensor &input_tensor,
std::optional<float> min = std::nullopt,
std::optional<float> max = std::nullopt,
const std::optional<MemoryConfig> &memory_config = std::nullopt);
};

template <UnaryCompositeOpType unary_comp_op_type>
struct ExecuteUnaryCompositeOpWithInt {

@@ -265,7 +273,7 @@ constexpr auto clip = ttnn::register_operation_with_auto_launch_op<
operations::unary::ExecuteUnaryCompositeOpWithFloats<operations::unary::UnaryCompositeOpType::CLIP>>();
constexpr auto clamp = ttnn::register_operation_with_auto_launch_op<
"ttnn::clamp",
operations::unary::ExecuteUnaryCompositeOpWithFloats<operations::unary::UnaryCompositeOpType::CLAMP>>();
operations::unary::ExecuteUnaryCompositeClamp>();
constexpr auto selu = ttnn::register_operation_with_auto_launch_op<
"ttnn::selu",
operations::unary::ExecuteUnaryCompositeOpWithFloats<operations::unary::UnaryCompositeOpType::SELU>>();
58 changes: 54 additions & 4 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -23,6 +23,56 @@ namespace unary {

namespace detail {

template <typename unary_operation_t>
void bind_unary_composite_optional_floats_with_default(py::module& module, const unary_operation_t& operation, const std::string& parameter_name_a, const std::string& parameter_a_doc, std::optional<float> parameter_a_value, const std::string& parameter_name_b, const std::string& parameter_b_doc, std::optional<float> parameter_b_value, const std::string& description) {
auto doc = fmt::format(
R"doc(
{8}
Args:
input_tensor (ttnn.Tensor): the input tensor.
Keyword args:
{2} (float): {3}. Defaults to `{4}`.
{5} (float): {6}. Defaults to `{7}`.
memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
Returns:
ttnn.Tensor: the output tensor.
Example:
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor, {2} = {4}, {5} = {7})
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
parameter_name_a,
parameter_a_doc,
parameter_a_value,
parameter_name_b,
parameter_b_doc,
parameter_b_value,
description);

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const ttnn::Tensor& input_tensor,
std::optional<float> parameter_a,
std::optional<float> parameter_b,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor, parameter_a, parameter_b, memory_config);
},
py::arg("input_tensor"),
py::kw_only(),
py::arg(parameter_name_a.c_str()) = parameter_a_value,
py::arg(parameter_name_b.c_str()) = parameter_b_value,
py::arg("memory_config") = std::nullopt});
}
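As a usage note, the `py::kw_only()` marker in the binding above exposes `min` and `max` as keyword-only arguments on the Python side, each defaulting to `None`; a hedged sketch of the resulting call pattern (assuming `tensor` is a tensor already placed on a device):

```python
# Keyword-only call pattern implied by py::kw_only() in the binding (sketch).
out = ttnn.clamp(tensor, min=-1.0)   # max defaults to None
out = ttnn.clamp(tensor, max=1.0)    # min defaults to None
# ttnn.clamp(tensor, -1.0, 1.0)      # positional bounds would not match this keyword-only overload
```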

template <typename unary_operation_t>
void bind_unary_operation(py::module& module, const unary_operation_t& operation, const std::string& math, const std::string& info_doc = "" ) {
auto doc = fmt::format(
@@ -1583,12 +1633,12 @@ void py_module(py::module& module) {
"low", "Low value",
"high", "High value",
R"doc(Performs clip function on :attr:`input_tensor`, :attr:`low`, :attr:`high`.)doc");
detail::bind_unary_composite_floats(
detail::bind_unary_composite_optional_floats_with_default(
module,
ttnn::clamp,
"low", "Low value",
"high", "High value",
R"doc(Performs clamp function on :attr:`input_tensor`, :attr:`low`, :attr:`high`.)doc");
"min", "Minimum value", std::nullopt,
"max", "Maximum value", std::nullopt,
R"doc(Performs clamp function on :attr:`input_tensor`, :attr:`min`, :attr:`max`. Only one of 'min' or 'max' value can be None.)doc");
detail::bind_unary_composite_floats_with_default(
module,
ttnn::selu,
2 changes: 1 addition & 1 deletion ttnn/ttnn/operations/unary.py
@@ -286,7 +286,7 @@ def _golden_function_polygamma(input_tensor_a, k, *args, **kwargs):
ttnn.attach_golden_function(ttnn.polygamma, golden_function=_golden_function_polygamma)


def _golden_function_clamp(input_tensor_a, min, max, *args, **kwargs):
def _golden_function_clamp(input_tensor_a, min=None, max=None, *args, **kwargs):
import torch

return torch.clamp(input=input_tensor_a, min=min, max=max)
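For context, `torch.clamp` accepts `None` for either bound (but not both), which is exactly what the updated golden function now forwards; a small illustration of the reference semantics:

```python
import torch

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
print(torch.clamp(x, min=-1.0))          # lower bound only -> [-1.0, -0.5, 0.0, 0.5, 2.0]
print(torch.clamp(x, max=1.0))           # upper bound only -> [-2.0, -0.5, 0.0, 0.5, 1.0]
print(torch.clamp(x, min=1.0, max=0.0))  # min > max: every element becomes max -> [0.0, 0.0, 0.0, 0.0, 0.0]
```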
