Replicating jax's export feature to cache filter_jitted functions #879
Comments
This should definitely be possible; we'd just have to write an API that wraps the existing … Under the hood … I'd be happy to take a PR on this.
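A minimal sketch of what such a wrapper might look like, assuming JAX 0.4.30 or later (where jax.export is public). filter_export and _is_array are hypothetical names, not part of Equinox. The sketch flattens the arguments, treats JAX/NumPy arrays as the only exported inputs, and closes over every other leaf so that it is baked into the exported computation.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export


def _is_array(x):
    # JAX and NumPy arrays become exported inputs; everything else is static.
    return isinstance(x, (jax.Array, np.ndarray, np.generic))


def filter_export(fn, *args):
    """Hypothetical helper: export `fn`, keeping only its array leaves as inputs."""
    leaves, treedef = jax.tree_util.tree_flatten(args)
    mask = [_is_array(x) for x in leaves]
    dynamic = [x for x, m in zip(leaves, mask) if m]

    def inner(*dynamic_leaves):
        # Re-insert the traced arrays; static leaves are closed over (baked in).
        it = iter(dynamic_leaves)
        full = [next(it) if m else x for x, m in zip(leaves, mask)]
        return fn(*jax.tree_util.tree_unflatten(treedef, full))

    specs = [jax.ShapeDtypeStruct(np.shape(x), x.dtype) for x in dynamic]
    return export.export(jax.jit(inner))(*specs)


# Usage: `power` and `activation` are static, `x` is the only exported input.
def f(x, power, activation):
    return activation(x) ** power


x = jnp.arange(3.0)
exported = filter_export(f, x, 2, jnp.tanh)
restored = export.deserialize(exported.serialize())
y = restored.call(x)
```

The restored artifact accepts only the array leaves, in flattened order. The static values are fixed at export time, so calling with a different power or activation would need a fresh export.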
When I try
I get this error.
I'm guessing that all the extra things filter_jit turns into static arguments wouldn't be supported for serialization? Do you think there is any way around that?
Just following up with the minimal example to replicate the problem.
Right, so this is because the internal JIT'd function (…) … This is kind of the whole point of … This is the function that's actually JIT'd: line 43 in d9b3ffd.
And here is where they are split up in this way: lines 220 to 222 in d9b3ffd.
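To make the dynamic/static split concrete, here is a simplified stand-in for the pattern described in this comment. It is not Equinox's actual code: partition, combine, _call and simple_filter_jit are hypothetical names. Array leaves are traced, while every other leaf (ints, functions, and so on) reaches jax.jit as a hashable static argument, which is the kind of input the question above guesses cannot be serialized.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def is_array(x):
    # Arrays are traced; every other leaf (ints, strings, functions) is static.
    return isinstance(x, (jax.Array, np.ndarray, np.generic))


def partition(tree):
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    dynamic = [x if is_array(x) else None for x in leaves]
    static = tuple(None if is_array(x) else x for x in leaves)
    return dynamic, static, treedef


def combine(dynamic, static, treedef):
    # Exactly one of each (dynamic, static) pair is None.
    leaves = [d if s is None else s for d, s in zip(dynamic, static)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


@functools.partial(jax.jit, static_argnums=(1, 2))
def _call(dynamic, static, treedef):
    # The function that is actually JIT'd: arrays traced, the rest static.
    fn, args = combine(dynamic, static, treedef)
    return fn(*args)


def simple_filter_jit(fn):
    def wrapper(*args):
        dynamic, static, treedef = partition((fn, args))
        return _call(dynamic, static, treedef)
    return wrapper


def f(x, power):
    return x ** power


g = simple_filter_jit(f)
y = g(jnp.arange(3.0), 2)
```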
JAX allows you to serialize and deserialize a jitted function, as described here:
https://jax.readthedocs.io/en/latest/_autosummary/jax.export.export.html#jax.export.export
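For reference, a minimal round trip with a plain jax.jit function, assuming JAX 0.4.30 or later: export against an abstract input shape, serialize to bytes, deserialize, and call. The function f is only an illustration.

```python
import jax
import jax.numpy as jnp
from jax import export


def f(x):
    return jnp.sin(x) * 2.0


# Export against an abstract input: shape (3,), float32.
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((3,), jnp.float32))
blob = exported.serialize()  # a bytearray that can be written to disk

# Later (possibly in another process): rebuild and call it.
restored = export.deserialize(blob)
x = jnp.arange(3, dtype=jnp.float32)
y = restored.call(x)
```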
I tried this for a filter_jitted function, but received this error.
Is it possible to replicate serialization like this for filter_jit? I wasn't sure if that was even theoretically possible, but figured it was at least worth asking.