Hello, during my tests I discovered a problem with my PR that enables combining the Matryoshka loss with the cached losses (#3068). I moved the loss.backward() call out of the minibatch loop. As a result, the function builds one large computation graph containing the loss calculation of all minibatches, which defeats the purpose of doing this computation in minibatches. This drastically increases memory consumption and can easily lead to out-of-memory errors. I am already working on a fix for this issue.
I am sorry for any inconvenience this will cause.
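To illustrate the difference, here is a minimal PyTorch sketch. It is not the code from #3068: the linear model, the `toy_loss` function, and the minibatch size are all made up for the example. It only shows that calling backward inside the loop accumulates the same gradients as a single backward on the summed loss, while letting each minibatch graph be freed before the next one is built.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the real model and data.
model = torch.nn.Linear(8, 4)
data = torch.randn(32, 8)
minibatches = data.split(8)  # four minibatches of 8 rows each


def toy_loss(output: torch.Tensor) -> torch.Tensor:
    # Placeholder loss, assumed only for this illustration.
    return output.pow(2).mean()


# Variant A: one backward after the loop.
# The graphs of all minibatches stay in memory until backward runs.
model.zero_grad()
total = sum(toy_loss(model(mb)) for mb in minibatches)
total.backward()
grad_single = model.weight.grad.clone()

# Variant B: backward inside the loop.
# Each graph is released right after its backward call,
# and the gradients accumulate in .grad across iterations.
model.zero_grad()
for mb in minibatches:
    loss = toy_loss(model(mb))
    loss.backward()
grad_per_minibatch = model.weight.grad.clone()

same = torch.allclose(grad_single, grad_per_minibatch, atol=1e-6)
print(same)
```

Both variants should yield matching gradients up to floating-point tolerance; the difference is only in how many graphs are alive at the same time.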
Well spotted, thank you @Marcel256. I didn't catch this when reviewing: I thought the section removed from calculate_loss_and_cache_gradients was identical to the section in calculate_loss. Perhaps we can add a parameter to calculate_loss that controls whether backward is called inside the minibatch loop, although I'll gladly await your fix.
Also, don't stress it - the faulty commit from #3068 hasn't been included in any release yet, so nothing bad has happened yet. In my time with this project, I've probably introduced 20 bugs like this 😋
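A rough sketch of what such a parameter could look like is below. The signature, the `with_backward` name, and the `loss_fn` argument are assumptions made for illustration, not the library's actual calculate_loss; the eventual fix may look different.

```python
import torch


def calculate_loss(minibatches, loss_fn, with_backward=False):
    """Hypothetical helper: sum the loss over minibatches.

    If with_backward is True, backward is called per minibatch and the
    loss is detached, so no graph outlives its loop iteration.
    """
    losses = []
    for minibatch in minibatches:
        loss = loss_fn(minibatch)
        if with_backward:
            loss.backward()
            loss = loss.detach()
        losses.append(loss)
    return torch.stack(losses).sum()


# Toy setup, assumed only for this example.
weights = torch.ones(3, requires_grad=True)
minibatches = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])]


def loss_fn(minibatch):
    return (weights * minibatch).sum()


# Backward per minibatch: gradients are accumulated, result is detached.
eager = calculate_loss(minibatches, loss_fn, with_backward=True)

# No backward: the returned loss still carries the full graph.
deferred = calculate_loss(minibatches, loss_fn)
```

With the flag set, the caller gets a detached loss value for logging and the gradients are already in place; without it, the caller keeps control over when backward runs.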