forked from ROCm/AITemplate
-
Notifications
You must be signed in to change notification settings - Fork 0
/
default.nix
50 lines (48 loc) · 995 Bytes
/
default.nix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Development shell for AITemplate: cmake, the CUDA toolkit, and a Python
# interpreter bundled with the packages listed in `ait-deps`.
#
# Arguments:
#   pkgs   - the nixpkgs set to use. Defaults to <nixpkgs> imported with
#            unfree packages allowed and CUDA support enabled globally
#            (the CUDA toolkit is unfree in nixpkgs).
#   python - the interpreter whose package set supplies `ait-deps`.
#            Defaults to python310, matching the previous hard-coded choice.
{ pkgs ? import <nixpkgs> {
    config = {
      allowUnfree = true;
      cudaSupport = true;
    };
  }
, python ? pkgs.python310
}:
let
  # Python packages needed to build and run AITemplate; `ps` is the package
  # set of the selected `python` interpreter.
  ait-deps = ps: with ps; [
    # Prebuilt PyTorch wheel. `torch-bin` is the canonical attribute name;
    # `pytorch-bin` is only a deprecated alias for it in nixpkgs.
    torch-bin
    pip
    wheel
    click
    unidecode
    inflect
    librosa
    jinja2
    sympy
    einops
    parameterized
    transformers
    # Disabled: cuda-python installed from its PyPI wheel.
    # NOTE: this pins the cp39 wheel (abi/python = "cp39") while the default
    # interpreter is python310 -- update version, abi, python and sha256
    # together before re-enabling, or the wheel will not match.
    # (
    #   buildPythonPackage rec {
    #     pname = "cuda_python";
    #     version = "12.1.0";
    #     format = "wheel";
    #     src = fetchPypi {
    #       inherit pname version format;
    #       sha256 = "94506d730baade1744767e2c05d5ddd84d7fbe4c9b6f694a54a3f376f7ffa525";
    #       abi = "cp39";
    #       python = "cp39";
    #       platform = "manylinux_2_17_x86_64.manylinux2014_x86_64";
    #     };
    #     doCheck = false;
    #   }
    # )
  ];
in
pkgs.mkShell {
  buildInputs = [
    pkgs.cmake
    pkgs.cudatoolkit
    (python.withPackages ait-deps)
  ];
  # Expose the toolkit's store path through CUDA_PATH for build scripts
  # that locate CUDA via that variable.
  shellHook = ''
    export CUDA_PATH=${pkgs.cudatoolkit}
    echo "You are now using a NIX environment"
  '';
}