Commit

update summarization tutorial image
resolving merge conflicts
piskvorky authored and rutum committed Aug 24, 2015
1 parent 87340d9 commit e6fcbb3
Showing 3 changed files with 125 additions and 41 deletions.
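For context: taken together, these changes let an already-trained Word2Vec model grow its vocabulary and continue training on new sentences. A minimal sketch of the workflow, pieced together from the diff and the new unit test below (illustrative, not canonical API documentation):

from gensim.models import word2vec

sentences = [['human', 'interface', 'computer'], ['graph', 'trees']]
new_sentences = [['computer', 'artificial', 'intelligence'], ['artificial', 'trees']]

model = word2vec.Word2Vec(sentences, min_count=1)  # initial vocabulary scan + training
model.build_vocab(new_sentences, update=True)      # extend the existing vocabulary in place
model.train(new_sentences)                         # continue training with the new words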
3 changes: 2 additions & 1 deletion gensim/models/doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,8 @@ def reset_from(self, other_model):
self.docvecs.borrow_from(other_model.docvecs)
super(Doc2Vec, self).reset_from(other_model)

def scan_vocab(self, documents, progress_per=10000, trim_rule=None):

def scan_vocab(self, documents, progress_per=10000, trim_rule=None, update=False):
logger.info("collecting all words and their counts")
document_no = -1
total_words = 0
Expand Down
145 changes: 105 additions & 40 deletions gensim/models/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def train_cbow_pair(model, word, input_word_indices, l1, alpha, learn_vectors=Tr
if learn_hidden:
model.syn1neg[word_indices] += outer(gb, l1) # learn hidden -> output
neu1e += dot(gb, l2b) # save error

if learn_vectors:
# learn input -> hidden, here for all words in the window separately
if not model.cbow_mean and input_word_indices:
Expand Down Expand Up @@ -485,17 +484,20 @@ def create_binary_tree(self):

logger.info("built huffman tree with maximum node depth %i", max_depth)

def build_vocab(self, sentences, keep_raw_vocab=False, trim_rule=None):

def build_vocab(self, sentences, keep_raw_vocab=False, trim_rule=None, update=False):

"""
Build vocabulary from a sequence of sentences (can be a once-only generator stream).
Each sentence must be a list of unicode strings.
"""
self.scan_vocab(sentences, trim_rule=trim_rule) # initial survey
self.scale_vocab(keep_raw_vocab, trim_rule=trim_rule) # trim by min_count & precalculate downsampling
self.finalize_vocab() # build tables & arrays
self.scan_vocab(sentences, update=update, trim_rule=trim_rule)  # initial survey
self.scale_vocab(update=update, keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule)  # trim by min_count & precalculate downsampling
self.finalize_vocab(update=update)  # build tables & arrays

def scan_vocab(self, sentences, progress_per=10000, trim_rule=None):
def scan_vocab(self, sentences, progress_per=10000, trim_rule=None, update=False):
"""Do an initial scan of all words appearing in sentences."""
logger.info("collecting all words and their counts")
sentence_no = -1
Expand All @@ -509,54 +511,79 @@ def scan_vocab(self, sentences, progress_per=10000, trim_rule=None):
for word in sentence:
vocab[word] += 1

if self.max_vocab_size and len(vocab) > self.max_vocab_size:
total_words += utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule)
min_reduce += 1
if not update:
if self.max_vocab_size and len(vocab) > self.max_vocab_size:
total_words += utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule)
min_reduce += 1

total_words += sum(itervalues(vocab))
logger.info("collected %i word types from a corpus of %i raw words and %i sentences",
len(vocab), total_words, sentence_no + 1)
self.corpus_count = sentence_no + 1
self.raw_vocab = vocab

def scale_vocab(self, min_count=None, sample=None, dry_run=False, keep_raw_vocab=False, trim_rule=None):
def scale_vocab(self, min_count=None, sample=None, dry_run=False, keep_raw_vocab=False, trim_rule=None, update=False):
"""
Apply vocabulary settings for `min_count` (discarding less-frequent words)
and `sample` (controlling the downsampling of more-frequent words).

Calling with `dry_run=True` will only simulate the provided settings and
report the size of the retained vocabulary, effective corpus length, and
estimated memory requirements. Results are both printed via logging and
returned as a dict.

Delete the raw vocabulary after the scaling is done to free up RAM,
unless `keep_raw_vocab` is set.
"""
min_count = min_count or self.min_count
sample = sample or self.sample

# Discard words less-frequent than min_count
if not dry_run:
self.index2word = []
# make stored settings match these applied settings
self.min_count = min_count
self.sample = sample
self.vocab = {}
drop_unique, drop_total, retain_total, original_total = 0, 0, 0, 0
retain_words = []
for word, v in iteritems(self.raw_vocab):
if keep_vocab_item(word, v, min_count, trim_rule=trim_rule):
retain_words.append(word)
retain_total += v
original_total += v
if not dry_run:
self.vocab[word] = Vocab(count=v, index=len(self.index2word))
self.index2word.append(word)
else:
drop_unique += 1
drop_total += v
original_total += v

if not update:
logger.info("Loading a fresh vocabulary")
# Discard words less-frequent than min_count
if not dry_run:
self.index2word = []
# make stored settings match these applied settings
self.min_count = min_count
self.sample = sample
self.vocab = {}

for word, v in iteritems(self.raw_vocab):
if keep_vocab_item(word, v, min_count, trim_rule=trim_rule):
retain_words.append(word)
retain_total += v
original_total += v
if not dry_run:
self.vocab[word] = Vocab(count=v,
index=len(self.index2word))
self.index2word.append(word)
else:
drop_unique += 1
drop_total += v
original_total += v
else:
logger.info("Updating model with new vocabulary")
for word, v in iteritems(self.raw_vocab):
if word not in self.vocab:
# the word does not already exist in vocab
if keep_vocab_item(word, v, min_count, trim_rule=trim_rule):
retain_words.append(word)
retain_total += v
original_total += v
if not dry_run:
self.vocab[word] = Vocab(count=v,
index=len(self.index2word))
self.index2word.append(word)
else:
drop_unique += 1
drop_total += v
original_total += v

logger.info("min_count=%d retains %i unique words (drops %i)",
min_count, len(retain_words), drop_unique)
logger.info("min_count leaves %i word corpus (%i%% of original %i)",
Expand Down Expand Up @@ -603,10 +630,10 @@ def scale_vocab(self, min_count=None, sample=None, dry_run=False, keep_raw_vocab

return report_values
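
As a side note, the dry-run reporting path documented above can be exercised on its own; a sketch, assuming `model` already has a scanned raw vocabulary and the keyword defaults as edited in this diff:

report = model.scale_vocab(min_count=5, dry_run=True)
# `report` is a dict with the retained vocabulary size, effective corpus
# length and estimated memory requirements; the model itself is unchanged.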

def finalize_vocab(self):
def finalize_vocab(self, update=False):
"""Build tables and model weights based on final vocabulary settings."""
if not self.index2word:
self.scale_vocab()
self.scale_vocab(update=False)
if self.hs:
# add info about each word's Huffman encoding
self.create_binary_tree()
Expand All @@ -621,7 +648,10 @@ def finalize_vocab(self):
self.index2word.append(word)
self.vocab[word] = v
# set initial input/projection and hidden weights
self.reset_weights()
if not update:
self.reset_weights()
else:
self.update_weights()

def reset_from(self, other_model):
"""
Expand Down Expand Up @@ -921,6 +951,41 @@ def worker_loop():
def clear_sims(self):
self.syn0norm = None


def update_weights(self):
"""
Copy all the existing weights, and reset the weights for the newly
added vocabulary.
"""
logger.info("updating layer weights")
newsyn0 = empty((len(self.vocab), self.layer1_size), dtype=REAL)

# copy the weights that are already learned
for i in xrange(0, len(self.syn0)):
newsyn0[i] = self.syn0[i]

# randomize the remaining (newly added) words
for i in xrange(len(self.syn0), len(newsyn0)):
# construct deterministic seed from word AND seed argument
newsyn0[i] = self.seeded_vector(self.index2word[i] + str(self.seed))
self.syn0 = newsyn0

if self.hs:
oldsyn1 = self.syn1
self.syn1 = zeros((len(self.vocab), self.layer1_size), dtype=REAL)
for i in xrange(0, len(oldsyn1)):
self.syn1[i] = oldsyn1[i]
if self.negative:
oldneg = self.syn1neg
self.syn1neg = zeros((len(self.vocab), self.layer1_size), dtype=REAL)
for i in xrange(0, len(oldneg)):
self.syn1neg[i] = oldneg[i]
self.syn0norm = None

# do not suppress learning for already learned words
self.syn0_lockf = ones(len(self.vocab), dtype=REAL) # zeros suppress learning
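
The grow-and-copy pattern above, sketched in isolation with plain NumPy; the sizes (12 trained words growing to 14) mirror the unit test below and are purely illustrative:

import numpy as np

layer1_size = 100
old_syn0 = np.random.rand(12, layer1_size).astype(np.float32)  # stand-in for trained vectors

new_syn0 = np.empty((14, layer1_size), dtype=np.float32)
new_syn0[:len(old_syn0)] = old_syn0  # learned rows survive verbatim
# only the two new rows receive fresh, small, zero-centred random values,
# matching the (rand - 0.5) / layer1_size scheme of seeded_vector
new_syn0[len(old_syn0):] = (np.random.rand(2, layer1_size).astype(np.float32) - 0.5) / layer1_size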


def reset_weights(self):
"""Reset all projection weights to an initial (untrained) state, but keep the existing vocabulary."""
logger.info("resetting layer weights")
Expand Down Expand Up @@ -1409,6 +1474,7 @@ def save(self, *args, **kwargs):

save.__doc__ = utils.SaveLoad.save.__doc__


@classmethod
def load(cls, *args, **kwargs):
model = super(Word2Vec, cls).load(*args, **kwargs)
Expand Down Expand Up @@ -1444,7 +1510,6 @@ def __init__(self, init_fn, job_fn):
def put(self, job):
self.job_fn(job, self.inits)


class BrownCorpus(object):
"""Iterate over sentences from the Brown corpus (part of NLTK data)."""
def __init__(self, dirname):
Expand Down
18 changes: 18 additions & 0 deletions gensim/test/test_word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,24 @@ def __iter__(self):
['graph', 'minors', 'survey']
]

new_sentences = [
['computer', 'artificial', 'intelligence'],
['artificial', 'trees']
]

def testfile():
# temporary data will be stored to this file
return os.path.join(tempfile.gettempdir(), 'gensim_word2vec.tst')


def rule_for_testing(word, count, min_count):
if word == "human":
return utils.RULE_DISCARD # throw out
else:
return utils.RULE_DEFAULT # apply default rule, i.e. min_count

class TestWord2VecModel(unittest.TestCase):

def testPersistence(self):
"""Test storing/loading the entire model."""
model = word2vec.Word2Vec(sentences, min_count=1)
Expand Down Expand Up @@ -88,6 +94,18 @@ def testLambdaRule(self):
model = word2vec.Word2Vec(sentences, min_count=1, trim_rule=rule)
self.assertTrue("human" not in model.vocab)

def testOnlineLearning(self):
"""Test that the algorithm is able to add new words to the
vocabulary and to a trained model."""
model = word2vec.Word2Vec(sentences, min_count=1)
model.build_vocab(new_sentences, update=True)
model.train(new_sentences)
self.assertEqual(len(model.vocab), 14)  # 12 original words + 2 new ones ('artificial', 'intelligence')
self.assertEqual(model.syn0.shape[0], 14)
self.assertEqual(model.syn0.shape[1], 100)


def testPersistenceWord2VecFormat(self):
"""Test storing/loading the entire model in word2vec format."""
model = word2vec.Word2Vec(sentences, min_count=1)
Expand Down
