-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmisc.py
70 lines (56 loc) · 2.54 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
from typing import List
import ml_collections
import argparse
from core.diffusion.schedule import NamedSchedule
from core.diffusion.sde import VPSDE
def str2bool(v):
    """Convert a command-line style value to a bool (usable as an argparse ``type=``).

    A ``bool`` is returned unchanged. Otherwise ``v.lower()`` is compared against
    'yes'/'true'/'t'/'y'/'1' (-> True) and 'no'/'false'/'f'/'n'/'0' (-> False).

    Raises:
        argparse.ArgumentTypeError: if the lowered value is in neither set.
    """
    if isinstance(v, bool):
        return v
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    lowered = v.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_schedule(schedule):
    """Build a discrete schedule from a spec string of the form ``'<type>_<N>'``.

    Specs starting with ``linear`` or ``cosine`` yield ``NamedSchedule(type, N)``;
    specs starting with ``vpsde`` yield ``VPSDE().get_schedule(N)``.

    Raises:
        NotImplementedError: if ``schedule`` does not start with a supported type.
        ValueError: if ``schedule`` is not exactly two ``_``-separated fields, or
            ``N`` is not an integer.
    """
    if not schedule.startswith(('linear', 'cosine', 'vpsde')):
        raise NotImplementedError(f'unsupported schedule: {schedule!r}')
    parts = schedule.split('_')
    if len(parts) != 2:
        # Same exception type a bad tuple-unpack would raise, but with an actionable message.
        raise ValueError(f"expected schedule of the form '<type>_<N>', got {schedule!r}")
    typ, n_str = parts
    n = int(n_str)
    if schedule.startswith('vpsde'):
        # The VP-SDE derives its own discretisation; the type token is not needed.
        return VPSDE().get_schedule(n)
    return NamedSchedule(typ, n)
def parse_sde(sde):
    """Return the SDE object named by ``sde``.

    Only ``'vpsde'`` is supported; it maps to a default-constructed ``VPSDE()``.

    Raises:
        NotImplementedError: for any other name (the message includes the name).
    """
    if sde == 'vpsde':
        return VPSDE()
    raise NotImplementedError(f'unsupported sde: {sde!r}')
def sub_dict(dct: dict, *keys):
    """Return a new dict containing only the entries of ``dct`` whose key is in ``keys``.

    Keys missing from ``dct`` are skipped; the result follows the order of ``keys``.
    """
    picked = {}
    for k in keys:
        if k in dct:
            picked[k] = dct[k]
    return picked
def dict2str(dct):
    """Flatten a mapping into ``key_value`` pairs joined by ``_``.

    e.g. ``{'steps': 50, 'eta': 0.0}`` -> ``'steps_50_eta_0.0'``; an empty mapping
    gives ``''``. Iteration follows the mapping's own order.
    """
    return "_".join(f"{key}_{val}" for key, val in dct.items())
def create_sample_config(get_config_fn, workspace, ckpt: str, hparams: dict, keys: List[str], description=None):
    """Build the config for a ``sample2dir`` evaluation of checkpoint ``ckpt``.

    Outputs go under ``<workspace>/evaluate/evaluator/sample2dir/<ckpt>/``.

    Args:
        get_config_fn: factory called as
            ``get_config_fn(path=..., task='sample2dir', **hparams)``; the object it
            returns must accept attribute assignment.
        workspace: root directory that all output paths are joined onto.
        ckpt: checkpoint identifier, used as a path component.
        hparams: hyper-parameters forwarded to ``get_config_fn``.
        keys: names of ``hparams`` entries used to build the default description;
            names absent from ``hparams`` are skipped.
        description: explicit run label; when falsy, one is derived from the
            selected hparams via ``dict2str``.

    Returns:
        The config from ``get_config_fn`` with ``workspace``, ``backup_root`` and
        ``interact.fname_log`` set.
    """
    # Label the run by its selected hparams so different settings land in different dirs.
    description = description or dict2str(sub_dict(hparams, *keys))
    prefix = f'evaluate/evaluator/sample2dir/{ckpt}'
    path = os.path.join(workspace, f'{prefix}/{description}')
    config = get_config_fn(path=path, task='sample2dir', **hparams)
    config.workspace = workspace
    config.backup_root = os.path.join(workspace, f'{prefix}/reproducibility/{description}')
    config.interact = interact = ml_collections.ConfigDict()
    interact.fname_log = os.path.join(workspace, f'{prefix}/{description}.log')
    return config
def create_nll_config(get_config_fn, workspace, ckpt: str, hparams: dict, keys: List[str], description=None):
    """Build the config for an ``nll`` evaluation of checkpoint ``ckpt``.

    Outputs go under ``<workspace>/evaluate/evaluator/nll/<ckpt>/``; the result
    file is ``<description>.pth`` in that directory.

    Args:
        get_config_fn: factory called as
            ``get_config_fn(fname=..., task='nll', **hparams)``; the object it
            returns must accept attribute assignment.
        workspace: root directory that all output paths are joined onto.
        ckpt: checkpoint identifier, used as a path component.
        hparams: hyper-parameters forwarded to ``get_config_fn``.
        keys: names of ``hparams`` entries used to build the default description;
            names absent from ``hparams`` are skipped.
        description: explicit run label; when falsy, one is derived from the
            selected hparams via ``dict2str``.

    Returns:
        The config from ``get_config_fn`` with ``workspace``, ``backup_root`` and
        ``interact.fname_log`` set.
    """
    # Label the run by its selected hparams so different settings land in different files.
    description = description or dict2str(sub_dict(hparams, *keys))
    prefix = f'evaluate/evaluator/nll/{ckpt}'
    fname = os.path.join(workspace, f'{prefix}/{description}.pth')
    config = get_config_fn(fname=fname, task='nll', **hparams)
    config.workspace = workspace
    config.backup_root = os.path.join(workspace, f'{prefix}/reproducibility/{description}')
    config.interact = interact = ml_collections.ConfigDict()
    interact.fname_log = os.path.join(workspace, f'{prefix}/{description}.log')
    return config