
[Bug]: Inference fails if data is moved to GPU #25484

Open

guillermoayllon opened this issue Jul 10, 2024 · 9 comments
Labels: bug (Something isn't working) · category: GPU (OpenVINO GPU plugin) · PSE


@guillermoayllon

OpenVINO Version

2024.2.0

Operating System

Ubuntu 20.04 (LTS)

Device used for inference

GPU

Framework

PyTorch

Model used

VGG16

Issue description

Hiya,

My goal is to perform inference on an Intel GPU with an OpenVINO model. However, inference fails if I move the input data to the GPU before running inference.

Step-by-step reproduction

First, I convert the VGG16 model to OpenVINO IR format:

import torch
from torchvision import models
import openvino as ov
import intel_extension_for_pytorch  # registers the torch "xpu" device

model = models.vgg16(weights="VGG16_Weights.DEFAULT")
model.eval()

ov_model = ov.convert_model(model, input=[128, 3, 224, 224])

Then I compile the model:

core = ov.Core()
compiled_model = core.compile_model(model=ov_model, device_name="GPU")
output_layer = compiled_model.output(0)

Finally, I attempt inference in the following way:

data_loader = img_generator_pytorch(TEST_DATASET_PATH, batch_size=128, shuffle=False)

for images, _ in data_loader:
    images = images.to(torch.device("xpu"), memory_format=torch.channels_last)
    with torch.no_grad():
        compiled_model(images)[output_layer]

To produce the dataloader, I use this function:

from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import ImageFolder

def img_generator_pytorch(dataset_path: str, batch_size: int, shuffle: bool, drop_remainder: bool = False):
    """
    Image Data Generator for preprocessing the images, input to feed the VGG16 model

    Args:
        dataset_path (str): The path to the dataset.
        batch_size (int): The batch size.
        shuffle (bool): Whether to shuffle the dataset.
        drop_remainder (bool, optional): Whether to drop the last batch if it's smaller than the batch size. Defaults to False.

    Returns:
        DataLoader: The data loader for the dataset to be used in training, validation, or evaluation.
    """
    data_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])  # ImageNet mean and std

    dataset = ImageFolder(dataset_path, transform=data_transform)
    data_loader = DataLoader(dataset, 
                             batch_size, 
                             shuffle=shuffle, 
                             drop_last=drop_remainder)
                            
    data_loader.num_images = len(dataset)
 
    return data_loader

Environment:

openvino                         2024.2.0
openvino-telemetry               2024.1.0
intel_extension_for_pytorch      2.1.30.post0
torch                            2.1.0.post2+cxx11.abi
torchaudio                       2.1.0.post2+cxx11.abi
torchvision                      0.16.0.post2+cxx11.abi

Device info: 2x Intel(R) Data Center GPU Max 1100

Relevant log output

Traceback (most recent call last):
  File "/home/cic/intel_sustainable_AI_phase2/development/issue_script.py", line 58, in <module>
    compiled_model(images)[output_layer]
  File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/openvino/runtime/ie_api.py", line 388, in __call__
    return self._infer_request.infer(
  File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/openvino/runtime/ie_api.py", line 132, in infer
    return OVDict(super().infer(_data_dispatch(
  File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/openvino/runtime/utils/data_helpers/data_dispatcher.py", line 429, in _data_dispatch
    return create_shared(inputs, request) if is_shared else create_copied(inputs, request)
  File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/functools.py", line 888, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/openvino/runtime/utils/data_helpers/data_dispatcher.py", line 209, in create_shared
    request._inputs_data = normalize_arrays(inputs, is_shared=True)
  File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/functools.py", line 888, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/openvino/runtime/utils/data_helpers/data_dispatcher.py", line 155, in normalize_arrays
    return to_c_style(np.array(inputs, copy=False), is_shared) if is_shared else np.array(inputs, copy=True)
  File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/torch/_tensor.py", line 1030, in __array__
    return self.numpy()
TypeError: can't convert xpu:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Issue submission checklist

  • I'm reporting an issue. It's not a question.
  • I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
  • There is reproducer code and related data files such as images, videos, models, etc.
guillermoayllon added the bug and support_request labels on Jul 10, 2024
guillermoayllon (Author) commented Jul 10, 2024

Converting the tensor to host memory before inference fixes the issue, but throughput is significantly reduced, to the point where the non-optimized model ends up with higher throughput.
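
For reference, the host-memory workaround looks like this (a minimal sketch reusing compiled_model, output_layer, and data_loader from the reproducer above):

for images, _ in data_loader:
    images = images.to(torch.device("xpu"), memory_format=torch.channels_last)
    with torch.no_grad():
        # Copy the batch back to host memory, as the error message suggests;
        # inference then works, but each batch pays an extra device-to-host copy.
        result = compiled_model(images.cpu())[output_layer]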

paguilomanas commented Jul 11, 2024

I have the same issue. Are there any examples of how to integrate an Intel GPU with OpenVINO code, apart from the device argument in core.compile_model? Also, I saw that the compress_to_fp16 argument of ov.convert_model is deprecated. Is there an equivalent, up-to-date compression option?
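
For reference, FP16 compression moved from ov.convert_model to ov.save_model, where it is enabled by default; a minimal sketch, assuming the ov_model from the reproducer and a hypothetical output path:

import openvino as ov

# compress_to_fp16 now lives on ov.save_model (default True);
# "vgg16_fp16.xml" is a hypothetical output path.
ov.save_model(ov_model, "vgg16_fp16.xml", compress_to_fp16=True)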

avitial added the category: GPU (OpenVINO GPU plugin) label on Jul 11, 2024
avitial removed the bug label on Jul 11, 2024
@Iffa-Intel

@guillermoayllon could you share:

  1. Your model files, so we can investigate further?
  2. Are you using a custom model? If so, please share some information about it.

@guillermoayllon (Author)

Hello @Iffa-Intel, thank you for looking into the issue.

The model we are using is TorchVision's VGG16: https://pytorch-org.translate.goog/vision/main/models/generated/torchvision.models.vgg16.html?_x_tr_sl=en&_x_tr_tl=es&_x_tr_hl=es&_x_tr_pto=sc

One should be able to reproduce the issue by importing the model and using the default weights as indicated in the first code block above:

import torch
from torchvision import models
import openvino as ov
import intel_extension_for_pytorch

model = models.vgg16(weights="VGG16_Weights.DEFAULT")
model.eval()

ov_model = ov.convert_model(model, input=[128, 3, 224, 224])

Iffa-Intel commented Jul 16, 2024

@guillermoayllon to get that VGG16 model (as you mentioned), intel_extension_for_pytorch is required.
May I know which version you used? I noticed that only v1.10.200+gpu supports Ubuntu 20.04.
If you are using another version, perhaps that is the underlying cause of the throughput issue.

@guillermoayllon (Author)

Hello @Iffa-Intel, thank you for pointing that out.

The Ubuntu version above is wrong; the version we are actually using is Ubuntu 22.04.4 LTS (not 20.04).
The intel_extension_for_pytorch version mentioned is correct: 2.1.30.post0.

Nevertheless, the throughput issue appears when we convert the tensor to host memory before inference, and that is not ideal: we would like to keep the tensor in GPU memory.

Iffa-Intel added the PSE label on Jul 19, 2024
avitial added the bug label and removed the support_request label on Oct 8, 2024
avitial (Contributor) commented Oct 8, 2024

Ref. 154510

p-wysocki (Contributor) commented Nov 25, 2024

Hello @guillermoayllon, you can follow the current progress and an explanation of the issue in #27725. The new feature allows OpenVINO GPU tensors to be created directly from a pointer to a Torch GPU tensor:

from openvino import Tensor, Shape
from openvino.frontend.pytorch.utils import pt_to_ov_type_map  # assumed import path

image = torch.rand(128, 3, 224, 224)
image = image.to(torch.device("xpu"), memory_format=torch.channels_last)
data_ptr = image.detach().data_ptr()
ov_tensor = Tensor(data_ptr, Shape(image.shape), pt_to_ov_type_map[str(image.dtype)])

The solution should be available in the next OpenVINO release.

Until the PR gets merged, if you would like to stay on your current OpenVINO version, the only thing that can be done is to create the Torch tensor on CPU (instead of XPU) and then pass it to OpenVINO inference. Inference will still run on the GPU, but that way you avoid the extra device-to-host copy. It's only a partial fix, since the data still has to be copied from CPU back to GPU.
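
A minimal sketch of that partial fix, reusing compiled_model, output_layer, and data_loader from the reproducer above:

for images, _ in data_loader:
    # Keep the batch on CPU; the GPU plugin copies it to device memory itself,
    # so no manual .to("xpu") and no device-to-host round trip are needed.
    with torch.no_grad():
        result = compiled_model(images)[output_layer]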

Once the PR is merged you can test the solution using our nightly releases:
https://docs.openvino.ai/2024/about-openvino/release-notes-openvino/release-policy.html#nightly-releases
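
At the time of writing, the nightly wheels are installed roughly like this (check the linked page for the current command):

pip install --pre --upgrade openvino --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly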

@guillermoayllon (Author)

Thank you for looking into this.
