diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 89464face725..6906daccad00 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -4493,7 +4493,7 @@ def _get_pytorch_value_type(typ, default_dtype="float32"): return "ListType" elif kind in ["IntType", "FloatType", "BoolType", "StringType", "OptionalType"]: pt_dtype = str(typ).lower() - dtype = pt_dtype if pt_dtype == "OptionalType" else _convert_data_type(pt_dtype) + dtype = pt_dtype if kind == "OptionalType" else _convert_data_type(pt_dtype) return dtype else: return "UnsupportedType"