diff --git a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py
index 6034c80fe904cd..ff4d08444b5266 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py
@@ -13,6 +13,56 @@ def make_constant(*args, **kwargs):
     return op.Constant(*args, **kwargs)
 
 
+def get_type_from_py_type(value):
+    if isinstance(value, float):
+        return OVType.f32
+    if isinstance(value, bool):  # check bool before int: bool is a subclass of int
+        return OVType.boolean
+    if isinstance(value, int):
+        return OVType.i32
+    return OVType.dynamic
+
+def ivalue_to_constant(ivalue):
+    ov_type = get_type_from_py_type(ivalue)
+    if ov_type.is_static():
+        return make_constant(ov_type, Shape([]), [ivalue]).outputs()
+
+    if isinstance(ivalue, list):
+        assert len(ivalue) > 0, "Can't deduce type for empty list"
+        ov_type = get_type_from_py_type(ivalue[0])
+        assert ov_type.is_static(), "Can't deduce type for list"
+        return make_constant(ov_type, Shape([len(ivalue)]), ivalue).outputs()
+
+    if ivalue.type() in pt_to_ov_type_map:
+        try:
+            ovshape = PartialShape(ivalue.size())
+            ovtype = pt_to_ov_type_map[ivalue.type()]
+            ov_const = make_constant(ovtype, ovshape.get_shape(), ivalue.data_ptr())
+        except Exception:
+            # fall back to the old variant that makes a slow data copy
+            print("[ WARNING ] Constant couldn't be created from data_ptr; falling back to a slow copy.")
+            nvalues = ivalue.numpy()
+            ovtype = np_to_ov_type_map[str(nvalues.dtype)]
+            ovshape = PartialShape(nvalues.shape)
+            ov_const = make_constant(ovtype, ovshape.get_shape(), nvalues.flatten().tolist())
+        return ov_const.outputs()
+
+def get_value_from_getattr(getattr_node, self_module):
+    assert getattr_node.kind() == 'prim::GetAttr', "Got node of kind not equal to prim::GetAttr"
+    # GetAttr nodes can be nested
+    stack = []
+    while getattr_node.kind() == 'prim::GetAttr':
+        stack.append(getattr_node)
+        inputs = list(getattr_node.inputs())
+        if len(inputs) == 0:
+            break
+        getattr_node = inputs[0].node()
+    module = self_module
+    while len(stack) > 0:
+        node = stack.pop()
+        assert hasattr(module, node.s('name'))
+        module = getattr(module, node.s('name'))
+    return module
 
 
 pt_to_ov_type_map = {
     'float': OVType.f32,
@@ -43,14 +93,19 @@ def make_constant(*args, **kwargs):
 
 
 class TorchScriptPythonDecoder (Decoder):
-    def __init__(self, pt_module):
+    def __init__(self, pt_module, graph_element=None):
         Decoder.__init__(self)
         # We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted
         self.m_decoders = []
+        if graph_element is None:
+            assert hasattr(pt_module, 'inlined_graph'), 'pt_module must have inlined_graph'
+            self.graph_element = pt_module.inlined_graph
+        else:
+            self.graph_element = graph_element
         self.pt_module = pt_module
 
     def inputs(self):
-        return [x.unique() for x in self.pt_module.inputs()]
+        return [x.unique() for x in self.graph_element.inputs()]
 
     def input(self, index):  # TODO: remove
         return self.inputs()[index]  # TODO: find specialized method
@@ -75,38 +130,38 @@ def _get_known_type_for_value(self, type):
         '''
         Returns known/unknown types wrapped as OVAny
         '''
-        #print(f'Trying to parse type {type} of class {type.__class__}')
+        # print(f'Trying to parse type {type} of class {type.__class__}')
         # Check for simple scalar types first
         # TODO: Don't use str, use native types
         if type is None:
             return OVAny(OVType.dynamic)
         if str(type) in pt_to_ov_type_map:
-            #print(f'Recognized native type, type.__class__ = {type.__class__}')
+            # print(f'Recognized native type, type.__class__ = {type.__class__}')
             return OVAny(pt_to_ov_type_map[str(type)])
         elif type.__class__ is torch.TensorType:
-            #print(f'Recognized Tensor type with type.dtype() = {type.dtype()}')
+            # print(f'Recognized Tensor type with type.dtype() = {type.dtype()}')
            # Tensor type, parse element type
            # TODO: replace string by native type
            # return OVAny(PartialShape([1,2,3]))
            return OVAny(DecoderType.Tensor(self._get_known_type_for_value(type.dtype())))
         elif type.__class__ is torch.ListType:
             element_type = type.getElementType()
-            #print(f'Recognized torch List type. Type of element is {element_type}')
+            # print(f'Recognized torch List type. Type of element is {element_type}')
             return OVAny(DecoderType.List(self._get_known_type_for_value(element_type)))
         else:
-            #print(f'Not a tensor nor native type: {type}')
+            # print(f'Not a tensor nor native type: {type}')
             # Not yet recognized
             return OVAny(OVType.dynamic)
-            #pt_type_class = value.type().__class__
+            # pt_type_class = value.type().__class__
             # if pt_type_class is torch.ListType:
 
     def get_shape_for_value(self, value):
         if value.isCompleteTensor():
             ps = PartialShape(value.type().sizes())
-            #print(f'SHAPE FOR COMPLETE TENSOR: {ps}')
+            # print(f'SHAPE FOR COMPLETE TENSOR: {ps}')
             return ps
         else:
-            #print(f'NOT COMPLETE TENSOR for {value}')
+            # print(f'NOT COMPLETE TENSOR for {value}')
             # TODO: Recognize types that we can represent as a nested constructs with objects from DecoderType
             # If recognized, return scalar instead of dynamic. Scalar means a single value of that custom type.
             # See get_type_for_value for reference
@@ -114,7 +169,7 @@ def get_shape_for_value(self, value):
         return PartialShape.dynamic()
 
     def get_type_for_value(self, value):
-        #print(f'Decoding value type for value {value}')
+        # print(f'Decoding value type for value {value}')
         full_type = self._get_known_type_for_value(value.type())
         # DecoderType.print(full_type)  # new (full) type interpretation
         return full_type
@@ -157,40 +212,40 @@ def get_output_transpose_order(self, index):
         return []
 
     def get_subgraph_size(self):
-        return len(self.get_subgraphs()) if hasattr(self.pt_module, 'blocks') else 1
+        return len(self.get_subgraphs()) if hasattr(self.graph_element, 'blocks') else 1
 
     def visit_subgraph(self, node_visitor):
         # make sure topological order is satisfied
-        for node in self.pt_module.nodes():
-            decoder = TorchScriptPythonDecoder(node)
+        for node in self.graph_element.nodes():
+            decoder = TorchScriptPythonDecoder(self.pt_module, node)
             self.m_decoders.append(decoder)
             node_visitor(decoder)
 
     def get_subgraphs(self):
-        return list(self.pt_module.blocks())
+        return list(self.graph_element.blocks())
 
     def get_subgraph_decoder(self, index):
-        decoder = TorchScriptPythonDecoder(self.get_subgraphs()[index])
+        decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index])
         self.m_decoders.append(decoder)
         return decoder
 
     def get_op_type(self):
-        return self.pt_module.kind()
+        return self.graph_element.kind()
 
     def get_schema(self):
-        return self.pt_module.schema()
+        return self.graph_element.schema()
 
     def outputs(self):
-        return [x.unique() for x in self.pt_module.outputs()]
+        return [x.unique() for x in self.graph_element.outputs()]
 
     def _raw_outputs(self):
-        return [x for x in self.pt_module.outputs()]
+        return [x for x in self.graph_element.outputs()]
 
     def _raw_output(self, index):
         return self._raw_outputs()[index]
 
     def _raw_inputs(self):
-        return [x for x in self.pt_module.inputs()]
+        return [x for x in self.graph_element.inputs()]
 
     def _raw_input(self, index):
         return self._raw_inputs()[index]
@@ -204,29 +259,37 @@ def output(self, index):
     def mark_node(self, node):
         return node
 
+    def try_decode_get_attr(self):
+        pt_value = get_value_from_getattr(self.graph_element, self.pt_module)
+        assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr"
+        if not isinstance(pt_value, (torch.jit.ScriptModule, torch.jit.TracedModule)):
+            return ivalue_to_constant(pt_value)
+        else:
+            return []
+
     def as_constant(self):
         if not self.get_op_type() == 'prim::Constant':
-            #print(f'[ ERROR ] Requested const value {self._raw_output(0)} from a non const prim {self.get_op_type()}')
+            # print(f'[ ERROR ] Requested const value {self._raw_output(0)} from a non const prim {self.get_op_type()}')
             return None
         pt_value = self._raw_output(0)
 
         pt_type_class = pt_value.type().__class__
-        #print(f'Not a tensor, type = {pt_value.type()}\ndir = {dir(pt_value.type())}\n__class__ = {pt_value.type().__class__}')
+        # print(f'Not a tensor, type = {pt_value.type()}\ndir = {dir(pt_value.type())}\n__class__ = {pt_value.type().__class__}')
         if pt_type_class is torch.TensorType:
             return self.as_constant_tensor(pt_value)
         if pt_type_class is torch.ListType:
             return self.as_constant_list(pt_value)
-        #print(f'Trying to recognize value {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
+        # print(f'Trying to recognize value {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
         if str(pt_value.type()) in ['torch.int32', 'int']:
-            #print(f'Found int value= {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
+            # print(f'Found int value= {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
             return make_constant(OVType.i32, Shape([]), [pt_value.toIValue()]).outputs()
         if str(pt_value.type()) in ['torch.float', 'torch.FloatType', 'float']:
-            #print(f'Found float value= {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
+            # print(f'Found float value= {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
             return make_constant(OVType.f32, Shape([]), [pt_value.toIValue()]).outputs()
         if str(pt_value.type()) in ['torch.bool', 'bool']:
-            #print('Scalar bool detected')
+            # print('Scalar bool detected')
             return make_constant(OVType.boolean, Shape([]), [pt_value.toIValue()]).outputs()
-        #print(f'Left value not converted to const, value = {pt_value}')
+        # print(f'Left value not converted to const, value = {pt_value}')
 
         return None
@@ -241,7 +304,7 @@ def as_string(self):
 
     def as_constant_tensor(self, pt_value):
         ivalue = pt_value.toIValue()
-        if pt_value.isCompleteTensor():
+        if pt_value.isCompleteTensor():
             try:
                 ivalue = ivalue.to(memory_format=torch.contiguous_format).detach().cpu()
             except:
@@ -266,35 +329,14 @@ def as_constant_tensor(self, pt_value):
             ov_const = make_constant(ovtype, ovshape.get_shape(), values)
             return ov_const.outputs()
         else:
-            # Incomplete tensor can be scalar
-            if isinstance(ivalue, float):
-                return make_constant(OVType.f32, Shape([]), [ivalue]).outputs()
-            if isinstance(ivalue, int):
-                return make_constant(OVType.i32, Shape([]), [ivalue]).outputs()
-            if isinstance(ivalue, bool):
-                return make_constant(OVType.boolean, Shape([]), [ivalue]).outputs()
-
-            # TODO: verify that it correctly reads incomplete consts
-            if ivalue.type() in pt_to_ov_type_map:
-                try:
-                    ovshape = PartialShape(ivalue.size())
-                    ovtype = pt_to_ov_type_map[ivalue.type()]
-                    ov_const = make_constant(ovtype, ovshape.get_shape(), ivalue.data_ptr())
-                except:
-                    # old variant that makes a slow data copying
-                    print(f"[ WARNING ] Constant wasn't able to convert from data_ptr.")
-                    nvalues = ivalue.numpy()
-                    ovtype = np_to_ov_type_map[str(nvalues.dtype)]
-                    ovshape = PartialShape(nvalues.shape)
-                    ov_const = make_constant(ovtype, ovshape.get_shape(), nvalues.flatten().tolist())
-                return ov_const.outputs()
+            return ivalue_to_constant(ivalue)
         return None
 
     def as_constant_list(self, pt_value):
         # For now it is treat a list as a 1D tensor; it is required by converters to avoid need to massively rewrite them in that part where constant attributes are queried
         pt_element_type = str(pt_value.type().getElementType())
         ivalue = pt_value.toIValue()
-        #print(f'List toIValue: {ivalue}, type of it: {type(ivalue)}')
+        # print(f'List toIValue: {ivalue}, type of it: {type(ivalue)}')
         is_known_type = pt_element_type in pt_to_ov_type_map
 
         # WA to broken ov.Type
@@ -306,7 +348,7 @@ def as_constant_list(self, pt_value):
 
         if is_known_type:
             ovtype = pt_to_ov_type_map[pt_element_type]
-            #print(f'ovtype = {ovtype}, pt_element_type = {pt_element_type}, OVType.i32 = {OVType.i32}, {OVType.f32}')
+            # print(f'ovtype = {ovtype}, pt_element_type = {pt_element_type}, OVType.i32 = {OVType.i32}, {OVType.f32}')
             ovshape = PartialShape([len(ivalue)])
             ov_const = make_constant(ovtype, ovshape.get_shape(), ivalue)
             return ov_const.outputs()
@@ -316,8 +358,15 @@ def input_is_none(self, index):
             return True
         else:
             r_input = self._raw_input(index)
-            return str(r_input.type()) in ['torch.NoneType', 'NoneType']
+            if str(r_input.type()) in ['torch.NoneType', 'NoneType']:
+                return True
+            else:
+                in_node = r_input.node()
+                if in_node.kind() == 'prim::GetAttr':
+                    pt_value = get_value_from_getattr(in_node, self.pt_module)
+                    return pt_value is None
+        return False
 
     def debug(self):
         print(f'DEBUG CALLED FOR {self._raw_output(0)}')
-        # self.pt_module.print()
+        # self.graph_element.print()
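
For reference, a standalone sketch (not part of the patch) of the graph pattern that the new get_value_from_getattr/try_decode_get_attr path resolves; the module and attribute names below are invented for illustration:

import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(3))

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        # In TorchScript this lowers to prim::GetAttr[name="inner"](%self),
        # then prim::GetAttr[name="scale"] on its result.
        return x * self.inner.scale

scripted = torch.jit.script(Outer())
# get_value_from_getattr() walks such a chain by stacking the GetAttr nodes
# and then applying getattr(module, name) from the outermost module inward.
print(scripted.inlined_graph)
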
diff --git a/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp b/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp
index 4fb8f76bfc53ae..825dddf60a572a 100644
--- a/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp
+++ b/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp
@@ -54,6 +54,10 @@ class PyDecoder : public ov::frontend::pytorch::Decoder {
         PYBIND11_OVERRIDE_PURE(bool, Decoder, input_is_none, index);
     }
 
+    ov::OutputVector try_decode_get_attr() override {
+        PYBIND11_OVERRIDE_PURE(ov::OutputVector, Decoder, try_decode_get_attr);
+    }
+
     ov::OutputVector as_constant() override {
         PYBIND11_OVERRIDE_PURE(ov::OutputVector, Decoder, as_constant);
     }
diff --git a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp
index 463864520a3e1f..28c3a190992c7a 100644
--- a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp
+++ b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp
@@ -109,6 +109,8 @@ struct Decoder {  // TODO: Is it required to be enable_shared_from_this?
     // TODO: required? can be implemented in the context of a single node?
     virtual bool input_is_none(size_t index) const = 0;
 
+    virtual ov::OutputVector try_decode_get_attr() = 0;
+
     // Work for natural constant nodes, e.g. for prim::Constant; don't know other nodes kinds that fit
     // TODO: why OutputVector instead of just single output?
     virtual OutputVector as_constant() = 0;
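
The intended contract of the new try_decode_get_attr hook, as implemented by the Python decoder above: return the attribute folded to OV Constant outputs when it is a plain value, and an empty vector when it resolves to a submodule (whose methods are expected to be inlined already). A hedged caller-side sketch; `decoder` and `handle_constants` are hypothetical stand-ins:

# Hypothetical caller; `decoder` is assumed to be a TorchScriptPythonDecoder
# positioned on a prim::GetAttr node (see translate_get_attr below).
outputs = decoder.try_decode_get_attr()
if outputs:
    handle_constants(outputs)  # attribute folded to OV Constant outputs
else:
    pass  # attribute is a ScriptModule; nothing to fold at this node
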
diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index c09152a74f333b..751dbe166bcb11 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -73,22 +73,6 @@ std::shared_ptr<Model> FrontEnd::convert_partially(const ov::frontend::InputModel::Ptr& model) const {
         auto pytorch_model = std::dynamic_pointer_cast<pytorch::InputModel>(model);
         auto model = convert_pytorch_model(pytorch_model->m_model);
 
-        // TODO: Propose better solution for the next code block
-        // Usually if nn.Module.forward is given as a source model for conversion, there is the first Parameter
-        // that represents original `self` argument in forward(self, ...). `self` shouldn't play any role in model
-        // inference if model is completelly frozed and all methods are inlined. So we check if it doesn't have any
-        // consumers in the finally converted model and remove this parameter. This parameter should have index 0.
-        if (model->get_parameters().size() > 0) {
-            auto self = model->get_parameters()[0];
-            if (self->output(0).get_target_inputs().empty()) {
-                // There is no consumers: safe to remove
-                // std::cout << "[ WARNING ] Removing parameter[0] in converted Pytorch model, because it is never "
-                //              "used and treated as `self`\n";
-                model->remove_parameter(self);
-            } else {
-                std::cout << "[ WARNING ] Couldn't remove parameter[0] in converted Pytorch model\n";
-            }
-        }
         return model;
     } catch (const std::runtime_error& e) {
         std::cerr << "[ ERROR ] Unexpected error while converting pytorch model: " << e.what() << "\n";
@@ -104,6 +88,10 @@ std::shared_ptr<Model> FrontEnd::decode(const InputModel::Ptr& model) const {
 
 void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
     ov::pass::Manager manager;
+    manager.register_pass<ov::pass::ConstantFolding>();
+    manager.register_pass<ov::pass::UnrollIf>();
+    // Have to run UnrollIf a second time, because conditions are defined outside of the nested If (ticket 98155)
+    manager.register_pass<ov::pass::UnrollIf>();
     manager.register_pass();
     manager.register_pass();
     manager.register_pass();
@@ -116,6 +104,23 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
     manager.run_passes(model);
 
     apply_pytorch_conversion_transforms(model);
+
+    // TODO: Propose a better solution for the next code block
+    // Usually if nn.Module.forward is given as a source model for conversion, the first Parameter
+    // represents the original `self` argument in forward(self, ...). `self` shouldn't play any role in model
+    // inference if the model is completely frozen and all methods are inlined. So we check that it doesn't have
+    // any consumers in the final converted model and remove this parameter. This parameter should have index 0.
+    if (model->get_parameters().size() > 0) {
+        auto self = model->get_parameters()[0];
+        if (self->output(0).get_target_inputs().empty()) {
+            // There are no consumers: safe to remove
+            // std::cout << "[ WARNING ] Removing parameter[0] in converted Pytorch model, because it is never "
+            //              "used and treated as `self`\n";
+            model->remove_parameter(self);
+        } else {
+            std::cout << "[ WARNING ] Couldn't remove parameter[0] in converted Pytorch model\n";
+        }
+    }
 }
 
 void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
diff --git a/src/frontends/pytorch/src/op/get_attr.cpp b/src/frontends/pytorch/src/op/get_attr.cpp
new file mode 100644
index 00000000000000..d47d079b663938
--- /dev/null
+++ b/src/frontends/pytorch/src/op/get_attr.cpp
@@ -0,0 +1,24 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/opsets/opset8.hpp"
+#include "pt_framework_node.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_get_attr(NodeContext& context) {
+    auto res = context.get_decoder()->try_decode_get_attr();
+    FRONT_END_OP_CONVERSION_CHECK(res.size() > 0, "GetAttr must have at least one output.");
+    return res;
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/op/list_construct.cpp b/src/frontends/pytorch/src/op/list_construct.cpp
new file mode 100644
index 00000000000000..e0dafcdfd74a27
--- /dev/null
+++ b/src/frontends/pytorch/src/op/list_construct.cpp
@@ -0,0 +1,40 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/opsets/opset8.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_list_construct(NodeContext& context) {
+    // Process the case when prim::ListConstruct has all inputs constant
+    ov::OutputVector consts;
+    for (size_t i = 0; i < context.get_input_size(); i++) {
+        auto input = context.get_input_from_visible_context(i);
+        auto c_node = std::dynamic_pointer_cast<opset8::Constant>(input.get_node_shared_ptr());
+        FRONT_END_OP_CONVERSION_CHECK(c_node, "Translation for prim::ListConstruct supports only constant inputs");
+        if (c_node->get_shape().size() == 0) {
+            c_node = std::make_shared<opset8::Constant>(c_node->get_element_type(), Shape{1}, c_node->get_data_ptr());
+        }
+        consts.push_back(c_node);
+    }
+    auto list_construct = std::make_shared<opset8::Concat>(consts, 0);
+    if (list_construct->has_evaluate()) {
+        OutputVector replacements(list_construct->get_output_size());
+
+        if (list_construct->constant_fold(replacements, list_construct->input_values())) {
+            return replacements;
+        }
+    }
+    return {context.mark_output(list_construct)};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 0a3eec454b3840..5060b8a2c84fa3 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -40,6 +40,7 @@ OP_CONVERTER(translate_floor_divide);
 OP_CONVERTER(translate_full);
 OP_CONVERTER(translate_full_like);
 OP_CONVERTER(translate_gelu);
+OP_CONVERTER(translate_get_attr);
 OP_CONVERTER(translate_group_norm);
 OP_CONVERTER(translate_hardtanh);
 OP_CONVERTER(translate_if);
@@ -48,6 +49,7 @@ OP_CONVERTER(translate_int);
 OP_CONVERTER(translate_layer_norm);
 OP_CONVERTER(translate_len);
 OP_CONVERTER(translate_linear);
+OP_CONVERTER(translate_list_construct);
 OP_CONVERTER(translate_loop);
 OP_CONVERTER(translate_max_poolnd);
 OP_CONVERTER(translate_max);
@@ -96,6 +98,7 @@ OP_CONVERTER(translate_zeros_like);
 
 const std::map get_supported_ops() {
     return {
+        {"aten::__not__", op::translate_1to1_match_1_inputs<opset8::LogicalNot>},
         {"aten::_convolution", op::translate_convolution},
         {"aten::_convolution_mode", op::translate_convolution_mode},
         {"aten::abs", op::translate_1to1_match_1_inputs<opset8::Abs>},
@@ -145,6 +148,8 @@ const std::map get_supported_ops() {
         {"aten::dim", op::translate_dim},
         {"aten::div", op::translate_div},
         {"aten::div_", op::inplace_op<op::translate_div>},
+        {"aten::dropout", op::skip_node},
+        {"aten::dropout_", op::skip_node},
         {"aten::elu", op::translate_elu},
         {"aten::embedding", op::translate_embedding},
         {"aten::eq", op::translate_1to1_match_2_inputs<opset8::Equal>},
@@ -250,8 +255,10 @@ const std::map get_supported_ops() {
         {"aten::zeros", op::translate_zeros},
         {"aten::zeros_like", op::translate_zeros_like},
         {"prim::Constant", op::translate_constant},
+        {"prim::GetAttr", op::translate_get_attr},
         {"prim::If", op::translate_if},
         {"prim::is_cuda", op::return_false_scalar},
+        {"prim::ListConstruct", op::translate_list_construct},
         {"prim::Loop", op::translate_loop},
         {"prim::NumToTensor", op::skip_node},  // In openvino we already store number as tensor with shape []
         {"prim::requires_grad", op::return_false_scalar},
diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp
index 7db301f1ad8a97..fa97a3426fb5f3 100644
--- a/src/frontends/pytorch/src/utils.cpp
+++ b/src/frontends/pytorch/src/utils.cpp
@@ -404,7 +404,7 @@ std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<Decoder> pytorch_model,
             OV_FRONTEND_REQUIRE(tensor_map.count(tensor_id));
             // model input was mutated we need to make a result for it
             auto mutated_tensor = tensor_map.at(tensor_id);
-            // empty external_tensor_map means this is main body of the model and we don't want to creatre
+            // empty external_tensor_map means this is main body of the model and we don't want to create
             // additional outputs in that case.
             if (mutated_tensor.get_target_inputs().empty() && !external_tensor_map.empty())
                 results.push_back(std::make_shared<opset8::Result>(tensor_map.at(tensor_id)));
diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
index eb01856521e4ec..ac21ac0ac0d163 100644
--- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
+++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -47,10 +47,11 @@ def _test(self, model, ref_net, kind, ie_device, precision, ir_version, infer_ti
         with torch.no_grad():
             model.eval()
             if not kwargs.get('trace_model', False):
-                model = torch.jit.freeze(torch.jit.script(model))
+                model = torch.jit.script(model)
             else:
                 torch_inputs = [torch.from_numpy(inp) for inp in inputs]
-                model = torch.jit.freeze(torch.jit.trace(model, torch_inputs))
+                model = torch.jit.trace(model, torch_inputs)
+            model = torch.jit.freeze(model)
             graph = model.inlined_graph
             print(graph)
 
@@ -60,7 +61,7 @@ def _test(self, model, ref_net, kind, ie_device, precision, ir_version, infer_ti
         fe_manager = FrontEndManager()
         fe = fe_manager.load_by_framework('pytorch')
 
-        decoder = TorchScriptPythonDecoder(graph)
+        decoder = TorchScriptPythonDecoder(model)
         im = fe.load(decoder)
         om = fe.convert(im)
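
With these changes the test harness scripts or traces first, freezes as a separate step, and hands the decoder the module itself instead of its graph. A minimal end-to-end sketch under the same assumptions (MyModule is a placeholder nn.Module):

import torch
from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder

model = torch.jit.script(MyModule().eval())  # MyModule is a placeholder
model = torch.jit.freeze(model)

fe = FrontEndManager().load_by_framework('pytorch')
decoder = TorchScriptPythonDecoder(model)  # pass the module; the decoder takes inlined_graph itself
im = fe.load(decoder)
om = fe.convert(im)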