diff --git a/src/bindings/python/src/openvino/frontend/pytorch/patch_model.py b/src/bindings/python/src/openvino/frontend/pytorch/patch_model.py index 8e823888557222..a55e031cf2c73f 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/patch_model.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/patch_model.py @@ -18,12 +18,7 @@ def __exit__(self, *args): def patch_model(model, module_extensions, orig_forward_name): - for name, m in model.named_modules(): - if hasattr(m, orig_forward_name): - # already patched, skipping with a warning because it is unexpected - print(f'[ WARNING ] Unexpectedly found already patched module {name} while applying ModuleExtension during PyTorch model conversion. ' - 'Result of the conversion maybe broken. Depending on the exact issue it may lead to broken original model.') - continue + def module_patcher(m, name): extension = None if m in module_extensions: extension = module_extensions[m] @@ -54,7 +49,7 @@ def forward(*args, **kwargs): m.forward = getattr(m, orig_forward_name) # call user code results = extension.evaluate( - m, *Trampoline.stashed_args, **Trampoline.stashed_kwargs) + m, *Trampoline.stashed_args, **Trampoline.stashed_kwargs) # call user code m.forward = patched_forward # return patched forward back return results @@ -65,6 +60,14 @@ def new_forward(*args, **kwargs): setattr(m, orig_forward_name, m.forward) m.forward = new_forward + for name, m in model.named_modules(): + if hasattr(m, orig_forward_name): + # already patched, skipping with a warning because it is unexpected + print(f'[ WARNING ] Unexpectedly found already patched module {name} while applying ModuleExtension during PyTorch model conversion. ' + 'Result of the conversion maybe broken. Depending on the exact issue it may lead to broken original model.') + continue + module_patcher(m, name) + def unpatch_model(model, orig_forward_name): for _, m in model.named_modules(): diff --git a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py index 5063f66d1e139e..98369f2953a03e 100644 --- a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py +++ b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py @@ -260,6 +260,7 @@ def forward(self, inp): assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ "Parameter", "CustomElu", "Result"] + def test_framework_map_macros(): from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder @@ -293,10 +294,11 @@ def forward(self, x): class CosModel(torch.nn.Module): def __init__(self): - super(CosModel, self).__init__() + super(CosModel, self).__init__() def forward(self, x): - return torch.cos(x.to(torch.float32)) + return torch.cos(x.to(torch.float32)) + def test_op_extension(): from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder @@ -350,7 +352,7 @@ def test_op_extension_generic(): def test_module_extension(): from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder - from openvino.frontend.pytorch import ModuleExtension + from openvino.frontend.pytorch import ModuleExtension, ConversionExtension from openvino import convert_model class ModelWithModule(torch.nn.Module): @@ -374,23 +376,71 @@ def forward(self, x): assert converted_model assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ "Parameter", "Convert", "Cos", "Result"] - - converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[ModuleExtension(CosModel, "aten::sin")]) + + converted_model = convert_model(model, example_input=( + torch.randn(100),), extension=[ModuleExtension(CosModel, "aten::sin")]) + assert converted_model + assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ + "Parameter", "Sin", "Result"] + + converted_model = convert_model(model, example_input=(torch.randn( + 100),), extension=[ModuleExtension(model.cos_module, "aten::sin")]) assert converted_model assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ "Parameter", "Sin", "Result"] - converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[ModuleExtension(model.cos_module, "aten::sin")]) + converted_model = convert_model(model, example_input=(torch.randn( + 100),), extension=[ModuleExtension("cos_module", "aten::sin")]) assert converted_model assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ "Parameter", "Sin", "Result"] - converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[ModuleExtension("cos_module", "aten::sin")]) + def sin_op(context): + return ops.sin(context.get_input(0)).outputs() + + converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[ + ModuleExtension("cos_module", "MyOp"), ConversionExtension("MyOp", sin_op)]) assert converted_model assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ "Parameter", "Sin", "Result"] +def test_multiple_module_extension(): + from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder + from openvino.frontend.pytorch import ModuleExtension + from openvino import convert_model + + class ModelWithModule(torch.nn.Module): + def __init__(self): + super(ModelWithModule, self).__init__() + self.cos_module = CosModel() + self.relu_module = torch.nn.ReLU() + + def forward(self, x): + x = x.to(torch.float32) + return self.cos_module(x) + self.relu_module(x) + + model = ModelWithModule() + decoder = TorchScriptPythonDecoder(model) + + fem = FrontEndManager() + fe = fem.load_by_framework(framework="pytorch") + assert fe + + input_model = fe.load(decoder) + assert input_model + converted_model = fe.convert(input_model) + assert converted_model + assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ + "Parameter", "Convert", "Convert", "Cos", "Constant", "Relu", "Multiply", "Add", "Result"] + + converted_model = convert_model(model, example_input=( + torch.randn(100),), extension=[ModuleExtension(CosModel, "aten::sin"), ModuleExtension(model.relu_module, "aten::tan")]) + assert converted_model + assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ + "Parameter", "Sin", "Tan", "Add", "Result"] + + def test_pytorch_telemetry(): from openvino.frontend import TelemetryExtension from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder @@ -536,25 +586,31 @@ def forward(self, x: float, y: torch.Tensor): r_t = "t" if isinstance(l_type, type): - ov_lhs = ops.parameter(PartialShape([]), pt_to_ov_type_map.get(l_type.__name__)) + ov_lhs = ops.parameter(PartialShape( + []), pt_to_ov_type_map.get(l_type.__name__)) pt_lhs = l_type(5) l_t = l_type.__name__ elif l_scalar: - ov_lhs = ops.parameter(PartialShape([]), pt_to_ov_type_map.get(str(l_type))) + ov_lhs = ops.parameter(PartialShape( + []), pt_to_ov_type_map.get(str(l_type))) pt_lhs = torch.tensor(1, dtype=l_type) else: - ov_lhs = ops.parameter(PartialShape([2, 2]), pt_to_ov_type_map.get(str(l_type))) + ov_lhs = ops.parameter(PartialShape( + [2, 2]), pt_to_ov_type_map.get(str(l_type))) pt_lhs = torch.rand([2, 2]).to(dtype=l_type) if isinstance(r_type, type): - ov_rhs = ops.parameter(PartialShape([]), pt_to_ov_type_map.get(r_type.__name__)) + ov_rhs = ops.parameter(PartialShape( + []), pt_to_ov_type_map.get(r_type.__name__)) pt_rhs = r_type(5) r_t = r_type.__name__ elif r_scalar: - ov_rhs = ops.parameter(PartialShape([]), pt_to_ov_type_map.get(str(r_type))) + ov_rhs = ops.parameter(PartialShape( + []), pt_to_ov_type_map.get(str(r_type))) pt_rhs = torch.tensor(1, dtype=r_type) else: - ov_rhs = ops.parameter(PartialShape([2, 2]), pt_to_ov_type_map.get(str(r_type))) + ov_rhs = ops.parameter(PartialShape( + [2, 2]), pt_to_ov_type_map.get(str(r_type))) pt_rhs = torch.rand([2, 2]).to(dtype=r_type) model = get_scripted_model(locals().get(f"aten_add_{l_t}_{r_t}")()) decoder = TorchScriptPythonDecoder(model) @@ -578,7 +634,8 @@ def forward(self, x: float, y: torch.Tensor): pt_out_type = pt_to_ov_type_map.get(str(pt_out_type)) ov_out_type = om.get_output_element_type(0) if pt_out_type == Type.i64 and ov_out_type == Type.i32 and "int" in [l_t, r_t]: - pytest.xfail("Pytorch int-like scalar in OV is converted to i32 instead of i64, mismatch is expected.") + pytest.xfail( + "Pytorch int-like scalar in OV is converted to i32 instead of i64, mismatch is expected.") assert pt_out_type == ov_out_type assert PartialShape(pt_out_shape) == om.get_output_partial_shape(0)