-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
Copy pathpytorch_std_mean.py
41 lines (30 loc) · 1.2 KB
/
pytorch_std_mean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
Code for calculating the mean and standard deviation of a dataset.
This is useful for normalizing the dataset to obtain mean 0, std 1.
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-05-09 Initial coding
* 2022-12-16 Updated comments, code revision, and checked code still works with latest PyTorch.
"""
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from tqdm import tqdm
# Select the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# CIFAR-10 stored under "dataset/" (fetched on demand because download=True);
# every image is passed through ToTensor before being returned.
train_set = datasets.CIFAR10(
    root="dataset/",
    download=True,
    transform=transforms.ToTensor(),
)

# Serve the dataset in shuffled mini-batches of 64 samples.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
def get_mean_std(loader):
    """Compute the per-channel mean and standard deviation of a dataset.

    The data is traversed once, using ``Var[X] = E[X**2] - E[X]**2``.
    Sums are weighted by the number of values each batch contributes, so a
    smaller final batch does not skew the result and the statistics do not
    depend on the batch size or on shuffling.

    Args:
        loader: Iterable yielding ``(data, target)`` pairs. ``data`` must be
            a real-valued 4-D tensor; it is reduced over dims 0, 2 and 3, so
            dim 1 is treated as the channel dimension. ``target`` is ignored.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel.
        ``std`` is the population (biased) standard deviation. The results
        are returned in the dtype of the first batch when that dtype is
        floating point, otherwise in float64.

    Raises:
        ValueError: If ``loader`` yields no values at all.
    """
    # var[X] = E[X**2] - E[X]**2
    reduce_dims = [0, 2, 3]
    channels_sum, channels_sqrd_sum, num_values = 0, 0, 0
    in_dtype = None

    for data, _ in tqdm(loader):
        if in_dtype is None:
            in_dtype = data.dtype

        # Promote to float64 so that accumulating many batches does not lose
        # precision (and so integer/half inputs neither overflow nor fail).
        batch = data.to(torch.float64)
        channels_sum += batch.sum(dim=reduce_dims)
        channels_sqrd_sum += (batch**2).sum(dim=reduce_dims)

        # Count the values reduced per channel in this batch; weighting by
        # this count (rather than by "1 per batch") keeps a short last batch
        # from being over-represented.
        num_values += batch.size(0) * batch.size(2) * batch.size(3)

    if num_values == 0:
        raise ValueError("loader yielded no data; cannot compute mean and std")

    mean = channels_sum / num_values
    # Rounding can push the variance of a (near-)constant channel slightly
    # below zero, which would turn the square root into NaN.
    variance = torch.clamp(channels_sqrd_sum / num_values - mean**2, min=0)
    std = variance**0.5

    # Hand back the same dtype the caller fed in, as the original did.
    if in_dtype.is_floating_point:
        mean, std = mean.to(in_dtype), std.to(in_dtype)
    return mean, std
# Run the statistics pass over the training loader and print the per-channel
# mean followed by the per-channel standard deviation, one per line.
mean, std = get_mean_std(train_loader)
print(mean, std, sep="\n")