Update TxT model to the latest gluon API #2802

Merged
merged 1 commit on Aug 31, 2020
101 changes: 55 additions & 46 deletions pyzoo/zoo/models/recommendation/txt.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 #

+from mxnet import gluon
 from gluonnlp.model.transformer import TransformerEncoder, TransformerEncoderCell
-from mxnet.gluon.block import HybridBlock


-class MeanMaxPooling(HybridBlock):
-    def __init__(self, axis=1, dropout=0.0, prefix=None, params=None):
-        super().__init__(prefix=prefix, params=params)
+class MeanMaxPooling(gluon.nn.HybridBlock):
+    def __init__(self, axis=1, dropout=0.0, prefix=None, params=None, **kwargs):
+        super(MeanMaxPooling, self).__init__(**kwargs)
+        # super().__init__(prefix=prefix, params=params)
         self.axis = axis
         self.dropout = dropout

@@ -30,43 +31,43 @@ def hybrid_forward(self, F, inputs):
         outputs = F.concat(mean_out, max_out, dim=1)
         if self.dropout:
             outputs = F.Dropout(data=outputs, p=self.dropout)
-        outputs = F.LayerNorm(data=outputs)
+        # outputs = F.LayerNorm(outputs)
         return outputs


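Review note: a minimal sketch (not part of this PR) of what the updated MeanMaxPooling block computes on a toy batch. It assumes MXNet 1.x and that the unchanged lines above this hunk compute mean_out and max_out by reducing over self.axis; shapes and values are illustrative only.

import mxnet as mx

pool = MeanMaxPooling(axis=1, dropout=0.0)
pool.initialize()
x = mx.nd.random.uniform(shape=(2, 8, 4))  # (batch, seq_len, units)
y = pool(x)                                # mean and max over the seq axis, concatenated
print(y.shape)                             # (2, 8): units doubled, seq axis reduced away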
-class SequenceTransformer(HybridBlock):
+class SequenceTransformer(gluon.nn.HybridBlock):
     def __init__(self, num_items, item_embed, item_hidden_size, item_max_length, item_num_heads,
                  item_num_layers, item_transformer_dropout, item_pooling_dropout, cross_size,
-                 prefix=None,
-                 params=None):
-        super().__init__(prefix=prefix, params=params)
-        self.num_items = num_items
-        self.item_embed = item_embed
-        self.cross_size = cross_size
+                 prefix=None, params=None, **kwargs):
+        super(SequenceTransformer, self).__init__(**kwargs)
+        # super().__init__(prefix=prefix, params=params)
         with self.name_scope():
             self.item_pooling_dp = MeanMaxPooling(dropout=item_pooling_dropout)
-            self.item_encoder = TransformerEncoder(units=item_embed, hidden_size=item_hidden_size,
+            self.item_encoder = TransformerEncoder(units=item_embed,
+                                                   hidden_size=item_hidden_size,
                                                    num_heads=item_num_heads,
                                                    num_layers=item_num_layers,
                                                    max_length=item_max_length,
                                                    dropout=item_transformer_dropout)
+            self.embedding = gluon.nn.Embedding(input_dim=num_items, output_dim=item_embed)
+            self.dense = gluon.nn.Dense(cross_size)

     def hybrid_forward(self, F, input_item, item_valid_length=None):
-        item_embed_out = F.Embedding(data=input_item, input_dim=self.num_items,
-                                     output_dim=self.item_embed)
-        item_encoding, item_att = self.item_encoder.hybrid_forward(F, inputs=item_embed_out,
-                                                                   valid_length=item_valid_length)
-        item_out = self.item_pooling_dp.hybrid_forward(F, inputs=item_encoding)
-        item_out = F.FullyConnected(data=item_out, num_hidden=self.cross_size)
+        item_embed_out = self.embedding(input_item)
+        item_encoding, item_att = self.item_encoder(
+            inputs=item_embed_out, valid_length=item_valid_length)
+        item_out = self.item_pooling_dp(item_encoding)
+        item_out = self.dense(item_out)

         return item_out


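Review note: the substance of this hunk is the move from symbolic F.Embedding / F.FullyConnected ops and hand-written .hybrid_forward(F, ...) calls to child blocks registered in __init__ and called directly, so Gluon owns and initializes their parameters. A hedged usage sketch (not in the PR); every argument value below is hypothetical:

import mxnet as mx

seq = SequenceTransformer(num_items=1000, item_embed=100, item_hidden_size=256,
                          item_max_length=8, item_num_heads=4, item_num_layers=2,
                          item_transformer_dropout=0.0, item_pooling_dropout=0.1,
                          cross_size=100)
seq.initialize()
items = mx.nd.array([[1, 2, 3, 4, 0, 0, 0, 0]])  # (batch=1, seq_len=8), zero-padded
valid = mx.nd.array([4.0])                       # valid (unpadded) length per row
item_out = seq(items, valid)                     # -> (1, cross_size)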
-class ContextTransformer(HybridBlock):
+class ContextTransformer(gluon.nn.HybridBlock):
     def __init__(self, context_dims, context_embed, context_hidden_size,
                  context_num_heads, context_transformer_dropout, context_pooling_dropout,
-                 cross_size, prefix=None, params=None):
-        super().__init__(prefix=prefix, params=params)
+                 cross_size, prefix=None, params=None, **kwargs):
+        super(ContextTransformer, self).__init__(**kwargs)
+        # super().__init__(prefix=prefix, params=params)
         self.context_dims = context_dims
         self.context_embed = context_embed
         self.cross_size = cross_size
@@ -77,36 +78,37 @@ def __init__(self, context_dims, context_embed, context_hidden_size,
                 num_heads=context_num_heads,
                 dropout=context_transformer_dropout
             )
+            self.dense = gluon.nn.Dense(self.cross_size)
+            self.embeddings = gluon.nn.HybridSequential()
+            for i, context_dim in enumerate(self.context_dims):
+                self.embeddings.add(gluon.nn.Embedding(self.context_dims[i], self.context_embed))

     def hybrid_forward(self, F, input_context_list):
-        context_embed = [F.Embedding(data=input_context_list[i], input_dim=self.context_dims[i],
-                                     output_dim=self.context_embed)
-                         for i, context_dim in enumerate(self.context_dims)]
+        context_embed = [
+            self.embeddings[i](input_context) for i, input_context in enumerate(input_context_list)]
         context_input = []
         for i in context_embed:
             context_input.append(F.expand_dims(i, axis=1))
         context_embedding = F.concat(*context_input, dim=1)
-        context_encoding, context_att = self.context_encoder. \
-            hybrid_forward(F, inputs=context_embedding)
-        context_out = self.context_pooling_dp.hybrid_forward(F, inputs=context_encoding)
-        context_out = F.FullyConnected(data=context_out, num_hidden=self.cross_size)
+        context_encoding, context_att = self.context_encoder(context_embedding)
+        context_out = self.context_pooling_dp(context_encoding)
+        context_out = self.dense(context_out)

         return context_out


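Review note: each context feature now gets its own gluon.nn.Embedding, collected in a HybridSequential so their parameters are registered, and hybrid_forward indexes into it per input. A sketch under assumed dimensions (not in the PR); it also assumes the hidden constructor lines pass units=context_embed to TransformerEncoder, mirroring the item encoder:

import mxnet as mx

ctx_net = ContextTransformer(context_dims=[7, 24], context_embed=100,
                             context_hidden_size=256, context_num_heads=2,
                             context_transformer_dropout=0.0,
                             context_pooling_dropout=0.0, cross_size=100)
ctx_net.initialize()
weekday = mx.nd.array([3])          # hypothetical categorical feature in [0, 7)
hour = mx.nd.array([15])            # hypothetical categorical feature in [0, 24)
ctx_out = ctx_net([weekday, hour])  # -> (1, cross_size)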
-class TxT(HybridBlock):
+class TxT(gluon.nn.HybridBlock):
     def __init__(self, num_items, context_dims, item_embed=100, context_embed=100,
                  item_hidden_size=256, item_max_length=8, item_num_heads=4, item_num_layers=2,
                  item_transformer_dropout=0.0, item_pooling_dropout=0.1, context_hidden_size=256,
-                 context_num_heads=2, context_transformer_dropout=0.0,
-                 context_pooling_dropout=0.0, act_type="relu", cross_size=100,
-                 prefix=None, params=None):
-        super().__init__(prefix=prefix, params=params)
-        self.num_items = num_items
+                 context_num_heads=2, context_transformer_dropout=0.0, context_pooling_dropout=0.0,
+                 act_type="gelu", cross_size=100, prefix=None, params=None, **kwargs):
+        super(TxT, self).__init__(**kwargs)
         self.act_type = act_type
         with self.name_scope():
             self.sequence_transformer = SequenceTransformer(
-                num_items=num_items, item_embed=item_embed,
+                num_items=num_items,
+                item_embed=item_embed,
                 item_hidden_size=item_hidden_size,
                 item_max_length=item_max_length,
                 item_num_heads=item_num_heads,
@@ -126,17 +128,24 @@ def __init__(self, num_items, context_dims, item_embed=100, context_embed=100,
                 cross_size=cross_size,
                 prefix=prefix, params=params
             )
+            self.dense1 = gluon.nn.Dense(units=num_items//2)
+            if act_type == "relu":
+                self.act = gluon.nn.Activation(activation="relu")
+            elif act_type == "gelu":
+                self.act = gluon.nn.GELU()
+            elif act_type == "leakyRelu":
+                self.act = gluon.nn.LeakyReLU(alpha=0.2)
+            else:
+                raise NotImplementedError
+            self.dense2 = gluon.nn.Dense(units=num_items, activation=None)

-    def hybrid_forward(self, F, input_item, input_context_list,
-                       label, item_valid_length=None):
-        item_outs = self.sequence_transformer.hybrid_forward(F, input_item=input_item,
-                                                             item_valid_length=item_valid_length)
-        context_outs = self.context_transformer.hybrid_forward(
-            F, input_context_list=input_context_list
-        )
+    def hybrid_forward(self, F, input_item, item_valid_length, input_context_list):
+        item_outs = self.sequence_transformer(input_item, item_valid_length)
+        context_outs = self.context_transformer(input_context_list)
+
         outs = F.broadcast_mul(item_outs, context_outs)
-        outs = F.Activation(data=outs, act_type=self.act_type)
-        outs = F.FullyConnected(data=outs, num_hidden=int(self.num_items))
-        outs = F.SoftmaxOutput(data=outs, label=label)
+        outs = self.dense1(outs)
+        outs = self.act(outs)
+        outs = self.dense2(outs)

         return outs
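Review note: with F.SoftmaxOutput(label) removed from hybrid_forward, TxT now returns raw logits over num_items and the loss moves to the training loop (e.g. gluon.loss.SoftmaxCrossEntropyLoss, which applies the softmax internally). An end-to-end sketch (not in the PR) with hypothetical shapes and ids:

import mxnet as mx
from mxnet import gluon

net = TxT(num_items=1000, context_dims=[7, 24])
net.initialize()

input_item = mx.nd.array([[1, 2, 3, 4, 0, 0, 0, 0]])  # (batch=1, item_max_length=8)
item_valid_length = mx.nd.array([4.0])
input_context_list = [mx.nd.array([3]), mx.nd.array([15])]

logits = net(input_item, item_valid_length, input_context_list)  # -> (1, 1000)

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
label = mx.nd.array([42])        # hypothetical target item id
loss = loss_fn(logits, label)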