diff --git a/mmedit/datasets/pipelines/matting_aug.py b/mmedit/datasets/pipelines/matting_aug.py index eb8f29b79f..8d83aaeaf9 100644 --- a/mmedit/datasets/pipelines/matting_aug.py +++ b/mmedit/datasets/pipelines/matting_aug.py @@ -4,6 +4,7 @@ import cv2 import mmcv import numpy as np +from mmcv.fileio import FileClient from ..registry import PIPELINES from .utils import adjust_gamma, random_choose_unknown @@ -218,7 +219,12 @@ class CompositeFg: the randomly loaded images. """ - def __init__(self, fg_dirs, alpha_dirs, interpolation='nearest'): + def __init__(self, + fg_dirs, + alpha_dirs, + interpolation='nearest', + io_backend='disk', + **kwargs): self.fg_dirs = fg_dirs if isinstance(fg_dirs, list) else [fg_dirs] self.alpha_dirs = alpha_dirs if isinstance(alpha_dirs, list) else [alpha_dirs] @@ -226,6 +232,9 @@ def __init__(self, fg_dirs, alpha_dirs, interpolation='nearest'): self.fg_list, self.alpha_list = self._get_file_list( self.fg_dirs, self.alpha_dirs) + self.io_backend = io_backend + self.file_client = None + self.kwargs = kwargs def __call__(self, results): """Call function. @@ -237,6 +246,8 @@ def __call__(self, results): Returns: dict: A dict containing the processed data and information. """ + if self.file_client is None: + self.file_client = FileClient(self.io_backend, **self.kwargs) fg = results['fg'] alpha = results['alpha'].astype(np.float32) / 255. h, w = results['fg'].shape[:2] @@ -244,8 +255,10 @@ def __call__(self, results): # randomly select fg if np.random.rand() < 0.5: idx = np.random.randint(len(self.fg_list)) - fg2 = mmcv.imread(self.fg_list[idx]) - alpha2 = mmcv.imread(self.alpha_list[idx], 'grayscale') + fg2_bytes = self.file_client.get(self.fg_list[idx]) + fg2 = mmcv.imfrombytes(fg2_bytes) + alpha2_bytes = self.file_client.get(self.alpha_list[idx]) + alpha2 = mmcv.imfrombytes(alpha2_bytes, flag='grayscale') alpha2 = alpha2.astype(np.float32) / 255. fg2 = mmcv.imresize(fg2, (w, h), interpolation=self.interpolation)