
RuntimeError: Module <class 'brevitas.proxy.float_runtime_quant.ActFloatQuantProxyFromInjector'> not supported for export #1091

jcollyer-turing opened this issue Nov 13, 2024 · 19 comments
Labels
bug Something isn't working

Comments

@jcollyer-turing

Describe the bug

I am attempting to save PTQ TorchVision models produced by the ptq_benchmark_torchvision.py script, after amending the script to save the model with export_torch_qcdq as a final step.

The traceback is below:

Traceback (most recent call last):
  File "/Users/jcollyer/Documents/git/brevitas-datagen/ptq_benchmark_torchvision.py", line 343, in ptq_torchvision_models
    export_torch_qcdq(quant_model.to('cpu'), torch.randn(1, 3, 244, 244).to('cpu'), export_path = f"{folder}/{uuid}")
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/__init__.py", line 29, in export_torch_qcdq
    return TorchQCDQManager.export(*args, **kwargs)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/torch/qcdq/manager.py", line 56, in export
    traced_module = cls.jit_inference_trace(module, args, export_path)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/manager.py", line 209, in jit_inference_trace
    module.apply(cls.set_export_handler)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 895, in apply
    module.apply(fn)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 895, in apply
    module.apply(fn)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 896, in apply
    fn(self)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/torch/qcdq/manager.py", line 39, in set_export_handler
    _set_proxy_export_handler(cls, module)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/manager.py", line 127, in _set_proxy_export_handler
    _set_export_handler(manager_cls, module, QuantProxyProtocol, no_inheritance=True)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/manager.py", line 115, in _set_export_handler
    raise RuntimeError(f"Module {module.__class__} not supported for export.")
RuntimeError: Module <class 'brevitas.proxy.float_runtime_quant.ActFloatQuantProxyFromInjector'> not supported for export.

Reproducibility

  • Can be reproduced consistently.
  • Difficult to reproduce.
  • Unable to reproduce.

To Reproduce

Steps to reproduce the behavior. For example:

  1. Add export_torch_qcdq() and its associated args at the end of ptq_benchmark_torchvision.py
  2. Execute the following command (with ImageNet calibration and validation sets):
python ptq_benchmark_torchvision.py 0 --calibration-dir <path-to-calib> --validation-dir <path-to-val> \
--quant_format float \
--scale_factor_type float_scale \
--weight_bit_width 2 3 4 5 6 7 8 \
--act_bit_width 2 3 4 5 6 7 8 \
--weight_mantissa_bit_width 1 2 3 4 5 6 \
--weight_exponent_bit_width 1 2 3 4 5 6 \
--act_mantissa_bit_width 1 2 3 4 5 6 \
--act_exponent_bit_width 1 2 3 4 5 6 \
--bias_bit_width None \
--weight_quant_granularity per_channel per_tensor \
--act_quant_type sym \
--weight_param_method stats \
--act_param_method mse \
--bias_corr True \
--graph_eq_iterations 20 \
--graph_eq_merge_bias True \
--act_equalization layerwise \
--learned_round False \
--gptq False \
--gpxq_act_order False \
--gpfq False \
--gpfq_p None \
--gpfa2q False \
--accumulator_bit_width None \
--uint_sym_act_for_unsigned_values False \
--act_quant_percentile None

Expected behavior
The model should be saved.

Please complete the following information (if known):

  • Brevitas version: 0.11.0
  • PyTorch version: 2.4.1
  • Operating System / platform: macOS (Apple M2, running on CPU, not MPS)

Additional context
I have tried torch.save() natively and it doesn't work either.

@jcollyer-turing added the bug label on Nov 13, 2024
@Giuseppe5
Collaborator

Hello,
Thanks for pointing this out.
At the moment we only support ONNX for FP8 export.

Unfortunately, the torch quant/dequant ops that we normally use to map quantization (see https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor.html) don't support minifloat/fp8 quantization.
We could work around this with custom ops, but doing that properly requires torch 2.0+, and we will probably move in that direction in future versions of Brevitas once we deprecate older PyTorch versions.
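As a minimal sketch of that limitation (illustration only, not Brevitas code), torch.quantize_per_tensor only accepts the integer quantized dtypes, so there is no way to express an fp8/minifloat target through it:

import torch

x = torch.randn(4)
# Only integer quantized dtypes (torch.qint8, torch.quint8, torch.qint32) are
# accepted here; there is no fp8/minifloat dtype, which is why the QCDQ-style
# torch export cannot represent minifloat quantization.
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
print(q.dtype)  # torch.qint8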

With respect to ONNX, we can only export FP8 and not lower bit widths, for similar reasons: ONNX supports only a few FP8 types and does not allow defining a custom minifloat data type with arbitrary mantissa and exponent bit widths (or other configurations).

If you share your goal for this export flow, maybe we can guide you towards a custom solution while we build a generic one.

@jcollyer-turing
Author

Hey @Giuseppe5 - Thank you for the speedy reply.

That makes a lot of sense; it sounds like something that is tricky to do generically, given the complexity of supporting the different ops correctly.

I am running a series of experiments to look at the effect quantisation has on adversarial robustness. I want to quantise the models in an HPC-style environment, save them somewhere sensible, and then evaluate the robustness downstream. It doesn't necessarily have to be a valid torch/onnx model to load into another library. I tried to use dill, which I have had success with in the past, but it too seems to have its own issues (it struggles with the multi-inheritance/mixins).

I'd appreciate any ideas you may have to save the models to be loaded again later.

@Giuseppe5
Collaborator

We recently merged into dev the possibility to export minifloat to QONNX (QONNX ref and QONNX Minifloat ref).

This representation allows us to express explicitly all the various minifloat configurations that Brevitas can simulate.
QONNX generally provides reference implementations of its kernels, similar to what happens in ONNX, allowing you to execute your graph and check numerical correctness. However, I think that for the moment QONNX FloatQuant is only an interface and the reference implementation is still being worked on, which means you can't consider this a "valid" ONNX graph for the moment. I will let @maltanar comment on this in case I'm saying something that is not correct or precise.

Would this work for you? Do you need a torch-based export for this task or do you just need to represent (even with custom ops) the computational graph?

@jcollyer-turing
Author

This sounds like something that could be very useful for my use case, but it does not cover everything. Ideally, it would be torch-based for easy integration into torch-based adversarial attack libraries (especially the attacks that require gradient information).

That being said, the features of QONNX and the ability to calculate inference statistics are actually another interesting set of data to collect. If this is the easiest way for me to save the model and run downstream inference, I can definitely work with it!

@Giuseppe5
Collaborator

Another option is to just save the state dict with model.state_dict() (all quantization parameters would be saved as well), re-generate and quantize the model downstream, and then re-load the checkpoint. This would require you to carry Brevitas as a dependency downstream as well.
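A minimal sketch of that option (QuantLinear is just a stand-in for the actual PTQ'd torchvision model, and the file name is illustrative):

import torch
import brevitas.nn as qnn

# Where PTQ runs: the state dict also carries the quantization parameters,
# as noted above.
model = qnn.QuantLinear(3, 8)
model.eval()
model(torch.randn(1, 3))  # at least one forward pass before saving is recommended
torch.save(model.state_dict(), "quant_checkpoint.pt")

# Downstream (Brevitas still needed as a dependency): re-create the same
# quantized architecture, then re-load the checkpoint.
model_reloaded = qnn.QuantLinear(3, 8)
model_reloaded.load_state_dict(torch.load("quant_checkpoint.pt"))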

Talking with @nickfraser, we might try a few ideas on how to make Brevitas compatible with serialization. It should be a quick experiment, and I'll keep the issue updated if it works out, so that you can replicate it while we go through the whole PR process.

@jcollyer-turing
Author

That sounds wonderful. Thank you!

Re. carrying Brevitas as a dependency: that is definitely an option and something I could do, but it introduces a different challenge (sorry!) - Brevitas pins numpy<=1.26.4, which clashes with other requirements elsewhere in my project.

@Giuseppe5
Collaborator

Giuseppe5 commented Nov 15, 2024

To be fair, that pin might be outdated.

When numpy 2.0 was released, a lot of things broke and our decision was to wait until it stabilized a bit before we tried to unpin it. We don't use any specific numpy functions that are no longer available in 2.0 or things like that. I will open a PR with the unpinned version to see how many things break.

Fingers crossed: #1093

@jcollyer-turing
Author

Sounds great! Thank you 👍

@Giuseppe5
Collaborator

I think there is reasonable optimism that Brevitas no longer has any hard requirement on numpy.

Having said that, I notice that torch keeps installing numpy 1.26 even if I don't specify any particular version.

I tried manually upgrading it (with torch 2.4) and everything seems to work fine. With an older version of torch there seem to be conflicts.

Let me know what you see from your side. I will still look into serialization but this could be the fastest way to get there IMHO.

@jcollyer-turing
Author

What is the best branch to install/set up to test this? And any luck with the serialisation experiments?

@Giuseppe5
Collaborator

Giuseppe5 commented Nov 18, 2024

What is the best branch to install/setup to test this?

After installing Brevitas, just update numpy to whatever version you need and everything should work. Be sure to update onnxruntime to the latest version as well.

and any luck with the serialisation experimentation?

I opened a draft PR with an example: #1096

It seems to work locally but there could be side effects when using the pickled model downstream.
We always assume that the quant_injector is present, so we use it even though we could store the value at init time and then discard the quant_injector completely.

I believe all the issues can be fixed relatively easily.
Basically, what could happen downstream is that you try to access an attribute of the quant_injector and it's not there. If that's the case, the solution is to modify the proxy to store that attribute at init time before you generate the pickle, and then you're good to go. If you face any such issues, feel free to report them here and I can add fixes in the draft PR.
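Purely as an illustration of that pattern (the class and attribute names below are hypothetical, not the actual Brevitas proxy API):

# Hypothetical sketch: cache the value at init time, while the injector is
# still attached, so the pickled object no longer needs the injector.
class ExampleProxy:
    def __init__(self, quant_injector):
        self.quant_injector = quant_injector
        self.cached_bit_width = quant_injector.bit_width  # stored up front

    def bit_width(self):
        # Serve the cached copy instead of querying self.quant_injector,
        # which may be gone after pickling.
        return self.cached_bit_width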

Wouldn't you also need to carry Brevitas downstream with pickle?

@jcollyer-turing
Author

I would need to carry Brevitas as a dependency if it were pickle, but I would be able to save the model in a torch-compatible way, allowing me to get gradient info, which I think wouldn't be possible if I go to QONNX.

I will pull the draft PR version and have an experiment this week and report back! (Probably close of play Thursday!)

@jcollyer-turing
Author

jcollyer-turing commented Nov 29, 2024

Hey @Giuseppe5 - Sorry for being slow to get to this. I have looked at the linked PR; the code example uses brevitas.nn but does not show the unloading/saving/loading of the injector. After I apply the quantisation using the ImageNet helper script, I end up with a model which is <class 'torchvision.models.mobilenetv2.MobileNetV2'>. Can I use the same method as described in the PR but removing the injectors?

Edit:

I've worked out that the layers get replaced with quant versions, and I can now see the components the injector adds, like ActQuantProxyFromInjector. Is there a uniform way for me to drop the injector itself?

@Giuseppe5
Collaborator

The unloading/saving/loading of the injectors happens within the context manager I created in that PR.

When you enter the context manager, the injectors are temporarily detached, allowing you to serialize your model (the call to torch.save in the context manager would fail otherwise).

After you exit the context manager, the injectors are re-attached. The serialized model won't have them, and this might cause some issues when re-loading the model with torch.load, but I believe we can address any bug that comes up because of that.

To sum it up, after you generate a model with the script you mentioned, enter the context manager and save it, and then try to re-load the model and use it as you would normally. If you see bugs, let me know and I'll advise on how to proceed/update the PR.
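A minimal sketch of that flow, based on the draft PR's quant_inference_mode (QuantLinear is a stand-in for the generated model and the file name is illustrative):

import torch
import brevitas.nn as qnn
from brevitas.export.inference.manager import quant_inference_mode

model = qnn.QuantLinear(3, 8)
model.eval()

# Inside the context manager the injectors are detached, so torch.save can
# pickle the entire module; they are re-attached when the context exits.
with quant_inference_mode(model, delete_injector=True):
    model(torch.randn(1, 3))
    torch.save(model, "quant_model_full.pt")

# Later: the reloaded model no longer carries the injectors.
model_reloaded = torch.load("quant_model_full.pt")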

Does this answer your question? Unfortunately I am not quite sure I understand the part after the EDIT, so I hope this is sufficient to get you unblocked.

@jcollyer-turing
Author

Thanks for the speedy reply @Giuseppe5. I now understand, so thank you for the response.

I am having some issues setting up the draft PR. I have pulled it locally and tried several different methods of installing (uv and vanilla pip), both locally and from git. Every time I do, I get the following error:

WARNING setuptools_scm.pyproject_reading toml section missing 'pyproject.toml does not contain a tool.setuptools_scm section'
      Traceback (most recent call last):
        File "/private/var/folders/3b/jc4f12pn5p3djgnj25yqdbb80000gr/T/pip-build-env-ysck1g_2/normal/lib/python3.9/site-packages/setuptools_scm/_integration/pyproject_reading.py", line 36, in read_pyproject
          section = defn.get("tool", {})[tool_name]
      KeyError: 'setuptools_scm'

Any ideas?

@nickfraser
Collaborator

nickfraser commented Dec 9, 2024

setuptools_scm issues are often caused by either of the following:

  • The setuptools version being used is too new (likely in this case)
  • setuptools_scm is unable to find a version tag; running git fetch --tags will hopefully solve it (run git tag --list to see that you've retrieved the v0.11.0 tag)

@jcollyer-turing
Author

Thank you @nickfraser! I managed to get it set up by starting afresh, installing requirements.txt, and then installing the package using setup.py!

@Giuseppe5 - I have managed to successfully get the toy example in #1096 (with a few mods - see below!) working for both saving and loading.

Saving

import brevitas.nn as qnn
from brevitas.export.inference.manager import quant_inference_mode
import torch 

model = qnn.QuantLinear(3, 8)
# eval() is needed to suppress an error about saving while in training mode
model.eval()

with quant_inference_mode(model, delete_injector=True):
    b = model(torch.randn(1, 3))
    # Amended this to save the state_dict rather than the model itself
    torch.save(model.state_dict(), "test_dict.pickle")

Loading

import brevitas.nn as qnn
from brevitas.export.inference.manager import quant_inference_mode
import torch 

model = qnn.QuantLinear(3, 8)
model.eval()
model.load_state_dict(torch.load("test_dict.pickle"))

with quant_inference_mode(model):
    b = model(torch.randn(1, 3))
    print(b)

I am time-limited this week, but I will try to have a go with a more complex model (such as a TorchVision ImageNet ResNet) and then report back my findings!

Thank you for the changes and tips to date. I really appreciate it!

@Giuseppe5
Collaborator

Giuseppe5 commented Dec 9, 2024

Amended this to save the state_dict rather than model itself

Just one comment: I believe that if you're only storing the state dict, you don't need the context manager + delete_injector. That part is needed only if you try to serialize (i.e., torch.save) the entire model.

import torch
import brevitas.nn as qnn

model = qnn.QuantLinear(3, 8)
# eval() is needed to suppress an error about saving while in training mode
model.eval()

b = model(torch.randn(1, 3))
torch.save(model.state_dict(), "test_dict.pickle")

In any case, feel free to experiment and keep us posted with updates/issues :)

@Giuseppe5
Collaborator

Also, it is always recommended to do at least one forward pass with your quantized model before saving/exporting, because some quant parameters require a forward pass for initialization.
