Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metal lowbit kernels: executorch ops #1322

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
*/

#ifdef USE_ATEN
using namespace at::native::mps;
using at::native::mps::MetalShaderLibrary;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#include <torchao/experimental/kernels/mps/src/MetalShaderLibrary.h>
#endif

static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(
Expand Down
78 changes: 78 additions & 0 deletions torchao/experimental/kernels/mps/src/MetalShaderLibrary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#ifdef USE_EXECUTORCH
#include <executorch/backends/apple/mps/runtime/MPSDevice.h>
using executorch::backends::mps::delegate::MPSDevice;
static id<MTLDevice> MTL_DEVICE = MPSDevice::getInstance()->device();
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif

// Compiles a Metal shader library from `source` on the given device.
// Outside of ExecuTorch builds a compilation failure throws via
// throw_exception; under USE_EXECUTORCH the nil library is handed back to
// the caller (error handling to be unified with ET — see TODO below).
static id<MTLLibrary> compileLibraryFromSource(
    id<MTLDevice> device,
    const std::string& source) {
  NSString* shaderText = [NSString stringWithUTF8String:source.c_str()];
  MTLCompileOptions* compileOptions = [MTLCompileOptions new];
  [compileOptions setLanguageVersion:MTLLanguageVersion3_1];
  NSError* compileError = nil;
  id<MTLLibrary> library = [device newLibraryWithSource:shaderText
                                                options:compileOptions
                                                  error:&compileError];
#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling
  if (library == nil) {
    throw_exception(
        "Failed to compile: " +
        std::string(compileError.description.UTF8String));
  }
#endif
  return library;
}

/**
 * Thin C++ wrapper that compiles a Metal shader source string once, at
 * construction time, and vends compute pipeline states for its kernels.
 * Standalone counterpart of ATen's at::native::mps::MetalShaderLibrary for
 * builds without PyTorch (ExecuTorch or plain torchao).
 */
class MetalShaderLibrary {
 public:
  // Compiles `src` eagerly. In non-ExecuTorch builds a compile failure
  // throws inside compileLibraryFromSource; under USE_EXECUTORCH `lib`
  // is left nil on failure.
  MetalShaderLibrary(const std::string& src) : shaderSource(src) {
    lib = compileLibraryFromSource(device, shaderSource);
  }
  // Owns an Objective-C library handle: neither copyable nor movable.
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  MetalShaderLibrary(MetalShaderLibrary&&) = delete;

  // Returns a compute pipeline state for the kernel function named `fname`.
  id<MTLComputePipelineState> getPipelineStateForFunc(
      const std::string& fname) {
    return makePipelineState(lookupFunction(fname));
  }

 private:
  std::string shaderSource;
  id<MTLDevice> device = MTL_DEVICE;
  id<MTLLibrary> lib = nil;

  // Looks up a kernel function by name in the compiled library.
  id<MTLFunction> lookupFunction(const std::string& funcName) const {
    NSString* name = [NSString stringWithUTF8String:funcName.c_str()];
    id<MTLFunction> fn = [lib newFunctionWithName:name];
#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling
    if (fn == nil) {
      throw_exception("Can't get function:" + funcName);
    }
#endif
    return fn;
  }

  // Builds a compute pipeline state from a compiled kernel function.
  id<MTLComputePipelineState> makePipelineState(id<MTLFunction> fn) const {
    NSError* err = nil;
    auto state = [device newComputePipelineStateWithFunction:fn error:&err];
#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling
    if (state == nil) {
      throw_exception(
          "Failed to construct pipeline state: " +
          std::string(err.description.UTF8String));
    }
#endif
    return state;
  }
};
65 changes: 0 additions & 65 deletions torchao/experimental/kernels/mps/src/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,63 +40,6 @@ inline id<MTLDevice> getMetalDevice() {

static id<MTLDevice> MTL_DEVICE = getMetalDevice();

static id<MTLLibrary> compileLibraryFromSource(
id<MTLDevice> device,
const std::string& source) {
NSError* error = nil;
MTLCompileOptions* options = [MTLCompileOptions new];
[options setLanguageVersion:MTLLanguageVersion3_1];
NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()];
id<MTLLibrary> library = [device newLibraryWithSource:kernel_source
options:options
error:&error];
if (library == nil) {
throw_exception(
"Failed to compile: " + std::string(error.description.UTF8String));
}
return library;
}

class MetalShaderLibrary {
public:
MetalShaderLibrary(const std::string& src) : shaderSource(src) {
lib = compileLibraryFromSource(device, shaderSource);
}
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
MetalShaderLibrary(MetalShaderLibrary&&) = delete;

id<MTLComputePipelineState> getPipelineStateForFunc(
const std::string& fname) {
return get_compute_pipeline_state(load_func(fname));
}

private:
std::string shaderSource;
id<MTLDevice> device = MTL_DEVICE;
id<MTLLibrary> lib = nil;

id<MTLFunction> load_func(const std::string& func_name) const {
id<MTLFunction> func = [lib
newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
if (func == nil) {
throw_exception("Can't get function:" + func_name);
}
return func;
}

id<MTLComputePipelineState> get_compute_pipeline_state(
id<MTLFunction> func) const {
NSError* error = nil;
auto cpl = [device newComputePipelineStateWithFunction:func error:&error];
if (cpl == nil) {
throw_exception(
"Failed to construct pipeline state: " +
std::string(error.description.UTF8String));
}
return cpl;
}
};

class MPSStream {
public:
MPSStream() {
Expand Down Expand Up @@ -136,14 +79,6 @@ class MPSStream {
id<MTLComputeCommandEncoder> _commandEncoder = nil;
};

inline void finalize_block(MPSStream* mpsStream) {
id<MTLCommandEncoder> encoder = mpsStream->commandEncoder();
id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
[encoder endEncoding];
[cmdBuffer commit];
[cmdBuffer waitUntilCompleted];
}

inline MPSStream* getCurrentMPSStream() {
return new MPSStream();
}
24 changes: 19 additions & 5 deletions torchao/experimental/kernels/mps/src/lowbit.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

#include <torchao/experimental/kernels/mps/src/dispatch.h>
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h>
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h> // metal_lowbit_quantized_lib
#include <torchao/experimental/kernels/mps/src/packing.h>

#include <cassert>
Expand All @@ -20,9 +20,9 @@
#ifdef USE_ATEN
#include <ATen/native/mps/OperationUtils.h>
using namespace at::native::mps;
inline void finalize_block(MPSStream* mpsStream) {}
void (*dispatch_block)(dispatch_queue_t, dispatch_block_t) =
dispatch_sync_with_rethrow;
#elif defined(USE_EXECUTORCH)
#include <executorch/backends/apple/mps/runtime/MPSStream.h>
using namespace executorch::backends::mps::delegate;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif
Expand Down Expand Up @@ -103,7 +103,13 @@ inline void linear_lowbit_quant_weights_mps_impl(
0};

MPSStream* mpsStream = getCurrentMPSStream();
#ifdef USE_ATEN
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
#elif defined(USE_EXECUTORCH)
dispatch_sync(mpsStream->queue(), ^() {
#else
dispatch_block(mpsStream->queue(), ^() {
#endif
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
id<MTLComputePipelineState> cpl =
Expand All @@ -119,7 +125,15 @@ inline void linear_lowbit_quant_weights_mps_impl(
length:sizeof(uint32_t) * sizes.size()
atIndex:5];
dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K);
finalize_block(mpsStream);
#ifdef USE_EXECUTORCH
ET_CHECK(mpsStream->synchronize(SyncType::COMMIT_AND_WAIT) == executorch::runtime::Error::Ok);
#else
id<MTLCommandEncoder> encoder = mpsStream->commandEncoder();
id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
[encoder endEncoding];
[cmdBuffer commit];
[cmdBuffer waitUntilCompleted];
#endif
}
});
}
Expand Down
24 changes: 24 additions & 0 deletions torchao/experimental/ops/mps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,27 @@ install(
EXPORT _targets
DESTINATION lib
)

if(TORCHAO_BUILD_EXECUTORCH_OPS)
  # ExecuTorch flavor of the MPS lowbit ops: sources under executorch/ are
  # compiled with USE_EXECUTORCH so the kernels route through the ExecuTorch
  # MPS delegate instead of ATen.
  # NOTE: file(GLOB) is resolved at configure time — re-run cmake after
  # adding new sources.
  file(GLOB _SRCS "${CMAKE_CURRENT_SOURCE_DIR}/executorch/*.mm")
  add_library(torchao_ops_mps_linear_fp_act_xbit_weight_executorch OBJECT ${_SRCS})
  add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_executorch generated_metal_shader_lib)
  # Scope the ExecuTorch schema/flatbuffers headers to this target instead of
  # leaking them into every later target via directory-level
  # include_directories().
  target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE
    "${EXECUTORCH_INCLUDE_DIRS}"
    "${CMAKE_INSTALL_PREFIX}/schema/include"
    "${CMAKE_INSTALL_PREFIX}/../third-party/flatbuffers/include"
  )
  target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1)
  target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE
    "${EXECUTORCH_LIBRARIES}"
    ${METAL_LIB}
    ${FOUNDATION_LIB}
  )

  # Static archive that downstream consumers link; the OBJECT library above
  # supplies its members.
  add_library(torchao_ops_mps_executorch STATIC)
  target_link_libraries(torchao_ops_mps_executorch PRIVATE
    torchao_ops_mps_linear_fp_act_xbit_weight_executorch
  )
  install(
    TARGETS
      torchao_ops_mps_executorch
      torchao_ops_mps_linear_fp_act_xbit_weight_executorch
    EXPORT _targets
    DESTINATION lib
  )
endif()
43 changes: 39 additions & 4 deletions torchao/experimental/ops/mps/aten/register.mm
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@ void check_linear_mps_args(
}

template <int nbit>
Tensor linear_mps_kernel(
Tensor linear_mps_kernel_out(
const Tensor& A,
const Tensor& B,
int64_t group_size,
const Tensor& S,
const Tensor& Z) {
const Tensor& Z,
Tensor& C) {
TORCH_CHECK(
A.is_mps(), __func__, ": A is on ", A.device(), " but expected on mps");
TORCH_CHECK(
Expand All @@ -84,15 +85,15 @@ Tensor linear_mps_kernel(
S.is_mps(), __func__, ": S is on ", S.device(), " but expected on mps");
TORCH_CHECK(
Z.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps");
TORCH_CHECK(
C.is_mps(), __func__, ": C is on ", C.device(), " but expected on mps");

check_linear_mps_args<nbit>(A, B, group_size, S, Z);

auto M = A.size(0);
auto N = B.size(0);
auto K = A.size(1);

auto C = at::empty({M, N}, A.options());

LowBitQuantWeights<nbit>::linear(
getMTLBufferStorage(A),
getMTLBufferStorage(B),
Expand All @@ -108,6 +109,19 @@ Tensor linear_mps_kernel(
return C;
}

// Allocating convenience wrapper around linear_mps_kernel_out: creates the
// (M x N) output tensor with A's dtype/device options and forwards to the
// out-variant, which performs all argument validation.
template <int nbit>
Tensor linear_mps_kernel(
    const Tensor& A,
    const Tensor& B,
    int64_t group_size,
    const Tensor& S,
    const Tensor& Z) {
  const auto out_rows = A.size(0);
  const auto out_cols = B.size(0);
  Tensor out = at::empty({out_rows, out_cols}, A.options());
  return linear_mps_kernel_out<nbit>(A, B, group_size, S, Z, out);
}

template <int nbit>
Tensor linear_mps_kernel_meta(
const Tensor& A,
Expand Down Expand Up @@ -169,6 +183,20 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
"_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
m.def(
"_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
m.def(
"_linear_fp_act_1bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_2bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_3bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_4bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_5bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_6bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_7bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
Expand All @@ -189,6 +217,13 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel<5>);
m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel<6>);
m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>);
m.impl("_linear_fp_act_1bit_weight.out", &linear_mps_kernel_out<1>);
m.impl("_linear_fp_act_2bit_weight.out", &linear_mps_kernel_out<2>);
m.impl("_linear_fp_act_3bit_weight.out", &linear_mps_kernel_out<3>);
m.impl("_linear_fp_act_4bit_weight.out", &linear_mps_kernel_out<4>);
m.impl("_linear_fp_act_5bit_weight.out", &linear_mps_kernel_out<5>);
m.impl("_linear_fp_act_6bit_weight.out", &linear_mps_kernel_out<6>);
m.impl("_linear_fp_act_7bit_weight.out", &linear_mps_kernel_out<7>);
}

TORCH_LIBRARY_IMPL(torchao, Meta, m) {
Expand Down
Loading
Loading