
Convert aten.bitwise_not to ttnn.bitwise_not
`ttnn.bitwise_not` takes an extraneous `value: int` argument. I'm fixing this in the kernel. (A short sketch of the mismatch follows the changed-file summary below.)
jdh8 committed Oct 20, 2024
1 parent 9e4b61a commit b2f9acb
Showing 3 changed files with 38 additions and 1 deletion.
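
Before the diffs, a quick sketch of the mismatch described in the commit message. Only the torch side is executable; the ttnn signature is quoted from the author's note rather than verified here.

# Sketch of the signature mismatch; the ttnn side is shown only as a comment.
import torch

x = torch.randint(-256, 256, (4, 4), dtype=torch.int32)

# aten.bitwise_not is unary: bitwise_not(Tensor self) -> Tensor
y = torch.ops.aten.bitwise_not.default(x)
assert torch.equal(y, ~x)  # on integer tensors, ~ is bitwise NOT

# Per the commit message, the ttnn kernel additionally demands an int:
#   ttnn.bitwise_not(input_tensor, value)   # `value` is extraneous
# Fixing the kernel lets the lowering below forward args/kwargs unchanged.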
34 changes: 34 additions & 0 deletions tests/lowering/eltwise/unary/test_bitwise_not.py
@@ -0,0 +1,34 @@
+import torch
+import torch_ttnn
+import pytest
+import ttnn
+
+
+class BitwiseNotModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input):
+        return torch.bitwise_not(input)
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [(4, 4)],
+)
+def test_bitwise_not(device, input_shape):
+    m = BitwiseNotModule()
+    input = torch.randint(-256, 256, input_shape, dtype=torch.int32)
+    result_before = m.forward(input)
+    option = torch_ttnn.TorchTtnnOption(device=device)
+    option.gen_graphviz = True
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+    result_after = m.forward(input)
+    option._out_fx_graphs[0].print_tabular()
+
+    # Check the graph has been rewritten and contains ttnn ops
+    nodes = list(option._out_fx_graphs[0].nodes)
+    assert [node.target for node in nodes].count(ttnn.bitwise_not) == 1
+    # Check inference result
+    assert torch.equal(result_before, result_after)
2 changes: 1 addition & 1 deletion torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -35,7 +35,7 @@ def is_function_call(node) -> bool:
     ttnn.atan,
     ttnn.atan2, # binary
     ttnn.atanh,
-    # ttnn.clone, in target_wrappers
+    ttnn.bitwise_not,
     ttnn.cos,
     ttnn.cosh,
     ttnn.erf,
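
For readers unfamiliar with the pass: judging by its name, add_data_move_pass presumably brackets each op registered in this list with host-to-device and device-to-host transfers. Below is a generic, self-contained FX sketch of that wrapping pattern, not the project's actual pass; to_device_stub and from_device_stub are hypothetical stand-ins, not torch_ttnn APIs.

# Generic FX sketch of a data-move pass; the stubs are hypothetical.
import torch
from torch.fx import symbolic_trace

def to_device_stub(x):    # hypothetical host -> device transfer
    return x

def from_device_stub(x):  # hypothetical device -> host transfer
    return x

DEVICE_OPS = {torch.bitwise_not}  # analogous to the list extended above

class M(torch.nn.Module):
    def forward(self, x):
        return torch.bitwise_not(x)

gm = symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target in DEVICE_OPS:
        # Feed the op a device copy of its input...
        with gm.graph.inserting_before(node):
            moved = gm.graph.call_function(to_device_stub, node.args)
        node.args = (moved,)
        # ...and hand a host copy of its output to downstream users.
        with gm.graph.inserting_after(node):
            back = gm.graph.call_function(from_device_stub, (node,))
        node.replace_all_uses_with(back)
        back.args = (node,)  # undo the rewrite of `back`'s own argument
gm.recompile()

x = torch.randint(-256, 256, (4, 4), dtype=torch.int32)
assert torch.equal(gm(x), torch.bitwise_not(x))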
3 changes: 3 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
@@ -155,6 +155,9 @@ def __init__(self, target, args, kwargs):
         if target == torch.ops.aten.atanh.default:
             return self.call_function_prop_meta(ttnn.atanh, args, kwargs)
 
+        if target == torch.ops.aten.bitwise_not.default:
+            return self.call_function_prop_meta(ttnn.bitwise_not, args, kwargs)
+
         if target == torch.ops.aten.clamp.default:
             return self.call_function_prop_meta(ttnn.clip, args, kwargs)
 
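
A minimal, self-contained sketch of the dispatch pattern above: an FX pass that swaps a matched target for a replacement while keeping args and kwargs, mirroring the role of call_function_prop_meta. This is not the project's code; the real pass runs on an aten-level graph and matches torch.ops.aten.bitwise_not.default, whereas symbolic_trace keeps the Python-level target, so the sketch matches that instead. ttnn_bitwise_not_stub is a hypothetical stand-in for ttnn.bitwise_not.

# Sketch of target substitution in an FX graph; the stub is hypothetical.
import torch
from torch.fx import symbolic_trace

def ttnn_bitwise_not_stub(x):  # hypothetical stand-in for ttnn.bitwise_not
    return torch.bitwise_not(x)

class M(torch.nn.Module):
    def forward(self, x):
        return torch.bitwise_not(x)

gm = symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.bitwise_not:
        node.target = ttnn_bitwise_not_stub  # same args/kwargs, new target
gm.recompile()

x = torch.randint(-256, 256, (4, 4), dtype=torch.int32)
assert torch.equal(gm(x), torch.bitwise_not(x))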
