
[Bug]: Exception "object is not initialized" when using OpenVINO with torch.nn.functional.max_pool1d #23528

Closed
3 tasks done
Thrsu opened this issue Mar 19, 2024 · 2 comments
Labels: bug (Something isn't working), category: CPU (OpenVINO CPU plugin)

Comments


Thrsu commented Mar 19, 2024

OpenVINO Version

2024.0.0-14509-34caeefd078-releases/2024/0

Operating System

Ubuntu 18.04 (LTS)

Device used for inference

CPU

Framework

PyTorch

Model used

Given in the following script

Issue description

When a PyTorch model containing `torch.nn.functional.max_pool1d` is converted and run with OpenVINO on CPU, an exception with the message "object is not initialized" is raised. The error occurs at the inference step, when the model returned by `compile_model` is called.

Step-by-step reproduction

  1. Execute the following script, which reproduces the issue:
```python
import numpy as np
import openvino as ov
import torch
from torch.nn import Module


def compile_torch(model, input_data):
    # Convert the traced model to OpenVINO IR and round-trip it through disk
    ov_model = ov.convert_model(model, example_input=input_data)
    ir_path = "temp_OVIR.xml"
    ov.save_model(ov_model, ir_path, compress_to_fp16=False)
    core = ov.Core()
    model = core.read_model(ir_path)

    # Compile for CPU; the exception is raised when the compiled model is run
    compiled_model = core.compile_model(model=model, device_name="CPU")
    output_key = compiled_model.output(0)
    result = compiled_model(input_data)[output_key]
    return result


input_data = torch.randn([1, 2, 9], dtype=torch.float32)


class max_pool1d(Module):
    def forward(self, *args):
        # kernel_size=16 is larger than the input length (9); combined with
        # padding=4 and ceil_mode=True this triggers the failure
        return torch.nn.functional.max_pool1d(
            args[0], 16, stride=2, padding=4, dilation=1,
            return_indices=False, ceil_mode=True,
        )


torch_model = max_pool1d().float().eval()
torch_outputs = torch_model(input_data).cpu().numpy()

# Trace and freeze the module before handing it to the OpenVINO converter
trace = torch.jit.trace(torch_model, input_data)
trace = torch.jit.freeze(trace)

res_ov = compile_torch(trace, input_data)
np.testing.assert_allclose(torch_outputs, res_ov, rtol=1e-3, atol=1e-3)
```

Relevant log output

```
RuntimeError: Exception from src/inference/src/cpp/infer_request.cpp:223:
object is not initialized
```
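As a cross-check that depends on neither framework, the expected result can be computed with a naive NumPy implementation. This is a minimal sketch: `max_pool1d_reference` is a hypothetical helper, not part of PyTorch or OpenVINO. It follows PyTorch's documented output-length rule for `ceil_mode=True` (windows may run off the end, but a window that would start inside the right padding is dropped) and treats padding as negative infinity, so it assumes floating-point input. For the reproducer's input shape `[1, 2, 9]` it produces an output of shape `[1, 2, 2]`.

```python
import numpy as np


def max_pool1d_reference(x, kernel_size, stride, padding=0, dilation=1, ceil_mode=False):
    """Naive max pooling over the last axis of a floating-point array (hypothetical helper)."""
    length = x.shape[-1]
    span = dilation * (kernel_size - 1) + 1  # effective window extent
    numer = length + 2 * padding - span
    if ceil_mode:
        out_len = -(-numer // stride) + 1  # ceiling division
        # Drop a last window that would start inside the right padding
        if (out_len - 1) * stride >= length + padding:
            out_len -= 1
    else:
        out_len = numer // stride + 1

    out = np.empty(x.shape[:-1] + (out_len,), dtype=x.dtype)
    for i in range(out_len):
        start = i * stride - padding
        # Keep only taps inside the unpadded input; padded taps behave like -inf
        idx = [start + j * dilation for j in range(kernel_size)
               if 0 <= start + j * dilation < length]
        out[..., i] = x[..., idx].max(axis=-1) if idx else -np.inf
    return out


x = np.random.default_rng(0).standard_normal((1, 2, 9)).astype(np.float32)
y = max_pool1d_reference(x, 16, stride=2, padding=4, dilation=1, ceil_mode=True)
print(y.shape)  # (1, 2, 2)
```

In the reproducer, `max_pool1d_reference(input_data.numpy(), 16, stride=2, padding=4, ceil_mode=True)` could be compared against both `torch_outputs` and `res_ov`. With these parameters every window covers the whole input row, so both output elements equal the row maximum.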

Issue submission checklist

  • I'm reporting an issue. It's not a question.
  • I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
  • There is reproducer code and related data files such as images, videos, models, etc.
Thrsu added the bug and support_request labels Mar 19, 2024
ilya-lavrenov added the category: PyTorch FE label Mar 19, 2024
mvafin (Contributor) commented Mar 28, 2024

@Thrsu This problem seems to be caused by the CPU plugin. The generated IR looks correct, and the GPU plugin can infer it. Reassigning this to the CPU team.

mvafin removed their assignment Mar 28, 2024
mvafin added the category: CPU label and removed the category: PyTorch FE label Mar 28, 2024
yuxu42 (Contributor) commented Mar 28, 2024

@tiger100256-hu Could you please take a look? Thanks!

github-merge-queue bot pushed a commit that referenced this issue Apr 23, 2024
### Details:
- *The issue happens when `src - (1 + (krn - 1) * dil) + pad_l < 0 && (src - (1 + (krn - 1) * dil)) % stride != 0` in ceil mode.*

### Tickets:
- [issues-23528](#23528)

---------

Signed-off-by: HU Yuan2
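The trigger condition from the commit message can be evaluated directly for the reproducer's parameters. A minimal sketch, assuming `src` is the input length, `krn` the kernel size, `dil` the dilation and `pad_l` the left padding (the commit message does not define these names), and assuming the modulo applies to the whole difference `src - (1 + (krn - 1) * dil)`. `hits_failing_case` is a hypothetical helper, not the CPU plugin's code.

```python
def hits_failing_case(src, krn, stride, pad_l, dil, ceil_mode):
    """Evaluate the trigger condition quoted in the fix's commit message (illustrative only)."""
    # Input length minus the effective kernel extent 1 + (krn - 1) * dil
    diff = src - (1 + (krn - 1) * dil)
    # Whether the remainder is zero does not depend on C++ vs. Python % sign rules
    return diff + pad_l < 0 and diff % stride != 0 and bool(ceil_mode)


# Reproducer: input length 9, kernel 16, stride 2, padding 4, dilation 1, ceil_mode=True
print(hits_failing_case(src=9, krn=16, stride=2, pad_l=4, dil=1, ceil_mode=True))  # True
```

Here `diff` is `-7`, so `diff + pad_l` is `-3` and `-7` is not divisible by the stride `2`. With `stride=1` or `ceil_mode=False` the condition no longer holds.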
yuxu42 closed this as completed Apr 24, 2024
alvoron pushed a commit to alvoron/openvino that referenced this issue Apr 29, 2024

7 participants