Skip to content

Commit

Permalink
[PT FE] Fix issue with kwargs in signature (#19088)
Browse files Browse the repository at this point in the history
* Fix issue with kwargs in signature

* Update src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py

* Fix problem with some ops in detectron2

* Use debug name for extra input signature
  • Loading branch information
mvafin authored Aug 9, 2023
1 parent 4ee47fc commit dafe437
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 19 deletions.
11 changes: 9 additions & 2 deletions src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,15 @@ def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=N
self.pt_module = pt_module
self.raw_inputs = list(self.graph_element.inputs())
self.raw_outputs = list(self.graph_element.outputs())
if self._input_signature is not None and "self" in self.raw_inputs[0].debugName():
self._input_signature.insert(0, "self")
if self._input_signature is not None:
if "self" in self.raw_inputs[0].debugName():
self._input_signature.insert(0, "self")
if 0 < len(self._input_signature) < len(self.raw_inputs):
# last input is args input, we need to multiply that name by number of extra inputs
self._input_signature = self._input_signature[:-1]
n = len(self._input_signature)
for i in range(len(self.raw_inputs) - n):
self._input_signature.append(self.raw_inputs[i + n].debugName())

if isinstance(self.graph_element, torch.Graph):
self._transform_tensor_list_constants_to_listconstruct(self.graph_element)
Expand Down
17 changes: 2 additions & 15 deletions src/frontends/pytorch/src/utils_quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ Output<Node> quantize(const NodeContext& context,
FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
}

std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node) {
auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node.get_node_shared_ptr());
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(std::shared_ptr<Node> node) {
auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node);
if (!quant_node) {
return nullptr;
}
Expand All @@ -168,19 +168,6 @@ std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node) {
return quant_node;
}

std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node, const std::string& type) {
auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node.get_node_shared_ptr());
if (!quant_node) {
return nullptr;
}
const auto& attrs = quant_node->get_attrs();
if (attrs.find(QuantizedPtNode::quantized_node_type_key) == attrs.end() ||
attrs.at(QuantizedPtNode::quantized_node_type_key) != type) {
return nullptr;
}
return quant_node;
}

} // namespace pytorch
} // namespace frontend
} // namespace ov
3 changes: 1 addition & 2 deletions src/frontends/pytorch/src/utils_quantize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ Output<Node> quantize(const NodeContext& context,
const Output<Node>& zero_point,
const Output<Node>& quantized_node);

std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node);
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node, const std::string& type);
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(std::shared_ptr<Node> node);

namespace op {
/**
Expand Down

0 comments on commit dafe437

Please sign in to comment.