Separate training and runtime attention #174

Merged 12 commits on Dec 12, 2016
10 changes: 5 additions & 5 deletions neuralmonkey/decoders/decoder.py
@@ -245,12 +245,13 @@ def _get_rnn_cell(self):
return OrthoGRUCell(self.rnn_size)


def _collect_attention_objects(self):
def _collect_attention_objects(self, runtime_mode):
"""Collect attention objects from encoders."""
if not self.use_attention:
return []
return [e.attention_object for e in self.encoders if e.attention_object]

return [e.get_attention_object(runtime_mode)
for e in self.encoders]

def _embed_inputs(self, inputs):
"""Embed inputs using the decoder"s word embedding matrix
@@ -323,7 +324,7 @@ def _attention_decoder(self, inputs, initial_state, runtime_mode=False,
scope: The variable scope to use with this function.
"""
cell = self._get_rnn_cell()
att_objects = self._collect_attention_objects()
att_objects = self._collect_attention_objects(runtime_mode)

## Broadcast the initial state to the whole batch if needed
if len(initial_state.get_shape()) == 1:
@@ -366,9 +367,8 @@ def _attention_decoder(self, inputs, initial_state, runtime_mode=False,

if runtime_mode:
for i, a in enumerate(att_objects):
attentions = a.attentions_in_time[-len(inputs):]
alignments = tf.expand_dims(tf.transpose(
tf.pack(attentions), perm=[1, 2, 0]), -1)
tf.pack(a.attentions_in_time), perm=[1, 2, 0]), -1)

tf.image_summary("attention_{}".format(i), alignments,
collections=["summary_val_plots"],
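For reference, the change above packs the full attentions_in_time list into the summary image instead of slicing off only the last len(inputs) steps. A minimal NumPy sketch of the shape manipulation, with made-up sizes (this snippet is not part of the PR):

import numpy as np

batch, input_len, decoder_steps = 2, 7, 5
# attentions_in_time holds one [batch, input_len] distribution per decoder step
attentions = [np.random.rand(batch, input_len) for _ in range(decoder_steps)]

stacked = np.stack(attentions)  # like tf.pack: [steps, batch, input_len]
alignments = stacked.transpose(1, 2, 0)[..., np.newaxis]  # like transpose + expand_dims
# -> [batch, input_len, steps, 1]: one greyscale alignment image per sentence,
#    rows are input positions, columns are decoder steps
assert alignments.shape == (batch, input_len, decoder_steps, 1)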
24 changes: 14 additions & 10 deletions neuralmonkey/decoding_function.py
@@ -14,8 +14,9 @@ class Attention(object):
# pylint: disable=unused-argument,too-many-instance-attributes,too-many-arguments
# For maintaining the same API as in CoverageAttention

def __init__(self, attention_states, scope, input_weights=None,
max_fertility=None):
def __init__(self, attention_states, scope,
input_weights=None, attention_fertility=None,
runtime_mode=False):
"""Create the attention object.

Args:
@@ -24,8 +25,10 @@ def __init__(self, attention_states, scope, input_weights=None,
scope: The name of the variable scope in the graph used by this
attention object.
input_weights: (Optional) The padding weights on the input.
max_fertility: (Optional) For the Coverage attention compatibilty,
maximum fertility of one word.
attention_fertility: (Optional) For the Coverage attention
compatibility, maximum fertility of one word.
runtime_mode: (Optional) Indicates whether the object will be used
for runtime decoding.
"""
self.scope = scope
self.attentions_in_time = []
@@ -107,19 +110,20 @@ class CoverageAttention(Attention):
# pylint: disable=too-many-arguments
# Great objects require a great number of parameters
def __init__(self, attention_states, scope,
input_weights=None, max_fertility=5):
input_weights=None, attention_fertility=5):

super(CoverageAttention, self).__init__(attention_states, scope,
input_weights=input_weights,
max_fertility=max_fertility)
super(CoverageAttention, self).__init__(
attention_states, scope,
input_weights=input_weights,
attention_fertility=attention_fertility)

self.coverage_weights = tf.get_variable("coverage_matrix",
[1, 1, 1, self.attn_size])
self.fertility_weights = tf.get_variable("fertility_matrix",
[1, 1, self.attn_size])
self.max_fertility = max_fertility
self.attention_fertility = attention_fertility

self.fertility = 1e-8 + self.max_fertility * tf.sigmoid(
self.fertility = 1e-8 + self.attention_fertility * tf.sigmoid(
tf.reduce_sum(self.fertility_weights * self.attention_states, [2]))

def get_logits(self, y):
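The max_fertility to attention_fertility rename above does not change the fertility term itself. A rough NumPy stand-in for that computation, with illustrative sizes (not part of the PR):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

batch, attn_length, attn_size = 2, 6, 4
attention_states = np.random.randn(batch, attn_length, attn_size)
fertility_weights = np.random.randn(1, 1, attn_size)  # like the "fertility_matrix" variable
attention_fertility = 5

# one positive fertility per source position, bounded above by attention_fertility
fertility = 1e-8 + attention_fertility * sigmoid(
    (fertility_weights * attention_states).sum(axis=2))
assert fertility.shape == (batch, attn_length)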
41 changes: 41 additions & 0 deletions neuralmonkey/encoders/attentive.py
@@ -0,0 +1,41 @@
from abc import ABCMeta, abstractproperty
import tensorflow as tf

# pylint: disable=too-few-public-methods
class Attentive(metaclass=ABCMeta):
"""A base class fro an attentive part of graph (typically encoder).

Objects inheriting this class are able to generate an attention object
that allows a decoder to attend over a tensor provided by the encoder
(e.g., input word representations in the case of MT, or convolutional
maps in the case of image captioning).
"""
def __init__(self, attention_type, **kwargs):
Review comment (Member): It would be good to also add docstrings, but I don't want to hold up the merge any longer.

self._attention_type = attention_type
self._attention_kwargs = kwargs

def get_attention_object(self, runtime: bool=False):
"""Attention object that can be used in decoder."""
# pylint: disable=no-member
if hasattr(self, "name") and self.name:
name = self.name
else:
name = str(self)

return self._attention_type(
self._attention_tensor,
scope="attention_{}".format(name),
input_weights=self._attention_mask,
runtime_mode=runtime,
**self._attention_kwargs) if self._attention_type else None

@abstractproperty
def _attention_tensor(self):
"""Tensor over which the attention is done."""
raise NotImplementedError(
"Attentive object is missing attention_tensor.")

Review comment (Member Author): Why not use @abstractmethod?

@property
def _attention_mask(self):
"""Zero/one masking the attention logits."""
return tf.ones(tf.shape(self._attention_tensor))
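A minimal usage sketch of the new base class. DummyAttention and ToyEncoder are hypothetical names used only for illustration; the sketch assumes the module path added in this PR and an importable TensorFlow:

from neuralmonkey.encoders.attentive import Attentive

class DummyAttention:
    """Stand-in for neuralmonkey.decoding_function.Attention (hypothetical)."""
    def __init__(self, attention_states, scope, input_weights=None,
                 runtime_mode=False, **kwargs):
        self.attention_states = attention_states
        self.scope = scope
        self.runtime_mode = runtime_mode

class ToyEncoder(Attentive):
    name = "toy"

    def __init__(self, states):
        super().__init__(DummyAttention)
        self._states = states

    @property
    def _attention_tensor(self):
        return self._states

    @property
    def _attention_mask(self):
        return None  # avoid the default tf.ones mask in this toy example

encoder = ToyEncoder(states="fake-attention-states")
train_att = encoder.get_attention_object(runtime=False)
runtime_att = encoder.get_attention_object(runtime=True)
assert train_att.scope == runtime_att.scope == "attention_toy"
assert runtime_att.runtime_mode and not train_att.runtime_mode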
58 changes: 31 additions & 27 deletions neuralmonkey/encoders/cnn_encoder.py
@@ -5,12 +5,13 @@

import numpy as np
import tensorflow as tf
from neuralmonkey.encoders.attentive import Attentive
from neuralmonkey.decoding_function import Attention

# tests: lint, mypy

# pylint: disable=too-many-instance-attributes, too-few-public-methods
class CNNEncoder(object):
class CNNEncoder(Attentive):
"""

An image encoder. It projects the input image through a series of
@@ -45,7 +46,8 @@ def __init__(self, data_id, convolutions, rnn_layers,
bidirectional=True,
batch_normalization=True,
local_response_normalization=True,
dropout_keep_prob=0.5):
dropout_keep_prob=0.5,
attention_type=Attention):
"""
Initializes and configures the computational graph creator.

@@ -85,6 +87,7 @@ def __init__(self, data_id, convolutions, rnn_layers,
dropout keeping probability

"""
super().__init__(attention_type)

self.convolutions = convolutions
self.data_id = data_id
@@ -123,15 +126,16 @@ def __init__(self, data_id, convolutions, rnn_layers,
self.image_processing_layers = []

with tf.variable_scope("convolutions"):
for i, (filter_size, n_filters, pool_size) \
in enumerate(convolutions):
for i, (filter_size,
n_filters,
pool_size) in enumerate(convolutions):
with tf.variable_scope("cnn_layer_{}".format(i)):
conv_w = tf.get_variable(
"wieghts",
shape=[filter_size, filter_size,
last_n_channels, n_filters],
initializer= \
tf.truncated_normal_initializer(stddev=.1))
initializer=tf.truncated_normal_initializer(
stddev=.1))
conv_b = tf.get_variable(
"biases",
shape=[n_filters],
@@ -169,11 +173,9 @@ def __init__(self, data_id, convolutions, rnn_layers,
last_layer_size = last_n_channels * image_height * image_width

with tf.variable_scope("rnn_inputs"):
encoder_ins = [tf.reshape(x,
[-1, last_n_channels * image_height])
for x in tf.split(2, image_width,
last_layer,
name='split_input')]
encoder_ins = [
tf.reshape(x, [-1, last_n_channels * image_height]) for x in
tf.split(2, image_width, last_layer, name='split_input')]
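The reformatted expression above is behaviourally identical to the old one: the last convolutional map is cut into columns and each column becomes one RNN input. A NumPy sketch of the shapes, with illustrative sizes (not part of the PR):

import numpy as np

batch, image_height, image_width, last_n_channels = 2, 4, 6, 8
last_layer = np.random.rand(batch, image_height, image_width, last_n_channels)

# like tf.split(2, image_width, last_layer) followed by the reshape above
encoder_ins = [
    last_layer[:, :, i, :].reshape(-1, last_n_channels * image_height)
    for i in range(image_width)]
assert len(encoder_ins) == image_width
assert encoder_ins[0].shape == (batch, last_n_channels * image_height)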

def rnn_encoder(inputs, last_layer_size, scope):
with tf.variable_scope(scope):
@@ -202,28 +204,30 @@ def rnn_encoder(inputs, last_layer_size, scope):
encoder_state = rnn_encoder(
encoder_ins, last_layer_size, "encoder-forward")

# pylint: disable=redefined-variable-type
if bidirectional:
backward_encoder_state = rnn_encoder(
list(reversed(encoder_ins)),
last_layer_size,
"encoder-backward")
# pylint: disable=redefined-variable-type
last_layer_size, "encoder-backward")
encoder_state = tf.concat(
1, [encoder_state, backward_encoder_state])

self.encoded = encoder_state

self.attention_tensor = \
tf.reshape(last_layer, [-1, image_width,
last_n_channels * image_height])
self.__attention_tensor = tf.reshape(
last_layer, [-1, image_width,
last_n_channels * image_height])

att_in_weights = tf.squeeze(
self.__attention_mask = tf.squeeze(
tf.reduce_prod(last_padding_masks, [1]), [2])

self.attention_object = Attention(self.attention_tensor,
scope="attention_{}".format(
name),
input_weights=att_in_weights)
@property
def _attention_tensor(self):
return self.__attention_tensor

@property
def _attention_mask(self):
return self.__attention_mask
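For orientation, the tensor exposed by _attention_tensor has shape [batch, image_width, last_n_channels * image_height] (one vector per image column), and _attention_mask reduces the padding masks to one weight per column. A small NumPy sketch of the mask computation, with illustrative sizes (not part of the PR):

import numpy as np

batch, image_height, image_width = 2, 4, 6
last_padding_masks = np.random.randint(
    0, 2, (batch, image_height, image_width, 1)).astype(float)

# like tf.squeeze(tf.reduce_prod(last_padding_masks, [1]), [2])
attention_mask = last_padding_masks.prod(axis=1).squeeze(axis=2)
assert attention_mask.shape == (batch, image_width)  # one weight per image column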

def feed_dict(self, dataset, train=False):
# if it is from the pickled file, it is list, not numpy tensor,
@@ -247,7 +251,9 @@ def feed_dict(self, dataset, train=False):
# pylint: disable=too-many-locals
def batch_norm(tensor, n_out, phase_train, scope='bn', scale_after_norm=True):
"""
Batch normalization on convolutional maps. Taken from
Batch normalization on convolutional maps.

Taken from
http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow

Arguments:
@@ -282,8 +288,6 @@ def mean_var_with_update():
mean_var_with_update,
lambda: (ema_mean, ema_var))

normed = \
tf.nn.batch_norm_with_global_normalization(tensor, mean, var,
beta, gamma, 1e-3,
scale_after_norm)
normed = tf.nn.batch_norm_with_global_normalization(
tensor, mean, var, beta, gamma, 1e-3, scale_after_norm)
return normed
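A hypothetical call to the helper above, assuming the TensorFlow 0.x API this file targets; the placeholder names are made up for illustration and are not part of the PR:

import tensorflow as tf
from neuralmonkey.encoders.cnn_encoder import batch_norm

feature_maps = tf.placeholder(tf.float32, [None, 32, 32, 16], name="feature_maps")
is_training = tf.placeholder(tf.bool, name="is_training")
normed = batch_norm(feature_maps, n_out=16, phase_train=is_training,
                    scope="bn_example")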