maintenance: better syntax and simplified code (#10)
* better syntax and simplified code

* simplify syntax
adbar authored Jun 18, 2024
1 parent 1817a4c commit 1aa0c67
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions py3langid/langid.py
@@ -17,6 +17,7 @@

 from base64 import b64decode
 from collections import Counter
+from operator import itemgetter
 from pathlib import Path
 from urllib.parse import parse_qs

@@ -33,6 +34,9 @@
 # affect the relative ordering of the predicted classes. It can be
 # re-enabled at runtime - see the readme.

+# quantization: faster but less precise
+DATATYPE = "uint16"
+


 def load_model(path=None):
     """
@@ -60,7 +64,7 @@ def set_languages(langs=None):
     return IDENTIFIER.set_languages(langs)


-def classify(instance, datatype='uint16'):
+def classify(instance, datatype=DATATYPE):
    """
    Convenience method using a global identifier instance with the default
    model included in langid.py. Identifies the language that a string is
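
As a point of reference, a minimal usage sketch of the module-level classify() helper whose default changes above, assuming the package's usual import name and the bundled model; the exact score depends on that model:

import py3langid as langid

# classify() now takes its default dtype from DATATYPE ("uint16"); a wider type
# such as "uint32" trades a little speed for more counting headroom on long texts
lang, score = langid.classify("Questo è un breve esempio in italiano.")
print(lang, score)   # expected language code 'it'; the score is model-dependent

lang, score = langid.classify("A short English sample.", datatype="uint32")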
@@ -198,9 +202,7 @@ def set_languages(self, langs=None):
         nb_ptc, nb_pc, nb_classes = self.__full_model

         if langs is None:
-            self.nb_classes = nb_classes
-            self.nb_ptc = nb_ptc
-            self.nb_pc = nb_pc
+            self.nb_classes, self.nb_ptc, self.nb_pc = nb_classes, nb_ptc, nb_pc

         else:
             # We were passed a restricted set of languages. Trim the arrays accordingly
@@ -209,12 +211,12 @@ def set_languages(self, langs=None):
                 if lang not in nb_classes:
                     raise ValueError(f"Unknown language code {lang}")

-            subset_mask = np.fromiter((l in langs for l in nb_classes), dtype=bool)
+            subset_mask = np.isin(nb_classes, langs)
             self.nb_classes = [c for c in nb_classes if c in langs]
             self.nb_ptc = nb_ptc[:, subset_mask]
             self.nb_pc = nb_pc[subset_mask]

-    def instance2fv(self, text, datatype='uint16'):
+    def instance2fv(self, text, datatype=DATATYPE):
         """
         Map an instance into the feature space of the trained model.
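
The replaced comprehension and np.isin build the same boolean mask; a small standalone check of that equivalence with made-up class labels:

import numpy as np

nb_classes = ["de", "en", "fr", "it"]   # made-up label list
langs = ["en", "it"]

old_mask = np.fromiter((l in langs for l in nb_classes), dtype=bool)
new_mask = np.isin(nb_classes, langs)

assert (old_mask == new_mask).all()     # both are array([False, True, False, True])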
@@ -227,11 +229,12 @@ def instance2fv(self, text, datatype='uint16'):

         # Convert the text to a sequence of ascii values and
         # Count the number of times we enter each state
-        state = 0
-        indexes = []
-        for letter in list(text):
+        state, indexes = 0, []
+        extend = indexes.extend
+
+        for letter in text:
             state = self.tk_nextmove[(state << 8) + letter]
-            indexes.extend(self.tk_output.get(state, []))
+            extend(self.tk_output.get(state, []))

         # datatype: consider that less feature counts are going to be needed
         arr = np.zeros(self.nb_numfeats, dtype=datatype)
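
Binding indexes.extend to a local name saves one attribute lookup per character in the hot loop. A toy re-creation of the pattern with placeholder transition tables (the real tk_nextmove and tk_output come from the trained model and are arrays, not the dicts used here):

# placeholder tables keyed by (state << 8) + byte, standing in for the model data
tk_nextmove = {97: 1, (1 << 8) + 98: 2}   # byte 'a' from state 0, byte 'b' from state 1
tk_output = {1: [5], 2: [5, 9]}           # feature indexes emitted when a state is entered

state, indexes = 0, []
extend = indexes.extend                   # bound once instead of looked up per character

for letter in b"ab":                      # bytes iterate as integers, like the encoded text
    state = tk_nextmove.get((state << 8) + letter, 0)
    extend(tk_output.get(state, []))

print(indexes)                            # [5, 5, 9]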
@@ -247,7 +250,7 @@ def nb_classprobs(self, fv):
         # compute the partial log-probability of the document in each class
         return pdc + self.nb_pc

-    def classify(self, text, datatype='uint16'):
+    def classify(self, text, datatype=DATATYPE):
         """
         Classify an instance.
         """
@@ -262,7 +265,7 @@ def rank(self, text):
         """
         fv = self.instance2fv(text)
         probs = self.norm_probs(self.nb_classprobs(fv))
-        return [(str(k), float(v)) for (v, k) in sorted(zip(probs, self.nb_classes), reverse=True)]
+        return sorted(zip(self.nb_classes, probs), key=itemgetter(1), reverse=True)

     def cl_path(self, path):
         """

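The updated rank() sorts the (language, probability) pairs directly by their second element. A self-contained sketch of that sorting pattern with dummy scores:

from operator import itemgetter

pairs = [("en", 0.21), ("fr", 0.05), ("it", 0.74)]   # dummy (language, probability) pairs

ranked = sorted(pairs, key=itemgetter(1), reverse=True)
print(ranked)   # [('it', 0.74), ('en', 0.21), ('fr', 0.05)]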