Add torchao mps ops
manuelcandales committed Dec 10, 2024
1 parent e0ce144 commit 0f1825c
Showing 4 changed files with 70 additions and 22 deletions.
25 changes: 25 additions & 0 deletions docs/quantization.md
@@ -196,6 +196,31 @@ Note: only the ExecuTorch C++ runner in torchchat when built using the instructi
./cmake-out/et_run llama3_1.pte -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
```

## Experimental TorchAO MPS lowbit kernels

WARNING: These kernels only work on devices with Apple Silicon.

### Use

#### linear:fpaxw
The linear:fpaxw scheme quantizes only the weights (activations stay in floating point), groupwise, with a specified bitwidth and groupsize.
It takes two arguments: bitwidth (1, 2, 3, 4, 5, 6, or 7) and groupsize (32, 64, 128, or 256).
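
For reference, the value passed to the `--quantize` flag looks like this (the same shape used in the eager-mode example below), here selecting 4-bit weights with a groupsize of 256:
```
'{"linear:fpaxw": {"bitwidth": 4, "groupsize": 256}}'
```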

### Setup
To use linear:fpaxw, you must first build the torchao MPS experimental kernels. These only work on devices with Apple Silicon.

From the torchchat root directory, run
```
sh torchchat/utils/scripts/build_torchao_ops.sh mps
```
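
If the build succeeds, the ATen MPS ops library should be installed under `torchao-build/cmake-out/lib`; the filename below is taken from the loading logic in `torchchat/utils/quantize.py`, so treat this as a quick sanity check rather than a guaranteed layout:
```
ls torchao-build/cmake-out/lib/libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib
```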

### Examples

#### Eager mode
```
python3 torchchat.py generate stories110M --device mps --dtype float32 --quantize '{"linear:fpaxw": {"bitwidth": 4, "groupsize": 256}}' --prompt "Once upon a time," --num-samples 5
```
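
The bitwidth and groupsize can be varied independently within the supported values; for instance, this illustrative variant (not one of the original examples) uses 3-bit weights with a groupsize of 128:
```
python3 torchchat.py generate stories110M --device mps --dtype float32 --quantize '{"linear:fpaxw": {"bitwidth": 3, "groupsize": 128}}' --prompt "Once upon a time," --num-samples 5
```
Note that `--device mps` is required: `quantize_model` raises a RuntimeError when `linear:fpaxw` is requested on any other device.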

## Quantization Profiles

Four [sample profiles](https://github.com/pytorch/torchchat/tree/main/torchchat/quant_config/) are included with the torchchat distribution: `cuda.json`, `desktop.json`, `mobile.json`, and `pi5.json`.
38 changes: 21 additions & 17 deletions torchchat/utils/quantize.py
@@ -63,10 +63,10 @@
def get_named_parameters(func: Callable) -> List[str]:
# Get the signature of the function
signature = inspect.signature(func)

# Extract the parameters from the signature
parameters = signature.parameters

# Filter and return named parameters
named_params = [
name for name, param in parameters.items()
@@ -80,8 +80,8 @@ def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer:
print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
del q_kwargs[key]
return q_kwargs


#########################################################################
### torchchat quantization API ###

@@ -116,15 +116,18 @@ def quantize_model(
if not support_tensor_subclass:
unwrap_tensor_subclass(model)
continue

if quantizer in ["linear:a8wxdq", "embedding:wx"]:
# These quantizers require float32 input weights. Note that after quantization,
# the weights will no longer be float32, but lowbit integers
if get_precision() != torch.float32:
print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
set_precision(torch.float32)

- # We set global precision from quantize options if it is specified at cli.py:485
+ if quantizer == "linear:fpaxw" and device != "mps":
+     raise RuntimeError("linear:fpaxw quantization can only run on mps device!")
+
+ # We set global precision from quantize options if it is specified at cli.py:485
# so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
precision = get_precision()

@@ -915,10 +918,12 @@ def quantized_model(self) -> nn.Module:
from torchao_experimental_quant_api import (
Int8DynActIntxWeightLinearQuantizer,
IntxWeightEmbeddingQuantizer,
UIntxWeightOnlyLinearQuantizer,
)

quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
quantizer_class_dict["linear:fpaxw"] = UIntxWeightOnlyLinearQuantizer

# Try loading custom op
try:
@@ -928,15 +933,14 @@ def quantized_model(self) -> nn.Module:
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
torch.ops.load_library(libs[0])
except Exception as e:
print("Failed to torchao ops library with error: ", e)
print("Slow fallback kernels will be used.")
print("Unabled to load torchao cpu ops library. Slow fallback kernels will be used.")

try:
libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
libpath = f"{torchao_build_path}/cmake-out/lib/{libname}"
torch.ops.load_library(libpath)
except Exception as e:
print("Unabled to load torchao mps ops library.")

except Exception as e:
class ErrorHandler(QuantHandler):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
global torchao_experimental_load_error
raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")

torchao_experimental_load_error = e
quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
quantizer_class_dict["embedding:wx"] = ErrorHandler
print("Unabled to import torchao experimental quant_api with error: ", e)
7 changes: 6 additions & 1 deletion torchchat/utils/scripts/build_torchao_ops.sh
@@ -5,12 +5,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

device=${1:-cpu}

if [[ "$device" != "cpu" && "$device" != "mps" ]]; then
echo "Invalid argument: $device. Valid values are 'cpu' or 'mps'." >&2
exit 1
fi

source "$(dirname "${BASH_SOURCE[0]}")/install_utils.sh"

pushd ${TORCHCHAT_ROOT}
find_cmake_prefix_path
clone_torchao
- install_torchao_aten_ops
+ install_torchao_aten_ops "$device"
popd
22 changes: 18 additions & 4 deletions torchchat/utils/scripts/install_utils.sh
@@ -178,21 +178,35 @@ clone_torchao() {

git clone https://github.com/pytorch/ao.git
cd ao
- git checkout $(cat ${TORCHCHAT_ROOT}/install/.pins/torchao-pin.txt)
+ # The next two lines will be removed before landing this PR
+ # Instead, the torchao-pin.txt will be updated once ao PR #1322 lands
+ git fetch origin pull/1322/head:pr-1322
+ git checkout pr-1322
+ # git checkout $(cat ${TORCHCHAT_ROOT}/install/.pins/torchao-pin.txt)

popd
}

install_torchao_aten_ops() {
echo "Building torchao custom ops for ATen"
pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental
local device=${1:-cpu}

if [[ "$device" == "cpu" ]]; then
echo "Building torchao custom ops for ATen"
pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental
elif [[ "$device" == "mps" ]]; then
echo "Building torchao mps custom ops for ATen"
pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental/ops/mps
else
echo "Invalid argument: $device. Valid values are 'cpu' or 'mps'." >&2
return 1
fi

CMAKE_OUT_DIR=${TORCHCHAT_ROOT}/torchao-build/cmake-out
cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \
-DCMAKE_BUILD_TYPE="Release" \
-S . \
-     -B ${CMAKE_OUT_DIR} -G Ninja
+     -B ${CMAKE_OUT_DIR}
cmake --build ${CMAKE_OUT_DIR} --target install --config Release

popd
