Skip to content

Commit

Permalink
[BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes (apache#5277
Browse files Browse the repository at this point in the history
)

* add target to region

* refactor annotate_target

* Make all unit test working

* quick fix

* enable BN, unit test failed

* Fix vm test, unit test. Refactor annotate_target a bit.

* quick fix fusion

* revert fusion change

* style fix

* Refactor merge region pass

* format

* minor fix

* Skip e2e test

* lint

* support AnnotateTarget multiple runs

* Add HasAttr and revert DNNL codegen

* address comment

Co-authored-by: Zhi Chen <[email protected]>
  • Loading branch information
2 people authored and dpankratz committed Apr 24, 2020
1 parent 202821c commit bff3c99
Show file tree
Hide file tree
Showing 15 changed files with 609 additions and 529 deletions.
9 changes: 8 additions & 1 deletion python/tvm/relay/op/contrib/dnnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,17 @@ def _func_wrapper(attrs, args):
return _func_wrapper


_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu")
_register_external_op_helper("add")
_register_external_op_helper("subtract")
_register_external_op_helper("multiply")


@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
"""Check if the external DNNL codegen should be used.
FIXME(@zhiics, @comaniac): Turn off due to not support of multiple outputs.
"""
return False
10 changes: 6 additions & 4 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,22 +587,24 @@ def PartitionGraph():



def AnnotateTarget(target):
def AnnotateTarget(targets):
"""Annotate ops in an experession with a provied compiler/target and then
use it for codegen.
Parameters
----------
target : String
The target compiler used for codegen.
targets : str or List[str]
The list of target compilers used for codegen.
Returns
-------
ret : tvm.relay.Pass
The annotated pass that wrapps ops with subgraph_start and
subgraph_end.
"""
return _ffi_api.AnnotateTarget(target)
if isinstance(targets, str):
targets = [targets]
return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])


def Inline():
Expand Down
60 changes: 34 additions & 26 deletions src/relay/analysis/annotated_region_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/runtime/container.h>

#include <unordered_map>
#include <vector>
Expand All @@ -31,7 +32,7 @@ namespace relay {

AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
for (auto candidate : regions_) {
if (candidate->nodes.find(expr) != candidate->nodes.end()) {
if (candidate->nodes_.find(expr) != candidate->nodes_.end()) {
return candidate;
}
}
Expand All @@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
}

// Merge src to dest and erase src.
dest->nodes.insert(src->nodes.begin(), src->nodes.end());
for (const auto& input : src->ins) {
dest->ins.push_back(input);
dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end());
for (const auto& input : src->ins_) {
dest->ins_.push_back(input);
}
for (const auto& output : src->outs) {
dest->outs.push_back(output);
for (const auto& output : src->outs_) {
dest->outs_.push_back(output);
}
// if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs
std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) {
for (const auto& input : dest->ins_) {
auto call = Downcast<Call>(input);
auto it = src->nodes.find(call->args[0]);
if (it != src->nodes.end()) {
dest->outs.remove(*it);
auto it = src->nodes_.find(call->args[0]);
if (it != src->nodes_.end()) {
dest->outs_.remove(*it);
ins_to_remove.push_back(input);
}
}
for (const auto& input : ins_to_remove) {
dest->ins.remove(input);
dest->ins_.remove(input);
}
regions_.erase(src);
}
Expand All @@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr)
if (src.defined()) {
MergeRegions(src, dest);
} else {
dest->nodes.insert(expr);
dest->nodes_.insert(expr);
}
}

AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() {
AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
auto ret = regions_.emplace(AnnotatedRegion());
(*ret.first)->id = region_id_++;
(*ret.first)->id_ = region_id_++;
(*ret.first)->target_ = target;
return *ret.first;
}

class AnnotatedRegionSet::Creator : public ExprVisitor {
public:
Creator(const Op& region_begin_op, const Op& region_end_op) :
begin_op_(region_begin_op), end_op_(region_end_op) {}

AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}
Creator(const Op& region_begin_op, const Op& region_end_op)
: begin_op_(region_begin_op), end_op_(region_end_op) {}

void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>();
Expand All @@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
<< "Cannot find the corresponding region for start annotation:\n"
<< AsText(GetRef<Call>(call), false));
}
region->ins.push_back(GetRef<Call>(call));
region->ins_.push_back(GetRef<Call>(call));
} else {
CHECK_EQ(call->op, end_op_);
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
std::string target = call->attrs.as<CompilerAttrs>()->compiler;

// Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) {
region = region_set_->MakeRegion();
region->nodes.insert(call->args[0]);
// Create a new region if the argument is not belonged to any regions yet.
region = region_set_->MakeRegion(target);
region->nodes_.insert(call->args[0]);
} else {
// If the argument is belonged to a region, it must have the same target.
// Otherwise we should see a region_begin op.
CHECK_EQ(region->GetTarget(), target);
}
region->nodes.insert(GetRef<Call>(call));
region->outs.push_back(GetRef<Call>(call));
region->nodes_.insert(GetRef<Call>(call));
region->outs_.push_back(GetRef<Call>(call));
}
ExprVisitor::VisitExpr_(call);
}

AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}

void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op));
if (region.defined()) {
Expand Down
37 changes: 23 additions & 14 deletions src/relay/analysis/annotated_region_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/transform.h>

#include <string>
Expand All @@ -49,47 +50,55 @@ class AnnotatedRegionSet;
class AnnotatedRegionNode : public Object {
public:
void VisitAttrs(AttrVisitor* v) {
v->Visit("id", &id);
Array<Expr> nodes_array(nodes.begin(), nodes.end());
v->Visit("id", &id_);
v->Visit("target", &target_);
Array<Expr> nodes_array(nodes_.begin(), nodes_.end());
v->Visit("nodes", &nodes_array);
Array<Expr> args_array(ins.begin(), ins.end());
Array<Expr> args_array(ins_.begin(), ins_.end());
v->Visit("args", &args_array);
Array<Expr> rets_array(outs.begin(), outs.end());
Array<Expr> rets_array(outs_.begin(), outs_.end());
v->Visit("rets", &rets_array);
}

/*! \brief Get the region ID. */
int GetID() const {
return id;
return id_;
}

/*! \brief Get the region target. */
std::string GetTarget() const {
return target_;
}

/*! \brief Get the region's inputs. */
std::list<Expr> GetInputs() const {
return ins;
return ins_;
}

/*! \brief Get the region's outputs. */
std::list<Expr> GetOutputs() const {
return outs;
return outs_;
}

/*! \brief Get the region's nodes. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
return nodes;
return nodes_;
}

static constexpr const char* _type_key = "relay.AnnotatedRegion";
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object);

protected:
/*! \brief The region ID. */
int id{-1};
int id_{-1};
/*! \brief The target for this region. */
std::string target_ = "default";
/*! \brief The inputs to this region. */
std::list<Expr> ins;
std::list<Expr> ins_;
/*! \brief The outputs of this region */
std::list<Expr> outs;
std::list<Expr> outs_;
/*! \brief Nodes in this region. */
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes_;

friend class AnnotatedRegionSet;
friend class AnnotatedRegionSetNode;
Expand Down Expand Up @@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object {
void AddToRegion(AnnotatedRegion dest, const Expr& expr);

/*!
* \brief Make a new region.
* \brief Make a new region for a target.
*
* \return The new region.
*/
AnnotatedRegion MakeRegion();
AnnotatedRegion MakeRegion(const std::string& target);

std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_;
/*! \brief The next region ID to assign. */
Expand Down
71 changes: 24 additions & 47 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}

void VisitExpr_(const TupleGetItemNode* op) final {
VisitExpr(op->tuple);
CHECK(out_.size() > static_cast<size_t>(op->index));

// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be requried for the primary outputs.
auto item = out_[op->index];
out_.clear();
out_.push_back(item);
// Do nothing
}

void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream;

std::ostringstream buf_stream;
// Args: ID
std::vector<std::string> args;

Expand Down Expand Up @@ -103,52 +96,36 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}
}

// Analyze the output buffers
std::vector<Type> out_types;
if (call->checked_type()->IsInstance<TupleTypeNode>()) {
auto type_node = call->checked_type().as<TupleTypeNode>();
for (auto field : type_node->fields) {
CHECK(field->IsInstance<TensorTypeNode>());
out_types.push_back(field);
}
} else if (call->checked_type()->IsInstance<TensorTypeNode>()) {
CHECK(call->checked_type()->IsInstance<TensorTypeNode>());
out_types.push_back(call->checked_type());
} else {
LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false);
}

out_.clear();
for (auto out_type : out_types) {
const auto& dtype = GetDtypeString(out_type.as<TensorTypeNode>());

std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(out_type);
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
std::ostringstream buf_stream;
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;

// Update output buffer
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
// Analyze the output buffer
auto type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node);
const auto& dtype = GetDtypeString(type_node);
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type());
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;

// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
decl_stream << ", " << args[i];
}
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());

// Update output buffer
out_.clear();
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
}

std::string JIT(void) {
Expand Down
13 changes: 6 additions & 7 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -924,20 +924,19 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());

// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());

// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
// and we use these ops to invoke the symbols in the module generated by
// external codegen.
pass_seqs.push_back(transform::Inline());

// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));

Expand Down
Loading

0 comments on commit bff3c99

Please sign in to comment.