[Enhance] Support RGB images on ScanNet for multi-view detector #696

Merged: 9 commits, Jul 20, 2021
10 changes: 9 additions & 1 deletion data/scannet/README.md
@@ -6,7 +6,9 @@ We follow the procedure in [votenet](https://github.com/facebookresearch/votenet

2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`. Add the `--max_num_point 50000` flag if you only use the ScanNet data for the detection task. It will downsample the scenes to fewer points.

3. Enter the project root directory, generate training data by running
3. In this directory, extract RGB images with poses by running `python extract_posed_images.py`. This step is optional: skip it if you don't plan to use multi-view RGB images. Add `--max-images-per-scene -1` to disable the limit on the number of images per scene. ScanNet scenes contain up to 5000+ frames each; extracting all the `.jpg` images requires about 2 TB of disk space, while the recommended 300 images per scene require less than 100 GB. For example, the multi-view 3D detector ImVoxelNet samples 50 and 100 images per training and test scene, respectively.

4. Enter the project root directory, generate training data by running

```bash
python tools/create_data.py scannet --root-path ./data/scannet --out-dir ./data/scannet --extra-tag scannet
@@ -16,6 +18,7 @@ The overall process could be achieved through the following script

```bash
python batch_load_scannet_data.py
python extract_posed_images.py
cd ../..
python tools/create_data.py scannet --root-path ./data/scannet --out-dir ./data/scannet --extra-tag scannet
```
@@ -43,6 +46,11 @@ scannet
│ ├── train_resampled_scene_idxs.npy
│ ├── val_label_weight.npy
│ ├── val_resampled_scene_idxs.npy
├── posed_images
│ ├── scenexxxx_xx
│ │ ├── xxxxxx.txt
│ │ ├── xxxxxx.jpg
│ │ ├── intrinsic.txt
├── scannet_infos_train.pkl
├── scannet_infos_val.pkl
├── scannet_infos_test.pkl
179 changes: 179 additions & 0 deletions data/scannet/extract_posed_images.py
@@ -0,0 +1,179 @@
# Modified from https://github.com/ScanNet/ScanNet/blob/master/SensReader/python/SensorData.py # noqa
import imageio
import mmcv
import numpy as np
import os
import struct
import zlib
from argparse import ArgumentParser
from functools import partial

COMPRESSION_TYPE_COLOR = {-1: 'unknown', 0: 'raw', 1: 'png', 2: 'jpeg'}

COMPRESSION_TYPE_DEPTH = {
-1: 'unknown',
0: 'raw_ushort',
1: 'zlib_ushort',
2: 'occi_ushort'
}


class RGBDFrame:
"""Class for single ScanNet RGB-D image processing."""

def load(self, file_handle):
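# a frame stores a 4x4 camera-to-world pose, color/depth timestamps,
# the byte sizes of the compressed color/depth buffers and the buffers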
self.camera_to_world = np.asarray(
struct.unpack('f' * 16, file_handle.read(16 * 4)),
dtype=np.float32).reshape(4, 4)
self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0]
self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0]
self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0]
self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0]
self.color_data = b''.join(
struct.unpack('c' * self.color_size_bytes,
file_handle.read(self.color_size_bytes)))
self.depth_data = b''.join(
struct.unpack('c' * self.depth_size_bytes,
file_handle.read(self.depth_size_bytes)))

def decompress_depth(self, compression_type):
assert compression_type == 'zlib_ushort'
return zlib.decompress(self.depth_data)

def decompress_color(self, compression_type):
assert compression_type == 'jpeg'
return imageio.imread(self.color_data)


class SensorData:
"""Class for single ScanNet scene processing.

Single scene file contains multiple RGB-D images.
"""

def __init__(self, filename, limit):
self.version = 4
self.load(filename, limit)

def load(self, filename, limit):
with open(filename, 'rb') as f:
version = struct.unpack('I', f.read(4))[0]
assert self.version == version
strlen = struct.unpack('Q', f.read(8))[0]
self.sensor_name = b''.join(
struct.unpack('c' * strlen, f.read(strlen)))
self.intrinsic_color = np.asarray(
struct.unpack('f' * 16, f.read(16 * 4)),
dtype=np.float32).reshape(4, 4)
self.extrinsic_color = np.asarray(
struct.unpack('f' * 16, f.read(16 * 4)),
dtype=np.float32).reshape(4, 4)
self.intrinsic_depth = np.asarray(
struct.unpack('f' * 16, f.read(16 * 4)),
dtype=np.float32).reshape(4, 4)
self.extrinsic_depth = np.asarray(
struct.unpack('f' * 16, f.read(16 * 4)),
dtype=np.float32).reshape(4, 4)
self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack(
'i', f.read(4))[0]]
self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack(
'i', f.read(4))[0]]
self.color_width = struct.unpack('I', f.read(4))[0]
self.color_height = struct.unpack('I', f.read(4))[0]
self.depth_width = struct.unpack('I', f.read(4))[0]
self.depth_height = struct.unpack('I', f.read(4))[0]
self.depth_shift = struct.unpack('f', f.read(4))[0]
num_frames = struct.unpack('Q', f.read(8))[0]
self.frames = []
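# the .sens stream is sequential, so every frame must be parsed;
# when a positive limit is given, keep only a random subset of frames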
if limit > 0 and limit < num_frames:
index = np.random.choice(
np.arange(num_frames), limit, replace=False).tolist()
else:
index = list(range(num_frames))
for i in range(num_frames):
frame = RGBDFrame()
frame.load(f)
if i in index:
self.frames.append(frame)

def export_depth_images(self, output_path):
if not os.path.exists(output_path):
os.makedirs(output_path)
for f in range(len(self.frames)):
depth_data = self.frames[f].decompress_depth(
self.depth_compression_type)
# decode the raw ushort depth buffer (frombuffer, as fromstring is deprecated)
depth = np.frombuffer(
depth_data, dtype=np.uint16).reshape(self.depth_height,
self.depth_width)
imageio.imwrite(
os.path.join(output_path,
self.index_to_str(f) + '.png'), depth)

def export_color_images(self, output_path):
if not os.path.exists(output_path):
os.makedirs(output_path)
for f in range(len(self.frames)):
color = self.frames[f].decompress_color(
self.color_compression_type)
imageio.imwrite(
os.path.join(output_path,
self.index_to_str(f) + '.jpg'), color)

@staticmethod
def index_to_str(index):
return str(index).zfill(5)

@staticmethod
def save_mat_to_file(matrix, filename):
with open(filename, 'w') as f:
for line in matrix:
np.savetxt(f, line[np.newaxis], fmt='%f')

def export_poses(self, output_path):
if not os.path.exists(output_path):
os.makedirs(output_path)
for f in range(len(self.frames)):
self.save_mat_to_file(
self.frames[f].camera_to_world,
os.path.join(output_path,
self.index_to_str(f) + '.txt'))

def export_intrinsics(self, output_path):
if not os.path.exists(output_path):
os.makedirs(output_path)
self.save_mat_to_file(self.intrinsic_color,
os.path.join(output_path, 'intrinsic.txt'))


def process_scene(path, limit, idx):
"""Process single ScanNet scene.

Extract RGB images, poses and camera intrinsics.
"""
data = SensorData(os.path.join(path, idx, f'{idx}.sens'), limit)
output_path = os.path.join('posed_images', idx)
data.export_color_images(output_path)
data.export_intrinsics(output_path)
data.export_poses(output_path)


def process_directory(path, limit, nproc):
print(f'processing {path}')
mmcv.track_parallel_progress(
func=partial(process_scene, path, limit),
tasks=os.listdir(path),
nproc=nproc)


if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--max-images-per-scene', type=int, default=300)
parser.add_argument('--nproc', type=int, default=8)
args = parser.parse_args()

# process train and val scenes
if os.path.exists('scans'):
process_directory('scans', args.max_images_per_scene, args.nproc)
# process test scenes
if os.path.exists('scans_test'):
process_directory('scans_test', args.max_images_per_scene, args.nproc)
12 changes: 12 additions & 0 deletions docs/compatibility.md
@@ -2,6 +2,18 @@

This document provides detailed descriptions of the BC-breaking changes in MMDetection3D.

## MMDetection3D 0.16.0

### ScanNet dataset for ImVoxelNet

We adopt a new pre-processing procedure for the ScanNet dataset in order to support ImVoxelNet, a multi-view method that requires image data. In previous versions of MMDetection3D, the ScanNet dataset was only used for point cloud based 3D detection and segmentation methods. We plan to add ImVoxelNet to our model zoo, and thus update the ScanNet pre-processing accordingly by adding image-related steps. Specifically, we made these changes:

- Add [script](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/extract_posed_images.py) for extracting RGB data.
- Update [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/data_converter/scannet_data_utils.py) for annotation creation.
- Add instructions on preparing image data to the documentation.

Please refer to the ScanNet [README.md](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/README.md/) for more details.

## MMDetection3D 0.15.0

### MMCV Version
20 changes: 18 additions & 2 deletions docs/datasets/scannet_det.md
@@ -4,9 +4,9 @@

For the overall process, please refer to the [README](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/README.md/) page for ScanNet.

### Export ScanNet data
### Export ScanNet point cloud data

By exporting ScanNet data, we load the raw point cloud data and generate the relevant annotations including semantic label, instance label and ground truth bounding boxes.
By exporting ScanNet point cloud data, we load the raw point cloud data and generate the relevant annotations, including semantic labels, instance labels and ground truth bounding boxes.

```shell
python batch_load_scannet_data.py
@@ -127,6 +127,16 @@ def export(mesh_file,

After exporting each scan, the raw point cloud could be downsampled, e.g. to 50000 points, if the number of points is too large. In addition, invalid semantic labels outside of the `nyu40id` standard or optional `DONOT CARE` classes should be filtered. Finally, the point cloud data, semantic labels, instance labels and ground truth bounding boxes should be saved in `.npy` files.
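A minimal sketch of this export step (the function and file names here are illustrative, not the exact implementation):

```python
import numpy as np


def downsample_and_save(points, semantic_labels, instance_labels,
                        prefix, max_num_point=50000):
    """Downsample a scan and save its points and labels as .npy files."""
    if points.shape[0] > max_num_point:
        # sample without replacement so every kept point is unique
        choices = np.random.choice(
            points.shape[0], max_num_point, replace=False)
        points = points[choices]
        semantic_labels = semantic_labels[choices]
        instance_labels = instance_labels[choices]
    np.save(f'{prefix}_vert.npy', points)
    np.save(f'{prefix}_sem_label.npy', semantic_labels)
    np.save(f'{prefix}_ins_label.npy', instance_labels)
```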

### Export ScanNet RGB data

By exporting ScanNet RGB data, for each scene we load a set of RGB images with corresponding 4x4 pose matrices and a single 4x4 camera intrinsic matrix. Note that this step is optional and can be skipped if you do not plan to use multi-view detection.

```shell
python extract_posed_images.py
```

Each of the 1201 train, 312 validation and 100 test scenes contains a single `.sens` file. For instance, for scene `0001_01` we have `data/scannet/scans/scene0001_01/0001_01.sens`. For this scene, all images and poses are extracted to `data/scannet/posed_images/scene0001_01`. Specifically, there will be 300 image files `xxxxx.jpg`, 300 camera pose files `xxxxx.txt` and a single `intrinsic.txt` file. Typically, a single scene contains several thousand images; by default, we extract only 300 of them, requiring less than 100 GB of disk space in total. To extract more images, use the `--max-images-per-scene` parameter.
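The exported matrices are plain text and can be loaded with NumPy. A minimal sketch (the scene name and frame id below are illustrative):

```python
import numpy as np

scene = 'data/scannet/posed_images/scene0001_01'
# 4x4 color camera intrinsic matrix, shared by all frames of the scene
intrinsic = np.loadtxt(f'{scene}/intrinsic.txt')
# 4x4 camera-to-world pose of the first extracted frame
pose = np.loadtxt(f'{scene}/00000.txt')
assert intrinsic.shape == (4, 4) and pose.shape == (4, 4)
```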

### Create dataset

```shell
@@ -201,6 +211,11 @@ scannet
│ ├── train_resampled_scene_idxs.npy
│ ├── val_label_weight.npy
│ ├── val_resampled_scene_idxs.npy
├── posed_images
│ ├── scenexxxx_xx
│ │ ├── xxxxxx.txt
│ │ ├── xxxxxx.jpg
│ │ ├── intrinsic.txt
├── scannet_infos_train.pkl
├── scannet_infos_val.pkl
├── scannet_infos_test.pkl
@@ -209,6 +224,7 @@ scannet
- `points/xxxxx.bin`: The `axis-unaligned` point cloud data after downsampling. Note: the points will be axis-aligned by the `GlobalAlignment` pre-processing step of the 3D detection task.
- `instance_mask/xxxxx.bin`: The instance label for each point, value range: [0, NUM_INSTANCES], 0: unannotated.
- `semantic_mask/xxxxx.bin`: The semantic label for each point, value range: [1, 40], i.e. `nyu40id` standard. Note: the `nyu40id` ids will be mapped to train ids by the `PointSegClassMapping` transform in the training pipeline.
- `posed_images/scenexxxx_xx`: The set of `.jpg` images with their `.txt` 4x4 pose matrices and a single `intrinsic.txt` file with the camera intrinsic matrix.
- `scannet_infos_train.pkl`: The train data infos, the detailed info of each scan is as follows:
- info['point_cloud']: {'num_features': 6, 'lidar_idx': sample_idx}.
- info['pts_path']: The path of `points/xxxxx.bin`.
54 changes: 53 additions & 1 deletion mmdet3d/datasets/scannet_dataset.py
@@ -53,7 +53,7 @@ def __init__(self,
ann_file,
pipeline=None,
classes=None,
modality=None,
modality=dict(use_camera=False, use_depth=True),
box_type_3d='Depth',
filter_empty_gt=True,
test_mode=False):
@@ -66,6 +66,58 @@
box_type_3d=box_type_3d,
filter_empty_gt=filter_empty_gt,
test_mode=test_mode)
assert 'use_camera' in self.modality and \
'use_depth' in self.modality
assert self.modality['use_camera'] or self.modality['use_depth']
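# With this default, existing point cloud configs keep working; a
# multi-view detector such as ImVoxelNet could instead pass (illustrative
# values, not part of this diff):
#   modality=dict(use_camera=True, use_depth=False)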

def get_data_info(self, index):
"""Get data info according to the given index.

Args:
index (int): Index of the sample data to get.

Returns:
dict: Data information that will be passed to the data \
preprocessing pipelines. It includes the following keys:

- sample_idx (str): Sample index.
- pts_filename (str): Filename of point clouds.
- file_name (str): Filename of point clouds.
- img_prefix (str | None, optional): Prefix of image files.
- img_info (dict, optional): Image info.
- ann_info (dict): Annotation info.
"""
info = self.data_infos[index]
sample_idx = info['point_cloud']['lidar_idx']
pts_filename = osp.join(self.data_root, info['pts_path'])
input_dict = dict(sample_idx=sample_idx)

if self.modality['use_depth']:
input_dict['pts_filename'] = pts_filename
input_dict['file_name'] = pts_filename

if self.modality['use_camera']:
img_info = []
for img_path in info['img_paths']:
img_info.append(
dict(filename=osp.join(self.data_root, img_path)))
intrinsic = info['intrinsics']
axis_align_matrix = self._get_axis_align_matrix(info)
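# depth2img maps axis-aligned world coordinates back to the camera
# frame (by inverting the aligned camera-to-world pose) and projects
# them to image pixels with the color camera intrinsic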
depth2img = []
for extrinsic in info['extrinsics']:
depth2img.append(
intrinsic @ np.linalg.inv(axis_align_matrix @ extrinsic))

input_dict['img_prefix'] = None
input_dict['img_info'] = img_info
input_dict['depth2img'] = depth2img

if not self.test_mode:
annos = self.get_ann_info(index)
input_dict['ann_info'] = annos
if self.filter_empty_gt and ~(annos['gt_labels_3d'] != -1).any():
return None
return input_dict

def get_ann_info(self, index):
"""Get annotation info according to the given index.