Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/better dictionaries #135

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions simplemma/strategies/dictionaries/dictionary_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from functools import lru_cache
from os import listdir, path
from pathlib import Path
from typing import ByteString, Dict, Protocol
from typing import Dict, Mapping, Protocol

DATA_FOLDER = str(Path(__file__).parent / "data")
SUPPORTED_LANGUAGES = [
Expand All @@ -24,7 +24,7 @@
]


def _load_dictionary_from_disk(langcode: str) -> Dict[ByteString, ByteString]:
def _load_dictionary_from_disk(langcode: str) -> Dict[bytes, bytes]:
"""
Load a dictionary from disk.

Expand Down Expand Up @@ -62,22 +62,41 @@ class DictionaryFactory(Protocol):
def get_dictionary(
self,
lang: str,
) -> Dict[ByteString, ByteString]:
) -> Mapping[str, str]:
"""
Get the dictionary for a specific language.

Args:
lang (str): The language code.

Returns:
Dict[str, str]: The dictionary for the specified language.
Mapping[str, str]: The dictionary for the specified language.

Raises:
ValueError: If the specified language is not supported.
"""
raise NotImplementedError


class MappingStrToByteString(Mapping[str, str]):
"""Wrapper around ByString dict to make them behave like str dict."""

__slots__ = ["_dict"]

def __init__(self, dict: Dict[bytes, bytes]):
self._dict = dict

def __getitem__(self, item: str):
return self._dict[item.encode()].decode()

def __iter__(self):
for key in self._dict.iterkeys():
yield key.decode()

def __len__(self):
return len(self._dict)


class DefaultDictionaryFactory(DictionaryFactory):
"""
Default Dictionary Factory.
Expand All @@ -86,7 +105,7 @@ class DefaultDictionaryFactory(DictionaryFactory):
It provides functionality for loading and caching dictionaries from disk that are included in Simplemma.
"""

__slots__ = ["_data", "_load_dictionary_from_disk"]
__slots__ = ["_load_dictionary_from_disk"]

def __init__(self, cache_max_size: int = 8):
"""
Expand All @@ -96,27 +115,26 @@ def __init__(self, cache_max_size: int = 8):
cache_max_size (int): The maximum size of the cache for loaded dictionaries.
Defaults to `8`.
"""
self._data: Dict[str, Dict[ByteString, ByteString]] = {}
self._load_dictionary_from_disk = lru_cache(maxsize=cache_max_size)(
_load_dictionary_from_disk
)

def get_dictionary(
self,
lang: str,
) -> Dict[ByteString, ByteString]:
) -> Mapping[str, str]:
"""
Get the dictionary for a specific language.

Args:
lang (str): The language code.

Returns:
Dict[str, str]: The dictionary for the specified language.
Mapping[str, str]: The dictionary for the specified language.

Raises:
ValueError: If the specified language is not supported.
"""
if lang not in SUPPORTED_LANGUAGES:
raise ValueError(f"Unsupported language: {lang}")
return self._load_dictionary_from_disk(lang)
return MappingStrToByteString(self._load_dictionary_from_disk(lang))
120 changes: 120 additions & 0 deletions simplemma/strategies/dictionaries/trie_dictionary_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from functools import lru_cache
from pathlib import Path
import tempfile
from typing import ByteString, Dict, Mapping, Optional

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'ByteString' is not used.
Import of 'Dict' is not used.

from marisa_trie import RecordTrie, HUGE_CACHE # type: ignore[import-not-found]

from simplemma import __version__ as SIMPLEMMA_VERSION
from simplemma.strategies.dictionaries.dictionary_factory import (
SUPPORTED_LANGUAGES,
DictionaryFactory,
DefaultDictionaryFactory,
)


class TrieWrapDict(Mapping[str, str]):
"""Wrapper around RecordTrie to make them behave like dicts."""

__slots__ = ["_trie"]

def __init__(self, trie: RecordTrie):
self._trie = trie

def __getitem__(self, item: str) -> str:
return self._trie[item][0]

def __iter__(self):
for key in self._trie.iterkeys():
yield key

def __len__(self):
return len(self._trie)


class TrieDictionaryFactory(DictionaryFactory):
"""Memory optimized DictionaryFactory backed by MARISA-tries.

This dictionary factory creates dictionaries, which are backed by a
MARISA-trie instead of a dict, to make them consume very little
memory compared to the DefaultDictionaryFactory. Trade-offs are that
lookup performance isn't as good as with dicts.
"""

__slots__ = ["_cache_dir", "_defaultDictionaryFactory", "_get_dictionary"]

def __init__(
self,
cache_max_size: int = 8,
use_disk_cache: bool = True,
disk_cache_dir: Optional[str] = None,
) -> None:
"""Initialize the TrieDictionaryFactory.

Args:
cache_max_size (int): The maximum number dictionaries to
keep in memory. Defaults to `8`.
use_disk_cache (bool): Whether to cache the tries on disk to
speed up loading time. Defaults to `True`.
disk_cache_dir (Optional[str]): Path where the generated
tries should be stored in. Defaults to a Simplemma-
specific subdirectory of the user's cache directory.
"""

self._defaultDictionaryFactory = DefaultDictionaryFactory(cache_max_size=0)
if use_disk_cache:
self._cache_dir: Optional[Path] = (
Path(disk_cache_dir)
if disk_cache_dir is not None
else (Path(tempfile.gettempdir()) / "simplemma" / SIMPLEMMA_VERSION)
)
else:
self._cache_dir = None

self._get_dictionary = lru_cache(maxsize=cache_max_size)(
self._get_dictionary_uncached
)

def _try_read_trie_from_disk(self, lang: str) -> bool:
"""Check if a trie for the given language is available on disk."""
if self._cache_dir is None:
return False
try:
return RecordTrie().load(str(self._cache_dir / f"{lang}.pkl"))
except FileNotFoundError:
return False

def _write_trie_to_disk(self, lang: str, trie: RecordTrie) -> None:
"""Persist the trie to disk for later usage.

The persisted trie can be loaded by subsequent runs to speed up
loading times.
"""
if self._cache_dir is None:
return

trie.save(self._cache_dir / f"{lang}.pkl")

def _get_dictionary_uncached(self, lang: str) -> Mapping[str, str]:
"""Get the dictionary for the given language."""
if lang not in SUPPORTED_LANGUAGES:
raise ValueError(f"Unsupported language: {lang}")

if self._cache_dir:
trie = RecordTrie().load(str(self._cache_dir / f"{lang}.pkl"))

if trie is None:
trie = RecordTrie(
self._defaultDictionaryFactory.get_dictionary(lang).items(),
cache_size=HUGE_CACHE,
)
if self._cache_dir:
self._write_trie_to_disk(lang, trie)

return TrieWrapDict(trie)

def get_dictionary(
self,
lang: str,
) -> Mapping[str, str]:
return self._get_dictionary(lang)
16 changes: 4 additions & 12 deletions simplemma/strategies/dictionary_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
It provides lemmatization using dictionary lookup.
"""

from typing import ByteString, Dict, Optional
from typing import Optional

from .dictionaries.dictionary_factory import DefaultDictionaryFactory, DictionaryFactory
from .lemmatization_strategy import LemmatizationStrategy
Expand All @@ -26,13 +26,6 @@ def __init__(
"""
self._dictionary_factory = dictionary_factory

def _get(
self, token: str, dictionary: Dict[ByteString, ByteString]
) -> Optional[str]:
"Convenience function to handle bytestring to string conversion."
result = dictionary.get(token.encode("utf-8"))
return result.decode("utf-8") if result else None # type: ignore[union-attr]

def get_lemma(self, token: str, lang: str) -> Optional[str]:
"""
Get Lemma using Dictionary Lookup
Expand All @@ -50,9 +43,8 @@ def get_lemma(self, token: str, lang: str) -> Optional[str]:
"""
# Search the language data, reverse case to extend coverage.
dictionary = self._dictionary_factory.get_dictionary(lang)
result = self._get(token, dictionary)
if result:
return result
if token in dictionary:
return dictionary[token]
# Try upper or lowercase.
token = token.lower() if token[0].isupper() else token.capitalize()
return self._get(token, dictionary)
return dictionary.get(token)
4 changes: 2 additions & 2 deletions simplemma/strategies/greedy_dictionary_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_lemma(self, token: str, lang: str) -> str:
return token

dictionary = self._dictionary_factory.get_dictionary(lang)
candidate = token.encode("utf-8")
candidate = token
for _ in range(self._steps):
if candidate not in dictionary:
break
Expand All @@ -73,4 +73,4 @@ def get_lemma(self, token: str, lang: str) -> str:

candidate = new_candidate

return candidate.decode("utf-8")
return candidate
8 changes: 2 additions & 6 deletions simplemma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
- [validate_lang_input][simplemma.utils.validate_lang_input]: Validates the language input and ensures it is a valid tuple.
"""

from typing import ByteString, Tuple, Union
from typing import Tuple, Union


def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]:
Expand All @@ -31,9 +31,7 @@ def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]:
return lang # type: ignore[return-value]


def levenshtein_dist(
first: Union[ByteString, str], second: Union[ByteString, str]
) -> int:
def levenshtein_dist(str1: str, str2: str) -> int:
"""
Calculate the Levenshtein distance between two strings.

Expand All @@ -49,8 +47,6 @@ def levenshtein_dist(
int: The Levenshtein distance between the two strings.

"""
str1 = first.encode("utf-8") if isinstance(first, str) else first
str2 = second.encode("utf-8") if isinstance(second, str) else second
# inspired by this noticeably faster code:
# https://gist.github.com/p-hash/9e0f9904ce7947c133308fbe48fe032b
if str1 == str2:
Expand Down
5 changes: 2 additions & 3 deletions tests/test_dictionary_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def test_logic() -> None:
# different order
mydict = dictionary_pickler._read_dict(testfile, "es", silent=True)
assert len(mydict) == 5
assert mydict[b"closeones"] == b"closeone"
assert mydict["closeones"] == "closeone"
item = sorted(mydict.keys(), reverse=True)[0]
assert item == b"valid-word"
assert item == "valid-word"

# file I/O
assert dictionary_pickler._determine_path("lists", "de").endswith("de.txt")
Expand All @@ -37,4 +37,3 @@ def test_logic() -> None:
listpath = os.path.join(TEST_DIR, "data")
os_handle, temp_outputfile = tempfile.mkstemp(suffix=".pkl", text=True)
dictionary_pickler._pickle_dict("zz", listpath, temp_outputfile)
dictionary_pickler._pickle_dict("zz", listpath, in_place=True)
10 changes: 4 additions & 6 deletions tests/test_lemmatizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for `simplemma` package."""

from typing import ByteString, Dict
from typing import Dict

import pytest

Expand All @@ -17,8 +17,8 @@ class CustomDictionaryFactory(DictionaryFactory):
def get_dictionary(
self,
lang: str,
) -> Dict[ByteString, ByteString]:
return {b"testing": b"the test works!!"}
) -> Dict[str, str]:
return {"testing": "the test works!!"}

assert (
Lemmatizer(
Expand Down Expand Up @@ -113,9 +113,7 @@ def test_readme() -> None:
".",
]
# error
assert Lemmatizer().lemmatize("スパゲッティ", lang="pt") == lemmatize(
"スパゲッティ", lang="pt"
)
assert Lemmatizer().lemmatize("スパゲッティ", lang="pt") == lemmatize("スパゲッティ", lang="pt")
assert lemmatize("スパゲッティ", lang="pt") == "スパゲッティ"

with pytest.raises(ValueError):
Expand Down
Loading
Loading