Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Adding integration with Label Studio #554

Merged
merged 64 commits into from
Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
6a887e9
Adding label studio integration
KonstantinKorotaev Jul 7, 2021
f9987d1
Import fixes
KonstantinKorotaev Jul 7, 2021
81d2de3
Moving project.json to test data in S3
KonstantinKorotaev Jul 7, 2021
ae1cbe7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2021
c51dd64
Moving from dataset to datasource
KonstantinKorotaev Aug 13, 2021
79f3d70
Merge fixes
KonstantinKorotaev Aug 13, 2021
718b152
Add video example
KonstantinKorotaev Aug 17, 2021
02c4cea
Fix merge conflict and Analysis checks
KonstantinKorotaev Aug 18, 2021
0c3f2fb
Updating imports to solve conflict
KonstantinKorotaev Aug 18, 2021
92672ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2021
6049fae
Moving to separate classes for image, video and text
KonstantinKorotaev Aug 23, 2021
e995d8c
Fixing merge conflict
KonstantinKorotaev Aug 23, 2021
3e8056d
Merging changes from PyTorchLightning master
KonstantinKorotaev Aug 23, 2021
33958d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
9a0d919
Merge pull request #2 from KonstantinKorotaev/PyTorchLightning-master
KonstantinKorotaev Aug 23, 2021
0a3c906
Merge pull request #3 from KonstantinKorotaev/PyTorchLightning-master
KonstantinKorotaev Aug 23, 2021
34e7a75
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
595f266
Fixing DeepSource analysis issues
KonstantinKorotaev Aug 23, 2021
8ba9f4c
Fix last DeepSource analysis issues
KonstantinKorotaev Aug 23, 2021
c9af6f6
Merge pull request #4 from KonstantinKorotaev/PyTorchLightning-master
KonstantinKorotaev Aug 23, 2021
fb3103f
Delete useless init
KonstantinKorotaev Aug 23, 2021
bf82fd9
Move label studio datasource to separate file
KonstantinKorotaev Aug 27, 2021
7f67ce1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
5ecd88e
Add test for LabelStudioDataSource._load_json_data
KonstantinKorotaev Aug 30, 2021
ba560c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 30, 2021
19fe4a8
Fix DeepSource analysis issues
KonstantinKorotaev Aug 30, 2021
7b44ea1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 30, 2021
baf4fa1
Merge branch 'master' into master
tchaton Sep 3, 2021
3c117f0
Adding test for each Label Studio datasource
KonstantinKorotaev Sep 8, 2021
23b6d7e
Fixing typo and grouping import
KonstantinKorotaev Sep 8, 2021
d711724
Merge remote-tracking branch 'upstream/master'
KonstantinKorotaev Sep 8, 2021
4964a6c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2021
55b1345
Fix import for DefaultDataSources
KonstantinKorotaev Sep 8, 2021
e41151d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2021
c5c11d7
Fix import for ImageClassificationData
KonstantinKorotaev Sep 8, 2021
ea2e474
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2021
a940f6f
Fix tests conditions
KonstantinKorotaev Sep 9, 2021
0cc7bd9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
a3066d8
Fix data sources prerequisite
KonstantinKorotaev Sep 9, 2021
f59a1fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
119593e
Separating tests for Datamodule
KonstantinKorotaev Sep 9, 2021
20c0c48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
28c51ee
Fixing link to file for image data sets
KonstantinKorotaev Sep 9, 2021
a46140b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
466bd69
Fix LabelStudioImageClassificationDataSource test
KonstantinKorotaev Sep 9, 2021
18ae793
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
77f1338
Adding App for visualizing predictions
KonstantinKorotaev Sep 9, 2021
5bf56b2
Fix import for label studio launch_app
KonstantinKorotaev Sep 9, 2021
50ebe0f
Fix strings in tests
KonstantinKorotaev Sep 9, 2021
c172b01
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
81fd7a7
Rename visualizer module
KonstantinKorotaev Sep 9, 2021
f512e0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
5590a2a
Fix test for video and image multilabel
KonstantinKorotaev Sep 9, 2021
a19712f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
0d9dd95
Fixing docstring and test condition
KonstantinKorotaev Sep 9, 2021
54186b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
593e51e
Fix docstring and CODEOWNERS
KonstantinKorotaev Sep 13, 2021
022e1ab
Adding converter to tasks
KonstantinKorotaev Sep 14, 2021
0e1080a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2021
85bf6b6
Merge branch 'master' into master
tchaton Sep 24, 2021
b81f53f
update
tchaton Sep 28, 2021
46b3f6b
Merge branch 'master' into master
tchaton Sep 28, 2021
b89d6bb
update changelog
tchaton Sep 28, 2021
5b5be24
Merge commit 'refs/pull/554/head' of https://github.com/PyTorchLightn…
tchaton Sep 28, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 95 additions & 2 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import platform
from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
Expand All @@ -21,14 +22,14 @@
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import IterableDataset, Subset
from torch.utils.data.dataset import IterableDataset, random_split, Subset
from torch.utils.data.sampler import Sampler

from flash.core.data.auto_dataset import BaseAutoDataset, IterableAutoDataset
from flash.core.data.base_viz import BaseVisualization
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_pipeline import DataPipeline, DefaultPreprocess, Postprocess, Preprocess
from flash.core.data.data_source import DataSource, DefaultDataSources
from flash.core.data.data_source import DataSource, DefaultDataSources, LabelStudioDataset
from flash.core.data.splits import SplitDataset
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
Expand Down Expand Up @@ -1152,3 +1153,95 @@ def from_fiftyone(
num_workers=num_workers,
**preprocess_kwargs,
)

@classmethod
def from_labelstudio(
cls,
export_json: str = None,
img_folder: str = None,
KonstantinKorotaev marked this conversation as resolved.
Show resolved Hide resolved
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object
from the given export file and data directory using the
:class:`~flash.core.data.data_source.DataSource` of name
:attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS`
from the passed or constructed :class:`~flash.core.data.process.Preprocess`.

Args:
export_json: path to label studio export file
img_folder: path to label studio data folder
train_transform: The dictionary of transforms to use during training which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
val_transform: The dictionary of transforms to use during validation which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
test_transform: The dictionary of transforms to use during testing which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
predict_transform: The dictionary of transforms to use during predicting which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
:class:`~flash.core.data.data_module.DataModule`.
preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
:class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
will be constructed and used.
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.

Returns:
The constructed data module.

Examples::

data_module = DataModule.from_labelstudio(
export_json='project.json',
img_folder='label-studio/media/upload',
val_split=0.8,
)
"""
# loading export data
with open(export_json) as f:
js = json.load(f)
# loading data sets
full_dataset = LabelStudioDataset(js, img_folder)
KonstantinKorotaev marked this conversation as resolved.
Show resolved Hide resolved
val_dataset = LabelStudioDataset(js, img_folder, val=True)
# creating splitting params
l = len(full_dataset)
prop = int(l * val_split)
# splitting full data set
train_dataset, test_dataset = random_split(full_dataset, [prop, l - prop])

preprocess = preprocess or cls.preprocess_cls(
train_transform,
val_transform,
test_transform,
predict_transform,
**preprocess_kwargs,
)
data_source = preprocess.data_source_of_name(DefaultDataSources.FOLDERS)
data = cls(
KonstantinKorotaev marked this conversation as resolved.
Show resolved Hide resolved
train_dataset,
val_dataset,
test_dataset,
None,
data_source=data_source,
preprocess=preprocess,
data_fetcher=data_fetcher,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
)
return data
64 changes: 64 additions & 0 deletions flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from pytorch_lightning.utilities.enums import LightningEnum
from torch.nn import Module
from torch.utils.data.dataset import Dataset
from torchvision.datasets.folder import default_loader

from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset
from flash.core.data.properties import ProcessState, Properties
Expand Down Expand Up @@ -156,6 +157,7 @@ class DefaultDataSources(LightningEnum):
JSON = "json"
DATASET = "dataset"
FIFTYONE = "fiftyone"
LABELSTUDIO = "labelstudio"

# TODO: Create a FlashEnum class???
def __hash__(self) -> int:
Expand Down Expand Up @@ -543,3 +545,65 @@ def _get_classes(self, data):
classes = data.distinct(label_path)

return classes


class LabelStudioDataset(Dataset):
KonstantinKorotaev marked this conversation as resolved.
Show resolved Hide resolved
KonstantinKorotaev marked this conversation as resolved.
Show resolved Hide resolved
r"""Dataset wrapping label studio annotations.

Each sample will be retrieved by checking result field of the annotation.

Args:
*js: json of export file
*img_folder: path to image folder of label studio
"""

def __init__(self, js, img_folder, val=False):
self._raw_data = js
self._img_folder = img_folder
self.results = list()
self.classes = set()
# iterate through all tasks in exported data
for task in self._raw_data:
KonstantinKorotaev marked this conversation as resolved.
Show resolved Hide resolved
for annotation in task['annotations']:
# Adding ground_truth annotation to separate dataset
result = annotation['result']
for res in result:
t = res['type']
for label in res['value'][t]:
# check if labeling result is a list of labels
if isinstance(label, list):
for sublabel in label:
self.classes.add(sublabel)
temp = dict()
temp['file_upload'] = task['file_upload']
temp['label'] = sublabel
if annotation['ground_truth'] & val:
self.results.append(temp)
elif not annotation['ground_truth'] or not val:
self.results.append(temp)
else:
self.classes.add(label)
temp = dict()
temp['file_upload'] = task['file_upload']
temp['label'] = label
if annotation['ground_truth'] & val:
self.results.append(temp)
elif not annotation['ground_truth'] or not val:
self.results.append(temp)
self.num_classes = len(self.classes)

def __getitem__(self, idx):
r = self.results[idx]
# extracting path to file
p = os.path.join(self._img_folder, r['file_upload'])
# loading image
sample = default_loader(p)
# casting to list and sorting classes
sorted_labels = sorted(list(self.classes))
# checking index of class
label = sorted_labels.index(r['label'])
result = {DefaultDataKeys.INPUT: sample, DefaultDataKeys.TARGET: label}
return result

def __len__(self):
return len(self.results)
40 changes: 40 additions & 0 deletions flash_examples/integrations/labelstudio/image_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from itertools import chain

import flash
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.image import ImageClassificationData, ImageClassifier

# 1 Download data
download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data.zip")

# 1. Load export data
datamodule = ImageClassificationData.from_labelstudio(
export_json='data/project.json',
img_folder='data/upload/',
val_split=0.8,
)

# 2. Fine tune a model
model = ImageClassifier(
backbone="resnet18",
num_classes=datamodule.num_classes,
)
trainer = flash.Trainer(max_epochs=3)

trainer.finetune(
model,
datamodule=datamodule,
strategy=FreezeUnfreeze(unfreeze_epoch=1),
)
trainer.save_checkpoint("image_classification_model.pt")

# 3. Predict from checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
model.serializer = Labels()

predictions = model.predict([
"data/test/1.jpg",
"data/test/2.jpg",
])