Skip to content

Commit

Permalink
tensorrt: dont break eval for unrelated packages
Browse files Browse the repository at this point in the history
  • Loading branch information
SomeoneSerge committed Dec 4, 2023
1 parent 5bda2ec commit 3ee37e4
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 13 deletions.
28 changes: 22 additions & 6 deletions pkgs/development/libraries/science/math/tensorrt/extension.nix
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,32 @@ final: prev: let
isSupported = fileData: elem cudaVersion fileData.supportedCudaVersions;
# Return the first file that is supported. In practice there should only ever be one anyway.
supportedFile = files: findFirst isSupported null files;
# Supported versions with versions as keys and file as value
supportedVersions = filterAttrs (version: file: file !=null ) (mapAttrs (version: files: supportedFile files) tensorRTVersions);

# Compute versioned attribute name to be used in this package set
computeName = version: "tensorrt_${toUnderscore version}";

# Supported versions with versions as keys and file as value
supportedVersions = lib.recursiveUpdate
{
tensorrt = {
enable = false;
fileVersionCuda = null;
fileVersionCudnn = null;
fullVersion = "0.0.0";
sha256 = null;
tarball = null;
supportedCudaVersions = [ ];
};
}
(mapAttrs' (version: attrs: nameValuePair (computeName version) attrs)
(filterAttrs (version: file: file != null) (mapAttrs (version: files: supportedFile files) tensorRTVersions)));

# Add all supported builds as attributes
allBuilds = mapAttrs' (version: file: nameValuePair (computeName version) (buildTensorRTPackage (removeAttrs file ["fileVersionCuda"]))) supportedVersions;
allBuilds = mapAttrs (name: file: buildTensorRTPackage (removeAttrs file ["fileVersionCuda"])) supportedVersions;

# Set the default attributes, e.g. tensorrt = tensorrt_8_4;
defaultBuild = { "tensorrt" = if allBuilds ? ${computeName tensorRTDefaultVersion}
then allBuilds.${computeName tensorRTDefaultVersion}
else throw "tensorrt-${tensorRTDefaultVersion} does not support your cuda version ${cudaVersion}"; };
defaultName = computeName tensorRTDefaultVersion;
defaultBuild = lib.optionalAttrs (allBuilds ? ${defaultName}) { tensorrt = allBuilds.${computeName tensorRTDefaultVersion}; };
in {
inherit buildTensorRTPackage;
} // allBuilds // defaultBuild;
Expand Down
15 changes: 9 additions & 6 deletions pkgs/development/libraries/science/math/tensorrt/generic.nix
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,22 @@
, cudnn
}:

{ fullVersion
{ enable ? true
, fullVersion
, fileVersionCudnn ? null
, tarball
, sha256
, supportedCudaVersions ? [ ]
}:

assert fileVersionCudnn == null || lib.assertMsg (lib.strings.versionAtLeast cudnn.version fileVersionCudnn)
assert !enable || fileVersionCudnn == null || lib.assertMsg (lib.strings.versionAtLeast cudnn.version fileVersionCudnn)
"This version of TensorRT requires at least cuDNN ${fileVersionCudnn} (current version is ${cudnn.version})";

backendStdenv.mkDerivation rec {
pname = "cudatoolkit-${cudatoolkit.majorVersion}-tensorrt";
version = fullVersion;
src = requireFile rec {
src = if !enable then null else
requireFile rec {
name = tarball;
inherit sha256;
message = ''
Expand All @@ -38,13 +40,13 @@ backendStdenv.mkDerivation rec {

outputs = [ "out" "dev" ];

nativeBuildInputs = [
nativeBuildInputs = lib.optionals enable [
autoPatchelfHook
autoAddOpenGLRunpathHook
];

# Used by autoPatchelfHook
buildInputs = [
buildInputs = lib.optionals enable [
backendStdenv.cc.cc.lib # libstdc++
cudatoolkit
cudnn
Expand Down Expand Up @@ -75,14 +77,15 @@ backendStdenv.mkDerivation rec {
'';

passthru.stdenv = backendStdenv;
passthru.enable = enable;

meta = with lib; {
# Check that the cudatoolkit version satisfies our min/max constraints (both
# inclusive). We mark the package as broken if it fails to satisfies the
# official version constraints (as recorded in default.nix). In some cases
# you _may_ be able to smudge version constraints, just know that you're
# embarking into unknown and unsupported territory when doing so.
broken = !(elem cudaVersion supportedCudaVersions);
broken = !enable || !(elem cudaVersion supportedCudaVersions);
description = "TensorRT: a high-performance deep learning interface";
homepage = "https://developer.nvidia.com/tensorrt";
license = licenses.unfree;
Expand Down
2 changes: 1 addition & 1 deletion pkgs/top-level/python-packages.nix
Original file line number Diff line number Diff line change
Expand Up @@ -13956,7 +13956,7 @@ self: super: with self; {

tensorly = callPackage ../development/python-modules/tensorly { };

tensorrt = callPackage ../development/python-modules/tensorrt { };
tensorrt = callPackage ../development/python-modules/tensorrt { cudaPackages = pkgs.cudaPackages_11; };

tensorstore = callPackage ../development/python-modules/tensorstore { };

Expand Down

0 comments on commit 3ee37e4

Please sign in to comment.