diff --git a/README.md b/README.md
index 4eeac34..b49de52 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@ # ImVoxelNet: Image to Voxels Projection for Monocular and Multi-View General-Purpose 3D Object Detection
 
 **News**:
+ * :fire: July, 2021. We update `ScanNet` image preprocessing both [here](https://github.com/saic-vul/imvoxelnet/pull/21) and in [mmdetection3d](https://github.com/open-mmlab/mmdetection3d/pull/696).
  * :fire: June, 2021. `ImVoxelNet` for `KITTI` is now [supported](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/imvoxelnet) in [mmdetection3d](https://github.com/open-mmlab/mmdetection3d).
 
 This repository contains implementation of the monocular/multi-view 3D object detector ImVoxelNet, introduced in our paper:
@@ -38,7 +39,7 @@ We support three benchmarks based on the **SUN RGB-D** dataset.
   you should follow the instructions in [sunrgbd](data/sunrgbd).
 * For the [PerspectiveNet](https://papers.nips.cc/paper/2019/hash/b87517992f7dce71b674976b280257d2-Abstract.html)
   benchmark with 30 object categories, the same instructions can be applied;
-  you only need to pass `--dataset sunrgbd_monocular` when running `create_data.py`.
+  you only need to set the `dataset` argument to `sunrgbd_monocular` when running `create_data.py`.
 * The [Total3DUnderstanding](https://github.com/yinyunie/Total3DUnderstanding)
   benchmark implies detecting objects of 37 categories along with camera pose and room layout estimation.
   Download the preprocessed data as
@@ -49,38 +50,9 @@ We support three benchmarks based on the **SUN RGB-D** dataset.
   python tools/data_converter/sunrgbd_total.py
   ```
 
-**ScanNet.** Please follow instructions in [scannet](data/scannet).
-Note that `create_data.py` works with point clouds, not RGB images; thus, you should do some preprocessing before running `create_data.py`.
-1. First, you should obtain RGB images. We recommend using a script from [SensReader](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python).
-2. Then, copy the camera pose `.txt` files and `.jpg` images to the `scannet/sens_reader` folder.
-3. Copy axis alignment matrix `.txt` files to the `scannet/txts` folder.
-4. Move the results of `batch_load_scannet_data.py` to the `scannet/mmdetection3d` folder. Final directory structure:
-```
-scannet
-├── sens_reader
-│ ├── scans
-│ │ ├── scene0000_00
-│ │ │ ├── out
-│ │ │ │ ├── frame-000001.color.jpg
-│ │ │ │ ├── frame-000001.pose.txt
-│ │ │ │ ├── frame-000002.color.jpg
-│ │ │ │ ├── ...
-│ │ ├── ...
-├── txts
-│ ├── scene0000_00.txt
-│ ├── ...
-├── mmdetection3d
-│ ├── scene0000_00_bbox.npy
-│ ├── scene0000_00_ins_label.npy
-│ ├── scene0000_00_sem_label.npy
-│ ├── scene0000_00_vert.npy
-│ ├── scene0000_01_bbox.npy
-│ ├── ...
-```
-Now, you may run `create_data.py` with `--dataset scannet_monocular`.
-
+For **ScanNet**, please follow the instructions in [scannet](data/scannet).
 For **KITTI** and **nuScenes**, please follow instructions in [getting_started.md](docs/getting_started.md).
-For `nuScenes`, set `--dataset nuscenes_monocular`.
+For `nuScenes`, set the `dataset` argument to `nuscenes_monocular`.
 
 ### Getting Started
 
diff --git a/data/scannet/README.md b/data/scannet/README.md
index ccb6e30..5b072ba 100644
--- a/data/scannet/README.md
+++ b/data/scannet/README.md
@@ -1,23 +1,30 @@
-### Prepare ScanNet Data
+### Prepare ScanNet Data for Indoor Detection or Segmentation Task
+
 We follow the procedure in [votenet](https://github.com/facebookresearch/votenet/).
 
-1. Download ScanNet v2 data [HERE](https://github.com/ScanNet/ScanNet). Link or move the 'scans' folder to this level of directory.
+1. Download ScanNet v2 data [HERE](https://github.com/ScanNet/ScanNet). Link or move the 'scans' folder to this level of directory. If you are performing segmentation tasks and want to upload the results to its official [benchmark](http://kaldir.vc.in.tum.de/scannet_benchmark/), please also link or move the 'scans_test' folder to this directory.
+
+2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`. Add the `--max_num_point 50000` flag if you only use the ScanNet data for the detection task. It will downsample the scenes to fewer points.
+
+3. In this directory, extract RGB images with poses by running `python extract_posed_images.py`. This step is optional. Skip it if you don't plan to use multi-view RGB images. Add `--max-images-per-scene -1` to disable limiting the number of images per scene. ScanNet scenes contain up to 5000+ frames each. After extraction, all the .jpg images require 2 TB of disk space. The recommended 300 images per scene require less than 100 GB. For example, the multi-view 3D detector ImVoxelNet samples 50 and 100 images per training and test scene, respectively.
 
-2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`.
+4. Enter the project root directory and generate training data by running
 
-3. Enter the project root directory, generate training data by running
 ```bash
 python tools/create_data.py scannet --root-path ./data/scannet --out-dir ./data/scannet --extra-tag scannet
 ```
 
 The overall process could be achieved through the following script
+
 ```bash
 python batch_load_scannet_data.py
+python extract_posed_images.py
 cd ../..
 python tools/create_data.py scannet --root-path ./data/scannet --out-dir ./data/scannet --extra-tag scannet
 ```
 
 The directory structure after pre-processing should be as below
+
 ```
 scannet
 ├── scannet_utils.py
@@ -26,11 +33,26 @@ scannet
 ├── scannet_utils.py
 ├── README.md
 ├── scans
-├── scannet_train_instance_data
+├── scans_test
+├── scannet_instance_data
 ├── points
+│   ├── xxxxx.bin
 ├── instance_mask
+│   ├── xxxxx.bin
 ├── semantic_mask
+│   ├── xxxxx.bin
+├── seg_info
+│   ├── train_label_weight.npy
+│   ├── train_resampled_scene_idxs.npy
+│   ├── val_label_weight.npy
+│   ├── val_resampled_scene_idxs.npy
+├── posed_images
+│   ├── scenexxxx_xx
+│   │   ├── xxxxxx.txt
+│   │   ├── xxxxxx.jpg
+│   │   ├── intrinsic.txt
 ├── scannet_infos_train.pkl
 ├── scannet_infos_val.pkl
+├── scannet_infos_test.pkl
 ```
diff --git a/data/scannet/batch_load_scannet_data.py b/data/scannet/batch_load_scannet_data.py
index 38ca1de..60b53b3 100644
--- a/data/scannet/batch_load_scannet_data.py
+++ b/data/scannet/batch_load_scannet_data.py
@@ -16,14 +16,17 @@ from load_scannet_data import export
 from os import path as osp
 
-SCANNET_DIR = 'scans'
 DONOTCARE_CLASS_IDS = np.array([])
 OBJ_CLASS_IDS = np.array(
     [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
 
 
-def export_one_scan(scan_name, output_filename_prefix, max_num_point,
-                    label_map_file, scannet_dir):
+def export_one_scan(scan_name,
+                    output_filename_prefix,
+                    max_num_point,
+                    label_map_file,
+                    scannet_dir,
+                    test_mode=False):
     mesh_file = osp.join(scannet_dir, scan_name, scan_name + '_vh_clean_2.ply')
     agg_file = osp.join(scannet_dir, scan_name,
                         scan_name + '.aggregation.json')
@@ -31,43 +34,63 @@ def export_one_scan(scan_name, output_filename_prefix, max_num_point,
                         scan_name + '_vh_clean_2.0.010000.segs.json')
     # includes axisAlignment info for the train set scans.
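+    # test scans have no axisAlignment entry; export() falls back to identity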
     meta_file = osp.join(scannet_dir, scan_name, f'{scan_name}.txt')
-    mesh_vertices, semantic_labels, instance_labels, instance_bboxes, \
-        instance2semantic = export(mesh_file, agg_file, seg_file,
-                                   meta_file, label_map_file, None)
-
-    mask = np.logical_not(np.in1d(semantic_labels, DONOTCARE_CLASS_IDS))
-    mesh_vertices = mesh_vertices[mask, :]
-    semantic_labels = semantic_labels[mask]
-    instance_labels = instance_labels[mask]
-
-    num_instances = len(np.unique(instance_labels))
-    print(f'Num of instances: {num_instances}')
-
-    bbox_mask = np.in1d(instance_bboxes[:, -1], OBJ_CLASS_IDS)
-    instance_bboxes = instance_bboxes[bbox_mask, :]
-    print(f'Num of care instances: {instance_bboxes.shape[0]}')
-
-    N = mesh_vertices.shape[0]
-    if N > max_num_point:
-        choices = np.random.choice(N, max_num_point, replace=False)
-        mesh_vertices = mesh_vertices[choices, :]
-        semantic_labels = semantic_labels[choices]
-        instance_labels = instance_labels[choices]
+    mesh_vertices, semantic_labels, instance_labels, unaligned_bboxes, \
+        aligned_bboxes, instance2semantic, axis_align_matrix = export(
+            mesh_file, agg_file, seg_file, meta_file, label_map_file, None,
+            test_mode)
+
+    if not test_mode:
+        mask = np.logical_not(np.in1d(semantic_labels, DONOTCARE_CLASS_IDS))
+        mesh_vertices = mesh_vertices[mask, :]
+        semantic_labels = semantic_labels[mask]
+        instance_labels = instance_labels[mask]
+
+        num_instances = len(np.unique(instance_labels))
+        print(f'Num of instances: {num_instances}')
+
+        bbox_mask = np.in1d(unaligned_bboxes[:, -1], OBJ_CLASS_IDS)
+        unaligned_bboxes = unaligned_bboxes[bbox_mask, :]
+        bbox_mask = np.in1d(aligned_bboxes[:, -1], OBJ_CLASS_IDS)
+        aligned_bboxes = aligned_bboxes[bbox_mask, :]
+        assert unaligned_bboxes.shape[0] == aligned_bboxes.shape[0]
+        print(f'Num of care instances: {unaligned_bboxes.shape[0]}')
+
+    if max_num_point is not None:
+        max_num_point = int(max_num_point)
+        N = mesh_vertices.shape[0]
+        if N > max_num_point:
+            choices = np.random.choice(N, max_num_point, replace=False)
+            mesh_vertices = mesh_vertices[choices, :]
+            if not test_mode:
+                semantic_labels = semantic_labels[choices]
+                instance_labels = instance_labels[choices]
 
     np.save(f'{output_filename_prefix}_vert.npy', mesh_vertices)
-    np.save(f'{output_filename_prefix}_sem_label.npy', semantic_labels)
-    np.save(f'{output_filename_prefix}_ins_label.npy', instance_labels)
-    np.save(f'{output_filename_prefix}_bbox.npy', instance_bboxes)
-
-
-def batch_export(max_num_point, output_folder, train_scan_names_file,
-                 label_map_file, scannet_dir):
+    if not test_mode:
+        np.save(f'{output_filename_prefix}_sem_label.npy', semantic_labels)
+        np.save(f'{output_filename_prefix}_ins_label.npy', instance_labels)
+        np.save(f'{output_filename_prefix}_unaligned_bbox.npy',
+                unaligned_bboxes)
+        np.save(f'{output_filename_prefix}_aligned_bbox.npy', aligned_bboxes)
+        np.save(f'{output_filename_prefix}_axis_align_matrix.npy',
+                axis_align_matrix)
+
+
+def batch_export(max_num_point,
+                 output_folder,
+                 scan_names_file,
+                 label_map_file,
+                 scannet_dir,
+                 test_mode=False):
+    if test_mode and not os.path.exists(scannet_dir):
+        # test data preparation is optional
+        return
     if not os.path.exists(output_folder):
         print(f'Creating new data folder: {output_folder}')
         os.mkdir(output_folder)
 
-    train_scan_names = [line.rstrip() for line in open(train_scan_names_file)]
-    for scan_name in train_scan_names:
+    scan_names = [line.rstrip() for line in open(scan_names_file)]
+    for scan_name in scan_names:
         print('-' * 20 + 'begin')
         print(datetime.datetime.now())
         print(scan_name)
@@ -78,7 +101,7 @@ def batch_export(max_num_point, output_folder, train_scan_names_file,
             continue
         try:
             export_one_scan(scan_name, output_filename_prefix, max_num_point,
-                            label_map_file, scannet_dir)
+                            label_map_file, scannet_dir, test_mode)
         except Exception:
             print(f'Failed export scan: {scan_name}')
         print('-' * 20 + 'done')
@@ -88,14 +111,18 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         '--max_num_point',
-        default=50000,
+        default=None,
         help='The maximum number of the points.')
     parser.add_argument(
         '--output_folder',
-        default='./scannet_train_instance_data',
+        default='./scannet_instance_data',
         help='output folder of the result.')
     parser.add_argument(
-        '--scannet_dir', default='scans', help='scannet data directory.')
+        '--train_scannet_dir', default='scans', help='scannet train data directory.')
+    parser.add_argument(
+        '--test_scannet_dir',
+        default='scans_test',
+        help='scannet test data directory.')
     parser.add_argument(
         '--label_map_file',
         default='meta_data/scannetv2-labels.combined.tsv',
@@ -104,10 +131,25 @@ def main():
         '--train_scan_names_file',
         default='meta_data/scannet_train.txt',
         help='The path of the file that stores the train scan names.')
+    parser.add_argument(
+        '--test_scan_names_file',
+        default='meta_data/scannetv2_test.txt',
+        help='The path of the file that stores the test scan names.')
     args = parser.parse_args()
-    batch_export(args.max_num_point, args.output_folder,
-                 args.train_scan_names_file, args.label_map_file,
-                 args.scannet_dir)
+    batch_export(
+        args.max_num_point,
+        args.output_folder,
+        args.train_scan_names_file,
+        args.label_map_file,
+        args.train_scannet_dir,
+        test_mode=False)
+    batch_export(
+        args.max_num_point,
+        args.output_folder,
+        args.test_scan_names_file,
+        args.label_map_file,
+        args.test_scannet_dir,
+        test_mode=True)
 
 
 if __name__ == '__main__':
diff --git a/data/scannet/extract_posed_images.py b/data/scannet/extract_posed_images.py
new file mode 100644
index 0000000..7018f32
--- /dev/null
+++ b/data/scannet/extract_posed_images.py
@@ -0,0 +1,179 @@
+# Modified from https://github.com/ScanNet/ScanNet/blob/master/SensReader/python/SensorData.py # noqa
+import imageio
+import mmcv
+import numpy as np
+import os
+import struct
+import zlib
+from argparse import ArgumentParser
+from functools import partial
+
+COMPRESSION_TYPE_COLOR = {-1: 'unknown', 0: 'raw', 1: 'png', 2: 'jpeg'}
+
+COMPRESSION_TYPE_DEPTH = {
+    -1: 'unknown',
+    0: 'raw_ushort',
+    1: 'zlib_ushort',
+    2: 'occi_ushort'
+}
+
+
+class RGBDFrame:
+    """Class for single ScanNet RGB-D image processing."""
+
+    def load(self, file_handle):
+        self.camera_to_world = np.asarray(
+            struct.unpack('f' * 16, file_handle.read(16 * 4)),
+            dtype=np.float32).reshape(4, 4)
+        self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0]
+        self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0]
+        self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0]
+        self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0]
+        self.color_data = b''.join(
+            struct.unpack('c' * self.color_size_bytes,
+                          file_handle.read(self.color_size_bytes)))
+        self.depth_data = b''.join(
+            struct.unpack('c' * self.depth_size_bytes,
+                          file_handle.read(self.depth_size_bytes)))
+
+    def decompress_depth(self, compression_type):
+        assert compression_type == 'zlib_ushort'
+        return zlib.decompress(self.depth_data)
+
+    def decompress_color(self, compression_type):
+        assert compression_type == 'jpeg'
+        return imageio.imread(self.color_data)
+
+
+class SensorData:
+    """Class for single ScanNet scene processing.
+
+    Single scene file contains multiple RGB-D images.
+    """
+
+    def __init__(self, filename, limit):
+        self.version = 4
+        self.load(filename, limit)
+
+    def load(self, filename, limit):
+        with open(filename, 'rb') as f:
+            version = struct.unpack('I', f.read(4))[0]
+            assert self.version == version
+            strlen = struct.unpack('Q', f.read(8))[0]
+            self.sensor_name = b''.join(
+                struct.unpack('c' * strlen, f.read(strlen)))
+            self.intrinsic_color = np.asarray(
+                struct.unpack('f' * 16, f.read(16 * 4)),
+                dtype=np.float32).reshape(4, 4)
+            self.extrinsic_color = np.asarray(
+                struct.unpack('f' * 16, f.read(16 * 4)),
+                dtype=np.float32).reshape(4, 4)
+            self.intrinsic_depth = np.asarray(
+                struct.unpack('f' * 16, f.read(16 * 4)),
+                dtype=np.float32).reshape(4, 4)
+            self.extrinsic_depth = np.asarray(
+                struct.unpack('f' * 16, f.read(16 * 4)),
+                dtype=np.float32).reshape(4, 4)
+            self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack(
+                'i', f.read(4))[0]]
+            self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack(
+                'i', f.read(4))[0]]
+            self.color_width = struct.unpack('I', f.read(4))[0]
+            self.color_height = struct.unpack('I', f.read(4))[0]
+            self.depth_width = struct.unpack('I', f.read(4))[0]
+            self.depth_height = struct.unpack('I', f.read(4))[0]
+            self.depth_shift = struct.unpack('f', f.read(4))[0]
+            num_frames = struct.unpack('Q', f.read(8))[0]
+            self.frames = []
+            if limit > 0 and limit < num_frames:
+                index = np.random.choice(
+                    np.arange(num_frames), limit, replace=False).tolist()
+            else:
+                index = list(range(num_frames))
+            for i in range(num_frames):
+                frame = RGBDFrame()
+                frame.load(f)
+                if i in index:
+                    self.frames.append(frame)
+
+    def export_depth_images(self, output_path):
+        if not os.path.exists(output_path):
+            os.makedirs(output_path)
+        for f in range(len(self.frames)):
+            depth_data = self.frames[f].decompress_depth(
+                self.depth_compression_type)
+            depth = np.fromstring(
+                depth_data, dtype=np.uint16).reshape(self.depth_height,
+                                                     self.depth_width)
+            imageio.imwrite(
+                os.path.join(output_path,
+                             self.index_to_str(f) + '.png'), depth)
+
+    def export_color_images(self, output_path):
+        if not os.path.exists(output_path):
+            os.makedirs(output_path)
+        for f in range(len(self.frames)):
+            color = self.frames[f].decompress_color(
+                self.color_compression_type)
+            imageio.imwrite(
+                os.path.join(output_path,
+                             self.index_to_str(f) + '.jpg'), color)
+
+    @staticmethod
+    def index_to_str(index):
+        return str(index).zfill(5)
+
+    @staticmethod
+    def save_mat_to_file(matrix, filename):
+        with open(filename, 'w') as f:
+            for line in matrix:
+                np.savetxt(f, line[np.newaxis], fmt='%f')
+
+    def export_poses(self, output_path):
+        if not os.path.exists(output_path):
+            os.makedirs(output_path)
+        for f in range(len(self.frames)):
+            self.save_mat_to_file(
+                self.frames[f].camera_to_world,
+                os.path.join(output_path,
+                             self.index_to_str(f) + '.txt'))
+
+    def export_intrinsics(self, output_path):
+        if not os.path.exists(output_path):
+            os.makedirs(output_path)
+        self.save_mat_to_file(self.intrinsic_color,
+                              os.path.join(output_path, 'intrinsic.txt'))
+
+
+def process_scene(path, limit, idx):
+    """Process single ScanNet scene.
+
+    Extract RGB images, poses and camera intrinsics.
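+
+    The `.sens` file for scene `idx` is expected at `<path>/<idx>/<idx>.sens`.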
+ """ + data = SensorData(os.path.join(path, idx, f'{idx}.sens'), limit) + output_path = os.path.join('posed_images', idx) + data.export_color_images(output_path) + data.export_intrinsics(output_path) + data.export_poses(output_path) + + +def process_directory(path, limit, nproc): + print(f'processing {path}') + mmcv.track_parallel_progress( + func=partial(process_scene, path, limit), + tasks=os.listdir(path), + nproc=nproc) + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--max-images-per-scene', type=int, default=300) + parser.add_argument('--nproc', type=int, default=8) + args = parser.parse_args() + + # process train and val scenes + if os.path.exists('scans'): + process_directory('scans', args.max_images_per_scene, args.nproc) + # process test scenes + if os.path.exists('scans_test'): + process_directory('scans_test', args.max_images_per_scene, args.nproc) diff --git a/data/scannet/load_scannet_data.py b/data/scannet/load_scannet_data.py index 80d417b..911bb4c 100644 --- a/data/scannet/load_scannet_data.py +++ b/data/scannet/load_scannet_data.py @@ -52,12 +52,31 @@ def read_segmentation(filename): return seg_to_verts, num_verts +def extract_bbox(mesh_vertices, object_id_to_segs, object_id_to_label_id, + instance_ids): + num_instances = len(np.unique(list(object_id_to_segs.keys()))) + instance_bboxes = np.zeros((num_instances, 7)) + for obj_id in object_id_to_segs: + label_id = object_id_to_label_id[obj_id] + obj_pc = mesh_vertices[instance_ids == obj_id, 0:3] + if len(obj_pc) == 0: + continue + xyz_min = np.min(obj_pc, axis=0) + xyz_max = np.max(obj_pc, axis=0) + bbox = np.concatenate([(xyz_min + xyz_max) / 2.0, xyz_max - xyz_min, + np.array([label_id])]) + # NOTE: this assumes obj_id is in 1,2,3,.,,,.NUM_INSTANCES + instance_bboxes[obj_id - 1, :] = bbox + return instance_bboxes + + def export(mesh_file, agg_file, seg_file, meta_file, label_map_file, - output_file=None): + output_file=None, + test_mode=False): """Export original files to vert, ins_label, sem_label and bbox file. Args: @@ -68,6 +87,8 @@ def export(mesh_file, label_map_file (str): Path of the label_map_file. output_file (str): Path of the output folder. Default: None. + test_mode (bool): Whether is generating test data without labels. + Default: False. It returns a tuple, which containts the the following things: np.ndarray: Vertices of points data. 
@@ -83,6 +104,8 @@
 
     # Load scene axis alignment matrix
     lines = open(meta_file).readlines()
+    # test set data doesn't have align_matrix
+    axis_align_matrix = np.eye(4)
     for line in lines:
         if 'axisAlignment' in line:
             axis_align_matrix = [
@@ -91,56 +114,55 @@
             ]
             break
     axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4))
+
+    # perform global alignment of mesh vertices
     pts = np.ones((mesh_vertices.shape[0], 4))
     pts[:, 0:3] = mesh_vertices[:, 0:3]
     pts = np.dot(pts, axis_align_matrix.transpose())  # Nx4
-    mesh_vertices[:, 0:3] = pts[:, 0:3]
+    aligned_mesh_vertices = np.concatenate([pts[:, 0:3], mesh_vertices[:, 3:]],
+                                           axis=1)
 
     # Load semantic and instance labels
-    object_id_to_segs, label_to_segs = read_aggregation(agg_file)
-    seg_to_verts, num_verts = read_segmentation(seg_file)
-    label_ids = np.zeros(shape=(num_verts), dtype=np.uint32)
-    object_id_to_label_id = {}
-    for label, segs in label_to_segs.items():
-        label_id = label_map[label]
-        for seg in segs:
-            verts = seg_to_verts[seg]
-            label_ids[verts] = label_id
-    instance_ids = np.zeros(
-        shape=(num_verts), dtype=np.uint32)  # 0: unannotated
-    num_instances = len(np.unique(list(object_id_to_segs.keys())))
-    for object_id, segs in object_id_to_segs.items():
-        for seg in segs:
-            verts = seg_to_verts[seg]
-            instance_ids[verts] = object_id
-        if object_id not in object_id_to_label_id:
-            object_id_to_label_id[object_id] = label_ids[verts][0]
-    instance_bboxes = np.zeros((num_instances, 7))
-    for obj_id in object_id_to_segs:
-        label_id = object_id_to_label_id[obj_id]
-        obj_pc = mesh_vertices[instance_ids == obj_id, 0:3]
-        if len(obj_pc) == 0:
-            continue
-        xmin = np.min(obj_pc[:, 0])
-        ymin = np.min(obj_pc[:, 1])
-        zmin = np.min(obj_pc[:, 2])
-        xmax = np.max(obj_pc[:, 0])
-        ymax = np.max(obj_pc[:, 1])
-        zmax = np.max(obj_pc[:, 2])
-        bbox = np.array([(xmin + xmax) / 2, (ymin + ymax) / 2,
-                         (zmin + zmax) / 2, xmax - xmin, ymax - ymin,
-                         zmax - zmin, label_id])
-        # NOTE: this assumes obj_id is in 1,2,3,.,,,.NUM_INSTANCES
-        instance_bboxes[obj_id - 1, :] = bbox
+    if not test_mode:
+        object_id_to_segs, label_to_segs = read_aggregation(agg_file)
+        seg_to_verts, num_verts = read_segmentation(seg_file)
+        label_ids = np.zeros(shape=(num_verts), dtype=np.uint32)
+        object_id_to_label_id = {}
+        for label, segs in label_to_segs.items():
+            label_id = label_map[label]
+            for seg in segs:
+                verts = seg_to_verts[seg]
+                label_ids[verts] = label_id
+        instance_ids = np.zeros(
+            shape=(num_verts), dtype=np.uint32)  # 0: unannotated
+        for object_id, segs in object_id_to_segs.items():
+            for seg in segs:
+                verts = seg_to_verts[seg]
+                instance_ids[verts] = object_id
+            if object_id not in object_id_to_label_id:
+                object_id_to_label_id[object_id] = label_ids[verts][0]
+        unaligned_bboxes = extract_bbox(mesh_vertices, object_id_to_segs,
+                                        object_id_to_label_id, instance_ids)
+        aligned_bboxes = extract_bbox(aligned_mesh_vertices, object_id_to_segs,
+                                      object_id_to_label_id, instance_ids)
+    else:
+        label_ids = None
+        instance_ids = None
+        unaligned_bboxes = None
+        aligned_bboxes = None
+        object_id_to_label_id = None
 
     if output_file is not None:
         np.save(output_file + '_vert.npy', mesh_vertices)
-        np.save(output_file + '_sem_label.npy', label_ids)
-        np.save(output_file + '_ins_label.npy', instance_ids)
-        np.save(output_file + '_bbox.npy', instance_bboxes)
-
-    return mesh_vertices, label_ids, instance_ids,\
-        instance_bboxes, object_id_to_label_id
+        if not test_mode:
+            np.save(output_file + '_sem_label.npy', label_ids)
+            np.save(output_file + '_ins_label.npy', instance_ids)
+            np.save(output_file + '_unaligned_bbox.npy', unaligned_bboxes)
+            np.save(output_file + '_aligned_bbox.npy', aligned_bboxes)
+            np.save(output_file + '_axis_align_matrix.npy', axis_align_matrix)
+
+    return mesh_vertices, label_ids, instance_ids, unaligned_bboxes, \
+        aligned_bboxes, object_id_to_label_id, axis_align_matrix
 
 
 def main():
diff --git a/mmdet3d/datasets/scannet_monocular_dataset.py b/mmdet3d/datasets/scannet_monocular_dataset.py
index 48e730f..f00d9f8 100644
--- a/mmdet3d/datasets/scannet_monocular_dataset.py
+++ b/mmdet3d/datasets/scannet_monocular_dataset.py
@@ -16,21 +16,18 @@ class ScanNetMultiViewDataset(MultiViewMixin, Custom3DDataset):
     def get_data_info(self, index):
         info = self.data_infos[index]
         input_dict = defaultdict(list)
-        for i in range(len(info['image_paths'])):
-            img_filename = osp.join(self.data_root, info['image_paths'][i])
+        axis_align_matrix = info['annos']['axis_align_matrix'].astype(np.float32)
+        for i in range(len(info['img_paths'])):
+            img_filename = osp.join(self.data_root, info['img_paths'][i])
             input_dict['img_prefix'].append(None)
             input_dict['img_info'].append(dict(filename=img_filename))
-            extrinsic = np.linalg.inv(info['axis_align_matrix'] @ info['pose'][i])
+            extrinsic = np.linalg.inv(axis_align_matrix @ info['extrinsics'][i])
             input_dict['lidar2img'].append(extrinsic.astype(np.float32))
         input_dict = dict(input_dict)
-        # if info['annos']['gt_num'] != 0:
-        #     origin = np.mean(info['annos']['gt_boxes_upright_depth'][:, :3], axis=0)
-        # else:
-        #     origin = np.array([0, 0, 0])
         origin = np.array([.0, .0, .5])
         input_dict['lidar2img'] = dict(
             extrinsic=input_dict['lidar2img'],
-            intrinsic=info['intrinsic'].astype(np.float32),
+            intrinsic=info['intrinsics'].astype(np.float32),
             origin=origin.astype(np.float32)
         )
diff --git a/tools/create_data.py b/tools/create_data.py
index bc67b0d..9d14369 100644
--- a/tools/create_data.py
+++ b/tools/create_data.py
@@ -114,7 +114,7 @@ def lyft_data_prep(root_path,
         root_path, info_val_path, version=version)
 
 
-def scannet_data_prep(root_path, info_prefix, out_dir, workers, monocular):
+def scannet_data_prep(root_path, info_prefix, out_dir, workers):
     """Prepare the info file for scannet dataset.
 
     Args:
@@ -124,7 +124,7 @@ def scannet_data_prep(root_path, info_prefix, out_dir, workers, monocular):
         workers (int): Number of threads to be used.
""" indoor.create_indoor_info_file( - root_path, info_prefix, out_dir, workers=workers, monocular=monocular) + root_path, info_prefix, out_dir, workers=workers) def sunrgbd_data_prep(root_path, info_prefix, out_dir, workers, monocular): @@ -207,7 +207,7 @@ def waymo_data_prep(root_path, '--out-dir', type=str, default='./data/kitti', - required='False', + required=False, help='name of info pkl') parser.add_argument('--extra-tag', type=str, default='kitti') parser.add_argument( @@ -281,16 +281,7 @@ def waymo_data_prep(root_path, root_path=args.root_path, info_prefix=args.extra_tag, out_dir=args.out_dir, - workers=args.workers, - monocular=False - ) - elif args.dataset == 'scannet_monocular': - scannet_data_prep( - root_path=args.root_path, - info_prefix=args.extra_tag, - out_dir=args.out_dir, - workers=args.workers, - monocular=True + workers=args.workers ) elif args.dataset == 'sunrgbd': sunrgbd_data_prep( diff --git a/tools/data_converter/indoor_converter.py b/tools/data_converter/indoor_converter.py index 4c5ce24..986f5d9 100644 --- a/tools/data_converter/indoor_converter.py +++ b/tools/data_converter/indoor_converter.py @@ -1,7 +1,7 @@ import mmcv import os -from tools.data_converter.scannet_data_utils import ScanNetData, ScanNetMonocularData +from tools.data_converter.scannet_data_utils import ScanNetData from tools.data_converter.sunrgbd_data_utils import SUNRGBDData @@ -35,7 +35,7 @@ def create_indoor_info_file(data_path, val_dataset = SUNRGBDData( root_path=data_path, split='val', use_v1=use_v1, monocular=monocular) else: - dataset = ScanNetMonocularData if monocular else ScanNetData + dataset = ScanNetData train_dataset = dataset(root_path=data_path, split='train') val_dataset = dataset(root_path=data_path, split='val') diff --git a/tools/data_converter/scannet_data_utils.py b/tools/data_converter/scannet_data_utils.py index 68fef04..55a1010 100644 --- a/tools/data_converter/scannet_data_utils.py +++ b/tools/data_converter/scannet_data_utils.py @@ -1,15 +1,13 @@ -import os import mmcv import numpy as np +import os from concurrent import futures as futures from os import path as osp class ScanNetData(object): """ScanNet data. - Generate scannet infos for scannet_converter. - Args: root_path (str): Root path of the raw data. split (str): Set split type of the data. Default: 'train'. 
@@ -38,27 +36,59 @@ def __init__(self, root_path, split='train'):
                                   f'scannetv2_{split}.txt')
         mmcv.check_file_exist(split_file)
         self.sample_id_list = mmcv.list_from_file(split_file)
+        self.test_mode = (split == 'test')
 
     def __len__(self):
         return len(self.sample_id_list)
 
-    def get_box_label(self, idx):
-        box_file = osp.join(self.root_dir, 'scannet_train_instance_data',
-                            f'{idx}_bbox.npy')
+    def get_aligned_box_label(self, idx):
+        box_file = osp.join(self.root_dir, 'scannet_instance_data',
+                            f'{idx}_aligned_bbox.npy')
         mmcv.check_file_exist(box_file)
         return np.load(box_file)
 
+    def get_unaligned_box_label(self, idx):
+        box_file = osp.join(self.root_dir, 'scannet_instance_data',
+                            f'{idx}_unaligned_bbox.npy')
+        mmcv.check_file_exist(box_file)
+        return np.load(box_file)
+
+    def get_axis_align_matrix(self, idx):
+        matrix_file = osp.join(self.root_dir, 'scannet_instance_data',
+                               f'{idx}_axis_align_matrix.npy')
+        mmcv.check_file_exist(matrix_file)
+        return np.load(matrix_file)
+
+    def get_images(self, idx):
+        paths = []
+        path = osp.join(self.root_dir, 'posed_images', idx)
+        for file in sorted(os.listdir(path)):
+            if file.endswith('.jpg'):
+                paths.append(osp.join('posed_images', idx, file))
+        return paths
+
+    def get_extrinsics(self, idx):
+        extrinsics = []
+        path = osp.join(self.root_dir, 'posed_images', idx)
+        for file in sorted(os.listdir(path)):
+            if file.endswith('.txt') and file != 'intrinsic.txt':
+                extrinsics.append(np.loadtxt(osp.join(path, file)))
+        return extrinsics
+
+    def get_intrinsics(self, idx):
+        matrix_file = osp.join(self.root_dir, 'posed_images', idx,
+                               'intrinsic.txt')
+        mmcv.check_file_exist(matrix_file)
+        return np.loadtxt(matrix_file)
+
     def get_infos(self, num_workers=4, has_label=True, sample_id_list=None):
         """Get data infos.
-
         This method gets information from the raw data.
-
         Args:
             num_workers (int): Number of threads to be used. Default: 4.
             has_label (bool): Whether the data has label. Default: True.
             sample_id_list (list[int]): Index list of the sample.
                 Default: None.
-
         Returns:
             infos (list[dict]): Information of the raw data.
""" @@ -68,58 +98,87 @@ def process_single_scene(sample_idx): info = dict() pc_info = {'num_features': 6, 'lidar_idx': sample_idx} info['point_cloud'] = pc_info - pts_filename = osp.join(self.root_dir, - 'scannet_train_instance_data', + pts_filename = osp.join(self.root_dir, 'scannet_instance_data', f'{sample_idx}_vert.npy') - pts_instance_mask_path = osp.join(self.root_dir, - 'scannet_train_instance_data', - f'{sample_idx}_ins_label.npy') - pts_semantic_mask_path = osp.join(self.root_dir, - 'scannet_train_instance_data', - f'{sample_idx}_sem_label.npy') - points = np.load(pts_filename) - pts_instance_mask = np.load(pts_instance_mask_path).astype(np.long) - pts_semantic_mask = np.load(pts_semantic_mask_path).astype(np.long) - mmcv.mkdir_or_exist(osp.join(self.root_dir, 'points')) - mmcv.mkdir_or_exist(osp.join(self.root_dir, 'instance_mask')) - mmcv.mkdir_or_exist(osp.join(self.root_dir, 'semantic_mask')) - points.tofile( osp.join(self.root_dir, 'points', f'{sample_idx}.bin')) - pts_instance_mask.tofile( - osp.join(self.root_dir, 'instance_mask', f'{sample_idx}.bin')) - pts_semantic_mask.tofile( - osp.join(self.root_dir, 'semantic_mask', f'{sample_idx}.bin')) - info['pts_path'] = osp.join('points', f'{sample_idx}.bin') - info['pts_instance_mask_path'] = osp.join('instance_mask', - f'{sample_idx}.bin') - info['pts_semantic_mask_path'] = osp.join('semantic_mask', - f'{sample_idx}.bin') + + # update with RGB image paths if exist + if os.path.exists(osp.join(self.root_dir, 'posed_images')): + info['intrinsics'] = self.get_intrinsics(sample_idx) + all_extrinsics = self.get_extrinsics(sample_idx) + all_img_paths = self.get_images(sample_idx) + # some poses in ScanNet are invalid + extrinsics, img_paths = [], [] + for extrinsic, img_path in zip(all_extrinsics, all_img_paths): + if np.all(np.isfinite(extrinsic)): + img_paths.append(img_path) + extrinsics.append(extrinsic) + info['extrinsics'] = extrinsics + info['img_paths'] = img_paths + + if not self.test_mode: + pts_instance_mask_path = osp.join( + self.root_dir, 'scannet_instance_data', + f'{sample_idx}_ins_label.npy') + pts_semantic_mask_path = osp.join( + self.root_dir, 'scannet_instance_data', + f'{sample_idx}_sem_label.npy') + + pts_instance_mask = np.load(pts_instance_mask_path).astype( + np.long) + pts_semantic_mask = np.load(pts_semantic_mask_path).astype( + np.long) + + mmcv.mkdir_or_exist(osp.join(self.root_dir, 'instance_mask')) + mmcv.mkdir_or_exist(osp.join(self.root_dir, 'semantic_mask')) + + pts_instance_mask.tofile( + osp.join(self.root_dir, 'instance_mask', + f'{sample_idx}.bin')) + pts_semantic_mask.tofile( + osp.join(self.root_dir, 'semantic_mask', + f'{sample_idx}.bin')) + + info['pts_instance_mask_path'] = osp.join( + 'instance_mask', f'{sample_idx}.bin') + info['pts_semantic_mask_path'] = osp.join( + 'semantic_mask', f'{sample_idx}.bin') if has_label: annotations = {} - boxes_with_classes = self.get_box_label( - sample_idx) # k, 6 + class - annotations['gt_num'] = boxes_with_classes.shape[0] + # box is of shape [k, 6 + class] + aligned_box_label = self.get_aligned_box_label(sample_idx) + unaligned_box_label = self.get_unaligned_box_label(sample_idx) + annotations['gt_num'] = aligned_box_label.shape[0] if annotations['gt_num'] != 0: - minmax_boxes3d = boxes_with_classes[:, :-1] # k, 6 - classes = boxes_with_classes[:, -1] # k, 1 + aligned_box = aligned_box_label[:, :-1] # k, 6 + unaligned_box = unaligned_box_label[:, :-1] + classes = aligned_box_label[:, -1] # k annotations['name'] = np.array([ 
                        self.label2cat[self.cat_ids2class[classes[i]]]
                        for i in range(annotations['gt_num'])
                    ])
-                    annotations['location'] = minmax_boxes3d[:, :3]
-                    annotations['dimensions'] = minmax_boxes3d[:, 3:6]
-                    annotations['gt_boxes_upright_depth'] = minmax_boxes3d
+                    # default names are given to aligned bbox for compatibility
+                    # we also save unaligned bbox info with marked names
+                    annotations['location'] = aligned_box[:, :3]
+                    annotations['dimensions'] = aligned_box[:, 3:6]
+                    annotations['gt_boxes_upright_depth'] = aligned_box
+                    annotations['unaligned_location'] = unaligned_box[:, :3]
+                    annotations['unaligned_dimensions'] = unaligned_box[:, 3:6]
+                    annotations[
+                        'unaligned_gt_boxes_upright_depth'] = unaligned_box
                     annotations['index'] = np.arange(
                         annotations['gt_num'], dtype=np.int32)
                     annotations['class'] = np.array([
                         self.cat_ids2class[classes[i]]
                         for i in range(annotations['gt_num'])
                     ])
+                axis_align_matrix = self.get_axis_align_matrix(sample_idx)
+                annotations['axis_align_matrix'] = axis_align_matrix  # 4x4
                 info['annos'] = annotations
             return info
 
@@ -128,56 +187,3 @@ def process_single_scene(sample_idx):
         with futures.ThreadPoolExecutor(num_workers) as executor:
             infos = executor.map(process_single_scene, sample_id_list)
         return list(infos)
-
-
-class ScanNetMonocularData(ScanNetData):
-    def process_single_scene(self, sample_idx, has_label):
-        info = dict(image_paths=[], pose=[])
-
-        with open(os.path.join(self.root_dir, 'txts', f'{sample_idx}.txt')) as file:
-            for line in file.readlines():
-                splits = line.split(' = ')
-                if splits[0] == 'axisAlignment':
-                    axis_align_matrix = np.fromstring(splits[1], sep=' ').reshape(4, 4)
-                    break
-        info['axis_align_matrix'] = axis_align_matrix
-
-        frame_sub_path = f'sens_reader/scans/{sample_idx}/out'
-        frame_path = osp.join(self.root_dir, frame_sub_path)
-        base_file_names = {x.split('.')[0] for x in os.listdir(frame_path)}
-        base_file_names.remove('_info')
-        for base_file_name in base_file_names:
-            pose = np.loadtxt(osp.join(frame_path, f'{base_file_name}.pose.txt'))
-            if np.all(np.isfinite(pose)):
-                info['image_paths'].append(osp.join(frame_sub_path, f'{base_file_name}.color.jpg'))
-                info['pose'].append(pose)
-
-        with open(osp.join(frame_path, '_info.txt')) as file:
-            splits = file.readlines()[7].split(' = ')
-            assert splits[0] == 'm_calibrationColorIntrinsic'
-            info['intrinsic'] = np.fromstring(splits[1], sep=' ').reshape(4, 4)
-
-        if has_label:
-            annotations = {}
-            bbox_path = osp.join(self.root_dir, 'mmdetection3d', f'{sample_idx}_bbox.npy')
-            boxes_with_classes = np.load(bbox_path)
-            annotations['gt_num'] = boxes_with_classes.shape[0]
-            if annotations['gt_num'] != 0:
-                minmax_boxes3d = boxes_with_classes[:, :-1]  # k, 6
-                classes = boxes_with_classes[:, -1]  # k, 1
-                annotations['gt_boxes_upright_depth'] = minmax_boxes3d
-                annotations['class'] = np.array([
-                    self.cat_ids2class[classes[i]]
-                    for i in range(annotations['gt_num'])
-                ])
-            info['annos'] = annotations
-        return info
-
-    def get_infos(self, num_workers=4, has_label=True, sample_id_list=None):
-        sample_id_list = sample_id_list if sample_id_list is not None \
-            else self.sample_id_list
-        infos = []
-        for i, sample_idx in enumerate(sample_id_list):
-            print(f'{self.split} sample_idx: {sample_idx} {i}/{len(sample_id_list)}')
-            infos.append(self.process_single_scene(sample_idx, has_label))
-        return infos
diff --git a/tools/dist_test.py b/tools/dist_test.py
index b4180d9..73723e9 100644
--- a/tools/dist_test.py
+++ b/tools/dist_test.py
@@ -1,3 +1,3 @@
 import os
 
-os.system('bash tools/dist_test.sh configs/imvoxelnet/imvoxelnet_kitti.py work_dirs/atlas_kitti/20210503_214214.pth 2 --eval mAP')
+os.system('bash tools/dist_test.sh configs/imvoxelnet/imvoxelnet_kitti.py work_dirs/imvoxelnet_kitti/20210503_214214.pth 2 --eval mAP')
diff --git a/tools/dist_train.py b/tools/dist_train.py
index 32479c5..d9049ff 100644
--- a/tools/dist_train.py
+++ b/tools/dist_train.py
@@ -1,3 +1,3 @@
 import os
 
-os.system('bash tools/dist_train.sh configs/imvoxelnet/imvoxelnet_scannet.py 2')
\ No newline at end of file
+os.system('bash tools/dist_train.sh configs/imvoxelnet/imvoxelnet_kitti.py 2')
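
Below is a minimal sketch (not part of the patch) of how the `posed_images` output written by `extract_posed_images.py` can be read back. It mirrors `ScanNetData.get_intrinsics`/`get_extrinsics` from `tools/data_converter/scannet_data_utils.py`; the root path and scene id are illustrative examples.

```python
import os
import numpy as np

scene_dir = os.path.join('data', 'scannet', 'posed_images', 'scene0000_00')

# 4x4 color-camera intrinsics, one file per scene
intrinsic = np.loadtxt(os.path.join(scene_dir, 'intrinsic.txt'))

# per-frame 4x4 camera-to-world poses; some ScanNet poses contain inf/nan
# values, so they are filtered out exactly as the converter does
poses = []
for name in sorted(os.listdir(scene_dir)):
    if name.endswith('.txt') and name != 'intrinsic.txt':
        pose = np.loadtxt(os.path.join(scene_dir, name))
        if np.all(np.isfinite(pose)):
            poses.append(pose)

print(intrinsic.shape, len(poses))  # (4, 4) and up to 300 frames by default
```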