diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 03663bff41a1..3ffc39b2bc00 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -5014,6 +5014,19 @@ def from_pytorch( data_inputs.append(arg) else: func_args.append(arg) + + # Ensures the order of data_input is the same as the order of inputs specified in input_info. + order_input_infos = { + input_info[0]: len(input_infos) - idx for idx, input_info in enumerate(input_infos) + } + data_inputs = sorted( + data_inputs, + key=lambda data_input: order_input_infos[data_input.name_hint] + if data_input.name_hint in order_input_infos + else -1, + reverse=True, + ) + func_args = data_inputs + func_args mod["main"] = tvm.relay.Function(func_args, ret)