diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index c226da5864fc..79b63a8a9147 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -263,6 +263,9 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {
 
 std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
                      const std::vector<std::string>& func_args) {
+  bool has_bias = attrs.at("op_type") == "cutlass.conv2d_bias" ||
+                  attrs.at("op_type") == "cutlass.conv2d_bias_relu";
+
   std::ostringstream conv2d_decl;
   CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
   CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n");
@@ -307,6 +310,11 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
   ICHECK(func_args.size() >= 2);
   CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n");
   CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");
+  if (has_bias) {
+    ICHECK(func_args.size() >= 3);
+    CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n");
+  }
+
   CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
   CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
   CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
@@ -322,9 +330,17 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
   CutlassPrint(conv2d_decl, " problem_size,\n");
   CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
   CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
+  if (has_bias) {
+    CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_c_bias), 0},\n");
+  } else {
+    CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
+  }
   CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
-  CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
-  CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
+  if (has_bias) {
+    CutlassPrint(conv2d_decl, " {alpha},\n");
+  } else {
+    CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
+  }
   CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");
 
   CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
@@ -461,6 +477,20 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d"});
       return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller),
                           Conv2dArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.conv2d_bias") {
+      const CallNode* current_call = callee->body.as<CallNode>();
+      std::string add_or_bias_add = current_call->op.as<OpNode>()->name;
+      const auto* conv2d_call =
+          GetRootCall(callee->body.as<CallNode>(), 1, {"nn.conv2d", add_or_bias_add});
+      return GenerateBody(conv2d_call, "cutlass_conv2d_bias", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.conv2d_bias_relu") {
+      const CallNode* current_call = callee->body.as<CallNode>();
+      std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
+      const auto* conv2d_call =
+          GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "nn.relu"});
+      return GenerateBody(conv2d_call, "cutlass_conv2d_bias_relu", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
     }
 
     LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -507,7 +537,8 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
     } else if (func_name == "cutlass_batch_matmul") {
       ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
-    } else if (func_name == "cutlass_conv2d") {
+    } else if (func_name == "cutlass_conv2d" || func_name == "cutlass_conv2d_bias" ||
+               func_name == "cutlass_conv2d_bias_relu") {
       ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
     }
 
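For readers who want to see the has_bias branches end to end, here is a minimal standalone sketch (not part of the patch) that replays the string emission above and prints the resulting argument list. AppendLine stands in for CutlassPrint, and the opening "typename Conv2d::Arguments arguments{" line is an assumption about code outside this diff; the emitted identifiers (ptr_a, layout_A, ElementOutput, alpha, ...) only have meaning inside the full generated kernel.

    // Standalone sketch: mimics how Conv2dOp assembles the CUTLASS argument list as text.
    #include <iostream>
    #include <sstream>
    #include <string>

    static void AppendLine(std::ostringstream& os, const std::string& text) { os << text; }

    int main() {
      // true for the "cutlass.conv2d_bias" and "cutlass.conv2d_bias_relu" op types
      bool has_bias = true;
      std::ostringstream conv2d_decl;
      // Assumed opening line; the diff only shows the emissions that follow it.
      AppendLine(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
      AppendLine(conv2d_decl, " problem_size,\n");
      AppendLine(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
      AppendLine(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
      if (has_bias) {
        // The bias pointer takes the place of tensor C; the 0 stride is intended to
        // broadcast the 1-D bias across the output.
        AppendLine(conv2d_decl, " {static_cast<ElementOutput*>(ptr_c_bias), 0},\n");
      } else {
        AppendLine(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
      }
      AppendLine(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
      if (has_bias) {
        AppendLine(conv2d_decl, " {alpha},\n");  // epilogue arguments for the bias case
      } else {
        AppendLine(conv2d_decl, "{alpha, beta}\n};\n");  // plain linear-combination epilogue
      }
      std::cout << conv2d_decl.str();  // inspect the emitted snippet
      return 0;
    }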