diff --git a/docs/notebooks/FastText_Tutorial.ipynb b/docs/notebooks/FastText_Tutorial.ipynb index 31144c3b96..96a977ab0e 100644 --- a/docs/notebooks/FastText_Tutorial.ipynb +++ b/docs/notebooks/FastText_Tutorial.ipynb @@ -45,9 +45,7 @@ { "cell_type": "code", "execution_count": 1, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -99,9 +97,7 @@ { "cell_type": "code", "execution_count": 2, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -140,9 +136,7 @@ { "cell_type": "code", "execution_count": 3, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -176,9 +170,7 @@ { "cell_type": "code", "execution_count": 4, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -224,9 +216,7 @@ { "cell_type": "code", "execution_count": 5, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "ename": "KeyError", @@ -258,9 +248,7 @@ { "cell_type": "code", "execution_count": 6, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -295,9 +283,7 @@ { "cell_type": "code", "execution_count": 7, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -314,8 +300,8 @@ ] }, "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" + "output_type": "execute_result", + "metadata": {} } ], "source": [ @@ -336,9 +322,7 @@ { "cell_type": "code", "execution_count": 8, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -356,8 +340,8 @@ ] }, "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" + "output_type": "execute_result", + "metadata": {} } ], "source": [ @@ -368,9 +352,7 @@ { "cell_type": "code", "execution_count": 9, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -379,8 +361,8 @@ ] }, "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" + "output_type": "execute_result", + "metadata": {} } ], "source": [ @@ -390,9 +372,7 @@ { "cell_type": "code", "execution_count": 10, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -401,8 +381,8 @@ ] }, "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" + "output_type": "execute_result", + "metadata": {} } ], "source": [ @@ -412,9 +392,7 @@ { "cell_type": "code", "execution_count": 11, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -432,8 +410,8 @@ ] }, "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" + "output_type": "execute_result", + "metadata": {} } ], "source": [ @@ -443,9 +421,7 @@ { "cell_type": "code", "execution_count": 13, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -539,8 +515,8 @@ ] }, "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" + "output_type": "execute_result", + "metadata": {} } ], "source": [ @@ -550,9 +526,7 @@ { "cell_type": "code", "execution_count": 15, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -561,8 +535,8 @@ ] }, "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" + "output_type": "execute_result", + "metadata": {} } ], "source": [ @@ -592,7 +566,7 @@ "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 2.0 }, "file_extension": ".py", 
"mimetype": "text/x-python", @@ -604,4 +578,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/gensim/models/doc2vec.py b/gensim/models/doc2vec.py index 00712f0685..2d78305c71 100644 --- a/gensim/models/doc2vec.py +++ b/gensim/models/doc2vec.py @@ -53,7 +53,7 @@ from gensim.utils import call_on_class_only from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc -from gensim.models.word2vec import Word2Vec, Vocab, train_cbow_pair, train_sg_pair, train_batch_sg +from gensim.models.word2vec import Word2Vec, train_cbow_pair, train_sg_pair, train_batch_sg from six.moves import xrange, zip from six import string_types, integer_types, itervalues diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index d97a5ae497..9a66f56287 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -22,7 +22,7 @@ from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc from gensim.corpora.dictionary import Dictionary -from six import string_types +from six import string_types, iteritems from six.moves import xrange from scipy import stats @@ -30,6 +30,24 @@ logger = logging.getLogger(__name__) +class Vocab(object): + """ + A single vocabulary item, used internally for collecting per-word frequency/sampling info, + and for constructing binary trees (incl. both word leaves and inner nodes). + + """ + def __init__(self, **kwargs): + self.count = 0 + self.__dict__.update(kwargs) + + def __lt__(self, other): # used for sorting in a priority queue + return self.count < other.count + + def __str__(self): + vals = ['%s:%r' % (key, self.__dict__[key]) for key in sorted(self.__dict__) if not key.startswith('_')] + return "%s(%s)" % (self.__class__.__name__, ', '.join(vals)) + + class KeyedVectors(utils.SaveLoad): """ Class to contain vectors and vocab for the Word2Vec training class and other w2v methods not directly @@ -46,14 +64,150 @@ def save(self, *args, **kwargs): kwargs['ignore'] = kwargs.get('ignore', ['syn0norm']) super(KeyedVectors, self).save(*args, **kwargs) + def save_word2vec_format(self, fname, fvocab=None, binary=False): + """ + Store the input-hidden weight matrix in the same format used by the original + C word2vec-tool, for compatibility. 
+ + `fname` is the file used to save the vectors in + `fvocab` is an optional file used to save the vocabulary + `binary` is an optional boolean indicating whether the data is to be saved + in binary word2vec format (default: False) + + """ + vector_size = self.syn0.shape[1] + if fvocab is not None: + logger.info("storing vocabulary in %s" % (fvocab)) + with utils.smart_open(fvocab, 'wb') as vout: + for word, vocab in sorted(iteritems(self.vocab), key=lambda item: -item[1].count): + vout.write(utils.to_utf8("%s %s\n" % (word, vocab.count))) + logger.info("storing %sx%s projection weights into %s" % (len(self.vocab), vector_size, fname)) + assert (len(self.vocab), vector_size) == self.syn0.shape + with utils.smart_open(fname, 'wb') as fout: + fout.write(utils.to_utf8("%s %s\n" % self.syn0.shape)) + # store in sorted order: most frequent words at the top + for word, vocab in sorted(iteritems(self.vocab), key=lambda item: -item[1].count): + row = self.syn0[vocab.index] + if binary: + fout.write(utils.to_utf8(word) + b" " + row.tostring()) + else: + fout.write(utils.to_utf8("%s %s\n" % (word, ' '.join("%f" % val for val in row)))) + + + @classmethod + def load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict', + limit=None, datatype=REAL): + """ + Load the input-hidden weight matrix from the original C word2vec-tool format. + + Note that the information stored in the file is incomplete (the binary tree is missing), + so while you can query for word similarity etc., you cannot continue training + with a model loaded this way. + + `binary` is a boolean indicating whether the data is in binary word2vec format. + `norm_only` is a boolean indicating whether to only store normalised word2vec vectors in memory. + Word counts are read from `fvocab` filename, if set (this is the file generated + by `-save-vocab` flag of the original C tool). + + If you trained the C model using non-utf8 encoding for words, specify that + encoding in `encoding`. + + `unicode_errors`, default 'strict', is a string suitable to be passed as the `errors` + argument to the unicode() (Python 2.x) or str() (Python 3.x) function. If your source + file may include word tokens truncated in the middle of a multibyte unicode character + (as is common from the original word2vec.c tool), 'ignore' or 'replace' may help. + + `limit` sets a maximum number of word-vectors to read from the file. The default, + None, means read all. + + `datatype` (experimental) can coerce dimensions to a non-default float type (such + as np.float16) to save memory. (Such types may result in much slower bulk operations + or incompatibility with optimized routines.) 
+ + """ + counts = None + if fvocab is not None: + logger.info("loading word counts from %s", fvocab) + counts = {} + with utils.smart_open(fvocab) as fin: + for line in fin: + word, count = utils.to_unicode(line).strip().split() + counts[word] = int(count) + + logger.info("loading projection weights from %s", fname) + with utils.smart_open(fname) as fin: + header = utils.to_unicode(fin.readline(), encoding=encoding) + vocab_size, vector_size = map(int, header.split()) # throws for invalid file format + if limit: + vocab_size = min(vocab_size, limit) + result = cls() + result.syn0 = zeros((vocab_size, vector_size), dtype=datatype) + + def add_word(word, weights): + word_id = len(result.vocab) + if word in result.vocab: + logger.warning("duplicate word '%s' in %s, ignoring all but first", word, fname) + return + if counts is None: + # most common scenario: no vocab file given. just make up some bogus counts, in descending order + result.vocab[word] = Vocab(index=word_id, count=vocab_size - word_id) + elif word in counts: + # use count from the vocab file + result.vocab[word] = Vocab(index=word_id, count=counts[word]) + else: + # vocab file given, but word is missing -- set count to None (TODO: or raise?) + logger.warning("vocabulary file is incomplete: '%s' is missing", word) + result.vocab[word] = Vocab(index=word_id, count=None) + result.syn0[word_id] = weights + result.index2word.append(word) + + if binary: + binary_len = dtype(REAL).itemsize * vector_size + for line_no in xrange(vocab_size): + # mixed text and binary: read text first, then binary + word = [] + while True: + ch = fin.read(1) + if ch == b' ': + break + if ch == b'': + raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?") + if ch != b'\n': # ignore newlines in front of words (some binary files have) + word.append(ch) + word = utils.to_unicode(b''.join(word), encoding=encoding, errors=unicode_errors) + weights = fromstring(fin.read(binary_len), dtype=REAL) + add_word(word, weights) + else: + for line_no in xrange(vocab_size): + line = fin.readline() + if line == b'': + raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?") + parts = utils.to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ") + if len(parts) != vector_size + 1: + raise ValueError("invalid vector on line %s (is this really the text format?)" % (line_no)) + word, weights = parts[0], list(map(REAL, parts[1:])) + add_word(word, weights) + if result.syn0.shape[0] != len(result.vocab): + logger.info( + "duplicate words detected, shrinking matrix size from %i to %i", + result.syn0.shape[0], len(result.vocab) + ) + result.syn0 = ascontiguousarray(result.syn0[: len(result.vocab)]) + assert (len(result.vocab), vector_size) == result.syn0.shape + + logger.info("loaded %s matrix from %s" % (result.syn0.shape, fname)) + return result + def word_vec(self, word, use_norm=False): """ Accept a single word as input. Returns the word's representations in vector space, as a 1D numpy array. + If `use_norm` is True, returns the normalized word vector. 
+ Example:: - >>> trained_model.word_vec('office', use_norm=True) + >>> trained_model['office'] array([ -1.40128313e-02, ...]) """ diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py index ccdf93c8a1..0a59224e9b 100644 --- a/gensim/models/word2vec.py +++ b/gensim/models/word2vec.py @@ -80,7 +80,7 @@ from gensim.utils import keep_vocab_item, call_on_class_only from gensim.utils import keep_vocab_item -from gensim.models.keyedvectors import KeyedVectors +from gensim.models.keyedvectors import KeyedVectors, Vocab try: from queue import Queue, Empty @@ -322,23 +322,6 @@ def score_cbow_pair(model, word, word2_indices, l1): return sum(lprob) -class Vocab(object): - """ - A single vocabulary item, used internally for collecting per-word frequency/sampling info, - and for constructing binary trees (incl. both word leaves and inner nodes). - - """ - def __init__(self, **kwargs): - self.count = 0 - self.__dict__.update(kwargs) - - def __lt__(self, other): # used for sorting in a priority queue - return self.count < other.count - - def __str__(self): - vals = ['%s:%r' % (key, self.__dict__[key]) for key in sorted(self.__dict__) if not key.startswith('_')] - return "%s(%s)" % (self.__class__.__name__, ', '.join(vals)) - class Word2Vec(utils.SaveLoad): """ @@ -1109,136 +1092,19 @@ def seeded_vector(self, seed_string): once = random.RandomState(self.hashfxn(seed_string) & 0xffffffff) return (once.rand(self.vector_size) - 0.5) / self.vector_size - def save_word2vec_format(self, fname, fvocab=None, binary=False): - """ - Store the input-hidden weight matrix in the same format used by the original - C word2vec-tool, for compatibility. - - `fname` is the file used to save the vectors in - `fvocab` is an optional file used to save the vocabulary - `binary` is an optional boolean indicating whether the data is to be saved - in binary word2vec format (default: False) - - """ - if fvocab is not None: - logger.info("storing vocabulary in %s" % (fvocab)) - with utils.smart_open(fvocab, 'wb') as vout: - for word, vocab in sorted(iteritems(self.wv.vocab), key=lambda item: -item[1].count): - vout.write(utils.to_utf8("%s %s\n" % (word, vocab.count))) - logger.info("storing %sx%s projection weights into %s" % (len(self.wv.vocab), self.vector_size, fname)) - assert (len(self.wv.vocab), self.vector_size) == self.wv.syn0.shape - with utils.smart_open(fname, 'wb') as fout: - fout.write(utils.to_utf8("%s %s\n" % self.wv.syn0.shape)) - # store in sorted order: most frequent words at the top - for word, vocab in sorted(iteritems(self.wv.vocab), key=lambda item: -item[1].count): - row = self.wv.syn0[vocab.index] - if binary: - fout.write(utils.to_utf8(word) + b" " + row.tostring()) - else: - fout.write(utils.to_utf8("%s %s\n" % (word, ' '.join("%f" % val for val in row)))) + def save_word2vec_format(self, *args, **kwargs): + if Word2Vec.keyed_vector_warnings: + logger.warning('word2vec.save_word2vec_format will be deprected in future gensim releases. Please use model.wv.save_word2vec_format') + return self.wv.save_word2vec_format(*args, **kwargs) @classmethod - def load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict', - limit=None, datatype=REAL): - """ - Load the input-hidden weight matrix from the original C word2vec-tool format. - - Note that the information stored in the file is incomplete (the binary tree is missing), - so while you can query for word similarity etc., you cannot continue training - with a model loaded this way. 
- - `binary` is a boolean indicating whether the data is in binary word2vec format. - `norm_only` is a boolean indicating whether to only store normalised word2vec vectors in memory. - Word counts are read from `fvocab` filename, if set (this is the file generated - by `-save-vocab` flag of the original C tool). - - If you trained the C model using non-utf8 encoding for words, specify that - encoding in `encoding`. - - `unicode_errors`, default 'strict', is a string suitable to be passed as the `errors` - argument to the unicode() (Python 2.x) or str() (Python 3.x) function. If your source - file may include word tokens truncated in the middle of a multibyte unicode character - (as is common from the original word2vec.c tool), 'ignore' or 'replace' may help. - - `limit` sets a maximum number of word-vectors to read from the file. The default, - None, means read all. - - `datatype` (experimental) can coerce dimensions to a non-default float type (such - as np.float16) to save memory. (Such types may result in much slower bulk operations - or incompatibility with optimized routines.) - - """ - counts = None - if fvocab is not None: - logger.info("loading word counts from %s", fvocab) - counts = {} - with utils.smart_open(fvocab) as fin: - for line in fin: - word, count = utils.to_unicode(line).strip().split() - counts[word] = int(count) - - logger.info("loading projection weights from %s", fname) - with utils.smart_open(fname) as fin: - header = utils.to_unicode(fin.readline(), encoding=encoding) - vocab_size, vector_size = map(int, header.split()) # throws for invalid file format - if limit: - vocab_size = min(vocab_size, limit) - result = cls(size=vector_size) - result.wv.syn0 = zeros((vocab_size, vector_size), dtype=datatype) - - def add_word(word, weights): - word_id = len(result.wv.vocab) - if word in result.wv.vocab: - logger.warning("duplicate word '%s' in %s, ignoring all but first", word, fname) - return - if counts is None: - # most common scenario: no vocab file given. just make up some bogus counts, in descending order - result.wv.vocab[word] = Vocab(index=word_id, count=vocab_size - word_id) - elif word in counts: - # use count from the vocab file - result.wv.vocab[word] = Vocab(index=word_id, count=counts[word]) - else: - # vocab file given, but word is missing -- set count to None (TODO: or raise?) - logger.warning("vocabulary file is incomplete: '%s' is missing", word) - result.wv.vocab[word] = Vocab(index=word_id, count=None) - result.wv.syn0[word_id] = weights - result.wv.index2word.append(word) + def load_word2vec_format(cls, *args, **kwargs): + if Word2Vec.keyed_vector_warnings: + logger.warning('Word2vec.load_word2vec_format will be deprected in future gensim releases. 
Please use KeyedVectors.load_word2vec_format') - if binary: - binary_len = dtype(REAL).itemsize * vector_size - for line_no in xrange(vocab_size): - # mixed text and binary: read text first, then binary - word = [] - while True: - ch = fin.read(1) - if ch == b' ': - break - if ch == b'': - raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?") - if ch != b'\n': # ignore newlines in front of words (some binary files have) - word.append(ch) - word = utils.to_unicode(b''.join(word), encoding=encoding, errors=unicode_errors) - weights = fromstring(fin.read(binary_len), dtype=REAL) - add_word(word, weights) - else: - for line_no in xrange(vocab_size): - line = fin.readline() - if line == b'': - raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?") - parts = utils.to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ") - if len(parts) != vector_size + 1: - raise ValueError("invalid vector on line %s (is this really the text format?)" % (line_no)) - word, weights = parts[0], list(map(REAL, parts[1:])) - add_word(word, weights) - if result.wv.syn0.shape[0] != len(result.wv.vocab): - logger.info( - "duplicate words detected, shrinking matrix size from %i to %i", - result.wv.syn0.shape[0], len(result.wv.vocab) - ) - result.wv.syn0 = ascontiguousarray(result.wv.syn0[: len(result.wv.vocab)]) - assert (len(result.wv.vocab), result.vector_size) == result.wv.syn0.shape - - logger.info("loaded %s matrix from %s" % (result.wv.syn0.shape, fname)) + wv = KeyedVectors.load_word2vec_format(*args, **kwargs) + result = cls(size=wv.syn0.shape[1]) + result.wv = wv return result def intersect_word2vec_format(self, fname, lockf=0.0, binary=False, encoding='utf8', unicode_errors='strict'): diff --git a/gensim/models/wrappers/__init__.py b/gensim/models/wrappers/__init__.py index 8933171250..b954e679f3 100644 --- a/gensim/models/wrappers/__init__.py +++ b/gensim/models/wrappers/__init__.py @@ -6,4 +6,4 @@ from .dtmmodel import DtmModel from .ldavowpalwabbit import LdaVowpalWabbit from .fasttext import FastText -from .wordrank import Wordrank \ No newline at end of file +from .wordrank import Wordrank diff --git a/gensim/models/wrappers/fasttext.py b/gensim/models/wrappers/fasttext.py index 81bb6ccce4..d38b736987 100644 --- a/gensim/models/wrappers/fasttext.py +++ b/gensim/models/wrappers/fasttext.py @@ -216,6 +216,10 @@ def save(self, *args, **kwargs): kwargs['ignore'] = kwargs.get('ignore', ['syn0norm', 'syn0_all_norm']) super(FastText, self).save(*args, **kwargs) + @classmethod + def load_word2vec_format(cls, *args, **kwargs): + return FastTextKeyedVectors.load_word2vec_format(*args, **kwargs) + @classmethod def load_fasttext_format(cls, model_file): """ @@ -229,7 +233,8 @@ def load_fasttext_format(cls, model_file): Expected value for this example: `/path/to/train` """ - model = cls.load_word2vec_format('%s.vec' % model_file) + model = cls() + model.wv = cls.load_word2vec_format('%s.vec' % model_file) model.load_binary_data('%s.bin' % model_file) return model diff --git a/gensim/test/test_word2vec.py b/gensim/test/test_word2vec.py index 6c96652c28..035765e8a0 100644 --- a/gensim/test/test_word2vec.py +++ b/gensim/test/test_word2vec.py @@ -280,6 +280,16 @@ def testPersistenceWord2VecFormatWithVocab(self): binary_model_with_vocab = word2vec.Word2Vec.load_word2vec_format(testfile(), testvocab, binary=True) self.assertEqual(model.wv.vocab['human'].count, binary_model_with_vocab.wv.vocab['human'].count) + def 
testPersistenceKeyedVectorsFormatWithVocab(self): + """Test storing/loading the entire model and vocabulary in word2vec format.""" + model = word2vec.Word2Vec(sentences, min_count=1) + model.init_sims() + testvocab = os.path.join(tempfile.gettempdir(), 'gensim_word2vec.vocab') + model.wv.save_word2vec_format(testfile(), testvocab, binary=True) + kv_binary_model_with_vocab = keyedvectors.KeyedVectors.load_word2vec_format(testfile(), testvocab, binary=True) + self.assertEqual(model.wv.vocab['human'].count, kv_binary_model_with_vocab.vocab['human'].count) + + def testPersistenceWord2VecFormatCombinationWithStandardPersistence(self): """Test storing/loading the entire model and vocabulary in word2vec format chained with saving and loading via `save` and `load` methods`.""" @@ -292,6 +302,7 @@ def testPersistenceWord2VecFormatCombinationWithStandardPersistence(self): binary_model_with_vocab = word2vec.Word2Vec.load(testfile()) self.assertEqual(model.wv.vocab['human'].count, binary_model_with_vocab.wv.vocab['human'].count) + def testLargeMmap(self): """Test storing/loading the entire model.""" model = word2vec.Word2Vec(sentences, min_count=1) @@ -579,8 +590,8 @@ def testDeleteTemporaryTrainingData(self): def testNormalizeAfterTrainingData(self): model = word2vec.Word2Vec(sentences, min_count=1) - model.save_word2vec_format(testfile(), binary=True) - norm_only_model = word2vec.Word2Vec.load_word2vec_format(testfile(), binary=True) + model.save(testfile()) + norm_only_model = word2vec.Word2Vec.load(testfile()) norm_only_model.delete_temporary_training_data(replace_word_vectors_with_normalized=True) self.assertFalse(np.allclose(model['human'], norm_only_model['human']))
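
Editor's note: a minimal usage sketch of the API surface this patch migrates to, not part of the change itself. It assumes the toy corpus and file paths shown (both illustrative); the calls mirror the new `save_word2vec_format`/`load_word2vec_format` on `KeyedVectors` and the deprecation shims left on `Word2Vec`, as exercised by `testPersistenceKeyedVectorsFormatWithVocab` above.

    from gensim.models import word2vec
    from gensim.models.keyedvectors import KeyedVectors

    # Tiny illustrative corpus; any iterable of tokenized sentences works.
    sentences = [['human', 'interface', 'computer'],
                 ['survey', 'user', 'computer', 'system']]
    model = word2vec.Word2Vec(sentences, min_count=1)

    # New style: persist the vectors through the KeyedVectors object on the model.
    model.wv.save_word2vec_format('/tmp/vectors.bin', fvocab='/tmp/vocab.txt', binary=True)

    # Reload the vectors without a full Word2Vec model (no further training possible).
    wv = KeyedVectors.load_word2vec_format('/tmp/vectors.bin', fvocab='/tmp/vocab.txt', binary=True)
    print(wv.vocab['human'].count)      # count restored from the fvocab file
    print(wv.word_vec('human').shape)   # 1D numpy vector for the word

    # Old style still works for now, but logs a deprecation warning and
    # delegates to model.wv / KeyedVectors under the hood.
    model.save_word2vec_format('/tmp/vectors_old.bin', binary=True)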