Skip to content

Commit

Permalink
Refactor UnflattenedModule's adapt flat args (pytorch#140840)
Browse files Browse the repository at this point in the history
Test Plan: unblocks model launch

Differential Revision: D66014709

Pull Request resolved: pytorch#140840
Approved by: https://github.com/pianpwk
  • Loading branch information
angelayi authored and Ryo-not-rio committed Dec 2, 2024
1 parent 9c2a357 commit ac26005
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions torch/export/unflatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,31 @@ def _print_graph(self):
if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph):
print(mod.graph)

def _adapt_flat_args(self, flat_args, in_spec):
signature = self.module_call_graph[0].signature
if in_spec == signature.in_spec:
return flat_args

if self.flat_args_adapter is None:
raise TypeError(
"There is no flat args adapter sepcified. "
"Are you sure you are calling this with the right arguments? "
)
else:
flat_args = self.flat_args_adapter.adapt(
target_spec=signature.in_spec,
input_spec=in_spec,
input_args=flat_args,
)

if len(flat_args) != signature.in_spec.num_leaves:
raise TypeError(
f"Flat args adaption failed, number of args mismatch "
f"Adatped: {len(flat_args)} \n"
f"Exported module: {signature.in_spec.num_leaves}"
)
return flat_args

def forward(self, *args, **kwargs):
signature = self.module_call_graph[0].signature

Expand All @@ -544,26 +569,9 @@ def forward(self, *args, **kwargs):
f"Input treespec: {in_spec}. ",
f"Exported module treespec: {signature.in_spec}",
)
if self.flat_args_adapter is None:
raise TypeError(
"There is no flat args adapter sepcified. "
"Are you sure you are calling this with the right arguments? "
)
else:
if not self.adapted:
print("Adapting flat arg to match exported module's treespec")
flat_args = self.flat_args_adapter.adapt(
target_spec=signature.in_spec,
input_spec=in_spec,
input_args=flat_args,
)
self.adapted = True
if len(flat_args) != signature.in_spec.num_leaves:
raise TypeError(
f"Flat args adaption failed, number of args mismatch "
f"Adatped: {len(flat_args)} \n"
f"Exported module: {signature.in_spec.num_leaves}"
)
print("Adapting flat arg to match exported module's treespec")
flat_args = self._adapt_flat_args(flat_args, in_spec)
self.adapted = True

if self.check_input_constraints:
# Import here to avoid an unfortunate circular dependency.
Expand Down

0 comments on commit ac26005

Please sign in to comment.