Skip to content

Commit

Permalink
[shaders] Expose a binding index set per target language
Browse files Browse the repository at this point in the history
vello_shaders exposes a list of resource binding types for each
compute stage. This list contains only the bindings that are
actually reachable (obeying WebGPU WGSL validation rules). The
list also contains only the types and no index information is
provided; instead the elements are provided in their order of
declaration in the shader source and the client is expected to
infer the index from the declaration order.

This leads to some issues that are unique to each target language:

1. When targeting WGSL, this scheme is fragile if the shader source
   doesn't keep the index declarations contiguous. Any gaps in the
   indices will lead to incorrect results. This can happen due to a
   programmer error. This can also happen due to bindings that are
   unreachable by the entry point function, for example as a result
   of some refactoring.

2. When targeting MSL, the client needs to know the vello_shader's
   crate's internal binding re-assignment scheme and implement the
   same logic on their end to compute the indices. This stems from
   the fact that WGSL bindings *where indices are scoped to a
   bind group) and MSL bindings (where indices are scoped separately
   by resource type) get assigned from different ranges and
   numerically not the same.

To resolve this, the crate now exposes a list of indices alongside each
backend source.

Resolves #330
  • Loading branch information
armansito committed Sep 14, 2023
1 parent 0d5a926 commit 8816f27
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 27 deletions.
50 changes: 45 additions & 5 deletions crates/shaders/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,58 @@ fn write_shaders(
wg_bufs
)?;
if cfg!(feature = "wgsl") {
writeln!(buf, " wgsl: Cow::Borrowed({:?}),", info.source)?;
}
if cfg!(feature = "msl") {
let indices = info
.bindings
.iter()
.map(|binding| binding.location.1)
.collect::<Vec<_>>();
writeln!(buf, " wgsl: WgslSource {{")?;
writeln!(
buf,
" msl: Cow::Borrowed({:?}),",
compile::msl::translate(info).unwrap()
" code: Cow::Borrowed({:?}),",
info.source
)?;
writeln!(
buf,
" binding_indices : Cow::Borrowed(&{:?}),",
indices
)?;
writeln!(buf, " }},")?;
}
if cfg!(feature = "msl") {
write_msl(buf, info)?;
}
writeln!(buf, " }},")?;
}
writeln!(buf, " }};")?;
writeln!(buf, "}}")?;
Ok(())
}

#[cfg(not(feature = "msl"))]
fn write_msl(_: &mut String, _: &ShaderInfo) -> Result<(), std::fmt::Error> {
Ok(())
}

#[cfg(feature = "msl")]
fn write_msl(buf: &mut String, info: &ShaderInfo) -> Result<(), std::fmt::Error> {
let mut index_iter = compile::msl::BindingIndexIterator::new();
let indices = info
.bindings
.iter()
.map(|binding| index_iter.next(binding.ty))
.collect::<Vec<_>>();
writeln!(buf, " msl: MslSource {{")?;
writeln!(
buf,
" code: Cow::Borrowed({:?}),",
compile::msl::translate(info).unwrap()
)?;
writeln!(
buf,
" binding_indices : Cow::Borrowed(&{:?}),",
indices
)?;
writeln!(buf, " }},")?;
Ok(())
}
1 change: 1 addition & 0 deletions crates/shaders/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use {
pub mod permutations;
pub mod preprocess;

#[cfg(feature = "msl")]
pub mod msl;

use crate::types::{BindType, BindingInfo, WorkgroupBufferInfo};
Expand Down
70 changes: 50 additions & 20 deletions crates/shaders/src/compile/msl.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,42 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT

use naga::back::msl;
use naga::back::msl as naga_msl;
use {
super::{BindType, ShaderInfo},
crate::types::msl::BindingIndex,
};

use super::{BindType, ShaderInfo};

pub fn translate(shader: &ShaderInfo) -> Result<String, msl::Error> {
let mut map = msl::EntryPointResourceMap::default();
let mut buffer_index = 0u8;
let mut image_index = 0u8;
let mut binding_map = msl::BindingMap::default();
pub fn translate(shader: &ShaderInfo) -> Result<String, naga_msl::Error> {
let mut map = naga_msl::EntryPointResourceMap::default();
let mut idx_iter = BindingIndexIterator::new();
let mut binding_map = naga_msl::BindingMap::default();
for resource in &shader.bindings {
let binding = naga::ResourceBinding {
group: resource.location.0,
binding: resource.location.1,
};
let mut target = msl::BindTarget::default();
match resource.ty {
BindType::Buffer | BindType::BufReadOnly | BindType::Uniform => {
target.buffer = Some(buffer_index);
buffer_index += 1;
let mut target = naga_msl::BindTarget::default();
match idx_iter.next(resource.ty) {
BindingIndex::Buffer(idx) => {
target.buffer = Some(idx);
}
BindType::Image | BindType::ImageRead => {
target.texture = Some(image_index);
image_index += 1;
BindingIndex::Texture(idx) => {
target.texture = Some(idx);
}
}
target.mutable = resource.ty.is_mutable();
binding_map.insert(binding, target);
}
map.insert(
"main".to_string(),
msl::EntryPointResources {
naga_msl::EntryPointResources {
resources: binding_map,
push_constant_buffer: None,
sizes_buffer: Some(30),
},
);
let options = msl::Options {
let options = naga_msl::Options {
lang_version: (2, 0),
per_entry_point_map: map,
inline_samplers: vec![],
Expand All @@ -46,11 +45,42 @@ pub fn translate(shader: &ShaderInfo) -> Result<String, msl::Error> {
bounds_check_policies: naga::proc::BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: false,
};
let (source, _) = msl::write_string(
let (source, _) = naga_msl::write_string(
&shader.module,
&shader.module_info,
&options,
&msl::PipelineOptions::default(),
&naga_msl::PipelineOptions::default(),
)?;
Ok(source)
}

pub struct BindingIndexIterator {
buffer_idx: u8,
tex_idx: u8,
}

impl BindingIndexIterator {
pub fn new() -> Self {
Self {
buffer_idx: 0,
tex_idx: 0,
}
}

pub fn next(&mut self, ty: BindType) -> BindingIndex {
match ty {
BindType::Buffer | BindType::BufReadOnly | BindType::Uniform => {
let idx = self.buffer_idx;
self.buffer_idx += 1;
assert!(self.buffer_idx > 0);
BindingIndex::Buffer(idx)
}
BindType::Image | BindType::ImageRead => {
let idx = self.tex_idx;
self.tex_idx += 1;
assert!(self.tex_idx > 0);
BindingIndex::Texture(idx)
}
}
}
}
74 changes: 72 additions & 2 deletions crates/shaders/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ pub mod compile;

pub use types::{BindType, BindingInfo, WorkgroupBufferInfo};

#[cfg(feature = "msl")]
pub use types::msl;

use std::borrow::Cow;

#[derive(Clone, Debug)]
Expand All @@ -18,10 +21,77 @@ pub struct ComputeShader<'a> {
pub workgroup_buffers: Cow<'a, [WorkgroupBufferInfo]>,

#[cfg(feature = "wgsl")]
pub wgsl: Cow<'a, str>,
pub wgsl: WgslSource<'a>,

#[cfg(feature = "msl")]
pub msl: Cow<'a, str>,
pub msl: MslSource<'a>,
}

#[cfg(feature = "wgsl")]
#[derive(Clone, Debug)]
pub struct WgslSource<'a> {
pub code: Cow<'a, str>,

/// Contains the binding index of each resource listed in `ComputeShader::bindings`.
/// This is guaranteed to have the same element count as `ComputeShader::bindings`.
///
/// In WGSL, each index directly corresponds to the value of the corresponding
/// `@binding(..)` declaration in the shader source. The bind group index (i.e.
/// value of `@group(..)`) is always 0.
///
/// Example:
/// --------
///
/// // An unused binding (i.e. declaration is not reachable from the entry-point)
/// @group(0) @binding(0) var<uniform> foo: Foo;
///
/// // Used bindings:
/// @group(0) @binding(1) var<storage> buffer: Buffer;
/// @group(0) @binding(2) var tex: texture_2d<f32>;
/// ...
///
/// This results in the following bindings:
///
/// bindings: [BindType::Buffer, BindType::ImageRead],
/// ...
/// wgsl: WgslSource {
/// code: ...,
/// binding_indices: [1, 2],
/// },
pub binding_indices: Cow<'a, [u8]>,
}

#[cfg(feature = "msl")]
#[derive(Clone, Debug)]
pub struct MslSource<'a> {
pub code: Cow<'a, str>,

/// Contains the binding index of each resource listed in `ComputeShader::bindings`.
/// This is guaranteed to have the same element count as `ComputeShader::bindings`.
///
/// In MSL, each index is scoped to the index range of the corresponding resource type.
///
/// Example:
/// --------
///
/// // An unused binding (i.e. declaration is not reachable from the entry-point)
/// @group(0) @binding(0) var<uniform> foo: Foo;
///
/// // Used bindings:
/// @group(0) @binding(1) var<storage> buffer: Buffer;
/// @group(0) @binding(2) var tex: texture_2d<f32>;
/// ...
///
/// This results in the following bindings:
///
/// bindings: [BindType::Buffer, BindType::ImageRead],
/// ...
/// msl: MslSource {
/// code: ...,
/// // In MSL these would be declared as `[[buffer(0)]]` and `[[texture(0)]]`.
/// binding_indices: [msl::BindingIndex::Buffer(0), msl::BindingIndex::Texture(0)],
/// },
pub binding_indices: Cow<'a, [msl::BindingIndex]>,
}

pub trait PipelineHost {
Expand Down
20 changes: 20 additions & 0 deletions crates/shaders/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,23 @@ pub struct WorkgroupBufferInfo {
/// The order in which the workgroup variable is declared in the shader module.
pub index: u32,
}

#[cfg(feature = "msl")]
pub mod msl {
use std::fmt;

#[derive(Clone)]
pub enum BindingIndex {
Buffer(u8),
Texture(u8),
}

impl fmt::Debug for BindingIndex {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Self::Buffer(i) => write!(f, "msl::BindingIndex::Buffer({})", i),
Self::Texture(i) => write!(f, "msl::BindingIndex::Texture({})", i),
}
}
}
}

0 comments on commit 8816f27

Please sign in to comment.