-
Notifications
You must be signed in to change notification settings - Fork 12
/
utils.py
41 lines (36 loc) · 1.26 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# utilities for CVP-MVSNet
# by: Jiayu
# Date: 2019-08-13
# Note: Part of the code is modified from the MVSNet_pytorch project by xy-guo
# Link: https://github.com/xy-guo/MVSNet_pytorch
# Thanks to the author for such great code!
import functools

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
def make_recursive_func(func):
    """Lift ``func`` so it is applied to every leaf of a nested container.

    The returned wrapper walks lists, tuples and dicts recursively and
    rebuilds each container with the same kind (list -> list,
    tuple -> tuple, dict -> dict with the same keys). It calls ``func``
    only on values that are none of those three container types.

    Args:
        func: Callable applied to each leaf value.

    Returns:
        The recursive wrapper. ``functools.wraps`` copies ``func``'s
        ``__name__``, ``__doc__`` and related metadata onto it, so decorated
        helpers keep their identity in tracebacks and introspection.
    """
    @functools.wraps(func)
    def wrapper(vars):
        if isinstance(vars, list):
            return [wrapper(x) for x in vars]
        if isinstance(vars, tuple):
            # Tuple subclasses (e.g. namedtuples) come back as plain tuples.
            return tuple(wrapper(x) for x in vars)
        if isinstance(vars, dict):
            return {k: wrapper(v) for k, v in vars.items()}
        return func(vars)
    return wrapper
@make_recursive_func
def tocuda(vars):
    """Move a (possibly nested) value onto the default CUDA device.

    ``make_recursive_func`` applies this to every leaf inside lists,
    tuples and dicts.

    Leaf handling:
        - ``torch.Tensor``: moved with ``.cuda()``.
        - ``str``: returned unchanged.
        - ``float``, ``int`` (including ``bool``), ``np.ndarray`` and NumPy
          scalars (``np.number`` / ``np.bool_``, e.g. ``np.float32``):
          wrapped with ``torch.tensor`` and moved with ``.cuda()``.

    Raises:
        NotImplementedError: for any other leaf type.
    """
    if isinstance(vars, torch.Tensor):
        return vars.cuda()
    if isinstance(vars, str):
        return vars
    # NumPy scalars such as np.float32 / np.int64 / np.bool_ are not
    # subclasses of Python float/int, so accept them explicitly next to
    # ndarrays instead of rejecting them.
    if isinstance(vars, (float, int, np.ndarray, np.number, np.bool_)):
        return torch.tensor(vars).cuda()
    raise NotImplementedError("invalid input type {}".format(type(vars)))