Commit

update setup process with doc2vec + docs cleanup

* renamed LabeledText to LabeledSentence &c

piskvorky committed Oct 5, 2014
1 parent a7520e6 commit 1aa11bb
Showing 5 changed files with 13,072 additions and 102 deletions.
2 changes: 2 additions & 0 deletions MANIFEST.in
@@ -10,3 +10,5 @@ include ez_setup.py
include gensim/models/voidptr.h
include gensim/models/word2vec_inner.c
include gensim/models/word2vec_inner.pyx
+include gensim/models/doc2vec_inner.c
+include gensim/models/doc2vec_inner.pyx
1 change: 1 addition & 0 deletions gensim/models/__init__.py
@@ -12,6 +12,7 @@
from .rpmodel import RpModel
from .logentropy_model import LogEntropyModel
from .word2vec import Word2Vec
+from .doc2vec import Doc2Vec
from .ldamulticore import LdaMulticore
from .dtmmodel import DtmModel

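With `Doc2Vec` now exported from `gensim.models`, the class is importable alongside `Word2Vec`. A minimal smoke test of the new export (illustrative, not part of the commit):

    from gensim.models import Doc2Vec  # resolves via the import added above
    print(Doc2Vec.__module__)          # gensim.models.doc2vec
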
216 changes: 115 additions & 101 deletions gensim/models/doc2vec.py
@@ -41,136 +41,147 @@
 except ImportError:
     from Queue import Queue

-from numpy import zeros, random, get_include, sum as np_sum
+from numpy import zeros, random, sum as np_sum

-logger = logging.getLogger("gensim.models.doc2vec")
+logger = logging.getLogger(__name__)

 from gensim import utils  # utility fnc for pickling, common scipy operations etc
 from word2vec import Word2Vec, Vocab, train_cbow_pair, train_sg_pair


-try:
-    raise ImportError
-    from gensim_addons.models.doc2vec_inner import train_sentence_dbow, train_sentence_dm, FAST_VERSION
-except ImportError:
-    try:
-        # try to compile and use the faster cython version
-        import pyximport
-        models_dir = os.path.dirname(__file__) or os.getcwd()
-        pyximport.install(setup_args={"include_dirs": [models_dir, get_include()]})
-        from doc2vec_inner import train_sentence_dbow, train_sentence_dm, FAST_VERSION
-    except:
-        # failed... fall back to plain numpy (20-80x slower training than the above)
-        FAST_VERSION = -1
-
-def train_sentence_dbow(model, sentence, lbls, alpha, work=None, train_words=True, train_lbls=True):
-    """
-    Update distributed bag of words model by training on a single sentence.
-
-    The sentence is a list of Vocab objects (or None, where the corresponding
-    word is not in the vocabulary). Called internally from `Doc2Vec.train()`.
-
-    This is the non-optimized, Python version. If you have cython installed, gensim
-    will use the optimized version from doc2vec_inner instead.
-
-    """
-    neg_labels = []
-    if model.negative:
-        # precompute negative labels
-        neg_labels = zeros(model.negative + 1)
-        neg_labels[0] = 1.0
-
-    for label in lbls:
-        if label is None:
-            continue  # OOV word in the input sentence => skip
-        for word in sentence:
-            if word is None:
-                continue  # OOV word in the input sentence => skip
-            train_sg_pair(model, word, label, alpha, neg_labels, train_words, train_lbls)
-
-    return len([word for word in sentence if word is not None])
-
-def train_sentence_dm(model, sentence, lbls, alpha, work=None, neu1=None, train_words=True, train_lbls=True):
-    """
-    Update distributed memory model by training on a single sentence.
-
-    The sentence is a list of Vocab objects (or None, where the corresponding
-    word is not in the vocabulary). Called internally from `Doc2Vec.train()`.
-
-    This is the non-optimized, Python version. If you have cython installed, gensim
-    will use the optimized version from doc2vec_inner instead.
-
-    """
-    lbl_indices = [lbl.index for lbl in lbls if lbl is not None]
-    lbl_sum = np_sum(model.syn0[lbl_indices], axis=0)
-    lbl_len = len(lbl_indices)
-    neg_labels = []
-    if model.negative:
-        # precompute negative labels
-        neg_labels = zeros(model.negative + 1)
-        neg_labels[0] = 1.
-
-    for pos, word in enumerate(sentence):
-        if word is None:
-            continue  # OOV word in the input sentence => skip
-        reduced_window = random.randint(model.window)  # `b` in the original doc2vec code
-        start = max(0, pos - model.window + reduced_window)
-        window_pos = enumerate(sentence[start : pos + model.window + 1 - reduced_window], start)
-        word2_indices = [word2.index for pos2, word2 in window_pos if (word2 is not None and pos2 != pos)]
-        l1 = np_sum(model.syn0[word2_indices], axis=0) + lbl_sum  # 1 x layer1_size
-        if word2_indices and model.cbow_mean:
-            l1 /= (len(word2_indices) + lbl_len)
-        neu1e = train_cbow_pair(model, word, word2_indices, l1, alpha, neg_labels, train_words, train_words)
-        if train_lbls:
-            model.syn0[lbl_indices] += neu1e
-
-    return len([word for word in sentence if word is not None])
-
-
-class LabeledText(object):
-    """A single labeled text item. Replaces list of words for each sentence from Word2Vec."""
-    def __init__(self, text, labels):
-        self.text = text
-        self.labels = labels
-
-    def __str__(self):
-        return 'LabeledText(' + str(self.text) + ', ' + str(self.labels) + ')'
+try:
+    from gensim.models.doc2vec_inner import train_sentence_dbow, train_sentence_dm, FAST_VERSION
+except:
+    # failed... fall back to plain numpy (20-80x slower training than the above)
+    FAST_VERSION = -1
+
+    def train_sentence_dbow(model, sentence, lbls, alpha, work=None, train_words=True, train_lbls=True):
+        """
+        Update distributed bag of words model by training on a single sentence.
+
+        The sentence is a list of Vocab objects (or None, where the corresponding
+        word is not in the vocabulary). Called internally from `Doc2Vec.train()`.
+
+        This is the non-optimized, Python version. If you have cython installed, gensim
+        will use the optimized version from doc2vec_inner instead.
+
+        """
+        neg_labels = []
+        if model.negative:
+            # precompute negative labels
+            neg_labels = zeros(model.negative + 1)
+            neg_labels[0] = 1.0
+
+        for label in lbls:
+            if label is None:
+                continue  # OOV word in the input sentence => skip
+            for word in sentence:
+                if word is None:
+                    continue  # OOV word in the input sentence => skip
+                train_sg_pair(model, word, label, alpha, neg_labels, train_words, train_lbls)
+
+        return len([word for word in sentence if word is not None])
+
+    def train_sentence_dm(model, sentence, lbls, alpha, work=None, neu1=None, train_words=True, train_lbls=True):
+        """
+        Update distributed memory model by training on a single sentence.
+
+        The sentence is a list of Vocab objects (or None, where the corresponding
+        word is not in the vocabulary). Called internally from `Doc2Vec.train()`.
+
+        This is the non-optimized, Python version. If you have cython installed, gensim
+        will use the optimized version from doc2vec_inner instead.
+
+        """
+        lbl_indices = [lbl.index for lbl in lbls if lbl is not None]
+        lbl_sum = np_sum(model.syn0[lbl_indices], axis=0)
+        lbl_len = len(lbl_indices)
+        neg_labels = []
+        if model.negative:
+            # precompute negative labels
+            neg_labels = zeros(model.negative + 1)
+            neg_labels[0] = 1.
+
+        for pos, word in enumerate(sentence):
+            if word is None:
+                continue  # OOV word in the input sentence => skip
+            reduced_window = random.randint(model.window)  # `b` in the original doc2vec code
+            start = max(0, pos - model.window + reduced_window)
+            window_pos = enumerate(sentence[start : pos + model.window + 1 - reduced_window], start)
+            word2_indices = [word2.index for pos2, word2 in window_pos if (word2 is not None and pos2 != pos)]
+            l1 = np_sum(model.syn0[word2_indices], axis=0) + lbl_sum  # 1 x layer1_size
+            if word2_indices and model.cbow_mean:
+                l1 /= (len(word2_indices) + lbl_len)
+            neu1e = train_cbow_pair(model, word, word2_indices, l1, alpha, neg_labels, train_words, train_words)
+            if train_lbls:
+                model.syn0[lbl_indices] += neu1e
+
+        return len([word for word in sentence if word is not None])
+
+
+class LabeledSentence(object):
+    """
+    A single labeled sentence = text item.
+    Replaces "sentence as a list of words" from Word2Vec.
+
+    """
+    def __init__(self, words, labels):
+        """
+        `words` is a list of tokens (unicode strings), `labels` a
+        list of text labels associated with this text.
+
+        """
+        self.words = words
+        self.labels = labels
+
+    def __str__(self):
+        return '%s(%s, %s)' % (self.__class__.__name__, self.words, self.labels)
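
A short sketch of the renamed class in use; the tokens and the label value are illustrative, not from the commit:

    from gensim.models.doc2vec import LabeledSentence

    sentence = LabeledSentence(words=['the', 'cat', 'sat'], labels=['SENT_0'])
    print(sentence)  # LabeledSentence(['the', 'cat', 'sat'], ['SENT_0'])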


class Doc2Vec(Word2Vec):
    """Class for training, using and evaluating neural networks described in http://arxiv.org/pdf/1405.4053v2.pdf"""
    def __init__(self, sentences=None, size=300, alpha=0.025, window=8, min_count=5,
                 sample=0, seed=1, workers=1, min_alpha=0.0001, dm=1, hs=1, negative=0,
-                dm_mean=0, train_words=True, train_lbls=True):
+                dm_mean=0, train_words=True, train_lbls=True, **kwargs):
        """
        Initialize the model from an iterable of `sentences`. Each sentence is a
-        list of LabeledText objects that will be used for training.
+        LabeledSentence object that will be used for training.

-        The `sentences` iterable can be simply a list of LabeledText elements, but for larger corpora,
+        The `sentences` iterable can be simply a list of LabeledSentence elements, but for larger corpora,
        consider an iterable that streams the sentences directly from disk/network.

        If you don't supply `sentences`, the model is left uninitialized -- use if
        you plan to initialize it in some other way.

        `dm` defines the training algorithm. By default (`dm=1`), distributed memory is used.
        Otherwise, `dbow` is employed.

        `size` is the dimensionality of the feature vectors.

        `window` is the maximum distance between the current and predicted word within a sentence.

        `alpha` is the initial learning rate (will linearly drop to zero as training progresses).

        `seed` = for the random number generator.

        `min_count` = ignore all words with total frequency lower than this.

        `sample` = threshold for configuring which higher-frequency words are randomly downsampled;
            default is 0 (off), useful value is 1e-5.

-        `workers` = use this many worker threads to train the model (=faster training with multicore machines)
-        `hs` = if 1 (default), hierarchical sampling will be used for model training (else set to 0)
+        `workers` = use this many worker threads to train the model (=faster training with multicore machines).
+        `hs` = if 1 (default), hierarchical sampling will be used for model training (else set to 0).

        `negative` = if > 0, negative sampling will be used, the int for negative
-        specifies how many "noise words" should be drawn (usually between 5-20)
+        specifies how many "noise words" should be drawn (usually between 5-20).

        `dm_mean` = if 0 (default), use the sum of the context word vectors. If 1, use the mean.
        Only applies when dm is used.

        """
        Word2Vec.__init__(self, size=size, alpha=alpha, window=window, min_count=min_count,
                          sample=sample, seed=seed, workers=workers, min_alpha=min_alpha,
-                         sg=(1+dm) % 2, hs=hs, negative=negative, cbow_mean=dm_mean)
+                         sg=(1+dm) % 2, hs=hs, negative=negative, cbow_mean=dm_mean, **kwargs)
        self.train_words = train_words
        self.train_lbls = train_lbls
if sentences is not None:
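
For orientation, a sketch of driving the updated constructor end-to-end; the corpus and parameter values are illustrative, not taken from this commit:

    from gensim.models.doc2vec import Doc2Vec, LabeledSentence

    corpus = [
        LabeledSentence(['human', 'interface', 'computer'], ['SENT_0']),
        LabeledSentence(['survey', 'user', 'computer', 'system'], ['SENT_1']),
    ]
    # dm=1 (the default) selects distributed memory; dm=0 would select DBOW
    model = Doc2Vec(corpus, size=50, window=8, min_count=1, workers=1)
    print(model['SENT_0'])  # label vectors share vocab/syn0 with the word vectors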
@@ -185,14 +196,14 @@ def _vocab_from(sentences):
            if sentence_no % 10000 == 0:
                logger.info("PROGRESS: at item #%i, processed %i words and %i word types" %
                            (sentence_no, total_words, len(vocab)))
-            sentence_length = len(sentence.text)
+            sentence_length = len(sentence.words)
            for label in sentence.labels:
                total_words += 1
                if label in vocab:
                    vocab[label].count += sentence_length
                else:
                    vocab[label] = Vocab(count=sentence_length)
-            for word in sentence.text:
+            for word in sentence.words:
                total_words += 1
                if word in vocab:
                    vocab[word].count += 1
Expand All @@ -205,7 +216,7 @@ def _vocab_from(sentences):
def _prepare_sentences(self, sentences):
for sentence in sentences:
# avoid calling random_sample() where prob >= 1, to speed things up a little:
sampled = [self.vocab[word] for word in sentence.text
sampled = [self.vocab[word] for word in sentence.words
if word in self.vocab and (self.vocab[word].sample_probability >= 1.0 or
self.vocab[word].sample_probability >= random.random_sample())]
yield (sampled, [self.vocab[word] for word in sentence.labels if word in self.vocab])
@@ -225,7 +236,8 @@ def save(self, *args, **kwargs):

class LabeledBrownCorpus(object):
-    """Iterate over sentences from the Brown corpus (part of NLTK data)."""
+    """Iterate over sentences from the Brown corpus (part of NLTK data), yielding
+    each sentence out as a LabeledSentence object."""
    def __init__(self, dirname):
        self.dirname = dirname

@@ -243,11 +255,14 @@ def __iter__(self):
                words = ["%s/%s" % (token.lower(), tag[:2]) for token, tag in token_tags if tag[:2].isalpha()]
                if not words:  # don't bother sending out empty sentences
                    continue
-                yield LabeledText(words, [fname+'_SENT_'+str(item_no)])
+                yield LabeledSentence(words, ['%s_SENT_%s' % (fname, item_no)])
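
Usage sketch for the Brown corpus iterator; the directory path is hypothetical and must point at a local copy of the NLTK Brown corpus:

    from gensim.models.doc2vec import LabeledBrownCorpus

    for sentence in LabeledBrownCorpus('/path/to/nltk_data/corpora/brown'):
        print(sentence)  # words look like 'the/AT'; each label embeds file name and line number
        break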


class LabeledLineSentence(object):
-    """Simple format: one sentence = one line; words already preprocessed and separated by whitespace."""
+    """Simple format: one sentence = one line = one LabeledSentence object.
+    Words are expected to be already preprocessed and separated by whitespace,
+    labels are constructed automatically from the sentence line number."""
    def __init__(self, source):
        """
        `source` can be either a string or a file object.
@@ -271,10 +286,9 @@ def __iter__(self):
            # Things that don't have seek will trigger an exception
            self.source.seek(0)
            for item_no, line in enumerate(self.source):
-                yield LabeledText(utils.to_unicode(line).split(), ['SENT_'+str(item_no)])
+                yield LabeledSentence(utils.to_unicode(line).split(), ['SENT_%s' % item_no])
        except AttributeError:
            # If it didn't work like a file, use it as a string filename
            with utils.smart_open(self.source) as fin:
                for item_no, line in enumerate(fin):
-                    yield LabeledText(utils.to_unicode(line).split(), ['SENT_'+str(item_no)])
+                    yield LabeledSentence(utils.to_unicode(line).split(), ['SENT_%s' % item_no])
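
Usage sketch; 'corpus.txt' stands in for any plain-text file with one whitespace-tokenized sentence per line:

    from gensim.models.doc2vec import LabeledLineSentence

    for sentence in LabeledLineSentence('corpus.txt'):
        print(sentence.words, sentence.labels)  # labels auto-generated as ['SENT_0'], ['SENT_1'], ...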