support sigmoid fusion (only fp32 accum for now)
masahi committed Dec 12, 2021
1 parent 3705bbd commit 0489d14
Showing 6 changed files with 51 additions and 15 deletions.
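For context, the fusion target is a conv2d → bias_add → sigmoid chain with fp32 accumulation. A minimal sketch of the kind of Relay graph the new pattern matches (shapes mirror the test below; this snippet is illustrative, not code from the commit):

    import tvm
    from tvm import relay

    # NCHW conv2d + bias + sigmoid. out_dtype="float32" matters: per the
    # commit title, only fp32 accumulation is supported for sigmoid fusion.
    data = relay.var("data", shape=(16, 16, 32, 32), dtype="float16")
    weight = relay.var("weight", shape=(32, 16, 3, 3), dtype="float16")
    bias = relay.var("bias", shape=(32,), dtype="float32")

    conv = relay.nn.conv2d(data, weight, padding=(1, 1), out_dtype="float32")
    out = relay.sigmoid(relay.nn.bias_add(conv, bias))
    mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))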
2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/build.py
@@ -222,6 +222,8 @@ def handle_conv2d(
         cutlass_op_def = out["opdef_bias"]
     elif op_type == "cutlass.conv2d_bias_relu":
         cutlass_op_def = out["opdef_bias_relu"]
+    elif op_type == "cutlass.conv2d_bias_sigmoid":
+        cutlass_op_def = out["opdef_bias_sigmoid"]
     else:
         raise ValueError("%s pattern is not implemented." % op_type)

8 changes: 5 additions & 3 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -87,12 +87,14 @@ def create_conv2d_operator(
         op_entry["runtime"] = 9999999

         # fused ops
-        for epilogue, opdef in zip(
+        for epilogue, opdef, no_bias_scaling in zip(
             [
                 EpilogueFunctor.LinearCombinationBias,
                 EpilogueFunctor.LinearCombinationRelu,
+                EpilogueFunctor.LinearCombinationSigmoid,
             ],
-            ["opdef_bias", "opdef_bias_relu"],
+            ["opdef_bias", "opdef_bias_relu", "opdef_bias_sigmoid"],
+            [True, True, False],
         ):
             op = Conv2dOperation(
                 ConvKind.Fprop,

@@ -108,7 +110,7 @@
                 swizzling_functor_,
             )

-            op_entry[opdef] = kernel_emitter.emit(op)
+            op_entry[opdef] = kernel_emitter.emit(op, no_bias_scaling)

             ret.append(op_entry)
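The new `no_bias_scaling` flag records, per epilogue, whether the emitted kernel may add the bias without a `beta` scale. `LinearCombinationSigmoid` is the odd one out: it gets `no_bias_scaling=False`, so the codegen change below applies the bias as `beta * C` with `beta = 1` rather than the bias-only `{alpha}` form. Restated as data, the triples the `zip` above iterates over (illustrative):

    # (epilogue functor, op_entry key, no_bias_scaling)
    fused_epilogues = [
        ("LinearCombinationBias",    "opdef_bias",         True),
        ("LinearCombinationRelu",    "opdef_bias_relu",    True),
        ("LinearCombinationSigmoid", "opdef_bias_sigmoid", False),
    ]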
2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -148,13 +148,15 @@ class EpilogueFunctor(enum.Enum):
     LinearCombinationRelu = enum_auto()
     LinearCombinationBias = enum_auto()
     LinearCombinationGelu = enum_auto()
+    LinearCombinationSigmoid = enum_auto()


 EpilogueFunctorTag = {
     EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination",
     EpilogueFunctor.LinearCombinationRelu: "cutlass::epilogue::thread::LinearCombinationRelu",
     EpilogueFunctor.LinearCombinationBias: "cutlass::epilogue::thread::LinearCombination",
     EpilogueFunctor.LinearCombinationGelu: "cutlass::epilogue::thread::LinearCombinationGELU",
+    EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
 }


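At emit time the enum member resolves to the CUTLASS C++ functor name through `EpilogueFunctorTag`. A quick sanity check (pure Python, assuming a TVM checkout that includes this commit):

    from tvm.contrib.cutlass.library import EpilogueFunctor, EpilogueFunctorTag

    tag = EpilogueFunctorTag[EpilogueFunctor.LinearCombinationSigmoid]
    assert tag == "cutlass::epilogue::thread::LinearCombinationSigmoid"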
12 changes: 10 additions & 2 deletions python/tvm/relay/op/contrib/cutlass.py
@@ -68,8 +68,11 @@ def make_conv2d_pattern(with_bias=False, with_act=None, out_dtype="float16"):
     else:
         conv2d_out = conv2d

-    if with_act is not None and with_act == "relu":
-        return is_op("nn.relu")(conv2d_out)
+    if with_act is not None:
+        if with_act == "relu":
+            return is_op("nn.relu")(conv2d_out)
+        if with_act == "sigmoid":
+            return is_op("sigmoid")(conv2d_out)

     return conv2d_out

@@ -149,6 +152,11 @@ def partition_for_cutlass(mod):
             make_conv2d_pattern(with_bias=True, with_act="relu"),
             check_conv2d,
         ),
+        (
+            "cutlass.conv2d_bias_sigmoid",
+            make_conv2d_pattern(with_bias=True, with_act="sigmoid"),
+            check_conv2d,
+        ),
         ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True), check_conv2d),
         ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
     ]
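The sigmoid entry is placed before the plain `cutlass.conv2d_bias` pattern deliberately: the pattern table is matched in order, so the more specific fused pattern must come first or the bare bias pattern would claim the conv2d + bias prefix. A rough usage sketch, reusing the illustrative `mod` from the first snippet above:

    from tvm.relay.op.contrib.cutlass import partition_for_cutlass

    mod = partition_for_cutlass(mod)
    # The partitioned module should now contain a composite function whose
    # "Composite" attribute is "cutlass.conv2d_bias_sigmoid".
    print(mod)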
23 changes: 18 additions & 5 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -264,7 +264,9 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {
 std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
                      const std::vector<std::string>& func_args) {
   bool has_bias = attrs.at("op_type") == "cutlass.conv2d_bias" ||
-                  attrs.at("op_type") == "cutlass.conv2d_bias_relu";
+                  attrs.at("op_type") == "cutlass.conv2d_bias_relu" ||
+                  attrs.at("op_type") == "cutlass.conv2d_bias_sigmoid";
+  bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid";

   std::ostringstream conv2d_decl;
   CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
@@ -317,8 +319,11 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,

   CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
   CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
-  CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
-
+  if (has_bias && no_bias_scaling) {
+    CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
+  } else {
+    CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n");
+  }
   CutlassPrint(conv2d_decl, "using cutlass::layout::TensorNHWC;\n");
   CutlassPrint(conv2d_decl,
                "TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(N, H, W, C)));\n");
@@ -338,7 +343,7 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
     CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
   }
   CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
-  if (has_bias) {
+  if (has_bias && no_bias_scaling) {
     CutlassPrint(conv2d_decl, " {alpha}\n};\n");
   } else {
     CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
@@ -493,6 +498,13 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
           GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "nn.relu"});
       return GenerateBody(conv2d_call, "cutlass_conv2d_bias_relu", GetArgumentNames(caller),
                           Conv2dArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.conv2d_bias_sigmoid") {
+      const CallNode* current_call = callee->body.as<CallNode>();
+      std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
+      const auto* conv2d_call =
+          GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "sigmoid"});
+      return GenerateBody(conv2d_call, "cutlass_conv2d_bias_sigmoid", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
     }

     LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -540,7 +552,8 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
     } else if (func_name == "cutlass_batch_matmul") {
      ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
     } else if (func_name == "cutlass_conv2d" || func_name == "cutlass_conv2d_bias" ||
-               func_name == "cutlass_conv2d_bias_relu") {
+               func_name == "cutlass_conv2d_bias_relu" ||
+               func_name == "cutlass_conv2d_bias_sigmoid") {
       ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
     }

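Net effect of the codegen changes: for sigmoid fusion the emitted epilogue arguments are `{alpha, beta}` with `alpha = beta = 1` and the bias supplied as the source tensor, i.e. the kernel computes `sigmoid(conv(x, w) + bias)`; the bias and bias+relu paths keep the bias-only `{alpha}` form with no `beta` scaling. A numpy reference for the sigmoid path (an illustrative sketch, assuming an NHWC fp32 accumulator):

    import numpy as np

    def conv2d_bias_sigmoid_ref(conv_out, bias, alpha=1.0, beta=1.0):
        # conv_out: (N, H, W, C) fp32 accumulator; bias: (C,), broadcast
        # over the channel axis and scaled by beta, as in the emitted epilogue.
        z = alpha * conv_out + beta * bias.reshape(1, 1, 1, -1)
        return 1.0 / (1.0 + np.exp(-z))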
19 changes: 14 additions & 5 deletions tests/python/contrib/test_cutlass.py
@@ -134,6 +134,10 @@ def get_conv2d_nchw_bias_relu(d_shape, w_shape, padding, out_dtype="float16"):
     return relay.nn.relu(get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype))


+def get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float16"):
+    return relay.sigmoid(get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype))
+
+
 def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
     mod = partition_for_cutlass(mod)
     mod, num_cutlass_partition = tune_cutlass_kernels(
@@ -423,12 +427,17 @@ def test_conv2d_fusion():
     w_shape = (32, 16, 3, 3)
     padding = (1, 1)

-    mod_nchw = get_conv2d_nchw_bias(d_shape, w_shape, padding)
-    verify_conv2d(
-        mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
-    )
+    # mod_nchw = get_conv2d_nchw_bias(d_shape, w_shape, padding)
+    # verify_conv2d(
+    #     mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+    # )
+
+    # mod_nchw = get_conv2d_nchw_bias_relu(d_shape, w_shape, padding)
+    # verify_conv2d(
+    #     mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+    # )

-    mod_nchw = get_conv2d_nchw_bias_relu(d_shape, w_shape, padding)
+    mod_nchw = get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float32")
     verify_conv2d(
         mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
     )
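The updated test drives the fused path end to end through `verify_conv2d`, building the module with `out_dtype="float32"` to match the fp32-accumulation restriction in the commit title; the bias-only and bias+relu cases are left commented out rather than removed. To run just this test (assuming an SM80 GPU and a CUTLASS-enabled TVM build): `pytest tests/python/contrib/test_cutlass.py -k test_conv2d_fusion`.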
