Skip to content

Commit

Permalink
addressed comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rohanmukh committed Oct 14, 2020
1 parent 34ba321 commit b30569a
Showing 1 changed file with 32 additions and 64 deletions.
96 changes: 32 additions & 64 deletions python/tvm/relay/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ def register_tensorrt_annotations(trt_version, use_implicit_batch=True):
# _register_external_op_helper("split")
# _register_external_op_helper("slice_like")

# @tvm.ir.register_op_attr("add", "target.tensorrt")
def add_whitelist_fn(attrs, args): # pylint: disable=unused-variable

if any([x.checked_type.dtype != "float32" for x in args]):
Expand All @@ -218,15 +217,12 @@ def add_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False

# Skip this add op in TRT to avoid accuracy mismatch
if all([a == b for a, b in zip(args[0].checked_type.shape, [1, 546, 1, 1])]) and all(
[a == b for a, b in zip(args[1].checked_type.shape, [1, 546, 1, 1])]
):
if all([list(map(int, arg.checked_type.shape)) == [1, 546, 1, 1] for arg in args]):
print("add: bug in TRT with add of shape (1, 546, 1, 1).")
return False

return True

# @tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt")
def batch_norm_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -236,7 +232,6 @@ def batch_norm_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("nn.softmax", "target.tensorrt")
def softmax_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -246,7 +241,6 @@ def softmax_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("nn.conv2d", "target.tensorrt")
def conv2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -262,12 +256,6 @@ def conv2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

_register_external_dynamic_check_func("add", add_whitelist_fn)
_register_external_dynamic_check_func("nn.batch_norm", batch_norm_whitelist_fn)
_register_external_dynamic_check_func("nn.softmax", softmax_whitelist_fn)
_register_external_dynamic_check_func("nn.conv2d", conv2d_whitelist_fn)

# @tvm.ir.register_op_attr("nn.dense", "target.tensorrt")
def dense_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -282,7 +270,6 @@ def dense_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("nn.bias_add", "target.tensorrt")
def bias_add_whitelist_fn(attrs, args): # pylint: disable=unused-variable
# TODO(trevmorr): BiasAddSimplifier creates a pattern which cannot be
# converted to TRT without binding params and constant folding.
Expand All @@ -297,7 +284,6 @@ def bias_add_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("nn.max_pool2d", "target.tensorrt")
def max_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -310,7 +296,6 @@ def max_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("nn.avg_pool2d", "target.tensorrt")
def avg_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -336,7 +321,6 @@ def avg_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("nn.global_max_pool2d", "target.tensorrt")
def global_max_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -346,7 +330,6 @@ def global_max_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-vari
return False
return True

# @tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.tensorrt")
def global_avg_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -356,7 +339,6 @@ def global_avg_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-vari
return False
return True

# @tvm.ir.register_op_attr("expand_dims", "target.tensorrt")
def expand_dims_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -366,7 +348,6 @@ def expand_dims_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("squeeze", "target.tensorrt")
def squeeze_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -379,7 +360,6 @@ def squeeze_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("concatenate", "target.tensorrt")
def concatenate_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.dtype != "float32" for x in args[0].checked_type.fields]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -396,7 +376,6 @@ def concatenate_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("nn.conv2d_transpose", "target.tensorrt")
def conv2d_transpose_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -423,7 +402,6 @@ def conv2d_transpose_whitelist_fn(attrs, args): # pylint: disable=unused-variab
return False
return True

# @tvm.ir.register_op_attr("transpose", "target.tensorrt")
def transpose_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -433,7 +411,6 @@ def transpose_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("reshape", "target.tensorrt")
def reshape_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if args[0].checked_type.dtype != "float32":
print("Only float32 inputs are supported for TensorRT.")
Expand Down Expand Up @@ -463,7 +440,6 @@ def reshape_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("nn.pad", "target.tensorrt")
def pad_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -476,20 +452,6 @@ def pad_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

_register_external_dynamic_check_func("nn.dense", dense_whitelist_fn)
_register_external_dynamic_check_func("nn.bias_add", bias_add_whitelist_fn)
_register_external_dynamic_check_func("nn.max_pool2d", max_pool_2d_whitelist_fn)
_register_external_dynamic_check_func("nn.avg_pool2d", avg_pool_2d_whitelist_fn)
_register_external_dynamic_check_func("nn.global_max_pool2d", global_max_pool_2d_whitelist_fn)
_register_external_dynamic_check_func("nn.global_avg_pool2d", global_avg_pool_2d_whitelist_fn)
_register_external_dynamic_check_func("expand_dims", expand_dims_whitelist_fn)
_register_external_dynamic_check_func("squeeze", squeeze_whitelist_fn)
_register_external_dynamic_check_func("concatenate", concatenate_whitelist_fn)
_register_external_dynamic_check_func("nn.conv2d_transpose", conv2d_transpose_whitelist_fn)
_register_external_dynamic_check_func("transpose", transpose_whitelist_fn)
_register_external_dynamic_check_func("reshape", reshape_whitelist_fn)
_register_external_dynamic_check_func("nn.pad", pad_whitelist_fn)

def reduce_whitelist_fn(attrs, args, op_name, trt_version):
if not attrs.axis or len(attrs.axis) == 0:
print("{}: cannot reduce to scalar.".format(op_name))
Expand All @@ -502,25 +464,23 @@ def reduce_whitelist_fn(attrs, args, op_name, trt_version):
return False
return True

_register_external_op_helper_func("sum", reduce_whitelist_fn, trt_version)
_register_external_op_helper_func("prod", reduce_whitelist_fn, trt_version)
_register_external_op_helper_func("max", reduce_whitelist_fn, trt_version)
_register_external_op_helper_func("min", reduce_whitelist_fn, trt_version)
_register_external_op_helper_func("mean", reduce_whitelist_fn, trt_version)

def trt_5_1_5_whitelist_fn(attrs, args, op_name, trt_version):
    """Generic whitelist check for ops that TensorRT only supports from 5.1.5.

    Returns True when the linked TensorRT runtime is at least version 5.1.5,
    otherwise prints a diagnostic and returns False so the op stays on TVM.
    """
    if trt_version >= (5, 1, 5):
        return True
    print("{}: requires TensorRT version 5.1.5 or higher.".format(op_name))
    return False

# Register version-aware checks: the reduce ops share reduce_whitelist_fn,
# and the trailing group is gated on TensorRT >= 5.1.5 via
# trt_5_1_5_whitelist_fn. trt_version is captured so each check can test the
# runtime version at annotation time.
_register_external_op_helper_func("sum", reduce_whitelist_fn, trt_version)
_register_external_op_helper_func("prod", reduce_whitelist_fn, trt_version)
_register_external_op_helper_func("max", reduce_whitelist_fn, trt_version)
_register_external_op_helper_func("min", reduce_whitelist_fn, trt_version)
_register_external_op_helper_func("mean", reduce_whitelist_fn, trt_version)
_register_external_op_helper_func("nn.leaky_relu", trt_5_1_5_whitelist_fn, trt_version)
_register_external_op_helper_func("sin", trt_5_1_5_whitelist_fn, trt_version)
_register_external_op_helper_func("cos", trt_5_1_5_whitelist_fn, trt_version)
_register_external_op_helper_func("atan", trt_5_1_5_whitelist_fn, trt_version)
_register_external_op_helper_func("ceil", trt_5_1_5_whitelist_fn, trt_version)

# @tvm.ir.register_op_attr("strided_slice", "target.tensorrt")
def strided_slice_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -546,12 +506,10 @@ def strided_slice_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("image.resize", "target.tensorrt")
def resize_whitelist_fn(attrs, args):  # pylint: disable=unused-variable
    """Whitelist check for image.resize.

    Unconditionally rejects the op: TensorRT's resize output does not match
    TVM's (TODO(trevmorr)), so it must not be offloaded.
    """
    return False

# @tvm.ir.register_op_attr("nn.adaptive_max_pool2d", "target.tensorrt")
def adapative_max_pool2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -561,7 +519,6 @@ def adapative_max_pool2d_whitelist_fn(attrs, args): # pylint: disable=unused-va
return False
return True

# @tvm.ir.register_op_attr("nn.adaptive_avg_pool2d", "target.tensorrt")
def adapative_avg_pool2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -571,22 +528,10 @@ def adapative_avg_pool2d_whitelist_fn(attrs, args): # pylint: disable=unused-va
return False
return True

# @tvm.ir.register_op_attr("nn.upsampling", "target.tensorrt")
def upsampling_whitelist_fn(attrs, args):  # pylint: disable=unused-variable
    """Whitelist check for nn.upsampling.

    Unconditionally rejects the op: TensorRT's upsampling output does not
    match TVM's (TODO(trevmorr)), so it must not be offloaded.
    """
    return False

_register_external_dynamic_check_func("strided_slice", strided_slice_whitelist_fn)
_register_external_dynamic_check_func("image.resize", resize_whitelist_fn)
_register_external_dynamic_check_func(
"nn.adaptive_max_pool2d", adapative_max_pool2d_whitelist_fn
)
_register_external_dynamic_check_func(
"nn.adaptive_avg_pool2d", adapative_avg_pool2d_whitelist_fn
)
_register_external_dynamic_check_func("nn.upsampling", upsampling_whitelist_fn)

# @tvm.ir.register_op_attr("nn.conv3d", "target.tensorrt")
def conv3d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -605,7 +550,6 @@ def conv3d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("nn.max_pool3d", "target.tensorrt")
def max_pool_3d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -618,7 +562,6 @@ def max_pool_3d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("nn.avg_pool3d", "target.tensorrt")
def avg_pool_3d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand All @@ -631,7 +574,6 @@ def avg_pool_3d_whitelist_fn(attrs, args): # pylint: disable=unused-variable
return False
return True

# @tvm.ir.register_op_attr("nn.conv3d_transpose", "target.tensorrt")
def conv3d_transpose_whitelist_fn(attrs, args): # pylint: disable=unused-variable
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
Expand Down Expand Up @@ -666,6 +608,32 @@ def conv3d_transpose_whitelist_fn(attrs, args): # pylint: disable=unused-variab
return False
return True

# Register every per-op dynamic whitelist check in one place so the set of
# TensorRT-offloadable ops is visible at a glance. Each *_whitelist_fn above
# decides at annotation time whether the concrete call can be offloaded.
_register_external_dynamic_check_func("add", add_whitelist_fn)
_register_external_dynamic_check_func("nn.batch_norm", batch_norm_whitelist_fn)
_register_external_dynamic_check_func("nn.softmax", softmax_whitelist_fn)
_register_external_dynamic_check_func("nn.conv2d", conv2d_whitelist_fn)
_register_external_dynamic_check_func("nn.dense", dense_whitelist_fn)
_register_external_dynamic_check_func("nn.bias_add", bias_add_whitelist_fn)
_register_external_dynamic_check_func("nn.max_pool2d", max_pool_2d_whitelist_fn)
_register_external_dynamic_check_func("nn.avg_pool2d", avg_pool_2d_whitelist_fn)
_register_external_dynamic_check_func("nn.global_max_pool2d", global_max_pool_2d_whitelist_fn)
_register_external_dynamic_check_func("nn.global_avg_pool2d", global_avg_pool_2d_whitelist_fn)
_register_external_dynamic_check_func("expand_dims", expand_dims_whitelist_fn)
_register_external_dynamic_check_func("squeeze", squeeze_whitelist_fn)
_register_external_dynamic_check_func("concatenate", concatenate_whitelist_fn)
_register_external_dynamic_check_func("nn.conv2d_transpose", conv2d_transpose_whitelist_fn)
_register_external_dynamic_check_func("transpose", transpose_whitelist_fn)
_register_external_dynamic_check_func("reshape", reshape_whitelist_fn)
_register_external_dynamic_check_func("nn.pad", pad_whitelist_fn)
_register_external_dynamic_check_func("strided_slice", strided_slice_whitelist_fn)
_register_external_dynamic_check_func("image.resize", resize_whitelist_fn)
_register_external_dynamic_check_func(
    "nn.adaptive_max_pool2d", adapative_max_pool2d_whitelist_fn
)
_register_external_dynamic_check_func(
    "nn.adaptive_avg_pool2d", adapative_avg_pool2d_whitelist_fn
)
_register_external_dynamic_check_func("nn.upsampling", upsampling_whitelist_fn)
_register_external_dynamic_check_func("nn.conv3d", conv3d_whitelist_fn)
_register_external_dynamic_check_func("nn.max_pool3d", max_pool_3d_whitelist_fn)
_register_external_dynamic_check_func("nn.avg_pool3d", avg_pool_3d_whitelist_fn)
Expand Down

0 comments on commit b30569a

Please sign in to comment.