-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Convert
aten.bitwise_not
to ttnn.bitwise_not
`ttnn.bitwise_not` takes an extraneous `value: int`. I'm fixing this in the kernel.
- Loading branch information
Showing
3 changed files
with
38 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
import torch_ttnn | ||
import pytest | ||
import ttnn | ||
|
||
|
||
class BitwiseNotModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
return torch.bitwise_not(input) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shape", | ||
[(4, 4)], | ||
) | ||
def test_bitwise_not(device, input_shape): | ||
m = BitwiseNotModule() | ||
input = torch.randint(-256, 256, input_shape, dtype=torch.int32) | ||
result_before = m.forward(input) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
option.gen_graphviz = True | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(input) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has been rewritten and contains ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
assert [node.target for node in nodes].count(ttnn.bitwise_not) == 1 | ||
# Check inference result | ||
assert torch.equal(result_before, result_after) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters