[enhancement] adapt baseinferencer from engine. #1575

Merged · 9 commits · Jan 30, 2023
97 changes: 50 additions & 47 deletions mmedit/apis/inferencers/base_mmedit_inferencer.py
@@ -1,12 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Dict, List, Optional, Sequence, Tuple, Union
import os
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from mmengine import mkdir_or_exist
from mmengine.dataset import Compose
from mmengine.infer import BaseInferencer
from mmengine.runner import load_checkpoint
from mmengine.structures import BaseDataElement
from torchvision import utils

from mmedit.registry import MODELS
from mmedit.utils import ConfigType, SampleList
@@ -19,7 +22,7 @@
ResType = Union[Dict, List[Dict], BaseDataElement, List[BaseDataElement]]


class BaseMMEditInferencer:
class BaseMMEditInferencer(BaseInferencer):
"""Base inferencer.

Args:
@@ -47,21 +50,12 @@ def __init__(self,
seed: int = 2022,
**kwargs) -> None:
# Load config to cfg
if isinstance(config, str):
config = Config.fromfile(config)
elif not isinstance(config, (ConfigDict, Config)):
raise TypeError('config must be a filename or any ConfigType'
f'object, but got {type(config)}')
self.cfg = config

if config.model.get('pretrained'):
config.model.pretrained = None

if device is None:
device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.device = device
self._init_model(config, ckpt, device)
super().__init__(config, ckpt, device)

self._init_extra_parameters(extra_parameters)
self.base_params = self._dispatch_kwargs(**kwargs)
self.seed = seed
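
Note on the refactor above: `__init__` now delegates config loading, device selection, and model setup to `mmengine.infer.BaseInferencer` via `super().__init__(config, ckpt, device)`. The value of inheriting from `BaseInferencer` is its staged call chain (preprocess, forward, visualize, postprocess), so subclasses only override the stages that differ. A minimal toy sketch of that chain, for orientation only — `ToyInferencer` is illustrative and is not part of this PR or of mmengine:

```python
from typing import Any, List


class ToyInferencer:
    """Illustrative stand-in for the staged pipeline that
    mmengine's BaseInferencer drives via __call__."""

    def preprocess(self, inputs: List) -> List:
        # Turn raw inputs (paths, arrays, ...) into model-ready data.
        return inputs

    def forward(self, data: Any) -> Any:
        # Run the model; this toy version just passes data through.
        return data

    def visualize(self, inputs: List, preds: Any) -> List:
        # Optionally render predictions; may return an empty list.
        return []

    def postprocess(self, preds: Any, visualization: List) -> Any:
        # Pack predictions and visualizations into the final result.
        return {'predictions': preds, 'visualization': visualization}

    def __call__(self, inputs: List) -> Any:
        data = self.preprocess(inputs)
        preds = self.forward(data)
        vis = self.visualize(inputs, preds)
        return self.postprocess(preds, vis)
```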
@@ -77,7 +71,16 @@ def _init_model(self, cfg: Union[ConfigType, str], ckpt: Optional[str],
model.cfg = cfg
model.to(device)
model.eval()
self.model = model
return model

def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline."""
if 'test_dataloader' in cfg and \
'dataset' in cfg.test_dataloader and \
'pipeline' in cfg.test_dataloader.dataset:
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
return Compose(pipeline_cfg)
return None

def _init_extra_parameters(self, extra_parameters: Dict) -> None:
"""Initialize extra_parameters of each kind of inferencer."""
@@ -149,37 +152,6 @@ def get_extra_parameters(self) -> List[str]:
"""
return list(self.extra_parameters.keys())

@abstractmethod
def preprocess(self, inputs: InputsType) -> Dict:
"""Process the inputs into a model-feedable format.

Args:
inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.

Returns:
Dict: Result of preprocess
"""

@abstractmethod
def forward(self, inputs: InputsType) -> PredType:
"""Forward the inputs to the model."""

@abstractmethod
def visualize(self,
inputs: InputsType,
preds: PredType,
result_out_dir: str = None) -> List[np.ndarray]:
"""Visualize predictions.

Args:
inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
preds (List[Dict]): Predictions of the model.
result_out_dir (str): Output directory of images. Defaults to ''.

Returns:
List[np.ndarray]: Result of visualize
"""

def postprocess(
self,
preds: PredType,
@@ -226,3 +198,34 @@ def _pred2dict(self, pred_tensor: torch.Tensor) -> Dict:
result = {}
result['infer_results'] = pred_tensor
return result

def visualize(self,
inputs: list,
preds: Any,
show: bool = False,
result_out_dir: str = '',
**kwargs) -> List[np.ndarray]:
"""Visualize predictions.

Customize your visualization by overriding this method. visualize
should return visualization results, which could be np.ndarray or any
other objects.

Args:
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
preds (Any): Predictions of the model.
show (bool): Whether to display the image in a popup window.
Defaults to False.
result_out_dir (str): Output directory of images. Defaults to ''.

Returns:
List[np.ndarray]: Visualization results.
"""
results = (preds[:, [2, 1, 0]] + 1.) / 2.

# save images
if result_out_dir:
mkdir_or_exist(os.path.dirname(result_out_dir))
utils.save_image(results, result_out_dir)

return results
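
The tensor arithmetic in the new `visualize` assumes model outputs are BGR images normalized to [-1, 1]: the fancy index `[:, [2, 1, 0]]` flips channels to RGB, the affine map rescales to [0, 1], and `torchvision.utils.save_image` tiles the batch into a single grid image. A standalone sketch of the same transform on random data:

```python
import torch
from torchvision import utils

preds = torch.rand(4, 3, 64, 64) * 2 - 1   # fake BGR batch in [-1, 1]
results = (preds[:, [2, 1, 0]] + 1.) / 2.  # BGR -> RGB, rescale to [0, 1]
utils.save_image(results, 'grid.png')      # batch saved as one image grid
```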
4 changes: 4 additions & 0 deletions mmedit/apis/inferencers/inpainting_inferencer.py
@@ -23,6 +23,10 @@ class InpaintingInferencer(BaseMMEditInferencer):
visualize=['result_out_dir'],
postprocess=[])

def _init_pipeline(self, cfg) -> Compose:
"""Initialize the test pipeline."""
return None

def preprocess(self, img: InputsType, mask: InputsType) -> Dict:
"""Process the inputs into a model-feedable format.

3 changes: 1 addition & 2 deletions mmedit/apis/inferencers/mmedit_inferencer.py
@@ -4,7 +4,6 @@
import torch

from mmedit.utils import ConfigType
from .base_mmedit_inferencer import BaseMMEditInferencer
from .conditional_inferencer import ConditionalInferencer
from .eg3d_inferencer import EG3DInferencer
from .inpainting_inferencer import InpaintingInferencer
@@ -17,7 +16,7 @@
from .video_restoration_inferencer import VideoRestorationInferencer


class MMEditInferencer(BaseMMEditInferencer):
class MMEditInferencer:
"""Class to assign task to different inferencers.

Args:
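
With the `BaseMMEditInferencer` parent dropped, `MMEditInferencer` becomes a plain dispatcher that routes each task to its dedicated inferencer rather than acting as an inferencer itself. A hedged sketch of the dispatch idea — the task names, mapping, and helper below are assumptions for illustration, not code from this diff:

```python
from mmedit.apis.inferencers.conditional_inferencer import ConditionalInferencer
from mmedit.apis.inferencers.inpainting_inferencer import InpaintingInferencer

# Hypothetical task mapping; the real one lives inside MMEditInferencer.
TASK_TO_INFERENCER = {
    'inpainting': InpaintingInferencer,
    'conditional': ConditionalInferencer,
}


def build_inferencer(task: str, config, ckpt, device=None):
    """Dispatch to a task-specific inferencer (illustrative helper)."""
    try:
        cls = TASK_TO_INFERENCER[task]
    except KeyError:
        raise ValueError(f'unsupported task: {task!r}')
    # Constructor arguments mirror BaseMMEditInferencer.__init__ as
    # shown in the diff above (config, ckpt, device).
    return cls(config, ckpt, device)
```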