Skip to content

Commit

Permalink
Hopefully fix JAX not seeing GPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
rainx0r committed Aug 22, 2024
1 parent ac76843 commit c37f0cd
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion metaworld-jax/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ RUN apt-get update && apt install -y --no-install-recommends git python3-pip lib
WORKDIR /usr/src/app
COPY requirements.txt ./
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir --upgrade -r requirements.txt
pip install --no-cache-dir --upgrade --pre -r requirements.txt

ENTRYPOINT ["python"]

4 changes: 3 additions & 1 deletion metaworld-jax/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Jax
jax[cuda12]==0.4.31
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
jax[cuda12]==0.4.32.dev20240822

flax==0.8.5
distrax==0.1.5

Expand Down

0 comments on commit c37f0cd

Please sign in to comment.