Skip to content

Commit

Permalink
Re-enable slang-test for WGSL (#5120)
Browse files Browse the repository at this point in the history
My previous commit disabled the WGSL test by a mistake. This commit fixes the mistake and run the slang-test for WGSL tests.

frexp and modf were still not working for the vector types.
  • Loading branch information
jkwak-work authored Sep 20, 2024
1 parent 26ca9c5 commit 0677956
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
40 changes: 37 additions & 3 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -8686,6 +8686,7 @@ vector<T, N> fract(vector<T, N> x)
// Split float into mantissa and exponent
__generic<T : __BuiltinFloatingPointType>
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
T frexp(T x, out int exp)
{
Expand All @@ -8708,14 +8709,16 @@ T frexp(T x, out int exp)

__generic<T : __BuiltinFloatingPointType>
[__readNone]
[ForceInline]
[require(wgsl)]
void __wgsl_frexp(T x, out T fract, out int exp)
{
__intrinsic_asm "{ var s = frexp($0); $1 = s.fract; $2 = s.exp; }";
__intrinsic_asm "{ var s = frexp($0); (*($1)) = s.fract; (*($2)) = s.exp; }";
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
vector<T, N> frexp(vector<T, N> x, out vector<int, N> exp)
{
Expand All @@ -8727,13 +8730,27 @@ vector<T, N> frexp(vector<T, N> x, out vector<int, N> exp)
case spirv: return spirv_asm {
result:$$vector<T, N> = OpExtInst glsl450 Frexp $x &exp
};
case wgsl:
vector<T,N> fract;
__wgsl_frexp<T>(x, fract, exp);
return fract;
default:
VECTOR_MAP_BINARY(T, N, frexp, x, exp);
}
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[__readNone]
[ForceInline]
[require(wgsl)]
void __wgsl_frexp(vector<T, N> x, out vector<T, N> fract, out vector<int, N> exp)
{
__intrinsic_asm "{ var s = frexp($0); (*($1)) = s.fract; (*($2)) = s.exp; }";
}

__generic<T : __BuiltinFloatingPointType, let N : int, let M : int, let L : int>
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
matrix<T, N, M> frexp(matrix<T, N, M> x, out matrix<int, N, M, L> exp)
{
Expand Down Expand Up @@ -11099,6 +11116,7 @@ vector<T,N> fmedian3(vector<T,N> x, vector<T,N> y, vector<T,N> z)
// split into integer and fractional parts (both with same sign)
__generic<T : __BuiltinFloatingPointType>
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
T modf(T x, out T ip)
{
Expand All @@ -11121,14 +11139,16 @@ T modf(T x, out T ip)

__generic<T : __BuiltinFloatingPointType>
[__readNone]
[ForceInline]
[require(wgsl)]
void __wgsl_modf(T x, out T fract, out T whole)
{
__intrinsic_asm "{ var s = modf($0); $1 = s.fract; $2 = s.whole; }";
__intrinsic_asm "{ var s = modf($0); (*($1)) = s.fract; (*($2)) = s.whole; }";
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
vector<T,N> modf(vector<T,N> x, out vector<T,N> ip)
{
Expand All @@ -11140,13 +11160,27 @@ vector<T,N> modf(vector<T,N> x, out vector<T,N> ip)
case spirv: return spirv_asm {
result:$$vector<T,N> = OpExtInst glsl450 Modf $x &ip
};
case wgsl:
vector<T,N> fract;
__wgsl_modf<T>(x, fract, ip);
return fract;
default:
VECTOR_MAP_BINARY(T, N, modf, x, ip);
}
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[__readNone]
[ForceInline]
[require(wgsl)]
void __wgsl_modf(vector<T,N> x, out vector<T,N> fract, out vector<T,N> whole)
{
__intrinsic_asm "{ var s = modf($0); (*($1)) = s.fract; (*($2)) = s.whole; }";
}

__generic<T : __BuiltinFloatingPointType, let N : int, let M : int, let L : int>
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
matrix<T,N,M> modf(matrix<T,N,M> x, out matrix<T,N,M,L> ip)
{
Expand Down Expand Up @@ -20624,4 +20658,4 @@ extension<T, L : IBufferDataLayout> RWStructuredBuffer<T, L> : IRWArray<T>
extension<T, L : IBufferDataLayout> RasterizerOrderedStructuredBuffer<T, L> : IRWArray<T>
{
int getCount() { uint count; uint stride; this.GetDimensions(count, stride); return count; }
}
}
6 changes: 1 addition & 5 deletions tools/slang-test/slang-test-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ static PassThroughFlags _getPassThroughFlagsForTarget(SlangCompileTarget target)
case SLANG_HOST_CPP_SOURCE:
case SLANG_CUDA_SOURCE:
case SLANG_METAL:
case SLANG_WGSL:
{
return 0;
}
Expand All @@ -993,11 +994,6 @@ static PassThroughFlags _getPassThroughFlagsForTarget(SlangCompileTarget target)
return PassThroughFlag::Metal;
}

case SLANG_WGSL:
{
return PassThroughFlag::WGSL;
}

case SLANG_SHADER_HOST_CALLABLE:
case SLANG_HOST_HOST_CALLABLE:

Expand Down

0 comments on commit 0677956

Please sign in to comment.