Standalone LSTM decoder language model (#934)
Summary:
Currently, the LSTM models in Fairseq master can only be used in an encoder/decoder setting, for example, in `class LSTMModel(FairseqEncoderDecoderModel)`. This PR adds a standalone LSTM decoder language model.

Changes:
- adds support for using `LSTMDecoder` without an encoder, signalled by `encoder_output_units=0` (see the conceptual sketch after this list).
- fixes bugs in `LSTMDecoder` that only surface when it is used standalone, for example `src_lengths` not being handled as an optional argument.
- adds `class LSTMLanguageModel(FairseqLanguageModel)` for training LSTM language models.
- adds tests for the `LSTMLanguageModel`; the changes to `LSTMDecoder` are covered by existing test cases.
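
For orientation, here is a minimal sketch in plain PyTorch (not fairseq code; the class name and sizes are illustrative) of what the standalone decoder computes when `encoder_output_units=0`: with no encoder, the recurrent state starts at zero and input feeding is disabled, so each step consumes only the embedded previous token.

```python
import torch
import torch.nn as nn

class TinyLSTMLanguageModel(nn.Module):
    """Illustrative decoder-only LSTM LM: zero initial state, no input feeding."""

    def __init__(self, vocab_size, embed_dim=512, hidden_size=512, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.layers = nn.ModuleList([
            nn.LSTMCell(embed_dim if layer == 0 else hidden_size, hidden_size)
            for layer in range(num_layers)
        ])
        self.fc_out = nn.Linear(hidden_size, vocab_size)

    def forward(self, prev_output_tokens):
        # prev_output_tokens: (batch, seq_len) token ids
        bsz, seqlen = prev_output_tokens.size()
        x = self.embed(prev_output_tokens).transpose(0, 1)  # (seq_len, batch, embed_dim)
        # no encoder: every layer starts from a zero hidden/cell state
        hiddens = [x.new_zeros(bsz, rnn.hidden_size) for rnn in self.layers]
        cells = [x.new_zeros(bsz, rnn.hidden_size) for rnn in self.layers]
        outs = []
        for j in range(seqlen):
            inp = x[j]  # no input feeding: only the embedded previous token
            for i, rnn in enumerate(self.layers):
                hiddens[i], cells[i] = rnn(inp, (hiddens[i], cells[i]))
                inp = hiddens[i]
            outs.append(inp)
        # (batch, seq_len, vocab) logits for the next token at each position
        return self.fc_out(torch.stack(outs, dim=1))

# usage: logits = TinyLSTMLanguageModel(vocab_size=1000)(torch.randint(0, 1000, (2, 16)))
```
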
Pull Request resolved: fairinternal/fairseq-py#934

Reviewed By: myleott

Differential Revision: D18816310

Pulled By: joshim5

fbshipit-source-id: 4773695a7f5d36aa773da8a45db2e02f76c968a9
joshim5 authored and facebook-github-bot committed Jan 24, 2020
1 parent 1da061f commit 9f4256e
Showing 3 changed files with 195 additions and 16 deletions.
71 changes: 55 additions & 16 deletions fairseq/models/lstm.py
@@ -17,6 +17,8 @@
)
from fairseq.modules import AdaptiveSoftmax

DEFAULT_MAX_SOURCE_POSITIONS = 1e5
DEFAULT_MAX_TARGET_POSITIONS = 1e5

@register_model('lstm')
class LSTMModel(FairseqEncoderDecoderModel):
@@ -85,6 +87,9 @@ def build_model(cls, args, task):
if args.encoder_layers != args.decoder_layers:
raise ValueError('--encoder-layers must match --decoder-layers')

max_source_positions = getattr(args, 'max_source_positions', DEFAULT_MAX_SOURCE_POSITIONS)
max_target_positions = getattr(args, 'max_target_positions', DEFAULT_MAX_TARGET_POSITIONS)

def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
@@ -149,6 +154,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
dropout_out=args.encoder_dropout_out,
bidirectional=args.encoder_bidirectional,
pretrained_embed=pretrained_encoder_embed,
max_source_positions=max_source_positions
)
decoder = LSTMDecoder(
dictionary=task.target_dictionary,
@@ -166,6 +172,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
if args.criterion == 'adaptive_loss' else None
),
max_target_positions=max_target_positions
)
return cls(encoder, decoder)

@@ -176,13 +183,15 @@ def __init__(
self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
dropout_in=0.1, dropout_out=0.1, bidirectional=False,
left_pad=True, pretrained_embed=None, padding_value=0.,
max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS
):
super().__init__(dictionary)
self.num_layers = num_layers
self.dropout_in = dropout_in
self.dropout_out = dropout_out
self.bidirectional = bidirectional
self.hidden_size = hidden_size
self.max_source_positions = max_source_positions

num_embeddings = len(dictionary)
self.padding_idx = dictionary.pad()
@@ -269,7 +278,7 @@ def reorder_encoder_out(self, encoder_out, new_order):

def max_positions(self):
"""Maximum input length supported by the encoder."""
return int(1e5) # an arbitrary large number
return self.max_source_positions


class AttentionLayer(nn.Module):
@@ -312,13 +321,15 @@ def __init__(
num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
encoder_output_units=512, pretrained_embed=None,
share_input_output_embed=False, adaptive_softmax_cutoff=None,
max_target_positions=DEFAULT_MAX_TARGET_POSITIONS
):
super().__init__(dictionary)
self.dropout_in = dropout_in
self.dropout_out = dropout_out
self.hidden_size = hidden_size
self.share_input_output_embed = share_input_output_embed
self.need_attn = True
self.max_target_positions = max_target_positions

self.adaptive_softmax = None
num_embeddings = len(dictionary)
@@ -329,14 +340,18 @@ def __init__(
self.embed_tokens = pretrained_embed

self.encoder_output_units = encoder_output_units
if encoder_output_units != hidden_size:
if encoder_output_units != hidden_size and encoder_output_units != 0:
self.encoder_hidden_proj = Linear(encoder_output_units, hidden_size)
self.encoder_cell_proj = Linear(encoder_output_units, hidden_size)
else:
self.encoder_hidden_proj = self.encoder_cell_proj = None

# disable input feeding if there is no encoder
# input feeding is described in arxiv.org/abs/1508.04025
input_feed_size = 0 if encoder_output_units == 0 else hidden_size
self.layers = nn.ModuleList([
LSTMCell(
input_size=hidden_size + embed_dim if layer == 0 else hidden_size,
input_size=input_feed_size + embed_dim if layer == 0 else hidden_size,
hidden_size=hidden_size,
)
for layer in range(num_layers)
@@ -355,7 +370,7 @@ def __init__(
elif not self.share_input_output_embed:
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
x, attn_scores = self.extract_features(
prev_output_tokens, encoder_out, incremental_state
)
@@ -367,16 +382,23 @@ def extract_features(
"""
Similar to *forward* but only return features.
"""
encoder_padding_mask = encoder_out['encoder_padding_mask']
encoder_out = encoder_out['encoder_out']
if encoder_out is not None:
encoder_padding_mask = encoder_out['encoder_padding_mask']
encoder_out = encoder_out['encoder_out']
else:
encoder_padding_mask = None
encoder_out = None

if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
bsz, seqlen = prev_output_tokens.size()

# get outputs from encoder
encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
srclen = encoder_outs.size(0)
if encoder_out is not None:
encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
srclen = encoder_outs.size(0)
else:
srclen = None

# embed tokens
x = self.embed_tokens(prev_output_tokens)
@@ -389,20 +411,33 @@ def extract_features(
cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
if cached_state is not None:
prev_hiddens, prev_cells, input_feed = cached_state
else:
elif encoder_out is not None:
# setup recurrent cells
num_layers = len(self.layers)
prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
prev_cells = [encoder_cells[i] for i in range(num_layers)]
if self.encoder_hidden_proj is not None:
prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens]
prev_cells = [self.encoder_cell_proj(x) for x in prev_cells]
input_feed = x.new_zeros(bsz, self.hidden_size)

attn_scores = x.new_zeros(srclen, seqlen, bsz)
else:
# setup zero cells, since there is no encoder
num_layers = len(self.layers)
zero_state = x.new_zeros(bsz, self.hidden_size)
prev_hiddens = [zero_state for i in range(num_layers)]
prev_cells = [zero_state for i in range(num_layers)]
input_feed = None

assert srclen is not None or self.attention is None, \
"attention is not supported if there are no encoder outputs"
attn_scores = x.new_zeros(srclen, seqlen, bsz) if self.attention is not None else None
outs = []
for j in range(seqlen):
# input feeding: concatenate context vector from previous time step
input = torch.cat((x[j, :, :], input_feed), dim=1)
if input_feed is not None:
input = torch.cat((x[j, :, :], input_feed), dim=1)
else:
input = x[j]

for i, rnn in enumerate(self.layers):
# recurrent cell
@@ -423,7 +458,8 @@ def extract_features(
out = F.dropout(out, p=self.dropout_out, training=self.training)

# input feeding
input_feed = out
if input_feed is not None:
input_feed = out

# save final output
outs.append(out)
@@ -445,7 +481,7 @@ def extract_features(
x = F.dropout(x, p=self.dropout_out, training=self.training)

# srclen x tgtlen x bsz -> bsz x tgtlen x srclen
if not self.training and self.need_attn:
if not self.training and self.need_attn and self.attention is not None:
attn_scores = attn_scores.transpose(0, 2)
else:
attn_scores = None
@@ -469,14 +505,17 @@ def reorder_incremental_state(self, incremental_state, new_order):
def reorder_state(state):
if isinstance(state, list):
return [reorder_state(state_i) for state_i in state]
return state.index_select(0, new_order)
elif state is not None:
return state.index_select(0, new_order)
else:
return None

new_state = tuple(map(reorder_state, cached_state))
utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

def max_positions(self):
"""Maximum output length supported by the decoder."""
return int(1e5) # an arbitrary large number
return self.max_target_positions

def make_generation_fast_(self, need_attn=False, **kwargs):
self.need_attn = need_attn
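A side note on the `reorder_incremental_state` change above: `reorder_state` now tolerates `None` entries, because the input-feed slot of the cached state is `None` when there is no encoder. A minimal standalone sketch of that reordering logic (plain PyTorch; the example tensors are illustrative):

```python
import torch

def reorder_state(state, new_order):
    """Reorder cached decoder state along the batch dimension.

    Beam search permutes the batch rows at every step: tensors are reindexed
    with index_select, lists are reordered recursively, and None entries
    (e.g. a disabled input-feed slot) pass through unchanged.
    """
    if isinstance(state, list):
        return [reorder_state(s, new_order) for s in state]
    elif state is not None:
        return state.index_select(0, new_order)
    return None

# Example: keep rows 2, 0, 1 of a batch of 3 cached states.
new_order = torch.tensor([2, 0, 1])
cached = ([torch.randn(3, 4), torch.randn(3, 4)], torch.randn(3, 4), None)
reordered = tuple(reorder_state(s, new_order) for s in cached)
```
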
125 changes: 125 additions & 0 deletions fairseq/models/lstm_lm.py
@@ -0,0 +1,125 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq import options, utils
from fairseq.models import (
FairseqLanguageModel, register_model, register_model_architecture
)
from fairseq.models.lstm import (
LSTMDecoder, Embedding
)

DEFAULT_MAX_TARGET_POSITIONS = 1e5

@register_model('lstm_lm')
class LSTMLanguageModel(FairseqLanguageModel):
def __init__(self, decoder):
super().__init__(decoder)

@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-hidden-size', type=int, metavar='N',
help='decoder hidden size')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='number of decoder layers')
parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
help='decoder output embedding dimension')
parser.add_argument('--decoder-attention', type=str, metavar='BOOL',
help='decoder attention')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion')

# Granular dropout settings (if not specified these default to --dropout)
parser.add_argument('--decoder-dropout-in', type=float, metavar='D',
help='dropout probability for decoder input embedding')
parser.add_argument('--decoder-dropout-out', type=float, metavar='D',
help='dropout probability for decoder output')
parser.add_argument('--share-decoder-input-output-embed', default=False,
action='store_true',
help='share decoder input and output embeddings')

@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""

# make sure all arguments are present in older models
base_architecture(args)

if getattr(args, 'max_target_positions', None) is not None:
max_target_positions = args.max_target_positions
else:
max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)

def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
embed_dict = utils.parse_embedding(embed_path)
utils.print_embed_overlap(embed_dict, dictionary)
return utils.load_embedding(embed_dict, dictionary, embed_tokens)

pretrained_decoder_embed = None
if args.decoder_embed_path:
pretrained_decoder_embed = load_pretrained_embedding_from_file(
args.decoder_embed_path,
task.target_dictionary,
args.decoder_embed_dim
)

if args.share_decoder_input_output_embed:
# double check all parameters combinations are valid
if task.source_dictionary != task.target_dictionary:
raise ValueError('--share-decoder-input-output-embeddings requires a joint dictionary')

if args.decoder_embed_dim != args.decoder_out_embed_dim:
raise ValueError(
'--share-decoder-input-output-embeddings requires '
'--decoder-embed-dim to match --decoder-out-embed-dim'
)

decoder = LSTMDecoder(
dictionary=task.dictionary,
embed_dim=args.decoder_embed_dim,
hidden_size=args.decoder_hidden_size,
out_embed_dim=args.decoder_out_embed_dim,
num_layers=args.decoder_layers,
dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out,
attention=options.eval_bool(args.decoder_attention),
encoder_output_units=0,
pretrained_embed=pretrained_decoder_embed,
share_input_output_embed=args.share_decoder_input_output_embed,
adaptive_softmax_cutoff=(
options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
if args.criterion == 'adaptive_loss' else None
),
max_target_positions=max_target_positions
)

return cls(decoder)


@register_model_architecture('lstm_lm', 'lstm_lm')
def base_architecture(args):
args.dropout = getattr(args, 'dropout', 0.1)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim)
args.decoder_layers = getattr(args, 'decoder_layers', 1)
args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
args.decoder_attention = getattr(args, 'decoder_attention', '0')
args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')
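
As a usage note, `@register_model_architecture` makes it straightforward to layer named presets on top of `base_architecture`. A hypothetical wider variant might look like the sketch below; the `lstm_lm_big` name and sizes are illustrative and not part of this PR.

```python
from fairseq.models import register_model_architecture
from fairseq.models.lstm_lm import base_architecture

# Hypothetical preset (not in this PR): a wider, two-layer LSTM LM.
@register_model_architecture('lstm_lm', 'lstm_lm_big')
def lstm_lm_big(args):
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
    args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1024)
    args.decoder_layers = getattr(args, 'decoder_layers', 2)
    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1024)
    base_architecture(args)  # fill in the remaining defaults
```
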
15 changes: 15 additions & 0 deletions tests/test_binaries.py
@@ -503,6 +503,21 @@ def test_lightconv_lm(self):
'--tokens-per-sample', '500',
])

def test_lstm_lm(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_lstm_lm') as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_language_model(
data_dir, 'lstm_lm', ['--add-bos-token'], run_validation=True,
)
eval_lm_main(data_dir)
generate_main(data_dir, [
'--task', 'language_modeling',
'--sample-break-mode', 'eos',
'--tokens-per-sample', '500',
])


class TestMaskedLanguageModel(unittest.TestCase):

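The new test exercises the full preprocess → train → eval-LM → generate pipeline on dummy data. For reference, the headline number reported by the eval-LM step is token-level perplexity; a minimal standalone sketch of that computation (plain PyTorch, not the fairseq implementation; `lm_perplexity` is an illustrative name):

```python
import math
import torch.nn.functional as F

def lm_perplexity(logits, targets, pad_idx):
    """exp of the mean per-token negative log-likelihood, ignoring padding."""
    # logits: (batch, seq_len, vocab); targets: (batch, seq_len) token ids
    nll = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=pad_idx,
        reduction='mean',
    )
    return math.exp(nll.item())
```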
