From 9f6ce7cbf9df277a062ffcedc17064612bbd89e6 Mon Sep 17 00:00:00 2001 From: liwangshengya <40720152+liwangshengya@users.noreply.github.com> Date: Thu, 30 Mar 2023 12:20:56 +0800 Subject: [PATCH] [relay][frontend][pytorch]Fix a bug in the _get_pytorch_value_type function (#14421) * Fix a bug in the _get_pytorch_value_type function * Fix lint --- python/tvm/relay/frontend/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 89464face725..6906daccad00 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -4493,7 +4493,7 @@ def _get_pytorch_value_type(typ, default_dtype="float32"): return "ListType" elif kind in ["IntType", "FloatType", "BoolType", "StringType", "OptionalType"]: pt_dtype = str(typ).lower() - dtype = pt_dtype if pt_dtype == "OptionalType" else _convert_data_type(pt_dtype) + dtype = pt_dtype if kind == "OptionalType" else _convert_data_type(pt_dtype) return dtype else: return "UnsupportedType"