Add walkthrough test on python and debug (PaddlePaddle#204)
haozech authored Sep 6, 2020
1 parent 4dd175d commit 2a0c2f8
Showing 13 changed files with 212 additions and 22 deletions.
16 changes: 12 additions & 4 deletions cinn/common/graph_utils.h
@@ -31,13 +31,17 @@ class GraphNode;
*/
class GraphEdge : public Object {
public:
-  GraphEdge(GraphNode* source, GraphNode* sink) : source_(source), sink_(sink) {}
+  GraphEdge(GraphNode* source, GraphNode* sink, int index = -1) : source_(source), sink_(sink), index_(index) {}

GraphNode* source() const { return source_; }
GraphNode* sink() const { return sink_; }
const char* type_info() const override { return __type_info__; }
int index() const { return index_; }

private:
//! The index of this edge in the sink node's inlinks_ or the source node's outlinks_.
//! Used to preserve the order of an operator node's input/output tensors.
int index_{-1};
//! Source of this edge.
GraphNode* source_{};
//! End of this edge.
@@ -64,9 +68,10 @@ class GraphNode : public Object {
EdgeT *a, *b;
CHECK(other);
CHECK_NE(other, this) << "cannot link to itself";
-    auto edge  = make_shared<GraphEdge>(this, other);
-    auto edge1 = make_shared<GraphEdge>(this, other);
+    auto edge  = make_shared<GraphEdge>(this, other, index_outlinks);
+    auto edge1 = make_shared<GraphEdge>(this, other, other->index_inlinks);
+    index_outlinks++;
+    other->index_inlinks++;
outlinks_.insert(edge);
other->inlinks_.insert(edge1);

@@ -140,6 +145,9 @@ class GraphNode : public Object {
std::set<common::Shared<GraphEdge>, GraphEdgeCompare> outlinks_;

mutable int visited_time_{};
//! Running counters used to assign each new edge an index among this node's input/output tensors.
int index_inlinks{0};
int index_outlinks{0};
};

/**
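Why the new index matters: inlinks_ and outlinks_ are std::sets, which do not preserve insertion order, yet an operator such as conv2d must be able to tell its first input (the data) from its second (the filter). A minimal Python sketch of the idea, with illustrative names rather than the CINN API:

class Edge:
    def __init__(self, source, sink, index):
        self.source, self.sink, self.index = source, sink, index

class Node:
    def __init__(self, name):
        self.name = name
        self.inlinks = set()    # unordered, like the std::set member above
        self.index_inlinks = 0  # running counter, as added to GraphNode

def link(source, sink):
    # mirrors GraphNode::LinkTo: stamp the edge with the sink's current
    # counter so the original link order can be recovered later
    edge = Edge(source, sink, sink.index_inlinks)
    sink.index_inlinks += 1
    sink.inlinks.add(edge)
    return edge

data, filt, conv = Node("A"), Node("W"), Node("conv2d")
link(data, conv)
link(filt, conv)
ordered = sorted(conv.inlinks, key=lambda e: e.index)  # recovers (A, W)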
33 changes: 33 additions & 0 deletions cinn/frontend/syntax.cc
@@ -1,4 +1,5 @@
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/utils/string.h"

@@ -34,6 +35,38 @@ Variable Program::add(const Variable& a, const Variable& b) {
return instr.GetOutputs()[0];
}

Variable Program::relu(const Variable& a) {
Instruction instr("relu");
instr.SetInputs({a});
AddInstruction(instr);
return instr.GetOutputs()[0];
}

std::vector<Variable> Program::conv2d(
const Variable& a,
const Variable& b,
const std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t>& attr_store) {
Instruction instr("conv2d");
instr.SetInputs({a, b});
for (auto& iter : attr_store) {
instr.SetAttr(iter.first, iter.second);
}
AddInstruction(instr);
return instr.GetOutputs();
}

Variable Program::batchnorm(const Variable& a,
const Variable& b,
const std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t>& attr_store) {
Instruction instr("batchnorm");
instr.SetInputs({a, b});
for (auto& iter : attr_store) {
instr.SetAttr(iter.first, iter.second);
}
AddInstruction(instr);
return instr.GetOutputs()[0];
}

Instruction& Program::operator[](size_t i) {
CHECK_LT(i, instrs.size());
return instrs[i];
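All three helpers above follow one pattern: create a named Instruction, record its inputs, copy the attributes in, append the instruction to the program, and return its output(s). A toy Python model of that pattern, using hypothetical classes rather than the real frontend:

class Instruction:
    def __init__(self, op_name, inputs, attrs=None):
        self.op_name = op_name
        self.inputs = list(inputs)
        self.attrs = dict(attrs or {})
        self.outputs = []

class Program:
    def __init__(self):
        self.instrs = []

    def batchnorm(self, a, b, attr_store):
        # one "batchnorm" instruction, two inputs, attributes copied in,
        # and the single output variable handed back to the caller
        instr = Instruction("batchnorm", [a, b], attr_store)
        instr.outputs = ["var_%d" % len(self.instrs)]  # stand-in for a fresh Variable
        self.instrs.append(instr)
        return instr.outputs[0]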
35 changes: 35 additions & 0 deletions cinn/frontend/syntax.h
@@ -145,6 +145,41 @@ struct Program {
*/
Variable add(const Variable& a, const Variable& b);

/**
* Apply the Rectified Linear Unit to the input Variable.
* Computes: output = max(input, 0).
*
* @param a The first variable.
* @return The result.
*/
Variable relu(const Variable& a);

/**
* The conv2d layer computes the output from the input and the filter,
* using the strides, paddings, dilations, and groups parameters.
*
* @param a The input variable.
* @param b The filter (weights) variable.
* @param attr_store Attributes such as padding, stride, and dilation.
* @return The result.
*/
std::vector<Variable> conv2d(const Variable& a,
const Variable& b,
const std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t>& attr_store);

/**
* The batchnorm layer can be used as a normalizer function
* for convolution or fully_connected operations.
*
* @param a The input variable.
* @param b The filter (weights) variable.
* @param attr_store Attributes such as epsilon.
* @return The result.
*/
Variable batchnorm(const Variable& a,
const Variable& b,
const std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t>& attr_store);

/**
* Get \p i-th instruction.
*/
20 changes: 15 additions & 5 deletions cinn/hlir/framework/graph_compiler.cc
@@ -9,6 +9,16 @@ namespace cinn {
namespace hlir {
namespace framework {

void GraphCompiler::PrintFunc() {
auto [nodes, edges] = graph_->topological_order();
for (auto& n : nodes) {
auto* node = n->safe_as<Node>();
if (node) {
auto lowered_func = GetOpFunc(node);
}
}
}

std::unique_ptr<Program> GraphCompiler::Build() {
auto [nodes, edges] = graph_->topological_order();
for (auto& n : nodes) {
@@ -52,7 +62,7 @@ ir::LoweredFunc GraphCompiler::GetOpFunc(const Node* node) {
auto& dtype_dict = graph_->GetAttrs<std::unordered_map<std::string, Type>>("inferdtype");
std::vector<ir::Tensor> inputs;
std::vector<common::CINNValue> cinn_inputs;
-  for (auto& i : node->inlinks()) {
+  for (auto& i : node->inlinks_in_order()) {
std::string input_id = i->source()->as<NodeData>()->id();
std::vector<int> in_shape = shape_dict.at(input_id);
Type dtype = dtype_dict.at(input_id);
@@ -63,7 +73,7 @@ ir::LoweredFunc GraphCompiler::GetOpFunc(const Node* node) {
cinn_inputs.push_back(common::CINNValue(temp));
}
std::vector<Type> out_types;
-  for (auto& out : node->outlinks()) {
+  for (auto& out : node->outlinks_in_order()) {
std::string out_id = out->sink()->safe_as<NodeData>()->id();
Type dtype = dtype_dict.at(out_id);
out_types.push_back(dtype);
@@ -80,21 +90,21 @@ ir::LoweredFunc GraphCompiler::GetOpFunc(const Node* node) {
}

auto func = Lower(GenOpFuncName(node), stages, inputs);

LOG(INFO) << "The function of node [" << node->attrs.node_name << "] is: " << func;
return func;
}

std::vector<std::string> GraphCompiler::OpGetInputNames(const Node* node) const {
std::vector<std::string> res;
-  for (auto& i : node->inlinks()) {
+  for (auto& i : node->inlinks_in_order()) {
res.push_back(i->source()->as<NodeData>()->id());
}
return res;
}

std::vector<std::string> GraphCompiler::OpGetOutputNames(const Node* node) const {
std::vector<std::string> res;
-  for (auto& i : node->outlinks()) {
+  for (auto& i : node->outlinks_in_order()) {
res.push_back(i->sink()->as<NodeData>()->id());
}
return res;
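Two details are worth noting in this file. First, PrintFunc discards the LoweredFunc it builds; the actual printing happens through the LOG(INFO) statement added to GetOpFunc. Second, the switch from inlinks()/outlinks() to the new *_in_order() variants is what guarantees that an op such as conv2d sees (input, filter) in exactly that order rather than in std::set iteration order.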
2 changes: 2 additions & 0 deletions cinn/hlir/framework/graph_compiler.h
@@ -59,6 +59,8 @@ class GraphCompiler final {

std::unique_ptr<Program> Build();

void PrintFunc();

private:
ir::LoweredFunc GetOpFunc(const Node* node);

34 changes: 32 additions & 2 deletions cinn/hlir/framework/node.cc
@@ -1,14 +1,15 @@
#include "cinn/hlir/framework/node.h"
#include <algorithm>

namespace cinn {
namespace hlir {
namespace framework {

- std::tuple<common::GraphEdge *, common::GraphEdge *> Node::LinkTo(NodeData *other) {
+ std::tuple<common::GraphEdge*, common::GraphEdge*> Node::LinkTo(NodeData* other) {
return this->common::GraphNode::LinkTo(other->as<common::GraphNode>());
}

- std::tuple<common::GraphEdge *, common::GraphEdge *> NodeData::LinkTo(Node *other) {
+ std::tuple<common::GraphEdge*, common::GraphEdge*> NodeData::LinkTo(Node* other) {
return this->common::GraphNode::LinkTo(other->as<common::GraphNode>());
}

@@ -49,6 +50,35 @@ std::ostream &operator<<(std::ostream &os, const NodeAttr &node_attr) {
return os;
}

//! Comparator that orders edges by index, i.e. by input/output tensor position.
bool edge_index_compare(const common::Shared<common::GraphEdge>& a, const common::Shared<common::GraphEdge>& b) {
return a->index() < b->index();
}

const std::vector<common::Shared<common::GraphEdge>>& Node::inlinks_in_order() const {
if (inlinks_in_order_.empty()) {
for (auto& in_edge : this->inlinks()) {
inlinks_in_order_.push_back(in_edge);
CHECK_GE(in_edge->index(), 0) << "The index of a node's inlinks should be >= 0! Now index is: "
<< in_edge->index() << ". Please check.";
}
std::sort(inlinks_in_order_.begin(), inlinks_in_order_.end(), edge_index_compare);
}
return inlinks_in_order_;
}

const std::vector<common::Shared<common::GraphEdge>>& Node::outlinks_in_order() const {
if (outlinks_in_order_.empty()) {
for (auto& out_edge : this->outlinks()) {
outlinks_in_order_.push_back(out_edge);
CHECK_GE(out_edge->index(), 0) << "The index of a node's outlinks should be >= 0! Now index is: "
<< out_edge->index() << ". Please check.";
}
std::sort(outlinks_in_order_.begin(), outlinks_in_order_.end(), edge_index_compare);
}
return outlinks_in_order_;
}

} // namespace framework
} // namespace hlir
} // namespace cinn
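inlinks_in_order() and outlinks_in_order() build their sorted vectors on first use and memoize them in mutable members, which is how the methods stay const. A sketch of the same lazy-cache idiom in Python, illustrative only:

from collections import namedtuple

Edge = namedtuple("Edge", "source sink index")

class Node:
    def __init__(self, inlinks):
        self._inlinks = set(inlinks)  # unordered, like the std::set member
        self._cache = None            # plays the role of the mutable vector

    def inlinks_in_order(self):
        # build and memoize the sorted view on the first call, as the C++ does
        if self._cache is None:
            assert all(e.index >= 0 for e in self._inlinks), "edge lacks an index"
            self._cache = sorted(self._inlinks, key=lambda e: e.index)
        return self._cache

One consequence of the empty() guard in the C++ version: edges linked after the first call never enter the cached vector, so the ordered accessors should only be used once the graph is fully built.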
9 changes: 9 additions & 0 deletions cinn/hlir/framework/node.h
@@ -8,6 +8,7 @@
#include <vector>

#include "cinn/common/graph_utils.h"
#include "cinn/common/shared.h"
#include "cinn/hlir/framework/op.h"

namespace cinn {
@@ -75,6 +76,12 @@ class Node : public common::GraphNode {
*/
NodeAttr attrs;

//! Get the input edges in index order, so the operator's input tensors are matched correctly.
const std::vector<common::Shared<common::GraphEdge>> &inlinks_in_order() const;

//! Get the output edges in index order, so the operator's output tensors are matched correctly.
const std::vector<common::Shared<common::GraphEdge>> &outlinks_in_order() const;

inline const Operator *op() const { return this->attrs.op; }

inline bool is_variable() { return (this->attrs.op == nullptr); }
@@ -95,6 +102,8 @@
* \brief The unique id of the node.
*/
std::string id_;
mutable std::vector<common::Shared<common::GraphEdge>> outlinks_in_order_{};
mutable std::vector<common::Shared<common::GraphEdge>> inlinks_in_order_{};
};

/**
1 change: 1 addition & 0 deletions cinn/hlir/framework/pass.cc
@@ -1,4 +1,5 @@
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/pass/use_pass.h"

namespace cinn {
namespace hlir {
14 changes: 11 additions & 3 deletions cinn/hlir/op/nn.cc
@@ -169,13 +169,21 @@ std::vector<std::vector<int>> InferShapeForConv2d(const std::vector<std::vector<
CHECK_EQ(inputs_shape[0].size(), 4) << "The first input tensor's shape size of conv2d op is not 4! Please check.";
int out_shape_h = (inputs_shape[0][2] - ((inputs_shape[1][2] - 1) * dilation + 1) + 2 * padding[0]) / stride[0] + 1;
int out_shape_w = (inputs_shape[0][3] - ((inputs_shape[1][3] - 1) * dilation + 1) + 2 * padding[1]) / stride[1] + 1;
-  std::vector<std::vector<int>> res{{inputs_shape[0][0], inputs_shape[1][0], out_shape_h, out_shape_w}};
+  std::vector<std::vector<int>> res{{inputs_shape[0][0],
+                                     inputs_shape[0][1],
+                                     inputs_shape[0][2] + 2 * padding[0],
+                                     inputs_shape[0][3] + 2 * padding[1]},
+                                    {inputs_shape[1][0],
+                                     inputs_shape[1][1],
+                                     (inputs_shape[1][2] - 1) * dilation + 1,
+                                     (inputs_shape[1][3] - 1) * dilation + 1},
+                                    {inputs_shape[0][0], inputs_shape[1][0], out_shape_h, out_shape_w}};
return res;
}

std::vector<Type> InferDtypeForConv2d(const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
-  std::vector<Type> res{inputs_type[0]};
+  std::vector<Type> res{inputs_type[0], inputs_type[1], inputs_type[0]};
return res;
}

@@ -254,7 +262,7 @@ CINN_REGISTER_HELPER(nn_ops) {
CINN_REGISTER_OP(conv2d)
.describe("Do a 2-D convolution with an NCHW-layout.")
.set_num_inputs(2)  // here we consider the filter as another input
-      .set_num_outputs(1)
+      .set_num_outputs(3)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForConv2d)
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForConv2d))
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForConv2d))
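InferShapeForConv2d now returns three shapes (the zero-padded input, the dilated filter, and the convolution output), matching set_num_outputs(3) above. A quick numeric check of the formulas, using made-up sizes; this sketch is not part of the commit:

# Numeric check of the conv2d shape inference above (illustrative sizes).
in_shape = [1, 3, 224, 224]  # NCHW input
w_shape = [64, 3, 7, 7]      # OIHW filter
padding, stride, dilation = [3, 3], [2, 2], 1

out_h = (in_shape[2] - ((w_shape[2] - 1) * dilation + 1) + 2 * padding[0]) // stride[0] + 1
out_w = (in_shape[3] - ((w_shape[3] - 1) * dilation + 1) + 2 * padding[1]) // stride[1] + 1

padded = [in_shape[0], in_shape[1], in_shape[2] + 2 * padding[0], in_shape[3] + 2 * padding[1]]
dilated = [w_shape[0], w_shape[1], (w_shape[2] - 1) * dilation + 1, (w_shape[3] - 1) * dilation + 1]
output = [in_shape[0], w_shape[0], out_h, out_w]
print(padded, dilated, output)  # [1, 3, 230, 230] [64, 3, 7, 7] [1, 64, 112, 112]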
33 changes: 30 additions & 3 deletions cinn/pybind/frontend.cc
@@ -1,18 +1,35 @@
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "cinn/common/type.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/op/use_ops.h"
#include "cinn/utils/string.h"

namespace cinn::pybind {

using common::Type;
using frontend::Placeholder;
namespace py = pybind11;
using namespace cinn::frontend; // NOLINT

void BindFrontend(pybind11::module *m) {
py::class_<Variable>(*m, "Variable") //
.def(py::init<const std::string &>(), py::arg("id") = "")
.def("__str__", [](Variable &self) { return self->id; })
.def("__repr__", [](Variable &self) { return utils::GetStreamCnt(self); });
.def("__repr__", [](Variable &self) { return utils::GetStreamCnt(self); })
.def("set_type",
[](Variable &self, const Type &type) {
self->type = type;
return self;
})
.def("set_shape", [](Variable &self, const std::vector<int> &shape) {
self->shape = shape;
return self;
});

py::class_<Placeholder>(*m, "Placeholder") //
.def(py::init<const common::Type &, const std::vector<int> &, std::string_view>(),
@@ -45,7 +62,17 @@ void BindFrontend(pybind11::module *m) {
.def(py::init<>())
.def("size", &Program::size)
.def("__getitem__", [](Program &self, int idx) { return self[idx]; })
.def("add", &Program::add);
.def("add", &Program::add)
.def("relu", &Program::relu)
.def("conv2d", &Program::conv2d)
.def("batchnorm", &Program::batchnorm)
.def("print_func", [](Program &self, const common::Target &target) {
std::shared_ptr<hlir::framework::Graph> g(new hlir::framework::Graph(self));
hlir::framework::ApplyPass(g.get(), "InferShape");
std::shared_ptr<hlir::framework::Scope> scope = hlir::framework::BuildScope(target, g);
hlir::framework::GraphCompiler gc(target, scope, g);
gc.PrintFunc();
});
} // namespace frontend

} // namespace cinn::pybind
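With these bindings a walkthrough can be written directly in Python, which is presumably what the new python/tests/test_frontend.py does (that file's diff did not load on this page). A hedged sketch: the Variable and Program methods come from the bindings above, while the attribute names and the omitted Target construction are assumptions:

from cinn.frontend import Program, Variable

prog = Program()
a = Variable("A").set_shape([1, 3, 224, 224])  # set_shape returns self, so calls chain
w = Variable("W").set_shape([64, 3, 7, 7])
outs = prog.conv2d(a, w, {"padding": [3, 3], "stride": [2, 2], "dilation": 1})
c = prog.relu(outs[-1])  # the conv result is assumed to be the last of conv2d's three outputs
print(prog.size())       # number of instructions recorded so far
# prog.print_func(target)  # would lower every op and log its function;
#                          # constructing the common::Target is omitted here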
5 changes: 5 additions & 0 deletions python/CMakeLists.txt
@@ -51,3 +51,8 @@ ADD_TEST(NAME test_cinn_op_broadcast
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH}
python3 ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_op_broadcast.py
)

ADD_TEST(NAME test_cinn_frontend
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH}
python3 ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_frontend.py
)
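After a build, the new test runs like the existing ones: CTest invokes python3 on tests/test_frontend.py with PYTHONPATH pointing at the built python/ directory, so running ctest -R test_cinn_frontend from the build directory exercises it.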
2 changes: 1 addition & 1 deletion python/cinn/frontend.py
@@ -1 +1 @@
-from .core_api.frontend import Variable, Program, Instruction, Placeholder
+from .core_api.frontend import *
(The remaining changed file, presumably python/tests/test_frontend.py with the new walkthrough test, did not load in this view.)
