From 9b0333517e467c4f46184c04a86761e028448e79 Mon Sep 17 00:00:00 2001 From: Chen-Pang He Date: Sat, 14 Sep 2024 22:14:41 +0000 Subject: [PATCH] Slightly optimize converting `aten.t` - Leave illegal `aten.t` ops as is - Optimize out no-op 0-D and 1-D cases - Simplify the code --- torch_ttnn/passes/lowering/to_tt_pass.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py index 818d229e0..1e3efef9c 100644 --- a/torch_ttnn/passes/lowering/to_tt_pass.py +++ b/torch_ttnn/passes/lowering/to_tt_pass.py @@ -610,13 +610,11 @@ def rewrite_node(node): return new_nodes[-1] if node.target == torch.ops.aten.t.default: - permutation = list() rank = len(node.meta["val"].size()) - assert rank >= 0 and rank <= 2, "Input tensor can only be 0D, 1D or 2D" if rank == 2: - permutation = [1, 0] - return g.call_function(ttnn.permute, args=(args[0], permutation)) - return None + return g.call_function(ttnn.permute, (args[0], (1, 0))) + return args[0] if rank < 2 else None + if node.target == torch.ops.aten.constant_pad_nd.default: input, pad, value = args input_shape = input.meta["val"].size()