Skip to content

Commit

Permalink
Merge pull request #325222 from SomeoneSerge/fix/gpu-access/torch-bin
Browse files Browse the repository at this point in the history
python3Packages.torch-bin: gpuChecks -> tests.tester-<name>.gpuCheck
  • Loading branch information
SomeoneSerge authored Jul 7, 2024
2 parents a4f437f + 3eff201 commit 4fbd433
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 43 deletions.
5 changes: 4 additions & 1 deletion pkgs/development/python-modules/torch/bin.nix
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ buildPythonPackage {

pythonImportsCheck = [ "torch" ];

passthru.gpuChecks.cudaAvailable = callPackage ./test-cuda.nix { torch = torch-bin; };
passthru.tests = callPackage ./tests.nix {
torchWithCuda = torch-bin;
torchWithRocm = torch-bin;
};

meta = {
description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
Expand Down
40 changes: 0 additions & 40 deletions pkgs/development/python-modules/torch/gpu-checks.nix

This file was deleted.

19 changes: 19 additions & 0 deletions pkgs/development/python-modules/torch/mk-runtime-check.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
cudaPackages,
feature,
torch,
versionAttr,
}:

cudaPackages.writeGpuTestPython
{
inherit feature;
libraries = [ torch ];
name = "${feature}Available";
}
''
import torch
message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
assert torch.cuda.is_available() and torch.version.${versionAttr}, message
print(message)
''
22 changes: 20 additions & 2 deletions pkgs/development/python-modules/torch/tests.nix
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
{ callPackage }:
{
callPackage,
torchWithCuda,
torchWithRocm,
}:

callPackage ./gpu-checks.nix { }
{
# To perform the runtime check use either
# `nix run .#python3Packages.torch.tests.tester-cudaAvailable` (outside the sandbox), or
# `nix build .#python3Packages.torch.tests.tester-cudaAvailable.gpuCheck` (in a relaxed sandbox)
tester-cudaAvailable = callPackage ./mk-runtime-check.nix {
feature = "cuda";
versionAttr = "cuda";
torch = torchWithCuda;
};
tester-rocmAvailable = callPackage ./mk-runtime-check.nix {
feature = "rocm";
versionAttr = "hip";
torch = torchWithRocm;
};
}

0 comments on commit 4fbd433

Please sign in to comment.