Skip to content

Commit

Permalink
Slightly optimize converting aten.t
Browse files Browse the repository at this point in the history
- Leave illegal `aten.t` ops as is
- Optimize out no-op 0-D and 1-D cases
- Simplify the code
  • Loading branch information
jdh8 committed Sep 15, 2024
1 parent ccc66e5 commit 9b03335
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,13 +610,11 @@ def rewrite_node(node):
return new_nodes[-1]

if node.target == torch.ops.aten.t.default:
permutation = list()
rank = len(node.meta["val"].size())
assert rank >= 0 and rank <= 2, "Input tensor can only be 0D, 1D or 2D"
if rank == 2:
permutation = [1, 0]
return g.call_function(ttnn.permute, args=(args[0], permutation))
return None
return g.call_function(ttnn.permute, (args[0], (1, 0)))
return args[0] if rank < 2 else None

if node.target == torch.ops.aten.constant_pad_nd.default:
input, pad, value = args
input_shape = input.meta["val"].size()
Expand Down

0 comments on commit 9b03335

Please sign in to comment.