Re-implement Nx.LinAlg.eigh as defn #1027
Although #1424 did move eigh to defn, it's still worth looking into a new implementation for better speed and accuracy.
Do you have a basic test I could use to compare them?
@christianjgreen in short, you can just compare the execution time of […]. SVD will consequently suffer the same performance drop, because SVD uses eigh under the hood.
Thanks for the info!
Update: after adding some optimizations to the QR algorithm, I got it down to ~8 seconds. That is still twice as slow as the Jacobi method, which makes me think something else can be optimized. What would the library maintainers prefer: starting work on a QDWH-eigh with the Jacobi method, or trying to optimize the QR code so that it matches its expected big-O behavior?
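The Jacobi method referred to above diagonalizes a symmetric matrix by repeatedly applying plane rotations, each chosen to zero one off-diagonal entry. Below is a minimal NumPy sketch of the classic cyclic Jacobi variant. It assumes a real symmetric input and is not Nx's implementation. `jacobi_eigh`, `tol`, and `max_sweeps` are hypothetical names chosen for illustration.

```python
import numpy as np


def jacobi_eigh(a, tol=1e-12, max_sweeps=50):
    """Cyclic Jacobi eigenvalue algorithm for a real symmetric matrix.

    Returns (w, v) with eigenvalues w in ascending order and a @ v ≈ v @ diag(w).
    """
    a = np.array(a, dtype=float)
    n = a.shape[0]
    v = np.eye(n)
    for _ in range(max_sweeps):
        # Stop once the off-diagonal part is negligible relative to the whole matrix.
        if np.linalg.norm(a - np.diag(np.diag(a))) <= tol * np.linalg.norm(a):
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if a[p, q] == 0.0:
                    continue
                # Rotation that zeroes a[p, q]; t = tan(theta) is the smaller root.
                tau = (a[q, q] - a[p, p]) / (2.0 * a[p, q])
                if tau >= 0:
                    t = 1.0 / (tau + np.sqrt(1.0 + tau * tau))
                else:
                    t = 1.0 / (tau - np.sqrt(1.0 + tau * tau))
                c = 1.0 / np.sqrt(1.0 + t * t)
                s = t * c
                # Dense rotation matrix for readability only.
                j = np.eye(n)
                j[p, p] = j[q, q] = c
                j[p, q] = s
                j[q, p] = -s
                a = j.T @ a @ j
                v = v @ j
    w = np.diag(a)
    order = np.argsort(w)
    return w[order], v[:, order]
```

Applying each rotation as a dense product makes one sweep cost O(n⁵). Updating only rows and columns p and q, as a real implementation would, brings a sweep down to O(n³).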
Last update, and sorry for all the pings! Even though QR-eigh decomposition is supposed to be much faster than Jacobi on large matrices, I can't get it to even match the performance of the Jacobi method. That leads me to believe something is amiss with how the QR algorithm gets compiled. I've tried a few things, like Wilkinson shifts, deflation, and checking only the subdiagonals, but the number of iterations still grows too high before converging. Current testing on my naïve implementation with the default 1000 iterations and an eps of […] I'll defer to @polvalente and @josevalim for next steps, as I'm a complete newbie here.
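The techniques mentioned above (Wilkinson shifts, deflation, checking only the subdiagonal) fit together roughly as follows in a textbook shifted QR algorithm. This is a minimal NumPy sketch assuming a real symmetric input, not the code discussed in this thread. `tridiagonalize`, `wilkinson_shift`, and `qr_eigvalsh` are hypothetical names, and the sketch returns eigenvalues only. The `max_iter=1000` default merely mirrors the default mentioned above. Reducing the matrix to tridiagonal form first is what makes a subdiagonal-only convergence check valid.

```python
import numpy as np


def tridiagonalize(a):
    """Reduce a symmetric matrix to tridiagonal form with Householder reflections."""
    t = np.array(a, dtype=float)
    n = t.shape[0]
    for k in range(n - 2):
        x = t[k + 1:, k]
        # Reflect x onto alpha * e1; the sign choice avoids cancellation.
        alpha = -np.copysign(np.linalg.norm(x), x[0])
        v = x.copy()
        v[0] -= alpha
        v_norm = np.linalg.norm(v)
        if v_norm == 0.0:
            continue  # already zero below the subdiagonal
        v /= v_norm
        # Similarity transform H @ t @ H with H = I - 2 v v^T acting on rows/cols k+1:.
        t[k + 1:, :] -= 2.0 * np.outer(v, v @ t[k + 1:, :])
        t[:, k + 1:] -= 2.0 * np.outer(t[:, k + 1:] @ v, v)
    return t


def wilkinson_shift(a, b, c):
    """Eigenvalue of the trailing 2x2 block [[a, b], [b, c]] closest to c."""
    if b == 0.0:
        return c
    d = (a - c) / 2.0
    return c - b * b / (d + np.copysign(np.hypot(d, b), d))


def qr_eigvalsh(a, tol=1e-12, max_iter=1000):
    """Eigenvalues (ascending) via Wilkinson-shifted QR iteration with deflation."""
    t = tridiagonalize(a)
    n = t.shape[0]
    scale = np.linalg.norm(t)
    m = n - 1  # the active, not yet deflated block is t[:m + 1, :m + 1]
    iters = 0
    while m > 0:
        # Deflation: a negligible last subdiagonal means the trailing
        # diagonal entry has converged, so shrink the active block.
        if abs(t[m, m - 1]) <= tol * scale:
            m -= 1
            continue
        if iters == max_iter:
            raise np.linalg.LinAlgError("QR iteration did not converge")
        mu = wilkinson_shift(t[m - 1, m - 1], t[m, m - 1], t[m, m])
        shift = mu * np.eye(m + 1)
        q, r = np.linalg.qr(t[:m + 1, :m + 1] - shift)
        t[:m + 1, :m + 1] = r @ q + shift
        iters += 1
    return np.sort(np.diag(t))
```

For clarity, each step here calls `np.linalg.qr` on the whole active block, which costs O(m³). Production implementations instead apply the shift implicitly with Givens rotations ("bulge chasing"), so one step on a tridiagonal matrix costs O(m).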
Currently, we have a custom implementation for Nx.BinaryBackend and call the XLA implementation for eigh in EXLA.
However, the XLA implementation seems to suffer from issues similar to those of the SVD one: it ends up slower than, and differs in accuracy from, the implementation JAX uses (https://github.com/google/jax/blob/main/jax/_src/lax/eigh.py).
Especially since we already have QDWH implemented in Nx.LinAlg.SVD.qdwh, it seems like a good idea to also reimplement eigh as a defn with optional + custom_grad (like Nx.LinAlg.svd).
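The QDWH-based eigh in the linked JAX file follows a spectral divide-and-conquer scheme (QDWH-eigh). It works in four steps:

1. Compute the polar factor U of A − σI.
2. Form the projector P = (U + I)/2 onto the eigenvectors whose eigenvalues exceed σ.
3. Split A into two smaller blocks using orthonormal bases of range(P) and its complement.
4. Recurse on each block.

The NumPy sketch below illustrates only that structure, under several assumptions. It swaps QDWH for a plain scaled Newton iteration to compute the polar factor. It splits at the mean eigenvalue trace(A)/n, which is an arbitrary choice made here. It returns eigenvalues only. `polar_sign` and `eigvalsh_divide` are hypothetical names, not Nx or JAX APIs.

```python
import numpy as np


def polar_sign(x, tol=1e-12, max_iter=100):
    """Unitary polar factor of a symmetric, nonsingular matrix.

    Stand-in for QDWH (an assumption of this sketch): the scaled Newton
    iteration X <- (mu X + X^-1 / mu) / 2, which for symmetric X converges
    to sign(X) = V sign(Lambda) V^T.
    """
    for _ in range(max_iter):
        x_inv = np.linalg.inv(x)
        # Frobenius-norm scaling speeds up the early iterations.
        mu = np.sqrt(np.linalg.norm(x_inv) / np.linalg.norm(x))
        x_new = 0.5 * (mu * x + x_inv / mu)
        x_new = 0.5 * (x_new + x_new.T)  # keep the iterate exactly symmetric
        if np.linalg.norm(x_new - x) <= tol * np.linalg.norm(x_new):
            return x_new
        x = x_new
    return x


def eigvalsh_divide(a, rng=None):
    """Eigenvalues (ascending) of a symmetric matrix by spectral divide-and-conquer."""
    rng = np.random.default_rng(0) if rng is None else rng
    a = 0.5 * (a + a.T)
    n = a.shape[0]
    # Base case: 1x1 or numerically diagonal blocks need no further splitting.
    if n == 1 or np.linalg.norm(a - np.diag(np.diag(a))) <= 1e-14 * np.linalg.norm(a):
        return np.sort(np.diag(a))
    # Split at the mean eigenvalue trace(A)/n. Since A is not a multiple of I,
    # some eigenvalue lies strictly on each side (assumes none equals sigma).
    sigma = np.trace(a) / n
    eye = np.eye(n)
    p = 0.5 * (polar_sign(a - sigma * eye) + eye)  # projector onto eigenvalues > sigma
    k = int(round(np.trace(p)))                     # number of eigenvalues > sigma
    if k == 0 or k == n:
        raise np.linalg.LinAlgError("spectral split failed")
    # Orthonormal bases for range(P) and range(I - P) via a randomized range finder.
    q_hi, _ = np.linalg.qr(p @ rng.standard_normal((n, k)))
    q_lo, _ = np.linalg.qr((eye - p) @ rng.standard_normal((n, n - k)))
    lo = eigvalsh_divide(q_lo.T @ a @ q_lo, rng)
    hi = eigvalsh_divide(q_hi.T @ a @ q_hi, rng)
    return np.concatenate([lo, hi])
```

Every eigenvalue in the lower block lies below σ and every eigenvalue in the upper block lies above it. Concatenating the two sorted results therefore yields a sorted spectrum. In a defn version, the existing Nx.LinAlg.SVD.qdwh mentioned above would be the natural candidate for the polar step.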