Remove bias in augmented loss + fix perplexity #3

Open. Wants to merge 1 commit into master.
model/augmented_model.py (9 additions, 21 deletions)
@@ -1,11 +1,12 @@
import numpy as np
from keras.models import Sequential
from keras.layers import Embedding, Dense, TimeDistributed, LSTM, Activation, Dropout, Lambda
from keras import backend as K
-import tensorflow as tf
from keras.losses import kullback_leibler_divergence
from model.lang_model_sgd import LangModelSGD
from model.one_hot_model import OneHotModel
+from model.utils import Bias
+from model.utils import perplexity
+from model.utils import log_perplexity


class AugmentedModel(OneHotModel):
@@ -27,42 +28,29 @@ def __init__(self,
self.model.pop() # remove activation
self.model.pop() # remove projection (use self embedding)
self.model.add(Lambda(lambda x: K.dot(x, K.transpose(self.embedding.embeddings))))
+self.model.add(Bias())
self.model.add(Activation("softmax"))

def augmented_loss(self, y_true, y_pred):
-_y_pred = Activation("softmax")(y_pred)
-loss = K.categorical_crossentropy(_y_pred, y_true)
+loss = K.categorical_crossentropy(y_true, y_pred)

# y is (batch x seq x vocab)
y_indexes = K.argmax(y_true, axis=2) # turn one hot to index. (batch x seq)
y_vectors = self.embedding(y_indexes) # lookup the vector (batch x seq x vector_length)

-#v_length = self.setting.vector_length
-#y_vectors = K.reshape(y_vectors, (-1, v_length))
-#y_t = K.map_fn(lambda v: K.dot(self.embedding.embeddings, K.reshape(v, (-1, 1))), y_vectors)
-#y_t = K.squeeze(y_t, axis=2) # unknown but necessary operation
-#y_t = K.reshape(y_t, (-1, self.sequence_size, self.vocab_size))

# vector x embedding dot products (batch x seq x vocab)
-y_t = tf.tensordot(y_vectors, K.transpose(self.embedding.embeddings), 1)
-y_t = K.reshape(y_t, (-1, self.sequence_size, self.vocab_size)) # explicitly set shape
+y_t = K.dot(y_vectors, K.transpose(self.embedding.embeddings))
y_t = K.softmax(y_t / self.temperature)
-_y_pred_t = Activation("softmax")(y_pred / self.temperature)
-aug_loss = kullback_leibler_divergence(y_t, _y_pred_t)
+y_pred_t = K.softmax((K.log(y_pred) - self.model.layers[-2].bias) / self.temperature)
+aug_loss = kullback_leibler_divergence(y_t, y_pred_t)
loss += (self.gamma * self.temperature) * aug_loss
return loss

-@classmethod
-def perplexity(cls, y_true, y_pred):
-    _y_pred = Activation("softmax")(y_pred)
-    return super(AugmentedModel, cls).perplexity(y_true, _y_pred)

def compile(self):
-self.model.pop() # remove activation (to calculate aug loss)
self.model.compile(
loss=self.augmented_loss,
optimizer=LangModelSGD(self.setting),
metrics=["accuracy", self.perplexity]
metrics=["accuracy", log_perplexity, perplexity]
)

def get_name(self):
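For readers following the loss change: the revised `augmented_loss` assumes the model now ends in `Bias()` plus `Activation("softmax")`, so `y_pred` arrives as probabilities rather than logits. Below is a rough NumPy sketch of what the new code computes, purely for illustration; the names `E`, `bias`, `gamma` and `temperature` stand in for `self.embedding.embeddings`, the `Bias` layer's weight, `self.gamma` and `self.temperature`, and the shapes and values are made up:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# toy shapes: batch=2, seq=3, vocab=5, vector_length=4
batch, seq, vocab, dim = 2, 3, 5, 4
temperature, gamma = 10.0, 0.5

rng = np.random.RandomState(0)
E = rng.randn(vocab, dim)                     # tied embedding matrix (vocab x dim)
bias = rng.randn(vocab)                       # the Bias layer's weight
y_true = np.eye(vocab)[rng.randint(vocab, size=(batch, seq))]  # one-hot targets
logits = rng.randn(batch, seq, vocab)         # stand-in for x . E^T + bias
y_pred = softmax(logits)                      # what the softmax output layer produces

# data term: ordinary categorical cross-entropy on the softmax output
loss = -(y_true * np.log(y_pred)).sum(axis=-1)

# augmented term: target distribution built from embedding similarities
y_vectors = y_true @ E                                  # look up the target word vectors
y_t = softmax((y_vectors @ E.T) / temperature)          # (batch x seq x vocab)

# prediction re-softened at the same temperature, with the output bias removed
y_pred_t = softmax((np.log(y_pred) - bias) / temperature)

kl = (y_t * (np.log(y_t) - np.log(y_pred_t))).sum(axis=-1)
loss = loss + gamma * temperature * kl                  # per-position loss, (batch x seq)
```

Subtracting the bias from `K.log(y_pred)` recovers the tied logits up to a per-position constant, which cancels inside the softmax, so the temperature-scaled prediction has the same form as the embedding-similarity target `y_t`.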
model/lang_model_sgd.py (0 additions, 3 deletions)
@@ -1,6 +1,3 @@
-import copy
-import numpy as np
-import tensorflow as tf
from keras import backend as K
from keras.optimizers import Optimizer
from keras.callbacks import LearningRateScheduler
model/one_hot_model.py (6 additions, 10 deletions)
@@ -1,13 +1,15 @@
import os
import numpy as np
-import tensorflow as tf
from keras.models import Sequential
from keras.layers import Embedding, Dense, TimeDistributed, LSTM, Activation, Dropout
from keras import losses
from keras import backend as K
from keras.callbacks import ModelCheckpoint, TensorBoard
from model.lang_model_sgd import LangModelSGD
from model.setting import Setting
+from model.utils import PerplexityLogger
+from model.utils import perplexity
+from model.utils import log_perplexity


class OneHotModel():
@@ -31,7 +33,7 @@ def __init__(self,
self.embedding = Embedding(self.vocab_size, vector_length, input_length=sequence_size)
layer1 = LSTM(vector_length, return_sequences=True, dropout=dropout, recurrent_dropout=dropout)
layer2 = LSTM(vector_length, return_sequences=True, dropout=dropout, recurrent_dropout=dropout)
-projection = TimeDistributed(Dense(self.vocab_size))
+projection = Dense(self.vocab_size)
self.model = Sequential()
self.model.add(self.embedding)
self.model.add(layer1)
@@ -43,14 +45,8 @@ def compile(self):
self.model.compile(
loss=losses.categorical_crossentropy,
optimizer=LangModelSGD(self.setting),
metrics=["accuracy", self.perplexity]
metrics=["accuracy", log_perplexity, perplexity]
)

-@classmethod
-def perplexity(cls, y_true, y_pred):
-    cross_entropy = K.mean(K.categorical_crossentropy(y_pred, y_true), axis=-1)
-    perplexity = K.exp(cross_entropy)
-    return perplexity

def fit(self, x_train, y_train, x_test, y_test, batch_size=20, epochs=20):
self.model.fit(
@@ -72,7 +68,7 @@ def fit_generator(self, generator, steps_per_epoch, test_generator, test_steps_p
)

def _get_callbacks(self):
-callbacks = [self.model.optimizer.get_lr_scheduler()]
+callbacks = [PerplexityLogger(), self.model.optimizer.get_lr_scheduler()]
folder_name = self.get_name()
self_path = os.path.join(self.checkpoint_path, folder_name)
if self.checkpoint_path:
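A side note on the projection change: dropping the `TimeDistributed` wrapper appears to rely on Keras 2 applying `Dense` along the last axis of a 3-D tensor, so the projection is still made once per timestep. A minimal shape check under that assumption (the layer sizes are arbitrary):

```python
from keras.models import Sequential
from keras.layers import LSTM, Dense

vocab_size, sequence_size, vector_length = 10, 7, 8

m = Sequential()
m.add(LSTM(vector_length, return_sequences=True,
           input_shape=(sequence_size, vector_length)))
m.add(Dense(vocab_size))   # applied per timestep, same effect as TimeDistributed(Dense(vocab_size))
print(m.output_shape)      # (None, 7, 10): one vocab-sized projection per position
```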
model/utils.py (76 additions, 0 deletions)
@@ -0,0 +1,76 @@
import numpy as np
from keras import backend as K
from keras.engine import InputSpec
from keras.engine import Layer
from keras import constraints
from keras import initializers
from keras import regularizers
from keras.callbacks import Callback


class Bias(Layer):
    def __init__(self, bias_initializer='zeros',
                 bias_regularizer=None,
                 bias_constraint=None,
                 **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(Bias, self).__init__(**kwargs)
        self.bias_initializer = initializers.get(bias_initializer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

    def build(self, input_shape):
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]

        self.bias = self.add_weight(shape=(input_dim,),
                                    initializer=self.bias_initializer,
                                    name='bias',
                                    regularizer=self.bias_regularizer,
                                    constraint=self.bias_constraint)

        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def call(self, inputs):
        return K.bias_add(inputs, self.bias)

    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        return tuple(output_shape)

    def get_config(self):
        config = {
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(Bias, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def log_perplexity(y_true, y_pred):
    cross_entropy = K.mean(K.categorical_crossentropy(y_true, y_pred), axis=-1)
    return cross_entropy


def perplexity(y_true, y_pred):
    # will be calculated by perplexity logger
    return K.mean(K.zeros_like(y_pred), axis=-1)


class PerplexityLogger(Callback):
    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        logs['perplexity'] = np.exp(logs['log_perplexity'])

    def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            logs['perplexity'] = np.exp(logs['log_perplexity'])
            logs['val_perplexity'] = np.exp(logs['val_log_perplexity'])
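The `log_perplexity` metric plus the `PerplexityLogger` callback addresses how Keras aggregates metrics: values are averaged over batches, so exponentiating inside the metric (as the removed `perplexity` classmethod did, which also passed `y_pred` and `y_true` in the pre-Keras-2 argument order) reports the mean of per-batch exponentials rather than the exponential of the mean cross-entropy. A toy illustration with made-up per-batch values:

```python
import numpy as np

batch_cross_entropies = np.array([2.0, 6.0, 4.0])   # mean cross-entropy per batch (made up)

# old style: exp inside the metric, then Keras averages the already-exponentiated values
old_report = np.mean(np.exp(batch_cross_entropies))   # ~155.1

# new style: average the log-perplexities, exponentiate once in the callback
new_report = np.exp(np.mean(batch_cross_entropies))   # ~54.6 == exp(4.0)

print(old_report, new_report)
```

By Jensen's inequality the old style can only overstate perplexity. The zero-valued `perplexity` metric left in `compile` presumably just reserves the key in `logs` so the callback has a place to write the corrected number.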
train.py (1 addition, 1 deletion)
@@ -51,7 +51,7 @@ def train_augmented(network_size, dataset_kind, tying=False, epochs=40, stride=0
train_steps, train_generator = dp.make_batch_iter(dataset, sequence_size=sequence_size, stride=stride)
valid_steps, valid_generator = dp.make_batch_iter(dataset, kind="valid", sequence_size=sequence_size, stride=stride)

-# make one hot model
+# make augmented model
model = AugmentedModel(vocab_size, sequence_size, setting, tying=tying, checkpoint_path=LOG_ROOT)
model.compile()
model.fit_generator(train_generator, train_steps, valid_generator, valid_steps, epochs=epochs)