Commit

Merge pull request #7 from chinmayapancholi13/word2vec_keras_integration
[WIP] Updating Gensim's Word2vec-Keras integration
stephenhky authored Jun 27, 2017
2 parents 189a57d + 6400793 commit 311d41e
Showing 2 changed files with 94 additions and 32 deletions.
57 changes: 47 additions & 10 deletions shorttext/classifiers/embed/nnlib/VarNNEmbedVecClassification.py
@@ -1,4 +1,6 @@
import numpy as np
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

import shorttext.utils.kerasmodel_io as kerasio
import shorttext.utils.classification_exceptions as e
@@ -61,7 +63,7 @@ class VarNNEmbeddedVecClassifier:
>>> classifier.score('artificial intelligence')
{'mathematics': 0.57749695, 'physics': 0.33749574, 'theology': 0.085007325}
"""
def __init__(self, wvmodel, vecsize=300, maxlen=15):
def __init__(self, wvmodel, vecsize=300, maxlen=15, with_gensim=True):
""" Initialize the classifier.
:param wvmodel: Word2Vec model
@@ -74,6 +76,7 @@ def __init__(self, wvmodel, vecsize=300, maxlen=15):
self.wvmodel = wvmodel
self.vecsize = vecsize
self.maxlen = maxlen
self.with_gensim = with_gensim
self.trained = False

def convert_trainingdata_matrix(self, classdict):
@@ -99,7 +102,13 @@ def convert_trainingdata_matrix(self, classdict):
category_bucket = [0]*len(classlabels)
category_bucket[lblidx_dict[label]] = 1
indices.append(category_bucket)
phrases.append(tokenize(shorttext))
if self.with_gensim == True:
phrases.append(shorttext)
else:
phrases.append(tokenize(shorttext))

if self.with_gensim == True:
return classlabels, phrases, indices

# store embedded vectors
train_embedvec = np.zeros(shape=(len(phrases), self.maxlen, self.vecsize))
@@ -110,7 +119,6 @@ def convert_trainingdata_matrix(self, classdict):

return classlabels, train_embedvec, indices


def train(self, classdict, kerasmodel, nb_epoch=10):
""" Train the classifier.
@@ -127,11 +135,23 @@ def train(self, classdict, kerasmodel, nb_epoch=10):
:type kerasmodel: keras.models.Sequential
:type nb_epoch: int
"""
# convert classdict to training input vectors
self.classlabels, train_embedvec, indices = self.convert_trainingdata_matrix(classdict)
if self.with_gensim:
# convert classdict to training input vectors
self.classlabels, x_train, y_train = self.convert_trainingdata_matrix(classdict)

tokenizer = Tokenizer()
tokenizer.fit_on_texts(x_train)
x_train = tokenizer.texts_to_sequences(x_train)
x_train = pad_sequences(x_train, maxlen=self.maxlen)

# train the model
kerasmodel.fit(train_embedvec, indices, epochs=nb_epoch)
# train the model
kerasmodel.fit(x_train, y_train, epochs=nb_epoch)
else:
# convert classdict to training input vectors
self.classlabels, train_embedvec, indices = self.convert_trainingdata_matrix(classdict)

# train the model
kerasmodel.fit(train_embedvec, indices, epochs=nb_epoch)

# flag switch
self.model = kerasmodel
@@ -210,6 +230,18 @@ def shorttext_to_matrix(self, shorttext):
matrix[i] = self.word_to_embedvec(tokens[i])
return matrix

def process_text(self, shorttext):
"""Process the input text by tokenizing and padding it.
:param shorttext: a short sentence
"""
tokenizer = Tokenizer()
tokenizer.fit_on_texts(shorttext)
x_train = tokenizer.texts_to_sequences(shorttext)

x_train = pad_sequences(x_train, maxlen=self.maxlen)
return x_train

def score(self, shorttext):
""" Calculate the scores for all the class labels for the given short sentence.
@@ -227,8 +259,12 @@ def score(self, shorttext):
if not self.trained:
raise e.ModelNotTrainedException()

# retrieve vector
matrix = np.array([self.shorttext_to_matrix(shorttext)])
if self.with_gensim == True:
# tokenize and pad input text
matrix = self.process_text(shorttext)
else:
# retrieve vector
matrix = np.array([self.shorttext_to_matrix(shorttext)])

# classification using the neural network
predictions = self.model.predict(matrix)
@@ -237,6 +273,7 @@
scoredict = {}
for idx, classlabel in zip(range(len(self.classlabels)), self.classlabels):
scoredict[classlabel] = predictions[0][idx]

return scoredict

def load_varnnlibvec_classifier(wvmodel, name, compact=True):
@@ -256,4 +293,4 @@ def load_varnnlibvec_classifier(wvmodel, name, compact=True):
classifier.load_compact_model(name)
else:
classifier.loadmodel(name)
return classifier
return classifier
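
For reference, a minimal end-to-end sketch of the new with_gensim=True path is shown below. The toy corpus, class dictionary, and parameter values are hypothetical, and the sketch assumes a gensim Word2Vec model that exposes get_embedding_layer(), which is what CNNWordEmbed in this branch calls internally.

from gensim.models import Word2Vec

from shorttext.classifiers.embed.nnlib.VarNNEmbedVecClassification import VarNNEmbeddedVecClassifier
from shorttext.classifiers.embed.nnlib.frameworks import CNNWordEmbed

# toy corpus standing in for a real pre-trained word2vec model (hypothetical)
sentences = [['linear', 'algebra'], ['quantum', 'mechanics'], ['divine', 'providence']]
wvmodel = Word2Vec(sentences, size=100, min_count=1)

# training data: class label -> list of short texts (hypothetical)
classdict = {'mathematics': ['linear algebra', 'group theory'],
             'physics': ['quantum mechanics', 'special relativity'],
             'theology': ['divine providence', 'redemption']}

# build the Keras model through the gensim-aware branch of CNNWordEmbed
kerasmodel = CNNWordEmbed(nb_labels=len(classdict), wvmodel=wvmodel,
                          vecsize=100, maxlen=15, with_gensim=True)

# train via the Tokenizer / pad_sequences pipeline added in this commit, then score
classifier = VarNNEmbeddedVecClassifier(wvmodel, vecsize=100, maxlen=15, with_gensim=True)
classifier.train(classdict, kerasmodel, nb_epoch=10)
print(classifier.score('artificial intelligence'))
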
69 changes: 47 additions & 22 deletions shorttext/classifiers/embed/nnlib/frameworks.py
@@ -1,6 +1,7 @@
from keras.layers import Conv1D, MaxPooling1D, Flatten, Dense, Dropout, LSTM
from keras.models import Sequential
from keras.models import Sequential, Model
from keras.regularizers import l2
from keras.engine import Input

# Codes were changed because of Keras.
# Keras 1 --> Keras 2: https://github.com/fchollet/keras/wiki/Keras-2.0-release-notes
@@ -9,6 +10,7 @@
# ref: https://gist.github.com/entron/b9bc61a74e7cadeb1fec
# ref: http://cs231n.github.io/convolutional-networks/
def CNNWordEmbed(nb_labels,
wvmodel=None,
nb_filters=1200,
n_gram=2,
maxlen=15,
@@ -17,14 +19,16 @@ def CNNWordEmbed(nb_labels,
final_activation='softmax',
dense_wl2reg=0.0,
dense_bl2reg=0.0,
optimizer='adam'):
optimizer='adam',
with_gensim=True):
""" Returns the convolutional neural network (CNN/ConvNet) for word-embedded vectors.
Reference: Yoon Kim, "Convolutional Neural Networks for Sentence Classification,"
*EMNLP* 2014, 1746-1751 (arXiv:1408.5882). [`arXiv
<https://arxiv.org/abs/1408.5882>`_]
:param nb_labels: number of class labels
:param wvmodel: pre-trained Gensim word2vec model
:param nb_filters: number of filters (Default: 1200)
:param n_gram: n-gram, or window size of CNN/ConvNet (Default: 2)
:param maxlen: maximum number of words in a sentence (Default: 15)
@@ -34,7 +38,8 @@ def CNNWordEmbed(nb_labels,
:param dense_wl2reg: L2 regularization coefficient (Default: 0.0)
:param dense_bl2reg: L2 regularization coefficient for bias (Default: 0.0)
:param optimizer: optimizer for gradient descent. Options: sgd, rmsprop, adagrad, adadelta, adam, adamax, nadam. (Default: adam)
:return: keras sequantial model for CNN/ConvNet for Word-Embeddings
:param with_gensim: boolean flag indicating whether the word embeddings being used are derived from a Gensim Word2Vec model. (Default: True)
:return: keras model (`Sequential` or `Model`) for CNN/ConvNet for Word-Embeddings
:type nb_labels: int
:type nb_filters: int
:type n_gram: int
@@ -45,24 +50,46 @@ def CNNWordEmbed(nb_labels,
:type dense_wl2reg: float
:type dense_bl2reg: float
:type optimizer: str
:rtype: keras.model.Sequential
:rtype: keras.models.Sequential or keras.models.Model
"""
model = Sequential()
model.add(Conv1D(filters=nb_filters,
kernel_size=n_gram,
padding='valid',
activation='relu',
input_shape=(maxlen, vecsize)))
if cnn_dropout > 0.0:
model.add(Dropout(cnn_dropout))
model.add(MaxPooling1D(pool_size=maxlen - n_gram + 1))
model.add(Flatten())
model.add(Dense(nb_labels,
activation=final_activation,
kernel_regularizer=l2(dense_wl2reg),
bias_regularizer=l2(dense_bl2reg))
)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)
if with_gensim == True:
embedding_layer = wvmodel.get_embedding_layer()
sequence_input = Input(shape=(maxlen,),
dtype='int32')
x = embedding_layer(sequence_input)
x = Conv1D(filters=nb_filters,
kernel_size=n_gram,
padding='valid',
activation='relu',
input_shape=(maxlen, vecsize))(x)
if cnn_dropout > 0.0:
x = Dropout(cnn_dropout)(x)
x = MaxPooling1D(pool_size=maxlen - n_gram + 1)(x)
x = Flatten()(x)
x = Dense(nb_labels,
activation=final_activation,
kernel_regularizer=l2(dense_wl2reg),
bias_regularizer=l2(dense_bl2reg))(x)

model = Model(sequence_input, x)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)
else:
model = Sequential()
model.add(Conv1D(filters=nb_filters,
kernel_size=n_gram,
padding='valid',
activation='relu',
input_shape=(maxlen, vecsize)))
if cnn_dropout > 0.0:
model.add(Dropout(cnn_dropout))
model.add(MaxPooling1D(pool_size=maxlen - n_gram + 1))
model.add(Flatten())
model.add(Dense(nb_labels,
activation=final_activation,
kernel_regularizer=l2(dense_wl2reg),
bias_regularizer=l2(dense_bl2reg))
)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)

return model

@@ -207,5 +234,3 @@ def CLSTMWordEmbed(nb_labels,
model.compile(loss='categorical_crossentropy', optimizer=optimizer)

return model


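The gensim branch of CNNWordEmbed above switches from Sequential to the functional Model API because the network now starts from an Input of padded token indices that is run through an embedding layer, rather than from pre-computed embedding matrices. A rough functional-API equivalent, written as a hypothetical helper that takes the embedding weight matrix directly instead of calling wvmodel.get_embedding_layer(), could look like this:

from keras.layers import Conv1D, Dense, Dropout, Embedding, Flatten, Input, MaxPooling1D
from keras.models import Model

def cnn_wordembed_sketch(nb_labels, embedding_weights, maxlen=15, nb_filters=1200, n_gram=2,
                         cnn_dropout=0.0, final_activation='softmax', optimizer='adam'):
    # embedding_weights: (vocab_size, vecsize) array whose row order must match
    # the word-to-index mapping used when the texts are converted to sequences
    vocab_size, vecsize = embedding_weights.shape
    embedding_layer = Embedding(vocab_size, vecsize, weights=[embedding_weights],
                                input_length=maxlen, trainable=False)

    sequence_input = Input(shape=(maxlen,), dtype='int32')    # padded token indices
    x = embedding_layer(sequence_input)                       # (maxlen, vecsize) per sample
    x = Conv1D(filters=nb_filters, kernel_size=n_gram, padding='valid', activation='relu')(x)
    if cnn_dropout > 0.0:
        x = Dropout(cnn_dropout)(x)
    x = MaxPooling1D(pool_size=maxlen - n_gram + 1)(x)
    x = Flatten()(x)
    x = Dense(nb_labels, activation=final_activation)(x)

    model = Model(sequence_input, x)
    model.compile(loss='categorical_crossentropy', optimizer=optimizer)
    return model

Whichever embedding layer is used, its row order has to agree with the word-to-index mapping of the Tokenizer fitted in train(), since those padded indices are what the embedding layer looks up.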