[Fix][Op] Do not apply WrapParam when dtypes are the same (apache#31)
MasterJH5574 authored Nov 25, 2022
1 parent dfe9d42 commit 62b3ad4
Showing 1 changed file with 4 additions and 1 deletion.

python/tvm/relax/op/transform.py (4 additions, 1 deletion)
@@ -262,7 +262,8 @@ def cast(data: Expr, dtype: Union[str, tvm.DataType]) -> Expr:
 
 
 def wrap_param(data: Expr, dtype: Union[str, tvm.DataType] = "float32") -> Expr:
-    """Cast input tensor which is model param to data type.
+    """Cast input tensor which is model param to data type if the dtype of the input data is not
+    the same as the given dtype.
     Parameters
     ----------
@@ -278,6 +279,8 @@ def wrap_param(data: Expr, dtype: Union[str, tvm.DataType] = "float32") -> Expr:
     The casted result.
     """
     assert isinstance(data, relax.Constant)
+    if data.data.dtype == dtype:
+        return data
     if isinstance(dtype, str):
         dtype = tvm.DataType(dtype)
     return _ffi_api.wrap_param(data, dtype)
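The early return added by this commit follows a common pattern: skip the cast (and the extra operation it would introduce) when the parameter already has the requested dtype. A minimal NumPy sketch of the same idea, using a hypothetical `wrap_param_np` helper rather than the actual TVM Relax API:

```python
import numpy as np

def wrap_param_np(data: np.ndarray, dtype: str = "float32") -> np.ndarray:
    """Cast `data` to `dtype`, returning the original array unchanged
    when the dtypes already match (mirroring the early return in this fix)."""
    if data.dtype == np.dtype(dtype):
        return data  # dtypes match: no cast, no copy
    return data.astype(dtype)

params = np.arange(4, dtype="float32")
same = wrap_param_np(params)               # same dtype: identity, original object returned
cast = wrap_param_np(params, "float16")    # differing dtype: a cast copy is produced
```

Without the dtype check, every call would produce a cast even when it is a no-op; with it, the same-dtype path is a pure identity.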
