Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shaders] Expose a binding index set per target language #362

Merged
merged 2 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions crates/shaders/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,58 @@ fn write_shaders(
wg_bufs
)?;
if cfg!(feature = "wgsl") {
writeln!(buf, " wgsl: Cow::Borrowed({:?}),", info.source)?;
}
if cfg!(feature = "msl") {
let indices = info
.bindings
.iter()
.map(|binding| binding.location.1)
.collect::<Vec<_>>();
writeln!(buf, " wgsl: WgslSource {{")?;
writeln!(
buf,
" msl: Cow::Borrowed({:?}),",
compile::msl::translate(info).unwrap()
" code: Cow::Borrowed({:?}),",
info.source
)?;
writeln!(
buf,
" binding_indices : Cow::Borrowed(&{:?}),",
indices
)?;
writeln!(buf, " }},")?;
}
if cfg!(feature = "msl") {
write_msl(buf, info)?;
}
writeln!(buf, " }},")?;
}
writeln!(buf, " }};")?;
writeln!(buf, "}}")?;
Ok(())
}

#[cfg(not(feature = "msl"))]
fn write_msl(_: &mut String, _: &ShaderInfo) -> Result<(), std::fmt::Error> {
Ok(())
}

#[cfg(feature = "msl")]
fn write_msl(buf: &mut String, info: &ShaderInfo) -> Result<(), std::fmt::Error> {
let mut index_iter = compile::msl::BindingIndexIterator::default();
let indices = info
.bindings
.iter()
.map(|binding| index_iter.next(binding.ty))
.collect::<Vec<_>>();
writeln!(buf, " msl: MslSource {{")?;
writeln!(
buf,
" code: Cow::Borrowed({:?}),",
compile::msl::translate(info).unwrap()
)?;
writeln!(
buf,
" binding_indices : Cow::Borrowed(&{:?}),",
indices
)?;
writeln!(buf, " }},")?;
Ok(())
}
1 change: 1 addition & 0 deletions crates/shaders/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use {
pub mod permutations;
pub mod preprocess;

#[cfg(feature = "msl")]
pub mod msl;

use crate::types::{BindType, BindingInfo, WorkgroupBufferInfo};
Expand Down
64 changes: 44 additions & 20 deletions crates/shaders/src/compile/msl.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,42 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT

use naga::back::msl;
use naga::back::msl as naga_msl;
use {
super::{BindType, ShaderInfo},
crate::types::msl::BindingIndex,
};

use super::{BindType, ShaderInfo};

pub fn translate(shader: &ShaderInfo) -> Result<String, msl::Error> {
let mut map = msl::EntryPointResourceMap::default();
let mut buffer_index = 0u8;
let mut image_index = 0u8;
let mut binding_map = msl::BindingMap::default();
pub fn translate(shader: &ShaderInfo) -> Result<String, naga_msl::Error> {
let mut map = naga_msl::EntryPointResourceMap::default();
let mut idx_iter = BindingIndexIterator::default();
let mut binding_map = naga_msl::BindingMap::default();
for resource in &shader.bindings {
let binding = naga::ResourceBinding {
group: resource.location.0,
binding: resource.location.1,
};
let mut target = msl::BindTarget::default();
match resource.ty {
BindType::Buffer | BindType::BufReadOnly | BindType::Uniform => {
target.buffer = Some(buffer_index);
buffer_index += 1;
let mut target = naga_msl::BindTarget::default();
match idx_iter.next(resource.ty) {
BindingIndex::Buffer(idx) => {
target.buffer = Some(idx);
}
BindType::Image | BindType::ImageRead => {
target.texture = Some(image_index);
image_index += 1;
BindingIndex::Texture(idx) => {
target.texture = Some(idx);
}
}
target.mutable = resource.ty.is_mutable();
binding_map.insert(binding, target);
}
map.insert(
"main".to_string(),
msl::EntryPointResources {
naga_msl::EntryPointResources {
resources: binding_map,
push_constant_buffer: None,
sizes_buffer: Some(30),
},
);
let options = msl::Options {
let options = naga_msl::Options {
lang_version: (2, 0),
per_entry_point_map: map,
inline_samplers: vec![],
Expand All @@ -46,11 +45,36 @@ pub fn translate(shader: &ShaderInfo) -> Result<String, msl::Error> {
bounds_check_policies: naga::proc::BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: false,
};
let (source, _) = msl::write_string(
let (source, _) = naga_msl::write_string(
&shader.module,
&shader.module_info,
&options,
&msl::PipelineOptions::default(),
&naga_msl::PipelineOptions::default(),
)?;
Ok(source)
}

#[derive(Default)]
pub struct BindingIndexIterator {
buffer_idx: u8,
tex_idx: u8,
}

impl BindingIndexIterator {
pub fn next(&mut self, ty: BindType) -> BindingIndex {
match ty {
BindType::Buffer | BindType::BufReadOnly | BindType::Uniform => {
let idx = self.buffer_idx;
self.buffer_idx += 1;
assert!(self.buffer_idx > 0);
BindingIndex::Buffer(idx)
}
BindType::Image | BindType::ImageRead => {
let idx = self.tex_idx;
self.tex_idx += 1;
assert!(self.tex_idx > 0);
BindingIndex::Texture(idx)
}
}
}
}
74 changes: 72 additions & 2 deletions crates/shaders/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ pub mod compile;

pub use types::{BindType, BindingInfo, WorkgroupBufferInfo};

#[cfg(feature = "msl")]
pub use types::msl;

use std::borrow::Cow;

#[derive(Clone, Debug)]
Expand All @@ -18,10 +21,77 @@ pub struct ComputeShader<'a> {
pub workgroup_buffers: Cow<'a, [WorkgroupBufferInfo]>,

#[cfg(feature = "wgsl")]
pub wgsl: Cow<'a, str>,
pub wgsl: WgslSource<'a>,

#[cfg(feature = "msl")]
pub msl: Cow<'a, str>,
pub msl: MslSource<'a>,
}

#[cfg(feature = "wgsl")]
#[derive(Clone, Debug)]
pub struct WgslSource<'a> {
pub code: Cow<'a, str>,

/// Contains the binding index of each resource listed in `ComputeShader::bindings`.
/// This is guaranteed to have the same element count as `ComputeShader::bindings`.
///
/// In WGSL, each index directly corresponds to the value of the corresponding
/// `@binding(..)` declaration in the shader source. The bind group index (i.e.
/// value of `@group(..)`) is always 0.
///
/// Example:
/// --------
///
/// // An unused binding (i.e. declaration is not reachable from the entry-point)
/// @group(0) @binding(0) var<uniform> foo: Foo;
///
/// // Used bindings:
/// @group(0) @binding(1) var<storage> buffer: Buffer;
/// @group(0) @binding(2) var tex: texture_2d<f32>;
/// ...
///
/// This results in the following bindings:
///
/// bindings: [BindType::Buffer, BindType::ImageRead],
/// ...
/// wgsl: WgslSource {
/// code: ...,
/// binding_indices: [1, 2],
/// },
pub binding_indices: Cow<'a, [u8]>,
}

#[cfg(feature = "msl")]
#[derive(Clone, Debug)]
pub struct MslSource<'a> {
pub code: Cow<'a, str>,

/// Contains the binding index of each resource listed in `ComputeShader::bindings`.
/// This is guaranteed to have the same element count as `ComputeShader::bindings`.
///
/// In MSL, each index is scoped to the index range of the corresponding resource type.
///
/// Example:
/// --------
///
/// // An unused binding (i.e. declaration is not reachable from the entry-point)
/// @group(0) @binding(0) var<uniform> foo: Foo;
///
/// // Used bindings:
/// @group(0) @binding(1) var<storage> buffer: Buffer;
/// @group(0) @binding(2) var tex: texture_2d<f32>;
/// ...
///
/// This results in the following bindings:
///
/// bindings: [BindType::Buffer, BindType::ImageRead],
/// ...
/// msl: MslSource {
/// code: ...,
/// // In MSL these would be declared as `[[buffer(0)]]` and `[[texture(0)]]`.
/// binding_indices: [msl::BindingIndex::Buffer(0), msl::BindingIndex::Texture(0)],
/// },
pub binding_indices: Cow<'a, [msl::BindingIndex]>,
}

pub trait PipelineHost {
Expand Down
20 changes: 20 additions & 0 deletions crates/shaders/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,23 @@ pub struct WorkgroupBufferInfo {
/// The order in which the workgroup variable is declared in the shader module.
pub index: u32,
}

#[cfg(feature = "msl")]
pub mod msl {
use std::fmt;

#[derive(Clone)]
pub enum BindingIndex {
Buffer(u8),
Texture(u8),
}

impl fmt::Debug for BindingIndex {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Self::Buffer(i) => write!(f, "msl::BindingIndex::Buffer({})", i),
Self::Texture(i) => write!(f, "msl::BindingIndex::Texture({})", i),
}
}
}
}
10 changes: 3 additions & 7 deletions shader/clip_reduce.wgsl
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense

#import config
#import bbox
#import clip

@group(0) @binding(0)
var<uniform> config: Config;

@group(0) @binding(1)
var<storage> clip_inp: array<ClipInp>;

@group(0) @binding(2)
@group(0) @binding(1)
var<storage> path_bboxes: array<PathBbox>;

@group(0) @binding(3)
@group(0) @binding(2)
var<storage, read_write> reduced: array<Bic>;

@group(0) @binding(4)
@group(0) @binding(3)
var<storage, read_write> clip_out: array<ClipEl>;

let WG_SIZE = 256u;
Expand Down
13 changes: 5 additions & 8 deletions shader/fine.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ struct Tile {
var<uniform> config: Config;

@group(0) @binding(1)
var<storage> tiles: array<Tile>;

@group(0) @binding(2)
var<storage> segments: array<Segment>;

#ifdef full
Expand All @@ -28,19 +25,19 @@ var<storage> segments: array<Segment>;

let GRADIENT_WIDTH = 512;

@group(0) @binding(2)
var<storage> ptcl: array<u32>;

@group(0) @binding(3)
var output: texture_storage_2d<rgba8unorm, write>;
var<storage> info: array<u32>;

@group(0) @binding(4)
var<storage> ptcl: array<u32>;
var output: texture_storage_2d<rgba8unorm, write>;

@group(0) @binding(5)
var gradients: texture_2d<f32>;

@group(0) @binding(6)
var<storage> info: array<u32>;

@group(0) @binding(7)
var image_atlas: texture_2d<f32>;

fn read_fill(cmd_ix: u32) -> CmdFill {
Expand Down
11 changes: 4 additions & 7 deletions shader/path_coarse_full.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@ var<uniform> config: Config;
var<storage> scene: array<u32>;

@group(0) @binding(2)
var<storage> tag_monoids: array<TagMonoid>;

@group(0) @binding(3)
var<storage> cubics: array<Cubic>;

@group(0) @binding(4)
@group(0) @binding(3)
var<storage> paths: array<Path>;

// We don't get this from import as it's the atomic version
Expand All @@ -30,13 +27,13 @@ struct AtomicTile {
segments: atomic<u32>,
}

@group(0) @binding(5)
@group(0) @binding(4)
var<storage, read_write> bump: BumpAllocators;

@group(0) @binding(6)
@group(0) @binding(5)
var<storage, read_write> tiles: array<AtomicTile>;

@group(0) @binding(7)
@group(0) @binding(6)
var<storage, read_write> segments: array<Segment>;

struct SubdivResult {
Expand Down
Loading