diff --git a/python/tvm/relax/op/transform.py b/python/tvm/relax/op/transform.py index f518c27b7736..f2554e55d101 100644 --- a/python/tvm/relax/op/transform.py +++ b/python/tvm/relax/op/transform.py @@ -262,7 +262,8 @@ def cast(data: Expr, dtype: Union[str, tvm.DataType]) -> Expr: def wrap_param(data: Expr, dtype: Union[str, tvm.DataType] = "float32") -> Expr: - """Cast input tensor which is model param to data type. + """Cast input tensor which is model param to data type if the dtype of the input data is not + the same as the given dtype. Parameters ---------- @@ -278,6 +279,8 @@ def wrap_param(data: Expr, dtype: Union[str, tvm.DataType] = "float32") -> Expr: The casted result. """ assert isinstance(data, relax.Constant) + if data.data.dtype == dtype: + return data if isinstance(dtype, str): dtype = tvm.DataType(dtype) return _ffi_api.wrap_param(data, dtype)