Benchmarking Numba vs JAX #1266
Comments
Thanks @tomwhite, that's very interesting. I think the limitation of having to use a NumPy-like API rather than simple loops, together with this significant performance drop on key functionality, is a pretty strong argument for sticking with Numba. Numba has its issues, but overall I'm still very positive about it.
I did have a look at using JAX's lower-level fori_loop, but I couldn't see how it could loop over array dimensions (or if that is even possible). Also, the approach I've used for |
I managed to express |
It's pretty obscure code too. Numba FTW!
We have used Numba to accelerate CPU computations in sgkit for a long time - and to great effect. We have previously discussed trying JAX as an alternative to see how it compares.
I created a notebook to compare basic allele counting (`count_call_alleles`) using Numba with a JAX equivalent here: https://github.com/tomwhite/sgkit/blob/134f15c18558458f6b04607cca9915e5ca90acf6/benchmark-numba-vs-jax.ipynb

The JAX code I wrote is significantly slower than Numba in this case. It could be because I've used JAX's `bincount` function, rather than writing code that operates directly on the arrays - like Numba does - but I'm not sure how to do that with JAX.
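For readers unfamiliar with the operation being benchmarked, here is a minimal plain-Python sketch of loop-style allele counting, the kind of code that "operates directly on the arrays". It is not the notebook's code or sgkit's implementation: the function name `count_call_alleles_loop`, the nested-list input layout (variants x samples x ploidy), and the convention that negative values mark missing alleles are all assumptions made for illustration. It is written without the Numba decorator so that it runs with no dependencies.

```python
def count_call_alleles_loop(calls, n_alleles):
    """Count alleles per call with explicit loops.

    calls: nested lists shaped (variants, samples, ploidy) holding
           integer allele indices; negative values are assumed to
           mean "missing" and are skipped (illustrative assumption).
    Returns nested lists shaped (variants, samples, n_alleles).
    """
    result = []
    for variant in calls:
        variant_counts = []
        for genotype in variant:
            counts = [0] * n_alleles
            for allele in genotype:
                # Ignore missing or out-of-range allele indices.
                if 0 <= allele < n_alleles:
                    counts[allele] += 1
            variant_counts.append(counts)
        result.append(variant_counts)
    return result


# One variant, three diploid samples: 0/1, 1/1 and missing/0.
example = [[[0, 1], [1, 1], [-1, 0]]]
print(count_call_alleles_loop(example, n_alleles=2))
```

Each inner loop writes straight into a small per-call counter, which is the pattern a loop compiler such as Numba handles well. A `bincount`-based version has to express the same per-call histogram as an array operation instead.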