Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add support for space delimited multi-label targets (#1076)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Dec 15, 2021
1 parent ba72af6 commit 6f3dc08
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for multi-label, space delimited, targets ([#1076](https://github.com/PyTorchLightning/lightning-flash/pull/1076))

### Changed

### Deprecated
Expand Down
27 changes: 23 additions & 4 deletions flash/core/data/utilities/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,17 @@ def _as_list(x: Union[List, torch.Tensor, np.ndarray]) -> List:
return x


def _strip(x: str) -> str:
return x.strip(", ")


class TargetMode(Enum):
"""The ``TargetMode`` Enum describes the different supported formats for targets in Flash."""

MULTI_TOKEN = auto()
MULTI_NUMERIC = auto()
MUTLI_COMMA_DELIMITED = auto()
MUTLI_SPACE_DELIMITED = auto()
MULTI_BINARY = auto()

SINGLE_TOKEN = auto()
Expand All @@ -67,9 +72,12 @@ def from_target(cls, target: Any) -> "TargetMode":
target: A target that is one of: a single target, a list of targets, a comma delimited string.
"""
if isinstance(target, str):
# TODO: This could be a dangerous assumption if people happen to have a label that contains a comma
target = _strip(target)
# TODO: This could be a dangerous assumption if people happen to have a label that contains a comma or space
if "," in target:
return TargetMode.MUTLI_COMMA_DELIMITED
elif " " in target:
return TargetMode.MUTLI_SPACE_DELIMITED
else:
return TargetMode.SINGLE_TOKEN
elif _is_list_like(target):
Expand All @@ -88,6 +96,7 @@ def multi_label(self) -> bool:
return any(
[
self is TargetMode.MUTLI_COMMA_DELIMITED,
self is TargetMode.MUTLI_SPACE_DELIMITED,
self is TargetMode.MULTI_NUMERIC,
self is TargetMode.MULTI_TOKEN,
self is TargetMode.MULTI_BINARY,
Expand Down Expand Up @@ -116,7 +125,7 @@ def binary(self) -> bool:
_RESOLUTION_MAPPING = {
TargetMode.MULTI_BINARY: [TargetMode.MULTI_NUMERIC],
TargetMode.SINGLE_BINARY: [TargetMode.MULTI_BINARY, TargetMode.MULTI_NUMERIC],
TargetMode.SINGLE_TOKEN: [TargetMode.MUTLI_COMMA_DELIMITED],
TargetMode.SINGLE_TOKEN: [TargetMode.MUTLI_COMMA_DELIMITED, TargetMode.MUTLI_SPACE_DELIMITED],
TargetMode.SINGLE_NUMERIC: [TargetMode.MULTI_NUMERIC],
}

Expand Down Expand Up @@ -179,7 +188,7 @@ def __init__(self, labels: List[Any]):
self.label_to_idx = {label: idx for idx, label in enumerate(labels)}

def format(self, target: Any) -> Any:
return self.label_to_idx[(target[0] if not isinstance(target, str) else target).strip()]
return self.label_to_idx[_strip(target[0] if not isinstance(target, str) else target)]


class MultiLabelTargetFormatter(SingleLabelTargetFormatter):
Expand All @@ -201,6 +210,11 @@ def format(self, target: Any) -> Any:
return super().format(target.split(","))


class SpaceDelimitedTargetFormatter(MultiLabelTargetFormatter):
def format(self, target: Any) -> Any:
return super().format(target.split(" "))


class MultiNumericTargetFormatter(TargetFormatter):
def __init__(self, num_classes: int):
self.num_classes = num_classes
Expand Down Expand Up @@ -245,6 +259,8 @@ def get_target_formatter(
return SingleLabelTargetFormatter(labels)
elif target_mode is TargetMode.MUTLI_COMMA_DELIMITED:
return CommaDelimitedTargetFormatter(labels)
elif target_mode is TargetMode.MUTLI_SPACE_DELIMITED:
return SpaceDelimitedTargetFormatter(labels)
return MultiLabelTargetFormatter(labels)


Expand Down Expand Up @@ -289,13 +305,16 @@ def get_target_details(targets: List[Any], target_mode: TargetMode) -> Tuple[Opt
if target_mode is TargetMode.MUTLI_COMMA_DELIMITED:
for target in targets:
tokens.extend(target.split(","))
elif target_mode is TargetMode.MUTLI_SPACE_DELIMITED:
for target in targets:
tokens.extend(target.split(" "))
elif target_mode is TargetMode.MULTI_TOKEN:
for target in targets:
tokens.extend(target)
else:
tokens = targets

tokens = [token.strip() for token in tokens]
tokens = [_strip(token) for token in tokens]
labels = list(sorted_alphanumeric(set(tokens)))
num_classes = len(labels)
return labels, num_classes
7 changes: 4 additions & 3 deletions flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,10 @@ def from_voc(
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "ObjectDetectionData":
"""Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders
and annotation files in the `PASCAL VOC (Visual Obect Challenge)
<http://host.robots.ox.ac.uk/pascal/VOC/>`_ XML format.
""".. _PASCAL: http://host.robots.ox.ac.uk/pascal/VOC/
Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders
and annotation files in the `PASCAL VOC (Visual Object Challenge) XML format <PASCAL>`_.
Args:
train_folder: The folder containing the train data.
Expand Down
29 changes: 25 additions & 4 deletions tests/core/data/utilities/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@
["blue", "green", "red"],
3,
),
Case(
["blue green", "green red", "red blue"],
[[1, 1, 0], [0, 1, 1], [1, 0, 1]],
TargetMode.MUTLI_SPACE_DELIMITED,
["blue", "green", "red"],
3,
),
# Ambiguous
Case([[0], [1, 2], [2, 0]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], TargetMode.MULTI_NUMERIC, None, 3),
Case([[1, 0, 0], [0, 1, 1], [1, 0, 1]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], TargetMode.MULTI_BINARY, None, 3),
Expand All @@ -67,6 +74,13 @@
["blue", "green", "red"],
3,
),
Case(
["blue", "green red", "red blue"],
[[1, 0, 0], [0, 1, 1], [1, 0, 1]],
TargetMode.MUTLI_SPACE_DELIMITED,
["blue", "green", "red"],
3,
),
# Special cases
Case(["blue ", " green", "red"], [0, 1, 2], TargetMode.SINGLE_TOKEN, ["blue", "green", "red"], 3),
Case(
Expand All @@ -76,6 +90,13 @@
["blue", "green", "red"],
3,
),
Case(
["blue", "green ,red", "red ,blue"],
[[1, 0, 0], [0, 1, 1], [1, 0, 1]],
TargetMode.MUTLI_COMMA_DELIMITED,
["blue", "green", "red"],
3,
),
Case(
[f"class_{i}" for i in range(10000)],
list(range(10000)),
Expand Down Expand Up @@ -115,16 +136,16 @@ def test_speed(case):
else:
targets = case.target * repeats

start = time.time()
start = time.perf_counter()
target_mode = get_target_mode(targets)
labels, num_classes = get_target_details(targets, target_mode)
formatter = get_target_formatter(target_mode, labels, num_classes)
end = time.time()
end = time.perf_counter()

assert (end - start) / len(targets) < 1e-5 # 0.01ms per target

start = time.time()
start = time.perf_counter()
_ = [formatter(t) for t in targets]
end = time.time()
end = time.perf_counter()

assert (end - start) / len(targets) < 1e-5 # 0.01ms per target

0 comments on commit 6f3dc08

Please sign in to comment.