From 7948aaea4c2105a355bbae6b2204d987e58b0a70 Mon Sep 17 00:00:00 2001 From: Philip May Date: Fri, 22 Dec 2023 21:30:53 +0100 Subject: [PATCH] Add TextDistance class for calculating distance between texts. (#113) * Add TextDistance class for calculating cosine distance between texts * Refactor TextDistance class to use Manhattan distance and normalize character counters * Add pytest markers to test functions * Skip loading prostate and leukemia big datasets due to issue #118 * Skip leukemia big test due to issue #118 * Add counted_char_set to TextDistance class * Normalize counter and calculate Manhatten distance * Add additional tests for TextDistance class This commit adds three new test cases to the `test_text.py` file. The new tests verify the functionality of the `TextDistance` class by checking the distance calculation for different input texts. The tests cover scenarios such as orthogonal texts, extended texts, and an exception case. These new tests enhance the test coverage and ensure the accuracy of the `TextDistance` class. * add validation for max_dimensions parameter in TextDistance class * Add test cases for TextDistance class * Update TextDistance class documentation and add type hints --- mltb2/text.py | 112 ++++++++++++++++++++++++++++++++++++++++++++- tests/test_data.py | 5 ++ tests/test_text.py | 74 ++++++++++++++++++++++++++++++ 3 files changed, 190 insertions(+), 1 deletion(-) diff --git a/mltb2/text.py b/mltb2/text.py index d721e87..9858d6e 100644 --- a/mltb2/text.py +++ b/mltb2/text.py @@ -12,7 +12,12 @@ """ import re -from typing import Dict, Final, Pattern, Tuple +from collections import Counter, defaultdict +from dataclasses import dataclass, field +from typing import Dict, Final, Iterable, Optional, Pattern, Set, Tuple, Union + +from scipy.spatial.distance import cityblock +from tqdm import tqdm INVISIBLE_CHARACTERS: Final[Tuple[str, ...]] = ( "\u200b", # Zero Width Space (ZWSP) https://www.compart.com/en/unicode/U+200b @@ -138,3 +143,108 @@ def clean_all_invisible_chars_and_whitespaces(text: str) -> str: text = replace_multiple_whitespaces(text) text = text.strip() return text + + +def _normalize_counter_to_defaultdict(counter: Counter, max_dimensions: int) -> defaultdict: + """Normalize a counter to to ``max_dimensions``. + + The number of dimensions is limited to ``max_dimensions`` + of the most commen characters. + The counter values are normalized by deviding them by the total count. + + Args: + counter: The counter to normalize. + max_dimensions: The maximum number of dimensions to use for the normalization. + Must be greater than 0. + Returns: + The normalized counter with a maximum of ``max_dimensions`` dimensions. + """ + total_count = sum(counter.values()) + normalized_counter = defaultdict(float) + for char, count in counter.most_common(max_dimensions): + normalized_counter[char] = count / total_count + return normalized_counter + + +@dataclass +class TextDistance: + """Calculate the distance between two texts. + + One text (or multiple texts) must first be fitted with :func:`~TextDistance.fit`. + After that the distance to other given texts can be calculated with :func:`~TextDistance.distance`. + After the distance was calculated the first time, the class can + not be fitted again. + + Args: + show_progress_bar: Show a progressbar during processing. + max_dimensions: The maximum number of dimensions to use for the distance calculation. + Must be greater than 0. + Raises: + ValueError: If ``max_dimensions`` is not greater than 0. + """ + + show_progress_bar: bool = False + max_dimensions: int = 100 + + # counter for the text we fit + _char_counter: Optional[Counter] = field(default_factory=Counter, init=False) + + # normalized counter for the text we fit - see _normalize_char_counter + _normalized_char_counts: Optional[defaultdict] = field(default=None, init=False) + + # set of all counted characters - see _normalize_char_counter + _counted_char_set: Optional[Set[str]] = field(default=None, init=False) + + def __post_init__(self) -> None: + """Do post init.""" + if not self.max_dimensions > 0: + raise ValueError("'max_dimensions' must be > 0!") + + def fit(self, text: Union[str, Iterable[str]]) -> None: + """Fit the text. + + Args: + text: The text to fit. + Raises: + ValueError: If :func:`~TextDistance.fit` is called after + :func:`~TextDistance.distance`. + """ + if self._char_counter is None: + raise ValueError("Fit mut not be called after distance calculation!") + + if isinstance(text, str): + self._char_counter.update(text) + else: + for t in tqdm(text, disable=not self.show_progress_bar): + self._char_counter.update(t) + + def _normalize_char_counter(self) -> None: + """Normalize the char counter to a defaultdict. + + This supports lazy postprocessing of the char counter. + """ + if self._char_counter is not None: + self._normalized_char_counts = _normalize_counter_to_defaultdict(self._char_counter, self.max_dimensions) + self._char_counter = None + self._counted_char_set = set(self._normalized_char_counts) + + def distance(self, text) -> float: + """Calculate the distance between the fitted text and the given text. + + This implementation uses the Manhattan distance (:func:`scipy.spatial.distance.cityblock`). + The distance is only calculated for ``max_dimensions`` most commen characters. + + Args: + text: The text to calculate the Manhattan distance to. + """ + self._normalize_char_counter() + all_vector = [] + text_vector = [] + text_count = Counter(text) + text_count_defaultdict = _normalize_counter_to_defaultdict(text_count, self.max_dimensions) + for c in self._counted_char_set.union(text_count_defaultdict): # type: ignore + all_vector.append( + self._normalized_char_counts[c] # type: ignore + ) # if c is not in defaultdict, it will return 0 + text_vector.append(text_count_defaultdict[c]) # if c is not in defaultdict, it will return 0 + return cityblock(all_vector, text_vector) diff --git a/tests/test_data.py b/tests/test_data.py index 3ea9165..dd8db57 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -3,6 +3,7 @@ # which is available at https://opensource.org/licenses/MIT import pandas as pd +import pytest from numpy.testing import assert_almost_equal from mltb2.data import _load_colon_data, _load_colon_label, load_colon, load_leukemia_big, load_prostate @@ -44,6 +45,7 @@ def test_load_colon_compare_original(): assert_almost_equal(result[1].to_numpy(), ori_result[1].to_numpy()) +@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118") def test_load_prostate(): result = load_prostate() assert result is not None @@ -55,6 +57,7 @@ def test_load_prostate(): assert result[1].shape == (102, 6033) +@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118") def test_load_prostate_compare_original(): result = load_prostate() ori_result = load_prostate_data() @@ -64,6 +67,7 @@ def test_load_prostate_compare_original(): assert_almost_equal(result[1].to_numpy(), ori_result[1].to_numpy()) +@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118") def test_load_leukemia_big(): result = load_leukemia_big() assert result is not None @@ -75,6 +79,7 @@ def test_load_leukemia_big(): assert result[1].shape == (72, 7128) +@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118") def test_load_leukemia_big_compare_original(): result = load_leukemia_big() ori_result = load_leukemia_data() diff --git a/tests/test_text.py b/tests/test_text.py index 708e987..d3e2b49 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -2,11 +2,16 @@ # This software is distributed under the terms of the MIT license # which is available at https://opensource.org/licenses/MIT +from collections import Counter, defaultdict +from math import isclose + import pytest from mltb2.text import ( INVISIBLE_CHARACTERS, SPECIAL_WHITESPACES, + TextDistance, + _normalize_counter_to_defaultdict, clean_all_invisible_chars_and_whitespaces, has_invisible_characters, has_special_whitespaces, @@ -112,3 +117,72 @@ def test_clean_all_invisible_chars_and_whitespaces_empty_result(): text = " \u200b\u00ad\u2007 " result = clean_all_invisible_chars_and_whitespaces(text) assert result == "" + + +def test_text_distance_distance_same(): + text = "Hello World!" + td = TextDistance() + td.fit(text) + assert len(td._char_counter) == 9 + assert td._normalized_char_counts is None + assert td._counted_char_set is None + distance = td.distance(text) + assert td._char_counter is None # none after fit + assert td._normalized_char_counts is not None + assert td._counted_char_set is not None + + assert isclose(distance, 0.0), distance + + +def test_text_distance_orthogonal(): + text = "ab" + td = TextDistance() + td.fit(text) + distance = td.distance("xy") + assert distance > 0.0, distance + assert isclose(distance, 2.0), distance + + +def test_text_distance_extended(): + text = "aabbbb" # a:1/3, b:2/3 + td = TextDistance() + td.fit(text) + distance = td.distance("bbcccc") # b:1/3, c:2/3 + assert distance > 0.0, distance + assert isclose(distance, 1 / 3 + 1 / 3 + 2 / 3), distance + + +def test_text_distance_fit_not_allowed_after_distance(): + text = "Hello World!" + td = TextDistance() + td.fit(text) + _ = td.distance(text) + with pytest.raises(ValueError): + td.fit("Hello World") + + +def test_text_distance_max_dimensions_must_be_greater_zero(): + with pytest.raises(ValueError): + _ = TextDistance(max_dimensions=0) + + +def test_normalize_counter_to_defaultdict(): + counter = Counter("aaaabbbcc") + max_dimensions = 2 + normalized_counter = _normalize_counter_to_defaultdict(counter, max_dimensions) + + assert isinstance(normalized_counter, defaultdict) + assert len(normalized_counter) == max_dimensions + assert isclose(normalized_counter["a"], 4 / 9) + assert isclose(normalized_counter["b"], 3 / 9) + assert "c" not in normalized_counter + assert len(normalized_counter) == max_dimensions + + +def test_normalize_counter_to_defaultdict_empty_counter(): + counter = Counter() + max_dimensions = 2 + normalized_counter = _normalize_counter_to_defaultdict(counter, max_dimensions) + + assert isinstance(normalized_counter, defaultdict) + assert len(normalized_counter) == 0