diff --git a/Cargo.lock b/Cargo.lock index eee65954e8..4bb9acfa96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1713,6 +1713,7 @@ checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" dependencies = [ "cfg-if", "crunchy", + "serde", ] [[package]] @@ -2149,6 +2150,7 @@ dependencies = [ "criterion", "diff", "env_logger", + "half", "hexf-parse", "hlsl-snapshots", "indexmap", diff --git a/naga/Cargo.toml b/naga/Cargo.toml index 3041a60099..7eea4f46d9 100644 --- a/naga/Cargo.toml +++ b/naga/Cargo.toml @@ -25,8 +25,9 @@ dot-out = [] glsl-in = ["dep:pp-rs"] glsl-out = [] msl-out = [] -serialize = ["dep:serde", "bitflags/serde", "indexmap/serde"] -deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde"] +f16 = ["half"] +serialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"] +deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"] arbitrary = ["dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary"] spv-in = ["dep:petgraph", "dep:spirv"] spv-out = ["dep:spirv"] @@ -59,6 +60,7 @@ pp-rs = { version = "0.2.1", optional = true } hexf-parse = { version = "0.2.1", optional = true } unicode-xid = { version = "0.2.3", optional = true } arrayvec.workspace = true +half = {version = "2.4.1", optional = true} [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] criterion = { version = "0.5", features = [] } diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index c8c7ea557d..c4f980c0bf 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2571,6 +2571,10 @@ impl<'a, W: Write> Writer<'a, W> { // decimal part even it's zero which is needed for a valid glsl float constant crate::Literal::F64(value) => write!(self.out, "{:?}LF", value)?, crate::Literal::F32(value) => write!(self.out, "{:?}", value)?, + #[cfg(feature = "half")] + crate::Literal::F16(value) => { + return Err(Error::Custom("GLSL has no 16-bit float type".into())); + } // Unsigned integers need a `u` at the end // // While `core` doesn't necessarily need it, it's allowed and since `es` needs it we diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 86d8f89035..c853fd42d7 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2236,6 +2236,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // decimal part even it's zero crate::Literal::F64(value) => write!(self.out, "{value:?}L")?, crate::Literal::F32(value) => write!(self.out, "{value:?}")?, + #[cfg(feature = "half")] + crate::Literal::F16(value) => write!(self.out, "{value:?}h")?, crate::Literal::U32(value) => write!(self.out, "{}u", value)?, crate::Literal::I32(value) => write!(self.out, "{}", value)?, crate::Literal::U64(value) => write!(self.out, "{}uL", value)?, diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index e250d0b72c..df965bd312 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -324,8 +324,12 @@ impl crate::Scalar { match self { Self { kind: Sk::Float, - width: _, + width: 4, } => "float", + Self { + kind: Sk::Float, + width: 2, + } => "half", Self { kind: Sk::Sint, width: 4, @@ -1266,6 +1270,10 @@ impl Writer { crate::Literal::F64(_) => { return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64)) } + #[cfg(feature = "half")] + crate::Literal::F16(_) => { + return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT16)) + } crate::Literal::F32(value) => { if value.is_infinite() { let sign = if value.is_sign_negative() { "-" } else { "" }; diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 789f6f62bf..1e764be5d3 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1211,6 +1211,8 @@ impl Writer { match expressions[expr] { Expression::Literal(literal) => match literal { + #[cfg(feature = "half")] + crate::Literal::F16(value) => write!(self.out, "{}h", value)?, crate::Literal::F32(value) => write!(self.out, "{}f", value)?, crate::Literal::U32(value) => write!(self.out, "{}u", value)?, crate::Literal::I32(value) => { @@ -1957,6 +1959,10 @@ const fn scalar_kind_str(scalar: crate::Scalar) -> &'static str { kind: Sk::Float, width: 4, } => "f32", + Scalar { + kind: Sk::Float, + width: 2, + } => "f16", Scalar { kind: Sk::Sint, width: 4, diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 7ac5a18cd6..a31e2d3846 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -43,6 +43,9 @@ use crate::{ FastHashMap, FastHashSet, FastIndexMap, }; +#[cfg(feature = "half")] +use half::f16; + use petgraph::graphmap::GraphMap; use std::{convert::TryInto, mem, num::NonZeroU32, path::PathBuf}; diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index e7cce17723..1198f4640b 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1656,6 +1656,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expr: Typed = match *expr { ast::Expression::Literal(literal) => { let literal = match literal { + ast::Literal::Number(Number::F16(f)) => crate::Literal::F16(f), ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f), ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i), ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u), diff --git a/naga/src/front/wgsl/parse/number.rs b/naga/src/front/wgsl/parse/number.rs index ceb2cb336c..c12406cc81 100644 --- a/naga/src/front/wgsl/parse/number.rs +++ b/naga/src/front/wgsl/parse/number.rs @@ -1,6 +1,9 @@ use crate::front::wgsl::error::NumberError; use crate::front::wgsl::parse::lexer::Token; +#[cfg(feature = "half")] +use half::f16; + /// When using this type assume no Abstract Int/Float for now #[derive(Copy, Clone, Debug, PartialEq)] pub enum Number { @@ -17,6 +20,8 @@ pub enum Number { /// Concrete u64 U64(u64), /// Concrete f32 + F16(f16), + /// Concrete f32 F32(f32), /// Concrete f64 F64(f64), diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 24e1b02c76..f5ffa4dc47 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -286,6 +286,8 @@ pub use crate::arena::{Arena, Handle, Range, UniqueArena}; pub use crate::span::{SourceLocation, Span, SpanContext, WithSpan}; #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; +#[cfg(feature = "half")] +use half::f16; #[cfg(feature = "deserialize")] use serde::Deserialize; #[cfg(feature = "serialize")] @@ -881,6 +883,8 @@ pub enum Literal { F64(f64), /// May not be NaN or infinity. F32(f32), + #[cfg(feature = "half")] + F16(f16), U32(u32), I32(i32), U64(u64), diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index a0057f39ac..03c1d86702 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -114,6 +114,8 @@ bitflags::bitflags! { const SUBGROUP = 0x10000; /// Support for subgroup barriers. const SUBGROUP_BARRIER = 0x20000; + /// Support for 16-bit floating-point types. + const FLOAT16 = 0x40000; } }