This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Temporary fix for RTD build (#605)
* Try something

* Try something

* Try something

* Try something

* Try something

* Try something

* Try something

* Try something

* Add few more paths

* Test

* Drop

* Add back, remove requires

* Remove

* task

* temp

* test

* test

* test

* attempt

* Format code with autopep8

* attempt

* attempt

* temp

* Format code with autopep8

* Fix a few

* Format code with autopep8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Try fix

* Try fix

* Try fix

* Try something

* Try something

* Try something

* Try something

* Cleaning

* Fixes

* Remove CI addition

Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Jul 23, 2021
1 parent 0f6bb7e commit ffe31b5
Showing 13 changed files with 244 additions and 256 deletions.
2 changes: 1 addition & 1 deletion docs/source/api/audio.rst
@@ -28,8 +28,8 @@ __________________
     :nosignatures:
     :template: classtemplate.rst

-    ~speech_recognition.model.SpeechRecognition
     ~speech_recognition.data.SpeechRecognitionData
+    ~speech_recognition.model.SpeechRecognition

     speech_recognition.data.SpeechRecognitionPreprocess
     speech_recognition.data.SpeechRecognitionBackboneState
16 changes: 0 additions & 16 deletions docs/source/api/pointcloud.rst
@@ -9,22 +9,6 @@ flash.pointcloud

 .. currentmodule:: flash.pointcloud

-Segmentation
-____________
-
-.. autosummary::
-    :toctree: generated/
-    :nosignatures:
-    :template: classtemplate.rst
-
-    ~segmentation.model.PointCloudSegmentation
-    ~segmentation.data.PointCloudSegmentationData
-
-    segmentation.data.PointCloudSegmentationPreprocess
-    segmentation.data.PointCloudSegmentationFoldersDataSource
-    segmentation.data.PointCloudSegmentationDatasetDataSource
-
-
 Object Detection
 ________________

6 changes: 3 additions & 3 deletions docs/source/reference/pointcloud_segmentation.rst
@@ -57,9 +57,9 @@ Here's the structure:
 Learn more: http://www.semantic-kitti.org/dataset.html


-Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.segmentation.data.PointCloudSegmentationData`.
-We select a pre-trained ``randlanet_semantic_kitti`` backbone for our :class:`~flash.image.segmentation.model.PointCloudSegmentation` task.
-We then use the trained :class:`~flash.image.segmentation.model.PointCloudSegmentation` for inference.
+Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the ``PointCloudSegmentationData``.
+We select a pre-trained ``randlanet_semantic_kitti`` backbone for our ``PointCloudSegmentation`` task.
+We then use the trained ``PointCloudSegmentation`` for inference.
 Finally, we save the model.
 Here's the full example:

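The full example referenced above is not included in this diff. As a rough sketch of the workflow the paragraph describes (the dataset URL, folder layout, and exact `from_folders` arguments are assumptions, not copied from the repository):

import torch

import flash
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData

# 1. Download a small SemanticKITTI-style sample (URL and paths are illustrative).
download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")

# 2. Create the DataModule from KITTI-formatted folders.
datamodule = PointCloudSegmentationData.from_folders(
    train_folder="data/SemanticKittiTiny/train",
    val_folder="data/SemanticKittiTiny/val",
)

# 3. Build the task with the pre-trained backbone named in the docs above.
model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)

# 4. Finetune briefly, run inference on a folder of scans, and save the checkpoint.
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)
predictions = model.predict(["data/SemanticKittiTiny/predict"])
trainer.save_checkpoint("pointcloud_segmentation_model.pt")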
2 changes: 1 addition & 1 deletion flash/audio/speech_recognition/model.py
@@ -18,12 +18,12 @@
 import torch
 import torch.nn as nn

-from flash import Task
 from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES
 from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding
 from flash.audio.speech_recognition.data import SpeechRecognitionBackboneState
 from flash.core.data.process import Serializer
 from flash.core.data.states import CollateFn
+from flash.core.model import Task
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.imports import _AUDIO_AVAILABLE

2 changes: 1 addition & 1 deletion flash/core/utilities/imports.py
@@ -116,7 +116,7 @@ def _compare_version(package: str, op, version) -> bool:
     _SEGMENTATION_MODELS_AVAILABLE,
 ])
 _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE
-_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE
+_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE
 _AUDIO_AVAILABLE = all([_ASTEROID_AVAILABLE, _TORCHAUDIO_AVAILABLE, _SOUNDFILE_AVAILABLE, _TRANSFORMERS_AVAILABLE])
 _GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE

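For context, these module-level flags gate optional imports so that `import flash` succeeds (for example during a Read the Docs build) even when the optional extras are missing. A minimal sketch of the pattern, using a hypothetical `_module_available` helper rather than the repository's actual helpers:

import importlib.util


def _module_available(name: str) -> bool:
    # True when the module can be located without actually importing it.
    return importlib.util.find_spec(name) is not None


_OPEN3D_AVAILABLE = _module_available("open3d")
_TORCHVISION_AVAILABLE = _module_available("torchvision")

# After this commit, point cloud support requires both open3d and torchvision.
_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE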
3 changes: 1 addition & 2 deletions flash/pointcloud/detection/data.py
@@ -4,9 +4,8 @@

 from flash.core.data.base_viz import BaseDataFetcher
 from flash.core.data.data_module import DataModule
-from flash.core.data.data_pipeline import Deserializer
 from flash.core.data.data_source import BaseDataFormat, DataSource, DefaultDataKeys, DefaultDataSources
-from flash.core.data.process import Preprocess
+from flash.core.data.process import Deserializer, Preprocess
 from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE

 if _POINTCLOUD_AVAILABLE:
4 changes: 2 additions & 2 deletions flash/pointcloud/detection/model.py
@@ -20,11 +20,11 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader, Sampler

-import flash
 from flash.core.data.auto_dataset import BaseAutoDataset
 from flash.core.data.data_source import DefaultDataKeys
 from flash.core.data.process import Serializer
 from flash.core.data.states import CollateFn
+from flash.core.model import Task
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.apply_func import get_callable_dict
 from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
@@ -37,7 +37,7 @@ class PointCloudObjectDetectorSerializer(Serializer):
     pass


-class PointCloudObjectDetector(flash.Task):
+class PointCloudObjectDetector(Task):
     """The ``PointCloudObjectDetector`` is a :class:`~flash.core.classification.ClassificationTask` that classifies
     pointcloud data.
10 changes: 5 additions & 5 deletions flash/pointcloud/detection/open3d_ml/data_sources.py
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from os.path import basename, dirname, exists, isdir, isfile, join
-from posix import listdir
 from typing import Any, Dict, List, Optional, Union

 import yaml
@@ -69,7 +69,7 @@ def load_meta(self, root_dir, dataset: Optional[BaseAutoDataset]):
         dataset.color_map = self.meta["color_map"]

     def load_data(self, folder: str, dataset: Optional[BaseAutoDataset]):
-        sub_directories = listdir(folder)
+        sub_directories = os.listdir(folder)
         if len(sub_directories) != 3:
             raise MisconfigurationException(
                 f"Using KITTI Format, the {folder} should contains 3 directories "
@@ -84,9 +84,9 @@ def load_data(self, folder: str, dataset: Optional[BaseAutoDataset]):
         labels_dir = join(folder, self.labels_folder_name)
         calibrations_dir = join(folder, self.calibrations_folder_name)

-        scan_paths = [join(scans_dir, f) for f in listdir(scans_dir)]
-        label_paths = [join(labels_dir, f) for f in listdir(labels_dir)]
-        calibration_paths = [join(calibrations_dir, f) for f in listdir(calibrations_dir)]
+        scan_paths = [join(scans_dir, f) for f in os.listdir(scans_dir)]
+        label_paths = [join(labels_dir, f) for f in os.listdir(labels_dir)]
+        calibration_paths = [join(calibrations_dir, f) for f in os.listdir(calibrations_dir)]

         assert len(scan_paths) == len(label_paths) == len(calibration_paths)

9 changes: 3 additions & 6 deletions flash/pointcloud/segmentation/data.py
@@ -1,13 +1,10 @@
 from typing import Any, Callable, Dict, Optional, Tuple

 from flash.core.data.data_module import DataModule
-from flash.core.data.data_pipeline import Deserializer
 from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources
-from flash.core.data.process import Preprocess
-from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, requires_extras
-
-if _POINTCLOUD_AVAILABLE:
-    from flash.pointcloud.segmentation.open3d_ml.sequences_dataset import SequencesDataset
+from flash.core.data.process import Deserializer, Preprocess
+from flash.core.utilities.imports import requires_extras
+from flash.pointcloud.segmentation.open3d_ml.sequences_dataset import SequencesDataset


 class PointCloudSegmentationDatasetDataSource(DataSource):
2 changes: 1 addition & 1 deletion flash/pointcloud/segmentation/model.py
@@ -23,7 +23,6 @@
 from torch.utils.data import DataLoader, Sampler
 from torchmetrics import IoU

-import flash
 from flash.core.classification import ClassificationTask
 from flash.core.data.auto_dataset import BaseAutoDataset
 from flash.core.data.data_source import DefaultDataKeys
@@ -112,6 +111,7 @@ def __init__(
         multi_label: bool = False,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudSegmentationSerializer(),
     ):
+        import flash
         if metrics is None:
             metrics = IoU(num_classes=num_classes)

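The second hunk above moves `import flash` from module level into `__init__`. A toy sketch of that deferred-import pattern (the class and attribute here are illustrative, not the real task):

class DeferredFlashImport:
    """Toy sketch (not the real task) of the deferred-import pattern above."""

    def __init__(self) -> None:
        # Importing flash here, at instantiation time rather than at module import
        # time, avoids a circular import while flash's package __init__ is still
        # being built and keeps the module importable in lightweight environments.
        import flash

        self.flash_version = getattr(flash, "__version__", "unknown")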
139 changes: 73 additions & 66 deletions flash/pointcloud/segmentation/open3d_ml/app.py
@@ -13,87 +13,94 @@
# limitations under the License.
import torch

from flash import DataModule
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataKeys
from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE

if _POINTCLOUD_AVAILABLE:

from open3d._ml3d.torch.dataloaders import TorchDataloader
from open3d._ml3d.vis.visualizer import LabelLUT, Visualizer
from open3d._ml3d.vis.visualizer import LabelLUT
from open3d._ml3d.vis.visualizer import Visualizer as Open3dVisualizer

class Visualizer(Visualizer):
else:

def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768):
"""Visualize a dataset.
Open3dVisualizer = object

Example:
Minimal example for visualizing a dataset::
import open3d.ml.torch as ml3d # or open3d.ml.tf as ml3d

dataset = ml3d.datasets.SemanticKITTI(dataset_path='/path/to/SemanticKITTI/')
vis = ml3d.vis.Visualizer()
vis.visualize_dataset(dataset, 'all', indices=range(100))
class Visualizer(Open3dVisualizer):

Args:
dataset: The dataset to use for visualization.
split: The dataset split to be used, such as 'training'
indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4].
width: The width of the visualization window.
height: The height of the visualization window.
"""
# Setup the labels
def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768):
"""Visualize a dataset.
Example:
Minimal example for visualizing a dataset::
import open3d.ml.torch as ml3d # or open3d.ml.tf as ml3d
dataset = ml3d.datasets.SemanticKITTI(dataset_path='/path/to/SemanticKITTI/')
vis = ml3d.vis.Visualizer()
vis.visualize_dataset(dataset, 'all', indices=range(100))
Args:
dataset: The dataset to use for visualization.
split: The dataset split to be used, such as 'training'
indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4].
width: The width of the visualization window.
height: The height of the visualization window.
"""
# Setup the labels
lut = LabelLUT()
color_map = dataset.color_map
for id, val in dataset.label_to_names.items():
lut.add_label(val, id, color=color_map[id])
self.set_lut("labels", lut)

self._consolidate_bounding_boxes = True
self._init_dataset(dataset, split, indices)
self._visualize("Open3D - " + dataset.name, width, height)


class App:

def __init__(self, datamodule: DataModule):
self.datamodule = datamodule
self._enabled = True # not flash._IS_TESTING

def get_dataset(self, stage: str = "train"):
dataloader = getattr(self.datamodule, f"{stage}_dataloader")()
dataset = dataloader.dataset.dataset
if isinstance(dataset, TorchDataloader):
return dataset.dataset
return dataset

def show_train_dataset(self, indices=None):
if self._enabled:
dataset = self.get_dataset("train")
viz = Visualizer()
viz.visualize_dataset(dataset, 'all', indices=indices)

def show_predictions(self, predictions):
if self._enabled:
dataset = self.get_dataset("train")
color_map = dataset.color_map

predictions_visualizations = []
for pred in predictions:
predictions_visualizations.append({
"points": torch.stack(pred[DefaultDataKeys.INPUT]),
"labels": torch.stack(pred[DefaultDataKeys.TARGET]),
"predictions": torch.argmax(torch.stack(pred[DefaultDataKeys.PREDS]), axis=-1) + 1,
"name": pred[DefaultDataKeys.METADATA]["name"],
})

viz = Visualizer()
lut = LabelLUT()
color_map = dataset.color_map
for id, val in dataset.label_to_names.items():
lut.add_label(val, id, color=color_map[id])
self.set_lut("labels", lut)

self._consolidate_bounding_boxes = True
self._init_dataset(dataset, split, indices)
self._visualize("Open3D - " + dataset.name, width, height)

class App:

def __init__(self, datamodule: DataModule):
self.datamodule = datamodule
self._enabled = True # not flash._IS_TESTING

def get_dataset(self, stage: str = "train"):
dataloader = getattr(self.datamodule, f"{stage}_dataloader")()
dataset = dataloader.dataset.dataset
if isinstance(dataset, TorchDataloader):
return dataset.dataset
return dataset

def show_train_dataset(self, indices=None):
if self._enabled:
dataset = self.get_dataset("train")
viz = Visualizer()
viz.visualize_dataset(dataset, 'all', indices=indices)

def show_predictions(self, predictions):
if self._enabled:
dataset = self.get_dataset("train")
color_map = dataset.color_map

predictions_visualizations = []
for pred in predictions:
predictions_visualizations.append({
"points": torch.stack(pred[DefaultDataKeys.INPUT]),
"labels": torch.stack(pred[DefaultDataKeys.TARGET]),
"predictions": torch.argmax(torch.stack(pred[DefaultDataKeys.PREDS]), axis=-1) + 1,
"name": pred[DefaultDataKeys.METADATA]["name"],
})

viz = Visualizer()
lut = LabelLUT()
color_map = dataset.color_map
for id, val in dataset.label_to_names.items():
lut.add_label(val, id, color=color_map[id])
viz.set_lut("labels", lut)
viz.set_lut("predictions", lut)
viz.visualize(predictions_visualizations)
viz.set_lut("labels", lut)
viz.set_lut("predictions", lut)
viz.visualize(predictions_visualizations)


def launch_app(datamodule: DataModule) -> 'App':
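Because this hunk interleaves the old, indented layout with the new module-level one, the resulting structure is hard to read in place. Condensed from the added lines (a paraphrase, not the file verbatim), the new module shape is:

from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE

if _POINTCLOUD_AVAILABLE:
    from open3d._ml3d.vis.visualizer import Visualizer as Open3dVisualizer
else:
    # Harmless stand-in: the subclass below can still be defined without Open3D,
    # so importing this module never requires the optional dependency.
    Open3dVisualizer = object


class Visualizer(Open3dVisualizer):
    """Extends the Open3D visualizer when the pointcloud extras are installed; otherwise an inert stub."""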