
Benchmarking Numba vs JAX #1266

Open
tomwhite opened this issue Oct 7, 2024 · 4 comments

@tomwhite
Collaborator

tomwhite commented Oct 7, 2024

We have used Numba to accelerate CPU computations in sgkit for a long time - and to great effect. We have previously discussed trying JAX as an alternative to see how it compares.

I created a notebook to compare basic allele counting (count_call_alleles) using Numba with a JAX equivalent here: https://github.com/tomwhite/sgkit/blob/134f15c18558458f6b04607cca9915e5ca90acf6/benchmark-numba-vs-jax.ipynb

The JAX code I wrote is significantly slower than Numba in this case. This could be because I used JAX's bincount function rather than writing code that operates directly on the arrays (as the Numba version does), but I'm not sure how to do that with JAX.
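Below is a minimal sketch of a bincount-based approach in JAX. It assumes genotypes of shape (variants, samples, ploidy) with negative values marking missing calls. `count_call_alleles_bincount` is a hypothetical name, and the notebook's code may differ. `jnp.bincount` needs a static `length` under `jit`/`vmap`. Missing calls are given zero weight so that the result does not depend on how `jnp.bincount` treats negative inputs.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical name; one way to use jnp.bincount, not necessarily the notebook's code.
@partial(jax.jit, static_argnums=1)
def count_call_alleles_bincount(g, n_alleles):
    """g: integer array (variants, samples, ploidy); negative values are missing."""
    # Weight missing calls by 0 so they never contribute to any bin.
    weights = (g >= 0).astype(jnp.int32)

    def per_call(alleles, w):
        # `length` must be a static Python int under jit/vmap.
        return jnp.bincount(alleles, weights=w, length=n_alleles)

    # Map the 1-D bincount over samples (inner vmap) and variants (outer vmap).
    return jax.vmap(jax.vmap(per_call))(g, weights)
```

The nested `vmap` applies a 1-D `bincount` to each `(ploidy,)` call, giving output of shape (variants, samples, n_alleles).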

@jeromekelleher
Collaborator

Thanks @tomwhite, that's very interesting. I think the limitation of having to use a NumPy-like API rather than simple loops, together with this significant performance drop on key functionality, is a pretty strong argument for sticking with Numba. Numba has its issues, but overall I'm still very positive about it.

@tomwhite
Collaborator Author

tomwhite commented Oct 8, 2024

I did have a look at using JAX's lower-level fori_loop, but I couldn't see how it could loop over array dimensions (or if that is even possible).
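`jax.lax.fori_loop` can iterate over an array dimension as long as the loop bounds come from static shapes. The loop index is a traced integer that can be used both for indexing and for `.at[...]` updates. Here is a sketch of per-variant allele counting with nested loops. It assumes genotypes of shape (variants, samples, ploidy) with negative values marking missing calls. The function name is hypothetical, and the sketch has not been benchmarked.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical name; a sketch of explicit loops with lax.fori_loop.
@partial(jax.jit, static_argnums=1)
def count_alleles_fori(g, n_alleles):
    """g: integer array (variants, samples, ploidy); negative values are missing.

    Returns per-variant allele counts of shape (variants, n_alleles).
    """
    n_variants = g.shape[0]  # shapes are static, so they can be loop bounds
    calls = g.reshape(n_variants, -1)  # flatten samples x ploidy
    n_calls = calls.shape[1]
    init = jnp.zeros((n_variants, n_alleles), dtype=jnp.int32)

    def variant_body(i, out):
        def call_body(j, out):
            a = calls[i, j]  # indexing with traced loop indices is allowed
            valid = a >= 0
            # Redirect missing calls to index 0 but add 0 there.
            return out.at[i, jnp.where(valid, a, 0)].add(valid.astype(jnp.int32))

        return jax.lax.fori_loop(0, n_calls, call_body, out)

    return jax.lax.fori_loop(0, n_variants, variant_body, init)
```

Because arrays are immutable, the loop state is the whole output array, threaded through each iteration and updated functionally.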

Also, the approach I've used for count_call_alleles won't work for hardy_weinberg_test (as @eric-czech pointed out to me), since that function is too complicated to express in NumPy-like operations and would therefore need loops.

@tomwhite
Collaborator Author

I managed to express count_alleles using the lower-level jax.lax.scan primitive - but it's still an order of magnitude slower than the Numba equivalent.
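For comparison, here is one possible shape for a `jax.lax.scan` version. It uses an outer scan over variants and an inner scan over the flattened calls, carrying a vector of counts. The sketch assumes genotypes of shape (variants, samples, ploidy), negative values marking missing calls, and a result of per-variant counts with shape (variants, n_alleles). `count_alleles_scan` is hypothetical, and the linked notebook may be written differently.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical name; a sketch with nested lax.scan, not the notebook's code.
@partial(jax.jit, static_argnums=1)
def count_alleles_scan(g, n_alleles):
    """g: integer array (variants, samples, ploidy); negative values are missing.

    Returns per-variant allele counts of shape (variants, n_alleles).
    """
    calls = g.reshape(g.shape[0], -1)  # one row of calls per variant

    def call_step(counts, a):
        # Carry the running counts; add 1 for a valid allele, 0 for missing.
        valid = a >= 0
        return counts.at[jnp.where(valid, a, 0)].add(valid.astype(jnp.int32)), None

    def variant_step(carry, row):
        counts, _ = jax.lax.scan(call_step, jnp.zeros(n_alleles, dtype=jnp.int32), row)
        return carry, counts  # emit one counts vector per variant

    _, out = jax.lax.scan(variant_step, None, calls)
    return out
```

The loop state has to be threaded explicitly through the carry, and per-iteration outputs are stacked by `scan`. This is much of why the code reads less directly than the equivalent Numba loops.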

Notebook at https://github.com/sgkit-dev/sgkit/blob/df5f7a146ba64a0654d00bc5310e7b64c329b19e/benchmark-numba-vs-jax.ipynb

@jeromekelleher
Collaborator

It's pretty obscure code too. Numba FTW!
