
Commit ca071a6
Hybridize convolutional encoder (#186)
leezu authored and szha committed Jul 10, 2018
1 parent fa9dfc1 commit ca071a6
Showing 2 changed files with 52 additions and 27 deletions.
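This commit converts ConvolutionalEncoder from an imperative gluon.Block into a gluon.HybridBlock, so callers can hybridize the encoder (cache a symbolic graph) for faster execution and export. The pattern is the standard Gluon one: inherit from HybridBlock, implement hybrid_forward(self, F, ...) instead of forward, and route all tensor ops through F so they work with both mx.nd and mx.sym. A minimal sketch of that pattern, using a hypothetical ToyBlock that is not part of this commit:

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

class ToyBlock(gluon.HybridBlock):
    """Hypothetical example of the Block -> HybridBlock pattern."""
    def __init__(self, **kwargs):
        super(ToyBlock, self).__init__(**kwargs)
        with self.name_scope():
            self.dense = nn.Dense(4)

    def hybrid_forward(self, F, x):  # F is mx.nd imperatively, mx.sym once hybridized
        return F.relu(self.dense(x))

net = ToyBlock()
net.initialize()
net.hybridize()               # later calls run through the cached symbolic graph
out = net(mx.nd.ones((2, 3))) # shape (2, 4)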
34 changes: 15 additions & 19 deletions gluonnlp/model/convolutional_encoder.py
@@ -24,14 +24,14 @@
 
 __all__ = ['ConvolutionalEncoder']
 
-from mxnet import gluon, nd
+from mxnet import gluon
 from mxnet.gluon import nn
 from gluonnlp.initializer import HighwayBias
 
 from .highway import Highway
 
 
-class ConvolutionalEncoder(gluon.Block):
+class ConvolutionalEncoder(gluon.HybridBlock):
     r"""Convolutional encoder.
 
     We implement the convolutional encoder proposed in the following work::
@@ -104,18 +104,20 @@ def __init__(self,
         self._output_size = output_size
 
         with self.name_scope():
-            self._convs = nn.HybridSequential()
+            self._convs = gluon.contrib.nn.HybridConcurrent()
             maxpool_output_size = 0
             with self._convs.name_scope():
                 for num_filter, ngram_size in zip(self._num_filters, self._ngram_filter_sizes):
-                    self._convs.add(nn.Conv1D(in_channels=self._embed_size,
-                                              channels=num_filter,
-                                              kernel_size=ngram_size,
-                                              use_bias=True))
+                    seq = nn.HybridSequential()
+                    seq.add(nn.Conv1D(in_channels=self._embed_size,
+                                      channels=num_filter,
+                                      kernel_size=ngram_size,
+                                      use_bias=True))
+                    seq.add(gluon.nn.HybridLambda(lambda F, x: F.max(x, axis=2)))
+                    seq.add(nn.Activation(conv_layer_activation))
+                    self._convs.add(seq)
                     maxpool_output_size += num_filter
 
-            self._activation = nn.Activation(conv_layer_activation)
-
             if self._num_highway:
                 self._highways = Highway(maxpool_output_size,
                                          self._num_highway,
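For context: gluon.contrib.nn.HybridConcurrent feeds the same input to every child block and concatenates the outputs, by default along the last axis. Since each per-ngram branch above ends in a max-over-time reduction that leaves a (batch_size, num_filter) output, that concatenation reproduces the old nd.concat(..., dim=1) behavior while remaining hybridizable. A small standalone sketch of the same branch structure (toy shapes, not the encoder's defaults):

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

branches = gluon.contrib.nn.HybridConcurrent()  # concatenates child outputs (last axis by default)
with branches.name_scope():
    for channels, kernel in [(1, 1), (1, 2)]:
        seq = nn.HybridSequential()
        seq.add(nn.Conv1D(in_channels=2, channels=channels, kernel_size=kernel))
        seq.add(gluon.nn.HybridLambda(lambda F, x: F.max(x, axis=2)))  # max over the time axis
        seq.add(nn.Activation('tanh'))
        branches.add(seq)

branches.initialize()
branches.hybridize()
x = mx.nd.random.uniform(shape=(3, 2, 5))  # (batch_size, embed_size, length)
print(branches(x).shape)                   # (3, 2): one channel from each branch, concatenated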
@@ -131,7 +133,7 @@ def __init__(self,
                 self._projection = None
                 self._output_size = maxpool_output_size
 
-    def forward(self, inputs, mask=None): # pylint: disable=arguments-differ
+    def hybrid_forward(self, F, inputs, mask=None): # pylint: disable=arguments-differ
         r"""
         Forward computation for char_encoder
@@ -150,17 +152,11 @@ def forward(self, inputs, mask=None): # pylint: disable=arguments-differ
         """
         if mask is not None:
-            inputs = inputs * mask.expand_dims(-1)
-
-        inputs = nd.transpose(inputs, axes=(1, 2, 0))
+            inputs = F.broadcast_mul(inputs, mask.expand_dims(-1))
 
-        filter_outputs = []
-        for conv in self._convs:
-            filter_outputs.append(
-                self._activation(conv(inputs).max(axis=2))
-            )
+        inputs = F.transpose(inputs, axes=(1, 2, 0))
 
-        output = nd.concat(*filter_outputs, dim=1) if len(filter_outputs) > 1 else filter_outputs[0]
+        output = self._convs(inputs)
 
         if self._highways:
             output = self._highways(output)
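With the block hybridized, typical usage looks like the updated unit tests below: the encoder takes inputs of shape (seq_length, batch_size, embed_size) plus an optional (seq_length, batch_size) mask, and returns (batch_size, output_size). A short sketch using the parameter values from the tests:

import mxnet as mx
from gluonnlp.model import ConvolutionalEncoder

encoder = ConvolutionalEncoder(embed_size=7,
                               num_filters=(1, 1, 2, 3),
                               ngram_filter_sizes=(1, 2, 3, 4),
                               output_size=30)
encoder.initialize()
encoder.hybridize()  # possible now that the encoder is a HybridBlock

inputs = mx.nd.random.uniform(shape=(4, 8, 7))  # (seq_length, batch_size, embed_size)
mask = mx.nd.ones(inputs.shape[:-1])            # (seq_length, batch_size)
print(encoder(inputs, mask).shape)              # (8, 30): (batch_size, output_size)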
45 changes: 37 additions & 8 deletions tests/unittest/test_convolutional_encoder.py
@@ -18,17 +18,25 @@
 # under the License.
 
 import mxnet as mx
+import pytest
 from numpy.testing import assert_almost_equal
 
 from gluonnlp.model import ConvolutionalEncoder
 
 
-def test_conv_encoder_nonhighway_forward():
+@pytest.mark.parametrize('hybridize', [True, False])
+@pytest.mark.parametrize('mask', [True, False])
+def test_conv_encoder_nonhighway_forward(hybridize, mask):
     encoder = ConvolutionalEncoder(embed_size=2, num_filters=(1, 1), ngram_filter_sizes=(1, 2))
     print(encoder)
     encoder.initialize(init='One')
+    if hybridize:
+        encoder.hybridize()
     inputs = mx.nd.array([[[.7, .8], [.1, 1.5], [.2, .3]], [[.5, .6], [.2, 2.5], [.4, 4]]])
-    output = encoder(inputs, None)
+    if mask:
+        output = encoder(inputs, mx.nd.ones(inputs.shape[:-1]))
+    else:
+        output = encoder(inputs)
     assert output.shape == (3, 2), output.shape
     assert_almost_equal(output.asnumpy(),
                         mx.nd.array([[1.37, 1.42],
@@ -37,36 +45,57 @@ def test_conv_encoder_nonhighway_forward():
                         decimal=2)
 
 
-def test_conv_encoder_nohighway_forward_largeinputs():
+@pytest.mark.parametrize('hybridize', [True, False])
+@pytest.mark.parametrize('mask', [True, False])
+def test_conv_encoder_nohighway_forward_largeinputs(hybridize, mask):
     encoder = ConvolutionalEncoder(embed_size=7,
                                    num_filters=(1, 1, 2, 3),
                                    ngram_filter_sizes=(1, 2, 3, 4),
                                    output_size=30)
     print(encoder)
     encoder.initialize()
+    if hybridize:
+        encoder.hybridize()
     inputs = mx.nd.random.uniform(shape=(4, 8, 7))
-    output = encoder(inputs, None)
+    if mask:
+        output = encoder(inputs, mx.nd.ones(inputs.shape[:-1]))
+    else:
+        output = encoder(inputs)
     assert output.shape == (8, 30), output.shape
 
 
-def test_conv_encoder_highway_forward():
+@pytest.mark.parametrize('hybridize', [True, False])
+@pytest.mark.parametrize('mask', [True, False])
+def test_conv_encoder_highway_forward(hybridize, mask):
     encoder = ConvolutionalEncoder(embed_size=2,
                                    num_filters=(2, 1),
                                    ngram_filter_sizes=(1, 2),
                                    num_highway=2,
                                    output_size=1)
     print(encoder)
     encoder.initialize()
+    if hybridize:
+        encoder.hybridize()
     inputs = mx.nd.array([[[.7, .8], [.1, 1.5], [.7, .8]], [[.7, .8], [.1, 1.5], [.7, .8]]])
-    output = encoder(inputs, None)
+    if mask:
+        output = encoder(inputs, mx.nd.ones(inputs.shape[:-1]))
+    else:
+        output = encoder(inputs)
     print(output)
     assert output.shape == (3, 1), output.shape
 
 
-def test_conv_encoder_highway_default_forward():
+@pytest.mark.parametrize('hybridize', [True, False])
+@pytest.mark.parametrize('mask', [True, False])
+def test_conv_encoder_highway_default_forward(hybridize, mask):
    encoder = ConvolutionalEncoder()
    encoder.initialize(init='One')
+    if hybridize:
+        encoder.hybridize()
    print(encoder)
    inputs = mx.nd.random.uniform(shape=(10, 20, 15))
-    output = encoder(inputs, None)
+    if mask:
+        output = encoder(inputs, mx.nd.ones(inputs.shape[:-1]))
+    else:
+        output = encoder(inputs)
    assert output.shape == (20, 525), output.shape
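The parametrized tests above assert output shapes (and, in the first test, values). A related sanity check that is not part of the commit would be to confirm the hybridized encoder reproduces the imperative result; a small standalone sketch, reusing the first test's fixed inputs and init='One' weights:

import mxnet as mx
from numpy.testing import assert_almost_equal
from gluonnlp.model import ConvolutionalEncoder

def make_encoder():
    enc = ConvolutionalEncoder(embed_size=2, num_filters=(1, 1), ngram_filter_sizes=(1, 2))
    enc.initialize(init='One')  # deterministic weights, so both copies are identical
    return enc

inputs = mx.nd.array([[[.7, .8], [.1, 1.5], [.2, .3]],
                      [[.5, .6], [.2, 2.5], [.4, 4]]])

imperative = make_encoder()
hybrid = make_encoder()
hybrid.hybridize()

# The cached-graph path should match the imperative path to numerical precision.
assert_almost_equal(imperative(inputs).asnumpy(), hybrid(inputs).asnumpy(), decimal=6)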
