Skip to content

Commit

Permalink
Add TextDistance class for calculating distance between texts. (#113)
Browse files Browse the repository at this point in the history
* Add TextDistance class for calculating cosine distance between texts

* Refactor TextDistance class to use Manhattan distance and normalize character counters

* Add pytest markers to test functions

* Skip loading prostate and leukemia big datasets due to issue #118

* Skip leukemia big test due to issue #118

* Add counted_char_set to TextDistance class

* Normalize counter and calculate Manhatten distance

* Add additional tests for TextDistance class

This commit adds three new test cases to the `test_text.py` file. The new tests verify the functionality of the `TextDistance` class by checking the distance calculation for different input texts. The tests cover scenarios such as orthogonal texts, extended texts, and an exception case. These new tests enhance the test coverage and ensure the accuracy of the `TextDistance` class.

* add validation for max_dimensions parameter in TextDistance class

* Add test cases for TextDistance class

* Update TextDistance class documentation and add type hints
  • Loading branch information
PhilipMay authored Dec 22, 2023
1 parent 5e788f0 commit 7948aae
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 1 deletion.
112 changes: 111 additions & 1 deletion mltb2/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
"""

import re
from typing import Dict, Final, Pattern, Tuple
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from typing import Dict, Final, Iterable, Optional, Pattern, Set, Tuple, Union

from scipy.spatial.distance import cityblock
from tqdm import tqdm

INVISIBLE_CHARACTERS: Final[Tuple[str, ...]] = (
"\u200b", # Zero Width Space (ZWSP) https://www.compart.com/en/unicode/U+200b
Expand Down Expand Up @@ -138,3 +143,108 @@ def clean_all_invisible_chars_and_whitespaces(text: str) -> str:
text = replace_multiple_whitespaces(text)
text = text.strip()
return text


def _normalize_counter_to_defaultdict(counter: Counter, max_dimensions: int) -> defaultdict:
"""Normalize a counter to to ``max_dimensions``.
The number of dimensions is limited to ``max_dimensions``
of the most commen characters.
The counter values are normalized by deviding them by the total count.
Args:
counter: The counter to normalize.
max_dimensions: The maximum number of dimensions to use for the normalization.
Must be greater than 0.
Returns:
The normalized counter with a maximum of ``max_dimensions`` dimensions.
"""
total_count = sum(counter.values())
normalized_counter = defaultdict(float)
for char, count in counter.most_common(max_dimensions):
normalized_counter[char] = count / total_count
return normalized_counter


@dataclass
class TextDistance:
"""Calculate the distance between two texts.
One text (or multiple texts) must first be fitted with :func:`~TextDistance.fit`.
After that the distance to other given texts can be calculated with :func:`~TextDistance.distance`.
After the distance was calculated the first time, the class can
not be fitted again.
Args:
show_progress_bar: Show a progressbar during processing.
max_dimensions: The maximum number of dimensions to use for the distance calculation.
Must be greater than 0.
Raises:
ValueError: If ``max_dimensions`` is not greater than 0.
"""

show_progress_bar: bool = False
max_dimensions: int = 100

# counter for the text we fit
_char_counter: Optional[Counter] = field(default_factory=Counter, init=False)

# normalized counter for the text we fit - see _normalize_char_counter
_normalized_char_counts: Optional[defaultdict] = field(default=None, init=False)

# set of all counted characters - see _normalize_char_counter
_counted_char_set: Optional[Set[str]] = field(default=None, init=False)

def __post_init__(self) -> None:
"""Do post init."""
if not self.max_dimensions > 0:
raise ValueError("'max_dimensions' must be > 0!")

def fit(self, text: Union[str, Iterable[str]]) -> None:
"""Fit the text.
Args:
text: The text to fit.
Raises:
ValueError: If :func:`~TextDistance.fit` is called after
:func:`~TextDistance.distance`.
"""
if self._char_counter is None:
raise ValueError("Fit mut not be called after distance calculation!")

if isinstance(text, str):
self._char_counter.update(text)
else:
for t in tqdm(text, disable=not self.show_progress_bar):
self._char_counter.update(t)

def _normalize_char_counter(self) -> None:
"""Normalize the char counter to a defaultdict.
This supports lazy postprocessing of the char counter.
"""
if self._char_counter is not None:
self._normalized_char_counts = _normalize_counter_to_defaultdict(self._char_counter, self.max_dimensions)
self._char_counter = None
self._counted_char_set = set(self._normalized_char_counts)

def distance(self, text) -> float:
"""Calculate the distance between the fitted text and the given text.
This implementation uses the Manhattan distance (:func:`scipy.spatial.distance.cityblock`).
The distance is only calculated for ``max_dimensions`` most commen characters.
Args:
text: The text to calculate the Manhattan distance to.
"""
self._normalize_char_counter()
all_vector = []
text_vector = []
text_count = Counter(text)
text_count_defaultdict = _normalize_counter_to_defaultdict(text_count, self.max_dimensions)
for c in self._counted_char_set.union(text_count_defaultdict): # type: ignore
all_vector.append(
self._normalized_char_counts[c] # type: ignore
) # if c is not in defaultdict, it will return 0
text_vector.append(text_count_defaultdict[c]) # if c is not in defaultdict, it will return 0
return cityblock(all_vector, text_vector)
5 changes: 5 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# which is available at https://opensource.org/licenses/MIT

import pandas as pd
import pytest
from numpy.testing import assert_almost_equal

from mltb2.data import _load_colon_data, _load_colon_label, load_colon, load_leukemia_big, load_prostate
Expand Down Expand Up @@ -44,6 +45,7 @@ def test_load_colon_compare_original():
assert_almost_equal(result[1].to_numpy(), ori_result[1].to_numpy())


@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118")
def test_load_prostate():
result = load_prostate()
assert result is not None
Expand All @@ -55,6 +57,7 @@ def test_load_prostate():
assert result[1].shape == (102, 6033)


@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118")
def test_load_prostate_compare_original():
result = load_prostate()
ori_result = load_prostate_data()
Expand All @@ -64,6 +67,7 @@ def test_load_prostate_compare_original():
assert_almost_equal(result[1].to_numpy(), ori_result[1].to_numpy())


@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118")
def test_load_leukemia_big():
result = load_leukemia_big()
assert result is not None
Expand All @@ -75,6 +79,7 @@ def test_load_leukemia_big():
assert result[1].shape == (72, 7128)


@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118")
def test_load_leukemia_big_compare_original():
result = load_leukemia_big()
ori_result = load_leukemia_data()
Expand Down
74 changes: 74 additions & 0 deletions tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

from collections import Counter, defaultdict
from math import isclose

import pytest

from mltb2.text import (
INVISIBLE_CHARACTERS,
SPECIAL_WHITESPACES,
TextDistance,
_normalize_counter_to_defaultdict,
clean_all_invisible_chars_and_whitespaces,
has_invisible_characters,
has_special_whitespaces,
Expand Down Expand Up @@ -112,3 +117,72 @@ def test_clean_all_invisible_chars_and_whitespaces_empty_result():
text = " \u200b\u00ad\u2007 "
result = clean_all_invisible_chars_and_whitespaces(text)
assert result == ""


def test_text_distance_distance_same():
text = "Hello World!"
td = TextDistance()
td.fit(text)
assert len(td._char_counter) == 9
assert td._normalized_char_counts is None
assert td._counted_char_set is None
distance = td.distance(text)
assert td._char_counter is None # none after fit
assert td._normalized_char_counts is not None
assert td._counted_char_set is not None

assert isclose(distance, 0.0), distance


def test_text_distance_orthogonal():
text = "ab"
td = TextDistance()
td.fit(text)
distance = td.distance("xy")
assert distance > 0.0, distance
assert isclose(distance, 2.0), distance


def test_text_distance_extended():
text = "aabbbb" # a:1/3, b:2/3
td = TextDistance()
td.fit(text)
distance = td.distance("bbcccc") # b:1/3, c:2/3
assert distance > 0.0, distance
assert isclose(distance, 1 / 3 + 1 / 3 + 2 / 3), distance


def test_text_distance_fit_not_allowed_after_distance():
text = "Hello World!"
td = TextDistance()
td.fit(text)
_ = td.distance(text)
with pytest.raises(ValueError):
td.fit("Hello World")


def test_text_distance_max_dimensions_must_be_greater_zero():
with pytest.raises(ValueError):
_ = TextDistance(max_dimensions=0)


def test_normalize_counter_to_defaultdict():
counter = Counter("aaaabbbcc")
max_dimensions = 2
normalized_counter = _normalize_counter_to_defaultdict(counter, max_dimensions)

assert isinstance(normalized_counter, defaultdict)
assert len(normalized_counter) == max_dimensions
assert isclose(normalized_counter["a"], 4 / 9)
assert isclose(normalized_counter["b"], 3 / 9)
assert "c" not in normalized_counter
assert len(normalized_counter) == max_dimensions


def test_normalize_counter_to_defaultdict_empty_counter():
counter = Counter()
max_dimensions = 2
normalized_counter = _normalize_counter_to_defaultdict(counter, max_dimensions)

assert isinstance(normalized_counter, defaultdict)
assert len(normalized_counter) == 0

0 comments on commit 7948aae

Please sign in to comment.