Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle typecasting aten._to_copy with torch.bool types #351

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
22 changes: 22 additions & 0 deletions tests/lowering/creation/test_masked_fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def forward(self, input, mask, fill_value):
return input.masked_fill(mask, fill_value)


class ToCopyMaskedFill(torch.nn.Module):
    """Module that typecasts the mask to torch.bool via aten._to_copy
    and feeds the result straight into masked_fill.

    Exercises the lowering path where a bool-casting ``_to_copy`` node's
    only consumer is an op with a ttnn lowering.
    """

    def forward(self, input, mask, fill_value):
        bool_mask = torch.ops.aten._to_copy.default(mask, dtype=torch.bool)
        filled = input.masked_fill(bool_mask, fill_value)
        return filled


@pytest.mark.parametrize(
"input_shape, mask_shape, fill_value",
[
Expand Down Expand Up @@ -87,3 +96,16 @@ def forward(self, input, mask, fill_value):
)
def test_masked_fill(device, input_shape, mask_shape, fill_value):
_test_masked_fill_common(device, MaskedFillModule(), input_shape, mask_shape, fill_value)


@pytest.mark.parametrize(
    "input_shape, mask_shape, fill_value",
    [
        ((1, 1, 32, 32), (1, 1, 32, 32), -3.3895313892515355e38),
        ((1, 16, 32, 32), (1, 1, 32, 32), -3.3895313892515355e38),
    ],
)
def test_masked_fill_to_copy(device, input_shape, mask_shape, fill_value):
    """Run masked_fill with a bool-casting _to_copy in front and verify lowering."""
    graph_targets = _test_masked_fill_common(device, ToCopyMaskedFill(), input_shape, mask_shape, fill_value)
    # The bool typecast must have been folded away during rewriting, so no
    # aten._to_copy node may survive in the lowered graph.
    assert torch.ops.aten._to_copy.default not in graph_targets
5 changes: 5 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
],
]

# Hand-curated blocklist for aten.masked_fill.Scalar: argument signatures
# (self/mask shapes plus fill value) that must not be lowered to ttnn.
# NOTE(review): this PR moves the entry here from to_tt_guard_autogen.py so
# it can be maintained manually.
aten_masked_fill_scalar_blocklist = [
    ["Tensor<[2, 1, 7, 7]> self = ?", "Tensor<[2, 1, 7, 7]> mask = ?", "number value = -3.3895313892515355e+38"],
]

# Allow this specific aten.view signature to lower (drop it from the
# blocklist) so that the yolos model can pass.
aten_view_default_blocklist.remove(["Tensor<[1, 192, 32, 42]> self = ?", "List[int] size = [1, 192, 1344]"])

Expand Down Expand Up @@ -187,6 +191,7 @@
# Register manual guards: each rejects lowering for any node whose argument
# signature appears in the corresponding blocklist.
GUARD[torch.ops.aten._to_copy.default] = partial(guard_aten, aten__to_copy_default_blocklist)
GUARD[torch.ops.aten.unsqueeze.default] = partial(guard_aten, aten_unsqueeze_default_blocklist)
GUARD[torch.ops.aten.squeeze.dim] = partial(guard_aten, aten_squeeze_dim_blocklist)
GUARD[torch.ops.aten.masked_fill.Scalar] = partial(guard_aten, aten_masked_fill_scalar_blocklist)


def can_lowering_to_ttnn(node):
Expand Down
5 changes: 0 additions & 5 deletions torch_ttnn/passes/lowering/to_tt_guard_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,6 @@
["Tensor<[1, 40, 28, 28]> self = ?", "List[int] size = [1, 40, 28, 28]", "List[int] stride = [31360, 784, 28, 1]"],
]
# Autogenerated blocklist: aten.mm signatures that must not be lowered.
aten_mm_default_blocklist = [["Tensor<[1, 21843]> self = ?", "Tensor<[21843, 768]> mat2 = ?"]]
# NOTE(review): the masked_fill blocklist below is deleted by this change;
# its curated replacement lives in to_tt_guard.py.
aten_masked_fill_scalar_blocklist = [
    ["Tensor<[2, 1, 7, 7]> self = ?", "Tensor<[2, 1, 7, 7]> mask = ?", "number value = -3.3895313892515355e+38"],
]


def get_inputs(node):
Expand Down Expand Up @@ -677,7 +674,6 @@ def guard_aten(blocklist, node):
torch.ops.aten.native_dropout.default: partial(guard_aten, aten_native_dropout_default_blocklist),
torch.ops.aten.new_empty_strided.default: partial(guard_aten, aten_new_empty_strided_default_blocklist),
torch.ops.aten.mm.default: partial(guard_aten, aten_mm_default_blocklist),
torch.ops.aten.masked_fill.Scalar: partial(guard_aten, aten_masked_fill_scalar_blocklist),
}

guard_ops = [
Expand Down Expand Up @@ -715,5 +711,4 @@ def guard_aten(blocklist, node):
"torch.ops.aten.native_dropout.default",
"torch.ops.aten.new_empty_strided.default",
"torch.ops.aten.mm.default",
"torch.ops.aten.masked_fill.Scalar",
]
7 changes: 6 additions & 1 deletion torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,8 +720,13 @@ def rewrite_node(node):
if node.target == torch.ops.aten._to_copy.default:
target_users_ops = [user.target for user in node.users.keys()]
# Float and int types can be converted to ttnn.bfloat16, but bool may be problematic
# Can be removed if bool is only used for the following ops that have lowering:
ops_safe_to_remove = set([torch.ops.aten.masked_fill.Scalar])
can_remove_bool = kwargs["dtype"] != bool or (
kwargs["dtype"] == torch.bool and ops_safe_to_remove.intersection(target_users_ops)
)
# Skip if type casting from bool and if the graph output uses this op
if kwargs["dtype"] not in [torch.bool] and "output" not in target_users_ops:
if can_remove_bool and "output" not in target_users_ops:
# Essentially remove this op because it's used as a typecast
return node.args[0]
else:
Expand Down
Loading