batch_size in PointCloud is ineffective when reverse-mode differentiation is used #417

jaschau opened this issue Aug 23, 2023 · 9 comments

jaschau commented Aug 23, 2023

I have a large-scale optimal transport problem between two PointClouds that I want to differentiate. The cost matrix does not fit into memory, so I was quite happy to see the support for the batch_size parameter in PointCloud.
Unfortunately, since reverse-mode differentiation stores all intermediate results, I still run out of memory when computing the gradients of the optimal transport problem. I believe the issue is related to the one discussed in jax-ml/jax#3186. There, the authors suggest using the @jax.remat decorator so that the intermediates of the relevant code snippets are not stored and are instead recomputed during the backward pass.
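As a quick toy illustration of the idea (not OTT code): wrapping a function in @jax.remat (an alias of jax.checkpoint) tells reverse-mode AD to recompute that function's intermediates during the backward pass instead of saving them, trading compute for memory.

import jax
import jax.numpy as jnp

@jax.remat  # alias of jax.checkpoint: don't save intermediates, recompute them
def expensive_block(x):
    return jnp.tanh(x @ x.T).sum()

# gradients still work, but the intermediates of expensive_block are
# recomputed during the backward pass instead of being kept in memory
g = jax.grad(expensive_block)(jnp.ones((512, 512)))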

Here's a minimal example for reproduction. You might need to adjust the problem size (or reduce the batch size) depending on the memory of your GPU.

from functools import partial
import jax
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence


# %%
def sample_points_uniformly_from_disc(r, n, key):
    key_r, key_phi = jrandom.split(key, 2)
    # sqrt is necessary to achieve uniform distribution, c.f. https://stats.stackexchange.com/questions/481543/generating-random-points-uniformly-on-a-disk
    r_vals = r * jnp.sqrt(jrandom.uniform(key_r, shape=(n,)))
    phi_vals = 2 * jnp.pi * jrandom.uniform(key_phi, shape=(n,))
    y = jnp.stack(
        (r_vals * jnp.cos(phi_vals), r_vals * jnp.sin(phi_vals)),
        axis=1
    )
    return y



# %% [markdown]
# # Sample points uniformly from discs with different radii

# %%
key = jrandom.PRNGKey(seed=42)
key, key_x, key_y = jrandom.split(key, 3)
n = 50000
x = sample_points_uniformly_from_disc(5, n, key_x)
y = sample_points_uniformly_from_disc(10, n, key_y)
plt.plot(x[:, 0], x[:, 1], ".")
plt.plot(y[:, 0], y[:, 1], ".")


# %% [markdown]
# # Run forward sinkhorn divergence with batch size

# %%
@partial(jax.jit, static_argnames=["batch_size"])
def f(x, y, a, b, batch_size=None):
    out = sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud, x, y, a=a, b=b,
        batch_size=batch_size,
        sinkhorn_kwargs={"use_danskin": True}
    )
    return out.divergence, out

batch_size = 10000
div, div_res = f(x, y, None, None, batch_size=batch_size)
print("div:", div)
# outputs 12.471696

# %% [markdown]
# # Run backward pass (fails with OOM)

# %%
df = jax.value_and_grad(f, has_aux=True)
(div, div_res), grad = df(x, y, None, None, batch_size=batch_size)
print("div=", div)

When I run the last cell on my machine, I get

2023-08-23 12:06:41.435311: W external/xla/xla/service/hlo_rematerialization.cc:2202] Can't reduce memory use below 17.77GiB (19078594560 bytes) by rematerialization; only reduced to 29.82GiB (32015643581 bytes), down from 29.82GiB (32015644425 bytes) originally
2023-08-23 12:06:52.591406: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 9.31GiB (rounded to 10000000000)requested by op 
2023-08-23 12:06:52.591535: W external/tsl/tsl/framework/bfc_allocator.cc:497] *****************************************************_______________________________________________
2023-08-23 12:06:52.591823: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 10000000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   781.2KiB
              constant allocation:       328B
        maybe_live_out allocation:   18.63GiB
     preallocated temp allocation:    9.32GiB
  preallocated temp fragmentation:         0B (0.00%)
                 total allocation:   27.95GiB
Peak buffers:
	Buffer 1:
		Size: 9.31GiB
		Operator: op_name="jit(f)/jit(main)/broadcast_in_dim[shape=(5, 10000, 50000) broadcast_dimensions=()]" source_file="/tmp/ipykernel_2433426/603878625.py" source_line=3
		XLA Label: broadcast
		Shape: f32[5,10000,50000]
		==========================

	Buffer 2:
		Size: 9.31GiB
		Operator: op_name="jit(f)/jit(main)/while/body/dynamic_update_slice" source_file="/tmp/ipykernel_2433426/603878625.py" source_line=3 deduplicated_name="fusion.145"
		XLA Label: fusion
		Shape: f32[5,10000,50000]
		==========================

	Buffer 3:
		Size: 9.31GiB
		Operator: op_name="jit(f)/jit(main)/while/body/dynamic_update_slice" source_file="/tmp/ipykernel_2433426/603878625.py" source_line=3 deduplicated_name="fusion.145"
		XLA Label: fusion
		Shape: f32[5,10000,50000]
		==========================

Any help in addressing this would be really appreciated!


jaschau commented Aug 23, 2023

I have now realized that adding @jax.remat to

def body0(carry, i: int):

and to

def body1(carry, i: int):

solves the issue.
Would it make sense to add this to the code permanently? It shouldn't affect the forward pass and would only change the backward pass when online mode is used (i.e. batch_size != None). What do you think?
Are there any other places where @jax.remat should be added that I missed?
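For concreteness, here is a self-contained toy that mirrors the pattern (illustrative names only, not the actual OTT internals): the scan body computes one cost block per iteration, and wrapping it in @jax.remat keeps those blocks from being stacked in the residuals that reverse-mode AD would otherwise save across all scan iterations.

import jax
import jax.numpy as jnp


def batched_apply(x, y, vec, batch_size):
    # blocked squared-Euclidean cost applied to vec; assumes n % batch_size == 0
    n_batches = x.shape[0] // batch_size

    @jax.remat  # recompute this batch's cost block in the backward pass
    def body(carry, i):
        x_b = jax.lax.dynamic_slice_in_dim(x, i * batch_size, batch_size)
        cost_b = jnp.sum((x_b[:, None, :] - y[None, :, :]) ** 2, axis=-1)
        return carry, cost_b @ vec

    _, out = jax.lax.scan(body, None, jnp.arange(n_batches))
    return out.reshape(-1)


x = jnp.ones((1000, 2))
y = jnp.zeros((2000, 2))
vec = jnp.ones((2000,)) / 2000
# without the remat on body, reverse-mode AD stacks all cost blocks across the scan
grad_x = jax.grad(lambda x: batched_apply(x, y, vec, batch_size=100).sum())(x)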


michalk8 commented Aug 23, 2023

Hi @jaschau, thanks a lot, I will also take a closer look at this; in the meantime, could you please try running

df = jax.jit(jax.value_and_grad(f, has_aux=True), static_argnames=["batch_size"])

I suspect there might also be an issue with the grad function not being jitted.
UPDATE: I just tried it myself; the additional optimizations from jitting don't seem to help.


jaschau commented Aug 23, 2023

Hi @michalk8, thanks a lot for looking into this. I just tested it as well and found that jitting doesn't seem to help here.

marcocuturi commented

Hi @jaschau,

thanks! Just to clarify, regarding

Unfortunately, since reverse-mode differentiation stores all intermediate results, I still run out of memory when computing the gradients of the optimal transport problem.

the problem does not come from storing intermediates across many Sinkhorn iterations, because you are using Danskin (use_danskin=True stops gradients from flowing through the Sinkhorn iterations). When differentiating the divergence, you are differentiating 3 terms that are all reg_ot_cost, which simply means differentiating 3 large <vector, matrix · vector> quantities, where the matrix has a Jacobian w.r.t. the argument you want to differentiate.

These are all computed using compute_kl_reg_cost, where the first two arguments f, g are frozen (they are output by Sinkhorn), and only ot_prob.geom, which stores everything, is differentiated.

I suspect (I am sure Michal will think about this) that the issue might be that differentiating the geometry application (geom.marginal_from_potentials(f, g)) materializes the entire derivative w.r.t. the entire matrix. One possible way to address this would be to have a custom_vjp for geometries with a batch_size that follows the same splitting strategy. But the remat approach seems very nice!
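To make the custom_vjp idea concrete, here is a rough, self-contained sketch (toy names and a squared-Euclidean cost, not the actual OTT geometry API; it also assumes n is divisible by batch_size): the forward pass applies the cost block by block, and the backward pass recomputes each block to pull the cotangent through it, so the full cost matrix is never materialized.

from functools import partial

import jax
import jax.numpy as jnp


def _blocked_apply(x, y, vec, batch_size):
    # blocked <cost(x, y), vec> product; assumes x.shape[0] % batch_size == 0
    def body(carry, i):
        x_b = jax.lax.dynamic_slice_in_dim(x, i * batch_size, batch_size)
        cost_b = jnp.sum((x_b[:, None, :] - y[None, :, :]) ** 2, axis=-1)
        return carry, cost_b @ vec

    _, out = jax.lax.scan(body, None, jnp.arange(x.shape[0] // batch_size))
    return out.reshape(-1)


@partial(jax.custom_vjp, nondiff_argnums=(3,))
def apply_cost(x, y, vec, batch_size):
    return _blocked_apply(x, y, vec, batch_size)


def apply_cost_fwd(x, y, vec, batch_size):
    # save only the (small) inputs as residuals, never the cost matrix
    return _blocked_apply(x, y, vec, batch_size), (x, y, vec)


def apply_cost_bwd(batch_size, res, g):
    x, y, vec = res

    def body(carry, i):
        dy, dvec = carry
        x_b = jax.lax.dynamic_slice_in_dim(x, i * batch_size, batch_size)
        g_b = jax.lax.dynamic_slice_in_dim(g, i * batch_size, batch_size)
        # recompute a single cost block and pull the cotangent through it
        block = lambda x_b, y, vec: jnp.sum(
            (x_b[:, None, :] - y[None, :, :]) ** 2, axis=-1) @ vec
        _, pull = jax.vjp(block, x_b, y, vec)
        dx_b, dy_b, dvec_b = pull(g_b)
        return (dy + dy_b, dvec + dvec_b), dx_b

    init = (jnp.zeros_like(y), jnp.zeros_like(vec))
    (dy, dvec), dx = jax.lax.scan(body, init, jnp.arange(x.shape[0] // batch_size))
    return dx.reshape(x.shape), dy, dvec


apply_cost.defvjp(apply_cost_fwd, apply_cost_bwd)

# gradients w.r.t. x, y and vec are formed block by block
x = jnp.ones((1000, 2))
y = jnp.zeros((2000, 2))
vec = jnp.ones((2000,)) / 2000
grad_x = jax.grad(lambda x: apply_cost(x, y, vec, 100).sum())(x)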

Also seems to be related to this

michalk8 commented

I suspect (I am sure Michal will think about this) that the issue might be that differentiating the geometry application (geom.marginal_from_potentials(f, g)) materializes the entire derivative w.r.t. the entire matrix.

Based on the traceback, it just materializes the (50000, 10000) array, which is correct (when batch_size=10000).

marcocuturi commented

Thanks! Then maybe this is a problem of the batch_size being too large, and remat handles it better, more adaptively?


jaschau commented Aug 24, 2023

Hi, at least according to the OOM error I posted above, it seems that it's not just the (50000, 10000) array that is being created, but rather an array covering all 5 slices of batches of size 10000.

Peak buffers:
	Buffer 1:
		Size: 9.31GiB
		Operator: op_name="jit(f)/jit(main)/broadcast_in_dim[shape=(5, 10000, 50000) broadcast_dimensions=()]" source_file="/tmp/ipykernel_2433426/603878625.py" source_line=3
		XLA Label: broadcast
		Shape: f32[5,10000,50000]
		==========================

I also tested what happens when I use @jax.remat as indicated above with even larger point clouds. With n=300000 and batch_size=100, I still obtain an OOM with the following peak buffer:

Peak buffers:
	Buffer 1:
		Size: 3.35GiB
		XLA Label: copy
		Shape: f32[3000,300000]
		==========================

So with the remat indicated above, we go down from a buffer of size (3000, 100, 300000) to a buffer of size (3000, 300000), which still leads to OOM. I can remedy this by rematting the entire scan call instead of only the scan body, i.e. by replacing the scan call

_, (h_res, h_sign) = jax.lax.scan(

with

    @jax.remat
    def compute_h_res_h_sign(f, g, eps, vec):
        _, (h_res, h_sign) = jax.lax.scan(
            fun, init=(f, g, eps, vec), xs=jnp.arange(n)
        )
        return h_res, h_sign

    h_res, h_sign = compute_h_res_h_sign(f, g, eps, vec)

marcocuturi commented

Sorry for not answering in a while @jaschau, I was off for a few days.

Your option of adding a jax.remat decorator sounds very elegant. Please don't hesitate to submit it as a PR if you have found that it answers your needs (along with adequate tests).


jaschau commented Sep 4, 2023

No worries. I'll be happy to provide a PR. I will have to think about how to properly test it, though, so it might take some days/weeks until I get around to it.
