
[Quantization] Output tensor type is lost after serializing and loading back a quantized model #39690

Open
masahi opened this issue Jun 9, 2020 · 9 comments
Labels
oncall: jit Add this issue/PR to JIT oncall triage queue

Comments

@masahi

masahi commented Jun 9, 2020

Cross posting from https://discuss.pytorch.org/t/output-tensor-type-is-lost-after-serializing-and-loading-back-a-quantized-model/84700

@jerryzh168 @raghuramank100

It seems that after I serialize and load back a quantized model, the output type of quantized operators, QUInt8, is lost and is instead replaced by a float Tensor type. See below for a module with a single quantized conv layer.

Before torch.jit.save

graph(%self.1 : __torch__.AnnotatedConvModel,
      %X : Float(2, 3, 10, 10)):
  ...
  %input : QUInt8(2, 3, 10, 10) = aten::quantize_per_tensor(%X, %67, %68, %69), scope: __module.quant # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:43:0
  ...
  %Xq : QUInt8(2, 3, 8, 8) = quantized::conv2d(%input, %71, %74, %77, %80, %81, %82, %83), scope: __module.conv # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:215:0
  %85 : Float(2, 3, 8, 8) = aten::dequantize(%Xq), scope: __module.dequant # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:74:0
  return (%85)

After torch.jit.load

graph(%self.1 : __torch__.AnnotatedConvModel,
      %X.1 : Tensor):
  ...
  %input.1 : Tensor = aten::quantize_per_tensor(%X.1, %9, %10, %11) # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:43:0
  %Xq.1 : Tensor = quantized::conv2d(%input.1, %15, %17, %18, %19, %16, %20, %21) # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:215:0
  ...
  %24 : Tensor = aten::dequantize(%Xq.1) # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:74:0
  return (%24)

The PyTorch frontend in TVM uses this tensor type information to decide whether a torch op is invoked on a quantized tensor. See, for example, the conversion of adaptive average pooling, which requires special care in the quantized case, even though the same op aten::adaptive_avg_pool2d appears in the Torch IR for both float and quantized inputs.

https://github.com/apache/incubator-tvm/blob/master/python/tvm/relay/frontend/pytorch.py#L600-L601

Without correct typing, we cannot convert serialized quantized PyTorch models. What happens right now is that, since Torch tells TVM the input tensor is a float type, TVM incorrectly converts some quantized ops into float ops.

A repro script, tested on v1.5

import torch
from torch.quantization import QuantStub, DeQuantStub, default_qconfig


class AnnotatedConvModel(torch.nn.Module):
    def __init__(self):
        super(AnnotatedConvModel, self).__init__()
        self.qconfig = default_qconfig
        self.conv = torch.nn.Conv2d(3, 3, 3, bias=False).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x


def quantize_model(model, inp):
    # Post-training static quantization: attach qconfig, insert observers,
    # calibrate with one forward pass, then convert to quantized modules.
    model.qconfig = default_qconfig
    torch.quantization.prepare(model, inplace=True)
    model(inp)
    torch.quantization.convert(model, inplace=True)

def test_conv():
    inp = torch.rand(2, 3, 10, 10)
    annotated_conv_model = AnnotatedConvModel()
    quantize_model(annotated_conv_model, inp)

    # Trace the quantized model and inline calls so the quantized ops
    # (quantize_per_tensor, quantized::conv2d, dequantize) appear in the graph.
    trace = torch.jit.trace(annotated_conv_model, inp)
    torch._C._jit_pass_inline(trace.graph)
    print(trace.graph)  # values are typed as QUInt8(...)

    # Round-trip through serialization: the tensor types degrade to plain Tensor.
    torch.jit.save(trace, "trace.pt")
    trace = torch.jit.load("trace.pt")
    print(trace.graph)  # values are typed as Tensor


test_conv()
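
As a side note, a small helper (a sketch added here for illustration, not part of the original repro) can make the type loss explicit using the JIT graph Python bindings (graph.nodes(), Node.outputs(), Value.type(), Value.debugName()):

def quantized_values(graph):
    """Return debug names of graph values whose annotated type is quantized."""
    names = []
    for node in graph.nodes():
        for out in node.outputs():
            # Before serialization the type prints as e.g. "QUInt8(2, 3, 8, 8)";
            # after torch.jit.load it degrades to plain "Tensor".
            if str(out.type()).startswith("QUInt8"):
                names.append(out.debugName())
    return names

Called on trace.graph before torch.jit.save, this should list %input and %Xq; called after torch.jit.load, it should come back empty.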

cc @suo @gmagogsfm @jerryzh168 @jianyuh @dzhulgakov @raghuramank100 @jamesr66a

@pbelevich pbelevich added the oncall: quantization Quantization support in PyTorch label Jun 9, 2020
@dzhulgakov
Collaborator

It's not related to quantization actually, and more of a property of JIT/TorchScript. Tracing records shapes/types in the IR, but they are not serialized (and generally not enforced for execution). So on loading you're seeing a generic Tensor type.

If you have to rely on shapes/types, I suggest doing conversion right after the tracing. That's e.g. what onnx export does.

Alternatively you can write some analysis pass, the rules would be pretty simple, starting from quantize_xxx() or quantized::xxx ops.
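
A rough sketch of what such a pass could look like on the Python side (the op sets below are illustrative assumptions, not an existing PyTorch pass):

# Ops that produce quantized tensors, and ops that merely pass the quantized
# dtype through; both sets are illustrative, not exhaustive.
QUANT_SOURCES = {"aten::quantize_per_tensor", "aten::quantize_per_channel"}
DTYPE_PRESERVING = {"aten::adaptive_avg_pool2d", "aten::max_pool2d",
                    "aten::relu", "aten::flatten"}

def propagate_quant_dtypes(graph):
    """Return debug names of values believed to hold quantized tensors."""
    quantized = set()
    for node in graph.nodes():
        kind = node.kind()
        if kind == "aten::dequantize":
            continue  # dequantize always produces a float tensor
        produces_quant = (
            kind in QUANT_SOURCES
            or kind.startswith("quantized::")
            or (kind in DTYPE_PRESERVING
                and any(i.debugName() in quantized for i in node.inputs()))
        )
        if produces_quant:
            for out in node.outputs():
                quantized.add(out.debugName())
    return quantized

This only recovers which values are quantized on the consumer side; re-annotating the types in the serialized IR itself would need something more, e.g. the internal shape analysis pass mentioned later in this thread.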

cc @jamesr66a

@masahi
Author

masahi commented Jun 9, 2020

If you have to rely on shapes/types, I suggest doing conversion right after the tracing. That's e.g. what onnx export does.

Yes, this is what I've been doing. The need for serialization arises when the calibration step is heavy. Ideally we would like to calibrate -> quantize -> trace -> serialize once, rather than running calibration on an in-memory torch module every time we want to do the conversion.

I can see dropping shape information on serialization makes sense, but shouldn't dtype be preserved? The output type of aten::quantize_per_tensor should never be a float Tensor, so having %input.1 : Tensor = aten::quantize_per_tensor(...) in the IR seems strange to me. I wonder how Torch can compute the correct output with this IR (I verified that the output of a serialized, quantized resnet18 from torchvision is correct).

Alternatively you can write some analysis pass, the rules would be pretty simple, starting from quantize_xxx() or quantized::xxx ops.

I'm ok with writing a type propagation pass myself, but it would be great if Torch could add this pass as one of its jit passes, since this could be useful to other people.

@z-a-f
Contributor

z-a-f commented Jul 8, 2020

I am going to remove the quant label and add the jit/script one, as it seems this is a larger-scope issue.

@z-a-f z-a-f added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jul 8, 2020
@vkuzo vkuzo removed the oncall: quantization Quantization support in PyTorch label Jul 31, 2020
@kongroo

kongroo commented Sep 3, 2020

It's not related to quantization actually, and more of a property of JIT/TorchScript. Tracing records shapes/types in the IR, but they are not serialized (and generally not enforced for execution). So on loading you're seeing a generic Tensor type.

If you have to rely on shapes/types, I suggest doing conversion right after the tracing. That's e.g. what onnx export does.

Alternatively you can write some analysis pass, the rules would be pretty simple, starting from quantize_xxx() or quantized::xxx ops.

cc @jamesr66a

Is there any way to restore the shapes/types in the IR after torch.jit.load? Maybe something like tracing the loaded ScriptModule again?

@t-vi
Collaborator

t-vi commented Sep 29, 2020

There is torch._C._jit_pass_complete_shape_analysis(graph, inputs, with_grad: bool), which works but is internal.
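
For reference, a usage sketch against the repro above (this is an internal API, so the exact signature may change between releases):

import torch

loaded = torch.jit.load("trace.pt")
inp = torch.rand(2, 3, 10, 10)
# Propagate shapes/dtypes from an example input through the loaded graph;
# the last argument toggles gradient propagation and can stay False here.
torch._C._jit_pass_complete_shape_analysis(loaded.graph, (inp,), False)
print(loaded.graph)  # quantized values should show up as QUInt8(...) again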

@icyhearts

It's not related to quantization actually, and more of a property of JIT/TorchScript. Tracing records shapes/types in the IR, but they are not serialized (and generally not enforced for execution). So on loading you're seeing a generic Tensor type.
If you have to rely on shapes/types, I suggest doing conversion right after the tracing. That's e.g. what onnx export does.
Alternatively you can write some analysis pass, the rules would be pretty simple, starting from quantize_xxx() or quantized::xxx ops.
cc @jamesr66a

Is there any way to restore the shapes/types in the IR after torch.jit.load? Maybe something like tracing the loaded ScriptModule again?

I also need this utility.

@ekalda

ekalda commented Feb 15, 2021

Hello! I've been seeing some odd behaviour of quantized PyTorch models and found this issue - would you mind confirming whether the problem I've run into is the same as the one reported here?

I have a scripted quantized PyTorch model that I have previously saved and then loaded back. When I put it through TVM, I get lots of warning like

WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32

When I look at the generated Relay graph, it has lots of questionable qnn.quantize nodes, e.g. for convolution

  %1 = qnn.quantize(meta[relay.Constant][0] /* ty=Tensor[(32, 1, 3, 3), float32] */, 0.00470623f /* ty=float32 */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(32, 1, 3, 3), int8] */;
  %2 = qnn.conv2d(%0, %1, 0 /* ty=int32 */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0.00470623f /* ty=float32 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 32, 26, 26), int32] */;

If I interpret it correctly (please let me know if I don't!), %1 is trying to turn a float32 tensor into an int8 tensor; however, that tensor should already be in int8. So I suspect this node is there because the type got erased and the frontend assumes it is float32?

@masahi
Author

masahi commented Feb 15, 2021

No, this is expected. Weight tensors are unpacked to float32 and quantized to int8. "that tensor should already be in int8" is not true in PyTorch (maybe it is for tflite).

Please ask TVM-specific questions at https://discuss.tvm.apache.org/. I'm happy to help with PyTorch quantization issues.

@ekalda

ekalda commented Feb 15, 2021

@masahi Sorry, I didn't realize this is the PyTorch repo, I should have asked in a TVM space. Thank you for the clarification, I didn't realize before that the weights of a quantized model are kept as floats and left for the compiler to turn into int8. I'm glad the behaviour is as expected!
