Skip to content

Commit

Permalink
Merge pull request #7219 from reyoung/feature/correctly_handle_lod_in…
Browse files Browse the repository at this point in the history
…formation_for_image_operators

Correctly handle lod information of image operators
  • Loading branch information
reyoung authored Jan 5, 2018
2 parents f3c42f6 + 040dc59 commit a8b3996
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
3 changes: 2 additions & 1 deletion paddle/operators/batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"Input X must have 2 to 5 dimensions.");

const int C =
const int64_t C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);

Expand All @@ -78,6 +78,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("VarianceOut", {C});
ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
}
};

Expand Down
7 changes: 3 additions & 4 deletions paddle/operators/conv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
paddings.size(), strides.size(),
"Conv paddings dimension and Conv strides dimension should be the same.");

int input_channels = in_dims[1];
PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");

int output_channels = filter_dims[0];
PADDLE_ENFORCE_EQ(
output_channels % groups, 0,
filter_dims[0] % groups, 0,
"The number of output channels should be divided by groups.");

std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
Expand All @@ -66,6 +64,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
dilations[i], paddings[i], strides[i]));
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output");
}

Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
Expand Down
1 change: 1 addition & 0 deletions paddle/operators/pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("X", "Out");
}

void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
Expand Down

0 comments on commit a8b3996

Please sign in to comment.