"exported scatter to python" #9038

Merged (3 commits) on Mar 15, 2018
35 changes: 17 additions & 18 deletions paddle/fluid/operators/scatter_op.cc
@@ -23,24 +23,24 @@ class ScatterOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Ref"),
-                   "Input(Ref) of ScatterOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Index"),
-                   "Input(Index) of ScatterOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of ScatterOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Ids"),
+                   "Input(Ids) of ScatterOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Updates"),
                    "Input(Updates) of ScatterOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of ScatterOp should not be null.");

     auto updates_dims = ctx->GetInputDim("Updates");
-    auto ref_dims = ctx->GetInputDim("Ref");
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Index").size(), 1,
-                      "Update Index should be 1-D.");
+    auto ref_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Ids").size(), 1,
+                      "Update Ids should be 1-D.");
     PADDLE_ENFORCE_EQ(ref_dims.size(), updates_dims.size(),
-                      "Reference and Updates should have the same shape size");
+                      "X and Updates should have the same shape size");
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0],
-                      ctx->GetInputDim("Index")[0],
-                      "Updates and Index should have same batch-size.");
+                      ctx->GetInputDim("Ids")[0],
+                      "Updates and Ids should have same batch-size.");
     framework::DDim data_dim(updates_dims);
     for (int i = 1; i < data_dim.size(); ++i) {
       PADDLE_ENFORCE_EQ(data_dim[i], updates_dims[i]);
@@ -52,7 +52,7 @@ class ScatterOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
         ctx.device_context());
   }
 };
@@ -64,14 +64,14 @@ class ScatterGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     ctx->SetOutputDim(framework::GradVarName("Updates"),
                       ctx->GetInputDim("Updates"));
-    ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref"));
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }

  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
         ctx.device_context());
   }
 };
@@ -80,9 +80,8 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ScatterOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("Ref", "The source input of scatter op");
-    AddInput("Index",
-             "The index input of scatter op where Ref will be updated");
+    AddInput("X", "The source input of scatter op");
+    AddInput("Ids", "The index input of scatter op where X will be updated");
     AddInput("Updates", "The updated value of updates op");
     AddOutput("Out", "The output of add op");
     AddComment(R"DOC(
@@ -91,8 +90,8 @@ Scatter Operator.
 This operator obtains output by updating the input on selected indices on the first axis:

 $$
-Out = Ref \\
-Out[Index] = Ref[Index] + Updates
+Out = X \\
+Out[Ids] = X[Ids] + Updates
 $$

 )DOC");
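The formula in the DOC block maps directly onto NumPy. One subtlety worth flagging: the DOC writes Out[Ids] = X[Ids] + Updates, but ScatterAssign and the unit test at the bottom of this PR overwrite the selected rows rather than accumulate into them. A minimal sketch of the op's semantics as the test exercises them (illustrative NumPy, not Paddle code):

    import numpy as np

    x = np.ones((3, 3), dtype="float32")                  # X: the tensor being updated
    ids = np.array([1, 2], dtype="int32")                 # Ids: 1-D row indices
    updates = np.random.random((2, 3)).astype("float32")  # Updates: one row per id

    out = np.copy(x)    # Out = X
    out[ids] = updates  # Out[Ids] = Updates (overwrite, matching the test below)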
20 changes: 10 additions & 10 deletions paddle/fluid/operators/scatter_op.cu
@@ -25,14 +25,14 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "This kernel only runs on GPU device.");
-    auto *Ref = ctx.Input<Tensor>("Ref");
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *X = ctx.Input<Tensor>("X");
+    auto *Ids = ctx.Input<Tensor>("Ids");
     auto *Updates = ctx.Input<Tensor>("Updates");
     auto *Out = ctx.Output<Tensor>("Out");

-    Out->ShareDataWith(*Ref);
+    Out->ShareDataWith(*X);

-    GPUScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
+    GPUScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
   }
 };
@@ -42,16 +42,16 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "This kernel only runs on GPU device.");
-    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
+    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));

-    // In place gradient: dRef = dO
-    dRef->ShareDataWith(*dOut);
+    // In place gradient: dX = dO
+    dX->ShareDataWith(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
-    // Gradient by Gather: dUpdates = dO[Index]
-    GPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
+    // Gradient by Gather: dUpdates = dO[Ids]
+    GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
   }
 };
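The backward kernel is easiest to read as two rules, taken from its own comments: the gradient w.r.t. X passes through unchanged (dX = dO, via the shared buffer), and the gradient w.r.t. Updates is a gather of dOut at the scattered indices (dUpdates = dO[Ids]). A NumPy sketch of those two rules (illustrative only, not the GPUGather API):

    import numpy as np

    d_out = np.random.random((3, 3)).astype("float32")  # incoming gradient of Out
    ids = np.array([1, 2], dtype="int32")

    d_x = d_out             # in-place gradient: dX = dO
    d_updates = d_out[ids]  # gradient by gather: dUpdates = dO[Ids]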
22 changes: 11 additions & 11 deletions paddle/fluid/operators/scatter_op.h
@@ -29,15 +29,15 @@ class ScatterOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
-    auto *Ref = ctx.Input<Tensor>("Ref");
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *X = ctx.Input<Tensor>("X");
+    auto *Ids = ctx.Input<Tensor>("Ids");
     auto *Updates = ctx.Input<Tensor>("Updates");
     auto *Out = ctx.Output<Tensor>("Out");

-    // In place output: Out = Ref, Out[Index] += Updates
-    Out->ShareDataWith(*Ref);
+    // In place output: Out = X, Out[Ids] += Updates
+    Out->ShareDataWith(*X);
     // Apply ScatterUpdate: Out[index] += Updates[:]
-    ScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
+    ScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
   }
 };

@@ -47,16 +47,16 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
-    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
+    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));

-    // In place gradient: dRef = dO
-    dRef->ShareDataWith(*dOut);
+    // In place gradient: dX = dO
+    dX->ShareDataWith(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
-    // Gradient by Gather: dUpdates += dO[Index]
-    CPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
+    // Gradient by Gather: dUpdates += dO[Ids]
+    CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
   }
 };
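Both the CPU and GPU kernels rely on Out->ShareDataWith(*X): the output tensor aliases X's buffer instead of copying it, so the scatter write is also visible through X. A NumPy analogue of that aliasing (a sketch of the semantics, not the Paddle API):

    import numpy as np

    x = np.ones((3, 3), dtype="float32")
    out = x                       # alias, not a copy, like ShareDataWith
    out[np.array([1, 2])] = 0.0   # the scatter write...
    assert x[1, 0] == 0.0         # ...is visible through x as well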
32 changes: 7 additions & 25 deletions python/paddle/fluid/layers/ops.py
@@ -45,31 +45,13 @@
 ]

 __all__ = [
-    'mean',
-    'mul',
-    'reshape',
-    'scale',
-    'sigmoid_cross_entropy_with_logits',
-    'elementwise_add',
-    'elementwise_div',
-    'elementwise_sub',
-    'elementwise_mul',
-    'elementwise_max',
-    'elementwise_min',
-    'elementwise_pow',
-    'clip',
-    'clip_by_norm',
-    'softmax',
-    'sequence_softmax',
-    'logical_and',
-    'logical_or',
-    'logical_xor',
-    'logical_not',
-    'uniform_random',
-    'uniform_random_batch_size_like',
-    'gaussian_random',
-    'gaussian_random_batch_size_like',
-    'cumsum',
+    'mean', 'mul', 'reshape', 'scale', 'sigmoid_cross_entropy_with_logits',
+    'elementwise_add', 'elementwise_div', 'elementwise_sub', 'elementwise_mul',
+    'elementwise_max', 'elementwise_min', 'elementwise_pow', 'clip',
+    'clip_by_norm', 'softmax', 'sequence_softmax', 'logical_and', 'logical_or',
+    'logical_xor', 'logical_not', 'uniform_random',
+    'uniform_random_batch_size_like', 'gaussian_random',
+    'gaussian_random_batch_size_like', 'cumsum', 'scatter'
 ] + __activations__

 for _OP in set(__all__):
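Adding 'scatter' to __all__ is the whole export: the loop this hunk cuts off at generates a Python layer function for every name in the list from the op proto registered in C++. A sketch of that mechanism, with the import path and the abbreviated list as assumptions (the real generator lives alongside ops.py):

    # Paraphrase of the export loop in python/paddle/fluid/layers/ops.py.
    from paddle.fluid.layers.layer_function_generator import generate_layer_fn

    __all__ = ['scatter']  # abbreviated; the full list is in the diff above

    for _OP in set(__all__):
        # Build a layer function from the registered op proto and bind it
        # to this module, so users can call fluid.layers.scatter(...).
        globals()[_OP] = generate_layer_fn(_OP)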
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -25,7 +25,7 @@ def setUp(self):
         updates_np = np.random.random((2, 3)).astype("float32")
         output_np = np.copy(ref_np)
         output_np[index_np] = updates_np
-        self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}

     def test_check_output(self):
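For context, a self-contained version of the renamed test, padded out with the pieces the hunk elides; the setUp shapes and the check_grad arguments are assumptions based on the usual OpTest pattern, not part of this diff:

    import unittest

    import numpy as np
    from op_test import OpTest  # Paddle's unit-test harness


    class TestScatterOp(OpTest):
        def setUp(self):
            self.op_type = "scatter"
            ref_np = np.ones((3, 3)).astype("float32")  # assumed shape
            index_np = np.array([1, 2]).astype("int32")
            updates_np = np.random.random((2, 3)).astype("float32")
            output_np = np.copy(ref_np)
            output_np[index_np] = updates_np            # overwrite semantics
            self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
            self.outputs = {'Out': output_np}

        def test_check_output(self):
            self.check_output()

        def test_check_grad(self):
            # Assumed: dUpdates = dOut[Ids]; in_place because Out shares X's buffer.
            self.check_grad(['Updates'], 'Out', in_place=True)


    if __name__ == "__main__":
        unittest.main()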