Skip to content

Commit

Permalink
Fix tests not passing after #66
Browse files Browse the repository at this point in the history
  • Loading branch information
mseitzer committed Oct 10, 2021
1 parent b5c8527 commit 41cb138
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/test_fid_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_calculate_fid_given_statistics(mocker, tmp_path, device):
m1, m2 = np.zeros((dim,)), np.ones((dim,))
sigma = np.eye(dim)

def dummy_statistics(path, model, batch_size, dims, device):
def dummy_statistics(path, model, batch_size, dims, device, num_workers):
if path.endswith('1'):
return m1, sigma
elif path.endswith('2'):
Expand All @@ -37,7 +37,8 @@ def dummy_statistics(path, model, batch_size, dims, device):
fid_value = fid_score.calculate_fid_given_paths(paths,
batch_size=dim,
device=device,
dims=dim)
dims=dim,
num_workers=0)

# Given equal covariance, FID is just the squared norm of difference
assert fid_value == np.sum((m1 - m2)**2)
Expand All @@ -59,7 +60,8 @@ def test_compute_statistics_of_path(mocker, tmp_path, device):
stats = fid_score.compute_statistics_of_path(str(tmp_path), model,
batch_size=len(images),
dims=3,
device=device)
device=device,
num_workers=0)

assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3)
assert np.allclose(stats[1], np.ones((3, 3)) * 0.25)
Expand All @@ -78,7 +80,8 @@ def test_compute_statistics_of_path_from_file(mocker, tmp_path, device):
stats = fid_score.compute_statistics_of_path(str(path), model,
batch_size=1,
dims=5,
device=device)
device=device,
num_workers=0)

assert np.allclose(stats[0], mu)
assert np.allclose(stats[1], sigma)
Expand Down

0 comments on commit 41cb138

Please sign in to comment.