Skip to content

Commit

Permalink
basic sample wise metric used in GenLoop
Browse files Browse the repository at this point in the history
  • Loading branch information
VongolaWu authored and LeoXing1996 committed Jan 13, 2023
1 parent 6de38d0 commit b2fc76d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
26 changes: 26 additions & 0 deletions mmedit/evaluation/metrics/base_sample_wise_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from typing import List, Optional, Sequence

import torch.nn as nn
from mmengine.evaluator import BaseMetric
from torch.utils.data.dataloader import DataLoader

from mmedit.registry import METRICS
from .metrics_utils import average, obtain_data
Expand Down Expand Up @@ -109,3 +111,27 @@ def process(self, data_batch: Sequence[dict],

def process_image(self, gt, pred, mask):
return 0

def evaluate(self, size=None) -> dict:
if size is None:
size = self.size
return super().evaluate(size)

def prepare(self, module: nn.Module, dataloader: DataLoader):
self.SAMPLER_MODE = 'normal'
self.sample_model = 'orig'
self.size = dataloader.dataset.__len__()

def get_metric_sampler(self, model: nn.Module, dataloader: DataLoader,
metrics) -> DataLoader:
"""Get sampler for normal metrics. Directly returns the dataloader.
Args:
model (nn.Module): Model to evaluate.
dataloader (DataLoader): Dataloader for real images.
metrics (List['GenMetric']): Metrics with the same sample mode.
Returns:
DataLoader: Default sampler for normal metrics.
"""
return dataloader
24 changes: 24 additions & 0 deletions tests/test_evaluation/test_metrics/test_base_sample_wise_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader

from mmedit.datasets import BasicImageDataset
from mmedit.evaluation.metrics import base_sample_wise_metric


Expand Down Expand Up @@ -52,3 +54,25 @@ def test_process():
metric.process(data_batch, predictions)
assert len(metric.results) == 2
assert metric.results[0]['metric'] == 0


def test_prepare():
data_root = 'tests/data/image/'
dataset = BasicImageDataset(
metainfo=dict(dataset_type='sr_annotation_dataset', task_name='sisr'),
data_root=data_root,
data_prefix=dict(img='lq', gt='gt'),
filename_tmpl=dict(img='{}_x4'),
pipeline=[])
dataloader = DataLoader(dataset)

metric = base_sample_wise_metric.BaseSampleWiseMetric()
metric.metric = 'metric'

metric.prepare(None, dataloader)
assert metric.SAMPLER_MODE == 'normal'
assert metric.sample_model == 'orig'
assert metric.size == 1

metric.get_metric_sampler(None, dataloader, [])
assert dataloader == dataloader

0 comments on commit b2fc76d

Please sign in to comment.