Skip to content

Commit

Permalink
[wgsl-in] Fail on repeated attributes (gfx-rs#2428)
Browse files Browse the repository at this point in the history
* [wgsl-in] Fail on repeated attributes

Fixes gfx-rs#2425.

* [wgsl-in] Use ParsedAttribute to keep track of parsed attributes
  • Loading branch information
fornwall authored Aug 13, 2023
1 parent 30afa5b commit 7a19f3a
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 41 deletions.
6 changes: 6 additions & 0 deletions src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ pub enum Error<'a> {
InvalidIdentifierUnderscore(Span),
ReservedIdentifierPrefix(Span),
UnknownAddressSpace(Span),
RepeatedAttribute(Span),
UnknownAttribute(Span),
UnknownBuiltin(Span),
UnknownAccess(Span),
Expand Down Expand Up @@ -430,6 +431,11 @@ impl<'a> Error<'a> {
labels: vec![(bad_span, "unknown address space".into())],
notes: vec![],
},
Error::RepeatedAttribute(bad_span) => ParseError {
message: format!("repeated attribute: '{}'", &source[bad_span]),
labels: vec![(bad_span, "repated attribute".into())],
notes: vec![],
},
Error::UnknownAttribute(bad_span) => ParseError {
message: format!("unknown attribute: '{}'", &source[bad_span]),
labels: vec![(bad_span, "unknown attribute".into())],
Expand Down
109 changes: 68 additions & 41 deletions src/front/wgsl/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,33 @@ enum Rule {
GeneralExpr,
}

struct ParsedAttribute<T> {
value: Option<T>,
}

impl<T> Default for ParsedAttribute<T> {
fn default() -> Self {
Self { value: None }
}
}

impl<T> ParsedAttribute<T> {
fn set(&mut self, value: T, name_span: Span) -> Result<(), Error<'static>> {
if self.value.is_some() {
return Err(Error::RepeatedAttribute(name_span));
}
self.value = Some(value);
Ok(())
}
}

#[derive(Default)]
struct BindingParser {
location: Option<u32>,
built_in: Option<crate::BuiltIn>,
interpolation: Option<crate::Interpolation>,
sampling: Option<crate::Sampling>,
invariant: bool,
location: ParsedAttribute<u32>,
built_in: ParsedAttribute<crate::BuiltIn>,
interpolation: ParsedAttribute<crate::Interpolation>,
sampling: ParsedAttribute<crate::Sampling>,
invariant: ParsedAttribute<bool>,
}

impl BindingParser {
Expand All @@ -139,38 +159,44 @@ impl BindingParser {
match name {
"location" => {
lexer.expect(Token::Paren('('))?;
self.location = Some(Parser::non_negative_i32_literal(lexer)?);
self.location
.set(Parser::non_negative_i32_literal(lexer)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
"builtin" => {
lexer.expect(Token::Paren('('))?;
let (raw, span) = lexer.next_ident_with_span()?;
self.built_in = Some(conv::map_built_in(raw, span)?);
self.built_in
.set(conv::map_built_in(raw, span)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
"interpolate" => {
lexer.expect(Token::Paren('('))?;
let (raw, span) = lexer.next_ident_with_span()?;
self.interpolation = Some(conv::map_interpolation(raw, span)?);
self.interpolation
.set(conv::map_interpolation(raw, span)?, name_span)?;
if lexer.skip(Token::Separator(',')) {
let (raw, span) = lexer.next_ident_with_span()?;
self.sampling = Some(conv::map_sampling(raw, span)?);
self.sampling
.set(conv::map_sampling(raw, span)?, name_span)?;
}
lexer.expect(Token::Paren(')'))?;
}
"invariant" => self.invariant = true,
"invariant" => {
self.invariant.set(true, name_span)?;
}
_ => return Err(Error::UnknownAttribute(name_span)),
}
Ok(())
}

const fn finish<'a>(self, span: Span) -> Result<Option<crate::Binding>, Error<'a>> {
fn finish<'a>(self, span: Span) -> Result<Option<crate::Binding>, Error<'a>> {
match (
self.location,
self.built_in,
self.interpolation,
self.sampling,
self.invariant,
self.location.value,
self.built_in.value,
self.interpolation.value,
self.sampling.value,
self.invariant.value.unwrap_or_default(),
) {
(None, None, None, None, false) => Ok(None),
(Some(location), None, interpolation, sampling, false) => {
Expand Down Expand Up @@ -990,22 +1016,22 @@ impl Parser {
ExpectedToken::Token(Token::Separator(',')),
));
}
let (mut size, mut align) = (None, None);
let (mut size, mut align) = (ParsedAttribute::default(), ParsedAttribute::default());
self.push_rule_span(Rule::Attribute, lexer);
let mut bind_parser = BindingParser::default();
while lexer.skip(Token::Attribute) {
match lexer.next_ident_with_span()? {
("size", _) => {
("size", name_span) => {
lexer.expect(Token::Paren('('))?;
let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
lexer.expect(Token::Paren(')'))?;
size = Some((value, span));
size.set((value, span), name_span)?;
}
("align", _) => {
("align", name_span) => {
lexer.expect(Token::Paren('('))?;
let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
lexer.expect(Token::Paren(')'))?;
align = Some((value, span));
align.set((value, span), name_span)?;
}
(word, word_span) => bind_parser.parse(lexer, word, word_span)?,
}
Expand All @@ -1023,8 +1049,8 @@ impl Parser {
name,
ty,
binding,
size,
align,
size: size.value,
align: align.value,
});
}

Expand Down Expand Up @@ -2131,32 +2157,33 @@ impl Parser {
) -> Result<(), Error<'a>> {
// read attributes
let mut binding = None;
let mut stage = None;
let mut stage = ParsedAttribute::default();
let mut workgroup_size = [0u32; 3];
let mut early_depth_test = None;
let (mut bind_index, mut bind_group) = (None, None);
let mut early_depth_test = ParsedAttribute::default();
let (mut bind_index, mut bind_group) =
(ParsedAttribute::default(), ParsedAttribute::default());

self.push_rule_span(Rule::Attribute, lexer);
while lexer.skip(Token::Attribute) {
match lexer.next_ident_with_span()? {
("binding", _) => {
("binding", name_span) => {
lexer.expect(Token::Paren('('))?;
bind_index = Some(Self::non_negative_i32_literal(lexer)?);
bind_index.set(Self::non_negative_i32_literal(lexer)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
("group", _) => {
("group", name_span) => {
lexer.expect(Token::Paren('('))?;
bind_group = Some(Self::non_negative_i32_literal(lexer)?);
bind_group.set(Self::non_negative_i32_literal(lexer)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
("vertex", _) => {
stage = Some(crate::ShaderStage::Vertex);
("vertex", name_span) => {
stage.set(crate::ShaderStage::Vertex, name_span)?;
}
("fragment", _) => {
stage = Some(crate::ShaderStage::Fragment);
("fragment", name_span) => {
stage.set(crate::ShaderStage::Fragment, name_span)?;
}
("compute", _) => {
stage = Some(crate::ShaderStage::Compute);
("compute", name_span) => {
stage.set(crate::ShaderStage::Compute, name_span)?;
}
("workgroup_size", _) => {
lexer.expect(Token::Paren('('))?;
Expand All @@ -2175,7 +2202,7 @@ impl Parser {
}
}
}
("early_depth_test", _) => {
("early_depth_test", name_span) => {
let conservative = if lexer.skip(Token::Paren('(')) {
let (ident, ident_span) = lexer.next_ident_with_span()?;
let value = conv::map_conservative_depth(ident, ident_span)?;
Expand All @@ -2184,14 +2211,14 @@ impl Parser {
} else {
None
};
early_depth_test = Some(crate::EarlyDepthTest { conservative });
early_depth_test.set(crate::EarlyDepthTest { conservative }, name_span)?;
}
(_, word_span) => return Err(Error::UnknownAttribute(word_span)),
}
}

let attrib_span = self.pop_rule_span(lexer);
match (bind_group, bind_index) {
match (bind_group.value, bind_index.value) {
(Some(group), Some(index)) => {
binding = Some(crate::ResourceBinding {
group,
Expand Down Expand Up @@ -2254,9 +2281,9 @@ impl Parser {
(Token::Word("fn"), _) => {
let function = self.function_decl(lexer, out, &mut dependencies)?;
Some(ast::GlobalDeclKind::Fn(ast::Function {
entry_point: stage.map(|stage| ast::EntryPoint {
entry_point: stage.value.map(|stage| ast::EntryPoint {
stage,
early_depth_test,
early_depth_test: early_depth_test.value,
workgroup_size,
}),
..function
Expand Down
39 changes: 39 additions & 0 deletions src/front/wgsl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,42 @@ fn parse_texture_load_store_expecting_four_args() {
);
}
}

#[test]
fn parse_repeated_attributes() {
use crate::{
front::wgsl::{error::Error, Frontend},
Span,
};

let template_vs = "@vertex fn vs() -> __REPLACE__ vec4<f32> { return vec4<f32>(0.0); }";
let template_struct = "struct A { __REPLACE__ data: vec3<f32> }";
let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array<i32>;";
let template_stage = "__REPLACE__ fn vs() -> vec4<f32> { return vec4<f32>(0.0); }";
for (attribute, template) in [
("align(16)", template_struct),
("binding(0)", template_resource),
("builtin(position)", template_vs),
("compute", template_stage),
("fragment", template_stage),
("group(0)", template_resource),
("interpolate(flat)", template_vs),
("invariant", template_vs),
("location(0)", template_vs),
("size(16)", template_struct),
("vertex", template_stage),
("early_depth_test(less_equal)", template_resource),
] {
let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}"));
let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32;
let span_start = shader.rfind(attribute).unwrap() as u32;
let span_end = span_start + name_length;
let expected_span = Span::new(span_start, span_end);

let result = Frontend::new().inner(&shader);
assert!(matches!(
result.unwrap_err(),
Error::RepeatedAttribute(span) if span == expected_span
));
}
}

0 comments on commit 7a19f3a

Please sign in to comment.