Skip to content

Commit

Permalink
feat: implement F16 support in shaders
Browse files Browse the repository at this point in the history
Co-Authored-By: Erich Gubler <[email protected]>
  • Loading branch information
FL33TW00D and ErichDonGubler committed Oct 22, 2024
1 parent 765dacf commit e556c47
Show file tree
Hide file tree
Showing 39 changed files with 1,151 additions and 61 deletions.
17 changes: 15 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ ndk-sys = "0.5.0"
#gpu-alloc = { path = "../gpu-alloc/gpu-alloc" }

[patch.crates-io]
half = { git = "https://github.com/FL33TW00D/half-rs.git", branch = "feature/arbitrary" }
#glow = { path = "../glow" }
#web-sys = { path = "../wasm-bindgen/crates/web-sys" }
#js-sys = { path = "../wasm-bindgen/crates/js-sys" }
Expand Down
1 change: 1 addition & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ encase = { workspace = true, features = ["glam"] }
flume.workspace = true
getrandom.workspace = true
glam.workspace = true
half = { version = "2.1.0", features = ["bytemuck"] }
ktx2.workspace = true
log.workspace = true
nanorand.workspace = true
Expand Down
1 change: 1 addition & 0 deletions examples/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod mipmap;
pub mod msaa_line;
pub mod render_to_texture;
pub mod repeated_compute;
pub mod shader_f16;
pub mod shadow;
pub mod skybox;
pub mod srgb_blend;
Expand Down
6 changes: 6 additions & 0 deletions examples/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ const EXAMPLES: &[ExampleDesc] = &[
webgl: false, // No RODS
webgpu: true,
},
ExampleDesc {
name: "shader-f16",
function: wgpu_examples::shader_f16::main,
webgl: false, // No RODS
webgpu: true,
},
];

fn get_example_name() -> Option<String> {
Expand Down
9 changes: 9 additions & 0 deletions examples/src/shader_f16/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# shader-f16

Demonstrates how to run a compute shader that operates on 16-bit floating-point (`f16`) values using wgpu.

## To Run

```
RUST_LOG=shader_f16 cargo run --bin wgpu-examples shader_f16
```
189 changes: 189 additions & 0 deletions examples/src/shader_f16/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
use half::f16;
use std::{borrow::Cow, str::FromStr};
use wgpu::util::DeviceExt;

/// Collects the input values, runs them through the GPU and prints the results.
///
/// Values are taken from the command line starting at the third argument; the
/// README invokes the example as `wgpu-examples shader_f16 ...`, so the first two
/// arguments are the binary and the example name. When no values are given a
/// small default set is used.
///
/// # Panics
///
/// Panics if an argument cannot be parsed as an `f16`, or if `execute_gpu`
/// returns `None`.
#[cfg_attr(test, allow(dead_code))]
async fn run() {
    let numbers: Vec<f16> = if std::env::args().len() <= 2 {
        let default = vec![
            f16::from_f32(27.),
            f16::from_f32(7.),
            f16::from_f32(5.),
            f16::from_f32(3.),
        ];
        println!("No numbers were provided, defaulting to {default:?}");
        default
    } else {
        std::env::args()
            .skip(2)
            // The inputs are parsed as `f16`, so any floating-point literal is accepted.
            .map(|s| f16::from_str(&s).expect("You must pass a list of numbers that can be parsed as f16!"))
            .collect()
    };

    let results = execute_gpu(&numbers)
        .await
        .expect("failed to acquire a GPU adapter to run the shader on");
    // `{:?}` on a `Vec` already prints the surrounding brackets.
    println!("Results: {results:?}");
    #[cfg(target_arch = "wasm32")]
    log::info!("Results: {results:?}");
}

/// Acquires an adapter and a device with `SHADER_F16` enabled, then forwards
/// `numbers` to `execute_gpu_inner`.
///
/// Returns `None` when no adapter could be obtained.
///
/// # Panics
///
/// Panics if the device request fails.
#[cfg_attr(test, allow(dead_code))]
async fn execute_gpu(numbers: &[f16]) -> Option<Vec<f16>> {
    // Entry point into wgpu.
    let instance = wgpu::Instance::default();

    // The adapter is the general connection to the GPU; bail out with `None`
    // if there is none.
    let adapter_options = wgpu::RequestAdapterOptions::default();
    let adapter = instance.request_adapter(&adapter_options).await?;

    // The device is the feature-specific connection to the GPU. `SHADER_F16`
    // is listed as a required feature because the shader uses `f16`.
    let device_descriptor = wgpu::DeviceDescriptor {
        label: None,
        required_features: wgpu::Features::SHADER_F16,
        required_limits: wgpu::Limits::downlevel_defaults(),
        memory_hints: Default::default(),
    };
    let (device, queue) = adapter
        .request_device(&device_descriptor, None)
        .await
        .unwrap();

    execute_gpu_inner(&device, &queue, numbers).await
}

/// Runs the `shader.wgsl` compute shader over `numbers` and reads the result back.
///
/// The values are uploaded into a storage buffer that the shader modifies in
/// place; the buffer is then copied into a mappable staging buffer and returned
/// as a `Vec<f16>`.
///
/// Returns `Some(values)` on success; every failure path in this function panics
/// instead of returning `None`.
///
/// # Panics
///
/// Panics if the workgroup count does not fit in a `u32`, if polling the device
/// times out, or if mapping the staging buffer fails.
async fn execute_gpu_inner(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    numbers: &[f16],
) -> Option<Vec<f16>> {
    // Loads the shader from WGSL.
    let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: None,
        source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
    });

    // Size in bytes of the input (and therefore of the output).
    let size = std::mem::size_of_val(numbers) as wgpu::BufferAddress;

    // Staging buffer, created without data.
    // `BufferUsages::MAP_READ` allows it to be read (outside the shader).
    // `BufferUsages::COPY_DST` allows it to be the destination of the copy.
    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    // Storage buffer, initialised with `numbers`. Its usages allow it to be:
    //   - a storage buffer (bound in a bind group and thus available to a shader),
    //   - the destination of a copy,
    //   - the source of a copy.
    let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Storage Buffer"),
        contents: bytemuck::cast_slice(numbers),
        usage: wgpu::BufferUsages::STORAGE
            | wgpu::BufferUsages::COPY_DST
            | wgpu::BufferUsages::COPY_SRC,
    });

    // A pipeline specifies the operation of a shader. No explicit layout is
    // given here; the bind group layout is retrieved from the pipeline below.
    let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: None,
        layout: None,
        module: &cs_module,
        entry_point: None,
        compilation_options: Default::default(),
        cache: None,
    });

    // A bind group defines how buffers are accessed by shaders.
    // `binding: 0` matches `@binding(0)` of the storage buffer in the shader.
    let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &bind_group_layout,
        entries: &[wgpu::BindGroupEntry {
            binding: 0,
            resource: storage_buffer.as_entire_binding(),
        }],
    });

    // The shader declares the buffer as `array<vec4<f16>>` and indexes it with
    // `global_invocation_id.x` at a workgroup size of 1, so each workgroup
    // handles four values. Dispatching one workgroup per *value* would index
    // past the end of the array for most invocations.
    // NOTE(review): inputs whose length is not a multiple of 4 presumably need
    // padding to be fully processed and copied back — confirm.
    let workgroup_count = u32::try_from(numbers.len().div_ceil(4))
        .expect("too many values to dispatch in a single compute pass");

    // A command encoder records the work that is later submitted to the queue.
    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        cpass.set_pipeline(&compute_pipeline);
        cpass.set_bind_group(0, Some(&bind_group), &[]);
        cpass.insert_debug_marker("compute f16 fma");
        cpass.dispatch_workgroups(workgroup_count, 1, 1);
    }
    // Adds a copy from the storage buffer to the mappable staging buffer.
    encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size);

    // Submits the recorded commands for processing.
    queue.submit(Some(encoder.finish()));

    // Requests the mapping; the callback forwards the mapping result through the
    // channel. Note that nothing is awaited yet.
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = flume::bounded(1);
    buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

    // Poll the device in a blocking manner so that the mapping callback runs.
    // In an actual application, `device.poll(...)` should
    // be called in an event loop or on another thread.
    device.poll(wgpu::Maintain::wait()).panic_on_timeout();

    // Waits for the mapping result sent by the callback above.
    let Ok(Ok(())) = receiver.recv_async().await else {
        panic!("failed to run compute on gpu!")
    };

    // The mapped contents are raw bytes; reinterpret them as `f16` and copy them out.
    let data = buffer_slice.get_mapped_range();
    let result = bytemuck::cast_slice(&data).to_vec();

    // All mapped views must be dropped before the buffer is unmapped.
    drop(data);
    staging_buffer.unmap();

    Some(result)
}

/// Entry point of the example: sets up logging for the current target and
/// drives `run` to completion.
pub fn main() {
    // On wasm32 the future is handed to `spawn_local` instead of being
    // blocked on, and panics are routed through `console_error_panic_hook`.
    #[cfg(target_arch = "wasm32")]
    {
        std::panic::set_hook(Box::new(console_error_panic_hook::hook));
        console_log::init().expect("could not initialize logger");
        wasm_bindgen_futures::spawn_local(run());
    }
    // Everywhere else: initialise `env_logger` and block until `run` finishes.
    #[cfg(not(target_arch = "wasm32"))]
    {
        env_logger::init();
        let example = run();
        pollster::block_on(example);
    }
}
9 changes: 9 additions & 0 deletions examples/src/shader_f16/shader.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Opt in to the `f16` type used below.
enable f16;

// Storage buffer viewed as groups of four `f16` values.
// It is used as both the input values and the output for convenience.
@group(0) @binding(0)
var<storage, read_write> values: array<vec4<f16>>;

// Replaces each `vec4<f16>` element `v` with `fma(v, v, v)`, i.e. `v * v + v`,
// one element per invocation.
@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // Read this invocation's own element. Reading `values[0]` from every
    // invocation would race with the invocation that writes `values[0]`.
    let v = values[global_id.x];
    values[global_id.x] = fma(v, v, v);
}
7 changes: 5 additions & 2 deletions naga/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ msl-out = []
## If you want to enable MSL output it regardless of the target platform, use `naga/msl-out`.
msl-out-if-target-apple = []

serialize = ["dep:serde", "bitflags/serde", "indexmap/serde"]
deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde"]
serialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"]
deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"]
arbitrary = ["dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary"]
spv-in = ["dep:petgraph", "dep:spirv"]
spv-out = ["dep:spirv"]
Expand Down Expand Up @@ -82,6 +82,9 @@ petgraph = { version = "0.6", optional = true }
pp-rs = { version = "0.2.1", optional = true }
hexf-parse = { version = "0.2.1", optional = true }
unicode-xid = { version = "0.2.6", optional = true }
# TODO: remove `[patch]` entry in workspace `Cargo.toml` for `half` after we upstream `arbitrary` support
half = { version = "2.4.1", features = ["arbitrary", "num-traits"] }
num-traits = "0.2"

[build-dependencies]
cfg_aliases.workspace = true
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2647,6 +2647,9 @@ impl<'a, W: Write> Writer<'a, W> {
// decimal part even it's zero which is needed for a valid glsl float constant
crate::Literal::F64(value) => write!(self.out, "{value:?}LF")?,
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::F16(_) => {
return Err(Error::Custom("GLSL has no 16-bit float type".into()));
}
// Unsigned integers need a `u` at the end
//
// While `core` doesn't necessarily need it, it's allowed and since `es` needs it we
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2383,6 +2383,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// decimal part even it's zero
crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
crate::Literal::I32(value) => write!(self.out, "{value}")?,
crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
Expand Down
23 changes: 22 additions & 1 deletion naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use crate::{
proc::{self, NameKey, TypeResolution},
valid, FastHashMap, FastHashSet,
};
use half::f16;
use num_traits::real::Real;
#[cfg(test)]
use std::ptr;
use std::{
Expand Down Expand Up @@ -390,8 +392,12 @@ impl crate::Scalar {
match self {
Self {
kind: Sk::Float,
width: _,
width: 4,
} => "float",
Self {
kind: Sk::Float,
width: 2,
} => "half",
Self {
kind: Sk::Sint,
width: 4,
Expand Down Expand Up @@ -1379,6 +1385,21 @@ impl<W: Write> Writer<W> {
crate::Literal::F64(_) => {
return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
}
crate::Literal::F16(value) => {
if value.is_infinite() {
let sign = if value.is_sign_negative() { "-" } else { "" };
write!(self.out, "{sign}INFINITY")?;
} else if value.is_nan() {
write!(self.out, "NAN")?;
} else {
let suffix = if value.fract() == f16::from_f32(0.0) {
".0h"
} else {
"h"
};
write!(self.out, "{value}{suffix}")?;
}
}
crate::Literal::F32(value) => {
if value.is_infinite() {
let sign = if value.is_sign_negative() { "-" } else { "" };
Expand Down
Loading

0 comments on commit e556c47

Please sign in to comment.