feat: implement F16 support in shaders #5701

Open: wants to merge 1 commit into base `trunk`.
17 changes: 15 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -193,6 +193,7 @@ ndk-sys = "0.5.0"
#gpu-alloc = { path = "../gpu-alloc/gpu-alloc" }

[patch.crates-io]
half = { git = "https://github.com/FL33TW00D/half-rs.git", branch = "feature/arbitrary" }
#glow = { path = "../glow" }
#web-sys = { path = "../wasm-bindgen/crates/web-sys" }
#js-sys = { path = "../wasm-bindgen/crates/js-sys" }
1 change: 1 addition & 0 deletions examples/Cargo.toml
@@ -35,6 +35,7 @@ encase = { workspace = true, features = ["glam"] }
flume.workspace = true
getrandom.workspace = true
glam.workspace = true
half = { version = "2.1.0", features = ["bytemuck"] }
ktx2.workspace = true
log.workspace = true
nanorand.workspace = true
1 change: 1 addition & 0 deletions examples/src/lib.rs
@@ -17,6 +17,7 @@ pub mod mipmap;
pub mod msaa_line;
pub mod render_to_texture;
pub mod repeated_compute;
pub mod shader_f16;
pub mod shadow;
pub mod skybox;
pub mod srgb_blend;
6 changes: 6 additions & 0 deletions examples/src/main.rs
@@ -146,6 +146,12 @@ const EXAMPLES: &[ExampleDesc] = &[
webgl: false, // No RODS
webgpu: true,
},
ExampleDesc {
name: "shader-f16",
function: wgpu_examples::shader_f16::main,
webgl: false, // No compute shaders in WebGL
webgpu: true,
},
];

fn get_example_name() -> Option<String> {
9 changes: 9 additions & 0 deletions examples/src/shader_f16/README.md
@@ -0,0 +1,9 @@
# shader-f16

Demonstrates the ability to perform compute work in f16 using wgpu.

## To Run

```
RUST_LOG=shader_f16 cargo run --bin wgpu-examples shader_f16
```
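`SHADER_F16` is an optional feature, so the `request_device` call in the example below will simply fail on adapters that do not expose it. The following is a minimal sketch, not part of the PR, of checking adapter support up front using the same descriptor fields the example passes (the function name `request_f16_device` is ours):

```rust
/// Returns a device/queue pair only if the adapter supports f16 shader math.
/// Mirrors the example's `request_device` call, with an explicit capability
/// check in front of it.
async fn request_f16_device(
    adapter: &wgpu::Adapter,
) -> Option<(wgpu::Device, wgpu::Queue)> {
    if !adapter.features().contains(wgpu::Features::SHADER_F16) {
        // The adapter cannot run the f16 shader; let the caller fall back.
        return None;
    }
    adapter
        .request_device(
            &wgpu::DeviceDescriptor {
                label: Some("f16 device"),
                required_features: wgpu::Features::SHADER_F16,
                required_limits: wgpu::Limits::downlevel_defaults(),
                memory_hints: Default::default(),
            },
            None, // no API call tracing
        )
        .await
        .ok()
}
```

A caller can then fall back to an f32 code path when `None` is returned instead of panicking.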
189 changes: 189 additions & 0 deletions examples/src/shader_f16/mod.rs
@@ -0,0 +1,189 @@
use half::f16;
use std::{borrow::Cow, str::FromStr};
use wgpu::util::DeviceExt;

#[cfg_attr(test, allow(dead_code))]
async fn run() {
let numbers = if std::env::args().len() <= 2 {
let default = vec![
f16::from_f32(27.),
f16::from_f32(7.),
f16::from_f32(5.),
f16::from_f32(3.),
];
println!("No numbers were provided, defaulting to {default:?}");
default
} else {
std::env::args()
.skip(2)
.map(|s| f16::from_str(&s).expect("You must pass a list of numbers!"))
.collect()
};

let results = execute_gpu(&numbers).await.unwrap();
println!("Results: {results:?}");
#[cfg(target_arch = "wasm32")]
log::info!("Results: {results:?}");
}

#[cfg_attr(test, allow(dead_code))]
async fn execute_gpu(numbers: &[f16]) -> Option<Vec<f16>> {
// Instantiates instance of WebGPU
let instance = wgpu::Instance::default();

// `request_adapter` instantiates the general connection to the GPU
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions::default())
.await?;

// `request_device` instantiates the feature-specific connection to the GPU, defining some parameters;
// `required_features` lists the features the device must support (here, `SHADER_F16`).
let (device, queue) = adapter
.request_device(
&wgpu::DeviceDescriptor {
label: None,
required_features: wgpu::Features::SHADER_F16,
required_limits: wgpu::Limits::downlevel_defaults(),
memory_hints: Default::default(),
},
None,
)
.await
.unwrap();

execute_gpu_inner(&device, &queue, numbers).await
}

async fn execute_gpu_inner(
device: &wgpu::Device,
queue: &wgpu::Queue,
numbers: &[f16],
) -> Option<Vec<f16>> {
// Loads the shader from WGSL
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
});

// Gets the size in bytes of the buffer.
let size = std::mem::size_of_val(numbers) as wgpu::BufferAddress;

// Instantiates buffer without data.
// `usage` of buffer specifies how it can be used:
// `BufferUsages::MAP_READ` allows it to be read (outside the shader).
// `BufferUsages::COPY_DST` allows it to be the destination of the copy.
let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});

// Instantiates buffer with data (`numbers`).
// Usage allowing the buffer to be:
// A storage buffer (can be bound within a bind group and thus available to a shader).
// The destination of a copy.
// The source of a copy.
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Storage Buffer"),
contents: bytemuck::cast_slice(numbers),
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
});

// A bind group defines how buffers are accessed by shaders.
// It is to WebGPU what a descriptor set is to Vulkan.
// `binding` here refers to the `binding` of a buffer in the shader (`layout(set = 0, binding = 0) buffer`).

// A pipeline specifies the operation of a shader

// Instantiates the pipeline.
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &cs_module,
entry_point: None,
compilation_options: Default::default(),
cache: None,
});

// Instantiates the bind group, once again specifying the binding of buffers.
let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &bind_group_layout,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: storage_buffer.as_entire_binding(),
}],
});

// A command encoder executes one or many pipelines.
// It is to WebGPU what a command buffer is to Vulkan.
let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: None,
timestamp_writes: None,
});
cpass.set_pipeline(&compute_pipeline);
cpass.set_bind_group(0, Some(&bind_group), &[]);
cpass.insert_debug_marker("compute f16 fma");
cpass.dispatch_workgroups(numbers.len() as u32, 1, 1); // Number of cells to run, the (x,y,z) size of item being processed
}
// Adds a copy operation to the command encoder.
// Will copy data from storage buffer on GPU to staging buffer on CPU.
encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size);

// Submits command encoder for processing
queue.submit(Some(encoder.finish()));

// Note that we're not calling `.await` here.
let buffer_slice = staging_buffer.slice(..);
// Sets the buffer up for mapping, sending over the result of the mapping back to us when it is finished.
let (sender, receiver) = flume::bounded(1);
buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

// Poll the device in a blocking manner so that our future resolves.
// In an actual application, `device.poll(...)` should
// be called in an event loop or on another thread.
device.poll(wgpu::Maintain::wait()).panic_on_timeout();

// Awaits until the staging buffer can be read from
if let Ok(Ok(())) = receiver.recv_async().await {
// Gets contents of buffer
let data = buffer_slice.get_mapped_range();
// Since contents are in bytes, this converts those bytes back to f16
let result = bytemuck::cast_slice(&data).to_vec();

// With the current interface, we have to make sure all mapped views are
// dropped before we unmap the buffer.
drop(data);
staging_buffer.unmap(); // Unmaps buffer from memory
// If you are familiar with C++ these 2 lines can be thought of similarly to:
// delete myPointer;
// myPointer = NULL;
// It effectively frees the memory

// Returns data from buffer
Some(result)
} else {
panic!("failed to run compute on gpu!")
}
}

pub fn main() {
#[cfg(not(target_arch = "wasm32"))]
{
env_logger::init();
pollster::block_on(run());
}
#[cfg(target_arch = "wasm32")]
{
std::panic::set_hook(Box::new(console_error_panic_hook::hook));
console_log::init().expect("could not initialize logger");
wasm_bindgen_futures::spawn_local(run());
}
}
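Both the upload (`BufferInitDescriptor::contents`) and the readback above rely on `f16` implementing `bytemuck::Pod`, which the examples' `Cargo.toml` enables via the `bytemuck` feature of `half`. A small, GPU-free sketch of that byte-level round trip, assuming `half` and `bytemuck` as dependencies:

```rust
use half::f16;

fn main() {
    let values = vec![f16::from_f32(27.0), f16::from_f32(7.0)];

    // &[f16] -> &[u8]: the shape handed to `BufferInitDescriptor::contents`.
    let bytes: &[u8] = bytemuck::cast_slice(&values);
    assert_eq!(bytes.len(), values.len() * 2); // each f16 is 2 bytes wide

    // &[u8] -> &[f16]: the same cast used on the mapped staging buffer.
    let back: &[f16] = bytemuck::cast_slice(bytes);
    assert_eq!(back, values.as_slice());
}
```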
9 changes: 9 additions & 0 deletions examples/src/shader_f16/shader.wgsl
@@ -0,0 +1,9 @@
enable f16;

@group(0) @binding(0)
var<storage, read_write> values: array<vec4<f16>>; // this is used as both input and output for convenience

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
values[global_id.x] = fma(values[0], values[0], values[0]);
}
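For the default four inputs, which fit in a single `vec4<f16>`, the shader replaces each component `x` of `values[0]` with `fma(x, x, x)`, i.e. `x * x + x`. A rough CPU-side reference of that arithmetic with the `half` crate, doing the math through `f32` on the host (the helper name `fma_reference` is illustrative, not part of the example):

```rust
use half::f16;

/// CPU reference for the shader: y = fma(x, x, x) = x * x + x, per component.
fn fma_reference(input: &[f16]) -> Vec<f16> {
    input
        .iter()
        .map(|&x| {
            let x = x.to_f32();
            f16::from_f32(x * x + x)
        })
        .collect()
}

fn main() {
    let defaults = [27.0f32, 7.0, 5.0, 3.0].map(f16::from_f32);
    // Expected GPU output for the default inputs: [756.0, 56.0, 30.0, 12.0]
    println!("{:?}", fma_reference(&defaults));
}
```

All four results are whole numbers below 2048, so they are exactly representable in f16 and should match the GPU output.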
7 changes: 5 additions & 2 deletions naga/Cargo.toml
@@ -41,8 +41,8 @@ msl-out = []
## If you want to enable MSL output regardless of the target platform, use `naga/msl-out`.
msl-out-if-target-apple = []

serialize = ["dep:serde", "bitflags/serde", "indexmap/serde"]
deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde"]
serialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"]
deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"]
arbitrary = ["dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary"]
spv-in = ["dep:petgraph", "dep:spirv"]
spv-out = ["dep:spirv"]
@@ -82,6 +82,9 @@ petgraph = { version = "0.6", optional = true }
pp-rs = { version = "0.2.1", optional = true }
hexf-parse = { version = "0.2.1", optional = true }
unicode-xid = { version = "0.2.6", optional = true }
# TODO: remove `[patch]` entry in workspace `Cargo.toml` for `half` after we upstream `arbitrary` support
half = { version = "2.4.1", features = ["arbitrary", "num-traits"] }
num-traits = "0.2"

[build-dependencies]
cfg_aliases.workspace = true
3 changes: 3 additions & 0 deletions naga/src/back/glsl/mod.rs
@@ -2647,6 +2647,9 @@ impl<'a, W: Write> Writer<'a, W> {
// decimal part even if it's zero, which is needed for a valid glsl float constant
crate::Literal::F64(value) => write!(self.out, "{value:?}LF")?,
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::F16(_) => {
return Err(Error::Custom("GLSL has no 16-bit float type".into()));
}
// Unsigned integers need a `u` at the end
//
// While `core` doesn't necessarily need it, it's allowed and since `es` needs it we
1 change: 1 addition & 0 deletions naga/src/back/hlsl/writer.rs
@@ -2383,6 +2383,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// decimal part even if it's zero
crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
crate::Literal::I32(value) => write!(self.out, "{value}")?,
crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
23 changes: 22 additions & 1 deletion naga/src/back/msl/writer.rs
@@ -6,6 +6,8 @@ use crate::{
proc::{self, NameKey, TypeResolution},
valid, FastHashMap, FastHashSet,
};
use half::f16;
use num_traits::real::Real;
#[cfg(test)]
use std::ptr;
use std::{
@@ -390,8 +392,12 @@ impl crate::Scalar {
match self {
Self {
kind: Sk::Float,
width: _,
width: 4,
} => "float",
Self {
kind: Sk::Float,
width: 2,
} => "half",
Self {
kind: Sk::Sint,
width: 4,
@@ -1385,6 +1391,21 @@ impl<W: Write> Writer<W> {
crate::Literal::F64(_) => {
return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
}
crate::Literal::F16(value) => {
if value.is_infinite() {
let sign = if value.is_sign_negative() { "-" } else { "" };
write!(self.out, "{sign}INFINITY")?;
} else if value.is_nan() {
write!(self.out, "NAN")?;
} else {
let suffix = if value.fract() == f16::from_f32(0.0) {
".0h"
} else {
"h"
};
write!(self.out, "{value}{suffix}")?;
}
}
crate::Literal::F32(value) => {
if value.is_infinite() {
let sign = if value.is_sign_negative() { "-" } else { "" };
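The `F16` branch above appends `.0h` to whole values so the emitted MSL still reads as a floating-point literal, and a bare `h` otherwise, with `INFINITY`/`NAN` handled separately. A rough standalone sketch of that suffix choice, using an `f32` round trip for `fract` instead of the `num_traits::real::Real` trait the writer imports:

```rust
use half::f16;

/// Mirrors the suffix logic in the MSL backend's f16 literal path.
fn msl_half_literal(value: f16) -> String {
    if value.is_infinite() {
        let sign = if value.is_sign_negative() { "-" } else { "" };
        format!("{sign}INFINITY")
    } else if value.is_nan() {
        "NAN".to_string()
    } else if value.to_f32().fract() == 0.0 {
        format!("{value}.0h") // e.g. "1.0h", so MSL reads it as a half literal
    } else {
        format!("{value}h") // e.g. "0.5h"
    }
}

fn main() {
    println!("{}", msl_half_literal(f16::from_f32(1.0)));
    println!("{}", msl_half_literal(f16::from_f32(0.5)));
    println!("{}", msl_half_literal(f16::NEG_INFINITY));
}
```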