Add backward implementation for LSTM operator. #5115
Changes from 8 commits
@@ -21,17 +21,20 @@ class LSTMOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Input"),
                    "Input(Input) of LSTM should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                    "Output(Hidden) of LSTM should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Cell"),
                    "Output(Cell) of LSTM should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
                    "Output(BatchGate) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
+                   "Output(BatchCellPreAct) of LSTM should not be null.");

-    auto x_dims = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
+    auto in_dims = ctx->GetInputDim("Input");
+    PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");

     if (ctx->HasInput("H0")) {
       PADDLE_ENFORCE(ctx->HasInput("C0"),
@@ -44,7 +47,7 @@ class LSTMOp : public framework::OperatorWithKernel {
                      "should be the same.");
     }

-    int frame_size = x_dims[1] / 4;
+    int frame_size = in_dims[1] / 4;
     auto w_dims = ctx->GetInputDim("Weight");
     PADDLE_ENFORCE_EQ(w_dims.size(), 2,
                       "The rank of Input(Weight) should be 2.");
@@ -71,12 +74,21 @@ class LSTMOp : public framework::OperatorWithKernel {
                      "4 * %d if disable peepholes connection",
                      frame_size);
     }
-    ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
-    ctx->SetOutputDim("Cell", {x_dims[0], frame_size});
-    ctx->SetOutputDim("BatchGate", x_dims);
+    framework::DDim out_dims({in_dims[0], frame_size});
+    ctx->SetOutputDim("Hidden", out_dims);
+    ctx->SetOutputDim("Cell", out_dims);
+    ctx->SetOutputDim("BatchGate", in_dims);
+    ctx->SetOutputDim("BatchCellPreAct", out_dims);
     ctx->ShareLoD("Input", "Hidden");
     ctx->ShareLoD("Input", "Cell");
   }
+
+ protected:
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(
+        ctx.Input<framework::LoDTensor>("Input")->type());
+  }
 };
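For orientation, here is a quick shape walk-through of the revised InferShape. The concrete sizes (T = 6 total time steps, hidden size D = 32) are illustrative assumptions, not values from the PR:

    // Illustrative sizes only (assumed): T = 6, D = 32.
    //   Input           : {6, 128}   T x 4D, four gate pre-activations per step
    //   frame_size      : 128 / 4 = 32, i.e. D recovered from the input width
    //   Weight          : {32, 128}  D x 4D hidden-to-hidden weights
    //   Hidden, Cell    : {6, 32}    out_dims = {in_dims[0], frame_size}
    //   BatchGate       : {6, 128}   same dims as Input (in_dims)
    //   BatchCellPreAct : {6, 32}    out_dims, cached for the backward pass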
@@ -86,16 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Input",
              "(LoDTensor) the first input is a LodTensor, which support "
              "variable-time length input sequence. The underlying tensor in "
-             "this LoDTensor is a matrix with shape (T X 4D), where, T is the "
+             "this LoDTensor is a matrix with shape (T X 4D), where T is the "
              "total time steps in this mini-batch, D is the hidden size.");
     AddInput("H0",
              "(Tensor, optional) the initial hidden state is an optional "
              "input. This is a tensor with shape (N x D), where N is the "
-             "batch size, D is the hidden size.");
+             "batch size, D is the hidden size.")
+        .AsDispensable();
     AddInput("C0",
              "(Tensor, optional) the initial cell state is an optional "
              "input. This is a tensor with shape (N x D), where N is the "
-             "batch size. `H0` and `C0` can be NULL but only at the same time");
+             "batch size. `H0` and `C0` can be NULL but only at the same time")
+        .AsDispensable();
     AddInput("Weight",
              "(Tensor) the learnable hidden-hidden weights."
              " - The shape is (D x 4D), where D is the hidden size. "
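A note on the new .AsDispensable() calls (my reading of the op registry semantics, not something this diff states): a dispensable input may be absent at run time, so shape inference has to guard on the input before touching its dims, which is exactly what the H0/C0 branch in the first hunk does:

    // Sketch of the guard pattern dispensable inputs require
    // (hypothetical condensed excerpt, not a line-for-line quote of the PR):
    if (ctx->HasInput("H0")) {             // H0 is dispensable and may be absent
      PADDLE_ENFORCE(ctx->HasInput("C0"),  // H0 and C0 only make sense together
                     "Input(C0) and Input(H0) should be both null or both set.");
      // ... validate that the H0/C0 dims agree with frame_size ...
    }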
@@ -109,22 +123,27 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
              " - Bias = {b_c, b_i, b_f, b_o}."
              "2. `usePeepholes = True` "
              " - The shape is (1 x 7D). "
-             " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
+             " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.")
+        .AsDispensable();
-    AddOutput("Hidden",
-              "(LoDTensor) the hidden state lod tensor of LSTM operator. "
-              "The shape and lod is the same with the `Input`.");
-    AddOutput("Cell",
-              "(LoDTensor) the cell state lod tensor of LSTM operator. "
-              "The shape and lod is the same with the `Input`.");
     AddOutput("BatchGate",
               "(LoDTensor) This LoDTensor contains input gate, forget gate "
               "and output gate after the nonlinear computation. This "
               "LoDTensor has the same shape with the reorganized input, which "
-              "was also be called batch input. The LoD size is 2. The first "
+              "is also called batch input. The LoD size is 2. The first "
               "LoD is the batch offsets and the second LoD contains the "
               "indexes, which denote the position of reorganized sequence "
               "in the raw input.")
         .AsIntermediate();
+    AddOutput("Hidden",
+              "(LoDTensor) the hidden state lod tensor of LSTM operator. "
+              "The shape and lod is the same with the `Input`.");
+    AddOutput("Cell",
+              "(LoDTensor) the cell state lod tensor of LSTM operator. "
+              "The shape and lod is the same with the `Input`.");

[Review, on the `Cell` output description] Same as above.
[Author] Done.

+    AddOutput("BatchCellPreAct",
+              "(LoDTensor) This LoDTensor is get in the forward and used "
+              "in the backward.")
+        .AsIntermediate();

[Review, on the `BatchCellPreAct` description] get -> got
[Author] Done.

     AddAttr<bool>("usePeepholes",
                   "(bool, default: True) "
                   "whether to enable diagonal/peephole connections.")
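For reference, the layout described above (a 4D-wide Input and Weight, and an optional 7D bias) matches the standard peephole LSTM; this background is my summary, not text from the PR. Writing $z_{i,t}, z_{f,t}, z_{c,t}, z_{o,t}$ for the four D-wide slices of row $t$ of Input (the precomputed input projections):

    \begin{aligned}
    i_t &= \sigma\!\left(z_{i,t} + h_{t-1} W_i + W_{ic} \odot c_{t-1} + b_i\right) \\
    f_t &= \sigma\!\left(z_{f,t} + h_{t-1} W_f + W_{fc} \odot c_{t-1} + b_f\right) \\
    \tilde{c}_t &= \tanh\!\left(z_{c,t} + h_{t-1} W_c + b_c\right) \\
    c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
    o_t &= \sigma\!\left(z_{o,t} + h_{t-1} W_o + W_{oc} \odot c_t + b_o\right) \\
    h_t &= o_t \odot \tanh(c_t)
    \end{aligned}

The four $z$ blocks account for the 4D width of Input and Weight, and the peephole vectors $W_{ic}, W_{fc}, W_{oc}$ account for the extra 3D in the 7D bias when usePeepholes is true.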
@@ -202,15 +221,37 @@ class LSTMGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
-                   "Input(Hidden@GRAD) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
-                   "Input(Cell@GRAD) should not be null");
-    ctx->SetOutputDim(framework::GradVarName("Weight"),
-                      ctx->GetInputDim("Weight"));
-    ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
+    PADDLE_ENFORCE(ctx->HasInput("Input"),
+                   "Input(Input) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Hidden"),
+                   "Input(Hidden) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Cell"),
+                   "Input(Cell) of LSTM should not be null.");
+
+    PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
+                   "Input(BatchGate) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
+                   "Input(BatchCellPreAct) of LSTM should not be null.");
+
+    auto in_g_name = framework::GradVarName("Input");
+    if (ctx->HasOutput(in_g_name))
+      ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input"));
+
+    auto w_g_name = framework::GradVarName("Weight");
+    if (ctx->HasOutput(w_g_name))
+      ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight"));
+
+    auto b_g_name = framework::GradVarName("Bias");
+    if (ctx->HasOutput(b_g_name))
+      ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
   }

+ protected:
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(
+        ctx.Input<framework::LoDTensor>("Input")->type());
+  }
 };
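As background for what the new LSTMGradOp has to compute (the standard peephole-LSTM backward recurrences; my summary, not text from the PR), with $\delta x$ denoting $\partial L / \partial x$ and a tilde marking gate pre-activations:

    \begin{aligned}
    \delta \tilde{o}_t &= \delta h_t \odot \tanh(c_t) \odot o_t (1 - o_t) \\
    \delta c_t &\mathrel{+}= \delta h_t \odot o_t \odot \left(1 - \tanh^2(c_t)\right) + W_{oc} \odot \delta \tilde{o}_t \\
    \delta \tilde{f}_t &= \delta c_t \odot c_{t-1} \odot f_t (1 - f_t) \\
    \delta \tilde{i}_t &= \delta c_t \odot \tilde{c}_t \odot i_t (1 - i_t) \\
    \delta \tilde{z}_{c,t} &= \delta c_t \odot i_t \odot \left(1 - \tilde{c}_t^2\right) \\
    \delta c_{t-1} &= \delta c_t \odot f_t + W_{fc} \odot \delta \tilde{f}_t + W_{ic} \odot \delta \tilde{i}_t
    \end{aligned}

The gate activations and cell pre-activations these formulas consume are exactly what BatchGate and BatchCellPreAct cache in the forward pass, which is why the gradient op now enforces both as inputs rather than only the output gradients.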
[Review, on the `Hidden` output description] In "the hidden state lod tensor of LSTM operator", isn't the "lod tensor" in the middle redundant?
[Author] Done. Removed "lod tensor" and fixed the shape info.