Commit
Adding writing vocabulary, vectors, output layer for FB format (piskv…
lopusz committed Dec 22, 2019
1 parent fb5923e commit 41afcfa
Showing 2 changed files with 162 additions and 29 deletions.
185 changes: 159 additions & 26 deletions gensim/models/_fasttext_bin.py
@@ -41,6 +41,14 @@

_END_OF_WORD_MARKER = b'\x00'

# The FastText dictionary data structure holds elements of type `entry`, whose `entry_type` can be
# either `word` (0 :: int8) or `label` (1 :: int8). Here we deal with the unsupervised case only,
# so we want the `word` type.
# See https://github.com/facebookresearch/fastText/blob/master/src/dictionary.h

_DICT_WORD_ENTRY_TYPE_MARKER = b'\x00'
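# Illustrative sketch (the `_EXAMPLE_WORD_ENTRY` constant is hypothetical, not part of the format code):
# each dictionary entry is serialized as the encoded word, a terminating zero byte, an int64 count,
# and the entry-type byte above, e.g. for the word "cat" seen 7 times:
_EXAMPLE_WORD_ENTRY = b'cat' + _END_OF_WORD_MARKER + np.int64(7).tobytes() + _DICT_WORD_ENTRY_TYPE_MARKER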


logger = logging.getLogger(__name__)

# Constants for FastText version and FastText file format magic (both int32)
@@ -101,6 +109,7 @@ def _yield_field_names():
yield 'nwords'
yield 'vectors_ngrams'
yield 'hidden_output'
yield 'ntokens'


_FIELD_NAMES = sorted(set(_yield_field_names()))
@@ -184,7 +193,8 @@ def _load_vocab(fin, new_format, encoding='utf-8'):
raise NotImplementedError("Supervised fastText models are not supported")
logger.info("loading %s words for fastText model from %s", vocab_size, fin.name)

_struct_unpack(fin, '@1q') # number of tokens
ntokens, = _struct_unpack(fin, '@q') # number of tokens

if new_format:
pruneidx_size, = _struct_unpack(fin, '@q')

@@ -213,7 +223,7 @@ def _load_vocab(fin, new_format, encoding='utf-8'):
for j in range(pruneidx_size):
_struct_unpack(fin, '@2i')

return raw_vocab, vocab_size, nwords
return raw_vocab, vocab_size, nwords, ntokens


def _load_matrix(fin, new_format=True):
@@ -319,16 +329,17 @@ def load(fin, encoding='utf-8', full_model=True):
fin = open(fin, 'rb')

magic, version = _struct_unpack(fin, '@2i')
new_format = magic == _FASTTEXT_FILEFORMAT_MAGIC

header_spec = _NEW_HEADER_FORMAT if new_format else _OLD_HEADER_FORMAT
model = {name: _struct_unpack(fin, fmt)[0] for (name, fmt) in header_spec}


if not new_format:
model.update(dim=magic, ws=version)

raw_vocab, vocab_size, nwords = _load_vocab(fin, new_format, encoding=encoding)
model.update(raw_vocab=raw_vocab, vocab_size=vocab_size, nwords=nwords)
raw_vocab, vocab_size, nwords, ntokens = _load_vocab(fin, new_format, encoding=encoding)
model.update(raw_vocab=raw_vocab, vocab_size=vocab_size, nwords=nwords, ntokens=ntokens)

vectors_ngrams = _load_matrix(fin, new_format=new_format)

@@ -376,8 +387,7 @@ def _backslashreplace_backport(ex):


def _sign_model(fout):
# Reimplementation of the FastText::signModel function, see
# https://github.com/facebookresearch/fastText/blob/master/src/fasttext.cc
# Reimplementation of the [FastText::signModel](https://github.com/facebookresearch/fastText/blob/master/src/fasttext.cc)
fout.write(_FASTTEXT_FILEFORMAT_MAGIC.tobytes())
fout.write(_FASTTEXT_VERSION.tobytes())

@@ -423,7 +433,7 @@ def _get_field(model, field, field_type):
# cbow = continuous bag of words (default)
# sg = skip-gram
# sup = supervised
res = 1 if model.sg == 1 else 2
res = 2 if model.sg == 1 else 1
elif field == 'neg':
res = model.negative
elif field == 't':
@@ -444,34 +454,88 @@ def _get_field(model, field, field_type):

def _args_save(fout, model):

# Reimplementation of the Args::save method, see
# https://github.com/facebookresearch/fastText/blob/master/src/args.cc
# Reimplementation of the [Args::save](https://github.com/facebookresearch/fastText/blob/master/src/args.cc)

for field, field_type in _NEW_HEADER_FORMAT:
fout.write(_get_field(model, field, field_type))


def _dict_save(fout, model):
pass
def _dict_save(fout, model, encoding):
# Reimplementation of the [Dictionary::save](https://github.com/facebookresearch/fastText/blob/master/src/dictionary.cc)

# out.write((char*)&size_, sizeof(int32_t));
# out.write((char*)&nwords_, sizeof(int32_t));
# out.write((char*)&nlabels_, sizeof(int32_t));
# out.write((char*)&ntokens_, sizeof(int64_t));
# out.write((char*)&pruneidx_size_, sizeof(int64_t));
# for (int32_t i = 0; i < size_; i++) {
# entry e = words_[i];
# out.write(e.word.data(), e.word.size() * sizeof(char));
# out.put(0);
# out.write((char*)&(e.count), sizeof(int64_t));
# out.write((char*)&(e.type), sizeof(entry_type));
# }
# for (const auto pair : pruneidx_) {
# out.write((char*)&(pair.first), sizeof(int32_t));
# out.write((char*)&(pair.second), sizeof(int32_t));
# }

# `size` counts all dictionary entries (words and labels), `nwords` only the words;
# without labels they are equal, so the same value is written twice below

fout.write(np.int32(len(model.wv.vocab)).tobytes())

fout.write(np.int32(len(model.wv.vocab)).tobytes())

# nlabels=0 <- no labels, since we are in unsupervised mode
fout.write(np.int32(0).tobytes())

fout.write(np.int64(model.corpus_total_words).tobytes())

def _save_vocab(fout, model):
pass
# pruneidx_size_=-1, a value of -1 denotes no pruning index (pruning is only supported in supervised mode)
fout.write(np.int64(-1).tobytes())

for word, vocab_entry in model.wv.vocab.items():
fout.write(word.encode(encoding))
fout.write(_END_OF_WORD_MARKER)
fout.write(np.int64(vocab_entry.count).tobytes())
fout.write(_DICT_WORD_ENTRY_TYPE_MARKER)

def _save_vector_ngrams(fout, model):
pass
# We are in the unsupervised case, therefore pruneidx is empty, so we do not need to write anything else


def _save_hidden_outputs(fout, model):
pass
def _input_save(fout, model):
vocab_n, vocab_dim = model.wv.vectors_vocab.shape
ngrams_n, ngrams_dim = model.wv.vectors_ngrams.shape

assert vocab_dim == ngrams_dim
assert vocab_n == len(model.wv.vocab)
assert ngrams_n == model.wv.bucket

def _save(fout, model):
fout.write(struct.pack('@2q', vocab_n + ngrams_n, vocab_dim))
fout.write(model.wv.vectors_vocab.tobytes())
fout.write(model.wv.vectors_ngrams.tobytes())


def _output_save(fout, model):

# TODO Can model.hs and model.negative be both False?
# TODO Can model.hs and model.negative be both True?

if model.hs:
hidden_output = model.trainables.syn1
if model.negative:
hidden_output = model.trainables.syn1neg

hidden_n, hidden_dim = hidden_output.shape
fout.write(struct.pack('@2q', hidden_n, hidden_dim))
fout.write(hidden_output.tobytes())


def _save(fout, model, encoding):

# Unfortunately there is no documentation of the FB binary format
# This is just reimplementation of FastText::saveModel method
# See https://github.com/facebookresearch/fastText/blob/master/src/fasttext.cc
# This is just a reimplementation of
# [FastText::saveModel](https://github.com/facebookresearch/fastText/blob/master/src/fasttext.cc)

# As of writing this (12.2019) the C++ code looks as follows
#
@@ -487,18 +551,87 @@ def _save(fout, model):

_sign_model(fout)
_args_save(fout, model)
_dict_save(fout, model)
_save_vector_ngrams(fout, model)
_save_hidden_outputs(fout, model)
_dict_save(fout, model, encoding)
fout.write(struct.pack('@?', False)) # quant_ flag; TODO Check if quantization works for unsupervised models

# Save words and ngrams vectors
_input_save(fout, model)
fout.write(struct.pack('@?', False)) # qout flag; TODO Check if quantization works for unsupervised models

# Save output layers of the model
_output_save(fout, model)

def save(fout, model):

def save(fout, model, encoding='utf-8'):
if isinstance(fout, str):
with open(fout, "wb") as fout_stream:
_save(fout_stream, model)
_save(fout_stream, model, encoding)
else:
_save(fout, model, encoding)


# COMPARING FUNCTIONALITY


def _sign_load(fin):
keys = ['fileformat_magic', 'version']
vals = _struct_unpack(fin, '@2i')
return dict(zip(keys, vals))


def _load_key_fmt_list_to_dict(fin, key_fmt_list):
res = {}
for key, fmt in key_fmt_list:
res[key] = _struct_unpack(fin, fmt)[0]
return res


def _args_load(fin):
return _load_key_fmt_list_to_dict(fin, _NEW_HEADER_FORMAT)


def _dict_header_load(fin):
DICT_HEADER_FORMAT = [('size', 'i'),
('nwords', 'i'),
('nlabels', 'i'),
('ntokens', '@q'),
('pruneidx_size', '@q')]
return _load_key_fmt_list_to_dict(fin, DICT_HEADER_FORMAT)


def _yield_differing_keys(d1, d2):
assert set(d1.keys()) == set(d2.keys())

for k in d1.keys():
v1, v2 = d1[k], d2[k]

if v1 != v2:
yield k, v1, v2

return None


def _print_differences_between_dicts(d1, d2):
if d1 != d2:
for k, v1, v2 in _yield_differing_keys(d1, d2):
print('Key "%s" differs -> %s != %s' % (k, str(v1), str(v2)))


def compare_fasttext_files(fname1, fname2):

with open(fname1, 'rb') as fin1, open(fname2, 'rb') as fin2:
sign1 = _sign_load(fin1)
sign2 = _sign_load(fin2)
_print_differences_between_dicts(sign1, sign2)

args1 = _args_load(fin1)
args2 = _args_load(fin2)
_print_differences_between_dicts(args1, args2)

dict_header1 = _dict_header_load(fin1)
dict_header2 = _dict_header_load(fin2)
_print_differences_between_dicts(dict_header1, dict_header2)


if six.PY2:
codecs.register_error('backslashreplace', _backslashreplace_backport)
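
A minimal usage sketch of the new writer together with the comparison helper (the file paths
and the toy model are hypothetical; `load_facebook_model` is gensim's existing Facebook-format
loader, not part of this commit):

from gensim.models.fasttext import load_facebook_model
from gensim.models import _fasttext_bin

# Load a full model from an existing Facebook-format binary (hypothetical path).
model = load_facebook_model('/tmp/toy_model.bin')

# Write it back out in Facebook's binary format with the new writer.
_fasttext_bin.save('/tmp/toy_model_resaved.bin', model, encoding='utf-8')

# Report any signature, args or dictionary-header fields that differ between the two files.
_fasttext_bin.compare_fasttext_files('/tmp/toy_model.bin', '/tmp/toy_model_resaved.bin')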
6 changes: 3 additions & 3 deletions gensim/models/fasttext.py
@@ -1216,15 +1216,15 @@ def _load_fasttext_format(model_file, encoding='utf-8', full_model=True):
window=m.ws,
iter=m.epoch,
negative=m.neg,
hs=(m.loss == 1),
sg=(m.model == 2),
hs=int(m.loss == 2),
sg=int(m.model == 2),
bucket=m.bucket,
min_count=m.min_count,
sample=m.t,
min_n=m.minn,
max_n=m.maxn,
)

model.corpus_total_words = m.ntokens
model.vocabulary.raw_vocab = m.raw_vocab
model.vocabulary.nwords = m.nwords
model.vocabulary.vocab_size = m.vocab_size
