# lgan.py
import os
import numpy as np
import h5py
from utils import ensure_dir
from config import ConfigLGAN
from trainer import TrainerLatentWGAN
from dataset.lgan_dataset import get_dataloader
# Script entry point. Builds the configuration and the trainer, then runs one
# of two modes selected by ``cfg.test``:
#   * falsy  -> train (resuming from ``cfg.ckpt`` first when ``cfg.cont`` is set)
#   * truthy -> load ``cfg.ckpt``, generate ``cfg.n_samples`` codes, save to HDF5
cfg = ConfigLGAN()
print("data path:", cfg.data_root)
agent = TrainerLatentWGAN(cfg)

if not cfg.test:
    # load from checkpoint if provided
    if cfg.cont:
        agent.load_ckpt(cfg.ckpt)

    # create dataloader; training runs whether or not a checkpoint was loaded
    train_loader = get_dataloader(cfg)
    agent.train(train_loader)
else:
    # load trained weights
    agent.load_ckpt(cfg.ckpt)

    # run generator
    generated_shape_codes = agent.generate(cfg.n_samples)

    # save generated z under <exp_dir>/results/, with the checkpoint name and
    # sample count embedded in the file name
    save_path = os.path.join(
        cfg.exp_dir,
        "results/fake_z_ckpt{}_num{}.h5".format(cfg.ckpt, cfg.n_samples),
    )
    # NOTE(review): ensure_dir presumably creates the directory if it is
    # missing -- confirm against utils.ensure_dir.
    ensure_dir(os.path.dirname(save_path))
    # 'w' mode creates the file, truncating any existing file at this path.
    with h5py.File(save_path, 'w') as fp:
        fp.create_dataset("zs", shape=generated_shape_codes.shape, data=generated_shape_codes)