From c37f0cda31ed217bb9679dc44a48bc16df59dad6 Mon Sep 17 00:00:00 2001 From: rainx0r Date: Thu, 22 Aug 2024 20:56:50 +0100 Subject: [PATCH] Hopefully fix JAX not seeing GPUs --- metaworld-jax/Dockerfile | 2 +- metaworld-jax/requirements.txt | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/metaworld-jax/Dockerfile b/metaworld-jax/Dockerfile index 9c95316..32af31e 100644 --- a/metaworld-jax/Dockerfile +++ b/metaworld-jax/Dockerfile @@ -6,7 +6,7 @@ RUN apt-get update && apt install -y --no-install-recommends git python3-pip lib WORKDIR /usr/src/app COPY requirements.txt ./ RUN pip install --no-cache-dir --upgrade pip && \ - pip install --no-cache-dir --upgrade -r requirements.txt + pip install --no-cache-dir --upgrade --pre -r requirements.txt ENTRYPOINT ["python"] diff --git a/metaworld-jax/requirements.txt b/metaworld-jax/requirements.txt index 172961c..8ac9bf2 100644 --- a/metaworld-jax/requirements.txt +++ b/metaworld-jax/requirements.txt @@ -1,5 +1,7 @@ # Jax -jax[cuda12]==0.4.31 +--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +jax[cuda12]==0.4.32.dev20240822 + flax==0.8.5 distrax==0.1.5