Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RuntimeError in export_phrases (change defaultdict to dict) #3041

Merged
merged 9 commits into from
Feb 13, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions gensim/models/phrases.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
"""

import logging
from collections import defaultdict
import itertools
from math import log
import pickle
Expand Down Expand Up @@ -412,7 +411,7 @@ def load(cls, *args, **kwargs):
if not isinstance(word, str):
logger.info("old version of %s loaded, upgrading %i words in memory", cls.__name__, len(model.vocab))
logger.info("re-save the loaded model to avoid this upgrade in the future")
vocab = defaultdict(int)
vocab = {}
for key, value in model.vocab.items(): # needs lots of extra RAM temporarily!
vocab[str(key, encoding='utf8')] = value
model.vocab = vocab
Expand Down Expand Up @@ -554,7 +553,7 @@ def __init__(
self.min_count = min_count
self.threshold = threshold
self.max_vocab_size = max_vocab_size
self.vocab = defaultdict(int) # mapping between token => its count
self.vocab = {} # mapping between token => its count
self.min_reduce = 1 # ignore any tokens with count smaller than this
self.delimiter = delimiter
self.progress_per = progress_per
Expand All @@ -579,7 +578,7 @@ def __str__(self):
def _learn_vocab(sentences, max_vocab_size, delimiter, connector_words, progress_per):
"""Collect unigram and bigram counts from the `sentences` iterable."""
sentence_no, total_words, min_reduce = -1, 0, 1
vocab = defaultdict(int)
vocab = {}
logger.info("collecting all words and their counts")
for sentence_no, sentence in enumerate(sentences):
if sentence_no % progress_per == 0:
Expand All @@ -590,10 +589,11 @@ def _learn_vocab(sentences, max_vocab_size, delimiter, connector_words, progress
start_token, in_between = None, []
for word in sentence:
if word not in connector_words:
vocab[word] += 1
vocab[word] = vocab.get(word, 0) + 1
if start_token is not None:
phrase_tokens = itertools.chain([start_token], in_between, [word])
vocab[delimiter.join(phrase_tokens)] += 1
joined_phrase_token = delimiter.join(phrase_tokens)
vocab[joined_phrase_token] = vocab.get(joined_phrase_token, 0) + 1
start_token, in_between = word, [] # treat word as both end of a phrase AND beginning of another
elif start_token is not None:
in_between.append(word)
Expand Down Expand Up @@ -654,7 +654,7 @@ def add_vocab(self, sentences):
logger.info("merging %i counts into %s", len(vocab), self)
self.min_reduce = max(self.min_reduce, min_reduce)
for word, count in vocab.items():
self.vocab[word] += count
self.vocab[word] = self.vocab.get(word, 0) + count
if len(self.vocab) > self.max_vocab_size:
utils.prune_vocab(self.vocab, self.min_reduce)
self.min_reduce += 1
Expand All @@ -666,17 +666,17 @@ def add_vocab(self, sentences):

def score_candidate(self, word_a, word_b, in_between):
# Micro optimization: check for quick early-out conditions, before the actual scoring.
word_a_cnt = self.vocab[word_a]
word_a_cnt = self.vocab.get(word_a, 0)
if word_a_cnt <= 0:
return None, None

word_b_cnt = self.vocab[word_b]
word_b_cnt = self.vocab.get(word_b, 0)
if word_b_cnt <= 0:
return None, None

phrase = self.delimiter.join([word_a] + in_between + [word_b])
# XXX: Why do we care about *all* phrase tokens? Why not just score the start+end bigram?
phrase_cnt = self.vocab[phrase]
phrase_cnt = self.vocab.get(phrase, 0)
if phrase_cnt <= 0:
return None, None

Expand Down
46 changes: 40 additions & 6 deletions gensim/test/test_phrases.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,34 @@ def dumb_scorer(worda_count, wordb_count, bigram_count, len_vocab, min_count, co
class TestPhrasesModel(PhrasesCommon, unittest.TestCase):

def test_export_phrases(self):
"""Test Phrases bigram export phrases."""
"""Test Phrases bigram and trigram export phrases."""
bigram = Phrases(self.sentences, min_count=1, threshold=1, delimiter=' ')
trigram = Phrases(bigram[self.sentences], min_count=1, threshold=1, delimiter=' ')
seen_bigrams = set(bigram.export_phrases().keys())
seen_trigrams = set(trigram.export_phrases().keys())

assert seen_bigrams == set([
'human interface',
'response time',
'graph minors',
'minors survey',
])

assert seen_trigrams == set([
'human interface',
'graph minors survey',
])

def test_find_phrases(self):
"""Test Phrases bigram find phrases."""
bigram = Phrases(self.sentences, min_count=1, threshold=1, delimiter=' ')
seen_bigrams = set(bigram.find_phrases(self.sentences).keys())

assert seen_bigrams == {
assert seen_bigrams == set([
'response time',
'graph minors',
'human interface',
}
])

def test_multiple_bigrams_single_entry(self):
"""Test a single entry produces multiple bigrams."""
Expand Down Expand Up @@ -441,7 +460,7 @@ def test_multiple_bigrams_single_entry(self):
'human interface',
])

def test_export_phrases(self):
def test_find_phrases(self):
"""Test Phrases bigram export phrases."""
bigram = Phrases(self.sentences, min_count=1, threshold=1, connector_words=self.connector_words, delimiter=' ')
seen_bigrams = set(bigram.find_phrases(self.sentences).keys())
Expand All @@ -453,6 +472,21 @@ def test_export_phrases(self):
'lack of interest',
])

def test_export_phrases(self):
"""Test Phrases bigram export phrases."""
bigram = Phrases(self.sentences, min_count=1, threshold=1, delimiter=' ')
seen_bigrams = set(bigram.export_phrases().keys())
assert seen_bigrams == set([
'and graph',
'data and',
'graph of',
'graph survey',
'human interface',
'lack of',
'of interest',
'of trees',
])

def test_scoring_default(self):
""" test the default scoring, from the mikolov word2vec paper """
bigram = Phrases(self.sentences, min_count=1, threshold=1, connector_words=self.connector_words)
Expand Down Expand Up @@ -510,9 +544,9 @@ def test__getitem__(self):
assert phrased_sentence == ['data_and_graph', 'survey', 'for', 'human_interface']


class TestFrozenPhrasesModelCompatibilty(unittest.TestCase):
class TestFrozenPhrasesModelCompatibility(unittest.TestCase):

def test_compatibilty(self):
def test_compatibility(self):
phrases = Phrases.load(datapath("phrases-3.6.0.model"))
phraser = FrozenPhrases.load(datapath("phraser-3.6.0.model"))
test_sentences = ['trees', 'graph', 'minors']
Expand Down