Skip to content

Commit

Permalink
[naga wgsl-in] Automatic conversions for local var initializers.
Browse files Browse the repository at this point in the history
  • Loading branch information
jimblandy authored and teoxoy committed Dec 6, 2023
1 parent 1676ee0 commit f470103
Show file tree
Hide file tree
Showing 6 changed files with 518 additions and 62 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Passing an owned value `window` to `Surface` will return a `Surface<'static>`. S
- Introduce a new `Scalar` struct type for use in Naga's IR, and update all frontend, middle, and backend code appropriately. By @jimblandy in [#4673](https://github.com/gfx-rs/wgpu/pull/4673).
- Add more metal keywords. By @fornwall in [#4707](https://github.com/gfx-rs/wgpu/pull/4707).

- Add partial support for WGSL abstract types (@jimblandy in [#4743](https://github.com/gfx-rs/wgpu/pull/4743)).
- Add partial support for WGSL abstract types (@jimblandy in [#4743](https://github.com/gfx-rs/wgpu/pull/4743), [#4755](https://github.com/gfx-rs/wgpu/pull/4755)).

Abstract types make numeric literals easier to use, by
automatically converting literals and other constant expressions
Expand All @@ -121,9 +121,10 @@ Passing an owned value `window` to `Surface` will return a `Surface<'static>`. S
Even though the literals are abstract integers, Naga recognizes
that it is safe and necessary to convert them to `f32` values in
order to build the vector. You can also use abstract values as
initializers for global constants, like this:
initializers for global constants and global and local variables,
like this:

const unit_x: vec2<f32> = vec2(1, 0);
var unit_x: vec2<f32> = vec2(1, 0);

The literals `1` and `0` are abstract integers, and the expression
`vec2(1, 0)` is an abstract vector. However, Naga recognizes that
Expand Down
70 changes: 37 additions & 33 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1162,45 +1162,49 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(());
}
ast::LocalDecl::Var(ref v) => {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);

let initializer = match v.init {
Some(init) => Some(
self.expression(init, &mut ctx.as_expression(block, &mut emitter))?,
),
None => None,
};

let explicit_ty =
v.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global()))
.transpose()?;

let ty = match (explicit_ty, initializer) {
(Some(explicit), Some(initializer)) => {
let mut ctx = ctx.as_expression(block, &mut emitter);
let initializer_ty = resolve_inner!(ctx, initializer);
if !ctx.module.types[explicit]
.inner
.equivalent(initializer_ty, &ctx.module.types)
{
let gctx = &ctx.module.to_ctx();
return Err(Error::InitializationTypeMismatch {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);
let mut ectx = ctx.as_expression(block, &mut emitter);

let ty;
let initializer;
match (v.init, explicit_ty) {
(Some(init), Some(explicit_ty)) => {
let init = self.expression_for_abstract(init, &mut ectx)?;
let ty_res = crate::proc::TypeResolution::Handle(explicit_ty);
let init = ectx
.try_automatic_conversions(init, &ty_res, v.name.span)
.map_err(|error| match error {
Error::AutoConversion {
dest_span: _,
dest_type,
source_span: _,
source_type,
} => Error::InitializationTypeMismatch {
name: v.name.span,
expected: explicit.to_wgsl(gctx),
got: initializer_ty.to_wgsl(gctx),
});
}
explicit
expected: dest_type,
got: source_type,
},
other => other,
})?;
ty = explicit_ty;
initializer = Some(init);
}
(Some(explicit), None) => explicit,
(None, Some(initializer)) => ctx
.as_expression(block, &mut emitter)
.register_type(initializer)?,
(None, None) => {
return Err(Error::MissingType(v.name.span));
(Some(init), None) => {
let concretized = self.expression(init, &mut ectx)?;
ty = ectx.register_type(concretized)?;
initializer = Some(concretized);
}
};
(None, Some(explicit_ty)) => {
ty = explicit_ty;
initializer = None;
}
(None, None) => return Err(Error::MissingType(v.name.span)),
}

let (const_initializer, initializer) = {
match initializer {
Expand Down
76 changes: 76 additions & 0 deletions naga/tests/in/abstract-types-var.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,79 @@ var<private> xafpaiai: array<i32, 2> = array(1, 2);
var<private> xafpaiaf: array<f32, 2> = array(1, 2.0);
var<private> xafpafai: array<f32, 2> = array(1.0, 2);
var<private> xafpafaf: array<f32, 2> = array(1.0, 2.0);

fn all_constant_arguments() {
var xvipaiai: vec2<i32> = vec2(42, 43);
var xvupaiai: vec2<u32> = vec2(44, 45);
var xvfpaiai: vec2<f32> = vec2(46, 47);

var xvupuai: vec2<u32> = vec2(42u, 43);
var xvupaiu: vec2<u32> = vec2(42, 43u);

var xvuuai: vec2<u32> = vec2<u32>(42u, 43);
var xvuaiu: vec2<u32> = vec2<u32>(42, 43u);

var xmfpaiaiaiai: mat2x2<f32> = mat2x2(1, 2, 3, 4);
var xmfpafaiaiai: mat2x2<f32> = mat2x2(1.0, 2, 3, 4);
var xmfpaiafaiai: mat2x2<f32> = mat2x2(1, 2.0, 3, 4);
var xmfpaiaiafai: mat2x2<f32> = mat2x2(1, 2, 3.0, 4);
var xmfpaiaiaiaf: mat2x2<f32> = mat2x2(1, 2, 3, 4.0);

var xmfp_faiaiai: mat2x2<f32> = mat2x2(1.0f, 2, 3, 4);
var xmfpai_faiai: mat2x2<f32> = mat2x2(1, 2.0f, 3, 4);
var xmfpaiai_fai: mat2x2<f32> = mat2x2(1, 2, 3.0f, 4);
var xmfpaiaiai_f: mat2x2<f32> = mat2x2(1, 2, 3, 4.0f);

var xvispai: vec2<i32> = vec2(1);
var xvfspaf: vec2<f32> = vec2(1.0);
var xvis_ai: vec2<i32> = vec2<i32>(1);
var xvus_ai: vec2<u32> = vec2<u32>(1);
var xvfs_ai: vec2<f32> = vec2<f32>(1);
var xvfs_af: vec2<f32> = vec2<f32>(1.0);

var xafafaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var xaf_faf: array<f32, 2> = array<f32, 2>(1.0f, 2.0);
var xafaf_f: array<f32, 2> = array<f32, 2>(1.0, 2.0f);
var xafaiai: array<f32, 2> = array<f32, 2>(1, 2);
var xai_iai: array<i32, 2> = array<i32, 2>(1i, 2);
var xaiai_i: array<i32, 2> = array<i32, 2>(1, 2i);

// Ideally these would infer the var type from the initializer,
// but we don't support that yet.
var xaipaiai: array<i32, 2> = array(1, 2);
var xafpaiai: array<f32, 2> = array(1, 2);
var xafpaiaf: array<f32, 2> = array(1, 2.0);
var xafpafai: array<f32, 2> = array(1.0, 2);
var xafpafaf: array<f32, 2> = array(1.0, 2.0);
}

fn mixed_constant_and_runtime_arguments() {
var u: u32;
var i: i32;
var f: f32;

var xvupuai: vec2<u32> = vec2(u, 43);
var xvupaiu: vec2<u32> = vec2(42, u);

var xvuuai: vec2<u32> = vec2<u32>(u, 43);
var xvuaiu: vec2<u32> = vec2<u32>(42, u);

var xmfp_faiaiai: mat2x2<f32> = mat2x2(f, 2, 3, 4);
var xmfpai_faiai: mat2x2<f32> = mat2x2(1, f, 3, 4);
var xmfpaiai_fai: mat2x2<f32> = mat2x2(1, 2, f, 4);
var xmfpaiaiai_f: mat2x2<f32> = mat2x2(1, 2, 3, f);

var xaf_faf: array<f32, 2> = array<f32, 2>(f, 2.0);
var xafaf_f: array<f32, 2> = array<f32, 2>(1.0, f);
var xaf_fai: array<f32, 2> = array<f32, 2>(f, 2);
var xafai_f: array<f32, 2> = array<f32, 2>(1, f);
var xai_iai: array<i32, 2> = array<i32, 2>(i, 2);
var xaiai_i: array<i32, 2> = array<i32, 2>(1, i);

var xafp_faf: array<f32, 2> = array(f, 2.0);
var xafpaf_f: array<f32, 2> = array(1.0, f);
var xafp_fai: array<f32, 2> = array(f, 2);
var xafpai_f: array<f32, 2> = array(1, f);
var xaip_iai: array<i32, 2> = array(i, 2);
var xaipai_i: array<i32, 2> = array(1, i);
}
105 changes: 105 additions & 0 deletions naga/tests/out/msl/abstract-types-var.msl
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,108 @@ struct type_5 {
struct type_7 {
int inner[2];
};

void all_constant_arguments(
) {
metal::int2 xvipaiai = metal::int2(42, 43);
metal::uint2 xvupaiai = metal::uint2(44u, 45u);
metal::float2 xvfpaiai = metal::float2(46.0, 47.0);
metal::uint2 xvupuai = metal::uint2(42u, 43u);
metal::uint2 xvupaiu = metal::uint2(42u, 43u);
metal::uint2 xvuuai = metal::uint2(42u, 43u);
metal::uint2 xvuaiu = metal::uint2(42u, 43u);
metal::float2x2 xmfpaiaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpafaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiafaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiaiafai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiaiaiaf = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfp_faiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpai_faiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiai_fai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiaiai_f = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::int2 xvispai = metal::int2(1);
metal::float2 xvfspaf = metal::float2(1.0);
metal::int2 xvis_ai = metal::int2(1);
metal::uint2 xvus_ai = metal::uint2(1u);
metal::float2 xvfs_ai = metal::float2(1.0);
metal::float2 xvfs_af = metal::float2(1.0);
type_5 xafafaf = type_5 {1.0, 2.0};
type_5 xaf_faf = type_5 {1.0, 2.0};
type_5 xafaf_f = type_5 {1.0, 2.0};
type_5 xafaiai = type_5 {1.0, 2.0};
type_7 xai_iai = type_7 {1, 2};
type_7 xaiai_i = type_7 {1, 2};
type_7 xaipaiai = type_7 {1, 2};
type_5 xafpaiai = type_5 {1.0, 2.0};
type_5 xafpaiaf = type_5 {1.0, 2.0};
type_5 xafpafai = type_5 {1.0, 2.0};
type_5 xafpafaf = type_5 {1.0, 2.0};
}

void mixed_constant_and_runtime_arguments(
) {
uint u = {};
int i = {};
float f = {};
metal::uint2 xvupuai_1 = {};
metal::uint2 xvupaiu_1 = {};
metal::uint2 xvuuai_1 = {};
metal::uint2 xvuaiu_1 = {};
metal::float2x2 xmfp_faiaiai_1 = {};
metal::float2x2 xmfpai_faiai_1 = {};
metal::float2x2 xmfpaiai_fai_1 = {};
metal::float2x2 xmfpaiaiai_f_1 = {};
type_5 xaf_faf_1 = {};
type_5 xafaf_f_1 = {};
type_5 xaf_fai = {};
type_5 xafai_f = {};
type_7 xai_iai_1 = {};
type_7 xaiai_i_1 = {};
type_5 xafp_faf = {};
type_5 xafpaf_f = {};
type_5 xafp_fai = {};
type_5 xafpai_f = {};
type_7 xaip_iai = {};
type_7 xaipai_i = {};
uint _e3 = u;
xvupuai_1 = metal::uint2(_e3, 43u);
uint _e7 = u;
xvupaiu_1 = metal::uint2(42u, _e7);
uint _e11 = u;
xvuuai_1 = metal::uint2(_e11, 43u);
uint _e15 = u;
xvuaiu_1 = metal::uint2(42u, _e15);
float _e19 = f;
xmfp_faiaiai_1 = metal::float2x2(metal::float2(_e19, 2.0), metal::float2(3.0, 4.0));
float _e27 = f;
xmfpai_faiai_1 = metal::float2x2(metal::float2(1.0, _e27), metal::float2(3.0, 4.0));
float _e35 = f;
xmfpaiai_fai_1 = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(_e35, 4.0));
float _e43 = f;
xmfpaiaiai_f_1 = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, _e43));
float _e51 = f;
xaf_faf_1 = type_5 {_e51, 2.0};
float _e55 = f;
xafaf_f_1 = type_5 {1.0, _e55};
float _e59 = f;
xaf_fai = type_5 {_e59, 2.0};
float _e63 = f;
xafai_f = type_5 {1.0, _e63};
int _e67 = i;
xai_iai_1 = type_7 {_e67, 2};
int _e71 = i;
xaiai_i_1 = type_7 {1, _e71};
float _e75 = f;
xafp_faf = type_5 {_e75, 2.0};
float _e79 = f;
xafpaf_f = type_5 {1.0, _e79};
float _e83 = f;
xafp_fai = type_5 {_e83, 2.0};
float _e87 = f;
xafpai_f = type_5 {1.0, _e87};
int _e91 = i;
xaip_iai = type_7 {_e91, 2};
int _e95 = i;
xaipai_i = type_7 {1, _e95};
return;
}
Loading

0 comments on commit f470103

Please sign in to comment.