diff --git a/metaworld-jax/requirements.txt b/metaworld-jax/requirements.txt index d802543..d6f8cd6 100644 --- a/metaworld-jax/requirements.txt +++ b/metaworld-jax/requirements.txt @@ -1,5 +1,5 @@ # Jax -jax[cuda12_pip]==0.4.31 +jax[cuda12]==0.4.31 flax==0.8.5 distrax==0.1.5