Skip to content

Commit

Permalink
[SYCL][CUDA][libclc] Implement nextafter for sycl::half in generic/. (#…
Browse files Browse the repository at this point in the history
…4939)

sycl::nextafter(half,half) was defaulting to sycl::nextafter(float,float) which does not return the next half. 

Software implementation written in libclc/generic and #included into ptx-nvidiacl.
  • Loading branch information
hdelan authored Nov 18, 2021
1 parent 0c55d3a commit 53c3268
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
49 changes: 49 additions & 0 deletions libclc/generic/libspirv/math/half_nextafter.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef HALF_NEXTAFTER_INC
#define HALF_NEXTAFTER_INC

#include <clcmacro.h>
#include <math/math.h>
#include <spirv/spirv.h>

#ifdef cl_khr_fp16

#pragma OPENCL EXTENSION cl_khr_fp16 : enable

_CLC_OVERLOAD _CLC_DEF half __spirv_ocl_nextafter(half x, half y) {
// NaNs
if (x != x)
return x;
if (y != y)
return y;
// Parity
if (x == y)
return x;

short *a = (short *)&x;
short *b = (short *)&y;
// Checking for sign digit
if (*a & 0x8000)
*a = 0x8000 - *a;
if (*b & 0x8000)
*b = 0x8000 - *b;
// Increment / decrement
*a += (*a < *b) ? 1 : -1;
// Undo the sign flip if necessary
*a = (*a < 0) ? 0x8000 - *a : *a;
return x;
}

_CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_nextafter, half,
half)

#endif

#endif
2 changes: 2 additions & 0 deletions libclc/generic/libspirv/math/nextafter.cl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ _CLC_DEFINE_BINARY_BUILTIN(double, __spirv_ocl_nextafter, __builtin_nextafter,
double, double)

#endif

#include "half_nextafter.inc"
20 changes: 19 additions & 1 deletion libclc/ptx-nvidiacl/libspirv/math/nextafter.cl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,28 @@

#include <spirv/spirv.h>

#include "utils.h"
#include <../../include/libdevice.h>
#include <clcmacro.h>

#define __CLC_FUNCTION __spirv_ocl_nextafter
#define __CLC_BUILTIN __nv_nextafter
#define __CLC_BUILTIN_F __CLC_XCONCAT(__CLC_BUILTIN, f)
#include <math/binary_builtin.inc>
#define __CLC_BUILTIN_D __CLC_BUILTIN

_CLC_DEFINE_BINARY_BUILTIN(float, __CLC_FUNCTION, __CLC_BUILTIN_F, float, float)

#ifndef __FLOAT_ONLY

#ifdef cl_khr_fp64

#pragma OPENCL EXTENSION cl_khr_fp64 : enable

_CLC_DEFINE_BINARY_BUILTIN(double, __CLC_FUNCTION, __CLC_BUILTIN_D, double,
double)

#endif

#include "../../../generic/libspirv/math/half_nextafter.inc"

#endif

0 comments on commit 53c3268

Please sign in to comment.