Skip to content

Commit

Permalink
Follow function call statements (#69)
Browse files Browse the repository at this point in the history
* Follow function call statements

* Update test code to test entire output
  • Loading branch information
stefnotch authored Nov 11, 2024
1 parent 76ebfc0 commit d783b7c
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 1 deletion.
39 changes: 38 additions & 1 deletion wgsl_to_wgpu/src/wgsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,50 @@ pub fn entry_stages(module: &naga::Module) -> wgpu::ShaderStages {
.collect()
}

fn update_stages_blocks(
module: &naga::Module,
block: &naga::Block,
global_stages: &mut BTreeMap<String, wgpu::ShaderStages>,
stage: wgpu::ShaderStages,
) {
for statement in block.iter() {
match statement {
naga::Statement::Block(block) => {
update_stages_blocks(module, block, global_stages, stage);
}
naga::Statement::If { accept, reject, .. } => {
update_stages_blocks(module, accept, global_stages, stage);
update_stages_blocks(module, reject, global_stages, stage);
}
naga::Statement::Switch { cases, .. } => {
for c in cases {
update_stages_blocks(module, &c.body, global_stages, stage);
}
}
naga::Statement::Loop {
body, continuing, ..
} => {
update_stages_blocks(module, body, global_stages, stage);
update_stages_blocks(module, continuing, global_stages, stage);
}
naga::Statement::Call { function, .. } => {
update_stages(module, &module.functions[*function], global_stages, stage);
}
_ => (),
}
}
}

fn update_stages(
module: &naga::Module,
function: &naga::Function,
global_stages: &mut BTreeMap<String, wgpu::ShaderStages>,
stage: wgpu::ShaderStages,
) {
// Search the function body to find function call statements
update_stages_blocks(module, &function.body, global_stages, stage);

// Search the function body to find used globals.
// TODO: This doesn't handle function calls properly?
for (_, e) in function.expressions.iter() {
match e {
naga::Expression::GlobalVariable(g) => {
Expand All @@ -55,6 +91,7 @@ fn update_stages(
}
}
naga::Expression::CallResult(f) => {
// Function call expressions
update_stages(module, &module.functions[*f], global_stages, stage);
}
_ => (),
Expand Down
17 changes: 17 additions & 0 deletions wgsl_to_wgpu/tests/create_shader_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,20 @@ fn vertex_entries() {

assert_eq!(include_str!("output/vertex_entries.rs"), actual);
}

#[test]
fn shader_stage_collection() {
// Check the visibility: wgpu::ShaderStages::COMPUTE
let actual = wgsl_to_wgpu::create_shader_module(
include_str!("wgsl/shader_stage_collection.wgsl"),
"shader.wgsl",
wgsl_to_wgpu::WriteOptions {
rustfmt: true,
derive_encase_host_shareable: true,
..Default::default()
},
)
.unwrap();

assert_eq!(include_str!("output/shader_stage_collection.rs"), actual);
}
115 changes: 115 additions & 0 deletions wgsl_to_wgpu/tests/output/shader_stage_collection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
pub mod bind_groups {
#[derive(Debug)]
pub struct BindGroup0(wgpu::BindGroup);
#[derive(Debug)]
pub struct BindGroupLayout0<'a> {
pub counter: wgpu::BufferBinding<'a>,
}
const LAYOUT_DESCRIPTOR0: wgpu::BindGroupLayoutDescriptor = wgpu::BindGroupLayoutDescriptor {
label: Some("LayoutDescriptor0"),
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}],
};
impl BindGroup0 {
pub fn get_bind_group_layout(device: &wgpu::Device) -> wgpu::BindGroupLayout {
device.create_bind_group_layout(&LAYOUT_DESCRIPTOR0)
}
pub fn from_bindings(device: &wgpu::Device, bindings: BindGroupLayout0) -> Self {
let bind_group_layout = device.create_bind_group_layout(&LAYOUT_DESCRIPTOR0);
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &bind_group_layout,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: wgpu::BindingResource::Buffer(bindings.counter),
}],
label: Some("BindGroup0"),
});
Self(bind_group)
}
pub fn set<P: SetBindGroup>(&self, pass: &mut P) {
pass.set_bind_group(0, &self.0, &[]);
}
}
#[derive(Debug, Copy, Clone)]
pub struct BindGroups<'a> {
pub bind_group0: &'a BindGroup0,
}
impl<'a> BindGroups<'a> {
pub fn set<P: SetBindGroup>(&self, pass: &mut P) {
self.bind_group0.set(pass);
}
}
pub trait SetBindGroup {
fn set_bind_group(
&mut self,
index: u32,
bind_group: &wgpu::BindGroup,
offsets: &[wgpu::DynamicOffset],
);
}
impl SetBindGroup for wgpu::ComputePass<'_> {
fn set_bind_group(
&mut self,
index: u32,
bind_group: &wgpu::BindGroup,
offsets: &[wgpu::DynamicOffset],
) {
self.set_bind_group(index, bind_group, offsets);
}
}
impl SetBindGroup for wgpu::RenderPass<'_> {
fn set_bind_group(
&mut self,
index: u32,
bind_group: &wgpu::BindGroup,
offsets: &[wgpu::DynamicOffset],
) {
self.set_bind_group(index, bind_group, offsets);
}
}
}
pub fn set_bind_groups<P: bind_groups::SetBindGroup>(
pass: &mut P,
bind_group0: &bind_groups::BindGroup0,
) {
bind_group0.set(pass);
}
pub mod compute {
pub const MAIN_WORKGROUP_SIZE: [u32; 3] = [1, 1, 1];
pub fn create_main_pipeline(device: &wgpu::Device) -> wgpu::ComputePipeline {
let module = super::create_shader_module(device);
let layout = super::create_pipeline_layout(device);
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Compute Pipeline main"),
layout: Some(&layout),
module: &module,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: Default::default(),
})
}
}
pub const ENTRY_MAIN: &str = "main";
pub const SOURCE: &str = include_str!("shader.wgsl");
pub fn create_shader_module(device: &wgpu::Device) -> wgpu::ShaderModule {
let source = std::borrow::Cow::Borrowed(SOURCE);
device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(source),
})
}
pub fn create_pipeline_layout(device: &wgpu::Device) -> wgpu::PipelineLayout {
device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bind_groups::BindGroup0::get_bind_group_layout(device)],
push_constant_ranges: &[],
})
}
11 changes: 11 additions & 0 deletions wgsl_to_wgpu/tests/wgsl/shader_stage_collection.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@group(0) @binding(0)
var<storage, read_write> counter : array<atomic<u32>, 1>;

fn add_one() {
atomicAdd(&counter[0], 1u);
}

@compute @workgroup_size(1, 1, 1)
fn main() {
add_one();
}

0 comments on commit d783b7c

Please sign in to comment.