autoregressive refactor: compatibility with beam_search_decoder
varisd authored and jlibovicky committed Mar 18, 2019
1 parent dab4ccb commit 60164f9
Showing 17 changed files with 71 additions and 42 deletions.
12 changes: 9 additions & 3 deletions neuralmonkey/decoders/autoregressive.py
@@ -56,6 +56,9 @@ class DecoderHistories(NamedTuple(
     This should only record decoding history and the decoding should not be
     dependent on these values.
 
+    Attributes defined here (and in the `other` substructure) should always
+    be time-major (e.g., shape(time, batch, ...)).
+
     Attributes:
         logits: A tensor of shape ``(time, batch, vocabulary)`` which contains
             the unnormalized output scores of words in a vocabulary.
@@ -93,6 +96,9 @@ class DecoderFeedables(NamedTuple(
     The decoder should be able to decode in each step only using this.
 
+    Attributes defined here (and in the `other` substructure) should always
+    be batch-major (e.g., shape(batch, ...)).
+
     Attributes:
         step: A scalar int tensor, stores the number of the current time step.
         finished: A boolean tensor of shape ``(batch)``, which says whether
@@ -276,7 +282,7 @@ def train_logits(self) -> tf.Tensor:
     @tensor
     def train_output_states(self) -> tf.Tensor:
         train_result = LoopState(*self.train_loop_result)
-        return train_result.histories.decoder_outputs
+        return train_result.histories.output_states
 
     @tensor
     def train_logprobs(self) -> tf.Tensor:
@@ -324,12 +330,12 @@ def runtime_logits(self) -> tf.Tensor:
     @tensor
     def runtime_output_states(self) -> tf.Tensor:
         runtime_result = LoopState(*self.runtime_loop_result)
-        return runtime_result.histories.decoder_outputs
+        return runtime_result.histories.output_states
 
     @tensor
     def runtime_mask(self) -> tf.Tensor:
         runtime_result = LoopState(*self.runtime_loop_result)
-        return runtime_result.histories.mask
+        return runtime_result.histories.output_mask
 
     @tensor
     def decoded(self) -> tf.Tensor:
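The time-major/batch-major convention documented above is the heart of this refactor. A minimal NumPy sketch of what it implies (the sizes are hypothetical; the real structures hold TensorFlow tensors):

import numpy as np

# Hypothetical sizes, for illustration only.
time_steps, batch, vocab, emb = 4, 2, 6, 8

# DecoderHistories entries are time-major: axis 0 is the decoding step.
logits_history = np.zeros((time_steps, batch, vocab))

# DecoderFeedables entries are batch-major: axis 0 is the batch.
embedded_input = np.zeros((batch, emb))

# The last decoding step is then a plain slice of a history, which is
# exactly how the beam search decoder below reads its logits.
last_logits = logits_history[-1, :, :]  # shape (batch, vocab)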
15 changes: 11 additions & 4 deletions neuralmonkey/decoders/beam_search_decoder.py
@@ -279,13 +279,14 @@ def get_initial_loop_state(self) -> BeamSearchLoopState:
         # time, the placeholder replacement is done on the whole structures, as
         # you can see below.
 
+        logits = dec_next_ls.histories.logits[-1, :, :]
         search_state = SearchState(
             logprob_sum=tf.tile(
                 tf.expand_dims([0.0] + [-INF] * (self.beam_size - 1), 0),
                 [self.batch_size, 1],
                 name="bs_logprob_sum"),
             prev_logprobs=tf.reshape(
-                tf.nn.log_softmax(dec_next_ls.feedables.prev_logits),
+                tf.nn.log_softmax(logits),
                 [self.batch_size, self.beam_size, len(self.vocabulary)]),
             lengths=tf.zeros(
                 [self.batch_size, self.beam_size], dtype=tf.int32,
@@ -296,13 +297,14 @@ def get_initial_loop_state(self) -> BeamSearchLoopState:
         # We add the input_symbol to token_ids during search_results
         # initialization for simpler beam_body implementation
 
+        input_symbols = dec_next_ls.histories.output_symbols[-1, :]
         search_results = SearchResults(
             scores=tf.zeros(
                 shape=[self.batch_size, self.beam_size],
                 dtype=tf.float32,
                 name="beam_scores"),
             token_ids=tf.reshape(
-                feedables.input_symbol,
+                input_symbols,
                 [1, self.batch_size, self.beam_size],
                 name="beam_tokens"))
@@ -505,11 +507,15 @@ def body(*args: Any) -> BeamSearchLoopState:
                 dec_loop_state.feedables)
 
             next_feedables = next_feedables._replace(
-                input_symbol=tf.reshape(next_word_ids, [-1]),
+                embedded_input=self.parent_decoder.embed_input_symbols(
+                    tf.reshape(next_word_ids, [-1])),
                 finished=tf.reshape(next_finished, [-1]))
 
             # histories have shape [len, batch, ...]
             def gather_fn(x):
+                if len(x.shape.dims) < 2:
+                    return x
+
                 return partial_transpose(
                     gather_flat(
                         partial_transpose(x, [1, 0]),
@@ -528,10 +534,11 @@ def gather_fn(x):
             # CALL THE DECODER BODY FUNCTION
             next_loop_state = decoder_body(*dec_loop_state)
 
+            logits = next_loop_state.histories.logits[-1, :, :]
             next_search_state = SearchState(
                 logprob_sum=next_beam_logprob_sum,
                 prev_logprobs=tf.reshape(
-                    tf.nn.log_softmax(next_loop_state.feedables.prev_logits),
+                    tf.nn.log_softmax(logits),
                     [self.batch_size, self.beam_size, len(self.vocabulary)]),
                 lengths=next_beam_lengths,
                 finished=next_finished)
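A sketch of what the new prev_logprobs computation does, assuming the parent decoder runs over batch_size * beam_size flattened hypotheses (NumPy stand-ins for tf.nn.log_softmax and tf.reshape; sizes hypothetical):

import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax, mirroring tf.nn.log_softmax.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

batch_size, beam_size, vocab = 2, 3, 5

# Last step of the time-major logits history: one row per flat hypothesis.
flat_logits = np.random.randn(batch_size * beam_size, vocab)

# Regroup the flat hypotheses per beam, as the tf.reshape above does.
prev_logprobs = log_softmax(flat_logits).reshape(batch_size, beam_size, vocab)
assert prev_logprobs.shape == (batch_size, beam_size, vocab)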
6 changes: 6 additions & 0 deletions neuralmonkey/decoders/decoder.py
@@ -334,6 +334,12 @@ def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
                 cell_output, self.dropout_keep_prob, self.train_mode)
 
         with tf.name_scope("rnn_output_projection"):
+            if self.embedding_size != self.output_dimension:
+                raise ValueError(
+                    "The dimension ({}) of the output projection must be "
+                    "the same as the dimension of the input embedding "
+                    "({})".format(self.output_dimension,
+                                  self.embedding_size))
             # pylint: disable=not-callable
             output = self.output_projection(
                 cell_output, loop_state.feedables.embedded_input,
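The test-configuration updates at the bottom of this commit (rnn_size and maxout_size values bumped to match embedding_size) appear to be fallout from this new check. A hypothetical sketch of the invariant; the assumption that the output dimension falls back to rnn_size when no explicit output projection is configured is ours, not the source's:

# Hypothetical values mirroring tests/rl.ini after this commit.
rnn_size = 9
embedding_size = 9
output_dimension = rnn_size  # assumed fallback: no output projection given

if embedding_size != output_dimension:
    raise ValueError(
        "The dimension ({}) of the output projection must be "
        "the same as the dimension of the input embedding ({})".format(
            output_dimension, embedding_size))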
32 changes: 19 additions & 13 deletions neuralmonkey/decoders/transformer.py
@@ -39,10 +39,16 @@ class TransformerFeedables(NamedTuple(
         ("input_mask", tf.Tensor)])):
     """Additional feedables used only by the Transformer-based decoder.
 
+    Follows the shape pattern of having a batch-sized first dimension:
+    shape(batch_size, ...)
+
     Attributes:
         input_sequence: The whole input sequence (embedded) that is fed into
             the decoder in each decoding step.
+            shape(batch, len, emb)
-        input_mask: Mask for masking finished sequences.
+        input_mask: Mask for masking finished sequences. The last dimension
+            is required for compatibility with the beam_search_decoder.
+            shape(batch, len, 1)
     """


@@ -392,14 +398,14 @@ def train_loop_result(self) -> LoopState:
         decoder_ls = AutoregressiveDecoder.get_initial_loop_state(self)
 
         input_sequence = self.embed_input_symbols(self.train_input_symbols)
+        input_mask = tf.transpose(self.train_mask)
 
         last_layer = self.layer(
-            self.depth, input_sequence, tf.transpose(self.train_mask))
+            self.depth, input_sequence, input_mask)
 
-        # We transpose input sequence and mask only to convey to
-        # the defined shapes
         tr_feedables = TransformerFeedables(
-            input_sequence=tf.transpose(input_sequence),
-            input_mask=self.train_mask)
+            input_sequence=input_sequence,
+            input_mask=tf.expand_dims(input_mask, -1))
 
         # t_states shape: (batch, time, channels)
         # dec_w shape: (channels, vocab)
@@ -453,11 +459,11 @@ def get_initial_loop_state(self) -> LoopState:
 
         tr_feedables = TransformerFeedables(
             input_sequence=tf.zeros(
-                shape=[0, self.batch_size, self.dimension],
+                shape=[self.batch_size, 0, self.dimension],
                 dtype=tf.float32,
                 name="input_sequence"),
             input_mask=tf.zeros(
-                shape=[0, self.batch_size],
+                shape=[self.batch_size, 0, 1],
                 dtype=tf.float32,
                 name="input_mask"))
@@ -486,16 +492,16 @@ def next_state(self, loop_state: LoopState) -> Tuple[tf.Tensor, Any, Any]:
         with tf.variable_scope(self._variable_scope, reuse=tf.AUTO_REUSE):
             # shape (time, batch)
             input_sequence = append_tensor(
-                tr_feedables.input_sequence, feedables.embedded_input)
+                tr_feedables.input_sequence, feedables.embedded_input, 1)
 
             unfinished_mask = tf.to_float(tf.logical_not(feedables.finished))
             input_mask = append_tensor(
-                tr_feedables.input_mask, unfinished_mask)
+                tr_feedables.input_mask,
+                tf.expand_dims(unfinished_mask, -1),
+                axis=1)
 
             last_layer = self.layer(
-                self.depth,
-                tf.transpose(input_sequence, [1, 0, 2]),
-                tf.transpose(input_mask))
+                self.depth, input_sequence, tf.squeeze(input_mask, -1))
 
             # (batch, state_size)
             output_state = last_layer.temporal_states[:, -1, :]
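A NumPy sketch of how the batch-major Transformer feedables grow by one position per step, mirroring append_tensor(..., axis=1) and the trailing singleton mask dimension above (sizes hypothetical):

import numpy as np

batch, emb = 2, 4

# Batch-major feedables start empty along the time axis (axis 1),
# as in get_initial_loop_state above.
input_sequence = np.zeros((batch, 0, emb))
input_mask = np.zeros((batch, 0, 1))

# One decoding step appends one position along axis 1.
step_embedding = np.random.randn(batch, emb)
unfinished = np.ones(batch)  # 1.0 for sequences that are still running

input_sequence = np.concatenate(
    [input_sequence, step_embedding[:, None, :]], axis=1)
input_mask = np.concatenate(
    [input_mask, unfinished[:, None, None]], axis=1)

assert input_sequence.shape == (batch, 1, emb)
assert input_mask.shape == (batch, 1, 1)  # last dim kept for beam search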
7 changes: 5 additions & 2 deletions neuralmonkey/tf_utils.py
@@ -219,14 +219,17 @@ def layer_norm(x: tf.Tensor, epsilon: float = 1e-6) -> tf.Tensor:
     return norm_x * gamma + beta
 
 
-def append_tensor(tensor: tf.Tensor, appendval: tf.Tensor) -> tf.Tensor:
+def append_tensor(tensor: tf.Tensor,
+                  appendval: tf.Tensor,
+                  axis: int = 0) -> tf.Tensor:
     """Append an ``N``-D Tensor to an ``(N+1)``-D Tensor.
 
     Arguments:
         tensor: The original Tensor
         appendval: The Tensor to add
+        axis: Which axis should we use
 
     Returns:
         An ``(N+1)``-D Tensor with ``appendval`` on the last position.
     """
-    return tf.concat([tensor, tf.expand_dims(appendval, 0)], 0)
+    return tf.concat([tensor, tf.expand_dims(appendval, axis)], axis)
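A quick usage sketch of the generalized helper, written against a NumPy stand-in (np.concatenate and np.expand_dims behave like their TF counterparts here):

import numpy as np

def append_tensor(tensor, appendval, axis=0):
    # NumPy stand-in for the TF helper above, for illustration only.
    return np.concatenate([tensor, np.expand_dims(appendval, axis)], axis)

history = np.zeros((3, 2, 5))  # time-major: (time, batch, vocab)
step = np.ones((2, 5))         # one step: (batch, vocab)

# The default axis keeps the old, time-major behaviour.
assert append_tensor(history, step).shape == (4, 2, 5)

# The new axis argument serves batch-major structures, e.g. the
# Transformer feedables that grow along axis 1.
states = np.zeros((2, 3, 5))   # batch-major: (batch, time, vocab)
assert append_tensor(states, step, axis=1).shape == (2, 4, 5)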
2 changes: 1 addition & 1 deletion neuralmonkey/trainers/rl_trainer.py
@@ -123,7 +123,7 @@ def _score_with_reward_function(references: np.array,
         sample_loop_result = self.decoder.decoding_loop(
             train_mode=False, sample=True, temperature=self.temperature)
         sample_logits = sample_loop_result.histories.logits
-        sample_decoded = sample_loop_result.histories.outputs
+        sample_decoded = sample_loop_result.histories.output_symbols
 
         # rewards, shape (batch)
         # simulate from reference
2 changes: 1 addition & 1 deletion tests/bahdanau.ini
@@ -85,7 +85,7 @@ supress_unk=True
 
 [dec_maxout_output]
 class=decoders.output_projection.maxout_output
-maxout_size=7
+maxout_size=9
 
 [trainer1]
 class=trainers.cross_entropy_trainer.CrossEntropyTrainer
2 changes: 1 addition & 1 deletion tests/bpe.ini
@@ -85,7 +85,7 @@ encoder=<encoder>
 class=decoders.decoder.Decoder
 name="decoder"
 encoders=[<encoder>]
-embedding_size=9
+embedding_size=10
 attentions=[<attention>]
 dropout_keep_prob=0.5
 data_id="target_bpe"
2 changes: 1 addition & 1 deletion tests/factored.ini
@@ -63,7 +63,7 @@ name="decoder"
 encoders=[<encoder>]
 attentions=[<attention>]
 rnn_size=32
-embedding_size=20
+embedding_size=32
 dropout_keep_prob=0.5
 data_id="target"
 max_output_len=10
8 changes: 4 additions & 4 deletions tests/flat-multiattention.ini
@@ -72,7 +72,7 @@ class=decoders.decoder.Decoder
 name="decoder_flat_noshare_nosentinel"
 attentions=[<flat_noshare_nosentinel>]
 encoders=[<encoder>, <imagenet>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
@@ -92,7 +92,7 @@ class=decoders.decoder.Decoder
 name="decoder_flat_share_nosentinel"
 attentions=[<flat_share_nosentinel>]
 encoders=[<encoder>, <imagenet>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
@@ -112,7 +112,7 @@ class=decoders.decoder.Decoder
 name="decoder_flat_share_sentinel"
 attentions=[<flat_share_sentinel>]
 encoders=[<encoder>, <imagenet>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
@@ -132,7 +132,7 @@ class=decoders.decoder.Decoder
 name="decoder_flat_noshare_sentinel"
 attentions=[<flat_noshare_sentinel>]
 encoders=[<encoder>, <imagenet>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
9 changes: 5 additions & 4 deletions tests/hier-multiattention.ini
@@ -4,6 +4,7 @@ tf_manager=<tf_manager>
 output="tests/outputs/hier-multiattention"
 overwrite_output_dir=True
 epochs=1
+batch_size=1
 train_dataset=<train_data>
 val_dataset=<val_data>
 trainer=<trainer>
@@ -99,7 +100,7 @@ class=decoders.decoder.Decoder
 name="decoder_hier_noshare_nosentinel"
 encoders=[<encoder>, <imagenet>]
 attentions=[<hier_noshare_nosentinel>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
@@ -119,7 +120,7 @@ class=decoders.decoder.Decoder
 name="decoder_hier_share_nosentinel"
 encoders=[<encoder>, <imagenet>]
 attentions=[<hier_share_nosentinel>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
@@ -139,7 +140,7 @@ class=decoders.decoder.Decoder
 name="decoder_hier_share_sentinel"
 encoders=[<encoder>, <imagenet>]
 attentions=[<hier_share_sentinel>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
@@ -159,7 +160,7 @@ class=decoders.decoder.Decoder
 name="decoder_hier_noshare_sentinel"
 encoders=[<encoder>, <imagenet>]
 attentions=[<hier_noshare_sentinel>]
-rnn_size=2
+rnn_size=3
 embedding_size=3
 dropout_keep_prob=0.5
 data_id="target"
2 changes: 1 addition & 1 deletion tests/language-model.ini
@@ -48,7 +48,7 @@ w2v=<word2vec>
 class=decoders.decoder.Decoder
 name="decoder"
 encoders=[]
-rnn_size=8
+rnn_size=5
 embedding_size=5
 dropout_keep_prob=0.5
 data_id="target"
6 changes: 3 additions & 3 deletions tests/post-edit.ini
@@ -30,7 +30,7 @@ path="tests/data/postedit_target_vocab.tsv"
 class=encoders.SentenceEncoder
 rnn_size=15
 max_input_len=5
-embedding_size=10
+embedding_size=30
 dropout_keep_prob=0.8
 data_id="source"
 name="src_encoder"
@@ -60,7 +60,7 @@ values_encoder=<trans_encoder>
 class=model.sequence.EmbeddedSequence
 name="trans_encoder_input_sequence"
 data_id="translated"
-embedding_size=10
+embedding_size=30
 max_length=5
 vocabulary=<target_vocabulary>
@@ -113,4 +113,4 @@ validation_period=2
 logging_period=1
 visualize_embeddings=[<trans_embedded_input>]
 postprocess=[("target", <postprocess>)]
-overwrite_output_dir=True
\ No newline at end of file
+overwrite_output_dir=True
2 changes: 1 addition & 1 deletion tests/rl.ini
@@ -57,7 +57,7 @@ path="tests/data/decoder_vocab.tsv"
 class=decoders.decoder.Decoder
 name="decoder"
 encoders=[<encoder>]
-rnn_size=8
+rnn_size=9
 embedding_size=9
 attentions=[<attention>]
 dropout_keep_prob=0.5
2 changes: 1 addition & 1 deletion tests/self-critical.ini
@@ -58,7 +58,7 @@ path="tests/data/decoder_vocab.tsv"
 class=decoders.decoder.Decoder
 name="decoder"
 encoders=[<encoder>]
-rnn_size=8
+rnn_size=9
 embedding_size=9
 attentions=[<attention>]
 dropout_keep_prob=0.5
2 changes: 1 addition & 1 deletion tests/small_sent_cnn.ini
@@ -65,7 +65,7 @@ path="tests/data/str/vocab.tsv"
 class=decoders.Decoder
 name="decoder"
 encoders=[<encoder>]
-rnn_size=8
+rnn_size=9
 embedding_size=9
 attentions=[<attention>]
 dropout_keep_prob=0.5
2 changes: 1 addition & 1 deletion tests/str.ini
@@ -84,7 +84,7 @@ class=decoders.decoder.Decoder
 name="decoder"
 encoders=[<encoder>]
 attentions=[<attention>]
-rnn_size=8
+rnn_size=9
 embedding_size=9
 dropout_keep_prob=0.5
 data_id="target_chars"
