New feature: wordrank wrapper #1066
Merged

Commits (20 total, all by parulsethi; the diff below shows changes from 2 of them):
- 420e426 added wordrank wrapper
- d2f5607 update example
- c175851 add comparison ipynb
- 227546c update graph output
- 9b100e4 try diff interpreter
- 7df527d Merge branch 'develop' into wordrank_wrapper
- 8e9a6cb use unittest2
- 5c42762 use unittest2
- 17b2b75 add backport_collections for py2.6
- f569120 added tutorial ipynb
- 11ac3e8 remove subprocess32
- f75c9ff Merge branch 'develop' into wordrank_wrapper
- 7f541a2 made requested changes
- fb75890 replace with vocab for comparison
- 14d6f90 added conclusions
- 4b9271e added some comments
- 207dd8b changed test data loc and update check_output
- 09f4617 remove extra comment in check_output
- b3ecdd4 update check_output
- 4256252 update wordrank_wrapper's check_output call
Files changed (large diffs are not rendered):
@@ -33,7 +33,6 @@
 import random
 import tempfile
 import os
-import subprocess
 
 import numpy
 
@@ -1,3 +1,7 @@
+# Copyright (C) 2017 Parul Sethi <[email protected]>
+# Copyright (C) 2017 Radim Rehurek <[email protected]>
+# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
+
 """
 Python wrapper around word representation learning from Wordrank.
 The wrapped model can NOT be updated with new documents for online training -- use gensim's
@@ -30,10 +34,6 @@
 from smart_open import smart_open
 from shutil import copyfile, rmtree
 
-if sys.version_info[:2] == (2, 6):
-    from backport_collections import Counter
-else:
-    from collections import Counter
 
 logger = logging.getLogger(__name__)
 
@@ -48,15 +48,15 @@ class Wordrank(Word2Vec):
     @classmethod
     def train(cls, wr_path, corpus_file, out_path, size=100, window=15, symmetric=1, min_count=5, max_vocab_size=0,
             sgd_num=100, lrate=0.001, period=10, iter=91, epsilon=0.75, dump_period=10, reg=0, alpha=100,
-            beta=99, loss='hinge', memory=4.0, cleanup_files=True, sorted_vocab=1, ensemble=1):
+            beta=99, loss='hinge', memory=4.0, cleanup_files=True, sorted_vocab=1, ensemble=0):
         """
         `wr_path` is the path to the Wordrank directory.
         `corpus_file` is the filename of the text file to be used for training the Wordrank model.
         Expects file to contain space-separated tokens in a single line
         `out_path` is the path to directory which will be created to save embeddings and training data.
         `size` is the dimensionality of the feature vectors.
         `window` is the number of context words to the left (and to the right, if symmetric = 1).
-        symmetric` if 0, only use left context words, else use left and right both.
+        `symmetric` if 0, only use left context words, else use left and right both.
         `min_count` = ignore all words with total frequency lower than this.
         `max_vocab_size` upper bound on vocabulary size, i.e. keep the <int> most frequent words. Default is 0 for no limit.
         `sgd_num` number of SGD taken for each data point.
@@ -90,37 +90,30 @@ def train(cls, wr_path, corpus_file, out_path, size=100, window=15, symmetric=1,
         copyfile(corpus_file, os.path.join(meta_dir, corpus_file.split('/')[-1]))
         os.chdir(meta_dir)
 
-        cmd0 = ['../../glove/vocab_count', '-min-count', str(min_count), '-max-vocab', str(max_vocab_size)]
-        cmd1 = ['../../glove/cooccur', '-memory', str(memory), '-vocab-file', temp_vocab_file, '-window-size', str(window), '-symmetric', str(symmetric)]
-        cmd2 = ['../../glove/shuffle', '-memory', str(memory)]
-        cmd3 = ['cut', '-d', " ", '-f', '1', temp_vocab_file]
-        cmds = [cmd0, cmd1, cmd2, cmd3]
-        logger.info("Preparing training data using glove code '%s'", cmds)
-        o0 = smart_open(temp_vocab_file, 'w')
-        o1 = smart_open(cooccurrence_file, 'w')
-        o2 = smart_open(cooccurrence_shuf_file, 'w')
-        o3 = smart_open(vocab_file, 'w')
-        i0 = smart_open(corpus_file.split('/')[-1])
-        i1 = smart_open(corpus_file.split('/')[-1])
-        i2 = smart_open(cooccurrence_file)
-        i3 = None
-        outputs = [o0, o1, o2, o3]
-        inputs = [i0, i1, i2, i3]
-        prepare_train_data = [utils.check_output(cmd, stdin=inp, stdout=out) for cmd, inp, out in zip(cmds, inputs, outputs)]
-        o0.close()
-        o1.close()
-        o2.close()
-        o3.close()
-        i0.close()
-        i1.close()
-        i2.close()
+        cmd_vocab_count = ['../../glove/vocab_count', '-min-count', str(min_count), '-max-vocab', str(max_vocab_size)]
+        cmd_cooccurence_count = ['../../glove/cooccur', '-memory', str(memory), '-vocab-file', temp_vocab_file, '-window-size', str(window), '-symmetric', str(symmetric)]
+        cmd_shuffle_cooccurences = ['../../glove/shuffle', '-memory', str(memory)]
+        cmd_del_vocab_freq = ['cut', '-d', " ", '-f', '1', temp_vocab_file]
+
+        commands = [cmd_vocab_count, cmd_cooccurence_count, cmd_shuffle_cooccurences]
+        logger.info("Prepare training data using glove code '%s'", commands)
+        input_fnames = [corpus_file.split('/')[-1], corpus_file.split('/')[-1], cooccurrence_file]
+        output_fnames = [temp_vocab_file, cooccurrence_file, cooccurrence_shuf_file]
+
+        for command, input_fname, output_fname in zip(commands, input_fnames, output_fnames):
+            with smart_open(input_fname, 'rb') as r:
+                with smart_open(output_fname, 'wb') as w:
+                    utils.check_output(command, stdin=r, stdout=w)
+        with smart_open(vocab_file, 'wb') as w:
+            utils.check_output(cmd_del_vocab_freq, stdout=w)
 
-        with smart_open(vocab_file) as f:
+        with smart_open(vocab_file, 'rb') as f:
             numwords = sum(1 for line in f)
-        with smart_open(cooccurrence_shuf_file) as f:
+        with smart_open(cooccurrence_shuf_file, 'rb') as f:
             numlines = sum(1 for line in f)
-        with smart_open(meta_file, 'w') as f:
-            f.write("{0} {1}\n{2} {3}\n{4} {5}".format(numwords, numwords, numlines, cooccurrence_shuf_file, numwords, vocab_file))
+        with smart_open(meta_file, 'wb') as f:
+            meta_info = "{0} {1}\n{2} {3}\n{4} {5}".format(numwords, numwords, numlines, cooccurrence_shuf_file, numwords, vocab_file)
+            f.write(meta_info.encode('utf-8'))
 
         wr_args = {
             'path': 'meta',
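The refactor replaces the numbered handles (o0..o3, i0..i3) with a data-driven loop over (command, input, output) triples, and files are now closed deterministically by the with-blocks. The underlying pattern is plain stdin/stdout redirection between files; here is a standalone sketch using only the standard library plus smart_open. The command and file names are illustrative, and subprocess.check_call is used because the stock subprocess.check_output rejects a stdout argument (gensim's utils.check_output is a patched variant that accepts one):

    import subprocess
    from smart_open import smart_open

    # Illustrative pipeline: each external command reads one file on stdin
    # and writes the next file on stdout, mirroring the glove preprocessing steps.
    commands = [['sort'], ['uniq']]                # hypothetical commands
    input_fnames = ['corpus.txt', 'sorted.txt']    # hypothetical file names
    output_fnames = ['sorted.txt', 'deduped.txt']

    for command, input_fname, output_fname in zip(commands, input_fnames, output_fnames):
        with smart_open(input_fname, 'rb') as r:
            with smart_open(output_fname, 'wb') as w:
                subprocess.check_call(command, stdin=r, stdout=w)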
@@ -189,11 +182,11 @@ def sort_embeddings(self, vocab_file):
             self.wv.vocab[word].count = counts[word]
 
     def ensemble_embedding(self, word_embedding, context_embedding):
-        """Addition of two embeddings."""
+        """Replace syn0 with the sum of context and word embeddings."""
         glove2word2vec(context_embedding, context_embedding+'.w2vformat')
         w_emb = Word2Vec.load_word2vec_format('%s.w2vformat' % word_embedding)
         c_emb = Word2Vec.load_word2vec_format('%s.w2vformat' % context_embedding)
-        assert Counter(w_emb.wv.index2word) == Counter(c_emb.wv.index2word), 'Vocabs are not same for both embeddings'
+        assert set(w_emb.wv.index2word) == set(c_emb.wv.index2word), 'Vocabs are not same for both embeddings'
 
         prev_c_emb = copy.deepcopy(c_emb.wv.syn0)
         for word_id, word in enumerate(w_emb.wv.index2word):
Review comments:

Reviewer: Is it possible to compare wv.vocab?

parulsethi: Oh yes, similarly using set(wv.vocab). Correcting in next commit.
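A one-line version of the suggested comparison (model names are hypothetical): since wv.vocab is a dict keyed by word, set(wv.vocab) yields its key set, giving an order-insensitive vocabulary check:

    # Hypothetical models; set(some_dict) is the set of the dict's keys.
    assert set(model_a.wv.vocab) == set(model_b.wv.vocab), 'Vocabs are not same for both embeddings'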