Skip to content

Commit

Permalink
[naga wgsl] Impl const_assert (#6198)
Browse files Browse the repository at this point in the history
Signed-off-by: sagudev <[email protected]>
  • Loading branch information
sagudev authored Sep 2, 2024
1 parent ace2e20 commit 4e9a2a5
Show file tree
Hide file tree
Showing 12 changed files with 299 additions and 17 deletions.
18 changes: 18 additions & 0 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ pub(crate) enum Error<'a> {
limit: u8,
},
PipelineConstantIDValue(Span),
NotBool(Span),
ConstAssertFailed(Span),
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -815,6 +817,22 @@ impl<'a> Error<'a> {
)],
notes: vec![],
},
Error::NotBool(span) => ParseError {
message: "must be a const-expression that resolves to a bool".to_string(),
labels: vec![(
span,
"must resolve to bool".into(),
)],
notes: vec![],
},
Error::ConstAssertFailed(span) => ParseError {
message: "const_assert failure".to_string(),
labels: vec![(
span,
"evaluates to false".into(),
)],
notes: vec![],
},
}
}
}
Expand Down
43 changes: 26 additions & 17 deletions naga/src/front/wgsl/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ impl<'a> Index<'a> {
// While doing so, reject conflicting definitions.
let mut globals = FastHashMap::with_capacity_and_hasher(tu.decls.len(), Default::default());
for (handle, decl) in tu.decls.iter() {
let ident = decl_ident(decl);
let name = ident.name;
if let Some(old) = globals.insert(name, handle) {
return Err(Error::Redefinition {
previous: decl_ident(&tu.decls[old]).span,
current: ident.span,
});
if let Some(ident) = decl_ident(decl) {
let name = ident.name;
if let Some(old) = globals.insert(name, handle) {
return Err(Error::Redefinition {
previous: decl_ident(&tu.decls[old])
.expect("decl should have ident for redefinition")
.span,
current: ident.span,
});
}
}
}

Expand Down Expand Up @@ -130,7 +133,7 @@ impl<'a> DependencySolver<'a, '_> {
return if dep_id == id {
// A declaration refers to itself directly.
Err(Error::RecursiveDeclaration {
ident: decl_ident(decl).span,
ident: decl_ident(decl).expect("decl should have ident").span,
usage: dep.usage,
})
} else {
Expand All @@ -146,14 +149,19 @@ impl<'a> DependencySolver<'a, '_> {
.unwrap_or(0);

Err(Error::CyclicDeclaration {
ident: decl_ident(&self.module.decls[dep_id]).span,
ident: decl_ident(&self.module.decls[dep_id])
.expect("decl should have ident")
.span,
path: self.path[start_at..]
.iter()
.map(|curr_dep| {
let curr_id = curr_dep.decl;
let curr_decl = &self.module.decls[curr_id];

(decl_ident(curr_decl).span, curr_dep.usage)
(
decl_ident(curr_decl).expect("decl should have ident").span,
curr_dep.usage,
)
})
.collect(),
})
Expand Down Expand Up @@ -182,13 +190,14 @@ impl<'a> DependencySolver<'a, '_> {
}
}

const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> Option<ast::Ident<'a>> {
match decl.kind {
ast::GlobalDeclKind::Fn(ref f) => f.name,
ast::GlobalDeclKind::Var(ref v) => v.name,
ast::GlobalDeclKind::Const(ref c) => c.name,
ast::GlobalDeclKind::Override(ref o) => o.name,
ast::GlobalDeclKind::Struct(ref s) => s.name,
ast::GlobalDeclKind::Type(ref t) => t.name,
ast::GlobalDeclKind::Fn(ref f) => Some(f.name),
ast::GlobalDeclKind::Var(ref v) => Some(v.name),
ast::GlobalDeclKind::Const(ref c) => Some(c.name),
ast::GlobalDeclKind::Override(ref o) => Some(o.name),
ast::GlobalDeclKind::Struct(ref s) => Some(s.name),
ast::GlobalDeclKind::Type(ref t) => Some(t.name),
ast::GlobalDeclKind::ConstAssert(_) => None,
}
}
36 changes: 36 additions & 0 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,20 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ctx.globals
.insert(alias.name.name, LoweredGlobalDecl::Type(ty));
}
ast::GlobalDeclKind::ConstAssert(condition) => {
let condition = self.expression(condition, &mut ctx.as_const())?;

let span = ctx.module.global_expressions.get_span(condition);
match ctx
.module
.to_ctx()
.eval_expr_to_bool_from(condition, &ctx.module.global_expressions)
{
Some(true) => Ok(()),
Some(false) => Err(Error::ConstAssertFailed(span)),
_ => Err(Error::NotBool(span)),
}?;
}
}
}

Expand Down Expand Up @@ -1742,6 +1756,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
value,
}
}
ast::StatementKind::ConstAssert(condition) => {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);

let condition =
self.expression(condition, &mut ctx.as_const(block, &mut emitter))?;

let span = ctx.function.expressions.get_span(condition);
match ctx
.module
.to_ctx()
.eval_expr_to_bool_from(condition, &ctx.function.expressions)
{
Some(true) => Ok(()),
Some(false) => Err(Error::ConstAssertFailed(span)),
_ => Err(Error::NotBool(span)),
}?;

block.extend(emitter.finish(&ctx.function.expressions));

return Ok(());
}
ast::StatementKind::Ignore(expr) => {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);
Expand Down
2 changes: 2 additions & 0 deletions naga/src/front/wgsl/parse/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub enum GlobalDeclKind<'a> {
Override(Override<'a>),
Struct(Struct<'a>),
Type(TypeAlias<'a>),
ConstAssert(Handle<Expression<'a>>),
}

#[derive(Debug)]
Expand Down Expand Up @@ -284,6 +285,7 @@ pub enum StatementKind<'a> {
Increment(Handle<Expression<'a>>),
Decrement(Handle<Expression<'a>>),
Ignore(Handle<Expression<'a>>),
ConstAssert(Handle<Expression<'a>>),
}

#[derive(Debug)]
Expand Down
26 changes: 26 additions & 0 deletions naga/src/front/wgsl/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2016,6 +2016,20 @@ impl Parser {
lexer.expect(Token::Separator(';'))?;
ast::StatementKind::Kill
}
// https://www.w3.org/TR/WGSL/#const-assert-statement
"const_assert" => {
let _ = lexer.next();
// parentheses are optional
let paren = lexer.skip(Token::Paren('('));

let condition = self.general_expression(lexer, ctx)?;

if paren {
lexer.expect(Token::Paren(')'))?;
}
lexer.expect(Token::Separator(';'))?;
ast::StatementKind::ConstAssert(condition)
}
// assignment or a function call
_ => {
self.function_call_or_assignment_statement(lexer, ctx, block)?;
Expand Down Expand Up @@ -2419,6 +2433,18 @@ impl Parser {
..function
}))
}
(Token::Word("const_assert"), _) => {
// parentheses are optional
let paren = lexer.skip(Token::Paren('('));

let condition = self.general_expression(lexer, &mut ctx)?;

if paren {
lexer.expect(Token::Paren(')'))?;
}
lexer.expect(Token::Separator(';'))?;
Some(ast::GlobalDeclKind::ConstAssert(condition))
}
(Token::End, _) => return Ok(()),
other => return Err(Error::Unexpected(other.1, ExpectedToken::GlobalItem)),
};
Expand Down
13 changes: 13 additions & 0 deletions naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,19 @@ impl GlobalCtx<'_> {
}
}

/// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`.
#[allow(dead_code)]
pub(super) fn eval_expr_to_bool_from(
&self,
handle: crate::Handle<crate::Expression>,
arena: &crate::Arena<crate::Expression>,
) -> Option<bool> {
match self.eval_expr_to_literal_from(handle, arena) {
Some(crate::Literal::Bool(value)) => Some(value),
_ => None,
}
}

#[allow(dead_code)]
pub(crate) fn eval_expr_to_literal(
&self,
Expand Down
11 changes: 11 additions & 0 deletions naga/tests/in/const_assert.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Sourced from https://www.w3.org/TR/WGSL/#const-assert-statement
const x = 1;
const y = 2;
const_assert x < y; // valid at module-scope.
const_assert(y != 0); // parentheses are optional.

fn foo() {
const z = x + y - 2;
const_assert z > 0; // valid in functions.
const_assert(z > 0);
}
54 changes: 54 additions & 0 deletions naga/tests/out/ir/const_assert.compact.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Sint,
width: 4,
)),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {},
),
constants: [
(
name: Some("x"),
ty: 0,
init: 0,
),
(
name: Some("y"),
ty: 0,
init: 1,
),
],
overrides: [],
global_variables: [],
global_expressions: [
Literal(I32(1)),
Literal(I32(2)),
],
functions: [
(
name: Some("foo"),
arguments: [],
result: None,
local_variables: [],
expressions: [
Literal(I32(1)),
],
named_expressions: {
0: "z",
},
body: [
Return(
value: None,
),
],
),
],
entry_points: [],
)
54 changes: 54 additions & 0 deletions naga/tests/out/ir/const_assert.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Sint,
width: 4,
)),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {},
),
constants: [
(
name: Some("x"),
ty: 0,
init: 0,
),
(
name: Some("y"),
ty: 0,
init: 1,
),
],
overrides: [],
global_variables: [],
global_expressions: [
Literal(I32(1)),
Literal(I32(2)),
],
functions: [
(
name: Some("foo"),
arguments: [],
result: None,
local_variables: [],
expressions: [
Literal(I32(1)),
],
named_expressions: {
0: "z",
},
body: [
Return(
value: None,
),
],
),
],
entry_points: [],
)
7 changes: 7 additions & 0 deletions naga/tests/out/wgsl/const_assert.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
const x: i32 = 1i;
const y: i32 = 2i;

fn foo() {
return;
}

1 change: 1 addition & 0 deletions naga/tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,7 @@ fn convert_wgsl() {
"const-exprs",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
("const_assert", Targets::WGSL | Targets::IR),
("separate-entry-points", Targets::SPIRV | Targets::GLSL),
(
"struct-layout",
Expand Down
Loading

0 comments on commit 4e9a2a5

Please sign in to comment.