Skip to content

Commit

Permalink
Merge pull request openvinotoolkit#12 from cavusmustafa/fx_backend_ge…
Browse files Browse the repository at this point in the history
…tattr_support

fx_backend: include get_attr ops to the partitions
  • Loading branch information
cavusmustafa authored Apr 18, 2023
2 parents fc457ae + 067be96 commit 308f68c
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 6 deletions.
27 changes: 26 additions & 1 deletion src/bindings/python/src/openvino/frontend/pytorch/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,25 @@
import torch
import numpy as np
import inspect
import ctypes

def fetch_attr(self_module, target : str):
"""
Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
Args:
target (str): The fully-qualified name of the attribute to fetch
Return:
Any: The value of the attribute.
"""
target_atoms = target.split('.')
attr_itr = self_module
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr

def make_constant(*args, **kwargs):
return op.Constant(*args, **kwargs)
Expand Down Expand Up @@ -95,6 +114,10 @@ def get_value_from_getattr(getattr_node, self_module):
"torch.BoolTensor": OVType.boolean,
}

ov_to_c_type_map = {
OVType.f32: ctypes.c_float,
OVType.i32: ctypes.c_int,
}

class TorchScriptPythonDecoder (Decoder):
def __init__(self, pt_module, graph_element=None, example_input=None, freeze=True):
Expand Down Expand Up @@ -662,7 +685,9 @@ def as_constant(self):
ovshape = PartialShape(ret.size())
ovtype = pt_to_ov_type_map[ret.type()]
print(ovshape, ovtype)
ov_const = make_constant(ovtype, ovshape.get_shape(), ret.data_ptr())
c_type = ctypes.POINTER(ov_to_c_type_map[ovtype])
data_c_ptr = ctypes.cast(ret.data_ptr(), c_type)
ov_const = op.Constant(ovtype, ovshape.get_shape(), data_c_ptr[:ret.nelement()])
print('Made constant')
return ov_const.outputs()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __init__(self):
"torch.ops.aten.addmm.default": None,
"_operator.getitem": None,
"torch.ops.aten.t.default": None,
"torch.ops.aten.empty.memory_format": None
#"torch.ops.aten.empty.memory_format": None
}

super().__init__(support_dict)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import torch
from torch.nn import Module
from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition

from torch.fx.experimental.proxy_tensor import DecompositionInterpreter
from torch._decomp import decomposition_table
Expand All @@ -29,7 +29,19 @@ def fx_serialize(self, graph_module: GraphModule, *args, **kwargs):
#DecompositionInterpreter(fx_gm, prim_graph, decomposition_table=aten2aten_decomp).run(*args, **kwargs)
#prim_module = torch.fx.GraphModule(fx_gm, prim_graph)
return fx_gm #prim_module


def add_get_attr_inputs(self, partitions: t.List[Partition]):
#TODO: Find a more efficient way to include input
#"get_attr" nodes to the partitions.
getattr_to_merge : Dict[Node, Node] = {}
for partition in partitions:
for pnode in partition.nodes:
for pnode_input in pnode.all_input_nodes:
if pnode_input.op in ['get_attr']:
if pnode_input.op not in getattr_to_merge:
getattr_to_merge[pnode_input] = partition
for getattr_node, getattr_part in getattr_to_merge.items():
getattr_part.add_node(getattr_node)

def make_partitions(self, graph_module: GraphModule) -> GraphModule:
# entry function for nvFuser backend
Expand All @@ -38,7 +50,9 @@ def make_partitions(self, graph_module: GraphModule) -> GraphModule:
# FX graph based partitioning based on nvfuser supported ops
partitioner = CapabilityBasedPartitioner(
graph_module, self.supported_ops, allows_single_node_partition=True)
fused_graph_module = partitioner.partition_and_fuse()
partitions = partitioner.propose_partitions()
self.add_get_attr_inputs(partitions)
fused_graph_module = partitioner.fuse_partitions(partitions)

return fused_graph_module

Expand Down
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"prim::Constant", op::translate_constant},
{"prim::device", op::translate_constant},
{"prim::GetAttr", op::translate_get_attr},
{"get_attr", op::translate_constant},
{"prim::If", op::translate_if},
{"prim::is_cuda", op::return_false_scalar},
{"prim::ListConstruct", op::translate_list_construct},
Expand Down

0 comments on commit 308f68c

Please sign in to comment.