diff --git a/src/front/wgsl/lower/construction.rs b/src/front/wgsl/lower/construction.rs index 80543d434b..912713f911 100644 --- a/src/front/wgsl/lower/construction.rs +++ b/src/front/wgsl/lower/construction.rs @@ -6,45 +6,57 @@ use crate::{Handle, Span}; use crate::front::wgsl::error::Error; use crate::front::wgsl::lower::{ExpressionContext, Lowerer}; -enum ConcreteConstructorHandle { - PartialVector { - size: crate::VectorSize, - }, +/// A cooked form of `ast::ConstructorType` that uses Naga types whenever +/// possible. +enum Constructor { + /// A vector construction whose component type is inferred from the + /// argument: `vec3(1.0)`. + PartialVector { size: crate::VectorSize }, + + /// A matrix construction whose component type is inferred from the + /// argument: `mat2x2(1,2,3,4)`. PartialMatrix { columns: crate::VectorSize, rows: crate::VectorSize, }, + + /// An array whose component type and size are inferred from the arguments: + /// `array(3,4,5)`. PartialArray, - Type(Handle), + + /// A known Naga type. + /// + /// When we match on this type, we need to see the `TypeInner` here, but at + /// the point that we build this value we'll still need mutable access to + /// the module later. To avoid borrowing from the module, the type parameter + /// `T` is `Handle` initially. Then we use `borrow_inner` to produce a + /// version holding a tuple `(Handle, &TypeInner)`. + Type(T), } -impl ConcreteConstructorHandle { - fn borrow<'a>(&self, module: &'a crate::Module) -> ConcreteConstructor<'a> { - match *self { - Self::PartialVector { size } => ConcreteConstructor::PartialVector { size }, - Self::PartialMatrix { columns, rows } => { - ConcreteConstructor::PartialMatrix { columns, rows } +impl Constructor> { + /// Return an equivalent `Constructor` value that includes borrowed + /// `TypeInner` values alongside any type handles. + /// + /// The returned form is more convenient to match on, since the patterns + /// can actually see what the handle refers to. + fn borrow_inner( + self, + module: &crate::Module, + ) -> Constructor<(Handle, &crate::TypeInner)> { + match self { + Constructor::PartialVector { size } => Constructor::PartialVector { size }, + Constructor::PartialMatrix { columns, rows } => { + Constructor::PartialMatrix { columns, rows } } - Self::PartialArray => ConcreteConstructor::PartialArray, - Self::Type(handle) => ConcreteConstructor::Type(handle, &module.types[handle].inner), + Constructor::PartialArray => Constructor::PartialArray, + Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)), } } } -enum ConcreteConstructor<'a> { - PartialVector { - size: crate::VectorSize, - }, - PartialMatrix { - columns: crate::VectorSize, - rows: crate::VectorSize, - }, - PartialArray, - Type(Handle, &'a crate::TypeInner), -} - -impl ConcreteConstructorHandle { - fn to_error_string(&self, ctx: &mut ExpressionContext) -> String { +impl Constructor<(Handle, &crate::TypeInner)> { + fn to_error_string(&self, ctx: &ExpressionContext) -> String { match *self { Self::PartialVector { size } => { format!("vec{}", size as u32,) @@ -53,7 +65,7 @@ impl ConcreteConstructorHandle { format!("mat{}x{}", columns as u32, rows as u32,) } Self::PartialArray => "array".to_string(), - Self::Type(ty) => ctx.format_type(ty), + Self::Type((handle, _inner)) => ctx.format_type(handle), } } } @@ -146,15 +158,24 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } }; - let constructor = constructor_h.borrow(ctx.module); + // Even though we computed `constructor` above, wait until now to borrow + // a reference to the `TypeInner`, so that the component-handling code + // above can have mutable access to the type arena. + let constructor = constructor_h.borrow_inner(ctx.module); let expr = match (components, constructor) { // Empty constructor (Components::None, dst_ty) => match dst_ty { - ConcreteConstructor::Type(ty, _) => { - return ctx.append_expression(crate::Expression::ZeroValue(ty), span) + Constructor::Type((result_ty, _)) => { + return ctx.append_expression(crate::Expression::ZeroValue(result_ty), span) + } + Constructor::PartialVector { .. } + | Constructor::PartialMatrix { .. } + | Constructor::PartialArray => { + // We have no arguments from which to infer the result type, so + // partial constructors aren't acceptable here. + return Err(Error::TypeNotInferrable(ty_span)); } - _ => return Err(Error::TypeNotInferrable(ty_span)), }, // Scalar constructor & conversion (scalar -> scalar) @@ -164,7 +185,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty_inner: &crate::TypeInner::Scalar { .. }, .. }, - ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { kind, width }), + Constructor::Type((_, &crate::TypeInner::Scalar { kind, width })), ) => crate::Expression::As { expr: component, kind, @@ -178,14 +199,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, .. }, - ConcreteConstructor::Type( + Constructor::Type(( _, &crate::TypeInner::Vector { size: dst_size, kind: dst_kind, width: dst_width, }, - ), + )), ) if dst_size == src_size => crate::Expression::As { expr: component, kind: dst_kind, @@ -199,7 +220,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, .. }, - ConcreteConstructor::PartialVector { size: dst_size }, + Constructor::PartialVector { size: dst_size }, ) if dst_size == src_size => { // This is a trivial conversion: the sizes match, and a Partial // constructor doesn't specify a scalar type, so nothing can @@ -219,14 +240,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, .. }, - ConcreteConstructor::Type( + Constructor::Type(( _, &crate::TypeInner::Matrix { columns: dst_columns, rows: dst_rows, width: dst_width, }, - ), + )), ) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As { expr: component, kind: crate::ScalarKind::Float, @@ -245,7 +266,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, .. }, - ConcreteConstructor::PartialMatrix { + Constructor::PartialMatrix { columns: dst_columns, rows: dst_rows, }, @@ -263,7 +284,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty_inner: &crate::TypeInner::Scalar { .. }, .. }, - ConcreteConstructor::PartialVector { size }, + Constructor::PartialVector { size }, ) => crate::Expression::Splat { size, value: component, @@ -281,14 +302,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, .. }, - ConcreteConstructor::Type( + Constructor::Type(( _, &crate::TypeInner::Vector { size, kind: dst_kind, width: dst_width, }, - ), + )), ) if dst_kind == src_kind || dst_width == src_width => crate::Expression::Splat { size, value: component, @@ -303,7 +324,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { | &crate::TypeInner::Vector { kind, width, .. }, .. }, - ConcreteConstructor::PartialVector { size }, + Constructor::PartialVector { size }, ) | ( Components::Many { @@ -312,7 +333,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. }, .. }, - ConcreteConstructor::Type(_, &crate::TypeInner::Vector { size, width, kind }), + Constructor::Type((_, &crate::TypeInner::Vector { size, width, kind })), ) => { let inner = crate::TypeInner::Vector { size, kind, width }; let ty = ctx.ensure_type_exists(inner); @@ -326,7 +347,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { first_component_ty_inner: &crate::TypeInner::Scalar { width, .. }, .. }, - ConcreteConstructor::PartialMatrix { columns, rows }, + Constructor::PartialMatrix { columns, rows }, ) | ( Components::Many { @@ -334,14 +355,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { first_component_ty_inner: &crate::TypeInner::Scalar { .. }, .. }, - ConcreteConstructor::Type( + Constructor::Type(( _, &crate::TypeInner::Matrix { columns, rows, width, }, - ), + )), ) => { let vec_ty = ctx.ensure_type_exists(crate::TypeInner::Vector { width, @@ -377,7 +398,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { first_component_ty_inner: &crate::TypeInner::Vector { width, .. }, .. }, - ConcreteConstructor::PartialMatrix { columns, rows }, + Constructor::PartialMatrix { columns, rows }, ) | ( Components::Many { @@ -385,14 +406,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { first_component_ty_inner: &crate::TypeInner::Vector { .. }, .. }, - ConcreteConstructor::Type( + Constructor::Type(( _, &crate::TypeInner::Matrix { columns, rows, width, }, - ), + )), ) => { let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { columns, @@ -403,7 +424,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } // Array constructor - infer type - (components, ConcreteConstructor::PartialArray) => { + (components, Constructor::PartialArray) => { let components = components.into_components_vec(); let base = ctx.register_type(components[0])?; @@ -426,10 +447,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Array or Struct constructor ( components, - ConcreteConstructor::Type( + Constructor::Type(( ty, &crate::TypeInner::Array { .. } | &crate::TypeInner::Struct { .. }, - ), + )), ) => { let components = components.into_components_vec(); crate::Expression::Compose { ty, components } @@ -438,19 +459,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // ERRORS // Bad conversion (type cast) - (Components::One { span, ty_inner, .. }, _) => { + (Components::One { span, ty_inner, .. }, constructor) => { let from_type = ctx.format_typeinner(ty_inner); return Err(Error::BadTypeCast { span, from_type, - to_type: constructor_h.to_error_string(ctx), + to_type: constructor.to_error_string(ctx), }); } // Too many parameters for scalar constructor ( Components::Many { spans, .. }, - ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { .. }), + Constructor::Type((_, &crate::TypeInner::Scalar { .. })), ) => { let span = spans[1].until(spans.last().unwrap()); return Err(Error::UnexpectedComponents(span)); @@ -459,12 +480,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Parameters are of the wrong type for vector or matrix constructor ( Components::Many { spans, .. }, - ConcreteConstructor::Type( + Constructor::Type(( _, &crate::TypeInner::Vector { .. } | &crate::TypeInner::Matrix { .. }, - ) - | ConcreteConstructor::PartialVector { .. } - | ConcreteConstructor::PartialMatrix { .. }, + )) + | Constructor::PartialVector { .. } + | Constructor::PartialMatrix { .. }, ) => { return Err(Error::InvalidConstructorComponentType(spans[0], 0)); } @@ -477,17 +498,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(expr) } - /// Build a Naga IR [`Type`] for `constructor` if there is enough - /// information to do so. + /// Build a [`Constructor`] for a WGSL construction expression. /// - /// For `Partial` variants of [`ast::ConstructorType`], we don't know the - /// component type, so in that case we return the appropriate `Partial` - /// variant of [`ConcreteConstructorHandle`]. + /// If `constructor` conveys enough information to determine which Naga [`Type`] + /// we're actually building (i.e., it's not a partial constructor), then + /// ensure the `Type` exists in [`ctx.module`], and return + /// [`Constructor::Type`]. /// - /// But for the other `ConstructorType` variants, we have everything we need - /// to know to actually produce a Naga IR type. In this case we add to/find - /// in [`ctx.module`] a suitable Naga `Type` and return a - /// [`ConcreteConstructorHandle::Type`] value holding its handle. + /// Otherwise, return the [`Constructor`] partial variant corresponding to + /// `constructor`. /// /// [`Type`]: crate::Type /// [`ctx.module`]: ExpressionContext::module @@ -495,21 +514,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &mut self, constructor: &ast::ConstructorType<'source>, ctx: &mut ExpressionContext<'source, '_, 'out>, - ) -> Result> { - let c = match *constructor { + ) -> Result>, Error<'source>> { + let handle = match *constructor { ast::ConstructorType::Scalar { width, kind } => { let ty = ctx.ensure_type_exists(crate::TypeInner::Scalar { width, kind }); - ConcreteConstructorHandle::Type(ty) - } - ast::ConstructorType::PartialVector { size } => { - ConcreteConstructorHandle::PartialVector { size } + Constructor::Type(ty) } + ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size }, ast::ConstructorType::Vector { size, kind, width } => { let ty = ctx.ensure_type_exists(crate::TypeInner::Vector { size, kind, width }); - ConcreteConstructorHandle::Type(ty) + Constructor::Type(ty) } - ast::ConstructorType::PartialMatrix { rows, columns } => { - ConcreteConstructorHandle::PartialMatrix { rows, columns } + ast::ConstructorType::PartialMatrix { columns, rows } => { + Constructor::PartialMatrix { columns, rows } } ast::ConstructorType::Matrix { rows, @@ -521,9 +538,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rows, width, }); - ConcreteConstructorHandle::Type(ty) + Constructor::Type(ty) } - ast::ConstructorType::PartialArray => ConcreteConstructorHandle::PartialArray, + ast::ConstructorType::PartialArray => Constructor::PartialArray, ast::ConstructorType::Array { base, size } => { let base = self.resolve_ast_type(base, &mut ctx.as_global())?; let size = self.array_size(size, &mut ctx.as_global())?; @@ -532,11 +549,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let stride = self.layouter[base].to_stride(); let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride }); - ConcreteConstructorHandle::Type(ty) + Constructor::Type(ty) } - ast::ConstructorType::Type(ty) => ConcreteConstructorHandle::Type(ty), + ast::ConstructorType::Type(ty) => Constructor::Type(ty), }; - Ok(c) + Ok(handle) } }