Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kernels of IncrementOp #9428

Merged
merged 11 commits into from
Mar 30, 2018
9 changes: 8 additions & 1 deletion paddle/fluid/operators/compare_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Y", string::Sprintf(
"(LoDTensor) the right hand operand of %s operator",
comment.type));
AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
"device")
.SetDefault(false);
AddOutput("Out", string::Sprintf(
"(LoDTensor) n-dim bool tensor. Each element is %s",
comment.equation));
Expand Down Expand Up @@ -75,7 +80,9 @@ class CompareOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
bool force_cpu = ctx.Attr<bool>("force_cpu");
kt.place_ = force_cpu ? platform::CPUPlace()
: ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};
Expand Down
13 changes: 12 additions & 1 deletion paddle/fluid/operators/conditional_block_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,18 @@ class ConditionalOp : public framework::OperatorBase {
"numel should be 1, actual numel is %d",
ips[0]->numel());
}
return ips[0]->data<bool>()[0];
bool res = false;
if (platform::is_gpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_CUDA
framework::LoDTensor cpu_tensor;
framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else {
res = ips[0]->data<bool>()[0];
}
return res;
}
};

Expand Down
94 changes: 37 additions & 57 deletions paddle/fluid/operators/increment_op.cc
Original file line number Diff line number Diff line change
@@ -1,71 +1,46 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/increment_op.h"

namespace paddle {
namespace operators {

class IncrementInferShape : public framework::InferShapeBase {
class IncrementOp : public framework::OperatorWithKernel {
public:
void operator()(framework::InferShapeContext *ctx) const override {
IncrementOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of IncrementOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of IncrementOp should not be null.");
PADDLE_ENFORCE_EQ(1, framework::product(ctx->GetInputDim("X")));
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", "Out");
}
};

struct IncrementFunctor {
IncrementFunctor(const framework::LoDTensor &x, framework::LoDTensor *out,
float value)
: x_(x), out_(out), value_(value) {}

template <typename T>
void operator()() const {
*out_->data<T>() = *x_.data<T>() + static_cast<T>(value_);
}

const framework::LoDTensor &x_;
framework::LoDTensor *out_;
float value_;
};

class IncrementOp : public framework::OperatorBase {
public:
IncrementOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &out =
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();

PADDLE_ENFORCE(platform::is_cpu_place(x.place()));
out.Resize(x.dims());
out.mutable_data(x.place(), x.type());
float value = Attr<float>("step");
VLOG(10) << Output("Out") << " increase " << Input("X") << " with "
<< value;
framework::VisitDataType(framework::ToDataType(out.type()),
IncrementFunctor(x, &out, value));
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// IncrementOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};

Expand Down Expand Up @@ -108,5 +83,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementInferShape,
ops::IncrementOpMaker, ops::IncrementGradOpMaker);
REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker,
ops::IncrementGradOpMaker);
REGISTER_OP_CPU_KERNEL(
increment, ops::IncrementKernel<paddle::platform::CPUDeviceContext, float>,
ops::IncrementKernel<paddle::platform::CPUDeviceContext, double>,
ops::IncrementKernel<paddle::platform::CPUDeviceContext, int>,
ops::IncrementKernel<paddle::platform::CPUDeviceContext, int64_t>)
22 changes: 22 additions & 0 deletions paddle/fluid/operators/increment_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/increment_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
increment, ops::IncrementKernel<paddle::platform::CUDADeviceContext, float>,
ops::IncrementKernel<paddle::platform::CUDADeviceContext, double>,
ops::IncrementKernel<paddle::platform::CUDADeviceContext, int>,
ops::IncrementKernel<paddle::platform::CUDADeviceContext, int64_t>)
39 changes: 39 additions & 0 deletions paddle/fluid/operators/increment_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class IncrementKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x_tensor = context.Input<framework::Tensor>("X");
auto* out_tensor = context.Output<framework::Tensor>("Out");
float step = context.Attr<float>("step");

out_tensor->mutable_data<T>(context.GetPlace());
auto& dev =
*context.template device_context<DeviceContext>().eigen_device();
framework::EigenScalar<T>::From(*out_tensor).device(dev) =
framework::EigenScalar<T>::From(*x_tensor) + static_cast<T>(step);
}
};

} // namespace operators
} // namespace paddle
20 changes: 15 additions & 5 deletions python/paddle/fluid/layers/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .. import core
from ..framework import Program, Variable, Operator
from ..layer_helper import LayerHelper, unique_name
from ..initializer import force_init_on_cpu
from ops import logical_and, logical_not, logical_or

__all__ = [
Expand Down Expand Up @@ -949,7 +950,7 @@ def create_array(dtype):
dtype=dtype)


def less_than(x, y, cond=None, **ignored):
def less_than(x, y, force_cpu=True, cond=None, **ignored):
"""
**Less than**

Expand All @@ -958,6 +959,7 @@ def less_than(x, y, cond=None, **ignored):
Args:
x(Variable): First operand of *less_than*
y(Variable): Second operand of *less_than*
force_cpu(Bool|True): The output data will be on CPU if set true.
cond(Variable|None): Optional output variable to store the result of *less_than*

Returns:
Expand All @@ -974,8 +976,11 @@ def less_than(x, y, cond=None, **ignored):
cond.stop_gradient = True

helper.append_op(
type='less_than', inputs={'X': [x],
'Y': [y]}, outputs={'Out': [cond]})
type='less_than',
inputs={'X': [x],
'Y': [y]},
outputs={'Out': [cond]},
attrs={'force_cpu': force_cpu or force_init_on_cpu()})
return cond


Expand Down Expand Up @@ -1396,7 +1401,8 @@ def step_input(self, x):
type='less_than',
inputs={'X': self.step_idx,
'Y': self.max_seq_len},
outputs={'Out': self.cond})
outputs={'Out': self.cond},
attrs={'force_cpu': True})

input_array = parent_block.create_var(
name=unique_name.generate('dynamic_rnn_input_array'),
Expand Down Expand Up @@ -1445,7 +1451,11 @@ def block(self):
for new_mem, mem_array in self.mem_link:
array_write(x=new_mem, i=self.step_idx, array=mem_array)

less_than(x=self.step_idx, y=self.max_seq_len, cond=self.cond)
less_than(
x=self.step_idx,
y=self.max_seq_len,
force_cpu=True,
cond=self.cond)

self.status = DynamicRNN.AFTER_RNN
for each_array in self.output_array:
Expand Down