Added save method for doc2vec #1256

Merged: 9 commits, Apr 19, 2017
Changes from 1 commit
41 changes: 40 additions & 1 deletion gensim/models/doc2vec.py
@@ -62,7 +62,7 @@
from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc
from gensim.models.word2vec import Word2Vec, train_cbow_pair, train_sg_pair, train_batch_sg
from six.moves import xrange, zip
-from six import string_types, integer_types, itervalues
+from six import string_types, integer_types, itervalues, iteritems

logger = logging.getLogger(__name__)

@@ -808,6 +808,45 @@ def delete_temporary_training_data(self, keep_doctags_vectors=True, keep_inferen
        if self.docvecs and hasattr(self.docvecs, 'doctag_syn0_lockf'):
            del self.docvecs.doctag_syn0_lockf

    def save_doc2vec_format(self, fname, word_vec=False, binary=False):
        """
        Store the input-hidden weight matrix.

        `fname` is the file used to save the vectors in
        `word_vec` is an optional boolean indicating whether to store word vectors
        in the same file as document vectors
        `binary` is an optional boolean indicating whether the data is to be saved
        in binary word2vec format (default: False)

        """
        total_vec = len(self.docvecs)
        if word_vec:
            total_vec = total_vec + len(self.wv.vocab)
        logger.info("storing %sx%s projection weights into %s" % (total_vec, self.vector_size, fname))

        # save document vectors
        with utils.smart_open(fname, 'wb') as fout:
            fout.write(utils.to_utf8("%s %s\n" % (total_vec, self.vector_size)))
            # store in input order
            for i in range(len(self.docvecs)):
                doctag = self.docvecs.index_to_doctag(i)
Contributor Author: this will work in case the user's model.docvecs.doctags is empty, and will assign the vector index as the doctag.

                row = self.docvecs.doctag_syn0[i]
                if binary:
                    fout.write(utils.to_utf8(doctag) + b" " + row.tostring())
                else:
                    fout.write(utils.to_utf8("%s %s\n" % (doctag, ' '.join("%f" % val for val in row))))

        # save word vectors
        if word_vec:
            with utils.smart_open(fname, 'ab') as fout:
                for word, vocab in sorted(iteritems(self.wv.vocab), key=lambda item: -item[1].count):
                    row = self.wv.syn0[vocab.index]
                    if binary:
                        fout.write(utils.to_utf8(word) + b" " + row.tostring())
                    else:
                        fout.write(utils.to_utf8("%s %s\n" % (word, ' '.join("%f" % val for val in row))))


class TaggedBrownCorpus(object):
"""Iterate over documents from the Brown corpus (part of NLTK data), yielding
each document out as a TaggedDocument object."""
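
A minimal usage sketch (not part of the diff): assuming the signature shown in this commit and the gensim/six APIs of that era, the new method could be exercised as below. The corpus, tag names, and file path are made up for illustration; since the output follows the word2vec text format, it should be readable back with KeyedVectors.load_word2vec_format.

from gensim.models import KeyedVectors
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

# Toy corpus; a real corpus and training parameters would differ (names here are illustrative).
docs = [TaggedDocument(words=['machine', 'learning'], tags=['doc_0']),
        TaggedDocument(words=['deep', 'learning'], tags=['doc_1'])]
model = Doc2Vec(docs, size=50, min_count=1, iter=5)

# Write document vectors plus word vectors in plain-text word2vec format:
# a "total_vec vector_size" header, then one "tag v1 v2 ..." line per vector.
# Tags fall back to integer indices when no string tags were supplied
# (see the review comment above).
model.save_doc2vec_format('doc2vec_vectors.txt', word_vec=True, binary=False)

# The saved file can then be loaded like any word2vec-format file.
vectors = KeyedVectors.load_word2vec_format('doc2vec_vectors.txt', binary=False)
print(vectors['doc_0'])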