Skip to content

Commit

Permalink
use nobetascaling
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 12, 2021
1 parent 274ec02 commit 7cf40e7
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,22 @@ class EmitConv2dInstance:
""" Responsible for emitting a CUTLASS template definition."""

def __init__(self):
self.epilogue_default = """
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>"""
self.epilogue_no_beta_scaling = """
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue},
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>"""

self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
Expand All @@ -159,12 +175,7 @@ def __init__(self):
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${epilogue},
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
${stages},
${math_operator},
Expand All @@ -175,7 +186,7 @@ def __init__(self):
>::Kernel;
"""

def emit(self, operation):
def emit(self, operation, no_beta_scaling=True):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
Expand Down Expand Up @@ -237,4 +248,12 @@ def emit(self, operation):
"align_b": str(operation.B.alignment),
}

return substitute_template(self.template, values)
template = substitute_template(
self.template,
{
"epilogue": self.epilogue_no_beta_scaling
if no_beta_scaling
else self.epilogue_default
},
)
return substitute_template(template, values)

0 comments on commit 7cf40e7

Please sign in to comment.