From c8486fcf05cbe45fcd799345b7c42dc10ecb92c6 Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Sun, 12 May 2024 17:39:26 -0700
Subject: [PATCH] handle rando device for exported model in model builder (#759)

* handle rando device for exported model in model builder

* typo
---
 build/builder.py | 20 +++++++++++++-------
 build/utils.py   |  8 ++++++++
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/build/builder.py b/build/builder.py
index e094b4e7a..1f9cfe35f 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -19,7 +19,7 @@ from quantize import quantize_model

 from build.model import Transformer
-from build.utils import device_sync, name_to_dtype
+from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype


 @dataclass
@@ -371,6 +371,12 @@ def _initialize_model(
         _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")

     if builder_args.dso_path:
+        if not is_cuda_or_cpu_device(builder_args.device):
+            print(
+                f"Cannot load specified DSO to {builder_args.device}. Attempting to load model to CPU instead"
+            )
+            builder_args.device = "cpu"
+
         # assert (
         #     quantize is None or quantize == "{ }"
         # ), "quantize not valid for exported DSO model. Specify quantization during export."
@@ -381,12 +387,6 @@ def _initialize_model(
         print(f"Time to load model: {time.time() - t0:.02f} seconds")

         try:
-            if "mps" in builder_args.device:
-                print(
-                    "Cannot load specified DSO to MPS. Attempting to load model to CPU instead"
-                )
-                builder_args.device = "cpu"
-
             # Replace model forward with the AOT-compiled forward
             # This is a hacky way to quickly demo AOTI's capability.
             # model is still a Python object, and any mutation to its
@@ -399,6 +399,12 @@ def _initialize_model(
         except:
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
     elif builder_args.pte_path:
+        if not is_cpu_device(builder_args.device):
+            print(
+                f"Cannot load specified PTE to {builder_args.device}. Attempting to load model to CPU instead"
+            )
+            builder_args.device = "cpu"
+
         # assert (
         #     quantize is None or quantize == "{ }"
         # ), "quantize not valid for exported PTE model. Specify quantization during export."
diff --git a/build/utils.py b/build/utils.py
index 940626328..417abbcb8 100644
--- a/build/utils.py
+++ b/build/utils.py
@@ -255,3 +255,11 @@ def get_device(device) -> str:
         else "mps" if is_mps_available() else "cpu"
     )
     return torch.device(device)
+
+
+def is_cuda_or_cpu_device(device) -> bool:
+    return device == "" or str(device) == "cpu" or ("cuda" in str(device))
+
+
+def is_cpu_device(device) -> bool:
+    return device == "" or str(device) == "cpu"