Store npz file #80

Merged: 2 commits, Jan 5, 2023
25 changes: 17 additions & 8 deletions README.md
@@ -2,12 +2,12 @@

# FID score for PyTorch

This is a port of the official implementation of [Fréchet Inception Distance](https://arxiv.org/abs/1706.08500) to PyTorch.
See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR) for the original implementation using Tensorflow.

FID is a measure of similarity between two datasets of images.
It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks.
FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network.
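
For two Gaussians with means `mu1`, `mu2` and covariances `sigma1`, `sigma2`, the Fréchet distance is `||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))`. The following is a minimal NumPy sketch of that formula, not the code used in this repository. The function name `frechet_distance` is chosen here for illustration, and the trace of the matrix square root is obtained from the eigenvalues of `sigma1 @ sigma2`, which assumes both covariances are positive semi-definite.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Illustrative sketch: assumes sigma1 and sigma2 are symmetric
    positive semi-definite matrices of the same shape.
    """
    mu1 = np.atleast_1d(np.asarray(mu1, dtype=np.float64))
    mu2 = np.atleast_1d(np.asarray(mu2, dtype=np.float64))
    sigma1 = np.atleast_2d(np.asarray(sigma1, dtype=np.float64))
    sigma2 = np.atleast_2d(np.asarray(sigma2, dtype=np.float64))

    diff = mu1 - mu2

    # For PSD inputs, the eigenvalues of sigma1 @ sigma2 are real and
    # non-negative, and Tr(sqrt(sigma1 @ sigma2)) is the sum of their
    # square roots. Small negative or imaginary parts are numerical noise.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()

    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * tr_covmean)


if __name__ == "__main__":
    d = 3
    # Identical Gaussians should be at (numerically) zero distance.
    print(frechet_distance(np.zeros(d), np.eye(d), np.zeros(d), np.eye(d)))
```

With equal covariances the trace terms cancel and the distance reduces to the squared Euclidean distance between the means.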

Further insights and an independent evaluation of the FID score can be found in [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337).

@@ -36,24 +36,33 @@ To compute the FID score between two datasets, where images of each dataset are
python -m pytorch_fid path/to/dataset1 path/to/dataset2
```

To run the evaluation on GPU, use the flag `--device cuda:N`, where `N` is the index of the GPU to use.

### Using different layers for feature maps

Unlike the official implementation, you can choose to use a different feature layer of the Inception network instead of the default `pool3` layer.
As the lower layer features still have spatial extent, the features are first global average pooled to a vector before estimating mean and covariance.

This might be useful if the datasets you want to compare have fewer than the otherwise required 2048 images.
Note that this changes the magnitude of the FID score, so you cannot compare it against scores calculated with a different dimensionality.
The resulting scores might also no longer correlate with visual quality.

You can select the dimensionality of features to use with the flag `--dims N`, where `N` is the dimensionality of features.
The choices are:
- 64: first max pooling features
- 192: second max pooling features
- 768: pre-aux classifier features
- 2048: final average pooling features (this is the default)
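
The sample-size requirement mentioned above follows from a general property of covariance estimates: a covariance matrix estimated from `n` samples has rank at most `n - 1`, so with fewer samples than feature dimensions it is singular. The sketch below illustrates this with random features standing in for Inception activations; the helper name `covariance_rank` is hypothetical and not part of this repository.

```python
import numpy as np


def covariance_rank(n_samples, dims, seed=0):
    """Rank of the sample covariance of n_samples random feature vectors.

    Random Gaussian features are used as a stand-in for Inception
    activations; only the shape of the data matters for this illustration.
    """
    rng = np.random.default_rng(seed)
    feats = rng.standard_normal((n_samples, dims))
    # rowvar=False: each row is one sample, each column one feature.
    sigma = np.cov(feats, rowvar=False)
    return int(np.linalg.matrix_rank(sigma))


if __name__ == "__main__":
    # Fewer samples than dimensions: the covariance is rank deficient.
    print(covariance_rank(n_samples=10, dims=64))
    # More samples than dimensions: the covariance can reach full rank.
    print(covariance_rank(n_samples=200, dims=64))
```

Choosing a smaller `--dims` value lowers the number of images needed for a full-rank covariance estimate, at the cost of the caveats listed above.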

## Generating a compatible `.npz` archive from a dataset
A frequent use case is to compare multiple models against an original dataset.
To avoid recomputing the statistics of the original dataset for every comparison, you can generate a compatible `.npz` archive from it. To do so, combine any of the previously mentioned arguments with the `--save-stats` flag. For example:
```
python -m pytorch_fid --save-stats path/to/dataset path/to/outputfile
```

The output file may then be used in place of the path to the original dataset for further comparisons.
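
The archive written by this PR stores two arrays under the keys `mu` and `sigma` (see the `np.savez_compressed` call in `save_fid_stats`). As a rough sketch of how such a file can be inspected, the snippet below writes placeholder statistics to a temporary `.npz` file and reads them back. The helper name `load_stats` and the 4-dimensional placeholder values are assumptions for illustration only.

```python
import os
import tempfile

import numpy as np


def load_stats(path):
    """Read the mean and covariance stored under the keys 'mu' and 'sigma'."""
    with np.load(path) as archive:
        # Index with [:] to copy the arrays out before the file is closed.
        return archive["mu"][:], archive["sigma"][:]


with tempfile.TemporaryDirectory() as tmp:
    stats_path = os.path.join(tmp, "stats.npz")

    # Placeholder statistics; real ones have one entry per feature dimension.
    np.savez_compressed(stats_path, mu=np.zeros(4), sigma=np.eye(4))

    mu, sigma = load_stats(stats_path)

print(mu.shape, sigma.shape)
```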

## Citing

If you use this repository in your research, consider citing it using the following Bibtex entry:
27 changes: 27 additions & 0 deletions src/pytorch_fid/fid_score.py
@@ -63,6 +63,9 @@ def tqdm(x):
                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
                    help=('Dimensionality of Inception features to use. '
                          'By default, uses pool3 features'))
parser.add_argument('--save-stats', action='store_true',
                    help=('Generate an npz archive from a directory of samples. '
                          'The first path is used as input and the second as output.'))
parser.add_argument('path', type=str, nargs=2,
                    help=('Paths to the generated images or '
                          'to .npz statistic files'))
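
To make the interaction between the new flag and the two positional paths concrete, here is a small standalone `argparse` sketch. It reproduces only the two arguments shown above, and the example paths `images/` and `stats.npz` are made up.

```python
import argparse

# Standalone parser with just the two arguments relevant to this change.
parser = argparse.ArgumentParser()
parser.add_argument('--save-stats', action='store_true')
parser.add_argument('path', type=str, nargs=2)

# Simulate: python -m pytorch_fid --save-stats images/ stats.npz
args = parser.parse_args(['--save-stats', 'images/', 'stats.npz'])

# argparse exposes '--save-stats' as the attribute 'save_stats'.
# With the flag set, path[0] is the input and path[1] is the output.
input_path, output_path = args.path
print(args.save_stats, input_path, output_path)
```

Because `path` keeps `nargs=2`, the flag changes the meaning of the second path (output file rather than second dataset) instead of adding a new argument.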
@@ -262,6 +265,26 @@ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
    return fid_value


def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
    """Saves the FID statistics of paths[0] as an npz archive at paths[1]"""
    if not os.path.exists(paths[0]):
        raise RuntimeError('Invalid path: %s' % paths[0])

    if os.path.exists(paths[1]):
        raise RuntimeError('Existing output file: %s' % paths[1])

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx]).to(device)

    print(f"Saving statistics for {paths[0]}")

    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
                                        dims, device, num_workers)

    # Note: NumPy appends '.npz' to paths[1] if it does not already end with it
    np.savez_compressed(paths[1], mu=m1, sigma=s1)
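
One detail worth keeping in mind: when given a string filename, `np.savez_compressed` appends `.npz` if it is missing, so the file that ends up on disk can differ from the path checked by `os.path.exists` above. A possible way to check the name that will actually be written is sketched below. Both helper names are hypothetical and not part of this PR.

```python
import os


def resolve_npz_path(path):
    """Return the filename NumPy would write for a string output path."""
    return path if path.endswith('.npz') else path + '.npz'


def check_output_free(path):
    """Raise if the file that would actually be written already exists."""
    target = resolve_npz_path(path)
    if os.path.exists(target):
        raise RuntimeError('Existing output file: %s' % target)
    return target
```

Under this sketch, `check_output_free('path/to/outputfile')` tests for `path/to/outputfile.npz` rather than the extension-less name.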


def main():
    args = parser.parse_args()

@@ -276,6 +299,10 @@ def main():
    else:
        num_workers = args.num_workers

    if args.save_stats:
        save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
        return

    fid_value = calculate_fid_given_paths(args.path,
                                          args.batch_size,
                                          device,