recompute_scale_factor patch without fine-tune #312

Closed
persts opened this issue Sep 9, 2022 · 5 comments

@persts
Contributor

persts commented Sep 9, 2022

Based on this recompute_scale_factor discussion and the update that followed, would it not be possible to patch the base MD models as follows?

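import torch

# Load the original MDv5a checkpoint. The YOLOv5 repo must be importable, since the
# checkpoint pickles YOLOv5 model classes. (On PyTorch >= 2.6, also pass weights_only=False.)
ckpt = torch.load('md_v5a.0.0.pt', map_location='cpu')

# Upsample modules pickled under PyTorch <= 1.10 lack the recompute_scale_factor
# attribute that newer versions of nn.Upsample read in forward(); set it explicitly.
for m in ckpt['model'].modules():
    if isinstance(m, torch.nn.Upsample):
        m.recompute_scale_factor = None

torch.save(ckpt, './md_v5a.0.1.pt')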

This would likely avoid the changes in box precision noted in #297 that were caused by the fine-tune hack. I can confirm that the new model loads in the newest versions of PyTorch, but I did not compare the precision of box coordinates between the two models.
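The patch can be tried without the MegaDetector checkpoint or the YOLOv5 code. The minimal sketch below builds a toy model, deletes recompute_scale_factor from its nn.Upsample to mimic a module unpickled from a PyTorch 1.10-era checkpoint (a simulation of the failure, not a reproduction of md_v5a.0.0.pt), and then applies the same fix as the loop above. The helper name patch_upsample_modules is hypothetical; its hasattr guard leaves modules that already have the attribute untouched.

```python
import torch
from torch import nn


def patch_upsample_modules(model: nn.Module) -> int:
    """Hypothetical helper: give every nn.Upsample in `model` that lacks
    recompute_scale_factor the default value None; return how many were patched."""
    patched = 0
    for m in model.modules():
        if isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
            m.recompute_scale_factor = None
            patched += 1
    return patched


# Toy model standing in for a detector unpickled from an old checkpoint.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.Upsample(scale_factor=2, mode="nearest"),
)
del model[1].recompute_scale_factor  # simulate the attribute missing after unpickling

x = torch.zeros(1, 3, 16, 16)
try:
    model(x)
except AttributeError as err:
    print("before patch:", err)

print("modules patched:", patch_upsample_modules(model))
print("output shape:", tuple(model(x).shape))
```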

@agentmorris
Contributor

Good idea!

Since this incompatibility appears to be fixed in PyTorch 1.12 and the latest YOLOv5 (without re-building the model at all), the main purpose of doing this export step is M1 support, which may also be fixed without re-building when PyTorch 1.13 is released. Recap from other threads: this incompatibility is fixed in the stable (1.12) build, but still exists in the nightly (1.13) build, and only the latter supports M1 inference. Unclear whether this issue will resurface when 1.13 becomes the stable build. If so, this may be a really useful step.

But for now, I did this, and confirmed that the resulting model does produce the same results as MDv5a to working precision.
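A check like "same results to working precision" can be expressed as the maximum absolute difference between two models' outputs on a shared input. The sketch below is an illustration under stated assumptions: each model returns a single tensor (a real detector's output may need unpacking first), the helper max_output_difference is hypothetical, and a deep copy of a toy model stands in for the re-saved checkpoint, since the patch changes no weights.

```python
import copy

import torch
from torch import nn


def max_output_difference(model_a: nn.Module, model_b: nn.Module, x: torch.Tensor) -> float:
    """Hypothetical helper: run both models in eval mode on the same input and
    return the largest absolute elementwise difference between their outputs."""
    model_a.eval()
    model_b.eval()
    with torch.no_grad():
        return (model_a(x) - model_b(x)).abs().max().item()


torch.manual_seed(0)
original = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.Upsample(scale_factor=2, mode="nearest"),
)

# Stand-in for the re-saved model: same weights, Upsample attribute set explicitly.
patched = copy.deepcopy(original)
for m in patched.modules():
    if isinstance(m, nn.Upsample):
        m.recompute_scale_factor = None

x = torch.rand(1, 3, 8, 8)
diff = max_output_difference(original, patched, x)
print(f"max abs difference: {diff:.3g}")
```

For real detections, comparing box coordinates with an explicit tolerance (for example torch.allclose with a chosen atol) is usually more informative than requiring exact equality.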

But the M1 testing I did earlier this week was on a borrowed setup to which I no longer have access, so I can't actually confirm that this approach solves the problem at hand, i.e. that the resulting model (a) avoids this incompatibility with PT 1.13 and (b) allows M1-accelerated inference. If you're not bored of this issue yet :), I'd be interested in your take on this, but IMO it's also OK to pause this for now since the workarounds we have are close enough to bridge the gap until we find out what happens with the 1.13 PT release.

@persts
Contributor Author

persts commented Sep 9, 2022

I still get the recompute_scale_factor issue with md_v5a.0.0.pt and torch 1.12.1+cu113 on Ubuntu and with torch 1.13.0.dev20220824 on the M1. The patch above fixes that.

The patched model on M1 with torch 1.13.0.dev20220824 still requires another slight modification to the PTDetector:

index 7778d1b..7774ac0 100644
--- a/detection/pytorch_detector.py
+++ b/detection/pytorch_detector.py
@@ -45,8 +45,8 @@ class PTDetector:
 
     @staticmethod
     def _load_model(model_pt_path, device):
-        checkpoint = torch.load(model_pt_path, map_location=device)
-        model = checkpoint['model'].float().fuse().eval()  # FP32 model
+        checkpoint = torch.load(model_pt_path)
+        model = checkpoint['model'].float().fuse().eval().to(device)  # FP32 model
         return model
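With that change, the device passed to _load_model decides where the model ends up. A minimal sketch of choosing that device with a fallback order of CUDA, then MPS (the Apple Silicon backend), then CPU; pick_device is a hypothetical helper, not part of the MegaDetector code, and the getattr guard assumes that PyTorch builds without MPS support simply have no torch.backends.mps.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then Apple's MPS backend, then the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)  # absent in PyTorch builds without MPS support
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
print("selected device:", device)
```

The result could then be passed as the device argument of PTDetector._load_model.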

It seems that older versions of YOLOv5 had some float64 variables. If map_location is set in torch.load(), you will get an error:

"TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead."

The float64 variables have apparently been removed in more recent versions of YOLOv5, which is why the fine-tuned model works with the current PTDetector code. So this will not be a problem for newer models going forward.
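The reordering helps because Module.float() casts every floating-point parameter and buffer to float32 before the model is moved, so no float64 tensor is ever sent to MPS. The sketch below illustrates this with a hypothetical LegacyHead module carrying a float64 buffer; it stands in for whatever float64 tensors the older YOLOv5 code kept, which this thread does not identify.

```python
import torch
from torch import nn


class LegacyHead(nn.Module):
    """Hypothetical stand-in for an older layer that kept a float64 buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=1)
        self.register_buffer("anchors", torch.ones(3, 2, dtype=torch.float64))


def to_fp32_on(model: nn.Module, device: torch.device) -> nn.Module:
    # Cast floating-point parameters and buffers to float32 first, then move the
    # model, so no float64 tensor is ever sent to a device that rejects it (e.g. MPS).
    return model.float().to(device)


# CPU keeps the example runnable anywhere; on an M1 Mac, torch.device("mps") could be passed instead.
model = to_fp32_on(LegacyHead(), torch.device("cpu"))
print({name: t.dtype for name, t in model.state_dict().items()})
```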

@agentmorris
Contributor

Thanks... that change (from "map_location=device" to ".to(device)") looks non-controversial; I'll confirm that it has no impact on the existing model in the recommended (PT 1.10, old YOLOv5) environment, and make that change when I get a chance. Will leave this issue open until I make that change.

I still prefer not to update the "official" models until the next time we update the recommended environment, which may not be for a while: at least not until the dust settles on the PyTorch 1.13 release and we see what incompatibilities still exist. But with the changes you've suggested and the discussions we have had on these issues, I think anyone who needs to avoid the recommended environment and/or wants to do M1 inference has plenty of options.

@agentmorris agentmorris self-assigned this Sep 12, 2022
@persts
Contributor Author

persts commented Sep 12, 2022

I think anyone who can get the MD pipeline up and running could apply this patch/approach with little effort if needed. I have never used the map_location=device parameter in torch.load(). I would be curious to see whether there is a difference in memory usage between the two approaches.

I don't think the recompute_scale_factor issue will ever go away for models built with PyTorch 1.10.x or lower. On Oct 20, 2021, recompute_scale_factor was exposed as a parameter of the Upsample module, so it looks like it is here to stay: https://github.com/pytorch/pytorch/commits/master/torch/nn/modules/upsampling.py

@aa-hernandez
Collaborator

Thank you for your valuable input. We've moved to a new codebase with the release of Pytorch-Wildlife v1.0. The issue you mentioned might not exist anymore or it might be resolved in the new codebase. We encourage you to try the new release here and share any further feedback or concerns. If the issue still persists, please let us know through this channel and we will work on it!

If you're still interested in working with the previous version, you can take a look at Dan Morris' fork. The issue may or may not be resolved in that codebase.
