Commit 170581a
feat(fast cli): Import torch lazily in all places used by the CLI that don't need a model (#1349)

These changes add a little complexity with the lazy and local imports, but
they also greatly improve the CLI's response time for `--help`, `list`, and `where`.

Changes:

* Move `import torch` into the function bodies that need it (a minimal
  sketch of the pattern follows this list)
* Use `importlib.metadata.version` to check the torch version rather than
  `torch.__version__`
* Switch from using `torch.inference_mode` as a decorator to using it as a
  context manager.
  * I also removed it from `convert_hf_checkpoint_to_tune` since that does
    not use torch at all
* In `build_utils`, wrap the dtype values in lambdas so they're lazily
  fetched.
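
The core pattern behind the first bullet is the difference between a
module-level and a function-local import: a module-level `import torch` is
paid on every invocation, including `--help`, while a local import is paid
only by commands that actually need a model, and repeat imports after the
first are cheap `sys.modules` lookups. A minimal sketch of the pattern,
with hypothetical command names that are not part of torchchat:

    def cmd_where(args) -> None:
        # Fast path: never touches torch, so `--help` and `where` stay fast.
        print("model directory: ...")


    def cmd_generate(args) -> None:
        # Local import: the expensive torch import happens only when a
        # command that needs a model runs; later calls hit the
        # sys.modules cache.
        import torch

        torch.manual_seed(0)

The cost being avoided can be measured with
`python -X importtime -c "import torch"`, which prints a per-module
breakdown of import time.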

#1347
Branch: FasterCli-1347

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart authored Nov 7, 2024
1 parent ac02ffb commit 170581a
Showing 3 changed files with 57 additions and 38 deletions.
20 changes: 13 additions & 7 deletions torchchat/cli/cli.py
@@ -5,20 +5,16 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
+import importlib.metadata
 import json
 import logging
 import os
 import sys
 from pathlib import Path

-import torch
-
-from torchchat.cli.download import download_and_convert, is_model_downloaded
-
 from torchchat.utils.build_utils import (
     allowable_dtype_names,
     allowable_params_table,
-    get_device_str,
 )

 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -42,6 +38,9 @@

 # Handle CLI arguments that are common to a majority of subcommands.
 def check_args(args, verb: str) -> None:
+    # Local import to avoid unnecessary expensive imports
+    from torchchat.cli.download import download_and_convert, is_model_downloaded
+
     # Handle model download. Skip this for download, since it has slightly
     # different semantics.
     if (
@@ -498,9 +497,10 @@ def _add_speculative_execution_args(parser) -> None:


 def arg_init(args):
-    if not (torch.__version__ > "2.3"):
+    torch_version = importlib.metadata.version("torch")
+    if not torch_version or (torch_version <= "2.3"):
         raise RuntimeError(
-            f"You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
+            f"You are using PyTorch {torch_version}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
         )

     if sys.version_info.major != 3 or sys.version_info.minor < 10:
@@ -521,6 +521,9 @@ def arg_init(args):
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
     else:
+        # Localized import to minimize expensive imports
+        from torchchat.utils.build_utils import get_device_str
+
         args.device = get_device_str(
             args.quantize.get("executor", {}).get("accelerator", args.device)
         )
@@ -534,5 +537,8 @@ def arg_init(args):
         vars(args)["compile_prefill"] = False

     if hasattr(args, "seed") and args.seed:
+        # Localized import to minimize expensive imports
+        import torch
+
         torch.manual_seed(args.seed)
     return args
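
A note on the new version check: `importlib.metadata.version("torch")`
reads the installed distribution's metadata without importing torch, which
is what keeps `arg_init` cheap. The comparison itself is still a plain
string comparison, which is lexicographic (for example, "2.10" <= "2.3" is
true). A more robust variant, sketched here under the assumption that the
third-party `packaging` library is available, parses the versions first:

    import importlib.metadata

    from packaging.version import Version  # assumes `packaging` is installed

    torch_version = importlib.metadata.version("torch")  # no torch import
    if Version(torch_version) <= Version("2.3"):
        raise RuntimeError(f"PyTorch {torch_version} is too old; need > 2.3")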
25 changes: 12 additions & 13 deletions torchchat/cli/convert_hf_checkpoint.py
@@ -11,25 +11,23 @@
 from pathlib import Path
 from typing import Optional

-import torch
-
-from torchchat.model import TransformerArgs
-
 # support running without installing as a package
 wd = Path(__file__).parent.parent
 sys.path.append(str(wd.resolve()))
 sys.path.append(str((wd / "build").resolve()))

-from torchchat.model import ModelArgs
-
 
-@torch.inference_mode()
 def convert_hf_checkpoint(
     *,
     model_dir: Optional[Path] = None,
     model_name: Optional[str] = None,
     remove_bin_files: bool = False,
 ) -> None:
+
+    # Local imports to avoid expensive imports
+    from torchchat.model import ModelArgs, TransformerArgs
+    import torch
+
     if model_dir is None:
         model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
     if model_name is None:
@@ -58,10 +56,11 @@ def convert_hf_checkpoint(
     tokenizer_pth = model_dir / "original" / "tokenizer.model"
     if consolidated_pth.is_file() and tokenizer_pth.is_file():
         # Confirm we can load it
-        loaded_result = torch.load(
-            str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
-        )
-        del loaded_result  # No longer needed
+        with torch.inference_mode():
+            loaded_result = torch.load(
+                str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
+            )
+            del loaded_result  # No longer needed
         print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
         os.rename(consolidated_pth, model_dir / "model.pth")
         os.rename(tokenizer_pth, model_dir / "tokenizer.model")
@@ -130,7 +129,8 @@ def load_safetensors():
     state_dict = None
     for loader in loaders:
         try:
-            state_dict = loader()
+            with torch.inference_mode():
+                state_dict = loader()
             break
         except Exception:
             continue
@@ -173,7 +173,6 @@ def load_safetensors():
         os.remove(file)


-@torch.inference_mode()
 def convert_hf_checkpoint_to_tune(
     *,
     model_dir: Optional[Path] = None,
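
Why the decorator form had to go: a decorator expression like
`@torch.inference_mode()` is evaluated when the `def` statement executes,
i.e. while the module is being imported, so it forces the torch import up
front even if the function is never called. Used as a context manager
inside the body, the same guard is deferred until the function runs. A
minimal sketch of the two forms, with a hypothetical `load_checkpoint`
helper:

    # Decorator form: the decorator expression runs at module-import time,
    # so torch must already be imported when this file loads.
    #
    #   @torch.inference_mode()
    #   def load_checkpoint(path):
    #       return torch.load(path, map_location="cpu", weights_only=True)

    def load_checkpoint(path):
        # Context-manager form: torch is imported, and inference mode
        # entered, only when the function is actually called.
        import torch

        with torch.inference_mode():
            return torch.load(path, map_location="cpu", weights_only=True)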
50 changes: 32 additions & 18 deletions torchchat/utils/build_utils.py
@@ -13,18 +13,31 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple

-import torch
-
 ##########################################################################
 ### unpack packed weights ###


+class _LazyImportTorch:
+    """This is a wrapper around the import of torch that only performs the
+    import when an actual attribute is needed off of torch.
+    """
+
+    @staticmethod
+    def __getattribute__(name: str) -> Any:
+        import torch
+
+        return getattr(torch, name)
+
+
+# Alias torch to the lazy import
+torch = _LazyImportTorch()
+
+
 def unpack_packed_weights(
     packed_weights: Dict[str, Any],
     packed_linear: Callable,
-    input_dtype: torch.dtype,
+    input_dtype: "torch.dtype",
     unpacked_dims: Tuple,
-) -> torch.Tensor:
+) -> "torch.Tensor":
     """Given a packed weight matrix `packed_weights`, a Callable
     implementing a packed linear function for the packed format, and the
     unpacked dimensions of the weights, recreate the unpacked weight
@@ -169,26 +182,27 @@ def name_to_dtype(name, device):
         return torch.bfloat16

     try:
-        return name_to_dtype_dict[name]
+        return _name_to_dtype_dict[name]()
     except KeyError:
         raise RuntimeError(f"unsupported dtype name {name} specified")


 def allowable_dtype_names() -> List[str]:
-    return name_to_dtype_dict.keys()
-
-
-name_to_dtype_dict = {
-    "fp32": torch.float,
-    "fp16": torch.float16,
-    "bf16": torch.bfloat16,
-    "float": torch.float,
-    "half": torch.float16,
-    "float32": torch.float,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "fast": None,
-    "fast16": None,
+    return _name_to_dtype_dict.keys()
+
+
+# NOTE: values are wrapped in lambdas to avoid proactive imports for torch
+_name_to_dtype_dict = {
+    "fp32": lambda: torch.float,
+    "fp16": lambda: torch.float16,
+    "bf16": lambda: torch.bfloat16,
+    "float": lambda: torch.float,
+    "half": lambda: torch.float16,
+    "float32": lambda: torch.float,
+    "float16": lambda: torch.float16,
+    "bfloat16": lambda: torch.bfloat16,
+    "fast": lambda: None,
+    "fast16": lambda: None,
 }
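
Three details make the lazy alias hold together. First, Python looks
`__getattribute__` up on the type and binds it through the descriptor
protocol, so declaring it a staticmethod means the function receives only
the attribute name; every `torch.<attr>` access on the alias then performs
the real import, which `sys.modules` caches after the first time. Second,
the annotations in `unpack_packed_weights` are quoted ("torch.dtype",
"torch.Tensor") so they stay inert strings rather than triggering an
attribute access, and hence the import, while the `def` executes. Third,
the dtype table's values are zero-argument lambdas, so lookups must be
called, as `name_to_dtype` now does with `_name_to_dtype_dict[name]()`. A
small usage sketch:

    dtype_thunk = _name_to_dtype_dict["bf16"]  # still no torch import

    dtype = dtype_thunk()  # first torch attribute access: torch imports here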

