From e89c045c5f636087f9c81f6a0f13878918463a09 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 19 Apr 2023 18:28:10 +0200 Subject: [PATCH 1/5] [glsl-in] refactor: use `Context::add_expression` in all places --- src/front/glsl/builtins.rs | 14 +-- src/front/glsl/context.rs | 132 ++++++++++++++------------ src/front/glsl/functions.rs | 24 ++--- src/front/glsl/parser/declarations.rs | 7 +- 4 files changed, 93 insertions(+), 84 deletions(-) diff --git a/src/front/glsl/builtins.rs b/src/front/glsl/builtins.rs index fa5d5d9b9e..55ff8fbc9b 100644 --- a/src/front/glsl/builtins.rs +++ b/src/front/glsl/builtins.rs @@ -2057,7 +2057,7 @@ impl MacroCall { body, ), MacroCall::Mod(size) => { - ctx.implicit_splat(frontend, &mut args[1], meta, size)?; + ctx.implicit_splat(frontend, &mut args[1], meta, size, body)?; // x - y * floor(x / y) @@ -2101,7 +2101,7 @@ impl MacroCall { ) } MacroCall::Splatted(fun, size, i) => { - ctx.implicit_splat(frontend, &mut args[i], meta, size)?; + ctx.implicit_splat(frontend, &mut args[i], meta, size, body)?; ctx.add_expression( Expression::Math { @@ -2125,8 +2125,8 @@ impl MacroCall { body, ), MacroCall::Clamp(size) => { - ctx.implicit_splat(frontend, &mut args[1], meta, size)?; - ctx.implicit_splat(frontend, &mut args[2], meta, size)?; + ctx.implicit_splat(frontend, &mut args[1], meta, size, body)?; + ctx.implicit_splat(frontend, &mut args[2], meta, size, body)?; ctx.add_expression( Expression::Math { @@ -2164,8 +2164,8 @@ impl MacroCall { return Ok(None); } MacroCall::SmoothStep { splatted } => { - ctx.implicit_splat(frontend, &mut args[0], meta, splatted)?; - ctx.implicit_splat(frontend, &mut args[1], meta, splatted)?; + ctx.implicit_splat(frontend, &mut args[0], meta, splatted, body)?; + ctx.implicit_splat(frontend, &mut args[1], meta, splatted, body)?; ctx.add_expression( Expression::Math { @@ -2196,7 +2196,7 @@ fn texture_call( let mut array_index = comps.array_index; if let Some(ref mut array_index_expr) = array_index { - ctx.conversion(array_index_expr, meta, Sk::Sint, 4)?; + ctx.conversion(array_index_expr, meta, Sk::Sint, 4, body)?; } Ok(ctx.add_expression( diff --git a/src/front/glsl/context.rs b/src/front/glsl/context.rs index e16dc125e7..755f2eba9c 100644 --- a/src/front/glsl/context.rs +++ b/src/front/glsl/context.rs @@ -114,28 +114,19 @@ impl Context { }: GlobalLookup, body: &mut Block, ) { - self.emit_end(body); let (expr, load, constant) = match kind { GlobalLookupKind::Variable(v) => { let span = frontend.module.global_variables.get_span(v); - let res = ( - self.expressions.append(Expression::GlobalVariable(v), span), + ( + self.add_expression(Expression::GlobalVariable(v), span, body), frontend.module.global_variables[v].space != AddressSpace::Handle, None, - ); - self.emit_start(); - - res + ) } GlobalLookupKind::BlockSelect(handle, index) => { let span = frontend.module.global_variables.get_span(handle); - let base = self - .expressions - .append(Expression::GlobalVariable(handle), span); - self.emit_start(); - let expr = self - .expressions - .append(Expression::AccessIndex { base, index }, span); + let base = self.add_expression(Expression::GlobalVariable(handle), span, body); + let expr = self.add_expression(Expression::AccessIndex { base, index }, span, body); ( expr, @@ -162,13 +153,11 @@ impl Context { } GlobalLookupKind::Constant(v, ty) => { let span = frontend.module.constants.get_span(v); - let res = ( - self.expressions.append(Expression::Constant(v), span), + ( + self.add_expression(Expression::Constant(v), span, body), false, Some((v, ty)), - ); - self.emit_start(); - res + ) } }; @@ -568,9 +557,10 @@ impl Context { right_meta, ScalarKind::Uint, 4, + body, )?, _ => self.binary_implicit_conversion( - frontend, &mut left, left_meta, &mut right, right_meta, + frontend, &mut left, left_meta, &mut right, right_meta, body, )?, } @@ -632,20 +622,21 @@ impl Context { ); // Divide the vectors - let column = self.expressions.append( + let column = self.add_expression( Expression::Binary { op, left: left_vector, right: right_vector, }, meta, + body, ); components.push(column) } // Rebuild the matrix from the divided vectors - self.expressions.append( + self.add_expression( Expression::Compose { ty: frontend.module.types.insert( Type { @@ -661,6 +652,7 @@ impl Context { components, }, meta, + body, ) } BinaryOperator::Equal | BinaryOperator::NotEqual => { @@ -699,13 +691,14 @@ impl Context { body, ); - let argument = self.expressions.append( + let argument = self.add_expression( Expression::Binary { op, left: left_vector, right: right_vector, }, meta, + body, ); // The result of comparing two vectors is a boolean vector @@ -750,9 +743,11 @@ impl Context { false => (BinaryOperator::NotEqual, RelationalFunction::Any), }; - let argument = self - .expressions - .append(Expression::Binary { op, left, right }, meta); + let argument = self.add_expression( + Expression::Binary { op, left, right }, + meta, + body, + ); self.add_expression( Expression::Relational { fun, argument }, @@ -872,20 +867,21 @@ impl Context { // Apply the operation to the splatted vector and // the column vector - let column = self.expressions.append( + let column = self.add_expression( Expression::Binary { op, left: scalar_vector, right: matrix_column, }, meta, + body, ); components.push(column) } // Rebuild the matrix from the operation result vectors - self.expressions.append( + self.add_expression( Expression::Compose { ty: frontend.module.types.insert( Type { @@ -901,6 +897,7 @@ impl Context { components, }, meta, + body, ) } _ => self.add_expression( @@ -963,20 +960,21 @@ impl Context { // Apply the operation to the splatted vector and // the column vector - let column = self.expressions.append( + let column = self.add_expression( Expression::Binary { op, left: matrix_column, right: scalar_vector, }, meta, + body, ); components.push(column) } // Rebuild the matrix from the operation result vectors - self.expressions.append( + self.add_expression( Expression::Compose { ty: frontend.module.types.insert( Type { @@ -992,6 +990,7 @@ impl Context { components, }, meta, + body, ) } _ => self.add_expression( @@ -1137,27 +1136,31 @@ impl Context { ) { match accept_power.cmp(&reject_power) { std::cmp::Ordering::Less => { - self.conversion(&mut accept, accept_meta, reject_kind, reject_width)?; + self.conversion( + &mut accept, + accept_meta, + reject_kind, + reject_width, + &mut accept_body, + )?; // The expression belongs to the `true` branch so we need to flush to // the respective body - self.emit_end(&mut accept_body); + self.emit_restart(&mut accept_body); } - // Technically there's nothing to flush but later we will need to - // add some expressions that must not be emitted so instead - // of flushing, starting and flushing again, just make sure - // everything is flushed. - std::cmp::Ordering::Equal => self.emit_end(body), + std::cmp::Ordering::Equal => {} std::cmp::Ordering::Greater => { - self.conversion(&mut reject, reject_meta, accept_kind, accept_width)?; + self.conversion( + &mut reject, + reject_meta, + accept_kind, + accept_width, + &mut reject_body, + )?; // The expression belongs to the `false` branch so we need to flush to // the respective body - self.emit_end(&mut reject_body); + self.emit_restart(&mut reject_body); } } - } else { - // Technically there's nothing to flush but later we will need to - // add some expressions that must not be emitted. - self.emit_end(body) } // We need to get the type of the resulting expression to create the local, @@ -1175,11 +1178,7 @@ impl Context { meta, ); - // Note: `Expression::LocalVariable` must not be emited so it's important - // that at this point the emitter is flushed but not started. - let local_expr = self - .expressions - .append(Expression::LocalVariable(local), meta); + let local_expr = self.add_expression(Expression::LocalVariable(local), meta, body); // Add to each body the store to the result variable accept_body.push( @@ -1208,16 +1207,14 @@ impl Context { meta, ); - // Restart the emitter - self.emit_start(); - // Note: `Expression::Load` must be emited before it's used so make // sure the emitter is active here. - self.expressions.append( + self.add_expression( Expression::Load { pointer: local_expr, }, meta, + body, ) } HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => { @@ -1232,7 +1229,7 @@ impl Context { }; if let Some((kind, width)) = scalar_components(ty) { - self.implicit_conversion(frontend, &mut value, value_meta, kind, width)?; + self.implicit_conversion(frontend, &mut value, value_meta, kind, width, body)?; } self.lower_store(pointer, value, meta, body); @@ -1365,6 +1362,7 @@ impl Context { meta, ScalarKind::Sint, 4, + body, )?; array_length } @@ -1375,7 +1373,13 @@ impl Context { meta, body, ); - self.conversion(&mut array_length, meta, ScalarKind::Sint, 4)?; + self.conversion( + &mut array_length, + meta, + ScalarKind::Sint, + 4, + body, + )?; array_length } } @@ -1439,14 +1443,16 @@ impl Context { meta: Span, kind: ScalarKind, width: crate::Bytes, + body: &mut Block, ) -> Result<()> { - *expr = self.expressions.append( + *expr = self.add_expression( Expression::As { expr: *expr, kind, convert: Some(width), }, meta, + body, ); Ok(()) @@ -1459,13 +1465,14 @@ impl Context { meta: Span, kind: ScalarKind, width: crate::Bytes, + body: &mut Block, ) -> Result<()> { if let (Some(tgt_power), Some(expr_power)) = ( type_power(kind, width), self.expr_power(frontend, *expr, meta)?, ) { if tgt_power > expr_power { - self.conversion(expr, meta, kind, width)?; + self.conversion(expr, meta, kind, width, body)?; } } @@ -1479,12 +1486,13 @@ impl Context { meta: Span, kind: ScalarKind, width: crate::Bytes, + body: &mut Block, ) -> Result<()> { if let Some((expr_scalar_kind, expr_width)) = self.expr_scalar_components(frontend, *expr, meta)? { if expr_scalar_kind != kind || expr_width != width { - self.conversion(expr, meta, kind, width)?; + self.conversion(expr, meta, kind, width, body)?; } } @@ -1498,6 +1506,7 @@ impl Context { left_meta: Span, right: &mut Handle, right_meta: Span, + body: &mut Block, ) -> Result<()> { let left_components = self.expr_scalar_components(frontend, *left, left_meta)?; let right_components = self.expr_scalar_components(frontend, *right, right_meta)?; @@ -1512,11 +1521,11 @@ impl Context { ) { match left_power.cmp(&right_power) { std::cmp::Ordering::Less => { - self.conversion(left, left_meta, right_kind, right_width)?; + self.conversion(left, left_meta, right_kind, right_width, body)?; } std::cmp::Ordering::Equal => {} std::cmp::Ordering::Greater => { - self.conversion(right, right_meta, left_kind, left_width)?; + self.conversion(right, right_meta, left_kind, left_width, body)?; } } } @@ -1530,13 +1539,12 @@ impl Context { expr: &mut Handle, meta: Span, vector_size: Option, + body: &mut Block, ) -> Result<()> { let expr_type = frontend.resolve_type(self, *expr, meta)?; if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) { - *expr = self - .expressions - .append(Expression::Splat { size, value: *expr }, meta) + *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta, body) } Ok(()) diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index 7dee023743..f85c765050 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -79,8 +79,8 @@ impl Frontend { let mut reject = ctx.add_expression(Expression::Literal(l0), expr_meta, body); let mut accept = ctx.add_expression(Expression::Literal(l1), expr_meta, body); - ctx.implicit_splat(self, &mut reject, meta, vector_size)?; - ctx.implicit_splat(self, &mut accept, meta, vector_size)?; + ctx.implicit_splat(self, &mut reject, meta, vector_size, body)?; + ctx.implicit_splat(self, &mut accept, meta, vector_size, body)?; let h = ctx.add_expression( Expression::Select { @@ -99,7 +99,7 @@ impl Frontend { Ok(match self.module.types[ty].inner { TypeInner::Vector { size, kind, width } if vector_size.is_none() => { - ctx.forced_conversion(self, &mut value, expr_meta, kind, width)?; + ctx.forced_conversion(self, &mut value, expr_meta, kind, width, body)?; if let TypeInner::Scalar { .. } = *self.resolve_type(ctx, value, expr_meta)? { ctx.add_expression(Expression::Splat { size, value }, meta, body) @@ -186,7 +186,7 @@ impl Frontend { .get(0) .and_then(|member| scalar_components(&self.module.types[member.ty].inner)); if let Some((kind, width)) = scalar_components { - ctx.implicit_conversion(self, &mut value, expr_meta, kind, width)?; + ctx.implicit_conversion(self, &mut value, expr_meta, kind, width, body)?; } ctx.add_expression( @@ -202,7 +202,7 @@ impl Frontend { TypeInner::Array { base, .. } => { let scalar_components = scalar_components(&self.module.types[base].inner); if let Some((kind, width)) = scalar_components { - ctx.implicit_conversion(self, &mut value, expr_meta, kind, width)?; + ctx.implicit_conversion(self, &mut value, expr_meta, kind, width, body)?; } ctx.add_expression( @@ -242,7 +242,7 @@ impl Frontend { // `Expression::As` doesn't support matrix width // casts so we need to do some extra work for casts - ctx.forced_conversion(self, &mut value, expr_meta, ScalarKind::Float, width)?; + ctx.forced_conversion(self, &mut value, expr_meta, ScalarKind::Float, width, body)?; match *self.resolve_type(ctx, value, expr_meta)? { TypeInner::Scalar { .. } => { // If a matrix is constructed with a single scalar value, then that @@ -395,7 +395,7 @@ impl Frontend { let mut components = Vec::with_capacity(size as usize); for (mut arg, expr_meta) in args.iter().copied() { - ctx.forced_conversion(self, &mut arg, expr_meta, kind, width)?; + ctx.forced_conversion(self, &mut arg, expr_meta, kind, width, body)?; if components.len() >= size as usize { break; @@ -461,7 +461,7 @@ impl Frontend { let mut flattened = Vec::with_capacity(columns as usize * rows as usize); for (mut arg, meta) in args.iter().copied() { - ctx.forced_conversion(self, &mut arg, meta, ScalarKind::Float, width)?; + ctx.forced_conversion(self, &mut arg, meta, ScalarKind::Float, width, body)?; match *self.resolve_type(ctx, arg, meta)? { TypeInner::Vector { size, .. } => { @@ -510,7 +510,7 @@ impl Frontend { for (mut arg, meta) in args.iter().copied() { let scalar_components = scalar_components(&self.module.types[base].inner); if let Some((kind, width)) = scalar_components { - ctx.implicit_conversion(self, &mut arg, meta, kind, width)?; + ctx.implicit_conversion(self, &mut arg, meta, kind, width, body)?; } components.push(arg) @@ -520,7 +520,7 @@ impl Frontend { for ((mut arg, meta), member) in args.iter().copied().zip(members.iter()) { let scalar_components = scalar_components(&self.module.types[member.ty].inner); if let Some((kind, width)) = scalar_components { - ctx.implicit_conversion(self, &mut arg, meta, kind, width)?; + ctx.implicit_conversion(self, &mut arg, meta, kind, width, body)?; } components.push(arg) @@ -845,7 +845,7 @@ impl Frontend { // Apply implicit conversions as needed if let Some((kind, width)) = scalar_comps { - ctx.implicit_conversion(self, &mut handle, meta, kind, width)?; + ctx.implicit_conversion(self, &mut handle, meta, kind, width, body)?; } arguments.push(handle) @@ -883,7 +883,7 @@ impl Frontend { ); if let Some((kind, width)) = proxy_write.convert { - ctx.conversion(&mut value, meta, kind, width)?; + ctx.conversion(&mut value, meta, kind, width, body)?; } ctx.emit_restart(body); diff --git a/src/front/glsl/parser/declarations.rs b/src/front/glsl/parser/declarations.rs index e99dda54f2..42b2abb5c0 100644 --- a/src/front/glsl/parser/declarations.rs +++ b/src/front/glsl/parser/declarations.rs @@ -161,7 +161,7 @@ impl<'source> ParsingContext<'source> { let scalar_components = scalar_components(&frontend.module.types[ty].inner); if let Some((kind, width)) = scalar_components { - ctx.implicit_conversion(frontend, &mut init, init_meta, kind, width)?; + ctx.implicit_conversion(frontend, &mut init, init_meta, kind, width, body)?; } Ok((init, init_meta)) @@ -233,8 +233,9 @@ impl<'source> ParsingContext<'source> { let scalar_components = scalar_components(&frontend.module.types[ty].inner); if let Some((kind, width)) = scalar_components { - ctx.ctx - .implicit_conversion(frontend, &mut expr, init_meta, kind, width)?; + ctx.ctx.implicit_conversion( + frontend, &mut expr, init_meta, kind, width, ctx.body, + )?; } meta.subsume(init_meta); From c15b45aeb030b12f84214b86e6720258eef60f00 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Thu, 27 Apr 2023 15:52:07 +0200 Subject: [PATCH 2/5] [glsl-in] move module and active body to Context --- src/front/glsl/builtins.rs | 184 +++----- src/front/glsl/context.rs | 622 +++++++++++--------------- src/front/glsl/functions.rs | 474 +++++++++----------- src/front/glsl/mod.rs | 26 +- src/front/glsl/parser.rs | 60 ++- src/front/glsl/parser/declarations.rs | 80 ++-- src/front/glsl/parser/expressions.rs | 63 ++- src/front/glsl/parser/functions.rs | 433 +++++++++--------- src/front/glsl/parser/types.rs | 54 ++- src/front/glsl/types.rs | 49 +- src/front/glsl/variables.rs | 76 ++-- 11 files changed, 946 insertions(+), 1175 deletions(-) diff --git a/src/front/glsl/builtins.rs b/src/front/glsl/builtins.rs index 55ff8fbc9b..811be65cce 100644 --- a/src/front/glsl/builtins.rs +++ b/src/front/glsl/builtins.rs @@ -7,7 +7,7 @@ use super::{ Error, ErrorKind, Frontend, Result, }; use crate::{ - BinaryOperator, Block, DerivativeAxis as Axis, DerivativeControl as Ctrl, Expression, Handle, + BinaryOperator, DerivativeAxis as Axis, DerivativeControl as Ctrl, Expression, Handle, ImageClass, ImageDimension as Dim, ImageQuery, MathFunction, Module, RelationalFunction, SampleLevel, ScalarKind as Sk, Span, Type, TypeInner, UnaryOperator, VectorSize, }; @@ -1684,7 +1684,6 @@ impl MacroCall { &self, frontend: &mut Frontend, ctx: &mut Context, - body: &mut Block, args: &mut [Handle], meta: Span, ) -> Result>> { @@ -1694,14 +1693,8 @@ impl MacroCall { args[0] } MacroCall::SamplerShadow => { - sampled_to_depth( - &mut frontend.module, - ctx, - args[0], - meta, - &mut frontend.errors, - ); - frontend.invalidate_expression(ctx, args[0], meta)?; + sampled_to_depth(ctx, args[0], meta, &mut frontend.errors); + ctx.invalidate_expression(args[0], meta)?; ctx.samplers.insert(args[0], args[1]); args[0] } @@ -1714,7 +1707,7 @@ impl MacroCall { let mut coords = args[1]; if proj { - let size = match *frontend.resolve_type(ctx, coords, meta)? { + let size = match *ctx.resolve_type(coords, meta)? { TypeInner::Vector { size, .. } => size, _ => unreachable!(), }; @@ -1724,8 +1717,7 @@ impl MacroCall { index: size as u32 - 1, }, Span::default(), - body, - ); + )?; let left = if let VectorSize::Bi = size { ctx.add_expression( Expression::AccessIndex { @@ -1733,8 +1725,7 @@ impl MacroCall { index: 0, }, Span::default(), - body, - ) + )? } else { let size = match size { VectorSize::Tri => VectorSize::Bi, @@ -1743,9 +1734,8 @@ impl MacroCall { right = ctx.add_expression( Expression::Splat { size, value: right }, Span::default(), - body, - ); - ctx.vector_resize(size, coords, Span::default(), body) + )?; + ctx.vector_resize(size, coords, Span::default())? }; coords = ctx.add_expression( Expression::Binary { @@ -1754,13 +1744,11 @@ impl MacroCall { right, }, Span::default(), - body, - ); + )?; } let extra = args.get(2).copied(); - let comps = - frontend.coordinate_components(ctx, args[0], coords, extra, meta, body)?; + let comps = frontend.coordinate_components(ctx, args[0], coords, extra, meta)?; let mut num_args = 2; @@ -1807,7 +1795,7 @@ impl MacroCall { true => { let offset_arg = args[num_args]; num_args += 1; - match frontend.solve_constant(ctx, offset_arg, meta) { + match ctx.solve_constant(offset_arg, meta) { Ok(v) => Some(v), Err(e) => { frontend.errors.push(e); @@ -1826,7 +1814,7 @@ impl MacroCall { .map_or(SampleLevel::Auto, SampleLevel::Bias); } - texture_call(ctx, args[0], level, comps, texture_offset, body, meta)? + texture_call(ctx, args[0], level, comps, texture_offset, meta)? } MacroCall::TextureSize { arrayed } => { @@ -1838,20 +1826,18 @@ impl MacroCall { }, }, Span::default(), - body, - ); + )?; if arrayed { let mut components = Vec::with_capacity(4); - let size = match *frontend.resolve_type(ctx, expr, meta)? { + let size = match *ctx.resolve_type(expr, meta)? { TypeInner::Vector { size: ori_size, .. } => { for index in 0..(ori_size as u32) { components.push(ctx.add_expression( Expression::AccessIndex { base: expr, index }, Span::default(), - body, - )) + )?) } match ori_size { @@ -1871,10 +1857,9 @@ impl MacroCall { query: ImageQuery::NumLayers, }, Span::default(), - body, - )); + )?); - let ty = frontend.module.types.insert( + let ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Vector { @@ -1886,7 +1871,7 @@ impl MacroCall { Span::default(), ); - expr = ctx.add_expression(Expression::Compose { components, ty }, meta, body) + expr = ctx.add_expression(Expression::Compose { components, ty }, meta)? } ctx.add_expression( @@ -1896,12 +1881,10 @@ impl MacroCall { convert: Some(4), }, Span::default(), - body, - ) + )? } MacroCall::ImageLoad { multi } => { - let comps = - frontend.coordinate_components(ctx, args[0], args[1], None, meta, body)?; + let comps = frontend.coordinate_components(ctx, args[0], args[1], None, meta)?; let (sample, level) = match (multi, args.get(2)) { (_, None) => (None, None), (true, Some(&arg)) => (Some(arg), None), @@ -1916,14 +1899,12 @@ impl MacroCall { level, }, Span::default(), - body, - ) + )? } MacroCall::ImageStore => { - let comps = - frontend.coordinate_components(ctx, args[0], args[1], None, meta, body)?; - ctx.emit_restart(body); - body.push( + let comps = frontend.coordinate_components(ctx, args[0], args[1], None, meta)?; + ctx.emit_restart(); + ctx.body.push( crate::Statement::ImageStore { image: args[0], coordinate: comps.coordinate, @@ -1943,8 +1924,7 @@ impl MacroCall { arg3: args.get(3).copied(), }, Span::default(), - body, - ), + )?, mc @ (MacroCall::FindLsbUint | MacroCall::FindMsbUint) => { let fun = match mc { MacroCall::FindLsbUint => MathFunction::FindLsb, @@ -1960,8 +1940,7 @@ impl MacroCall { arg3: None, }, Span::default(), - body, - ); + )?; ctx.add_expression( Expression::As { expr: res, @@ -1969,8 +1948,7 @@ impl MacroCall { convert: Some(4), }, Span::default(), - body, - ) + )? } MacroCall::BitfieldInsert => { let conv_arg_2 = ctx.add_expression( @@ -1980,8 +1958,7 @@ impl MacroCall { convert: Some(4), }, Span::default(), - body, - ); + )?; let conv_arg_3 = ctx.add_expression( Expression::As { expr: args[3], @@ -1989,8 +1966,7 @@ impl MacroCall { convert: Some(4), }, Span::default(), - body, - ); + )?; ctx.add_expression( Expression::Math { fun: MathFunction::InsertBits, @@ -2000,8 +1976,7 @@ impl MacroCall { arg3: Some(conv_arg_3), }, Span::default(), - body, - ) + )? } MacroCall::BitfieldExtract => { let conv_arg_1 = ctx.add_expression( @@ -2011,8 +1986,7 @@ impl MacroCall { convert: Some(4), }, Span::default(), - body, - ); + )?; let conv_arg_2 = ctx.add_expression( Expression::As { expr: args[2], @@ -2020,8 +1994,7 @@ impl MacroCall { convert: Some(4), }, Span::default(), - body, - ); + )?; ctx.add_expression( Expression::Math { fun: MathFunction::ExtractBits, @@ -2031,8 +2004,7 @@ impl MacroCall { arg3: None, }, Span::default(), - body, - ) + )? } MacroCall::Relational(fun) => ctx.add_expression( Expression::Relational { @@ -2040,13 +2012,10 @@ impl MacroCall { argument: args[0], }, Span::default(), - body, - ), - MacroCall::Unary(op) => ctx.add_expression( - Expression::Unary { op, expr: args[0] }, - Span::default(), - body, - ), + )?, + MacroCall::Unary(op) => { + ctx.add_expression(Expression::Unary { op, expr: args[0] }, Span::default())? + } MacroCall::Binary(op) => ctx.add_expression( Expression::Binary { op, @@ -2054,10 +2023,9 @@ impl MacroCall { right: args[1], }, Span::default(), - body, - ), + )?, MacroCall::Mod(size) => { - ctx.implicit_splat(frontend, &mut args[1], meta, size, body)?; + ctx.implicit_splat(&mut args[1], meta, size)?; // x - y * floor(x / y) @@ -2068,8 +2036,7 @@ impl MacroCall { right: args[1], }, Span::default(), - body, - ); + )?; let floor = ctx.add_expression( Expression::Math { fun: MathFunction::Floor, @@ -2079,8 +2046,7 @@ impl MacroCall { arg3: None, }, Span::default(), - body, - ); + )?; let mult = ctx.add_expression( Expression::Binary { op: BinaryOperator::Multiply, @@ -2088,8 +2054,7 @@ impl MacroCall { right: args[1], }, Span::default(), - body, - ); + )?; ctx.add_expression( Expression::Binary { op: BinaryOperator::Subtract, @@ -2097,11 +2062,10 @@ impl MacroCall { right: mult, }, Span::default(), - body, - ) + )? } MacroCall::Splatted(fun, size, i) => { - ctx.implicit_splat(frontend, &mut args[i], meta, size, body)?; + ctx.implicit_splat(&mut args[i], meta, size)?; ctx.add_expression( Expression::Math { @@ -2112,8 +2076,7 @@ impl MacroCall { arg3: args.get(3).copied(), }, Span::default(), - body, - ) + )? } MacroCall::MixBoolean => ctx.add_expression( Expression::Select { @@ -2122,11 +2085,10 @@ impl MacroCall { reject: args[0], }, Span::default(), - body, - ), + )?, MacroCall::Clamp(size) => { - ctx.implicit_splat(frontend, &mut args[1], meta, size, body)?; - ctx.implicit_splat(frontend, &mut args[2], meta, size, body)?; + ctx.implicit_splat(&mut args[1], meta, size)?; + ctx.implicit_splat(&mut args[2], meta, size)?; ctx.add_expression( Expression::Math { @@ -2137,8 +2099,7 @@ impl MacroCall { arg3: args.get(3).copied(), }, Span::default(), - body, - ) + )? } MacroCall::BitCast(kind) => ctx.add_expression( Expression::As { @@ -2147,8 +2108,7 @@ impl MacroCall { convert: None, }, Span::default(), - body, - ), + )?, MacroCall::Derivate(axis, ctrl) => ctx.add_expression( Expression::Derivative { axis, @@ -2156,16 +2116,16 @@ impl MacroCall { expr: args[0], }, Span::default(), - body, - ), + )?, MacroCall::Barrier => { - ctx.emit_restart(body); - body.push(crate::Statement::Barrier(crate::Barrier::all()), meta); + ctx.emit_restart(); + ctx.body + .push(crate::Statement::Barrier(crate::Barrier::all()), meta); return Ok(None); } MacroCall::SmoothStep { splatted } => { - ctx.implicit_splat(frontend, &mut args[0], meta, splatted, body)?; - ctx.implicit_splat(frontend, &mut args[1], meta, splatted, body)?; + ctx.implicit_splat(&mut args[0], meta, splatted)?; + ctx.implicit_splat(&mut args[1], meta, splatted)?; ctx.add_expression( Expression::Math { @@ -2176,8 +2136,7 @@ impl MacroCall { arg3: None, }, Span::default(), - body, - ) + )? } })) } @@ -2189,14 +2148,13 @@ fn texture_call( level: SampleLevel, comps: CoordComponents, offset: Option>, - body: &mut Block, meta: Span, ) -> Result> { if let Some(sampler) = ctx.samplers.get(&image).copied() { let mut array_index = comps.array_index; if let Some(ref mut array_index_expr) = array_index { - ctx.conversion(array_index_expr, meta, Sk::Sint, 4, body)?; + ctx.conversion(array_index_expr, meta, Sk::Sint, 4)?; } Ok(ctx.add_expression( @@ -2211,8 +2169,7 @@ fn texture_call( depth_ref: comps.depth_ref, }, meta, - body, - )) + )?) } else { Err(Error { kind: ErrorKind::SemanticError("Bad call".into()), @@ -2241,13 +2198,12 @@ impl Frontend { coord: Handle, extra: Option>, meta: Span, - body: &mut Block, ) -> Result { if let TypeInner::Image { dim, arrayed, class, - } = *self.resolve_type(ctx, image, meta)? + } = *ctx.resolve_type(image, meta)? { let image_size = match dim { Dim::D1 => None, @@ -2255,7 +2211,7 @@ impl Frontend { Dim::D3 => Some(VectorSize::Tri), Dim::Cube => Some(VectorSize::Tri), }; - let coord_size = match *self.resolve_type(ctx, coord, meta)? { + let coord_size = match *ctx.resolve_type(coord, meta)? { TypeInner::Vector { size, .. } => Some(size), _ => None, }; @@ -2267,7 +2223,7 @@ impl Frontend { let coordinate = match (image_size, coord_size) { (Some(size), Some(coord_s)) if size != coord_s => { - ctx.vector_resize(size, coord, Span::default(), body) + ctx.vector_resize(size, coord, Span::default())? } (None, Some(_)) => ctx.add_expression( Expression::AccessIndex { @@ -2275,8 +2231,7 @@ impl Frontend { index: 0, }, Span::default(), - body, - ), + )?, _ => coord, }; @@ -2289,8 +2244,7 @@ impl Frontend { Some(ctx.add_expression( Expression::AccessIndex { base: coord, index }, Span::default(), - body, - )) + )?) } else { None }; @@ -2306,8 +2260,7 @@ impl Frontend { Some(ctx.add_expression( Expression::AccessIndex { base: coord, index }, Span::default(), - body, - )) + )?) } } false => None, @@ -2338,7 +2291,6 @@ impl Frontend { /// Helper function to cast a expression holding a sampled image to a /// depth image. pub fn sampled_to_depth( - module: &mut Module, ctx: &mut Context, image: Handle, meta: Span, @@ -2346,7 +2298,7 @@ pub fn sampled_to_depth( ) { // Get the a mutable type handle of the underlying image storage let ty = match ctx[image] { - Expression::GlobalVariable(handle) => &mut module.global_variables.get_mut(handle).ty, + Expression::GlobalVariable(handle) => &mut ctx.module.global_variables.get_mut(handle).ty, Expression::FunctionArgument(i) => { // Mark the function argument as carrying a depth texture ctx.parameters_info[i as usize].depth = true; @@ -2362,7 +2314,7 @@ pub fn sampled_to_depth( } }; - match module.types[*ty].inner { + match ctx.module.types[*ty].inner { // Update the image class to depth in case it already isn't TypeInner::Image { class, @@ -2370,7 +2322,7 @@ pub fn sampled_to_depth( arrayed, } => match class { ImageClass::Sampled { multi, .. } => { - *ty = module.types.insert( + *ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Image { diff --git a/src/front/glsl/context.rs b/src/front/glsl/context.rs index 755f2eba9c..407619c3d2 100644 --- a/src/front/glsl/context.rs +++ b/src/front/glsl/context.rs @@ -44,7 +44,7 @@ impl ExprPos { } #[derive(Debug)] -pub struct Context { +pub struct Context<'a> { pub expressions: Arena, pub locals: Arena, @@ -74,10 +74,12 @@ pub struct Context { pub typifier: Typifier, emitter: Emitter, stmt_ctx: Option, + pub body: Block, + pub module: &'a mut crate::Module, } -impl Context { - pub fn new(frontend: &Frontend, body: &mut Block) -> Self { +impl<'a> Context<'a> { + pub fn new(frontend: &Frontend, module: &'a mut crate::Module) -> Result { let mut this = Context { expressions: Arena::new(), locals: Arena::new(), @@ -92,53 +94,84 @@ impl Context { typifier: Typifier::new(), emitter: Emitter::default(), stmt_ctx: Some(StmtContext::new()), + body: Block::new(), + module, }; this.emit_start(); for &(ref name, lookup) in frontend.global_variables.iter() { - this.add_global(frontend, name, lookup, body) + this.add_global(name, lookup)? } - this + Ok(this) + } + + pub fn new_body(&mut self, cb: F) -> Result + where + F: FnOnce(&mut Self) -> Result<()>, + { + self.new_body_with_ret(cb).map(|(b, _)| b) + } + + pub fn new_body_with_ret(&mut self, cb: F) -> Result<(Block, R)> + where + F: FnOnce(&mut Self) -> Result, + { + self.emit_restart(); + let old_body = std::mem::replace(&mut self.body, Block::new()); + let res = cb(self); + self.emit_restart(); + let new_body = std::mem::replace(&mut self.body, old_body); + res.map(|r| (new_body, r)) + } + + pub fn with_body(&mut self, body: Block, cb: F) -> Result + where + F: FnOnce(&mut Self) -> Result<()>, + { + self.emit_restart(); + let old_body = std::mem::replace(&mut self.body, body); + let res = cb(self); + self.emit_restart(); + let body = std::mem::replace(&mut self.body, old_body); + res.map(|_| body) } pub fn add_global( &mut self, - frontend: &Frontend, name: &str, GlobalLookup { kind, entry_arg, mutable, }: GlobalLookup, - body: &mut Block, - ) { + ) -> Result<()> { let (expr, load, constant) = match kind { GlobalLookupKind::Variable(v) => { - let span = frontend.module.global_variables.get_span(v); + let span = self.module.global_variables.get_span(v); ( - self.add_expression(Expression::GlobalVariable(v), span, body), - frontend.module.global_variables[v].space != AddressSpace::Handle, + self.add_expression(Expression::GlobalVariable(v), span)?, + self.module.global_variables[v].space != AddressSpace::Handle, None, ) } GlobalLookupKind::BlockSelect(handle, index) => { - let span = frontend.module.global_variables.get_span(handle); - let base = self.add_expression(Expression::GlobalVariable(handle), span, body); - let expr = self.add_expression(Expression::AccessIndex { base, index }, span, body); + let span = self.module.global_variables.get_span(handle); + let base = self.add_expression(Expression::GlobalVariable(handle), span)?; + let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?; ( expr, { - let ty = frontend.module.global_variables[handle].ty; + let ty = self.module.global_variables[handle].ty; - match frontend.module.types[ty].inner { + match self.module.types[ty].inner { TypeInner::Struct { ref members, .. } => { if let TypeInner::Array { size: crate::ArraySize::Dynamic, .. - } = frontend.module.types[members[index as usize].ty].inner + } = self.module.types[members[index as usize].ty].inner { false } else { @@ -152,9 +185,9 @@ impl Context { ) } GlobalLookupKind::Constant(v, ty) => { - let span = frontend.module.constants.get_span(v); + let span = self.module.constants.get_span(v); ( - self.add_expression(Expression::Constant(v), span, body), + self.add_expression(Expression::Constant(v), span)?, false, Some((v, ty)), ) @@ -170,6 +203,8 @@ impl Context { }; self.symbol_table.add(name.into(), var); + + Ok(()) } /// Starts the expression emitter @@ -182,7 +217,7 @@ impl Context { self.emitter.start(&self.expressions) } - /// Emits all the expressions captured by the emitter to the passed `body` + /// Emits all the expressions captured by the emitter to the current body /// /// # Panics /// @@ -190,36 +225,31 @@ impl Context { /// - If called twice in a row without calling [`emit_start`]. /// /// [`emit_start`]: Self::emit_start - pub fn emit_end(&mut self, body: &mut Block) { - body.extend(self.emitter.finish(&self.expressions)) + pub fn emit_end(&mut self) { + self.body.extend(self.emitter.finish(&self.expressions)) } - /// Emits all the expressions captured by the emitter to the passed `body` + /// Emits all the expressions captured by the emitter to the current body /// and starts the emitter again /// /// # Panics /// /// - If called before calling [`emit_start`][Self::emit_start]. - pub fn emit_restart(&mut self, body: &mut Block) { - self.emit_end(body); + pub fn emit_restart(&mut self) { + self.emit_end(); self.emit_start() } - pub fn add_expression( - &mut self, - expr: Expression, - meta: Span, - body: &mut Block, - ) -> Handle { + pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result> { let needs_pre_emit = expr.needs_pre_emit(); if needs_pre_emit { - self.emit_end(body); + self.emit_end(); } let handle = self.expressions.append(expr, meta); if needs_pre_emit { self.emit_start(); } - handle + Ok(handle) } /// Add variable to current scope @@ -246,12 +276,10 @@ impl Context { /// Add function argument to current scope pub fn add_function_arg( &mut self, - frontend: &mut Frontend, - body: &mut Block, name_meta: Option<(String, Span)>, ty: Handle, qualifier: ParameterQualifier, - ) { + ) -> Result<()> { let index = self.arguments.len(); let mut arg = FunctionArgument { name: name_meta.as_ref().map(|&(ref name, _)| name.clone()), @@ -260,14 +288,14 @@ impl Context { }; self.parameters.push(ty); - let opaque = match frontend.module.types[ty].inner { + let opaque = match self.module.types[ty].inner { TypeInner::Image { .. } | TypeInner::Sampler { .. } => true, _ => false, }; if qualifier.is_lhs() { - let span = frontend.module.types.get_span(arg.ty); - arg.ty = frontend.module.types.insert( + let span = self.module.types.get_span(arg.ty); + arg.ty = self.module.types.insert( Type { name: None, inner: TypeInner::Pointer { @@ -287,7 +315,7 @@ impl Context { }); if let Some((name, meta)) = name_meta { - let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta, body); + let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?; let mutable = qualifier != ParameterQualifier::Const && !opaque; let load = qualifier.is_lhs(); @@ -300,11 +328,11 @@ impl Context { }, meta, ); - let local_expr = self.add_expression(Expression::LocalVariable(handle), meta, body); + let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?; - self.emit_restart(body); + self.emit_restart(); - body.push( + self.body.push( Statement::Store { pointer: local_expr, value: expr, @@ -331,6 +359,8 @@ impl Context { self.symbol_table.add(name, var); } + + Ok(()) } /// Returns a [`StmtContext`](StmtContext) to be used in parsing and lowering @@ -353,9 +383,8 @@ impl Context { frontend: &mut Frontend, expr: Handle, pos: ExprPos, - body: &mut Block, ) -> Result<(Option>, Span)> { - let res = self.lower_inner(&stmt, frontend, expr, pos, body); + let res = self.lower_inner(&stmt, frontend, expr, pos); stmt.hir_exprs.clear(); self.stmt_ctx = Some(stmt); @@ -374,9 +403,8 @@ impl Context { frontend: &mut Frontend, expr: Handle, pos: ExprPos, - body: &mut Block, ) -> Result<(Handle, Span)> { - let res = self.lower_expect_inner(&stmt, frontend, expr, pos, body); + let res = self.lower_expect_inner(&stmt, frontend, expr, pos); stmt.hir_exprs.clear(); self.stmt_ctx = Some(stmt); @@ -395,9 +423,8 @@ impl Context { frontend: &mut Frontend, expr: Handle, pos: ExprPos, - body: &mut Block, ) -> Result<(Handle, Span)> { - let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos, body)?; + let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?; let expr = match maybe_expr { Some(e) => e, @@ -417,8 +444,7 @@ impl Context { pointer: Handle, value: Handle, meta: Span, - body: &mut Block, - ) { + ) -> Result<()> { if let Expression::Swizzle { size, mut vector, @@ -445,20 +471,18 @@ impl Context { index: pattern[index].index(), }, meta, - body, - ); + )?; let src = self.add_expression( Expression::AccessIndex { base: value, index: index as u32, }, meta, - body, - ); + )?; - self.emit_restart(body); + self.emit_restart(); - body.push( + self.body.push( Statement::Store { pointer: dst, value: src, @@ -467,10 +491,12 @@ impl Context { ); } } else { - self.emit_restart(body); + self.emit_restart(); - body.push(Statement::Store { pointer, value }, meta); + self.body.push(Statement::Store { pointer, value }, meta); } + + Ok(()) } /// Internal implementation of [`lower`](Self::lower) @@ -480,7 +506,6 @@ impl Context { frontend: &mut Frontend, expr: Handle, pos: ExprPos, - body: &mut Block, ) -> Result<(Option>, Span)> { let HirExpr { ref kind, meta } = stmt.hir_exprs[expr]; @@ -489,12 +514,12 @@ impl Context { let handle = match *kind { HirExprKind::Access { base, index } => { let (index, index_meta) = - self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs, body)?; + self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?; let maybe_constant_index = match pos { // Don't try to generate `AccessIndex` if in a LHS position, since it // wouldn't produce a pointer. ExprPos::Lhs => None, - _ => frontend.solve_constant(self, index, index_meta).ok(), + _ => self.solve_constant(index, index_meta).ok(), }; let base = self @@ -503,7 +528,6 @@ impl Context { frontend, base, pos.maybe_access_base(maybe_constant_index.is_some()), - body, )? .0; @@ -512,22 +536,20 @@ impl Context { Some(self.add_expression( Expression::AccessIndex { base, - index: - frontend.module.to_ctx().eval_expr_to_u32(const_expr).ok()?, + index: self.module.to_ctx().eval_expr_to_u32(const_expr).ok()?, }, meta, - body, )) }) .unwrap_or_else(|| { - self.add_expression(Expression::Access { base, index }, meta, body) - }); + self.add_expression(Expression::Access { base, index }, meta) + })?; if ExprPos::Rhs == pos { - let resolved = frontend.resolve_type(self, pointer, meta)?; + let resolved = self.resolve_type(pointer, meta)?; if resolved.pointer_space().is_some() { return Ok(( - Some(self.add_expression(Expression::Load { pointer }, meta, body)), + Some(self.add_expression(Expression::Load { pointer }, meta)?), meta, )); } @@ -536,39 +558,32 @@ impl Context { pointer } HirExprKind::Select { base, ref field } => { - let base = self.lower_expect_inner(stmt, frontend, base, pos, body)?.0; + let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0; - frontend.field_selection(self, pos, body, base, field, meta)? + frontend.field_selection(self, pos, base, field, meta)? } HirExprKind::Literal(literal) if pos != ExprPos::Lhs => { - self.add_expression(Expression::Literal(literal), meta, body) + self.add_expression(Expression::Literal(literal), meta)? } HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => { let (mut left, left_meta) = - self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs, body)?; + self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?; let (mut right, right_meta) = - self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs, body)?; + self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?; match op { - BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => self - .implicit_conversion( - frontend, - &mut right, - right_meta, - ScalarKind::Uint, - 4, - body, - )?, - _ => self.binary_implicit_conversion( - frontend, &mut left, left_meta, &mut right, right_meta, body, - )?, + BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => { + self.implicit_conversion(&mut right, right_meta, ScalarKind::Uint, 4)? + } + _ => self + .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?, } - frontend.typifier_grow(self, left, left_meta)?; - frontend.typifier_grow(self, right, right_meta)?; + self.typifier_grow(left, left_meta)?; + self.typifier_grow(right, right_meta)?; - let left_inner = self.typifier.get(left, &frontend.module.types); - let right_inner = self.typifier.get(right, &frontend.module.types); + let left_inner = self.typifier.get(left, &self.module.types); + let right_inner = self.typifier.get(right, &self.module.types); match (left_inner, right_inner) { ( @@ -613,13 +628,11 @@ impl Context { let left_vector = self.add_expression( Expression::AccessIndex { base: left, index }, meta, - body, - ); + )?; let right_vector = self.add_expression( Expression::AccessIndex { base: right, index }, meta, - body, - ); + )?; // Divide the vectors let column = self.add_expression( @@ -629,31 +642,25 @@ impl Context { right: right_vector, }, meta, - body, - ); + )?; components.push(column) } - // Rebuild the matrix from the divided vectors - self.add_expression( - Expression::Compose { - ty: frontend.module.types.insert( - Type { - name: None, - inner: TypeInner::Matrix { - columns: left_columns, - rows: left_rows, - width: left_width, - }, - }, - Span::default(), - ), - components, + let ty = self.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns: left_columns, + rows: left_rows, + width: left_width, + }, }, - meta, - body, - ) + Span::default(), + ); + + // Rebuild the matrix from the divided vectors + self.add_expression(Expression::Compose { ty, components }, meta)? } BinaryOperator::Equal | BinaryOperator::NotEqual => { // Naga IR doesn't support matrix comparisons so we need to @@ -683,13 +690,11 @@ impl Context { let left_vector = self.add_expression( Expression::AccessIndex { base: left, index }, meta, - body, - ); + )?; let right_vector = self.add_expression( Expression::AccessIndex { base: right, index }, meta, - body, - ); + )?; let argument = self.add_expression( Expression::Binary { @@ -698,8 +703,7 @@ impl Context { right: right_vector, }, meta, - body, - ); + )?; // The result of comparing two vectors is a boolean vector // so use a relational function like all to get a single @@ -707,8 +711,7 @@ impl Context { let compare = self.add_expression( Expression::Relational { fun, argument }, meta, - body, - ); + )?; // Fold the result root = Some(match root { @@ -719,19 +722,16 @@ impl Context { right, }, meta, - body, - ), + )?, None => compare, }); } root.unwrap() } - _ => self.add_expression( - Expression::Binary { left, op, right }, - meta, - body, - ), + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta)? + } } } (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op { @@ -743,21 +743,12 @@ impl Context { false => (BinaryOperator::NotEqual, RelationalFunction::Any), }; - let argument = self.add_expression( - Expression::Binary { op, left, right }, - meta, - body, - ); + let argument = + self.add_expression(Expression::Binary { op, left, right }, meta)?; - self.add_expression( - Expression::Relational { fun, argument }, - meta, - body, - ) - } - _ => { - self.add_expression(Expression::Binary { left, op, right }, meta, body) + self.add_expression(Expression::Relational { fun, argument }, meta)? } + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, }, (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op { BinaryOperator::Add @@ -768,11 +759,8 @@ impl Context { | BinaryOperator::InclusiveOr | BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => { - let scalar_vector = self.add_expression( - Expression::Splat { size, value: right }, - meta, - body, - ); + let scalar_vector = self + .add_expression(Expression::Splat { size, value: right }, meta)?; self.add_expression( Expression::Binary { @@ -781,12 +769,9 @@ impl Context { right: scalar_vector, }, meta, - body, - ) - } - _ => { - self.add_expression(Expression::Binary { left, op, right }, meta, body) + )? } + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, }, (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op { BinaryOperator::Add @@ -795,11 +780,8 @@ impl Context { | BinaryOperator::And | BinaryOperator::ExclusiveOr | BinaryOperator::InclusiveOr => { - let scalar_vector = self.add_expression( - Expression::Splat { size, value: left }, - meta, - body, - ); + let scalar_vector = + self.add_expression(Expression::Splat { size, value: left }, meta)?; self.add_expression( Expression::Binary { @@ -808,12 +790,9 @@ impl Context { right, }, meta, - body, - ) - } - _ => { - self.add_expression(Expression::Binary { left, op, right }, meta, body) + )? } + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, }, ( &TypeInner::Scalar { @@ -852,8 +831,7 @@ impl Context { value: left, }, meta, - body, - ); + )?; let mut components = Vec::with_capacity(columns as usize); @@ -862,8 +840,7 @@ impl Context { let matrix_column = self.add_expression( Expression::AccessIndex { base: right, index }, meta, - body, - ); + )?; // Apply the operation to the splatted vector and // the column vector @@ -874,37 +851,29 @@ impl Context { right: matrix_column, }, meta, - body, - ); + )?; components.push(column) } - // Rebuild the matrix from the operation result vectors - self.add_expression( - Expression::Compose { - ty: frontend.module.types.insert( - Type { - name: None, - inner: TypeInner::Matrix { - columns, - rows, - width: left_width, - }, - }, - Span::default(), - ), - components, + let ty = self.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + width: left_width, + }, }, - meta, - body, - ) + Span::default(), + ); + + // Rebuild the matrix from the operation result vectors + self.add_expression(Expression::Compose { ty, components }, meta)? + } + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta)? } - _ => self.add_expression( - Expression::Binary { left, op, right }, - meta, - body, - ), } } ( @@ -945,8 +914,7 @@ impl Context { value: right, }, meta, - body, - ); + )?; let mut components = Vec::with_capacity(columns as usize); @@ -955,8 +923,7 @@ impl Context { let matrix_column = self.add_expression( Expression::AccessIndex { base: left, index }, meta, - body, - ); + )?; // Apply the operation to the splatted vector and // the column vector @@ -967,48 +934,40 @@ impl Context { right: scalar_vector, }, meta, - body, - ); + )?; components.push(column) } - // Rebuild the matrix from the operation result vectors - self.add_expression( - Expression::Compose { - ty: frontend.module.types.insert( - Type { - name: None, - inner: TypeInner::Matrix { - columns, - rows, - width: left_width, - }, - }, - Span::default(), - ), - components, + let ty = self.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + width: left_width, + }, }, - meta, - body, - ) + Span::default(), + ); + + // Rebuild the matrix from the operation result vectors + self.add_expression(Expression::Compose { ty, components }, meta)? + } + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta)? } - _ => self.add_expression( - Expression::Binary { left, op, right }, - meta, - body, - ), } } - _ => self.add_expression(Expression::Binary { left, op, right }, meta, body), + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, } } HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => { let expr = self - .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs, body)? + .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)? .0; - self.add_expression(Expression::Unary { op, expr }, meta, body) + self.add_expression(Expression::Unary { op, expr }, meta)? } HirExprKind::Variable(ref var) => match pos { ExprPos::Lhs => { @@ -1034,7 +993,7 @@ impl Context { LocalVariable { name: None, ty, - init: Some(frontend.module.const_expressions.append( + init: Some(self.module.const_expressions.append( Expression::Constant(constant), Span::default(), )), @@ -1042,11 +1001,7 @@ impl Context { Span::default(), ); - self.add_expression( - Expression::LocalVariable(local), - Span::default(), - body, - ) + self.add_expression(Expression::LocalVariable(local), Span::default())? } else { var.expr } @@ -1055,7 +1010,7 @@ impl Context { } } _ if var.load => { - self.add_expression(Expression::Load { pointer: var.expr }, meta, body) + self.add_expression(Expression::Load { pointer: var.expr }, meta)? } ExprPos::Rhs => var.expr, }, @@ -1063,7 +1018,6 @@ impl Context { let maybe_expr = frontend.function_or_constructor_call( self, stmt, - body, call.kind.clone(), &call.args, meta, @@ -1096,31 +1050,20 @@ impl Context { // Lower the condition first to the current bodyy let condition = self - .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs, body)? + .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)? .0; - // Emit all expressions since we will be adding statements to - // other bodies next - self.emit_restart(body); + let (mut accept_body, (mut accept, accept_meta)) = + self.new_body_with_ret(|ctx| { + // Lower the `true` branch + ctx.lower_expect_inner(stmt, frontend, accept, pos) + })?; - // Create the bodies for the two cases - let mut accept_body = Block::new(); - let mut reject_body = Block::new(); - - // Lower the `true` branch - let (mut accept, accept_meta) = - self.lower_expect_inner(stmt, frontend, accept, pos, &mut accept_body)?; - - // Flush the body of the `true` branch, to start emitting on the - // `false` branch - self.emit_restart(&mut accept_body); - - // Lower the `false` branch - let (mut reject, reject_meta) = - self.lower_expect_inner(stmt, frontend, reject, pos, &mut reject_body)?; - - // Flush the body of the `false` branch - self.emit_restart(&mut reject_body); + let (mut reject_body, (mut reject, reject_meta)) = + self.new_body_with_ret(|ctx| { + // Lower the `false` branch + ctx.lower_expect_inner(stmt, frontend, reject, pos) + })?; // We need to do some custom implicit conversions since the two target expressions // are in different bodies @@ -1129,36 +1072,34 @@ impl Context { Some((reject_power, reject_width, reject_kind)), ) = ( // Get the components of both branches and calculate the type power - self.expr_scalar_components(frontend, accept, accept_meta)? + self.expr_scalar_components(accept, accept_meta)? .and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))), - self.expr_scalar_components(frontend, reject, reject_meta)? + self.expr_scalar_components(reject, reject_meta)? .and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))), ) { match accept_power.cmp(&reject_power) { std::cmp::Ordering::Less => { - self.conversion( - &mut accept, - accept_meta, - reject_kind, - reject_width, - &mut accept_body, - )?; - // The expression belongs to the `true` branch so we need to flush to - // the respective body - self.emit_restart(&mut accept_body); + accept_body = self.with_body(accept_body, |ctx| { + ctx.conversion( + &mut accept, + accept_meta, + reject_kind, + reject_width, + )?; + Ok(()) + })?; } std::cmp::Ordering::Equal => {} std::cmp::Ordering::Greater => { - self.conversion( - &mut reject, - reject_meta, - accept_kind, - accept_width, - &mut reject_body, - )?; - // The expression belongs to the `false` branch so we need to flush to - // the respective body - self.emit_restart(&mut reject_body); + reject_body = self.with_body(reject_body, |ctx| { + ctx.conversion( + &mut reject, + reject_meta, + accept_kind, + accept_width, + )?; + Ok(()) + })?; } } } @@ -1166,7 +1107,7 @@ impl Context { // We need to get the type of the resulting expression to create the local, // this must be done after implicit conversions to ensure both branches have // the same type. - let ty = frontend.resolve_type_handle(self, accept, accept_meta)?; + let ty = self.resolve_type_handle(accept, accept_meta)?; // Add the local that will hold the result of our conditional let local = self.locals.append( @@ -1178,9 +1119,9 @@ impl Context { meta, ); - let local_expr = self.add_expression(Expression::LocalVariable(local), meta, body); + let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?; - // Add to each body the store to the result variable + // Add to each the store to the result variable accept_body.push( Statement::Store { pointer: local_expr, @@ -1198,7 +1139,7 @@ impl Context { // Finally add the `If` to the main body with the `condition` we lowered // earlier and the branches we prepared. - body.push( + self.body.push( Statement::If { condition, accept: accept_body, @@ -1214,38 +1155,36 @@ impl Context { pointer: local_expr, }, meta, - body, - ) + )? } HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => { let (pointer, ptr_meta) = - self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs, body)?; + self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?; let (mut value, value_meta) = - self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs, body)?; + self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?; - let ty = match *frontend.resolve_type(self, pointer, ptr_meta)? { - TypeInner::Pointer { base, .. } => &frontend.module.types[base].inner, + let ty = match *self.resolve_type(pointer, ptr_meta)? { + TypeInner::Pointer { base, .. } => &self.module.types[base].inner, ref ty => ty, }; if let Some((kind, width)) = scalar_components(ty) { - self.implicit_conversion(frontend, &mut value, value_meta, kind, width, body)?; + self.implicit_conversion(&mut value, value_meta, kind, width)?; } - self.lower_store(pointer, value, meta, body); + self.lower_store(pointer, value, meta)?; value } HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => { - let (pointer, _) = - self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs, body)?; + let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?; let left = if let Expression::Swizzle { .. } = self.expressions[pointer] { pointer } else { - self.add_expression(Expression::Load { pointer }, meta, body) + self.add_expression(Expression::Load { pointer }, meta)? }; - let res = match *frontend.resolve_type(self, left, meta)? { + let res = match *self.resolve_type(left, meta)? { TypeInner::Scalar { kind, width } => { let ty = TypeInner::Scalar { kind, width }; Literal::one(kind, width).map(|i| (ty, i, None, None)) @@ -1282,18 +1221,17 @@ impl Context { } }; - let mut right = self.add_expression(Expression::Literal(literal), meta, body); + let mut right = self.add_expression(Expression::Literal(literal), meta)?; // Glsl allows pre/postfixes operations on vectors and matrices, so if the // target is either of them change the right side of the addition to be splatted // to the same size as the target, furthermore if the target is a matrix // use a composed matrix using the splatted value. if let Some(size) = rows { - right = - self.add_expression(Expression::Splat { size, value: right }, meta, body); + right = self.add_expression(Expression::Splat { size, value: right }, meta)?; if let Some(cols) = columns { - let ty = frontend.module.types.insert( + let ty = self.module.types.insert( Type { name: None, inner: ty_inner, @@ -1307,14 +1245,13 @@ impl Context { components: std::iter::repeat(right).take(cols as usize).collect(), }, meta, - body, - ); + )?; } } - let value = self.add_expression(Expression::Binary { op, left, right }, meta, body); + let value = self.add_expression(Expression::Binary { op, left, right }, meta)?; - self.lower_store(pointer, value, meta, body); + self.lower_store(pointer, value, meta)?; if postfix { left @@ -1329,7 +1266,7 @@ impl Context { } if ExprPos::Lhs != pos => { let args = args .iter() - .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs, body)) + .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs)) .collect::>>()?; match name.as_ref() { "length" => { @@ -1341,10 +1278,8 @@ impl Context { meta, }); } - let lowered_array = self - .lower_expect_inner(stmt, frontend, object, pos, body)? - .0; - let array_type = frontend.resolve_type(self, lowered_array, meta)?; + let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0; + let array_type = self.resolve_type(lowered_array, meta)?; match *array_type { TypeInner::Array { @@ -1354,32 +1289,20 @@ impl Context { let mut array_length = self.add_expression( Expression::Literal(Literal::U32(size.get())), meta, - body, - ); + )?; self.forced_conversion( - frontend, &mut array_length, meta, ScalarKind::Sint, 4, - body, )?; array_length } // let the error be handled in type checking if it's not a dynamic array _ => { - let mut array_length = self.add_expression( - Expression::ArrayLength(lowered_array), - meta, - body, - ); - self.conversion( - &mut array_length, - meta, - ScalarKind::Sint, - 4, - body, - )?; + let mut array_length = self + .add_expression(Expression::ArrayLength(lowered_array), meta)?; + self.conversion(&mut array_length, meta, ScalarKind::Sint, 4)?; array_length } } @@ -1418,22 +1341,16 @@ impl Context { pub fn expr_scalar_components( &mut self, - frontend: &Frontend, expr: Handle, meta: Span, ) -> Result> { - let ty = frontend.resolve_type(self, expr, meta)?; + let ty = self.resolve_type(expr, meta)?; Ok(scalar_components(ty)) } - pub fn expr_power( - &mut self, - frontend: &Frontend, - expr: Handle, - meta: Span, - ) -> Result> { + pub fn expr_power(&mut self, expr: Handle, meta: Span) -> Result> { Ok(self - .expr_scalar_components(frontend, expr, meta)? + .expr_scalar_components(expr, meta)? .and_then(|(kind, width)| type_power(kind, width))) } @@ -1443,7 +1360,6 @@ impl Context { meta: Span, kind: ScalarKind, width: crate::Bytes, - body: &mut Block, ) -> Result<()> { *expr = self.add_expression( Expression::As { @@ -1452,27 +1368,23 @@ impl Context { convert: Some(width), }, meta, - body, - ); + )?; Ok(()) } pub fn implicit_conversion( &mut self, - frontend: &Frontend, expr: &mut Handle, meta: Span, kind: ScalarKind, width: crate::Bytes, - body: &mut Block, ) -> Result<()> { - if let (Some(tgt_power), Some(expr_power)) = ( - type_power(kind, width), - self.expr_power(frontend, *expr, meta)?, - ) { + if let (Some(tgt_power), Some(expr_power)) = + (type_power(kind, width), self.expr_power(*expr, meta)?) + { if tgt_power > expr_power { - self.conversion(expr, meta, kind, width, body)?; + self.conversion(expr, meta, kind, width)?; } } @@ -1481,18 +1393,14 @@ impl Context { pub fn forced_conversion( &mut self, - frontend: &Frontend, expr: &mut Handle, meta: Span, kind: ScalarKind, width: crate::Bytes, - body: &mut Block, ) -> Result<()> { - if let Some((expr_scalar_kind, expr_width)) = - self.expr_scalar_components(frontend, *expr, meta)? - { + if let Some((expr_scalar_kind, expr_width)) = self.expr_scalar_components(*expr, meta)? { if expr_scalar_kind != kind || expr_width != width { - self.conversion(expr, meta, kind, width, body)?; + self.conversion(expr, meta, kind, width)?; } } @@ -1501,15 +1409,13 @@ impl Context { pub fn binary_implicit_conversion( &mut self, - frontend: &Frontend, left: &mut Handle, left_meta: Span, right: &mut Handle, right_meta: Span, - body: &mut Block, ) -> Result<()> { - let left_components = self.expr_scalar_components(frontend, *left, left_meta)?; - let right_components = self.expr_scalar_components(frontend, *right, right_meta)?; + let left_components = self.expr_scalar_components(*left, left_meta)?; + let right_components = self.expr_scalar_components(*right, right_meta)?; if let ( Some((left_power, left_width, left_kind)), @@ -1521,11 +1427,11 @@ impl Context { ) { match left_power.cmp(&right_power) { std::cmp::Ordering::Less => { - self.conversion(left, left_meta, right_kind, right_width, body)?; + self.conversion(left, left_meta, right_kind, right_width)?; } std::cmp::Ordering::Equal => {} std::cmp::Ordering::Greater => { - self.conversion(right, right_meta, left_kind, left_width, body)?; + self.conversion(right, right_meta, left_kind, left_width)?; } } } @@ -1535,16 +1441,14 @@ impl Context { pub fn implicit_splat( &mut self, - frontend: &Frontend, expr: &mut Handle, meta: Span, vector_size: Option, - body: &mut Block, ) -> Result<()> { - let expr_type = frontend.resolve_type(self, *expr, meta)?; + let expr_type = self.resolve_type(*expr, meta)?; if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) { - *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta, body) + *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)? } Ok(()) @@ -1555,8 +1459,7 @@ impl Context { size: VectorSize, vector: Handle, meta: Span, - body: &mut Block, - ) -> Handle { + ) -> Result> { self.add_expression( Expression::Swizzle { size, @@ -1564,12 +1467,11 @@ impl Context { pattern: crate::SwizzleComponent::XYZW, }, meta, - body, ) } } -impl Index> for Context { +impl Index> for Context<'_> { type Output = Expression; fn index(&self, index: Handle) -> &Self::Output { diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index f85c765050..0614943874 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -7,9 +7,9 @@ use super::{ Frontend, Result, }; use crate::{ - front::glsl::types::type_power, proc::ensure_block_returns, AddressSpace, Arena, Block, - EntryPoint, Expression, Function, FunctionArgument, FunctionResult, Handle, Literal, - LocalVariable, ScalarKind, Span, Statement, StructMember, Type, TypeInner, + front::glsl::types::type_power, proc::ensure_block_returns, AddressSpace, Block, EntryPoint, + Expression, Function, FunctionArgument, FunctionResult, Handle, Literal, LocalVariable, + ScalarKind, Span, Statement, StructMember, Type, TypeInner, }; use std::iter; @@ -28,27 +28,25 @@ impl Frontend { &mut self, ctx: &mut Context, stmt: &StmtContext, - body: &mut Block, fc: FunctionCallKind, raw_args: &[Handle], meta: Span, ) -> Result>> { let args: Vec<_> = raw_args .iter() - .map(|e| ctx.lower_expect_inner(stmt, self, *e, ExprPos::Rhs, body)) + .map(|e| ctx.lower_expect_inner(stmt, self, *e, ExprPos::Rhs)) .collect::>()?; match fc { FunctionCallKind::TypeConstructor(ty) => { if args.len() == 1 { - self.constructor_single(ctx, body, ty, args[0], meta) - .map(Some) + self.constructor_single(ctx, ty, args[0], meta).map(Some) } else { - self.constructor_many(ctx, body, ty, args, meta).map(Some) + self.constructor_many(ctx, ty, args, meta).map(Some) } } FunctionCallKind::Function(name) => { - self.function_call(ctx, stmt, body, name, args, raw_args, meta) + self.function_call(ctx, stmt, name, args, raw_args, meta) } } } @@ -56,31 +54,29 @@ impl Frontend { fn constructor_single( &mut self, ctx: &mut Context, - body: &mut Block, ty: Handle, (mut value, expr_meta): (Handle, Span), meta: Span, ) -> Result> { - let expr_type = self.resolve_type(ctx, value, expr_meta)?; + let expr_type = ctx.resolve_type(value, expr_meta)?; let vector_size = match *expr_type { TypeInner::Vector { size, .. } => Some(size), _ => None, }; + let expr_is_bool = expr_type.scalar_kind() == Some(ScalarKind::Bool); + // Special case: if casting from a bool, we need to use Select and not As. - match self.module.types[ty].inner.scalar_kind() { - Some(result_scalar_kind) - if expr_type.scalar_kind() == Some(ScalarKind::Bool) - && result_scalar_kind != ScalarKind::Bool => - { + match ctx.module.types[ty].inner.scalar_kind() { + Some(result_scalar_kind) if expr_is_bool && result_scalar_kind != ScalarKind::Bool => { let l0 = Literal::zero(result_scalar_kind, 4).unwrap(); let l1 = Literal::one(result_scalar_kind, 4).unwrap(); - let mut reject = ctx.add_expression(Expression::Literal(l0), expr_meta, body); - let mut accept = ctx.add_expression(Expression::Literal(l1), expr_meta, body); + let mut reject = ctx.add_expression(Expression::Literal(l0), expr_meta)?; + let mut accept = ctx.add_expression(Expression::Literal(l1), expr_meta)?; - ctx.implicit_splat(self, &mut reject, meta, vector_size, body)?; - ctx.implicit_splat(self, &mut accept, meta, vector_size, body)?; + ctx.implicit_splat(&mut reject, meta, vector_size)?; + ctx.implicit_splat(&mut accept, meta, vector_size)?; let h = ctx.add_expression( Expression::Select { @@ -89,24 +85,22 @@ impl Frontend { condition: value, }, expr_meta, - body, - ); + )?; return Ok(h); } _ => {} } - Ok(match self.module.types[ty].inner { + Ok(match ctx.module.types[ty].inner { TypeInner::Vector { size, kind, width } if vector_size.is_none() => { - ctx.forced_conversion(self, &mut value, expr_meta, kind, width, body)?; + ctx.forced_conversion(&mut value, expr_meta, kind, width)?; - if let TypeInner::Scalar { .. } = *self.resolve_type(ctx, value, expr_meta)? { - ctx.add_expression(Expression::Splat { size, value }, meta, body) + if let TypeInner::Scalar { .. } = *ctx.resolve_type(value, expr_meta)? { + ctx.add_expression(Expression::Splat { size, value }, meta)? } else { self.vector_constructor( ctx, - body, ty, size, kind, @@ -119,7 +113,7 @@ impl Frontend { TypeInner::Scalar { kind, width } => { let mut expr = value; if let TypeInner::Vector { .. } | TypeInner::Matrix { .. } = - *self.resolve_type(ctx, value, expr_meta)? + *ctx.resolve_type(value, expr_meta)? { expr = ctx.add_expression( Expression::AccessIndex { @@ -127,19 +121,17 @@ impl Frontend { index: 0, }, meta, - body, - ); + )?; } - if let TypeInner::Matrix { .. } = *self.resolve_type(ctx, value, expr_meta)? { + if let TypeInner::Matrix { .. } = *ctx.resolve_type(value, expr_meta)? { expr = ctx.add_expression( Expression::AccessIndex { base: expr, index: 0, }, meta, - body, - ); + )?; } ctx.add_expression( @@ -149,12 +141,11 @@ impl Frontend { convert: Some(width), }, meta, - body, - ) + )? } TypeInner::Vector { size, kind, width } => { if vector_size.map_or(true, |s| s != size) { - value = ctx.vector_resize(size, value, expr_meta, body); + value = ctx.vector_resize(size, value, expr_meta)?; } ctx.add_expression( @@ -164,29 +155,19 @@ impl Frontend { convert: Some(width), }, meta, - body, - ) + )? } TypeInner::Matrix { columns, rows, width, - } => self.matrix_one_arg( - ctx, - body, - ty, - columns, - rows, - width, - (value, expr_meta), - meta, - )?, + } => self.matrix_one_arg(ctx, ty, columns, rows, width, (value, expr_meta), meta)?, TypeInner::Struct { ref members, .. } => { let scalar_components = members .get(0) - .and_then(|member| scalar_components(&self.module.types[member.ty].inner)); + .and_then(|member| scalar_components(&ctx.module.types[member.ty].inner)); if let Some((kind, width)) = scalar_components { - ctx.implicit_conversion(self, &mut value, expr_meta, kind, width, body)?; + ctx.implicit_conversion(&mut value, expr_meta, kind, width)?; } ctx.add_expression( @@ -195,14 +176,13 @@ impl Frontend { components: vec![value], }, meta, - body, - ) + )? } TypeInner::Array { base, .. } => { - let scalar_components = scalar_components(&self.module.types[base].inner); + let scalar_components = scalar_components(&ctx.module.types[base].inner); if let Some((kind, width)) = scalar_components { - ctx.implicit_conversion(self, &mut value, expr_meta, kind, width, body)?; + ctx.implicit_conversion(&mut value, expr_meta, kind, width)?; } ctx.add_expression( @@ -211,8 +191,7 @@ impl Frontend { components: vec![value], }, meta, - body, - ) + )? } _ => { self.errors.push(Error { @@ -229,7 +208,6 @@ impl Frontend { fn matrix_one_arg( &mut self, ctx: &mut Context, - body: &mut Block, ty: Handle, columns: crate::VectorSize, rows: crate::VectorSize, @@ -242,13 +220,13 @@ impl Frontend { // `Expression::As` doesn't support matrix width // casts so we need to do some extra work for casts - ctx.forced_conversion(self, &mut value, expr_meta, ScalarKind::Float, width, body)?; - match *self.resolve_type(ctx, value, expr_meta)? { + ctx.forced_conversion(&mut value, expr_meta, ScalarKind::Float, width)?; + match *ctx.resolve_type(value, expr_meta)? { TypeInner::Scalar { .. } => { // If a matrix is constructed with a single scalar value, then that // value is used to initialize all the values along the diagonal of // the matrix; the rest are given zeros. - let vector_ty = self.module.types.insert( + let vector_ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Vector { @@ -261,7 +239,7 @@ impl Frontend { ); let zero_literal = Literal::zero(ScalarKind::Float, width).unwrap(); - let zero = ctx.add_expression(Expression::Literal(zero_literal), meta, body); + let zero = ctx.add_expression(Expression::Literal(zero_literal), meta)?; for i in 0..columns as u32 { components.push( @@ -276,8 +254,7 @@ impl Frontend { .collect(), }, meta, - body, - ), + )?, ) } } @@ -294,10 +271,10 @@ impl Frontend { let zero_literal = Literal::zero(ScalarKind::Float, width).unwrap(); let one_literal = Literal::one(ScalarKind::Float, width).unwrap(); - let zero = ctx.add_expression(Expression::Literal(zero_literal), meta, body); - let one = ctx.add_expression(Expression::Literal(one_literal), meta, body); + let zero = ctx.add_expression(Expression::Literal(zero_literal), meta)?; + let one = ctx.add_expression(Expression::Literal(one_literal), meta)?; - let vector_ty = self.module.types.insert( + let vector_ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Vector { @@ -319,8 +296,7 @@ impl Frontend { index: i, }, meta, - body, - ); + )?; components.push(match ori_rows.cmp(&rows) { Ordering::Less => { @@ -333,15 +309,14 @@ impl Frontend { index: r, }, meta, - body, ) } else if r == i { - one + Ok(one) } else { - zero + Ok(zero) } }) - .collect(); + .collect::>()?; ctx.add_expression( Expression::Compose { @@ -349,11 +324,10 @@ impl Frontend { components, }, meta, - body, - ) + )? } Ordering::Equal => vector, - Ordering::Greater => ctx.vector_resize(rows, vector, meta, body), + Ordering::Greater => ctx.vector_resize(rows, vector, meta)?, }) } else { let compose_expr = Expression::Compose { @@ -366,7 +340,7 @@ impl Frontend { .collect(), }; - let vec = ctx.add_expression(compose_expr, meta, body); + let vec = ctx.add_expression(compose_expr, meta)?; components.push(vec) } @@ -377,14 +351,13 @@ impl Frontend { } } - Ok(ctx.add_expression(Expression::Compose { ty, components }, meta, body)) + ctx.add_expression(Expression::Compose { ty, components }, meta) } #[allow(clippy::too_many_arguments)] fn vector_constructor( &mut self, ctx: &mut Context, - body: &mut Block, ty: Handle, size: crate::VectorSize, kind: ScalarKind, @@ -395,13 +368,13 @@ impl Frontend { let mut components = Vec::with_capacity(size as usize); for (mut arg, expr_meta) in args.iter().copied() { - ctx.forced_conversion(self, &mut arg, expr_meta, kind, width, body)?; + ctx.forced_conversion(&mut arg, expr_meta, kind, width)?; if components.len() >= size as usize { break; } - match *self.resolve_type(ctx, arg, expr_meta)? { + match *ctx.resolve_type(arg, expr_meta)? { TypeInner::Scalar { .. } => components.push(arg), TypeInner::Matrix { rows, columns, .. } => { components.reserve(rows as usize * columns as usize); @@ -412,14 +385,12 @@ impl Frontend { index: c, }, expr_meta, - body, - ); + )?; for r in 0..(rows as u32) { components.push(ctx.add_expression( Expression::AccessIndex { base, index: r }, expr_meta, - body, - )) + )?) } } } @@ -429,8 +400,7 @@ impl Frontend { components.push(ctx.add_expression( Expression::AccessIndex { base: arg, index }, expr_meta, - body, - )) + )?) } } _ => components.push(arg), @@ -439,20 +409,19 @@ impl Frontend { components.truncate(size as usize); - Ok(ctx.add_expression(Expression::Compose { ty, components }, meta, body)) + ctx.add_expression(Expression::Compose { ty, components }, meta) } fn constructor_many( &mut self, ctx: &mut Context, - body: &mut Block, ty: Handle, args: Vec<(Handle, Span)>, meta: Span, ) -> Result> { let mut components = Vec::with_capacity(args.len()); - match self.module.types[ty].inner { + let struct_member_data = match ctx.module.types[ty].inner { TypeInner::Matrix { columns, rows, @@ -461,9 +430,9 @@ impl Frontend { let mut flattened = Vec::with_capacity(columns as usize * rows as usize); for (mut arg, meta) in args.iter().copied() { - ctx.forced_conversion(self, &mut arg, meta, ScalarKind::Float, width, body)?; + ctx.forced_conversion(&mut arg, meta, ScalarKind::Float, width)?; - match *self.resolve_type(ctx, arg, meta)? { + match *ctx.resolve_type(arg, meta)? { TypeInner::Vector { size, .. } => { for i in 0..(size as u32) { flattened.push(ctx.add_expression( @@ -472,15 +441,14 @@ impl Frontend { index: i, }, meta, - body, - )) + )?) } } _ => flattened.push(arg), } } - let ty = self.module.types.insert( + let ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Vector { @@ -499,42 +467,51 @@ impl Frontend { components: Vec::from(chunk), }, meta, - body, - )) + )?) } + None } TypeInner::Vector { size, kind, width } => { - return self.vector_constructor(ctx, body, ty, size, kind, width, &args, meta) + return self.vector_constructor(ctx, ty, size, kind, width, &args, meta) } TypeInner::Array { base, .. } => { for (mut arg, meta) in args.iter().copied() { - let scalar_components = scalar_components(&self.module.types[base].inner); + let scalar_components = scalar_components(&ctx.module.types[base].inner); if let Some((kind, width)) = scalar_components { - ctx.implicit_conversion(self, &mut arg, meta, kind, width, body)?; - } - - components.push(arg) - } - } - TypeInner::Struct { ref members, .. } => { - for ((mut arg, meta), member) in args.iter().copied().zip(members.iter()) { - let scalar_components = scalar_components(&self.module.types[member.ty].inner); - if let Some((kind, width)) = scalar_components { - ctx.implicit_conversion(self, &mut arg, meta, kind, width, body)?; + ctx.implicit_conversion(&mut arg, meta, kind, width)?; } components.push(arg) } + None } + TypeInner::Struct { ref members, .. } => Some( + members + .iter() + .map(|member| scalar_components(&ctx.module.types[member.ty].inner)) + .collect::>(), + ), _ => { return Err(Error { kind: ErrorKind::SemanticError("Constructor: Too many arguments".into()), meta, }) } + }; + + if let Some(struct_member_data) = struct_member_data { + for ((mut arg, meta), scalar_components) in + args.iter().copied().zip(struct_member_data.iter().copied()) + { + if let Some((kind, width)) = scalar_components { + ctx.implicit_conversion(&mut arg, meta, kind, width)?; + } + + components.push(arg) + } } - Ok(ctx.add_expression(Expression::Compose { ty, components }, meta, body)) + ctx.add_expression(Expression::Compose { ty, components }, meta) } #[allow(clippy::too_many_arguments)] @@ -542,7 +519,6 @@ impl Frontend { &mut self, ctx: &mut Context, stmt: &StmtContext, - body: &mut Block, name: String, args: Vec<(Handle, Span)>, raw_args: &[Handle], @@ -551,13 +527,13 @@ impl Frontend { // Grow the typifier to be able to index it later without needing // to hold the context mutably for &(expr, span) in args.iter() { - self.typifier_grow(ctx, expr, span)?; + ctx.typifier_grow(expr, span)?; } // Check if the passed arguments require any special variations let mut variations = builtin_required_variations( args.iter() - .map(|&(expr, _)| ctx.typifier.get(expr, &self.module.types)), + .map(|&(expr, _)| ctx.typifier.get(expr, &ctx.module.types)), ); // Initiate the declaration if it wasn't previously initialized and inject builtins @@ -565,7 +541,7 @@ impl Frontend { variations |= BuiltinVariations::STANDARD; Default::default() }); - inject_builtin(declaration, &mut self.module, &name, variations); + inject_builtin(declaration, ctx.module, &name, variations); // Borrow again but without mutability, at this point a declaration is guaranteed let declaration = self.lookup_function.get(&name).unwrap(); @@ -609,18 +585,14 @@ impl Frontend { // If the image is used in the overload as a depth texture convert it // before comparing, otherwise exact matches wouldn't be reported if parameter_info.depth { - sampled_to_depth( - &mut self.module, - ctx, - call_argument.0, - call_argument.1, - &mut self.errors, - ); - self.invalidate_expression(ctx, call_argument.0, call_argument.1)? + sampled_to_depth(ctx, call_argument.0, call_argument.1, &mut self.errors); + ctx.invalidate_expression(call_argument.0, call_argument.1)? } - let overload_param_ty = &self.module.types[*overload_parameter].inner; - let call_arg_ty = self.resolve_type(ctx, call_argument.0, call_argument.1)?; + ctx.typifier_grow(call_argument.0, call_argument.1)?; + + let overload_param_ty = &ctx.module.types[*overload_parameter].inner; + let call_arg_ty = ctx.typifier.get(call_argument.0, &ctx.module.types); log::trace!( "Testing parameter {}\n\tOverload = {:?}\n\tCall = {:?}", @@ -823,12 +795,11 @@ impl Frontend { .zip(¶meters) { let (mut handle, meta) = - ctx.lower_expect_inner(stmt, self, *expr, parameter_info.qualifier.as_pos(), body)?; + ctx.lower_expect_inner(stmt, self, *expr, parameter_info.qualifier.as_pos())?; if parameter_info.qualifier.is_lhs() { self.process_lhs_argument( ctx, - body, meta, *parameter, parameter_info, @@ -841,11 +812,11 @@ impl Frontend { continue; } - let scalar_comps = scalar_components(&self.module.types[*parameter].inner); + let scalar_comps = scalar_components(&ctx.module.types[*parameter].inner); // Apply implicit conversions as needed if let Some((kind, width)) = scalar_comps { - ctx.implicit_conversion(self, &mut handle, meta, kind, width, body)?; + ctx.implicit_conversion(&mut handle, meta, kind, width)?; } arguments.push(handle) @@ -853,15 +824,15 @@ impl Frontend { match kind { FunctionKind::Call(function) => { - ctx.emit_end(body); + ctx.emit_end(); let result = if !is_void { - Some(ctx.add_expression(Expression::CallResult(function), meta, body)) + Some(ctx.add_expression(Expression::CallResult(function), meta)?) } else { None }; - body.push( + ctx.body.push( crate::Statement::Call { function, arguments, @@ -879,16 +850,15 @@ impl Frontend { pointer: proxy_write.value, }, meta, - body, - ); + )?; if let Some((kind, width)) = proxy_write.convert { - ctx.conversion(&mut value, meta, kind, width, body)?; + ctx.conversion(&mut value, meta, kind, width)?; } - ctx.emit_restart(body); + ctx.emit_restart(); - body.push( + ctx.body.push( Statement::Store { pointer: proxy_write.target, value, @@ -899,9 +869,7 @@ impl Frontend { Ok(result) } - FunctionKind::Macro(builtin) => { - builtin.call(self, ctx, body, arguments.as_mut_slice(), meta) - } + FunctionKind::Macro(builtin) => builtin.call(self, ctx, arguments.as_mut_slice(), meta), } } @@ -911,7 +879,6 @@ impl Frontend { fn process_lhs_argument( &mut self, ctx: &mut Context, - body: &mut Block, meta: Span, parameter_ty: Handle, parameter_info: &ParameterInfo, @@ -920,7 +887,7 @@ impl Frontend { proxy_writes: &mut Vec, arguments: &mut Vec>, ) -> Result<()> { - let original_ty = self.resolve_type(ctx, original, meta)?; + let original_ty = ctx.resolve_type(original, meta)?; let original_pointer_space = original_ty.pointer_space(); // The type of a possible spill variable needed for a proxy write @@ -928,7 +895,7 @@ impl Frontend { // If the argument is to be passed as a pointer but the type of the // expression returns a vector it must mean that it was for example // swizzled and it must be spilled into a local before calling - TypeInner::Vector { size, kind, width } => Some(self.module.types.insert( + TypeInner::Vector { size, kind, width } => Some(ctx.module.types.insert( Type { name: None, inner: TypeInner::Vector { size, kind, width }, @@ -951,7 +918,7 @@ impl Frontend { }; Some( - self.module + ctx.module .types .insert(Type { name: None, inner }, Span::default()), ) @@ -962,17 +929,15 @@ impl Frontend { // Since the original expression might be a pointer and we want a value // for the proxy writes, we might need to load the pointer. let value = if original_pointer_space.is_some() { - ctx.add_expression( - Expression::Load { pointer: original }, - Span::default(), - body, - ) + ctx.add_expression(Expression::Load { pointer: original }, Span::default())? } else { original }; - let call_arg_ty = self.resolve_type(ctx, call_argument.0, call_argument.1)?; - let overload_param_ty = &self.module.types[parameter_ty].inner; + ctx.typifier_grow(call_argument.0, call_argument.1)?; + + let overload_param_ty = &ctx.module.types[parameter_ty].inner; + let call_arg_ty = ctx.typifier.get(call_argument.0, &ctx.module.types); let needs_conversion = call_arg_ty != overload_param_ty; let arg_scalar_comps = scalar_components(call_arg_ty); @@ -995,12 +960,12 @@ impl Frontend { Span::default(), ); let spill_expr = - ctx.add_expression(Expression::LocalVariable(spill_var), Span::default(), body); + ctx.add_expression(Expression::LocalVariable(spill_var), Span::default())?; // If the argument is also copied in we must store the value of the // original variable to the spill variable. if let ParameterQualifier::InOut = parameter_info.qualifier { - body.push( + ctx.body.push( Statement::Store { pointer: spill_expr, value, @@ -1037,8 +1002,7 @@ impl Frontend { index: *component as u32, }, Span::default(), - body, - ); + )?; let spill_component = ctx.add_expression( Expression::AccessIndex { @@ -1046,8 +1010,7 @@ impl Frontend { index: i as u32, }, Span::default(), - body, - ); + )?; proxy_writes.push(ProxyWrite { target: original, @@ -1071,32 +1034,28 @@ impl Frontend { pub(crate) fn add_function( &mut self, - ctx: Context, + mut ctx: Context, name: String, result: Option, - mut body: Block, meta: Span, ) { - ensure_block_returns(&mut body); + ensure_block_returns(&mut ctx.body); let void = result.is_none(); - let &mut Frontend { - ref mut lookup_function, - ref mut module, - .. - } = self; - // Check if the passed arguments require any special variations - let mut variations = - builtin_required_variations(ctx.parameters.iter().map(|&arg| &module.types[arg].inner)); + let mut variations = builtin_required_variations( + ctx.parameters + .iter() + .map(|&arg| &ctx.module.types[arg].inner), + ); // Initiate the declaration if it wasn't previously initialized and inject builtins - let declaration = lookup_function.entry(name.clone()).or_insert_with(|| { + let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| { variations |= BuiltinVariations::STANDARD; Default::default() }); - inject_builtin(declaration, module, &name, variations); + inject_builtin(declaration, ctx.module, &name, variations); let Context { expressions, @@ -1104,6 +1063,8 @@ impl Frontend { arguments, parameters, parameters_info, + body, + module, .. } = ctx; @@ -1141,7 +1102,7 @@ impl Frontend { decl.defined = true; decl.parameters_info = parameters_info; match decl.kind { - FunctionKind::Call(handle) => *self.module.functions.get_mut(handle) = function, + FunctionKind::Call(handle) => *module.functions.get_mut(handle) = function, FunctionKind::Macro(_) => { let handle = module.functions.append(function, meta); decl.kind = FunctionKind::Call(handle) @@ -1170,27 +1131,25 @@ impl Frontend { ) { let void = result.is_none(); - let &mut Frontend { - ref mut lookup_function, - ref mut module, - .. - } = self; - // Check if the passed arguments require any special variations - let mut variations = - builtin_required_variations(ctx.parameters.iter().map(|&arg| &module.types[arg].inner)); + let mut variations = builtin_required_variations( + ctx.parameters + .iter() + .map(|&arg| &ctx.module.types[arg].inner), + ); // Initiate the declaration if it wasn't previously initialized and inject builtins - let declaration = lookup_function.entry(name.clone()).or_insert_with(|| { + let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| { variations |= BuiltinVariations::STANDARD; Default::default() }); - inject_builtin(declaration, module, &name, variations); + inject_builtin(declaration, ctx.module, &name, variations); let Context { arguments, parameters, parameters_info, + module, .. } = ctx; @@ -1238,27 +1197,26 @@ impl Frontend { /// recursively /// /// The passed arguments to the callback are: + /// - The ctx /// - The name /// - The pointer expression to the global storage /// - The handle to the type of the entry point argument /// - The binding of the entry point argument - /// - The expression arena fn arg_type_walker( - &self, + ctx: &mut Context, name: Option, binding: crate::Binding, pointer: Handle, ty: Handle, - expressions: &mut Arena, f: &mut impl FnMut( + &mut Context, Option, Handle, Handle, crate::Binding, - &mut Arena, ), - ) { - match self.module.types[ty].inner { + ) -> Result<()> { + match ctx.module.types[ty].inner { // TODO: Better error reporting // right now we just don't walk the array if the size isn't known at // compile time and let validation catch it @@ -1269,11 +1227,11 @@ impl Frontend { } => { let mut location = match binding { crate::Binding::Location { location, .. } => location, - crate::Binding::BuiltIn(_) => return, + crate::Binding::BuiltIn(_) => return Ok(()), }; let interpolation = - self.module.types[base] + ctx.module.types[base] .inner .scalar_kind() .map(|kind| match kind { @@ -1282,13 +1240,13 @@ impl Frontend { }); for index in 0..size.get() { - let member_pointer = expressions.append( + let member_pointer = ctx.add_expression( Expression::AccessIndex { base: pointer, index, }, crate::Span::default(), - ); + )?; let binding = crate::Binding::Location { location, @@ -1298,35 +1256,28 @@ impl Frontend { }; location += 1; - self.arg_type_walker( - name.clone(), - binding, - member_pointer, - base, - expressions, - f, - ) + Self::arg_type_walker(ctx, name.clone(), binding, member_pointer, base, f)? } } TypeInner::Struct { ref members, .. } => { let mut location = match binding { crate::Binding::Location { location, .. } => location, - crate::Binding::BuiltIn(_) => return, + crate::Binding::BuiltIn(_) => return Ok(()), }; - for (i, member) in members.iter().enumerate() { - let member_pointer = expressions.append( + for (i, member) in members.clone().into_iter().enumerate() { + let member_pointer = ctx.add_expression( Expression::AccessIndex { base: pointer, index: i as u32, }, crate::Span::default(), - ); + )?; - let binding = match member.binding.clone() { + let binding = match member.binding { Some(binding) => binding, None => { - let interpolation = self.module.types[member.ty] + let interpolation = ctx.module.types[member.ty] .inner .scalar_kind() .map(|kind| match kind { @@ -1344,51 +1295,51 @@ impl Frontend { } }; - self.arg_type_walker( - member.name.clone(), - binding, - member_pointer, - member.ty, - expressions, - f, - ) + Self::arg_type_walker(ctx, member.name, binding, member_pointer, member.ty, f)? } } - _ => f(name, pointer, ty, binding, expressions), + _ => f(ctx, name, pointer, ty, binding), } + + Ok(()) } pub(crate) fn add_entry_point( &mut self, function: Handle, - global_init_body: Block, - mut expressions: Arena, - ) { + mut ctx: Context, + ) -> Result<()> { let mut arguments = Vec::new(); - let mut body = Block::with_capacity( + + let body = Block::with_capacity( // global init body - global_init_body.len() + - // prologue and epilogue - self.entry_args.len() * 2 - // Call, Emit for composing struct and return - + 3, + ctx.body.len() + + // prologue and epilogue + self.entry_args.len() * 2 + // Call, Emit for composing struct and return + + 3, ); + let global_init_body = std::mem::replace(&mut ctx.body, body); + for arg in self.entry_args.iter() { if arg.storage != StorageQualifier::Input { continue; } - let pointer = - expressions.append(Expression::GlobalVariable(arg.handle), Default::default()); + let pointer = ctx + .expressions + .append(Expression::GlobalVariable(arg.handle), Default::default()); - self.arg_type_walker( + let ty = ctx.module.global_variables[arg.handle].ty; + + Self::arg_type_walker( + &mut ctx, arg.name.clone(), arg.binding.clone(), pointer, - self.module.global_variables[arg.handle].ty, - &mut expressions, - &mut |name, pointer, ty, binding, expressions| { + ty, + &mut |ctx, name, pointer, ty, binding| { let idx = arguments.len() as u32; arguments.push(FunctionArgument { @@ -1397,16 +1348,18 @@ impl Frontend { binding: Some(binding), }); - let value = - expressions.append(Expression::FunctionArgument(idx), Default::default()); - body.push(Statement::Store { pointer, value }, Default::default()); + let value = ctx + .expressions + .append(Expression::FunctionArgument(idx), Default::default()); + ctx.body + .push(Statement::Store { pointer, value }, Default::default()); }, - ) + )? } - body.extend_block(global_init_body); + ctx.body.extend_block(global_init_body); - body.push( + ctx.body.push( Statement::Call { function, arguments: Vec::new(), @@ -1424,16 +1377,19 @@ impl Frontend { continue; } - let pointer = - expressions.append(Expression::GlobalVariable(arg.handle), Default::default()); + let pointer = ctx + .expressions + .append(Expression::GlobalVariable(arg.handle), Default::default()); + + let ty = ctx.module.global_variables[arg.handle].ty; - self.arg_type_walker( + Self::arg_type_walker( + &mut ctx, arg.name.clone(), arg.binding.clone(), pointer, - self.module.global_variables[arg.handle].ty, - &mut expressions, - &mut |name, pointer, ty, binding, expressions| { + ty, + &mut |ctx, name, pointer, ty, binding| { members.push(StructMember { name, ty, @@ -1441,21 +1397,23 @@ impl Frontend { offset: span, }); - span += self.module.types[ty].inner.size(self.module.to_ctx()); + span += ctx.module.types[ty].inner.size(ctx.module.to_ctx()); - let len = expressions.len(); - let load = expressions.append(Expression::Load { pointer }, Default::default()); - body.push( - Statement::Emit(expressions.range_from(len)), + let len = ctx.expressions.len(); + let load = ctx + .expressions + .append(Expression::Load { pointer }, Default::default()); + ctx.body.push( + Statement::Emit(ctx.expressions.range_from(len)), Default::default(), ); components.push(load) }, - ) + )? } let (ty, value) = if !components.is_empty() { - let ty = self.module.types.insert( + let ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Struct { members, span }, @@ -1463,11 +1421,12 @@ impl Frontend { Default::default(), ); - let len = expressions.len(); - let res = - expressions.append(Expression::Compose { ty, components }, Default::default()); - body.push( - Statement::Emit(expressions.range_from(len)), + let len = ctx.expressions.len(); + let res = ctx + .expressions + .append(Expression::Compose { ty, components }, Default::default()); + ctx.body.push( + Statement::Emit(ctx.expressions.range_from(len)), Default::default(), ); @@ -1476,9 +1435,14 @@ impl Frontend { (None, None) }; - body.push(Statement::Return { value }, Default::default()); + ctx.body + .push(Statement::Return { value }, Default::default()); - self.module.entry_points.push(EntryPoint { + let Context { + body, expressions, .. + } = ctx; + + ctx.module.entry_points.push(EntryPoint { name: "main".to_string(), stage: self.meta.stage, early_depth_test: Some(crate::EarlyDepthTest { conservative: None }) @@ -1492,6 +1456,8 @@ impl Frontend { ..Default::default() }, }); + + Ok(()) } } diff --git a/src/front/glsl/mod.rs b/src/front/glsl/mod.rs index a1aa4622d3..f8f554bf2d 100644 --- a/src/front/glsl/mod.rs +++ b/src/front/glsl/mod.rs @@ -175,8 +175,6 @@ pub struct Frontend { layouter: Layouter, errors: Vec, - - module: Module, } impl Frontend { @@ -188,10 +186,6 @@ impl Frontend { self.global_variables.clear(); self.entry_args.clear(); self.layouter.clear(); - - // This is necessary because if the last parsing errored out, the module - // wouldn't have been taken - self.module = Module::default(); } /// Parses a shader either outputting a shader [`Module`](Module) or a list @@ -208,14 +202,18 @@ impl Frontend { let lexer = lex::Lexer::new(source, &options.defines); let mut ctx = ParsingContext::new(lexer); - if let Err(e) = ctx.parse(self) { - self.errors.push(e); - } - - if self.errors.is_empty() { - Ok(std::mem::take(&mut self.module)) - } else { - Err(std::mem::take(&mut self.errors)) + match ctx.parse(self) { + Ok(module) => { + if self.errors.is_empty() { + Ok(module) + } else { + Err(std::mem::take(&mut self.errors)) + } + } + Err(e) => { + self.errors.push(e); + Err(std::mem::take(&mut self.errors)) + } } } diff --git a/src/front/glsl/parser.rs b/src/front/glsl/parser.rs index 7db30ea80f..1a3b5a9086 100644 --- a/src/front/glsl/parser.rs +++ b/src/front/glsl/parser.rs @@ -9,7 +9,7 @@ use super::{ variables::{GlobalOrConstant, VarDeclaration}, Frontend, Result, }; -use crate::{arena::Handle, proc::U32EvalError, Block, Expression, Span, Type}; +use crate::{arena::Handle, proc::U32EvalError, Expression, Module, Span, Type}; use pp_rs::token::{PreprocessorError, Token as PPToken, TokenValue as PPTokenValue}; use std::iter::Peekable; @@ -162,13 +162,14 @@ impl<'source> ParsingContext<'source> { }) } - pub fn parse(&mut self, frontend: &mut Frontend) -> Result<()> { + pub fn parse(&mut self, frontend: &mut Frontend) -> Result { + let mut module = Module::default(); + // Body and expression arena for global initialization - let mut body = Block::new(); - let mut ctx = Context::new(frontend, &mut body); + let mut ctx = Context::new(frontend, &mut module)?; while self.peek(frontend).is_some() { - self.parse_external_declaration(frontend, &mut ctx, &mut body)?; + self.parse_external_declaration(frontend, &mut ctx)?; } // Add an `EntryPoint` to `parser.module` for `main`, if a @@ -177,8 +178,8 @@ impl<'source> ParsingContext<'source> { for decl in declaration.overloads.iter() { if let FunctionKind::Call(handle) = decl.kind { if decl.defined && decl.parameters.is_empty() { - frontend.add_entry_point(handle, body, ctx.expressions); - return Ok(()); + frontend.add_entry_point(handle, ctx)?; + return Ok(module); } } } @@ -190,10 +191,14 @@ impl<'source> ParsingContext<'source> { }) } - fn parse_uint_constant(&mut self, frontend: &mut Frontend) -> Result<(u32, Span)> { - let (const_expr, meta) = self.parse_constant_expression(frontend)?; + fn parse_uint_constant( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<(u32, Span)> { + let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?; - let res = frontend.module.to_ctx().eval_expr_to_u32(const_expr); + let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr); let int = match res { Ok(value) => Ok(value), @@ -213,16 +218,15 @@ impl<'source> ParsingContext<'source> { fn parse_constant_expression( &mut self, frontend: &mut Frontend, + module: &mut Module, ) -> Result<(Handle, Span)> { - let mut block = Block::new(); - - let mut ctx = Context::new(frontend, &mut block); + let mut ctx = Context::new(frontend, module)?; let mut stmt_ctx = ctx.stmt_ctx(); - let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, &mut block, None)?; - let (root, meta) = ctx.lower_expect(stmt_ctx, frontend, expr, ExprPos::Rhs, &mut block)?; + let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?; + let (root, meta) = ctx.lower_expect(stmt_ctx, frontend, expr, ExprPos::Rhs)?; - Ok((frontend.solve_constant(&ctx, root, meta)?, meta)) + Ok((ctx.solve_constant(root, meta)?, meta)) } } @@ -387,16 +391,14 @@ impl Frontend { } } -pub struct DeclarationContext<'ctx, 'qualifiers> { +pub struct DeclarationContext<'ctx, 'qualifiers, 'a> { qualifiers: TypeQualifiers<'qualifiers>, /// Indicates a global declaration external: bool, - - ctx: &'ctx mut Context, - body: &'ctx mut Block, + ctx: &'ctx mut Context<'a>, } -impl<'ctx, 'qualifiers> DeclarationContext<'ctx, 'qualifiers> { +impl<'ctx, 'qualifiers, 'a> DeclarationContext<'ctx, 'qualifiers, 'a> { fn add_var( &mut self, frontend: &mut Frontend, @@ -415,24 +417,14 @@ impl<'ctx, 'qualifiers> DeclarationContext<'ctx, 'qualifiers> { match self.external { true => { - let global = frontend.add_global_var(self.ctx, self.body, decl)?; + let global = frontend.add_global_var(self.ctx, decl)?; let expr = match global { GlobalOrConstant::Global(handle) => Expression::GlobalVariable(handle), GlobalOrConstant::Constant(handle) => Expression::Constant(handle), }; - Ok(self.ctx.add_expression(expr, meta, self.body)) + Ok(self.ctx.add_expression(expr, meta)?) } - false => frontend.add_local_var(self.ctx, self.body, decl), + false => frontend.add_local_var(self.ctx, decl), } } - - /// Emits all the expressions captured by the emitter and starts the emitter again - /// - /// Alias to [`emit_restart`] with the declaration body - /// - /// [`emit_restart`]: Context::emit_restart - #[inline] - fn flush_expressions(&mut self) { - self.ctx.emit_restart(self.body); - } } diff --git a/src/front/glsl/parser/declarations.rs b/src/front/glsl/parser/declarations.rs index 42b2abb5c0..2abd67672b 100644 --- a/src/front/glsl/parser/declarations.rs +++ b/src/front/glsl/parser/declarations.rs @@ -13,8 +13,8 @@ use crate::{ Error, ErrorKind, Frontend, Span, }, proc::Alignment, - AddressSpace, Block, Expression, FunctionResult, Handle, ScalarKind, Statement, StructMember, - Type, TypeInner, + AddressSpace, Expression, FunctionResult, Handle, ScalarKind, Statement, StructMember, Type, + TypeInner, }; use super::{DeclarationContext, ParsingContext, Result}; @@ -73,10 +73,9 @@ impl<'source> ParsingContext<'source> { &mut self, frontend: &mut Frontend, global_ctx: &mut Context, - global_body: &mut Block, ) -> Result<()> { if self - .parse_declaration(frontend, global_ctx, global_body, true)? + .parse_declaration(frontend, global_ctx, true)? .is_none() { let token = self.bump(frontend)?; @@ -103,7 +102,6 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ty: Handle, ctx: &mut Context, - body: &mut Block, ) -> Result<(Handle, Span)> { // initializer: // assignment_expression @@ -118,10 +116,9 @@ impl<'source> ParsingContext<'source> { let mut components = Vec::new(); loop { // The type expected to be parsed inside the initializer list - let new_ty = - element_or_member_type(ty, components.len(), &mut frontend.module.types); + let new_ty = element_or_member_type(ty, components.len(), &mut ctx.module.types); - components.push(self.parse_initializer(frontend, new_ty, ctx, body)?.0); + components.push(self.parse_initializer(frontend, new_ty, ctx)?.0); let token = self.bump(frontend)?; match token.value { @@ -150,18 +147,17 @@ impl<'source> ParsingContext<'source> { } Ok(( - ctx.add_expression(Expression::Compose { ty, components }, meta, body), + ctx.add_expression(Expression::Compose { ty, components }, meta)?, meta, )) } else { let mut stmt = ctx.stmt_ctx(); - let expr = self.parse_assignment(frontend, ctx, &mut stmt, body)?; - let (mut init, init_meta) = - ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?; + let expr = self.parse_assignment(frontend, ctx, &mut stmt)?; + let (mut init, init_meta) = ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; - let scalar_components = scalar_components(&frontend.module.types[ty].inner); + let scalar_components = scalar_components(&ctx.module.types[ty].inner); if let Some((kind, width)) = scalar_components { - ctx.implicit_conversion(frontend, &mut init, init_meta, kind, width, body)?; + ctx.implicit_conversion(&mut init, init_meta, kind, width)?; } Ok((init, init_meta)) @@ -223,19 +219,17 @@ impl<'source> ParsingContext<'source> { // parse an array specifier if it exists // NOTE: unlike other parse methods this one doesn't expect an array specifier and // returns Ok(None) rather than an error if there is not one - self.parse_array_specifier(frontend, &mut meta, &mut ty)?; + self.parse_array_specifier(frontend, ctx.ctx, &mut meta, &mut ty)?; let init = self .bump_if(frontend, TokenValue::Assign) .map::, _>(|_| { - let (mut expr, init_meta) = - self.parse_initializer(frontend, ty, ctx.ctx, ctx.body)?; + let (mut expr, init_meta) = self.parse_initializer(frontend, ty, ctx.ctx)?; - let scalar_components = scalar_components(&frontend.module.types[ty].inner); + let scalar_components = scalar_components(&ctx.ctx.module.types[ty].inner); if let Some((kind, width)) = scalar_components { - ctx.ctx.implicit_conversion( - frontend, &mut expr, init_meta, kind, width, ctx.body, - )?; + ctx.ctx + .implicit_conversion(&mut expr, init_meta, kind, width)?; } meta.subsume(init_meta); @@ -247,7 +241,7 @@ impl<'source> ParsingContext<'source> { let is_const = ctx.qualifiers.storage.0 == StorageQualifier::Const; let maybe_const_expr = if ctx.external { if let Some((root, meta)) = init { - match frontend.solve_constant(ctx.ctx, root, meta) { + match ctx.ctx.solve_constant(root, meta) { Ok(res) => Some(res), // If the declaration is external (global scope) and is constant qualified // then the initializer must be a constant expression @@ -264,8 +258,8 @@ impl<'source> ParsingContext<'source> { let pointer = ctx.add_var(frontend, ty, name, maybe_const_expr, meta)?; if let Some((value, _)) = init.filter(|_| maybe_const_expr.is_none()) { - ctx.flush_expressions(); - ctx.body.push(Statement::Store { pointer, value }, meta); + ctx.ctx.emit_restart(); + ctx.ctx.body.push(Statement::Store { pointer, value }, meta); } let token = self.bump(frontend)?; @@ -292,7 +286,6 @@ impl<'source> ParsingContext<'source> { &mut self, frontend: &mut Frontend, ctx: &mut Context, - body: &mut Block, external: bool, ) -> Result> { //declaration: @@ -308,12 +301,12 @@ impl<'source> ParsingContext<'source> { // type_qualifier IDENTIFIER identifier_list SEMICOLON if self.peek_type_qualifier(frontend) || self.peek_type_name(frontend) { - let mut qualifiers = self.parse_type_qualifiers(frontend)?; + let mut qualifiers = self.parse_type_qualifiers(frontend, ctx)?; if self.peek_type_name(frontend) { // This branch handles variables and function prototypes and if // external is true also function definitions - let (ty, mut meta) = self.parse_type(frontend)?; + let (ty, mut meta) = self.parse_type(frontend, ctx)?; let token = self.bump(frontend)?; let token_fallthrough = match token.value { @@ -323,11 +316,10 @@ impl<'source> ParsingContext<'source> { self.bump(frontend)?; let result = ty.map(|ty| FunctionResult { ty, binding: None }); - let mut body = Block::new(); - let mut context = Context::new(frontend, &mut body); + let mut context = Context::new(frontend, ctx.module)?; - self.parse_function_args(frontend, &mut context, &mut body)?; + self.parse_function_args(frontend, &mut context)?; let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta; meta.subsume(end_meta); @@ -350,11 +342,10 @@ impl<'source> ParsingContext<'source> { token.meta, frontend, &mut context, - &mut body, &mut None, )?; - frontend.add_function(context, name, result, body, meta); + frontend.add_function(context, name, result, meta); Ok(Some(meta)) } @@ -395,7 +386,6 @@ impl<'source> ParsingContext<'source> { qualifiers, external, ctx, - body, }; self.backtrack(token_fallthrough)?; @@ -420,7 +410,6 @@ impl<'source> ParsingContext<'source> { self.parse_block_declaration( frontend, ctx, - body, &mut qualifiers, ty_name, token.meta, @@ -428,7 +417,7 @@ impl<'source> ParsingContext<'source> { .map(Some) } else { if qualifiers.invariant.take().is_some() { - frontend.make_variable_invariant(ctx, body, &ty_name, token.meta); + frontend.make_variable_invariant(ctx, &ty_name, token.meta)?; qualifiers.unused_errors(&mut frontend.errors); self.expect(frontend, TokenValue::Semicolon)?; @@ -501,9 +490,9 @@ impl<'source> ParsingContext<'source> { } }; - let (ty, meta) = self.parse_type_non_void(frontend)?; + let (ty, meta) = self.parse_type_non_void(frontend, ctx)?; - match frontend.module.types[ty].inner { + match ctx.module.types[ty].inner { TypeInner::Scalar { kind: ScalarKind::Float | ScalarKind::Sint, .. @@ -529,7 +518,6 @@ impl<'source> ParsingContext<'source> { &mut self, frontend: &mut Frontend, ctx: &mut Context, - body: &mut Block, qualifiers: &mut TypeQualifiers, ty_name: String, mut meta: Span, @@ -549,10 +537,10 @@ impl<'source> ParsingContext<'source> { }; let mut members = Vec::new(); - let span = self.parse_struct_declaration_list(frontend, &mut members, layout)?; + let span = self.parse_struct_declaration_list(frontend, ctx, &mut members, layout)?; self.expect(frontend, TokenValue::RightBrace)?; - let mut ty = frontend.module.types.insert( + let mut ty = ctx.module.types.insert( Type { name: Some(ty_name), inner: TypeInner::Struct { @@ -567,7 +555,7 @@ impl<'source> ParsingContext<'source> { let name = match token.value { TokenValue::Semicolon => None, TokenValue::Identifier(name) => { - self.parse_array_specifier(frontend, &mut meta, &mut ty)?; + self.parse_array_specifier(frontend, ctx, &mut meta, &mut ty)?; self.expect(frontend, TokenValue::Semicolon)?; @@ -586,7 +574,6 @@ impl<'source> ParsingContext<'source> { let global = frontend.add_global_var( ctx, - body, VarDeclaration { qualifiers, ty, @@ -608,7 +595,7 @@ impl<'source> ParsingContext<'source> { entry_arg: None, mutable: true, }; - ctx.add_global(frontend, &k, lookup, body); + ctx.add_global(&k, lookup)?; frontend.global_variables.push((k, lookup)); } @@ -620,6 +607,7 @@ impl<'source> ParsingContext<'source> { pub fn parse_struct_declaration_list( &mut self, frontend: &mut Frontend, + ctx: &mut Context, members: &mut Vec, layout: StructLayout, ) -> Result { @@ -629,12 +617,12 @@ impl<'source> ParsingContext<'source> { loop { // TODO: type_qualifier - let (base_ty, mut meta) = self.parse_type_non_void(frontend)?; + let (base_ty, mut meta) = self.parse_type_non_void(frontend, ctx)?; loop { let (name, name_meta) = self.expect_ident(frontend)?; let mut ty = base_ty; - self.parse_array_specifier(frontend, &mut meta, &mut ty)?; + self.parse_array_specifier(frontend, ctx, &mut meta, &mut ty)?; meta.subsume(name_meta); @@ -642,7 +630,7 @@ impl<'source> ParsingContext<'source> { ty, meta, layout, - &mut frontend.module.types, + &mut ctx.module.types, &mut frontend.errors, ); diff --git a/src/front/glsl/parser/expressions.rs b/src/front/glsl/parser/expressions.rs index 7e47b2eea7..32e0959c26 100644 --- a/src/front/glsl/parser/expressions.rs +++ b/src/front/glsl/parser/expressions.rs @@ -9,7 +9,7 @@ use crate::{ token::{Token, TokenValue}, Error, Frontend, Result, Span, }, - ArraySize, BinaryOperator, Block, Handle, Literal, Type, TypeInner, UnaryOperator, + ArraySize, BinaryOperator, Handle, Literal, Type, TypeInner, UnaryOperator, }; impl<'source> ParsingContext<'source> { @@ -18,7 +18,6 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, stmt: &mut StmtContext, - body: &mut Block, ) -> Result> { let mut token = self.bump(frontend)?; @@ -47,7 +46,7 @@ impl<'source> ParsingContext<'source> { } TokenValue::BoolConstant(value) => Literal::Bool(value), TokenValue::LeftParen => { - let expr = self.parse_expression(frontend, ctx, stmt, body)?; + let expr = self.parse_expression(frontend, ctx, stmt)?; let meta = self.expect(frontend, TokenValue::RightParen)?.meta; token.meta.subsume(meta); @@ -84,7 +83,6 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, stmt: &mut StmtContext, - body: &mut Block, meta: &mut Span, ) -> Result>> { let mut args = Vec::new(); @@ -92,7 +90,7 @@ impl<'source> ParsingContext<'source> { meta.subsume(token.meta); } else { loop { - args.push(self.parse_assignment(frontend, ctx, stmt, body)?); + args.push(self.parse_assignment(frontend, ctx, stmt)?); let token = self.bump(frontend)?; match token.value { @@ -122,21 +120,20 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, stmt: &mut StmtContext, - body: &mut Block, ) -> Result> { let mut base = if self.peek_type_name(frontend) { - let (mut handle, mut meta) = self.parse_type_non_void(frontend)?; + let (mut handle, mut meta) = self.parse_type_non_void(frontend, ctx)?; self.expect(frontend, TokenValue::LeftParen)?; - let args = self.parse_function_call_args(frontend, ctx, stmt, body, &mut meta)?; + let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?; if let TypeInner::Array { size: ArraySize::Dynamic, stride, base, - } = frontend.module.types[handle].inner + } = ctx.module.types[handle].inner { - let span = frontend.module.types.get_span(handle); + let span = ctx.module.types.get_span(handle); let size = u32::try_from(args.len()) .ok() @@ -148,7 +145,7 @@ impl<'source> ParsingContext<'source> { meta, })?; - handle = frontend.module.types.insert( + handle = ctx.module.types.insert( Type { name: None, inner: TypeInner::Array { @@ -175,7 +172,7 @@ impl<'source> ParsingContext<'source> { let (name, mut meta) = self.expect_ident(frontend)?; let expr = if self.bump_if(frontend, TokenValue::LeftParen).is_some() { - let args = self.parse_function_call_args(frontend, ctx, stmt, body, &mut meta)?; + let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?; let kind = match frontend.lookup_type.get(&name) { Some(ty) => FunctionCallKind::TypeConstructor(*ty), @@ -187,7 +184,7 @@ impl<'source> ParsingContext<'source> { meta, } } else { - let var = match frontend.lookup_variable(ctx, body, &name, meta) { + let var = match frontend.lookup_variable(ctx, &name, meta)? { Some(var) => var, None => { return Err(Error { @@ -205,7 +202,7 @@ impl<'source> ParsingContext<'source> { stmt.hir_exprs.append(expr, Default::default()) } else { - self.parse_primary(frontend, ctx, stmt, body)? + self.parse_primary(frontend, ctx, stmt)? }; while let TokenValue::LeftBracket @@ -217,7 +214,7 @@ impl<'source> ParsingContext<'source> { match value { TokenValue::LeftBracket => { - let index = self.parse_expression(frontend, ctx, stmt, body)?; + let index = self.parse_expression(frontend, ctx, stmt)?; let end_meta = self.expect(frontend, TokenValue::RightBracket)?.meta; meta.subsume(end_meta); @@ -233,8 +230,7 @@ impl<'source> ParsingContext<'source> { let (field, end_meta) = self.expect_ident(frontend)?; if self.bump_if(frontend, TokenValue::LeftParen).is_some() { - let args = - self.parse_function_call_args(frontend, ctx, stmt, body, &mut meta)?; + let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?; base = stmt.hir_exprs.append( HirExpr { @@ -287,13 +283,12 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, stmt: &mut StmtContext, - body: &mut Block, ) -> Result> { Ok(match self.expect_peek(frontend)?.value { TokenValue::Plus | TokenValue::Dash | TokenValue::Bang | TokenValue::Tilde => { let Token { value, mut meta } = self.bump(frontend)?; - let expr = self.parse_unary(frontend, ctx, stmt, body)?; + let expr = self.parse_unary(frontend, ctx, stmt)?; let end_meta = stmt.hir_exprs[expr].meta; let kind = match value { @@ -315,7 +310,7 @@ impl<'source> ParsingContext<'source> { TokenValue::Increment | TokenValue::Decrement => { let Token { value, meta } = self.bump(frontend)?; - let expr = self.parse_unary(frontend, ctx, stmt, body)?; + let expr = self.parse_unary(frontend, ctx, stmt)?; stmt.hir_exprs.append( HirExpr { @@ -332,7 +327,7 @@ impl<'source> ParsingContext<'source> { Default::default(), ) } - _ => self.parse_postfix(frontend, ctx, stmt, body)?, + _ => self.parse_postfix(frontend, ctx, stmt)?, }) } @@ -341,13 +336,12 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, stmt: &mut StmtContext, - body: &mut Block, passthrough: Option>, min_bp: u8, ) -> Result> { let mut left = passthrough .ok_or(ErrorKind::EndOfFile /* Dummy error */) - .or_else(|_| self.parse_unary(frontend, ctx, stmt, body))?; + .or_else(|_| self.parse_unary(frontend, ctx, stmt))?; let mut meta = stmt.hir_exprs[left].meta; while let Some((l_bp, r_bp)) = binding_power(&self.expect_peek(frontend)?.value) { @@ -357,7 +351,7 @@ impl<'source> ParsingContext<'source> { let Token { value, .. } = self.bump(frontend)?; - let right = self.parse_binary(frontend, ctx, stmt, body, None, r_bp)?; + let right = self.parse_binary(frontend, ctx, stmt, None, r_bp)?; let end_meta = stmt.hir_exprs[right].meta; meta.subsume(end_meta); @@ -403,16 +397,15 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, stmt: &mut StmtContext, - body: &mut Block, passthrough: Option>, ) -> Result> { - let mut condition = self.parse_binary(frontend, ctx, stmt, body, passthrough, 0)?; + let mut condition = self.parse_binary(frontend, ctx, stmt, passthrough, 0)?; let mut meta = stmt.hir_exprs[condition].meta; if self.bump_if(frontend, TokenValue::Question).is_some() { - let accept = self.parse_expression(frontend, ctx, stmt, body)?; + let accept = self.parse_expression(frontend, ctx, stmt)?; self.expect(frontend, TokenValue::Colon)?; - let reject = self.parse_assignment(frontend, ctx, stmt, body)?; + let reject = self.parse_assignment(frontend, ctx, stmt)?; let end_meta = stmt.hir_exprs[reject].meta; meta.subsume(end_meta); @@ -437,15 +430,14 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, stmt: &mut StmtContext, - body: &mut Block, ) -> Result> { - let tgt = self.parse_unary(frontend, ctx, stmt, body)?; + let tgt = self.parse_unary(frontend, ctx, stmt)?; let mut meta = stmt.hir_exprs[tgt].meta; Ok(match self.expect_peek(frontend)?.value { TokenValue::Assign => { self.bump(frontend)?; - let value = self.parse_assignment(frontend, ctx, stmt, body)?; + let value = self.parse_assignment(frontend, ctx, stmt)?; let end_meta = stmt.hir_exprs[value].meta; meta.subsume(end_meta); @@ -468,7 +460,7 @@ impl<'source> ParsingContext<'source> { | TokenValue::RightShiftAssign | TokenValue::XorAssign => { let token = self.bump(frontend)?; - let right = self.parse_assignment(frontend, ctx, stmt, body)?; + let right = self.parse_assignment(frontend, ctx, stmt)?; let end_meta = stmt.hir_exprs[right].meta; meta.subsume(end_meta); @@ -504,7 +496,7 @@ impl<'source> ParsingContext<'source> { Default::default(), ) } - _ => self.parse_conditional(frontend, ctx, stmt, body, Some(tgt))?, + _ => self.parse_conditional(frontend, ctx, stmt, Some(tgt))?, }) } @@ -513,13 +505,12 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, stmt: &mut StmtContext, - body: &mut Block, ) -> Result> { - let mut expr = self.parse_assignment(frontend, ctx, stmt, body)?; + let mut expr = self.parse_assignment(frontend, ctx, stmt)?; while let TokenValue::Comma = self.expect_peek(frontend)?.value { self.bump(frontend)?; - expr = self.parse_assignment(frontend, ctx, stmt, body)?; + expr = self.parse_assignment(frontend, ctx, stmt)?; } Ok(expr) diff --git a/src/front/glsl/parser/functions.rs b/src/front/glsl/parser/functions.rs index c05d060ea1..358bc52e76 100644 --- a/src/front/glsl/parser/functions.rs +++ b/src/front/glsl/parser/functions.rs @@ -40,12 +40,11 @@ impl<'source> ParsingContext<'source> { &mut self, frontend: &mut Frontend, ctx: &mut Context, - body: &mut Block, terminator: &mut Option, ) -> Result> { // Type qualifiers always identify a declaration statement if self.peek_type_qualifier(frontend) { - return self.parse_declaration(frontend, ctx, body, false); + return self.parse_declaration(frontend, ctx, false); } // Type names can identify either declaration statements or type constructors @@ -61,7 +60,7 @@ impl<'source> ParsingContext<'source> { self.backtrack(token)?; if declaration { - return self.parse_declaration(frontend, ctx, body, false); + return self.parse_declaration(frontend, ctx, false); } } @@ -79,14 +78,14 @@ impl<'source> ParsingContext<'source> { let meta_rest = match *value { TokenValue::Continue => { let meta = self.bump(frontend)?.meta; - body.push(Statement::Continue, meta); - terminator.get_or_insert(body.len()); + ctx.body.push(Statement::Continue, meta); + terminator.get_or_insert(ctx.body.len()); self.expect(frontend, TokenValue::Semicolon)?.meta } TokenValue::Break => { let meta = self.bump(frontend)?.meta; - body.push(Statement::Break, meta); - terminator.get_or_insert(body.len()); + ctx.body.push(Statement::Break, meta); + terminator.get_or_insert(ctx.body.len()); self.expect(frontend, TokenValue::Semicolon)?.meta } TokenValue::Return => { @@ -96,25 +95,25 @@ impl<'source> ParsingContext<'source> { _ => { // TODO: Implicit conversions let mut stmt = ctx.stmt_ctx(); - let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?; + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; self.expect(frontend, TokenValue::Semicolon)?; let (handle, meta) = - ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?; + ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; (Some(handle), meta) } }; - ctx.emit_restart(body); + ctx.emit_restart(); - body.push(Statement::Return { value }, meta); - terminator.get_or_insert(body.len()); + ctx.body.push(Statement::Return { value }, meta); + terminator.get_or_insert(ctx.body.len()); meta } TokenValue::Discard => { let meta = self.bump(frontend)?.meta; - body.push(Statement::Kill, meta); - terminator.get_or_insert(body.len()); + ctx.body.push(Statement::Kill, meta); + terminator.get_or_insert(ctx.body.len()); self.expect(frontend, TokenValue::Semicolon)?.meta } @@ -124,33 +123,31 @@ impl<'source> ParsingContext<'source> { self.expect(frontend, TokenValue::LeftParen)?; let condition = { let mut stmt = ctx.stmt_ctx(); - let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?; + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; let (handle, more_meta) = - ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?; + ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; meta.subsume(more_meta); handle }; self.expect(frontend, TokenValue::RightParen)?; - ctx.emit_restart(body); - - let mut accept = Block::new(); - if let Some(more_meta) = - self.parse_statement(frontend, ctx, &mut accept, &mut None)? - { - meta.subsume(more_meta) - } - - let mut reject = Block::new(); - if self.bump_if(frontend, TokenValue::Else).is_some() { - if let Some(more_meta) = - self.parse_statement(frontend, ctx, &mut reject, &mut None)? - { + let accept = ctx.new_body(|ctx| { + if let Some(more_meta) = self.parse_statement(frontend, ctx, &mut None)? { meta.subsume(more_meta); } - } + Ok(()) + })?; - body.push( + let reject = ctx.new_body(|ctx| { + if self.bump_if(frontend, TokenValue::Else).is_some() { + if let Some(more_meta) = self.parse_statement(frontend, ctx, &mut None)? { + meta.subsume(more_meta); + } + } + Ok(()) + })?; + + ctx.body.push( Statement::If { condition, accept, @@ -169,17 +166,16 @@ impl<'source> ParsingContext<'source> { let (selector, uint) = { let mut stmt = ctx.stmt_ctx(); - let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?; - let (root, meta) = - ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?; - let uint = frontend.resolve_type(ctx, root, meta)?.scalar_kind() + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + let (root, meta) = ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; + let uint = ctx.resolve_type(root, meta)?.scalar_kind() == Some(crate::ScalarKind::Uint); (root, uint) }; self.expect(frontend, TokenValue::RightParen)?; - ctx.emit_restart(body); + ctx.emit_restart(); let mut cases = Vec::new(); // Track if any default case is present in the switch statement. @@ -192,12 +188,12 @@ impl<'source> ParsingContext<'source> { self.bump(frontend)?; let mut stmt = ctx.stmt_ctx(); - let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?; + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; let (root, meta) = - ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?; - let const_expr = frontend.solve_constant(ctx, root, meta)?; + ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; + let const_expr = ctx.solve_constant(root, meta)?; - match frontend.module.const_expressions[const_expr] { + match ctx.module.const_expressions[const_expr] { Expression::Literal(Literal::I32(value)) => match uint { true => crate::SwitchValue::U32(value as u32), false => crate::SwitchValue::I32(value), @@ -244,35 +240,32 @@ impl<'source> ParsingContext<'source> { self.expect(frontend, TokenValue::Colon)?; - let mut body = Block::new(); + let mut fall_through = true; - let mut case_terminator = None; - loop { - match self.expect_peek(frontend)?.value { - TokenValue::Case | TokenValue::Default | TokenValue::RightBrace => { - break - } - _ => { - self.parse_statement( - frontend, - ctx, - &mut body, - &mut case_terminator, - )?; + let body = ctx.new_body(|ctx| { + let mut case_terminator = None; + loop { + match self.expect_peek(frontend)?.value { + TokenValue::Case | TokenValue::Default | TokenValue::RightBrace => { + break + } + _ => { + self.parse_statement(frontend, ctx, &mut case_terminator)?; + } } } - } - let mut fall_through = true; + if let Some(mut idx) = case_terminator { + if let Statement::Break = ctx.body[idx - 1] { + fall_through = false; + idx -= 1; + } - if let Some(mut idx) = case_terminator { - if let Statement::Break = body[idx - 1] { - fall_through = false; - idx -= 1; + ctx.body.cull(idx..) } - body.cull(idx..) - } + Ok(()) + })?; cases.push(SwitchCase { value, @@ -317,51 +310,48 @@ impl<'source> ParsingContext<'source> { }) } - body.push(Statement::Switch { selector, cases }, meta); + ctx.body.push(Statement::Switch { selector, cases }, meta); meta } TokenValue::While => { let mut meta = self.bump(frontend)?.meta; - let mut loop_body = Block::new(); - - let mut stmt = ctx.stmt_ctx(); - self.expect(frontend, TokenValue::LeftParen)?; - let root = self.parse_expression(frontend, ctx, &mut stmt, &mut loop_body)?; - meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta); + let loop_body = ctx.new_body(|ctx| { + let mut stmt = ctx.stmt_ctx(); + self.expect(frontend, TokenValue::LeftParen)?; + let root = self.parse_expression(frontend, ctx, &mut stmt)?; + meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta); - let (expr, expr_meta) = - ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs, &mut loop_body)?; - let condition = ctx.add_expression( - Expression::Unary { - op: UnaryOperator::Not, - expr, - }, - expr_meta, - &mut loop_body, - ); + let (expr, expr_meta) = ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?; + let condition = ctx.add_expression( + Expression::Unary { + op: UnaryOperator::Not, + expr, + }, + expr_meta, + )?; - ctx.emit_restart(&mut loop_body); + ctx.emit_restart(); - loop_body.push( - Statement::If { - condition, - accept: new_break(), - reject: Block::new(), - }, - crate::Span::default(), - ); + ctx.body.push( + Statement::If { + condition, + accept: new_break(), + reject: Block::new(), + }, + crate::Span::default(), + ); - meta.subsume(expr_meta); + meta.subsume(expr_meta); - if let Some(body_meta) = - self.parse_statement(frontend, ctx, &mut loop_body, &mut None)? - { - meta.subsume(body_meta); - } + if let Some(body_meta) = self.parse_statement(frontend, ctx, &mut None)? { + meta.subsume(body_meta); + } + Ok(()) + })?; - body.push( + ctx.body.push( Statement::Loop { body: loop_body, continuing: Block::new(), @@ -375,47 +365,46 @@ impl<'source> ParsingContext<'source> { TokenValue::Do => { let mut meta = self.bump(frontend)?.meta; - let mut loop_body = Block::new(); - - let mut terminator = None; - self.parse_statement(frontend, ctx, &mut loop_body, &mut terminator)?; + let loop_body = ctx.new_body(|ctx| { + let mut terminator = None; + self.parse_statement(frontend, ctx, &mut terminator)?; - let mut stmt = ctx.stmt_ctx(); + let mut stmt = ctx.stmt_ctx(); - self.expect(frontend, TokenValue::While)?; - self.expect(frontend, TokenValue::LeftParen)?; - let root = self.parse_expression(frontend, ctx, &mut stmt, &mut loop_body)?; - let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta; + self.expect(frontend, TokenValue::While)?; + self.expect(frontend, TokenValue::LeftParen)?; + let root = self.parse_expression(frontend, ctx, &mut stmt)?; + let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta; - meta.subsume(end_meta); + meta.subsume(end_meta); - let (expr, expr_meta) = - ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs, &mut loop_body)?; - let condition = ctx.add_expression( - Expression::Unary { - op: UnaryOperator::Not, - expr, - }, - expr_meta, - &mut loop_body, - ); + let (expr, expr_meta) = ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?; + let condition = ctx.add_expression( + Expression::Unary { + op: UnaryOperator::Not, + expr, + }, + expr_meta, + )?; - ctx.emit_restart(&mut loop_body); + ctx.emit_restart(); - loop_body.push( - Statement::If { - condition, - accept: new_break(), - reject: Block::new(), - }, - crate::Span::default(), - ); + ctx.body.push( + Statement::If { + condition, + accept: new_break(), + reject: Block::new(), + }, + crate::Span::default(), + ); - if let Some(idx) = terminator { - loop_body.cull(idx..) - } + if let Some(idx) = terminator { + ctx.body.cull(idx..) + } + Ok(()) + })?; - body.push( + ctx.body.push( Statement::Loop { body: loop_body, continuing: Block::new(), @@ -434,96 +423,98 @@ impl<'source> ParsingContext<'source> { if self.bump_if(frontend, TokenValue::Semicolon).is_none() { if self.peek_type_name(frontend) || self.peek_type_qualifier(frontend) { - self.parse_declaration(frontend, ctx, body, false)?; + self.parse_declaration(frontend, ctx, false)?; } else { let mut stmt = ctx.stmt_ctx(); - let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?; - ctx.lower(stmt, frontend, expr, ExprPos::Rhs, body)?; + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower(stmt, frontend, expr, ExprPos::Rhs)?; self.expect(frontend, TokenValue::Semicolon)?; } } - let (mut block, mut continuing) = (Block::new(), Block::new()); - - if self.bump_if(frontend, TokenValue::Semicolon).is_none() { - let (expr, expr_meta) = if self.peek_type_name(frontend) - || self.peek_type_qualifier(frontend) - { - let mut qualifiers = self.parse_type_qualifiers(frontend)?; - let (ty, mut meta) = self.parse_type_non_void(frontend)?; - let name = self.expect_ident(frontend)?.0; - - self.expect(frontend, TokenValue::Assign)?; - - let (value, end_meta) = - self.parse_initializer(frontend, ty, ctx, &mut block)?; - meta.subsume(end_meta); - - let decl = VarDeclaration { - qualifiers: &mut qualifiers, - ty, - name: Some(name), - init: None, - meta, - }; - - let pointer = frontend.add_local_var(ctx, &mut block, decl)?; - - ctx.emit_restart(&mut block); + let loop_body = ctx.new_body(|ctx| { + if self.bump_if(frontend, TokenValue::Semicolon).is_none() { + let (expr, expr_meta) = if self.peek_type_name(frontend) + || self.peek_type_qualifier(frontend) + { + let mut qualifiers = self.parse_type_qualifiers(frontend, ctx)?; + let (ty, mut meta) = self.parse_type_non_void(frontend, ctx)?; + let name = self.expect_ident(frontend)?.0; + + self.expect(frontend, TokenValue::Assign)?; + + let (value, end_meta) = self.parse_initializer(frontend, ty, ctx)?; + meta.subsume(end_meta); + + let decl = VarDeclaration { + qualifiers: &mut qualifiers, + ty, + name: Some(name), + init: None, + meta, + }; - block.push(Statement::Store { pointer, value }, meta); + let pointer = frontend.add_local_var(ctx, decl)?; - (value, end_meta) - } else { - let mut stmt = ctx.stmt_ctx(); - let root = self.parse_expression(frontend, ctx, &mut stmt, &mut block)?; - ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs, &mut block)? - }; + ctx.emit_restart(); - let condition = ctx.add_expression( - Expression::Unary { - op: UnaryOperator::Not, - expr, - }, - expr_meta, - &mut block, - ); + ctx.body.push(Statement::Store { pointer, value }, meta); - ctx.emit_restart(&mut block); + (value, end_meta) + } else { + let mut stmt = ctx.stmt_ctx(); + let root = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)? + }; - block.push( - Statement::If { - condition, - accept: new_break(), - reject: Block::new(), - }, - crate::Span::default(), - ); + let condition = ctx.add_expression( + Expression::Unary { + op: UnaryOperator::Not, + expr, + }, + expr_meta, + )?; + + ctx.emit_restart(); + + ctx.body.push( + Statement::If { + condition, + accept: new_break(), + reject: Block::new(), + }, + crate::Span::default(), + ); - self.expect(frontend, TokenValue::Semicolon)?; - } + self.expect(frontend, TokenValue::Semicolon)?; + } + Ok(()) + })?; - match self.expect_peek(frontend)?.value { - TokenValue::RightParen => {} - _ => { - let mut stmt = ctx.stmt_ctx(); - let rest = - self.parse_expression(frontend, ctx, &mut stmt, &mut continuing)?; - ctx.lower(stmt, frontend, rest, ExprPos::Rhs, &mut continuing)?; + let continuing = ctx.new_body(|ctx| { + match self.expect_peek(frontend)?.value { + TokenValue::RightParen => {} + _ => { + let mut stmt = ctx.stmt_ctx(); + let rest = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower(stmt, frontend, rest, ExprPos::Rhs)?; + } } - } + Ok(()) + })?; meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta); - if let Some(stmt_meta) = - self.parse_statement(frontend, ctx, &mut block, &mut None)? - { - meta.subsume(stmt_meta); - } + let loop_body = ctx.with_body(loop_body, |ctx| { + if let Some(stmt_meta) = self.parse_statement(frontend, ctx, &mut None)? { + meta.subsume(stmt_meta); + } + Ok(()) + })?; - body.push( + ctx.body.push( Statement::Loop { - body: block, + body: loop_body, continuing, break_if: None, }, @@ -535,22 +526,20 @@ impl<'source> ParsingContext<'source> { meta } TokenValue::LeftBrace => { - let meta = self.bump(frontend)?.meta; - - let mut block = Block::new(); + let mut meta = self.bump(frontend)?.meta; let mut block_terminator = None; - let meta = self.parse_compound_statement( - meta, - frontend, - ctx, - &mut block, - &mut block_terminator, - )?; - body.push(Statement::Block(block), meta); + let block = ctx.new_body(|ctx| { + let block_meta = + self.parse_compound_statement(meta, frontend, ctx, &mut block_terminator)?; + meta.subsume(block_meta); + Ok(()) + })?; + + ctx.body.push(Statement::Block(block), meta); if block_terminator.is_some() { - terminator.get_or_insert(body.len()); + terminator.get_or_insert(ctx.body.len()); } meta @@ -561,8 +550,8 @@ impl<'source> ParsingContext<'source> { // tokens. Unknown or invalid tokens will be caught there and // turned into an error. let mut stmt = ctx.stmt_ctx(); - let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?; - ctx.lower(stmt, frontend, expr, ExprPos::Rhs, body)?; + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower(stmt, frontend, expr, ExprPos::Rhs)?; self.expect(frontend, TokenValue::Semicolon)?.meta } }; @@ -576,7 +565,6 @@ impl<'source> ParsingContext<'source> { mut meta: Span, frontend: &mut Frontend, ctx: &mut Context, - body: &mut Block, terminator: &mut Option, ) -> Result { ctx.symbol_table.push_scope(); @@ -590,7 +578,7 @@ impl<'source> ParsingContext<'source> { break; } - let stmt = self.parse_statement(frontend, ctx, body, terminator)?; + let stmt = self.parse_statement(frontend, ctx, terminator)?; if let Some(stmt_meta) = stmt { meta.subsume(stmt_meta); @@ -598,7 +586,7 @@ impl<'source> ParsingContext<'source> { } if let Some(idx) = *terminator { - body.cull(idx..) + ctx.body.cull(idx..) } ctx.symbol_table.pop_scope(); @@ -609,8 +597,7 @@ impl<'source> ParsingContext<'source> { pub fn parse_function_args( &mut self, frontend: &mut Frontend, - context: &mut Context, - body: &mut Block, + ctx: &mut Context, ) -> Result<()> { if self.bump_if(frontend, TokenValue::Void).is_some() { return Ok(()); @@ -619,19 +606,19 @@ impl<'source> ParsingContext<'source> { loop { if self.peek_type_name(frontend) || self.peek_parameter_qualifier(frontend) { let qualifier = self.parse_parameter_qualifier(frontend); - let mut ty = self.parse_type_non_void(frontend)?.0; + let mut ty = self.parse_type_non_void(frontend, ctx)?.0; match self.expect_peek(frontend)?.value { TokenValue::Comma => { self.bump(frontend)?; - context.add_function_arg(frontend, body, None, ty, qualifier); + ctx.add_function_arg(None, ty, qualifier)?; continue; } TokenValue::Identifier(_) => { let mut name = self.expect_ident(frontend)?; - self.parse_array_specifier(frontend, &mut name.1, &mut ty)?; + self.parse_array_specifier(frontend, ctx, &mut name.1, &mut ty)?; - context.add_function_arg(frontend, body, Some(name), ty, qualifier); + ctx.add_function_arg(Some(name), ty, qualifier)?; if self.bump_if(frontend, TokenValue::Comma).is_some() { continue; diff --git a/src/front/glsl/parser/types.rs b/src/front/glsl/parser/types.rs index 4761e8bc53..d2d24b5dab 100644 --- a/src/front/glsl/parser/types.rs +++ b/src/front/glsl/parser/types.rs @@ -3,6 +3,7 @@ use std::num::NonZeroU32; use crate::{ front::glsl::{ ast::{QualifierKey, QualifierValue, StorageQualifier, StructLayout, TypeQualifiers}, + context::Context, error::ExpectedToken, parser::ParsingContext, token::{Token, TokenValue}, @@ -17,10 +18,11 @@ impl<'source> ParsingContext<'source> { pub fn parse_array_specifier( &mut self, frontend: &mut Frontend, + ctx: &mut Context, span: &mut Span, ty: &mut Handle, ) -> Result<()> { - while self.parse_array_specifier_single(frontend, span, ty)? {} + while self.parse_array_specifier_single(frontend, ctx, span, ty)? {} Ok(()) } @@ -28,6 +30,7 @@ impl<'source> ParsingContext<'source> { fn parse_array_specifier_single( &mut self, frontend: &mut Frontend, + ctx: &mut Context, span: &mut Span, ty: &mut Handle, ) -> Result { @@ -38,7 +41,7 @@ impl<'source> ParsingContext<'source> { span.subsume(meta); ArraySize::Dynamic } else { - let (value, constant_span) = self.parse_uint_constant(frontend)?; + let (value, constant_span) = self.parse_uint_constant(frontend, ctx)?; let size = NonZeroU32::new(value).ok_or(Error { kind: ErrorKind::SemanticError("Array size must be greater than zero".into()), meta: constant_span, @@ -48,9 +51,9 @@ impl<'source> ParsingContext<'source> { ArraySize::Constant(size) }; - frontend.layouter.update(frontend.module.to_ctx()).unwrap(); + frontend.layouter.update(ctx.module.to_ctx()).unwrap(); let stride = frontend.layouter[*ty].to_stride(); - *ty = frontend.module.types.insert( + *ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Array { @@ -68,11 +71,15 @@ impl<'source> ParsingContext<'source> { } } - pub fn parse_type(&mut self, frontend: &mut Frontend) -> Result<(Option>, Span)> { + pub fn parse_type( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<(Option>, Span)> { let token = self.bump(frontend)?; let mut handle = match token.value { TokenValue::Void => return Ok((None, token.meta)), - TokenValue::TypeName(ty) => frontend.module.types.insert(ty, token.meta), + TokenValue::TypeName(ty) => ctx.module.types.insert(ty, token.meta), TokenValue::Struct => { let mut meta = token.meta; let ty_name = self.expect_ident(frontend)?.0; @@ -80,12 +87,13 @@ impl<'source> ParsingContext<'source> { let mut members = Vec::new(); let span = self.parse_struct_declaration_list( frontend, + ctx, &mut members, StructLayout::Std140, )?; let end_meta = self.expect(frontend, TokenValue::RightBrace)?.meta; meta.subsume(end_meta); - let ty = frontend.module.types.insert( + let ty = ctx.module.types.insert( Type { name: Some(ty_name.clone()), inner: TypeInner::Struct { members, span }, @@ -120,12 +128,16 @@ impl<'source> ParsingContext<'source> { }; let mut span = token.meta; - self.parse_array_specifier(frontend, &mut span, &mut handle)?; + self.parse_array_specifier(frontend, ctx, &mut span, &mut handle)?; Ok((Some(handle), span)) } - pub fn parse_type_non_void(&mut self, frontend: &mut Frontend) -> Result<(Handle, Span)> { - let (maybe_ty, meta) = self.parse_type(frontend)?; + pub fn parse_type_non_void( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<(Handle, Span)> { + let (maybe_ty, meta) = self.parse_type(frontend, ctx)?; let ty = maybe_ty.ok_or_else(|| Error { kind: ErrorKind::SemanticError("Type can't be void".into()), meta, @@ -156,6 +168,7 @@ impl<'source> ParsingContext<'source> { pub fn parse_type_qualifiers<'a>( &mut self, frontend: &mut Frontend, + ctx: &mut Context, ) -> Result> { let mut qualifiers = TypeQualifiers::default(); @@ -164,7 +177,7 @@ impl<'source> ParsingContext<'source> { // Handle layout qualifiers outside the match since this can push multiple values if token.value == TokenValue::Layout { - self.parse_layout_qualifier_id_list(frontend, &mut qualifiers)?; + self.parse_layout_qualifier_id_list(frontend, ctx, &mut qualifiers)?; continue; } @@ -287,11 +300,12 @@ impl<'source> ParsingContext<'source> { pub fn parse_layout_qualifier_id_list( &mut self, frontend: &mut Frontend, + ctx: &mut Context, qualifiers: &mut TypeQualifiers, ) -> Result<()> { self.expect(frontend, TokenValue::LeftParen)?; loop { - self.parse_layout_qualifier_id(frontend, &mut qualifiers.layout_qualifiers)?; + self.parse_layout_qualifier_id(frontend, ctx, &mut qualifiers.layout_qualifiers)?; if self.bump_if(frontend, TokenValue::Comma).is_some() { continue; @@ -308,6 +322,7 @@ impl<'source> ParsingContext<'source> { pub fn parse_layout_qualifier_id( &mut self, frontend: &mut Frontend, + ctx: &mut Context, qualifiers: &mut crate::FastHashMap, ) -> Result<()> { // layout_qualifier_id: @@ -332,13 +347,14 @@ impl<'source> ParsingContext<'source> { } else { let key = QualifierKey::String(name.into()); let value = if self.bump_if(frontend, TokenValue::Assign).is_some() { - let (value, end_meta) = match self.parse_uint_constant(frontend) { - Ok(v) => v, - Err(e) => { - frontend.errors.push(e); - (0, Span::default()) - } - }; + let (value, end_meta) = + match self.parse_uint_constant(frontend, ctx) { + Ok(v) => v, + Err(e) => { + frontend.errors.push(e); + (0, Span::default()) + } + }; token.meta.subsume(end_meta); QualifierValue::Uint(value) diff --git a/src/front/glsl/types.rs b/src/front/glsl/types.rs index 513bc22754..a91a0a9f28 100644 --- a/src/front/glsl/types.rs +++ b/src/front/glsl/types.rs @@ -1,6 +1,4 @@ -use super::{ - constants::ConstantSolver, context::Context, Error, ErrorKind, Frontend, Result, Span, -}; +use super::{constants::ConstantSolver, context::Context, Error, ErrorKind, Result, Span}; use crate::{ proc::ResolveContext, Bytes, Expression, Handle, ImageClass, ImageDimension, ScalarKind, Type, TypeInner, VectorSize, @@ -226,7 +224,7 @@ pub const fn type_power(kind: ScalarKind, width: Bytes) -> Option { }) } -impl Frontend { +impl Context<'_> { /// Resolves the types of the expressions until `expr` (inclusive) /// /// This needs to be done before the [`typifier`] can be queried for @@ -240,16 +238,11 @@ impl Frontend { /// /// [`typifier`]: Context::typifier /// [`resolve_type`]: Self::resolve_type - pub(crate) fn typifier_grow( - &self, - ctx: &mut Context, - expr: Handle, - meta: Span, - ) -> Result<()> { - let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments); + pub(crate) fn typifier_grow(&mut self, expr: Handle, meta: Span) -> Result<()> { + let resolve_ctx = ResolveContext::with_locals(self.module, &self.locals, &self.arguments); - ctx.typifier - .grow(expr, &ctx.expressions, &resolve_ctx) + self.typifier + .grow(expr, &self.expressions, &resolve_ctx) .map_err(|error| Error { kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()), meta, @@ -263,14 +256,13 @@ impl Frontend { /// /// [`typifier`]: Context::typifier /// [`typifier_grow`]: Self::typifier_grow - pub(crate) fn resolve_type<'b>( - &'b self, - ctx: &'b mut Context, + pub(crate) fn resolve_type( + &mut self, expr: Handle, meta: Span, - ) -> Result<&'b TypeInner> { - self.typifier_grow(ctx, expr, meta)?; - Ok(ctx.typifier.get(expr, &self.module.types)) + ) -> Result<&TypeInner> { + self.typifier_grow(expr, meta)?; + Ok(self.typifier.get(expr, &self.module.types)) } /// Gets the type handle for the result of the `expr` expression @@ -290,25 +282,23 @@ impl Frontend { /// [`resolve_type`]: Self::resolve_type pub(crate) fn resolve_type_handle( &mut self, - ctx: &mut Context, expr: Handle, meta: Span, ) -> Result> { - self.typifier_grow(ctx, expr, meta)?; - Ok(ctx.typifier.register_type(expr, &mut self.module.types)) + self.typifier_grow(expr, meta)?; + Ok(self.typifier.register_type(expr, &mut self.module.types)) } /// Invalidates the cached type resolution for `expr` forcing a recomputation - pub(crate) fn invalidate_expression<'b>( - &'b self, - ctx: &'b mut Context, + pub(crate) fn invalidate_expression( + &mut self, expr: Handle, meta: Span, ) -> Result<()> { - let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments); + let resolve_ctx = ResolveContext::with_locals(self.module, &self.locals, &self.arguments); - ctx.typifier - .invalidate(expr, &ctx.expressions, &resolve_ctx) + self.typifier + .invalidate(expr, &self.expressions, &resolve_ctx) .map_err(|error| Error { kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()), meta, @@ -317,13 +307,12 @@ impl Frontend { pub(crate) fn solve_constant( &mut self, - ctx: &Context, root: Handle, meta: Span, ) -> Result> { let mut solver = ConstantSolver { types: &mut self.module.types, - expressions: &ctx.expressions, + expressions: &self.expressions, constants: &mut self.module.constants, const_expressions: &mut self.module.const_expressions, }; diff --git a/src/front/glsl/variables.rs b/src/front/glsl/variables.rs index ad0bd7cbd2..58cc31461e 100644 --- a/src/front/glsl/variables.rs +++ b/src/front/glsl/variables.rs @@ -5,9 +5,9 @@ use super::{ Frontend, Result, Span, }; use crate::{ - AddressSpace, Binding, Block, BuiltIn, Constant, Expression, GlobalVariable, Handle, - Interpolation, LocalVariable, ResourceBinding, ScalarKind, ShaderStage, SwizzleComponent, Type, - TypeInner, VectorSize, + AddressSpace, Binding, BuiltIn, Constant, Expression, GlobalVariable, Handle, Interpolation, + LocalVariable, ResourceBinding, ScalarKind, ShaderStage, SwizzleComponent, Type, TypeInner, + VectorSize, }; pub struct VarDeclaration<'a, 'key> { @@ -40,12 +40,11 @@ impl Frontend { fn add_builtin( &mut self, ctx: &mut Context, - body: &mut Block, name: &str, data: BuiltInData, meta: Span, - ) -> Option { - let ty = self.module.types.insert( + ) -> Result> { + let ty = ctx.module.types.insert( Type { name: None, inner: data.inner, @@ -53,7 +52,7 @@ impl Frontend { meta, ); - let handle = self.module.global_variables.append( + let handle = ctx.module.global_variables.append( GlobalVariable { name: Some(name.into()), space: AddressSpace::Private, @@ -81,7 +80,7 @@ impl Frontend { }, )); - let expr = ctx.add_expression(Expression::GlobalVariable(handle), meta, body); + let expr = ctx.add_expression(Expression::GlobalVariable(handle), meta)?; let var = VariableReference { expr, @@ -93,18 +92,17 @@ impl Frontend { ctx.symbol_table.add_root(name.into(), var.clone()); - Some(var) + Ok(Some(var)) } pub(crate) fn lookup_variable( &mut self, ctx: &mut Context, - body: &mut Block, name: &str, meta: Span, - ) -> Option { + ) -> Result> { if let Some(var) = ctx.symbol_table.lookup(name).cloned() { - return Some(var); + return Ok(Some(var)); } let data = match name { @@ -182,7 +180,7 @@ impl Frontend { storage: StorageQualifier::Output, }, "gl_ClipDistance" | "gl_CullDistance" => { - let base = self.module.types.insert( + let base = ctx.module.types.insert( Type { name: None, inner: TypeInner::Scalar { @@ -217,7 +215,7 @@ impl Frontend { "gl_VertexIndex" => BuiltIn::VertexIndex, "gl_SampleID" => BuiltIn::SampleIndex, "gl_LocalInvocationIndex" => BuiltIn::LocalInvocationIndex, - _ => return None, + _ => return Ok(None), }; BuiltInData { @@ -232,17 +230,16 @@ impl Frontend { } }; - self.add_builtin(ctx, body, name, data, meta) + self.add_builtin(ctx, name, data, meta) } pub(crate) fn make_variable_invariant( &mut self, ctx: &mut Context, - body: &mut Block, name: &str, meta: Span, - ) { - if let Some(var) = self.lookup_variable(ctx, body, name, meta) { + ) -> Result<()> { + if let Some(var) = self.lookup_variable(ctx, name, meta)? { if let Some(index) = var.entry_arg { if let Binding::BuiltIn(BuiltIn::Position { ref mut invariant }) = self.entry_args[index].binding @@ -251,19 +248,19 @@ impl Frontend { } } } + Ok(()) } pub(crate) fn field_selection( &mut self, ctx: &mut Context, pos: ExprPos, - body: &mut Block, expression: Handle, name: &str, meta: Span, ) -> Result> { - let (ty, is_pointer) = match *self.resolve_type(ctx, expression, meta)? { - TypeInner::Pointer { base, .. } => (&self.module.types[base].inner, true), + let (ty, is_pointer) = match *ctx.resolve_type(expression, meta)? { + TypeInner::Pointer { base, .. } => (&ctx.module.types[base].inner, true), ref ty => (ty, false), }; match *ty { @@ -281,12 +278,11 @@ impl Frontend { index: index as u32, }, meta, - body, - ); + )?; Ok(match pos { ExprPos::Rhs if is_pointer => { - ctx.add_expression(Expression::Load { pointer }, meta, body) + ctx.add_expression(Expression::Load { pointer }, meta)? } _ => pointer, }) @@ -358,19 +354,17 @@ impl Frontend { pointer: expression, }, meta, - body, - ); + )?; } _ => {} }; - return Ok(ctx.add_expression( + return ctx.add_expression( Expression::AccessIndex { base: expression, index: pattern[0].index(), }, meta, - body, - )); + ); } 2 => VectorSize::Bi, 3 => VectorSize::Tri, @@ -396,8 +390,7 @@ impl Frontend { pointer: expression, }, meta, - body, - ); + )?; } Ok(ctx.add_expression( @@ -407,8 +400,7 @@ impl Frontend { pattern, }, meta, - body, - )) + )?) } else { Err(Error { kind: ErrorKind::SemanticError( @@ -430,7 +422,6 @@ impl Frontend { pub(crate) fn add_global_var( &mut self, ctx: &mut Context, - body: &mut Block, VarDeclaration { qualifiers, mut ty, @@ -449,7 +440,7 @@ impl Frontend { .uint_layout_qualifier("location", &mut self.errors) .unwrap_or(0); let interpolation = qualifiers.interpolation.take().map(|(i, _)| i).or_else(|| { - let kind = self.module.types[ty].inner.scalar_kind()?; + let kind = ctx.module.types[ty].inner.scalar_kind()?; Some(match kind { ScalarKind::Float => Interpolation::Perspective, _ => Interpolation::Flat, @@ -457,7 +448,7 @@ impl Frontend { }); let sampling = qualifiers.sampling.take().map(|(s, _)| s); - let handle = self.module.global_variables.append( + let handle = ctx.module.global_variables.append( GlobalVariable { name: name.clone(), space: AddressSpace::Private, @@ -501,7 +492,7 @@ impl Frontend { ty, init, }; - let handle = self.module.constants.fetch_or_append(constant, meta); + let handle = ctx.module.constants.fetch_or_append(constant, meta); let lookup = GlobalLookup { kind: GlobalLookupKind::Constant(handle, ty), @@ -518,7 +509,7 @@ impl Frontend { *access = allowed_access; } } - AddressSpace::Uniform => match self.module.types[ty].inner { + AddressSpace::Uniform => match ctx.module.types[ty].inner { TypeInner::Image { class, dim, @@ -547,7 +538,7 @@ impl Frontend { _ => unreachable!(), } - ty = self.module.types.insert( + ty = ctx.module.types.insert( Type { name: None, inner: TypeInner::Image { @@ -593,7 +584,7 @@ impl Frontend { _ => None, }; - let handle = self.module.global_variables.append( + let handle = ctx.module.global_variables.append( GlobalVariable { name: name.clone(), space, @@ -615,7 +606,7 @@ impl Frontend { }; if let Some(name) = name { - ctx.add_global(self, &name, lookup, body); + ctx.add_global(&name, lookup)?; self.global_variables.push((name, lookup)); } @@ -628,7 +619,6 @@ impl Frontend { pub(crate) fn add_local_var( &mut self, ctx: &mut Context, - body: &mut Block, decl: VarDeclaration, ) -> Result> { let storage = decl.qualifiers.storage; @@ -652,7 +642,7 @@ impl Frontend { }, decl.meta, ); - let expr = ctx.add_expression(Expression::LocalVariable(handle), decl.meta, body); + let expr = ctx.add_expression(Expression::LocalVariable(handle), decl.meta)?; if let Some(name) = decl.name { let maybe_var = ctx.add_local_var(name.clone(), expr, mutable); From 6d782e9bc05ceb0c7afcc1a4f9227572e307de4f Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 2 Aug 2023 11:59:09 -0700 Subject: [PATCH 3/5] [glsl-in] Document Frontend::add_entry_point. --- src/front/glsl/functions.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index 0614943874..c0a208394f 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -1304,6 +1304,36 @@ impl Frontend { Ok(()) } + /// Create a Naga [`EntryPoint`] that calls the GLSL `main` function. + /// + /// We compile the GLSL `main` function as an ordinary Naga [`Function`]. + /// This function synthesizes a Naga [`EntryPoint`] to call that. + /// + /// Each GLSL input and output variable (including builtins) becomes a Naga + /// [`GlobalVariable`]s in the [`Private`] address space, which `main` can + /// access in the usual way. + /// + /// The `EntryPoint` we synthesize here has an argument for each GLSL input + /// variable, and returns a struct with a member for each GLSL output + /// variable. The entry point contains code to: + /// + /// - copy its arguments into the Naga globals representing the GLSL input + /// variables, + /// + /// - call the Naga `Function` representing the GLSL `main` function, and then + /// + /// - build its return value from whatever values the GLSL `main` left in + /// the Naga globals representing GLSL `output` variables. + /// + /// Upon entry, [`ctx.body`] should contain code, accumulated by prior calls + /// to [`ParsingContext::parse_external_declaration`][pxd], to initialize + /// private global variables as needed. This code gets spliced into the + /// entry point before the call to `main`. + /// + /// [`GlobalVariable`]: crate::GlobalVariable + /// [`Private`]: crate::AddressSpace::Private + /// [`ctx.body`]: Context::body + /// [pxd]: super::ParsingContext::parse_external_declaration pub(crate) fn add_entry_point( &mut self, function: Handle, From 3b3efd7053560a84d94ead0b31795bed9b43e52b Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Fri, 1 Sep 2023 12:43:43 -0700 Subject: [PATCH 4/5] [glsl-in] Doc fix. --- src/front/glsl/parser/functions.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/front/glsl/parser/functions.rs b/src/front/glsl/parser/functions.rs index 358bc52e76..200afc78fd 100644 --- a/src/front/glsl/parser/functions.rs +++ b/src/front/glsl/parser/functions.rs @@ -195,6 +195,11 @@ impl<'source> ParsingContext<'source> { match ctx.module.const_expressions[const_expr] { Expression::Literal(Literal::I32(value)) => match uint { + // This unchecked cast isn't good, but since + // we only reach this code when the selector + // is unsigned but the case label is signed, + // verification will reject the module + // anyway (which also matches GLSL's rules). true => crate::SwitchValue::U32(value as u32), false => crate::SwitchValue::I32(value), }, From 6429a67c03c1ad589a7b39e9f528a4f547984ed3 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 5 Sep 2023 13:18:40 -0700 Subject: [PATCH 5/5] [glsl-in] Move `arg_type_walker` method from Frontend to Context. --- src/front/glsl/functions.rs | 234 ++++++++++++++++++------------------ 1 file changed, 117 insertions(+), 117 deletions(-) diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index c0a208394f..fe96d2c4eb 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -1191,119 +1191,6 @@ impl Frontend { }); } - /// Helper function for building the input/output interface of the entry point - /// - /// Calls `f` with the data of the entry point argument, flattening composite types - /// recursively - /// - /// The passed arguments to the callback are: - /// - The ctx - /// - The name - /// - The pointer expression to the global storage - /// - The handle to the type of the entry point argument - /// - The binding of the entry point argument - fn arg_type_walker( - ctx: &mut Context, - name: Option, - binding: crate::Binding, - pointer: Handle, - ty: Handle, - f: &mut impl FnMut( - &mut Context, - Option, - Handle, - Handle, - crate::Binding, - ), - ) -> Result<()> { - match ctx.module.types[ty].inner { - // TODO: Better error reporting - // right now we just don't walk the array if the size isn't known at - // compile time and let validation catch it - TypeInner::Array { - base, - size: crate::ArraySize::Constant(size), - .. - } => { - let mut location = match binding { - crate::Binding::Location { location, .. } => location, - crate::Binding::BuiltIn(_) => return Ok(()), - }; - - let interpolation = - ctx.module.types[base] - .inner - .scalar_kind() - .map(|kind| match kind { - ScalarKind::Float => crate::Interpolation::Perspective, - _ => crate::Interpolation::Flat, - }); - - for index in 0..size.get() { - let member_pointer = ctx.add_expression( - Expression::AccessIndex { - base: pointer, - index, - }, - crate::Span::default(), - )?; - - let binding = crate::Binding::Location { - location, - interpolation, - sampling: None, - second_blend_source: false, - }; - location += 1; - - Self::arg_type_walker(ctx, name.clone(), binding, member_pointer, base, f)? - } - } - TypeInner::Struct { ref members, .. } => { - let mut location = match binding { - crate::Binding::Location { location, .. } => location, - crate::Binding::BuiltIn(_) => return Ok(()), - }; - - for (i, member) in members.clone().into_iter().enumerate() { - let member_pointer = ctx.add_expression( - Expression::AccessIndex { - base: pointer, - index: i as u32, - }, - crate::Span::default(), - )?; - - let binding = match member.binding { - Some(binding) => binding, - None => { - let interpolation = ctx.module.types[member.ty] - .inner - .scalar_kind() - .map(|kind| match kind { - ScalarKind::Float => crate::Interpolation::Perspective, - _ => crate::Interpolation::Flat, - }); - let binding = crate::Binding::Location { - location, - interpolation, - sampling: None, - second_blend_source: false, - }; - location += 1; - binding - } - }; - - Self::arg_type_walker(ctx, member.name, binding, member_pointer, member.ty, f)? - } - } - _ => f(ctx, name, pointer, ty, binding), - } - - Ok(()) - } - /// Create a Naga [`EntryPoint`] that calls the GLSL `main` function. /// /// We compile the GLSL `main` function as an ordinary Naga [`Function`]. @@ -1363,8 +1250,7 @@ impl Frontend { let ty = ctx.module.global_variables[arg.handle].ty; - Self::arg_type_walker( - &mut ctx, + ctx.arg_type_walker( arg.name.clone(), arg.binding.clone(), pointer, @@ -1413,8 +1299,7 @@ impl Frontend { let ty = ctx.module.global_variables[arg.handle].ty; - Self::arg_type_walker( - &mut ctx, + ctx.arg_type_walker( arg.name.clone(), arg.binding.clone(), pointer, @@ -1491,6 +1376,121 @@ impl Frontend { } } +impl Context<'_> { + /// Helper function for building the input/output interface of the entry point + /// + /// Calls `f` with the data of the entry point argument, flattening composite types + /// recursively + /// + /// The passed arguments to the callback are: + /// - The ctx + /// - The name + /// - The pointer expression to the global storage + /// - The handle to the type of the entry point argument + /// - The binding of the entry point argument + fn arg_type_walker( + &mut self, + name: Option, + binding: crate::Binding, + pointer: Handle, + ty: Handle, + f: &mut impl FnMut( + &mut Context, + Option, + Handle, + Handle, + crate::Binding, + ), + ) -> Result<()> { + match self.module.types[ty].inner { + // TODO: Better error reporting + // right now we just don't walk the array if the size isn't known at + // compile time and let validation catch it + TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + .. + } => { + let mut location = match binding { + crate::Binding::Location { location, .. } => location, + crate::Binding::BuiltIn(_) => return Ok(()), + }; + + let interpolation = + self.module.types[base] + .inner + .scalar_kind() + .map(|kind| match kind { + ScalarKind::Float => crate::Interpolation::Perspective, + _ => crate::Interpolation::Flat, + }); + + for index in 0..size.get() { + let member_pointer = self.add_expression( + Expression::AccessIndex { + base: pointer, + index, + }, + crate::Span::default(), + )?; + + let binding = crate::Binding::Location { + location, + interpolation, + sampling: None, + second_blend_source: false, + }; + location += 1; + + self.arg_type_walker(name.clone(), binding, member_pointer, base, f)? + } + } + TypeInner::Struct { ref members, .. } => { + let mut location = match binding { + crate::Binding::Location { location, .. } => location, + crate::Binding::BuiltIn(_) => return Ok(()), + }; + + for (i, member) in members.clone().into_iter().enumerate() { + let member_pointer = self.add_expression( + Expression::AccessIndex { + base: pointer, + index: i as u32, + }, + crate::Span::default(), + )?; + + let binding = match member.binding { + Some(binding) => binding, + None => { + let interpolation = self.module.types[member.ty] + .inner + .scalar_kind() + .map(|kind| match kind { + ScalarKind::Float => crate::Interpolation::Perspective, + _ => crate::Interpolation::Flat, + }); + let binding = crate::Binding::Location { + location, + interpolation, + sampling: None, + second_blend_source: false, + }; + location += 1; + binding + } + }; + + self.arg_type_walker(member.name, binding, member_pointer, member.ty, f)? + } + } + _ => f(self, name, pointer, ty, binding), + } + + Ok(()) + } +} + /// Helper enum containing the type of conversion need for a call #[derive(PartialEq, Eq, Clone, Copy, Debug)] enum Conversion {