-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcuda.nix
49 lines (47 loc) · 1.58 KB
/
cuda.nix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# flake-parts module that publishes the `flake.flakeModules.cuda` module.
#
# A flake that imports `flakeModules.cuda` also gets the `common` and
# `linkNvidiaDrivers` flake modules. On non-Darwin systems it then:
#   * sets `nixpkgs.config.allowUnfree` and `nixpkgs.config.cudaSupport`, and
#   * extends every `ml-ops.common` environment with a joined CUDA tree
#     (`cuda.home`). That tree is exported as CUDA_HOME, added to the devenv
#     shell packages, and its `lib` directory is put on LD_LIBRARY_PATH.
topLevel@{ flake-parts-lib, inputs, ... }: {
  imports = [
    # Provides the `flake.flakeModules` output option used below.
    inputs.flake-parts.flakeModules.flakeModules
    ./common.nix
    ./link-nvidia-drivers.nix
  ];

  flake.flakeModules.cuda = {
    imports = [
      topLevel.config.flake.flakeModules.common
      topLevel.config.flake.flakeModules.linkNvidiaDrivers
    ];

    options.perSystem = flake-parts-lib.mkPerSystemOption ({ lib, pkgs, system, ... }: {
      # CUDA is only packaged for Linux in nixpkgs, so skip *every* Darwin
      # system. The original check excluded only aarch64-darwin, which left
      # x86_64-darwin with cudaSupport enabled.
      # The condition reads only the `system` string, never `pkgs`. `pkgs` is
      # presumably instantiated from the `nixpkgs.config` values set below, so
      # depending on it here could recurse (NOTE(review): confirm in common.nix).
      config = lib.mkIf (!(lib.hasSuffix "-darwin" system)) {
        nixpkgs.config.allowUnfree = true;
        nixpkgs.config.cudaSupport = true;

        ml-ops.common = { config, ... }: {
          config.LD_LIBRARY_PATH = lib.mkMerge [
            # Location where NixOS exposes the host's GPU driver libraries.
            "/run/opengl-driver/lib"
            # bitsandbytes needs to find the CUDA libraries at runtime.
            "${config.environmentVariables.CUDA_HOME}/lib"
          ];

          config.devenvShellModule.packages = [
            config.cuda.home
          ];

          config.environmentVariables.CUDA_HOME = toString config.cuda.home;

          options.cuda.home = lib.mkOption {
            type = lib.types.package;
            description = ''
              Single derivation used as CUDA_HOME. By default it is a
              symlinkJoin of every package listed in `cuda.packages`.
            '';
            default = pkgs.symlinkJoin {
              name = "cuda-home";
              paths = config.cuda.packages;
            };
          };

          options.cuda.packages = lib.mkOption {
            type = lib.types.listOf lib.types.package;
            description = ''
              CUDA packages joined into the default `cuda.home`. List
              definitions from other modules are merged with the set below.
            '';
          };

          config.cuda.packages = [
            pkgs.cudaPackages.cuda_nvcc
            pkgs.cudaPackages.cudatoolkit
            pkgs.cudaPackages.cuda_cudart
            pkgs.cudaPackages.cudnn
            pkgs.cudaPackages.libcublas
          ];
        };
      };
    });
  };
}