diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 62c3da973e9a..8b5846217021 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -113,7 +113,10 @@ def concatenate(data, # Find the dtype of the input expr. This is required for the requantize op. Since, this is # concatenate op, the dtype of the input is same as dtype of the output. - data0 = relay.transform.infer_type(data[0]) + mod = relay.Module.from_expr(data[0]) + mod = relay.transform.InferType()(mod) + entry = mod["main"] + data0 = entry if isinstance(data[0], relay.Function) else entry.body in_dtype = data0.checked_type.dtype # First check if all the input qnn params match. If yes, we can call concatenate first, followed