Skip to content

Commit

Permalink
make dtype selection a function of host and device (#768)
Browse files Browse the repository at this point in the history
* address #651 by making dtype selection a function of host and device

* fix typo

* typo

* typo
  • Loading branch information
mikekgfb authored May 13, 2024
1 parent 262d5de commit 49651ab
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-readme-pr-mps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
conda activate test-quantization-mps-macos
# NS: Remove previous installation of torch first
# as this script does not install anything into conda env
#but rather system dep
#but rather system dep
pip3 uninstall -y torch || true
set -eou pipefail
Expand Down
5 changes: 4 additions & 1 deletion build/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,14 @@ def from_args(cls, args): # -> BuilderArgs:

if args.output_pte_path and args.dtype.startswith("fast"):
if args.dtype == "fast":
# As per Kimish, float32 should be faster on ET XNNPACK
# (because fp16 is implemented as upcast to fp32 for several
# operators, and in particular a8w4dq and ET's sdpa+kv)
dtype = torch.float32
else:
dtype = torch.float16
else:
dtype = name_to_dtype(args.dtype)
dtype = name_to_dtype(args.dtype, args.device)

return cls(
checkpoint_dir=checkpoint_dir,
Expand Down
6 changes: 4 additions & 2 deletions build/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,15 @@ def get_precision():
### dtype name to torch.dtype mapping ###


def name_to_dtype(name):
def name_to_dtype(name, device):
if (name == "fast") or (name == "fast16"):
# MacOS now supports bfloat16
import platform

if platform.processor() == "arm":
if int(platform.mac_ver()[0].split(".")[0]) < 14:
device=get_device_str(device)
# ARM CPU is faster with float16, MPS with bf16 if supported
if device == "cpu" or int(platform.mac_ver()[0].split(".")[0]) < 14:
return torch.float16
return torch.bfloat16

Expand Down
2 changes: 1 addition & 1 deletion quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
self.tokenizer = tokenizer

if isinstance(dtype, str):
dtype = name_to_dtype(dtype)
dtype = name_to_dtype(dtype, device)
self.dtype = dtype

def create_quantized_state_dict(self) -> Dict: # "StateDict"
Expand Down

0 comments on commit 49651ab

Please sign in to comment.