Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fit and distance flags to TextDistance class for call order checks. #122

Merged
merged 2 commits into from
Dec 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions mltb2/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ class TextDistance:
# set of all counted characters - see _normalize_char_counter
_counted_char_set: Optional[Set[str]] = field(default=None, init=False)

# flag if fit was called
_fit_called: bool = field(default=False, init=False)

# flag if distance was called
_distance_called: bool = field(default=False, init=False)

def __post_init__(self) -> None:
"""Do post init."""
if not self.max_dimensions > 0:
Expand All @@ -209,24 +215,27 @@ def fit(self, text: Union[str, Iterable[str]]) -> None:
ValueError: If :func:`~TextDistance.fit` is called after
:func:`~TextDistance.distance`.
"""
if self._char_counter is None:
raise ValueError("Fit mut not be called after distance calculation!")
if self._distance_called:
raise ValueError("fit must not be called after distance calculation!")

if isinstance(text, str):
self._char_counter.update(text)
self._char_counter.update(text) # type: ignore
else:
for t in tqdm(text, disable=not self.show_progress_bar):
self._char_counter.update(t)
self._char_counter.update(t) # type: ignore

self._fit_called = True

def _normalize_char_counter(self) -> None:
"""Normalize the char counter to a defaultdict.

This supports lazy postprocessing of the char counter.
"""
if self._char_counter is not None:
self._normalized_char_counts = _normalize_counter_to_defaultdict(self._char_counter, self.max_dimensions)
if not self._distance_called:
self._normalized_char_counts = _normalize_counter_to_defaultdict(self._char_counter, self.max_dimensions) # type: ignore
self._char_counter = None
self._counted_char_set = set(self._normalized_char_counts)
self._distance_called = True

def distance(self, text) -> float:
"""Calculate the distance between the fitted text and the given text.
Expand All @@ -237,6 +246,8 @@ def distance(self, text) -> float:
Args:
text: The text to calculate the Manhattan distance to.
"""
if not self._fit_called:
raise ValueError("fit must not be called before distance!")
self._normalize_char_counter()
all_vector = []
text_vector = []
Expand Down
7 changes: 7 additions & 0 deletions tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ def test_text_distance_fit_not_allowed_after_distance():
td.fit("Hello World")


def test_text_distance_distance_not_allowed_before_fit():
text = "Hello World!"
td = TextDistance()
with pytest.raises(ValueError):
_ = td.distance(text)


def test_text_distance_max_dimensions_must_be_greater_zero():
with pytest.raises(ValueError):
_ = TextDistance(max_dimensions=0)
Expand Down