From 62b3ad49ecfb7e0030309e3d603fb1ee6c6aec62 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 24 Nov 2022 19:36:45 -0500 Subject: [PATCH] [Fix][Op] Do not apply WrapParam when dtypes are the same (#31) --- python/tvm/relax/op/transform.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/op/transform.py b/python/tvm/relax/op/transform.py index f518c27b7736..f2554e55d101 100644 --- a/python/tvm/relax/op/transform.py +++ b/python/tvm/relax/op/transform.py @@ -262,7 +262,8 @@ def cast(data: Expr, dtype: Union[str, tvm.DataType]) -> Expr: def wrap_param(data: Expr, dtype: Union[str, tvm.DataType] = "float32") -> Expr: - """Cast input tensor which is model param to data type. + """Cast the input tensor, which is a model parameter, to the given data type if its dtype is + not the same as the given dtype. Parameters ---------- @@ -278,6 +279,8 @@ def wrap_param(data: Expr, dtype: Union[str, tvm.DataType] = "float32") -> Expr: The casted result. """ assert isinstance(data, relax.Constant) + if data.data.dtype == dtype: + return data if isinstance(dtype, str): dtype = tvm.DataType(dtype) return _ffi_api.wrap_param(data, dtype)