From 7b86dc30de050ac5f2a992c09c15ce8d7556449c Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Thu, 12 Dec 2024 00:25:59 -0800 Subject: [PATCH] Bump PT pin to 20241028 (#1419) * Bump PT pin to 20241014 * Push bump to 1028 and add migration to export_for_training --- install/install_requirements.sh | 4 ++-- torchchat/export.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 3e1f9a655..eab92a4f1 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -62,10 +62,10 @@ echo "Using pip executable: $PIP_EXECUTABLE" # NOTE: If a newly-fetched version of the executorch repo changes the value of # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -PYTORCH_NIGHTLY_VERSION=dev20241013 +PYTORCH_NIGHTLY_VERSION=dev20241028 # Nightly version for torchvision -VISION_NIGHTLY_VERSION=dev20241013 +VISION_NIGHTLY_VERSION=dev20241028 # Nightly version for torchtune TUNE_NIGHTLY_VERSION=dev20241013 diff --git a/torchchat/export.py b/torchchat/export.py index 7c5243b68..979778b7c 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -122,7 +122,7 @@ def export_for_server( from executorch.exir.tracer import Value from torch._export import capture_pre_autograd_graph - from torch.export import export, ExportedProgram + from torch.export import export, export_for_training, ExportedProgram from torchchat.model import apply_rotary_emb, Attention from torchchat.utils.build_utils import get_precision @@ -238,7 +238,7 @@ def _to_core_aten( raise ValueError( f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}" ) - core_aten_ep = export(model, example_inputs, dynamic_shapes=dynamic_shapes) + core_aten_ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes) if verbose: logging.info(f"Core ATen graph:\n{core_aten_ep.graph}") return core_aten_ep