Skip to content

Commit

Permalink
[Enhancement] Mock LPIPS module in PPL metric (#1490)
Browse files Browse the repository at this point in the history
* mock LPIPS module in PPL metric

* mock cpu test
  • Loading branch information
LeoXing1996 authored Dec 5, 2022
1 parent fe7ac22 commit b1ed317
Showing 1 changed file with 14 additions and 25 deletions.
39 changes: 14 additions & 25 deletions tests/test_evaluation/test_metrics/test_ppl.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn
from mmengine.runner import Runner

from mmedit.datasets import BasicImageDataset
from mmedit.datasets.transforms import PackEditInputs
from mmedit.evaluation import PerceptualPathLength
from mmedit.models import LSGAN, GenDataPreprocessor
from mmedit.models.editors.stylegan2 import StyleGAN2Generator
from mmedit.utils import register_all_modules

register_all_modules()


def process_fn(data_batch, predictions):
Expand All @@ -22,29 +24,17 @@ def process_fn(data_batch, predictions):
return data_batch, _predictions


class vgg_pytorch_classifier(nn.Module):

def __init__(self):
super().__init__()

def forward(self, x):
return torch.randn(x.shape[0], 4096)
class LPIPS_mock:

def __init__(self, *args, **kwargs):
pass

class vgg_mock(nn.Module):
def to(self, *args, **kwargs):
return self

def __init__(self, style):
super().__init__()
self.classifier = nn.Sequential(nn.Identity(), nn.Identity(),
nn.Identity(),
vgg_pytorch_classifier())
self.style = style

def forward(self, x, *args, **kwargs):
if self.style.upper() == 'STYLEGAN':
return torch.randn(x.shape[0], 4096)
else: # torch
return torch.randn(x.shape[0], 7 * 7 * 512)
def __call__(self, x1, x2, *args, **kwargs):
num_batche = x1.shape[0]
return torch.rand(num_batche, 1, 1, 1)


class TestPPL:
Expand All @@ -69,9 +59,7 @@ def setup_class(cls):
generator = StyleGAN2Generator(64, 8)
cls.module = LSGAN(generator, data_preprocessor=gan_data_preprocessor)

cls.mock_vgg_pytorch = MagicMock(
return_value=(vgg_mock('PyTorch'), 'False'))

@patch('lpips.LPIPS', LPIPS_mock)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_ppl_cuda(self):
ppl = PerceptualPathLength(
Expand Down Expand Up @@ -104,6 +92,7 @@ def test_ppl_cuda(self):
ppl_res = ppl.compute_metrics(ppl.fake_results)
assert ppl_res['ppl_score'] >= 0

@patch('lpips.LPIPS', LPIPS_mock)
@pytest.mark.skipif(
'win' in platform.system().lower() and 'cu' in torch.__version__,
reason='skip on windows-cuda due to limited RAM.')
Expand Down

0 comments on commit b1ed317

Please sign in to comment.