From 20b2d6cb3e71b913ed5d05ece2aaa1d3f8ac4eab Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Sun, 31 Mar 2024 17:54:02 -0700 Subject: [PATCH] Improve cpu prompt eval speed This change upstreams llamafile's cpu matrix multiplication kernels which improve image and prompt evaluation speed. For starters, Q4_0 and Q8_0 weights should go ~40% faster on CPU. The biggest benefits are with data types like f16 / f32, which process prompts 2x faster thus making them faster than quantized data types for prompt evals. This change also introduces bona fide AVX512 support since tinyBLAS is able to exploit the larger register file. For example, on my CPU llama.cpp llava-cli processes an image prompt at 305 tokens/second, using the Q4_K and Q4_0 types, which has always been faster than if we used f16 LLaVA weights, which at HEAD go 188 tokens/second. With this change, f16 LLaVA performance leap frogs to 464 tokens/second. On Intel Core i9-14900K this change improves F16 prompt perf by 5x. For example, using llama.cpp at HEAD with Mistral 7b f16 to process a 215 token prompt will go 13 tok/sec. This change has fixes making it go 52 tok/sec. It's mostly thanks to my vectorized outer product kernels but also because I added support for correctly counting the number of cores on Alderlake, so the default thread count discounts Intel's new efficiency cores. Only Linux right now can count cores. This work was sponsored by Mozilla who's given permission to change the license of this code from Apache 2.0 to MIT. To read more about what's improved, and how it works, see: https://justine.lol/matmul/ --- CMakeLists.txt | 2 + Makefile | 5 +- Package.swift | 1 + build.zig | 15 +- common/common.cpp | 69 +++ common/common.h | 3 +- ggml.c | 44 ++ scripts/sync-ggml-am.sh | 4 + sgemm.cpp | 1173 +++++++++++++++++++++++++++++++++++++++ sgemm.h | 12 + 10 files changed, 1319 insertions(+), 9 deletions(-) create mode 100644 sgemm.cpp create mode 100644 sgemm.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 19fdfa46ca4f15..158174c2094d25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1151,6 +1151,8 @@ add_library(ggml OBJECT ggml-backend.h ggml-quants.c ggml-quants.h + sgemm.cpp + sgemm.h ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL} ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} diff --git a/Makefile b/Makefile index ebbbcd354c010d..79d478af62cc7b 100644 --- a/Makefile +++ b/Makefile @@ -676,13 +676,16 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h ggml-common.h $(CC) $(CFLAGS) -c $< -o $@ +sgemm.o: sgemm.cpp sgemm.h ggml.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + unicode.o: unicode.cpp unicode.h $(CXX) $(CXXFLAGS) -c $< -o $@ unicode-data.o: unicode-data.cpp unicode-data.h $(CXX) $(CXXFLAGS) -c $< -o $@ -OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o +OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o sgemm.o llama.o: llama.cpp unicode.h ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h $(CXX) $(CXXFLAGS) -c $< -o $@ diff --git a/Package.swift b/Package.swift index 8b7195869b7874..bcdcafe48543b3 100644 --- a/Package.swift +++ b/Package.swift @@ -30,6 +30,7 @@ let package = Package( ], sources: [ "ggml.c", + "sgemm.cpp", "llama.cpp", "unicode.cpp", "unicode-data.cpp", diff --git a/build.zig b/build.zig index 7f36e596888da1..ad10b4aa60c109 100644 --- a/build.zig +++ b/build.zig @@ -112,6 +112,7 @@ pub fn build(b: *std.build.Builder) !void { make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false; const ggml = make.obj("ggml", "ggml.c"); + const sgemm = make.obj("sgemm", "sgemm.cpp"); const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c"); const ggml_backend = make.obj("ggml-backend", "ggml-backend.c"); const ggml_quants = make.obj("ggml-quants", "ggml-quants.c"); @@ -128,14 +129,14 @@ pub fn build(b: *std.build.Builder) !void { const clip = make.obj("clip", "examples/llava/clip.cpp"); const llava = make.obj("llava", "examples/llava/llava.cpp"); - _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, sampling, console, grammar_parser }); - _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo }); - _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo }); - _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo }); - _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, train }); - _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, train }); + _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, sampling, console, grammar_parser }); + _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo }); + _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo }); + _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo }); + _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, train }); + _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, train }); - const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, sampling, grammar_parser, json_schema_to_grammar, clip, llava }); + const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, sampling, grammar_parser, json_schema_to_grammar, clip, llava }); if (server.target.isWindows()) { server.linkSystemLibrary("ws2_32"); } diff --git a/common/common.cpp b/common/common.cpp index b8323c1c1a90b1..4864c2ecb41750 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -104,6 +104,75 @@ int32_t get_num_physical_cores() { return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } +#if defined(__x86_64__) && defined(__linux__) +#include + +static void cpuid(unsigned leaf, unsigned subleaf, + unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) { + __asm__("movq\t%%rbx,%%rsi\n\t" + "cpuid\n\t" + "xchgq\t%%rbx,%%rsi" + : "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx) + : "0"(leaf), "2"(subleaf)); +} + +static int pin_cpu(int cpu) { + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(cpu, &mask); + return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask); +} + +static bool is_hybrid_cpu(void) { + unsigned eax, ebx, ecx, edx; + cpuid(7, 0, &eax, &ebx, &ecx, &edx); + return !!(edx & (1u << 15)); +} + +static bool is_running_on_efficiency_core(void) { + unsigned eax, ebx, ecx, edx; + cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx); + int intel_atom = 0x20; + int core_type = (eax & 0xff000000u) >> 24; + return core_type == intel_atom; +} + +static int count_math_cpus(int cpu_count) { + int result = 0; + for (int cpu = 0; cpu < cpu_count; ++cpu) { + if (pin_cpu(cpu)) + return -1; + if (is_running_on_efficiency_core()) + continue; // efficiency cores harm lockstep threading + ++cpu; // hyperthreading isn't useful for linear algebra + ++result; + } + return result; +} + +#endif // __x86_64__ && __linux__ + +/** + * Returns number of CPUs on system that are useful for math. + */ +int get_math_cpu_count() { +#if defined(__x86_64__) && defined(__linux__) + int cpu_count = sysconf(_SC_NPROCESSORS_ONLN); + if (cpu_count < 1) + return get_num_physical_cores(); + if (is_hybrid_cpu()) { + cpu_set_t affinity; + if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) { + int result = count_math_cpus(cpu_count); + pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity); + if (result > 0) + return result; + } + } +#endif + return get_num_physical_cores(); +} + void process_escapes(std::string & input) { std::size_t input_len = input.length(); std::size_t output_idx = 0; diff --git a/common/common.h b/common/common.h index 99ee90bc3c7282..c03de15f76787c 100644 --- a/common/common.h +++ b/common/common.h @@ -39,6 +39,7 @@ extern char const *LLAMA_BUILD_TARGET; struct llama_control_vector_load_info; +int get_math_cpu_count(); int32_t get_num_physical_cores(); // @@ -48,7 +49,7 @@ int32_t get_num_physical_cores(); struct gpt_params { uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed - int32_t n_threads = get_num_physical_cores(); + int32_t n_threads = get_math_cpu_count(); int32_t n_threads_draft = -1; int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_threads_batch_draft = -1; diff --git a/ggml.c b/ggml.c index 7471e792606c15..5be24bfe817b9a 100644 --- a/ggml.c +++ b/ggml.c @@ -4,6 +4,7 @@ #include "ggml-impl.h" #include "ggml-quants.h" #include "ggml.h" +#include "sgemm.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -10817,6 +10818,27 @@ static void ggml_compute_forward_mul_mat( } #endif + if (src1_cont) { + for (int64_t j = 0; j < ne13; j++) + for (int64_t i = 0; i < ne12; i++) + if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), + (const char *)src0->data + i/r2*nb02 + j/r3*nb03, + nb01/ggml_type_size(src0->type), + (const char *)src1->data + i*nb12 + j*nb13, + nb11/ggml_type_size(src1->type), + (char *)dst->data + i*nb2 + j*nb3, + nb1/ggml_type_size(dst->type), + ith, nth, + params->type, + src0->type, + src1->type, + dst->type)) + goto UseGgmlGemm1; + return; + } +UseGgmlGemm1: + (void)0; + if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { return; @@ -10848,6 +10870,28 @@ static void ggml_compute_forward_mul_mat( const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); + if (src1_cont) { + for (int64_t j = 0; j < ne13; j++) + for (int64_t i = 0; i < ne12; i++) + if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), + (const char *)src0->data + i/r2*nb02 + j/r3*nb03, + nb01/ggml_type_size(src0->type), + (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i + + nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*j), + row_size/ggml_type_size(vec_dot_type), + (char *)dst->data + i*nb2 + j*nb3, + nb1/ggml_type_size(dst->type), + ith, nth, + params->type, + src0->type, + vec_dot_type, + dst->type)) + goto UseGgmlGemm2; + return; + } +UseGgmlGemm2: + (void)0; + const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = ne1*ne12*ne13; // src1 rows diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index 3003290f6c2fb4..2d7c05ea4e30b6 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -113,6 +113,8 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then # src/ggml-sycl.h -> ggml-sycl.h # src/ggml-vulkan.cpp -> ggml-vulkan.cpp # src/ggml-vulkan.h -> ggml-vulkan.h + # src/sgemm.cpp -> sgemm.cpp + # src/sgemm.h -> sgemm.h # include/ggml/ggml.h -> ggml.h # include/ggml/ggml-alloc.h -> ggml-alloc.h # include/ggml/ggml-backend.h -> ggml-backend.h @@ -147,6 +149,8 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then -e 's/src\/ggml-sycl\.h/ggml-sycl.h/g' \ -e 's/src\/ggml-vulkan\.cpp/ggml-vulkan.cpp/g' \ -e 's/src\/ggml-vulkan\.h/ggml-vulkan.h/g' \ + -e 's/src\/sgemm\.cpp/sgemm.cpp/g' \ + -e 's/src\/sgemm\.h/sgemm.h/g' \ -e 's/include\/ggml\/ggml\.h/ggml.h/g' \ -e 's/include\/ggml\/ggml-alloc\.h/ggml-alloc.h/g' \ -e 's/include\/ggml\/ggml-backend\.h/ggml-backend.h/g' \ diff --git a/sgemm.cpp b/sgemm.cpp new file mode 100644 index 00000000000000..a6e0c155cf26aa --- /dev/null +++ b/sgemm.cpp @@ -0,0 +1,1173 @@ +// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- +// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi +// +// Copyright 2024 Mozilla Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// +// _ _ ___ _ _ ___ +// | |_(_)_ _ _ _| _ ) | /_\ / __| +// | _| | ' \ || | _ \ |__ / _ \\__ \. +// \__|_|_||_\_, |___/____/_/ \_\___/ +// |__/ +// +// BASIC LINEAR ALGEBRA SUBPROGRAMS +// +// +// This file implements multithreaded CPU matrix multiplication for the +// common contiguous use case C = Aᵀ * B. These kernels are designed to +// have excellent performance[1] for matrices that fit in the CPU cache +// without imposing any overhead such as cache filling or malloc calls. +// +// This implementation does not guarantee any upper bound with rounding +// errors, which grow along with k. Our goal's to maximally exploit the +// hardware for performance, and then use whatever resources remain for +// improving numerical accuracy. +// +// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, 2024. [Online]. +// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. + +#pragma GCC diagnostic ignored "-Wpedantic" +#pragma GCC diagnostic ignored "-Wignored-attributes" + +#include "sgemm.h" +#include +#ifdef __x86_64__ +#include +#endif +#ifdef __ARM_NEON +#include +#endif +#include "ggml.h" +#include "ggml-impl.h" +#include "ggml-quants.h" + +#ifdef _MSC_VER +#define NOINLINE __declspec(noinline) +#else +#define NOINLINE __attribute__((__noinline__)) +#endif + +#if defined(__ARM_NEON) || defined(__AVX512F__) +#define VECTOR_REGISTERS 32 +#else +#define VECTOR_REGISTERS 16 +#endif + +// there will be blocks +#define BEGIN_KERNEL(RM, RN) \ + int ytiles = (m - m0) / RM; \ + int xtiles = (n - n0) / RN; \ + int tiles = ytiles * xtiles; \ + int duty = (tiles + nth - 1) / nth; \ + if (duty < 1) \ + duty = 1; \ + int start = duty * ith; \ + int end = start + duty; \ + if (end > tiles) \ + end = tiles; \ + for (int job = start; job < end; ++job) { \ + int i = m0 + job / xtiles * RM; \ + int j = n0 + job % xtiles * RN; + +#define END_KERNEL() } + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +namespace { + +typedef ggml_fp16_t half; + +inline float unhalf(half d) { + return GGML_FP16_TO_FP32(d); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// vectorized fused multiply add + +#ifdef __ARM_NEON +inline float32x4_t madd(float32x4_t x, float32x4_t y, float32x4_t z) { + return vaddq_f32(vmulq_f32(x, y), z); +} +#endif // __ARM_NEON + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +inline float16x8_t madd(float16x8_t x, float16x8_t y, float16x8_t z) { + return vaddq_f16(vmulq_f16(x, y), z); +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline __m128 madd(__m128 x, __m128 y, __m128 z) { + return _mm_add_ps(_mm_mul_ps(x, y), z); +} +#endif // __SSE__ + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline __m256 madd(__m256 x, __m256 y, __m256 z) { + return _mm256_add_ps(_mm256_mul_ps(x, y), z); +} +#endif // __AVX__ + +#if defined(__AVX512F__) +inline __m512 madd(__m512 x, __m512 y, __m512 z) { + return _mm512_add_ps(_mm512_mul_ps(x, y), z); +} +#endif // __AVX512F__ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// vectorized horizontal sum + +#if defined(__ARM_NEON) +inline float hsum(float32x4_t x) { + return vaddvq_f32(x); +} +#endif // __ARM_NEON + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +inline float hsum(float16x8_t x) { + float32x4_t t = vcvt_f32_f16(vget_low_f16(x)); + float32x4_t u = vcvt_f32_f16(vget_high_f16(x)); + return vaddvq_f32(vaddq_f32(t, u)); +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline float hsum(__m128 x) { + x = _mm_add_ps(x, _mm_movehl_ps(x, x)); + x = _mm_add_ss(x, _mm_movehdup_ps(x)); + return _mm_cvtss_f32(x); +} +#endif // __SSE__ + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline float hsum(__m256 x) { + return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x))); +} +#endif // __AVX__ + +#if defined(__AVX512F__) +inline float hsum(__m512 x) { + return _mm512_reduce_add_ps(x); +} +#endif // __AVX512F__ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// vectorized memory loading + +template T load(const U *); + +#if defined(__ARM_NEON) +template <> inline float32x4_t load(const float *p) { + return vld1q_f32(p); +} +template <> inline float16x8_t load(const half *p) { + return vld1q_f16((const __fp16 *)p); +} +template <> inline float32x4_t load(const half *p) { + return vcvt_f32_f16(vld1_f16((const __fp16 *)p));; +} +#endif // __ARM_NEON + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> inline __m128 load(const float *p) { + return _mm_loadu_ps(p); +} +#endif // __SSE__ + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> inline __m256 load(const float *p) { + return _mm256_loadu_ps(p); +} +#endif // __AVX__ + +#if defined(__F16C__) +template <> inline __m256 load(const half *p) { + return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p)); +} +#endif // __F16C__ + +#if defined(__AVX512F__) +template <> inline __m512 load(const float *p) { + return _mm512_loadu_ps(p); +} +template <> inline __m512 load(const half *p) { + return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p)); +} +#endif // __AVX512F__ + +////////////////////////////////////////////////////////////////////////////////////////// +// Floating Point Matrix Multiplication + +template +class tinyBLAS { + public: + tinyBLAS(int k, + const TA *A, int lda, + const TB *B, int ldb, + TC *C, int ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int m, int n) { + mnpack(0, m, 0, n); + } + + private: + NOINLINE void mnpack(int m0, int m, int n0, int n) { + if (m - m0 <= 0 || n - n0 <= 0) + return; + int mc, nc, mp, np; + if (VECTOR_REGISTERS >= 32 && n - n0 >= 5 && m - m0 >= 5) { + mc = 5; + nc = 5; + gemm5x5(m0, m, n0, n); + } else if (n - n0 >= 4 && m - m0 >= 3) { + mc = 3; + nc = 4; + gemm3x4(m0, m, n0, n); + } else if (n - n0 >= 4) { + mc = 1; + nc = 4; + gemm1x4(m0, m, n0, n); + } else if (m - m0 >= 4) { + mc = 4; + nc = 1; + gemm4x1(m0, m, n0, n); + } else { + mc = 1; + nc = 1; + gemm1x1(m0, m, n0, n); + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, mp, np, n); + mnpack(mp, m, np, n); + } + + NOINLINE void gemm5x5(int m0, int m, int n0, int n) { + BEGIN_KERNEL(5, 5) + V c00 = {0}; + V c01 = {0}; + V c02 = {0}; + V c03 = {0}; + V c04 = {0}; + V c10 = {0}; + V c11 = {0}; + V c12 = {0}; + V c13 = {0}; + V c14 = {0}; + V c20 = {0}; + V c21 = {0}; + V c22 = {0}; + V c23 = {0}; + V c24 = {0}; + V c30 = {0}; + V c31 = {0}; + V c32 = {0}; + V c33 = {0}; + V c34 = {0}; + V c40 = {0}; + V c41 = {0}; + V c42 = {0}; + V c43 = {0}; + V c44 = {0}; + for (int l = 0; l < k; l += KN) { + V k0 = load(B + ldb * (j + 0) + l); + V k1 = load(B + ldb * (j + 1) + l); + V k2 = load(B + ldb * (j + 2) + l); + V k3 = load(B + ldb * (j + 3) + l); + V k4 = load(B + ldb * (j + 4) + l); + V a0 = load(A + lda * (i + 0) + l); + c00 = madd(a0, k0, c00); + c01 = madd(a0, k1, c01); + c02 = madd(a0, k2, c02); + c03 = madd(a0, k3, c03); + c04 = madd(a0, k4, c04); + V a1 = load(A + lda * (i + 1) + l); + c10 = madd(a1, k0, c10); + c11 = madd(a1, k1, c11); + c12 = madd(a1, k2, c12); + c13 = madd(a1, k3, c13); + c14 = madd(a1, k4, c14); + V a2 = load(A + lda * (i + 2) + l); + c20 = madd(a2, k0, c20); + c21 = madd(a2, k1, c21); + c22 = madd(a2, k2, c22); + c23 = madd(a2, k3, c23); + c24 = madd(a2, k4, c24); + V a3 = load(A + lda * (i + 3) + l); + c30 = madd(a3, k0, c30); + c31 = madd(a3, k1, c31); + c32 = madd(a3, k2, c32); + c33 = madd(a3, k3, c33); + c34 = madd(a3, k4, c34); + V a4 = load(A + lda * (i + 4) + l); + c40 = madd(a4, k0, c40); + c41 = madd(a4, k1, c41); + c42 = madd(a4, k2, c42); + c43 = madd(a4, k3, c43); + c44 = madd(a4, k4, c44); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00); + C[ldc * (j + 0) + (i + 1)] = hsum(c10); + C[ldc * (j + 0) + (i + 2)] = hsum(c20); + C[ldc * (j + 0) + (i + 3)] = hsum(c30); + C[ldc * (j + 0) + (i + 4)] = hsum(c40); + C[ldc * (j + 1) + (i + 0)] = hsum(c01); + C[ldc * (j + 1) + (i + 1)] = hsum(c11); + C[ldc * (j + 1) + (i + 2)] = hsum(c21); + C[ldc * (j + 1) + (i + 3)] = hsum(c31); + C[ldc * (j + 1) + (i + 4)] = hsum(c41); + C[ldc * (j + 2) + (i + 0)] = hsum(c02); + C[ldc * (j + 2) + (i + 1)] = hsum(c12); + C[ldc * (j + 2) + (i + 2)] = hsum(c22); + C[ldc * (j + 2) + (i + 3)] = hsum(c32); + C[ldc * (j + 2) + (i + 4)] = hsum(c42); + C[ldc * (j + 3) + (i + 0)] = hsum(c03); + C[ldc * (j + 3) + (i + 1)] = hsum(c13); + C[ldc * (j + 3) + (i + 2)] = hsum(c23); + C[ldc * (j + 3) + (i + 3)] = hsum(c33); + C[ldc * (j + 3) + (i + 4)] = hsum(c43); + C[ldc * (j + 4) + (i + 0)] = hsum(c04); + C[ldc * (j + 4) + (i + 1)] = hsum(c14); + C[ldc * (j + 4) + (i + 2)] = hsum(c24); + C[ldc * (j + 4) + (i + 3)] = hsum(c34); + C[ldc * (j + 4) + (i + 4)] = hsum(c44); + END_KERNEL() + } + + NOINLINE void gemm3x4(int m0, int m, int n0, int n) { + BEGIN_KERNEL(3, 4) + V c00 = {0}; + V c01 = {0}; + V c02 = {0}; + V c03 = {0}; + V c10 = {0}; + V c11 = {0}; + V c12 = {0}; + V c13 = {0}; + V c20 = {0}; + V c21 = {0}; + V c22 = {0}; + V c23 = {0}; + for (int l = 0; l < k; l += KN) { + V k0 = load(B + ldb * (j + 0) + l); + V k1 = load(B + ldb * (j + 1) + l); + V k2 = load(B + ldb * (j + 2) + l); + V k3 = load(B + ldb * (j + 3) + l); + V a0 = load(A + lda * (i + 0) + l); + c00 = madd(a0, k0, c00); + c01 = madd(a0, k1, c01); + c02 = madd(a0, k2, c02); + c03 = madd(a0, k3, c03); + V a1 = load(A + lda * (i + 1) + l); + c10 = madd(a1, k0, c10); + c11 = madd(a1, k1, c11); + c12 = madd(a1, k2, c12); + c13 = madd(a1, k3, c13); + V a2 = load(A + lda * (i + 2) + l); + c20 = madd(a2, k0, c20); + c21 = madd(a2, k1, c21); + c22 = madd(a2, k2, c22); + c23 = madd(a2, k3, c23); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00); + C[ldc * (j + 0) + (i + 1)] = hsum(c10); + C[ldc * (j + 0) + (i + 2)] = hsum(c20); + C[ldc * (j + 1) + (i + 0)] = hsum(c01); + C[ldc * (j + 1) + (i + 1)] = hsum(c11); + C[ldc * (j + 1) + (i + 2)] = hsum(c21); + C[ldc * (j + 2) + (i + 0)] = hsum(c02); + C[ldc * (j + 2) + (i + 1)] = hsum(c12); + C[ldc * (j + 2) + (i + 2)] = hsum(c22); + C[ldc * (j + 3) + (i + 0)] = hsum(c03); + C[ldc * (j + 3) + (i + 1)] = hsum(c13); + C[ldc * (j + 3) + (i + 2)] = hsum(c23); + END_KERNEL() + } + + NOINLINE void gemm1x4(int m0, int m, int n0, int n) { + BEGIN_KERNEL(1, 4) + V c00 = {0}; + V c01 = {0}; + V c02 = {0}; + V c03 = {0}; + for (int l = 0; l < k; l += KN) { + V a0 = load(A + lda * (i + 0) + l); + V k0 = load(B + ldb * (j + 0) + l); + V k1 = load(B + ldb * (j + 1) + l); + V k2 = load(B + ldb * (j + 2) + l); + V k3 = load(B + ldb * (j + 3) + l); + c00 = madd(a0, k0, c00); + c01 = madd(a0, k1, c01); + c02 = madd(a0, k2, c02); + c03 = madd(a0, k3, c03); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00); + C[ldc * (j + 1) + (i + 0)] = hsum(c01); + C[ldc * (j + 2) + (i + 0)] = hsum(c02); + C[ldc * (j + 3) + (i + 0)] = hsum(c03); + END_KERNEL() + } + + NOINLINE void gemm4x1(int m0, int m, int n0, int n) { + BEGIN_KERNEL(4, 1) + V c00 = {0}; + V c01 = {0}; + V c02 = {0}; + V c10 = {0}; + V c11 = {0}; + V c12 = {0}; + V c20 = {0}; + V c21 = {0}; + V c22 = {0}; + V c30 = {0}; + V c31 = {0}; + V c32 = {0}; + int l = 0; + while (l + KN * 3 <= k) { + { + V k0 = load(B + ldb * (j + 0) + l); + c00 = madd(load(A + lda * (i + 0) + l), k0, c00); + c10 = madd(load(A + lda * (i + 1) + l), k0, c10); + c20 = madd(load(A + lda * (i + 2) + l), k0, c20); + c30 = madd(load(A + lda * (i + 3) + l), k0, c30); + } + l += KN; + { + V k0 = load(B + ldb * (j + 0) + l); + c01 = madd(load(A + lda * (i + 0) + l), k0, c01); + c11 = madd(load(A + lda * (i + 1) + l), k0, c11); + c21 = madd(load(A + lda * (i + 2) + l), k0, c21); + c31 = madd(load(A + lda * (i + 3) + l), k0, c31); + } + l += KN; + { + V k0 = load(B + ldb * (j + 0) + l); + c02 = madd(load(A + lda * (i + 0) + l), k0, c02); + c12 = madd(load(A + lda * (i + 1) + l), k0, c12); + c22 = madd(load(A + lda * (i + 2) + l), k0, c22); + c32 = madd(load(A + lda * (i + 3) + l), k0, c32); + } + l += KN; + } + for (; l < k; l += KN) { + V k0 = load(B + ldb * (j + 0) + l); + c00 = madd(load(A + lda * (i + 0) + l), k0, c00); + c10 = madd(load(A + lda * (i + 1) + l), k0, c10); + c20 = madd(load(A + lda * (i + 2) + l), k0, c20); + c30 = madd(load(A + lda * (i + 3) + l), k0, c30); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00) + hsum(c01) + hsum(c02); + C[ldc * (j + 0) + (i + 1)] = hsum(c10) + hsum(c11) + hsum(c12); + C[ldc * (j + 0) + (i + 2)] = hsum(c20) + hsum(c21) + hsum(c22); + C[ldc * (j + 0) + (i + 3)] = hsum(c30) + hsum(c31) + hsum(c32); + END_KERNEL() + } + + NOINLINE void gemm1x1(int m0, int m, int n0, int n) { + BEGIN_KERNEL(1, 1) + V c0 = {0}; + V c1 = {0}; + V c2 = {0}; + V c3 = {0}; + int l = 0; + while (l + KN * 4 <= k) { + c0 = madd(load(A + lda * i + l), load(B + ldb * j + l), c0); + l += KN; + c1 = madd(load(A + lda * i + l), load(B + ldb * j + l), c1); + l += KN; + c2 = madd(load(A + lda * i + l), load(B + ldb * j + l), c2); + l += KN; + c3 = madd(load(A + lda * i + l), load(B + ldb * j + l), c3); + l += KN; + } + for (; l < k; l += KN) + c0 = madd(load(A + lda * i + l), load(B + ldb * j + l), c0); + C[ldc * j + i] = (hsum(c0) + hsum(c1)) + (hsum(c2) + hsum(c3)); + END_KERNEL() + } + + const TA *const A; + const TB *const B; + TC *const C; + const int k; + const int lda; + const int ldb; + const int ldc; + const int ith; + const int nth; +}; + +////////////////////////////////////////////////////////////////////////////////////////// +// quant zero matrix multiplication + +#ifdef __ARM_FEATURE_DOTPROD +class tinyBLAS_Q0_ARM { + public: + tinyBLAS_Q0_ARM(int k, + const block_q8_0 *A, int lda, + const block_q8_0 *B, int ldb, + float *C, int ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int m, int n) { + mnpack(0, m, 0, n); + } + + private: + NOINLINE void mnpack(int m0, int m, int n0, int n) { + if (m - m0 <= 0 || n - n0 <= 0) + return; + int mc, nc, mp, np; + if (m - m0 >= 3 && n - n0 >= 3) { + mc = 3; + nc = 3; + gemm3x3(m0, m, n0, n); + } else { + mc = 1; + nc = 1; + gemm1x1(m0, m, n0, n); + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, mp, np, n); + mnpack(mp, m, np, n); + } + + NOINLINE void gemm3x3(int m0, int m, int n0, int n) { + BEGIN_KERNEL(3, 3) + int32x4_t zero = vdupq_n_s32(0); + float32x4_t c00 = vdupq_n_f32(0.f); + float32x4_t c01 = vdupq_n_f32(0.f); + float32x4_t c02 = vdupq_n_f32(0.f); + float32x4_t c10 = vdupq_n_f32(0.f); + float32x4_t c11 = vdupq_n_f32(0.f); + float32x4_t c12 = vdupq_n_f32(0.f); + float32x4_t c20 = vdupq_n_f32(0.f); + float32x4_t c21 = vdupq_n_f32(0.f); + float32x4_t c22 = vdupq_n_f32(0.f); + const block_q8_0 *Ap0 = A + lda * (i + 0); + const block_q8_0 *Ap1 = A + lda * (i + 1); + const block_q8_0 *Ap2 = A + lda * (i + 2); + const block_q8_0 *Bp0 = B + ldb * (j + 0); + const block_q8_0 *Bp1 = B + ldb * (j + 1); + const block_q8_0 *Bp2 = B + ldb * (j + 2); + for (int l = 0; l < k; ++l) { + c00 = vmlaq_n_f32( + c00, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, vld1q_s8(Ap0[l].qs), vld1q_s8(Bp0[l].qs)), + vld1q_s8(Ap0[l].qs + 16), vld1q_s8(Bp0[l].qs + 16))), + unhalf(Ap0[l].d) * unhalf(Bp0[l].d)); + c01 = vmlaq_n_f32( + c01, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, vld1q_s8(Ap0[l].qs), vld1q_s8(Bp1[l].qs)), + vld1q_s8(Ap0[l].qs + 16), vld1q_s8(Bp1[l].qs + 16))), + unhalf(Ap0[l].d) * unhalf(Bp1[l].d)); + c02 = vmlaq_n_f32( + c02, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, vld1q_s8(Ap0[l].qs), vld1q_s8(Bp2[l].qs)), + vld1q_s8(Ap0[l].qs + 16), vld1q_s8(Bp2[l].qs + 16))), + unhalf(Ap0[l].d) * unhalf(Bp2[l].d)); + c10 = vmlaq_n_f32( + c10, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, vld1q_s8(Ap1[l].qs), vld1q_s8(Bp0[l].qs)), + vld1q_s8(Ap1[l].qs + 16), vld1q_s8(Bp0[l].qs + 16))), + unhalf(Ap1[l].d) * unhalf(Bp0[l].d)); + c11 = vmlaq_n_f32( + c11, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, vld1q_s8(Ap1[l].qs), vld1q_s8(Bp1[l].qs)), + vld1q_s8(Ap1[l].qs + 16), vld1q_s8(Bp1[l].qs + 16))), + unhalf(Ap1[l].d) * unhalf(Bp1[l].d)); + c12 = vmlaq_n_f32( + c12, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, vld1q_s8(Ap1[l].qs), vld1q_s8(Bp2[l].qs)), + vld1q_s8(Ap1[l].qs + 16), vld1q_s8(Bp2[l].qs + 16))), + unhalf(Ap1[l].d) * unhalf(Bp2[l].d)); + c20 = vmlaq_n_f32( + c20, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, vld1q_s8(Ap2[l].qs), vld1q_s8(Bp0[l].qs)), + vld1q_s8(Ap2[l].qs + 16), vld1q_s8(Bp0[l].qs + 16))), + unhalf(Ap2[l].d) * unhalf(Bp0[l].d)); + c21 = vmlaq_n_f32( + c21, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, vld1q_s8(Ap2[l].qs), vld1q_s8(Bp1[l].qs)), + vld1q_s8(Ap2[l].qs + 16), vld1q_s8(Bp1[l].qs + 16))), + unhalf(Ap2[l].d) * unhalf(Bp1[l].d)); + c22 = vmlaq_n_f32( + c22, + vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, vld1q_s8(Ap2[l].qs), vld1q_s8(Bp2[l].qs)), + vld1q_s8(Ap2[l].qs + 16), vld1q_s8(Bp2[l].qs + 16))), + unhalf(Ap2[l].d) * unhalf(Bp2[l].d)); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00); + C[ldc * (j + 0) + (i + 1)] = hsum(c10); + C[ldc * (j + 0) + (i + 2)] = hsum(c20); + C[ldc * (j + 1) + (i + 0)] = hsum(c01); + C[ldc * (j + 1) + (i + 1)] = hsum(c11); + C[ldc * (j + 1) + (i + 2)] = hsum(c21); + C[ldc * (j + 2) + (i + 0)] = hsum(c02); + C[ldc * (j + 2) + (i + 1)] = hsum(c12); + C[ldc * (j + 2) + (i + 2)] = hsum(c22); + END_KERNEL() + } + + NOINLINE void gemm1x1(int m0, int m, int n0, int n) { + BEGIN_KERNEL(1, 1) + float32x4_t acc = vdupq_n_f32(0.f); + const block_q8_0 *Ap = A + lda * i; + const block_q8_0 *Bp = B + ldb * j; + for (int l = 0; l < k; ++l) { + acc = vmlaq_n_f32(acc, + vcvtq_f32_s32(vdotq_s32( + vdotq_s32(vdupq_n_s32(0), vld1q_s8(Ap[l].qs), vld1q_s8(Bp[l].qs)), + vld1q_s8(Ap[l].qs + 16), vld1q_s8(Bp[l].qs + 16))), + unhalf(Ap[l].d) * unhalf(Bp[l].d)); + } + C[ldc * j + i] = hsum(acc); + END_KERNEL() + } + + const block_q8_0 *const A; + const block_q8_0 *const B; + float *const C; + const int k; + const int lda; + const int ldb; + const int ldc; + const int ith; + const int nth; +}; +#endif // __ARM_FEATURE_DOTPROD + +#ifdef __AVX2__ +template +class tinyBLAS_Q0_AVX2 { + public: + tinyBLAS_Q0_AVX2(int k, + const TA *A, int lda, + const TB *B, int ldb, + TC *C, int ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int m, int n) { + mnpack(0, m, 0, n); + } + + private: + NOINLINE void mnpack(int m0, int m, int n0, int n) { + if (m - m0 <= 0 || n - n0 <= 0) + return; + int mc, nc, mp, np; + if (m - m0 >= 4 && n - n0 >= 3) { + mc = 4; + nc = 3; + gemm4x3(m0, m, n0, n); + } else if (m - m0 >= 4 && n - n0 >= 1) { + mc = 4; + nc = 1; + gemm4x1(m0, m, n0, n); + } else if (m - m0 >= 1 && n - n0 >= 4) { + mc = 1; + nc = 4; + gemm1x4(m0, m, n0, n); + } else { + mc = 1; + nc = 1; + gemm1x1(m0, m, n0, n); + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, mp, np, n); + mnpack(mp, m, np, n); + } + + NOINLINE void gemm4x3(int m0, int m, int n0, int n) { + BEGIN_KERNEL(4, 3) + __m256 c00 = _mm256_setzero_ps(); + __m256 c10 = _mm256_setzero_ps(); + __m256 c20 = _mm256_setzero_ps(); + __m256 c30 = _mm256_setzero_ps(); + __m256 c01 = _mm256_setzero_ps(); + __m256 c11 = _mm256_setzero_ps(); + __m256 c21 = _mm256_setzero_ps(); + __m256 c31 = _mm256_setzero_ps(); + __m256 c02 = _mm256_setzero_ps(); + __m256 c12 = _mm256_setzero_ps(); + __m256 c22 = _mm256_setzero_ps(); + __m256 c32 = _mm256_setzero_ps(); + const TA *Ap0 = A + lda * (i + 0); + const TA *Ap1 = A + lda * (i + 1); + const TA *Ap2 = A + lda * (i + 2); + const TA *Ap3 = A + lda * (i + 3); + const TB *Bp0 = B + ldb * (j + 0); + const TB *Bp1 = B + ldb * (j + 1); + const TB *Bp2 = B + ldb * (j + 2); + for (int l = 0; l < k; ++l) { + float da0 = unhalf(Ap0[l].d); + float da1 = unhalf(Ap1[l].d); + float da2 = unhalf(Ap2[l].d); + float da3 = unhalf(Ap3[l].d); + __m256i e0 = load(Ap0 + l); + __m256i e1 = load(Ap1 + l); + __m256i e2 = load(Ap2 + l); + __m256i e3 = load(Ap3 + l); + float db0 = unhalf(Bp0[l].d); + __m256 d00 = _mm256_set1_ps(da0 * db0); + __m256 d10 = _mm256_set1_ps(da1 * db0); + __m256 d20 = _mm256_set1_ps(da2 * db0); + __m256 d30 = _mm256_set1_ps(da3 * db0); + __m256i f0 = load(Bp0 + l); + __m256i u0 = _mm256_sign_epi8(f0, f0); + __m256i s00 = _mm256_sign_epi8(e0, f0); + __m256i s10 = _mm256_sign_epi8(e1, f0); + __m256i s20 = _mm256_sign_epi8(e2, f0); + __m256i s30 = _mm256_sign_epi8(e3, f0); + c00 = madd(d00, updot(u0, s00), c00); + c10 = madd(d10, updot(u0, s10), c10); + c20 = madd(d20, updot(u0, s20), c20); + c30 = madd(d30, updot(u0, s30), c30); + float db1 = unhalf(Bp1[l].d); + __m256 d01 = _mm256_set1_ps(da0 * db1); + __m256 d11 = _mm256_set1_ps(da1 * db1); + __m256 d21 = _mm256_set1_ps(da2 * db1); + __m256 d31 = _mm256_set1_ps(da3 * db1); + __m256i f1 = load(Bp1 + l); + __m256i u1 = _mm256_sign_epi8(f1, f1); + __m256i s01 = _mm256_sign_epi8(e0, f1); + __m256i s11 = _mm256_sign_epi8(e1, f1); + __m256i s21 = _mm256_sign_epi8(e2, f1); + __m256i s31 = _mm256_sign_epi8(e3, f1); + c01 = madd(d01, updot(u1, s01), c01); + c11 = madd(d11, updot(u1, s11), c11); + c21 = madd(d21, updot(u1, s21), c21); + c31 = madd(d31, updot(u1, s31), c31); + float db2 = unhalf(Bp2[l].d); + __m256 d02 = _mm256_set1_ps(da0 * db2); + __m256 d12 = _mm256_set1_ps(da1 * db2); + __m256 d22 = _mm256_set1_ps(da2 * db2); + __m256 d32 = _mm256_set1_ps(da3 * db2); + __m256i f2 = load(Bp2 + l); + __m256i u2 = _mm256_sign_epi8(f2, f2); + __m256i s02 = _mm256_sign_epi8(e0, f2); + __m256i s12 = _mm256_sign_epi8(e1, f2); + __m256i s22 = _mm256_sign_epi8(e2, f2); + __m256i s32 = _mm256_sign_epi8(e3, f2); + c02 = madd(d02, updot(u2, s02), c02); + c12 = madd(d12, updot(u2, s12), c12); + c22 = madd(d22, updot(u2, s22), c22); + c32 = madd(d32, updot(u2, s32), c32); + } + C[ldc * (j + 0) + (i + 0)] = hsum(c00); + C[ldc * (j + 0) + (i + 1)] = hsum(c10); + C[ldc * (j + 0) + (i + 2)] = hsum(c20); + C[ldc * (j + 0) + (i + 3)] = hsum(c30); + C[ldc * (j + 1) + (i + 0)] = hsum(c01); + C[ldc * (j + 1) + (i + 1)] = hsum(c11); + C[ldc * (j + 1) + (i + 2)] = hsum(c21); + C[ldc * (j + 1) + (i + 3)] = hsum(c31); + C[ldc * (j + 2) + (i + 0)] = hsum(c02); + C[ldc * (j + 2) + (i + 1)] = hsum(c12); + C[ldc * (j + 2) + (i + 2)] = hsum(c22); + C[ldc * (j + 2) + (i + 3)] = hsum(c32); + END_KERNEL() + } + + NOINLINE void gemm4x1(int m0, int m, int n0, int n) { + BEGIN_KERNEL(4, 1) + __m256 c0 = _mm256_setzero_ps(); + __m256 c1 = _mm256_setzero_ps(); + __m256 c2 = _mm256_setzero_ps(); + __m256 c3 = _mm256_setzero_ps(); + const TA *Ap0 = A + lda * (i + 0); + const TA *Ap1 = A + lda * (i + 1); + const TA *Ap2 = A + lda * (i + 2); + const TA *Ap3 = A + lda * (i + 3); + const TB *Bp = B + ldb * j; + for (int l = 0; l < k; ++l) { + float db0 = unhalf(Bp[l].d); + __m256i f = load(Bp + l); + __m256i u = _mm256_sign_epi8(f, f); + __m256 d0 = _mm256_set1_ps(unhalf(Ap0[l].d) * db0); + __m256 d1 = _mm256_set1_ps(unhalf(Ap1[l].d) * db0); + __m256 d2 = _mm256_set1_ps(unhalf(Ap2[l].d) * db0); + __m256 d3 = _mm256_set1_ps(unhalf(Ap3[l].d) * db0); + __m256i e0 = load(Ap0 + l); + __m256i e1 = load(Ap1 + l); + __m256i e2 = load(Ap2 + l); + __m256i e3 = load(Ap3 + l); + __m256i s0 = _mm256_sign_epi8(e0, f); + __m256i s1 = _mm256_sign_epi8(e1, f); + __m256i s2 = _mm256_sign_epi8(e2, f); + __m256i s3 = _mm256_sign_epi8(e3, f); + __m256 g0 = updot(u, s0); + __m256 g1 = updot(u, s1); + __m256 g2 = updot(u, s2); + __m256 g3 = updot(u, s3); + c0 = madd(d0, g0, c0); + c1 = madd(d1, g1, c1); + c2 = madd(d2, g2, c2); + c3 = madd(d3, g3, c3); + } + C[ldc * j + (i + 0)] = hsum(c0); + C[ldc * j + (i + 1)] = hsum(c1); + C[ldc * j + (i + 2)] = hsum(c2); + C[ldc * j + (i + 3)] = hsum(c3); + END_KERNEL() + } + + NOINLINE void gemm1x4(int m0, int m, int n0, int n) { + BEGIN_KERNEL(1, 4) + __m256 c0 = _mm256_setzero_ps(); + __m256 c1 = _mm256_setzero_ps(); + __m256 c2 = _mm256_setzero_ps(); + __m256 c3 = _mm256_setzero_ps(); + const TB *Bp0 = B + ldb * (j + 0); + const TB *Bp1 = B + ldb * (j + 1); + const TB *Bp2 = B + ldb * (j + 2); + const TB *Bp3 = B + ldb * (j + 3); + const TA *Ap = A + lda * i; + for (int l = 0; l < k; ++l) { + float da0 = unhalf(Ap[l].d); + __m256i f = load(Ap + l); + __m256i u = _mm256_sign_epi8(f, f); + __m256 d0 = _mm256_set1_ps(unhalf(Bp0[l].d) * da0); + __m256 d1 = _mm256_set1_ps(unhalf(Bp1[l].d) * da0); + __m256 d2 = _mm256_set1_ps(unhalf(Bp2[l].d) * da0); + __m256 d3 = _mm256_set1_ps(unhalf(Bp3[l].d) * da0); + __m256 g0 = updot(u, _mm256_sign_epi8(load(Bp0 + l), f)); + __m256 g1 = updot(u, _mm256_sign_epi8(load(Bp1 + l), f)); + __m256 g2 = updot(u, _mm256_sign_epi8(load(Bp2 + l), f)); + __m256 g3 = updot(u, _mm256_sign_epi8(load(Bp3 + l), f)); + c0 = madd(d0, g0, c0); + c1 = madd(d1, g1, c1); + c2 = madd(d2, g2, c2); + c3 = madd(d3, g3, c3); + } + C[ldc * (j + 0) + i] = hsum(c0); + C[ldc * (j + 1) + i] = hsum(c1); + C[ldc * (j + 2) + i] = hsum(c2); + C[ldc * (j + 3) + i] = hsum(c3); + END_KERNEL() + } + + NOINLINE void gemm1x1(int m0, int m, int n0, int n) { + BEGIN_KERNEL(1, 1) + __m256 c = _mm256_setzero_ps(); + const TA *Ap = A + lda * i; + const TB *Bp = B + ldb * j; + for (int l = 0; l < k; ++l) { + __m256 d = _mm256_set1_ps(unhalf(Ap[l].d) * unhalf(Bp[l].d)); + __m256i e = load(Ap + l); + __m256i f = load(Bp + l); + __m256 g = updot(_mm256_sign_epi8(e, e), _mm256_sign_epi8(f, e)); + c = madd(d, g, c); + } + C[ldc * j + i] = hsum(c); + END_KERNEL() + } + + inline __m256i load(const block_q8_0 *b) { + return _mm256_loadu_si256((const __m256i *)b->qs); + } + + inline __m256i load(const block_q4_0 *b) { + return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); + } + + inline __m256 updot(__m256i u, __m256i s) { + __m256i res; +#if defined(__AVXVNNI__) || defined(__AVX512VNNI__) + res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); +#else + res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); +#endif + return _mm256_cvtepi32_ps(res); + } + + static inline __m256i denibble(const uint8_t *p) { + const __m128i tmp = _mm_loadu_si128((const __m128i *)p); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8(15); + return _mm256_and_si256(lowMask, bytes); + } + + const TA *const A; + const TB *const B; + TC *const C; + const int k; + const int lda; + const int ldb; + const int ldc; + const int ith; + const int nth; +}; +#endif // __AVX2__ + +} // namespace + +/** + * Performs optimized matrix multiplication on CPU. + * + * This subroutine may compute C = Aᵀ * B with column major ordering. + * Despite its name, this isn't a generalized implementation. Work is + * only performed when a handwritten kernel is written and available. + * Otherwise the caller should fall back to a general matmul routine. + * + * For example, for single-threaded single-precision GEMM you can say + * + * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, + * 0, 1, GGML_TASK_TYPE_COMPUTE, + * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32); + * + * @param m is rows in `A` and `C` + * @param n is cols in `B` and `C` + * @param k is cols in `A` and rows in `B` + * @param A is first input matrix (always transposed) + * @param lda is row stride of `A` + * @param B is second input matrix (never transposed) + * @param ldb is row stride of `B` + * @param C is input/output array of output matrices + * @param ldc is row stride of `C` + * @param ith is thread id (must be less than `nth`) + * @param nth is number of threads (must be greater than zero) + * @param task is GGML task type + * @param Atype is GGML data type of `A` + * @param Btype is GGML data type of `B` + * @param Ctype is GGML data type of `C` + * @return true if this function was able to service the matmul request + */ +bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C, + int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) { + + assert(m >= 0); + assert(n >= 0); + assert(k >= 0); + assert(lda >= k); + assert(ldb >= k); + assert(ldc >= m); + assert(nth > 0); + assert(ith < nth); + assert(1ll * lda * m <= 0x7fffffff); + assert(1ll * ldb * n <= 0x7fffffff); + assert(1ll * ldc * n <= 0x7fffffff); + + if (Ctype != GGML_TYPE_F32) + return false; + + switch (Atype) { + + case GGML_TYPE_F32: { + if (Btype != GGML_TYPE_F32) + return false; +#if defined(__AVX512F__) + if (k % 16) + return false; + if (task != GGML_TASK_TYPE_COMPUTE) + return true; + tinyBLAS<16, __m512, float, float, float> tb{ + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__AVX__) + if (k % 8) + return false; + if (task != GGML_TASK_TYPE_COMPUTE) + return true; + tinyBLAS<8, __m256, float, float, float> tb{ + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__ARM_NEON) + if (n < 4) + return false; + if (k % 4) + return false; + if (task != GGML_TASK_TYPE_COMPUTE) + return true; + tinyBLAS<4, float32x4_t, float, float, float> tb{ + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + + case GGML_TYPE_F16: { +#if defined(__AVX512F__) + if (k % 16) + return false; + if (Btype != GGML_TYPE_F32) + return false; + if (task != GGML_TASK_TYPE_COMPUTE) + return true; + tinyBLAS<16, __m512, half, float, float> tb{ + k, (const half *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__AVX__) && defined(__F16C__) + if (k % 8) + return false; + if (Btype != GGML_TYPE_F32) + return false; + if (task != GGML_TASK_TYPE_COMPUTE) + return true; + tinyBLAS<8, __m256, half, float, float> tb{ + k, (const half *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + if (n < 4) + return false; + if (k % 8) + return false; + if (Btype != GGML_TYPE_F16) + return false; + if (task != GGML_TASK_TYPE_COMPUTE) + return true; + tinyBLAS<8, float16x8_t, half, half, float> tb{ + k, (const half *)A, lda, + (const half *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__ARM_NEON) + if (n < 4) + return false; + if (k % 4) + return false; + if (Btype != GGML_TYPE_F32) + return false; + if (task != GGML_TASK_TYPE_COMPUTE) + return true; + tinyBLAS<4, float32x4_t, half, float, float> tb{ + k, (const half *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + + case GGML_TYPE_Q8_0: { +#if defined(__AVX2__) + if (k % 32) + return false; + if (Btype != GGML_TYPE_Q8_0) + return false; + if (task != GGML_TASK_TYPE_COMPUTE) + return true; + tinyBLAS_Q0_AVX2 tb{ + k, (const block_q8_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__ARM_FEATURE_DOTPROD) + if (k % 32) + return false; + if (Btype != GGML_TYPE_Q8_0) + return false; + if (task != GGML_TASK_TYPE_COMPUTE) + return true; + tinyBLAS_Q0_ARM tb{ + k, (const block_q8_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + + case GGML_TYPE_Q4_0: { +#if defined(__AVX2__) + if (k % 32) + return false; + if (Btype != GGML_TYPE_Q8_0) + return false; + if (task != GGML_TASK_TYPE_COMPUTE) + return true; + tinyBLAS_Q0_AVX2 tb{ + k, (const block_q4_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + + default: + return false; + } + + (void)m; + (void)n; + (void)k; + (void)A; + (void)lda; + (void)B; + (void)ldb; + (void)C; + (void)ldc; + (void)ith; + (void)nth; + (void)task; + (void)Atype; + (void)Btype; + (void)Ctype; +} diff --git a/sgemm.h b/sgemm.h new file mode 100644 index 00000000000000..da23b209c4dd5b --- /dev/null +++ b/sgemm.h @@ -0,0 +1,12 @@ +#pragma once +#include +#ifdef __cplusplus +extern "C" { +#endif + +bool llamafile_sgemm(int, int, int, const void *, int, const void *, int, + void *, int, int, int, int, int, int, int); + +#ifdef __cplusplus +} +#endif