Inefficient loss calculation in cached losses #3107

Open
Marcel256 opened this issue Dec 2, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@Marcel256
Contributor

Hello, during my tests I discovered a problem with my PR enabling the combination of the Matryoshka loss and the cached losses (#3068). I moved the loss.backward() call out of the mini-batch loop, so the function now builds one big computation graph containing the loss calculation of all mini-batches, which defeats the purpose of computing the loss in mini-batches. This drastically increases memory consumption and can easily lead to out-of-memory errors. I am already working on a fix for this issue.

I am sorry for any inconvenience this will cause.
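For illustration, here is a minimal sketch of the two patterns with a placeholder model and loss (hypothetical names; the actual cached-loss code in sentence-transformers differs):

```python
import torch
from torch import nn

def backward_per_minibatch(model: nn.Module, minibatches: list[torch.Tensor]) -> float:
    # What the cached losses are meant to do: call backward inside the
    # loop so each mini-batch's computation graph is freed immediately,
    # keeping peak memory at roughly one mini-batch.
    total = 0.0
    for mb in minibatches:
        loss = model(mb).sum() / len(minibatches)  # placeholder loss term
        loss.backward()  # this mini-batch's graph is released here
        total += loss.item()
    return total

def backward_once(model: nn.Module, minibatches: list[torch.Tensor]) -> float:
    # The pattern introduced by the PR: collecting all mini-batch losses
    # and calling backward once keeps every mini-batch's graph alive
    # until the final call, which defeats the mini-batching and can OOM.
    loss = torch.stack([model(mb).sum() for mb in minibatches]).mean()
    loss.backward()
    return loss.item()
```

The accumulated gradients are identical in both cases; the difference is that the first variant releases each mini-batch's graph before building the next one, so peak memory stays bounded by a single mini-batch.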

@tomaarsen added the bug label Dec 2, 2024
@tomaarsen
Collaborator

Well spotted, thank you @Marcel256. I didn't see this when I was reviewing; I thought the removed section in calculate_loss_and_cache_gradients was identical to the section in calculate_loss. Perhaps we can add a parameter to calculate_loss that controls whether backward is called per mini-batch, although I'll gladly await your fix.

Also, don't stress it - the faulty commit from #3068 hasn't been included in any release, so nothing bad has happened yet. In my time with this project, I've probably introduced 20 bugs like this 😋

  • Tom Aarsen
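
A hedged sketch of that suggestion (hypothetical signature and placeholder loss term; the real calculate_loss in the cached losses takes different arguments):

```python
import torch
from torch import Tensor

def calculate_loss(reps: list[Tensor], backward: bool = False) -> Tensor:
    # Hypothetical parameter: if backward=True, run backward per mini-batch
    # and detach, so each graph is freed as soon as its gradients are cached.
    # Assumes each rep in `reps` requires grad.
    total_loss = torch.tensor(0.0)
    for rep in reps:
        loss = rep.sum() / len(reps)  # placeholder for the real loss term
        if backward:
            loss.backward()
            loss = loss.detach()  # drop this mini-batch's graph
        total_loss = total_loss + loss
    return total_loss
```

Called with backward=True this would reproduce the memory-friendly per-mini-batch behaviour; with backward=False the caller keeps the full graph and owns the single backward call.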
