[Enhance] Add MultiViewPipeline #748

Open
wants to merge 4 commits into base: 1.0
6 changes: 3 additions & 3 deletions mmdet3d/datasets/pipelines/__init__.py
@@ -3,8 +3,8 @@
from .formating import Collect3D, DefaultFormatBundle, DefaultFormatBundle3D
from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
LoadMultiViewImageFromFiles, LoadPointsFromFile,
- LoadPointsFromMultiSweeps, NormalizePointsColor,
- PointSegClassMapping)
+ LoadPointsFromMultiSweeps, MultiViewPipeline,
+ NormalizePointsColor, PointSegClassMapping)
from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample,
@@ -23,5 +23,5 @@
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment',
'IndoorPatchPointSample', 'LoadImageFromFileMono3D', 'ObjectNameFilter',
- 'RandomDropPointsColor', 'RandomJitterPoints'
+ 'RandomDropPointsColor', 'RandomJitterPoints', 'MultiViewPipeline'
]
77 changes: 76 additions & 1 deletion mmdet3d/datasets/pipelines/loading.py
@@ -1,9 +1,11 @@
import mmcv
import numpy as np
from collections import defaultdict

from mmdet3d.core.points import BasePoints, get_points_type
from mmdet.datasets.builder import PIPELINES
- from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile
+ from mmdet.datasets.pipelines import (Compose, LoadAnnotations,
+                                       LoadImageFromFile)


@PIPELINES.register_module()
@@ -666,3 +668,76 @@
repr_str += f'{indent_str}with_bbox_depth={self.with_bbox_depth}, '
repr_str += f'{indent_str}poly2mask={self.poly2mask})'
return repr_str


@PIPELINES.register_module()
class MultiViewPipeline(object):
"""Load and transform multi-view images.
[Review thread on the docstring above]

Tai-Wang (Member), Jul 21, 2021:
If we put Load into transforms, a comment is needed here or somewhere else, because it seems a little inconsistent with the name transforms itself. Or do you think it would be better to move the load step out of the transforms, like MultiScaleFlipAug3D?

Contributor (author):
Are you suggesting splitting my MultiViewPipeline into two separate classes, MultiViewPipeline and LoadMultiViewImageFromFilesV2? I am not sure that adding two new transforms is better than one, but if you think it is, I can do it.

Contributor (author):
Alternatively, we could move to_float32, color_type and file_client_args from LoadImageFromFile.__init__ to MultiViewPipeline.__init__ and remove LoadImageFromFile from MultiViewPipeline.transforms. However, that is not very elegant either. I prefer to add a comment and an assert about the necessity of LoadImageFromFile in MultiViewPipeline.transforms.
[A rough sketch of this alternative is included after the class definition below.]

Tai-Wang (Member):
@ZwwWayne Please have a look at this function and see whether there is a conflict with your multi-modality models on nuScenes.

[docstring continues]

Args:
    transforms (list[dict]): Transforms to apply to each image. Unlike the
        transforms list of MultiScaleFlipAug, LoadImageFromFile is required
        here as the first transform; the remaining transforms can be the
        usual ones such as Normalize, Pad and Resize.
    n_images (int): Number of images to sample. Defaults to -1
        (use all images).
    pose_keys (list[str]): Keys whose per-view values are sampled together
        with the images. Defaults to ('lidar2img', 'depth2img', 'cam2img').
"""

def __init__(self,
transforms,
n_images=-1,
pose_keys=('lidar2img', 'depth2img', 'cam2img')):
self.transforms = Compose(transforms)
assert isinstance(self.transforms.transforms[0], LoadImageFromFile)
self.n_images = n_images
self.pose_keys = pose_keys

def __call__(self, results):
"""Call function to load multi-view image from files.

Args:
results (dict): Result dict containing multi-view image filenames.

Returns:
dict: The result dict containing the multi-view image data.
The added keys are determined by the transforms in
self.transforms.
"""
assert len(results['img_info'])
pose_keys = [k for k in results if k in self.pose_keys]
for key in pose_keys:
assert len(results[key]) == len(results['img_info'])
img_pose_dict = defaultdict(list)
ids = np.arange(len(results['img_info']))

# sample self.n_images from all images
if self.n_images > 0:
replace = True if self.n_images > len(ids) else False
ids = np.random.choice(ids, self.n_images, replace=replace)

# apply self.transforms to sampled images
for i in ids.tolist():
img_results = dict(
img_prefix=results['img_prefix'],
img_info=results['img_info'][i])
img_results = self.transforms(img_results)
img_pose_dict['img'].append(img_results['img'])
img_pose_dict['img_info'].append(img_results['img_info'])
for key in pose_keys:
img_pose_dict[key].append(results[key][i])

# copy meta keys from the last processed image, then overwrite per-view entries
for key in img_results.keys():
results[key] = img_results[key]
for key in img_pose_dict:
results[key] = img_pose_dict[key]
return results

def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(transforms={self.transforms}, '
repr_str += f'n_images={self.n_images})'
return repr_str
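For context on the review thread above, here is a rough, illustrative sketch of the discussed alternative, in which the wrapper owns the image-loading options and LoadImageFromFile no longer appears in `transforms`. The class name MultiViewPipelineV2 and this exact layout are hypothetical and not part of this PR; the loader arguments simply mirror LoadImageFromFile's existing to_float32, color_type and file_client_args.

```python
from collections import defaultdict

import numpy as np
from mmdet.datasets.pipelines import Compose, LoadImageFromFile


class MultiViewPipelineV2(object):
    """Hypothetical variant: image loading is handled by the wrapper itself."""

    def __init__(self,
                 transforms,
                 n_images=-1,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk'),
                 pose_keys=('lidar2img', 'depth2img', 'cam2img')):
        # `transforms` now only holds Normalize, Pad, Resize, etc.
        self.loader = LoadImageFromFile(
            to_float32=to_float32,
            color_type=color_type,
            file_client_args=file_client_args)
        self.transforms = Compose(transforms)
        self.n_images = n_images
        self.pose_keys = pose_keys

    def __call__(self, results):
        pose_keys = [k for k in results if k in self.pose_keys]
        img_pose_dict = defaultdict(list)
        ids = np.arange(len(results['img_info']))
        # sample self.n_images views, with replacement if fewer are available
        if self.n_images > 0:
            ids = np.random.choice(
                ids, self.n_images, replace=self.n_images > len(ids))
        for i in ids.tolist():
            img_results = dict(
                img_prefix=results['img_prefix'],
                img_info=results['img_info'][i])
            # load the image first, then apply the remaining transforms
            img_results = self.transforms(self.loader(img_results))
            img_pose_dict['img'].append(img_results['img'])
            img_pose_dict['img_info'].append(img_results['img_info'])
            for key in pose_keys:
                img_pose_dict[key].append(results[key][i])
        # meta keys come from the last processed image; per-view lists overwrite them
        results.update(img_results)
        results.update(img_pose_dict)
        return results
```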

Check warning on line 743 in mmdet3d/datasets/pipelines/loading.py — Codecov / codecov/patch (mmdet3d/datasets/pipelines/loading.py#L740-L743): added lines #L740 - L743 (the new __repr__ method above) were not covered by tests.
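On the Codecov warning above: the new __repr__ is easy to cover without any image data. A minimal sketch of such a check (the test name is illustrative, not part of this PR):

```python
from mmdet3d.datasets.pipelines import MultiViewPipeline


def test_multi_view_pipeline_repr():
    # Constructing the transform is enough to exercise __repr__;
    # no image files are touched.
    pipeline = MultiViewPipeline(
        transforms=[dict(type='LoadImageFromFile')], n_images=2)
    assert repr(pipeline).startswith('MultiViewPipeline(')
    assert 'n_images=2' in repr(pipeline)
```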
Binary file modified tests/data/scannet/scannet_infos.pkl
Binary file not shown.
46 changes: 46 additions & 0 deletions tests/test_data/test_datasets/test_scannet_dataset.py
@@ -680,3 +680,49 @@ def test_seg_format_results():
expected_txt_path = osp.join(tmp_dir.name, 'results', 'scene0000_00.txt')
assert np.all(result_files[0]['seg_mask'] == expected_label)
mmcv.check_file_exist(expected_txt_path)


def test_multiview_getitem():
np.random.seed(0)
root_path = './tests/data/scannet/'
ann_file = './tests/data/scannet/scannet_infos.pkl'
class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
'window', 'bookshelf', 'picture', 'counter', 'desk',
'curtain', 'refrigerator', 'showercurtrain', 'toilet',
'sink', 'bathtub', 'garbagebin')
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
pipelines = [
dict(
type='LoadAnnotations3D',
with_bbox_3d=True,
with_label_3d=True,
with_mask_3d=True,
with_seg_3d=True),
dict(
type='MultiViewPipeline',
n_images=2,
transforms=[
dict(type='LoadImageFromFile'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32)
]),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d']),
]

scannet_dataset = ScanNetDataset(
root_path,
ann_file,
pipelines,
modality=dict(use_camera=True, use_depth=False))
data = scannet_dataset[0]
assert data['img']._data.shape == torch.Size([2, 3, 992, 1312])
assert data['img_metas']._data['ori_shape'] == (968, 1296, 3)
assert data['img_metas']._data['img_shape'] == (968, 1296, 3)
assert data['img_metas']._data['pad_shape'] == (992, 1312, 3)
assert len(data['img_metas']._data['depth2img']) == 2
for depth2img in data['img_metas']._data['depth2img']:
assert depth2img.shape == (4, 4)
@@ -1,9 +1,11 @@
import numpy as np
import os
import torch
from mmcv.parallel import DataContainer

from mmdet3d.datasets.pipelines import (DefaultFormatBundle,
- LoadMultiViewImageFromFiles)
+ LoadMultiViewImageFromFiles,
+ MultiViewPipeline)


def test_load_multi_view_image_from_files():
@@ -43,3 +45,31 @@ def test_load_multi_view_image_from_files():

assert isinstance(img, DataContainer)
assert img._data.shape == torch.Size((num_views, 3, 1280, 1920))


def test_multi_view_pipeline():
file_names = ['00000.jpg', '00011.jpg', '00102.jpg']
input_dict = dict(
img_prefix=None,
img_info=[
dict(
filename=os.path.join(
'tests/data/scannet/posed_images/scene0000_00', file_name))
for file_name in file_names
],
depth2img=[np.eye(4), np.eye(4) + 0.1,
np.eye(4) - 0.1])

pipeline = MultiViewPipeline(
transforms=[dict(type='LoadImageFromFile')], n_images=2)
results = pipeline(input_dict)

assert len(results['img']) == 2
assert len(results['depth2img']) == 2
shape = (968, 1296, 3)
for img in results['img']:
assert img.shape == shape
file_names = set(img_info['filename'] for img_info in results['img_info'])
assert len(file_names) == 2
depth2img = set(str(x) for x in results['depth2img'])
assert len(depth2img) == 2