Different Inception Scores compared to the BigGAN paper #12
I think you need to shuffle the data, because the score is computed over 50 images at a time. If all 50 images belong to the same class, the score will be much lower.
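To see why ordering matters, here is a minimal numpy sketch that applies the per-split KL formula used in this thread to hypothetical toy predictions (10 classes, 100 images each, 0.91 probability on the true class). `inception_score_from_preds` is an illustrative helper, not part of any library. In class-sorted order every split holds a single class, so the split marginal p(y) equals every p(y|x), the KL term is zero, and the score collapses to 1. After a random permutation the splits mix classes and the score rises sharply.

```python
import numpy as np


def inception_score_from_preds(preds, splits=10):
    """Mean/std of exp(E_x[KL(p(y|x) || p(y))]) over consecutive splits.

    preds -- (N, C) array of softmax probabilities, one row per image
    """
    n = len(preds)
    scores = []
    for k in range(splits):
        part = preds[k * (n // splits):(k + 1) * (n // splits)]
        py = part.mean(axis=0, keepdims=True)  # marginal p(y) within this split
        kl = np.sum(part * (np.log(part) - np.log(py)), axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))


# Hypothetical toy predictions: 10 classes x 100 images, each image assigned
# probability 0.91 to its own class and 0.01 to each of the other 9 classes.
num_classes, per_class = 10, 100
labels = np.repeat(np.arange(num_classes), per_class)  # class-sorted order
preds = np.full((labels.size, num_classes), 0.01)
preds[np.arange(labels.size), labels] = 0.91

sorted_mean, _ = inception_score_from_preds(preds, splits=10)
rng = np.random.default_rng(0)
shuffled_mean, _ = inception_score_from_preds(preds[rng.permutation(labels.size)], splits=10)

print(sorted_mean)    # each split holds one class, so p(y) == p(y|x) and the score is 1
print(shuffled_mean)  # each split mixes classes, so the score is far above 1
```

With perfectly balanced splits this toy setup would give exp(1.80), roughly 6; random splits land slightly below that.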
Do you mean that I need to add `shuffle=True` to the DataLoader in the original code?
@YLJALDC

```python
import numpy as np
import torch
import torch.utils.data
from torch.nn import functional as F
from torchvision.models import inception_v3


def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
    """Computes the inception score of the generated images imgs.

    imgs -- Torch dataset of (3xHxW) image tensors normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    resize -- whether to upsample inputs to 299x299 before Inception v3
    splits -- number of splits
    """
    N = len(imgs)
    assert batch_size > 0
    assert N > batch_size
    device = torch.device('cuda' if cuda else 'cpu')

    # Set up dataloader; shuffle so that each split mixes classes
    dataloader = torch.utils.data.DataLoader(imgs, shuffle=True, batch_size=batch_size)
    print("INFO : Dataset ready ...")

    # Load inception model (expects [-1, 1] inputs when transform_input=False)
    inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
    inception_model.eval()
    print("INFO : Inception model ready ...")

    def get_pred(x):
        if resize:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        x = inception_model(x)
        return F.softmax(x, dim=1).cpu().numpy()

    # Get predictions
    preds = np.zeros((N, 1000))
    print("INFO : Extra memory allocated ...")
    with torch.no_grad():  # inference only, no autograd graph needed
        for i, batch in enumerate(dataloader, 0):
            batch = batch.to(device)
            batch_size_i = batch.size(0)
            preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batch)
            if i % 1000 == 0:
                print('==> Processing ' + str(i) + 'th batch.')

    # Now compute the mean KL divergence for each split
    scores = []
    for k in range(splits):
        part = preds[k * (N // splits): (k + 1) * (N // splits), :]
        py = np.mean(part, axis=0, keepdims=True)  # marginal p(y) within the split
        kl = part * (np.log(part) - np.log(py))
        kl = np.mean(np.sum(kl, axis=1))
        scores.append(np.exp(kl))
    return np.mean(scores), np.std(scores)


if __name__ == '__main__':
    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    class IgnoreLabelDataset(torch.utils.data.Dataset):
        """Wraps a labeled dataset and returns only the image."""

        def __init__(self, orig):
            self.orig = orig

        def __getitem__(self, index):
            return self.orig[index][0]

        def __len__(self):
            return len(self.orig)

    cifar = dset.CIFAR10(root='data/', download=True,
                         transform=transforms.Compose([
                             transforms.Resize(32),
                             transforms.ToTensor(),
                             # Map [0, 1] to [-1, 1], as inception_score expects
                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                         ]))

    print("INFO : Calculating Inception Score ...")
    print(inception_score(IgnoreLabelDataset(cifar), cuda=True, batch_size=8, resize=True, splits=10))
```
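The docstring of `inception_score` requires inputs in [-1, 1]. A quick arithmetic check, assuming the images come from `transforms.ToTensor()` (values in [0, 1]) and `transforms.Normalize` computes `(x - mean) / std` per channel, shows which normalization settings meet that requirement. `normalized_range` is a hypothetical helper for illustration only.

```python
import numpy as np


def normalized_range(mean, std):
    """Per-channel output range of (x - mean) / std for pixel values x in [0, 1]."""
    mean, std = np.asarray(mean), np.asarray(std)
    return (0.0 - mean) / std, (1.0 - mean) / std


# (0.5, 0.5, 0.5) mean/std maps [0, 1] exactly onto [-1, 1].
half_lo, half_hi = normalized_range((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# CIFAR-10 per-channel statistics standardize the data instead.
cifar_lo, cifar_hi = normalized_range((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

print(half_lo, half_hi)    # [-1. -1. -1.] [1. 1. 1.]
print(cifar_lo, cifar_hi)  # about -2.2..-2.4 and +2.5..+2.8 per channel, well outside [-1, 1]
```

Per-channel dataset statistics are a reasonable choice for training a classifier, but they do not produce the input range this function is documented to expect.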
@HolmesShuan Is a score of 10.54 accurate enough? I read that the TensorFlow implementation gives something around 11. Is a difference of 0.5 small enough?
@vibss2397 I still recommend the TensorFlow implementation. On ImageNet, the IS seems even less consistent with the reported results. I ran the CIFAR experiments multiple times and noticed that the score is always 0.5 lower than the baseline.
I believe the difference exists because the network weights differ between the PyTorch and TensorFlow Inception models.
(Replying to Shuan's comment of Mon, Nov 16, 2020, 6:51 PM.)
Hello,
I was going through your paper and noticed that you report an Inception Score of 63.702 ± 7.869 on the ImageNet validation set (at 299x299 image size), while the BigGAN paper reports 166.5 (at 128x128 image size). Can you comment on this discrepancy?
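One factor that can pull a score far below a published number, and that connects to the shuffling comment at the top of this thread, is class coverage per split. Since IS = exp(H(y) - E_x[H(y|x)]) <= exp(H(y)), when Inception puts nearly all probability on each image's true class, a split's score is bounded by roughly the number of distinct classes it contains. The sketch below assumes the ImageNet validation set (50,000 images, 50 per class) is scored in class-sorted order, as a folder-by-folder loader such as torchvision's `ImageFolder` would yield, with `splits=10`; this thread does not state which loader or split count the paper used. `classes_per_split` and `split_inception_score` are hypothetical helpers.

```python
import numpy as np


def split_inception_score(preds):
    """exp(E_x[KL(p(y|x) || p(y))]) for one split; preds is an (N, C) softmax array."""
    py = preds.mean(axis=0, keepdims=True)  # marginal p(y) within the split
    return float(np.exp(np.mean(np.sum(preds * (np.log(preds) - np.log(py)), axis=1))))


def classes_per_split(num_images, images_per_class, splits):
    """Distinct classes per split when images arrive sorted by class (hypothetical helper)."""
    split_size = num_images // splits
    return split_size // images_per_class  # assumes split_size is a multiple of images_per_class


# Class-sorted ImageNet validation with 10 splits: each split covers only 100 classes.
n_classes = classes_per_split(50_000, 50, 10)
print(n_classes)  # 100

# Toy check of the bound: nearly one-hot predictions spread evenly over k classes
# give a split score just below k, never above it.
k, eps = 100, 1e-6
toy = np.full((k * 5, k), eps)
toy[np.arange(k * 5), np.repeat(np.arange(k), 5)] = 1 - eps * (k - 1)
toy_score = split_inception_score(toy)
print(toy_score)
```

Under these assumptions a per-split score near 100 would be the ceiling, whereas shuffled splits can in principle approach 1,000 classes.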