diff --git a/gensim/models/doc2vec.py b/gensim/models/doc2vec.py
index b0c31da313..d4c13474a7 100644
--- a/gensim/models/doc2vec.py
+++ b/gensim/models/doc2vec.py
@@ -609,7 +609,8 @@ def reset_from(self, other_model):
         self.docvecs.borrow_from(other_model.docvecs)
         super(Doc2Vec, self).reset_from(other_model)
 
-    def scan_vocab(self, documents, progress_per=10000, trim_rule=None):
+
+    def scan_vocab(self, documents, update=False, progress_per=10000, trim_rule=None):
         logger.info("collecting all words and their counts")
         document_no = -1
         total_words = 0
diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py
index fb665f4655..e0c6a56816 100755
--- a/gensim/models/word2vec.py
+++ b/gensim/models/word2vec.py
@@ -220,7 +220,6 @@ def train_cbow_pair(model, word, input_word_indices, l1, alpha, learn_vectors=Tr
         if learn_hidden:
             model.syn1neg[word_indices] += outer(gb, l1)  # learn hidden -> output
         neu1e += dot(gb, l2b)  # save error
-
     if learn_vectors:
         # learn input -> hidden, here for all words in the window separately
         if not model.cbow_mean and input_word_indices:
@@ -485,17 +484,20 @@ def create_binary_tree(self):
 
         logger.info("built huffman tree with maximum node depth %i", max_depth)
 
-    def build_vocab(self, sentences, keep_raw_vocab=False, trim_rule=None):
+
+    def build_vocab(self, sentences, keep_raw_vocab=False, trim_rule=None, update=False):
+
         """
         Build vocabulary from a sequence of sentences (can be a once-only generator stream).
         Each sentence must be a list of unicode strings.
 
         """
-        self.scan_vocab(sentences, trim_rule=trim_rule)  # initial survey
-        self.scale_vocab(keep_raw_vocab, trim_rule=trim_rule)  # trim by min_count & precalculate downsampling
-        self.finalize_vocab()  # build tables & arrays
+        self.scan_vocab(sentences, update, trim_rule=trim_rule)  # initial survey
+        self.scale_vocab(update, keep_raw_vocab, trim_rule=trim_rule)  # trim by min_count & precalculate downsampling
+        self.finalize_vocab(update)  # build tables & arrays
 
-    def scan_vocab(self, sentences, progress_per=10000, trim_rule=None):
+    def scan_vocab(self, sentences, update, progress_per=10000,
+                   trim_rule=None):
         """Do an initial scan of all words appearing in sentences."""
         logger.info("collecting all words and their counts")
         sentence_no = -1
@@ -509,9 +511,10 @@ def scan_vocab(self, sentences, progress_per=10000, trim_rule=None):
             for word in sentence:
                 vocab[word] += 1
 
-            if self.max_vocab_size and len(vocab) > self.max_vocab_size:
-                total_words += utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule)
-                min_reduce += 1
+            if not update:
+                if self.max_vocab_size and len(vocab) > self.max_vocab_size:
+                    total_words += utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule)
+                    min_reduce += 1
 
         total_words += sum(itervalues(vocab))
         logger.info("collected %i word types from a corpus of %i raw words and %i sentences",
@@ -519,15 +522,16 @@ def scan_vocab(self, sentences, progress_per=10000, trim_rule=None):
         self.corpus_count = sentence_no + 1
         self.raw_vocab = vocab
 
-    def scale_vocab(self, min_count=None, sample=None, dry_run=False, keep_raw_vocab=False, trim_rule=None):
+    def scale_vocab(self, update, min_count=None, sample=None, dry_run=False, keep_raw_vocab=False, trim_rule=None):
         """
-        Apply vocabulary settings for `min_count` (discarding less-frequent words)
-        and `sample` (controlling the downsampling of more-frequent words).
+        Apply vocabulary settings for `min_count` (discarding
+        less-frequent words) and `sample` (controlling the downsampling of
+        more-frequent words).
 
-        Calling with `dry_run=True` will only simulate the provided settings and
-        report the size of the retained vocabulary, effective corpus length, and
-        estimated memory requirements. Results are both printed via logging and
-        returned as a dict.
+        Calling with `dry_run=True` will only simulate the provided settings
+        and report the size of the retained vocabulary, effective corpus
+        length, and estimated memory requirements. Results are both printed
+        via logging and returned as a dict.
 
         Delete the raw vocabulary after the scaling is done to free up RAM,
         unless `keep_raw_vocab` is set.
 
@@ -535,28 +539,51 @@ def scale_vocab(self, min_count=None, sample=None, dry_run=False, keep_raw_vocab
         """
         min_count = min_count or self.min_count
         sample = sample or self.sample
-
-        # Discard words less-frequent than min_count
-        if not dry_run:
-            self.index2word = []
-            # make stored settings match these applied settings
-            self.min_count = min_count
-            self.sample = sample
-            self.vocab = {}
         drop_unique, drop_total, retain_total, original_total = 0, 0, 0, 0
         retain_words = []
-        for word, v in iteritems(self.raw_vocab):
-            if keep_vocab_item(word, v, min_count, trim_rule=trim_rule):
-                retain_words.append(word)
-                retain_total += v
-                original_total += v
-                if not dry_run:
-                    self.vocab[word] = Vocab(count=v, index=len(self.index2word))
-                    self.index2word.append(word)
-            else:
-                drop_unique += 1
-                drop_total += v
-                original_total += v
+
+        if not update:
+            logger.info("Loading a fresh vocabulary")
+            # Discard words less-frequent than min_count
+            if not dry_run:
+                self.index2word = []
+                # make stored settings match these applied settings
+                self.min_count = min_count
+                self.sample = sample
+                self.vocab = {}
+
+            for word, v in iteritems(self.raw_vocab):
+                if keep_vocab_item(word, v, min_count, trim_rule=trim_rule):
+                    retain_words.append(word)
+                    retain_total += v
+                    original_total += v
+                    if not dry_run:
+                        self.vocab[word] = Vocab(count=v,
+                                                 index=len(self.index2word))
+                        self.index2word.append(word)
+                else:
+                    drop_unique += 1
+                    drop_total += v
+                    original_total += v
+        else:
+            logger.info("Updating model with new vocabulary")
+            for word, v in iteritems(self.raw_vocab):
+                if not word in self.vocab:
+                    # the word does not already exist in vocab
+                    if keep_vocab_item(word, v, min_count,
+                                       trim_rule=trim_rule):
+                        retain_words.append(word)
+                        retain_total += v
+                        original_total += v
+                        if not dry_run:
+                            self.vocab[word] = Vocab(count=v,
+                                                     index=len(self.index2word))
+                            self.index2word.append(word)
+                    else:
+                        drop_unique += 1
+                        drop_total += v
+                        original_total += v
+
         logger.info("min_count=%d retains %i unique words (drops %i)",
                     min_count, len(retain_words), drop_unique)
         logger.info("min_count leaves %i word corpus (%i%% of original %i)",
@@ -603,10 +630,10 @@ def scale_vocab(self, min_count=None, sample=None, dry_run=False, keep_raw_vocab
 
         return report_values
 
-    def finalize_vocab(self):
+    def finalize_vocab(self, update):
         """Build tables and model weights based on final vocabulary settings."""
         if not self.index2word:
-            self.scale_vocab()
+            self.scale_vocab(False)
         if self.hs:
             # add info about each word's Huffman encoding
             self.create_binary_tree()
@@ -621,7 +648,10 @@ def finalize_vocab(self):
             self.index2word.append(word)
             self.vocab[word] = v
         # set initial input/projection and hidden weights
-        self.reset_weights()
+        if not update:
+            self.reset_weights()
+        else:
+            self.update_weights()
 
     def reset_from(self, other_model):
         """
@@ -921,6 +951,39 @@ def worker_loop():
     def clear_sims(self):
         self.syn0norm = None
 
+    def update_weights(self):
+        """
+        Copy all the existing weights, and reset the weights for the newly
+        added vocabulary.
+        """
+        logger.info("updating layer weights")
+        newsyn0 = empty((len(self.vocab), self.vector_size), dtype=REAL)
+
+        # copy the weights that are already learned
+        for i in xrange(0, len(self.syn0)):
+            newsyn0[i] = deepcopy(self.syn0[i])
+
+        # randomize the remaining (newly added) words
+        for i in xrange(len(self.syn0), len(newsyn0)):
+            # construct deterministic seed from word AND seed argument
+            newsyn0[i] = self.seeded_vector(self.index2word[i] + str(self.seed))
+        self.syn0 = deepcopy(newsyn0)
+
+        if self.hs:
+            oldsyn1 = deepcopy(self.syn1)
+            self.syn1 = zeros((len(self.vocab), self.layer1_size), dtype=REAL)
+            self.syn1[:len(oldsyn1)] = oldsyn1
+
+        if self.negative:
+            oldneg = deepcopy(self.syn1neg)
+            self.syn1neg = zeros((len(self.vocab), self.layer1_size), dtype=REAL)
+            self.syn1neg[:len(oldneg)] = oldneg
+
+        self.syn0norm = None
+
+        # do not suppress learning for already learned words
+        self.syn0_lockf = ones(len(self.vocab), dtype=REAL)  # zeros suppress learning
+
     def reset_weights(self):
         """Reset all projection weights to an initial (untrained) state, but keep the existing vocabulary."""
         logger.info("resetting layer weights")
@@ -1409,6 +1472,7 @@ def save(self, *args, **kwargs):
 
     save.__doc__ = utils.SaveLoad.save.__doc__
 
+
     @classmethod
     def load(cls, *args, **kwargs):
         model = super(Word2Vec, cls).load(*args, **kwargs)
@@ -1444,7 +1508,6 @@ def __init__(self, init_fn, job_fn):
     def put(self, job):
         self.job_fn(job, self.inits)
 
-
 class BrownCorpus(object):
     """Iterate over sentences from the Brown corpus (part of NLTK data)."""
     def __init__(self, dirname):
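For orientation while reviewing, here is a minimal usage sketch of the online-vocabulary API that the word2vec.py changes above introduce, modelled on the `testOnlineLearning` case added to `gensim/test/test_word2vec.py` at the end of this patch. The toy corpora are illustrative only and are not part of the patch:

```python
from gensim.models import word2vec

# toy corpora, mirroring the lists used in gensim/test/test_word2vec.py
sentences = [['human', 'interface', 'computer'], ['graph', 'minors', 'survey']]
new_sentences = [['computer', 'artificial', 'intelligence'], ['artificial', 'trees']]

# initial training builds the vocabulary and the weight matrices as before
model = word2vec.Word2Vec(sentences, min_count=1)

# with this patch, the vocabulary can be extended in place: scan_vocab/scale_vocab
# keep the existing entries and only add unseen words, and finalize_vocab() calls
# update_weights() instead of reset_weights()
model.build_vocab(new_sentences, update=True)
model.train(new_sentences)  # continue training on the new material
```

Without `update=True`, `build_vocab` keeps its previous behaviour: the vocabulary is rebuilt from scratch and `finalize_vocab` calls `reset_weights()`, discarding any trained vectors.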
diff --git a/gensim/scripts/make_wiki.py b/gensim/scripts/make_wiki.py
deleted file mode 120000
index 85ddf6cc4f..0000000000
--- a/gensim/scripts/make_wiki.py
+++ /dev/null
@@ -1 +0,0 @@
-make_wikicorpus.py
\ No newline at end of file
diff --git a/gensim/scripts/make_wiki.py b/gensim/scripts/make_wiki.py
new file mode 100755
index 0000000000..f1bee1b79b
--- /dev/null
+++ b/gensim/scripts/make_wiki.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2010 Radim Rehurek
+# Copyright (C) 2012 Lars Buitinck
+# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
+
+
+"""
+USAGE: %(program)s WIKI_XML_DUMP OUTPUT_PREFIX [VOCABULARY_SIZE]
+
+Convert articles from a Wikipedia dump to (sparse) vectors. The input is a
+bz2-compressed dump of Wikipedia articles, in XML format.
+
+This actually creates several files:
+
+* `OUTPUT_PREFIX_wordids.txt`: mapping between words and their integer ids
+* `OUTPUT_PREFIX_bow.mm`: bag-of-words (word counts) representation, in
+  Matrix Market format
+* `OUTPUT_PREFIX_tfidf.mm`: TF-IDF representation
+* `OUTPUT_PREFIX.tfidf_model`: TF-IDF model dump
+
+The output Matrix Market files can then be compressed (e.g., by bzip2) to save
+disk space; gensim's corpus iterators can work with compressed input, too.
+
+`VOCABULARY_SIZE` controls how many of the most frequent words to keep (after
+removing tokens that appear in more than 10% of all documents). Defaults to
+100,000.
+
+If you have the `pattern` package installed, this script will use a fancy
+lemmatization to get a lemma of each token (instead of plain alphabetic
+tokenizer). The package is available at https://github.com/clips/pattern .
+
+Example: python -m gensim.scripts.make_wikicorpus ~/gensim/results/enwiki-latest-pages-articles.xml.bz2 ~/gensim/results/wiki_en
+"""
+
+
+import logging
+import os.path
+import sys
+
+from gensim.corpora import Dictionary, HashDictionary, MmCorpus, WikiCorpus
+from gensim.models import TfidfModel
+
+
+# Wiki is first scanned for all distinct word types (~7M). The types that
+# appear in more than 10% of articles are removed and from the rest, the
+# DEFAULT_DICT_SIZE most frequent types are kept.
+DEFAULT_DICT_SIZE = 100000
+
+
+if __name__ == '__main__':
+    program = os.path.basename(sys.argv[0])
+    logger = logging.getLogger(program)
+
+    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s')
+    logging.root.setLevel(level=logging.INFO)
+    logger.info("running %s" % ' '.join(sys.argv))
+
+    # check and process input arguments
+    if len(sys.argv) < 3:
+        print(globals()['__doc__'] % locals())
+        sys.exit(1)
+    inp, outp = sys.argv[1:3]
+    if len(sys.argv) > 3:
+        keep_words = int(sys.argv[3])
+    else:
+        keep_words = DEFAULT_DICT_SIZE
+    online = 'online' in program
+    lemmatize = 'lemma' in program
+    debug = 'nodebug' not in program
+
+    if online:
+        dictionary = HashDictionary(id_range=keep_words, debug=debug)
+        dictionary.allow_update = True  # start collecting document frequencies
+        wiki = WikiCorpus(inp, lemmatize=lemmatize, dictionary=dictionary)
+        MmCorpus.serialize(outp + '_bow.mm', wiki, progress_cnt=10000)  # ~4h on my macbook pro without lemmatization, 3.1m articles (august 2012)
+        # with HashDictionary, the token->id mapping is only fully instantiated now, after `serialize`
+        dictionary.filter_extremes(no_below=20, no_above=0.1, keep_n=DEFAULT_DICT_SIZE)
+        dictionary.save_as_text(outp + '_wordids.txt.bz2')
+        wiki.save(outp + '_corpus.pkl.bz2')
+        dictionary.allow_update = False
+    else:
+        wiki = WikiCorpus(inp, lemmatize=lemmatize)  # takes about 9h on a macbook pro, for 3.5m articles (june 2011)
+        # only keep the most frequent words (out of total ~8.2m unique tokens)
+        wiki.dictionary.filter_extremes(no_below=20, no_above=0.1, keep_n=DEFAULT_DICT_SIZE)
+        # save dictionary and bag-of-words (term-document frequency matrix)
+        MmCorpus.serialize(outp + '_bow.mm', wiki, progress_cnt=10000)  # another ~9h
+        wiki.dictionary.save_as_text(outp + '_wordids.txt.bz2')
+        # load back the id->word mapping directly from file
+        # this seems to save more memory, compared to keeping the wiki.dictionary object from above
+        dictionary = Dictionary.load_from_text(outp + '_wordids.txt.bz2')
+        del wiki
+
+    # initialize corpus reader and word->id mapping
+    mm = MmCorpus(outp + '_bow.mm')
+
+    # build tfidf, ~50min
+    tfidf = TfidfModel(mm, id2word=dictionary, normalize=True)
+    tfidf.save(outp + '.tfidf_model')
+
+    # save tfidf vectors in matrix market format
+    # ~4h; result file is 15GB! bzip2'ed down to 4.5GB
+    MmCorpus.serialize(outp + '_tfidf.mm', tfidf[mm], progress_cnt=10000)
+
+    logger.info("finished running %s" % program)
diff --git a/gensim/scripts/make_wiki_lemma.py b/gensim/scripts/make_wiki_lemma.py
deleted file mode 120000
index 85ddf6cc4f..0000000000
--- a/gensim/scripts/make_wiki_lemma.py
+++ /dev/null
@@ -1 +0,0 @@
-make_wikicorpus.py
\ No newline at end of file
diff --git a/gensim/scripts/make_wiki_lemma.py b/gensim/scripts/make_wiki_lemma.py
new file mode 100755
index 0000000000..f1bee1b79b
--- /dev/null
+++ b/gensim/scripts/make_wiki_lemma.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2010 Radim Rehurek
+# Copyright (C) 2012 Lars Buitinck
+# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
+
+
+"""
+USAGE: %(program)s WIKI_XML_DUMP OUTPUT_PREFIX [VOCABULARY_SIZE]
+
+Convert articles from a Wikipedia dump to (sparse) vectors. The input is a
+bz2-compressed dump of Wikipedia articles, in XML format.
+
+This actually creates several files:
+
+* `OUTPUT_PREFIX_wordids.txt`: mapping between words and their integer ids
+* `OUTPUT_PREFIX_bow.mm`: bag-of-words (word counts) representation, in
+  Matrix Market format
+* `OUTPUT_PREFIX_tfidf.mm`: TF-IDF representation
+* `OUTPUT_PREFIX.tfidf_model`: TF-IDF model dump
+
+The output Matrix Market files can then be compressed (e.g., by bzip2) to save
+disk space; gensim's corpus iterators can work with compressed input, too.
+
+`VOCABULARY_SIZE` controls how many of the most frequent words to keep (after
+removing tokens that appear in more than 10% of all documents). Defaults to
+100,000.
+
+If you have the `pattern` package installed, this script will use a fancy
+lemmatization to get a lemma of each token (instead of plain alphabetic
+tokenizer). The package is available at https://github.com/clips/pattern .
+
+Example: python -m gensim.scripts.make_wikicorpus ~/gensim/results/enwiki-latest-pages-articles.xml.bz2 ~/gensim/results/wiki_en
+"""
+
+
+import logging
+import os.path
+import sys
+
+from gensim.corpora import Dictionary, HashDictionary, MmCorpus, WikiCorpus
+from gensim.models import TfidfModel
+
+
+# Wiki is first scanned for all distinct word types (~7M). The types that
+# appear in more than 10% of articles are removed and from the rest, the
+# DEFAULT_DICT_SIZE most frequent types are kept.
+DEFAULT_DICT_SIZE = 100000
+
+
+if __name__ == '__main__':
+    program = os.path.basename(sys.argv[0])
+    logger = logging.getLogger(program)
+
+    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s')
+    logging.root.setLevel(level=logging.INFO)
+    logger.info("running %s" % ' '.join(sys.argv))
+
+    # check and process input arguments
+    if len(sys.argv) < 3:
+        print(globals()['__doc__'] % locals())
+        sys.exit(1)
+    inp, outp = sys.argv[1:3]
+    if len(sys.argv) > 3:
+        keep_words = int(sys.argv[3])
+    else:
+        keep_words = DEFAULT_DICT_SIZE
+    online = 'online' in program
+    lemmatize = 'lemma' in program
+    debug = 'nodebug' not in program
+
+    if online:
+        dictionary = HashDictionary(id_range=keep_words, debug=debug)
+        dictionary.allow_update = True  # start collecting document frequencies
+        wiki = WikiCorpus(inp, lemmatize=lemmatize, dictionary=dictionary)
+        MmCorpus.serialize(outp + '_bow.mm', wiki, progress_cnt=10000)  # ~4h on my macbook pro without lemmatization, 3.1m articles (august 2012)
+        # with HashDictionary, the token->id mapping is only fully instantiated now, after `serialize`
+        dictionary.filter_extremes(no_below=20, no_above=0.1, keep_n=DEFAULT_DICT_SIZE)
+        dictionary.save_as_text(outp + '_wordids.txt.bz2')
+        wiki.save(outp + '_corpus.pkl.bz2')
+        dictionary.allow_update = False
+    else:
+        wiki = WikiCorpus(inp, lemmatize=lemmatize)  # takes about 9h on a macbook pro, for 3.5m articles (june 2011)
+        # only keep the most frequent words (out of total ~8.2m unique tokens)
+        wiki.dictionary.filter_extremes(no_below=20, no_above=0.1, keep_n=DEFAULT_DICT_SIZE)
+        # save dictionary and bag-of-words (term-document frequency matrix)
+        MmCorpus.serialize(outp + '_bow.mm', wiki, progress_cnt=10000)  # another ~9h
+        wiki.dictionary.save_as_text(outp + '_wordids.txt.bz2')
+        # load back the id->word mapping directly from file
+        # this seems to save more memory, compared to keeping the wiki.dictionary object from above
+        dictionary = Dictionary.load_from_text(outp + '_wordids.txt.bz2')
+        del wiki
+
+    # initialize corpus reader and word->id mapping
+    mm = MmCorpus(outp + '_bow.mm')
+
+    # build tfidf, ~50min
+    tfidf = TfidfModel(mm, id2word=dictionary, normalize=True)
+    tfidf.save(outp + '.tfidf_model')
+
+    # save tfidf vectors in matrix market format
+    # ~4h; result file is 15GB! bzip2'ed down to 4.5GB
+    MmCorpus.serialize(outp + '_tfidf.mm', tfidf[mm], progress_cnt=10000)
+
+    logger.info("finished running %s" % program)
diff --git a/gensim/scripts/make_wiki_online.py b/gensim/scripts/make_wiki_online.py
deleted file mode 120000
index 85ddf6cc4f..0000000000
--- a/gensim/scripts/make_wiki_online.py
+++ /dev/null
@@ -1 +0,0 @@
-make_wikicorpus.py
\ No newline at end of file
diff --git a/gensim/scripts/make_wiki_online.py b/gensim/scripts/make_wiki_online.py
new file mode 100755
index 0000000000..f1bee1b79b
--- /dev/null
+++ b/gensim/scripts/make_wiki_online.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2010 Radim Rehurek
+# Copyright (C) 2012 Lars Buitinck
+# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
+
+
+"""
+USAGE: %(program)s WIKI_XML_DUMP OUTPUT_PREFIX [VOCABULARY_SIZE]
+
+Convert articles from a Wikipedia dump to (sparse) vectors. The input is a
+bz2-compressed dump of Wikipedia articles, in XML format.
+
+This actually creates several files:
+
+* `OUTPUT_PREFIX_wordids.txt`: mapping between words and their integer ids
+* `OUTPUT_PREFIX_bow.mm`: bag-of-words (word counts) representation, in
+  Matrix Market format
+* `OUTPUT_PREFIX_tfidf.mm`: TF-IDF representation
+* `OUTPUT_PREFIX.tfidf_model`: TF-IDF model dump
+
+The output Matrix Market files can then be compressed (e.g., by bzip2) to save
+disk space; gensim's corpus iterators can work with compressed input, too.
+
+`VOCABULARY_SIZE` controls how many of the most frequent words to keep (after
+removing tokens that appear in more than 10% of all documents). Defaults to
+100,000.
+
+If you have the `pattern` package installed, this script will use a fancy
+lemmatization to get a lemma of each token (instead of plain alphabetic
+tokenizer). The package is available at https://github.com/clips/pattern .
+
+Example: python -m gensim.scripts.make_wikicorpus ~/gensim/results/enwiki-latest-pages-articles.xml.bz2 ~/gensim/results/wiki_en
+"""
+
+
+import logging
+import os.path
+import sys
+
+from gensim.corpora import Dictionary, HashDictionary, MmCorpus, WikiCorpus
+from gensim.models import TfidfModel
+
+
+# Wiki is first scanned for all distinct word types (~7M). The types that
+# appear in more than 10% of articles are removed and from the rest, the
+# DEFAULT_DICT_SIZE most frequent types are kept.
+DEFAULT_DICT_SIZE = 100000
+
+
+if __name__ == '__main__':
+    program = os.path.basename(sys.argv[0])
+    logger = logging.getLogger(program)
+
+    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s')
+    logging.root.setLevel(level=logging.INFO)
+    logger.info("running %s" % ' '.join(sys.argv))
+
+    # check and process input arguments
+    if len(sys.argv) < 3:
+        print(globals()['__doc__'] % locals())
+        sys.exit(1)
+    inp, outp = sys.argv[1:3]
+    if len(sys.argv) > 3:
+        keep_words = int(sys.argv[3])
+    else:
+        keep_words = DEFAULT_DICT_SIZE
+    online = 'online' in program
+    lemmatize = 'lemma' in program
+    debug = 'nodebug' not in program
+
+    if online:
+        dictionary = HashDictionary(id_range=keep_words, debug=debug)
+        dictionary.allow_update = True  # start collecting document frequencies
+        wiki = WikiCorpus(inp, lemmatize=lemmatize, dictionary=dictionary)
+        MmCorpus.serialize(outp + '_bow.mm', wiki, progress_cnt=10000)  # ~4h on my macbook pro without lemmatization, 3.1m articles (august 2012)
+        # with HashDictionary, the token->id mapping is only fully instantiated now, after `serialize`
+        dictionary.filter_extremes(no_below=20, no_above=0.1, keep_n=DEFAULT_DICT_SIZE)
+        dictionary.save_as_text(outp + '_wordids.txt.bz2')
+        wiki.save(outp + '_corpus.pkl.bz2')
+        dictionary.allow_update = False
+    else:
+        wiki = WikiCorpus(inp, lemmatize=lemmatize)  # takes about 9h on a macbook pro, for 3.5m articles (june 2011)
+        # only keep the most frequent words (out of total ~8.2m unique tokens)
+        wiki.dictionary.filter_extremes(no_below=20, no_above=0.1, keep_n=DEFAULT_DICT_SIZE)
+        # save dictionary and bag-of-words (term-document frequency matrix)
+        MmCorpus.serialize(outp + '_bow.mm', wiki, progress_cnt=10000)  # another ~9h
+        wiki.dictionary.save_as_text(outp + '_wordids.txt.bz2')
+        # load back the id->word mapping directly from file
+        # this seems to save more memory, compared to keeping the wiki.dictionary object from above
+        dictionary = Dictionary.load_from_text(outp + '_wordids.txt.bz2')
+        del wiki
+
+    # initialize corpus reader and word->id mapping
+    mm = MmCorpus(outp + '_bow.mm')
+
+    # build tfidf, ~50min
+    tfidf = TfidfModel(mm, id2word=dictionary, normalize=True)
+    tfidf.save(outp + '.tfidf_model')
+
+    # save tfidf vectors in matrix market format
+    # ~4h; result file is 15GB! bzip2'ed down to 4.5GB
+    MmCorpus.serialize(outp + '_tfidf.mm', tfidf[mm], progress_cnt=10000)
+
+    logger.info("finished running %s" % program)
diff --git a/gensim/scripts/make_wiki_online_lemma.py b/gensim/scripts/make_wiki_online_lemma.py
deleted file mode 120000
index 85ddf6cc4f..0000000000
--- a/gensim/scripts/make_wiki_online_lemma.py
+++ /dev/null
@@ -1 +0,0 @@
-make_wikicorpus.py
\ No newline at end of file
diff --git a/gensim/scripts/make_wiki_online_lemma.py b/gensim/scripts/make_wiki_online_lemma.py
new file mode 100755
index 0000000000..f1bee1b79b
--- /dev/null
+++ b/gensim/scripts/make_wiki_online_lemma.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2010 Radim Rehurek
+# Copyright (C) 2012 Lars Buitinck
+# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
+
+
+"""
+USAGE: %(program)s WIKI_XML_DUMP OUTPUT_PREFIX [VOCABULARY_SIZE]
+
+Convert articles from a Wikipedia dump to (sparse) vectors. The input is a
+bz2-compressed dump of Wikipedia articles, in XML format.
+
+This actually creates several files:
+
+* `OUTPUT_PREFIX_wordids.txt`: mapping between words and their integer ids
+* `OUTPUT_PREFIX_bow.mm`: bag-of-words (word counts) representation, in
+  Matrix Market format
+* `OUTPUT_PREFIX_tfidf.mm`: TF-IDF representation
+* `OUTPUT_PREFIX.tfidf_model`: TF-IDF model dump
+
+The output Matrix Market files can then be compressed (e.g., by bzip2) to save
+disk space; gensim's corpus iterators can work with compressed input, too.
+
+`VOCABULARY_SIZE` controls how many of the most frequent words to keep (after
+removing tokens that appear in more than 10% of all documents). Defaults to
+100,000.
+
+If you have the `pattern` package installed, this script will use a fancy
+lemmatization to get a lemma of each token (instead of plain alphabetic
+tokenizer). The package is available at https://github.com/clips/pattern .
+
+Example: python -m gensim.scripts.make_wikicorpus ~/gensim/results/enwiki-latest-pages-articles.xml.bz2 ~/gensim/results/wiki_en
+"""
+
+
+import logging
+import os.path
+import sys
+
+from gensim.corpora import Dictionary, HashDictionary, MmCorpus, WikiCorpus
+from gensim.models import TfidfModel
+
+
+# Wiki is first scanned for all distinct word types (~7M). The types that
+# appear in more than 10% of articles are removed and from the rest, the
+# DEFAULT_DICT_SIZE most frequent types are kept.
+DEFAULT_DICT_SIZE = 100000
+
+
+if __name__ == '__main__':
+    program = os.path.basename(sys.argv[0])
+    logger = logging.getLogger(program)
+
+    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s')
+    logging.root.setLevel(level=logging.INFO)
+    logger.info("running %s" % ' '.join(sys.argv))
+
+    # check and process input arguments
+    if len(sys.argv) < 3:
+        print(globals()['__doc__'] % locals())
+        sys.exit(1)
+    inp, outp = sys.argv[1:3]
+    if len(sys.argv) > 3:
+        keep_words = int(sys.argv[3])
+    else:
+        keep_words = DEFAULT_DICT_SIZE
+    online = 'online' in program
+    lemmatize = 'lemma' in program
+    debug = 'nodebug' not in program
+
+    if online:
+        dictionary = HashDictionary(id_range=keep_words, debug=debug)
+        dictionary.allow_update = True  # start collecting document frequencies
+        wiki = WikiCorpus(inp, lemmatize=lemmatize, dictionary=dictionary)
+        MmCorpus.serialize(outp + '_bow.mm', wiki, progress_cnt=10000)  # ~4h on my macbook pro without lemmatization, 3.1m articles (august 2012)
+        # with HashDictionary, the token->id mapping is only fully instantiated now, after `serialize`
+        dictionary.filter_extremes(no_below=20, no_above=0.1, keep_n=DEFAULT_DICT_SIZE)
+        dictionary.save_as_text(outp + '_wordids.txt.bz2')
+        wiki.save(outp + '_corpus.pkl.bz2')
+        dictionary.allow_update = False
+    else:
+        wiki = WikiCorpus(inp, lemmatize=lemmatize)  # takes about 9h on a macbook pro, for 3.5m articles (june 2011)
+        # only keep the most frequent words (out of total ~8.2m unique tokens)
+        wiki.dictionary.filter_extremes(no_below=20, no_above=0.1, keep_n=DEFAULT_DICT_SIZE)
+        # save dictionary and bag-of-words (term-document frequency matrix)
+        MmCorpus.serialize(outp + '_bow.mm', wiki, progress_cnt=10000)  # another ~9h
+        wiki.dictionary.save_as_text(outp + '_wordids.txt.bz2')
+        # load back the id->word mapping directly from file
+        # this seems to save more memory, compared to keeping the wiki.dictionary object from above
+        dictionary = Dictionary.load_from_text(outp + '_wordids.txt.bz2')
+        del wiki
+
+    # initialize corpus reader and word->id mapping
+    mm = MmCorpus(outp + '_bow.mm')
+
+    # build tfidf, ~50min
+    tfidf = TfidfModel(mm, id2word=dictionary, normalize=True)
+    tfidf.save(outp + '.tfidf_model')
+
+    # save tfidf vectors in matrix market format
+    # ~4h; result file is 15GB! bzip2'ed down to 4.5GB
+    MmCorpus.serialize(outp + '_tfidf.mm', tfidf[mm], progress_cnt=10000)
+
+    logger.info("finished running %s" % program)
diff --git a/gensim/scripts/make_wiki_online_nodebug.py b/gensim/scripts/make_wiki_online_nodebug.py
deleted file mode 120000
index 85ddf6cc4f..0000000000
--- a/gensim/scripts/make_wiki_online_nodebug.py
+++ /dev/null
@@ -1 +0,0 @@
-make_wikicorpus.py
\ No newline at end of file
diff --git a/gensim/scripts/make_wiki_online_nodebug.py b/gensim/scripts/make_wiki_online_nodebug.py
new file mode 100755
index 0000000000..f1bee1b79b
--- /dev/null
+++ b/gensim/scripts/make_wiki_online_nodebug.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2010 Radim Rehurek
+# Copyright (C) 2012 Lars Buitinck
+# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
+
+
+"""
+USAGE: %(program)s WIKI_XML_DUMP OUTPUT_PREFIX [VOCABULARY_SIZE]
+
+Convert articles from a Wikipedia dump to (sparse) vectors. The input is a
+bz2-compressed dump of Wikipedia articles, in XML format.
+
+This actually creates several files:
+
+* `OUTPUT_PREFIX_wordids.txt`: mapping between words and their integer ids
+* `OUTPUT_PREFIX_bow.mm`: bag-of-words (word counts) representation, in
+  Matrix Market format
+* `OUTPUT_PREFIX_tfidf.mm`: TF-IDF representation
+* `OUTPUT_PREFIX.tfidf_model`: TF-IDF model dump
+
+The output Matrix Market files can then be compressed (e.g., by bzip2) to save
+disk space; gensim's corpus iterators can work with compressed input, too.
+
+`VOCABULARY_SIZE` controls how many of the most frequent words to keep (after
+removing tokens that appear in more than 10% of all documents). Defaults to
+100,000.
+
+If you have the `pattern` package installed, this script will use a fancy
+lemmatization to get a lemma of each token (instead of plain alphabetic
+tokenizer). The package is available at https://github.com/clips/pattern .
+
+Example: python -m gensim.scripts.make_wikicorpus ~/gensim/results/enwiki-latest-pages-articles.xml.bz2 ~/gensim/results/wiki_en
+"""
+
+
+import logging
+import os.path
+import sys
+
+from gensim.corpora import Dictionary, HashDictionary, MmCorpus, WikiCorpus
+from gensim.models import TfidfModel
+
+
+# Wiki is first scanned for all distinct word types (~7M). The types that
+# appear in more than 10% of articles are removed and from the rest, the
+# DEFAULT_DICT_SIZE most frequent types are kept.
+DEFAULT_DICT_SIZE = 100000
+
+
+if __name__ == '__main__':
+    program = os.path.basename(sys.argv[0])
+    logger = logging.getLogger(program)
+
+    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s')
+    logging.root.setLevel(level=logging.INFO)
+    logger.info("running %s" % ' '.join(sys.argv))
+
+    # check and process input arguments
+    if len(sys.argv) < 3:
+        print(globals()['__doc__'] % locals())
+        sys.exit(1)
+    inp, outp = sys.argv[1:3]
+    if len(sys.argv) > 3:
+        keep_words = int(sys.argv[3])
+    else:
+        keep_words = DEFAULT_DICT_SIZE
+    online = 'online' in program
+    lemmatize = 'lemma' in program
+    debug = 'nodebug' not in program
+
+    if online:
+        dictionary = HashDictionary(id_range=keep_words, debug=debug)
+        dictionary.allow_update = True  # start collecting document frequencies
+        wiki = WikiCorpus(inp, lemmatize=lemmatize, dictionary=dictionary)
+        MmCorpus.serialize(outp + '_bow.mm', wiki, progress_cnt=10000)  # ~4h on my macbook pro without lemmatization, 3.1m articles (august 2012)
+        # with HashDictionary, the token->id mapping is only fully instantiated now, after `serialize`
+        dictionary.filter_extremes(no_below=20, no_above=0.1, keep_n=DEFAULT_DICT_SIZE)
+        dictionary.save_as_text(outp + '_wordids.txt.bz2')
+        wiki.save(outp + '_corpus.pkl.bz2')
+        dictionary.allow_update = False
+    else:
+        wiki = WikiCorpus(inp, lemmatize=lemmatize)  # takes about 9h on a macbook pro, for 3.5m articles (june 2011)
+        # only keep the most frequent words (out of total ~8.2m unique tokens)
+        wiki.dictionary.filter_extremes(no_below=20, no_above=0.1, keep_n=DEFAULT_DICT_SIZE)
+        # save dictionary and bag-of-words (term-document frequency matrix)
+        MmCorpus.serialize(outp + '_bow.mm', wiki, progress_cnt=10000)  # another ~9h
+        wiki.dictionary.save_as_text(outp + '_wordids.txt.bz2')
+        # load back the id->word mapping directly from file
+        # this seems to save more memory, compared to keeping the wiki.dictionary object from above
+        dictionary = Dictionary.load_from_text(outp + '_wordids.txt.bz2')
+        del wiki
+
+    # initialize corpus reader and word->id mapping
+    mm = MmCorpus(outp + '_bow.mm')
+
+    # build tfidf, ~50min
+    tfidf = TfidfModel(mm, id2word=dictionary, normalize=True)
+    tfidf.save(outp + '.tfidf_model')
+
+    # save tfidf vectors in matrix market format
+    # ~4h; result file is 15GB! bzip2'ed down to 4.5GB
+    MmCorpus.serialize(outp + '_tfidf.mm', tfidf[mm], progress_cnt=10000)
+
+    logger.info("finished running %s" % program)
diff --git a/gensim/test/test_word2vec.py b/gensim/test/test_word2vec.py
index a1c6664c70..a46cbd2231 100644
--- a/gensim/test/test_word2vec.py
+++ b/gensim/test/test_word2vec.py
@@ -45,11 +45,16 @@ def __iter__(self):
     ['graph', 'minors', 'survey']
 ]
 
+new_sentences = [
+    ['computer', 'artificial', 'intelligence'],
+    ['artificial', 'trees']
+]
 
 def testfile():
     # temporary data will be stored to this file
     return os.path.join(tempfile.gettempdir(), 'gensim_word2vec.tst')
 
+
 def rule_for_testing(word, count, min_count):
     if word == "human":
         return utils.RULE_DISCARD  # throw out
@@ -57,6 +62,7 @@ def rule_for_testing(word, count, min_count):
     return utils.RULE_DEFAULT  # apply default rule, i.e. min_count
 
 class TestWord2VecModel(unittest.TestCase):
+
     def testPersistence(self):
         """Test storing/loading the entire model."""
         model = word2vec.Word2Vec(sentences, min_count=1)
@@ -88,6 +94,17 @@ def testLambdaRule(self):
         model = word2vec.Word2Vec(sentences, min_count=1, trim_rule=rule)
         self.assertTrue("human" not in model.vocab)
 
+
+    def testOnlineLearning(self):
+        """Test that the algorithm is able to add new words to the
+        vocabulary and to a trained model"""
+        model = word2vec.Word2Vec(sentences, min_count=1)
+        model.build_vocab(new_sentences, update=True)
+        model.train(new_sentences)
+        self.assertEqual(len(model.vocab), 14)
+        self.assertEqual(model.syn0.shape[0], 14)
+
+
     def testPersistenceWord2VecFormat(self):
         """Test storing/loading the entire model in word2vec format."""
         model = word2vec.Word2Vec(sentences, min_count=1)
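As a side note for reviewers, the row-expansion performed by the new `update_weights()` method can be pictured with the standalone numpy sketch below. It is an illustration only, not the patch itself; the names `expand_embedding` and `seed_fn` are invented for the example, and the real method uses `Word2Vec.seeded_vector()` to initialise the new rows:

```python
import numpy as np

def expand_embedding(syn0_old, new_vocab_size, seed_fn):
    """Grow an embedding matrix to new_vocab_size rows: rows already trained
    are copied unchanged, rows for newly added words come from seed_fn."""
    old_rows, vector_size = syn0_old.shape
    syn0_new = np.empty((new_vocab_size, vector_size), dtype=syn0_old.dtype)
    syn0_new[:old_rows] = syn0_old              # keep the learned weights
    for i in range(old_rows, new_vocab_size):
        syn0_new[i] = seed_fn(i)                # fresh small random vectors for new words
    return syn0_new

# toy usage: grow a 5-word embedding of dimension 20 to 7 words
rng = np.random.RandomState(1)
seed_fn = lambda i: (rng.rand(20) - 0.5) / 20   # small random init, similar in spirit to seeded_vector()
grown = expand_embedding(np.zeros((5, 20), dtype=np.float32), 7, seed_fn)
assert grown.shape == (7, 20)
```

The patch applies the same idea to `syn1`/`syn1neg` (rows for new words start at zero) and resets `syn0_lockf` to all ones, so both the old and the newly added words keep learning during the follow-up `train()` call.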