From ffe31b5730fc25ad3b94a475608d0cc628d67a98 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 23 Jul 2021 12:05:52 +0100 Subject: [PATCH] Temporary fix for RTD build (#605) * Try something * Try something * Try something * Try something * Try something * Try something * Try something * Try something * Add few more paths * Test * Drop * Add back, remove requires * Remove * task * temp * test * test * test * ttempt * Format code with autopep8 * attempt * attempt * temp * Format code with autopep8 * Fix a few * Format code with autopep8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Try fix * Try fix * Try fix * Try something * Try something * Try something * Try something * Cleaning * Fixes * Remove CI addition Co-authored-by: SeanNaren Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/api/audio.rst | 2 +- docs/source/api/pointcloud.rst | 16 - .../reference/pointcloud_segmentation.rst | 6 +- flash/audio/speech_recognition/model.py | 2 +- flash/core/utilities/imports.py | 2 +- flash/pointcloud/detection/data.py | 3 +- flash/pointcloud/detection/model.py | 4 +- .../detection/open3d_ml/data_sources.py | 10 +- flash/pointcloud/segmentation/data.py | 9 +- flash/pointcloud/segmentation/model.py | 2 +- .../pointcloud/segmentation/open3d_ml/app.py | 139 ++++---- .../open3d_ml/sequences_dataset.py | 303 +++++++++--------- requirements/datatype_pointcloud.txt | 2 +- 13 files changed, 244 insertions(+), 256 deletions(-) diff --git a/docs/source/api/audio.rst b/docs/source/api/audio.rst index 706a364372..ae6455c6d8 100644 --- a/docs/source/api/audio.rst +++ b/docs/source/api/audio.rst @@ -28,8 +28,8 @@ __________________ :nosignatures: :template: classtemplate.rst - ~speech_recognition.model.SpeechRecognition ~speech_recognition.data.SpeechRecognitionData + ~speech_recognition.model.SpeechRecognition speech_recognition.data.SpeechRecognitionPreprocess speech_recognition.data.SpeechRecognitionBackboneState diff --git a/docs/source/api/pointcloud.rst b/docs/source/api/pointcloud.rst index a98c6124f0..b71b335445 100644 --- a/docs/source/api/pointcloud.rst +++ b/docs/source/api/pointcloud.rst @@ -9,22 +9,6 @@ flash.pointcloud .. currentmodule:: flash.pointcloud -Segmentation -____________ - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - ~segmentation.model.PointCloudSegmentation - ~segmentation.data.PointCloudSegmentationData - - segmentation.data.PointCloudSegmentationPreprocess - segmentation.data.PointCloudSegmentationFoldersDataSource - segmentation.data.PointCloudSegmentationDatasetDataSource - - Object Detection ________________ diff --git a/docs/source/reference/pointcloud_segmentation.rst b/docs/source/reference/pointcloud_segmentation.rst index eec2fbf2b6..a44b67d396 100644 --- a/docs/source/reference/pointcloud_segmentation.rst +++ b/docs/source/reference/pointcloud_segmentation.rst @@ -57,9 +57,9 @@ Here's the structure: Learn more: http://www.semantic-kitti.org/dataset.html -Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.segmentation.data.PointCloudSegmentationData`. -We select a pre-trained ``randlanet_semantic_kitti`` backbone for our :class:`~flash.image.segmentation.model.PointCloudSegmentation` task. -We then use the trained :class:`~flash.image.segmentation.model.PointCloudSegmentation` for inference. +Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the ``PointCloudSegmentationData``. +We select a pre-trained ``randlanet_semantic_kitti`` backbone for our ``PointCloudSegmentation`` task. +We then use the trained ``PointCloudSegmentation`` for inference. Finally, we save the model. Here's the full example: diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py index 588f4f89b2..d62767a8d8 100644 --- a/flash/audio/speech_recognition/model.py +++ b/flash/audio/speech_recognition/model.py @@ -18,12 +18,12 @@ import torch import torch.nn as nn -from flash import Task from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding from flash.audio.speech_recognition.data import SpeechRecognitionBackboneState from flash.core.data.process import Serializer from flash.core.data.states import CollateFn +from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _AUDIO_AVAILABLE diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index d1ba3388b6..fc6c017bed 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -116,7 +116,7 @@ def _compare_version(package: str, op, version) -> bool: _SEGMENTATION_MODELS_AVAILABLE, ]) _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE -_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE +_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE _AUDIO_AVAILABLE = all([_ASTEROID_AVAILABLE, _TORCHAUDIO_AVAILABLE, _SOUNDFILE_AVAILABLE, _TRANSFORMERS_AVAILABLE]) _GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index 30c877e70d..59f6f893f9 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -4,9 +4,8 @@ from flash.core.data.base_viz import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import Deserializer from flash.core.data.data_source import BaseDataFormat, DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Preprocess +from flash.core.data.process import Deserializer, Preprocess from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE if _POINTCLOUD_AVAILABLE: diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index ff1e718484..d1abee600a 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -20,11 +20,11 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader, Sampler -import flash from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_source import DefaultDataKeys from flash.core.data.process import Serializer from flash.core.data.states import CollateFn +from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.apply_func import get_callable_dict from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE @@ -37,7 +37,7 @@ class PointCloudObjectDetectorSerializer(Serializer): pass -class PointCloudObjectDetector(flash.Task): +class PointCloudObjectDetector(Task): """The ``PointCloudObjectDetector`` is a :class:`~flash.core.classification.ClassificationTask` that classifies pointcloud data. diff --git a/flash/pointcloud/detection/open3d_ml/data_sources.py b/flash/pointcloud/detection/open3d_ml/data_sources.py index bd594ebe2f..f88a0c1ed3 100644 --- a/flash/pointcloud/detection/open3d_ml/data_sources.py +++ b/flash/pointcloud/detection/open3d_ml/data_sources.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from os.path import basename, dirname, exists, isdir, isfile, join -from posix import listdir from typing import Any, Dict, List, Optional, Union import yaml @@ -69,7 +69,7 @@ def load_meta(self, root_dir, dataset: Optional[BaseAutoDataset]): dataset.color_map = self.meta["color_map"] def load_data(self, folder: str, dataset: Optional[BaseAutoDataset]): - sub_directories = listdir(folder) + sub_directories = os.listdir(folder) if len(sub_directories) != 3: raise MisconfigurationException( f"Using KITTI Format, the {folder} should contains 3 directories " @@ -84,9 +84,9 @@ def load_data(self, folder: str, dataset: Optional[BaseAutoDataset]): labels_dir = join(folder, self.labels_folder_name) calibrations_dir = join(folder, self.calibrations_folder_name) - scan_paths = [join(scans_dir, f) for f in listdir(scans_dir)] - label_paths = [join(labels_dir, f) for f in listdir(labels_dir)] - calibration_paths = [join(calibrations_dir, f) for f in listdir(calibrations_dir)] + scan_paths = [join(scans_dir, f) for f in os.listdir(scans_dir)] + label_paths = [join(labels_dir, f) for f in os.listdir(labels_dir)] + calibration_paths = [join(calibrations_dir, f) for f in os.listdir(calibrations_dir)] assert len(scan_paths) == len(label_paths) == len(calibration_paths) diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py index 4ef0f4c596..18d63ce265 100644 --- a/flash/pointcloud/segmentation/data.py +++ b/flash/pointcloud/segmentation/data.py @@ -1,13 +1,10 @@ from typing import Any, Callable, Dict, Optional, Tuple from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import Deserializer from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Preprocess -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, requires_extras - -if _POINTCLOUD_AVAILABLE: - from flash.pointcloud.segmentation.open3d_ml.sequences_dataset import SequencesDataset +from flash.core.data.process import Deserializer, Preprocess +from flash.core.utilities.imports import requires_extras +from flash.pointcloud.segmentation.open3d_ml.sequences_dataset import SequencesDataset class PointCloudSegmentationDatasetDataSource(DataSource): diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index b3936acc21..f0b5fdcc29 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -23,7 +23,6 @@ from torch.utils.data import DataLoader, Sampler from torchmetrics import IoU -import flash from flash.core.classification import ClassificationTask from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_source import DefaultDataKeys @@ -112,6 +111,7 @@ def __init__( multi_label: bool = False, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudSegmentationSerializer(), ): + import flash if metrics is None: metrics = IoU(num_classes=num_classes) diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/flash/pointcloud/segmentation/open3d_ml/app.py index 879f45570e..f525ef64c9 100644 --- a/flash/pointcloud/segmentation/open3d_ml/app.py +++ b/flash/pointcloud/segmentation/open3d_ml/app.py @@ -13,87 +13,94 @@ # limitations under the License. import torch -from flash import DataModule +from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE if _POINTCLOUD_AVAILABLE: from open3d._ml3d.torch.dataloaders import TorchDataloader - from open3d._ml3d.vis.visualizer import LabelLUT, Visualizer + from open3d._ml3d.vis.visualizer import LabelLUT + from open3d._ml3d.vis.visualizer import Visualizer as Open3dVisualizer - class Visualizer(Visualizer): +else: - def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768): - """Visualize a dataset. + Open3dVisualizer = object - Example: - Minimal example for visualizing a dataset:: - import open3d.ml.torch as ml3d # or open3d.ml.tf as ml3d - dataset = ml3d.datasets.SemanticKITTI(dataset_path='/path/to/SemanticKITTI/') - vis = ml3d.vis.Visualizer() - vis.visualize_dataset(dataset, 'all', indices=range(100)) +class Visualizer(Open3dVisualizer): - Args: - dataset: The dataset to use for visualization. - split: The dataset split to be used, such as 'training' - indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4]. - width: The width of the visualization window. - height: The height of the visualization window. - """ - # Setup the labels + def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768): + """Visualize a dataset. + + Example: + Minimal example for visualizing a dataset:: + import open3d.ml.torch as ml3d # or open3d.ml.tf as ml3d + + dataset = ml3d.datasets.SemanticKITTI(dataset_path='/path/to/SemanticKITTI/') + vis = ml3d.vis.Visualizer() + vis.visualize_dataset(dataset, 'all', indices=range(100)) + + Args: + dataset: The dataset to use for visualization. + split: The dataset split to be used, such as 'training' + indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4]. + width: The width of the visualization window. + height: The height of the visualization window. + """ + # Setup the labels + lut = LabelLUT() + color_map = dataset.color_map + for id, val in dataset.label_to_names.items(): + lut.add_label(val, id, color=color_map[id]) + self.set_lut("labels", lut) + + self._consolidate_bounding_boxes = True + self._init_dataset(dataset, split, indices) + self._visualize("Open3D - " + dataset.name, width, height) + + +class App: + + def __init__(self, datamodule: DataModule): + self.datamodule = datamodule + self._enabled = True # not flash._IS_TESTING + + def get_dataset(self, stage: str = "train"): + dataloader = getattr(self.datamodule, f"{stage}_dataloader")() + dataset = dataloader.dataset.dataset + if isinstance(dataset, TorchDataloader): + return dataset.dataset + return dataset + + def show_train_dataset(self, indices=None): + if self._enabled: + dataset = self.get_dataset("train") + viz = Visualizer() + viz.visualize_dataset(dataset, 'all', indices=indices) + + def show_predictions(self, predictions): + if self._enabled: + dataset = self.get_dataset("train") + color_map = dataset.color_map + + predictions_visualizations = [] + for pred in predictions: + predictions_visualizations.append({ + "points": torch.stack(pred[DefaultDataKeys.INPUT]), + "labels": torch.stack(pred[DefaultDataKeys.TARGET]), + "predictions": torch.argmax(torch.stack(pred[DefaultDataKeys.PREDS]), axis=-1) + 1, + "name": pred[DefaultDataKeys.METADATA]["name"], + }) + + viz = Visualizer() lut = LabelLUT() color_map = dataset.color_map for id, val in dataset.label_to_names.items(): lut.add_label(val, id, color=color_map[id]) - self.set_lut("labels", lut) - - self._consolidate_bounding_boxes = True - self._init_dataset(dataset, split, indices) - self._visualize("Open3D - " + dataset.name, width, height) - - class App: - - def __init__(self, datamodule: DataModule): - self.datamodule = datamodule - self._enabled = True # not flash._IS_TESTING - - def get_dataset(self, stage: str = "train"): - dataloader = getattr(self.datamodule, f"{stage}_dataloader")() - dataset = dataloader.dataset.dataset - if isinstance(dataset, TorchDataloader): - return dataset.dataset - return dataset - - def show_train_dataset(self, indices=None): - if self._enabled: - dataset = self.get_dataset("train") - viz = Visualizer() - viz.visualize_dataset(dataset, 'all', indices=indices) - - def show_predictions(self, predictions): - if self._enabled: - dataset = self.get_dataset("train") - color_map = dataset.color_map - - predictions_visualizations = [] - for pred in predictions: - predictions_visualizations.append({ - "points": torch.stack(pred[DefaultDataKeys.INPUT]), - "labels": torch.stack(pred[DefaultDataKeys.TARGET]), - "predictions": torch.argmax(torch.stack(pred[DefaultDataKeys.PREDS]), axis=-1) + 1, - "name": pred[DefaultDataKeys.METADATA]["name"], - }) - - viz = Visualizer() - lut = LabelLUT() - color_map = dataset.color_map - for id, val in dataset.label_to_names.items(): - lut.add_label(val, id, color=color_map[id]) - viz.set_lut("labels", lut) - viz.set_lut("predictions", lut) - viz.visualize(predictions_visualizations) + viz.set_lut("labels", lut) + viz.set_lut("predictions", lut) + viz.visualize(predictions_visualizations) def launch_app(datamodule: DataModule) -> 'App': diff --git a/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py index 0609e2e098..1ad0608e87 100644 --- a/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py +++ b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py @@ -26,156 +26,157 @@ from open3d._ml3d.datasets.utils import DataProcessing from open3d._ml3d.utils.config import Config - class SequencesDataset(Dataset): - - def __init__( - self, - data, - cache_dir='./logs/cache', - use_cache=False, - num_points=65536, - ignored_label_inds=[0], - predicting=False, - **kwargs - ): - - super().__init__() - - self.name = "Dataset" - self.ignored_label_inds = ignored_label_inds - - kwargs["cache_dir"] = cache_dir - kwargs["use_cache"] = use_cache - kwargs["num_points"] = num_points - kwargs["ignored_label_inds"] = ignored_label_inds - - self.cfg = Config(kwargs) - self.predicting = predicting - - if not predicting: - self.on_fit(data) - else: - self.on_predict(data) - - @property - def color_map(self): - return self.meta["color_map"] - - def on_fit(self, dataset_path): - self.split = basename(dataset_path) - - self.load_meta(dirname(dataset_path)) - self.dataset_path = dataset_path - self.label_to_names = self.get_label_to_names() - self.num_classes = len(self.label_to_names) - len(self.ignored_label_inds) - self.make_datasets() - - def load_meta(self, root_dir): - meta_file = join(root_dir, "meta.yaml") - if not exists(meta_file): - raise MisconfigurationException( - f"The {root_dir} should contain a `meta.yaml` file about the pointcloud sequences." - ) - - with open(meta_file, 'r') as f: - self.meta = yaml.safe_load(f) - - self.label_to_names = self.get_label_to_names() - self.num_classes = len(self.label_to_names) - - with open(meta_file, 'r') as f: - self.meta = yaml.safe_load(f) - - remap_dict_val = self.meta["learning_map"] - max_key = max(remap_dict_val.keys()) - remap_lut_val = np.zeros((max_key + 100), dtype=np.int32) - remap_lut_val[list(remap_dict_val.keys())] = list(remap_dict_val.values()) - - self.remap_lut_val = remap_lut_val - - def make_datasets(self): - self.path_list = [] - for seq in os.listdir(self.dataset_path): - sequence_path = join(self.dataset_path, seq) - directories = [f for f in os.listdir(sequence_path) if isdir(join(sequence_path, f)) and f != "labels"] - assert len(directories) == 1 - scan_dir = join(sequence_path, directories[0]) - for scan_name in os.listdir(scan_dir): - self.path_list.append(join(scan_dir, scan_name)) - - def on_predict(self, data): - if isinstance(data, list): - if not all(isfile(p) for p in data): - raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") - root_dir = split(data[0])[0] - elif isinstance(data, str): - if not isdir(data) and not isfile(data): - raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") - if isdir(data): - root_dir = data - data = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if ".bin" in f] - elif isfile(data): - root_dir = dirname(data) - data = [data] - else: - raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") - else: - raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") - - self.path_list = data - self.split = "predict" - self.load_meta(root_dir) - - def get_label_to_names(self): - """Returns a label to names dictonary object. - Returns: - A dict where keys are label numbers and - values are the corresponding names. - """ - return self.meta["label_to_names"] - - def __getitem__(self, index): - data = self.get_data(index) - data['attr'] = self.get_attr(index) - return data - - def get_data(self, idx): - pc_path = self.path_list[idx] - points = DataProcessing.load_pc_kitti(pc_path) - - dir, file = split(pc_path) - if self.predicting: - label_path = join(dir, file[:-4] + '.label') - else: - label_path = join(dir, '../labels', file[:-4] + '.label') - if not exists(label_path): - labels = np.zeros(np.shape(points)[0], dtype=np.int32) - if self.split not in ['test', 'all']: - raise FileNotFoundError(f' Label file {label_path} not found') +class SequencesDataset(Dataset): + + def __init__( + self, + data, + cache_dir='./logs/cache', + use_cache=False, + num_points=65536, + ignored_label_inds=[0], + predicting=False, + **kwargs + ): + + super().__init__() + + self.name = "Dataset" + self.ignored_label_inds = ignored_label_inds + + kwargs["cache_dir"] = cache_dir + kwargs["use_cache"] = use_cache + kwargs["num_points"] = num_points + kwargs["ignored_label_inds"] = ignored_label_inds + + self.cfg = Config(kwargs) + self.predicting = predicting + + if not predicting: + self.on_fit(data) + else: + self.on_predict(data) + + @property + def color_map(self): + return self.meta["color_map"] + + def on_fit(self, dataset_path): + self.split = basename(dataset_path) + + self.load_meta(dirname(dataset_path)) + self.dataset_path = dataset_path + self.label_to_names = self.get_label_to_names() + self.num_classes = len(self.label_to_names) - len(self.ignored_label_inds) + self.make_datasets() + + def load_meta(self, root_dir): + meta_file = join(root_dir, "meta.yaml") + if not exists(meta_file): + raise MisconfigurationException( + f"The {root_dir} should contain a `meta.yaml` file about the pointcloud sequences." + ) + + with open(meta_file, 'r') as f: + self.meta = yaml.safe_load(f) + + self.label_to_names = self.get_label_to_names() + self.num_classes = len(self.label_to_names) + + with open(meta_file, 'r') as f: + self.meta = yaml.safe_load(f) + + remap_dict_val = self.meta["learning_map"] + max_key = max(remap_dict_val.keys()) + remap_lut_val = np.zeros((max_key + 100), dtype=np.int32) + remap_lut_val[list(remap_dict_val.keys())] = list(remap_dict_val.values()) + + self.remap_lut_val = remap_lut_val + + def make_datasets(self): + self.path_list = [] + for seq in os.listdir(self.dataset_path): + sequence_path = join(self.dataset_path, seq) + directories = [f for f in os.listdir(sequence_path) if isdir(join(sequence_path, f)) and f != "labels"] + assert len(directories) == 1 + scan_dir = join(sequence_path, directories[0]) + for scan_name in os.listdir(scan_dir): + self.path_list.append(join(scan_dir, scan_name)) + + def on_predict(self, data): + if isinstance(data, list): + if not all(isfile(p) for p in data): + raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") + root_dir = split(data[0])[0] + elif isinstance(data, str): + if not isdir(data) and not isfile(data): + raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") + if isdir(data): + root_dir = data + data = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if ".bin" in f] + elif isfile(data): + root_dir = dirname(data) + data = [data] else: - labels = DataProcessing.load_label_kitti(label_path, self.remap_lut_val).astype(np.int32) - - data = { - 'point': points[:, 0:3], - 'feat': None, - 'label': labels, - } - - return data - - def get_attr(self, idx): - pc_path = self.path_list[idx] - dir, file = split(pc_path) - _, seq = split(split(dir)[0]) - name = '{}_{}'.format(seq, file[:-4]) - - pc_path = str(pc_path) - attr = {'idx': idx, 'name': name, 'path': pc_path, 'split': self.split} - return attr - - def __len__(self): - return len(self.path_list) - - def get_split(self, *_): - return self + raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") + else: + raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") + + self.path_list = data + self.split = "predict" + self.load_meta(root_dir) + + def get_label_to_names(self): + """Returns a label to names dictonary object. + Returns: + A dict where keys are label numbers and + values are the corresponding names. + """ + return self.meta["label_to_names"] + + def __getitem__(self, index): + data = self.get_data(index) + data['attr'] = self.get_attr(index) + return data + + def get_data(self, idx): + pc_path = self.path_list[idx] + points = DataProcessing.load_pc_kitti(pc_path) + + dir, file = split(pc_path) + if self.predicting: + label_path = join(dir, file[:-4] + '.label') + else: + label_path = join(dir, '../labels', file[:-4] + '.label') + if not exists(label_path): + labels = np.zeros(np.shape(points)[0], dtype=np.int32) + if self.split not in ['test', 'all']: + raise FileNotFoundError(f' Label file {label_path} not found') + + else: + labels = DataProcessing.load_label_kitti(label_path, self.remap_lut_val).astype(np.int32) + + data = { + 'point': points[:, 0:3], + 'feat': None, + 'label': labels, + } + + return data + + def get_attr(self, idx): + pc_path = self.path_list[idx] + dir, file = split(pc_path) + _, seq = split(split(dir)[0]) + name = '{}_{}'.format(seq, file[:-4]) + + pc_path = str(pc_path) + attr = {'idx': idx, 'name': name, 'path': pc_path, 'split': self.split} + return attr + + def __len__(self): + return len(self.path_list) + + def get_split(self, *_): + return self diff --git a/requirements/datatype_pointcloud.txt b/requirements/datatype_pointcloud.txt index 544ab6061b..cc6437f44c 100644 --- a/requirements/datatype_pointcloud.txt +++ b/requirements/datatype_pointcloud.txt @@ -1,4 +1,4 @@ -open3d +open3d==0.13 torch==1.7.1 torchvision tensorboard