From 5bda2ec626a4107f8adc3a7e58c16c0594c40d2c Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Fri, 24 Nov 2023 13:47:03 +0000 Subject: [PATCH] cudaPackagesGoogle: init, a package-set for jax and tf --- pkgs/development/python-modules/jaxlib/bin.nix | 8 ++++---- .../python-modules/jaxlib/default.nix | 4 ++-- .../python-modules/tensorflow/bin.nix | 6 +++--- .../python-modules/tensorflow/default.nix | 16 ++++++++-------- pkgs/top-level/all-packages.nix | 4 ++++ pkgs/top-level/python-packages.nix | 1 - 6 files changed, 21 insertions(+), 18 deletions(-) diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index d80cbc2a60183..e35b4759bd64f 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -29,11 +29,11 @@ , stdenv # Options: , cudaSupport ? config.cudaSupport -, cudaPackages ? {} +, cudaPackagesGoogle }: let - inherit (cudaPackages) cudatoolkit cudnn; + inherit (cudaPackagesGoogle) cudatoolkit cudnn; version = "0.4.20"; @@ -210,8 +210,8 @@ buildPythonPackage { maintainers = with maintainers; [ samuela ]; platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ]; broken = - !(cudaSupport -> (cudaPackages ? cudatoolkit) && lib.versionAtLeast cudatoolkit.version "11.1") - || !(cudaSupport -> (cudaPackages ? cudnn) && lib.versionAtLeast cudnn.version "8.2") + !(cudaSupport -> (cudaPackagesGoogle ? cudatoolkit) && lib.versionAtLeast cudatoolkit.version "11.1") + || !(cudaSupport -> (cudaPackagesGoogle ? cudnn) && lib.versionAtLeast cudnn.version "8.2") || !(cudaSupport -> stdenv.isLinux); }; } diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index c70ab0ac2b327..a04d6973ca4be 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -44,14 +44,14 @@ , config # CUDA flags: , cudaSupport ? config.cudaSupport -, cudaPackages ? {} +, cudaPackagesGoogle # MKL: , mklSupport ? true }: let - inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl; + inherit (cudaPackagesGoogle) backendStdenv cudatoolkit cudaFlags cudnn nccl; pname = "jaxlib"; version = "0.4.20"; diff --git a/pkgs/development/python-modules/tensorflow/bin.nix b/pkgs/development/python-modules/tensorflow/bin.nix index fa70e4cc4a305..1040023619262 100644 --- a/pkgs/development/python-modules/tensorflow/bin.nix +++ b/pkgs/development/python-modules/tensorflow/bin.nix @@ -22,7 +22,7 @@ , tensorboard , config , cudaSupport ? config.cudaSupport -, cudaPackages ? {} +, cudaPackagesGoogle , zlib , python , keras-applications @@ -43,7 +43,7 @@ assert ! (stdenv.isDarwin && cudaSupport); let packages = import ./binary-hashes.nix; - inherit (cudaPackages) cudatoolkit cudnn; + inherit (cudaPackagesGoogle) cudatoolkit cudnn; in buildPythonPackage { pname = "tensorflow" + lib.optionalString cudaSupport "-gpu"; inherit (packages) version; @@ -200,7 +200,7 @@ in buildPythonPackage { ]; passthru = { - inherit cudaPackages; + cudaPackages = cudaPackagesGoogle; }; meta = with lib; { diff --git a/pkgs/development/python-modules/tensorflow/default.nix b/pkgs/development/python-modules/tensorflow/default.nix index c8e292e316744..be8b26f3d0e99 100644 --- a/pkgs/development/python-modules/tensorflow/default.nix +++ b/pkgs/development/python-modules/tensorflow/default.nix @@ -19,8 +19,8 @@ # https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0 , config , cudaSupport ? config.cudaSupport -, cudaPackages ? { } -, cudaCapabilities ? cudaPackages.cudaFlags.cudaCapabilities +, cudaPackagesGoogle +, cudaCapabilities ? cudaPackagesGoogle.cudaFlags.cudaCapabilities , mklSupport ? false, mkl , tensorboardSupport ? true # XLA without CUDA is broken @@ -50,15 +50,15 @@ let # __ZN4llvm11SmallPtrSetIPKNS_10AllocaInstELj8EED1Ev in any of the # translation units, so the build fails at link time stdenv = - if cudaSupport then cudaPackages.backendStdenv + if cudaSupport then cudaPackagesGoogle.backendStdenv else if originalStdenv.isDarwin then llvmPackages_11.stdenv else originalStdenv; - inherit (cudaPackages) cudatoolkit nccl; + inherit (cudaPackagesGoogle) cudatoolkit nccl; # use compatible cuDNN (https://www.tensorflow.org/install/source#gpu) # cudaPackages.cudnn led to this: # https://github.com/tensorflow/tensorflow/issues/60398 cudnnAttribute = "cudnn_8_6"; - cudnn = cudaPackages.${cudnnAttribute}; + cudnn = cudaPackagesGoogle.${cudnnAttribute}; gentoo-patches = fetchzip { url = "https://dev.gentoo.org/~perfinion/patches/tensorflow-patches-2.12.0.tar.bz2"; hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs="; @@ -486,8 +486,8 @@ let broken = stdenv.isDarwin || !(xlaSupport -> cudaSupport) - || !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackages) - || !(cudaSupport -> cudaPackages ? cudatoolkit); + || !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackagesGoogle) + || !(cudaSupport -> cudaPackagesGoogle ? cudatoolkit); } // lib.optionalAttrs stdenv.isDarwin { timeout = 86400; # 24 hours maxSilent = 14400; # 4h, double the default of 7200s @@ -590,7 +590,7 @@ in buildPythonPackage { # Regression test for #77626 removed because not more `tensorflow.contrib`. passthru = { - inherit cudaPackages; + cudaPackages = cudaPackagesGoogle; deps = bazel-build.deps; libtensorflow = bazel-build.out; }; diff --git a/pkgs/top-level/all-packages.nix b/pkgs/top-level/all-packages.nix index e32f668e0d653..e7ee7c01292bb 100644 --- a/pkgs/top-level/all-packages.nix +++ b/pkgs/top-level/all-packages.nix @@ -7318,6 +7318,10 @@ with pkgs; cudaPackages_12_2 = callPackage ./cuda-packages.nix { cudaVersion = "12.2"; }; cudaPackages_12 = cudaPackages_12_0; + # Use the older cudaPackages for tensorflow and jax, as determined by cudnn + # compatibility: https://www.tensorflow.org/install/source#gpu + cudaPackagesGoogle = cudaPackages_11; + # TODO: try upgrading once there is a cuDNN release supporting CUDA 12. No # such cuDNN release as of 2023-01-10. cudaPackages = recurseIntoAttrs cudaPackages_11; diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index ebf97327f53fc..d2abc3c70698f 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -13924,7 +13924,6 @@ self: super: with self; { callPackage ../development/python-modules/tensorflow { inherit (pkgs.darwin) cctools; inherit (pkgs.config) cudaSupport; - inherit (self.tensorflow-bin) cudaPackages; inherit (pkgs.darwin.apple_sdk.frameworks) Foundation Security; flatbuffers-core = pkgs.flatbuffers; flatbuffers-python = self.flatbuffers;