This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
Add semantic segmentation task #239
Merged
Changes from 57 commits
58 commits
e2f5f20 semantic segmentation skeleton (edgarriba)
f3ce4c7 expose and add smoke tests for preproces and datamodule (edgarriba)
1ef1b40 data module connections working (edgarriba)
7f17fb2 preprocess not crashing(wip) (edgarriba)
7d9d46c implement segmentation sequential (edgarriba)
498e278 implement torchvision backbone model (edgarriba)
56fa4d5 model working (edgarriba)
950252e implement labels mapping (edgarriba)
6a75245 add map labels tests (edgarriba)
7a7f855 from filepaths training test not crashing (edgarriba)
def1ea0 non working visualiser (edgarriba)
ed17eb0 fix visualiser (edgarriba)
3eb6417 training working (edgarriba)
d529d9e training not crashing (edgarriba)
13095e6 cleanup example and move serializer to core (edgarriba)
2f9ede5 cleanup model code, tests and docs (edgarriba)
dc9b2b8 move transforms apart (edgarriba)
e767a53 implement ApplytransformsToKey augmentations (edgarriba)
f268b62 relative path (edgarriba)
99b99f0 fix load from pretrained and add resnet 101 (edgarriba)
d1a91fd create segmentation keys enum (edgarriba)
7343887 sync with master and fix val_split (edgarriba)
febe7f0 move apart segmentation backbones (edgarriba)
3891920 Merge branch 'master' into feat/segmentation (edgarriba)
248145b fix tests (edgarriba)
6d635db fix tests (edgarriba)
ca97034 fix tests (edgarriba)
da83251 fix memory leak issues (edgarriba)
f1e76f9 Merge branch 'master' into feat/segmentation (edgarriba)
2ef8c88 undo function filtering (edgarriba)
87a92f3 fix import (edgarriba)
73d462b more fixes for memory leaks (edgarriba)
8b971a4 add segmentation to docs (edgarriba)
69358a6 add inference example (edgarriba)
caabfb6 add image to docs and update with AdamW (edgarriba)
78301ff Merge branch 'master' into feat/segmentation (ethanwharris)
e8e92d1 Make pretrained arg kwarg (ethanwharris)
cf430f3 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
74ce6dc Data sources initial commit (ethanwharris)
df2b989 Update transforms (ethanwharris)
bb95b8f Updates (ethanwharris)
3a8842d Fixes (ethanwharris)
3596e16 Fix tests (ethanwharris)
859a0ef Fixes (ethanwharris)
268b4c8 Fixes (ethanwharris)
0b67769 Merge branch 'master' into feat/segmentation (ethanwharris)
091e50d Merge branch 'master' into feat/segmentation (ethanwharris)
4aa3716 Add tests (ethanwharris)
f50c200 Update docs/source/reference/semantic_segmentation.rst (ethanwharris)
11ed7c5 Update docs/source/reference/semantic_segmentation.rst (ethanwharris)
0fc3581 Add a check (ethanwharris)
1875967 Move KorniaParallelTransforms and add docstring (ethanwharris)
e9dee30 implement quick test for segmentation labels (edgarriba)
0049197 add small assertion tests (edgarriba)
4c75774 Rename test_serialisation.py to test_serialization.py (ethanwharris)
2d76f37 Switch to exception (ethanwharris)
5f254b6 Fix (ethanwharris)
8745191 Fixes (ethanwharris)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -153,3 +153,5 @@ wmt_en_ro | |
action_youtube_naudio | ||
kinetics | ||
movie_posters | ||
CameraRGB | ||
CameraSeg |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
|
||
.. _semantic_segmentation: | ||
|
||
###################### | ||
Semantic Segmentation | ||
###################### | ||
|
||
******** | ||
The task | ||
******** | ||
Semantic segmentation, or image segmentation, is the task of performing classification at the pixel level, meaning each pixel is assigned to one of a fixed set of classes. The model output shape is ``(batch_size, num_classes, height, width)``. | ||
|
||
See more: https://paperswithcode.com/task/semantic-segmentation | ||
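To make the output shape described above concrete, here is a minimal, dependency-free sketch of how per-class scores for one image could be turned into a per-pixel label map by taking the highest-scoring class at every pixel. The function name ``scores_to_label_map`` and the toy scores are illustrative assumptions, not part of Flash; with a PyTorch tensor of shape ``(batch_size, num_classes, height, width)`` the equivalent operation would be ``output.argmax(dim=1)``.

```python
from typing import List

# One image's scores, laid out as (num_classes, height, width).
Scores = List[List[List[float]]]


def scores_to_label_map(scores: Scores) -> List[List[int]]:
    """Return, for each pixel, the index of the class with the highest score."""
    num_classes = len(scores)
    height = len(scores[0])
    width = len(scores[0][0])
    return [
        [
            # Compare the scores of every class at pixel (y, x).
            max(range(num_classes), key=lambda c: scores[c][y][x])
            for x in range(width)
        ]
        for y in range(height)
    ]


# Toy example (hypothetical values): 3 classes over a 2x2 image.
scores = [
    [[0.9, 0.1], [0.2, 0.3]],   # class 0
    [[0.05, 0.8], [0.1, 0.3]],  # class 1
    [[0.05, 0.1], [0.7, 0.4]],  # class 2
]
label_map = scores_to_label_map(scores)
```

The result has shape ``(height, width)`` and holds one class index per pixel, which is the form a segmentation mask usually takes.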
|
||
.. raw:: html | ||
|
||
<p> | ||
<a href="https://i2.wp.com/syncedreview.com/wp-content/uploads/2019/12/image-9-1.png" > | ||
<img src="https://i2.wp.com/syncedreview.com/wp-content/uploads/2019/12/image-9-1.png"/> | ||
</a> | ||
</p> | ||
|
||
------ | ||
|
||
********* | ||
Inference | ||
********* | ||
|
||
A :class:`~flash.vision.SemanticSegmentation` model with the ``fcn_resnet50`` backbone, pre-trained on data from the `CARLA <http://carla.org/>`_ simulator, is provided for the inference example. | ||
|
||
|
||
Use the pre-trained :class:`~flash.vision.SemanticSegmentation` model for inference on any sequence of image file paths using :func:`~flash.vision.SemanticSegmentation.predict`: | ||
|
||
.. code-block:: python | ||
|
||
# import our libraries | ||
from flash.data.utils import download_data | ||
from flash.vision import SemanticSegmentation | ||
from flash.vision.segmentation.serialization import SegmentationLabels | ||
|
||
# 1. Download the data | ||
download_data( | ||
"https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", | ||
"data/" | ||
) | ||
|
||
# 2. Load the model from a checkpoint | ||
model = SemanticSegmentation.load_from_checkpoint( | ||
"https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt" | ||
) | ||
model.serializer = SegmentationLabels(visualize=True) | ||
|
||
# 3. Predict what's on a few images and visualize! | ||
predictions = model.predict([ | ||
'data/CameraRGB/F61-1.png', | ||
'data/CameraRGB/F62-1.png', | ||
'data/CameraRGB/F63-1.png', | ||
]) | ||
|
||
For more advanced inference options, see :ref:`predictions`. | ||
|
||
------ | ||
|
||
********** | ||
Finetuning | ||
********** | ||
|
||
Suppose you now want to customise the model by fine-tuning it on your own data; this example reuses the same dataset. | ||
Once we download the data using :func:`~flash.data.utils.download_data`, all we need is the folder of images and the folder of target masks to create the :class:`~flash.vision.SemanticSegmentationData`. | ||
|
||
.. note:: The dataset is structured so that each sample (an image and its corresponding labels) is split across two directories, with both files sharing the same filename. | ||
|
||
.. code-block:: | ||
|
||
data | ||
├── CameraRGB | ||
│ ├── F61-1.png | ||
│ ├── F61-2.png | ||
│ ... | ||
└── CameraSeg | ||
├── F61-1.png | ||
├── F61-2.png | ||
... | ||
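The layout above implies that an image and its mask can be matched purely by filename. As a rough sketch of that idea (not the loading code used by ``SemanticSegmentationData``, which is not shown here), the hypothetical helper ``pair_images_with_masks`` below walks the image folder and looks up the file of the same name in the mask folder. The temporary directory with empty placeholder files is only there to make the example self-contained.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple


def pair_images_with_masks(image_dir: Path, mask_dir: Path) -> List[Tuple[Path, Path]]:
    """Pair each image with the mask that shares its filename."""
    pairs = []
    for image_path in sorted(image_dir.glob("*.png")):
        mask_path = mask_dir / image_path.name
        if not mask_path.is_file():
            # A missing mask would silently misalign samples, so fail loudly.
            raise FileNotFoundError(f"No mask found for {image_path.name}")
        pairs.append((image_path, mask_path))
    return pairs


# Demonstrate on a throwaway copy of the layout with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for folder in ("CameraRGB", "CameraSeg"):
        (root / folder).mkdir()
        for name in ("F61-1.png", "F61-2.png"):
            (root / folder / name).touch()
    pairs = pair_images_with_masks(root / "CameraRGB", root / "CameraSeg")
    pair_names = [(image.name, mask.parent.name) for image, mask in pairs]
```

Sorting the image paths keeps the pairing deterministic, which matters when a fraction of the samples is later split off for validation.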
|
||
|
||
Now all we need is a few lines of code to build and train our task! | ||
|
||
.. code-block:: python | ||
|
||
import flash | ||
from flash.data.utils import download_data | ||
from flash.vision import SemanticSegmentation, SemanticSegmentationData | ||
from flash.vision.segmentation.serialization import SegmentationLabels | ||
|
||
# 1. Download the data | ||
download_data( | ||
"https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", | ||
"data/" | ||
) | ||
|
||
# 2.1 Load the data | ||
datamodule = SemanticSegmentationData.from_folders( | ||
train_folder="data/CameraRGB", | ||
train_target_folder="data/CameraSeg", | ||
batch_size=4, | ||
val_split=0.3, | ||
image_size=(200, 200), # (600, 800) | ||
) | ||
|
||
# 2.2 Visualise the samples | ||
labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) | ||
datamodule.set_labels_map(labels_map) | ||
datamodule.show_train_batch(["load_sample", "post_tensor_transform"]) | ||
|
||
# 3. Build the model | ||
model = SemanticSegmentation(backbone="torchvision/fcn_resnet50", num_classes=21) | ||
|
||
# 4. Create the trainer. | ||
trainer = flash.Trainer(max_epochs=1) | ||
|
||
# 5. Train the model | ||
trainer.finetune(model, datamodule=datamodule, strategy='freeze') | ||
|
||
# 6. Save it! | ||
trainer.save_checkpoint("semantic_segmentation_model.pt") | ||
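Step 2.2 above relies on a labels map to visualise the targets. The sketch below illustrates one plausible shape for such a map, a dictionary from class index to an RGB colour, and how it could be used to colour a label mask. It is an assumption made for illustration only: ``make_random_labels_map`` and ``colorize`` are hypothetical helpers, and the actual behaviour of ``SegmentationLabels.create_random_labels_map`` is not shown in this diff.

```python
import random
from typing import Dict, List, Tuple

Color = Tuple[int, int, int]


def make_random_labels_map(num_classes: int, seed: int = 0) -> Dict[int, Color]:
    """Assign a random RGB colour to every class index (hypothetical helper)."""
    rng = random.Random(seed)  # seeded so the colours are reproducible
    return {
        c: (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255))
        for c in range(num_classes)
    }


def colorize(label_map: List[List[int]], labels_map: Dict[int, Color]) -> List[List[Color]]:
    """Replace every class index in a (height, width) mask with its colour."""
    return [[labels_map[c] for c in row] for row in label_map]


labels_map = make_random_labels_map(num_classes=21)
colored = colorize([[0, 1], [20, 0]], labels_map)
```

Keeping the map as plain data makes it easy to swap the random colours for a fixed palette once the class names are known.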
|
||
------ | ||
|
||
************* | ||
API reference | ||
************* | ||
|
||
.. _segmentation: | ||
|
||
SemanticSegmentation | ||
-------------------- | ||
|
||
.. autoclass:: flash.vision.SemanticSegmentation | ||
:members: | ||
:exclude-members: forward | ||
|
||
.. _segmentation_data: | ||
|
||
SemanticSegmentationData | ||
------------------------ | ||
|
||
.. autoclass:: flash.vision.SemanticSegmentationData | ||
|
||
.. automethod:: flash.vision.SemanticSegmentationData.from_folders | ||
|
||
.. autoclass:: flash.vision.SemanticSegmentationPreprocess |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from flash.vision.segmentation.data import SemanticSegmentationData, SemanticSegmentationPreprocess | ||
from flash.vision.segmentation.model import SemanticSegmentation |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import torch.nn as nn | ||
|
||
from flash.core.registry import FlashRegistry | ||
from flash.utils.imports import _TORCHVISION_AVAILABLE | ||
|
||
if _TORCHVISION_AVAILABLE: | ||
import torchvision | ||
|
||
SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones") | ||
|
||
|
||
@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50") | ||
def load_torchvision_fcn_resnet50(num_classes: int, pretrained: bool = True) -> nn.Module: | ||
model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained) | ||
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)) | ||
return model | ||
|
||
|
||
@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet101") | ||
def load_torchvision_fcn_resnet101(num_classes: int, pretrained: bool = True) -> nn.Module: | ||
model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained) | ||
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)) | ||
return model |
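The backbones above are registered under string names through a decorator. For readers unfamiliar with that pattern, here is a minimal sketch of a name-to-factory registry. ``SimpleRegistry``, ``BACKBONES``, and ``load_toy_backbone`` are hypothetical stand-ins written for this example; they are not ``FlashRegistry``, whose implementation and lookup API are not part of this diff.

```python
from typing import Callable, Dict


class SimpleRegistry:
    """Hypothetical stand-in for a name -> factory registry (not FlashRegistry)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._functions: Dict[str, Callable] = {}

    def __call__(self, name: str) -> Callable[[Callable], Callable]:
        # Calling the registry returns a decorator that stores the function.
        def decorator(fn: Callable) -> Callable:
            self._functions[name] = fn
            return fn

        return decorator

    def get(self, name: str) -> Callable:
        if name not in self._functions:
            raise KeyError(f"'{name}' is not registered in '{self.name}'")
        return self._functions[name]


BACKBONES = SimpleRegistry("backbones")


@BACKBONES(name="toy/identity")
def load_toy_backbone(num_classes: int, pretrained: bool = True) -> dict:
    # A real loader would build and return an nn.Module here.
    return {"num_classes": num_classes, "pretrained": pretrained}


# Look the factory up by name, then call it with the task's arguments.
backbone = BACKBONES.get("toy/identity")(num_classes=21, pretrained=False)
```

Resolving backbones by name is what lets a string such as ``"torchvision/fcn_resnet50"`` be passed to the task while the construction details stay in one place.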
Should we integrate any other library here, like IceVision?
Torchvision should be good enough for now.
We already have heavy dependencies.