diff --git a/crates/bevy_render/src/render_resource/shader.rs b/crates/bevy_render/src/render_resource/shader.rs index 6dd84961111e5..b0adfe8b39723 100644 --- a/crates/bevy_render/src/render_resource/shader.rs +++ b/crates/bevy_render/src/render_resource/shader.rs @@ -360,16 +360,16 @@ impl ShaderProcessor { } }; - let shader_defs = HashSet::::from_iter(shader_defs.iter().cloned()); + let shader_defs_unique = HashSet::::from_iter(shader_defs.iter().cloned()); let mut scopes = vec![true]; let mut final_string = String::new(); for line in shader_str.split('\n') { if let Some(cap) = self.ifdef_regex.captures(line) { let def = cap.get(1).unwrap(); - scopes.push(*scopes.last().unwrap() && shader_defs.contains(def.as_str())); + scopes.push(*scopes.last().unwrap() && shader_defs_unique.contains(def.as_str())); } else if let Some(cap) = self.ifndef_regex.captures(line) { let def = cap.get(1).unwrap(); - scopes.push(*scopes.last().unwrap() && !shader_defs.contains(def.as_str())); + scopes.push(*scopes.last().unwrap() && !shader_defs_unique.contains(def.as_str())); } else if self.else_regex.is_match(line) { let mut is_parent_scope_truthy = true; if scopes.len() > 1 { @@ -388,19 +388,32 @@ impl ShaderProcessor { .captures(line) { let import = ShaderImport::AssetPath(cap.get(1).unwrap().as_str().to_string()); - apply_import(import_handles, shaders, &import, shader, &mut final_string)?; + self.apply_import( + import_handles, + shaders, + &import, + shader, + shader_defs, + &mut final_string, + )?; } else if let Some(cap) = SHADER_IMPORT_PROCESSOR .import_custom_path_regex .captures(line) { let import = ShaderImport::Custom(cap.get(1).unwrap().as_str().to_string()); - apply_import(import_handles, shaders, &import, shader, &mut final_string)?; + self.apply_import( + import_handles, + shaders, + &import, + shader, + shader_defs, + &mut final_string, + )?; } else if *scopes.last().unwrap() { final_string.push_str(line); final_string.push('\n'); } } - final_string.pop(); if scopes.len() != 1 { @@ -417,45 +430,51 @@ impl ShaderProcessor { } } } -} -fn apply_import( - import_handles: &HashMap>, - shaders: &HashMap, Shader>, - import: &ShaderImport, - shader: &Shader, - final_string: &mut String, -) -> Result<(), ProcessShaderError> { - let imported_shader = import_handles - .get(import) - .and_then(|handle| shaders.get(handle)) - .ok_or_else(|| ProcessShaderError::UnresolvedImport(import.clone()))?; - match &shader.source { - Source::Wgsl(_) => { - if let Source::Wgsl(import_source) = &imported_shader.source { - final_string.push_str(import_source); - } else { - return Err(ProcessShaderError::MismatchedImportFormat(import.clone())); + fn apply_import( + &self, + import_handles: &HashMap>, + shaders: &HashMap, Shader>, + import: &ShaderImport, + shader: &Shader, + shader_defs: &[String], + final_string: &mut String, + ) -> Result<(), ProcessShaderError> { + let imported_shader = import_handles + .get(import) + .and_then(|handle| shaders.get(handle)) + .ok_or_else(|| ProcessShaderError::UnresolvedImport(import.clone()))?; + let imported_processed = + self.process(imported_shader, shader_defs, shaders, import_handles)?; + + match &shader.source { + Source::Wgsl(_) => { + if let ProcessedShader::Wgsl(import_source) = &imported_processed { + final_string.push_str(import_source); + } else { + return Err(ProcessShaderError::MismatchedImportFormat(import.clone())); + } } - } - Source::Glsl(_, _) => { - if let Source::Glsl(import_source, _) = &imported_shader.source { - final_string.push_str(import_source); - } else { - return Err(ProcessShaderError::MismatchedImportFormat(import.clone())); + Source::Glsl(_, _) => { + if let ProcessedShader::Glsl(import_source, _) = &imported_processed { + final_string.push_str(import_source); + } else { + return Err(ProcessShaderError::MismatchedImportFormat(import.clone())); + } + } + Source::SpirV(_) => { + return Err(ProcessShaderError::ShaderFormatDoesNotSupportImports); } } - Source::SpirV(_) => { - return Err(ProcessShaderError::ShaderFormatDoesNotSupportImports); - } - } - Ok(()) + Ok(()) + } } #[cfg(test)] mod tests { - use bevy_asset::Handle; + use bevy_asset::{Handle, HandleUntyped}; + use bevy_reflect::TypeUuid; use bevy_utils::HashMap; use naga::ShaderStage; @@ -1094,4 +1113,106 @@ fn vertex( .unwrap(); assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED); } + + #[test] + fn process_import_ifdef() { + #[rustfmt::skip] + const FOO: &str = r" +#ifdef IMPORT_MISSING +fn in_import_missing() { } +#endif +#ifdef IMPORT_PRESENT +fn in_import_present() { } +#endif +"; + #[rustfmt::skip] + const INPUT: &str = r" +#import FOO +#ifdef MAIN_MISSING +fn in_main_missing() { } +#endif +#ifdef MAIN_PRESENT +fn in_main_present() { } +#endif +"; + #[rustfmt::skip] + const EXPECTED: &str = r" + +fn in_import_present() { } +fn in_main_present() { } +"; + let processor = ShaderProcessor::default(); + let mut shaders = HashMap::default(); + let mut import_handles = HashMap::default(); + let foo_handle = Handle::::default(); + shaders.insert(foo_handle.clone_weak(), Shader::from_wgsl(FOO)); + import_handles.insert( + ShaderImport::Custom("FOO".to_string()), + foo_handle.clone_weak(), + ); + let result = processor + .process( + &Shader::from_wgsl(INPUT), + &["MAIN_PRESENT".to_string(), "IMPORT_PRESENT".to_string()], + &shaders, + &import_handles, + ) + .unwrap(); + assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED); + } + + #[test] + fn process_import_in_import() { + #[rustfmt::skip] + const BAR: &str = r" +#ifdef DEEP +fn inner_import() { } +#endif +"; + const FOO: &str = r" +#import BAR +fn import() { } +"; + #[rustfmt::skip] + const INPUT: &str = r" +#import FOO +fn in_main() { } +"; + #[rustfmt::skip] + const EXPECTED: &str = r" + + +fn inner_import() { } +fn import() { } +fn in_main() { } +"; + let processor = ShaderProcessor::default(); + let mut shaders = HashMap::default(); + let mut import_handles = HashMap::default(); + { + let bar_handle = Handle::::default(); + shaders.insert(bar_handle.clone_weak(), Shader::from_wgsl(BAR)); + import_handles.insert( + ShaderImport::Custom("BAR".to_string()), + bar_handle.clone_weak(), + ); + } + { + let foo_handle = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 1).typed(); + shaders.insert(foo_handle.clone_weak(), Shader::from_wgsl(FOO)); + import_handles.insert( + ShaderImport::Custom("FOO".to_string()), + foo_handle.clone_weak(), + ); + } + let result = processor + .process( + &Shader::from_wgsl(INPUT), + &["DEEP".to_string()], + &shaders, + &import_handles, + ) + .unwrap(); + assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED); + } }