diff --git a/tests/lowering/eltwise/unary/test_bitwise_not.py b/tests/lowering/eltwise/unary/test_bitwise_not.py
new file mode 100644
index 000000000..ddb60bfb0
--- /dev/null
+++ b/tests/lowering/eltwise/unary/test_bitwise_not.py
@@ -0,0 +1,34 @@
+import torch
+import torch_ttnn
+import pytest
+import ttnn
+
+
+class BitwiseNotModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input):
+        return torch.bitwise_not(input)
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [(4, 4)],
+)
+def test_bitwise_not(device, input_shape):
+    m = BitwiseNotModule()
+    input = torch.randint(-256, 256, input_shape, dtype=torch.int32)
+    result_before = m.forward(input)
+    option = torch_ttnn.TorchTtnnOption(device=device)
+    option.gen_graphviz = True
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+    result_after = m.forward(input)
+    option._out_fx_graphs[0].print_tabular()
+
+    # Check the graph has been rewritten and contains ttnn ops
+    nodes = list(option._out_fx_graphs[0].nodes)
+    assert [node.target for node in nodes].count(ttnn.bitwise_not) == 1
+    # Check inference result
+    assert torch.equal(result_before, result_after)
diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py
index 0050ff3a3..8b38fa63b 100644
--- a/torch_ttnn/passes/lowering/add_data_move_pass.py
+++ b/torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -35,7 +35,7 @@ def is_function_call(node) -> bool:
     ttnn.atan,
     ttnn.atan2,  # binary
     ttnn.atanh,
-    # ttnn.clone, in target_wrappers
+    ttnn.bitwise_not,
     ttnn.cos,
     ttnn.cosh,
     ttnn.elu,
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index 906f0cc9b..78654afbe 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -190,6 +190,9 @@ def __init__(self, target, args, kwargs):
         if target == torch.ops.aten.atanh.default:
             return self.call_function_prop_meta(ttnn.atanh, args, kwargs)

+        if target == torch.ops.aten.bitwise_not.default:
+            return self.call_function_prop_meta(ttnn.bitwise_not, args, kwargs)
+
         if target == torch.ops.aten.clamp.default:
             # aten.clamp args are positional but ttnn.clip uses kw args
             new_kwargs = map_args_to_kwargs(args, ((1, "min"), (2, "max")), default_none=True)
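
Reviewer note (not part of the patch): a minimal sketch of exercising the new aten.bitwise_not -> ttnn.bitwise_not lowering end to end, outside pytest. The `device` argument is assumed to be a TT-NN device handle (the same object the test suite's `device` fixture provides); every other call mirrors the APIs already used in the test above.

```python
import torch
import torch_ttnn


class Invert(torch.nn.Module):
    def forward(self, x):
        return torch.bitwise_not(x)


def run_bitwise_not(device, shape=(4, 4)):
    m = Invert()
    x = torch.randint(-256, 256, shape, dtype=torch.int32)
    expected = m(x)

    option = torch_ttnn.TorchTtnnOption(device=device)
    compiled = torch.compile(m, backend=torch_ttnn.backend, options=option)
    # torch.compile is lazy: the first call traces the module and runs the
    # to_tt_pass / add_data_move_pass rewrites touched in this patch.
    actual = compiled(x)

    assert torch.equal(expected, actual)
    return actual
```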