Skip to content

Commit

Permalink
Merge pull request #183051 from mcwitt/squashed/upgrade-jax
Browse files Browse the repository at this point in the history
python310Packages.jax: 0.3.6 -> 0.3.16
  • Loading branch information
samuela authored Aug 30, 2022
2 parents 78e892c + 6bff360 commit 2307799
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 43 deletions.
12 changes: 0 additions & 12 deletions pkgs/development/python-modules/jax/cache-fix.patch

This file was deleted.

22 changes: 7 additions & 15 deletions pkgs/development/python-modules/jax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
, absl-py
, blas
, buildPythonPackage
, etils
, fetchFromGitHub
, fetchpatch
, jaxlib
, lapack
, matplotlib
, numpy
, opt-einsum
, pytestCheckHook
Expand All @@ -20,7 +21,7 @@ let
in
buildPythonPackage rec {
pname = "jax";
version = "0.3.6";
version = "0.3.16";
format = "setuptools";

disabled = pythonOlder "3.7";
Expand All @@ -29,34 +30,25 @@ buildPythonPackage rec {
owner = "google";
repo = pname;
rev = "jax-v${version}";
hash = "sha256-eGdAEZFHadNTHgciP4KMYHdwksz9g6un0Ar+A/KV5TE=";
hash = "sha256-4idh7boqBXSO9vEHxEcrzXjBIrKmmXiCf6cXh7En1/I=";
};

patches = [
# See https://github.com/google/jax/issues/7944
./cache-fix.patch

# See https://github.com/google/jax/issues/10292
(fetchpatch {
url = "https://github.com/google/jax/commit/cadc8046d56e0c1433cf48a2f106947d5f4ecbfd.patch";
hash = "sha256-jrpIqt4LzWAswt/Cpwtfa5d1Yn31HcXkVH3ETmaigA0=";
})
];

# jaxlib is _not_ included in propagatedBuildInputs because there are
# different versions of jaxlib depending on the desired target hardware. The
# JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
# CPU wheel is packaged.
propagatedBuildInputs = [
absl-py
etils
numpy
opt-einsum
scipy
typing-extensions
];
] ++ etils.optional-dependencies.epath;

checkInputs = [
jaxlib
matplotlib
pytestCheckHook
pytest-xdist
];
Expand Down
39 changes: 23 additions & 16 deletions pkgs/development/python-modules/jaxlib/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
, buildBazelPackage
, buildPythonPackage
, cctools
, curl
, cython
, fetchFromGitHub
, git
, IOKit
, jsoncpp
, nsync
, openssl
, pybind11
, setuptools
, symlinkJoin
Expand Down Expand Up @@ -50,7 +53,7 @@ let
inherit (cudaPackages) cudatoolkit cudnn nccl;

pname = "jaxlib";
version = "0.3.0";
version = "0.3.15";

meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
Expand Down Expand Up @@ -93,7 +96,7 @@ let
owner = "google";
repo = "jax";
rev = "${pname}-v${version}";
sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72";
sha256 = "sha256-pIl7zzl82w5HHnJadH2vtCT4mYFd5YmM9iHC2GoJD6s=";
};

nativeBuildInputs = [
Expand All @@ -103,15 +106,19 @@ let
setuptools
wheel
which
] ++ lib.optionals stdenv.isDarwin [
cctools
];

buildInputs = [
curl
double-conversion
giflib
grpc
jsoncpp
libjpeg_turbo
numpy
openssl
pkgs.flatbuffers
pkgs.protobuf
pybind11
Expand All @@ -124,6 +131,8 @@ let
cudnn
] ++ lib.optionals stdenv.isDarwin [
IOKit
] ++ lib.optionals (!stdenv.isDarwin) [
nsync
];

postPatch = ''
Expand All @@ -149,6 +158,7 @@ let
build --action_env=PYENV_ROOT
build --python_path="${python}/bin/python"
build --distinct_host_configuration=false
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
'' + lib.optionalString cudaSupport ''
build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
Expand All @@ -163,7 +173,7 @@ let
# Copy-paste from TF derivation.
# Most of these are not really used in jaxlib compilation but it's simpler to keep it
# 'as is' so that it's more compatible with TF derivation.
TF_SYSTEM_LIBS = lib.concatStringsSep "," [
TF_SYSTEM_LIBS = lib.concatStringsSep "," ([
"absl_py"
"astor_archive"
"astunparse_archive"
Expand All @@ -179,7 +189,6 @@ let
"cython"
"dill_archive"
"double_conversion"
"enum34_archive"
"flatbuffers"
"functools32_archive"
"gast_archive"
Expand All @@ -190,11 +199,9 @@ let
"libjpeg_turbo"
"lmdb"
"nasm"
# "nsync" # not packaged in nixpkgs
"opt_einsum_archive"
"org_sqlite"
"pasta"
"pcre"
"png"
"pybind11"
"six_archive"
Expand All @@ -204,7 +211,9 @@ let
"typing_extensions_archive"
"wrapt"
"zlib"
];
] ++ lib.optionals (!stdenv.isDarwin) [
"nsync" # fails to build on darwin
]);

# Make sure Bazel knows about our configuration flags during fetching so that the
# relevant dependencies can be downloaded.
Expand All @@ -226,9 +235,11 @@ let
fetchAttrs = {
sha256 =
if cudaSupport then
"sha256-Ald+vplRx/DDG/7TfHAqD4Gktb1BGnf7FSCCJzSI0eo="
"sha256-tdO4YjO985zbittb16RFWgxgUBrHYQfv5gRsA4IAkTk="
else if stdenv.isDarwin then
"sha256-+XYxfXBCASueqDGg0Zqcmpf7zmemYM6xCE+x0rl3j34="
else
"sha256-eK5IjTAncDarkWYKnXrEo7kw7J7iOH7in2L2GabnFYo=";
"sha256-La1wC8X5aGK5mXvYy/kO8n4J+zaRZEc/DAX5zaH1D5A=";
};

buildAttrs = {
Expand All @@ -239,15 +250,10 @@ let
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
# 3) Patch python path in the compiler driver.
# 4) Patch tensorflow sources to work with later versions of protobuf. See
# https://github.com/google/jax/issues/9534. Note that this should be
# removed on the next release after 0.3.0.
preBuild = ''
for src in ./jaxlib/*.{cc,h}; do
for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
sed -i 's@include/pybind11@pybind11@g' $src
done
substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
--replace "status.message()" "std::string{status.message()}"
'' + lib.optionalString cudaSupport ''
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
'' + lib.optionalString stdenv.isDarwin ''
Expand Down Expand Up @@ -275,7 +281,7 @@ let
};
platformTag =
if stdenv.targetPlatform.isLinux then
"manylinux2010_${stdenv.targetPlatform.linuxArch}"
"manylinux2014_${stdenv.targetPlatform.linuxArch}"
else if stdenv.system == "x86_64-darwin" then
"macosx_10_9_${stdenv.targetPlatform.linuxArch}"
else if stdenv.system == "aarch64-darwin" then
Expand Down Expand Up @@ -306,6 +312,7 @@ buildPythonPackage {

propagatedBuildInputs = [
absl-py
curl
double-conversion
flatbuffers
giflib
Expand Down

0 comments on commit 2307799

Please sign in to comment.