From 449f17038868328a3f1c1677e519116b13312c1c Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Mon, 11 Nov 2024 19:44:54 +0000 Subject: [PATCH] Fix unsqueeze for 4D inputs --- .../lowering/tensor_manipulation/test_unsqueeze.py | 9 ++++++++- torch_ttnn/passes/lowering/to_tt_pass.py | 13 ++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/lowering/tensor_manipulation/test_unsqueeze.py b/tests/lowering/tensor_manipulation/test_unsqueeze.py index cf825c4dd..06387d4e9 100644 --- a/tests/lowering/tensor_manipulation/test_unsqueeze.py +++ b/tests/lowering/tensor_manipulation/test_unsqueeze.py @@ -17,7 +17,14 @@ def forward(self, x, y): @pytest.mark.parametrize( "input_shape, dim", - [((5, 2, 4, 3), 1)], + [ + ((5, 2, 4, 3), 1), + pytest.param( + (50, 1, 3, 1024), + 0, + marks=pytest.mark.xfail(reason="Fails if ouput is > 4D, using TILE_LAYOUT, and W dim is >= 32."), + ), + ], ) def test_unsqueeze1(device, input_shape, dim): mod = UnsqueezeModule() diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py index 3e3472828..1700ae385 100644 --- a/torch_ttnn/passes/lowering/to_tt_pass.py +++ b/torch_ttnn/passes/lowering/to_tt_pass.py @@ -456,9 +456,9 @@ def rewrite_node(node): new_kwargs = { "fill_value": args[1], "device": TtnnDevice(), - "layout": TtnnTileLayout(), } full_node = g.call_function(ttnn.full, args=(arg_metadata.size(),), kwargs=new_kwargs) + full_node = g.call_function(ttnn.to_layout, (full_node, TtnnTileLayout()), {}) return g.call_function( relational_scalar_ops[node.target], args=(args[0], full_node), @@ -482,9 +482,9 @@ def rewrite_node(node): new_kwargs = { "fill_value": args[1], "device": TtnnDevice(), - "layout": TtnnTileLayout(), } - return g.call_function(ttnn.full, args=(tuple(args[0]),), kwargs=new_kwargs) + full = g.call_function(ttnn.full, args=(tuple(args[0]),), kwargs=new_kwargs) + return g.call_function(ttnn.to_layout, (full, TtnnTileLayout()), {}) # Replace op with scalar for eltwise ops # TODO: Generalize this to support all eltwise ops node_users = list(node.users.keys()) @@ -638,7 +638,9 @@ def rewrite_node(node): output_size = list(node.meta["val"].size()) - if output_size[-1] == input_size[-1] and len(output_size) <= 4: + # FIXME: Cannot reshape a 4D tensor if size[-1] >= 32. + can_unsqueeze_4d = (len(input_size) >= 4 and input_size[-1] < 32) or len(input_size) < 4 + if output_size[-1] == input_size[-1] and can_unsqueeze_4d: return g.call_function(ttnn.reshape, args=(args[0], output_size)) return None @@ -768,12 +770,13 @@ def rewrite_node(node): multiplier = np_tensor_shp // np_mask_shp mask_bcst = g.call_function(target_wrappers.repeat, args=(mask, multiplier.tolist())) - kwargs = {"dtype": TtnnBfloat16(), "layout": TtnnTileLayout(), "device": TtnnDevice()} + kwargs = {"dtype": TtnnBfloat16(), "device": TtnnDevice()} ones = g.call_function(ttnn.ones, (tensor_shape,), kwargs) mask_flip = g.call_function(ttnn.subtract, (ones, mask_bcst)) tensor_masked = g.call_function(ttnn.multiply, (tensor, mask_flip)) full = g.call_function(ttnn.full, (tensor_shape, fill_value), kwargs) + full = g.call_function(ttnn.to_layout, (full, TtnnTileLayout()), {}) full_masked = g.call_function(ttnn.multiply, (mask_bcst, full)) masked_fill = g.call_function(ttnn.add, (tensor_masked, full_masked))