
Commit

Merge pull request openvinotoolkit#54 from mvafin/mvafin/pt_fe/no_freezing

Support converting models without freezing
slyalin authored Jan 11, 2023
2 parents b1661dc + 51a4f8f commit 636be1a
Showing 9 changed files with 207 additions and 75 deletions.
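
This change lets the PyTorch frontend consume a TorchScript module that has not been frozen: the decoder keeps a reference to the live module and resolves prim::GetAttr nodes against it, instead of relying on torch.jit.freeze to bake attribute values into the graph. Below is a minimal usage sketch, assuming the decoder is importable as openvino.frontend.pytorch.decoder (per the file path shown in this diff); the ToyModel class is illustrative and not part of the commit.

import torch
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = 2.0  # stays a module attribute because the model is not frozen

    def forward(self, x):
        return x * self.scale

scripted = torch.jit.script(ToyModel())

# No torch.jit.freeze() call: the decoder walks scripted.inlined_graph directly and
# later folds prim::GetAttr values such as `self.scale` via try_decode_get_attr.
decoder = TorchScriptPythonDecoder(scripted)
print(decoder.inputs())  # unique ids of the graph inputs, including `self` at index 0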
159 changes: 104 additions & 55 deletions src/bindings/python/src/openvino/frontend/pytorch/decoder.py
@@ -13,6 +13,56 @@
def make_constant(*args, **kwargs):
return op.Constant(*args, **kwargs)

def get_type_from_py_type(value):
if isinstance(value, float):
return OVType.f32
if isinstance(value, int):
return OVType.i32
if isinstance(value, bool):
return OVType.boolean
return OVType.dynamic

def ivalue_to_constant(ivalue):
ov_type = get_type_from_py_type(ivalue)
if ov_type.is_static():
return make_constant(ov_type, Shape([]), [ivalue]).outputs()

if isinstance(ivalue, list):
assert len(ivalue) > 0, "Can't deduce type for empty list"
ov_type = get_type_from_py_type(ivalue[0])
assert ov_type.is_static(), "Can't deduce type for list"
return make_constant(ov_type, Shape([len(ivalue)]), ivalue).outputs()

if ivalue.type() in pt_to_ov_type_map:
try:
ovshape = PartialShape(ivalue.size())
ovtype = pt_to_ov_type_map[ivalue.type()]
ov_const = make_constant(ovtype, ovshape.get_shape(), ivalue.data_ptr())
except:
# fall back to the old variant that performs a slow data copy
print("[ WARNING ] Constant wasn't able to convert from data_ptr.")
nvalues = ivalue.numpy()
ovtype = np_to_ov_type_map[str(nvalues.dtype)]
ovshape = PartialShape(nvalues.shape)
ov_const = make_constant(ovtype, ovshape.get_shape(), nvalues.flatten().tolist())
return ov_const.outputs()

def get_value_from_getattr(getattr_node, self_module):
assert getattr_node.kind() == 'prim::GetAttr', "Got node of kind not equal to prim::GetAttr"
# GetAttr nodes can be nested
stack = []
while getattr_node.kind() == 'prim::GetAttr':
stack.append(getattr_node)
inputs = list(getattr_node.inputs())
if len(inputs) == 0:
break
getattr_node = inputs[0].node()
module = self_module
while len(stack) > 0:
node = stack.pop()
assert(hasattr(module, node.s('name')))
module = getattr(module, node.s('name'))
return module

pt_to_ov_type_map = {
'float': OVType.f32,
@@ -43,14 +93,19 @@ def make_constant(*args, **kwargs):


class TorchScriptPythonDecoder (Decoder):
def __init__(self, pt_module):
def __init__(self, pt_module, graph_element=None):
Decoder.__init__(self)
# We store every decoder created by this decoder so that none of them is deleted before this top-level decoder
self.m_decoders = []
if graph_element is None:
assert hasattr(pt_module, 'inlined_graph'), 'pt_module must have inlined_graph'
self.graph_element = pt_module.inlined_graph
else:
self.graph_element = graph_element
self.pt_module = pt_module

def inputs(self):
return [x.unique() for x in self.pt_module.inputs()]
return [x.unique() for x in self.graph_element.inputs()]

def input(self, index): # TODO: remove
return self.inputs()[index] # TODO: find specialized method
@@ -75,46 +130,46 @@ def _get_known_type_for_value(self, type):
'''
Returns known/unknown types wrapped as OVAny
'''
#print(f'Trying to parse type {type} of class {type.__class__}')
# print(f'Trying to parse type {type} of class {type.__class__}')
# Check for simple scalar types first
# TODO: Don't use str, use native types
if type is None:
return OVAny(OVType.dynamic)
if str(type) in pt_to_ov_type_map:
#print(f'Recognized native type, type.__class__ = {type.__class__}')
# print(f'Recognized native type, type.__class__ = {type.__class__}')
return OVAny(pt_to_ov_type_map[str(type)])
elif type.__class__ is torch.TensorType:
#print(f'Recognized Tensor type with type.dtype() = {type.dtype()}')
# print(f'Recognized Tensor type with type.dtype() = {type.dtype()}')
# Tensor type, parse element type
# TODO: replace string by native type
# return OVAny(PartialShape([1,2,3]))
return OVAny(DecoderType.Tensor(self._get_known_type_for_value(type.dtype())))
elif type.__class__ is torch.ListType:
element_type = type.getElementType()
#print(f'Recognized torch List type. Type of element is {element_type}')
# print(f'Recognized torch List type. Type of element is {element_type}')
return OVAny(DecoderType.List(self._get_known_type_for_value(element_type)))
else:
#print(f'Not a tensor nor native type: {type}')
# print(f'Not a tensor nor native type: {type}')
# Not yet recognized
return OVAny(OVType.dynamic)
#pt_type_class = value.type().__class__
# pt_type_class = value.type().__class__
# if pt_type_class is torch.ListType:

def get_shape_for_value(self, value):
if value.isCompleteTensor():
ps = PartialShape(value.type().sizes())
#print(f'SHAPE FOR COMPLETE TENSOR: {ps}')
# print(f'SHAPE FOR COMPLETE TENSOR: {ps}')
return ps
else:
#print(f'NOT COMPLETE TENSOR for {value}')
# print(f'NOT COMPLETE TENSOR for {value}')
# TODO: Recognize types that we can represent as nested constructs with objects from DecoderType
# If recognized, return scalar instead of dynamic. Scalar means a single value of that custom type.
# See get_type_for_value for reference
pass
return PartialShape.dynamic()

def get_type_for_value(self, value):
#print(f'Decoding value type for value {value}')
# print(f'Decoding value type for value {value}')
full_type = self._get_known_type_for_value(value.type())
# DecoderType.print(full_type) # new (full) type interpretation
return full_type
@@ -157,40 +212,40 @@ def get_output_transpose_order(self, index):
return []

def get_subgraph_size(self):
return len(self.get_subgraphs()) if hasattr(self.pt_module, 'blocks') else 1
return len(self.get_subgraphs()) if hasattr(self.graph_element, 'blocks') else 1

def visit_subgraph(self, node_visitor):
# make sure topological order is satisfied
for node in self.pt_module.nodes():
decoder = TorchScriptPythonDecoder(node)
for node in self.graph_element.nodes():
decoder = TorchScriptPythonDecoder(self.pt_module, node)
self.m_decoders.append(decoder)
node_visitor(decoder)

def get_subgraphs(self):
return list(self.pt_module.blocks())
return list(self.graph_element.blocks())

def get_subgraph_decoder(self, index):
decoder = TorchScriptPythonDecoder(self.get_subgraphs()[index])
decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index])
self.m_decoders.append(decoder)
return decoder

def get_op_type(self):
return self.pt_module.kind()
return self.graph_element.kind()

def get_schema(self):
return self.pt_module.schema()
return self.graph_element.schema()

def outputs(self):
return [x.unique() for x in self.pt_module.outputs()]
return [x.unique() for x in self.graph_element.outputs()]

def _raw_outputs(self):
return [x for x in self.pt_module.outputs()]
return [x for x in self.graph_element.outputs()]

def _raw_output(self, index):
return self._raw_outputs()[index]

def _raw_inputs(self):
return [x for x in self.pt_module.inputs()]
return [x for x in self.graph_element.inputs()]

def _raw_input(self, index):
return self._raw_inputs()[index]
@@ -204,29 +259,37 @@ def output(self, index):
def mark_node(self, node):
return node

def try_decode_get_attr(self):
pt_value = get_value_from_getattr(self.graph_element, self.pt_module)
assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr"
if not isinstance(pt_value, torch.jit.ScriptModule) or isinstance(pt_value, torch.jit.TracedModule):
return ivalue_to_constant(pt_value)
else:
return []

def as_constant(self):
if not self.get_op_type() == 'prim::Constant':
#print(f'[ ERROR ] Requested const value {self._raw_output(0)} from a non const prim {self.get_op_type()}')
# print(f'[ ERROR ] Requested const value {self._raw_output(0)} from a non const prim {self.get_op_type()}')
return None
pt_value = self._raw_output(0)

pt_type_class = pt_value.type().__class__
#print(f'Not a tensor, type = {pt_value.type()}\ndir = {dir(pt_value.type())}\n__class__ = {pt_value.type().__class__}')
# print(f'Not a tensor, type = {pt_value.type()}\ndir = {dir(pt_value.type())}\n__class__ = {pt_value.type().__class__}')
if pt_type_class is torch.TensorType:
return self.as_constant_tensor(pt_value)
if pt_type_class is torch.ListType:
return self.as_constant_list(pt_value)
#print(f'Trying to recognize value {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
# print(f'Trying to recognize value {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
if str(pt_value.type()) in ['torch.int32', 'int']:
#print(f'Found int value= {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
# print(f'Found int value= {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
return make_constant(OVType.i32, Shape([]), [pt_value.toIValue()]).outputs()
if str(pt_value.type()) in ['torch.float', 'torch.FloatType', 'float']:
#print(f'Found float value= {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
# print(f'Found float value= {pt_value}, type = {type(pt_value.toIValue())}, ivalue = {pt_value.toIValue()}')
return make_constant(OVType.f32, Shape([]), [pt_value.toIValue()]).outputs()
if str(pt_value.type()) in ['torch.bool', 'bool']:
#print('Scalar bool detected')
# print('Scalar bool detected')
return make_constant(OVType.boolean, Shape([]), [pt_value.toIValue()]).outputs()
#print(f'Left value not converted to const, value = {pt_value}')
# print(f'Left value not converted to const, value = {pt_value}')

return None

@@ -241,7 +304,7 @@ def as_string(self):

def as_constant_tensor(self, pt_value):
ivalue = pt_value.toIValue()
if pt_value.isCompleteTensor():
if pt_value.isCompleteTensor():
try:
ivalue = ivalue.to(memory_format=torch.contiguous_format).detach().cpu()
except:
@@ -266,35 +329,14 @@ def as_constant_tensor(self, pt_value):
ov_const = make_constant(ovtype, ovshape.get_shape(), values)
return ov_const.outputs()
else:
# Incomplete tensor can be scalar
if isinstance(ivalue, float):
return make_constant(OVType.f32, Shape([]), [ivalue]).outputs()
if isinstance(ivalue, int):
return make_constant(OVType.i32, Shape([]), [ivalue]).outputs()
if isinstance(ivalue, bool):
return make_constant(OVType.boolean, Shape([]), [ivalue]).outputs()

# TODO: verify that it correctly reads incomplete consts
if ivalue.type() in pt_to_ov_type_map:
try:
ovshape = PartialShape(ivalue.size())
ovtype = pt_to_ov_type_map[ivalue.type()]
ov_const = make_constant(ovtype, ovshape.get_shape(), ivalue.data_ptr())
except:
# fall back to the old variant that performs a slow data copy
print("[ WARNING ] Constant wasn't able to convert from data_ptr.")
nvalues = ivalue.numpy()
ovtype = np_to_ov_type_map[str(nvalues.dtype)]
ovshape = PartialShape(nvalues.shape)
ov_const = make_constant(ovtype, ovshape.get_shape(), nvalues.flatten().tolist())
return ov_const.outputs()
return ivalue_to_constant(ivalue)
return None

def as_constant_list(self, pt_value):
# For now a list is treated as a 1D tensor; converters rely on this to avoid massive rewrites of the code that queries constant attributes
pt_element_type = str(pt_value.type().getElementType())
ivalue = pt_value.toIValue()
#print(f'List toIValue: {ivalue}, type of it: {type(ivalue)}')
# print(f'List toIValue: {ivalue}, type of it: {type(ivalue)}')
is_known_type = pt_element_type in pt_to_ov_type_map

# WA to broken ov.Type
Expand All @@ -306,7 +348,7 @@ def as_constant_list(self, pt_value):

if is_known_type:
ovtype = pt_to_ov_type_map[pt_element_type]
#print(f'ovtype = {ovtype}, pt_element_type = {pt_element_type}, OVType.i32 = {OVType.i32}, {OVType.f32}')
# print(f'ovtype = {ovtype}, pt_element_type = {pt_element_type}, OVType.i32 = {OVType.i32}, {OVType.f32}')
ovshape = PartialShape([len(ivalue)])
ov_const = make_constant(ovtype, ovshape.get_shape(), ivalue)
return ov_const.outputs()
@@ -316,8 +358,15 @@ def input_is_none(self, index):
return True
else:
r_input = self._raw_input(index)
return str(r_input.type()) in ['torch.NoneType', 'NoneType']
if str(r_input.type()) in ['torch.NoneType', 'NoneType']:
return True
else:
in_node = r_input.node()
if in_node.kind() == 'prim::GetAttr':
pt_value = get_value_from_getattr(in_node, self.pt_module)
return pt_value is None
return False

def debug(self):
print(f'DEBUG CALLED FOR {self._raw_output(0)}')
# self.pt_module.print()
# self.graph_element.print()
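
For context on what the new GetAttr path does, here is a small self-contained sketch (plain PyTorch, no OpenVINO imports) of the graph walk performed by get_value_from_getattr: collect the nested prim::GetAttr chain and replay the recorded attribute names against the live module. The Outer/Inner classes and the resolve_getattr helper are illustrative re-implementations for demonstration, not code from this commit.

import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.ones(3)

    def forward(self, x):
        return x + self.weight

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        return self.inner(x)

def resolve_getattr(node, module):
    # Same walk as get_value_from_getattr above: go from the innermost GetAttr
    # up to `self`, then apply the recorded attribute names to the live module.
    stack = []
    while node.kind() == 'prim::GetAttr':
        stack.append(node)
        inputs = list(node.inputs())
        if not inputs:
            break
        node = inputs[0].node()
    value = module
    while stack:
        value = getattr(value, stack.pop().s('name'))
    return value

scripted = torch.jit.script(Outer())
for node in scripted.inlined_graph.nodes():
    if node.kind() == 'prim::GetAttr':
        print(node.s('name'), '->', type(resolve_getattr(node, scripted)))

In the decoder, ivalue_to_constant then maps the recovered Python or tensor value onto an OpenVINO Constant, which is what try_decode_get_attr hands back to the C++ frontend.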
@@ -54,6 +54,10 @@ class PyDecoder : public ov::frontend::pytorch::Decoder {
PYBIND11_OVERRIDE_PURE(bool, Decoder, input_is_none, index);
}

ov::OutputVector try_decode_get_attr() override {
PYBIND11_OVERRIDE_PURE(ov::OutputVector, Decoder, try_decode_get_attr);
}

ov::OutputVector as_constant() override {
PYBIND11_OVERRIDE_PURE(ov::OutputVector, Decoder, as_constant);
}
@@ -109,6 +109,8 @@ struct Decoder { // TODO: Is it required to be enable_shared_from_this?
// TODO: required? can be implemented in the context of a single node?
virtual bool input_is_none(size_t index) const = 0;

virtual ov::OutputVector try_decode_get_attr() = 0;

// Work for natural constant nodes, e.g. for prim::Constant; don't know other nodes kinds that fit
// TODO: why OutputVector instead of just single output?
virtual OutputVector as_constant() = 0;
37 changes: 21 additions & 16 deletions src/frontends/pytorch/src/frontend.cpp
@@ -73,22 +73,6 @@ std::shared_ptr<Model> FrontEnd::convert_partially(const ov::frontend::InputMode
auto pytorch_model = std::dynamic_pointer_cast<pytorch::InputModel>(model);
auto model = convert_pytorch_model(pytorch_model->m_model);

// TODO: Propose better solution for the next code block
// Usually if nn.Module.forward is given as a source model for conversion, the first Parameter
// represents the original `self` argument in forward(self, ...). `self` shouldn't play any role in model
// inference if the model is completely frozen and all methods are inlined. So we check that it has no
// consumers in the final converted model and remove this parameter. This parameter should have index 0.
if (model->get_parameters().size() > 0) {
auto self = model->get_parameters()[0];
if (self->output(0).get_target_inputs().empty()) {
// There are no consumers: safe to remove
// std::cout << "[ WARNING ] Removing parameter[0] in converted Pytorch model, because it is never "
// "used and treated as `self`\n";
model->remove_parameter(self);
} else {
std::cout << "[ WARNING ] Couldn't remove parameter[0] in converted Pytorch model\n";
}
}
return model;
} catch (const std::runtime_error& e) {
std::cerr << "[ ERROR ] Unexpected error while converting pytorch model: " << e.what() << "\n";
@@ -104,6 +88,10 @@ std::shared_ptr<Model> FrontEnd::decode(const InputModel::Ptr& model) const {
void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
ov::pass::Manager manager;

manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::UnrollIf>();
// Have to run UnrollIf a second time, because conditions are defined outside of the nested If (ticket 98155)
manager.register_pass<ngraph::pass::UnrollIf>();
manager.register_pass<ov::frontend::pytorch::pass::AtenCatToConcat>();
manager.register_pass<ov::frontend::pytorch::pass::AppendListUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimListUnpackReplacer>();
@@ -116,6 +104,23 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.run_passes(model);

apply_pytorch_conversion_transforms(model);

// TODO: Propose better solution for the next code block
// Usually if nn.Module.forward is given as a source model for conversion, the first Parameter
// represents the original `self` argument in forward(self, ...). `self` shouldn't play any role in model
// inference if the model is completely frozen and all methods are inlined. So we check that it has no
// consumers in the final converted model and remove this parameter. This parameter should have index 0.
if (model->get_parameters().size() > 0) {
auto self = model->get_parameters()[0];
if (self->output(0).get_target_inputs().empty()) {
// There are no consumers: safe to remove
// std::cout << "[ WARNING ] Removing parameter[0] in converted Pytorch model, because it is never "
// "used and treated as `self`\n";
model->remove_parameter(self);
} else {
std::cout << "[ WARNING ] Couldn't remove parameter[0] in converted Pytorch model\n";
}
}
}

void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
