Skip to content

Commit

Permalink
add cuda to jax install
Browse files Browse the repository at this point in the history
  • Loading branch information
smburbach committed May 17, 2024
1 parent 812a833 commit 6c2eb72
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
13 changes: 5 additions & 8 deletions jupyterhub/deeplearning/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,11 @@ RUN python3 -m pip install \
RUN python3 -m pip cache purge

# jax
RUN python3 -m pip install jax jaxlib
# RUN wget --directory-prefix=${REQUIREMENTS_DIR} https://raw.githubusercontent.com/briney/containers/main/requirements/jax_pip.txt \
# && fix-permissions ${REQUIREMENTS_DIR}
# RUN python3 -m pip install \
# --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
# -r ${REQUIREMENTS_DIR}/jax_pip.txt
RUN wget --directory-prefix=${REQUIREMENTS_DIR} https://raw.githubusercontent.com/briney/containers/main/requirements/jax_pip.txt \
&& fix-permissions ${REQUIREMENTS_DIR}
RUN python3 -m pip install \
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
-r ${REQUIREMENTS_DIR}/jax_pip.txt
RUN python3 -m pip cache purge

# AI/ML packages (including 🤗)
Expand All @@ -57,8 +56,6 @@ RUN wget --directory-prefix=${REQUIREMENTS_DIR} https://raw.githubusercontent.co
RUN python3 -m pip install -r ${REQUIREMENTS_DIR}/ai-ml_pip.txt
RUN python3 -m pip cache purge

# jupyterlab extensions

# the stable version of jupyterlab_nvdashboard doesn't work with jupyterlab 4.0 yet, so we need to install the pre-release version
RUN python3 -m pip install --extra-index-url https://pypi.anaconda.org/rapidsai-wheels-nightly/simple --pre jupyterlab_nvdashboard
# RUN wget --directory-prefix=${REQUIREMENTS_DIR} https://raw.githubusercontent.com/briney/containers/main/requirements/nvidia_pip.txt \
Expand Down
1 change: 0 additions & 1 deletion kubeflow/deeplearning/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ RUN python3 -m pip install \
--find-links https://download.pytorch.org/whl/torch_stable.html \
-r ${REQUIREMENTS_DIR}/torch_pip.txt
RUN python3 -m pip cache purge
# RUN python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118

# jax
RUN wget --directory-prefix=${REQUIREMENTS_DIR} https://raw.githubusercontent.com/briney/containers/main/requirements/jax_pip.txt \
Expand Down
3 changes: 2 additions & 1 deletion requirements/jax_pip.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
jax[cuda11_pip] # removed to save space
jax==0.4.28
jaxlib==0.4.28+cuda12.cudnn89

0 comments on commit 6c2eb72

Please sign in to comment.