Skip to content

Commit

Permalink
metal lowbit kernels: executorch ops
Browse files Browse the repository at this point in the history
Summary:
Refactors kernels/mps/src/OperationUtils.h, moving MetalShaderLibrary into its own header.
Integrates MPS delegate functions into lowbit.h
Registers out variants for the ATen ops
Registers ET ops

Differential Revision: D65957345
  • Loading branch information
manuelcandales authored and facebook-github-bot committed Nov 22, 2024
1 parent 7489c7d commit 734496d
Show file tree
Hide file tree
Showing 14 changed files with 365 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
*/
#ifdef USE_ATEN
using namespace at::native::mps;
using at::native::mps::MetalShaderLibrary;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#include <torchao/experimental/kernels/mps/src/MetalShaderLibrary.h>
#endif
static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(
Expand Down
78 changes: 78 additions & 0 deletions torchao/experimental/kernels/mps/src/MetalShaderLibrary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#ifdef USE_EXECUTORCH
#include <executorch/backends/apple/mps/runtime/MPSDevice.h>
using executorch::backends::mps::delegate::MPSDevice;
static id<MTLDevice> MTL_DEVICE = MPSDevice::getInstance()->device();
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif

// Compiles a Metal shading-language source string into an MTLLibrary on the
// given device, targeting Metal language version 3.1. Under USE_EXECUTORCH
// a compile failure returns nil; otherwise it throws with the compiler
// diagnostics.
static id<MTLLibrary> compileLibraryFromSource(
    id<MTLDevice> device,
    const std::string& source) {
  MTLCompileOptions* compileOptions = [MTLCompileOptions new];
  [compileOptions setLanguageVersion:MTLLanguageVersion3_1];
  NSString* sourceString = [NSString stringWithUTF8String:source.c_str()];
  NSError* compileError = nil;
  id<MTLLibrary> library = [device newLibraryWithSource:sourceString
                                                options:compileOptions
                                                  error:&compileError];
#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling
  if (library == nil) {
    throw_exception(
        "Failed to compile: " +
        std::string(compileError.description.UTF8String));
  }
#endif
  return library;
}

// Owns an MTLLibrary compiled from an embedded Metal source string and vends
// compute pipeline states for the kernel functions it contains.
class MetalShaderLibrary {
 public:
  MetalShaderLibrary(const std::string& src) : shaderSource(src) {
    lib = compileLibraryFromSource(device, shaderSource);
  }
  // Non-copyable and non-movable: the instance owns its compiled library.
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  MetalShaderLibrary(MetalShaderLibrary&&) = delete;

  // Looks up the kernel function `fname` in the compiled library and returns
  // a compute pipeline state ready for dispatch.
  id<MTLComputePipelineState> getPipelineStateForFunc(
      const std::string& fname) {
    id<MTLFunction> kernelFunc = load_func(fname);
    return get_compute_pipeline_state(kernelFunc);
  }

 private:
  std::string shaderSource;
  id<MTLDevice> device = MTL_DEVICE;
  id<MTLLibrary> lib = nil;

  // Fetches the named MTLFunction from the compiled library.
  id<MTLFunction> load_func(const std::string& func_name) const {
    NSString* name = [NSString stringWithUTF8String:func_name.c_str()];
    id<MTLFunction> func = [lib newFunctionWithName:name];
#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling
    if (func == nil) {
      throw_exception("Can't get function:" + func_name);
    }
#endif
    return func;
  }

  // Builds a compute pipeline state from a loaded kernel function.
  id<MTLComputePipelineState> get_compute_pipeline_state(
      id<MTLFunction> func) const {
    NSError* error = nil;
    auto cpl = [device newComputePipelineStateWithFunction:func error:&error];
#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling
    if (cpl == nil) {
      throw_exception(
          "Failed to construct pipeline state: " +
          std::string(error.description.UTF8String));
    }
#endif
    return cpl;
  }
};
65 changes: 0 additions & 65 deletions torchao/experimental/kernels/mps/src/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,63 +40,6 @@ inline id<MTLDevice> getMetalDevice() {

static id<MTLDevice> MTL_DEVICE = getMetalDevice();

// Compiles the given Metal source string into an MTLLibrary on `device`,
// targeting Metal language version 3.1; throws with the compiler diagnostics
// on failure.
static id<MTLLibrary> compileLibraryFromSource(
id<MTLDevice> device,
const std::string& source) {
NSError* error = nil;
MTLCompileOptions* options = [MTLCompileOptions new];
[options setLanguageVersion:MTLLanguageVersion3_1];
NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()];
id<MTLLibrary> library = [device newLibraryWithSource:kernel_source
options:options
error:&error];
// Surface the Metal compiler error message if compilation failed.
if (library == nil) {
throw_exception(
"Failed to compile: " + std::string(error.description.UTF8String));
}
return library;
}

// Wraps an MTLLibrary compiled from the given Metal source string and vends
// compute pipeline states for its kernel functions.
class MetalShaderLibrary {
public:
MetalShaderLibrary(const std::string& src) : shaderSource(src) {
lib = compileLibraryFromSource(device, shaderSource);
}
// Non-copyable / non-movable: the instance owns the compiled library handle.
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
MetalShaderLibrary(MetalShaderLibrary&&) = delete;

// Returns a compute pipeline state for the named kernel function.
id<MTLComputePipelineState> getPipelineStateForFunc(
const std::string& fname) {
return get_compute_pipeline_state(load_func(fname));
}

private:
std::string shaderSource;
id<MTLDevice> device = MTL_DEVICE;
id<MTLLibrary> lib = nil;

// Looks up the named function in the compiled library; throws if missing.
id<MTLFunction> load_func(const std::string& func_name) const {
id<MTLFunction> func = [lib
newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
if (func == nil) {
throw_exception("Can't get function:" + func_name);
}
return func;
}

// Builds a compute pipeline state from a loaded function; throws on failure.
id<MTLComputePipelineState> get_compute_pipeline_state(
id<MTLFunction> func) const {
NSError* error = nil;
auto cpl = [device newComputePipelineStateWithFunction:func error:&error];
if (cpl == nil) {
throw_exception(
"Failed to construct pipeline state: " +
std::string(error.description.UTF8String));
}
return cpl;
}
};

class MPSStream {
public:
MPSStream() {
Expand Down Expand Up @@ -136,14 +79,6 @@ class MPSStream {
id<MTLComputeCommandEncoder> _commandEncoder = nil;
};

// Ends the stream's current compute encoding, commits the command buffer,
// and blocks until the GPU has finished executing it.
inline void finalize_block(MPSStream* mpsStream) {
id<MTLCommandEncoder> encoder = mpsStream->commandEncoder();
id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
[encoder endEncoding];
[cmdBuffer commit];
[cmdBuffer waitUntilCompleted];
}

// Returns a freshly allocated MPSStream.
// NOTE(review): every call heap-allocates a new stream and no owner visible
// here ever deletes it — looks like a leak; consider caching a single
// stream. TODO confirm intended ownership with callers.
inline MPSStream* getCurrentMPSStream() {
return new MPSStream();
}
24 changes: 19 additions & 5 deletions torchao/experimental/kernels/mps/src/lowbit.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

#include <torchao/experimental/kernels/mps/src/dispatch.h>
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h>
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h> // metal_lowbit_quantized_lib
#include <torchao/experimental/kernels/mps/src/packing.h>

#include <cassert>
Expand All @@ -20,9 +20,9 @@
#ifdef USE_ATEN
#include <ATen/native/mps/OperationUtils.h>
using namespace at::native::mps;
inline void finalize_block(MPSStream* mpsStream) {}
void (*dispatch_block)(dispatch_queue_t, dispatch_block_t) =
dispatch_sync_with_rethrow;
#elif defined(USE_EXECUTORCH)
#include <executorch/backends/apple/mps/runtime/MPSStream.h>
using namespace executorch::backends::mps::delegate;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif
Expand Down Expand Up @@ -103,7 +103,13 @@ inline void linear_lowbit_quant_weights_mps_impl(
0};

MPSStream* mpsStream = getCurrentMPSStream();
#ifdef USE_ATEN
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
#elif defined(USE_EXECUTORCH)
dispatch_sync(mpsStream->queue(), ^() {
#else
dispatch_block(mpsStream->queue(), ^() {
#endif
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
id<MTLComputePipelineState> cpl =
Expand All @@ -119,7 +125,15 @@ inline void linear_lowbit_quant_weights_mps_impl(
length:sizeof(uint32_t) * sizes.size()
atIndex:5];
dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K);
finalize_block(mpsStream);
#ifdef USE_EXECUTORCH
ET_CHECK(mpsStream->synchronize(SyncType::COMMIT_AND_WAIT) == executorch::runtime::Error::Ok);
#else
id<MTLCommandEncoder> encoder = mpsStream->commandEncoder();
id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
[encoder endEncoding];
[cmdBuffer commit];
[cmdBuffer waitUntilCompleted];
#endif
}
});
}
Expand Down
24 changes: 24 additions & 0 deletions torchao/experimental/ops/mps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,27 @@ install(
EXPORT _targets
DESTINATION lib
)

# Optional ExecuTorch build of the MPS lowbit linear ops.
if(TORCHAO_BUILD_EXECUTORCH_OPS)
# ExecuTorch schema headers and its vendored flatbuffers.
include_directories(${CMAKE_INSTALL_PREFIX}/schema/include)
include_directories(${CMAKE_INSTALL_PREFIX}/../third-party/flatbuffers/include)
# Compile the ExecuTorch op registration sources as an object library.
file(GLOB _SRCS "${CMAKE_CURRENT_SOURCE_DIR}/executorch/*.mm")
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_executorch OBJECT ${_SRCS})
# The sources include the generated metal shader library header.
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_executorch generated_metal_shader_lib)
target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_INCLUDE_DIRS}")
# USE_EXECUTORCH selects the ET code paths in the shared kernel headers.
target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1)
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_LIBRARIES}")
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})

# Static archive wrapping the object library for installation.
add_library(torchao_ops_mps_executorch STATIC)
target_link_libraries(torchao_ops_mps_executorch PRIVATE
torchao_ops_mps_linear_fp_act_xbit_weight_executorch
)
install(
TARGETS
torchao_ops_mps_executorch
torchao_ops_mps_linear_fp_act_xbit_weight_executorch
EXPORT _targets
DESTINATION lib
)
endif()
43 changes: 39 additions & 4 deletions torchao/experimental/ops/mps/aten/register.mm
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@ void check_linear_mps_args(
}

template <int nbit>
Tensor linear_mps_kernel(
Tensor linear_mps_kernel_out(
const Tensor& A,
const Tensor& B,
int64_t group_size,
const Tensor& S,
const Tensor& Z) {
const Tensor& Z,
Tensor& C) {
TORCH_CHECK(
A.is_mps(), __func__, ": A is on ", A.device(), " but expected on mps");
TORCH_CHECK(
Expand All @@ -84,15 +85,15 @@ Tensor linear_mps_kernel(
S.is_mps(), __func__, ": S is on ", S.device(), " but expected on mps");
TORCH_CHECK(
Z.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps");
// Validate that the preallocated output tensor is on MPS as well.
// Fix: the message was copy-pasted from the Z check and reported
// "Z is on <Z.device()>" even though the condition tests C.
TORCH_CHECK(
C.is_mps(), __func__, ": C is on ", C.device(), " but expected on mps");

check_linear_mps_args<nbit>(A, B, group_size, S, Z);

auto M = A.size(0);
auto N = B.size(0);
auto K = A.size(1);

auto C = at::empty({M, N}, A.options());

LowBitQuantWeights<nbit>::linear(
getMTLBufferStorage(A),
getMTLBufferStorage(B),
Expand All @@ -108,6 +109,19 @@ Tensor linear_mps_kernel(
return C;
}

// Functional variant of the lowbit linear op: allocates the (M x N) output
// tensor with A's options and forwards to the out-variant kernel.
template <int nbit>
Tensor linear_mps_kernel(
    const Tensor& A,
    const Tensor& B,
    int64_t group_size,
    const Tensor& S,
    const Tensor& Z) {
  Tensor output = at::empty({A.size(0), B.size(0)}, A.options());
  return linear_mps_kernel_out<nbit>(A, B, group_size, S, Z, output);
}

template <int nbit>
Tensor linear_mps_kernel_meta(
const Tensor& A,
Expand Down Expand Up @@ -169,6 +183,20 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
"_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
m.def(
"_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
m.def(
"_linear_fp_act_1bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_2bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_3bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_4bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_5bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_6bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"_linear_fp_act_7bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
Expand All @@ -189,6 +217,13 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel<5>);
m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel<6>);
m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>);
m.impl("_linear_fp_act_1bit_weight.out", &linear_mps_kernel_out<1>);
m.impl("_linear_fp_act_2bit_weight.out", &linear_mps_kernel_out<2>);
m.impl("_linear_fp_act_3bit_weight.out", &linear_mps_kernel_out<3>);
m.impl("_linear_fp_act_4bit_weight.out", &linear_mps_kernel_out<4>);
m.impl("_linear_fp_act_5bit_weight.out", &linear_mps_kernel_out<5>);
m.impl("_linear_fp_act_6bit_weight.out", &linear_mps_kernel_out<6>);
m.impl("_linear_fp_act_7bit_weight.out", &linear_mps_kernel_out<7>);
}

TORCH_LIBRARY_IMPL(torchao, Meta, m) {
Expand Down
Loading

0 comments on commit 734496d

Please sign in to comment.