diff --git a/gensim/models/wrappers/fasttext.py b/gensim/models/wrappers/fasttext.py index 8e7a966cee..e32fd9e747 100644 --- a/gensim/models/wrappers/fasttext.py +++ b/gensim/models/wrappers/fasttext.py @@ -1,8 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Author: Jayant Jain -# Copyright (C) 2017 Radim Rehurek +# Copyright (C) 2017 Radim Rehurek +# Copyright (C) 2017 Carl Saroufim # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html @@ -390,10 +390,7 @@ def struct_unpack(self, file_handle, fmt): def init_ngrams(self): """ - Computes ngrams of all words present in vocabulary and stores vectors for only those ngrams. - Vectors for other ngrams are initialized with a random uniform distribution in FastText. These - vectors are discarded here to save space. - + Computes ngrams of all words present in vocabulary and stores vectors for those ngrams. """ self.wv.ngrams = {} all_ngrams = [] @@ -403,32 +400,11 @@ def init_ngrams(self): all_ngrams += compute_ngrams(w, self.wv.min_n, self.wv.max_n) self.wv.syn0[vocab.index] += np.array(self.wv.syn0_ngrams[vocab.index]) - all_ngrams = set(all_ngrams) - self.num_ngram_vectors = len(all_ngrams) - ngram_indices = [] - for i, ngram in enumerate(all_ngrams): - ngram_hash = ft_hash(ngram) - ngram_indices.append(len(self.wv.vocab) + ngram_hash % self.bucket) - self.wv.ngrams[ngram] = i - self.wv.syn0_ngrams = self.wv.syn0_ngrams.take(ngram_indices, axis=0) - - ngram_weights = self.wv.syn0_ngrams - - logger.info( - "loading weights for %s words for fastText model from %s", - len(self.wv.vocab), self.file_name - ) + self.wv.syn0_ngrams = self.wv.syn0_ngrams[len(self.wv.vocab):] - for w, vocab in self.wv.vocab.items(): - word_ngrams = compute_ngrams(w, self.wv.min_n, self.wv.max_n) - for word_ngram in word_ngrams: - self.wv.syn0[vocab.index] += np.array(ngram_weights[self.wv.ngrams[word_ngram]]) - - self.wv.syn0[vocab.index] /= (len(word_ngrams) + 1) - logger.info( - "loaded %s weight matrix for fastText model from %s", - self.wv.syn0.shape, self.file_name - ) + all_ngrams = set(all_ngrams) + for ngram in all_ngrams: + self.wv.ngrams[ngram] = ft_hash(ngram) % self.bucket def compute_ngrams(word, min_n, max_n):