Skip to content

Commit

Permalink
[GPU] Int4 utils fix
Browse files: browse the repository at this point in the history
  • Loading branch information
vladimir-paramuzov committed Oct 27, 2023
1 parent fd88a6b commit 3bc0f42
Showing 1 changed file with 38 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "common.cl"

// Single-byte packed container. cvt_int4x2_to_int8x2 expands one of these
// into a char2, i.e. the byte carries two values.
typedef struct __attribute__ ((packed)) int4x2_t {
    char s0;
} int4x2_t;
// Two int4x2_t members (s0, s1) in a packed struct.
typedef struct __attribute__ ((packed)) int4x4_t {
    int4x2_t s0;
    int4x2_t s1;
} int4x4_t;
// Four int4x2_t members (s0..s3) in a packed struct.
typedef struct __attribute__ ((packed)) int4x8_t {
    int4x2_t s0;
    int4x2_t s1;
    int4x2_t s2;
    int4x2_t s3;
} int4x8_t;
Expand All @@ -26,76 +28,78 @@ inline char2 cvt_int4x2_to_int8x2(int4x2_t v) __attribute__((overloadable)) {
return (char2)(v0, v1);
}

// Unpacks a uint4x2_t into a half2.
// The value is first expanded by cvt_uint4x2_to_uint8x2 (defined outside this
// view), and that result is converted with the convert_half2 built-in.
inline half2 unpack_to_half(uint4x2_t v) __attribute__((overloadable)) {
return convert_half2(cvt_uint4x2_to_uint8x2(v));
}

// Unpacks a uint4x2_t into a float2.
// The value is first expanded by cvt_uint4x2_to_uint8x2 (defined outside this
// view), and that result is converted with the convert_float2 built-in.
inline float2 unpack_to_float(uint4x2_t v) __attribute__((overloadable)) {
return convert_float2(cvt_uint4x2_to_uint8x2(v));
}

// Unpacks an int4x2_t into a half2: the packed byte is widened to a char2 by
// cvt_int4x2_to_int8x2, then converted with the convert_half2 built-in.
inline half2 unpack_to_half(int4x2_t v) __attribute__((overloadable)) {
    const char2 widened = cvt_int4x2_to_int8x2(v);
    return convert_half2(widened);
}

// Unpacks an int4x2_t into a float2: the packed byte is widened to a char2 by
// cvt_int4x2_to_int8x2, then converted with the convert_float2 built-in.
inline float2 unpack_to_float(int4x2_t v) __attribute__((overloadable)) {
    const char2 widened = cvt_int4x2_to_int8x2(v);
    return convert_float2(widened);
}

// Unpacks a uint4x4_t into a half4.
// Each member (s0, s1) is unpacked to a half2 by the two-element overload;
// the result holds s0's pair in components 0-1 and s1's pair in components 2-3.
inline half4 unpack_to_half(uint4x4_t v) __attribute__((overloadable)) {
    const half2 lo = unpack_to_half(v.s0);
    const half2 hi = unpack_to_half(v.s1);
    return (half4)(lo, hi);
}

// Unpacks a uint4x4_t into a float4.
// Each member (s0, s1) is unpacked to a float2 by the two-element overload;
// the result holds s0's pair in components 0-1 and s1's pair in components 2-3.
inline float4 unpack_to_float(uint4x4_t v) __attribute__((overloadable)) {
    const float2 lo = unpack_to_float(v.s0);
    const float2 hi = unpack_to_float(v.s1);
    return (float4)(lo, hi);
}

// Unpacks an int4x4_t into a half4.
// Each member (s0, s1) is unpacked to a half2 by the two-element overload;
// the result holds s0's pair in components 0-1 and s1's pair in components 2-3.
inline half4 unpack_to_half(int4x4_t v) __attribute__((overloadable)) {
    const half2 lo = unpack_to_half(v.s0);
    const half2 hi = unpack_to_half(v.s1);
    return (half4)(lo, hi);
}

// Unpacks an int4x4_t into a float4.
// Each member (s0, s1) is unpacked to a float2 by the two-element overload;
// the result holds s0's pair in components 0-1 and s1's pair in components 2-3.
inline float4 unpack_to_float(int4x4_t v) __attribute__((overloadable)) {
    const float2 lo = unpack_to_float(v.s0);
    const float2 hi = unpack_to_float(v.s1);
    return (float4)(lo, hi);
}

inline half8 unpack_to_half(uint4x8_t v) __attribute__((overloadable)) {
half2 f0 = unpack_to_half(v.s0);
half2 f1 = unpack_to_half(v.s1);
half2 f2 = unpack_to_half(v.s2);
half2 f3 = unpack_to_half(v.s3);
return (half8)(f0.s0, f0.s1, f1.s0, f1.s1, f2.s0, f2.s1, f3.s0, f3.s1);
// Unpacks a uint4x8_t into a float8.
// Members s0..s3 are each unpacked to a float2 by the two-element overload and
// concatenated in member order (s0's pair first, s3's pair last).
inline float8 unpack_to_float(uint4x8_t v) __attribute__((overloadable)) {
    const float2 p0 = unpack_to_float(v.s0);
    const float2 p1 = unpack_to_float(v.s1);
    const float2 p2 = unpack_to_float(v.s2);
    const float2 p3 = unpack_to_float(v.s3);
    return (float8)(p0, p1, p2, p3);
}

inline float8 unpack_to_float(uint4x8_t v) __attribute__((overloadable)) {
// Unpacks an int4x8_t into a float8.
// Members s0..s3 are each unpacked to a float2 by the two-element overload and
// concatenated in member order (s0's pair first, s3's pair last).
inline float8 unpack_to_float(int4x8_t v) __attribute__((overloadable)) {
    const float2 p0 = unpack_to_float(v.s0);
    const float2 p1 = unpack_to_float(v.s1);
    const float2 p2 = unpack_to_float(v.s2);
    const float2 p3 = unpack_to_float(v.s3);
    return (float8)(p0, p1, p2, p3);
}

inline half8 unpack_to_half(int4x8_t v) __attribute__((overloadable)) {
#if defined(cl_khr_fp16)
// Unpacks a uint4x2_t into a half2.
// The value is first expanded by cvt_uint4x2_to_uint8x2 (defined outside this
// view), and that result is converted with the convert_half2 built-in.
inline half2 unpack_to_half(uint4x2_t v) __attribute__((overloadable)) {
return convert_half2(cvt_uint4x2_to_uint8x2(v));
}

// Unpacks an int4x2_t into a half2: the packed byte is widened to a char2 by
// cvt_int4x2_to_int8x2, then converted with the convert_half2 built-in.
inline half2 unpack_to_half(int4x2_t v) __attribute__((overloadable)) {
    const char2 widened = cvt_int4x2_to_int8x2(v);
    return convert_half2(widened);
}

// Unpacks a uint4x4_t into a half4.
// Each member (s0, s1) is unpacked to a half2 by the two-element overload;
// the result holds s0's pair in components 0-1 and s1's pair in components 2-3.
inline half4 unpack_to_half(uint4x4_t v) __attribute__((overloadable)) {
    const half2 lo = unpack_to_half(v.s0);
    const half2 hi = unpack_to_half(v.s1);
    return (half4)(lo, hi);
}

// Unpacks an int4x4_t into a half4.
// Each member (s0, s1) is unpacked to a half2 by the two-element overload;
// the result holds s0's pair in components 0-1 and s1's pair in components 2-3.
inline half4 unpack_to_half(int4x4_t v) __attribute__((overloadable)) {
    const half2 lo = unpack_to_half(v.s0);
    const half2 hi = unpack_to_half(v.s1);
    return (half4)(lo, hi);
}

// Unpacks a uint4x8_t into a half8.
// Members s0..s3 are each unpacked to a half2 by the two-element overload and
// concatenated in member order (s0's pair first, s3's pair last).
inline half8 unpack_to_half(uint4x8_t v) __attribute__((overloadable)) {
    const half2 p0 = unpack_to_half(v.s0);
    const half2 p1 = unpack_to_half(v.s1);
    const half2 p2 = unpack_to_half(v.s2);
    const half2 p3 = unpack_to_half(v.s3);
    return (half8)(p0, p1, p2, p3);
}

inline float8 unpack_to_float(int4x8_t v) __attribute__((overloadable)) {
float2 f0 = unpack_to_float(v.s0);
float2 f1 = unpack_to_float(v.s1);
float2 f2 = unpack_to_float(v.s2);
float2 f3 = unpack_to_float(v.s3);
return (float8)(f0.s0, f0.s1, f1.s0, f1.s1, f2.s0, f2.s1, f3.s0, f3.s1);
// Unpacks an int4x8_t into a half8.
// Members s0..s3 are each unpacked to a half2 by the two-element overload and
// concatenated in member order (s0's pair first, s3's pair last).
inline half8 unpack_to_half(int4x8_t v) __attribute__((overloadable)) {
    const half2 p0 = unpack_to_half(v.s0);
    const half2 p1 = unpack_to_half(v.s1);
    const half2 p2 = unpack_to_half(v.s2);
    const half2 p3 = unpack_to_half(v.s3);
    return (half8)(p0, p1, p2, p3);
}
#endif // defined(cl_khr_fp16)

// Calls the unpack_to_<target_type> function on `value`; the overload is then
// chosen by the type of `value`.
// NOTE(review): CAT is not defined in this view (presumably provided by
// common.cl) and is assumed to token-paste its two arguments, so that e.g.
// UNPACK_INT4x2(half, v) becomes unpack_to_half(v) — confirm.
#define UNPACK_INT4x2(target_type, value) CAT(unpack_to_, target_type)(value)

0 comments on commit 3bc0f42

Please sign in to comment.