-
Notifications
You must be signed in to change notification settings - Fork 933
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: implement F16 support in shaders
Co-Authored-By: Erich Gubler <[email protected]>
- Loading branch information
1 parent
765dacf
commit e556c47
Showing
39 changed files
with
1,151 additions
and
61 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# shader-f16 | ||
|
||
Demonstrate the ability to perform compute in F16 using wgpu. | ||
|
||
## To Run | ||
|
||
``` | ||
RUST_LOG=shader_f16 cargo run --bin wgpu-examples shader_f16 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
use half::f16; | ||
use std::{borrow::Cow, str::FromStr}; | ||
use wgpu::util::DeviceExt; | ||
|
||
#[cfg_attr(test, allow(dead_code))] | ||
async fn run() { | ||
let numbers = if std::env::args().len() <= 2 { | ||
let default = vec![ | ||
f16::from_f32(27.), | ||
f16::from_f32(7.), | ||
f16::from_f32(5.), | ||
f16::from_f32(3.), | ||
]; | ||
println!("No numbers were provided, defaulting to {default:?}"); | ||
default | ||
} else { | ||
std::env::args() | ||
.skip(2) | ||
.map(|s| f16::from_str(&s).expect("You must pass a list of positive integers!")) | ||
.collect() | ||
}; | ||
|
||
let steps = execute_gpu(&numbers).await.unwrap(); | ||
println!("Steps: [{:?}]", steps); | ||
#[cfg(target_arch = "wasm32")] | ||
log::info!("Steps: [{:?}]", steps); | ||
} | ||
|
||
#[cfg_attr(test, allow(dead_code))] | ||
async fn execute_gpu(numbers: &[f16]) -> Option<Vec<f16>> { | ||
// Instantiates instance of WebGPU | ||
let instance = wgpu::Instance::default(); | ||
|
||
// `request_adapter` instantiates the general connection to the GPU | ||
let adapter = instance | ||
.request_adapter(&wgpu::RequestAdapterOptions::default()) | ||
.await?; | ||
|
||
// `request_device` instantiates the feature specific connection to the GPU, defining some parameters, | ||
// `features` being the available features. | ||
let (device, queue) = adapter | ||
.request_device( | ||
&wgpu::DeviceDescriptor { | ||
label: None, | ||
required_features: wgpu::Features::SHADER_F16, | ||
required_limits: wgpu::Limits::downlevel_defaults(), | ||
memory_hints: Default::default(), | ||
}, | ||
None, | ||
) | ||
.await | ||
.unwrap(); | ||
|
||
execute_gpu_inner(&device, &queue, numbers).await | ||
} | ||
|
||
async fn execute_gpu_inner( | ||
device: &wgpu::Device, | ||
queue: &wgpu::Queue, | ||
numbers: &[f16], | ||
) -> Option<Vec<f16>> { | ||
// Loads the shader from WGSL | ||
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor { | ||
label: None, | ||
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))), | ||
}); | ||
|
||
// Gets the size in bytes of the buffer. | ||
let size = std::mem::size_of_val(numbers) as wgpu::BufferAddress; | ||
|
||
// Instantiates buffer without data. | ||
// `usage` of buffer specifies how it can be used: | ||
// `BufferUsages::MAP_READ` allows it to be read (outside the shader). | ||
// `BufferUsages::COPY_DST` allows it to be the destination of the copy. | ||
let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { | ||
label: None, | ||
size, | ||
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, | ||
mapped_at_creation: false, | ||
}); | ||
|
||
// Instantiates buffer with data (`numbers`). | ||
// Usage allowing the buffer to be: | ||
// A storage buffer (can be bound within a bind group and thus available to a shader). | ||
// The destination of a copy. | ||
// The source of a copy. | ||
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { | ||
label: Some("Storage Buffer"), | ||
contents: bytemuck::cast_slice(numbers), | ||
usage: wgpu::BufferUsages::STORAGE | ||
| wgpu::BufferUsages::COPY_DST | ||
| wgpu::BufferUsages::COPY_SRC, | ||
}); | ||
|
||
// A bind group defines how buffers are accessed by shaders. | ||
// It is to WebGPU what a descriptor set is to Vulkan. | ||
// `binding` here refers to the `binding` of a buffer in the shader (`layout(set = 0, binding = 0) buffer`). | ||
|
||
// A pipeline specifies the operation of a shader | ||
|
||
// Instantiates the pipeline. | ||
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { | ||
label: None, | ||
layout: None, | ||
module: &cs_module, | ||
entry_point: None, | ||
compilation_options: Default::default(), | ||
cache: None, | ||
}); | ||
|
||
// Instantiates the bind group, once again specifying the binding of buffers. | ||
let bind_group_layout = compute_pipeline.get_bind_group_layout(0); | ||
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { | ||
label: None, | ||
layout: &bind_group_layout, | ||
entries: &[wgpu::BindGroupEntry { | ||
binding: 0, | ||
resource: storage_buffer.as_entire_binding(), | ||
}], | ||
}); | ||
|
||
// A command encoder executes one or many pipelines. | ||
// It is to WebGPU what a command buffer is to Vulkan. | ||
let mut encoder = | ||
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); | ||
{ | ||
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { | ||
label: None, | ||
timestamp_writes: None, | ||
}); | ||
cpass.set_pipeline(&compute_pipeline); | ||
cpass.set_bind_group(0, Some(&bind_group), &[]); | ||
cpass.insert_debug_marker("compute collatz iterations"); | ||
cpass.dispatch_workgroups(numbers.len() as u32, 1, 1); // Number of cells to run, the (x,y,z) size of item being processed | ||
} | ||
// Sets adds copy operation to command encoder. | ||
// Will copy data from storage buffer on GPU to staging buffer on CPU. | ||
encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size); | ||
|
||
// Submits command encoder for processing | ||
queue.submit(Some(encoder.finish())); | ||
|
||
// Note that we're not calling `.await` here. | ||
let buffer_slice = staging_buffer.slice(..); | ||
// Sets the buffer up for mapping, sending over the result of the mapping back to us when it is finished. | ||
let (sender, receiver) = flume::bounded(1); | ||
buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap()); | ||
|
||
// Poll the device in a blocking manner so that our future resolves. | ||
// In an actual application, `device.poll(...)` should | ||
// be called in an event loop or on another thread. | ||
device.poll(wgpu::Maintain::wait()).panic_on_timeout(); | ||
|
||
// Awaits until `buffer_future` can be read from | ||
if let Ok(Ok(())) = receiver.recv_async().await { | ||
// Gets contents of buffer | ||
let data = buffer_slice.get_mapped_range(); | ||
// Since contents are got in bytes, this converts these bytes back to u32 | ||
let result = bytemuck::cast_slice(&data).to_vec(); | ||
|
||
// With the current interface, we have to make sure all mapped views are | ||
// dropped before we unmap the buffer. | ||
drop(data); | ||
staging_buffer.unmap(); // Unmaps buffer from memory | ||
// If you are familiar with C++ these 2 lines can be thought of similarly to: | ||
// delete myPointer; | ||
// myPointer = NULL; | ||
// It effectively frees the memory | ||
|
||
// Returns data from buffer | ||
Some(result) | ||
} else { | ||
panic!("failed to run compute on gpu!") | ||
} | ||
} | ||
|
||
pub fn main() { | ||
#[cfg(not(target_arch = "wasm32"))] | ||
{ | ||
env_logger::init(); | ||
pollster::block_on(run()); | ||
} | ||
#[cfg(target_arch = "wasm32")] | ||
{ | ||
std::panic::set_hook(Box::new(console_error_panic_hook::hook)); | ||
console_log::init().expect("could not initialize logger"); | ||
wasm_bindgen_futures::spawn_local(run()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
enable f16; | ||
|
||
@group(0) @binding(0) | ||
var<storage, read_write> values: array<vec4<f16>>; // this is used as both values and output for convenience | ||
|
||
@compute @workgroup_size(1) | ||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) { | ||
values[global_id.x] = fma(values[0], values[0], values[0]); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.