Skip to content

Commit

Permalink
Small refactoring (openvinotoolkit#94)
Browse files Browse the repository at this point in the history
* Small refactoring

* Fix type

* Fix python codestyle
  • Loading branch information
mvafin authored Jan 12, 2023
1 parent 40dd26c commit f8a6694
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 86 deletions.
19 changes: 12 additions & 7 deletions src/bindings/python/src/openvino/frontend/pytorch/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
def make_constant(*args, **kwargs):
return op.Constant(*args, **kwargs)


def get_type_from_py_type(value):
if isinstance(value, float):
return OVType.f32
Expand All @@ -22,6 +23,7 @@ def get_type_from_py_type(value):
return OVType.boolean
return OVType.dynamic


def ivalue_to_constant(ivalue):
ov_type = get_type_from_py_type(ivalue)
if ov_type.is_static():
Expand All @@ -38,15 +40,16 @@ def ivalue_to_constant(ivalue):
ovshape = PartialShape(ivalue.size())
ovtype = pt_to_ov_type_map[ivalue.type()]
ov_const = make_constant(ovtype, ovshape.get_shape(), ivalue.data_ptr())
except:
except Exception:
# old variant that makes a slow data copying
print(f"[ WARNING ] Constant wasn't able to convert from data_ptr.")
print("[ WARNING ] Constant wasn't able to convert from data_ptr.")
nvalues = ivalue.numpy()
ovtype = np_to_ov_type_map[str(nvalues.dtype)]
ovshape = PartialShape(nvalues.shape)
ov_const = make_constant(ovtype, ovshape.get_shape(), nvalues.flatten().tolist())
return ov_const.outputs()


def get_value_from_getattr(getattr_node, self_module):
assert getattr_node.kind() == 'prim::GetAttr', "Got node of kind not equal to prim::GetAttr"
# GetAttr nodes can be nested
Expand All @@ -60,10 +63,11 @@ def get_value_from_getattr(getattr_node, self_module):
module = self_module
while len(stack) > 0:
node = stack.pop()
assert(hasattr(module, node.s('name')))
assert (hasattr(module, node.s('name')))
module = getattr(module, node.s('name'))
return module


pt_to_ov_type_map = {
'float': OVType.f32,
'int': OVType.i32,
Expand Down Expand Up @@ -307,7 +311,7 @@ def as_constant_tensor(self, pt_value):
if pt_value.isCompleteTensor():
try:
ivalue = ivalue.to(memory_format=torch.contiguous_format).detach().cpu()
except:
except Exception:
print("[ WARNING ] Tensor couldn't detach")
if str(pt_value.type().dtype()) in pt_to_ov_type_map:
# Constant interpretation doesn't respect new-full type of PT
Expand All @@ -322,9 +326,9 @@ def as_constant_tensor(self, pt_value):
# TODO Check strides and pass them somehow
values = ivalue.data_ptr()
ov_const = make_constant(ovtype, ovshape.get_shape(), values)
except:
except Exception:
# old variant that makes a slow data copying
print(f"[ WARNING ] Constant wasn't able to convert from data_ptr.")
print("[ WARNING ] Constant wasn't able to convert from data_ptr.")
values = ivalue.flatten().tolist()
ov_const = make_constant(ovtype, ovshape.get_shape(), values)
return ov_const.outputs()
Expand All @@ -333,7 +337,8 @@ def as_constant_tensor(self, pt_value):
return None

def as_constant_list(self, pt_value):
# For now it is treat a list as a 1D tensor; it is required by converters to avoid need to massively rewrite them in that part where constant attributes are queried
# For now it is treat a list as a 1D tensor; it is required by converters to avoid need to massively
# rewrite them in that part where constant attributes are queried
pt_element_type = str(pt_value.type().getElementType())
ivalue = pt_value.toIValue()
# print(f'List toIValue: {ivalue}, type of it: {type(ivalue)}')
Expand Down
3 changes: 2 additions & 1 deletion src/core/src/op/util/framework_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ void ov::op::util::FrameworkNode::validate_and_infer_types() {
reset_output_shape_to_dynamic = true;
} else {
NODE_VALIDATION_CHECK(this,
m_inputs_desc[i] == std::make_tuple(input_pshape, input_type),
std::get<0>(m_inputs_desc[i]).compatible(input_pshape) &&
std::get<1>(m_inputs_desc[i]).compatible(input_type),
get_error_message());
}
}
Expand Down
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/op/arange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_arange(NodeContext& context) {
auto zero = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {0}));
auto one = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {1}));
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op/expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace frontend {
namespace pytorch {
namespace op {

namespace {
OutputVector base_expand(NodeContext& context, ov::Output<ov::Node> x, ov::Output<ov::Node> sizes) {
auto one = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {1}));
auto sizes_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(sizes, element::i32));
Expand All @@ -21,6 +22,7 @@ OutputVector base_expand(NodeContext& context, ov::Output<ov::Node> x, ov::Outpu
auto shape = context.mark_node(std::make_shared<opset8::Select>(neg_sizes, ones, sizes));
return {std::make_shared<opset8::Broadcast>(x, shape, ov::op::BroadcastType::BIDIRECTIONAL)};
};
} // namespace

OutputVector translate_expand(NodeContext& context) {
auto x = context.get_input(0);
Expand Down
4 changes: 3 additions & 1 deletion src/frontends/pytorch/src/op/full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace frontend {
namespace pytorch {
namespace op {

namespace {
ov::Output<Node> base_translate_full(NodeContext& context, ov::Output<Node> sizes, ov::Output<Node> value) {
return context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
}
Expand All @@ -34,6 +35,7 @@ ov::Output<Node> base_translate_full_with_convertlike(NodeContext& context,
auto filled_tensor = base_translate_full(context, sizes, value);
return context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out));
}
} // namespace

OutputVector translate_full(NodeContext& context) {
auto sizes = context.get_input(0);
Expand Down Expand Up @@ -103,7 +105,7 @@ OutputVector translate_new_zeros(NodeContext& context) {
auto input = context.get_input(0);
auto sizes = context.get_input(1);
auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
if (context.get_input_size() == 6 && !context.input_is_none(2)){
if (context.get_input_size() == 6 && !context.input_is_none(2)) {
return {base_translate_full_with_convert(context, sizes, value, 2)};
}
return {base_translate_full_with_convertlike(context, sizes, value, input)};
Expand Down
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/op/im2col.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

namespace {
std::shared_ptr<Node> get_im2col_indices_along_dim(NodeContext& context,
ov::Output<Node> input_d,
Expand Down
10 changes: 6 additions & 4 deletions src/frontends/pytorch/src/op/upsample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_upsample2d(NodeContext& context, opset8::Interpolate::InterpolateMode interpolate_mode) {
namespace {
OutputVector base_translate_upsample2d(NodeContext& context, opset8::Interpolate::InterpolateMode interpolate_mode) {
auto data = context.get_input(0);
std::vector<size_t> pad{0};
auto size_mode = opset8::Interpolate::ShapeCalcMode::SIZES;
Expand Down Expand Up @@ -47,17 +48,18 @@ OutputVector translate_upsample2d(NodeContext& context, opset8::Interpolate::Int
}
return {context.mark_node(std::make_shared<opset8::Interpolate>(data, output_sizes, scales, target_axes, attrs))};
};
} // namespace

OutputVector translate_upsample_bilinear2d(NodeContext& context) {
return translate_upsample2d(context, opset8::Interpolate::InterpolateMode::LINEAR_ONNX);
return base_translate_upsample2d(context, opset8::Interpolate::InterpolateMode::LINEAR_ONNX);
};

OutputVector translate_upsample_nearest2d(NodeContext& context) {
return translate_upsample2d(context, opset8::Interpolate::InterpolateMode::NEAREST);
return base_translate_upsample2d(context, opset8::Interpolate::InterpolateMode::NEAREST);
};

OutputVector translate_upsample_bicubic2d(NodeContext& context) {
return translate_upsample2d(context, opset8::Interpolate::InterpolateMode::CUBIC);
return base_translate_upsample2d(context, opset8::Interpolate::InterpolateMode::CUBIC);
};

} // namespace op
Expand Down
75 changes: 2 additions & 73 deletions src/frontends/pytorch/src/pt_framework_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,47 +19,25 @@ class PtFrameworkNode : public ov::op::util::FrameworkNode {
: ov::op::util::FrameworkNode(inputs, output_size, decoder->get_subgraph_size()),
m_decoder(decoder) {
ov::op::util::FrameworkNodeAttrs attrs;
// std::cerr << "[ DEBUG ] Making PtFrameworkNode for " << m_decoder->get_op_type() << "\n";
attrs.set_type_name("PTFrameworkNode");
attrs["PtTypeName"] = m_decoder->get_op_type();
attrs["PtSchema"] = m_decoder->get_schema();
set_attrs(attrs);

// std::cout << attrs["PtTypeName"] << std::endl;

// Set output shapes and types if recognized

for (size_t i = 0; i < output_size; ++i) {
PartialShape ps;
// TODO: Try to decode PT type as a custom type
Any type = element::dynamic;
// FIXME: ROUGH
auto type = element::dynamic;
if (i < decoder->num_of_outputs()) {
try {
ps = m_decoder->get_output_shape(i);
} catch (...) {
// nothing, means the info cannot be queried and remains unknown
}
// FIXME: ROUGH
try {
type = simplified_type_interpret(m_decoder->get_output_type(i));
} catch (std::runtime_error& e) {
// nothing, means the info cannot be queried and remains unknown
std::cerr << "[ ERROR ] Cannot retrieve type\n" << e.what() << std::endl;
} catch (...) {
std::cerr << "[ ERROR ] Cannot retrieve type, not recognized exception\n";
}
} else {
// std::cerr << "[ WARNING ] Cannot retrieve type for output not existent in pt node: "
// << m_decoder->get_op_type() << " with 0 input: " << m_decoder->input(0) << std::endl;
}
// Let's see what type we have
// std::cout << "Can be represented as element::Type: " << type.is<element::Type>() << std::endl;
// std::cout << "element::Type value: " << type.as<element::Type>() << "\n";
// std::exit(0);

// TODO: Set custom `type` via special API
set_output_type(i, element::dynamic, ps);
set_output_type(i, type, ps);
}
}

Expand Down Expand Up @@ -91,55 +69,6 @@ class PtFrameworkNode : public ov::op::util::FrameworkNode {
return m_decoder.get();
}

bool visit_attributes(AttributeVisitor& visitor) override {
bool parent_visit_result = FrameworkNode::visit_attributes(visitor);
// TODO: correctly serialize bodies and descriptors. Only 1st body information can be serialized.
for (size_t i = 0; i < m_bodies.size(); ++i) {
visitor.on_attribute("body" + std::to_string(i), m_bodies[i]);
// visitor.on_attribute("input_descriptions" + std::to_string(i), m_input_descriptions[i]);
// visitor.on_attribute("output_descriptions", m_output_descriptions[i]);
}
return parent_visit_result;
}

void validate_and_infer_types() override {
for (int i = 0; i < m_bodies.size(); i++) {
// Input
for (const auto& input_description : m_input_descriptions[i]) {
auto index = input_description->m_input_index;
if (auto invariant_input_description =
ov::as_type_ptr<ov::op::util::MultiSubGraphOp::InvariantInputDescription>(input_description)) {
auto body_parameter =
m_bodies[i]->get_parameters().at(invariant_input_description->m_body_parameter_index);

auto body_param_partial_shape = body_parameter->get_partial_shape();
auto input_partial_shape = input(index).get_partial_shape();

body_parameter->set_partial_shape(input_partial_shape);
}
}

// Body
m_bodies[i]->validate_nodes_and_infer_types();

// Output
for (const auto& output_description : m_output_descriptions[i]) {
auto index = output_description->m_output_index;

const auto& body_value =
m_bodies[i]->get_results().at(output_description->m_body_value_index)->input_value(0).get_tensor();

if (auto body_output_description =
ov::as_type_ptr<ov::op::util::MultiSubGraphOp::BodyOutputDescription>(output_description)) {
const ov::PartialShape& ps = body_value.get_partial_shape();
auto et = body_value.get_element_type();
// TODO: Propagate custom type from body to the external in case if et is dynamic
set_output_type(index, et, ps);
}
}
}
}

private:
std::shared_ptr<Decoder> m_decoder;
};
Expand Down

0 comments on commit f8a6694

Please sign in to comment.