diff --git a/pkgs/development/r-modules/default.nix b/pkgs/development/r-modules/default.nix index 678604f04af2e..8937f8ce4a52f 100644 --- a/pkgs/development/r-modules/default.nix +++ b/pkgs/development/r-modules/default.nix @@ -1803,10 +1803,23 @@ let ''; }); - torch = old.torch.overrideAttrs (attrs: { - preConfigure = '' - patchShebangs configure - ''; + torch = let + # Sets the correct string to download either the binary for + # the cpu version of the torch package, or the + # the cuda-enabled version. + # To use gpu acceleration, set `config.cudaSupport = true;` + # when importing nixpkgs in your shell + accel = if pkgs.config.cudaSupport then "cu118" else "cpu"; + + binary_sha = if pkgs.config.cudaSupport then + "sha256-a80sG89C0svZzkjNRpY0rTR2P1JdvKAbWDGIIghsv2Y=" else + "sha256-qUn8Rot6ME7iTvtNd52iw3ebqMnpLz7kwl/9GoPHD+I="; + in + old.torch.overrideAttrs (attrs: { + src = pkgs.fetchzip { + url = "https://torch-cdn.mlverse.org/packages/${accel}/0.13.0/src/contrib/torch_0.13.0_R_x86_64-pc-linux-gnu.tar.gz"; + sha256 = binary_sha; + }; }); pak = old.pak.overrideAttrs (attrs: {