
Reshape -> Transpose gives bad PCC #13745

Open · yieldthought opened this issue Oct 11, 2024 · 18 comments

Labels: bug (Something isn't working), llama3, LLM_bug, Op Generalization (Generalization and relaxations of requirements in Ops), P1
yieldthought (Contributor) commented Oct 11, 2024

Describe the bug

    tt_input = tt_input.reshape(1, 2048, 4, 128)
    tt_output = ttnn.transpose(tt_input, 1, 2)

gives 0.0 PCC compared to Torch:

    torch_ref = torch_input.view(1, 2048, 4, 128)
    torch_ref = torch_ref.transpose(1, 2)

To Reproduce

import torch
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc

def test_transpose_with_reshape(device):
    # Create input tensor
    torch_input = torch.rand((1, 1, 2048, 512), dtype=torch.bfloat16)
    
    # TT operations
    tt_input = ttnn.from_torch(
        torch_input,
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    tt_input = tt_input.reshape(1, 2048, 4, 128)
    tt_output = ttnn.transpose(tt_input, 1, 2)
    
    # Convert back to PyTorch for comparison
    tt_result = ttnn.to_torch(tt_output)
    
    # PyTorch reference operations
    torch_ref = torch_input.view(1, 2048, 4, 128)
    torch_ref = torch_ref.transpose(1, 2)
    
    # Compare results
    assert_with_pcc(torch_ref, tt_result, 0.9999)

Expected behavior
Expected the TT output to match the PyTorch reference.

Please complete the following environment information:
Internal IRD-supplied N150

Additional context
Seen while bringing up Llama 3.2.

yieldthought added the bug, LLM_bug, Op Generalization, llama3, and P1 labels on Oct 11, 2024
sjameelTT (Contributor) commented:

Have you tried ttnn.reshape(tt_input, (1, 2048, 4, 128))? The tensor.op APIs are unreliable at the moment since they haven't been updated. We will get to updating those once we're done getting full functionality in place.

sjameelTT (Contributor) commented Oct 11, 2024

I tested on my branch sjameel/transpose_pad and it worked with ttnn.reshape. I will aim to get that into main soon.

sjameelTT (Contributor) commented:

@yieldthought is this still an issue on main?

ntarafdar (Contributor) commented:

@yieldthought can you comment on whether this is fixed? @sjameelTT just pushed a fix to transpose.

yieldthought (Contributor, Author) commented:

The test as provided in this issue still fails on main:

expected_pytorch_result = tensor([[[[0.0508, 0.3359, 0.5977,  ..., 0.5312, 0.7422, 0.4688],
          [0.5156, 0.5273, 0.5117,  ..., 0.5898, 0.7...406, 0.1172, 0.8828],
          [0.4883, 0.9336, 0.3477,  ..., 0.6133, 0.1484, 0.0469]]]],
       dtype=torch.bfloat16)
actual_pytorch_result = TorchTensor([[[[0.0508, 0.3359, 0.5977,  ..., 0.5312, 0.7422, 0.4688],
               [0.6914, 0.9375, 0.6523,  ..., 0...9, 0.6562],
               [0.4883, 0.9336, 0.3477,  ..., 0.6133, 0.1484, 0.0469]]]],
            dtype=torch.bfloat16)
pcc = 0.9999

    def assert_with_pcc(expected_pytorch_result, actual_pytorch_result, pcc=0.9999):
        assert list(expected_pytorch_result.shape) == list(
            actual_pytorch_result.shape
        ), f"list(expected_pytorch_result.shape)={list(expected_pytorch_result.shape)} vs list(actual_pytorch_result.shape)={list(actual_pytorch_result.shape)}"
        pcc_passed, pcc_message = comp_pcc(expected_pytorch_result, actual_pytorch_result, pcc)
>       assert pcc_passed, construct_pcc_assert_message(pcc_message, expected_pytorch_result, actual_pytorch_result)
E       AssertionError: 0.015757835260159045

tests/ttnn/utils_for_testing.py:30: AssertionError

cglagovichTT (Contributor) commented Nov 25, 2024

@ntarafdar can you add this test and arbitrary variants of it to your sweeps? The shapes in this test can be parametrized nicely to give broad coverage of reshape -> transpose.

In addition, please test {single device, multidevice} x {L1 interleaved, DRAM interleaved} x {RM, TILE}. If any of these configurations is not supported yet, the test should assert that it is unsupported.
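
A minimal single-device sketch of such a parametrization (the multidevice axis is left to the sweep infrastructure; the shape pair comes from this issue, and the repo's standard pytest device fixture is assumed):

    import pytest
    import torch
    import ttnn
    from tests.ttnn.utils_for_testing import assert_with_pcc

    @pytest.mark.parametrize("shape, new_shape", [((1, 1, 2048, 512), (1, 2048, 4, 128))])
    @pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
    @pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG])
    def test_reshape_then_transpose(device, shape, new_shape, layout, memory_config):
        torch_input = torch.rand(shape, dtype=torch.bfloat16)
        tt_input = ttnn.from_torch(
            torch_input,
            dtype=ttnn.bfloat16,
            layout=layout,
            device=device,
            memory_config=memory_config,
        )
        # Functional API: validates the reshape and moves data if required
        tt_output = ttnn.transpose(ttnn.reshape(tt_input, new_shape), 1, 2)
        torch_ref = torch_input.view(new_shape).transpose(1, 2)
        assert_with_pcc(torch_ref, ttnn.to_torch(tt_output), 0.9999)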

sjameelTT (Contributor) commented:

This is a reshape issue, not a transpose issue. Replace

    tt_input = tt_input.reshape(1, 2048, 4, 128)

with

    tt_input = ttnn.reshape(tt_input, (1, 2048, 4, 128))

and it will pass. @ntarafdar we should probably make sure that the tensor.reshape pybind behaves the same as ttnn.reshape...

ntarafdar (Contributor) commented:

Ahh, CC: @jvegaTT

tt_input.reshape does a view change and ignores padding...

ntarafdar (Contributor) commented:

I'm closing this issue, since it now works with @sjameelTT's suggestion. @cglagovichTT please file a separate issue for the additional testing.

cglagovichTT (Contributor) commented:

I think the issue should stay open, since tt_input.reshape(...) still has unexpected behavior. ttnn.reshape(...) is a workaround, so this can stay at P1.

ntarafdar (Contributor) commented:

tt_input.reshape is not the API to use; we are using the ttnn API for transpose but a tensor API for reshape. I agree this is a problem, but it is not the problem this issue is about.

cglagovichTT (Contributor) commented:

In that case tt_input.reshape should not be a valid call, so that pybind should be removed as part of resolving this issue. The important thing here is that there is unexpected behavior from the user's perspective which remains unresolved.

ntarafdar (Contributor) commented Nov 25, 2024

This might be a more involved change. A lot of code uses this reshape because it abuses the underlying data format to incorporate padding, and it is used inappropriately by developers everywhere. We can see what needs to be done, but it might be a lot of work for little reward.

ntarafdar reopened this on Nov 25, 2024
cglagovichTT (Contributor) commented:

I'd say that the reward is high. Unexpected API behavior like this costs developers lots of debug time.

uaydonat (Contributor) commented:

Simple bugs like these are very time consuming and annoying for customers. There is no way for a customer or developer to know not to use tt_input.reshape. So either it should be removed completely, or it should only allow the specific use cases (for padding?) that we know work, asserting in all other cases and pointing users to ttnn.reshape.
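
A sketch of the kind of guard this suggests (a hypothetical wrapper, not the actual pybind; the element-count check merely stands in for the real safety conditions, which would also involve padding and layout):

    import math

    def guarded_tensor_reshape(tensor, new_shape):
        # Hypothetical guard: permit the metadata-only path only when it is
        # provably safe; otherwise fail loudly and redirect callers to
        # ttnn.reshape, which handles the required data movement.
        if math.prod(tensor.shape) != math.prod(new_shape):
            raise RuntimeError(
                f"cannot view {tuple(tensor.shape)} as {tuple(new_shape)}; "
                "use ttnn.reshape(tensor, new_shape) instead"
            )
        return tensor.reshape(*new_shape)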

jvegaTT (Contributor) commented Nov 26, 2024

tt_input.reshape just updates the metadata of the tensor without changing the data. This is what we want sometimes, since some reshapes involve no change in the physical locations of the data, or we may want to pull some of the padding from the padded shape into the logical shape. ttnn.reshape quickly checks whether it is valid to call tt_input.reshape and does so if possible; if not, it performs the actual reshape with the required data movement.

There are use cases for both of these functions. Would renaming one of them be the solution? Both are extensively used all over our code base, so it would be a very involved change.
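
This split mirrors PyTorch, where Tensor.view is metadata-only and raises when no valid view exists, while torch.reshape silently falls back to a copy; a quick PyTorch-only illustration of the distinction:

    import torch

    t = torch.rand(1, 2048, 4, 128).transpose(1, 2)  # non-contiguous after transpose
    try:
        t.view(1, 2048, 512)  # metadata-only: raises, no valid stride arrangement exists
    except RuntimeError as err:
        print("view failed:", err)

    u = t.reshape(1, 2048, 512)  # falls back to copying the data and succeeds
    print(u.shape)  # torch.Size([1, 2048, 512])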

cglagovichTT (Contributor) commented:

It sounds like tensor.reshape is doing a view without validating its inputs. I think a solution would be to rename it to something that indicates it is an internal, unsafe view op, or to add input validation and call it view. What do you think?

Also, someone should check this, but in PyTorch don't tensor.op(...) and torch.op(tensor, ...) usually have the same behavior?
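
(For reference, in PyTorch the method and functional spellings are indeed equivalent, e.g.:

    import torch

    t = torch.rand(1, 2048, 4, 128)
    assert torch.equal(t.transpose(1, 2), torch.transpose(t, 1, 2))
    assert torch.equal(t.reshape(1, 2048, 512), torch.reshape(t, (1, 2048, 512)))

so the divergence between tt_input.reshape and ttnn.reshape breaks an expectation PyTorch users bring with them.)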

ntarafdar (Contributor) commented Nov 26, 2024

@nardoTT is on it! She is going to make sure that <tensor>.reshape(new_shape) and ttnn.reshape(<tensor>, new_shape) behave the same way, and we will also introduce an unsafe view for the cases that need it (off the top of my head I know slice needs the unsafe version, but more could show up).

She is bright but new, so I would give her a couple of weeks to ramp up and get this in, as it is part of her ramp-up task.
