Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fused-Multiply-Add #214

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open

Fused-Multiply-Add #214

wants to merge 15 commits into from

Conversation

ashvardanian
Copy link
Owner

@ashvardanian ashvardanian commented Oct 19, 2024

SimSIMD is expanding and becoming closer to a fully-fledged BLAS library. BLAS level 1 for now, but it's a start! SimSIMD will prioritize mixed and low-precision vector math, favoring modern AI workloads. For image & media processing workloads, the new fma and wsum kernels approach 65 GB/s per core on Intel Sapphire Rapids. That's 100x faster than the serial code for u8 inputs with f32 scaling and accumulation.

Contains the following element-wise operations:

  • WSum or Weighted-Sum: $R_i = \alpha \cdot A_i + \beta \cdot B_i$
  • FMA or Fused-Multiply-Add: $R_i = \alpha \cdot A_i \cdot B_i + \beta \cdot C_i$

In NumPy terms:

import numpy as np
def wsum(A: np.ndarray, B: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
    assert A.dtype == B.dtype, "Input types must match and affect the output style"
    return (Alpha * A + Beta * B).astype(A.dtype)
def fma(A: np.ndarray, B: np.ndarray, C: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
    assert A.dtype == B.dtype and A.dtype == C.dtype, "Input types must match and affect the output style"
    return (Alpha * A * B + Beta * C).astype(A.dtype)

This tiny set of operations is enough to implement a wide range of algorithms:

  • To scale a vector by a scalar, just call WSum with $\beta = 0$.
  • To sum two vectors, just call WSum with $\alpha = \beta = 1$.
  • To average two vectors, just call WSum with $\alpha = \beta = 0.5$.
  • To multiply vectors element-wise, just call FMA with $\beta = 0$.

Benchmarks

On Intel Sapphire Rapids:

Run on (16 X 3900 MHz CPU s)
CPU Caches:
  L1 Data 48 KiB (x8)
  L1 Instruction 32 KiB (x8)
  L2 Unified 2048 KiB (x8)
  L3 Unified 61440 KiB (x1)
Load Average: 0.79, 0.75, 0.56
-------------------------------------------------------------------------------------------------------------
Benchmark                                                   Time             CPU   Iterations UserCounters...
-------------------------------------------------------------------------------------------------------------
fma_f64_haswell<1536d>/min_time:10.000/threads:1         1344 ns         1344 ns     10391897 abs_delta=0 bytes=27.4208G/s pairs=743.836k/s relative_error=0
wsum_f64_haswell<1536d>/min_time:10.000/threads:1        1040 ns         1040 ns     13465261 abs_delta=0 bytes=23.6376G/s pairs=961.815k/s relative_error=0
fma_f32_haswell<1536d>/min_time:10.000/threads:1          651 ns          651 ns     21534450 abs_delta=23.597n bytes=28.3033G/s pairs=1.53555M/s relative_error=47.0002n
wsum_f32_haswell<1536d>/min_time:10.000/threads:1         392 ns          392 ns     36225731 abs_delta=19.6436n bytes=31.3326G/s pairs=2.54985M/s relative_error=54.2672n
fma_f16_haswell<1536d>/min_time:10.000/threads:1          188 ns          188 ns     74334715 abs_delta=9.24044u bytes=49.1302G/s pairs=5.33097M/s relative_error=18.3975u
wsum_f16_haswell<1536d>/min_time:10.000/threads:1         130 ns          129 ns    106997523 abs_delta=12.015u bytes=47.4441G/s pairs=7.72203M/s relative_error=33.1896u
fma_bf16_haswell<1536d>/min_time:10.000/threads:1         225 ns          225 ns     62443286 abs_delta=1.91338m bytes=41.0221G/s pairs=4.45118M/s relative_error=3.81108m
wsum_bf16_haswell<1536d>/min_time:10.000/threads:1        161 ns          161 ns     86471812 abs_delta=1.36093m bytes=38.1318G/s pairs=6.20635M/s relative_error=3.75961m
fma_u8_sapphire<1536d>/min_time:10.000/threads:1         70.9 ns         70.9 ns    197232316 abs_delta=9.2812 bytes=64.9867G/s pairs=14.103M/s relative_error=2.45142m
wsum_u8_sapphire<1536d>/min_time:10.000/threads:1        50.6 ns         50.6 ns    276672248 abs_delta=8.89144 bytes=60.6775G/s pairs=19.7518M/s relative_error=3.28203m
fma_i8_sapphire<1536d>/min_time:10.000/threads:1         94.0 ns         94.0 ns    149003863 abs_delta=10.1192 bytes=49.0403G/s pairs=10.6424M/s relative_error=6.98359m
wsum_i8_sapphire<1536d>/min_time:10.000/threads:1        70.4 ns         70.4 ns    198873173 abs_delta=9.76862 bytes=43.613G/s pairs=14.197M/s relative_error=9.3472m
fma_f64_skylake<1536d>/min_time:10.000/threads:1         1340 ns         1340 ns     10460553 abs_delta=39.3003a bytes=27.5182G/s pairs=746.479k/s relative_error=78.2836a
wsum_f64_skylake<1536d>/min_time:10.000/threads:1        1036 ns         1036 ns     13484768 abs_delta=28.4608a bytes=23.717G/s pairs=965.047k/s relative_error=78.6298a
fma_f32_skylake<1536d>/min_time:10.000/threads:1          626 ns          626 ns     22261554 abs_delta=25.3818n bytes=29.4286G/s pairs=1.5966M/s relative_error=50.5553n
wsum_f32_skylake<1536d>/min_time:10.000/threads:1         386 ns          386 ns     35032887 abs_delta=19.7444n bytes=31.8146G/s pairs=2.58908M/s relative_error=54.5454n
fma_bf16_skylake<1536d>/min_time:10.000/threads:1         188 ns          188 ns     74667249 abs_delta=415.805u bytes=48.9511G/s pairs=5.31154M/s relative_error=827.962u
wsum_bf16_skylake<1536d>/min_time:10.000/threads:1        147 ns          147 ns     95128759 abs_delta=269.793u bytes=41.8834G/s pairs=6.81696M/s relative_error=745.331u
fma_f16_serial<1536d>/min_time:10.000/threads:1           900 ns          900 ns     15592180 abs_delta=2.97965u bytes=10.2444G/s pairs=1.11159M/s relative_error=5.93995u
wsum_f16_serial<1536d>/min_time:10.000/threads:1          821 ns          821 ns     17058449 abs_delta=1.11521u bytes=7.48594G/s pairs=1.21841M/s relative_error=3.07961u
fma_u8_serial<1536d>/min_time:10.000/threads:1           6692 ns         6692 ns      2089290 abs_delta=1.66854 bytes=688.583M/s pairs=149.432k/s relative_error=440.882u
wsum_u8_serial<1536d>/min_time:10.000/threads:1          5577 ns         5577 ns      2508971 abs_delta=2.32787 bytes=550.797M/s pairs=179.296k/s relative_error=859.403u
fma_i8_serial<1536d>/min_time:10.000/threads:1           6874 ns         6874 ns      2039761 abs_delta=5.14013 bytes=670.367M/s pairs=145.479k/s relative_error=3.54862m
wsum_i8_serial<1536d>/min_time:10.000/threads:1          5851 ns         5851 ns      2394538 abs_delta=6.36953 bytes=525.018M/s pairs=170.904k/s relative_error=6.09231m

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant