Skip to content

Commit

Permalink
Merge pull request #14 from reyoung/feature/refactorize_framework_proto
Browse files Browse the repository at this point in the history
Polish Our code by YuYang's review
  • Loading branch information
reyoung authored Aug 14, 2017
2 parents 88a3d8d + f09cb65 commit 1ed5f02
Show file tree
Hide file tree
Showing 14 changed files with 128 additions and 164 deletions.
26 changes: 14 additions & 12 deletions paddle/framework/backward_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").IgnoreGradient();
AddInput("b", "Bias of Add").IgnoreGradient();
AddOutput("Out", "Out of Add").IgnoreGradient();
AddInput("X", "Input X of Add").AsNoGradient();
AddInput("b", "Bias of Add").AsNoGradient();
AddOutput("Out", "Out of Add").AsNoGradient();
AddComment("Add Op");
}
};
Expand Down Expand Up @@ -111,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput("X", "x");
AddInput("W", "w");
AddInput("b", "b");
AddOutput("mul_result", "").SetTemporary();
AddOutput("add_result", "").SetTemporary();
AddOutput("mul_result", "").AsIntermediate();
AddOutput("add_result", "").AsIntermediate();
AddOutput("Out", "");
AddComment("");
}
Expand Down Expand Up @@ -143,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").SetMultiple();
AddInput("X", "x").AsDuplicable();
AddOutput("Y", "y");
AddComment("");
}
Expand Down Expand Up @@ -392,18 +392,20 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
auto bwd_net = static_cast<ops::NetOp *>(backward.get());
ASSERT_EQ(bwd_net->ops_.size(), 3UL);
auto &grad_fc = *bwd_net->ops_[0];
EXPECT_EQ(grad_fc.inputs_["all"].size(),

const char *all = paddle::operators::NetOp::kAll;
EXPECT_EQ(grad_fc.inputs_[all].size(),
2UL /* external input number */
+ 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/
+ 2U /* internal variable number*/);
EXPECT_EQ(grad_fc.outputs_["all"].size(),
EXPECT_EQ(grad_fc.outputs_[all].size(),
2UL /* input number of mul*/
+ 2UL /* input number of rowwise_add
*/
+ 1UL /* input number of sigmod */);
EXPECT_EQ(bwd_net->ops_[1]->inputs_["all"].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[1]->outputs_["all"].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->inputs_["all"].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->outputs_["all"].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[1]->inputs_[all].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[1]->outputs_[all].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->inputs_[all].size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->outputs_[all].size(), 0UL);
}
7 changes: 0 additions & 7 deletions paddle/framework/ddim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim::DDim(std::initializer_list<int> init_list) {
*this = make_ddim(init_list);
}

std::string DDim::DebugString() const {
std::ostringstream ss;
ss << *this;
return ss.str();
}

} // namespace framework
} // namespace paddle
2 changes: 0 additions & 2 deletions paddle/framework/ddim.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ struct DDim {
DDim operator*(DDim d) const;

ssize_t size() const;

std::string DebugString() const;
};

/**
Expand Down
3 changes: 0 additions & 3 deletions paddle/framework/grad_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ permissions and limitations under the License. */

namespace paddle {
namespace framework {

class OpRegistry;

enum class OpArgType { IN, OUT };

static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
Expand Down
12 changes: 6 additions & 6 deletions paddle/framework/grad_op_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").SetMultiple();
AddInput("In2_mult", "a multiple input").AsDuplicable();
AddInput("In3", "another single input");
AddOutput("Out1", "a single output");
AddOutput("Out2_mult", "a multiple output").SetMultiple();
AddOutput("Out2_mult", "a multiple output").AsDuplicable();
AddComment("test op with multiple inputs and outputs");
}
};
Expand All @@ -34,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").SetMultiple().IgnoreGradient();
AddInput("In3_mult", "another multiple input").SetMultiple();
AddOutput("Out1_mult", "a multiple output").SetMultiple();
AddOutput("Out2", "a single output").IgnoreGradient();
AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient();
AddInput("In3_mult", "another multiple input").AsDuplicable();
AddOutput("Out1_mult", "a multiple output").AsDuplicable();
AddOutput("Out2", "a single output").AsNoGradient();
AddComment("op with inputs and outputs ignored in gradient calculating");
}
};
Expand Down
36 changes: 18 additions & 18 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,20 @@ class OpProtoAndCheckerMaker {
struct VariableBuilder {
OpProto::Var* var_;

VariableBuilder& SetMultiple() {
VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
return *this;
}

VariableBuilder& SetTemporary() {
VariableBuilder& AsIntermediate() {
var_->set_intermediate(true);
return *this;
}

VariableBuilder& IgnoreGradient() {
// TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it
// means that input/output is not needed when calculate gradient. It does
// not mean no gradient when backward. It should be changed soon.
VariableBuilder& AsNoGradient() {
var_->set_no_gradient(true);
return *this;
}
Expand Down Expand Up @@ -118,7 +121,7 @@ class OpProtoAndCheckerMaker {

class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarNameMap = std::map<std::string, std::vector<std::string>>;
using VarNameMap = OperatorBase::VarNameMap;

public:
template <typename OpType, typename ProtoMakerType>
Expand Down Expand Up @@ -164,25 +167,22 @@ class OpRegistry {
return std::shared_ptr<OperatorBase>(op);
}

static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
VarNameMap inputs;
for (auto& input : op_desc.inputs()) {
auto& var_names = inputs[input.parameter()];
auto& var_names_in_proto = input.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
}

VarNameMap outputs;
for (auto& output : op_desc.outputs()) {
auto& var_names = outputs[output.parameter()];
auto& var_names_in_proto = output.arguments();
static VarNameMap ConvertOpDescVarsToVarNameMap(
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) {
VarNameMap ret_val;
for (auto& var : op_desc_vars) {
auto& var_names = ret_val[var.parameter()];
auto& var_names_in_proto = var.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
}
return ret_val;
}

static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs;
for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = GetAttrValue(attr);
Expand Down
49 changes: 19 additions & 30 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op").SetMultiple();
AddOutput("output", "output of cosine op").SetTemporary();
AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
Expand All @@ -51,6 +51,15 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} // namespace framework
} // namespace paddle

static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
var->add_arguments(arg_name);
}
}

REGISTER_OP(cos_sim, paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
Expand All @@ -59,13 +68,8 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";

auto output = op_desc.add_outputs();
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());

float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add();
Expand All @@ -85,13 +89,8 @@ TEST(OpRegistry, CreateOp) {
TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";

auto output = op_desc.add_outputs();
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());

auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
Expand All @@ -115,13 +114,8 @@ TEST(OpRegistry, IllegalAttr) {
TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";

auto output = op_desc.add_outputs();
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());

ASSERT_TRUE(op_desc.IsInitialized());

Expand All @@ -136,13 +130,8 @@ TEST(OpRegistry, DefaultValue) {
TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("my_test_op");
auto input = op_desc.add_inputs();
input->set_parameter("input");
*input->mutable_arguments()->Add() = "ii";

auto output = op_desc.add_outputs();
output->set_parameter("output");
*output->mutable_arguments()->Add() = "oo";
BuildVar("input", {"ii"}, op_desc.add_inputs());
BuildVar("output", {"oo"}, op_desc.add_outputs());

// attr 'test_attr' is not set
bool caught = false;
Expand Down
57 changes: 44 additions & 13 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,33 +42,35 @@ std::unordered_map<std::string, OpProto>& OpProtos() {
}

const std::string& OperatorBase::Input(const std::string& name) const {
auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have input %s", type_,
name);
PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
auto& ins = Inputs(name);
PADDLE_ENFORCE_EQ(ins.size(), 1UL,
"Op %s input %s should contain only one variable", type_,
name);
return it->second[0];
return ins[0];
}

const std::vector<std::string>& OperatorBase::Inputs(
const std::string& name) const {
return inputs_.at(name);
auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Op %s do not have input %s", type_,
name);
return it->second;
}

const std::string& OperatorBase::Output(const std::string& name) const {
auto it = outputs_.find(name);
PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
name);
PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
"Op %s input %s should contain only one variable", type_,
auto& outs = Outputs(name);
PADDLE_ENFORCE_EQ(outs.size(), 1UL,
"Op %s output %s should contain only one variable", type_,
name);
return it->second[0];
return outs[0];
}

const std::vector<std::string>& OperatorBase::Outputs(
const std::string& name) const {
return outputs_.at(name);
auto it = outputs_.find(name);
PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
name);
return it->second;
}

std::string OperatorBase::DebugString() const {
Expand Down Expand Up @@ -120,5 +122,34 @@ void OperatorBase::Rename(const std::string& old_name,
}
}

std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
// push all outputs into ret_val
for (auto& o : outputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
auto it = OpProtos().find(type_);
PADDLE_ENFORCE(
it != OpProtos().end(),
"Operator %s not registered, cannot figure out intermediate outputs",
type_);

// get all OpProto::Var for outputs
for (auto& o : it->second.outputs()) {
// ignore all intermediate output
if (o.intermediate()) continue;
auto out = outputs_.find(o.name());
if (out != outputs_.end()) {
ret_val.reserve(ret_val.size() + out->second.size());
ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
}
}
return ret_val;
}

} // namespace framework
} // namespace paddle
Loading

0 comments on commit 1ed5f02

Please sign in to comment.