Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC] Refine DNNL Codegen #5288

Merged
merged 3 commits into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions python/tvm/relay/op/contrib/dnnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,10 @@ def _func_wrapper(attrs, args):
return _func_wrapper


_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu")
_register_external_op_helper("add")
_register_external_op_helper("subtract")
_register_external_op_helper("multiply")


@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
"""Check if the external DNNL codegen should be used.
FIXME(@zhiics, @comaniac): Turn off due to not support of multiple outputs.
"""
return False
71 changes: 47 additions & 24 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,19 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}

void VisitExpr_(const TupleGetItemNode* op) final {
// Do nothing
VisitExpr(op->tuple);
CHECK(out_.size() > static_cast<size_t>(op->index));

// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be requried for the primary outputs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks ok to me, why "FIXME"? TupleGetItem only cares about one output.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In terms of the semantic, if we generate a partition function like conv2d->BN->relu, we should have a tuple output with 3 elements (relu_out, BN_out_1, BN_out_2) even the rest 2 are useless. In general we should achieve this semantic, but since no one cares the rest 2 outputs of batch_norm it's fine for now.

auto item = out_[op->index];
out_.clear();
out_.push_back(item);
}

void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream;
std::ostringstream buf_stream;

// Args: ID
std::vector<std::string> args;

Expand Down Expand Up @@ -96,36 +103,52 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}
}

// Analyze the output buffer
auto type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node);
const auto& dtype = GetDtypeString(type_node);
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type());
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
// Analyze the output buffers
std::vector<Type> out_types;
if (call->checked_type()->IsInstance<TupleTypeNode>()) {
auto type_node = call->checked_type().as<TupleTypeNode>();
for (auto field : type_node->fields) {
CHECK(field->IsInstance<TensorTypeNode>());
out_types.push_back(field);
}
} else if (call->checked_type()->IsInstance<TensorTypeNode>()) {
CHECK(call->checked_type()->IsInstance<TensorTypeNode>());
out_types.push_back(call->checked_type());
} else {
LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false);
}

out_.clear();
for (auto out_type : out_types) {
const auto& dtype = GetDtypeString(out_type.as<TensorTypeNode>());

std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(out_type);
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
std::ostringstream buf_stream;
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;

// Update output buffer
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
}
this->PrintIndents();
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;

// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
decl_stream << ", " << args[i];
}
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());

// Update output buffer
out_.clear();
Output output;
output.name = out;
output.dtype = dtype;
output.need_copy = true;
output.size = out_size;
out_.push_back(output);
}

std::string JIT(void) {
Expand Down
13 changes: 7 additions & 6 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -924,19 +924,20 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());

// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());

masahi marked this conversation as resolved.
Show resolved Hide resolved
// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
// and we use these ops to invoke the symbols in the module generated by
// external codegen.
pass_seqs.push_back(transform::Inline());

// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse the shape functions.
pass_seqs.push_back(transform::FuseOps());
// Manifest the allocations needed for the shape functions.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));

Expand Down
6 changes: 4 additions & 2 deletions src/runtime/contrib/dnnl/dnnl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,11 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
read_from_dnnl_memory(out, dst_memory);
}

extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, int p_N_, int p_C_,
extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance,
float* out, float* new_mean, float* new_variance, int p_N_, int p_C_,
int p_H_, int p_W_, int p_E_) {
// FIXME(@comaniac): BN has 3 outputs: out, new_mean and new_variance, but we do not update
// the rest two because no one cares about them for now. Should update it in the future.
using tag = memory::format_tag;
using dt = memory::data_type;

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/contrib/dnnl/dnnl_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);

extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_,
int p_e_);
float* variance, float* out, float* new_mean, float* new_variance,
int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);

extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_);
Expand Down
5 changes: 3 additions & 2 deletions tests/python/relay/test_annotate_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_run():
test_annotate()
test_run()


@pytest.mark.skip(reason="fix constant node before opening this case")
def test_extern_dnnl_mobilenet():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
Expand All @@ -172,6 +172,7 @@ def test_extern_dnnl_mobilenet():
mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32')

mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = transform.AnnotateTarget("dnnl")(mod)
mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
Expand Down Expand Up @@ -267,5 +268,5 @@ def after():
if __name__ == "__main__":
test_multiple_ends()
test_extern_dnnl()
test_extern_dnnl_mobilenet()
#test_extern_dnnl_mobilenet()
test_composite_function()
6 changes: 4 additions & 2 deletions tests/python/relay/test_pass_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import sys
import numpy as np
import pytest

import tvm
import tvm.relay.testing
Expand Down Expand Up @@ -438,7 +439,7 @@ def get_func():
check_result(mod, {"data": i_data, "weight1": w1_data},
(1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)


@pytest.mark.skip(reason="fix constant node before opening this case")
def test_extern_dnnl_mobilenet():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
Expand All @@ -450,6 +451,7 @@ def test_extern_dnnl_mobilenet():
batch_size=1, dtype='float32')

op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"]
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = WhiteListAnnotator(op_list, "dnnl")(mod)
mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
Expand Down Expand Up @@ -862,7 +864,7 @@ def expected():
test_extern_ccompiler_default_ops()
test_extern_ccompiler()
test_extern_dnnl()
test_extern_dnnl_mobilenet()
#test_extern_dnnl_mobilenet()
masahi marked this conversation as resolved.
Show resolved Hide resolved
test_function_lifting()
test_function_lifting_inline()
test_constant_propagation()
Expand Down