Skip to content

Commit

Permalink
Add shader I64 and U64 support (gfx-rs#5154)
Browse files Browse the repository at this point in the history
Co-authored-by: Connor Fitzgerald <[email protected]>
  • Loading branch information
atlv24 and cwfitzgerald authored Mar 12, 2024
1 parent 3107f5e commit 4e6f873
Show file tree
Hide file tree
Showing 40 changed files with 2,183 additions and 164 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ By @cwfitzgerald in [#5325](https://github.com/gfx-rs/wgpu/pull/5325).
- As with other instance flags, this flag can be changed in calls to `InstanceFlags::with_env` with the new `WGPU_GPU_BASED_VALIDATION` environment variable.

By @ErichDonGubler in [#5146](https://github.com/gfx-rs/wgpu/pull/5146), [#5046](https://github.com/gfx-rs/wgpu/pull/5046).
- Signed and unsigned 64 bit integer support in shaders. By @rodolphito and @cwfitzgerald in [#5154](https://github.com/gfx-rs/wgpu/pull/5154)
- `wgpu::Instance` can now report which `wgpu::Backends` are available based on the build configuration. By @wumpf [#5167](https://github.com/gfx-rs/wgpu/pull/5167)
```diff
-wgpu::Instance::any_backend_feature_enabled()
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2456,6 +2456,9 @@ impl<'a, W: Write> Writer<'a, W> {
crate::Literal::I64(_) => {
return Err(Error::Custom("GLSL has no 64-bit integer type".into()));
}
crate::Literal::U64(_) => {
return Err(Error::Custom("GLSL has no 64-bit integer type".into()));
}
crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
return Err(Error::Custom(
"Abstract types should not appear in IR presented to backends".into(),
Expand Down
12 changes: 10 additions & 2 deletions naga/src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,16 @@ impl crate::Scalar {
/// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar>
pub(super) const fn to_hlsl_str(self) -> Result<&'static str, Error> {
match self.kind {
crate::ScalarKind::Sint => Ok("int"),
crate::ScalarKind::Uint => Ok("uint"),
crate::ScalarKind::Sint => match self.width {
4 => Ok("int"),
8 => Ok("int64_t"),
_ => Err(Error::UnsupportedScalar(self)),
},
crate::ScalarKind::Uint => match self.width {
4 => Ok("uint"),
8 => Ok("uint64_t"),
_ => Err(Error::UnsupportedScalar(self)),
},
crate::ScalarKind::Float => match self.width {
2 => Ok("half"),
4 => Ok("float"),
Expand Down
84 changes: 66 additions & 18 deletions naga/src/back/hlsl/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ The [`temp_access_chain`] field is a member of [`Writer`] solely to
allow re-use of the `Vec`'s dynamic allocation. Its value is no longer
needed once HLSL for the access has been generated.
Note about DXC and Load/Store functions:
DXC's HLSL has a generic [`Load` and `Store`] function for [`ByteAddressBuffer`] and
[`RWByteAddressBuffer`]. This is not available in FXC's HLSL, so we use
it only for types that are only available in DXC. Notably 64 and 16 bit types.
FXC's HLSL has functions Load, Load2, Load3, and Load4 and Store, Store2, Store3, Store4.
This loads/stores a vector of length 1, 2, 3, or 4. We use that for 32bit types, bitcasting to the
correct type if necessary.
[`Storage`]: crate::AddressSpace::Storage
[`ByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-byteaddressbuffer
[`RWByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer
Expand All @@ -42,6 +52,7 @@ needed once HLSL for the access has been generated.
[`Writer::temp_access_chain`]: super::Writer::temp_access_chain
[`temp_access_chain`]: super::Writer::temp_access_chain
[`Writer`]: super::Writer
[`Load` and `Store`]: https://github.com/microsoft/DirectXShaderCompiler/wiki/ByteAddressBuffer-Load-Store-Additions
*/

use super::{super::FunctionCtx, BackendResult, Error};
Expand Down Expand Up @@ -161,20 +172,39 @@ impl<W: fmt::Write> super::Writer<'_, W> {
// working around the borrow checker in `self.write_expr`
let chain = mem::take(&mut self.temp_access_chain);
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
let cast = scalar.kind.to_hlsl_cast();
write!(self.out, "{cast}({var_name}.Load(")?;
// See note about DXC and Load/Store in the module's documentation.
if scalar.width == 4 {
let cast = scalar.kind.to_hlsl_cast();
write!(self.out, "{cast}({var_name}.Load(")?;
} else {
let ty = scalar.to_hlsl_str()?;
write!(self.out, "{var_name}.Load<{ty}>(")?;
};
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, "))")?;
write!(self.out, ")")?;
if scalar.width == 4 {
write!(self.out, ")")?;
}
self.temp_access_chain = chain;
}
crate::TypeInner::Vector { size, scalar } => {
// working around the borrow checker in `self.write_expr`
let chain = mem::take(&mut self.temp_access_chain);
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
let cast = scalar.kind.to_hlsl_cast();
write!(self.out, "{}({}.Load{}(", cast, var_name, size as u8)?;
let size = size as u8;
// See note about DXC and Load/Store in the module's documentation.
if scalar.width == 4 {
let cast = scalar.kind.to_hlsl_cast();
write!(self.out, "{cast}({var_name}.Load{size}(")?;
} else {
let ty = scalar.to_hlsl_str()?;
write!(self.out, "{var_name}.Load<{ty}{size}>(")?;
};
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, "))")?;
write!(self.out, ")")?;
if scalar.width == 4 {
write!(self.out, ")")?;
}
self.temp_access_chain = chain;
}
crate::TypeInner::Matrix {
Expand Down Expand Up @@ -288,26 +318,44 @@ impl<W: fmt::Write> super::Writer<'_, W> {
}
};
match *ty_resolution.inner_with(&module.types) {
crate::TypeInner::Scalar(_) => {
crate::TypeInner::Scalar(scalar) => {
// working around the borrow checker in `self.write_expr`
let chain = mem::take(&mut self.temp_access_chain);
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
write!(self.out, "{level}{var_name}.Store(")?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, ", asuint(")?;
self.write_store_value(module, &value, func_ctx)?;
writeln!(self.out, "));")?;
// See note about DXC and Load/Store in the module's documentation.
if scalar.width == 4 {
write!(self.out, "{level}{var_name}.Store(")?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, ", asuint(")?;
self.write_store_value(module, &value, func_ctx)?;
writeln!(self.out, "));")?;
} else {
write!(self.out, "{level}{var_name}.Store(")?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, ", ")?;
self.write_store_value(module, &value, func_ctx)?;
writeln!(self.out, ");")?;
}
self.temp_access_chain = chain;
}
crate::TypeInner::Vector { size, .. } => {
crate::TypeInner::Vector { size, scalar } => {
// working around the borrow checker in `self.write_expr`
let chain = mem::take(&mut self.temp_access_chain);
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, ", asuint(")?;
self.write_store_value(module, &value, func_ctx)?;
writeln!(self.out, "));")?;
// See note about DXC and Load/Store in the module's documentation.
if scalar.width == 4 {
write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, ", asuint(")?;
self.write_store_value(module, &value, func_ctx)?;
writeln!(self.out, "));")?;
} else {
write!(self.out, "{}{}.Store(", level, var_name)?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, ", ")?;
self.write_store_value(module, &value, func_ctx)?;
writeln!(self.out, ");")?;
}
self.temp_access_chain = chain;
}
crate::TypeInner::Matrix {
Expand Down
91 changes: 72 additions & 19 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2022,6 +2022,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
crate::Literal::I32(value) => write!(self.out, "{}", value)?,
crate::Literal::U64(value) => write!(self.out, "{}uL", value)?,
crate::Literal::I64(value) => write!(self.out, "{}L", value)?,
crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
Expand Down Expand Up @@ -2551,7 +2552,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
convert,
} => {
let inner = func_ctx.resolve_type(expr, &module.types);
match convert {
let close_paren = match convert {
Some(dst_width) => {
let scalar = crate::Scalar {
kind,
Expand Down Expand Up @@ -2584,13 +2585,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
)));
}
};
true
}
None => {
write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
if inner.scalar_width() == Some(64) {
false
} else {
write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
true
}
}
}
};
self.write_expr(module, expr, func_ctx)?;
write!(self.out, ")")?;
if close_paren {
write!(self.out, ")")?;
}
}
Expression::Math {
fun,
Expand Down Expand Up @@ -2862,9 +2871,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
write!(self.out, ")")?
}
// These overloads are only missing on FXC, so this is only needed for 32bit types,
// as non-32bit types are DXC only.
Function::MissingIntOverload(fun_name) => {
let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind();
if let Some(ScalarKind::Sint) = scalar_kind {
let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
if let Some(crate::Scalar {
kind: ScalarKind::Sint,
width: 4,
}) = scalar_kind
{
write!(self.out, "asint({fun_name}(asuint(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
Expand All @@ -2874,9 +2889,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ")")?;
}
}
// These overloads are only missing on FXC, so this is only needed for 32bit types,
// as non-32bit types are DXC only.
Function::MissingIntReturnType(fun_name) => {
let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind();
if let Some(ScalarKind::Sint) = scalar_kind {
let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
if let Some(crate::Scalar {
kind: ScalarKind::Sint,
width: 4,
}) = scalar_kind
{
write!(self.out, "asint({fun_name}(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
Expand All @@ -2895,23 +2916,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
crate::VectorSize::Quad => ".xxxx",
};

if let ScalarKind::Uint = scalar.kind {
write!(self.out, "min((32u){s}, firstbitlow(")?;
let scalar_width_bits = scalar.width * 8;

if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
write!(
self.out,
"min(({scalar_width_bits}u){s}, firstbitlow("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
write!(self.out, "asint(min((32u){s}, firstbitlow(")?;
// This is only needed for the FXC path, on 32bit signed integers.
write!(
self.out,
"asint(min(({scalar_width_bits}u){s}, firstbitlow("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
}
}
TypeInner::Scalar(scalar) => {
if let ScalarKind::Uint = scalar.kind {
write!(self.out, "min(32u, firstbitlow(")?;
let scalar_width_bits = scalar.width * 8;

if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
write!(self.out, "asint(min(32u, firstbitlow(")?;
// This is only needed for the FXC path, on 32bit signed integers.
write!(
self.out,
"asint(min({scalar_width_bits}u, firstbitlow("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
}
Expand All @@ -2930,30 +2966,47 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
crate::VectorSize::Quad => ".xxxx",
};

if let ScalarKind::Uint = scalar.kind {
write!(self.out, "((31u){s} - firstbithigh(")?;
// scalar width - 1
let constant = scalar.width * 8 - 1;

if scalar.kind == ScalarKind::Uint {
write!(self.out, "(({constant}u){s} - firstbithigh(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
let conversion_func = match scalar.width {
4 => "asint",
_ => "",
};
write!(self.out, "(")?;
self.write_expr(module, arg, func_ctx)?;
write!(
self.out,
" < (0){s} ? (0){s} : (31){s} - asint(firstbithigh("
" < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
}
}
TypeInner::Scalar(scalar) => {
// scalar width - 1
let constant = scalar.width * 8 - 1;

if let ScalarKind::Uint = scalar.kind {
write!(self.out, "(31u - firstbithigh(")?;
write!(self.out, "({constant}u - firstbithigh(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
let conversion_func = match scalar.width {
4 => "asint",
_ => "",
};
write!(self.out, "(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " < 0 ? 0 : 31 - asint(firstbithigh(")?;
write!(
self.out,
" < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
}
Expand Down
Loading

0 comments on commit 4e6f873

Please sign in to comment.