diff --git a/Cargo.toml b/Cargo.toml index 0511538..03ccae6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "naga_oil" -version = "0.8.2" +version = "0.9.0" edition = "2021" license = "MIT OR Apache-2.0" description = "a crate for combining and manipulating shaders using naga IR" @@ -15,7 +15,7 @@ prune = [] override_any = [] [dependencies] -naga = { version = "0.12", features = ["wgsl-in", "wgsl-out", "glsl-in", "glsl-out", "clone", "span"] } +naga = { version = "0.13", features = ["wgsl-in", "wgsl-out", "glsl-in", "glsl-out", "clone", "span"] } tracing = "0.1" regex = "1.5" regex-syntax = "0.6" @@ -29,6 +29,6 @@ once_cell = "1.17.0" indexmap = "1.9.3" [dev-dependencies] -wgpu = { version = "0.16", features=["naga"] } +wgpu = { version = "0.17", features=["naga"] } futures-lite = "1" tracing-subscriber = { version = "0.3", features = ["std", "fmt"] } diff --git a/src/compose/mod.rs b/src/compose/mod.rs index d5d8201..caa27e6 100644 --- a/src/compose/mod.rs +++ b/src/compose/mod.rs @@ -143,6 +143,7 @@ pub mod comment_strip_iter; pub mod error; pub mod preprocess; mod test; +pub mod util; #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)] pub enum ShaderLanguage { diff --git a/src/compose/test.rs b/src/compose/test.rs index 686ad05..ff006e5 100644 --- a/src/compose/test.rs +++ b/src/compose/test.rs @@ -1083,6 +1083,45 @@ mod test { output_eq!(wgsl, "tests/expected/conditional_import_b.txt"); } + #[test] + fn use_shared_global() { + let mut composer = Composer::default(); + + composer + .add_composable_module(ComposableModuleDescriptor { + source: include_str!("tests/use_shared_global/mod.wgsl"), + file_path: "tests/use_shared_global/mod.wgsl", + ..Default::default() + }) + .unwrap(); + let module = composer + .make_naga_module(NagaModuleDescriptor { + source: include_str!("tests/use_shared_global/top.wgsl"), + file_path: "tests/use_shared_global/top.wgsl", + ..Default::default() + }) + .unwrap(); + + let info = naga::valid::Validator::new( + naga::valid::ValidationFlags::all(), + naga::valid::Capabilities::default(), + ) + .validate(&module) + .unwrap(); + let wgsl = naga::back::wgsl::write_string( + &module, + &info, + naga::back::wgsl::WriterFlags::EXPLICIT_TYPES, + ) + .unwrap(); + + // let mut f = std::fs::File::create("use_shared_global.txt").unwrap(); + // f.write_all(wgsl.as_bytes()).unwrap(); + // drop(f); + + output_eq!(wgsl, "tests/expected/use_shared_global.txt"); + } + #[cfg(feature = "test_shader")] #[test] fn effective_defs() { diff --git a/src/compose/tests/expected/dup_import.txt b/src/compose/tests/expected/dup_import.txt index f20ad86..855c630 100644 --- a/src/compose/tests/expected/dup_import.txt +++ b/src/compose/tests/expected/dup_import.txt @@ -1,4 +1,4 @@ -const _naga_oil_mod_MNXW443UOM_memberPI: f32 = 3.0999999046325684; +const _naga_oil_mod_MNXW443UOM_memberPI: f32 = 3.1; fn _naga_oil_mod_ME_memberf() -> f32 { return (_naga_oil_mod_MNXW443UOM_memberPI * 1.0); diff --git a/src/compose/tests/expected/import_in_decl.txt b/src/compose/tests/expected/import_in_decl.txt index b7cdc68..c1817c3 100644 --- a/src/compose/tests/expected/import_in_decl.txt +++ b/src/compose/tests/expected/import_in_decl.txt @@ -1,8 +1,7 @@ const _naga_oil_mod_MNXW443UOM_memberX: u32 = 1u; - const _naga_oil_mod_MJUW4ZA_membery: u32 = 2u; -var _naga_oil_mod_MJUW4ZA_memberarr: array; +var _naga_oil_mod_MJUW4ZA_memberarr: array; fn main() -> f32 { let _e2: u32 = _naga_oil_mod_MJUW4ZA_memberarr[0]; diff --git a/src/compose/tests/expected/item_import_test.txt b/src/compose/tests/expected/item_import_test.txt index f7a23ce..00d9170 100644 --- a/src/compose/tests/expected/item_import_test.txt +++ b/src/compose/tests/expected/item_import_test.txt @@ -2,7 +2,6 @@ - let _e1: u32 = _naga_oil_mod_MNXW443UOM_memberdouble(_naga_oil_mod_MNXW443UOM_memberX); let _e1: u32 = _naga_oil_mod_MNXW443UOM_memberdouble(_naga_oil_mod_MNXW443UOM_memberY); return (in * 2u); diff --git a/src/compose/tests/expected/use_shared_global.txt b/src/compose/tests/expected/use_shared_global.txt new file mode 100644 index 0000000..54d95c9 --- /dev/null +++ b/src/compose/tests/expected/use_shared_global.txt @@ -0,0 +1,15 @@ +var _naga_oil_mod_NVXWI_membera: f32 = 0.0; + +fn add() { + let _e2: f32 = _naga_oil_mod_NVXWI_membera; + _naga_oil_mod_NVXWI_membera = (_e2 + 1.0); + return; +} + +fn main() -> f32 { + add(); + add(); + let _e1: f32 = _naga_oil_mod_NVXWI_membera; + return _e1; +} + diff --git a/src/compose/tests/use_shared_global/mod.wgsl b/src/compose/tests/use_shared_global/mod.wgsl new file mode 100644 index 0000000..9c5ea97 --- /dev/null +++ b/src/compose/tests/use_shared_global/mod.wgsl @@ -0,0 +1,3 @@ +#define_import_path mod + +var a: f32 = 0.0; \ No newline at end of file diff --git a/src/compose/tests/use_shared_global/top.wgsl b/src/compose/tests/use_shared_global/top.wgsl new file mode 100644 index 0000000..4e07d33 --- /dev/null +++ b/src/compose/tests/use_shared_global/top.wgsl @@ -0,0 +1,11 @@ +#import mod + +fn add() { + mod::a += 1.0; +} + +fn main() -> f32 { + add(); + add(); + return mod::a; +} \ No newline at end of file diff --git a/src/compose/util.rs b/src/compose/util.rs new file mode 100644 index 0000000..c2646d1 --- /dev/null +++ b/src/compose/util.rs @@ -0,0 +1,267 @@ +use naga::Expression; + +// Expression does not implement PartialEq except for internal testing (cfg_attr(test)), so we must use our own version. +// This implementation is tweaked from the output of `cargo expand` +#[inline] +pub fn expression_eq(lhs: &Expression, rhs: &Expression) -> bool { + let __lhs_tag = std::mem::discriminant(lhs); + let __arg1_tag = std::mem::discriminant(rhs); + __lhs_tag == __arg1_tag + && match (lhs, rhs) { + (Expression::Literal(__lhs_0), Expression::Literal(__arg1_0)) => *__lhs_0 == *__arg1_0, + (Expression::Constant(__lhs_0), Expression::Constant(__arg1_0)) => { + *__lhs_0 == *__arg1_0 + } + (Expression::ZeroValue(__lhs_0), Expression::ZeroValue(__arg1_0)) => { + *__lhs_0 == *__arg1_0 + } + ( + Expression::Compose { + ty: __lhs_0, + components: __lhs_1, + }, + Expression::Compose { + ty: __arg1_0, + components: __arg1_1, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1, + ( + Expression::Access { + base: __lhs_0, + index: __lhs_1, + }, + Expression::Access { + base: __arg1_0, + index: __arg1_1, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1, + ( + Expression::AccessIndex { + base: __lhs_0, + index: __lhs_1, + }, + Expression::AccessIndex { + base: __arg1_0, + index: __arg1_1, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1, + ( + Expression::Splat { + size: __lhs_0, + value: __lhs_1, + }, + Expression::Splat { + size: __arg1_0, + value: __arg1_1, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1, + ( + Expression::Swizzle { + size: __lhs_0, + vector: __lhs_1, + pattern: __lhs_2, + }, + Expression::Swizzle { + size: __arg1_0, + vector: __arg1_1, + pattern: __arg1_2, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1 && *__lhs_2 == *__arg1_2, + (Expression::FunctionArgument(__lhs_0), Expression::FunctionArgument(__arg1_0)) => { + *__lhs_0 == *__arg1_0 + } + (Expression::GlobalVariable(__lhs_0), Expression::GlobalVariable(__arg1_0)) => { + *__lhs_0 == *__arg1_0 + } + (Expression::LocalVariable(__lhs_0), Expression::LocalVariable(__arg1_0)) => { + *__lhs_0 == *__arg1_0 + } + (Expression::Load { pointer: __lhs_0 }, Expression::Load { pointer: __arg1_0 }) => { + *__lhs_0 == *__arg1_0 + } + ( + Expression::ImageSample { + image: __lhs_0, + sampler: __lhs_1, + gather: __lhs_2, + coordinate: __lhs_3, + array_index: __lhs_4, + offset: __lhs_5, + level: __lhs_6, + depth_ref: __lhs_7, + }, + Expression::ImageSample { + image: __arg1_0, + sampler: __arg1_1, + gather: __arg1_2, + coordinate: __arg1_3, + array_index: __arg1_4, + offset: __arg1_5, + level: __arg1_6, + depth_ref: __arg1_7, + }, + ) => { + *__lhs_0 == *__arg1_0 + && *__lhs_1 == *__arg1_1 + && *__lhs_2 == *__arg1_2 + && *__lhs_3 == *__arg1_3 + && *__lhs_4 == *__arg1_4 + && *__lhs_5 == *__arg1_5 + && *__lhs_6 == *__arg1_6 + && *__lhs_7 == *__arg1_7 + } + ( + Expression::ImageLoad { + image: __lhs_0, + coordinate: __lhs_1, + array_index: __lhs_2, + sample: __lhs_3, + level: __lhs_4, + }, + Expression::ImageLoad { + image: __arg1_0, + coordinate: __arg1_1, + array_index: __arg1_2, + sample: __arg1_3, + level: __arg1_4, + }, + ) => { + *__lhs_0 == *__arg1_0 + && *__lhs_1 == *__arg1_1 + && *__lhs_2 == *__arg1_2 + && *__lhs_3 == *__arg1_3 + && *__lhs_4 == *__arg1_4 + } + ( + Expression::ImageQuery { + image: __lhs_0, + query: __lhs_1, + }, + Expression::ImageQuery { + image: __arg1_0, + query: __arg1_1, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1, + ( + Expression::Unary { + op: __lhs_0, + expr: __lhs_1, + }, + Expression::Unary { + op: __arg1_0, + expr: __arg1_1, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1, + ( + Expression::Binary { + op: __lhs_0, + left: __lhs_1, + right: __lhs_2, + }, + Expression::Binary { + op: __arg1_0, + left: __arg1_1, + right: __arg1_2, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1 && *__lhs_2 == *__arg1_2, + ( + Expression::Select { + condition: __lhs_0, + accept: __lhs_1, + reject: __lhs_2, + }, + Expression::Select { + condition: __arg1_0, + accept: __arg1_1, + reject: __arg1_2, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1 && *__lhs_2 == *__arg1_2, + ( + Expression::Derivative { + axis: __lhs_0, + ctrl: __lhs_1, + expr: __lhs_2, + }, + Expression::Derivative { + axis: __arg1_0, + ctrl: __arg1_1, + expr: __arg1_2, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1 && *__lhs_2 == *__arg1_2, + ( + Expression::Relational { + fun: __lhs_0, + argument: __lhs_1, + }, + Expression::Relational { + fun: __arg1_0, + argument: __arg1_1, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1, + ( + Expression::Math { + fun: __lhs_0, + arg: __lhs_1, + arg1: __lhs_2, + arg2: __lhs_3, + arg3: __lhs_4, + }, + Expression::Math { + fun: __arg1_0, + arg: __arg1_1, + arg1: __arg1_2, + arg2: __arg1_3, + arg3: __arg1_4, + }, + ) => { + *__lhs_0 == *__arg1_0 + && *__lhs_1 == *__arg1_1 + && *__lhs_2 == *__arg1_2 + && *__lhs_3 == *__arg1_3 + && *__lhs_4 == *__arg1_4 + } + ( + Expression::As { + expr: __lhs_0, + kind: __lhs_1, + convert: __lhs_2, + }, + Expression::As { + expr: __arg1_0, + kind: __arg1_1, + convert: __arg1_2, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1 && *__lhs_2 == *__arg1_2, + (Expression::CallResult(__lhs_0), Expression::CallResult(__arg1_0)) => { + *__lhs_0 == *__arg1_0 + } + ( + Expression::AtomicResult { + ty: __lhs_0, + comparison: __lhs_1, + }, + Expression::AtomicResult { + ty: __arg1_0, + comparison: __arg1_1, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1, + ( + Expression::WorkGroupUniformLoadResult { ty: __lhs_0 }, + Expression::WorkGroupUniformLoadResult { ty: __arg1_0 }, + ) => *__lhs_0 == *__arg1_0, + (Expression::ArrayLength(__lhs_0), Expression::ArrayLength(__arg1_0)) => { + *__lhs_0 == *__arg1_0 + } + ( + Expression::RayQueryGetIntersection { + query: __lhs_0, + committed: __lhs_1, + }, + Expression::RayQueryGetIntersection { + query: __arg1_0, + committed: __arg1_1, + }, + ) => *__lhs_0 == *__arg1_0 && *__lhs_1 == *__arg1_1, + _ => unreachable!(), + } +} diff --git a/src/derive.rs b/src/derive.rs index 12a1dae..3ae887a 100644 --- a/src/derive.rs +++ b/src/derive.rs @@ -1,10 +1,12 @@ use indexmap::IndexMap; use naga::{ - Arena, ArraySize, Block, Constant, ConstantInner, EntryPoint, Expression, Function, - FunctionArgument, FunctionResult, GlobalVariable, Handle, ImageQuery, LocalVariable, Module, - SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena, + Arena, Block, Constant, EntryPoint, Expression, Function, FunctionArgument, FunctionResult, + GlobalVariable, Handle, ImageQuery, LocalVariable, Module, SampleLevel, Span, Statement, + StructMember, SwitchCase, Type, TypeInner, UniqueArena, }; -use std::collections::HashMap; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; + +use crate::compose::util::expression_eq; #[derive(Debug, Default)] pub struct DerivedModule<'a> { @@ -13,11 +15,13 @@ pub struct DerivedModule<'a> { type_map: HashMap, Handle>, const_map: HashMap, Handle>, + const_expression_map: Rc, Handle>>>, global_map: HashMap, Handle>, function_map: HashMap>, types: UniqueArena, constants: Arena, + const_expressions: Rc>>, globals: Arena, functions: Arena, } @@ -36,6 +40,7 @@ impl<'a> DerivedModule<'a> { self.type_map.clear(); self.const_map.clear(); self.global_map.clear(); + self.const_expression_map.borrow_mut().clear(); } pub fn map_span(&self, span: Span) -> Span { @@ -102,27 +107,15 @@ impl<'a> DerivedModule<'a> { span: *span, } } - TypeInner::Array { base, size, stride } => { - let size = match size { - ArraySize::Constant(c) => ArraySize::Constant(self.import_const(c)), - ArraySize::Dynamic => ArraySize::Dynamic, - }; - TypeInner::Array { - base: self.import_type(base), - size, - stride: *stride, - } - } - TypeInner::BindingArray { base, size } => { - let size = match size { - ArraySize::Constant(c) => ArraySize::Constant(self.import_const(c)), - ArraySize::Dynamic => ArraySize::Dynamic, - }; - TypeInner::BindingArray { - base: self.import_type(base), - size, - } - } + TypeInner::Array { base, size, stride } => TypeInner::Array { + base: self.import_type(base), + size: *size, + stride: *stride, + }, + TypeInner::BindingArray { base, size } => TypeInner::BindingArray { + base: self.import_type(base), + size: *size, + }, }, }; let span = self.shader.as_ref().unwrap().types.get_span(*h_type); @@ -145,17 +138,9 @@ impl<'a> DerivedModule<'a> { let new_const = Constant { name: c.name.clone(), - specialization: c.specialization, - inner: match &c.inner { - ConstantInner::Scalar { .. } => c.inner.clone(), - ConstantInner::Composite { ty, components } => { - let components = components.iter().map(|c| self.import_const(c)).collect(); - ConstantInner::Composite { - ty: self.import_type(ty), - components, - } - } - }, + r#override: c.r#override.clone(), + ty: self.import_type(&c.ty), + init: self.import_const_expression(c.init), }; let span = self.shader.as_ref().unwrap().constants.get_span(*h_const); @@ -183,7 +168,7 @@ impl<'a> DerivedModule<'a> { space: gv.space, binding: gv.binding.clone(), ty: self.import_type(&gv.ty), - init: gv.init.map(|c| self.import_const(&c)), + init: gv.init.map(|c| self.import_const_expression(c)), }; let span = self @@ -199,22 +184,34 @@ impl<'a> DerivedModule<'a> { new_h }) } + // remap a const expression from source context into our derived context + pub fn import_const_expression(&mut self, h_cexpr: Handle) -> Handle { + self.import_expression( + h_cexpr, + &self.shader.as_ref().unwrap().const_expressions, + self.const_expression_map.clone(), + self.const_expressions.clone(), + false, + true, + ) + } // remap a block fn import_block( &mut self, block: &Block, old_expressions: &Arena, - already_imported: &mut HashMap, Handle>, - new_expressions: &mut Arena, + already_imported: Rc, Handle>>>, + new_expressions: Rc>>, ) -> Block { macro_rules! map_expr { ($e:expr) => { self.import_expression( *$e, old_expressions, - already_imported, - new_expressions, + already_imported.clone(), + new_expressions.clone(), + false, false, ) }; @@ -228,7 +225,12 @@ impl<'a> DerivedModule<'a> { macro_rules! map_block { ($b:expr) => { - self.import_block($b, old_expressions, already_imported, new_expressions) + self.import_block( + $b, + old_expressions, + already_imported.clone(), + new_expressions.clone(), + ) }; } @@ -286,18 +288,19 @@ impl<'a> DerivedModule<'a> { self.import_expression( expr, old_expressions, - already_imported, - new_expressions, + already_imported.clone(), + new_expressions.clone(), true, + false, ); } - let old_length = new_expressions.len(); + let old_length = new_expressions.borrow().len(); // iterate again to add expressions that should be part of the emit statement for expr in exprs.clone() { map_expr!(&expr); } - Statement::Emit(new_expressions.range_from(old_length)) + Statement::Emit(new_expressions.borrow().range_from(old_length)) } Statement::Store { pointer, value } => Statement::Store { pointer: map_expr!(pointer), @@ -325,6 +328,12 @@ impl<'a> DerivedModule<'a> { value: map_expr!(value), result: map_expr!(result), }, + Statement::WorkGroupUniformLoad { pointer, result } => { + Statement::WorkGroupUniformLoad { + pointer: map_expr!(pointer), + result: map_expr!(result), + } + } Statement::Return { value } => Statement::Return { value: map_expr_opt!(value), }, @@ -369,11 +378,12 @@ impl<'a> DerivedModule<'a> { &mut self, h_expr: Handle, old_expressions: &Arena, - already_imported: &mut HashMap, Handle>, - new_expressions: &mut Arena, + already_imported: Rc, Handle>>>, + new_expressions: Rc>>, non_emitting_only: bool, // only brings items that should NOT be emitted into scope + unique: bool, // ensure expressions are unique with custom comparison ) -> Handle { - if let Some(h_new) = already_imported.get(&h_expr) { + if let Some(h_new) = already_imported.borrow().get(&h_expr) { return *h_new; } @@ -382,9 +392,10 @@ impl<'a> DerivedModule<'a> { self.import_expression( *$e, old_expressions, - already_imported, - new_expressions, + already_imported.clone(), + new_expressions.clone(), non_emitting_only, + unique, ) }; } @@ -395,9 +406,10 @@ impl<'a> DerivedModule<'a> { self.import_expression( *expr, old_expressions, - already_imported, - new_expressions, + already_imported.clone(), + new_expressions.clone(), non_emitting_only, + unique, ) }) }; @@ -406,6 +418,14 @@ impl<'a> DerivedModule<'a> { let mut is_external = false; let expr = old_expressions.try_get(h_expr).unwrap(); let expr = match expr { + Expression::Literal(_) => { + is_external = true; + expr.clone() + } + Expression::ZeroValue(zv) => { + is_external = true; + Expression::ZeroValue(self.import_type(zv)) + } Expression::CallResult(f) => Expression::CallResult(self.map_function_handle(f)), Expression::Constant(c) => { is_external = true; @@ -434,7 +454,7 @@ impl<'a> DerivedModule<'a> { gather: *gather, coordinate: map_expr!(coordinate), array_index: map_expr_opt!(array_index), - offset: offset.map(|c| self.import_const(&c)), + offset: offset.map(|c| self.import_const_expression(c)), level: match level { SampleLevel::Auto | SampleLevel::Zero => *level, SampleLevel::Exact(expr) => SampleLevel::Exact(map_expr!(expr)), @@ -549,6 +569,7 @@ impl<'a> DerivedModule<'a> { } Expression::AtomicResult { .. } => expr.clone(), + Expression::WorkGroupUniformLoadResult { .. } => expr.clone(), Expression::RayQueryProceedResult => expr.clone(), Expression::RayQueryGetIntersection { query, committed } => { Expression::RayQueryGetIntersection { @@ -560,9 +581,19 @@ impl<'a> DerivedModule<'a> { if !non_emitting_only || is_external { let span = old_expressions.get_span(h_expr); - let h_new = new_expressions.append(expr, self.map_span(span)); + let h_new = if unique { + new_expressions.borrow_mut().fetch_if_or_append( + expr, + self.map_span(span), + expression_eq, + ) + } else { + new_expressions + .borrow_mut() + .append(expr, self.map_span(span)) + }; - already_imported.insert(h_expr, h_new); + already_imported.borrow_mut().insert(h_expr, h_new); h_new } else { h_expr @@ -591,27 +622,32 @@ impl<'a> DerivedModule<'a> { let new_local = LocalVariable { name: l.name.clone(), ty: self.import_type(&l.ty), - init: l.init.map(|c| self.import_const(&c)), + init: l.init.map(|c| self.import_const_expression(c)), }; let span = func.local_variables.get_span(h_l); let new_h = local_variables.append(new_local, self.map_span(span)); assert_eq!(h_l, new_h); } - let mut expressions = Arena::new(); - let mut expr_map = HashMap::new(); + let expressions = Rc::new(RefCell::new(Arena::new())); + let expr_map = Rc::new(RefCell::new(HashMap::new())); let body = self.import_block( &func.body, &func.expressions, - &mut expr_map, - &mut expressions, + expr_map.clone(), + expressions.clone(), ); let named_expressions = func .named_expressions .iter() - .flat_map(|(h_expr, name)| expr_map.get(h_expr).map(|new_h| (*new_h, name.clone()))) + .flat_map(|(h_expr, name)| { + expr_map + .borrow() + .get(h_expr) + .map(|new_h| (*new_h, name.clone())) + }) .collect::>>(); Function { @@ -619,7 +655,7 @@ impl<'a> DerivedModule<'a> { arguments, result, local_variables, - expressions, + expressions: Rc::try_unwrap(expressions).unwrap().into_inner(), named_expressions, body, } @@ -688,6 +724,9 @@ impl<'a> From> for naga::Module { types: derived.types, constants: derived.constants, global_variables: derived.globals, + const_expressions: Rc::try_unwrap(derived.const_expressions) + .unwrap() + .into_inner(), functions: derived.functions, special_types: Default::default(), entry_points: Default::default(), diff --git a/src/prune/mod.rs b/src/prune/mod.rs index 732762d..7566cac 100644 --- a/src/prune/mod.rs +++ b/src/prune/mod.rs @@ -93,6 +93,9 @@ impl FunctionReq { expr_map: &HashMap, Handle>, ) -> Expression { match expr { + Expression::Literal(_) => expr.clone(), + Expression::ZeroValue(_) => expr.clone(), + Expression::WorkGroupUniformLoadResult { ty: _ty } => expr.clone(), Expression::Access { base, index } => Expression::Access { base: expr_map[base], index: expr_map[index], @@ -1150,6 +1153,9 @@ impl<'a> Pruner<'a> { ); match expr { + Expression::Literal(_) => (), + Expression::ZeroValue(_) => (), + Expression::WorkGroupUniformLoadResult { .. } => (), Expression::AccessIndex { base, index } => self.add_expression( function, func_req, @@ -1236,10 +1242,6 @@ impl<'a> Pruner<'a> { context.locals.insert(*lv, part.clone()); } } - let lv = function.local_variables.try_get(*lv).unwrap(); - if let Some(init) = lv.init { - self.constants.insert(init); - } } Expression::Load { pointer } => { self.add_expression(function, func_req, context, *pointer, part); @@ -1250,7 +1252,7 @@ impl<'a> Pruner<'a> { gather: _gather, coordinate, array_index, - offset, + offset: _offset, level, depth_ref, } => { @@ -1259,7 +1261,6 @@ impl<'a> Pruner<'a> { self.add_expression(function, func_req, context, *coordinate, &PartReq::All); array_index .map(|e| self.add_expression(function, func_req, context, e, &PartReq::All)); - offset.map(|c| self.constants.insert(c)); match level { naga::SampleLevel::Auto | naga::SampleLevel::Zero => (), naga::SampleLevel::Exact(e) | naga::SampleLevel::Bias(e) => { @@ -1654,6 +1655,14 @@ impl<'a> Pruner<'a> { let required = self.store_required(context, &var_ref); RayQuery(required.is_some()) } + Statement::WorkGroupUniformLoad { pointer, result } => { + let var_ref = Self::resolve_var(function, *result, Vec::default()); + let required = self.store_required(context, &var_ref).is_some(); + if required { + self.add_expression(function, func_req, context, *pointer, &PartReq::All); + } + RayQuery(required) + } } } @@ -1780,6 +1789,11 @@ impl<'a> Pruner<'a> { let mut derived = DerivedModule::default(); derived.set_shader_source(self.module, 0); + // just copy all the constants for now, so we can copy const handles as well + for (h_cexpr, _) in self.module.const_expressions.iter() { + derived.import_const_expression(h_cexpr); + } + for (h_f, f) in self.module.functions.iter() { if let Some(req) = self.functions.get(&h_f) { if req.body_required.is_required() { diff --git a/src/redirect.rs b/src/redirect.rs index 32e6c5f..eca25bd 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -62,6 +62,7 @@ impl Redirector { | Statement::Break | Statement::Continue | Statement::Return { .. } + | Statement::WorkGroupUniformLoad { .. } | Statement::Kill | Statement::Barrier(_) | Statement::Store { .. }