Commit

Remove duplicated output args
zhiics committed Apr 11, 2020
1 parent 403929f commit 6dc8b77
Showing 5 changed files with 236 additions and 125 deletions.
64 changes: 37 additions & 27 deletions src/relay/backend/contrib/codegen_c/codegen.cc
@@ -40,35 +40,44 @@ using namespace backend;
* purpose. Only a few binary operators are covered. Users
* may need to extend them to cover more operators.
*/
class CodegenC : public ExprVisitor, public CodegenCBase {
class CodegenC : public relay::ExprFunctor<std::vector<Output>(const Expr&)>,
public CodegenCBase {
public:
explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }

void VisitExpr_(const VarNode* node) final {
std::vector<Output> VisitExpr(const Expr& expr) final {
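// Memoize per expression node: a subexpression shared by several consumers
// is visited only once and its cached outputs are reused, which avoids
// emitting duplicated output arguments.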
if (visited_.count(expr)) return visited_.at(expr);

std::vector<Output> output;
if (expr.as<ConstantNode>()) {
output = VisitExpr_(expr.as<ConstantNode>());
} else if (expr.as<VarNode>()) {
output = VisitExpr_(expr.as<VarNode>());
} else if (expr.as<CallNode>()) {
output = VisitExpr_(expr.as<CallNode>());
} else {
LOG(FATAL) << "DNNL codegen doesn't support: " << expr->GetTypeKey();
}
visited_[expr] = output;
return output;
}

std::vector<Output> VisitExpr_(const VarNode* node) final {
ext_func_args_.push_back(GetRef<Var>(node));
out_.clear();
Output output;
output.name = node->name_hint();
out_.push_back(output);
return {output};
}

void VisitExpr_(const ConstantNode* cn) final {
Constant constant = GetRef<Constant>(cn);
if (visited_.count(constant)) {
// Note this is for demostration purpose. ConstantNode doesn't necessarily
// belong to calls. We need to revisit this when tuples come into play.
out_.push_back(visited_[constant]);
return;
}
std::vector<Output> VisitExpr_(const ConstantNode* cn) final {
// Note this is for demonstration purpose. ConstantNode doesn't necessarily
// belong to calls. We need to revisit this when tuples come into play.

std::ostringstream decl_stream;
std::ostringstream buf_stream;

out_.clear();
Output output;
output.name = "const_" + std::to_string(const_idx_++);
out_.push_back(output);
visited_[constant] = output;

runtime::NDArray array = cn->data;
const auto& shape = array.Shape();
@@ -99,9 +108,11 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
}
buf_stream << "};";
ext_func_body.insert(ext_func_body.begin(), buf_stream.str());

return {output};
}

void VisitExpr_(const CallNode* call) final {
std::vector<Output> VisitExpr_(const CallNode* call) final {
std::ostringstream macro_stream;
std::ostringstream decl_stream;
std::ostringstream buf_stream;
@@ -138,8 +149,8 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
bool first = true;
decl_stream << func_name << "(";
for (size_t i = 0; i < call->args.size(); ++i) {
VisitExpr(call->args[i]);
for (auto out : out_) {
auto res = VisitExpr(call->args[i]);
for (auto out : res) {
if (!first) {
decl_stream << ", ";
}
@@ -162,26 +173,27 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
ext_func_body.push_back(decl_stream.str());

// Update output buffer
out_.clear();
// Note C codegen only handles TensorType. Therefore, we don't flatten
// tuples and only return a single value.
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
return {output};
}

/*!
* \brief Emit the source code that invokes C compiler compatible wrappers.
*
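* \param out The outputs.
*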
* \return The emitted code.
*/
std::string JIT() {
std::string JIT(const std::vector<Output>& out) {
// Write function macros
for (auto decl : func_decl_) {
code_stream_ << decl << "\n";
}
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out);
}

private:
@@ -202,9 +214,7 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
/*! \brief The declaration statements of buffers. */
std::vector<std::string> buf_decl_;
/*! \brief The name and index pairs for output. */
std::vector<Output> out_;
/*! \brief The cached expressions. */
std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
};

class CSourceCodegen : public CSourceModuleCodegenBase {
@@ -216,8 +226,8 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
auto sid = GetExtSymbol(func);

CodegenC builder(sid);
builder.VisitExpr(func->body);
code_stream_ << builder.JIT();
auto out = builder.VisitExpr(func->body);
code_stream_ << builder.JIT(out);
}

runtime::Module CreateCSourceModule(const ObjectRef& ref) override {
4 changes: 3 additions & 1 deletion src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -165,9 +165,11 @@ class CodegenCBase {
/*!
* \brief Emit the code for external runtime.
*
* \param out The outputs.
*
* \return The code string.
*/
virtual std::string JIT() = 0;
virtual std::string JIT(const std::vector<Output>& out) = 0;

/*!
* \brief A common interface that is used by various external runtime to
89 changes: 50 additions & 39 deletions src/relay/backend/contrib/dnnl/codegen.cc
@@ -128,42 +128,50 @@ std::vector<std::string> Add(const CallNode* call) {

// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement.
class CodegenDNNL : public ExprVisitor, public CodegenCBase {
class CodegenDNNL : public relay::ExprFunctor<std::vector<Output>(const Expr&)>,
public CodegenCBase {
public:
explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }

void VisitExpr_(const VarNode* node) final {
std::vector<Output> VisitExpr(const Expr& expr) final {
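// Memoized dispatch: return the cached outputs for expressions that have
// already been visited.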
if (visited_.count(expr)) return visited_.at(expr);

std::vector<Output> output;
if (expr.as<ConstantNode>()) {
output = VisitExpr_(expr.as<ConstantNode>());
} else if (expr.as<VarNode>()) {
output = VisitExpr_(expr.as<VarNode>());
} else if (expr.as<CallNode>()) {
output = VisitExpr_(expr.as<CallNode>());
} else if (expr.as<TupleGetItemNode>()) {
output = VisitExpr_(expr.as<TupleGetItemNode>());
} else {
LOG(FATAL) << "DNNL codegen doesn't support: " << expr->GetTypeKey();
}
visited_[expr] = output;
return output;
}

std::vector<Output> VisitExpr_(const VarNode* node) final {
ext_func_args_.push_back(GetRef<Var>(node));
out_.clear();
Output output;
output.name = node->name_hint();
out_.push_back(output);
return {output};
}

void VisitExpr_(const TupleGetItemNode* op) final {
VisitExpr(op->tuple);
CHECK(out_.size() > static_cast<size_t>(op->index));
std::vector<Output> VisitExpr_(const TupleGetItemNode* op) final {
auto res = VisitExpr(op->tuple);
CHECK_GT(res.size(), static_cast<size_t>(op->index));

// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be required for the primary outputs.
auto item = out_[op->index];
out_.clear();
out_.push_back(item);
return {res[op->index]};
}

void VisitExpr_(const ConstantNode* cn) final {
Constant constant = GetRef<Constant>(cn);
if (visited_.count(constant)) {
out_.push_back(visited_[constant]);
return;
}

out_.clear();
std::vector<Output> VisitExpr_(const ConstantNode* cn) final {
Output output;
output.name = "const_" + std::to_string(const_idx_++);
output.dtype = "float";
out_.push_back(output);
visited_[constant] = output;

runtime::NDArray array = cn->data;

@@ -176,33 +184,38 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now.";

std::ostringstream buf_stream;
buf_stream << "float* " << output.name << " = (float*)std::malloc(4 * " << num_elems << ");\n";
const float* ptr = static_cast<float*>(array.ToDLPack()->dl_tensor.data);
for (int64_t i = 0; i < num_elems; i++) {
buf_stream << " " << output.name << "[" << i << "] = " << ptr[i] << ";\n";

// Allocate large arrays in the static data section to avoid stack overflow.
// Note that this would probably increase compilation time as the source
// file could be really large.
buf_stream << "static float " << output.name << "[" << num_elems <<"] = {";
for (int64_t i = 0; i < num_elems - 1; i++) {
buf_stream << ptr[i] << ",";
}
if (num_elems > 0) buf_stream << ptr[num_elems - 1];
buf_stream << "};\n";

ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
return {output};
}

void VisitExpr_(const CallNode* call) final {
std::vector<Output> VisitExpr_(const CallNode* call) final {
GenerateBodyOutput ret;
if (const auto* func = call->op.as<FunctionNode>()) {
ret = GenerateCompositeFunctionCall(func, call);
} else {
ret = GenerateOpCall(call);
}

out_.clear();
for (size_t i = 0; i < ret.outputs.size(); ++i) {
buf_decl_.push_back(ret.buffers[i]);
out_.push_back(ret.outputs[i]);
}
buf_decl_.insert(buf_decl_.end(), ret.buffers.begin(), ret.buffers.end());
std::vector<Output> out = ret.outputs;
ext_func_body.push_back(ret.decl);
return ret.outputs;
}

std::string JIT(void) {
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
std::string JIT(const std::vector<Output>& out) {
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out);
}

private:
@@ -215,8 +228,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
std::vector<std::string> GetArgumentNames(const CallNode* call) {
std::vector<std::string> arg_names;
for (size_t i = 0; i < call->args.size(); ++i) {
VisitExpr(call->args[i]);
for (auto out : out_) {
auto res = VisitExpr(call->args[i]);
for (const auto& out : res) {
arg_names.push_back(out.name);
}
}
@@ -331,17 +344,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
*/
int buf_idx_{0};
/*! \brief The index of global constants. */
int const_idx_ = 0;
int const_idx_{0};
/*! \brief The arguments used by a wrapped function that calls DNNL kernels. */
Array<Var> ext_func_args_;
/*! \brief statement of the function that will be compiled using DNNL kernels. */
std::vector<std::string> ext_func_body;
/*! \brief The declaration of intermediate buffers. */
std::vector<std::string> buf_decl_;
/*! \brief The name of the the outputs. */
std::vector<Output> out_;
/*! \brief The cached expressions. */
std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
};

/*!
@@ -361,8 +372,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
auto sid = GetExtSymbol(func);

CodegenDNNL builder(sid);
builder.VisitExpr(func->body);
code_stream_ << builder.JIT();
auto out = builder.VisitExpr(func->body);
code_stream_ << builder.JIT(out);
}

/*!
51 changes: 37 additions & 14 deletions src/relay/transforms/partition_graph.cc
@@ -148,25 +148,42 @@ class Partitioner : public ExprMutator {
CHECK_EQ(call->args.size(), 1U);

// Traverse the rest graph.
auto input_expr = VisitExpr(call->args[0]);
Expr parent = call->args[0];
auto input_expr = VisitExpr(parent);

// Backtrack through compiler_begin/compiler_end ops to find the nearest ancestor that is not an annotation op.
while (const auto* parent_call = parent.as<CallNode>()) {
if (parent_call->op == compiler_begin_op ||
parent_call->op == compiler_end_op) {
parent = parent_call->args[0];
} else {
break;
}
}

AnnotatedRegion sg = GetRegion(GetRef<Call>(call));
int index = GetArgIdx(sg, GetRef<Call>(call));
CHECK_NE(index, -1);
// The type of the created variable is the same as the compiler_begin
// node.
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string varname =
target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
auto var = Var(varname, GetRef<Call>(call)->checked_type_);

auto cand = std::make_pair(var, input_expr);
if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
region_args[sg].end()) {
region_args[sg].push_back(cand);
}

return std::move(var);
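// If this producer already has an argument variable for the region, reuse it;
// otherwise create a new variable and record it in shared_output_, so the
// partitioned function does not receive duplicated arguments.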
if (shared_output_.count(parent) && shared_output_[parent].count(sg)) {
return shared_output_[parent][sg];
} else {
// The type of the created variable is the same as the compiler_begin
// node.
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string varname =
target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
auto var = Var(varname, GetRef<Call>(call)->checked_type_);

std::pair<Var, Expr> cand = std::make_pair(var, input_expr);

if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
region_args[sg].end()) {
region_args[sg].push_back(cand);
}
shared_output_[parent][sg] = var;
return std::move(var);
}
} else {
CHECK_EQ(call->op, compiler_end_op);
// The annotation node is inserted on edge so it must have only one
@@ -474,6 +491,12 @@ class Partitioner : public ExprMutator {
* belongs to
*/
std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;

/*!\brief Cache the output that is shared by different nodes. */
using RegionOutputMap = std::unordered_map<AnnotatedRegion, Var, ObjectHash, ObjectEqual>;
std::unordered_map<Expr, RegionOutputMap, ObjectHash, ObjectEqual> shared_output_;

/*!\brief The IRModule used for partitioning. */
IRModule module_;
};
