Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support dataset initialization with file_client #1402

Merged
merged 3 commits into from
Mar 28, 2022
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions mmseg/datasets/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class CustomDataset(Dataset):
Default: None
gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
load gt for evaluation, load from disk by default. Default: None.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
"""

CLASSES = None
Expand All @@ -87,7 +90,8 @@ def __init__(self,
reduce_zero_label=False,
classes=None,
palette=None,
gt_seg_map_loader_cfg=None):
gt_seg_map_loader_cfg=None,
file_client_args=dict(backend='disk')):
self.pipeline = Compose(pipeline)
self.img_dir = img_dir
self.img_suffix = img_suffix
Expand All @@ -105,6 +109,9 @@ def __init__(self,
) if gt_seg_map_loader_cfg is None else LoadAnnotations(
**gt_seg_map_loader_cfg)

self.file_client_args = file_client_args
self.file_client = mmcv.FileClient.infer_client(self.file_client_args)

if test_mode:
assert self.CLASSES is not None, \
'`cls.CLASSES` or `classes` should be specified when testing'
Expand Down Expand Up @@ -146,16 +153,21 @@ def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,

img_infos = []
if split is not None:
with open(split) as f:
for line in f:
img_name = line.strip()
img_info = dict(filename=img_name + img_suffix)
if ann_dir is not None:
seg_map = img_name + seg_map_suffix
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
lines = mmcv.list_from_file(
split, file_client_args=self.file_client_args)
for line in lines:
img_name = line.strip()
img_info = dict(filename=img_name + img_suffix)
if ann_dir is not None:
seg_map = img_name + seg_map_suffix
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
else:
for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
for img in self.file_client.list_dir_or_file(
dir_path=img_dir,
list_dir=False,
suffix=img_suffix,
recursive=True):
img_info = dict(filename=img)
if ann_dir is not None:
seg_map = img.replace(img_suffix, seg_map_suffix)
Expand Down