Skip to content

Commit

Permalink
Use file_client in CompositeFg (#212)
Browse files Browse the repository at this point in the history
Signed-off-by: lizz <[email protected]>
  • Loading branch information
innerlee authored Feb 26, 2021
1 parent 0abd0fe commit 48e39a9
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions mmedit/datasets/pipelines/matting_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import cv2
import mmcv
import numpy as np
from mmcv.fileio import FileClient

from ..registry import PIPELINES
from .utils import adjust_gamma, random_choose_unknown
Expand Down Expand Up @@ -218,14 +219,22 @@ class CompositeFg:
the randomly loaded images.
"""

def __init__(self, fg_dirs, alpha_dirs, interpolation='nearest'):
def __init__(self,
fg_dirs,
alpha_dirs,
interpolation='nearest',
io_backend='disk',
**kwargs):
self.fg_dirs = fg_dirs if isinstance(fg_dirs, list) else [fg_dirs]
self.alpha_dirs = alpha_dirs if isinstance(alpha_dirs,
list) else [alpha_dirs]
self.interpolation = interpolation

self.fg_list, self.alpha_list = self._get_file_list(
self.fg_dirs, self.alpha_dirs)
self.io_backend = io_backend
self.file_client = None
self.kwargs = kwargs

def __call__(self, results):
"""Call function.
Expand All @@ -237,15 +246,19 @@ def __call__(self, results):
Returns:
dict: A dict containing the processed data and information.
"""
if self.file_client is None:
self.file_client = FileClient(self.io_backend, **self.kwargs)
fg = results['fg']
alpha = results['alpha'].astype(np.float32) / 255.
h, w = results['fg'].shape[:2]

# randomly select fg
if np.random.rand() < 0.5:
idx = np.random.randint(len(self.fg_list))
fg2 = mmcv.imread(self.fg_list[idx])
alpha2 = mmcv.imread(self.alpha_list[idx], 'grayscale')
fg2_bytes = self.file_client.get(self.fg_list[idx])
fg2 = mmcv.imfrombytes(fg2_bytes)
alpha2_bytes = self.file_client.get(self.alpha_list[idx])
alpha2 = mmcv.imfrombytes(alpha2_bytes, flag='grayscale')
alpha2 = alpha2.astype(np.float32) / 255.

fg2 = mmcv.imresize(fg2, (w, h), interpolation=self.interpolation)
Expand Down

0 comments on commit 48e39a9

Please sign in to comment.