Skip to content

Commit

Permalink
adding fusion support to codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 12, 2021
1 parent 0de5ebd commit 274ec02
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions src/relay/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {

std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
const std::vector<std::string>& func_args) {
bool has_bias = attrs.at("op_type") == "cutlass.conv2d_bias" ||
attrs.at("op_type") == "cutlass.conv2d_bias_relu";

std::ostringstream conv2d_decl;
CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n");
Expand Down Expand Up @@ -307,6 +310,11 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
ICHECK(func_args.size() >= 2);
CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n");
CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");
if (has_bias) {
ICHECK(func_args.size() >= 3);
CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n");
}

CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
Expand All @@ -322,9 +330,17 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, " problem_size,\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
if (has_bias) {
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_c_bias), 0},\n");
} else {
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
}
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
if (has_bias) {
CutlassPrint(conv2d_decl, " {alpha},\n");
} else {
CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
}
CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");

CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
Expand Down Expand Up @@ -461,6 +477,20 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d"});
return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 1, {"nn.conv2d", add_or_bias_add});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias_relu") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "nn.relu"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_relu", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
}

LOG(FATAL) << "Unknown composite function: " << pattern_name;
Expand Down Expand Up @@ -507,7 +537,8 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_batch_matmul") {
ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_conv2d") {
} else if (func_name == "cutlass_conv2d" || func_name == "cutlass_conv2d_bias" ||
func_name == "cutlass_conv2d_bias_relu") {
ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
}

Expand Down

0 comments on commit 274ec02

Please sign in to comment.