Skip to content

Commit

Permalink
Implement bf16 compiler runtime library
Browse files Browse the repository at this point in the history
  • Loading branch information
jart committed Aug 2, 2024
1 parent 9ebacb7 commit a80ab3f
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 7 deletions.
39 changes: 39 additions & 0 deletions libc/intrin/extendbfsf2.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│
│ vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi │
╞══════════════════════════════════════════════════════════════════════════════╡
│ Copyright 2024 Justine Alexandra Roberts Tunney │
│ │
│ Permission to use, copy, modify, and/or distribute this software for │
│ any purpose with or without fee is hereby granted, provided that the │
│ above copyright notice and this permission notice appear in all copies. │
│ │
│ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │
│ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │
│ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │
│ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │
│ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │
│ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │
│ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │
│ PERFORMANCE OF THIS SOFTWARE. │
╚─────────────────────────────────────────────────────────────────────────────*/

float __extendbfsf2(__bf16 f) {
union {
__bf16 f;
unsigned short i;
} ub = {f};

// convert brain16 to binary32
unsigned x = (unsigned)ub.i << 16;

// force nan to quiet
if ((x & 0x7fffffff) > 0x7f800000)
x |= 0x00400000;

// pun to float
union {
unsigned i;
float f;
} uf = {x};
return uf.f;
}
24 changes: 24 additions & 0 deletions libc/intrin/truncdfbf2.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│
│ vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi │
╞══════════════════════════════════════════════════════════════════════════════╡
│ Copyright 2024 Justine Alexandra Roberts Tunney │
│ │
│ Permission to use, copy, modify, and/or distribute this software for │
│ any purpose with or without fee is hereby granted, provided that the │
│ above copyright notice and this permission notice appear in all copies. │
│ │
│ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │
│ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │
│ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │
│ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │
│ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │
│ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │
│ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │
│ PERFORMANCE OF THIS SOFTWARE. │
╚─────────────────────────────────────────────────────────────────────────────*/

__bf16 __truncsfbf2(float);
__bf16 __truncdfbf2(double f) {
// TODO(jart): What else are we supposed to do here?
return __truncsfbf2(f);
}
40 changes: 40 additions & 0 deletions libc/intrin/truncsfbf2.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│
│ vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi │
╞══════════════════════════════════════════════════════════════════════════════╡
│ Copyright 2024 Justine Alexandra Roberts Tunney │
│ │
│ Permission to use, copy, modify, and/or distribute this software for │
│ any purpose with or without fee is hereby granted, provided that the │
│ above copyright notice and this permission notice appear in all copies. │
│ │
│ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │
│ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │
│ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │
│ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │
│ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │
│ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │
│ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │
│ PERFORMANCE OF THIS SOFTWARE. │
╚─────────────────────────────────────────────────────────────────────────────*/

__bf16 __truncsfbf2(float f) {
union {
float f;
unsigned i;
} uf = {f};
unsigned x = uf.i;

if ((x & 0x7fffffff) > 0x7f800000)
// force nan to quiet
x = (x | 0x00400000) >> 16;
else
// convert binary32 to brain16 with nearest rounding
x = (x + (0x7fff + ((x >> 16) & 1))) >> 16;

// pun to bf16
union {
unsigned short i;
__bf16 f;
} ub = {x};
return ub.f;
}
47 changes: 40 additions & 7 deletions test/libc/tinymath/fdot_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "libc/stdio/stdio.h"
#include "libc/testlib/benchmark.h"
#include "libc/x/xasprintf.h"
#include "third_party/aarch64/arm_neon.internal.h"
#include "third_party/intel/immintrin.internal.h"

#define EXPENSIVE_TESTS 0

Expand All @@ -18,12 +20,11 @@
#define FASTMATH __attribute__((__optimize__("-O3,-ffast-math")))
#define PORTABLE __target_clones("avx512f,avx")

static unsigned long long lcg = 1;

int rand32(void) {
/* Knuth, D.E., "The Art of Computer Programming," Vol 2,
Seminumerical Algorithms, Third Edition, Addison-Wesley, 1998,
p. 106 (line 26) & p. 108 */
static unsigned long long lcg = 1;
lcg *= 6364136223846793005;
lcg += 1442695040888963407;
return lcg >> 32;
Expand Down Expand Up @@ -122,6 +123,34 @@ float fdotf_recursive(const float *A, const float *B, size_t n) {
}
}

optimizespeed float fdotf_intrin(const float *A, const float *B, size_t n) {
size_t i = 0;
#ifdef __AVX512F__
__m512 vec[CHUNK] = {};
for (; i + CHUNK * 16 <= n; i += CHUNK * 16)
for (int j = 0; j < CHUNK; ++j)
vec[j] = _mm512_fmadd_ps(_mm512_loadu_ps(A + i + j * 16),
_mm512_loadu_ps(B + i + j * 16), vec[j]);
float res = 0;
for (int j = 0; j < CHUNK; ++j)
res += _mm512_reduce_add_ps(vec[j]);
#elif defined(__aarch64__)
float32x4_t vec[CHUNK] = {};
for (; i + CHUNK * 4 <= n; i += CHUNK * 4)
for (int j = 0; j < CHUNK; ++j)
vec[j] =
vfmaq_f32(vec[j], vld1q_f32(A + i + j * 4), vld1q_f32(B + i + j * 4));
float res = 0;
for (int j = 0; j < CHUNK; ++j)
res += vaddvq_f32(vec[j]);
#else
float res = 0;
#endif
for (; i < n; ++i)
res += A[i] * B[i];
return res;
}

FASTMATH float fdotf_ruler(const float *A, const float *B, size_t n) {
int rule, step = 2;
size_t chunk, sp = 0;
Expand Down Expand Up @@ -179,6 +208,8 @@ void test_fdotf_ruler(void) {
}

PORTABLE float fdotf_hefty(const float *A, const float *B, size_t n) {
if (1)
return 0;
unsigned i, par, len = 0;
float sum, res[n / CHUNK + 1];
for (res[0] = i = 0; i + CHUNK <= n; i += CHUNK)
Expand Down Expand Up @@ -244,7 +275,7 @@ int main() {
#if EXPENSIVE_TESTS
size_t n = 512 * 1024;
#else
size_t n = 1024;
size_t n = 4096;
#endif

float *A = new float[n];
Expand All @@ -253,22 +284,24 @@ int main() {
A[i] = numba();
B[i] = numba();
}
float kahan, naive, dubble, recursive, hefty, ruler;
float kahan, naive, dubble, recursive, ruler, intrin;
test_fdotf_naive();
test_fdotf_hefty();
// test_fdotf_hefty();
test_fdotf_ruler();
BENCHMARK(20, 1, (kahan = barrier(fdotf_kahan(A, B, n))));
BENCHMARK(20, 1, (dubble = barrier(fdotf_dubble(A, B, n))));
BENCHMARK(20, 1, (naive = barrier(fdotf_naive(A, B, n))));
BENCHMARK(20, 1, (recursive = barrier(fdotf_recursive(A, B, n))));
BENCHMARK(20, 1, (intrin = barrier(fdotf_intrin(A, B, n))));
BENCHMARK(20, 1, (ruler = barrier(fdotf_ruler(A, B, n))));
BENCHMARK(20, 1, (hefty = barrier(fdotf_hefty(A, B, n))));
// BENCHMARK(20, 1, (hefty = barrier(fdotf_hefty(A, B, n))));
printf("dubble = %f (%g)\n", dubble, fabs(dubble - dubble));
printf("kahan = %f (%g)\n", kahan, fabs(kahan - dubble));
printf("naive = %f (%g)\n", naive, fabs(naive - dubble));
printf("recursive = %f (%g)\n", recursive, fabs(recursive - dubble));
printf("intrin = %f (%g)\n", intrin, fabs(intrin - dubble));
printf("ruler = %f (%g)\n", ruler, fabs(ruler - dubble));
printf("hefty = %f (%g)\n", hefty, fabs(hefty - dubble));
// printf("hefty = %f (%g)\n", hefty, fabs(hefty - dubble));
delete[] B;
delete[] A;

Expand Down
113 changes: 113 additions & 0 deletions test/math/bf16_test.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│
│ vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi │
╞══════════════════════════════════════════════════════════════════════════════╡
│ Copyright 2024 Justine Alexandra Roberts Tunney │
│ │
│ Permission to use, copy, modify, and/or distribute this software for │
│ any purpose with or without fee is hereby granted, provided that the │
│ above copyright notice and this permission notice appear in all copies. │
│ │
│ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │
│ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │
│ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │
│ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │
│ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │
│ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │
│ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │
│ PERFORMANCE OF THIS SOFTWARE. │
╚─────────────────────────────────────────────────────────────────────────────*/
#include "libc/math.h"

#define CHECK(x) \
if (!(x)) \
return __LINE__
#define FALSE(x) \
{ \
volatile bool x_ = x; \
if (x_) \
return __LINE__; \
}
#define TRUE(x) \
{ \
volatile bool x_ = x; \
if (!x_) \
return __LINE__; \
}

__bf16 identity(__bf16 x) {
return x;
}
__bf16 (*half)(__bf16) = identity;

unsigned toint(float f) {
union {
float f;
unsigned i;
} u = {f};
return u.i;
}

int main() {
volatile float f;
volatile double d;
volatile __bf16 pi = 3.141;

// half → float → half
f = pi;
pi = f;

// half → float
float __extendbfsf2(__bf16);
CHECK(0.f == __extendbfsf2(0));
CHECK(3.140625f == __extendbfsf2(pi));
CHECK(3.140625f == pi);

// half → double → half
d = pi;
pi = d;

// float → half
__bf16 __truncsfbf2(float);
CHECK(0 == (float)__truncsfbf2(0));
CHECK(pi == (float)__truncsfbf2(3.141f));
CHECK(3.140625f == (float)__truncsfbf2(3.141f));

// double → half
__bf16 __truncdfbf2(double);
CHECK(0 == (double)__truncdfbf2(0));
CHECK(3.140625 == (double)__truncdfbf2(3.141));

// specials
volatile __bf16 nan = NAN;
volatile __bf16 positive_infinity = +INFINITY;
volatile __bf16 negative_infinity = -INFINITY;
CHECK(isnan(nan));
CHECK(!isinf(pi));
CHECK(!isnan(pi));
CHECK(isinf(positive_infinity));
CHECK(isinf(negative_infinity));
CHECK(!isnan(positive_infinity));
CHECK(!isnan(negative_infinity));
CHECK(!signbit(pi));
CHECK(signbit(half(-pi)));
CHECK(!signbit(half(+0.)));
CHECK(signbit(half(-0.)));

// arithmetic
CHECK(half(-3) == -half(3));
CHECK(half(9) == half(3) * half(3));
CHECK(half(0) == half(pi) - half(pi));
CHECK(half(6.28125) == half(pi) + half(pi));

// comparisons
CHECK(half(3) > half(2));
CHECK(half(3) < half(4));
CHECK(half(3) <= half(3));
CHECK(half(3) >= half(3));
TRUE(half(NAN) != half(NAN));
FALSE(half(NAN) == half(NAN));
TRUE(half(3) != half(NAN));
FALSE(half(3) == half(NAN));
TRUE(half(NAN) != half(3));
FALSE(half(NAN) == half(3));
}

0 comments on commit a80ab3f

Please sign in to comment.