batch_size in PointCloud is ineffective when reverse-mode differentiation is used #417
I now realized that adding @jax.remat at ott/src/ott/geometry/pointcloud.py line 203 and line 221 (both at commit 137fd3a) …
Would it make sense to add this permanently to the code? This shouldn't affect the forward pass and would only affect the backward pass when online mode is used (i.e., batch_size != None). What do you think? Are there any other places where @jax.remat should be added that I missed?
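A minimal sketch of the proposal, assuming a hypothetical squared-Euclidean cost application (apply_sq_euclidean_cost is an illustrative name, not OTT's code): the per-batch function is wrapped in jax.checkpoint, which jax.remat aliases, only when batch_size is set. The dense path is untouched and the forward values are the same either way; remat only changes what reverse-mode differentiation saves for the backward pass.

```python
import jax
import jax.numpy as jnp


def apply_sq_euclidean_cost(x, y, vec, batch_size=None):
    """Hypothetical helper: returns C @ vec with C[i, j] = ||x_i - y_j||^2."""

    def rows(x_batch):
        # Builds only a (len(x_batch), m) block of the cost matrix.
        cost = jnp.sum((x_batch[:, None, :] - y[None, :, :]) ** 2, axis=-1)
        return cost @ vec

    if batch_size is None:
        # Dense mode: one (n, m) cost matrix, no remat.
        return rows(x)

    # Online mode: checkpoint the per-batch function so its cost block is
    # recomputed in the backward pass rather than saved for every batch.
    checkpointed_rows = jax.checkpoint(rows)
    n, d = x.shape
    assert n % batch_size == 0, "this sketch assumes batch_size divides n"
    x_batches = x.reshape(n // batch_size, batch_size, d)
    return jax.lax.map(checkpointed_rows, x_batches).reshape(n)


# Usage: dense and online modes should agree in value and gradient.
kx, ky, kv = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (64, 3))
y = jax.random.normal(ky, (32, 3))
vec = jax.random.normal(kv, (32,))


def loss(x, batch_size):
    return jnp.sum(apply_sq_euclidean_cost(x, y, vec, batch_size) ** 2)


g_dense = jax.grad(loss)(x, None)
g_online = jax.grad(loss)(x, 16)
```

Comparing dense and online gradients like this is one way to check that adding remat leaves the results unchanged.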
Hi @jaschau, thanks a lot! I will also take a closer look at this. In the meantime, could you please try running
df = jax.jit(jax.value_and_grad(f, has_aux=True), static_argnames=["batch_size"])
I suspect there might also be some issues with the grad function not being jitted.
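The suggested call pattern can be tried on a toy function. This is a minimal sketch: f here is a hypothetical stand-in for the reporter's loss (which is not reproduced in the thread). It returns (value, aux) because of has_aux=True, and batch_size is passed by keyword and marked static because it determines array shapes.

```python
import jax
import jax.numpy as jnp


def f(x, batch_size=None):
    """Hypothetical loss with an auxiliary output; batch_size sets a shape."""
    if batch_size is None:
        per_batch = jnp.sum(x ** 2)[None]
    else:
        # The reshape needs a concrete (static) batch_size.
        per_batch = jnp.sum(x.reshape(-1, batch_size) ** 2, axis=1)
    value = jnp.sum(per_batch)
    return value, per_batch  # (value, aux)


# Suggested pattern: jit the value-and-gradient function, batch_size static.
df = jax.jit(jax.value_and_grad(f, has_aux=True), static_argnames=["batch_size"])

x = jnp.arange(8.0)
(value, aux), grad = df(x, batch_size=4)
```

With has_aux=True, the result is nested as ((value, aux), grad), and batch_size is not differentiated.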
Hi @michalk8, thanks a lot for looking into this. I just tested it and found that jitting doesn't seem to help here.
Hi @jaschau, thanks! Just to clarify:
The problem does not come from storing many iterations, because you are using Danskin (gradients are not propagated through the Sinkhorn iterations). Here, when differentiating the divergence, you are differentiating 3 terms that are all … These are all computed using compute_kl_reg_cost, where the first two … I suspect (I am sure Michal will think about this) that the issue might be in differentiating the geometry application (… This also seems to be related to this …
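The reasoning above can be written out in standard entropic-OT notation (assumed here, not taken from OTT's code; OTT's exact normalization of the regularized cost may differ). The divergence combines three regularized OT costs, and by Danskin's theorem each term's gradient is evaluated at the optimal dual potentials, so nothing from the Sinkhorn loop has to be kept:

```latex
\begin{aligned}
S_\varepsilon(\alpha, \beta) &= \mathrm{OT}_\varepsilon(\alpha, \beta)
  - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(\alpha, \alpha)
  - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(\beta, \beta), \\
\mathrm{OT}_\varepsilon(x) &= \max_{f, g}\; \mathcal{L}(x; f, g), \qquad
\mathcal{L}(x; f, g) = \langle f, a \rangle + \langle g, b \rangle
  - \varepsilon \sum_{i,j} a_i b_j \left( e^{(f_i + g_j - C_{ij}(x))/\varepsilon} - 1 \right), \\
\nabla_x \mathrm{OT}_\varepsilon(x) &= \partial_x \mathcal{L}(x; f^\star, g^\star)
  = \sum_{i,j} P^\star_{ij}\, \nabla_x C_{ij}(x), \qquad
P^\star_{ij} = a_i b_j\, e^{(f^\star_i + g^\star_j - C_{ij}(x))/\varepsilon}.
\end{aligned}
```

The last line sums over every pair (i, j), so the backward pass still has to touch every entry of the n × m cost matrix. In online mode, that is the step where reverse-mode differentiation of the batched geometry application may keep one cost block per batch alive unless it is rematerialized.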
Based on the traceback, it just materializes the …
Thanks! Then maybe this is a problem of having a …
Hi, at least according to the OOM error I posted above, it seems that it's not just the …
I also tested what happens when I use the …
So with the remat indicated above, we go down from a buffer of size … ott/src/ott/geometry/pointcloud.py line 271 (commit 137fd3a)
Sorry for not answering in a while, @jaschau; I was off for a few days. Your option to add a …
No worries, I'll be happy to provide a PR. I will have to think about how to properly test it, though, so it might take some days or weeks until I get around to it.
I have a large-scale optimal transport problem between two PointClouds that I want to differentiate. The cost matrix does not fit into memory, so I was quite happy to see the support for the batch_size parameter in PointCloud. Unfortunately, since reverse-mode differentiation needs to store all intermediate results, I still run out of memory when calculating the gradients of the optimal transport problem. I believe the issue is related to the one discussed in jax-ml/jax#3186. There, the authors suggest using the @jax.remat decorator so that the intermediate results of the relevant code are not stored and are instead recomputed during the backward pass.
Here's a minimal example for reproduction. You might need to play around with the problem size (or reduce the batch size) depending on the size of your GPU.
When I run the last cell on my machine, I get
Any help in addressing this would be really appreciated!
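The workaround from jax-ml/jax#3186 can be sketched on the row-wise log-sum-exp that an online Sinkhorn update performs. This is a minimal illustration, not OTT's implementation: softmin_rows and softmin_dense are hypothetical helpers, the cost is squared Euclidean, and batch_size is assumed to divide the number of points. Each batch builds only a (batch_size, m) block of the cost matrix, and wrapping the per-batch function in jax.checkpoint (which jax.remat aliases) asks reverse-mode autodiff to recompute that block in the backward pass instead of saving it, so the saved residuals should no longer scale with the full n × m matrix.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def softmin_rows(x, y, g, epsilon, batch_size):
    """Hypothetical online softmin: f_i = -eps * logsumexp_j((g_j - C_ij) / eps)."""

    @jax.checkpoint  # recompute each cost block during the backward pass
    def one_batch(x_batch):
        # (batch_size, m) block of the squared-Euclidean cost matrix.
        cost = jnp.sum((x_batch[:, None, :] - y[None, :, :]) ** 2, axis=-1)
        return -epsilon * logsumexp((g[None, :] - cost) / epsilon, axis=1)

    n, d = x.shape
    x_batches = x.reshape(n // batch_size, batch_size, d)
    return jax.lax.map(one_batch, x_batches).reshape(n)


def softmin_dense(x, y, g, epsilon):
    """Reference version that materializes the full (n, m) cost matrix."""
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return -epsilon * logsumexp((g[None, :] - cost) / epsilon, axis=1)


kx, ky, kg = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (128, 2))
y = jax.random.normal(ky, (96, 2))
g = jax.random.normal(kg, (96,))

loss_online = lambda x: jnp.sum(softmin_rows(x, y, g, 0.1, batch_size=32))
loss_dense = lambda x: jnp.sum(softmin_dense(x, y, g, 0.1))
grad_online = jax.jit(jax.grad(loss_online))(x)
grad_dense = jax.grad(loss_dense)(x)
```

The trade-off is extra compute: every cost block is evaluated once in the forward pass and again in the backward pass.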