diff --git a/neuralmonkey/decoders/decoder.py b/neuralmonkey/decoders/decoder.py
index e8927cc4a..15a07a26e 100644
--- a/neuralmonkey/decoders/decoder.py
+++ b/neuralmonkey/decoders/decoder.py
@@ -245,12 +245,13 @@ def _get_rnn_cell(self):
         return OrthoGRUCell(self.rnn_size)
 
-    def _collect_attention_objects(self):
+    def _collect_attention_objects(self, runtime_mode):
         """Collect attention objects from encoders."""
         if not self.use_attention:
             return []
-        return [e.attention_object for e in self.encoders if e.attention_object]
+        return [e.get_attention_object(runtime_mode)
+                for e in self.encoders]
 
     def _embed_inputs(self, inputs):
         """Embed inputs using the decoder"s word embedding matrix
@@ -323,7 +324,7 @@ def _attention_decoder(self, inputs, initial_state, runtime_mode=False,
             scope: The variable scope to use with this function.
         """
         cell = self._get_rnn_cell()
-        att_objects = self._collect_attention_objects()
+        att_objects = self._collect_attention_objects(runtime_mode)
 
         ## Broadcast the initial state to the whole batch if needed
         if len(initial_state.get_shape()) == 1:
@@ -366,9 +367,8 @@ def _attention_decoder(self, inputs, initial_state, runtime_mode=False,
 
         if runtime_mode:
             for i, a in enumerate(att_objects):
-                attentions = a.attentions_in_time[-len(inputs):]
                 alignments = tf.expand_dims(tf.transpose(
-                    tf.pack(attentions), perm=[1, 2, 0]), -1)
+                    tf.pack(a.attentions_in_time), perm=[1, 2, 0]), -1)
 
                 tf.image_summary("attention_{}".format(i), alignments,
                                  collections=["summary_val_plots"],
diff --git a/neuralmonkey/decoding_function.py b/neuralmonkey/decoding_function.py
index 0e7ec10a5..4fb4b7f10 100644
--- a/neuralmonkey/decoding_function.py
+++ b/neuralmonkey/decoding_function.py
@@ -14,8 +14,9 @@ class Attention(object):
     # pylint: disable=unused-argument,too-many-instance-attributes,too-many-arguments
     # For maintaining the same API as in CoverageAttention
-    def __init__(self, attention_states, scope, input_weights=None,
-                 max_fertility=None):
+    def __init__(self, attention_states, scope,
+                 input_weights=None, attention_fertility=None,
+                 runtime_mode=False):
         """Create the attention object.
 
         Args:
@@ -24,8 +25,10 @@ def __init__(self, attention_states, scope, input_weights=None,
             scope: The name of the variable scope in the graph used by this
                 attention object.
             input_weights: (Optional) The padding weights on the input.
-            max_fertility: (Optional) For the Coverage attention compatibilty,
-                maximum fertility of one word.
+            attention_fertility: (Optional) For the Coverage attention
+                compatibility, maximum fertility of one word.
+            runtime_mode: (Optional) Indicates whether the object will be used
+                for runtime decoding.
""" self.scope = scope self.attentions_in_time = [] @@ -107,19 +110,20 @@ class CoverageAttention(Attention): # pylint: disable=too-many-arguments # Great objects require great number of parameters def __init__(self, attention_states, scope, - input_weights=None, max_fertility=5): + input_weights=None, attention_fertility=5): - super(CoverageAttention, self).__init__(attention_states, scope, - input_weights=input_weights, - max_fertility=max_fertility) + super(CoverageAttention, self).__init__( + attention_states, scope, + input_weights=input_weights, + attention_fertility=attention_fertility) self.coverage_weights = tf.get_variable("coverage_matrix", [1, 1, 1, self.attn_size]) self.fertility_weights = tf.get_variable("fertility_matrix", [1, 1, self.attn_size]) - self.max_fertility = max_fertility + self.attention_fertility = attention_fertility - self.fertility = 1e-8 + self.max_fertility * tf.sigmoid( + self.fertility = 1e-8 + self.attention_fertility * tf.sigmoid( tf.reduce_sum(self.fertility_weights * self.attention_states, [2])) def get_logits(self, y): diff --git a/neuralmonkey/encoders/attentive.py b/neuralmonkey/encoders/attentive.py new file mode 100644 index 000000000..4c8739e37 --- /dev/null +++ b/neuralmonkey/encoders/attentive.py @@ -0,0 +1,41 @@ +from abc import ABCMeta, abstractproperty +import tensorflow as tf + +# pylint: disable=too-few-public-methods +class Attentive(metaclass=ABCMeta): + """A base class fro an attentive part of graph (typically encoder). + + Objects inheriting this class are able to generate an attention object that + allows a decoder to perform attention over an attention_object provided by + the encoder (e.g., input word representations in case of MT or + convolutional maps in case of image captioning). + """ + def __init__(self, attention_type, **kwargs): + self._attention_type = attention_type + self._attention_kwargs = kwargs + + def get_attention_object(self, runtime: bool=False): + """Attention object that can be used in decoder.""" + # pylint: disable=no-member + if hasattr(self, "name") and self.name: + name = self.name + else: + name = str(self) + + return self._attention_type( + self._attention_tensor, + scope="attention_{}".format(name), + input_weights=self._attention_mask, + runtime_mode=runtime, + **self._attention_kwargs) if self._attention_type else None + + @abstractproperty + def _attention_tensor(self): + """Tensor over which the attention is done.""" + raise NotImplementedError( + "Attentive object is missing attention_tensor.") + + @property + def _attention_mask(self): + """Zero/one masking the attention logits.""" + return tf.ones(tf.shape(self._attention_tensor)) diff --git a/neuralmonkey/encoders/cnn_encoder.py b/neuralmonkey/encoders/cnn_encoder.py index 598a57be7..178c87063 100644 --- a/neuralmonkey/encoders/cnn_encoder.py +++ b/neuralmonkey/encoders/cnn_encoder.py @@ -5,12 +5,13 @@ import numpy as np import tensorflow as tf +from neuralmonkey.encoders.attentive import Attentive from neuralmonkey.decoding_function import Attention # tests: lint, mypy # pylint: disable=too-many-instance-attributes, too-few-public-methods -class CNNEncoder(object): +class CNNEncoder(Attentive): """ An image encoder. 
     It projects the input image through a serie of
@@ -45,7 +46,8 @@ def __init__(self, data_id, convolutions, rnn_layers,
                  bidirectional=True,
                  batch_normalization=True,
                  local_response_normalization=True,
-                 dropout_keep_prob=0.5):
+                 dropout_keep_prob=0.5,
+                 attention_type=Attention):
         """
         Initilizes and configures the computational graph creator.
@@ -85,6 +87,7 @@ def __init__(self, data_id, convolutions, rnn_layers,
             dropout keeping probability
         """
+        super().__init__(attention_type)
 
         self.convolutions = convolutions
         self.data_id = data_id
@@ -123,15 +126,16 @@ def __init__(self, data_id, convolutions, rnn_layers,
         self.image_processing_layers = []
 
         with tf.variable_scope("convolutions"):
-            for i, (filter_size, n_filters, pool_size) \
-                    in enumerate(convolutions):
+            for i, (filter_size,
+                    n_filters,
+                    pool_size) in enumerate(convolutions):
                 with tf.variable_scope("cnn_layer_{}".format(i)):
                     conv_w = tf.get_variable(
                         "wieghts",
                         shape=[filter_size, filter_size,
                                last_n_channels, n_filters],
-                        initializer= \
-                            tf.truncated_normal_initializer(stddev=.1))
+                        initializer=tf.truncated_normal_initializer(
+                            stddev=.1))
                     conv_b = tf.get_variable(
                         "biases",
                         shape=[n_filters],
@@ -169,11 +173,9 @@ def __init__(self, data_id, convolutions, rnn_layers,
         last_layer_size = last_n_channels * image_height * image_width
 
         with tf.variable_scope("rnn_inputs"):
-            encoder_ins = [tf.reshape(x,
-                                      [-1, last_n_channels * image_height])
-                           for x in tf.split(2, image_width,
-                                             last_layer,
-                                             name='split_input')]
+            encoder_ins = [
+                tf.reshape(x, [-1, last_n_channels * image_height]) for x in
+                tf.split(2, image_width, last_layer, name='split_input')]
 
         def rnn_encoder(inputs, last_layer_size, scope):
             with tf.variable_scope(scope):
@@ -202,28 +204,30 @@ def rnn_encoder(inputs, last_layer_size, scope):
         encoder_state = rnn_encoder(
             encoder_ins, last_layer_size, "encoder-forward")
 
+        # pylint: disable=redefined-variable-type
         if bidirectional:
             backward_encoder_state = rnn_encoder(
                 list(reversed(encoder_ins)),
-                last_layer_size,
-                "encoder-backward")
-            # pylint: disable=redefined-variable-type
+                last_layer_size, "encoder-backward")
             encoder_state = tf.concat(
                 1, [encoder_state, backward_encoder_state])
 
         self.encoded = encoder_state
 
-        self.attention_tensor = \
-            tf.reshape(last_layer, [-1, image_width,
-                                    last_n_channels * image_height])
+        self.__attention_tensor = tf.reshape(
+            last_layer, [-1, image_width,
+                         last_n_channels * image_height])
 
-        att_in_weights = tf.squeeze(
+        self.__attention_mask = tf.squeeze(
             tf.reduce_prod(last_padding_masks, [1]), [2])
 
-        self.attention_object = Attention(self.attention_tensor,
-                                          scope="attention_{}".format(
-                                              name),
-                                          input_weights=att_in_weights)
+    @property
+    def _attention_tensor(self):
+        return self.__attention_tensor
+
+    @property
+    def _attention_mask(self):
+        return self.__attention_mask
 
     def feed_dict(self, dataset, train=False):
         # if it is from the pickled file, it is list, not numpy tensor,
@@ -247,7 +251,9 @@ def feed_dict(self, dataset, train=False):
 # pylint: disable=too-many-locals
 def batch_norm(tensor, n_out, phase_train, scope='bn', scale_after_norm=True):
     """
-    Batch normalization on convolutional maps. Taken from
+    Batch normalization on convolutional maps.
+
+    Taken from
     http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
 
     Arguments:
@@ -282,8 +288,6 @@ def mean_var_with_update():
                             mean_var_with_update,
                             lambda: (ema_mean, ema_var))
 
-    normed = \
-        tf.nn.batch_norm_with_global_normalization(tensor, mean, var,
-                                                   beta, gamma, 1e-3,
-                                                   scale_after_norm)
+    normed = tf.nn.batch_norm_with_global_normalization(
+        tensor, mean, var, beta, gamma, 1e-3, scale_after_norm)
     return normed
diff --git a/neuralmonkey/encoders/factored_encoder.py b/neuralmonkey/encoders/factored_encoder.py
index 802276077..131577ea2 100644
--- a/neuralmonkey/encoders/factored_encoder.py
+++ b/neuralmonkey/encoders/factored_encoder.py
@@ -1,6 +1,7 @@
 import tensorflow as tf
 import numpy as np
 
+from neuralmonkey.encoders.attentive import Attentive
 from neuralmonkey.logging import log
 from neuralmonkey.nn.bidirectional_rnn_layer import BidirectionalRNNLayer
 from neuralmonkey.nn.noisy_gru_cell import NoisyGRUCell
@@ -10,8 +11,8 @@
 
 # tests: lint, mypy
 
-# pylint: disable=too-many-instance-attributes, too-few-public-methods
-class FactoredEncoder(object):
+# pylint: disable=too-many-instance-attributes
+class FactoredEncoder(Attentive):
     """Implementation of a generic encoder that processes an arbitrary
     number of input sequences.
     """
@@ -43,6 +44,10 @@ def __init__(self, max_input_len, vocabularies, data_ids, embedding_sizes,
             name: The name for this encoder. [sentence_encoder]
             dropout_keep_prob: 1 - Dropout probability [1]
         """
+        attention_type = kwargs.get("attention_type", None)
+        attention_fertility = kwargs.get("attention_fertility", 3)
+        super().__init__(
+            attention_type, attention_fertility=attention_fertility)
 
         for vocabulary in vocabularies:
             assert_type(self, 'vocabulary', vocabulary, Vocabulary)
@@ -59,27 +64,22 @@ def __init__(self, max_input_len, vocabularies, data_ids, embedding_sizes,
         self.use_noisy_activations = kwargs.get("use_noisy_activations", False)
         self.use_pervasive_dropout = kwargs.get("use_pervasive_dropout", False)
-        attention_type = kwargs.get("attention_type", None)
-        attention_fertility = kwargs.get("attention_fertility", 3)
 
         log("Building encoder graph, name: '{}'.".format(self.name))
         with tf.variable_scope(self.name):
             self._create_encoder_graph()
 
-            ## Attention mechanism
-            if attention_type is not None:
-                weight_tensor = tf.concat(
-                    1, [tf.expand_dims(w, 1) for w in self.padding_weights])
+            # Attention mechanism
 
-                self.attention_object = attention_type(
-                    self.attention_tensor,
-                    scope="attention_{}".format(self.name),
-                    dropout_placeholder=self.dropout_placeholder,
-                    input_weights=weight_tensor,
-                    max_fertility=attention_fertility)
+            log("Encoder graph constructed.")
 
-            log("Encoder graph constructed.")
+    @property
+    def _attention_mask(self):
+        return self.__attention_mask
 
+    @property
+    def _attention_tensor(self):
+        return self.__attention_tensor
 
     def _get_rnn_cell(self):
         """Return the RNN cell for the encoder"""
@@ -93,7 +93,7 @@ def _get_rnn_cell(self):
             #pylint: disable=no-member, undefined-variable
             # TODO fix this
             shape = tf.concat(0, [tf.shape(self.inputs[0]), [rnn_size]])
-            ## TODO shape needs recomputing
+            # TODO shape needs recomputing
             dropout_mask = tf.floor(tf.random_uniform(shape, 0.0, 1.0)
                                     + self.dropout_placeholder)
@@ -103,7 +103,6 @@ def _get_rnn_cell(self):
 
         return cell
 
-
     def _get_birnn_cells(self):
         """Return forward and backward RNN cells for the encoder"""
         forward = self._get_rnn_cell()
@@ -111,7 +110,6 @@ def _get_birnn_cells(self):
 
         return forward, backward
 
-
     # pylint: disable=too-many-locals
     def _create_encoder_graph(self):
         self.dropout_placeholder = tf.placeholder(tf.float32, name="dropout")
@@ -128,8 +126,8 @@ def _create_encoder_graph(self):
         for data_id, vocabulary, embedding_size in zip(
                 self.data_ids, self.vocabularies, self.embedding_sizes):
-            ## Create data placehoders. The tensors' length is max_input_len+2
-            ## because we add explicit start and end symbols.
+            # Create data placeholders. The tensors' length is max_input_len+2
+            # because we add explicit start and end symbols.
             prefix = ""
             if len(self.data_ids) > 1:
                 prefix = "{}_".format(data_id)
@@ -140,8 +138,8 @@ def _create_encoder_graph(self):
             inputs = [tf.placeholder(tf.int32, shape=[None], name=n)
                       for n in names]
 
-            ## Create embeddings for this factor and embed the placeholders
-            ## NOTE the initialization
+            # Create embeddings for this factor and embed the placeholders
+            # NOTE the initialization
             embeddings = tf.get_variable(
                 "word_embeddings", shape=[len(vocabulary), embedding_size],
                 initializer=tf.random_normal_initializer(stddev=0.01))
@@ -153,18 +151,18 @@ def _create_encoder_graph(self):
                 tf.nn.dropout(i, self.dropout_placeholder)
                 for i in embedded_inputs]
 
-            ## Resulting shape is batch x embedding_size
+            # Resulting shape is batch x embedding_size
             factors.append(dropped_embedded_inputs)
 
-            ## Add inputs and weights to self to be able to feed them
+            # Add inputs and weights to self to be able to feed them
            self.factor_inputs[data_id] = inputs
 
-        ## Concatenate all embedded factors into one tensor
-        ## Resulting shape is batch x sum(embedding_size)
+        # Concatenate all embedded factors into one tensor
+        # Resulting shape is batch x sum(embedding_size)
 
-        ## factors is a 2D list of embeddings of dims [factor-type, time-step]
-        ## by doing zip(*factors), we get a list of (factor-type) embedding
-        ## tuples indexed by the time step
+        # factors is a 2D list of embeddings of dims [factor-type, time-step]
+        # by doing zip(*factors), we get a list of (factor-type) embedding
+        # tuples indexed by the time step
         concatenated_factors = [tf.concat(1, related_factors)
                                 for related_factors in zip(*factors)]
         forward_gru, backward_gru = self._get_birnn_cells()
@@ -176,18 +174,20 @@ def _create_encoder_graph(self):
         self.outputs_bidi = bidi_layer.outputs_bidi
         self.encoded = bidi_layer.encoded
 
-        self.attention_tensor = tf.concat(1, [tf.expand_dims(o, 1)
-                                              for o in self.outputs_bidi])
-        self.attention_tensor = tf.nn.dropout(self.attention_tensor,
-                                              self.dropout_placeholder)
+        self.__attention_tensor = tf.concat(1, [tf.expand_dims(o, 1)
+                                                for o in self.outputs_bidi])
+        self.__attention_tensor = tf.nn.dropout(self.__attention_tensor,
+                                                self.dropout_placeholder)
+
+        self.__attention_mask = tf.concat(
+            1, [tf.expand_dims(w, 1) for w in self.padding_weights])
 
     # pylint: disable=too-many-locals
     def feed_dict(self, dataset, train=False):
         factors = {data_id: dataset.get_series(data_id)
                    for data_id in self.data_ids}
 
-        ## this method should be responsible for checking if the factored
-        ## sentences are of the same length
+        # this method should be responsible for checking if the factored
+        # sentences are of the same length
         res = {}
 
         # we asume that all factors have equal word counts
@@ -206,7 +206,7 @@ def feed_dict(self, dataset, train=False):
                                            train=train)
                for data_id, vocabulary in zip(self.data_ids, self.vocabularies)}
 
-        ## check input lengths
+        # check input lengths
         lengths = []
         paddings = None
diff --git a/neuralmonkey/encoders/image_encoder.py b/neuralmonkey/encoders/image_encoder.py
index 20129fd5e..480aa11e2 100644
--- a/neuralmonkey/encoders/image_encoder.py
+++ b/neuralmonkey/encoders/image_encoder.py
@@ -1,5 +1,7 @@
 import tensorflow as tf
 
+from neuralmonkey.encoders.attentive import Attentive
+
 # tests: lint, mypy
 
 # pylint: disable=too-few-public-methods
@@ -23,19 +25,17 @@ def __init__(self, dimension, output_shape, data_id):
         self.encoded = tf.tanh(tf.matmul(self.flat, project_w) + project_b)
 
-        self.attention_tensor = None
-        self.attention_object = None
-
     # pylint: disable=unused-argument
     def feed_dict(self, dataset, train=False):
         return {self.image_features: dataset.get_series(self.data_id)}
 
 
-class PostCNNImageEncoder(object):
+class PostCNNImageEncoder(Attentive):
     def __init__(self, input_shape, output_shape, data_id, name,
                  dropout_keep_prob=1.0, attention_type=None):
         assert len(input_shape) == 3
+        super().__init__(attention_type)
 
         self.input_shape = input_shape
         self.output_shape = output_shape
@@ -63,17 +63,15 @@ def __init__(self, input_shape, output_shape, data_id, name,
         self.encoded = tf.tanh(tf.matmul(self.flat, project_w) + project_b)
 
-        self.attention_tensor = \
-            tf.reshape(self.image_features,
-                       [-1, input_shape[0] * input_shape[1],
-                        input_shape[2]],
-                       name="flatten_image")
-
-        self.attention_object = \
-            attention_type(self.attention_tensor,
-                           scope="attention_img",
-                           dropout_placeholder=self.dropout_placeholder) \
-                           if attention_type else None
+        self.__attention_tensor = tf.reshape(
+            self.image_features,
+            [-1, input_shape[0] * input_shape[1],
+             input_shape[2]],
+            name="flatten_image")
+
+    @property
+    def _attention_tensor(self):
+        return self.__attention_tensor
 
     def feed_dict(self, dataset, train=False):
         res = {self.image_features: dataset.get_series(self.data_id)}
diff --git a/neuralmonkey/encoders/sentence_encoder.py b/neuralmonkey/encoders/sentence_encoder.py
index deeffebd2..959b6e8ee 100644
--- a/neuralmonkey/encoders/sentence_encoder.py
+++ b/neuralmonkey/encoders/sentence_encoder.py
@@ -4,6 +4,7 @@
 
 import tensorflow as tf
 
+from neuralmonkey.encoders.attentive import Attentive
 from neuralmonkey.logging import log
 from neuralmonkey.nn.noisy_gru_cell import NoisyGRUCell
 from neuralmonkey.nn.ortho_gru_cell import OrthoGRUCell
@@ -18,7 +19,7 @@
 
 
 # pylint: disable=too-many-instance-attributes
-class SentenceEncoder(object):
+class SentenceEncoder(Attentive):
     """A class that manages parts of the computation graph that are
     used for encoding of input sentences. It uses a bidirectional RNN.
 
@@ -61,6 +62,9 @@ def __init__(self,
             attention_fertility: Fertility parameter used with
                 CoverageAttention (default 3).
""" + super().__init__( + attention_type, attention_fertility=attention_fertility) + self.vocabulary = vocabulary self.data_id = data_id self.name = name @@ -86,18 +90,21 @@ def __init__(self, fw_cell, bw_cell, embedded_inputs, self.sentence_lengths, dtype=tf.float32) - self.attention_tensor = tf.concat(2, outputs_bidi_tup) - self.attention_tensor = self._dropout(self.attention_tensor) + self.__attention_tensor = tf.concat(2, outputs_bidi_tup) + self.__attention_tensor = self._dropout(self.__attention_tensor) self.encoded = tf.concat(1, encoded_tup) - self.attention_object = attention_type( - self.attention_tensor, scope="attention_{}".format(name), - input_weights=self.padding, - max_fertility=attention_fertility) if attention_type else None log("Sentence encoder initialized") + @property + def _attention_tensor(self): + return self.__attention_tensor + + @property + def _attention_mask(self): + return self._input_mask @property def vocabulary_size(self): @@ -112,16 +119,16 @@ def _create_input_placeholders(self): self.inputs = tf.placeholder(tf.int32, shape=[None, self.max_input_len], name="encoder_input") - self.padding = tf.placeholder( + self._input_mask = tf.placeholder( tf.float32, shape=[None, self.max_input_len], name="encoder_padding") - self.sentence_lengths = tf.to_int32(tf.reduce_sum(self.padding, 1)) + self.sentence_lengths = tf.to_int32( + tf.reduce_sum(self._input_mask, 1)) def _create_embedding_matrix(self): - """Create variables and operations for embedding - the input words + """Create variables and operations for embedding the input words. If parent encoder is specified, we reuse its embedding matrix """ @@ -206,6 +213,6 @@ def feed_dict(self, dataset: Dataset, train: bool=False) -> FeedDict: # as sentences_to_tensor returns lists of shape (time, batch), # we need to transpose fd[self.inputs] = list(zip(*vectors)) - fd[self.padding] = list(zip(*paddings)) + fd[self._input_mask] = list(zip(*paddings)) return fd