"JAX array is set as static" warning is raised unnecessarily #863
A simple fix would be to replace Line 585 in 804d82e, but I'm not sure this is in line with the intended use.
I would say this behavior is expected (whether or not it is wanted may be another question). In general, NumPy arrays are not hashable, and making a field static means storing it as aux data in the pytree (https://github.com/patrick-kidger/equinox/blob/main/equinox/_module.py#L946). Aux data is expected to be hashable, because "JAX sometimes needs to compare treedef for equality, or compute its hash for use in the JIT cache, and so care must be taken to ensure that the auxiliary data specified in the flattening recipe supports meaningful hashing and equality comparisons." (https://jax.readthedocs.io/en/latest/pytrees.html). Violating this can cause silent or (speaking from personal experience) confusing errors, which is why I wanted to add the warning.

I agree, at the least, that the warning message is wrong: a JAX array isn't being set as static here (a NumPy array is), so the message should be changed to match the check. As for NumPy arrays, I think they were included originally because of their hashing problems (to quote @jakevdp: "Neither np.ndarray nor jax.Array satisfy this, so they should not be included in aux_data. If you do include such values in aux_data, you'll get unsupported, poorly-defined behavior."). That said, there are definitely cases where static arrays are fine and correct (which is why this is a warning that can be ignored rather than an error), and if these cases are very common then the warning could become a burden. WDYT?
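A NumPy-only sketch of the hashability problem described above (is_hashable is a hypothetical helper, not part of Equinox or JAX). It shows that a tuple can serve as hashable aux data while the equivalent np.ndarray cannot be hashed, and that array equality is elementwise rather than a single True/False answer:

```python
import numpy as np


def is_hashable(obj) -> bool:
    """Hypothetical helper: return True if hash(obj) succeeds."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


static_tuple = (3, 2)
static_array = np.asarray(static_tuple)

print(is_hashable(static_tuple))  # True
print(is_hashable(static_array))  # False: np.ndarray sets __hash__ to None

# Equality is also problematic: == on arrays is elementwise, so the result
# cannot be used directly as a single boolean.
try:
    bool(static_array == np.asarray((3, 2)))
except ValueError as err:
    print("ambiguous truth value:", err)
```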
But in our example, the NumPy array is just an intermediary in the computation: it starts from a static tuple and returns a static tuple, which is why I don't think this warning should be expected. Also, I don't really see a way around it. The warning is raised when the class is instantiated, so I don't see how we could filter it. In our library, it will be raised every time we perform an operation on our class, and we really cannot make this class attribute non-static.
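For reference, Python's standard warnings module can suppress a warning by matching the start of its message. The sketch below uses a hypothetical make_instance stand-in that emits the same text as the warning quoted in this issue; it does not call Equinox. In practice the filter has to be active wherever the Module is instantiated, which is exactly the inconvenience described above when instantiation happens inside every operation.

```python
import warnings

# Message prefix copied from the warning quoted in this issue (matched as a regex prefix).
STATIC_ARRAY_MSG = r"A JAX array is being set as static!"


def make_instance():
    """Hypothetical stand-in for constructing the Module; emits the same warning text."""
    warnings.warn(
        "A JAX array is being set as static! This can result in unexpected "
        "behavior and is usually a mistake to do.",
        UserWarning,
    )
    return "instance"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Inserted after simplefilter, so it takes precedence: only this message is ignored.
    warnings.filterwarnings("ignore", message=STATIC_ARRAY_MSG, category=UserWarning)
    result = make_instance()

print(result, len(caught))  # instance 0
```

Other warnings still surface, since the filter only matches this specific message.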
Actually, investigating more, the following works without warning:

import numpy as np
import equinox as eqx


class Foo(eqx.Module):
    x: tuple[int, int] = eqx.field(static=True)

    def add_one(self):
        x_as_np = np.asarray(self.x)
        x_as_np += 1
        # .item() converts each NumPy scalar back to a built-in Python int
        x = tuple(i.item() for i in x_as_np)
        return Foo(x)


x = (3, 2)
foo = Foo(x)
foo.add_one()
# no warning

So what's being detected in the first example is that the tuple elements are of type
Hmm, I see it; yeah, I misread it: it's an int64 class from NumPy. That would be a misuse of the
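The int64 in question comes from iterating over a NumPy array. A NumPy-only sketch of the difference between the two conversions discussed above (variable names are illustrative; the exact scalar type name, e.g. int64 vs int32, depends on platform and NumPy version):

```python
import numpy as np

x = (3, 2)
x_as_np = np.asarray(x) + 1

# Iterating over an ndarray yields NumPy scalars (np.generic subclasses),
# not built-in Python ints.
naive = tuple(x_as_np)
# .tolist() converts every element back to a built-in Python int.
clean = tuple(x_as_np.tolist())

print([type(i).__name__ for i in naive])  # NumPy scalar type names, e.g. 'int64'
print([type(i).__name__ for i in clean])  # ['int', 'int']
print(naive == clean)                     # True: the values compare equal
print(hash(naive) == hash(clean))         # True: NumPy integer scalars hash like ints
```

So a check that looks for NumPy types among the contents of a static value would flag naive but not clean, even though both hash and compare equal.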
As of Equinox 0.11.6 and #800, the following MWE raises a
UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
This means that one cannot perform NumPy operations (which are often simpler than writing them in plain Python) on a static attribute. This is a use case we have in dynamiqs; see for instance the method __mul__ of this class, which represents an array in diagonal (DIA) sparse format. Note that we intentionally use numpy instead of jax.numpy to have "static" logic.
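A minimal sketch of how such NumPy "static logic" can coexist with static fields, in the same spirit as the .item() conversion that was found to avoid the warning: compute with NumPy, then convert the result back to built-in Python values before storing it. to_static is a hypothetical helper (not an Equinox or dynamiqs API), and the offset example assumes an elementwise product in which only diagonals present in both operands survive; it is not dynamiqs's actual __mul__.

```python
import numpy as np


def to_static(value):
    """Hypothetical helper: recursively convert NumPy arrays and scalars into
    hashable built-in Python values (nested tuples of int/float/bool/...)."""
    if isinstance(value, np.ndarray):
        # tolist() yields nested lists of Python scalars (or a bare scalar for 0-d arrays)
        return to_static(value.tolist())
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, (list, tuple)):
        return tuple(to_static(v) for v in value)
    return value


# Illustrative static logic on DIA-style diagonal offsets: assuming an
# elementwise product, only diagonals present in both operands survive.
offsets_a = (-1, 0, 1)
offsets_b = (0, 1, 2)
common = np.intersect1d(offsets_a, offsets_b)  # a NumPy array
static_offsets = to_static(common)
print(static_offsets)  # (0, 1)
```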