Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[External codegen] Add test cases for fused ops with manual annotation #4741

Closed
wants to merge 18 commits into from
Closed
1 change: 0 additions & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def optimize(self, func, target=None, params=None):

return mod, params


def _set_params(self, params):
self._set_params_func(_convert_param_map(params))

Expand Down
145 changes: 89 additions & 56 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include <fstream>
#include <sstream>
#include <numeric>

#include "../codegen_c/codegen_c.h"

Expand All @@ -50,82 +51,109 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
out_.push_back({node->name_hint(), 0});
}

void VisitExpr_(const TupleGetItemNode* op) final {
// Do nothing
}

void VisitExpr_(const CallNode* call) final {
std::ostringstream decl_stream;
std::ostringstream buf_stream;
// Args: ID
std::vector<std::string> args;
struct Output {
std::string decl, buf;
int out_size = 1;
std::string out;
};

auto generate_body = [=](const CallNode* root_call, const std::string& func_name,
const std::vector<std::string>& args,
const std::vector<std::string>& fused_func_args) {
// Make function call with input buffers when visiting arguments
bool first = true;
std::ostringstream arg_stream;
arg_stream << "(";
for (size_t i = 0; i < root_call->args.size(); ++i) {
VisitExpr(root_call->args[i]);
for (auto out : out_) {
if (!first) {
arg_stream << ", ";
}
first = false;
arg_stream << out.first;
}
}

for (auto arg_name : fused_func_args) {
arg_stream << ", " << arg_name;
}

// Analyze the output buffer
auto type_node = root_call->checked_type().as<TensorTypeNode>();
CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32))
<< "Only support single output tensor with float type";

auto out_shape = GetShape(root_call->checked_type());

Output ret;
ret.out = "buf_" + std::to_string(buf_idx_++);
ret.out_size = std::accumulate(out_shape.begin(), out_shape.end(), 1, std::multiplies<int>());

this->PrintIndents();

std::ostringstream buf_stream;
buf_stream << "float* " << ret.out << " = (float*)std::malloc(4 * " << ret.out_size << ");";
ret.buf = buf_stream.str();

// Get the arguments for various DNNL kernels.
if (IsOp(call, "nn.conv2d")) {
decl_stream << "dnnl_conv2d";
args = Conv2d(call);
arg_stream << ", " << ret.out;
// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
arg_stream << ", " << args[i];
}
arg_stream << ");";
ret.decl = func_name + arg_stream.str();

return ret;
};

Output ret;
if (auto conv_call = DetectFusedConv2DBiasReLU(call)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if we really want to handle fused op from relay for external codegen. This looks quite ad-hoc to me. You may have countless combinations.

Copy link
Member Author

@masahi masahi Jan 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is for it to serve as an example of handling fused ops inside external codegen. I assume dnnl backend itself is not meant to be used in production; The purpose is to be a more realistic example than CodegenC, so I thought why don't we add an example of how to handle fused ops. I never intended to cover other fusion cases.

Since we are trying to be so nice to new backend implementers (who might not be familiar with TVM internals) as to add convenient op level annotation and semi automatic fusion mechanism etc for them, I don't think it is reasonable to expect them to figure out how to handle more complicated but often common cases (like fusion) and everything else on their own. Hope this make sense.

Copy link
Member Author

@masahi masahi Jan 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another usage scenario which I think is going to be common is translation from quantized Relay models. It would be great to add an example of translating QNN subgraphs to backend implementations, for example. Without it, it is not obvious how to go about it.

Since DNNL has quantization support and everyone can use it, it would serve as a good example and test case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I agree with you that it's fine to handle fusion in this DNNL codegen, I also agree with @zhiics that the current implementation is a bit too ad-hoc even it's only used for demo purpose for now. As you have implemented, MKL DNN uses set_post_ops to attach ops to be fused. I think this part could be more general. For example:

if call == "relu":
    visit(arg)
    if this->curr_layer == "conv2d":
        generate_post_ops(call)
    else:
        generate_a_layer(call)

In this way, the codegen is able to deal with all MKL DNN supported conv2d fusion (conv2d, conv2d+add, conv2d+add+relu). We could still put heuristic pattern annotations to the annotator and improve it gradually. I like the one you made for conv2d+bias+relu in this PR, for instance.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, this is my minimal effort way to detect only the pattern I care about. Will think about how to make it more general.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can go ahead and implement this, but that would duplicate pattern matching logic that I already have in my python annotator. That sounds bad and it would become a perfect anti-example mentioned in the RFC below :)

I think I should close this one and wait for a better solution to be ready. I will wait for your input for now @comaniac @zhiics

https://discuss.tvm.ai/t/rfc-external-codegen-defining-composite-relay-operators/5470/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I had a brief discussion with @u99127 before. I will read the discussion more carefully and probably we can discuss from there and try to have some consensus on a design/implementation. Sorry for being late/slow because I am on vacation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can also leave the current dumb implementation as it is, with the understanding that

  • This is a temporary solution
  • It will serve as a concrete motivation and test case for validating a more general mechanism to be introduced

Trying to be a bit more clever and duplicating an entire state machine logic here do not seem worth it to me anymore. Either way I'm fine.

ret = generate_body(conv_call, "dnnl_fused_conv2d_bias_relu",
FusedConv2dBiasReLU(conv_call), ext_fused_func_args_);
} else if (IsOp(call, "nn.conv2d")) {
ret = generate_body(call, "dnnl_conv2d", Conv2d(call), {});
} else if (IsOp(call, "nn.dense")) {
decl_stream << "dnnl_dense";
args = Dense(call);
ret = generate_body(call, "dnnl_dense", Dense(call), {});
} else if (IsOp(call, "nn.relu")) {
decl_stream << "dnnl_relu";
args = Relu(call);
ret = generate_body(call, "dnnl_relu", Relu(call), {});
} else if (IsOp(call, "nn.batch_norm")) {
decl_stream << "dnnl_bn";
args = BatchNorm(call);
ret = generate_body(call, "dnnl_bn", BatchNorm(call), {});
} else if (IsOp(call, "add")) {
decl_stream << "dnnl_add";
args = Add(call);
ret = generate_body(call, "dnnl_add", Add(call), {});
} else {
LOG(FATAL) << "Unsupported op: " << AsText(call->op, false);
}

// Make function call with input buffers when visiting arguments
bool first = true;
decl_stream << "(";
for (size_t i = 0; i < call->args.size(); ++i) {
VisitExpr(call->args[i]);
for (auto out : out_) {
if (!first) {
decl_stream << ", ";
}
first = false;
decl_stream << out.first;
}
}

// Analyze the output buffer
auto type_node = call->checked_type().as<TensorTypeNode>();
CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32))
<< "Only support single output tensor with float type";
std::string out = "buf_" + std::to_string(buf_idx_++);
auto out_shape = GetShape(call->checked_type());
int out_size = 1;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_size *= out_shape[i];
}
this->PrintIndents();
buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
buf_decl_.push_back(buf_stream.str());
decl_stream << ", " << out;

// Attach attribute arguments
for (size_t i = 0; i < args.size(); ++i) {
decl_stream << ", " << args[i];
}
decl_stream << ");";
ext_func_body.push_back(decl_stream.str());
buf_decl_.push_back(ret.buf);
ext_func_body.push_back(ret.decl);

// Update output buffer
out_.clear();
out_.push_back({out, out_size});
out_.push_back({ret.out, ret.out_size});
}

std::string JIT(void) {
ext_func_args_.insert(ext_func_args_.end(), ext_fused_func_args_.begin(),
ext_fused_func_args_.end());
return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_);
}

private:
const CallNode* DetectFusedConv2DBiasReLU(const CallNode* call) {
if (!IsOp(call, "nn.relu")) return nullptr;
auto relu_arg = call->args[0];
const CallNode* add_call = relu_arg.as<CallNode>();
if (!add_call || !IsOp(add_call, "add")) return nullptr;
auto add_arg = add_call->args[0];
const CallNode* conv_call = add_arg.as<CallNode>();
if (!conv_call || !IsOp(conv_call, "nn.conv2d")) return nullptr;
auto bias_name = "dnnl_fused_input" + std::to_string(ext_fused_func_args_.size());
ext_fused_func_args_.push_back(bias_name);
return conv_call;
}

std::vector<std::string> Conv2d(const CallNode* call) {
std::vector<std::string> args;
const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();
Expand All @@ -152,6 +180,10 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
return args;
}

std::vector<std::string> FusedConv2dBiasReLU(const CallNode* call) {
return Conv2d(call);
}

std::vector<std::string> Dense(const CallNode* call) {
std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type());
Expand Down Expand Up @@ -214,6 +246,7 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
int buf_idx_{0};
/*! \brief The arguments used by a wrapped function that calls DNNL kernels. */
std::vector<std::string> ext_func_args_;
std::vector<std::string> ext_fused_func_args_;
/*! \brief statement of the function that will be compiled using DNNL kernels. */
std::vector<std::string> ext_func_body;
/*! \brief The declaration of intermeidate buffers. */
Expand Down
63 changes: 45 additions & 18 deletions src/runtime/contrib/dnnl/dnnl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) {
std::copy(src, src + bytes, reinterpret_cast<uint8_t*>(handle));
}

extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
int p_C_, int p_H_, int p_W_, int p_O_, int p_G_,
int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
int p_Sh_, int p_Sw_) {
void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out,
int p_N_, int p_C_, int p_H_, int p_W_,
int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_,
int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr) {
using tag = memory::format_tag;
using dt = memory::data_type;
engine eng(engine::kind::cpu, 0);
Expand All @@ -65,32 +65,26 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_};
if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_};
memory::dims conv2d_bias_tz = {p_O_};
memory::dims conv2d_dst_tz = {p_N_, p_O_,
(p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
memory::dims conv2d_dst_tz = {p_N_, p_O_, (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
(p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_};
memory::dims conv2d_strides = {p_Sh_, p_Sw_};
memory::dims conv2d_padding = {p_Ph_, p_Pw_};

std::vector<float> conv2d_bias(p_O_, 0);

auto user_src_memory =
memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data);
auto user_weights_memory = memory(
{{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng,
weights);
auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data);
auto user_weights_memory =
memory({{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng, weights);
auto conv2d_user_bias_memory =
memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, conv2d_bias.data());
memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias);

auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any);
auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any);
auto conv2d_weights_md = memory::desc({conv2d_weights_tz}, dt::f32, tag::any);
auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw);

auto conv2d_desc = convolution_forward::desc(
prop_kind::forward_inference, algorithm::convolution_direct,
conv2d_src_md, conv2d_weights_md, conv2d_bias_md, conv2d_dst_md,
conv2d_strides, conv2d_padding, conv2d_padding);
auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, eng);
prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md,
conv2d_bias_md, conv2d_dst_md, conv2d_strides, conv2d_padding, conv2d_padding);
auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, attr, eng);

auto conv2d_src_memory = user_src_memory;
auto conv2d_weights_memory = user_weights_memory;
Expand All @@ -105,6 +99,39 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
read_from_dnnl_memory(out, conv2d_dst_memory);
}

extern "C" void dnnl_conv2d(float* data, float* weights, float* out,
int p_N_, int p_C_, int p_H_, int p_W_,
int p_O_, int p_G_, int p_Ph_, int p_Pw_,
int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_) {
primitive_attr attr;
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out,
p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
attr);
}

primitive_attr create_attr_with_relu_post_op() {
post_ops ops;
ops.append_eltwise(1.f, algorithm::eltwise_relu, 0.f, 0.f);

primitive_attr attr;
attr.set_post_ops(ops);

return attr;
}

extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out,
int p_N_, int p_C_, int p_H_, int p_W_, int p_O_,
int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
int p_Sh_, int p_Sw_) {
return dnnl_conv2d_common(data, weights, bias, out,
p_N_, p_C_, p_H_, p_W_,
p_O_, p_G_, p_Ph_, p_Pw_,
p_Kh_, p_Kw_, p_Sh_, p_Sw_,
create_attr_with_relu_post_op());
}

extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_,
int p_I_, int p_O_) {
using tag = memory::format_tag;
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/contrib/dnnl/dnnl_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int
int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_);

extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias,
float* out, int p_N_, int p_C_, int p_H_,
int p_W_, int p_O_, int p_G_, int p_Ph_,
int p_Pw_, int p_Kh_, int p_Kw_, int p_Sh_,
int p_Sw_);

extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_,
int p_O_);

Expand Down
Loading