From 0f35bd1b2e61952ccf2330363c6ff91819d0fd5b Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Wed, 30 Nov 2022 19:22:35 +0800 Subject: [PATCH 1/2] mock LPIPS module in PPL metric --- .../test_evaluation/test_metrics/test_ppl.py | 39 +++++++------------ 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/tests/test_evaluation/test_metrics/test_ppl.py b/tests/test_evaluation/test_metrics/test_ppl.py index fb82e40606..50834459e8 100644 --- a/tests/test_evaluation/test_metrics/test_ppl.py +++ b/tests/test_evaluation/test_metrics/test_ppl.py @@ -1,10 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import platform -from unittest.mock import MagicMock +from unittest.mock import patch import pytest import torch -import torch.nn as nn from mmengine.runner import Runner from mmedit.datasets import BasicImageDataset @@ -12,6 +11,9 @@ from mmedit.evaluation import PerceptualPathLength from mmedit.models import LSGAN, GenDataPreprocessor from mmedit.models.editors.stylegan2 import StyleGAN2Generator +from mmedit.utils import register_all_modules + +register_all_modules() def process_fn(data_batch, predictions): @@ -22,29 +24,17 @@ def process_fn(data_batch, predictions): return data_batch, _predictions -class vgg_pytorch_classifier(nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.randn(x.shape[0], 4096) +class LPIPS_mock: + def __init__(self, *args, **kwargs): + pass -class vgg_mock(nn.Module): + def to(self, *args, **kwargs): + return self - def __init__(self, style): - super().__init__() - self.classifier = nn.Sequential(nn.Identity(), nn.Identity(), - nn.Identity(), - vgg_pytorch_classifier()) - self.style = style - - def forward(self, x, *args, **kwargs): - if self.style.upper() == 'STYLEGAN': - return torch.randn(x.shape[0], 4096) - else: # torch - return torch.randn(x.shape[0], 7 * 7 * 512) + def __call__(self, x1, x2, *args, **kwargs): + num_batche = x1.shape[0] + return torch.rand(num_batche, 1, 1, 1) class TestPPL: @@ -69,9 +59,7 @@ def setup_class(cls): generator = StyleGAN2Generator(64, 8) cls.module = LSGAN(generator, data_preprocessor=gan_data_preprocessor) - cls.mock_vgg_pytorch = MagicMock( - return_value=(vgg_mock('PyTorch'), 'False')) - + @patch('lpips.LPIPS', LPIPS_mock) @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') def test_ppl_cuda(self): ppl = PerceptualPathLength( @@ -104,6 +92,7 @@ def test_ppl_cuda(self): ppl_res = ppl.compute_metrics(ppl.fake_results) assert ppl_res['ppl_score'] >= 0 + # @patch('lpips.LPIPS', LPIPS_mock) @pytest.mark.skipif( 'win' in platform.system().lower() and 'cu' in torch.__version__, reason='skip on windows-cuda due to limited RAM.') From 9665daed9f6861edd18b5f16febe029188312b17 Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Wed, 30 Nov 2022 19:26:25 +0800 Subject: [PATCH 2/2] mock cpu test --- tests/test_evaluation/test_metrics/test_ppl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_evaluation/test_metrics/test_ppl.py b/tests/test_evaluation/test_metrics/test_ppl.py index 50834459e8..0025526940 100644 --- a/tests/test_evaluation/test_metrics/test_ppl.py +++ b/tests/test_evaluation/test_metrics/test_ppl.py @@ -92,7 +92,7 @@ def test_ppl_cuda(self): ppl_res = ppl.compute_metrics(ppl.fake_results) assert ppl_res['ppl_score'] >= 0 - # @patch('lpips.LPIPS', LPIPS_mock) + @patch('lpips.LPIPS', LPIPS_mock) @pytest.mark.skipif( 'win' in platform.system().lower() and 'cu' in torch.__version__, reason='skip on windows-cuda due to limited RAM.')