Add Pandaset class #610

Closed
wants to merge 5 commits

47 changes: 47 additions & 0 deletions ml3d/configs/randlanet_pandaset.yml
@@ -0,0 +1,47 @@
dataset:
name: Pandaset
dataset_path: path_to_dataset
cache_dir: ./logs/cache
test_result_folder: './logs/test'
training_split: [ '001', '002', '003', '005', '011', '013', '015', '016',
'017', '019', '021', '023', '024', '027', '028', '029',
'030', '032', '033', '034', '035', '037', '038', '039',
'040', '041', '042', '043', '044', '046', '052', '053',
'054', '056', '057', '058', '064', '065', '066', '067',
'070', '071', '072', '073', '077', '078', '080', '084',
'088', '089', '090', '094', '095', '097', '098', '101',
'102', '103', '105', '106', '109', '110', '112', '113'
]
test_split: ['115', '116', '117', '119', '120', '124', '139', '149', '158']
validation_split: ['122', '123']
use_cache: true
sampler:
name: 'SemSegRandomSampler'
model:
name: RandLANet
batcher: DefaultBatcher
num_classes: 39
num_points: 81920
num_neighbors: 16
framework: torch
num_layers: 4
ignored_label_inds: [0]
sub_sampling_ratio: [4, 4, 4, 4]
in_channels: 3
dim_features: 8
dim_output: [16, 64, 128, 256]
grid_size: 0.06
pipeline:
name: SemanticSegmentation
max_epoch: 50
save_ckpt_freq: 5
device: gpu
optimizer:
lr: 0.001
batch_size: 4
main_log_dir: './logs'
logs_dir: './logs'
scheduler_gamma: 0.9886
test_batch_size: 2
train_sum_dir: './logs/training_log'
val_batch_size: 2
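For context, here is a minimal sketch of how this config could drive a training run through the Open3D-ML Python API once the Pandaset class from this PR is available; the dataset path is a placeholder and the open3d.ml namespace usage follows the project's README examples:

import open3d.ml as _ml3d
import open3d.ml.torch as ml3d

# Load the YAML above and point it at the real dataset location (placeholder path).
cfg = _ml3d.utils.Config.load_from_file("ml3d/configs/randlanet_pandaset.yml")
cfg.dataset['dataset_path'] = "/path/to/pandaset"

dataset = ml3d.datasets.Pandaset(**cfg.dataset)
model = ml3d.models.RandLANet(**cfg.model)
pipeline = ml3d.pipelines.SemanticSegmentation(model, dataset=dataset, **cfg.pipeline)
pipeline.run_train()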
272 changes: 272 additions & 0 deletions ml3d/datasets/pandaset.py
@@ -0,0 +1,272 @@
import os
from os.path import join
import numpy as np
import pandas as pd
from pathlib import Path
import logging

from .base_dataset import BaseDataset, BaseDatasetSplit
from ..utils import make_dir, DATASET

log = logging.getLogger(__name__)

class Pandaset(BaseDataset):
""" This class is used to create a dataset based on the Pandaset autonomous
driving dataset.

https://pandaset.org/

The dataset includes 42 semantic classes and covers more than 100 scenes,
each of which is 8 seconds long.

"""
def __init__(self,
dataset_path,
name="Pandaset",
cache_dir="./logs/cache",
use_cache=False,
ignored_label_inds=[],
test_result_folder='./logs/test_log',
test_split=['115', '116', '117', '119', '120', '124', '139', '149', '158'],
training_split=[
'001', '002', '003', '005', '011', '013', '015', '016',
'017', '019', '021', '023', '024', '027', '028', '029',
'030', '032', '033', '034', '035', '037', '038', '039',
'040', '041', '042', '043', '044', '046', '052', '053',
'054', '056', '057', '058', '064', '065', '066', '067',
'070', '071', '072', '073', '077', '078', '080', '084',
'088', '089', '090', '094', '095', '097', '098', '101',
'102', '103', '105', '106', '109', '110', '112', '113'
],
validation_split=['122', '123'],
all_split=['001', '002', '003', '005', '011', '013', '015', '016',
'017', '019', '021', '023', '024', '027', '028', '029',
'030', '032', '033', '034', '035', '037', '038', '039',
'040', '041', '042', '043', '044', '046', '052', '053',
'054', '056', '057', '058', '064', '065', '066', '067',
'069', '070', '071', '072', '073', '077', '078', '080',
'084', '088', '089', '090', '094', '095', '097', '098',
'101', '102', '103', '105', '106', '109', '110', '112',
'113', '115', '116', '117', '119', '120', '122', '123',
'124', '139', '149', '158'],
**kwargs):

"""Initialize the function by passing the dataset and other details.

Args:
dataset_path: The path to the dataset to use.
name: The name of the dataset.
cache_dir: The directory where the cache is stored.
use_cache: Indicates if the dataset should be cached.
ignored_label_inds: A list of labels that should be ignored in the dataset.
Returns:
class: The corresponding class.
"""
super().__init__(dataset_path=dataset_path,
name=name,
cache_dir=cache_dir,
use_cache=use_cache,
ignored_label_inds=ignored_label_inds,
test_result_folder=test_result_folder,
test_split=test_split,
training_split=training_split,
validation_split=validation_split,
all_split=all_split,
**kwargs)

cfg = self.cfg

self.label_to_names = self.get_label_to_names()
self.num_classes = len(self.label_to_names)
self.label_values = np.sort([k for k, v in self.label_to_names.items()])

@staticmethod
def get_label_to_names():
"""Returns a label to names dictionary object.

Returns:
A dict where keys are label numbers and
values are the corresponding names.
"""
label_to_names = {
1: "Reflection",
2: "Vegetation",
3: "Ground",
4: "Road",
5: "Lane Line Marking",
6: "Stop Line Marking",
7: "Other Road Marking",
8: "Sidewalk",
9: "Driveway",
10: "Car",
11: "Pickup Truck",
12: "Medium-sized Truck",
13: "Semi-truck",
14: "Towed Object",
15: "Motorcycle",
16: "Other Vehicle - Construction Vehicle",
17: "Other Vehicle - Uncommon",
18: "Other Vehicle - Pedicab",
19: "Emergency Vehicle",
20: "Bus",
21: "Personal Mobility Device",
22: "Motorized Scooter",
23: "Bicycle",
24: "Train",
25: "Trolley",
26: "Tram / Subway",
27: "Pedestrian",
28: "Pedestrian with Object",
29: "Animals - Bird",
30: "Animals - Other",
31: "Pylons",
32: "Road Barriers",
33: "Signs",
34: "Cones",
35: "Construction Signs",
36: "Temporary Construction Barriers",
37: "Rolling Containers",
38: "Building",
39: "Other Static Object"
}
return label_to_names

def get_split(self, split):
"""Returns a dataset split.

Args:
split: A string identifying the dataset split, usually one of
'training', 'test', 'validation', or 'all'.

Returns:
A dataset split object providing the requested subset of the data.
"""
return PandasetSplit(self, split=split)

def get_split_list(self, split):
"""Returns the list of data splits available.

Args:
split: A string identifying the dataset split that is usually one of
'training', 'test', 'validation', or 'all'.

Returns:
A list of file paths for the requested split.

Raises:
ValueError: Indicates that the split name passed is incorrect. The split name should be one of
'training', 'test', 'validation', or 'all'.
"""
cfg = self.cfg
dataset_path = cfg.dataset_path
file_list = []

if split in ['train', 'training']:
seq_list = cfg.training_split
elif split in ['test', 'testing']:
seq_list = cfg.test_split
elif split in ['val', 'validation']:
seq_list = cfg.validation_split
elif split in ['all']:
seq_list = cfg.all_split
else:
raise ValueError("Invalid split {}".format(split))

for seq_id in seq_list:
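# Each sequence keeps one gzip-pickled DataFrame per frame under
# <dataset_path>/<sequence_id>/lidar/, with matching semantic labels
# in <sequence_id>/annotations/semseg/.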
pc_path = join(dataset_path, seq_id, 'lidar')
for f in np.sort(os.listdir(pc_path)):
if f.split('.')[-1] == 'gz':
file_list.append(join(pc_path, f))

return file_list

def is_tested(self, attr):
"""Checks if a datum in the dataset has been tested.

Args:
attr: The attribute that needs to be checked.

Returns:
If the datum attribute is tested, then return the path where the
attribute is stored; else, returns false.
"""
pass

def save_test_result(self, results, attr):
"""Saves the output of a model.

Args:
results: The output of a model for the datum associated with the
attribute passed.
attr: The attributes that correspond to the outputs passed in
results.
"""
cfg = self.cfg
pred = results['predict_labels']
name = attr['name']

test_path = join(cfg.test_result_folder, 'sequences')
make_dir(test_path)
save_path = join(test_path, name, 'predictions')
make_dir(save_path)

for ign in cfg.ignored_label_inds:
pred[pred >= ign] += 1

store_path = join(save_path, name + '.label')

pred = pred.astype(np.uint32)
pred.tofile(store_path)


class PandasetSplit(BaseDatasetSplit):
"""This class is used to create a split for Pandaset dataset.

Args:
dataset: The dataset to split.
split: A string identifying the dataset split that is usually one of
'training', 'test', 'validation', or 'all'.
**kwargs: The configuration of the model as keyword arguments.

Returns:
A dataset split object providing the requested subset of the data.
"""
def __init__(self, dataset, split='train'):
super().__init__(dataset, split=split)
log.info("Found {} pointclouds for {}".format(len(self.path_list),
split))

def __len__(self):
return len(self.path_list)

def get_data(self, idx):
pc_path = self.path_list[idx]
label_path = pc_path.replace('lidar', 'annotations/semseg')
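# Both the lidar frame and its labels are gzip-pickled pandas DataFrames;
# the lidar frame has columns x, y, z, i (intensity), t (timestamp) and
# d (device/lidar id).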

points = pd.read_pickle(pc_path)
labels = pd.read_pickle(label_path)

intensity = points['i'].to_numpy().astype(np.float32)
points = points.drop(columns=['i', 't', 'd']).to_numpy().astype(np.float32)
labels = labels.to_numpy().astype(np.int32)

data = {
'point': points,
'intensity': intensity,
'label': labels
}

return data

def get_attr(self, idx):
pc_path = self.path_list[idx]
value = Path(pc_path).parent.parent.name  # sequence id, e.g. '001'
name = Path(pc_path).name.split('.')[0]
name = value + '_' + name

attr = {'name': name, 'path': pc_path, 'split': self.split}
return attr

DATASET._register_module(Pandaset)
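A short usage sketch for the new class itself, e.g. to inspect one frame of the validation split; the dataset path is a placeholder and the import assumes the repo layout in this PR:

from ml3d.datasets.pandaset import Pandaset

dataset = Pandaset(dataset_path="/path/to/pandaset")  # placeholder path
val_split = dataset.get_split('validation')
print(len(val_split), "point clouds in the validation split")

data = val_split.get_data(0)
print(data['point'].shape)      # (N, 3) xyz coordinates
print(data['intensity'].shape)  # (N,) per-point intensity
print(data['label'].shape)      # per-point semantic labels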
6 changes: 2 additions & 4 deletions ml3d/torch/models/kpconv.py
@@ -557,7 +557,7 @@ def inference_preprocess(self):

return inputs

def update_probs(self, inputs, results, test_probs, test_labels):
def update_probs(self, inputs, results, test_probs):
self.test_smooth = 0.95
stk_probs = torch.nn.functional.softmax(results, dim=-1)
stk_probs = stk_probs.cpu().data.numpy()
@@ -577,16 +577,14 @@ def update_probs(self, inputs, results, test_probs, test_labels):
for b_i, length in enumerate(lengths):
# Get prediction
probs = stk_probs[i0:i0 + length]
labels = np.argmax(probs, 1)

proj_inds = r_inds_list[b_i]
proj_mask = r_mask_list[b_i]
test_probs[proj_mask] = self.test_smooth * test_probs[proj_mask] + (
1 - self.test_smooth) * probs
test_labels[proj_mask] = labels
i0 += length

return test_probs, test_labels
return test_probs

def inference_end(self, inputs, results):
m_softmax = torch.nn.Softmax(dim=-1)
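The kpconv change above keeps only the smoothed probability buffer and drops the per-call label bookkeeping. For illustration (not code from this PR), the update applied to the re-projected points is an exponential moving average:

import numpy as np

def smooth_probs(test_probs, new_probs, proj_mask, test_smooth=0.95):
    # Keep 95% of the running estimate and blend in 5% of the new softmax output.
    test_probs[proj_mask] = (test_smooth * test_probs[proj_mask] +
                             (1 - test_smooth) * new_probs)
    return test_probs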
5 changes: 2 additions & 3 deletions ml3d/torch/models/point_transformer.py
@@ -304,14 +304,13 @@ def transform(self, data, attr):

return data

def update_probs(self, inputs, results, test_probs, test_labels):
def update_probs(self, inputs, results, test_probs):
result = results.reshape(-1, self.cfg.num_classes)
probs = torch.nn.functional.softmax(result, dim=-1).cpu().data.numpy()
labels = np.argmax(probs, 1)

self.trans_point_sampler(patchwise=False)

return probs, labels
return probs

def inference_begin(self):
data = self.preprocess(data, {'split': 'test'})
5 changes: 2 additions & 3 deletions ml3d/torch/models/pvcnn.py
@@ -250,14 +250,13 @@ def transform(self, data, attr):

return data

def update_probs(self, inputs, results, test_probs, test_labels):
def update_probs(self, inputs, results, test_probs):
result = results.reshape(-1, self.cfg.num_classes)
probs = torch.nn.functional.softmax(result, dim=-1).cpu().data.numpy()
labels = np.argmax(probs, 1)

self.trans_point_sampler(patchwise=False)

return probs, labels
return probs

def inference_begin(self, data):
data = self.preprocess(data, {'split': 'test'})
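With update_probs no longer returning labels in any of the three models, the caller is expected to derive them from the accumulated probabilities when needed. A hypothetical helper (not part of this PR) that mirrors the ignored-label shift used in Pandaset.save_test_result:

import numpy as np

def labels_from_probs(test_probs, ignored_label_inds=()):
    # Argmax over classes gives the predicted label per point; shift past
    # ignored label indices the same way save_test_result does.
    labels = np.argmax(test_probs, axis=1)
    for ign in sorted(ignored_label_inds):
        labels[labels >= ign] += 1
    return labels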