Make word2vec2tensor script compatible with python3 (#2147)
* encoded strings to unicode

* added test scripts for word2vec2tensor

* added acknowledgement

* removed windows CRs

* changed FileNotFoundError to Exception to appease flake8

* addressed comments, added key-wise assert

* forgot to check flake8 again

* added decimal param to pass test
vsocrates authored and menshikh-iv committed Aug 14, 2018
1 parent f9beeaa commit 27c524d
Showing 2 changed files with 162 additions and 114 deletions.
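For context, word2vec2tensor converts a word2vec-format model into the two TSV files (vectors plus metadata) that the TensorFlow Embedding Projector loads. A minimal usage sketch of the converted function follows; the function name, keyword arguments, and output suffixes come from the diff and tests below, while the model path and output prefix are hypothetical placeholders:

from gensim.scripts.word2vec2tensor import word2vec2tensor

# Writes my_model_tensor.tsv (one tab-separated vector per line) and
# my_model_metadata.tsv (one token per line); both are written as UTF-8 bytes
# so the same code path works on Python 2 and Python 3.
word2vec2tensor(
    word2vec_model_path='my_word2vec.txt',  # hypothetical word2vec file in text format
    tensor_filename='my_model',             # hypothetical output prefix
    binary=False,                           # set True if the input model is in binary format
)

The module is normally run from the command line as python -m gensim.scripts.word2vec2tensor, via an argparse block that is not shown in the hunks below.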
13 changes: 7 additions & 6 deletions gensim/scripts/word2vec2tensor.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 #
+# Copyright (C) 2018 Vimig Socrates <[email protected]>
 # Copyright (C) 2016 Loreto Parisi <[email protected]>
 # Copyright (C) 2016 Silvio Olivastri <[email protected]>
 # Copyright (C) 2016 Radim Rehurek <[email protected]>
@@ -43,6 +44,7 @@
 import logging
 import argparse

+from smart_open import smart_open
 import gensim

 logger = logging.getLogger(__name__)
@@ -67,12 +69,11 @@ def word2vec2tensor(word2vec_model_path, tensor_filename, binary=False):
     outfiletsv = tensor_filename + '_tensor.tsv'
     outfiletsvmeta = tensor_filename + '_metadata.tsv'

-    with open(outfiletsv, 'w+') as file_vector:
-        with open(outfiletsvmeta, 'w+') as file_metadata:
-            for word in model.index2word:
-                file_metadata.write(gensim.utils.to_utf8(word) + gensim.utils.to_utf8('\n'))
-                vector_row = '\t'.join(str(x) for x in model[word])
-                file_vector.write(vector_row + '\n')
+    with smart_open(outfiletsv, 'wb') as file_vector, smart_open(outfiletsvmeta, 'wb') as file_metadata:
+        for word in model.index2word:
+            file_metadata.write(gensim.utils.to_utf8(word) + gensim.utils.to_utf8('\n'))
+            vector_row = '\t'.join(str(x) for x in model[word])
+            file_vector.write(gensim.utils.to_utf8(vector_row) + gensim.utils.to_utf8('\n'))

     logger.info("2D tensor file saved to %s", outfiletsv)
     logger.info("Tensor metadata file saved to %s", outfiletsvmeta)
263 changes: 155 additions & 108 deletions gensim/test/test_scripts.py
@@ -1,108 +1,155 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Copyright (C) 2018 Manos Stergiadis <[email protected]>
-# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
-
-"""
-Automated tests for checking the output of gensim.scripts.
-"""
-
-from __future__ import unicode_literals
-
-import json
-import logging
-import os.path
-import unittest
-
-from gensim.scripts.segment_wiki import segment_all_articles, segment_and_write_all_articles
-from smart_open import smart_open
-from gensim.test.utils import datapath, get_tmpfile
-
-
-class TestSegmentWiki(unittest.TestCase):
-
-    def setUp(self):
-        self.fname = datapath('enwiki-latest-pages-articles1.xml-p000000010p000030302-shortened.bz2')
-        self.expected_title = 'Anarchism'
-        self.expected_section_titles = [
-            'Introduction',
-            'Etymology and terminology',
-            'History',
-            'Anarchist schools of thought',
-            'Internal issues and debates',
-            'Topics of interest',
-            'Criticisms',
-            'References',
-            'Further reading',
-            'External links'
-        ]
-
-    def tearDown(self):
-        # remove all temporary test files
-        fname = get_tmpfile('script.tst')
-        extensions = ['', '.json']
-        for ext in extensions:
-            try:
-                os.remove(fname + ext)
-            except OSError:
-                pass
-
-    def test_segment_all_articles(self):
-        title, sections, interlinks = next(segment_all_articles(self.fname, include_interlinks=True))
-
-        # Check title
-        self.assertEqual(title, self.expected_title)
-
-        # Check section titles
-        section_titles = [s[0] for s in sections]
-        self.assertEqual(section_titles, self.expected_section_titles)
-
-        # Check text
-        first_section_text = sections[0][1]
-        first_sentence = "'''Anarchism''' is a political philosophy that advocates self-governed societies"
-        self.assertTrue(first_sentence in first_section_text)
-
-        # Check interlinks
-        self.assertTrue(interlinks['self-governance'] == 'self-governed')
-        self.assertTrue(interlinks['Hierarchy'] == 'hierarchical')
-        self.assertTrue(interlinks['Pierre-Joseph Proudhon'] == 'Proudhon')
-
-    def test_generator_len(self):
-        expected_num_articles = 106
-        num_articles = sum(1 for x in segment_all_articles(self.fname))
-
-        self.assertEqual(num_articles, expected_num_articles)
-
-    def test_json_len(self):
-        tmpf = get_tmpfile('script.tst.json')
-        segment_and_write_all_articles(self.fname, tmpf, workers=1)
-
-        expected_num_articles = 106
-        num_articles = sum(1 for line in smart_open(tmpf))
-        self.assertEqual(num_articles, expected_num_articles)
-
-    def test_segment_and_write_all_articles(self):
-        tmpf = get_tmpfile('script.tst.json')
-        segment_and_write_all_articles(self.fname, tmpf, workers=1, include_interlinks=True)
-
-        # Get the first line from the text file we created.
-        with open(tmpf) as f:
-            first = next(f)
-
-        # decode JSON line into a Python dictionary object
-        article = json.loads(first)
-        title, section_titles, interlinks = article['title'], article['section_titles'], article['interlinks']
-
-        self.assertEqual(title, self.expected_title)
-        self.assertEqual(section_titles, self.expected_section_titles)
-
-        # Check interlinks
-        self.assertTrue(interlinks['self-governance'] == 'self-governed')
-        self.assertTrue(interlinks['Hierarchy'] == 'hierarchical')
-        self.assertTrue(interlinks['Pierre-Joseph Proudhon'] == 'Proudhon')
-
-
-if __name__ == '__main__':
-    logging.basicConfig(level=logging.DEBUG)
-    unittest.main()
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2018 Vimig Socrates <[email protected]> heavily influenced by @AakaashRao
+# Copyright (C) 2018 Manos Stergiadis <[email protected]>
+# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
+
+"""
+Automated tests for checking the output of gensim.scripts.
+"""
+
+from __future__ import unicode_literals
+
+import json
+import logging
+import os.path
+import unittest
+
+from smart_open import smart_open
+import numpy as np
+
+from gensim.scripts.segment_wiki import segment_all_articles, segment_and_write_all_articles
+from gensim.test.utils import datapath, get_tmpfile
+
+from gensim.scripts.word2vec2tensor import word2vec2tensor
+from gensim.models import KeyedVectors
+
+
+class TestSegmentWiki(unittest.TestCase):
+
+    def setUp(self):
+        self.fname = datapath('enwiki-latest-pages-articles1.xml-p000000010p000030302-shortened.bz2')
+        self.expected_title = 'Anarchism'
+        self.expected_section_titles = [
+            'Introduction',
+            'Etymology and terminology',
+            'History',
+            'Anarchist schools of thought',
+            'Internal issues and debates',
+            'Topics of interest',
+            'Criticisms',
+            'References',
+            'Further reading',
+            'External links'
+        ]
+
+    def tearDown(self):
+        # remove all temporary test files
+        fname = get_tmpfile('script.tst')
+        extensions = ['', '.json']
+        for ext in extensions:
+            try:
+                os.remove(fname + ext)
+            except OSError:
+                pass
+
+    def test_segment_all_articles(self):
+        title, sections, interlinks = next(segment_all_articles(self.fname, include_interlinks=True))
+
+        # Check title
+        self.assertEqual(title, self.expected_title)
+
+        # Check section titles
+        section_titles = [s[0] for s in sections]
+        self.assertEqual(section_titles, self.expected_section_titles)
+
+        # Check text
+        first_section_text = sections[0][1]
+        first_sentence = "'''Anarchism''' is a political philosophy that advocates self-governed societies"
+        self.assertTrue(first_sentence in first_section_text)
+
+        # Check interlinks
+        self.assertTrue(interlinks['self-governance'] == 'self-governed')
+        self.assertTrue(interlinks['Hierarchy'] == 'hierarchical')
+        self.assertTrue(interlinks['Pierre-Joseph Proudhon'] == 'Proudhon')
+
+    def test_generator_len(self):
+        expected_num_articles = 106
+        num_articles = sum(1 for x in segment_all_articles(self.fname))
+
+        self.assertEqual(num_articles, expected_num_articles)
+
+    def test_json_len(self):
+        tmpf = get_tmpfile('script.tst.json')
+        segment_and_write_all_articles(self.fname, tmpf, workers=1)
+
+        expected_num_articles = 106
+        num_articles = sum(1 for line in smart_open(tmpf))
+        self.assertEqual(num_articles, expected_num_articles)
+
+    def test_segment_and_write_all_articles(self):
+        tmpf = get_tmpfile('script.tst.json')
+        segment_and_write_all_articles(self.fname, tmpf, workers=1, include_interlinks=True)
+
+        # Get the first line from the text file we created.
+        with open(tmpf) as f:
+            first = next(f)
+
+        # decode JSON line into a Python dictionary object
+        article = json.loads(first)
+        title, section_titles, interlinks = article['title'], article['section_titles'], article['interlinks']
+
+        self.assertEqual(title, self.expected_title)
+        self.assertEqual(section_titles, self.expected_section_titles)
+
+        # Check interlinks
+        self.assertTrue(interlinks['self-governance'] == 'self-governed')
+        self.assertTrue(interlinks['Hierarchy'] == 'hierarchical')
+        self.assertTrue(interlinks['Pierre-Joseph Proudhon'] == 'Proudhon')
+
+
+class TestWord2Vec2Tensor(unittest.TestCase):
+    def setUp(self):
+        self.datapath = datapath('word2vec_pre_kv_c')
+        self.output_folder = get_tmpfile('w2v2t_test')
+        self.metadata_file = self.output_folder + '_metadata.tsv'
+        self.tensor_file = self.output_folder + '_tensor.tsv'
+        self.vector_file = self.output_folder + '_vector.tsv'
+
+    def testConversion(self):
+        word2vec2tensor(word2vec_model_path=self.datapath, tensor_filename=self.output_folder)
+
+        with smart_open(self.metadata_file, 'rb') as f:
+            metadata = f.readlines()
+
+        with smart_open(self.tensor_file, 'rb') as f:
+            vectors = f.readlines()
+
+        # check if number of words and vector size in tensor file line up with word2vec
+        with smart_open(self.datapath, 'rb') as f:
+            first_line = f.readline().strip()
+
+        number_words, vector_size = map(int, first_line.split(b' '))
+        self.assertTrue(len(metadata) == len(vectors) == number_words,
+                        ('Metadata file %s and tensor file %s imply different number of rows.'
+                         % (self.metadata_file, self.tensor_file)))
+
+        # grab metadata and vectors from written file
+        metadata = [word.strip() for word in metadata]
+        vectors = [vector.replace(b'\t', b' ') for vector in vectors]
+
+        # get the original vector KV model
+        orig_model = KeyedVectors.load_word2vec_format(self.datapath, binary=False)
+
+        # check that the KV model and tensor files have the same values key-wise
+        for word, vector in zip(metadata, vectors):
+            word_string = word.decode("utf8")
+            vector_string = vector.decode("utf8")
+            vector_array = np.array(list(map(float, vector_string.split())))
+            np.testing.assert_almost_equal(orig_model[word_string], vector_array, decimal=5)
+
+
+if __name__ == '__main__':
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()
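The new TestWord2Vec2Tensor case reads both TSV files back as bytes and compares them key-wise against the original KeyedVectors model, using decimal=5 in np.testing.assert_almost_equal to tolerate the precision lost by round-tripping through text. For completeness, a sketch of consuming the same pair of files downstream (the output prefix is a hypothetical placeholder); the tensor file holds one tab-separated vector per line and the metadata file one token per line, in the same order:

import numpy as np

prefix = 'w2v2t_test'  # hypothetical prefix passed as tensor_filename

# One tab-separated float vector per line.
vectors = np.loadtxt(prefix + '_tensor.tsv', delimiter='\t')

# One UTF-8 token per line, aligned with the rows of vectors.
with open(prefix + '_metadata.tsv', encoding='utf-8') as f:
    words = [line.rstrip('\n') for line in f]

assert len(words) == vectors.shape[0]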
