From da4c3af9db571eee5b3fc7cd059eaf4764a698ca Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Tue, 20 Jul 2021 12:09:32 +0300 Subject: [PATCH] [Enhance] Support RGB images on ScanNet for multi-view detector (#696) * extract RGB images for ScanNet * update scannet_data_utils with rgb * fix docs an tools; add use_camera to ScanNetDataset * fix typos is scannet doc * fix very rare undefined poses for ScanNet * update ScanNet dataset for more clear MultiViewPipeline * update compatibility for ScanNet * fix typo in compatibility doc * use mmcv.track_parallel_progress for scannet images --- data/scannet/README.md | 10 +- data/scannet/extract_posed_images.py | 179 +++++++++++++++++++++ docs/compatibility.md | 12 ++ docs/datasets/scannet_det.md | 20 ++- mmdet3d/datasets/scannet_dataset.py | 54 ++++++- tools/data_converter/scannet_data_utils.py | 37 +++++ 6 files changed, 308 insertions(+), 4 deletions(-) create mode 100644 data/scannet/extract_posed_images.py diff --git a/data/scannet/README.md b/data/scannet/README.md index 855b9f38f6..5b072ba9ac 100644 --- a/data/scannet/README.md +++ b/data/scannet/README.md @@ -6,7 +6,9 @@ We follow the procedure in [votenet](https://github.com/facebookresearch/votenet 2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`. Add the `--max_num_point 50000` flag if you only use the ScanNet data for the detection task. It will downsample the scenes to less points. -3. Enter the project root directory, generate training data by running +3. In this directory, extract RGB image with poses by running `python extract_posed_images.py`. This step is optional. Skip it if you don't plan to use multi-view RGB images. Add `--max-images-per-scene -1` to disable limiting number of images per scene. ScanNet scenes contain up to 5000+ frames per each. After extraction, all the .jpg images require 2 Tb disk space. The recommended 300 images per scene require less then 100 Gb. For example multi-view 3d detector ImVoxelNet samples 50 and 100 images per training and test scene. + +4. Enter the project root directory, generate training data by running ```bash python tools/create_data.py scannet --root-path ./data/scannet --out-dir ./data/scannet --extra-tag scannet @@ -16,6 +18,7 @@ The overall process could be achieved through the following script ```bash python batch_load_scannet_data.py +python extract_posed_images.py cd ../.. python tools/create_data.py scannet --root-path ./data/scannet --out-dir ./data/scannet --extra-tag scannet ``` @@ -43,6 +46,11 @@ scannet │ ├── train_resampled_scene_idxs.npy │ ├── val_label_weight.npy │ ├── val_resampled_scene_idxs.npy +├── posed_images +│ ├── scenexxxx_xx +│ │ ├── xxxxxx.txt +│ │ ├── xxxxxx.jpg +│ │ ├── intrinsic.txt ├── scannet_infos_train.pkl ├── scannet_infos_val.pkl ├── scannet_infos_test.pkl diff --git a/data/scannet/extract_posed_images.py b/data/scannet/extract_posed_images.py new file mode 100644 index 0000000000..7018f32d11 --- /dev/null +++ b/data/scannet/extract_posed_images.py @@ -0,0 +1,179 @@ +# Modified from https://github.com/ScanNet/ScanNet/blob/master/SensReader/python/SensorData.py # noqa +import imageio +import mmcv +import numpy as np +import os +import struct +import zlib +from argparse import ArgumentParser +from functools import partial + +COMPRESSION_TYPE_COLOR = {-1: 'unknown', 0: 'raw', 1: 'png', 2: 'jpeg'} + +COMPRESSION_TYPE_DEPTH = { + -1: 'unknown', + 0: 'raw_ushort', + 1: 'zlib_ushort', + 2: 'occi_ushort' +} + + +class RGBDFrame: + """Class for single ScanNet RGB-D image processing.""" + + def load(self, file_handle): + self.camera_to_world = np.asarray( + struct.unpack('f' * 16, file_handle.read(16 * 4)), + dtype=np.float32).reshape(4, 4) + self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0] + self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0] + self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0] + self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0] + self.color_data = b''.join( + struct.unpack('c' * self.color_size_bytes, + file_handle.read(self.color_size_bytes))) + self.depth_data = b''.join( + struct.unpack('c' * self.depth_size_bytes, + file_handle.read(self.depth_size_bytes))) + + def decompress_depth(self, compression_type): + assert compression_type == 'zlib_ushort' + return zlib.decompress(self.depth_data) + + def decompress_color(self, compression_type): + assert compression_type == 'jpeg' + return imageio.imread(self.color_data) + + +class SensorData: + """Class for single ScanNet scene processing. + + Single scene file contains multiple RGB-D images. + """ + + def __init__(self, filename, limit): + self.version = 4 + self.load(filename, limit) + + def load(self, filename, limit): + with open(filename, 'rb') as f: + version = struct.unpack('I', f.read(4))[0] + assert self.version == version + strlen = struct.unpack('Q', f.read(8))[0] + self.sensor_name = b''.join( + struct.unpack('c' * strlen, f.read(strlen))) + self.intrinsic_color = np.asarray( + struct.unpack('f' * 16, f.read(16 * 4)), + dtype=np.float32).reshape(4, 4) + self.extrinsic_color = np.asarray( + struct.unpack('f' * 16, f.read(16 * 4)), + dtype=np.float32).reshape(4, 4) + self.intrinsic_depth = np.asarray( + struct.unpack('f' * 16, f.read(16 * 4)), + dtype=np.float32).reshape(4, 4) + self.extrinsic_depth = np.asarray( + struct.unpack('f' * 16, f.read(16 * 4)), + dtype=np.float32).reshape(4, 4) + self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack( + 'i', f.read(4))[0]] + self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack( + 'i', f.read(4))[0]] + self.color_width = struct.unpack('I', f.read(4))[0] + self.color_height = struct.unpack('I', f.read(4))[0] + self.depth_width = struct.unpack('I', f.read(4))[0] + self.depth_height = struct.unpack('I', f.read(4))[0] + self.depth_shift = struct.unpack('f', f.read(4))[0] + num_frames = struct.unpack('Q', f.read(8))[0] + self.frames = [] + if limit > 0 and limit < num_frames: + index = np.random.choice( + np.arange(num_frames), limit, replace=False).tolist() + else: + index = list(range(num_frames)) + for i in range(num_frames): + frame = RGBDFrame() + frame.load(f) + if i in index: + self.frames.append(frame) + + def export_depth_images(self, output_path): + if not os.path.exists(output_path): + os.makedirs(output_path) + for f in range(len(self.frames)): + depth_data = self.frames[f].decompress_depth( + self.depth_compression_type) + depth = np.fromstring( + depth_data, dtype=np.uint16).reshape(self.depth_height, + self.depth_width) + imageio.imwrite( + os.path.join(output_path, + self.index_to_str(f) + '.png'), depth) + + def export_color_images(self, output_path): + if not os.path.exists(output_path): + os.makedirs(output_path) + for f in range(len(self.frames)): + color = self.frames[f].decompress_color( + self.color_compression_type) + imageio.imwrite( + os.path.join(output_path, + self.index_to_str(f) + '.jpg'), color) + + @staticmethod + def index_to_str(index): + return str(index).zfill(5) + + @staticmethod + def save_mat_to_file(matrix, filename): + with open(filename, 'w') as f: + for line in matrix: + np.savetxt(f, line[np.newaxis], fmt='%f') + + def export_poses(self, output_path): + if not os.path.exists(output_path): + os.makedirs(output_path) + for f in range(len(self.frames)): + self.save_mat_to_file( + self.frames[f].camera_to_world, + os.path.join(output_path, + self.index_to_str(f) + '.txt')) + + def export_intrinsics(self, output_path): + if not os.path.exists(output_path): + os.makedirs(output_path) + self.save_mat_to_file(self.intrinsic_color, + os.path.join(output_path, 'intrinsic.txt')) + + +def process_scene(path, limit, idx): + """Process single ScanNet scene. + + Extract RGB images, poses and camera intrinsics. + """ + data = SensorData(os.path.join(path, idx, f'{idx}.sens'), limit) + output_path = os.path.join('posed_images', idx) + data.export_color_images(output_path) + data.export_intrinsics(output_path) + data.export_poses(output_path) + + +def process_directory(path, limit, nproc): + print(f'processing {path}') + mmcv.track_parallel_progress( + func=partial(process_scene, path, limit), + tasks=os.listdir(path), + nproc=nproc) + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--max-images-per-scene', type=int, default=300) + parser.add_argument('--nproc', type=int, default=8) + args = parser.parse_args() + + # process train and val scenes + if os.path.exists('scans'): + process_directory('scans', args.max_images_per_scene, args.nproc) + # process test scenes + if os.path.exists('scans_test'): + process_directory('scans_test', args.max_images_per_scene, args.nproc) diff --git a/docs/compatibility.md b/docs/compatibility.md index cca2522bcb..b29fd0b18e 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -2,6 +2,18 @@ This document provides detailed descriptions of the BC-breaking changes in MMDetection3D. +## MMDetection3D 0.16.0 + +### ScanNet dataset for ImVoxelNet + +We adopt a new pre-processing procedure for the ScanNet dataset in order to support ImVoxelNet, which is a multi-view method requiring image data. In previous versions of MMDetection3D, ScanNet dataset was only used for point cloud based 3D detection and segmentation methods. We plan adding ImVoxelNet to our model zoo, thus updating ScanNet correspondingly by adding image-related pre-processing steps. Specifically, we made these changes: + +- Add [script](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/extract_posed_images.py) for extracting RGB data. +- Update [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/data_converter/scannet_data_utils.py) for annotation creating. +- Add instructions in the documents on preparing image data. + +Please refer to the ScanNet [README.md](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/README.md/) for more details. + ## MMDetection3D 0.15.0 ### MMCV Version diff --git a/docs/datasets/scannet_det.md b/docs/datasets/scannet_det.md index e3685cd431..b43b267b8c 100644 --- a/docs/datasets/scannet_det.md +++ b/docs/datasets/scannet_det.md @@ -4,9 +4,9 @@ For the overall process, please refer to the [README](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/README.md/) page for ScanNet. -### Export ScanNet data +### Export ScanNet point cloud data -By exporting ScanNet data, we load the raw point cloud data and generate the relevant annotations including semantic label, instance label and ground truth bounding boxes. +By exporting ScanNet point cloud data, we load the raw point cloud data and generate the relevant annotations including semantic label, instance label and ground truth bounding boxes. ```shell python batch_load_scannet_data.py @@ -127,6 +127,16 @@ def export(mesh_file, After exporting each scan, the raw point cloud could be downsampled, e.g. to 50000, if the number of points is too large. In addition, invalid semantic labels outside of `nyu40id` standard or optional `DONOT CARE` classes should be filtered. Finally, the point cloud data, semantic labels, instance labels and ground truth bounding boxes should be saved in `.npy` files. +### Export ScanNet RGB data + +By exporting ScanNet RGB data, for each scene we load a set of RGB images with corresponding 4x4 pose matrices, and a single 4x4 camera intrinsic matrix. Note, that this step is optional and can be skipped if multi-view detection is not planned to use. + +```shell +python extract_posed_images.py +``` + +Each of 1201 train, 312 validation and 100 test scenes contains a single `.sens` file. For instance, for scene `0001_01` we have `data/scannet/scans/scene0001_01/0001_01.sens`. For this scene all images and poses are extracted to `data/scannet/posed_images/scene0001_01`. Specifically, there will be 300 image files xxxxx.jpg, 300 camera pose files xxxxx.txt and a single `intrinsic.txt` file. Typically, single scene contains several thousand images. By default, we extract only 300 of them with resulting weight of <100 Gb. To extract more images, use `--max-images-per-scene` parameter. + ### Create dataset ```shell @@ -201,6 +211,11 @@ scannet │ ├── train_resampled_scene_idxs.npy │ ├── val_label_weight.npy │ ├── val_resampled_scene_idxs.npy +├── posed_images +│ ├── scenexxxx_xx +│ │ ├── xxxxxx.txt +│ │ ├── xxxxxx.jpg +│ │ ├── intrinsic.txt ├── scannet_infos_train.pkl ├── scannet_infos_val.pkl ├── scannet_infos_test.pkl @@ -209,6 +224,7 @@ scannet - `points/xxxxx.bin`: The `axis-unaligned` point cloud data after downsample. Note: the point would be axis-aligned in pre-processing `GlobalAlignment` of 3d detection task. - `instance_mask/xxxxx.bin`: The instance label for each point, value range: [0, NUM_INSTANCES], 0: unannotated. - `semantic_mask/xxxxx.bin`: The semantic label for each point, value range: [1, 40], i.e. `nyu40id` standard. Note: the `nyu40id` id will be mapped to train id in train pipeline `PointSegClassMapping`. +- `posed_images/scenexxxx_xx`: The set of `.jpg` images with `.txt` 4x4 poses and the single `.txt` file with camera intrinsic matrix. - `scannet_infos_train.pkl`: The train data infos, the detailed info of each scan is as follows: - info['point_cloud']: {'num_features': 6, 'lidar_idx': sample_idx}. - info['pts_path']: The path of `points/xxxxx.bin`. diff --git a/mmdet3d/datasets/scannet_dataset.py b/mmdet3d/datasets/scannet_dataset.py index 7ffa397afa..8a109d835e 100644 --- a/mmdet3d/datasets/scannet_dataset.py +++ b/mmdet3d/datasets/scannet_dataset.py @@ -53,7 +53,7 @@ def __init__(self, ann_file, pipeline=None, classes=None, - modality=None, + modality=dict(use_camera=False, use_depth=True), box_type_3d='Depth', filter_empty_gt=True, test_mode=False): @@ -66,6 +66,58 @@ def __init__(self, box_type_3d=box_type_3d, filter_empty_gt=filter_empty_gt, test_mode=test_mode) + assert 'use_camera' in self.modality and \ + 'use_depth' in self.modality + assert self.modality['use_camera'] or self.modality['use_depth'] + + def get_data_info(self, index): + """Get data info according to the given index. + + Args: + index (int): Index of the sample data to get. + + Returns: + dict: Data information that will be passed to the data \ + preprocessing pipelines. It includes the following keys: + + - sample_idx (str): Sample index. + - pts_filename (str): Filename of point clouds. + - file_name (str): Filename of point clouds. + - img_prefix (str | None, optional): Prefix of image files. + - img_info (dict, optional): Image info. + - ann_info (dict): Annotation info. + """ + info = self.data_infos[index] + sample_idx = info['point_cloud']['lidar_idx'] + pts_filename = osp.join(self.data_root, info['pts_path']) + input_dict = dict(sample_idx=sample_idx) + + if self.modality['use_depth']: + input_dict['pts_filename'] = pts_filename + input_dict['file_name'] = pts_filename + + if self.modality['use_camera']: + img_info = [] + for img_path in info['img_paths']: + img_info.append( + dict(filename=osp.join(self.data_root, img_path))) + intrinsic = info['intrinsics'] + axis_align_matrix = self._get_axis_align_matrix(info) + depth2img = [] + for extrinsic in info['extrinsics']: + depth2img.append( + intrinsic @ np.linalg.inv(axis_align_matrix @ extrinsic)) + + input_dict['img_prefix'] = None + input_dict['img_info'] = img_info + input_dict['depth2img'] = depth2img + + if not self.test_mode: + annos = self.get_ann_info(index) + input_dict['ann_info'] = annos + if self.filter_empty_gt and ~(annos['gt_labels_3d'] != -1).any(): + return None + return input_dict def get_ann_info(self, index): """Get annotation info according to the given index. diff --git a/tools/data_converter/scannet_data_utils.py b/tools/data_converter/scannet_data_utils.py index baa8fb0a8b..0b4bc64391 100644 --- a/tools/data_converter/scannet_data_utils.py +++ b/tools/data_converter/scannet_data_utils.py @@ -1,5 +1,6 @@ import mmcv import numpy as np +import os from concurrent import futures as futures from os import path as osp @@ -60,6 +61,28 @@ def get_axis_align_matrix(self, idx): mmcv.check_file_exist(matrix_file) return np.load(matrix_file) + def get_images(self, idx): + paths = [] + path = osp.join(self.root_dir, 'posed_images', idx) + for file in sorted(os.listdir(path)): + if file.endswith('.jpg'): + paths.append(osp.join('posed_images', idx, file)) + return paths + + def get_extrinsics(self, idx): + extrinsics = [] + path = osp.join(self.root_dir, 'posed_images', idx) + for file in sorted(os.listdir(path)): + if file.endswith('.txt') and not file == 'intrinsic.txt': + extrinsics.append(np.loadtxt(osp.join(path, file))) + return extrinsics + + def get_intrinsics(self, idx): + matrix_file = osp.join(self.root_dir, 'posed_images', idx, + 'intrinsic.txt') + mmcv.check_file_exist(matrix_file) + return np.loadtxt(matrix_file) + def get_infos(self, num_workers=4, has_label=True, sample_id_list=None): """Get data infos. @@ -88,6 +111,20 @@ def process_single_scene(sample_idx): osp.join(self.root_dir, 'points', f'{sample_idx}.bin')) info['pts_path'] = osp.join('points', f'{sample_idx}.bin') + # update with RGB image paths if exist + if os.path.exists(osp.join(self.root_dir, 'posed_images')): + info['intrinsics'] = self.get_intrinsics(sample_idx) + all_extrinsics = self.get_extrinsics(sample_idx) + all_img_paths = self.get_images(sample_idx) + # some poses in ScanNet are invalid + extrinsics, img_paths = [], [] + for extrinsic, img_path in zip(all_extrinsics, all_img_paths): + if np.all(np.isfinite(extrinsic)): + img_paths.append(img_path) + extrinsics.append(extrinsic) + info['extrinsics'] = extrinsics + info['img_paths'] = img_paths + if not self.test_mode: pts_instance_mask_path = osp.join( self.root_dir, 'scannet_instance_data',