Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Onboard text classification inputs to new object #1022

Merged
merged 19 commits into from
Dec 6, 2021
12 changes: 11 additions & 1 deletion flash/core/data/splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import Dataset

import flash
from flash.core.data.properties import Properties


Expand All @@ -26,7 +27,11 @@ class SplitDataset(Properties, Dataset):
def __init__(self, dataset: Any, indices: List[int] = None, use_duplicated_indices: bool = False) -> None:
kwargs = {}
if isinstance(dataset, Properties):
kwargs = {"running_stage": dataset._running_stage, "state": dataset._state}
kwargs = dict(
running_stage=dataset._running_stage,
data_pipeline_state=dataset._data_pipeline_state,
state=dataset._state,
)
super().__init__(**kwargs)

if indices is None:
Expand All @@ -45,6 +50,11 @@ def __init__(self, dataset: Any, indices: List[int] = None, use_duplicated_indic
self.dataset = dataset
self.indices = indices

def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"):
    """Attach ``data_pipeline_state`` to this split and to the dataset it wraps.

    The state is first handed to the parent implementation. It is then
    forwarded to ``self.dataset``, but only when that object is itself a
    ``Properties`` instance.
    """
    super().attach_data_pipeline_state(data_pipeline_state)

    wrapped = self.dataset
    if not isinstance(wrapped, Properties):
        # The wrapped dataset is not a ``Properties`` object: nothing to forward.
        return
    wrapped.attach_data_pipeline_state(data_pipeline_state)

def __getattr__(self, key: str):
if key != "dataset":
return getattr(self.dataset, key)
Expand Down
35 changes: 19 additions & 16 deletions flash/core/integrations/labelstudio/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from copy import deepcopy
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Type

Expand All @@ -18,6 +19,7 @@
from flash.core.data.utils import image_default_loader
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.text.classification.model import TextClassificationBackboneState

if _PYTORCHVIDEO_AVAILABLE:
from pytorchvideo.data.clip_sampling import make_clip_sampler
Expand Down Expand Up @@ -277,31 +279,32 @@ class LabelStudioTextClassificationInput(LabelStudioInput):
Export data should point to text data
"""

def __init__(self, *args, backbone=None, max_length=128, **kwargs):
if backbone:
self.backbone = backbone
self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
self.max_length = max_length
def __init__(self, *args, max_length=128, **kwargs):
self.max_length = max_length
super().__init__(*args, **kwargs)

@property
def tokenizer(self):
    """Return the tokenizer for the backbone recorded in the attached state.

    The tokenizer is built on first access from the ``backbone`` attribute of
    the ``TextClassificationBackboneState`` returned by ``self.get_state`` and
    is then reused for every later access on the same instance.

    The cache lives in the instance ``__dict__`` rather than in a
    ``functools.lru_cache`` wrapped around the getter: an ``lru_cache`` on a
    method is keyed on ``self``, which keeps every instance (and its
    tokenizer) alive in a class-level cache and requires ``self`` to be
    hashable.
    """
    cached = self.__dict__.get("_tokenizer")
    if cached is None:
        backbone_state = self.get_state(TextClassificationBackboneState)
        cached = AutoTokenizer.from_pretrained(backbone_state.backbone, use_fast=True)
        # Only a successfully built tokenizer is stored, so a failed lookup
        # is retried on the next access.
        self.__dict__["_tokenizer"] = cached
    return cached

def load_sample(self, sample: Mapping[str, Any] = None) -> Any:
"""Load 1 sample from dataset."""
if not self.state:
self.state = self.get_state(LabelStudioState)

assert self.state

if self.backbone:
data = ""
for key in sample.get("data"):
data += sample.get("data").get(key)
tokenized_data = self.tokenizer(data, max_length=self.max_length, truncation=True, padding="max_length")
for key in tokenized_data:
tokenized_data[key] = torch.tensor(tokenized_data[key])
tokenized_data["labels"] = _get_labels_from_sample(sample["label"], self.state.classes)
# separate text data type block
result = tokenized_data
return result
data = ""
for key in sample.get("data"):
data += sample.get("data").get(key)
tokenized_data = self.tokenizer(data, max_length=self.max_length, truncation=True, padding="max_length")
for key in tokenized_data:
tokenized_data[key] = torch.tensor(tokenized_data[key])
tokenized_data["labels"] = _get_labels_from_sample(sample["label"], self.state.classes)
# separate text data type block
return tokenized_data


class LabelStudioVideoClassificationInput(LabelStudioIterableInput):
Expand Down
8 changes: 4 additions & 4 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,10 @@ def from_data_frame(
predict_data = (predict_data_frame, input_field, predict_images_root, predict_resolver)

return cls(
ImageClassificationCSVInput(RunningStage.TRAINING, *train_data, **dataset_kwargs),
ImageClassificationCSVInput(RunningStage.VALIDATING, *val_data, **dataset_kwargs),
ImageClassificationCSVInput(RunningStage.TESTING, *test_data, **dataset_kwargs),
ImageClassificationCSVInput(RunningStage.PREDICTING, *predict_data, **dataset_kwargs),
ImageClassificationDataFrameInput(RunningStage.TRAINING, *train_data, **dataset_kwargs),
ImageClassificationDataFrameInput(RunningStage.VALIDATING, *val_data, **dataset_kwargs),
ImageClassificationDataFrameInput(RunningStage.TESTING, *test_data, **dataset_kwargs),
ImageClassificationDataFrameInput(RunningStage.PREDICTING, *predict_data, **dataset_kwargs),
input_transform=cls.input_transform_cls(
train_transform,
val_transform,
Expand Down
Loading