Skip to content

Commit

Permalink
created function for adding get_attr nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
cavusmustafa committed Apr 18, 2023
1 parent d9a36a3 commit 067be96
Showing 1 changed file with 11 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.nn import Module
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition

from torch.fx.experimental.proxy_tensor import DecompositionInterpreter
from torch._decomp import decomposition_table
Expand All @@ -29,17 +29,8 @@ def fx_serialize(self, graph_module: GraphModule, *args, **kwargs):
#DecompositionInterpreter(fx_gm, prim_graph, decomposition_table=aten2aten_decomp).run(*args, **kwargs)
#prim_module = torch.fx.GraphModule(fx_gm, prim_graph)
return fx_gm #prim_module


def make_partitions(self, graph_module: GraphModule) -> GraphModule:
# entry function for nvFuser backend
# logger.debug("Compiling graph_module: ", graph_module.code)
print("Compiling graph_module: ", graph_module.code)
# FX graph based partitioning based on nvfuser supported ops
partitioner = CapabilityBasedPartitioner(
graph_module, self.supported_ops, allows_single_node_partition=True)
partitions = partitioner.propose_partitions()

def add_get_attr_inputs(self, partitions: t.List[Partition]):
#TODO: Find a more efficient way to include input
#"get_attr" nodes to the partitions.
getattr_to_merge : Dict[Node, Node] = {}
Expand All @@ -52,6 +43,15 @@ def make_partitions(self, graph_module: GraphModule) -> GraphModule:
for getattr_node, getattr_part in getattr_to_merge.items():
getattr_part.add_node(getattr_node)

def make_partitions(self, graph_module: GraphModule) -> GraphModule:
# entry function for nvFuser backend
# logger.debug("Compiling graph_module: ", graph_module.code)
print("Compiling graph_module: ", graph_module.code)
# FX graph based partitioning based on nvfuser supported ops
partitioner = CapabilityBasedPartitioner(
graph_module, self.supported_ops, allows_single_node_partition=True)
partitions = partitioner.propose_partitions()
self.add_get_attr_inputs(partitions)
fused_graph_module = partitioner.fuse_partitions(partitions)

return fused_graph_module
Expand Down

0 comments on commit 067be96

Please sign in to comment.