-
Notifications
You must be signed in to change notification settings - Fork 8
/
test.py
85 lines (64 loc) · 2.43 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torchvision
def test_timeembedding():
    """Smoke-test ``timestep_embedding`` on one random timestep.

    Draws a single integer timestep and prints the result of
    ``timestep_embedding(times_steps, 1000)``. The second argument is
    presumably the embedding dimension -- confirm against ``models_dis``.
    """
    # Only timestep_embedding is needed; the previously imported DisModel was unused.
    from models_dis import timestep_embedding

    # torch.randint's upper bound is exclusive, so the timestep is in [1, 99].
    times_steps = torch.randint(1, 100, (1,))
    print(timestep_embedding(times_steps, 1000))
def test_dismodel():
    """Print FLOPs and parameter count for every registered DiS model variant.

    Each constructor in ``models_dis.DiS_models`` is built with
    ``img_size=32`` and profiled with ``thop.profile`` on one random
    (1, 3, 32, 32) input and one random timestep. Runs on CUDA when it is
    available and falls back to CPU otherwise.
    """
    from models_dis import DiS_models
    from thop import profile

    # Previously every tensor/model was moved with a hard-coded .cuda(),
    # which raises on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for name, build_model in DiS_models.items():
        print(name)
        model = build_model(img_size=32).to(device)
        input_image = torch.randn(1, 3, 32, 32, device=device)
        times_steps = torch.randint(1, 100, (1,), device=device)
        # thop.profile returns (MACs, params); one multiply-accumulate is
        # counted as two floating-point operations, hence the factor of 2.
        macs, _ = profile(model, inputs=(input_image, times_steps))
        print('FLOPs = ' + str(macs * 2 / 1000**3) + 'G')
        parameters_sum = sum(p.numel() for p in model.parameters())
        print(parameters_sum / 1000000.0, "M")
def test_cifar10():
    """Load the local CIFAR-10 train and test splits and print a sample.

    Both splits are read from ``data_path`` with ``download=False``, so the
    files must already exist there. Prints the training-set summary and the
    first item of the test split.
    """
    data_path = "/TrainData/Multimodal/zhengcong.fei/dis/data"
    # Build the train split first, then the test split.
    train_split, test_split = (
        torchvision.datasets.CIFAR10(root=data_path, train=is_train, download=False)
        for is_train in (True, False)
    )
    print(train_split)
    print(test_split[0])
def test_imagenet1k():
    """Build an ``ImageFolder`` over the ImageNet train directory and print its first item."""
    from torchvision.datasets import ImageFolder

    train_dir = '/TrainData/Multimodal/public/datasets/ImageNet/train'
    print(ImageFolder(train_dir)[0])
def test_celeba():
    """Load the local CelebA dataset via ``datasets.load_dataset``.

    Prints the keys of the first example in the ``'train'`` split.
    """
    from datasets import load_dataset

    celeba_dir = "/TrainData/Multimodal/zhengcong.fei/dis/data/CelebA"
    first_train_example = load_dataset(celeba_dir)['train'][0]
    print(first_train_example.keys())
def test_fid_score():
    """Compute the FID score between two result directories and print it.

    Going by the directory names, ``path1`` and ``path2`` hold samples from
    the conditional and unconditional small CIFAR-10 runs respectively.
    """
    from tools.fid_score import calculate_fid_given_paths

    path1 = '/TrainData/Multimodal/zhengcong.fei/dis/results/cond_cifar10_small/his'
    path2 = '/TrainData/Multimodal/zhengcong.fei/dis/results/uncond_cifar10_small/his'
    fid = calculate_fid_given_paths((path1, path2))
    # Report the score; it used to be computed and then silently discarded.
    print('FID = ' + str(fid))
def test_vae():
    """Check that the ``AutoencoderKL`` checkpoint at ``vae_path`` loads.

    The loaded model is not used further; the load succeeding is the test.
    """
    from diffusers.models import AutoencoderKL

    vae_path = '/TrainData/Multimodal/zhengcong.fei/dis/vae'
    AutoencoderKL.from_pretrained(vae_path)
if __name__ == "__main__":
    # Run the smoke tests named on the command line, e.g.
    #   python test.py test_dismodel test_vae
    # With no arguments nothing runs, so the default behaviour is unchanged
    # from the previous all-commented-out calls.
    import sys

    _available = {
        _key: _value
        for _key, _value in globals().items()
        if _key.startswith("test_") and callable(_value)
    }
    for _name in sys.argv[1:]:
        if _name not in _available:
            raise SystemExit(f"unknown test {_name!r}; choose from {sorted(_available)}")
        _available[_name]()