Add inference program for Transformer. #727
Conversation
Force-pushed from 38a8d43 to 60dba9f.
Force-pushed from 60dba9f to ff80721.
Thank you for this work.
@@ -15,6 +15,23 @@ class TrainTaskConfig(object):
    # the params for learning rate scheduling
    warmup_steps = 4000

    # the directory for saving inference models
for saving inference models --> for saving trained models
Done.
class InferTaskConfig(object):
    use_gpu = False
    # number of sequences contained in a mini-batch
- number of sequences contained in a mini-batch --> the number of examples in one run for sequence generation.
- Please add a comment here to warn users that currently the batch size can only be set to 1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
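For concreteness, a minimal sketch of how the config might read after this comment is addressed; only use_gpu and the original comment line are quoted in the diff, so the batch_size field name is an assumption:

```python
class InferTaskConfig(object):
    use_gpu = False
    # the number of examples in one run for sequence generation.
    # NOTE: currently the batch size can only be set to 1.
    batch_size = 1  # assumed field name, not quoted in the diff
```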
    # the params for beam search
    beam_size = 5
    max_length = 30
    n_best = 1
Please add a comment for n_best. It is confusing to me what the difference is between beam_size and n_best.
Done.
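A hedged sketch of how the requested comment might distinguish the two settings; the wording reflects the standard beam-search meaning of these parameters, not necessarily the author's final comment:

```python
class InferTaskConfig(object):
    # the params for beam search
    beam_size = 5    # number of candidate sequences kept at each decoding step
    max_length = 30  # decoding stops once a sequence reaches this length
    # n_best: how many of the beam_size finished candidates are written out
    # for each source sentence; it should satisfy 1 <= n_best <= beam_size.
    n_best = 1
```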
@@ -15,6 +15,23 @@ class TrainTaskConfig(object):
    # the params for learning rate scheduling
    warmup_steps = 4000

    # the directory for saving inference models
    model_dir = "transformer_model"
change the name to "trained_models"
Done.
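Putting both review comments on this hunk together, the result might read as follows (a sketch; the comment rewording and the rename to "trained_models" follow the suggestions above):

```python
class TrainTaskConfig(object):
    # the params for learning rate scheduling
    warmup_steps = 4000

    # the directory for saving trained models
    model_dir = "trained_models"
```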
                max_length,
                slf_attn_bias_flag,
                src_attn_bias_flag,
                pos_flag=1):
change "pos_flag" into "is_pos=True"
Done.
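A sketch of the signature after this change; the parameter list is assembled from the fragments quoted in this review, so the exact order and names before max_length are assumptions:

```python
def make_inputs(input_data_names,
                n_head,
                d_model,
                batch_size,
                max_length,
                slf_attn_bias_flag,
                src_attn_bias_flag,
                is_pos=True):
    # Boolean flags let call sites pass True/False and let the body test
    # `if slf_attn_bias_flag:` instead of comparing against 1.
    pass
```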
dtype="float32", | ||
append_batch_size=False) | ||
enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model, | ||
batch_size, max_length, 1, 0) |
0 --> False
Done.
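With the flags made boolean, this call would presumably become:

```python
enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model,
                               batch_size, max_length, True, False)
```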
dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
                               batch_size, max_length, 1, 1)
the last two 1s --> True
Done.
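Likewise for the decoder inputs, the last two arguments would become booleans:

```python
dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
                               batch_size, max_length, True, True)
```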
# Padding indices do not contribute to the total loss. The weights are used
# to cancel out padding indices when calculating the loss.
gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
                            max_length, 0, 0, 0)
Make the last three parameters of make_inputs boolean parameters.
Done.
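And the label inputs, with all three trailing flags as booleans:

```python
gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
                            max_length, False, False, False)
```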
        name=input_data_names[2]
        if slf_attn_bias_flag == 1 else input_data_names[-1],
        shape=[batch_size, n_head, max_length, max_length]
        if slf_attn_bias_flag == 1 else [batch_size, max_length, d_model],
Make src_attn_bias_flag a boolean parameter.
Done.
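With boolean flags, the `== 1` comparisons in this hunk can be dropped. A sketch under the assumption that the surrounding call is fluid's layers.data, as suggested by the dtype and append_batch_size arguments quoted earlier:

```python
import paddle.fluid.layers as layers  # assumed import, matching the fluid-era API

input_layer = layers.data(
    name=input_data_names[2]
    if slf_attn_bias_flag else input_data_names[-1],
    shape=[batch_size, n_head, max_length, max_length]
    if slf_attn_bias_flag else [batch_size, max_length, d_model],
    dtype="float32",
    append_batch_size=False)
```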
LGTM.