diff --git a/gensim/models/phrases.py b/gensim/models/phrases.py index c10aeed0e7..55424904f1 100644 --- a/gensim/models/phrases.py +++ b/gensim/models/phrases.py @@ -208,6 +208,13 @@ def load(cls, *args, **kwargs): """ model = super(PhrasesTransformation, cls).load(*args, **kwargs) # update older models + # if value in phrasegrams dict is a tuple, load only the scores. + + for component, score in getattr(model, "phrasegrams", {}).items(): + if isinstance(score, tuple): + frequency, score_val = score + model.phrasegrams[component] = score_val + # if no scoring parameter, use default scoring if not hasattr(model, 'scoring'): logger.info('older version of %s loaded without scoring function', cls.__name__) @@ -815,7 +822,7 @@ def __init__(self, phrases_model): for bigram, score in phrases_model.export_phrases(corpus, self.delimiter, as_tuples=True): if bigram in self.phrasegrams: logger.info('Phraser repeat %s', bigram) - self.phrasegrams[bigram] = (phrases_model.vocab[self.delimiter.join(bigram)], score) + self.phrasegrams[bigram] = score count += 1 if not count % 50000: logger.info('Phraser added %i phrasegrams', count) @@ -858,7 +865,7 @@ def score_item(self, worda, wordb, components, scorer): """ try: - return self.phrasegrams[tuple(components)][1] + return self.phrasegrams[tuple(components)] except KeyError: return -1 diff --git a/gensim/test/test_data/phraser-3.6.0.model b/gensim/test/test_data/phraser-3.6.0.model new file mode 100644 index 0000000000..4416b13867 Binary files /dev/null and b/gensim/test/test_data/phraser-3.6.0.model differ diff --git a/gensim/test/test_data/phrases-3.6.0.model b/gensim/test/test_data/phrases-3.6.0.model new file mode 100644 index 0000000000..65b831439f Binary files /dev/null and b/gensim/test/test_data/phrases-3.6.0.model differ diff --git a/gensim/test/test_phrases.py b/gensim/test/test_phrases.py index e83bf5b2b9..1ec3a87f6a 100644 --- a/gensim/test/test_phrases.py +++ b/gensim/test/test_phrases.py @@ -12,7 +12,6 @@ import unittest import six - import numpy as np from gensim.utils import to_unicode @@ -646,6 +645,22 @@ def testEncoding(self): self.assertTrue(isinstance(transformed, six.text_type)) +class TestPhraserModelCompatibilty(unittest.TestCase): + + def testCompatibilty(self): + phr = Phraser.load(datapath("phraser-3.6.0.model")) + model = Phrases.load(datapath("phrases-3.6.0.model")) + + test_sentences = ['trees', 'graph', 'minors'] + expected_res = ['trees', 'graph_minors'] + + phr_out = phr[test_sentences] + model_out = model[test_sentences] + + self.assertEqual(phr_out, expected_res) + self.assertEqual(model_out, expected_res) + + if __name__ == '__main__': logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) unittest.main()