Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Adding integration with Label Studio (#554)
Browse files Browse the repository at this point in the history
Co-authored-by: thomas chaton <[email protected]>
  • Loading branch information
KonstantinKorotaev and tchaton authored Sep 28, 2021
1 parent f2c9b2a commit 0a28672
Show file tree
Hide file tree
Showing 14 changed files with 950 additions and 18 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
/.github/*.md @edenlightning @ethanwharris @ananyahjha93
/.github/ISSUE_TEMPLATE/*.md @edenlightning @ethanwharris @ananyahjha93
/docs/source/conf.py @borda @ethanwharris @ananyahjha93
/flash/core/integrations/labelstudio @KonstantinKorotaev @niklub
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `LabelStudio` integration ([#554](https://github.com/PyTorchLightning/lightning-flash/pull/554))

- Added support `learn2learn` training_strategy for `ImageClassifier` ([#737](https://github.com/PyTorchLightning/lightning-flash/pull/737))

- Added `vissl` training_strategies for `ImageEmbedder` ([#682](https://github.com/PyTorchLightning/lightning-flash/pull/682))
Expand Down
125 changes: 125 additions & 0 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,3 +1246,128 @@ def from_fiftyone(
num_workers=num_workers,
**preprocess_kwargs,
)

@classmethod
def from_labelstudio(
cls,
export_json: str = None,
train_export_json: str = None,
val_export_json: str = None,
test_export_json: str = None,
predict_export_json: str = None,
data_folder: str = None,
train_data_folder: str = None,
val_data_folder: str = None,
test_data_folder: str = None,
predict_data_folder: str = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object
from the given export file and data directory using the
:class:`~flash.core.data.data_source.DataSource` of name
:attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS`
from the passed or constructed :class:`~flash.core.data.process.Preprocess`.
Args:
export_json: path to label studio export file
train_export_json: path to label studio export file for train set,
overrides export_json if specified
val_export_json: path to label studio export file for validation
test_export_json: path to label studio export file for test
predict_export_json: path to label studio export file for predict
data_folder: path to label studio data folder
train_data_folder: path to label studio data folder for train data set,
overrides data_folder if specified
val_data_folder: path to label studio data folder for validation data
test_data_folder: path to label studio data folder for test data
predict_data_folder: path to label studio data folder for predict data
train_transform: The dictionary of transforms to use during training which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
val_transform: The dictionary of transforms to use during validation which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
test_transform: The dictionary of transforms to use during testing which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
predict_transform: The dictionary of transforms to use during predicting which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
:class:`~flash.core.data.data_module.DataModule`.
preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
:class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
will be constructed and used.
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Returns:
The constructed data module.
Examples::
data_module = DataModule.from_labelstudio(
export_json='project.json',
data_folder='label-studio/media/upload',
val_split=0.8,
)
"""
data = {
"data_folder": data_folder,
"export_json": export_json,
"split": val_split,
"multi_label": preprocess_kwargs.get("multi_label", False),
}
train_data = None
val_data = None
test_data = None
predict_data = None
if (train_data_folder or data_folder) and train_export_json:
train_data = {
"data_folder": train_data_folder or data_folder,
"export_json": train_export_json,
"multi_label": preprocess_kwargs.get("multi_label", False),
}
if (val_data_folder or data_folder) and val_export_json:
val_data = {
"data_folder": val_data_folder or data_folder,
"export_json": val_export_json,
"multi_label": preprocess_kwargs.get("multi_label", False),
}
if (test_data_folder or data_folder) and test_export_json:
test_data = {
"data_folder": test_data_folder or data_folder,
"export_json": test_export_json,
"multi_label": preprocess_kwargs.get("multi_label", False),
}
if (predict_data_folder or data_folder) and predict_export_json:
predict_data = {
"data_folder": predict_data_folder or data_folder,
"export_json": predict_export_json,
"multi_label": preprocess_kwargs.get("multi_label", False),
}
return cls.from_data_source(
DefaultDataSources.LABELSTUDIO,
train_data=train_data if train_data else data,
val_data=val_data,
test_data=test_data,
predict_data=predict_data,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_fetcher=data_fetcher,
preprocess=preprocess,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
**preprocess_kwargs,
)
1 change: 1 addition & 0 deletions flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class DefaultDataSources(LightningEnum):
JSON = "json"
DATASETS = "datasets"
FIFTYONE = "fiftyone"
LABELSTUDIO = "labelstudio"

# TODO: Create a FlashEnum class???
def __hash__(self) -> int:
Expand Down
Loading

0 comments on commit 0a28672

Please sign in to comment.