From 5ba112d3bd00e7f084c3cba32d82b389380893bd Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Sat, 3 Jul 2021 17:47:41 +0300 Subject: [PATCH 1/9] extract RGB images for ScanNet --- data/scannet/README.md | 4 +- data/scannet/extract_posed_images.py | 160 +++++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 1 deletion(-) create mode 100644 data/scannet/extract_posed_images.py diff --git a/data/scannet/README.md b/data/scannet/README.md index 855b9f38f6..de6f5c39b4 100644 --- a/data/scannet/README.md +++ b/data/scannet/README.md @@ -6,7 +6,9 @@ We follow the procedure in [votenet](https://github.com/facebookresearch/votenet 2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`. Add the `--max_num_point 50000` flag if you only use the ScanNet data for the detection task. It will downsample the scenes to less points. -3. Enter the project root directory, generate training data by running +3. In this directory, extract RGB image with poses by running `python extract_posed_images.py`. This step is optional. Skip it if you don't plan to use multi-view RGB images. Add `--max-images-per-scene -1` to disable limiting number of images per scene. ScanNet scenes contain up to 5000+ frames per each. And after extraction .jpg images requires around 2 Tb of dist space. The recommended 300 images per scene requires less then 100 Gb. For example multi-view 3d detector ImVoxelNet samples 50 and 100 images per training and test scene. + +4. Enter the project root directory, generate training data by running ```bash python tools/create_data.py scannet --root-path ./data/scannet --out-dir ./data/scannet --extra-tag scannet diff --git a/data/scannet/extract_posed_images.py b/data/scannet/extract_posed_images.py new file mode 100644 index 0000000000..9d69bdfed7 --- /dev/null +++ b/data/scannet/extract_posed_images.py @@ -0,0 +1,160 @@ +# Modified from https://github.com/ScanNet/ScanNet/blob/master/SensReader/python/SensorData.py # noqa +import imageio +import numpy as np +import os +import struct +import zlib +from argparse import ArgumentParser + +COMPRESSION_TYPE_COLOR = {-1: 'unknown', 0: 'raw', 1: 'png', 2: 'jpeg'} + +COMPRESSION_TYPE_DEPTH = { + -1: 'unknown', + 0: 'raw_ushort', + 1: 'zlib_ushort', + 2: 'occi_ushort' +} + + +class RGBDFrame: + + def load(self, file_handle): + self.camera_to_world = np.asarray( + struct.unpack('f' * 16, file_handle.read(16 * 4)), + dtype=np.float32).reshape(4, 4) + self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0] + self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0] + self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0] + self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0] + self.color_data = b''.join( + struct.unpack('c' * self.color_size_bytes, + file_handle.read(self.color_size_bytes))) + self.depth_data = b''.join( + struct.unpack('c' * self.depth_size_bytes, + file_handle.read(self.depth_size_bytes))) + + def decompress_depth(self, compression_type): + assert compression_type == 'zlib_ushort' + return zlib.decompress(self.depth_data) + + def decompress_color(self, compression_type): + assert compression_type == 'jpeg' + return imageio.imread(self.color_data) + + +class SensorData: + + def __init__(self, filename, limit): + self.version = 4 + self.load(filename, limit) + + def load(self, filename, limit): + with open(filename, 'rb') as f: + version = struct.unpack('I', f.read(4))[0] + assert self.version == version + strlen = struct.unpack('Q', f.read(8))[0] + self.sensor_name = b''.join( + struct.unpack('c' * strlen, f.read(strlen))) + self.intrinsic_color = np.asarray( + struct.unpack('f' * 16, f.read(16 * 4)), + dtype=np.float32).reshape(4, 4) + self.extrinsic_color = np.asarray( + struct.unpack('f' * 16, f.read(16 * 4)), + dtype=np.float32).reshape(4, 4) + self.intrinsic_depth = np.asarray( + struct.unpack('f' * 16, f.read(16 * 4)), + dtype=np.float32).reshape(4, 4) + self.extrinsic_depth = np.asarray( + struct.unpack('f' * 16, f.read(16 * 4)), + dtype=np.float32).reshape(4, 4) + self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack( + 'i', f.read(4))[0]] + self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack( + 'i', f.read(4))[0]] + self.color_width = struct.unpack('I', f.read(4))[0] + self.color_height = struct.unpack('I', f.read(4))[0] + self.depth_width = struct.unpack('I', f.read(4))[0] + self.depth_height = struct.unpack('I', f.read(4))[0] + self.depth_shift = struct.unpack('f', f.read(4))[0] + num_frames = struct.unpack('Q', f.read(8))[0] + self.frames = [] + if limit > 0 and limit < num_frames: + index = np.random.choice( + np.arange(num_frames), limit, replace=False).tolist() + else: + index = list(range(num_frames)) + for i in range(num_frames): + frame = RGBDFrame() + frame.load(f) + if i in index: + self.frames.append(frame) + + def export_depth_images(self, output_path): + if not os.path.exists(output_path): + os.makedirs(output_path) + for f in range(len(self.frames)): + depth_data = self.frames[f].decompress_depth( + self.depth_compression_type) + depth = np.fromstring( + depth_data, dtype=np.uint16).reshape(self.depth_height, + self.depth_width) + imageio.imwrite(os.path.join(output_path, str(f) + '.png'), depth) + + def export_color_images(self, output_path): + if not os.path.exists(output_path): + os.makedirs(output_path) + for f in range(len(self.frames)): + color = self.frames[f].decompress_color( + self.color_compression_type) + imageio.imwrite(os.path.join(output_path, str(f) + '.jpg'), color) + + def save_mat_to_file(self, matrix, filename): + with open(filename, 'w') as f: + for line in matrix: + np.savetxt(f, line[np.newaxis], fmt='%f') + + def export_poses(self, output_path): + if not os.path.exists(output_path): + os.makedirs(output_path) + for f in range(len(self.frames)): + self.save_mat_to_file(self.frames[f].camera_to_world, + os.path.join(output_path, + str(f) + '.txt')) + + def export_intrinsics(self, output_path): + if not os.path.exists(output_path): + os.makedirs(output_path) + print('exporting camera intrinsics to', output_path) + self.save_mat_to_file(self.intrinsic_color, + os.path.join(output_path, 'intrinsic_color.txt')) + self.save_mat_to_file(self.extrinsic_color, + os.path.join(output_path, 'extrinsic_color.txt')) + self.save_mat_to_file(self.intrinsic_depth, + os.path.join(output_path, 'intrinsic_depth.txt')) + self.save_mat_to_file(self.extrinsic_depth, + os.path.join(output_path, 'extrinsic_depth.txt')) + + +def process_scene(path, idx, limit): + data = SensorData(os.path.join(path, idx, f'{idx}.sens'), limit) + output_path = os.path.join('posed_images', idx) + data.export_color_images(output_path) + data.export_intrinsics(output_path) + data.export_poses(output_path) + + +def process_directory(path, limit): + scenes = os.listdir(path) + for idx in scenes: + print(f'extracting {path} {idx}') + process_scene(path, idx, limit) + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--max-images-per-scene', type=int, default=300) + args = parser.parse_args() + if os.path.exists('scans'): + process_directory('scans', args.max_images_per_scene) + if os.path.exists('scans_test'): + process_directory('scans_test', args.max_images_per_scene) From bcc4500d55db568dca57cc2e49556c8a2de07593 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Sun, 4 Jul 2021 10:04:34 +0300 Subject: [PATCH 2/9] update scannet_data_utils with rgb --- data/scannet/extract_posed_images.py | 11 ++++---- tools/data_converter/scannet_data_utils.py | 29 ++++++++++++++++++++++ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/data/scannet/extract_posed_images.py b/data/scannet/extract_posed_images.py index 9d69bdfed7..c0922d36a4 100644 --- a/data/scannet/extract_posed_images.py +++ b/data/scannet/extract_posed_images.py @@ -5,6 +5,8 @@ import struct import zlib from argparse import ArgumentParser +from functools import partial +from multiprocessing import Pool COMPRESSION_TYPE_COLOR = {-1: 'unknown', 0: 'raw', 1: 'png', 2: 'jpeg'} @@ -124,7 +126,6 @@ def export_poses(self, output_path): def export_intrinsics(self, output_path): if not os.path.exists(output_path): os.makedirs(output_path) - print('exporting camera intrinsics to', output_path) self.save_mat_to_file(self.intrinsic_color, os.path.join(output_path, 'intrinsic_color.txt')) self.save_mat_to_file(self.extrinsic_color, @@ -135,7 +136,7 @@ def export_intrinsics(self, output_path): os.path.join(output_path, 'extrinsic_depth.txt')) -def process_scene(path, idx, limit): +def process_scene(path, limit, idx): data = SensorData(os.path.join(path, idx, f'{idx}.sens'), limit) output_path = os.path.join('posed_images', idx) data.export_color_images(output_path) @@ -144,10 +145,8 @@ def process_scene(path, idx, limit): def process_directory(path, limit): - scenes = os.listdir(path) - for idx in scenes: - print(f'extracting {path} {idx}') - process_scene(path, idx, limit) + with Pool(8) as pool: + pool.map(partial(process_scene, path, limit), os.listdir(path)) if __name__ == '__main__': diff --git a/tools/data_converter/scannet_data_utils.py b/tools/data_converter/scannet_data_utils.py index baa8fb0a8b..adf607a4f3 100644 --- a/tools/data_converter/scannet_data_utils.py +++ b/tools/data_converter/scannet_data_utils.py @@ -1,5 +1,6 @@ import mmcv import numpy as np +import os from concurrent import futures as futures from os import path as osp @@ -60,6 +61,28 @@ def get_axis_align_matrix(self, idx): mmcv.check_file_exist(matrix_file) return np.load(matrix_file) + def get_images(self, idx): + paths = [] + path = osp.join(self.root_dir, 'posed_images', idx) + for file in sorted(os.listdir(path)): + if file.endswith('.jpg'): + paths.append(osp.join(path, file)) + return paths + + def get_extrinsics(self, idx): + extrinsics = [] + path = osp.join(self.root_dir, 'posed_images', idx) + for file in sorted(os.listdir(path)): + if file.endswith('.txt') and file[0].isdigit(): + extrinsics.append(np.loadtxt(osp.join(path, file))) + return extrinsics + + def get_intrinsics(self, idx): + matrix_file = osp.join(self.root_dir, 'posed_images', idx, + 'intrinsic_color.txt') + mmcv.check_file_exist(matrix_file) + return np.loadtxt(matrix_file) + def get_infos(self, num_workers=4, has_label=True, sample_id_list=None): """Get data infos. @@ -88,6 +111,12 @@ def process_single_scene(sample_idx): osp.join(self.root_dir, 'points', f'{sample_idx}.bin')) info['pts_path'] = osp.join('points', f'{sample_idx}.bin') + # update with RGB image paths if exist + if os.path.exists(osp.join(self.root_dir, 'posed_images')): + info['img_paths'] = self.get_images(sample_idx) + info['extrinsics'] = self.get_extrinsics(sample_idx) + info['intrinsics'] = self.get_intrinsics(sample_idx) + if not self.test_mode: pts_instance_mask_path = osp.join( self.root_dir, 'scannet_instance_data', From 4e123bf42543a070c3a3f3a99dcb0fe7b4864a19 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Mon, 12 Jul 2021 11:47:53 +0300 Subject: [PATCH 3/9] fix docs an tools; add use_camera to ScanNetDataset --- data/scannet/README.md | 8 +++- data/scannet/extract_posed_images.py | 36 +++++++++------ docs/datasets/scannet_det.md | 20 +++++++- mmdet3d/datasets/scannet_dataset.py | 53 +++++++++++++++++++++- tools/data_converter/scannet_data_utils.py | 6 +-- 5 files changed, 103 insertions(+), 20 deletions(-) diff --git a/data/scannet/README.md b/data/scannet/README.md index de6f5c39b4..5b072ba9ac 100644 --- a/data/scannet/README.md +++ b/data/scannet/README.md @@ -6,7 +6,7 @@ We follow the procedure in [votenet](https://github.com/facebookresearch/votenet 2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`. Add the `--max_num_point 50000` flag if you only use the ScanNet data for the detection task. It will downsample the scenes to less points. -3. In this directory, extract RGB image with poses by running `python extract_posed_images.py`. This step is optional. Skip it if you don't plan to use multi-view RGB images. Add `--max-images-per-scene -1` to disable limiting number of images per scene. ScanNet scenes contain up to 5000+ frames per each. And after extraction .jpg images requires around 2 Tb of dist space. The recommended 300 images per scene requires less then 100 Gb. For example multi-view 3d detector ImVoxelNet samples 50 and 100 images per training and test scene. +3. In this directory, extract RGB image with poses by running `python extract_posed_images.py`. This step is optional. Skip it if you don't plan to use multi-view RGB images. Add `--max-images-per-scene -1` to disable limiting number of images per scene. ScanNet scenes contain up to 5000+ frames per each. After extraction, all the .jpg images require 2 Tb disk space. The recommended 300 images per scene require less then 100 Gb. For example multi-view 3d detector ImVoxelNet samples 50 and 100 images per training and test scene. 4. Enter the project root directory, generate training data by running @@ -18,6 +18,7 @@ The overall process could be achieved through the following script ```bash python batch_load_scannet_data.py +python extract_posed_images.py cd ../.. python tools/create_data.py scannet --root-path ./data/scannet --out-dir ./data/scannet --extra-tag scannet ``` @@ -45,6 +46,11 @@ scannet │ ├── train_resampled_scene_idxs.npy │ ├── val_label_weight.npy │ ├── val_resampled_scene_idxs.npy +├── posed_images +│ ├── scenexxxx_xx +│ │ ├── xxxxxx.txt +│ │ ├── xxxxxx.jpg +│ │ ├── intrinsic.txt ├── scannet_infos_train.pkl ├── scannet_infos_val.pkl ├── scannet_infos_test.pkl diff --git a/data/scannet/extract_posed_images.py b/data/scannet/extract_posed_images.py index c0922d36a4..92451640e4 100644 --- a/data/scannet/extract_posed_images.py +++ b/data/scannet/extract_posed_images.py @@ -19,7 +19,7 @@ class RGBDFrame: - + """Class for single ScanNet RGB-D image processing.""" def load(self, file_handle): self.camera_to_world = np.asarray( struct.unpack('f' * 16, file_handle.read(16 * 4)), @@ -45,7 +45,9 @@ def decompress_color(self, compression_type): class SensorData: - + """Class for single ScanNet scene processing. + Single scene file contains multiple RGB-D images. + """ def __init__(self, filename, limit): self.version = 4 self.load(filename, limit) @@ -100,7 +102,8 @@ def export_depth_images(self, output_path): depth = np.fromstring( depth_data, dtype=np.uint16).reshape(self.depth_height, self.depth_width) - imageio.imwrite(os.path.join(output_path, str(f) + '.png'), depth) + imageio.imwrite(os.path.join( + output_path, self.index_to_str(f) + '.png'), depth) def export_color_images(self, output_path): if not os.path.exists(output_path): @@ -108,9 +111,15 @@ def export_color_images(self, output_path): for f in range(len(self.frames)): color = self.frames[f].decompress_color( self.color_compression_type) - imageio.imwrite(os.path.join(output_path, str(f) + '.jpg'), color) + imageio.imwrite(os.path.join( + output_path, self.index_to_str(f) + '.jpg'), color) - def save_mat_to_file(self, matrix, filename): + @staticmethod + def index_to_str(index): + return str(index).zfill(5) + + @staticmethod + def save_mat_to_file(matrix, filename): with open(filename, 'w') as f: for line in matrix: np.savetxt(f, line[np.newaxis], fmt='%f') @@ -121,22 +130,20 @@ def export_poses(self, output_path): for f in range(len(self.frames)): self.save_mat_to_file(self.frames[f].camera_to_world, os.path.join(output_path, - str(f) + '.txt')) + self.index_to_str(f) + '.txt')) def export_intrinsics(self, output_path): if not os.path.exists(output_path): os.makedirs(output_path) self.save_mat_to_file(self.intrinsic_color, - os.path.join(output_path, 'intrinsic_color.txt')) - self.save_mat_to_file(self.extrinsic_color, - os.path.join(output_path, 'extrinsic_color.txt')) - self.save_mat_to_file(self.intrinsic_depth, - os.path.join(output_path, 'intrinsic_depth.txt')) - self.save_mat_to_file(self.extrinsic_depth, - os.path.join(output_path, 'extrinsic_depth.txt')) + os.path.join(output_path, 'intrinsic.txt')) def process_scene(path, limit, idx): + """Process single ScanNet scene. + Extract RGB images, poses and camera intrinsics. + """ + print(f'processing {idx}') data = SensorData(os.path.join(path, idx, f'{idx}.sens'), limit) output_path = os.path.join('posed_images', idx) data.export_color_images(output_path) @@ -153,7 +160,10 @@ def process_directory(path, limit): parser = ArgumentParser() parser.add_argument('--max-images-per-scene', type=int, default=300) args = parser.parse_args() + + # process train and val scenes if os.path.exists('scans'): process_directory('scans', args.max_images_per_scene) + # process test scenes if os.path.exists('scans_test'): process_directory('scans_test', args.max_images_per_scene) diff --git a/docs/datasets/scannet_det.md b/docs/datasets/scannet_det.md index e3685cd431..b7421d6cb3 100644 --- a/docs/datasets/scannet_det.md +++ b/docs/datasets/scannet_det.md @@ -4,9 +4,9 @@ For the overall process, please refer to the [README](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/README.md/) page for ScanNet. -### Export ScanNet data +### Export ScanNet point cloud data -By exporting ScanNet data, we load the raw point cloud data and generate the relevant annotations including semantic label, instance label and ground truth bounding boxes. +By exporting ScanNet point cloud data, we load the raw point cloud data and generate the relevant annotations including semantic label, instance label and ground truth bounding boxes. ```shell python batch_load_scannet_data.py @@ -127,6 +127,16 @@ def export(mesh_file, After exporting each scan, the raw point cloud could be downsampled, e.g. to 50000, if the number of points is too large. In addition, invalid semantic labels outside of `nyu40id` standard or optional `DONOT CARE` classes should be filtered. Finally, the point cloud data, semantic labels, instance labels and ground truth bounding boxes should be saved in `.npy` files. +### Export ScanNet RGB data + +By exporting ScanNet RGB data, for each scene we load a set of RGB images with corresponding 4x4 pose matrices, and a single 4x4 camera intrinsic matrix. Note, that this step is optional and can be skipped if multi-view detection is not planned to use. + +```shell +python extract_posed_images.py +``` + +Each of 1201 train, 312 validation and ??? test scenes contains a single `.sens` file. For instance, for scene `0001_01` we have `data/scannet/scans/scene0001_01/0001_01.sens`. For this scene all images and poses are extracted to `data/scannet/posed_images/scene0001_01`. Specifically, there will be 300 image files xxxxx.jpg, 300 camera pose files xxxxx.txt and a single `intrinsic.txt` file. Typically, single scene contains several thousand images. By default, we extract only 300 of them with resulting weight of <100 Gb. To extract more images use `--max-images-per-scene` parameter. + ### Create dataset ```shell @@ -201,6 +211,11 @@ scannet │ ├── train_resampled_scene_idxs.npy │ ├── val_label_weight.npy │ ├── val_resampled_scene_idxs.npy +├── posed_images +│ ├── scenexxxx_xx +│ │ ├── xxxxxx.txt +│ │ ├── xxxxxx.jpg +│ │ ├── intrinsic.txt ├── scannet_infos_train.pkl ├── scannet_infos_val.pkl ├── scannet_infos_test.pkl @@ -209,6 +224,7 @@ scannet - `points/xxxxx.bin`: The `axis-unaligned` point cloud data after downsample. Note: the point would be axis-aligned in pre-processing `GlobalAlignment` of 3d detection task. - `instance_mask/xxxxx.bin`: The instance label for each point, value range: [0, NUM_INSTANCES], 0: unannotated. - `semantic_mask/xxxxx.bin`: The semantic label for each point, value range: [1, 40], i.e. `nyu40id` standard. Note: the `nyu40id` id will be mapped to train id in train pipeline `PointSegClassMapping`. +- `posed_images/scenexxxx_xx`: The set of `.jpg` images with `.txt` 4x4 poses and the single `.txt` file with camera intrinsic matrix. - `scannet_infos_train.pkl`: The train data infos, the detailed info of each scan is as follows: - info['point_cloud']: {'num_features': 6, 'lidar_idx': sample_idx}. - info['pts_path']: The path of `points/xxxxx.bin`. diff --git a/mmdet3d/datasets/scannet_dataset.py b/mmdet3d/datasets/scannet_dataset.py index 7ffa397afa..91d73612d5 100644 --- a/mmdet3d/datasets/scannet_dataset.py +++ b/mmdet3d/datasets/scannet_dataset.py @@ -53,7 +53,7 @@ def __init__(self, ann_file, pipeline=None, classes=None, - modality=None, + modality=dict(use_camera=False, use_depth=True), box_type_3d='Depth', filter_empty_gt=True, test_mode=False): @@ -66,6 +66,57 @@ def __init__(self, box_type_3d=box_type_3d, filter_empty_gt=filter_empty_gt, test_mode=test_mode) + assert 'use_camera' in self.modality and \ + 'use_depth' in self.modality + assert self.modality['use_camera'] or self.modality['use_depth'] + + def get_data_info(self, index): + """Get data info according to the given index. + + Args: + index (int): Index of the sample data to get. + + Returns: + dict: Data information that will be passed to the data \ + preprocessing pipelines. It includes the following keys: + + - sample_idx (str): Sample index. + - pts_filename (str): Filename of point clouds. + - file_name (str): Filename of point clouds. + - img_prefix (str | None, optional): Prefix of image files. + - img_info (dict, optional): Image info. + - ann_info (dict): Annotation info. + """ + info = self.data_infos[index] + sample_idx = info['point_cloud']['lidar_idx'] + pts_filename = osp.join(self.data_root, info['pts_path']) + input_dict = dict(sample_idx=sample_idx) + + if self.modality['use_depth']: + input_dict['pts_filename'] = pts_filename + input_dict['file_name'] = pts_filename + + if self.modality['use_camera']: + img_filename = [] + for img_path in info['img_paths']: + img_filename.append(osp.join(self.data_root, img_path)) + intrinsic = info['intrinsics'] + axis_align_matrix = self._get_axis_align_matrix(info) + depth2img = [] + for extrinsic in info['extrinsics']: + depth2img.append( + intrinsic @ np.linalg.inv(axis_align_matrix @ extrinsic)) + + input_dict['img_prefix'] = None + input_dict['img_info'] = dict(filename=img_filename) + input_dict['depth2img'] = depth2img + + if not self.test_mode: + annos = self.get_ann_info(index) + input_dict['ann_info'] = annos + if self.filter_empty_gt and ~(annos['gt_labels_3d'] != -1).any(): + return None + return input_dict def get_ann_info(self, index): """Get annotation info according to the given index. diff --git a/tools/data_converter/scannet_data_utils.py b/tools/data_converter/scannet_data_utils.py index adf607a4f3..775dacc4b1 100644 --- a/tools/data_converter/scannet_data_utils.py +++ b/tools/data_converter/scannet_data_utils.py @@ -66,20 +66,20 @@ def get_images(self, idx): path = osp.join(self.root_dir, 'posed_images', idx) for file in sorted(os.listdir(path)): if file.endswith('.jpg'): - paths.append(osp.join(path, file)) + paths.append(osp.join('posed_images', idx, file)) return paths def get_extrinsics(self, idx): extrinsics = [] path = osp.join(self.root_dir, 'posed_images', idx) for file in sorted(os.listdir(path)): - if file.endswith('.txt') and file[0].isdigit(): + if file.endswith('.txt') and not file == 'intrinsic.txt': extrinsics.append(np.loadtxt(osp.join(path, file))) return extrinsics def get_intrinsics(self, idx): matrix_file = osp.join(self.root_dir, 'posed_images', idx, - 'intrinsic_color.txt') + 'intrinsic.txt') mmcv.check_file_exist(matrix_file) return np.loadtxt(matrix_file) From edbf3f6d58afa5adfc639f17ca55a347934cae7f Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Tue, 13 Jul 2021 13:27:08 +0300 Subject: [PATCH 4/9] fix typos is scannet doc --- docs/datasets/scannet_det.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/datasets/scannet_det.md b/docs/datasets/scannet_det.md index b7421d6cb3..b43b267b8c 100644 --- a/docs/datasets/scannet_det.md +++ b/docs/datasets/scannet_det.md @@ -135,7 +135,7 @@ By exporting ScanNet RGB data, for each scene we load a set of RGB images with c python extract_posed_images.py ``` -Each of 1201 train, 312 validation and ??? test scenes contains a single `.sens` file. For instance, for scene `0001_01` we have `data/scannet/scans/scene0001_01/0001_01.sens`. For this scene all images and poses are extracted to `data/scannet/posed_images/scene0001_01`. Specifically, there will be 300 image files xxxxx.jpg, 300 camera pose files xxxxx.txt and a single `intrinsic.txt` file. Typically, single scene contains several thousand images. By default, we extract only 300 of them with resulting weight of <100 Gb. To extract more images use `--max-images-per-scene` parameter. +Each of 1201 train, 312 validation and 100 test scenes contains a single `.sens` file. For instance, for scene `0001_01` we have `data/scannet/scans/scene0001_01/0001_01.sens`. For this scene all images and poses are extracted to `data/scannet/posed_images/scene0001_01`. Specifically, there will be 300 image files xxxxx.jpg, 300 camera pose files xxxxx.txt and a single `intrinsic.txt` file. Typically, single scene contains several thousand images. By default, we extract only 300 of them with resulting weight of <100 Gb. To extract more images, use `--max-images-per-scene` parameter. ### Create dataset From 2e25eb6b55e6a762a3210a50f7a06ad599ed5219 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Tue, 13 Jul 2021 16:27:28 +0300 Subject: [PATCH 5/9] fix very rare undefined poses for ScanNet --- tools/data_converter/scannet_data_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tools/data_converter/scannet_data_utils.py b/tools/data_converter/scannet_data_utils.py index 775dacc4b1..0b4bc64391 100644 --- a/tools/data_converter/scannet_data_utils.py +++ b/tools/data_converter/scannet_data_utils.py @@ -113,9 +113,17 @@ def process_single_scene(sample_idx): # update with RGB image paths if exist if os.path.exists(osp.join(self.root_dir, 'posed_images')): - info['img_paths'] = self.get_images(sample_idx) - info['extrinsics'] = self.get_extrinsics(sample_idx) info['intrinsics'] = self.get_intrinsics(sample_idx) + all_extrinsics = self.get_extrinsics(sample_idx) + all_img_paths = self.get_images(sample_idx) + # some poses in ScanNet are invalid + extrinsics, img_paths = [], [] + for extrinsic, img_path in zip(all_extrinsics, all_img_paths): + if np.all(np.isfinite(extrinsic)): + img_paths.append(img_path) + extrinsics.append(extrinsic) + info['extrinsics'] = extrinsics + info['img_paths'] = img_paths if not self.test_mode: pts_instance_mask_path = osp.join( From d6c0d44f952930edf096200d39acfec31b92888e Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Tue, 13 Jul 2021 17:41:20 +0300 Subject: [PATCH 6/9] update ScanNet dataset for more clear MultiViewPipeline --- mmdet3d/datasets/scannet_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mmdet3d/datasets/scannet_dataset.py b/mmdet3d/datasets/scannet_dataset.py index 91d73612d5..8a109d835e 100644 --- a/mmdet3d/datasets/scannet_dataset.py +++ b/mmdet3d/datasets/scannet_dataset.py @@ -97,9 +97,10 @@ def get_data_info(self, index): input_dict['file_name'] = pts_filename if self.modality['use_camera']: - img_filename = [] + img_info = [] for img_path in info['img_paths']: - img_filename.append(osp.join(self.data_root, img_path)) + img_info.append( + dict(filename=osp.join(self.data_root, img_path))) intrinsic = info['intrinsics'] axis_align_matrix = self._get_axis_align_matrix(info) depth2img = [] @@ -108,7 +109,7 @@ def get_data_info(self, index): intrinsic @ np.linalg.inv(axis_align_matrix @ extrinsic)) input_dict['img_prefix'] = None - input_dict['img_info'] = dict(filename=img_filename) + input_dict['img_info'] = img_info input_dict['depth2img'] = depth2img if not self.test_mode: From a51645ac3d4b81fc692d13c8b6a051864f528171 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Wed, 14 Jul 2021 15:52:24 +0300 Subject: [PATCH 7/9] update compatibility for ScanNet --- docs/compatibility.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/compatibility.md b/docs/compatibility.md index cca2522bcb..ff8848bf96 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -2,6 +2,18 @@ This document provides detailed descriptions of the BC-breaking changes in MMDetection3D. +## MMDetection3D 0.16.0 + +### ScanNet dataset for ImVoxelNet + +We adopt a new pre-processing procedure for the ScanNet dataset in order to support ImVoxelNet, which is a multi-view method requiring image data. In previous versions of MMDetection3D, ScanNet dataset was only used for point cloud based 3D detection methods. We plan adding ImVoxelNet to our model zoo, thus updating ScanNet correspondingly by adding image-related pre-processing steps. Specifically, we made these changes: + +- Add [script](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/extract_posed_images.py) for extracting RGB data. +- Update [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/data_converter/scannet_data_utils.py) for annotation creating. +- Add instructions in the documents on preparing image data. + +Please refer to the ScanNet [README.md](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/README.md/) for more details. + ## MMDetection3D 0.15.0 ### MMCV Version From 054508697927e8133b7b32261c06050a7bd4411d Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Wed, 14 Jul 2021 17:04:27 +0300 Subject: [PATCH 8/9] fix typo in compatibility doc --- docs/compatibility.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index ff8848bf96..b29fd0b18e 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -6,7 +6,7 @@ This document provides detailed descriptions of the BC-breaking changes in MMDet ### ScanNet dataset for ImVoxelNet -We adopt a new pre-processing procedure for the ScanNet dataset in order to support ImVoxelNet, which is a multi-view method requiring image data. In previous versions of MMDetection3D, ScanNet dataset was only used for point cloud based 3D detection methods. We plan adding ImVoxelNet to our model zoo, thus updating ScanNet correspondingly by adding image-related pre-processing steps. Specifically, we made these changes: +We adopt a new pre-processing procedure for the ScanNet dataset in order to support ImVoxelNet, which is a multi-view method requiring image data. In previous versions of MMDetection3D, ScanNet dataset was only used for point cloud based 3D detection and segmentation methods. We plan adding ImVoxelNet to our model zoo, thus updating ScanNet correspondingly by adding image-related pre-processing steps. Specifically, we made these changes: - Add [script](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/extract_posed_images.py) for extracting RGB data. - Update [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/data_converter/scannet_data_utils.py) for annotation creating. From aea1385e2764c056076f6f2f37b82d1b3bc7c56f Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Fri, 16 Jul 2021 09:46:13 +0300 Subject: [PATCH 9/9] use mmcv.track_parallel_progress for scannet images --- data/scannet/extract_posed_images.py | 38 ++++++++++++++++++---------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/data/scannet/extract_posed_images.py b/data/scannet/extract_posed_images.py index 92451640e4..7018f32d11 100644 --- a/data/scannet/extract_posed_images.py +++ b/data/scannet/extract_posed_images.py @@ -1,12 +1,12 @@ # Modified from https://github.com/ScanNet/ScanNet/blob/master/SensReader/python/SensorData.py # noqa import imageio +import mmcv import numpy as np import os import struct import zlib from argparse import ArgumentParser from functools import partial -from multiprocessing import Pool COMPRESSION_TYPE_COLOR = {-1: 'unknown', 0: 'raw', 1: 'png', 2: 'jpeg'} @@ -20,6 +20,7 @@ class RGBDFrame: """Class for single ScanNet RGB-D image processing.""" + def load(self, file_handle): self.camera_to_world = np.asarray( struct.unpack('f' * 16, file_handle.read(16 * 4)), @@ -46,8 +47,10 @@ def decompress_color(self, compression_type): class SensorData: """Class for single ScanNet scene processing. + Single scene file contains multiple RGB-D images. """ + def __init__(self, filename, limit): self.version = 4 self.load(filename, limit) @@ -102,8 +105,9 @@ def export_depth_images(self, output_path): depth = np.fromstring( depth_data, dtype=np.uint16).reshape(self.depth_height, self.depth_width) - imageio.imwrite(os.path.join( - output_path, self.index_to_str(f) + '.png'), depth) + imageio.imwrite( + os.path.join(output_path, + self.index_to_str(f) + '.png'), depth) def export_color_images(self, output_path): if not os.path.exists(output_path): @@ -111,8 +115,9 @@ def export_color_images(self, output_path): for f in range(len(self.frames)): color = self.frames[f].decompress_color( self.color_compression_type) - imageio.imwrite(os.path.join( - output_path, self.index_to_str(f) + '.jpg'), color) + imageio.imwrite( + os.path.join(output_path, + self.index_to_str(f) + '.jpg'), color) @staticmethod def index_to_str(index): @@ -128,9 +133,10 @@ def export_poses(self, output_path): if not os.path.exists(output_path): os.makedirs(output_path) for f in range(len(self.frames)): - self.save_mat_to_file(self.frames[f].camera_to_world, - os.path.join(output_path, - self.index_to_str(f) + '.txt')) + self.save_mat_to_file( + self.frames[f].camera_to_world, + os.path.join(output_path, + self.index_to_str(f) + '.txt')) def export_intrinsics(self, output_path): if not os.path.exists(output_path): @@ -141,9 +147,9 @@ def export_intrinsics(self, output_path): def process_scene(path, limit, idx): """Process single ScanNet scene. + Extract RGB images, poses and camera intrinsics. """ - print(f'processing {idx}') data = SensorData(os.path.join(path, idx, f'{idx}.sens'), limit) output_path = os.path.join('posed_images', idx) data.export_color_images(output_path) @@ -151,19 +157,23 @@ def process_scene(path, limit, idx): data.export_poses(output_path) -def process_directory(path, limit): - with Pool(8) as pool: - pool.map(partial(process_scene, path, limit), os.listdir(path)) +def process_directory(path, limit, nproc): + print(f'processing {path}') + mmcv.track_parallel_progress( + func=partial(process_scene, path, limit), + tasks=os.listdir(path), + nproc=nproc) if __name__ == '__main__': parser = ArgumentParser() parser.add_argument('--max-images-per-scene', type=int, default=300) + parser.add_argument('--nproc', type=int, default=8) args = parser.parse_args() # process train and val scenes if os.path.exists('scans'): - process_directory('scans', args.max_images_per_scene) + process_directory('scans', args.max_images_per_scene, args.nproc) # process test scenes if os.path.exists('scans_test'): - process_directory('scans_test', args.max_images_per_scene) + process_directory('scans_test', args.max_images_per_scene, args.nproc)