diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py index 8a886ff260b81..35308928cdabf 100644 --- a/python/tvm/contrib/cutlass/conv2d_operation.py +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -143,6 +143,22 @@ class EmitConv2dInstance: """ Responsible for emitting a CUTLASS template definition.""" def __init__(self): + self.epilogue_default = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >""" + self.epilogue_no_beta_scaling = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::NoBetaScaling + >""" + self.template = """ // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" using ${operation_name} = @@ -159,12 +175,7 @@ def __init__(self): cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, - ${epilogue_functor}< - ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, - ${element_epilogue} - >, + ${epilogue}, ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, ${stages}, ${math_operator}, @@ -175,7 +186,7 @@ def __init__(self): >::Kernel; """ - def emit(self, operation): + def emit(self, operation, no_beta_scaling=True): """Instantiate a Conv2d kernel from given `operation`.""" warp_shape = [ int( @@ -237,4 +248,12 @@ def emit(self, operation): "align_b": str(operation.B.alignment), } - return substitute_template(self.template, values) + template = substitute_template( + self.template, + { + "epilogue": self.epilogue_no_beta_scaling + if no_beta_scaling + else self.epilogue_default + }, + ) + return substitute_template(template, values)