Skip to content

Commit

Permalink
TorchFX: Optional partition size limit
Browse files Browse the repository at this point in the history
  • Loading branch information
cavusmustafa committed Sep 15, 2023
1 parent b97ef1c commit 5015d7a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def _call(*args):
with torch.no_grad():
model.eval()
partitioner = Partitioner()
compiled_model = partitioner.make_partitions(model)
compiled_model, num_partitions = partitioner.make_partitions(model)
if num_partitions == 0:
return compile_fx(subgraph, example_inputs)

if executor_parameters is not None and 'model_hash_str' in executor_parameters:
# Check if the model is fully supported.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import typing as t
import logging
import os

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
Expand Down Expand Up @@ -59,7 +60,15 @@ def make_partitions(self, graph_module: GraphModule) -> GraphModule:
partitioner = CapabilityBasedPartitioner(
graph_module, self.supported_ops, allows_single_node_partition=False)
partitions = partitioner.propose_partitions()
self.add_get_attr_inputs(partitions)
fused_graph_module = partitioner.fuse_partitions(partitions)

return fused_graph_module
new_partitions = []
min_num_nodes = 0
if os.getenv("OPENVINO_TORCH_MIN_NUM_NODES") is not None:
min_num_nodes = int(os.getenv("OPENVINO_TORCH_MIN_NUM_NODES"))
for part in partitions:
if len(part.nodes) > min_num_nodes:
new_partitions.append(part)
self.add_get_attr_inputs(new_partitions)
fused_graph_module = partitioner.fuse_partitions(new_partitions)

return fused_graph_module, len(new_partitions)

0 comments on commit 5015d7a

Please sign in to comment.