-
Notifications
You must be signed in to change notification settings - Fork 0
/
VanVAE.py
53 lines (44 loc) · 2.33 KB
/
VanVAE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from models.VAE.VanillaVAE import *
from data.Dataloaders import *
import torch
from utils.util import parse_args_VanillaVAE
import wandb
def _load_model(args, in_shape, in_channels, device):
    """Construct a VanillaVAE and restore its weights from ``args.checkpoint``.

    Args:
        args: Parsed CLI namespace; forwarded to ``VanillaVAE`` and supplies
            the ``checkpoint`` path.
        in_shape: Input-shape value returned by ``pick_dataset``.
        in_channels: Channel count returned by ``pick_dataset``.
        device: Device the checkpoint tensors are mapped onto while loading.

    Returns:
        The ``VanillaVAE`` instance with the checkpoint's state dict loaded.
    """
    model = VanillaVAE(input_shape=in_shape, input_channels=in_channels, args=args)
    # Without map_location, a checkpoint saved from a CUDA run cannot be
    # deserialized on a CPU-only machine; mapping to the selected device
    # makes loading work in both environments.
    state_dict = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    return model


def main():
    """Train, sample from, or run outlier detection with a VanillaVAE.

    The mode is chosen from the parsed CLI flags, checked in the order
    ``train`` -> ``sample`` -> ``outlier_detection``; the first one set wins.

    Raises:
        ValueError: If none of the mode flags is set.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = parse_args_VanillaVAE()
    # Passed through to pick_dataset unchanged for the primary dataset.
    size = None

    if args.train:
        train_loader, in_shape, in_channels = pick_dataset(
            args.dataset,
            batch_size=args.batch_size,
            normalize=True,
            size=size,
            num_workers=args.num_workers,
        )
        if not args.no_wandb:
            wandb.init(
                project='VAE',
                config={
                    'dataset': args.dataset,
                    'batch_size': args.batch_size,
                    'n_epochs': args.n_epochs,
                    'lr': args.lr,
                    'latent_dim': args.latent_dim,
                    'hidden_dims': args.hidden_dims,
                    'input_size': in_shape,
                    'channels': in_channels,
                    'loss_type': args.loss_type,
                    'kld_weight': args.kld_weight,
                },
                name='VAE_{}'.format(args.dataset),
            )
        model = VanillaVAE(input_shape=in_shape, input_channels=in_channels, args=args)
        model.train_model(train_loader, args.n_epochs)
    elif args.sample:
        # Only the shape/channel metadata is needed here; the loader is unused.
        _, in_shape, in_channels = pick_dataset(
            args.dataset,
            batch_size=args.batch_size,
            normalize=True,
            size=size,
        )
        model = _load_model(args, in_shape, in_channels, device)
        model.sample(title="Sample", train=False)
    elif args.outlier_detection:
        in_loader, in_shape, in_channels = pick_dataset(
            args.dataset,
            batch_size=args.batch_size,
            normalize=True,
            size=size,
            mode='val',
            num_workers=args.num_workers,
        )
        # The out-of-distribution set is requested at the in-distribution
        # shape (size=in_shape), presumably so both can be fed to one model.
        out_loader, _, _ = pick_dataset(
            args.out_dataset,
            batch_size=args.batch_size,
            normalize=True,
            size=in_shape,
            mode='val',
            num_workers=args.num_workers,
        )
        model = _load_model(args, in_shape, in_channels, device)
        model.outlier_detection(in_loader, out_loader)
    else:
        raise ValueError("Invalid mode. Please specify train, sample or outlier_detection")


if __name__ == '__main__':
    main()