Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes #5277

Merged
merged 17 commits into from
Apr 10, 2020
9 changes: 8 additions & 1 deletion python/tvm/relay/op/contrib/dnnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,17 @@ def _func_wrapper(attrs, args):
return _func_wrapper


_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu")
_register_external_op_helper("add")
_register_external_op_helper("subtract")
_register_external_op_helper("multiply")


@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
"""Check if the external DNNL codegen should be used.
FIXME(@zhiics, @comaniac): Turned off because multiple outputs are not supported.
"""
return False
10 changes: 6 additions & 4 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,22 +587,24 @@ def PartitionGraph():



def AnnotateTarget(target):
def AnnotateTarget(targets):
"""Annotate ops in an expression with a provided compiler/target and then
use it for codegen.

Parameters
----------
target : String
The target compiler used for codegen.
targets : str or List[str]
The list of target compilers used for codegen.

Returns
-------
ret : tvm.relay.Pass
The annotated pass that wraps ops with subgraph_start and
subgraph_end.
"""
return _ffi_api.AnnotateTarget(target)
if isinstance(targets, str):
targets = [targets]
return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])


def Inline():
Expand Down
60 changes: 34 additions & 26 deletions src/relay/analysis/annotated_region_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/runtime/container.h>

#include <unordered_map>
#include <vector>
Expand All @@ -31,7 +32,7 @@ namespace relay {

AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
for (auto candidate : regions_) {
if (candidate->nodes.find(expr) != candidate->nodes.end()) {
if (candidate->nodes_.find(expr) != candidate->nodes_.end()) {
return candidate;
}
}
Expand All @@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
}

// Merge src to dest and erase src.
dest->nodes.insert(src->nodes.begin(), src->nodes.end());
for (const auto& input : src->ins) {
dest->ins.push_back(input);
dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end());
for (const auto& input : src->ins_) {
dest->ins_.push_back(input);
}
for (const auto& output : src->outs) {
dest->outs.push_back(output);
for (const auto& output : src->outs_) {
dest->outs_.push_back(output);
}
// if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs
std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) {
for (const auto& input : dest->ins_) {
auto call = Downcast<Call>(input);
auto it = src->nodes.find(call->args[0]);
if (it != src->nodes.end()) {
dest->outs.remove(*it);
auto it = src->nodes_.find(call->args[0]);
if (it != src->nodes_.end()) {
dest->outs_.remove(*it);
ins_to_remove.push_back(input);
}
}
for (const auto& input : ins_to_remove) {
dest->ins.remove(input);
dest->ins_.remove(input);
}
regions_.erase(src);
}
Expand All @@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr)
if (src.defined()) {
MergeRegions(src, dest);
} else {
dest->nodes.insert(expr);
dest->nodes_.insert(expr);
}
}

AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() {
AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
auto ret = regions_.emplace(AnnotatedRegion());
(*ret.first)->id = region_id_++;
(*ret.first)->id_ = region_id_++;
(*ret.first)->target_ = target;
return *ret.first;
}

class AnnotatedRegionSet::Creator : public ExprVisitor {
public:
Creator(const Op& region_begin_op, const Op& region_end_op) :
begin_op_(region_begin_op), end_op_(region_end_op) {}

AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}
Creator(const Op& region_begin_op, const Op& region_end_op)
: begin_op_(region_begin_op), end_op_(region_end_op) {}

void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>();
Expand All @@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
<< "Cannot find the corresponding region for start annotation:\n"
<< AsText(GetRef<Call>(call), false));
}
region->ins.push_back(GetRef<Call>(call));
region->ins_.push_back(GetRef<Call>(call));
} else {
CHECK_EQ(call->op, end_op_);
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
std::string target = call->attrs.as<CompilerAttrs>()->compiler;

// Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) {
region = region_set_->MakeRegion();
region->nodes.insert(call->args[0]);
// Create a new region if the argument does not belong to any region yet.
region = region_set_->MakeRegion(target);
region->nodes_.insert(call->args[0]);
} else {
// If the argument already belongs to a region, it must have the same target.
// Otherwise we should see a region_begin op.
CHECK_EQ(region->GetTarget(), target);
}
region->nodes.insert(GetRef<Call>(call));
region->outs.push_back(GetRef<Call>(call));
region->nodes_.insert(GetRef<Call>(call));
region->outs_.push_back(GetRef<Call>(call));
}
ExprVisitor::VisitExpr_(call);
}

/*!
 * \brief Visit the given expression and return the region set built up by the
 *        VisitExpr_ overloads of this creator.
 *
 * \param expr The expression to traverse.
 *
 * \return The accumulated region set. It is moved out of region_set_, so
 *         presumably a Creator is meant to serve a single Create call
 *         (NOTE(review): confirm against callers).
 */
AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}

void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op));
if (region.defined()) {
Expand Down
37 changes: 23 additions & 14 deletions src/relay/analysis/annotated_region_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/transform.h>

#include <string>
Expand All @@ -49,47 +50,55 @@ class AnnotatedRegionSet;
class AnnotatedRegionNode : public Object {
public:
void VisitAttrs(AttrVisitor* v) {
v->Visit("id", &id);
Array<Expr> nodes_array(nodes.begin(), nodes.end());
v->Visit("id", &id_);
v->Visit("target", &target_);
Array<Expr> nodes_array(nodes_.begin(), nodes_.end());
v->Visit("nodes", &nodes_array);
Array<Expr> args_array(ins.begin(), ins.end());
Array<Expr> args_array(ins_.begin(), ins_.end());
v->Visit("args", &args_array);
Array<Expr> rets_array(outs.begin(), outs.end());
Array<Expr> rets_array(outs_.begin(), outs_.end());
v->Visit("rets", &rets_array);
}

/*! \brief Get the region ID. */
int GetID() const {
return id;
return id_;
}

/*!
 * \brief Get the region target.
 *
 * \return A copy of the target string stored for this region. The member is
 *         initialized to "default" and is overwritten with the requested
 *         target when the region is created through MakeRegion.
 */
std::string GetTarget() const {
return target_;
}

/*! \brief Get the region's inputs. */
std::list<Expr> GetInputs() const {
return ins;
return ins_;
}

/*! \brief Get the region's outputs. */
std::list<Expr> GetOutputs() const {
return outs;
return outs_;
}

/*! \brief Get the region's nodes. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
return nodes;
return nodes_;
}

static constexpr const char* _type_key = "relay.AnnotatedRegion";
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object);

protected:
/*! \brief The region ID. */
int id{-1};
int id_{-1};
/*! \brief The target for this region. */
std::string target_ = "default";
/*! \brief The inputs to this region. */
std::list<Expr> ins;
std::list<Expr> ins_;
/*! \brief The outputs of this region */
std::list<Expr> outs;
std::list<Expr> outs_;
/*! \brief Nodes in this region. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes_;

friend class AnnotatedRegionSet;
friend class AnnotatedRegionSetNode;
Expand Down Expand Up @@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object {
void AddToRegion(AnnotatedRegion dest, const Expr& expr);

/*!
* \brief Make a new region.
* \brief Make a new region for a target.
*
* \return The new region.
*/
AnnotatedRegion MakeRegion();
AnnotatedRegion MakeRegion(const std::string& target);

std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_;
/*! \brief The next region ID to assign. */
Expand Down
71 changes: 24 additions & 47 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}

void VisitExpr_(const TupleGetItemNode* op) final {
VisitExpr(op->tuple);
CHECK(out_.size() > static_cast<size_t>(op->index));

// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be required for the primary outputs.
auto item = out_[op->index];
out_.clear();
out_.push_back(item);
// Do nothing
}

void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream;

std::ostringstream buf_stream;
// Args: ID
std::vector<std::string> args;

Expand Down Expand Up @@ -103,52 +96,36 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}
}

// Analyze the output buffers
std::vector<Type> out_types;
if (call->checked_type()->IsInstance<TupleTypeNode>()) {
auto type_node = call->checked_type().as<TupleTypeNode>();
for (auto field : type_node->fields) {
CHECK(field->IsInstance<TensorTypeNode>());
out_types.push_back(field);
}
} else if (call->checked_type()->IsInstance<TensorTypeNode>()) {
CHECK(call->checked_type()->IsInstance<TensorTypeNode>());
out_types.push_back(call->checked_type());
} else {
LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false);
}

out_.clear();
for (auto out_type : out_types) {
const auto& dtype = GetDtypeString(out_type.as<TensorTypeNode>());

std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(out_type);
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
std::ostringstream buf_stream;
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;

// Update output buffer
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
// Analyze the output buffer
auto type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node);
const auto& dtype = GetDtypeString(type_node);
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type());
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;

// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
decl_stream << ", " << args[i];
}
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());

// Update output buffer
out_.clear();
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
}

std::string JIT(void) {
Expand Down
13 changes: 6 additions & 7 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -924,20 +924,19 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());

// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());

// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
// and we use these ops to invoke the symbols in the module generated by
// external codegen.
pass_seqs.push_back(transform::Inline());

// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));

Expand Down
Loading