# custom_types.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnf
import sys
from typing import Tuple, List, Union, Callable, Type, Iterator, Dict, Set, Optional, Any, Sized
from enum import Enum
# Environment flags, evaluated once when the module is imported.
IS_WINDOWS = sys.platform == 'win32'

# `sys.gettrace` is fetched defensively: if the running interpreter does not
# expose it, `get_trace` is None and DEBUG falls back to False.
get_trace = getattr(sys, 'gettrace', None)
if get_trace is None:
    DEBUG = False
else:
    # A trace function is currently installed (e.g. by a debugger).
    DEBUG = get_trace() is not None

# Disabled: fixed seeding of torch and numpy when running in DEBUG mode.
# if DEBUG:
#     seed = 99
#     torch.manual_seed(seed)
#     np.random.seed(seed)
# Short aliases for types used in annotations.

# The NoneType class (the type of `None`).
N = type(None)
# NOTE(review): `np.array` is the array *constructor function*, not a class;
# the actual array class is `np.ndarray` (see ARRAY below). Static type
# checkers will not treat V as an array type -- confirm how callers use V
# before changing it.
V = np.array
# The numpy array class.
ARRAY = np.ndarray
# A tuple (any length) or list of numpy arrays.
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
# A tuple (any length) or list of V.
VS = Union[Tuple[V, ...], List[V]]
# V or None.
VN = Union[V, N]
# VS or None.
VNS = Union[VS, N]
# The torch tensor class.
T = torch.Tensor
# A tuple (any length) or list of tensors.
TS = Union[Tuple[T, ...], List[T]]
# A tensor or None.
TN = Optional[T]
# A tuple (any length) or list whose items are each a tensor or None.
TNS = Union[Tuple[TN, ...], List[TN]]
# A tuple/list of tensors, or None in place of the whole sequence.
TSN = Optional[TS]
# Either a torch tensor or a numpy array.
TA = Union[T, ARRAY]
# The torch device class.
D = torch.device
# Shared device object for the CPU.
CPU = torch.device('cpu')
def get_device(device_id: int) -> D:
    """Return the torch device to use for the requested CUDA index.

    When CUDA is not available the module-level ``CPU`` device is returned.
    Otherwise the requested index is capped at the highest index reported by
    ``torch.cuda.device_count()`` and the matching ``cuda:<index>`` device is
    returned.

    :param device_id: requested CUDA device index.
    :return: ``CPU`` or a ``cuda:<index>`` device.
    """
    if torch.cuda.is_available():
        last_index = torch.cuda.device_count() - 1
        # Cap the request so it never exceeds the last reported device index.
        index = min(last_index, device_id)
        return torch.device(f'cuda:{index}')
    return CPU


# Alias: ``CUDA(i)`` is the same call as ``get_device(i)``.
CUDA = get_device