diff --git a/src/gluonnlp/model/attention_cell.py b/src/gluonnlp/model/attention_cell.py
index f5d044b4ea..cec069da36 100644
--- a/src/gluonnlp/model/attention_cell.py
+++ b/src/gluonnlp/model/attention_cell.py
@@ -25,6 +25,38 @@
 from mxnet.gluon import nn
 from .block import L2Normalization

+
+def _apply_mask(F, att_score, mask, dtype):
+    """Fill in the masked scores with a very small value
+
+    Parameters
+    ----------
+    F : symbol or ndarray
+    att_score : Symbol or NDArray
+        Shape (batch_size, query_length, memory_length)
+    mask : Symbol or NDArray or None
+        Shape (batch_size, query_length, memory_length)
+    Returns
+    -------
+    att_score : Symbol or NDArray
+        Shape (batch_size, query_length, memory_length)
+    """
+    # Fill in the masked scores with a very small value
+    neg = -1e18
+    if np.dtype(dtype) == np.float16:
+        neg = -1e4
+    else:
+        try:
+            # if AMP (automatic mixed precision) is enabled, -1e18 will cause NaN.
+            from mxnet.contrib import amp
+            if amp.amp._amp_initialized:
+                neg = -1e4
+        except ImportError:
+            pass
+    att_score = F.where(mask, att_score, neg * F.ones_like(att_score))
+    return att_score
+
+
 # TODO(sxjscience) Add mask flag to softmax operator. Think about how to accelerate the kernel
 def _masked_softmax(F, att_score, mask, dtype):
     """Ignore the masked elements when calculating the softmax
@@ -38,23 +70,12 @@ def _masked_softmax(F, att_score, mask, dtype):
         Shape (batch_size, query_length, memory_length)
     Returns
     -------
-    att_weights : Symborl or NDArray
+    att_weights : Symbol or NDArray
         Shape (batch_size, query_length, memory_length)
     """
     if mask is not None:
         # Fill in the masked scores with a very small value
-        neg = -1e18
-        if np.dtype(dtype) == np.float16:
-            neg = -1e4
-        else:
-            try:
-                # if AMP (automatic mixed precision) is enabled, -1e18 will cause NaN.
-                from mxnet.contrib import amp
-                if amp.amp._amp_initialized:
-                    neg = -1e4
-            except ImportError:
-                pass
-        att_score = F.where(mask, att_score, neg * F.ones_like(att_score))
+        att_score = _apply_mask(F, att_score, mask, dtype)
         att_weights = F.softmax(att_score, axis=-1) * mask
     else:
         att_weights = F.softmax(att_score, axis=-1)
@@ -353,14 +374,23 @@ def hybrid_forward(self, F, x, g, v):  # pylint: disable=arguments-differ
                                                  weight_initializer=weight_initializer,
                                                  prefix='score_')

-    def _compute_weight(self, F, query, key, mask=None):
+    def _compute_score(self, F, query, key, mask=None):
         mapped_query = self._query_mid_layer(query)
         mapped_key = self._key_mid_layer(key)
         mid_feat = F.broadcast_add(F.expand_dims(mapped_query, axis=2),
                                    F.expand_dims(mapped_key, axis=1))
         mid_feat = self._act(mid_feat)
         att_score = self._attention_score(mid_feat).reshape(shape=(0, 0, 0))
-        att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask, self._dtype))
+        if mask is not None:
+            att_score = _apply_mask(F, att_score, mask, self._dtype)
+        return att_score
+
+    def _compute_weight(self, F, query, key, mask=None):
+        att_score = self._compute_score(F, query, key, mask)
+        att_weights = F.softmax(att_score, axis=-1)
+        if mask is not None:
+            att_weights = att_weights * mask
+        att_weights = self._dropout_layer(att_weights)
         return att_weights
@@ -449,7 +479,7 @@ def __init__(self, units=None, luong_style=False, scaled=True, normalized=False,
             with self.name_scope():
                 self._l2_norm = L2Normalization(axis=-1)

-    def _compute_weight(self, F, query, key, mask=None):
+    def _compute_score(self, F, query, key, mask=None):
         if self._units is not None:
             query = self._proj_query(query)
             if not self._luong_style:
@@ -466,6 +496,14 @@ def _compute_weight(self, F, query, key, mask=None):
             query = F.contrib.div_sqrt_dim(query)

         att_score = F.batch_dot(query, key, transpose_b=True)
+        if mask is not None:
+            att_score = _apply_mask(F, att_score, mask, self._dtype)
+        return att_score

-        att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask, self._dtype))
+    def _compute_weight(self, F, query, key, mask=None):
+        att_score = self._compute_score(F, query, key, mask)
+        att_weights = F.softmax(att_score, axis=-1)
+        if mask is not None:
+            att_weights = att_weights * mask
+        att_weights = self._dropout_layer(att_weights)
         return att_weights
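For readers who want to see the masking trick in isolation, here is a minimal, self-contained sketch (not part of the patch) that mirrors what `_apply_mask` followed by the softmax in `_compute_weight` does, using plain `mx.nd` operators. The toy shapes and the hand-built padding mask below are invented purely for illustration.

```python
# Minimal sketch of the masked-softmax path, mirroring _apply_mask + _compute_weight.
# All shapes and the padding pattern below are made up for demonstration.
import mxnet as mx
import numpy as np

batch_size, query_length, memory_length, units = 2, 3, 4, 8
query = mx.nd.random.normal(shape=(batch_size, query_length, units))
key = mx.nd.random.normal(shape=(batch_size, memory_length, units))

# mask: 1 where the memory position is valid, 0 where it is padding
np_mask = np.ones((batch_size, query_length, memory_length), dtype='float32')
np_mask[0, :, 3:] = 0   # sample 0: last memory slot is padding
np_mask[1, :, 2:] = 0   # sample 1: last two memory slots are padding
mask = mx.nd.array(np_mask)

# "_compute_score": raw dot-product scores with masked positions pushed to -1e18
att_score = mx.nd.batch_dot(query, key, transpose_b=True)   # (batch, query_len, mem_len)
neg = -1e18                                                 # -1e4 under float16 / AMP
att_score = mx.nd.where(mask, att_score, neg * mx.nd.ones_like(att_score))

# "_compute_weight": softmax over the masked scores, multiplied by the mask again
# so that padded positions come out exactly zero
att_weights = mx.nd.softmax(att_score, axis=-1) * mask
print(att_weights[0, 0])   # weights over 4 memory slots; the last one is ~0
```

The dtype-dependent constant matters because float16 cannot represent -1e18 (its maximum magnitude is roughly 6.5e4), so when the dtype is float16, or when AMP has cast the computation to float16, `_apply_mask` falls back to -1e4 to avoid NaN in the softmax.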
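The `_compute_score` / `_compute_weight` split does not change the public behaviour of the cells; callers still go through the usual `cell(query, key, value, mask)` calling convention. A hedged usage sketch, assuming a standard GluonNLP install and reusing the toy `query`, `key`, and `mask` tensors from the sketch above (the call to the private `_compute_score` is shown only to illustrate what the split newly exposes):

```python
# Hedged usage sketch: exercises the refactored DotProductAttentionCell path.
# Reuses query/key/mask from the previous sketch; the cell configuration is arbitrary.
import mxnet as mx
from gluonnlp.model.attention_cell import DotProductAttentionCell

cell = DotProductAttentionCell(scaled=True, dropout=0.0)
cell.initialize()

# Public behaviour is unchanged: a context vector plus normalized attention weights.
context_vec, att_weights = cell(query, key, key, mask)
print(context_vec.shape, att_weights.shape)   # (2, 3, 8) (2, 3, 4)

# The refactor additionally exposes the pre-softmax scores via _compute_score
# (a private helper; masked entries are filled with the large negative constant).
raw_scores = cell._compute_score(mx.nd, query, key, mask)
print(raw_scores.shape)                       # (2, 3, 4)
```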