Skip to content

Commit

Permalink
cudaPackagesGoogle: init, a package-set for jax and tf
Browse files Browse the repository at this point in the history
  • Loading branch information
SomeoneSerge committed Dec 4, 2023
1 parent 1e72cc2 commit 5bda2ec
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 18 deletions.
8 changes: 4 additions & 4 deletions pkgs/development/python-modules/jaxlib/bin.nix
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
, stdenv
# Options:
, cudaSupport ? config.cudaSupport
, cudaPackages ? {}
, cudaPackagesGoogle
}:

let
inherit (cudaPackages) cudatoolkit cudnn;
inherit (cudaPackagesGoogle) cudatoolkit cudnn;

version = "0.4.20";

Expand Down Expand Up @@ -210,8 +210,8 @@ buildPythonPackage {
maintainers = with maintainers; [ samuela ];
platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ];
broken =
!(cudaSupport -> (cudaPackages ? cudatoolkit) && lib.versionAtLeast cudatoolkit.version "11.1")
|| !(cudaSupport -> (cudaPackages ? cudnn) && lib.versionAtLeast cudnn.version "8.2")
!(cudaSupport -> (cudaPackagesGoogle ? cudatoolkit) && lib.versionAtLeast cudatoolkit.version "11.1")
|| !(cudaSupport -> (cudaPackagesGoogle ? cudnn) && lib.versionAtLeast cudnn.version "8.2")
|| !(cudaSupport -> stdenv.isLinux);
};
}
4 changes: 2 additions & 2 deletions pkgs/development/python-modules/jaxlib/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@
, config
# CUDA flags:
, cudaSupport ? config.cudaSupport
, cudaPackages ? {}
, cudaPackagesGoogle

# MKL:
, mklSupport ? true
}:

let
inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
inherit (cudaPackagesGoogle) backendStdenv cudatoolkit cudaFlags cudnn nccl;

pname = "jaxlib";
version = "0.4.20";
Expand Down
6 changes: 3 additions & 3 deletions pkgs/development/python-modules/tensorflow/bin.nix
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
, tensorboard
, config
, cudaSupport ? config.cudaSupport
, cudaPackages ? {}
, cudaPackagesGoogle
, zlib
, python
, keras-applications
Expand All @@ -43,7 +43,7 @@ assert ! (stdenv.isDarwin && cudaSupport);

let
packages = import ./binary-hashes.nix;
inherit (cudaPackages) cudatoolkit cudnn;
inherit (cudaPackagesGoogle) cudatoolkit cudnn;
in buildPythonPackage {
pname = "tensorflow" + lib.optionalString cudaSupport "-gpu";
inherit (packages) version;
Expand Down Expand Up @@ -200,7 +200,7 @@ in buildPythonPackage {
];

passthru = {
inherit cudaPackages;
cudaPackages = cudaPackagesGoogle;
};

meta = with lib; {
Expand Down
16 changes: 8 additions & 8 deletions pkgs/development/python-modules/tensorflow/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
# https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0
, config
, cudaSupport ? config.cudaSupport
, cudaPackages ? { }
, cudaCapabilities ? cudaPackages.cudaFlags.cudaCapabilities
, cudaPackagesGoogle
, cudaCapabilities ? cudaPackagesGoogle.cudaFlags.cudaCapabilities
, mklSupport ? false, mkl
, tensorboardSupport ? true
# XLA without CUDA is broken
Expand Down Expand Up @@ -50,15 +50,15 @@ let
# __ZN4llvm11SmallPtrSetIPKNS_10AllocaInstELj8EED1Ev in any of the
# translation units, so the build fails at link time
stdenv =
if cudaSupport then cudaPackages.backendStdenv
if cudaSupport then cudaPackagesGoogle.backendStdenv
else if originalStdenv.isDarwin then llvmPackages_11.stdenv
else originalStdenv;
inherit (cudaPackages) cudatoolkit nccl;
inherit (cudaPackagesGoogle) cudatoolkit nccl;
# use compatible cuDNN (https://www.tensorflow.org/install/source#gpu)
# cudaPackages.cudnn led to this:
# https://github.com/tensorflow/tensorflow/issues/60398
cudnnAttribute = "cudnn_8_6";
cudnn = cudaPackages.${cudnnAttribute};
cudnn = cudaPackagesGoogle.${cudnnAttribute};
gentoo-patches = fetchzip {
url = "https://dev.gentoo.org/~perfinion/patches/tensorflow-patches-2.12.0.tar.bz2";
hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs=";
Expand Down Expand Up @@ -486,8 +486,8 @@ let
broken =
stdenv.isDarwin
|| !(xlaSupport -> cudaSupport)
|| !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackages)
|| !(cudaSupport -> cudaPackages ? cudatoolkit);
|| !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackagesGoogle)
|| !(cudaSupport -> cudaPackagesGoogle ? cudatoolkit);
} // lib.optionalAttrs stdenv.isDarwin {
timeout = 86400; # 24 hours
maxSilent = 14400; # 4h, double the default of 7200s
Expand Down Expand Up @@ -590,7 +590,7 @@ in buildPythonPackage {
# Regression test for #77626 removed because not more `tensorflow.contrib`.

passthru = {
inherit cudaPackages;
cudaPackages = cudaPackagesGoogle;
deps = bazel-build.deps;
libtensorflow = bazel-build.out;
};
Expand Down
4 changes: 4 additions & 0 deletions pkgs/top-level/all-packages.nix
Original file line number Diff line number Diff line change
Expand Up @@ -7318,6 +7318,10 @@ with pkgs;
cudaPackages_12_2 = callPackage ./cuda-packages.nix { cudaVersion = "12.2"; };
cudaPackages_12 = cudaPackages_12_0;

# Use the older cudaPackages for tensorflow and jax, as determined by cudnn
# compatibility: https://www.tensorflow.org/install/source#gpu
cudaPackagesGoogle = cudaPackages_11;

# TODO: try upgrading once there is a cuDNN release supporting CUDA 12. No
# such cuDNN release as of 2023-01-10.
cudaPackages = recurseIntoAttrs cudaPackages_11;
Expand Down
1 change: 0 additions & 1 deletion pkgs/top-level/python-packages.nix
Original file line number Diff line number Diff line change
Expand Up @@ -13924,7 +13924,6 @@ self: super: with self; {
callPackage ../development/python-modules/tensorflow {
inherit (pkgs.darwin) cctools;
inherit (pkgs.config) cudaSupport;
inherit (self.tensorflow-bin) cudaPackages;
inherit (pkgs.darwin.apple_sdk.frameworks) Foundation Security;
flatbuffers-core = pkgs.flatbuffers;
flatbuffers-python = self.flatbuffers;
Expand Down

0 comments on commit 5bda2ec

Please sign in to comment.