From a99e6cde56c08df4abfbded37844b4c0c773670c Mon Sep 17 00:00:00 2001 From: Chen-Pang He Date: Thu, 19 Sep 2024 18:31:57 +0000 Subject: [PATCH] Convert `aten.bitwise_not` to `ttnn.bitwise_not` `ttnn.bitwise_not` takes an extraneous `value: int`. I'm fixing this in the kernel. --- .../eltwise/unary/test_bitwise_not.py | 34 +++++++++++++++++++ .../passes/lowering/add_data_move_pass.py | 2 +- torch_ttnn/passes/lowering/to_tt_pass.py | 3 ++ 3 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 tests/lowering/eltwise/unary/test_bitwise_not.py diff --git a/tests/lowering/eltwise/unary/test_bitwise_not.py b/tests/lowering/eltwise/unary/test_bitwise_not.py new file mode 100644 index 000000000..ddb60bfb0 --- /dev/null +++ b/tests/lowering/eltwise/unary/test_bitwise_not.py @@ -0,0 +1,34 @@ +import torch +import torch_ttnn +import pytest +import ttnn + + +class BitwiseNotModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.bitwise_not(input) + + +@pytest.mark.parametrize( + "input_shape", + [(4, 4)], +) +def test_bitwise_not(device, input_shape): + m = BitwiseNotModule() + input = torch.randint(-256, 256, input_shape, dtype=torch.int32) + result_before = m.forward(input) + option = torch_ttnn.TorchTtnnOption(device=device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(input) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has been rewritten and contains ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + assert [node.target for node in nodes].count(ttnn.bitwise_not) == 1 + # Check inference result + assert torch.equal(result_before, result_after) diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py index 7add48c36..40d48caa7 100644 --- a/torch_ttnn/passes/lowering/add_data_move_pass.py +++ b/torch_ttnn/passes/lowering/add_data_move_pass.py @@ -33,7 +33,7 @@ def is_function_call(node) -> bool: ttnn.atan, ttnn.atan2, # binary ttnn.atanh, - # ttnn.clone, in target_wrappers + ttnn.bitwise_not, ttnn.cos, ttnn.cosh, ttnn.erf, diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py index b0bb2c8f2..ff514e70e 100644 --- a/torch_ttnn/passes/lowering/to_tt_pass.py +++ b/torch_ttnn/passes/lowering/to_tt_pass.py @@ -134,6 +134,9 @@ def call_function(self, target, args, kwargs): if target == torch.ops.aten.atanh.default: return self.call_function_prop_meta(ttnn.atanh, args, kwargs) + if target == torch.ops.aten.bitwise_not.default: + return self.call_function_prop_meta(ttnn.bitwise_not, args, kwargs) + if target == torch.ops.aten.clamp.default: return self.call_function_prop_meta(ttnn.clip, args, kwargs)