From e17c4c1ddc19f282b12b232e4332f27b3291ea49 Mon Sep 17 00:00:00 2001
From: Denis Prokopenko <22414094+denproc@users.noreply.github.com>
Date: Thu, 30 Apr 2020 17:59:08 +0300
Subject: [PATCH] refactoring(fid): PyTorch implementation of square root of
 matrix. (#32)

---
 photosynthesis_metrics/fid.py | 118 +++++++++++++++++++++++++++-------
 1 file changed, 94 insertions(+), 24 deletions(-)

diff --git a/photosynthesis_metrics/fid.py b/photosynthesis_metrics/fid.py
index a037f1ac..c32f8101 100644
--- a/photosynthesis_metrics/fid.py
+++ b/photosynthesis_metrics/fid.py
@@ -9,17 +9,60 @@
 """
 from typing import Tuple
 
+import torch
+from photosynthesis_metrics.base import BaseFeatureMetric
 
-import numpy as np
-import torch
-from scipy import linalg
 
+def _approximation_error(A: torch.Tensor, sA: torch.Tensor) -> torch.Tensor:
+    # Relative residual of the square root: ||A - sA @ sA|| / ||A||.
+    normA = torch.norm(A)
+    error = A - torch.mm(sA, sA)
+    error = torch.norm(error) / normA
+    return error
 
-from photosynthesis_metrics.base import BaseFeatureMetric
 
+def _sqrtm_newton_schulz(A: torch.Tensor, num_iters: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Square root of a matrix computed with the Newton-Schulz iterative method.
+    Source: https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py
+    Args:
+        A: square matrix to take the root of (a 2-D tensor)
+        num_iters: number of iterations of the method
 
-def __compute_fid(mu1: np.ndarray, sigma1: np.ndarray, mu2: np.ndarray, sigma2: np.ndarray, eps=1e-6) -> float:
+    Returns:
+        Square root of the matrix
+        Residual error of the approximation
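+
+    Example:
+        An illustrative sketch, not part of the original interface; the input
+        and the tolerance below are arbitrary choices. Squaring the returned
+        root should approximately reproduce a well-conditioned SPD input:
+
+        >>> A = torch.randn(8, 8)
+        >>> A = A.mm(A.t()) + 8 * torch.eye(8)  # symmetric positive definite
+        >>> sA, err = _sqrtm_newton_schulz(A)
+        >>> float(_approximation_error(A, sA)) < 1e-4
+        True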
""" diff = mu1 - mu2 - covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + covmean, _ = _sqrtm_newton_schulz(sigma1.mm(sigma2)) # Product might be almost singular - if not np.isfinite(covmean).all(): + if not torch.isfinite(covmean).all(): print(f'FID calculation produces singular product; adding {eps} to diagonal of cov estimates') - offset = np.eye(sigma1.shape[0]) * eps - covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + offset = torch.eye(sigma1.size(0)) * eps + covmean, _ = _sqrtm_newton_schulz((sigma1 + offset).mm(sigma2 + offset)) + + tr_covmean = torch.trace(covmean) + return diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean - # Numerical error might give slight imaginary component - if np.iscomplexobj(covmean): - if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): - m = np.max(np.abs(covmean.imag)) - raise ValueError('Imaginary component {}'.format(m)) - covmean = covmean.real +def _cov(m: torch.Tensor, rowvar: bool=True) -> torch.Tensor: + r"""Estimate a covariance matrix given data. - tr_covmean = np.trace(covmean) - return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) + Covariance indicates the level to which two variables vary together. + If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, + then the covariance matrix element `C_{ij}` is the covariance of + `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. + + Args: + m: A 1-D or 2-D array containing multiple variables and observations. + Each row of `m` represents a variable, and each column a single + observation of all those variables. + rowvar: If `rowvar` is True, then each row represents a + variable, with observations in the columns. Otherwise, the + relationship is transposed: each column represents a variable, + while the rows contain observations. -def _compute_statistics(samples: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + Returns: + The covariance matrix of the variables. + """ + if m.dim() > 2: + raise ValueError('Tensor for covariance computations has more than 2 dimensions. ' + 'Only 1 or 2 dimensional arrays are allowed') + if m.dim() < 2: + m = m.view(1, -1) + if not rowvar and m.size(0) != 1: + m = m.t() + fact = 1.0 / (m.size(1) - 1) + m -= torch.mean(m, dim=1, keepdim=True) + mt = m.t() + return fact * m.matmul(mt).squeeze() + + +def _compute_statistics(samples: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: r"""Calculates the statistics used by FID Args: samples: Low-dimension representation of image set. @@ -63,12 +132,12 @@ def _compute_statistics(samples: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: mu: mean over all activations from the encoder. sigma: covariance matrix over all activations from the encoder. """ - mu = np.mean(samples, axis=0) - sigma = np.cov(samples, rowvar=False) + mu = torch.mean(samples, dim=0) + sigma = _cov(samples, rowvar=False) return mu, sigma -def compute_fid(x: torch.Tensor, y: torch.Tensor) -> float: +def compute_fid(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: r"""Numpy implementation of the Frechet Distance. Fits multivariate Gaussians: X ~ N(mu_1, sigm_1) and Y ~ N(mu_2, sigm_2) to image stacks. Then computes FID as d^2 = ||mu_1 - mu_2||^2 + Tr(sigm_1 + sigm_2 - 2*sqrt(sigm_1*sigm_2)). @@ -80,8 +149,8 @@ def compute_fid(x: torch.Tensor, y: torch.Tensor) -> float: Returns: -- : The Frechet Distance. 
""" - m_pred, s_pred = _compute_statistics(x.numpy()) - m_targ, s_targ = _compute_statistics(y.numpy()) + m_pred, s_pred = _compute_statistics(x) + m_targ, s_targ = _compute_statistics(y) score = __compute_fid(m_pred, s_pred, m_targ, s_targ) return score @@ -91,11 +160,12 @@ class FID(BaseFeatureMetric): r"""Creates a criterion that measures Frechet Inception Distance score for two datasets of images See https://arxiv.org/abs/1706.08500 for reference. """ + def __init__(self): super(FID, self).__init__() self.compute = compute_fid - def forward(self, predicted_features: torch.Tensor, target_features: torch.Tensor) -> float: + def forward(self, predicted_features: torch.Tensor, target_features: torch.Tensor) -> torch.Tensor: r"""Interface of Frechet Inception Distance. It's computed for a whole set of data and uses features from encoder instead of images itself to decrease computation cost. FID can compare two data distributions with different number of samples.