Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[spirv] Add SM 6.6 8-bit packed types and intrinsics #3325

Merged
merged 5 commits into from
Dec 30, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tools/clang/lib/SPIRV/AstTypeProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,8 @@ uint32_t getElementSpirvBitwidth(const ASTContext &astContext, QualType type,
case BuiltinType::Bool:
case BuiltinType::Int:
case BuiltinType::UInt:
case BuiltinType::Int8_4Packed:
case BuiltinType::UInt8_4Packed:
case BuiltinType::Float:
return 32;
case BuiltinType::Double:
Expand All @@ -456,6 +458,11 @@ uint32_t getElementSpirvBitwidth(const ASTContext &astContext, QualType type,
// if -enable-16bit-types is false.
case BuiltinType::HalfFloat:
return 32;
case BuiltinType::UChar:
case BuiltinType::Char_U:
case BuiltinType::SChar:
case BuiltinType::Char_S:
return 8;
// The following types are treated as 16-bit if '-enable-16bit-types' option
// is enabled. They are treated as 32-bit otherwise.
case BuiltinType::Min12Int:
Expand Down Expand Up @@ -485,6 +492,24 @@ bool canTreatAsSameScalarType(QualType type1, QualType type2) {
type2.removeLocalConst();

return (type1.getCanonicalType() == type2.getCanonicalType()) ||
// Treat uint8_t4_packed and int8_t4_packed as the same because they
// are both represented as 32-bit unsigned integers in SPIR-V.
(type1->isSpecificBuiltinType(BuiltinType::Int8_4Packed) &&
type2->isSpecificBuiltinType(BuiltinType::UInt8_4Packed)) ||
(type2->isSpecificBuiltinType(BuiltinType::Int8_4Packed) &&
type1->isSpecificBuiltinType(BuiltinType::UInt8_4Packed)) ||
// Treat uint8_t4_packed and uint32_t as the same because they
// are both represented as 32-bit unsigned integers in SPIR-V.
(type1->isSpecificBuiltinType(BuiltinType::UInt) &&
type2->isSpecificBuiltinType(BuiltinType::UInt8_4Packed)) ||
(type2->isSpecificBuiltinType(BuiltinType::UInt) &&
type1->isSpecificBuiltinType(BuiltinType::UInt8_4Packed)) ||
// Treat int8_t4_packed and uint32_t as the same because they
// are both represented as 32-bit unsigned integers in SPIR-V.
(type1->isSpecificBuiltinType(BuiltinType::UInt) &&
type2->isSpecificBuiltinType(BuiltinType::Int8_4Packed)) ||
(type2->isSpecificBuiltinType(BuiltinType::UInt) &&
type1->isSpecificBuiltinType(BuiltinType::Int8_4Packed)) ||
// Treat 'literal float' and 'float' as the same
(type1->isSpecificBuiltinType(BuiltinType::LitFloat) &&
type2->isFloatingType()) ||
Expand Down
4 changes: 4 additions & 0 deletions tools/clang/lib/SPIRV/CapabilityVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
// Integer-related capabilities
if (const auto *intType = dyn_cast<IntegerType>(type)) {
switch (intType->getBitwidth()) {
case 8: {
addCapability(spv::Capability::Int8);
break;
}
case 16: {
// Usage of a 16-bit integer type.
addCapability(spv::Capability::Int16);
Expand Down
12 changes: 12 additions & 0 deletions tools/clang/lib/SPIRV/LowerTypeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
return spvContext.getSIntType(32);
case BuiltinType::UInt:
case BuiltinType::ULong:
// The 'int8_t4_packed' and 'uint8_t4_packed' types are in fact 32-bit
// unsigned integers.
case BuiltinType::Int8_4Packed:
case BuiltinType::UInt8_4Packed:
return spvContext.getUIntType(32);

// void and bool
Expand Down Expand Up @@ -316,6 +320,14 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
case BuiltinType::UShort: // uint16_t
return spvContext.getUIntType(16);

// 8-bit integer types
case BuiltinType::UChar:
case BuiltinType::Char_U:
return spvContext.getUIntType(8);
case BuiltinType::SChar:
case BuiltinType::Char_S:
return spvContext.getSIntType(8);

// Relaxed precision types
case BuiltinType::Min10Float:
case BuiltinType::Min16Float:
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/lib/SPIRV/SpirvContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ SpirvContext::~SpirvContext() {
}

/// Returns log2 of the given scalar bitwidth.
///
/// Valid inputs are the power-of-two scalar sizes SPIR-V supports: 8, 16, 32,
/// and 64. The lower bound is 8 (not 16) because the SM 6.6 packed 8-bit
/// types introduce genuine 8-bit integer scalars.
inline uint32_t log2ForBitwidth(uint32_t bitwidth) {
  assert(bitwidth >= 8 && bitwidth <= 64 && llvm::isPowerOf2_32(bitwidth));

  return llvm::Log2_32(bitwidth);
}
Expand Down
172 changes: 172 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7452,6 +7452,20 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
}
break;
}
case hlsl::IntrinsicOp::IOP_pack_s8:
case hlsl::IntrinsicOp::IOP_pack_u8:
case hlsl::IntrinsicOp::IOP_pack_clamp_s8:
case hlsl::IntrinsicOp::IOP_pack_clamp_u8: {
retVal = processIntrinsic8BitPack(callExpr, hlslOpcode);
break;
}
case hlsl::IntrinsicOp::IOP_unpack_s8s16:
case hlsl::IntrinsicOp::IOP_unpack_s8s32:
case hlsl::IntrinsicOp::IOP_unpack_u8u16:
case hlsl::IntrinsicOp::IOP_unpack_u8u32: {
retVal = processIntrinsic8BitUnpack(callExpr, hlslOpcode);
break;
}
// DXR raytracing intrinsics
case hlsl::IntrinsicOp::IOP_DispatchRaysDimensions:
case hlsl::IntrinsicOp::IOP_DispatchRaysIndex:
Expand Down Expand Up @@ -9812,6 +9826,164 @@ SpirvEmitter::processIntrinsicLog10(const CallExpr *callExpr) {
return spvBuilder.createBinaryOp(scaleOp, returnType, log2, scale, loc);
}

SpirvInstruction *
SpirvEmitter::processIntrinsic8BitPack(const CallExpr *callExpr,
                                       hlsl::IntrinsicOp op) {
  const auto loc = callExpr->getExprLoc();
  assert(op == hlsl::IntrinsicOp::IOP_pack_s8 ||
         op == hlsl::IntrinsicOp::IOP_pack_u8 ||
         op == hlsl::IntrinsicOp::IOP_pack_clamp_s8 ||
         op == hlsl::IntrinsicOp::IOP_pack_clamp_u8);

  // Here's the signature for the pack intrinsic operations:
  //
  // uint8_t4_packed pack_u8(uint32_t4 unpackedVal);
  // uint8_t4_packed pack_u8(uint16_t4 unpackedVal);
  // int8_t4_packed pack_s8(int32_t4 unpackedVal);
  // int8_t4_packed pack_s8(int16_t4 unpackedVal);
  //
  // These functions take a vec4 of 16-bit or 32-bit integers as input. For each
  // element of the vec4, they pick the lower 8 bits, and drop the other bits.
  // The result is four 8-bit values (32 bits in total) which are packed in an
  // unsigned uint32_t.
  //
  //
  // Here's the signature for the pack_clamp intrinsic operations:
  //
  // uint8_t4_packed pack_clamp_u8(int32_t4 val); // Pack and Clamp [0, 255]
  // uint8_t4_packed pack_clamp_u8(int16_t4 val); // Pack and Clamp [0, 255]
  //
  // int8_t4_packed pack_clamp_s8(int32_t4 val); // Pack and Clamp [-128, 127]
  // int8_t4_packed pack_clamp_s8(int16_t4 val); // Pack and Clamp [-128, 127]
  //
  // These functions take a vec4 of 16-bit or 32-bit integers as input. For each
  // element of the vec4, they first clamp the value to a range (depending on
  // the signedness) then pick the lower 8 bits, and drop the other bits.
  // The result is four 8-bit values (32 bits in total) which are packed in an
  // unsigned uint32_t.
  //
  // Note: uint8_t4_packed and int8_t4_packed are NOT vector types! They are
  // both scalar 32-bit unsigned integer types where each byte represents one
  // value.
  //
  // Note: In pack_clamp_{s|u}8 intrinsics, an input of 0x100 will be turned
  // into 0xFF, not 0x00. Therefore, it is important to perform a clamp first,
  // and then a truncation.

  // Steps:
  // Use GLSL extended instruction set's clamp (only for clamp instructions).
  // Use OpUConvert/OpSConvert to truncate each element of the vec4 to 8 bits.
  // Use OpBitcast to make a 32-bit uint out of the new vec4.
  auto *arg = callExpr->getArg(0);
  const auto argType = arg->getType();
  SpirvInstruction *argInstr = doExpr(arg);
  QualType elemType = {};
  uint32_t elemCount = 0;
  (void)isVectorType(argType, &elemType, &elemCount);
  const bool isSigned = elemType->isSignedIntegerType();
  assert(elemCount == 4);
  (void)elemCount; // Only read by the assertion above; avoid NDEBUG warning.

  const bool doesClamp = op == hlsl::IntrinsicOp::IOP_pack_clamp_s8 ||
                         op == hlsl::IntrinsicOp::IOP_pack_clamp_u8;
  if (doesClamp) {
    const auto bitwidth = getElementSpirvBitwidth(
        astContext, elemType, spirvOptions.enable16BitTypes);
    // Clamp range depends on the signedness of the *result*: [0, 255] for
    // pack_clamp_u8, [-128, 127] for pack_clamp_s8.
    int32_t clampMin = op == hlsl::IntrinsicOp::IOP_pack_clamp_u8 ? 0 : -128;
    int32_t clampMax = op == hlsl::IntrinsicOp::IOP_pack_clamp_u8 ? 255 : 127;
    auto *minInstr = spvBuilder.getConstantInt(
        elemType, llvm::APInt(bitwidth, clampMin, isSigned));
    auto *maxInstr = spvBuilder.getConstantInt(
        elemType, llvm::APInt(bitwidth, clampMax, isSigned));
    auto *minVec = spvBuilder.getConstantComposite(
        argType, {minInstr, minInstr, minInstr, minInstr});
    auto *maxVec = spvBuilder.getConstantComposite(
        argType, {maxInstr, maxInstr, maxInstr, maxInstr});
    // The clamp flavor follows the signedness of the *input* elements.
    auto clampOp = isSigned ? GLSLstd450SClamp : GLSLstd450UClamp;
    argInstr = spvBuilder.createGLSLExtInst(argType, clampOp,
                                            {argInstr, minVec, maxVec}, loc);
  }

  if (isSigned) {
    QualType v4Int8Type =
        astContext.getExtVectorType(astContext.SignedCharTy, 4);
    auto *bytesVecInstr = spvBuilder.createUnaryOp(spv::Op::OpSConvert,
                                                   v4Int8Type, argInstr, loc);
    return spvBuilder.createUnaryOp(
        spv::Op::OpBitcast, astContext.Int8_4PackedTy, bytesVecInstr, loc);
  } else {
    QualType v4Uint8Type =
        astContext.getExtVectorType(astContext.UnsignedCharTy, 4);
    auto *bytesVecInstr = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
                                                   v4Uint8Type, argInstr, loc);
    return spvBuilder.createUnaryOp(
        spv::Op::OpBitcast, astContext.UInt8_4PackedTy, bytesVecInstr, loc);
  }
}

SpirvInstruction *
SpirvEmitter::processIntrinsic8BitUnpack(const CallExpr *callExpr,
                                         hlsl::IntrinsicOp op) {
  const auto loc = callExpr->getExprLoc();
  assert(op == hlsl::IntrinsicOp::IOP_unpack_s8s16 ||
         op == hlsl::IntrinsicOp::IOP_unpack_s8s32 ||
         op == hlsl::IntrinsicOp::IOP_unpack_u8u16 ||
         op == hlsl::IntrinsicOp::IOP_unpack_u8u32);

  // Here's the signature for the unpack intrinsic operations:
  //
  // int16_t4 unpack_s8s16(int8_t4_packed packedVal); // Sign Extended
  // uint16_t4 unpack_u8u16(uint8_t4_packed packedVal); // Non-Sign Extended
  // int32_t4 unpack_s8s32(int8_t4_packed packedVal); // Sign Extended
  // uint32_t4 unpack_u8u32(uint8_t4_packed packedVal); // Non-Sign Extended
  //
  // These functions take a 32-bit unsigned integer as input (where each byte of
  // the input represents one value, i.e. it's packed). They first unpack the
  // 32-bit integer to a vector of 4 bytes. Then for each element of the vec4,
  // they zero-extend or sign-extend the byte in order to achieve a 16-bit or
  // 32-bit vector of integers.
  //
  // Note: uint8_t4_packed and int8_t4_packed are NOT vector types! They are
  // both scalar 32-bit unsigned integer types where each byte represents one
  // value.

  // Steps:
  // Use OpBitcast to make a vec4 of bytes from a 32-bit value.
  // Use OpUConvert/OpSConvert to zero-extend/sign-extend each element of the
  // vec4 to 16 or 32 bits.
  auto *arg = callExpr->getArg(0);
  SpirvInstruction *argInstr = doExpr(arg);

  const bool isSigned = op == hlsl::IntrinsicOp::IOP_unpack_s8s16 ||
                        op == hlsl::IntrinsicOp::IOP_unpack_s8s32;

  // The result element width (16 vs 32) is encoded in the opcode.
  QualType resultType = {};
  if (op == hlsl::IntrinsicOp::IOP_unpack_s8s16 ||
      op == hlsl::IntrinsicOp::IOP_unpack_u8u16) {
    resultType = astContext.getExtVectorType(
        isSigned ? astContext.ShortTy : astContext.UnsignedShortTy, 4);
  } else {
    resultType = astContext.getExtVectorType(
        isSigned ? astContext.IntTy : astContext.UnsignedIntTy, 4);
  }

  if (isSigned) {
    QualType v4Int8Type =
        astContext.getExtVectorType(astContext.SignedCharTy, 4);
    auto *bytesVecInstr =
        spvBuilder.createUnaryOp(spv::Op::OpBitcast, v4Int8Type, argInstr, loc);
    return spvBuilder.createUnaryOp(spv::Op::OpSConvert, resultType,
                                    bytesVecInstr, loc);
  } else {
    QualType v4Uint8Type =
        astContext.getExtVectorType(astContext.UnsignedCharTy, 4);
    auto *bytesVecInstr = spvBuilder.createUnaryOp(spv::Op::OpBitcast,
                                                   v4Uint8Type, argInstr, loc);
    return spvBuilder.createUnaryOp(spv::Op::OpUConvert, resultType,
                                    bytesVecInstr, loc);
  }
}

SpirvInstruction *SpirvEmitter::processRayBuiltins(const CallExpr *callExpr,
hlsl::IntrinsicOp op) {
spv::BuiltIn builtin = spv::BuiltIn::Max;
Expand Down
9 changes: 9 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,15 @@ class SpirvEmitter : public ASTConsumer {
/// Processes the NonUniformResourceIndex intrinsic function.
SpirvInstruction *processIntrinsicNonUniformResourceIndex(const CallExpr *);

/// Processes the SM 6.6 pack_{s|u}8 and pack_clamp_{s|u}8 intrinsic
/// functions.
SpirvInstruction *processIntrinsic8BitPack(const CallExpr *,
hlsl::IntrinsicOp);

/// Processes the SM 6.6 unpack_{s|u}8{s|u}{16|32} intrinsic functions.
SpirvInstruction *processIntrinsic8BitUnpack(const CallExpr *,
hlsl::IntrinsicOp);

/// Process builtins specific to raytracing.
SpirvInstruction *processRayBuiltins(const CallExpr *, hlsl::IntrinsicOp op);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Run: %dxc -E main -T ps_6_6 -enable-16bit-types

// FileCheck test for SPIR-V codegen of the SM 6.6 pack_clamp_s8/pack_clamp_u8
// intrinsics. Both variants clamp with GLSL.std.450 SClamp (the inputs are
// always signed int16_t4/int32_t4), truncate to a vec4 of 8-bit ints via
// OpSConvert, then OpBitcast the vec4 to a scalar 32-bit uint (the packed
// representation of {u}int8_t4_packed).

float4 main(int16_t4 input1 : Inputs1, int16_t4 input2 : Inputs2) : SV_Target {
int16_t4 v4int16_var;
int32_t4 v4int32_var;

// Note: pack_clamp_s8 and pack_clamp_u8 do NOT accept an unsigned argument.

// CHECK: [[glsl_set:%\d+]] = OpExtInstImport "GLSL.std.450"

// CHECK: %short = OpTypeInt 16 1
// CHECK: %v4short = OpTypeVector %short 4

// Clamp bounds for pack_clamp_s8: [-128, 127], per element type.
// CHECK: [[const_v4short_n128:%\d+]] = OpConstantComposite %v4short %short_n128 %short_n128 %short_n128 %short_n128
// CHECK: [[const_v4short_127:%\d+]] = OpConstantComposite %v4short %short_127 %short_127 %short_127 %short_127

// CHECK: [[const_v4int_n128:%\d+]] = OpConstantComposite %v4int %int_n128 %int_n128 %int_n128 %int_n128
// CHECK: [[const_v4int_127:%\d+]] = OpConstantComposite %v4int %int_127 %int_127 %int_127 %int_127

// Clamp bounds for pack_clamp_u8: [0, 255], per element type.
// CHECK: [[const_v4short_0:%\d+]] = OpConstantComposite %v4short %short_0 %short_0 %short_0 %short_0
// CHECK: [[const_v4short_255:%\d+]] = OpConstantComposite %v4short %short_255 %short_255 %short_255 %short_255

// CHECK: [[const_v4int_0:%\d+]] = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
// CHECK: [[const_v4int_255:%\d+]] = OpConstantComposite %v4int %int_255 %int_255 %int_255 %int_255

// 8-bit signed int type used as the intermediate truncation target.
// CHECK: %char = OpTypeInt 8 1
// CHECK: %v4char = OpTypeVector %char 4

////////////////////////////
// pack_clamp_s8 variants //
////////////////////////////

// CHECK: [[v4int16_var:%\d+]] = OpLoad %v4short %v4int16_var
// CHECK: [[clamped:%\d+]] = OpExtInst %v4short [[glsl_set]] SClamp [[v4int16_var]] [[const_v4short_n128]] [[const_v4short_127]]
// CHECK: [[truncated:%\d+]] = OpSConvert %v4char [[clamped]]
// CHECK: [[packed:%\d+]] = OpBitcast %uint [[truncated]]
// CHECK: OpStore %ps1 [[packed]]
int8_t4_packed ps1 = pack_clamp_s8(v4int16_var);

// CHECK: [[v4int16_var:%\d+]] = OpLoad %v4int %v4int32_var
// CHECK: [[clamped:%\d+]] = OpExtInst %v4int [[glsl_set]] SClamp [[v4int16_var]] [[const_v4int_n128]] [[const_v4int_127]]
// CHECK: [[truncated:%\d+]] = OpSConvert %v4char [[clamped]]
// CHECK: [[packed:%\d+]] = OpBitcast %uint [[truncated]]
// CHECK: OpStore %ps3 [[packed]]
int8_t4_packed ps3 = pack_clamp_s8(v4int32_var);

////////////////////////////
// pack_clamp_u8 variants //
////////////////////////////

// SClamp (not UClamp) is expected here because the input elements are signed.
// CHECK: [[v4int16_var:%\d+]] = OpLoad %v4short %v4int16_var
// CHECK: [[clamped:%\d+]] = OpExtInst %v4short [[glsl_set]] SClamp [[v4int16_var]] [[const_v4short_0]] [[const_v4short_255]]
// CHECK: [[truncated:%\d+]] = OpSConvert %v4char [[clamped]]
// CHECK: [[packed:%\d+]] = OpBitcast %uint [[truncated]]
// CHECK: OpStore %pu1 [[packed]]
uint8_t4_packed pu1 = pack_clamp_u8(v4int16_var);

// CHECK: [[v4int32_var:%\d+]] = OpLoad %v4int %v4int32_var
// CHECK: [[clamped:%\d+]] = OpExtInst %v4int [[glsl_set]] SClamp [[v4int32_var]] [[const_v4int_0]] [[const_v4int_255]]
// CHECK: [[truncated:%\d+]] = OpSConvert %v4char [[clamped]]
// CHECK: [[packed:%\d+]] = OpBitcast %uint [[truncated]]
// CHECK: OpStore %pu3 [[packed]]
uint8_t4_packed pu3 = pack_clamp_u8(v4int32_var);

return 0.xxxx;
}
Loading