-
Notifications
You must be signed in to change notification settings - Fork 0
/
call_methods.py
262 lines (200 loc) · 7.63 KB
/
call_methods.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import argparse
from typing import Union
import torch
from data.datasets import BaseDataset
from model.models import BaseModel
from option.enums import ModelNames, NetworkNames, DatasetNames
def make_model(model_name: str, *args, **kwargs) -> BaseModel:
    """
    Creates a model from the given model name

    Every supported name maps to a class in ``model.ddpm`` (``DDPM`` or
    ``DDPMCFG``), to the bound method exposed as ``model.train_method``, and
    optionally to an EMA flag that is forced to ``True`` on ``kwargs["opt"]``
    before the model is constructed.

    Parameters
    ----------
    model_name: str
        The name of the model to create. It is lower-cased and then compared
        with ``==`` against the ``ModelNames`` entries, in table order.
    *args: list
        The arguments to pass to the model constructor
    **kwargs: dict
        The keyword arguments to pass to the model constructor. The EMA and
        Power Law EMA variants require an ``opt`` entry; its ``ema_apply`` or
        ``power_ema_apply`` attribute is set to ``True`` before construction.

    Returns
    -------
    model: BaseModel
        The created model, with ``train_method`` bound to the training
        routine of the selected variant (``train``, ``cfg_train`` or
        ``cfg_plus_train``)

    Raises
    ------
    ValueError
        If ``model_name`` does not match any known model
    KeyError
        If an EMA / Power Law EMA variant is requested without ``opt`` in
        ``kwargs``
    """
    import importlib

    # One row per supported variant:
    #   (name, class in model.ddpm, training-method attribute,
    #    opt flag forced to True before construction (or None),
    #    label used when echoing that flag (or None),
    #    announcement of the selected training method)
    variants = (
        # ========= DDPM variants without condition =========
        (
            ModelNames.DDPM, "DDPM", "train", None, None,
            "DDPM 35M will be trained using the 'DDPM.train method'...",
        ),
        (
            ModelNames.DDPMwithEMA, "DDPM", "train",
            "ema_apply", "Final DDPM ema_apply setting",
            "DDPM will be trained using the 'DDPM.train method' with EMA...",
        ),
        (
            ModelNames.DDPMwithPowerLawEMA, "DDPM", "train",
            "power_ema_apply", "Final DDPM power_law_ema_apply setting",
            "DDPM will be trained using the 'DDPM.train method' with Power Law EMA...",
        ),
        # ========= DDPM Classifier Free Guidance variants =========
        (
            ModelNames.CFG_DDPM, "DDPMCFG", "cfg_train", None, None,
            "DDPM CFG will be trained using the 'DDPMCFG.cfg_train method'...",
        ),
        (
            ModelNames.CFG_DDPM_EMA, "DDPMCFG", "cfg_train",
            "ema_apply", "Final CFG DDPM ema_apply setting",
            "DDPM CFG will be trained using the 'DDPMCFG.cfg_train method' with EMA...",
        ),
        (
            ModelNames.CFG_DDPM_PowerLawEMA, "DDPMCFG", "cfg_train",
            "power_ema_apply", "Final DDPM CFG power_law_ema_apply setting",
            "DDPM CFG will be trained using the 'DDPMCFG.cfg_train method' with Power Law EMA...",
        ),
        # ========= DDPM Classifier Free Guidance ++ variants =========
        (
            ModelNames.CFG_Plus_DDPM, "DDPMCFG", "cfg_plus_train", None, None,
            "DDPM CFG ++ will be trained using the 'DDPMCFG.cfg_plus_train method'...",
        ),
        (
            ModelNames.CFG_Plus_DDPM_EMA, "DDPMCFG", "cfg_plus_train",
            "ema_apply", "Final DDPM CFG ++ ema_apply setting",
            "DDPM CFG ++ will be trained using the 'DDPMCFG.cfg_plus_train method' with EMA...",
        ),
        (
            ModelNames.CFG_Plus_DDPM_PowerLawEMA, "DDPMCFG", "cfg_plus_train",
            "power_ema_apply", "Final DDPM CFG ++ power_law_ema_apply setting",
            "DDPM CFG ++ will be trained using the 'DDPMCFG.cfg_plus_train method' with Power Law EMA...",
        ),
    )

    requested = model_name.lower()
    spec = next((row for row in variants if requested == row[0]), None)
    if spec is None:
        raise ValueError(f"Invalid model name: {model_name}")
    _, class_name, train_attr, ema_flag, flag_label, train_msg = spec

    # The flag is set before construction, presumably so that the model's
    # constructor picks it up from ``opt``.
    if ema_flag is not None:
        opt = kwargs["opt"]
        setattr(opt, ema_flag, True)

    # Model classes are imported lazily, only once a valid name was given.
    model_cls = getattr(importlib.import_module("model.ddpm"), class_name)
    model = model_cls(*args, **kwargs)

    if ema_flag is not None:
        # Echo the value as it stands after construction.
        print(f"{flag_label}: {getattr(opt, ema_flag)}")
    model.train_method = getattr(model, train_attr)
    print(train_msg)

    print(f"Model {model_name} was created")
    return model
def make_network(network_name: str, *args, **kwargs) -> torch.nn.Module:
    """
    Instantiate the network registered under ``network_name``.

    Parameters
    ----------
    network_name: str
        Network identifier; lower-cased and compared against ``NetworkNames``
    *args: list
        Positional arguments forwarded to the network constructor
    **kwargs: dict
        Keyword arguments forwarded to the network constructor

    Returns
    -------
    network: torch.nn.Module
        The freshly constructed network

    Raises
    ------
    ValueError
        If ``network_name`` is not recognised
    """
    requested = network_name.lower()

    if requested == NetworkNames.DDPM_Unet:
        # Backbone for the plain DDPM models.
        from model.unet import UNet as network_cls
    elif requested == NetworkNames.CFG_Unet:
        # Backbone shared by the CFG and CFG ++ DDPM models.
        from model.unet import UnetCFG as network_cls
    else:
        raise ValueError(f"Invalid network name: {network_name}")

    network = network_cls(*args, **kwargs)
    print(f"Network {network_name} was created")
    return network
def make_dataset(dataset_name: str, opt: argparse.Namespace, *args, **kwargs):
    """
    Build the dataset split(s) registered under ``dataset_name``.

    After construction, every split gets a DataLoader attached through
    ``make_dataloader``. The loader is configured from ``opt.batch_size`` and
    ``opt.num_workers``, with ``shuffle=True`` and ``pin_memory=True``. Each
    split then reports its loader via ``print_dataloader_info()``.

    Parameters
    ----------
    dataset_name: str
        Dataset identifier; lower-cased and compared against ``DatasetNames``
    opt: argparse.Namespace
        Training options; passed to every dataset constructor and read for
        ``batch_size`` and ``num_workers``
    *args: list
        Extra positional arguments for the dataset constructor(s)
    **kwargs: dict
        Extra keyword arguments for the dataset constructor(s)

    Returns
    -------
    dataset: tuple
        ``(train, test)`` for MNIST and ``(train,)`` for the biological
        observations (presumably BaseDataset subclasses)

    Raises
    ------
    ValueError
        If ``dataset_name`` is not recognised
    """
    key = dataset_name.lower()

    if key == DatasetNames.MNIST:
        from data.mnist import MNISTDataset, MNISTTest

        train_split = MNISTDataset(opt, *args, **kwargs)
        splits = (train_split, MNISTTest(opt, *args, **kwargs))
    elif key == DatasetNames.BIOLOGICAL:
        from data.topographies import BiologicalObservation

        splits = (BiologicalObservation(opt, *args, **kwargs),)
    else:
        raise ValueError(f"Invalid dataset name: {dataset_name}")

    for split in splits:
        # NOTE(review): every split, including the MNIST test split, is
        # shuffled — confirm that is intended.
        make_dataloader(
            split,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        split.print_dataloader_info()

    print(f"Dataset {dataset_name} was created")
    return splits
def make_dataloader(dataset: BaseDataset, *args, **kwargs) -> None:
    """
    Wrap ``dataset`` in a ``torch.utils.data.DataLoader`` and attach it.

    The loader is stored on the dataset as ``dataset.dataloader`` rather than
    returned.

    Parameters
    ----------
    dataset: BaseDataset
        The dataset to load from; it must accept a new ``dataloader``
        attribute
    *args: list
        Positional arguments forwarded to ``DataLoader`` after the dataset
    **kwargs: dict
        Keyword arguments forwarded to ``DataLoader`` (e.g. ``batch_size``,
        ``shuffle``, ``num_workers``, ``pin_memory``)

    Returns
    -------
    None
    """
    loader = torch.utils.data.DataLoader(dataset, *args, **kwargs)
    dataset.dataloader = loader