Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add support for space delimited multi-label targets #1076

Merged
merged 9 commits into from
Dec 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for multi-label, space delimited, targets ([#1076](https://github.com/PyTorchLightning/lightning-flash/pull/1076))

### Changed

### Deprecated
Expand Down
27 changes: 23 additions & 4 deletions flash/core/data/utilities/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,17 @@ def _as_list(x: Union[List, torch.Tensor, np.ndarray]) -> List:
return x


def _strip(x: str) -> str:
return x.strip(", ")


class TargetMode(Enum):
"""The ``TargetMode`` Enum describes the different supported formats for targets in Flash."""

MULTI_TOKEN = auto()
MULTI_NUMERIC = auto()
MUTLI_COMMA_DELIMITED = auto()
MUTLI_SPACE_DELIMITED = auto()
MULTI_BINARY = auto()

SINGLE_TOKEN = auto()
Expand All @@ -67,9 +72,12 @@ def from_target(cls, target: Any) -> "TargetMode":
target: A target that is one of: a single target, a list of targets, a comma delimited string.
"""
if isinstance(target, str):
# TODO: This could be a dangerous assumption if people happen to have a label that contains a comma
target = _strip(target)
# TODO: This could be a dangerous assumption if people happen to have a label that contains a comma or space
if "," in target:
return TargetMode.MUTLI_COMMA_DELIMITED
elif " " in target:
return TargetMode.MUTLI_SPACE_DELIMITED
else:
return TargetMode.SINGLE_TOKEN
elif _is_list_like(target):
Expand All @@ -88,6 +96,7 @@ def multi_label(self) -> bool:
return any(
[
self is TargetMode.MUTLI_COMMA_DELIMITED,
self is TargetMode.MUTLI_SPACE_DELIMITED,
self is TargetMode.MULTI_NUMERIC,
self is TargetMode.MULTI_TOKEN,
self is TargetMode.MULTI_BINARY,
Expand Down Expand Up @@ -116,7 +125,7 @@ def binary(self) -> bool:
_RESOLUTION_MAPPING = {
TargetMode.MULTI_BINARY: [TargetMode.MULTI_NUMERIC],
TargetMode.SINGLE_BINARY: [TargetMode.MULTI_BINARY, TargetMode.MULTI_NUMERIC],
TargetMode.SINGLE_TOKEN: [TargetMode.MUTLI_COMMA_DELIMITED],
TargetMode.SINGLE_TOKEN: [TargetMode.MUTLI_COMMA_DELIMITED, TargetMode.MUTLI_SPACE_DELIMITED],
TargetMode.SINGLE_NUMERIC: [TargetMode.MULTI_NUMERIC],
}

Expand Down Expand Up @@ -179,7 +188,7 @@ def __init__(self, labels: List[Any]):
self.label_to_idx = {label: idx for idx, label in enumerate(labels)}

def format(self, target: Any) -> Any:
return self.label_to_idx[(target[0] if not isinstance(target, str) else target).strip()]
return self.label_to_idx[_strip(target[0] if not isinstance(target, str) else target)]


class MultiLabelTargetFormatter(SingleLabelTargetFormatter):
Expand All @@ -201,6 +210,11 @@ def format(self, target: Any) -> Any:
return super().format(target.split(","))


class SpaceDelimitedTargetFormatter(MultiLabelTargetFormatter):
def format(self, target: Any) -> Any:
return super().format(target.split(" "))


class MultiNumericTargetFormatter(TargetFormatter):
def __init__(self, num_classes: int):
self.num_classes = num_classes
Expand Down Expand Up @@ -245,6 +259,8 @@ def get_target_formatter(
return SingleLabelTargetFormatter(labels)
elif target_mode is TargetMode.MUTLI_COMMA_DELIMITED:
return CommaDelimitedTargetFormatter(labels)
elif target_mode is TargetMode.MUTLI_SPACE_DELIMITED:
return SpaceDelimitedTargetFormatter(labels)
return MultiLabelTargetFormatter(labels)


Expand Down Expand Up @@ -289,13 +305,16 @@ def get_target_details(targets: List[Any], target_mode: TargetMode) -> Tuple[Opt
if target_mode is TargetMode.MUTLI_COMMA_DELIMITED:
for target in targets:
tokens.extend(target.split(","))
elif target_mode is TargetMode.MUTLI_SPACE_DELIMITED:
for target in targets:
tokens.extend(target.split(" "))
elif target_mode is TargetMode.MULTI_TOKEN:
for target in targets:
tokens.extend(target)
else:
tokens = targets

tokens = [token.strip() for token in tokens]
tokens = [_strip(token) for token in tokens]
labels = list(sorted_alphanumeric(set(tokens)))
num_classes = len(labels)
return labels, num_classes
7 changes: 4 additions & 3 deletions flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,10 @@ def from_voc(
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "ObjectDetectionData":
"""Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders
and annotation files in the `PASCAL VOC (Visual Obect Challenge)
<http://host.robots.ox.ac.uk/pascal/VOC/>`_ XML format.
""".. _PASCAL: http://host.robots.ox.ac.uk/pascal/VOC/

Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders
and annotation files in the `PASCAL VOC (Visual Object Challenge) XML format <PASCAL>`_.

Args:
train_folder: The folder containing the train data.
Expand Down
29 changes: 25 additions & 4 deletions tests/core/data/utilities/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@
["blue", "green", "red"],
3,
),
Case(
["blue green", "green red", "red blue"],
[[1, 1, 0], [0, 1, 1], [1, 0, 1]],
TargetMode.MUTLI_SPACE_DELIMITED,
["blue", "green", "red"],
3,
),
# Ambiguous
Case([[0], [1, 2], [2, 0]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], TargetMode.MULTI_NUMERIC, None, 3),
Case([[1, 0, 0], [0, 1, 1], [1, 0, 1]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], TargetMode.MULTI_BINARY, None, 3),
Expand All @@ -67,6 +74,13 @@
["blue", "green", "red"],
3,
),
Case(
["blue", "green red", "red blue"],
[[1, 0, 0], [0, 1, 1], [1, 0, 1]],
TargetMode.MUTLI_SPACE_DELIMITED,
["blue", "green", "red"],
3,
),
# Special cases
Case(["blue ", " green", "red"], [0, 1, 2], TargetMode.SINGLE_TOKEN, ["blue", "green", "red"], 3),
Case(
Expand All @@ -76,6 +90,13 @@
["blue", "green", "red"],
3,
),
Case(
["blue", "green ,red", "red ,blue"],
[[1, 0, 0], [0, 1, 1], [1, 0, 1]],
TargetMode.MUTLI_COMMA_DELIMITED,
["blue", "green", "red"],
3,
),
Case(
[f"class_{i}" for i in range(10000)],
list(range(10000)),
Expand Down Expand Up @@ -115,16 +136,16 @@ def test_speed(case):
else:
targets = case.target * repeats

start = time.time()
start = time.perf_counter()
target_mode = get_target_mode(targets)
labels, num_classes = get_target_details(targets, target_mode)
formatter = get_target_formatter(target_mode, labels, num_classes)
end = time.time()
end = time.perf_counter()

assert (end - start) / len(targets) < 1e-5 # 0.01ms per target

start = time.time()
start = time.perf_counter()
_ = [formatter(t) for t in targets]
end = time.time()
end = time.perf_counter()

assert (end - start) / len(targets) < 1e-5 # 0.01ms per target