Skip to content

Commit

Permalink
cmake torchao_ops_mps_linear_fp_act_xbit_weight
Browse files Browse the repository at this point in the history
Differential Revision: D66120124

Pull Request resolved: #1304
  • Loading branch information
manuelcandales authored Nov 21, 2024
1 parent 7446433 commit 7489c7d
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 46 deletions.
21 changes: 15 additions & 6 deletions torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import os
import sys
import yaml

torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT")
assert torchao_root is not None, "TORCHAO_ROOT is not set"
if len(sys.argv) != 2:
print("Usage: gen_metal_shader_lib.py <output_file>")
sys.exit(1)

# Output file where the generated code will be written
OUTPUT_FILE = sys.argv[1]

MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps")
MPS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

# Path to yaml file containing the list of .metal files to include
METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")
Expand All @@ -21,9 +32,6 @@
# Path to the folder containing the .metal files
METAL_DIR = os.path.join(MPS_DIR, "metal")

# Output file where the generated code will be written
OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h")

prefix = """/**
* This file is generated by gen_metal_shader_lib.py
*/
Expand All @@ -48,6 +56,7 @@
"""

os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
with open(OUTPUT_FILE, "w") as outf:
outf.write(prefix)
for file in metal_files:
Expand Down
1 change: 1 addition & 0 deletions torchao/experimental/ops/mps/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cmake-out/
60 changes: 60 additions & 0 deletions torchao/experimental/ops/mps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)

project(torchao_ops_mps_linear_fp_act_xbit_weight)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED YES)

if (NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()

if (CMAKE_SYSTEM_NAME STREQUAL "Darwin")
if (NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(FATAL_ERROR "Unified Memory requires Apple Silicon architecture")
endif()
else()
message(FATAL_ERROR "Torchao experimental mps ops can only be built on macOS/iOS")
endif()

find_package(Torch REQUIRED)

# Generate metal_shader_lib.h by running gen_metal_shader_lib.py
set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
add_custom_command(
OUTPUT ${GENERATED_METAL_SHADER_LIB}
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB}
COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py"
)
add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB})

if(NOT TORCHAO_INCLUDE_DIRS)
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
endif()
message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")

include_directories(${TORCHAO_INCLUDE_DIRS})
include_directories(${CMAKE_INSTALL_PREFIX}/include)
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm)
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib)

target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE USE_ATEN=1)

# Enable Metal support
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})

install(
TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten
EXPORT _targets
DESTINATION lib
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// LICENSE file in the root directory of this source tree.

// clang-format off
#include <torch/extension.h>
#include <torch/library.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torchao/experimental/kernels/mps/src/lowbit.h>
// clang-format on
Expand Down Expand Up @@ -147,9 +147,6 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
return B;
}

// Registers _C as a Python extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

TORCH_LIBRARY(torchao, m) {
m.def("_pack_weight_1bit(Tensor W) -> Tensor");
m.def("_pack_weight_2bit(Tensor W) -> Tensor");
Expand Down
19 changes: 19 additions & 0 deletions torchao/experimental/ops/mps/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash -eu
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cd "$(dirname "$BASH_SOURCE")"

export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
export CMAKE_OUT=${PWD}/cmake-out
echo "CMAKE_OUT: ${CMAKE_OUT}"

cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
-S . \
-B ${CMAKE_OUT}
cmake --build ${CMAKE_OUT} -j 16 --target install --config Release
23 changes: 0 additions & 23 deletions torchao/experimental/ops/mps/setup.py

This file was deleted.

37 changes: 25 additions & 12 deletions torchao/experimental/ops/mps/test/test_lowbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,38 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import torch
import torchao_mps_ops
import unittest

from parameterized import parameterized

def parameterized(test_cases):
def decorator(func):
def wrapper(self):
for case in test_cases:
with self.subTest(case=case):
func(self, *case)
libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)

return wrapper

return decorator
try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError:
try:
torch.ops.load_library(libpath)
except:
raise RuntimeError(f"Failed to load library {libpath}")
else:
try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError as e:
raise e


class TestLowBitQuantWeightsLinear(unittest.TestCase):
cases = [
CASES = [
(nbit, *param)
for nbit in range(1, 8)
for param in [
Expand Down Expand Up @@ -73,7 +86,7 @@ def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit):
W = scales * W + zeros
return torch.mm(A, W.t())

@parameterized(cases)
@parameterized.expand(CASES)
def test_linear(self, nbit, M=1, K=32, N=32, group_size=32):
print(f"nbit: {nbit}, M: {M}, K: {K}, N: {N}, group_size: {group_size}")
A, W, S, Z = self._init_tensors(group_size, M, K, N, nbit=nbit)
Expand Down
23 changes: 22 additions & 1 deletion torchao/experimental/ops/mps/test/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,34 @@
import sys

import torch
import torchao_mps_ops
import unittest

from parameterized import parameterized
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
from torchao.experimental.quant_api import _quantize

libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)

try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError:
try:
torch.ops.load_library(libpath)
except:
raise RuntimeError(f"Failed to load library {libpath}")
else:
try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError as e:
raise e


class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
BITWIDTHS = range(1, 8)
Expand Down

0 comments on commit 7489c7d

Please sign in to comment.