
aten.constant_pad_nd.default not support rank > 4 #514

Open

swimdi opened this issue Nov 27, 2024 · 2 comments

swimdi commented Nov 27, 2024

aten.constant_pad_nd.default currently does not support lowering to ttnn.pad if its input rank is > 4.

Otherwise it fails with TT_THROW: Tensor rank is greater than 4.

By the way, there is no rank > 4 case in the models, so I think we can lower its priority.

@ayerofieiev-tt (Member)

@swimdi can you explain what this op does? Do we lower it to some TT-NN operation?


swimdi commented Nov 28, 2024

aten.constant_pad_nd.default pads the input tensor, and it is currently only partially lowered to ttnn.pad:

```python
if node.target == torch.ops.aten.constant_pad_nd.default:
    input, pad = args[0], args[1]
    if len(args) > 2:
        value = args[2]
    else:
        value = 0
    # TODO(#516)
    if any(p < 0 for p in pad):
        return None
    input_shape = input.meta["val"].size()
    output_shape = node.meta["val"].size()
    rank = len(input_shape)
    # Each pair in `pad` covers one dimension; leading dims get no padding
    full_pad = [(0, 0)] * (rank - len(pad) // 2)
    # The order of pad from pytorch is reversed
    full_pad += [(pad[i], pad[i + 1]) for i in range(0, len(pad), 2)][::-1]
    # TODO(#514)
    if rank > 4:
        return None
    # TODO(#192): Front padding isn't well supported so skip for now
    if not all(f == 0 for f, _ in full_pad):
        return None
    # Change layout to row-major for non-tile-size-aligned tensor
    if (
        rank < 2
        or input_shape[-1] % ttnn.TILE_SIZE != 0
        or input_shape[-2] % ttnn.TILE_SIZE != 0
        or full_pad[-1][1] % ttnn.TILE_SIZE != 0
        or full_pad[-2][1] % ttnn.TILE_SIZE != 0
    ):
        input = g.call_function(ttnn.to_layout, args=(input, TtnnRowMajorLayout()))
    # TODO(#515)
    if output_shape[-1] % 2 != 0:
        return None
    return g.call_function(ttnn.pad, args=(input, full_pad, value))
```
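For reference, the pad-order conversion in the snippet can be checked in isolation. A minimal sketch in plain Python, independent of torch/ttnn (`to_full_pad` is a hypothetical helper name, not part of the repo):

```python
def to_full_pad(rank, pad):
    """Convert PyTorch's flat pad list, which gives (before, after) amounts
    starting from the LAST dimension, into per-dimension (before, after)
    pairs ordered from the first dimension to the last."""
    # Dimensions not covered by `pad` get no padding.
    full_pad = [(0, 0)] * (rank - len(pad) // 2)
    # Group the flat list into pairs, then reverse to first-dim-first order.
    full_pad += [(pad[i], pad[i + 1]) for i in range(0, len(pad), 2)][::-1]
    return full_pad

# pad=[1, 2, 3, 4] pads the last dim by (1, 2) and the second-to-last by (3, 4)
print(to_full_pad(4, [1, 2, 3, 4]))  # [(0, 0), (0, 0), (3, 4), (1, 2)]
```

This also shows why the "front padding" guard skips whenever any `before` amount is non-zero: after the reversal, the first element of each pair is the padding added in front of that dimension.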
