Skip to content

Commit

Permalink
Added try for backends (#3138)
Browse files Browse the repository at this point in the history
### Changes

- Added `try-catch` blocks for `isinstance` during backend calculation.

### Reason for changes

- Corrupted frameworks may lead to exceptions even if these frameworks
were not intended to be used.

### Related tickets

- 158806

### Tests

- Manual
  • Loading branch information
KodiaqQ authored Dec 16, 2024
1 parent 3775503 commit 5a55a7d
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions nncf/common/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import importlib
from copy import deepcopy
from enum import Enum
from typing import List, TypeVar
from typing import Any, Callable, List, TypeVar

import nncf

Expand All @@ -26,6 +26,16 @@ class BackendType(Enum):
OPENVINO = "OpenVINO"


def result_verifier(func: Callable[[TModel], bool]) -> Callable[..., None]:
def verify_result(*args: Any, **kwargs: Any): # type: ignore
try:
return func(*args, **kwargs)
except AttributeError:
return False

return verify_result


def get_available_backends() -> List[BackendType]:
"""
Returns a list of available backends.
Expand All @@ -51,6 +61,7 @@ def get_available_backends() -> List[BackendType]:
return available_backends


@result_verifier
def is_torch_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of torch.nn.Module and not a torch.fx.GraphModule, otherwise False.
Expand All @@ -64,6 +75,7 @@ def is_torch_model(model: TModel) -> bool:
return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)


@result_verifier
def is_torch_fx_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of torch.fx.GraphModule, otherwise False.
Expand All @@ -76,6 +88,7 @@ def is_torch_fx_model(model: TModel) -> bool:
return isinstance(model, torch.fx.GraphModule)


@result_verifier
def is_tensorflow_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of tensorflow.Module, otherwise False.
Expand All @@ -88,6 +101,7 @@ def is_tensorflow_model(model: TModel) -> bool:
return isinstance(model, tensorflow.Module)


@result_verifier
def is_onnx_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of onnx.ModelProto, otherwise False.
Expand All @@ -100,6 +114,7 @@ def is_onnx_model(model: TModel) -> bool:
return isinstance(model, onnx.ModelProto)


@result_verifier
def is_openvino_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of openvino.runtime.Model, otherwise False.
Expand All @@ -112,6 +127,7 @@ def is_openvino_model(model: TModel) -> bool:
return isinstance(model, ov.Model)


@result_verifier
def is_openvino_compiled_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of openvino.runtime.CompiledModel, otherwise False.
Expand Down Expand Up @@ -150,7 +166,7 @@ def get_backend(model: TModel) -> BackendType:

raise nncf.UnsupportedBackendError(
"Could not infer the backend framework from the model type because "
"the framework is not available or the model type is unsupported. "
"the framework is not available or corrupted, or the model type is unsupported. "
"The available frameworks found: {}.".format(", ".join([b.value for b in available_backends]))
)

Expand Down

0 comments on commit 5a55a7d

Please sign in to comment.