From 6f3dc08913c07ab49b0b5abab131050aee3bdb5c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 15 Dec 2021 13:18:03 +0000 Subject: [PATCH] Add support for space delimited multi-label targets (#1076) --- CHANGELOG.md | 2 ++ flash/core/data/utilities/classification.py | 27 ++++++++++++++--- flash/image/detection/data.py | 7 +++-- .../data/utilities/test_classification.py | 29 ++++++++++++++++--- 4 files changed, 54 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20efe5a897..10fec32a22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for multi-label, space delimited, targets ([#1076](https://github.com/PyTorchLightning/lightning-flash/pull/1076)) + ### Changed ### Deprecated diff --git a/flash/core/data/utilities/classification.py b/flash/core/data/utilities/classification.py index c1a2916b25..ce3e0f07a0 100644 --- a/flash/core/data/utilities/classification.py +++ b/flash/core/data/utilities/classification.py @@ -36,12 +36,17 @@ def _as_list(x: Union[List, torch.Tensor, np.ndarray]) -> List: return x +def _strip(x: str) -> str: + return x.strip(", ") + + class TargetMode(Enum): """The ``TargetMode`` Enum describes the different supported formats for targets in Flash.""" MULTI_TOKEN = auto() MULTI_NUMERIC = auto() MUTLI_COMMA_DELIMITED = auto() + MUTLI_SPACE_DELIMITED = auto() MULTI_BINARY = auto() SINGLE_TOKEN = auto() @@ -67,9 +72,12 @@ def from_target(cls, target: Any) -> "TargetMode": target: A target that is one of: a single target, a list of targets, a comma delimited string. """ if isinstance(target, str): - # TODO: This could be a dangerous assumption if people happen to have a label that contains a comma + target = _strip(target) + # TODO: This could be a dangerous assumption if people happen to have a label that contains a comma or space if "," in target: return TargetMode.MUTLI_COMMA_DELIMITED + elif " " in target: + return TargetMode.MUTLI_SPACE_DELIMITED else: return TargetMode.SINGLE_TOKEN elif _is_list_like(target): @@ -88,6 +96,7 @@ def multi_label(self) -> bool: return any( [ self is TargetMode.MUTLI_COMMA_DELIMITED, + self is TargetMode.MUTLI_SPACE_DELIMITED, self is TargetMode.MULTI_NUMERIC, self is TargetMode.MULTI_TOKEN, self is TargetMode.MULTI_BINARY, @@ -116,7 +125,7 @@ def binary(self) -> bool: _RESOLUTION_MAPPING = { TargetMode.MULTI_BINARY: [TargetMode.MULTI_NUMERIC], TargetMode.SINGLE_BINARY: [TargetMode.MULTI_BINARY, TargetMode.MULTI_NUMERIC], - TargetMode.SINGLE_TOKEN: [TargetMode.MUTLI_COMMA_DELIMITED], + TargetMode.SINGLE_TOKEN: [TargetMode.MUTLI_COMMA_DELIMITED, TargetMode.MUTLI_SPACE_DELIMITED], TargetMode.SINGLE_NUMERIC: [TargetMode.MULTI_NUMERIC], } @@ -179,7 +188,7 @@ def __init__(self, labels: List[Any]): self.label_to_idx = {label: idx for idx, label in enumerate(labels)} def format(self, target: Any) -> Any: - return self.label_to_idx[(target[0] if not isinstance(target, str) else target).strip()] + return self.label_to_idx[_strip(target[0] if not isinstance(target, str) else target)] class MultiLabelTargetFormatter(SingleLabelTargetFormatter): @@ -201,6 +210,11 @@ def format(self, target: Any) -> Any: return super().format(target.split(",")) +class SpaceDelimitedTargetFormatter(MultiLabelTargetFormatter): + def format(self, target: Any) -> Any: + return super().format(target.split(" ")) + + class MultiNumericTargetFormatter(TargetFormatter): def __init__(self, num_classes: int): self.num_classes = num_classes @@ -245,6 +259,8 @@ def get_target_formatter( return SingleLabelTargetFormatter(labels) elif target_mode is TargetMode.MUTLI_COMMA_DELIMITED: return CommaDelimitedTargetFormatter(labels) + elif target_mode is TargetMode.MUTLI_SPACE_DELIMITED: + return SpaceDelimitedTargetFormatter(labels) return MultiLabelTargetFormatter(labels) @@ -289,13 +305,16 @@ def get_target_details(targets: List[Any], target_mode: TargetMode) -> Tuple[Opt if target_mode is TargetMode.MUTLI_COMMA_DELIMITED: for target in targets: tokens.extend(target.split(",")) + elif target_mode is TargetMode.MUTLI_SPACE_DELIMITED: + for target in targets: + tokens.extend(target.split(" ")) elif target_mode is TargetMode.MULTI_TOKEN: for target in targets: tokens.extend(target) else: tokens = targets - tokens = [token.strip() for token in tokens] + tokens = [_strip(token) for token in tokens] labels = list(sorted_alphanumeric(set(tokens))) num_classes = len(labels) return labels, num_classes diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index 8cd9e890d0..b8e940ae2c 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -151,9 +151,10 @@ def from_voc( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ObjectDetectionData": - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders - and annotation files in the `PASCAL VOC (Visual Obect Challenge) - `_ XML format. + """.. _PASCAL: http://host.robots.ox.ac.uk/pascal/VOC/ + + Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders + and annotation files in the `PASCAL VOC (Visual Object Challenge) XML format `_. Args: train_folder: The folder containing the train data. diff --git a/tests/core/data/utilities/test_classification.py b/tests/core/data/utilities/test_classification.py index 3f57198e2d..0674740f90 100644 --- a/tests/core/data/utilities/test_classification.py +++ b/tests/core/data/utilities/test_classification.py @@ -50,6 +50,13 @@ ["blue", "green", "red"], 3, ), + Case( + ["blue green", "green red", "red blue"], + [[1, 1, 0], [0, 1, 1], [1, 0, 1]], + TargetMode.MUTLI_SPACE_DELIMITED, + ["blue", "green", "red"], + 3, + ), # Ambiguous Case([[0], [1, 2], [2, 0]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], TargetMode.MULTI_NUMERIC, None, 3), Case([[1, 0, 0], [0, 1, 1], [1, 0, 1]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], TargetMode.MULTI_BINARY, None, 3), @@ -67,6 +74,13 @@ ["blue", "green", "red"], 3, ), + Case( + ["blue", "green red", "red blue"], + [[1, 0, 0], [0, 1, 1], [1, 0, 1]], + TargetMode.MUTLI_SPACE_DELIMITED, + ["blue", "green", "red"], + 3, + ), # Special cases Case(["blue ", " green", "red"], [0, 1, 2], TargetMode.SINGLE_TOKEN, ["blue", "green", "red"], 3), Case( @@ -76,6 +90,13 @@ ["blue", "green", "red"], 3, ), + Case( + ["blue", "green ,red", "red ,blue"], + [[1, 0, 0], [0, 1, 1], [1, 0, 1]], + TargetMode.MUTLI_COMMA_DELIMITED, + ["blue", "green", "red"], + 3, + ), Case( [f"class_{i}" for i in range(10000)], list(range(10000)), @@ -115,16 +136,16 @@ def test_speed(case): else: targets = case.target * repeats - start = time.time() + start = time.perf_counter() target_mode = get_target_mode(targets) labels, num_classes = get_target_details(targets, target_mode) formatter = get_target_formatter(target_mode, labels, num_classes) - end = time.time() + end = time.perf_counter() assert (end - start) / len(targets) < 1e-5 # 0.01ms per target - start = time.time() + start = time.perf_counter() _ = [formatter(t) for t in targets] - end = time.time() + end = time.perf_counter() assert (end - start) / len(targets) < 1e-5 # 0.01ms per target