Loading fastText models using only bin file #1341
@@ -35,11 +35,15 @@
 import numpy as np
 from numpy import float32 as REAL, sqrt, newaxis
 from gensim import utils
-from gensim.models.keyedvectors import KeyedVectors
+from gensim.models.keyedvectors import KeyedVectors, Vocab
 from gensim.models.word2vec import Word2Vec

 from six import string_types

+from numpy import exp, log, dot, zeros, outer, random, dtype, float32 as REAL,\
+    double, uint32, seterr, array, uint8, vstack, fromstring, sqrt, newaxis,\
+    ndarray, empty, sum as np_sum, prod, ones, ascontiguousarray
+
 logger = logging.getLogger(__name__)

 FASTTEXT_FILEFORMAT_MAGIC = 793712314
@@ -224,7 +228,7 @@ def load_word2vec_format(cls, *args, **kwargs):
         return FastTextKeyedVectors.load_word2vec_format(*args, **kwargs)

Review thread on the line above:

I believe that a load using this method only learns the full-word vectors as in the

Yes, this method is not used now for loading using bin only. I removed this unused code, but got a strange flake8 error for python 3+, therefore re-added it for this PR. I'll try removing this unused code later, maybe in a different PR. @gojomo

That is an odd error! I suspect it's not really the presence/absence of that method that triggered it, but something else either random or hidden in the whitespace.

@gojomo ok, test passed this time after removing this code 😄

For reference, this was a bug in the flake8 script, fixed in cebb3fc

     @classmethod
-    def load_fasttext_format(cls, model_file, encoding='utf8'):
+    def load_fasttext_format(cls, model_file, bin_only = False, encoding='utf8'):

Review comment on the new signature: Please add

         """
         Load the input-hidden weight matrix from the fast text output files.
@@ -237,8 +241,11 @@ def load_fasttext_format(cls, model_file, encoding='utf8'):

         """
         model = cls()
-        model.wv = cls.load_word2vec_format('%s.vec' % model_file, encoding=encoding)
-        model.load_binary_data('%s.bin' % model_file, encoding=encoding)
+        if bin_only:
+            model.load_binary_data('%s.bin' % model_file, bin_only, encoding=encoding)
+        else:
+            model.wv = cls.load_word2vec_format('%s.vec' % model_file, encoding=encoding)
+            model.load_binary_data('%s.bin' % model_file, encoding=encoding)
         return model

     @classmethod
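If the bin_only flag lands as shown in the hunk above, loading a model from just its .bin file would look roughly like the sketch below. This is a usage sketch against the API proposed in this PR, not released gensim behaviour: the model path is illustrative, and the wrapper import location (gensim.models.wrappers.fasttext) is assumed from gensim's layout at the time.

    from gensim.models.wrappers import fasttext

    # `bin_only` is the keyword proposed in this PR, not part of any release.
    # Pass the path *without* an extension; load_fasttext_format appends
    # '%s.bin' itself (and '%s.vec' too when bin_only is False).
    model = fasttext.FastText.load_fasttext_format('/path/to/model', bin_only=True)

    # Lookups mirror what testLoadBinOnly below exercises.
    print(model['example'][:5])
    print(len(model.wv.vocab), model.vector_size)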
@@ -251,11 +258,11 @@ def delete_training_files(cls, model_file):
             logger.debug('Training files %s not found when attempting to delete', model_file)
             pass

-    def load_binary_data(self, model_binary_file, encoding='utf8'):
+    def load_binary_data(self, model_binary_file, bin_only = False, encoding='utf8'):
         """Loads data from the output binary file created by FastText training"""
         with utils.smart_open(model_binary_file, 'rb') as f:
             self.load_model_params(f)
-            self.load_dict(f, encoding=encoding)
+            self.load_dict(f, bin_only, encoding=encoding)
             self.load_vectors(f)

     def load_model_params(self, file_handle):
@@ -281,15 +288,22 @@ def load_model_params(self, file_handle):
         self.wv.max_n = maxn
         self.sample = t

-    def load_dict(self, file_handle, encoding='utf8'):
+    def load_dict(self, file_handle, bin_only = False, encoding='utf8'):
         vocab_size, nwords, _ = self.struct_unpack(file_handle, '@3i')
         # Vocab stored by [Dictionary::save](https://github.com/facebookresearch/fastText/blob/master/src/dictionary.cc)
-        assert len(self.wv.vocab) == nwords, 'mismatch between vocab sizes'
-        assert len(self.wv.vocab) == vocab_size, 'mismatch between vocab sizes'
-        self.struct_unpack(file_handle, '@1q') # number of tokens
+        if not bin_only:
+            assert len(self.wv.vocab) == nwords, 'mismatch between vocab sizes'

Review comment: Let's also log

+            if len(self.wv.vocab) != vocab_size:
+                logger.warnings("If you are loading any model other than pretrained vector wiki.fr, ")
+                logger.warnings("Please report to gensim or fastText.")
+        else:
+            self.wv.syn0 = zeros((vocab_size, self.vector_size), dtype=REAL)

Review comment: This doesn't belong here, we need to be loading the vectors for the in-vocab words after loading the ngram weight matrix (

+            logger.info("here?")
+        # TO-DO : how to update this
+        ntokens= self.struct_unpack(file_handle, '@1q')

Review comment: Please take care of code style: space before

         if self.new_format:
             pruneidx_size, = self.struct_unpack(file_handle, '@q')
-        for i in range(nwords):
+        for i in range(vocab_size):
             word_bytes = b''
             char_byte = file_handle.read(1)
             # Read vocab word
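On the review comment above about filling wv.syn0 for in-vocab words only after the n-gram weight matrix is available: in fastText's scheme a word's reported vector is the average of the word's own vector and the vectors of its character n-grams (the word is bracketed with '<' and '>' before n-grams are extracted). The stand-alone NumPy sketch below illustrates that composition; it is not gensim's implementation, and the bucket hashing is a simplified stand-in for fastText's own hash function.

    import numpy as np

    def char_ngrams(word, min_n=3, max_n=6):
        # fastText brackets the word with '<' and '>' before extracting n-grams.
        w = '<' + word + '>'
        return [w[i:i + n]
                for n in range(min_n, max_n + 1)
                for i in range(len(w) - n + 1)]

    def compose_vector(word, ngram_matrix, bucket, word_vec=None):
        # Stand-in hash: real fastText maps n-grams to buckets with its own hash.
        rows = [ngram_matrix[hash(ng) % bucket] for ng in char_ngrams(word)]
        if word_vec is not None:
            rows.append(word_vec)  # in-vocab words also contribute their own vector
        return np.mean(rows, axis=0)

    # Toy usage: a random (bucket x dim) matrix standing in for the n-gram weights.
    bucket, dim = 1000, 10
    ngram_matrix = np.random.rand(bucket, dim).astype(np.float32)
    print(compose_vector('hundred', ngram_matrix, bucket)[:3])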
@@ -298,8 +312,24 @@ def load_dict(self, file_handle, encoding='utf8'):
                 char_byte = file_handle.read(1)
             word = word_bytes.decode(encoding)
             count, _ = self.struct_unpack(file_handle, '@qb')
-            assert self.wv.vocab[word].index == i, 'mismatch between gensim word index and fastText word index'
-            self.wv.vocab[word].count = count
+            if bin_only and word not in self.wv.vocab:

Review comment: This seems unnecessarily convoluted - can the word be in

+                self.wv.vocab[word] = Vocab(index=i, count=count)

Review thread:

Is this correct? The word

Actually, you were right about it, this is the last word, but we can't skip reading it, otherwise there will be an error in further byte reading.

Yes, I understand we have to read the bytes, agreed.

    if i == nwords and i < vocab_size:
        assert word == "__label__"
        continue  # don't add word to vocab

+            elif not bin_only:
+                assert self.wv.vocab[word].index == i, 'mismatch between gensim word index and fastText word index'
+                self.wv.vocab[word].count = count
+
+            if bin_only:
+                #self.wv.syn0[i] = weight # How to get weight vector for each word ?
+                self.wv.index2word.append(word)
+
+        if bin_only:
+            if self.wv.syn0.shape[0] != len(self.wv.vocab):
+                logger.info(
+                    "duplicate words detected, shrinking matrix size from %i to %i",

Review comment: Not sure I understand the need for this. Can duplicate words exist?

+                    self.wv.syn0.shape[0], len(self.wv.vocab)
+                )
+                self.wv.syn0 = ascontiguousarray(result.syn0[: len(self.wv.vocab)])
+            assert (len(self.wv.vocab), self.vector_size) == self.wv.syn0.shape

         if self.new_format:
             for j in range(pruneidx_size):
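For orientation while reading load_dict above: the dictionary section it walks through follows fastText's Dictionary::save layout, i.e. a '@3i' header (vocab_size, nwords, plus a third field the code ignores), an int64 number of tokens, an extra int64 pruneidx size in the new format, and then one record per entry. Assuming each record is the word's bytes up to a NUL terminator followed by the '@qb' pair (int64 count, int8 entry type) that the diff unpacks, a minimal stand-alone reader for one record might look like this sketch (not gensim code):

    import struct

    def read_dict_entry(f, encoding='utf8'):
        # Word bytes up to the terminating NUL byte, then count (int64) and
        # entry type (int8), mirroring the '@qb' unpack in load_dict above.
        word_bytes = b''
        char_byte = f.read(1)
        while char_byte != b'\x00':
            word_bytes += char_byte
            char_byte = f.read(1)
        count, entry_type = struct.unpack('@qb', f.read(struct.calcsize('@qb')))
        return word_bytes.decode(encoding), count, entry_type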
@@ -171,6 +171,37 @@ def testLoadFastTextNewFormat(self):
         self.assertEquals(self.test_new_model.wv.min_n, 3)
         self.model_sanity(new_model)

+    def testLoadBinOnly(self):
+        """ Test model succesfully loaded from fastText (new format) .bin files only """
+        new_model = fasttext.FastText.load_fasttext_format(self.test_new_model_file, bin_only = True)
+        vocab_size, model_size = 1763, 10
+        self.assertEqual(self.test_new_model.wv.syn0.shape, (vocab_size, model_size))

Review comment: Shouldn't we be testing

+        self.assertEqual(len(self.test_new_model.wv.vocab), vocab_size, model_size)
+        self.assertEqual(self.test_new_model.wv.syn0_all.shape, (self.test_new_model.num_ngram_vectors, model_size))
+
+        expected_vec_new = [-0.025627,
+                            -0.11448,
+                            0.18116,
+                            -0.96779,
+                            0.2532,
+                            -0.93224,
+                            0.3929,
+                            0.12679,
+                            -0.19685,
+                            -0.13179]  # obtained using ./fasttext print-word-vectors lee_fasttext_new.bin < queries.txt
+
+        self.assertTrue(numpy.allclose(self.test_new_model["hundred"], expected_vec_new, 0.001))
+        self.assertEquals(self.test_new_model.min_count, 5)
+        self.assertEquals(self.test_new_model.window, 5)
+        self.assertEquals(self.test_new_model.iter, 5)
+        self.assertEquals(self.test_new_model.negative, 5)
+        self.assertEquals(self.test_new_model.sample, 0.0001)
+        self.assertEquals(self.test_new_model.bucket, 1000)
+        self.assertEquals(self.test_new_model.wv.max_n, 6)
+        self.assertEquals(self.test_new_model.wv.min_n, 3)
+        self.model_sanity(new_model)
+
     def testLoadModelWithNonAsciiVocab(self):
         """Test loading model with non-ascii words in vocab"""
         model = fasttext.FastText.load_fasttext_format(datapath('non_ascii_fasttext'))
Review comment: Do we need all these imports?
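As a closing aside on testLoadBinOnly above: the expected_vec_new values were produced with ./fasttext print-word-vectors lee_fasttext_new.bin < queries.txt (per the inline comment). A hedged sketch of regenerating such a reference vector and cross-checking it against the wrapper under review, assuming a local fasttext binary and the model files sit in the working directory (the bin_only flag is this PR's proposal, not released gensim):

    import subprocess
    import numpy as np
    from gensim.models.wrappers import fasttext

    # Ask the reference fastText CLI for one word's vector; output lines are
    # "<word> v1 v2 ... vN".
    out = subprocess.run(
        ['./fasttext', 'print-word-vectors', 'lee_fasttext_new.bin'],
        input='hundred\n', capture_output=True, text=True, check=True,
    ).stdout
    reference = np.array(out.split()[1:], dtype=np.float32)

    # Load the same model through the wrapper and compare within a tolerance.
    model = fasttext.FastText.load_fasttext_format('lee_fasttext_new', bin_only=True)
    assert np.allclose(model['hundred'], reference, atol=1e-3)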