diff --git a/lpips/__init__.py b/lpips/__init__.py
index a3d19013..83f44154 100755
--- a/lpips/__init__.py
+++ b/lpips/__init__.py
@@ -10,35 +10,6 @@ from lpips.trainer import *
 from lpips.lpips import *
 
-# class PerceptualLoss(torch.nn.Module):
-#     def __init__(self, model='lpips', net='alex', spatial=False, use_gpu=False, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric)
-#     # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
-#         super(PerceptualLoss, self).__init__()
-#         print('Setting up Perceptual loss...')
-#         self.use_gpu = use_gpu
-#         self.spatial = spatial
-#         self.gpu_ids = gpu_ids
-#         self.model = dist_model.DistModel()
-#         self.model.initialize(model=model, net=net, use_gpu=use_gpu, spatial=self.spatial, gpu_ids=gpu_ids, version=version)
-#         print('...[%s] initialized'%self.model.name())
-#         print('...Done')
-
-#     def forward(self, pred, target, normalize=False):
-#         """
-#         Pred and target are Variables.
-#         If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
-#         If normalize is False, assumes the images are already between [-1,+1]
-
-#         Inputs pred and target are Nx3xHxW
-#         Output pytorch Variable N long
-#         """
-
-#         if normalize:
-#             target = 2 * target - 1
-#             pred = 2 * pred - 1
-
-#         return self.model.forward(target, pred)
-
 def normalize_tensor(in_feat,eps=1e-10):
     norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
     return in_feat/(norm_factor+eps)
diff --git a/lpips/lpips.py b/lpips/lpips.py
index 9b979c08..b5a3937c 100755
--- a/lpips/lpips.py
+++ b/lpips/lpips.py
@@ -22,8 +22,40 @@ def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and
 class LPIPS(nn.Module):
     def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
         pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):
-        # lpips - [True] means with linear calibration on top of base network
-        # pretrained - [True] means load linear weights
+        """ Initializes a perceptual loss torch.nn.Module
+
+        Parameters (default listed first)
+        ---------------------------------
+        lpips : bool
+            [True] use linear layers on top of base/trunk network
+            [False] means no linear layers; each layer is averaged together
+        pretrained : bool
+            This flag controls the linear layers, which are only in effect when lpips=True above
+            [True] means linear layers are calibrated with human perceptual judgments
+            [False] means linear layers are randomly initialized
+        pnet_rand : bool
+            [False] means trunk loaded with ImageNet classification weights
+            [True] means randomly initialized trunk
+        net : str
+            ['alex','vgg','squeeze'] are the base/trunk networks available
+        version : str
+            ['v0.1'] is the default and latest
+            ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
+        model_path : str
+            [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1
+
+        The following parameters should only be changed if training the network
+
+        eval_mode : bool
+            [True] is for test mode (default)
+            [False] is for training mode
+        pnet_tune : bool
+            [False] keep base/trunk frozen
+            [True] tune the base/trunk network
+        use_dropout : bool
+            [True] to use dropout when training linear layers
+            [False] for no dropout when training linear layers
+        """
 
         super(LPIPS, self).__init__()
         if(verbose):
@@ -102,19 +134,9 @@ def forward(self, in0, in1, retPerLayer=False, normalize=False):
             else:
                 res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
 
-        val = res[0]
-        for l in range(1,self.L):
+        val = 0
+        for l in range(self.L):
             val += res[l]
-
-        # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
-        # b = torch.max(self.lins[kk](feats0[kk]**2))
-        # for kk in range(self.L):
-        #     a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
-        #     b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
-        # a = a/self.L
-        # from IPython import embed
-        # embed()
-        # return 10*torch.log10(b/a)
 
         if(retPerLayer):
             return (val, res)
diff --git a/lpips_2imgs.py b/lpips_2imgs.py
index 19d8a9ae..74ca61eb 100644
--- a/lpips_2imgs.py
+++ b/lpips_2imgs.py
@@ -24,5 +24,8 @@
 	img1 = img1.cuda()
 
 # Compute distance
-dist01 = loss_fn.forward(img0,img1)
+dist01 = loss_fn.forward(img0, img1)
 print('Distance: %.3f'%dist01)
+
+dist01 = loss_fn.forward(img0, img1, retPerLayer=True)
+print(dist01)
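
For review convenience, here is a minimal usage sketch (not part of the diff) exercising the interface touched above: the `LPIPS` constructor options described in the new docstring and the `(val, res)` tuple returned by `forward` when `retPerLayer=True`. The dummy tensors, sizes, and `net='alex'` choice are illustrative assumptions; per the docstring, inputs are Nx3xHxW in [-1, +1] unless `normalize=True` is passed.

```python
import torch
import lpips

# Sketch only: linear layers calibrated on human judgments (pretrained=True, lpips=True defaults).
loss_fn = lpips.LPIPS(net='alex', version='0.1')
loss_fn.eval()

# Dummy Nx3xHxW inputs in [-1, +1]; images in [0, 1] would need normalize=True instead.
img0 = torch.zeros(1, 3, 64, 64)
img1 = torch.rand(1, 3, 64, 64) * 2 - 1

with torch.no_grad():
    dist = loss_fn(img0, img1)                              # overall distance, shape Nx1x1x1
    val, per_layer = loss_fn(img0, img1, retPerLayer=True)  # (total, per-layer list) per this diff

print('Distance: %.3f' % dist.item())
print([r.item() for r in per_layer])
```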