Online word2vec #700

Closed · wants to merge 3 commits
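In brief: this patch lets the vocabulary of an already trained model be extended with new sentences, after which training can continue. As a rough sketch of the intended workflow, based on the new tests added at the end of this diff (the `update` flag on `build_vocab` is the new API surface; exact behaviour depends on the final merged version):

```python
from gensim.models import word2vec

sentences = [
    ['human', 'interface', 'computer'],
    ['graph', 'minors', 'survey'],
]
new_sentences = [
    ['computer', 'artificial', 'intelligence'],
    ['artificial', 'trees'],
]

# train an initial model as usual
model = word2vec.Word2Vec(sentences, min_count=1)

# later: add the previously unseen words to the vocabulary and keep training
model.build_vocab(new_sentences, update=True)
model.train(new_sentences)
```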
2 changes: 1 addition & 1 deletion gensim/models/doc2vec.py
@@ -628,7 +628,7 @@ def reset_from(self, other_model):
self.docvecs.borrow_from(other_model.docvecs)
super(Doc2Vec, self).reset_from(other_model)

def scan_vocab(self, documents, progress_per=10000, trim_rule=None):
def scan_vocab(self, documents, update=False, progress_per=10000, trim_rule=None):
logger.info("collecting all words and their counts")
document_no = -1
total_words = 0
124 changes: 92 additions & 32 deletions gensim/models/word2vec.py
@@ -509,11 +509,11 @@ def build_vocab(self, sentences, keep_raw_vocab=False, trim_rule=None, progress_
Each sentence must be a list of unicode strings.

"""
self.scan_vocab(sentences, progress_per=progress_per, trim_rule=trim_rule) # initial survey
self.scale_vocab(keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule) # trim by min_count & precalculate downsampling
self.finalize_vocab() # build tables & arrays
self.scan_vocab(sentences, update, progress_per=progress_per, trim_rule=trim_rule) # initial survey
self.scale_vocab(update, keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule) # trim by min_count & precalculate downsampling
self.finalize_vocab(update) # build tables & arrays

def scan_vocab(self, sentences, progress_per=10000, trim_rule=None):
def scan_vocab(self, sentences, update, progress_per=10000, trim_rule=None):
Review comment (Owner): I'd prefer to put a default in here, so that the change is backward compatible (some users call these functions manually in their app; we don't want to break that just because of an optional upgrade).
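To illustrate the suggestion, a toy, self-contained sketch (not gensim code; the class and data here are made up) of how defaulting the new flag keeps existing call sites working; note that the doc2vec.py change above already defaults to `update=False`:

```python
# Toy stand-in for the scan_vocab signature discussed above, illustration only.
class ToyModel(object):
    def __init__(self):
        self.raw_vocab = {}

    def scan_vocab(self, sentences, update=False, progress_per=10000, trim_rule=None):
        if not update:
            self.raw_vocab = {}  # old behaviour: rebuild the vocabulary from scratch
        for sentence in sentences:
            for word in sentence:
                self.raw_vocab[word] = self.raw_vocab.get(word, 0) + 1

m = ToyModel()
m.scan_vocab([['hello', 'world']])             # existing callers keep working unchanged
m.scan_vocab([['new', 'words']], update=True)  # new callers opt in to the online path
print(m.raw_vocab)  # hello, world, new, words, each with count 1
```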

"""Do an initial scan of all words appearing in sentences."""
logger.info("collecting all words and their counts")
sentence_no = -1
@@ -527,17 +527,18 @@ def scan_vocab(self, sentences, progress_per=10000, trim_rule=None):
for word in sentence:
vocab[word] += 1

if self.max_vocab_size and len(vocab) > self.max_vocab_size:
total_words += utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule)
min_reduce += 1
if not update:
if self.max_vocab_size and len(vocab) > self.max_vocab_size:
total_words += utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule)
min_reduce += 1

total_words += sum(itervalues(vocab))
logger.info("collected %i word types from a corpus of %i raw words and %i sentences",
len(vocab), total_words, sentence_no + 1)
self.corpus_count = sentence_no + 1
self.raw_vocab = vocab

def scale_vocab(self, min_count=None, sample=None, dry_run=False, keep_raw_vocab=False, trim_rule=None):
def scale_vocab(self, update, min_count=None, sample=None, dry_run=False, keep_raw_vocab=False, trim_rule=None):
"""
Apply vocabulary settings for `min_count` (discarding less-frequent words)
and `sample` (controlling the downsampling of more-frequent words).
@@ -553,28 +554,51 @@ def scale_vocab(self, min_count=None, sample=None, dry_run=False, keep_raw_vocab
"""
min_count = min_count or self.min_count
sample = sample or self.sample

# Discard words less-frequent than min_count
if not dry_run:
self.index2word = []
# make stored settings match these applied settings
self.min_count = min_count
self.sample = sample
self.vocab = {}
drop_unique, drop_total, retain_total, original_total = 0, 0, 0, 0
retain_words = []
for word, v in iteritems(self.raw_vocab):
if keep_vocab_item(word, v, min_count, trim_rule=trim_rule):
retain_words.append(word)
retain_total += v
original_total += v
if not dry_run:
self.vocab[word] = Vocab(count=v, index=len(self.index2word))
self.index2word.append(word)
else:
drop_unique += 1
drop_total += v
original_total += v

if not update:
logger.info("Loading a fresh vocabulary")
# Discard words less-frequent than min_count
if not dry_run:
self.index2word = []
# make stored settings match these applied settings
self.min_count = min_count
self.sample = sample
self.vocab = {}

for word, v in iteritems(self.raw_vocab):
if keep_vocab_item(word, v, min_count, trim_rule=trim_rule):
retain_words.append(word)
retain_total += v
original_total += v
if not dry_run:
self.vocab[word] = Vocab(count=v,
index=len(self.index2word))
self.index2word.append(word)
else:
drop_unique += 1
drop_total += v
original_total += v
else:
logger.info("Updating model with new vocabulary")
for word, v in iteritems(self.raw_vocab):
if word not in self.vocab:
# the word does not already exist in vocab
if keep_vocab_item(word, v, min_count,
trim_rule=trim_rule):
retain_words.append(word)
retain_total += v
original_total += v
if not dry_run:
self.vocab[word] = Vocab(count=v,
index=len(self.index2word))
self.index2word.append(word)
else:
drop_unique += 1
drop_total += v
original_total += v

logger.info("min_count=%d retains %i unique words (drops %i)",
min_count, len(retain_words), drop_unique)
logger.info("min_count leaves %i word corpus (%i%% of original %i)",
@@ -621,11 +645,11 @@ def scale_vocab(self, min_count=None, sample=None, dry_run=False, keep_raw_vocab

return report_values

def finalize_vocab(self):
def finalize_vocab(self, update):
"""Build tables and model weights based on final vocabulary settings."""
if not self.index2word:
self.scale_vocab()
if self.sorted_vocab:
self.scale_vocab(False)
if self.sorted_vocab and not update:
self.sort_vocab()
if self.hs:
# add info about each word's Huffman encoding
@@ -641,7 +665,10 @@ def finalize_vocab(self):
self.index2word.append(word)
self.vocab[word] = v
# set initial input/projection and hidden weights
self.reset_weights()
if not update:
self.reset_weights()
else:
self.update_weights()

def sort_vocab(self):
"""Sort the vocabulary so the most frequent words have the lowest indexes."""
@@ -983,6 +1010,39 @@ def worker_loop():
def clear_sims(self):
self.syn0norm = None

def update_weights(self):
    """
    Copy all the existing weights, and reset the weights for the newly
    added vocabulary.
    """
    logger.info("updating layer weights")
    newsyn0 = empty((len(self.vocab), self.vector_size), dtype=REAL)

    # copy the weights that are already learned
    for i in xrange(len(self.syn0)):
        newsyn0[i] = self.syn0[i]

    # randomize the vectors for the newly added words
    for i in xrange(len(self.syn0), len(newsyn0)):
        # construct deterministic seed from word AND seed argument
        newsyn0[i] = self.seeded_vector(self.index2word[i] + str(self.seed))
    self.syn0 = newsyn0

    if self.hs:
        oldsyn1 = self.syn1
        self.syn1 = zeros((len(self.vocab), self.layer1_size), dtype=REAL)
        self.syn1[:len(oldsyn1)] = oldsyn1

    if self.negative:
        oldneg = self.syn1neg
        self.syn1neg = zeros((len(self.vocab), self.layer1_size), dtype=REAL)
        self.syn1neg[:len(oldneg)] = oldneg

    self.syn0norm = None

    # do not suppress learning for already learned words
    self.syn0_lockf = ones(len(self.vocab), dtype=REAL)  # zeros suppress learning

def reset_weights(self):
"""Reset all projection weights to an initial (untrained) state, but keep the existing vocabulary."""
logger.info("resetting layer weights")
26 changes: 26 additions & 0 deletions gensim/test/test_word2vec.py
@@ -47,6 +47,11 @@ def __iter__(self):
['graph', 'minors', 'survey']
]

new_sentences = [
['computer', 'artificial', 'intelligence'],
['artificial', 'trees']
]

def testfile():
# temporary data will be stored to this file
return os.path.join(tempfile.gettempdir(), 'gensim_word2vec.tst')
@@ -91,6 +96,27 @@ def testLambdaRule(self):
model = word2vec.Word2Vec(sentences, min_count=1, trim_rule=rule)
self.assertTrue("human" not in model.vocab)

class TestWord2VecModel(unittest.TestCase):
def testOnlineLearning(self):
"""Test that the algorithm is able to add new words to the
vocabulary and to a trained model"""
model = word2vec.Word2Vec(sentences, min_count=1, sorted_vocab=0)
model.build_vocab(new_sentences, update=True)
model.train(new_sentences)
self.assertEqual(len(model.vocab), 14)
self.assertEqual(model.syn0.shape[0], 14)
self.assertEqual(model.syn0.shape[1], 100)

def testOnlineLearningSortedVocab(self):
"""Test that the algorithm is able to add new words to the
vocabulary and to a trained model when using a sorted vocabulary"""
model = word2vec.Word2Vec(sentences, min_count=0, sorted_vocab=0)
model.build_vocab(new_sentences, update=True)
model.train(new_sentences)
self.assertEqual(len(model.vocab), 14)
self.assertEqual(model.syn0.shape[0], 14)
self.assertEqual(model.syn0.shape[1], 100)

def testPersistenceWord2VecFormat(self):
"""Test storing/loading the entire model in word2vec format."""
model = word2vec.Word2Vec(sentences, min_count=1)