Correct text typos and fix NTM's addressing bugs.
xinghai-sun committed May 15, 2017
1 parent cdbf36d commit aefd266
Showing 2 changed files with 60 additions and 59 deletions.
4 changes: 2 additions & 2 deletions mt_with_external_memory/README.md
@@ -99,5 +99,5 @@ TBD
## References

1. Alex Graves, Greg Wayne, Ivo Danihelka, [Neural Turing Machines](https://arxiv.org/abs/1410.5401). arXiv preprint arXiv:1410.5401, 2014.
-2. Mingxuan Wang, Zhengdong Lu, Hang Li, Qun Liu[Memory-enhanced Decoder Neural Machine Translation](https://arxiv.org/abs/1606.02003). arXiv preprint arXiv:1606.02003, 2016.
-3. Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio, [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473). arXiv preprint arXiv:1409.0473, 2014.
+2. Mingxuan Wang, Zhengdong Lu, Hang Li, Qun Liu, [Memory-enhanced Decoder Neural Machine Translation](https://arxiv.org/abs/1606.02003). arXiv preprint arXiv:1606.02003, 2016.
+3. Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio, [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473). arXiv preprint arXiv:1409.0473, 2014.
115 changes: 58 additions & 57 deletions mt_with_external_memory/mt_with_external_memory.py
@@ -1,18 +1,5 @@
-# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
"""
-This python script is a example model configuration for neural machine
+This python script is an example model configuration for neural machine
translation with external memory, based on PaddlePaddle V2 APIs.
The "external memory" refers to two types of memories.
@@ -21,7 +8,7 @@
Both types of external memories are exploited to enhance the vanilla
Seq2Seq neural machine translation.
-The implementation largely followers the paper
+The implementation primarily follows the paper
`Memory-enhanced Decoder for Neural Machine Translation
<https://arxiv.org/abs/1606.02003>`_,
with some minor differences (will be listed in README.md).
@@ -39,7 +26,7 @@
hidden_size = 1024
batch_size = 5
memory_slot_num = 8
-beam_size = 40
+beam_size = 3
infer_data_num = 3


@@ -67,24 +54,40 @@ class ExternalMemory(object):
:type name: basestring
:param mem_slot_size: Size of memory slot/vector.
:type mem_slot_size: int
-:param boot_layer: Boot layer for initializing memory. Sequence layer
-with sequence length indicating the number of memory
-slots, and size as mem_slot_size.
+:param boot_layer: Boot layer for initializing the external memory. The
+sequence layer has sequence length indicating the number
+of memory slots, and size as memory slot size.
:type boot_layer: LayerOutput
:param readonly: If true, the memory is read-only, and write function cannot
be called. Default is false.
:type readonly: bool
+:param enable_interpolation: If set true, the read/write addressing weights
+will be interpolated with the weights in the
+last step, with the affine coefficients being
+a learnable gate function.
+:type enable_interpolation: bool
"""

-def __init__(self, name, mem_slot_size, boot_layer, readonly=False):
+def __init__(self,
+name,
+mem_slot_size,
+boot_layer,
+readonly=False,
+enable_interpolation=True):
self.name = name
self.mem_slot_size = mem_slot_size
self.readonly = readonly
+self.enable_interpolation = enable_interpolation
self.external_memory = paddle.layer.memory(
name=self.name,
size=self.mem_slot_size,
is_seq=True,
boot_layer=boot_layer)
+# prepare a constant (zero) intializer for addressing weights
+self.zero_addressing_init = paddle.layer.slope_intercept(
+input=paddle.layer.fc(input=boot_layer, size=1),
+slope=0.0,
+intercept=0.0)
# set memory to constant when readonly=True
if self.readonly:
self.updated_external_memory = paddle.layer.mixed(
@@ -111,18 +114,18 @@ def __content_addressing__(self, key_vector):
size=self.mem_slot_size,
act=paddle.activation.Linear(),
bias_attr=False)
-merged = paddle.layer.addto(
+merged_projection = paddle.layer.addto(
input=[key_proj_expanded, memory_projection],
act=paddle.activation.Tanh())
# softmax addressing weight: w=softmax(v^T a)
addressing_weight = paddle.layer.fc(
-input=merged,
+input=merged_projection,
size=1,
act=paddle.activation.SequenceSoftmax(),
bias_attr=False)
return addressing_weight

-def __interpolation__(self, key_vector, addressing_weight):
+def __interpolation__(self, head_name, key_vector, addressing_weight):
"""
Interpolate between previous and current addressing weights.
"""
@@ -134,34 +137,33 @@ def __interpolation__(self, key_vector, addressing_weight):
bias_attr=False)
# interpolation: w_t = g*w_t+(1-g)*w_{t-1}
last_addressing_weight = paddle.layer.memory(
name=self.name + "_addressing_weight", size=1, is_seq=True)
gated_addressing_weight = paddle.layer.addto(
name=self.name + "_addressing_weight",
input=[
last_addressing_weight,
paddle.layer.scaling(weight=gate, input=addressing_weight),
paddle.layer.mixed(
input=paddle.layer.dotmul_operator(
a=gate, b=last_addressing_weight, scale=-1.0),
size=1)
],
act=paddle.activation.Tanh())
return gated_addressing_weight

def __get_addressing_weight__(self, key_vector):
name=self.name + "_addressing_weight_" + head_name,
size=1,
is_seq=True,
boot_layer=self.zero_addressing_init)
interpolated_weight = paddle.layer.interpolation(
name=self.name + "_addressing_weight_" + head_name,
input=[addressing_weight, addressing_weight],
weight=paddle.layer.expand(input=gate, expand_as=addressing_weight))
return interpolated_weight

def __get_addressing_weight__(self, head_name, key_vector):
"""
Get final addressing weights for read/write heads, including content
addressing and interpolation.
"""
# current content-based addressing
addressing_weight = self.__content_addressing__(key_vector)
-return addressing_weight
-# interpolation with previous addresing weight
-return self.__interpolation__(key_vector, addressing_weight)
+if self.enable_interpolation:
+return self.__interpolation__(head_name, key_vector,
+addressing_weight)
+else:
+return addressing_weight

def write(self, write_key):
"""
-Write head for external memory.
+Write onto the external memory.
It cannot be called if "readonly" set True.
:param write_key: Key vector for write heads to generate writing
@@ -172,7 +174,7 @@ def write(self, write_key):
if self.readonly:
raise ValueError("ExternalMemory with readonly=True cannot write.")
# get addressing weight for write head
-write_weight = self.__get_addressing_weight__(write_key)
+write_weight = self.__get_addressing_weight__("write_head", write_key)
# prepare add_vector and erase_vector
erase_vector = paddle.layer.fc(
input=write_key,
@@ -205,7 +207,7 @@ def write(self, write_key):

def read(self, read_key):
"""
-Read head for external memory.
+Read from the external memory.
:param write_key: Key vector for read head to generate addressing
signals.
@@ -214,7 +216,7 @@
:rtype: LayerOutput
"""
# get addressing weight for write head
-read_weight = self.__get_addressing_weight__(read_key)
+read_weight = self.__get_addressing_weight__("read_head", read_key)
# read content from external memory
scaled = paddle.layer.scaling(
weight=read_weight, input=self.updated_external_memory)
@@ -227,19 +229,16 @@ def bidirectional_gru_encoder(input, size, word_vec_dim):
Bidirectional GRU encoder.
"""
# token embedding
-embeddings = paddle.layer.embedding(
-input=input,
-size=word_vec_dim,
-param_attr=paddle.attr.ParamAttr(name='_encoder_word_embedding'))
+embeddings = paddle.layer.embedding(input=input, size=word_vec_dim)
# token-level forward and backard encoding for attentions
forward = paddle.networks.simple_gru(
input=embeddings, size=size, reverse=False)
backward = paddle.networks.simple_gru(
input=embeddings, size=size, reverse=True)
-merged = paddle.layer.concat(input=[forward, backward])
+forward_backward = paddle.layer.concat(input=[forward, backward])
# sequence-level encoding
backward_first = paddle.layer.first_seq(input=backward)
-return merged, backward_first
+return forward_backward, backward_first


def memory_enhanced_decoder(input, target, initial_state, source_context, size,
@@ -256,9 +255,9 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
The vanilla RNN/LSTM/GRU also has a narrow memory mechanism, namely the
hidden state vector (or cell state in LSTM) carrying information through
a span of sequence time, which is a successful design enriching the model
-with capability to "remember" things in the long run. However, such a vector
-state is somewhat limited to a very narrow memory bandwidth. External memory
-introduced here could easily increase the memory capacity with linear
+with the capability to "remember" things in the long run. However, such a
+vector state is somewhat limited to a very narrow memory bandwidth. External
+memory introduced here could easily increase the memory capacity with linear
complexity cost (rather than quadratic for vector state).
This enhanced decoder expands its "memory passage" through two
@@ -268,7 +267,7 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
- Unbounded memory for handling source language's token-wise information.
Exactly the attention mechanism over Seq2Seq.
-Notice that we take the attention mechanism as a special form of external
+Notice that we take the attention mechanism as a particular form of external
memory, with read-only memory bank initialized with encoder states, and a
read head with content-based addressing (attention). From this view point,
we arrive at a better understanding of attention mechanism itself and other
@@ -306,12 +305,14 @@ def recurrent_decoder_step(cur_embedding):
name="bounded_memory",
mem_slot_size=size,
boot_layer=bounded_memory_init,
-readonly=False)
+readonly=False,
+enable_interpolation=True)
unbounded_memory = ExternalMemory(
name="unbounded_memory",
mem_slot_size=size * 2,
boot_layer=unbounded_memory_init,
-readonly=True)
+readonly=True,
+enable_interpolation=False)
# write bounded memory
bounded_memory.write(state)
# read bounded memory
@@ -566,7 +567,7 @@ def infer():


def main():
-paddle.init(use_gpu=False, trainer_count=8)
+paddle.init(use_gpu=False, trainer_count=1)
train(num_passes=1)
infer()

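Note on the addressing fix shown above: in the old `__get_addressing_weight__`, the early `return addressing_weight` appears to have made the interpolation with the previous step's weights dead code; the rewritten version keeps a per-head addressing-weight memory, boots it from a constant zero initializer, and blends it with the current content-based weights through a gate computed from the key vector. The sketch below is only a rough, standalone NumPy illustration of that addressing math (content-based addressing followed by w_t = g*w_c + (1-g)*w_{t-1}); it is not the PaddlePaddle implementation, and the helper names and shapes (`content_addressing`, `interpolate`, `W_key`, `W_mem`, `v`, the 8x4 toy memory) are assumptions made up for illustration.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over the memory-slot axis."""
    e = np.exp(x - x.max())
    return e / e.sum()


def content_addressing(memory, key, W_key, W_mem, v):
    # One weight per memory slot, in the spirit of the in-code comment
    # "softmax addressing weight: w=softmax(v^T a)", with a = tanh(W_key k + W_mem M_i).
    proj = np.tanh(key @ W_key + memory @ W_mem)  # shape: (slots, slot_size)
    return softmax(proj @ v)                      # shape: (slots,)


def interpolate(w_content, w_prev, gate):
    # w_t = g * w_c + (1 - g) * w_{t-1}; in the model the gate is learned,
    # here it is just a constant.
    return gate * w_content + (1.0 - gate) * w_prev


# Toy sizes: 8 slots (cf. memory_slot_num = 8 in the config), slot size 4.
rng = np.random.default_rng(0)
slots, slot_size = 8, 4
memory = rng.standard_normal((slots, slot_size))
key = rng.standard_normal(slot_size)
W_key = rng.standard_normal((slot_size, slot_size))
W_mem = rng.standard_normal((slot_size, slot_size))
v = rng.standard_normal(slot_size)

w_prev = np.zeros(slots)  # analogous to the zero addressing boot layer in the diff
w_content = content_addressing(memory, key, W_key, W_mem, v)
w = interpolate(w_content, w_prev, gate=0.7)
read_vector = w @ memory  # reading = addressing-weighted sum over the slots
print(w, read_vector)
```

Fixing the gate at 1 reduces this to pure content-based addressing, which lines up with the diff constructing the read-only unbounded (attention-style) memory with `enable_interpolation=False`, while the bounded memory keeps `enable_interpolation=True`.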
