From 45999b2856e479080a9bc56610fcfc51fbbb4278 Mon Sep 17 00:00:00 2001 From: sweetcocoa Date: Mon, 3 Apr 2023 20:11:10 +0900 Subject: [PATCH 1/4] preserve the order of input_info --- python/tvm/relay/frontend/pytorch.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 03663bff41a1..ee4aedb47583 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -5014,6 +5014,17 @@ def from_pytorch( data_inputs.append(arg) else: func_args.append(arg) + + # Ensures that the order of data_input is the same as the order of inputs specified in input_info. + _data_inputs = [] + for input_info in input_infos: + _data_input = [di for di in data_inputs if di.name_hint == input_info[0]] + if len(_data_input) != 1: + msg = "Unexpected input name ({}) which is unreachable.".format(input_info[0]) + raise RuntimeError(msg) + _data_inputs.append(_data_input[0]) + data_inputs = _data_inputs + func_args = data_inputs + func_args mod["main"] = tvm.relay.Function(func_args, ret) From eb75592ec778f7dfb213896f5d070c89c2445a05 Mon Sep 17 00:00:00 2001 From: sweetcocoa Date: Tue, 4 Apr 2023 13:49:35 +0900 Subject: [PATCH 2/4] pylint: applied --- python/tvm/relay/frontend/pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index ee4aedb47583..6eafa3734f20 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -5015,7 +5015,7 @@ def from_pytorch( else: func_args.append(arg) - # Ensures that the order of data_input is the same as the order of inputs specified in input_info. + # Ensures the order of data_input is the same as the order of inputs specified in input_info. _data_inputs = [] for input_info in input_infos: _data_input = [di for di in data_inputs if di.name_hint == input_info[0]] @@ -5024,7 +5024,7 @@ def from_pytorch( raise RuntimeError(msg) _data_inputs.append(_data_input[0]) data_inputs = _data_inputs - + func_args = data_inputs + func_args mod["main"] = tvm.relay.Function(func_args, ret) From 5bae5f393f6c32eae89d417ced24a658ccd96588 Mon Sep 17 00:00:00 2001 From: sweetcocoa Date: Tue, 4 Apr 2023 20:00:47 +0900 Subject: [PATCH 3/4] remove unreachable assertion --- python/tvm/relay/frontend/pytorch.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 6eafa3734f20..c9356ac8bca7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -5016,14 +5016,13 @@ def from_pytorch( func_args.append(arg) # Ensures the order of data_input is the same as the order of inputs specified in input_info. - _data_inputs = [] - for input_info in input_infos: - _data_input = [di for di in data_inputs if di.name_hint == input_info[0]] - if len(_data_input) != 1: - msg = "Unexpected input name ({}) which is unreachable.".format(input_info[0]) - raise RuntimeError(msg) - _data_inputs.append(_data_input[0]) - data_inputs = _data_inputs + order_input_infos = {input_info[0]: idx for idx, input_info in enumerate(input_infos)} + data_inputs = sorted( + data_inputs, + key=lambda data_input: order_input_infos[data_input.name_hint] + if data_input.name_hint in order_input_infos + else -1, + ) func_args = data_inputs + func_args From af23e320fabee10ced999141a3868a8f1d30c079 Mon Sep 17 00:00:00 2001 From: sweetcocoa Date: Wed, 5 Apr 2023 19:24:42 +0900 Subject: [PATCH 4/4] sort fix : input info first, others back --- python/tvm/relay/frontend/pytorch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c9356ac8bca7..3ffc39b2bc00 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -5016,12 +5016,15 @@ def from_pytorch( func_args.append(arg) # Ensures the order of data_input is the same as the order of inputs specified in input_info. - order_input_infos = {input_info[0]: idx for idx, input_info in enumerate(input_infos)} + order_input_infos = { + input_info[0]: len(input_infos) - idx for idx, input_info in enumerate(input_infos) + } data_inputs = sorted( data_inputs, key=lambda data_input: order_input_infos[data_input.name_hint] if data_input.name_hint in order_input_infos else -1, + reverse=True, ) func_args = data_inputs + func_args