From b4f2e71e44c58d1134ec6810a1af84cf6269abbc Mon Sep 17 00:00:00 2001 From: theo-barfoot Date: Mon, 21 Oct 2024 18:04:23 +0000 Subject: [PATCH] jax[cuda] working --- .devcontainer/Dockerfile | 2 +- .devcontainer/environment.yml | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 3ef5303..7048eda 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -30,7 +30,7 @@ ARG CUDATOOLKIT_VERSION=12.4 RUN conda install pytorch=${PYTORCH_VERSION} pytorch-cuda=${CUDATOOLKIT_VERSION} -c pytorch -c nvidia # Handle environment.yml if it exists -RUN echo env_change_20241021 +RUN echo env_change_20241021_2 COPY environment.yml* noop.txt /tmp/conda-tmp/ RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then \ /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; \ diff --git a/.devcontainer/environment.yml b/.devcontainer/environment.yml index 6e30dbb..369bdc7 100644 --- a/.devcontainer/environment.yml +++ b/.devcontainer/environment.yml @@ -5,7 +5,6 @@ channels: dependencies: - numpy - cupy - - jax[cuda] - scipy - pre-commit==3.7.1 - black==24.4.2 @@ -18,3 +17,6 @@ dependencies: - libmamba - libmambapy - libarchive + - pip + - pip: + - "jax[cuda12]"