[Bug Report] Latest update to moreh_adamw has issue #14186

Closed
rfurko-tt opened this issue Oct 23, 2024 · 16 comments
@rfurko-tt (Contributor) commented Oct 23, 2024

Describe the bug
After updating to the latest tt-metal main branch, we see that training loss goes down much more slowly. After turning off the program cache, training loss behaves as expected. We suspect the bug is in the latest update to the moreh_adamw cache functions. One hypothesis is that we are no longer passing step correctly.

Expected behavior
Speed of convergence should not be affected by program caching.

Please complete the following environment information:

  • OS: Ubuntu 20.04
  • Version of software: 05ae26f
rfurko-tt added the bug (Something isn't working) label on Oct 23, 2024
@o2buzzle (Contributor)

That is odd, to say the least.

I would not expect that either, since the program cache should make things faster. It being slower would suggest that runtime argument replacement is taking longer than recompiling from scratch, which makes little sense to me.

@mrshaw01 (Contributor)

Hi Roman,
Does 'goes way slower' here mean it's slower (in wall-clock time) for the same number of steps? Or that with n steps it takes the same amount of time to run, but the loss is higher when the program cache is turned on than when it's turned off?

@dmakoviichuk-tt (Contributor)

@mrshaw01 we mean training loss.
It feels like step is not being updated.

@rfurko-tt (Contributor, Author)

Hi everyone,
Sorry for the confusion. The loss goes down much more slowly.

@dmakoviichuk-tt (Contributor)

But speaking of the cache: somehow I don't see a significant perf difference between cached and non-cached :)

@rfurko-tt (Contributor, Author)

Before the cache fix, we reach a loss of 3.0 at iteration 8.
After the cache fix, we only reach that value after 90 iterations.

@mrshaw01 (Contributor)

We’ve got it. o2buzzle will debug this issue.
Thank you for the information!

@o2buzzle (Contributor)

@rfurko-tt That would suggest an issue with override_runtime_arguments not overriding kernel runtime values, but so far I haven't found anything obviously wrong there, so it's likely something else deeper in the framework.

DPRINT inside the kernel shows that both values are updating as expected.
(screenshot of DPRINT output)
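For what it's worth, the failure mode being described can be modeled with a tiny standalone sketch (the names below are placeholders, not the tt-metal API): on a cache hit only the override path runs, so any per-invocation value it does not refresh stays frozen at whatever the first launch wrote.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <unordered_map>

struct RuntimeArgs { uint32_t step; float lr; };             // stand-in for kernel runtime args

std::unordered_map<std::size_t, RuntimeArgs> program_cache;  // keyed by the program hash

RuntimeArgs& launch(std::size_t hash, uint32_t step, float lr) {
    auto [it, cache_miss] = program_cache.try_emplace(hash, RuntimeArgs{step, lr});
    if (!cache_miss) {
        // Override path on a cache hit: dropping either line reproduces the symptom
        // (the kernel keeps seeing the values from the very first launch).
        it->second.step = step;
        it->second.lr = lr;
    }
    return it->second;
}

int main() {
    for (uint32_t step = 1; step <= 3; ++step) {
        const RuntimeArgs& args = launch(/*hash=*/0x14186, step, 1e-3f);
        std::cout << "step seen by kernel: " << args.step << "\n";  // 1, 2, 3 only if refreshed
    }
}
```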

@dmakoviichuk-tt (Contributor)

@o2buzzle that was our thinking, because it did sort of train, but as if with a 100x smaller lr.

@o2buzzle (Contributor)

@dmakoviichuk-tt Would you mind providing some sample code that reproduces the issue, for reference? As it is, I am not able to find anything that stands out as buggy, and the unit tests seem to agree with that.

@dmakoviichuk-tt (Contributor)

Hi @o2buzzle, unfortunately it is very hard to share this code right now.
What we found:
Commit 15338d4 works for us.
But when you introduced caching for the learning rate, it stopped working.

@dmakoviichuk-tt (Contributor) commented Oct 24, 2024

#14243 is the fix.
There is a chance that you forgot to add some tensor args to the hash.
But I like the current code a little more than manually adding everything.
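A rough sketch of the trade-off between the two approaches (Attrs and the hash helpers below are placeholders, not tt-metal code):

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>

struct Attrs { int shape; float lr; uint32_t step; };  // hypothetical op attributes

// Strategy 1: fold the varying scalars into the program hash. Changing lr or step
// changes the hash, so the cache misses and the program is rebuilt with the new
// values -- always correct, but it recompiles on every step.
std::size_t hash_including_scalars(const Attrs& a) {
    return std::hash<int>{}(a.shape)
         ^ (std::hash<float>{}(a.lr) << 1)
         ^ (std::hash<uint32_t>{}(a.step) << 2);
}

// Strategy 2: keep the hash stable across steps and refresh lr/step through the
// override callback on every cache hit -- fast, but only correct if every such
// value really is refreshed. The regression lives in the gap between the two:
// a value that is neither hashed nor refreshed goes stale.
std::size_t hash_excluding_scalars(const Attrs& a) {
    return std::hash<int>{}(a.shape);
}

int main() {
    Attrs step1{32, 1e-3f, 1}, step2{32, 1e-3f, 2};
    // 1 on typical implementations: different hash, so the program is rebuilt.
    std::cout << (hash_including_scalars(step1) != hash_including_scalars(step2)) << "\n";
    // 1: same hash, program reused, so the override callback must update step.
    std::cout << (hash_excluding_scalars(step1) == hash_excluding_scalars(step2)) << "\n";
}
```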

@o2buzzle (Contributor) commented Oct 25, 2024

fyi this dude (@mrshaw01) told me to add everything instead of just setting it to zero (something something code convention)
smh, fml

@mrshaw01 (Contributor)

Both approaches should work well.
The 'manual addition' approach has already been used in some operations, so I suggested it for consistency. However, it turned out to be buggy.
But never mind, the current approach is OK.

@dmakoviichuk-tt (Contributor)

@mrshaw01 it would be nice to find out why it is buggy, in the future. Is it the hash or missing parameters?

@o2buzzle (Contributor) commented Oct 25, 2024

The odd thing here is that the kernel itself reports (through DPRINT) that the value changes correctly when I write a unit test to modify it. So whatever has been happening, it must have somehow hit the cache while bypassing runtime argument replacement. Maybe some quirk in the caching mechanism? Or the unit test didn't have enough coverage.

I originally studied and re-implemented the hashing code from the example in group_attn_matmul_device_operation.cpp, and that one seems to work fine.

Also, Shaw didn't really review how the hash is handled per se; he just checked that it worked correctly with the provided unit tests (which it did pass, oddly) and gave input on how I implemented it, but otherwise didn't deal with that bit of the change directly. Any errors in the hashing details are mine and mine only.
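On the coverage point, the missing test shape is roughly: run the op twice so the second call hits the program cache, change only the scalar, and check that the value the kernel actually consumes follows it. A minimal standalone sketch of that shape, with CachedAdamW as a hypothetical stand-in rather than the real operation:

```cpp
#include <cassert>
#include <optional>

// Stand-in for the cached operation; not the real moreh_adamw interface.
struct CachedAdamW {
    std::optional<float> cached_lr;      // models the program created on the first call
    float run(float lr) {
        if (!cached_lr) cached_lr = lr;  // first call: cache miss, program created
        else cached_lr = lr;             // cache hit: the override path must refresh lr
        return *cached_lr;               // the value the kernel would actually see
    }
};

int main() {
    CachedAdamW op;
    assert(op.run(1e-3f) == 1e-3f);      // cache miss
    assert(op.run(1e-4f) == 1e-4f);      // cache hit: fails if the refresh is skipped
    return 0;
}
```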

dmakoviichuk-tt added a commit that referenced this issue Oct 25, 2024
* #14186: Fixed moreh_adam

* #0: fixed adam too
ct-clmsn pushed a commit to ct-clmsn/tt-metal that referenced this issue Nov 12, 2024