[PT FE] Make ModuleExtension patching in independent function scope (openvinotoolkit#23584)

### Details:
 - *Move ModuleExtension patching into an independent function scope* (see the sketch below)

### Tickets:
 - *ticket-id*
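
Context for reviewers (not part of the original PR description): defining the patched `forward` closures directly inside the `for name, m in model.named_modules()` loop exposes Python's late-binding closure behavior, where every closure created in the loop resolves the loop variable at call time and therefore sees the last module iterated. Moving the body into a dedicated `module_patcher(m, name)` function binds `m` once per call. A minimal, self-contained sketch of the pitfall and the fix, with hypothetical names not taken from the patch:

```python
# Illustrative sketch of the late-binding closure pitfall; names are
# hypothetical and do not come from patch_model.py.

def patch_in_loop(modules):
    forwards = []
    for m in modules:
        def fake_forward():
            # `m` is resolved when fake_forward is *called*, so every
            # closure created by this loop ends up seeing the last module.
            return f"patched {m}"
        forwards.append(fake_forward)
    return forwards

def patch_per_call(modules):
    def module_patcher(m):
        def fake_forward():
            # `m` is a parameter of module_patcher, bound once per call,
            # so each closure keeps the module it was created for.
            return f"patched {m}"
        return fake_forward
    return [module_patcher(m) for m in modules]

print([f() for f in patch_in_loop(["a", "b"])])   # ['patched b', 'patched b']
print([f() for f in patch_per_call(["a", "b"])])  # ['patched a', 'patched b']
```

The new `test_multiple_module_extension` below exercises exactly this scenario: two different modules patched by two different extensions in one model.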
mvafin authored and alvoron committed Apr 29, 2024
1 parent bfc33e0 commit f7ced72
Showing 2 changed files with 81 additions and 21 deletions.
17 changes: 10 additions & 7 deletions src/bindings/python/src/openvino/frontend/pytorch/patch_model.py
@@ -18,12 +18,7 @@ def __exit__(self, *args):


def patch_model(model, module_extensions, orig_forward_name):
for name, m in model.named_modules():
if hasattr(m, orig_forward_name):
# already patched, skipping with a warning because it is unexpected
print(f'[ WARNING ] Unexpectedly found already patched module {name} while applying ModuleExtension during PyTorch model conversion. '
'Result of the conversion may be broken. Depending on the exact issue it may lead to a broken original model.')
continue
def module_patcher(m, name):
extension = None
if m in module_extensions:
extension = module_extensions[m]
@@ -54,7 +49,7 @@ def forward(*args, **kwargs):
m.forward = getattr(m, orig_forward_name)
# call user code
results = extension.evaluate(
m, *Trampoline.stashed_args, **Trampoline.stashed_kwargs)
m, *Trampoline.stashed_args, **Trampoline.stashed_kwargs) # call user code
m.forward = patched_forward # return patched forward back
return results

@@ -65,6 +60,14 @@ def new_forward(*args, **kwargs):
setattr(m, orig_forward_name, m.forward)
m.forward = new_forward

for name, m in model.named_modules():
if hasattr(m, orig_forward_name):
# already patched, skipping with a warning because it is unexpected
print(f'[ WARNING ] Unexpectedly found already patched module {name} while applying ModuleExtension during PyTorch model conversion. '
'Result of the conversion may be broken. Depending on the exact issue it may lead to a broken original model.')
continue
module_patcher(m, name)


def unpatch_model(model, orig_forward_name):
for _, m in model.named_modules():
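A hypothetical usage sketch of the helpers above: only the `patch_model`/`unpatch_model` signatures and the `m in module_extensions` lookup come from this diff, while the sentinel attribute name and the instance-keyed extension map are assumptions. In practice `convert_model` drives this machinery for you, as the tests below show.

```python
# Hypothetical driver for patch_model/unpatch_model; the attribute name and
# the keying of the extension map are assumptions, not taken from this diff.
import torch
from openvino.frontend.pytorch import ModuleExtension
from openvino.frontend.pytorch.patch_model import patch_model, unpatch_model

ORIG_FORWARD = "_openvino_orig_forward"  # assumed sentinel attribute name

model = ModelWithModule()  # model/module classes as defined in the tests below
extensions = {model.cos_module: ModuleExtension(model.cos_module, "aten::sin")}

patch_model(model, extensions, ORIG_FORWARD)
try:
    traced = torch.jit.trace(model, (torch.randn(100),))
finally:
    # Restore the original forward methods whether or not tracing succeeded.
    unpatch_model(model, ORIG_FORWARD)
```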
85 changes: 71 additions & 14 deletions tests/layer_tests/py_frontend_tests/test_torch_frontend.py
@@ -260,6 +260,7 @@ def forward(self, inp):
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "CustomElu", "Result"]


def test_framework_map_macros():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder

@@ -293,10 +294,11 @@ def forward(self, x):

class CosModel(torch.nn.Module):
def __init__(self):
super(CosModel, self).__init__()
super(CosModel, self).__init__()

def forward(self, x):
return torch.cos(x.to(torch.float32))
return torch.cos(x.to(torch.float32))


def test_op_extension():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
@@ -350,7 +352,7 @@ def test_op_extension_generic():

def test_module_extension():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from openvino.frontend.pytorch import ModuleExtension
from openvino.frontend.pytorch import ModuleExtension, ConversionExtension
from openvino import convert_model

class ModelWithModule(torch.nn.Module):
@@ -374,23 +376,71 @@ def forward(self, x):
assert converted_model
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "Convert", "Cos", "Result"]

converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[ModuleExtension(CosModel, "aten::sin")])

converted_model = convert_model(model, example_input=(
torch.randn(100),), extension=[ModuleExtension(CosModel, "aten::sin")])
assert converted_model
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "Sin", "Result"]

converted_model = convert_model(model, example_input=(torch.randn(
100),), extension=[ModuleExtension(model.cos_module, "aten::sin")])
assert converted_model
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "Sin", "Result"]

converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[ModuleExtension(model.cos_module, "aten::sin")])
converted_model = convert_model(model, example_input=(torch.randn(
100),), extension=[ModuleExtension("cos_module", "aten::sin")])
assert converted_model
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "Sin", "Result"]

converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[ModuleExtension("cos_module", "aten::sin")])
def sin_op(context):
return ops.sin(context.get_input(0)).outputs()

converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[
ModuleExtension("cos_module", "MyOp"), ConversionExtension("MyOp", sin_op)])
assert converted_model
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "Sin", "Result"]


def test_multiple_module_extension():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from openvino.frontend.pytorch import ModuleExtension
from openvino import convert_model

class ModelWithModule(torch.nn.Module):
def __init__(self):
super(ModelWithModule, self).__init__()
self.cos_module = CosModel()
self.relu_module = torch.nn.ReLU()

def forward(self, x):
x = x.to(torch.float32)
return self.cos_module(x) + self.relu_module(x)

model = ModelWithModule()
decoder = TorchScriptPythonDecoder(model)

fem = FrontEndManager()
fe = fem.load_by_framework(framework="pytorch")
assert fe

input_model = fe.load(decoder)
assert input_model
converted_model = fe.convert(input_model)
assert converted_model
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "Convert", "Convert", "Cos", "Constant", "Relu", "Multiply", "Add", "Result"]

converted_model = convert_model(model, example_input=(
torch.randn(100),), extension=[ModuleExtension(CosModel, "aten::sin"), ModuleExtension(model.relu_module, "aten::tan")])
assert converted_model
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "Sin", "Tan", "Add", "Result"]


def test_pytorch_telemetry():
from openvino.frontend import TelemetryExtension
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
@@ -536,25 +586,31 @@ def forward(self, x: float, y: torch.Tensor):
r_t = "t"

if isinstance(l_type, type):
ov_lhs = ops.parameter(PartialShape([]), pt_to_ov_type_map.get(l_type.__name__))
ov_lhs = ops.parameter(PartialShape(
[]), pt_to_ov_type_map.get(l_type.__name__))
pt_lhs = l_type(5)
l_t = l_type.__name__
elif l_scalar:
ov_lhs = ops.parameter(PartialShape([]), pt_to_ov_type_map.get(str(l_type)))
ov_lhs = ops.parameter(PartialShape(
[]), pt_to_ov_type_map.get(str(l_type)))
pt_lhs = torch.tensor(1, dtype=l_type)
else:
ov_lhs = ops.parameter(PartialShape([2, 2]), pt_to_ov_type_map.get(str(l_type)))
ov_lhs = ops.parameter(PartialShape(
[2, 2]), pt_to_ov_type_map.get(str(l_type)))
pt_lhs = torch.rand([2, 2]).to(dtype=l_type)

if isinstance(r_type, type):
ov_rhs = ops.parameter(PartialShape([]), pt_to_ov_type_map.get(r_type.__name__))
ov_rhs = ops.parameter(PartialShape(
[]), pt_to_ov_type_map.get(r_type.__name__))
pt_rhs = r_type(5)
r_t = r_type.__name__
elif r_scalar:
ov_rhs = ops.parameter(PartialShape([]), pt_to_ov_type_map.get(str(r_type)))
ov_rhs = ops.parameter(PartialShape(
[]), pt_to_ov_type_map.get(str(r_type)))
pt_rhs = torch.tensor(1, dtype=r_type)
else:
ov_rhs = ops.parameter(PartialShape([2, 2]), pt_to_ov_type_map.get(str(r_type)))
ov_rhs = ops.parameter(PartialShape(
[2, 2]), pt_to_ov_type_map.get(str(r_type)))
pt_rhs = torch.rand([2, 2]).to(dtype=r_type)
model = get_scripted_model(locals().get(f"aten_add_{l_t}_{r_t}")())
decoder = TorchScriptPythonDecoder(model)
@@ -578,7 +634,8 @@ def forward(self, x: float, y: torch.Tensor):
pt_out_type = pt_to_ov_type_map.get(str(pt_out_type))
ov_out_type = om.get_output_element_type(0)
if pt_out_type == Type.i64 and ov_out_type == Type.i32 and "int" in [l_t, r_t]:
pytest.xfail("Pytorch int-like scalar in OV is converted to i32 instead of i64, mismatch is expected.")
pytest.xfail(
"Pytorch int-like scalar in OV is converted to i32 instead of i64, mismatch is expected.")
assert pt_out_type == ov_out_type
assert PartialShape(pt_out_shape) == om.get_output_partial_shape(0)

