[Quantization] Output tensor type is lost after serializing and loading back a quantized model #39690
It's not actually related to quantization; it's more a property of JIT/TorchScript. Tracing records shapes/types in the IR, but they are not serialized (and generally not enforced during execution), so on loading you see a generic Tensor type. If you have to rely on shapes/types, I suggest doing the conversion right after tracing; that's what ONNX export does, for example. Alternatively, you can write an analysis pass; the rules would be pretty simple, starting from quantize_xxx() or quantized::xxx ops. cc @jamesr66a
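A minimal sketch of what such an analysis pass could look like, written against the torch._C graph API; the set of dtype-preserving ops below is an illustrative assumption, not an exhaustive list:

```python
import torch

# Ops assumed (for illustration only) to preserve the quantized dtype of their
# first input; a real pass would need a fuller, vetted list.
PRESERVES_QUANT_DTYPE = {
    "aten::adaptive_avg_pool2d",
    "aten::max_pool2d",
    "aten::relu",
    "aten::flatten",
}

def find_quantized_values(graph):
    """Return debug names of Values believed to hold quantized tensors.

    Traced graphs are straight-line and topologically ordered, so a single
    forward pass over graph.nodes() is enough to propagate the property.
    """
    quantized = set()
    for node in graph.nodes():
        kind = node.kind()
        if kind in ("aten::quantize_per_tensor", "aten::quantize_per_channel") \
                or kind.startswith("quantized::"):
            # These ops produce quantized outputs by construction.
            for out in node.outputs():
                quantized.add(out.debugName())
        elif kind == "aten::dequantize":
            continue  # outputs are float again
        elif kind in PRESERVES_QUANT_DTYPE:
            inputs = list(node.inputs())
            if inputs and inputs[0].debugName() in quantized:
                for out in node.outputs():
                    quantized.add(out.debugName())
    return quantized

# Usage on a loaded module (file name is hypothetical):
# loaded = torch.jit.load("qconv.pt")
# print(find_quantized_values(loaded.graph))
```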
Yes, this is what I've been doing. The need for serialization arises when the calibration step is heavy: ideally we would like to calibrate -> quantize -> trace -> serialize once, rather than calibrating an in-memory torch module every time we want to do the conversion. I can see that dropping shape information on serialization makes sense, but shouldn't dtype be preserved? The output type of quantized operators is always a quantized tensor, so it seems like it could be recovered.
I'm ok with writing a type propagation pass myself, but it would be great if Torch could add this pass as one of its JIT passes, since it could be useful to other people.
I am going to remove the quant label and add the jit/script one, as it seems this is a larger-scope issue.
Is there any way to restore the shapes/types in the IR after torch.jit.load? Maybe something like tracing the loaded ScriptModule again?
There is |
I also need this utility.
Hello! I've been seeing some odd behaviour of quantized PyTorch models and found this issue; would you mind confirming whether the problem I've run into is the same as the one reported here? I have a scripted quantized PyTorch model that I previously saved and then loaded back. When I put it through TVM, I get a lot of warnings, and the generated Relay graph is full of questionable-looking operations that take float32 weights and quantize them to int8. If I interpret it correctly (please let me know if I don't!), those weight tensors should already be in int8.
No, this is expected: weight tensors are unpacked to float32 and then quantized to int8. "That tensor should already be in int8" is not true in PyTorch (maybe it is for tflite). Please ask TVM-specific questions at https://discuss.tvm.apache.org/; I'm happy to help with PyTorch quantization issues here.
@masahi Sorry, I didn't realize this is the PyTorch repo; I should have asked in a TVM space. Thank you for the clarification. I didn't realize before that the weights of a quantized model are kept as floats and left for the compiler to turn into int8. I'm glad the behaviour is as expected!
Cross posting from https://discuss.pytorch.org/t/output-tensor-type-is-lost-after-serializing-and-loading-back-a-quantized-model/84700
@jerryzh168 @raghuramank100
It seems that after I serialize and load back a quantized model, the output type of quantized operators, QUInt8, is lost and replaced by the generic float Tensor type. For a module with a single quantized conv layer, the graph before torch.jit.save types the conv output as QUInt8, but after torch.jit.load the same output appears as a plain Tensor.
The PyTorch frontend in TVM uses this tensor type information to decide whether a torch op is invoked on a quantized tensor. See, for example, the conversion of adaptive average pooling, which requires special care in the quantized case even though the same op, aten::adaptive_avg_pool2d, appears in the Torch IR for both float and quantized input:
https://github.com/apache/incubator-tvm/blob/master/python/tvm/relay/frontend/pytorch.py#L600-L601
Without correct typing, we cannot convert serialized quantized PyTorch models. What happens right now is that, because Torch tells TVM the input tensor is a float type, TVM incorrectly converts some quantized ops into float ops.
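Roughly, such a check amounts to reading the scalar type recorded on the op's input Value in the Torch IR. A minimal sketch (the helper is a made-up illustration; kind() and scalarType() are the torch._C accessors a check like this can rely on):

```python
import torch

def input_is_quantized(node):
    # node is a torch._C.Node, e.g. an aten::adaptive_avg_pool2d call.
    first_input = list(node.inputs())[0]
    ty = first_input.type()
    if ty.kind() != "TensorType":
        return False
    # On a freshly traced quantized model this returns "QUInt8"; after
    # torch.jit.save / torch.jit.load it comes back as None, because the
    # loaded graph only carries generic Tensor types.
    return ty.scalarType() in ("QUInt8", "QInt8")
```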
A repro script, tested on v1.5:
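A minimal sketch of such a repro, using eager-mode static quantization of a single conv module (the module, shapes, and file name here are illustrative, not the original script):

```python
import torch
import torch.nn as nn

class ConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 16, 3)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))

model = ConvModel().eval()
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)
model(torch.randn(1, 3, 32, 32))   # calibration pass
torch.quantization.convert(model, inplace=True)

example = torch.randn(1, 3, 32, 32)
traced = torch.jit.trace(model, example)
print(traced.graph)    # quantized::conv2d output is typed QUInt8(...) here

torch.jit.save(traced, "qconv.pt")
loaded = torch.jit.load("qconv.pt")
print(loaded.graph)    # the same output now shows up as a plain Tensor
```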
cc @suo @gmagogsfm @jerryzh168 @jianyuh @dzhulgakov @raghuramank100 @jamesr66a