From ac26005cebc6ce198903bb5ca2e24fb635e60f84 Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Sat, 16 Nov 2024 05:09:35 +0000 Subject: [PATCH] Refactor UnflattenedModule's adapt flat args (#140840) Test Plan: unblocks model launch Differential Revision: D66014709 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140840 Approved by: https://github.com/pianpwk --- torch/export/unflatten.py | 48 +++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 3a281e0523cbd5..57c18119c35cf7 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -519,6 +519,31 @@ def _print_graph(self): if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph): print(mod.graph) + def _adapt_flat_args(self, flat_args, in_spec): + signature = self.module_call_graph[0].signature + if in_spec == signature.in_spec: + return flat_args + + if self.flat_args_adapter is None: + raise TypeError( + "There is no flat args adapter sepcified. " + "Are you sure you are calling this with the right arguments? " + ) + else: + flat_args = self.flat_args_adapter.adapt( + target_spec=signature.in_spec, + input_spec=in_spec, + input_args=flat_args, + ) + + if len(flat_args) != signature.in_spec.num_leaves: + raise TypeError( + f"Flat args adaption failed, number of args mismatch " + f"Adatped: {len(flat_args)} \n" + f"Exported module: {signature.in_spec.num_leaves}" + ) + return flat_args + def forward(self, *args, **kwargs): signature = self.module_call_graph[0].signature @@ -544,26 +569,9 @@ def forward(self, *args, **kwargs): f"Input treespec: {in_spec}. ", f"Exported module treespec: {signature.in_spec}", ) - if self.flat_args_adapter is None: - raise TypeError( - "There is no flat args adapter sepcified. " - "Are you sure you are calling this with the right arguments? " - ) - else: - if not self.adapted: - print("Adapting flat arg to match exported module's treespec") - flat_args = self.flat_args_adapter.adapt( - target_spec=signature.in_spec, - input_spec=in_spec, - input_args=flat_args, - ) - self.adapted = True - if len(flat_args) != signature.in_spec.num_leaves: - raise TypeError( - f"Flat args adaption failed, number of args mismatch " - f"Adatped: {len(flat_args)} \n" - f"Exported module: {signature.in_spec.num_leaves}" - ) + print("Adapting flat arg to match exported module's treespec") + flat_args = self._adapt_flat_args(flat_args, in_spec) + self.adapted = True if self.check_input_constraints: # Import here to avoid an unfortunate circular dependency.