Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RT-DETR inference with float16 and bfloat16 #31639

Merged
merged 5 commits into from
Jun 26, 2024

Conversation

qubvel
Copy link
Member

@qubvel qubvel commented Jun 26, 2024

What does this PR do?

Fix positional embeddings and anchor data types for RT-DETR.

Running the model in float16 or bfloat16 with the following code crashes with the error:

import torch
import requests

from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

device = "cuda"
dtype = torch.float16

url = 'http://images.cocodataset.org/val2017/000000039769.jpg' 
image = Image.open(requests.get(url, stream=True).raw)

image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd", attn_implementation="eager", torch_dtype=dtype, device_map=device)

inputs = image_processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values.to(device).to(dtype)

with torch.no_grad():
    outputs = model(pixel_values)

assert outputs.logits.dtype == dtype

results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)

for result in results:
    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        score, label = score.item(), label_id.item()
        box = [round(i, 2) for i in box.tolist()]
        print(f"{model.config.id2label[label]}: {score:.2f} {box}")

Error:

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half

One approach to fix is to use torch.autocast(), but it could also be fixed with proper dtypes in the modeling file.

Steps:

  1. Add test that fails with float16/bfloat16 dtypes
  2. Add fix

Before submitting

  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qubvel
Copy link
Member Author

qubvel commented Jun 26, 2024

Added test failed for float16 and bfloat16 as expected:

Screenshot 2024-06-26 at 14 14 00

@qubvel
Copy link
Member Author

qubvel commented Jun 26, 2024

Tests are green with the fix

@qubvel qubvel requested a review from amyeroberts June 26, 2024 13:33
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing!

src/transformers/models/rt_detr/modeling_rt_detr.py Outdated Show resolved Hide resolved
src/transformers/models/rt_detr/modeling_rt_detr.py Outdated Show resolved Hide resolved
@qubvel
Copy link
Member Author

qubvel commented Jun 26, 2024

Thanks for the review!

@qubvel qubvel merged commit b1ec745 into huggingface:main Jun 26, 2024
21 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants