diff --git a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py index eea7e4266..7764c0871 100644 --- a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py +++ b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py @@ -37,9 +37,9 @@ */ #ifdef USE_ATEN -using namespace at::native::mps; +using at::native::mps::MetalShaderLibrary; #else -#include +#include #endif static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT( diff --git a/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h b/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h new file mode 100644 index 000000000..325f758a9 --- /dev/null +++ b/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h @@ -0,0 +1,78 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#ifdef USE_EXECUTORCH +#include +using executorch::backends::mps::delegate::MPSDevice; +static id MTL_DEVICE = MPSDevice::getInstance()->device(); +#else +#include +#endif + +static id compileLibraryFromSource( + id device, + const std::string& source) { + NSError* error = nil; + MTLCompileOptions* options = [MTLCompileOptions new]; + [options setLanguageVersion:MTLLanguageVersion3_1]; + NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()]; + id library = [device newLibraryWithSource:kernel_source + options:options + error:&error]; +#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling + if (library == nil) { + throw_exception( + "Failed to compile: " + std::string(error.description.UTF8String)); + } +#endif + return library; +} + +class MetalShaderLibrary { + public: + MetalShaderLibrary(const std::string& src) : shaderSource(src) { + lib = compileLibraryFromSource(device, shaderSource); + } + MetalShaderLibrary(const MetalShaderLibrary&) = delete; + MetalShaderLibrary(MetalShaderLibrary&&) = delete; + + id getPipelineStateForFunc( + const std::string& fname) { + return get_compute_pipeline_state(load_func(fname)); + } + + private: + std::string shaderSource; + id device = MTL_DEVICE; + id lib = nil; + + id load_func(const std::string& func_name) const { + id func = [lib + newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; +#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling + if (func == nil) { + throw_exception("Can't get function:" + func_name); + } +#endif + return func; + } + + id get_compute_pipeline_state( + id func) const { + NSError* error = nil; + auto cpl = [device newComputePipelineStateWithFunction:func error:&error]; +#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling + if (cpl == nil) { + throw_exception( + "Failed to construct pipeline state: " + + std::string(error.description.UTF8String)); + } +#endif + return cpl; + } +}; diff 
--git a/torchao/experimental/kernels/mps/src/OperationUtils.h b/torchao/experimental/kernels/mps/src/OperationUtils.h index 7cb902f23..7064c313d 100644 --- a/torchao/experimental/kernels/mps/src/OperationUtils.h +++ b/torchao/experimental/kernels/mps/src/OperationUtils.h @@ -40,63 +40,6 @@ inline id getMetalDevice() { static id MTL_DEVICE = getMetalDevice(); -static id compileLibraryFromSource( - id device, - const std::string& source) { - NSError* error = nil; - MTLCompileOptions* options = [MTLCompileOptions new]; - [options setLanguageVersion:MTLLanguageVersion3_1]; - NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()]; - id library = [device newLibraryWithSource:kernel_source - options:options - error:&error]; - if (library == nil) { - throw_exception( - "Failed to compile: " + std::string(error.description.UTF8String)); - } - return library; -} - -class MetalShaderLibrary { - public: - MetalShaderLibrary(const std::string& src) : shaderSource(src) { - lib = compileLibraryFromSource(device, shaderSource); - } - MetalShaderLibrary(const MetalShaderLibrary&) = delete; - MetalShaderLibrary(MetalShaderLibrary&&) = delete; - - id getPipelineStateForFunc( - const std::string& fname) { - return get_compute_pipeline_state(load_func(fname)); - } - - private: - std::string shaderSource; - id device = MTL_DEVICE; - id lib = nil; - - id load_func(const std::string& func_name) const { - id func = [lib - newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; - if (func == nil) { - throw_exception("Can't get function:" + func_name); - } - return func; - } - - id get_compute_pipeline_state( - id func) const { - NSError* error = nil; - auto cpl = [device newComputePipelineStateWithFunction:func error:&error]; - if (cpl == nil) { - throw_exception( - "Failed to construct pipeline state: " + - std::string(error.description.UTF8String)); - } - return cpl; - } -}; - class MPSStream { public: MPSStream() { @@ -136,14 +79,6 @@ class MPSStream { 
id _commandEncoder = nil; }; -inline void finalize_block(MPSStream* mpsStream) { - id encoder = mpsStream->commandEncoder(); - id cmdBuffer = mpsStream->commandBuffer(); - [encoder endEncoding]; - [cmdBuffer commit]; - [cmdBuffer waitUntilCompleted]; -} - inline MPSStream* getCurrentMPSStream() { return new MPSStream(); } diff --git a/torchao/experimental/kernels/mps/src/lowbit.h b/torchao/experimental/kernels/mps/src/lowbit.h index d37001350..6c9911352 100644 --- a/torchao/experimental/kernels/mps/src/lowbit.h +++ b/torchao/experimental/kernels/mps/src/lowbit.h @@ -10,7 +10,7 @@ #include #include -#include +#include // metal_lowbit_quantized_lib #include #include @@ -20,9 +20,9 @@ #ifdef USE_ATEN #include using namespace at::native::mps; -inline void finalize_block(MPSStream* mpsStream) {} -void (*dispatch_block)(dispatch_queue_t, dispatch_block_t) = - dispatch_sync_with_rethrow; +#elif defined(USE_EXECUTORCH) +#include +using namespace executorch::backends::mps::delegate; #else #include #endif @@ -103,7 +103,13 @@ inline void linear_lowbit_quant_weights_mps_impl( 0}; MPSStream* mpsStream = getCurrentMPSStream(); +#ifdef USE_ATEN + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { +#elif defined(USE_EXECUTORCH) + dispatch_sync(mpsStream->queue(), ^() { +#else dispatch_block(mpsStream->queue(), ^() { +#endif @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); id cpl = @@ -119,7 +125,15 @@ inline void linear_lowbit_quant_weights_mps_impl( length:sizeof(uint32_t) * sizes.size() atIndex:5]; dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K); - finalize_block(mpsStream); +#ifdef USE_EXECUTORCH + ET_CHECK(mpsStream->synchronize(SyncType::COMMIT_AND_WAIT) == executorch::runtime::Error::Ok); +#else + id encoder = mpsStream->commandEncoder(); + id cmdBuffer = mpsStream->commandBuffer(); + [encoder endEncoding]; + [cmdBuffer commit]; + [cmdBuffer waitUntilCompleted]; +#endif } }); } diff --git a/torchao/experimental/ops/mps/CMakeLists.txt 
b/torchao/experimental/ops/mps/CMakeLists.txt index 044433ef9..d3302ca65 100644 --- a/torchao/experimental/ops/mps/CMakeLists.txt +++ b/torchao/experimental/ops/mps/CMakeLists.txt @@ -58,3 +58,27 @@ install( EXPORT _targets DESTINATION lib ) + +if(TORCHAO_BUILD_EXECUTORCH_OPS) + include_directories(${CMAKE_INSTALL_PREFIX}/schema/include) + include_directories(${CMAKE_INSTALL_PREFIX}/../third-party/flatbuffers/include) + file(GLOB _SRCS "${CMAKE_CURRENT_SOURCE_DIR}/executorch/*.mm") + add_library(torchao_ops_mps_linear_fp_act_xbit_weight_executorch OBJECT ${_SRCS}) + add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_executorch generated_metal_shader_lib) + target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_INCLUDE_DIRS}") + target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1) + target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_LIBRARIES}") + target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE ${METAL_LIB} ${FOUNDATION_LIB}) + + add_library(torchao_ops_mps_executorch STATIC) + target_link_libraries(torchao_ops_mps_executorch PRIVATE + torchao_ops_mps_linear_fp_act_xbit_weight_executorch + ) + install( + TARGETS + torchao_ops_mps_executorch + torchao_ops_mps_linear_fp_act_xbit_weight_executorch + EXPORT _targets + DESTINATION lib + ) +endif() diff --git a/torchao/experimental/ops/mps/aten/register.mm b/torchao/experimental/ops/mps/aten/register.mm index 92a3ba89f..e11e55c5a 100644 --- a/torchao/experimental/ops/mps/aten/register.mm +++ b/torchao/experimental/ops/mps/aten/register.mm @@ -70,12 +70,13 @@ void check_linear_mps_args( } template -Tensor linear_mps_kernel( +Tensor linear_mps_kernel_out( const Tensor& A, const Tensor& B, int64_t group_size, const Tensor& S, - const Tensor& Z) { + const Tensor& Z, + Tensor& C) { TORCH_CHECK( A.is_mps(), __func__, ": A is on ", A.device(),
" but expected on mps"); TORCH_CHECK( @@ -84,6 +85,8 @@ Tensor linear_mps_kernel( S.is_mps(), __func__, ": S is on ", S.device(), " but expected on mps"); TORCH_CHECK( Z.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps"); + TORCH_CHECK( + C.is_mps(), __func__, ": C is on ", C.device(), " but expected on mps"); check_linear_mps_args(A, B, group_size, S, Z); @@ -91,8 +94,6 @@ Tensor linear_mps_kernel( auto N = B.size(0); auto K = A.size(1); - auto C = at::empty({M, N}, A.options()); - LowBitQuantWeights::linear( getMTLBufferStorage(A), getMTLBufferStorage(B), @@ -108,6 +109,19 @@ Tensor linear_mps_kernel( return C; } +template +Tensor linear_mps_kernel( + const Tensor& A, + const Tensor& B, + int64_t group_size, + const Tensor& S, + const Tensor& Z) { + auto M = A.size(0); + auto N = B.size(0); + auto C = at::empty({M, N}, A.options()); + return linear_mps_kernel_out(A, B, group_size, S, Z, C); +} + template Tensor linear_mps_kernel_meta( const Tensor& A, @@ -169,6 +183,20 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) { "_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor"); m.def( "_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor"); + m.def( + "_linear_fp_act_1bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_2bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_3bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_4bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_5bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!)
out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_6bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_7bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); } TORCH_LIBRARY_IMPL(torchao, CPU, m) { @@ -189,6 +217,13 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) { m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel<5>); m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel<6>); m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>); + m.impl("_linear_fp_act_1bit_weight.out", &linear_mps_kernel_out<1>); + m.impl("_linear_fp_act_2bit_weight.out", &linear_mps_kernel_out<2>); + m.impl("_linear_fp_act_3bit_weight.out", &linear_mps_kernel_out<3>); + m.impl("_linear_fp_act_4bit_weight.out", &linear_mps_kernel_out<4>); + m.impl("_linear_fp_act_5bit_weight.out", &linear_mps_kernel_out<5>); + m.impl("_linear_fp_act_6bit_weight.out", &linear_mps_kernel_out<6>); + m.impl("_linear_fp_act_7bit_weight.out", &linear_mps_kernel_out<7>); } TORCH_LIBRARY_IMPL(torchao, Meta, m) { diff --git a/torchao/experimental/ops/mps/executorch/linear_fp_act_xbit_weight.h b/torchao/experimental/ops/mps/executorch/linear_fp_act_xbit_weight.h new file mode 100644 index 000000000..d79145579 --- /dev/null +++ b/torchao/experimental/ops/mps/executorch/linear_fp_act_xbit_weight.h @@ -0,0 +1,112 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include +#include + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::backends::mps::delegate::getMTLBufferStorage; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::tensor_is_rank; + +namespace { + +std::string scalar_type_to_string(const ScalarType& scalar_type) { + switch (scalar_type) { + case ScalarType::Float: + return "float"; + case ScalarType::Half: + return "half"; + case ScalarType::BFloat16: + return "bfloat"; + default: + ET_CHECK_MSG( + false, "Unsupported type by lowbit quantized linear"); + return "undefined"; + } +} + +template +bool check_linear_mps_args( + const Tensor& A, + const Tensor& B, + int64_t group_size, + const Tensor& S, + const Tensor& Z) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + A.scalar_type() == ScalarType::BFloat16 || + A.scalar_type() == ScalarType::Half || + A.scalar_type() == ScalarType::Float, + "Expect A to be either 32-bit or 16-bit float tensor."); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + tensor_is_rank(A, 2), "Expect A to be 2D tensor."); + + auto N = B.size(0); + auto K = A.size(1); + + ET_LOG_MSG_AND_RETURN_IF_FALSE(K % 8 == 0, "Expect K to be multiple of 8"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + B.scalar_type() == ScalarType::Byte, "Expect B to be uint8 tensor."); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + B.size(1) == (K / 8) * nbit, "Expect B.size(1) == (K / 8) * nbit"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + group_size == 32 || group_size == 64 || group_size == 128 || + group_size == 256, + "Expect group_size to be 32, 64, 128 or 256"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + S.dim() == 2 && S.size(1) == N, + "Expect S to be 2d tensor with shape [:, N]"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + Z.dim() == 2 && Z.size(1) == N, + "Expect Z to be 2d tensor with shape [:, N]"); + + return true; +} + +template +Tensor& linear_mps_kernel_et_ctx_out( + KernelRuntimeContext& ctx, + const Tensor& A, + const Tensor& B, + int64_t group_size, +
const Tensor& S, + const Tensor& Z, + Tensor& out) { + (void)ctx; + + ET_KERNEL_CHECK( + ctx, + check_linear_mps_args(A, B, group_size, S, Z), + InvalidArgument, + out); + + auto M = A.size(0); + auto N = B.size(0); + auto K = A.size(1); + + torchao::kernels::mps::lowbit::LowBitQuantWeights::linear( + getMTLBufferStorage(A), + getMTLBufferStorage(B), + group_size, + getMTLBufferStorage(S), + getMTLBufferStorage(Z), + getMTLBufferStorage(out), + M, + K, + N, + scalar_type_to_string(A.scalar_type())); + + return out; +} + +} // namespace diff --git a/torchao/experimental/ops/mps/executorch/op_linear_fp_act_1bit_weight.mm b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_1bit_weight.mm new file mode 100644 index 000000000..be440a103 --- /dev/null +++ b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_1bit_weight.mm @@ -0,0 +1,13 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// Unlike ATen, ExecuTorch op registration appears to only allow one +// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new +// file is needed for each variant + +#include + +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_1bit_weight.out", linear_mps_kernel_et_ctx_out<1>); diff --git a/torchao/experimental/ops/mps/executorch/op_linear_fp_act_2bit_weight.mm b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_2bit_weight.mm new file mode 100644 index 000000000..ae57147c4 --- /dev/null +++ b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_2bit_weight.mm @@ -0,0 +1,13 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +// Unlike ATen, ExecuTorch op registration appears to only allow one +// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new +// file is needed for each variant + +#include + +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_2bit_weight.out", linear_mps_kernel_et_ctx_out<2>); diff --git a/torchao/experimental/ops/mps/executorch/op_linear_fp_act_3bit_weight.mm b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_3bit_weight.mm new file mode 100644 index 000000000..fbff97635 --- /dev/null +++ b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_3bit_weight.mm @@ -0,0 +1,13 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// Unlike ATen, ExecuTorch op registration appears to only allow one +// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new +// file is needed for each variant + +#include + +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_3bit_weight.out", linear_mps_kernel_et_ctx_out<3>); diff --git a/torchao/experimental/ops/mps/executorch/op_linear_fp_act_4bit_weight.mm b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_4bit_weight.mm new file mode 100644 index 000000000..916acc460 --- /dev/null +++ b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_4bit_weight.mm @@ -0,0 +1,13 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +// Unlike ATen, ExecuTorch op registration appears to only allow one +// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new +// file is needed for each variant + +#include + +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_4bit_weight.out", linear_mps_kernel_et_ctx_out<4>); diff --git a/torchao/experimental/ops/mps/executorch/op_linear_fp_act_5bit_weight.mm b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_5bit_weight.mm new file mode 100644 index 000000000..48e6fd657 --- /dev/null +++ b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_5bit_weight.mm @@ -0,0 +1,13 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// Unlike ATen, ExecuTorch op registration appears to only allow one +// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new +// file is needed for each variant + +#include + +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_5bit_weight.out", linear_mps_kernel_et_ctx_out<5>); diff --git a/torchao/experimental/ops/mps/executorch/op_linear_fp_act_6bit_weight.mm b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_6bit_weight.mm new file mode 100644 index 000000000..e36471508 --- /dev/null +++ b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_6bit_weight.mm @@ -0,0 +1,13 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +// Unlike ATen, ExecuTorch op registration appears to only allow one +// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new +// file is needed for each variant + +#include + +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_6bit_weight.out", linear_mps_kernel_et_ctx_out<6>); diff --git a/torchao/experimental/ops/mps/executorch/op_linear_fp_act_7bit_weight.mm b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_7bit_weight.mm new file mode 100644 index 000000000..cd7142492 --- /dev/null +++ b/torchao/experimental/ops/mps/executorch/op_linear_fp_act_7bit_weight.mm @@ -0,0 +1,13 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// Unlike ATen, ExecuTorch op registration appears to only allow one +// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new +// file is needed for each variant + +#include + +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_7bit_weight.out", linear_mps_kernel_et_ctx_out<7>);