Skip to content

Commit

Permalink
Cholesky mps implementation (pytorch#144193)
Browse files Browse the repository at this point in the history
Requested in pytorch#77764

PR is still in draft because it needs some cleanups and optimizations to get to cpu performance the least. Tasks:
- [x] Make `upper=True` work, only `upper=False` works now
- [x] Code cleanup
- [x] Optimizations(Though might need some help on this)(tried my best, maybe there is still some more to squeeze out)
- [x] Checks for positive definite input
- [x] Support for (*, N, N) input, currently only supports (B, N, N) input
- [x] Support other dtypes(float16, bfloat16)

Pull Request resolved: pytorch#144193
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <[email protected]>
  • Loading branch information
2 people authored and pytorchmergebot committed Jan 16, 2025
1 parent 1b34665 commit 727ae13
Show file tree
Hide file tree
Showing 6 changed files with 402 additions and 2 deletions.
266 changes: 266 additions & 0 deletions aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <metal_array>
#include <metal_stdlib>

using namespace metal;
template <typename T>
Expand Down Expand Up @@ -31,6 +32,271 @@ kernel void naive_matmul(
outputData[x * strides[2].x + y * strides[2].y] = rc;
}

inline float blockReduceSum(
threadgroup float* sharedScratch,
float val,
uint tid,
uint tpg) {
sharedScratch[tid] = val;
threadgroup_barrier(mem_flags::mem_threadgroup);

for (uint offset = tpg >> 1; offset > 0; offset >>= 1) {
if (tid < offset) {
sharedScratch[tid] += sharedScratch[tid + offset];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}

return sharedScratch[0];
}

kernel void factorDiagonalBlock(
device float* A [[buffer(0)]],
device int* success [[buffer(1)]],
constant uint& N [[buffer(2)]],
constant uint& NB [[buffer(3)]],
constant uint& k [[buffer(4)]],
uint tid [[thread_position_in_threadgroup]],
uint bid [[threadgroup_position_in_grid]],
uint tpg [[threads_per_threadgroup]]) {
const uint actSize = min(N - k * NB, NB); // uint64 before NB
const uint batch_offset = bid * N * N;

const uint row0 = k * NB;
const uint col0 = k * NB;

threadgroup float tile[32][33];
threadgroup float reduceScratch[256];
const uint tileSize = actSize * actSize;

for (uint i = tid; i < tileSize; i += tpg) {
uint r = i / actSize;
uint c = i % actSize;
tile[r][c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
}
threadgroup_barrier(mem_flags::mem_threadgroup);

for (uint kk = 0; kk < actSize; kk++) {
float diagElt = 0.0f;
if (kk > 0) {
float partialSum = 0.0f;
for (uint i = tid; i < kk; i += tpg) {
float val = tile[kk][i];
partialSum = fma(val, val, partialSum);
}
diagElt = blockReduceSum(reduceScratch, partialSum, tid, tpg);
}

if (tid == 0) {
float diagVal = tile[kk][kk] - diagElt;
// Check for positive definiteness
if (diagVal <= 0.0f) {
success[bid] = 0; // matrix is not positive definite
return;
}
tile[kk][kk] = sqrt(diagVal);
}
threadgroup_barrier(mem_flags::mem_threadgroup);

float pivot = tile[kk][kk];

for (uint j = kk + 1 + tid; j < actSize; j += tpg) {
float partialSum = 0.0f;
for (uint i = 0; i < kk; i++) {
partialSum = fma(tile[j][i], tile[kk][i], partialSum);
}

float val = tile[j][kk];
val -= partialSum;
val /= pivot;
tile[j][kk] = val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}

for (uint i = tid; i < tileSize; i += tpg) {
uint r = i / actSize;
uint c = i % actSize;
A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r][c];
}
}

kernel void applyTRSM(
device float* A [[buffer(0)]],
constant uint& N [[buffer(2)]],
constant uint& NB [[buffer(3)]],
constant uint& k [[buffer(4)]],
uint3 tid [[thread_position_in_threadgroup]],
uint3 tgid [[threadgroup_position_in_grid]],
uint3 tpg [[threads_per_threadgroup]]) {
uint b = tgid.x;
uint idxJ = tgid.y;

const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
const uint batch_offset = b * N * N;
const uint j = (k + 1) + idxJ;

uint row0 = j * NB;
uint col0 = k * NB;

uint actSize_j = (uint)min((int)(N - row0), (int)NB);
if (actSize_k == 0 || actSize_j == 0) {
return;
}
if (j == k) {
return;
}

threadgroup float diag[32 * 32];
threadgroup float target[32 * 32];

for (uint i = tid.x; i < actSize_k * actSize_k; i += tpg.x) {
uint r = i / actSize_k;
uint c = i % actSize_k;
diag[i] = A[batch_offset + (k * NB + r) * N + (k * NB + c)];
}
for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
uint r = i / actSize_k;
uint c = i % actSize_k;
target[i] = A[batch_offset + (row0 + r) * N + (col0 + c)];
}
threadgroup_barrier(mem_flags::mem_threadgroup);

for (uint col = 0; col < actSize_k; col++) {
float diag_val = diag[col * actSize_k + col];
if (abs(diag_val) < 1e-6f) {
diag_val = (diag_val < 0.0f) ? -1e-6f : 1e-6f;
}

for (uint row = tid.x; row < actSize_j; row += tpg.x) {
float sum = target[row * actSize_k + col];

// kahan sum
float c = 0.0f;
for (uint p = 0; p < col; p++) {
float y = -target[row * actSize_k + p] * diag[col * actSize_k + p] - c;
float t = sum + y;
c = (t - sum) - y;
sum = t;
}

target[row * actSize_k + col] = sum / diag_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}

for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
uint r = i / actSize_k;
uint c = i % actSize_k;
A[batch_offset + (row0 + r) * N + (col0 + c)] = target[i];
}
}

kernel void applySYRK(
device float* A [[buffer(0)]],
constant uint& N [[buffer(2)]],
constant uint& NB [[buffer(3)]],
constant uint& k [[buffer(4)]],
uint3 tid [[thread_position_in_threadgroup]],
uint3 tgid [[threadgroup_position_in_grid]],
uint3 tpg [[threads_per_threadgroup]]) {
uint b = tgid.x;
uint pairID = tgid.y;

uint jRel = (-1 + sqrt(1 + 8 * float(pairID))) / 2;
uint hRel = pairID - (jRel * (jRel + 1) >> 1);

const uint startJ = (k + 1);
uint j = startJ + jRel;
uint h = startJ + hRel;
uint row0 = j * NB;
uint col0 = h * NB;

const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
const uint actSize_j = min((uint)(N - row0), NB);
const uint actSize_h = min((uint)(N - col0), NB);
const uint batch_offset = b * N * N;

if (actSize_j == 0 || actSize_h == 0 || actSize_k == 0)
return;

threadgroup float left[32 * 33];
threadgroup float right_t[32 * 33];
threadgroup float tile[32 * 33];

const uint threads = min(tpg.x, actSize_j * actSize_k);

for (uint i = tid.x; i < actSize_j * actSize_k; i += threads) {
uint r = i / actSize_k;
uint c = i % actSize_k;
left[r * actSize_k + c] = A[batch_offset + (j * NB + r) * N + (k * NB + c)];
}

for (uint i = tid.x; i < actSize_h * actSize_k; i += threads) {
uint r = i / actSize_k;
uint c = i % actSize_k;
right_t[c * actSize_h + r] =
A[batch_offset + (h * NB + r) * N + (k * NB + c)];
}

for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
uint r = i / actSize_h;
uint c = i % actSize_h;
tile[r * actSize_h + c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
}

threadgroup_barrier(mem_flags::mem_threadgroup);

for (uint idx = tid.x; idx < actSize_j * actSize_h; idx += threads) {
uint r = idx / actSize_h;
uint c = idx % actSize_h;

if ((j == h) && (r < c))
continue;

uint tile_idx = r * actSize_h + c;
float sum = tile[tile_idx];

uint left_row = r * actSize_k;
uint right_col = c;

uint k = 0;
float4 sum4 = {0.0f, 0.0f, 0.0f, 0.0f};

for (; k + 4 <= actSize_k; k += 4) {
float4 left4 = {
left[left_row + k],
left[left_row + k + 1],
left[left_row + k + 2],
left[left_row + k + 3]};

float4 right4 = {
right_t[(k + 0) * actSize_h + right_col],
right_t[(k + 1) * actSize_h + right_col],
right_t[(k + 2) * actSize_h + right_col],
right_t[(k + 3) * actSize_h + right_col]};

sum4 = fma(left4, right4, sum4);
}

sum -= dot(sum4, 1.0);

for (; k < actSize_k; k++) {
sum = fma(-left[left_row + k], right_t[k * actSize_h + right_col], sum);
}

tile[tile_idx] = sum;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
uint r = i / actSize_h;
uint c = i % actSize_h;
A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r * actSize_h + c];
}
}

#define INSTANTIATE_NAIVE_MM(DTYPE) \
template [[host_name("naive_matmul_" #DTYPE)]] kernel void \
naive_matmul<DTYPE>( \
Expand Down
98 changes: 98 additions & 0 deletions aten/src/ATen/native/mps/operations/LinearAlgebra.mm
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <ATen/ops/addr_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/cholesky_native.h>
#include <ATen/ops/linalg_cholesky_native.h>
#include <ATen/ops/linalg_lu_factor_native.h>
#include <ATen/ops/linalg_solve_triangular_native.h>
#include <ATen/ops/mm_native.h>
Expand Down Expand Up @@ -780,6 +782,83 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
return out;
}

static Tensor& linalg_cholesky_mps_impl(const Tensor& input, bool upper, Tensor& out) {
using namespace mps;

TORCH_CHECK(out.is_mps());
TORCH_CHECK(input.scalar_type() == at::ScalarType::Float, "linalg.cholesky: Input tensor must be float32");
TORCH_CHECK(input.dim() >= 2, "linalg.cholesky: Input tensor must be at least 2D");
TORCH_CHECK(input.size(-2) == input.size(-1), "linalg.cholesky: Input tensor must be square");

if (input.numel() == 0 || out.numel() == 0) {
out.zero_();
return out;
}
resize_output(out, input.sizes());
out.copy_(input);

int64_t ndim = out.dim();
int64_t N = out.size(-1);
int64_t B = 1;
for (int64_t i = 0; i < ndim - 2; i++) {
B *= out.size(i);
}

auto stream = getCurrentMPSStream();
auto device = MPSDevice::getInstance()->device();

auto factorDiagonalPSO = lib.getPipelineStateForFunc("factorDiagonalBlock");
auto applyTRSMPSO = lib.getPipelineStateForFunc("applyTRSM");
auto applySYRKPSO = lib.getPipelineStateForFunc("applySYRK");

int64_t NB = std::min<int64_t>(32, N);
int64_t numBlocks = (N + NB - 1) / NB;

Tensor success = at::empty({B}, input.options().dtype(kInt)).fill_(1);
id<MTLBuffer> successBuffer = getMTLBufferStorage(success);

MTLSize threadGroupSize = MTLSizeMake(256, 1, 1);
id<MTLBuffer> outBuffer = getMTLBufferStorage(out);
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
[computeEncoder setBuffer:outBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&N length:sizeof(int64_t) atIndex:2];
[computeEncoder setBytes:&NB length:sizeof(int64_t) atIndex:3];

@autoreleasepool {
dispatch_sync_with_rethrow(stream->queue(), ^() {
for (int64_t k = 0; k < numBlocks; k++) {
[computeEncoder setComputePipelineState:factorDiagonalPSO];
[computeEncoder setBuffer:successBuffer offset:0 atIndex:1];
[computeEncoder setBytes:&k length:sizeof(int64_t) atIndex:4];
MTLSize gridSize = MTLSizeMake(B, 1, 1);
[computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];

// process all remaining blocks in this row/column in parallel
if (k < numBlocks - 1) {
int64_t startJ = k + 1;
int64_t nBlocksJ = (numBlocks - startJ);

if (nBlocksJ > 0) {
// TRSM for all blocks in parallel
MTLSize trsmGridSize = MTLSizeMake(B, nBlocksJ, 1);
[computeEncoder setComputePipelineState:applyTRSMPSO];
[computeEncoder dispatchThreadgroups:trsmGridSize threadsPerThreadgroup:threadGroupSize];

// SYRK for all independent block pairs in parallel
uint32_t nPairs = nBlocksJ * (nBlocksJ + 1) / 2;
MTLSize syrkGridSize = MTLSizeMake(B, nPairs, 1);
[computeEncoder setComputePipelineState:applySYRKPSO];
[computeEncoder dispatchThreadgroups:syrkGridSize threadsPerThreadgroup:threadGroupSize];
}
}
}
});
}

TORCH_CHECK(success.all().item<bool>(), "linalg.cholesky: Input matrix is not positive definite");
out.tril_(); //
return upper ? out.transpose_(ndim - 2, ndim - 1) : out;
}
} // namespace mps

Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
Expand Down Expand Up @@ -940,6 +1019,25 @@ Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, cons
return result;
}

Tensor cholesky_mps(const Tensor& self, bool upper) {
auto out = at::empty_like(self, MemoryFormat::Contiguous);
mps::linalg_cholesky_mps_impl(self, upper, out);
return out;
}

Tensor& cholesky_mps_out(const Tensor& self, bool upper, Tensor& out) {
return mps::linalg_cholesky_mps_impl(self, upper, out);
}

Tensor& linalg_cholesky_out_mps(const Tensor& self, bool upper, Tensor& out) {
return mps::linalg_cholesky_mps_impl(self, upper, out);
}

Tensor linalg_cholesky_mps(const Tensor& self, bool upper) {
auto out = at::empty_like(self, MemoryFormat::Contiguous);
return mps::linalg_cholesky_mps_impl(self, upper, out);
}

Tensor addbmm_mps(const Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
Expand Down
Loading

0 comments on commit 727ae13

Please sign in to comment.