Skip to content

Commit

Permalink
refactoring(fid): Pytorch implementation of square root of matrix. (#32)
Browse files Browse the repository at this point in the history
refactoring(fid): pytorch implementation of square root of matrix
  • Loading branch information
denproc authored Apr 30, 2020
1 parent 651eb84 commit e17c4c1
Showing 1 changed file with 94 additions and 24 deletions.
118 changes: 94 additions & 24 deletions photosynthesis_metrics/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,60 @@
"""

from typing import Tuple
import torch

from photosynthesis_metrics.base import BaseFeatureMetric

import numpy as np
import torch
from scipy import linalg

def _approximation_error(A: torch.Tensor, sA: torch.Tensor) -> torch.Tensor:
    r"""Measure how well ``sA`` approximates a square root of ``A``.

    Computes the residual ``||A - sA @ sA||`` and divides it by ``||A||``,
    both taken with ``torch.norm`` at its default settings.

    Args:
        A: Square 2-D matrix whose square root is being approximated.
        sA: Candidate square root of ``A``; must be compatible with
            ``torch.mm(sA, sA)`` and with subtraction from ``A``.

    Returns:
        Scalar (0-dim) tensor with the relative error. If ``A`` has zero norm
        the relative error is undefined, so the absolute residual norm is
        returned instead of NaN/inf.
    """
    norm_a = torch.norm(A)
    residual = torch.norm(A - torch.mm(sA, sA))
    # Guard against 0/0 (NaN) and x/0 (inf) when A is the zero matrix.
    if norm_a == 0:
        return residual
    return residual / norm_a

from photosynthesis_metrics.base import BaseFeatureMetric

def _sqrtm_newton_schulz(A: torch.Tensor, num_iters: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Square root of a matrix using the Newton-Schulz iterative method.

    Source: https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py

    The input is normalised by its Frobenius norm before iterating and the
    result is rescaled by the square root of that norm afterwards. Iteration
    stops early once the relative approximation error is within ``1e-5`` of zero.

    Args:
        A: Square 2-D matrix.
        num_iters: Maximum number of iterations of the method; must be positive.

    Returns:
        Tuple ``(sA, error)`` where ``sA`` is the approximate square root of
        ``A`` and ``error`` is the scalar relative error reported by
        ``_approximation_error`` for the last iterate.

    Raises:
        ValueError: If ``A`` is not 2-dimensional, is not square, or if
            ``num_iters`` is not positive.
    """
    expected_num_dims = 2
    if A.dim() != expected_num_dims:
        raise ValueError(f'Input dimension equals {A.dim()}, expected {expected_num_dims}')

    if A.size(0) != A.size(1):
        raise ValueError(f'Input matrix must be square, got shape {tuple(A.shape)}')

    if num_iters <= 0:
        raise ValueError(f'Number of iteration equals {num_iters}, expected greater than 0')

    dim = A.size(0)
    normA = A.norm(p='fro')

    # The zero matrix is its own square root; normalising it would divide by zero.
    if normA == 0:
        return torch.zeros_like(A), torch.zeros_like(normA)

    Y = A.div(normA)
    # Build the helpers on the same device and with the same dtype as the input.
    I = torch.eye(dim, dtype=A.dtype, device=A.device)
    Z = torch.eye(dim, dtype=A.dtype, device=A.device)

    for _ in range(num_iters):
        T = 0.5 * (3.0 * I - Z.mm(Y))
        Y = Y.mm(T)
        Z = T.mm(Z)

        # Undo the initial normalisation to obtain the root of the original matrix.
        sA = Y * torch.sqrt(normA)
        error = _approximation_error(A, sA)
        # Compare against a zero of matching dtype/device; torch.isclose rejects mismatches.
        if torch.isclose(error, torch.zeros_like(error), atol=1e-5):
            break
    return sA, error


def __compute_fid(mu1: torch.Tensor, sigma1: torch.Tensor, mu2: torch.Tensor, sigma2: torch.Tensor,
                  eps=1e-6) -> torch.Tensor:
    r"""
    The Frechet Inception Distance between two multivariate Gaussians X_predicted ~ N(mu_1, sigm_1)
    and X_target ~ N(mu_2, sigm_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(sigm_1 + sigm_2 - 2*sqrt(sigm_1*sigm_2)).

    Args:
        mu1: 1-D mean vector of the first distribution.
        sigma1: 2-D covariance matrix of the first distribution.
        mu2: 1-D mean vector of the second distribution.
        sigma2: 2-D covariance matrix of the second distribution.
        eps: Value added to the diagonals of both covariance matrices when the
            square root of their product contains non-finite entries.

    Returns:
        Scalar value of the distance between sets.
    """
    diff = mu1 - mu2
    covmean, _ = _sqrtm_newton_schulz(sigma1.mm(sigma2))
    # Product might be almost singular
    if not torch.isfinite(covmean).all():
        print(f'FID calculation produces singular product; adding {eps} to diagonal of cov estimates')
        # Create the offset with sigma1's dtype/device so the addition also works off-CPU.
        offset = torch.eye(sigma1.size(0), dtype=sigma1.dtype, device=sigma1.device) * eps
        covmean, _ = _sqrtm_newton_schulz((sigma1 + offset).mm(sigma2 + offset))

    tr_covmean = torch.trace(covmean)
    return diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean

# Numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError('Imaginary component {}'.format(m))

covmean = covmean.real
def _cov(m: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
    r"""Estimate a covariance matrix given data.

    Covariance indicates the level to which two variables vary together.
    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
    then the covariance matrix element `C_{ij}` is the covariance of
    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.

    The input tensor is left unmodified.

    Args:
        m: A 1-D or 2-D array containing multiple variables and observations.
            Each row of `m` represents a variable, and each column a single
            observation of all those variables.
        rowvar: If `rowvar` is True, then each row represents a
            variable, with observations in the columns. Otherwise, the
            relationship is transposed: each column represents a variable,
            while the rows contain observations. A 2-D input with a single
            row is never transposed.

    Returns:
        The covariance matrix of the variables, normalised by the number of
        observations minus one and squeezed (a single variable yields a 0-dim tensor).

    Raises:
        ValueError: If `m` has more than 2 dimensions.
        ZeroDivisionError: If there is only a single observation.
    """
    if m.dim() > 2:
        raise ValueError('Tensor for covariance computations has more than 2 dimensions. '
                         'Only 1 or 2 dimensional arrays are allowed')
    if m.dim() < 2:
        m = m.view(1, -1)
    if not rowvar and m.size(0) != 1:
        m = m.t()
    fact = 1.0 / (m.size(1) - 1)
    # Centre out of place: `m` may be a view of the caller's tensor, and an
    # in-place subtraction would silently overwrite the caller's data.
    centered = m - torch.mean(m, dim=1, keepdim=True)
    return fact * centered.matmul(centered.t()).squeeze()


def _compute_statistics(samples: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Calculates the statistics used by FID

    Args:
        samples: Low-dimension representation of image set. Rows are treated
            as observations and columns as variables.

    Returns:
        mu: mean over all activations from the encoder.
        sigma: covariance matrix over all activations from the encoder.
    """
    # Keep this order: the mean is taken first, then the covariance.
    mean = torch.mean(samples, dim=0)
    covariance = _cov(samples, rowvar=False)
    return mean, covariance


def compute_fid(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    r"""PyTorch implementation of the Frechet Distance.
    Fits multivariate Gaussians: X ~ N(mu_1, sigm_1) and Y ~ N(mu_2, sigm_2) to image stacks.
    Then computes FID as d^2 = ||mu_1 - mu_2||^2 + Tr(sigm_1 + sigm_2 - 2*sqrt(sigm_1*sigm_2)).

    Args:
        x: Samples from the first (predicted) set; statistics are taken over dim 0.
        y: Samples from the second (target) set; statistics are taken over dim 0.

    Returns:
        -- : The Frechet Distance.
    """
    predicted_stats = _compute_statistics(x)
    target_stats = _compute_statistics(y)

    # Each tuple is (mean, covariance), matching the (mu, sigma) argument pairs.
    return __compute_fid(*predicted_stats, *target_stats)
Expand All @@ -91,11 +160,12 @@ class FID(BaseFeatureMetric):
r"""Creates a criterion that measures Frechet Inception Distance score for two datasets of images
See https://arxiv.org/abs/1706.08500 for reference.
"""

def __init__(self):
super(FID, self).__init__()
self.compute = compute_fid

def forward(self, predicted_features: torch.Tensor, target_features: torch.Tensor) -> float:
def forward(self, predicted_features: torch.Tensor, target_features: torch.Tensor) -> torch.Tensor:
r"""Interface of Frechet Inception Distance.
It's computed for a whole set of data and uses features from encoder instead of images itself to decrease
computation cost. FID can compare two data distributions with different number of samples.
Expand Down

0 comments on commit e17c4c1

Please sign in to comment.