diff --git a/tests/test_general.py b/tests/test_general.py
index d338077..dd10224 100644
--- a/tests/test_general.py
+++ b/tests/test_general.py
@@ -10,8 +10,6 @@
     StopwordRemover,
 )
 
-import sys
-
 
 class StemmerTest(unittest.TestCase):
     def setUp(self):
@@ -230,14 +228,13 @@ def test_remove_stopwords(self):
         )
 
     def test_dynamic_stopwords(self):
-        py_version = int(sys.version.split('.')[1])
         dsw = self.stopword_remover.dynamically_detect_stop_words(
             "ben bugün gidip aşı olacağım sonra da eve gelip telefon açacağım aşı nasıl etkiledi eve gelip anlatırım aşı olmak bu dönemde çok ama ama ama ama çok önemli".split()
         )
-        expected = ["ama", "aşı", "çok", "eve"]
-        if py_version <= 8: #Sorting algorithm returns different results from python 3.8+ on
-            expected = ["ama", "aşı", "gelip", "eve"]
-        self.assertEqual(dsw, expected)
+        expected = ["ama", "aşı", "çok", "eve", "gelip"]
+
+        # Compared as sets since the detection order is not stable
+        self.assertEqual(set(dsw), set(expected))
         self.stopword_remover.add_to_stop_words(dsw)
         self.assertEqual(
             self.stopword_remover.drop_stop_words(
diff --git a/vnlp/stopword_remover/stopword_remover.py b/vnlp/stopword_remover/stopword_remover.py
index dc131ab..7d1c4e2 100644
--- a/vnlp/stopword_remover/stopword_remover.py
+++ b/vnlp/stopword_remover/stopword_remover.py
@@ -1,4 +1,5 @@
 from typing import List
+
 from pathlib import Path
 
 import numpy as np
@@ -59,6 +60,9 @@ def dynamically_detect_stop_words(
         ['ama', 'aşı', 'gelip', 'eve']
         """
         unq, cnts = np.unique(list_of_tokens, return_counts=True)
+        # Edge case: every word occurs exactly once
+        if len(unq) == len(list_of_tokens):
+            return []
         sorted_indices = cnts.argsort()[
             ::-1
         ]  # I need them in descending order
@@ -83,8 +87,12 @@
         ]  # removing nan
         argmax_second_der = np.argmax(pct_change_two)
 
+        # Correction term: np.argmax returns only the first occurrence of the maximum
+        amount_of_max = np.sum(cnts == cnts[argmax_second_der])
+
         # +2 is due to shifting twice due to np.diff()
-        detected_stop_words = unq[: argmax_second_der + 2].tolist()
+        # amount_of_max extends the cut so all words tied at that count are kept
+        detected_stop_words = unq[: argmax_second_der + amount_of_max].tolist()
 
         # Determine rare_words according to given rare_words_freq value
         # Add them to dynamic_stop_words list
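
The sketch below is not part of the patch; it is a minimal, self-contained illustration of the tie correction introduced in the last hunk. The word counts and the elbow index argmax_second_der are invented for the example; only the counting and slicing step mirrors the patched dynamically_detect_stop_words.

# Standalone sketch of the tie correction (illustrative only; counts and the
# elbow index are made up, not produced by the real elbow heuristic).
import numpy as np

# Words already sorted by descending frequency, as in dynamically_detect_stop_words.
unq = np.array(["ama", "aşı", "çok", "eve", "gelip", "bu", "da"])
cnts = np.array([4, 3, 2, 2, 2, 1, 1])

# Suppose the elbow heuristic lands on index 2 ("çok", count 2).
argmax_second_der = 2

# np.argmax returns only the FIRST index of the maximum, so "gelip" (same
# count as "çok") would be dropped by the old fixed "+ 2" offset. Counting
# how many words share that frequency keeps all of the tied words.
amount_of_max = np.sum(cnts == cnts[argmax_second_der])  # -> 3

detected_stop_words = unq[: argmax_second_der + amount_of_max].tolist()
print(detected_stop_words)  # ['ama', 'aşı', 'çok', 'eve', 'gelip']

This is also why the test's expected list now contains both "eve" and "gelip", and why the assertion compares sets rather than relying on a fixed order.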