Skip to content

Commit

Permalink
Merge pull request #5115 from qingqing01/lstm_bp
Browse files Browse the repository at this point in the history
Add backward implementation for LSTM operator.
  • Loading branch information
qingqing01 authored Nov 1, 2017
2 parents 3d56786 + 7061e01 commit 36d2060
Show file tree
Hide file tree
Showing 13 changed files with 579 additions and 165 deletions.
8 changes: 4 additions & 4 deletions paddle/framework/lod_tensor_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ TEST(LoDTensor, LoDInGPU) {
lod_tensor.mutable_data<float>(place);

lod_tensor.set_lod(src_lod);
CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL);
CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL);
EXPECT_EQ(lod_tensor.lod_element(0, 2).first, 4UL);
EXPECT_EQ(lod_tensor.lod_element(0, 4).first, 8UL);

auto lod = lod_tensor.lod();

test<<<1, 8>>>(lod[0].data(), lod[0].size());
cudaDeviceSynchronize();

for (size_t i = 0; i < src_lod[0].size(); ++i) {
CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
}
}
}
93 changes: 67 additions & 26 deletions paddle/operators/lstm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,20 @@ class LSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
"Output(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchGate) of LSTM should not be null.");

auto x_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");

if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
Expand All @@ -44,7 +47,7 @@ class LSTMOp : public framework::OperatorWithKernel {
"should be the same.");
}

int frame_size = x_dims[1] / 4;
int frame_size = in_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"The rank of Input(Weight) should be 2.");
Expand All @@ -71,12 +74,21 @@ class LSTMOp : public framework::OperatorWithKernel {
"4 * %d if disable peepholes connection",
frame_size);
}
ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
ctx->SetOutputDim("Cell", {x_dims[0], frame_size});
ctx->SetOutputDim("BatchGate", x_dims);
framework::DDim out_dims({in_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchGate", in_dims);
ctx->SetOutputDim("BatchCellPreAct", out_dims);
ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell");
}

protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<framework::LoDTensor>("Input")->type());
}
};

class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
Expand All @@ -86,16 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where, T is the "
"this LoDTensor is a matrix with shape (T X 4D), where T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput("H0",
"(Tensor, optional) the initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size.");
"batch size, D is the hidden size.")
.AsDispensable();
AddInput("C0",
"(Tensor, optional) the initial cell state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size. `H0` and `C0` can be NULL but only at the same time");
"batch size. `H0` and `C0` can be NULL but only at the same time")
.AsDispensable();
AddInput("Weight",
"(Tensor) the learnable hidden-hidden weights."
" - The shape is (D x 4D), where D is the hidden size. "
Expand All @@ -109,22 +123,27 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `usePeepholes = True` "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.")
.AsDispensable();
AddOutput("Hidden",
"(LoDTensor) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("BatchGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which "
"was also be called batch input. The LoD size is 2. The first "
"is also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence "
"in the raw input.")
.AsIntermediate();
AddOutput("Hidden",
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) the cell state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("BatchCellPreAct",
"(LoDTensor) This LoDTensor is got in the forward and used "
"in the backward.")
.AsIntermediate();
AddAttr<bool>("usePeepholes",
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
Expand Down Expand Up @@ -202,15 +221,37 @@ class LSTMGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(Hidden@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
"Input(Cell@GRAD) should not be null");
ctx->SetOutputDim(framework::GradVarName("Weight"),
ctx->GetInputDim("Weight"));
ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cell"),
"Input(Cell) of LSTM should not be null.");

PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
"Input(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTM should not be null.");

auto in_g_name = framework::GradVarName("Input");
if (ctx->HasOutput(in_g_name))
ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input"));

auto w_g_name = framework::GradVarName("Weight");
if (ctx->HasOutput(w_g_name))
ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight"));

auto b_g_name = framework::GradVarName("Bias");
if (ctx->HasOutput(b_g_name))
ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
}

protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<framework::LoDTensor>("Input")->type());
}
};

Expand Down
Loading

0 comments on commit 36d2060

Please sign in to comment.