Skip to content

Commit

Permalink
Fix unsqueeze for 4D inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinwuTT committed Nov 11, 2024
1 parent 6979151 commit 449f170
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
9 changes: 8 additions & 1 deletion tests/lowering/tensor_manipulation/test_unsqueeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@ def forward(self, x, y):

@pytest.mark.parametrize(
"input_shape, dim",
[((5, 2, 4, 3), 1)],
[
((5, 2, 4, 3), 1),
pytest.param(
(50, 1, 3, 1024),
0,
marks=pytest.mark.xfail(reason="Fails if ouput is > 4D, using TILE_LAYOUT, and W dim is >= 32."),
),
],
)
def test_unsqueeze1(device, input_shape, dim):
mod = UnsqueezeModule()
Expand Down
13 changes: 8 additions & 5 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,9 @@ def rewrite_node(node):
new_kwargs = {
"fill_value": args[1],
"device": TtnnDevice(),
"layout": TtnnTileLayout(),
}
full_node = g.call_function(ttnn.full, args=(arg_metadata.size(),), kwargs=new_kwargs)
full_node = g.call_function(ttnn.to_layout, (full_node, TtnnTileLayout()), {})
return g.call_function(
relational_scalar_ops[node.target],
args=(args[0], full_node),
Expand All @@ -482,9 +482,9 @@ def rewrite_node(node):
new_kwargs = {
"fill_value": args[1],
"device": TtnnDevice(),
"layout": TtnnTileLayout(),
}
return g.call_function(ttnn.full, args=(tuple(args[0]),), kwargs=new_kwargs)
full = g.call_function(ttnn.full, args=(tuple(args[0]),), kwargs=new_kwargs)
return g.call_function(ttnn.to_layout, (full, TtnnTileLayout()), {})
# Replace op with scalar for eltwise ops
# TODO: Generalize this to support all eltwise ops
node_users = list(node.users.keys())
Expand Down Expand Up @@ -638,7 +638,9 @@ def rewrite_node(node):

output_size = list(node.meta["val"].size())

if output_size[-1] == input_size[-1] and len(output_size) <= 4:
# FIXME: Cannot reshape a 4D tensor if size[-1] >= 32.
can_unsqueeze_4d = (len(input_size) >= 4 and input_size[-1] < 32) or len(input_size) < 4
if output_size[-1] == input_size[-1] and can_unsqueeze_4d:
return g.call_function(ttnn.reshape, args=(args[0], output_size))
return None

Expand Down Expand Up @@ -768,12 +770,13 @@ def rewrite_node(node):
multiplier = np_tensor_shp // np_mask_shp
mask_bcst = g.call_function(target_wrappers.repeat, args=(mask, multiplier.tolist()))

kwargs = {"dtype": TtnnBfloat16(), "layout": TtnnTileLayout(), "device": TtnnDevice()}
kwargs = {"dtype": TtnnBfloat16(), "device": TtnnDevice()}
ones = g.call_function(ttnn.ones, (tensor_shape,), kwargs)
mask_flip = g.call_function(ttnn.subtract, (ones, mask_bcst))
tensor_masked = g.call_function(ttnn.multiply, (tensor, mask_flip))

full = g.call_function(ttnn.full, (tensor_shape, fill_value), kwargs)
full = g.call_function(ttnn.to_layout, (full, TtnnTileLayout()), {})
full_masked = g.call_function(ttnn.multiply, (mask_bcst, full))

masked_fill = g.call_function(ttnn.add, (tensor_masked, full_masked))
Expand Down

0 comments on commit 449f170

Please sign in to comment.