From 8a550337a9897539449c2458f20072793c9cd389 Mon Sep 17 00:00:00 2001 From: rainx0r Date: Tue, 20 Aug 2024 09:57:45 +0100 Subject: [PATCH] Switch to nvidia's JAX container --- metaworld-jax/Dockerfile | 6 +++--- metaworld-jax/requirements.txt | 12 +----------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/metaworld-jax/Dockerfile b/metaworld-jax/Dockerfile index dd2bb07..437a7a4 100644 --- a/metaworld-jax/Dockerfile +++ b/metaworld-jax/Dockerfile @@ -1,11 +1,11 @@ -FROM python:3.12.5-slim +FROM ghcr.io/nvidia/jax:jax-2024-08-19 LABEL maintainer="me@evangelos.ai" -RUN apt-get update && apt install -y --no-install-recommends git python3-pip libglfw3 libglfw3-dev +RUN apt-get update && apt install -y --no-install-recommends git libglfw3 libglfw3-dev WORKDIR /usr/src/app COPY requirements.txt ./ RUN pip install --no-cache-dir --upgrade pip && \ - pip install --no-cache-dir -r requirements.txt + pip install --no-cache-dir --upgrade -r requirements.txt ENTRYPOINT ["python"] diff --git a/metaworld-jax/requirements.txt b/metaworld-jax/requirements.txt index 42f1c69..f269b78 100644 --- a/metaworld-jax/requirements.txt +++ b/metaworld-jax/requirements.txt @@ -1,14 +1,4 @@ -# Cuda -nvidia-cublas-cu12~=12.2.0 -nvidia-cuda-cupti-cu12~=12.2.0 -nvidia-cuda-nvcc-cu12~=12.2.0 -nvidia-cuda-runtime-cu12~=12.2.0 -nvidia-cusparse-cu12~=12.2.0 -nvidia-nvjitlink-cu12~=12.2.0 - # Jax -jax[cuda12]==0.4.31 -flax==0.8.5 distrax==0.1.5 # Metaworld @@ -21,5 +11,5 @@ torch==2.4.0 # Logging wandb==0.17.6 tensorboard==2.17.1 -orbax-checkpoint @ git+https://github.com/google/orbax/@2ce2fb27f9786442b08ba14c8767c460dd6e8a0a#subdirectory=checkpoint +orbax-checkpoint==0.6.0