Skip to content

Commit

Permalink
Meshlet fill cluster buffers rewritten (#15955)
Browse files Browse the repository at this point in the history
# Objective
- Make the meshlet fill cluster buffers pass slightly faster
- Address #15920 for meshlets
- Added PreviousGlobalTransform as a required meshlet component to avoid
extra archetype moves, slightly alleviating
#14681 for meshlets
- Enforce that MeshletPlugin::cluster_buffer_slots is not greater than
2^25 (glitches will occur otherwise). Technically this field controls
post-lod/culling cluster count, and the issue is on pre-lod/culling
cluster count, but it's still valid now, and in the future this will be
more true.

Needs to be merged after #15846
and #15886

## Solution

- Old pass dispatched a thread per cluster, and did a binary search over
the instances to find which instance the cluster belongs to, and what
meshlet index within the instance it is.
- New pass dispatches a workgroup per instance, and has the workgroup
loop over all meshlets in the instance in order to write out the cluster
data.
- Use a push constant instead of arrayLength to fix the linked bug
- Remap 1d->2d dispatch for software raster only if actually needed to
save on spawning excess workgroups

## Testing

- Did you test these changes? If so, how?
- Ran the meshlet example, and an example with 1041 instances of 32217
meshlets per instance. Profiled the second scene with nsight, went from
0.55ms -> 0.40ms. Small savings. We're pretty much VRAM bandwidth bound
at this point.
- How can other people (reviewers) test your changes? Is there anything
specific they need to know?
  - Run the meshlet example

## Changelog (non-meshlets)
- PreviousGlobalTransform now implements the Default trait
  • Loading branch information
JMS55 authored and mockersf committed Oct 24, 2024
1 parent 119e37d commit 33365c5
Show file tree
Hide file tree
Showing 11 changed files with 140 additions and 83 deletions.
6 changes: 3 additions & 3 deletions crates/bevy_pbr/src/meshlet/cull_clusters.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
meshlet_software_raster_indirect_args,
meshlet_hardware_raster_indirect_args,
meshlet_raster_clusters,
meshlet_raster_cluster_rightmost_slot,
constants,
MeshletBoundingSphere,
}
#import bevy_render::maths::affine3_to_square
Expand All @@ -32,7 +32,7 @@ fn cull_clusters(
) {
// Calculate the cluster ID for this thread
let cluster_id = local_invocation_index + 128u * dot(workgroup_id, vec3(num_workgroups.x * num_workgroups.x, num_workgroups.x, 1u));
if cluster_id >= arrayLength(&meshlet_cluster_meshlet_ids) { return; }
if cluster_id >= constants.scene_cluster_count { return; }

#ifdef MESHLET_SECOND_CULLING_PASS
if !cluster_is_second_pass_candidate(cluster_id) { return; }
Expand Down Expand Up @@ -138,7 +138,7 @@ fn cull_clusters(
} else {
// Append this cluster to the list for hardware rasterization
buffer_slot = atomicAdd(&meshlet_hardware_raster_indirect_args.instance_count, 1u);
buffer_slot = meshlet_raster_cluster_rightmost_slot - buffer_slot;
buffer_slot = constants.meshlet_raster_cluster_rightmost_slot - buffer_slot;
}
meshlet_raster_clusters[buffer_slot] = cluster_id;
}
Expand Down
56 changes: 31 additions & 25 deletions crates/bevy_pbr/src/meshlet/fill_cluster_buffers.wgsl
Original file line number Diff line number Diff line change
@@ -1,44 +1,50 @@
#import bevy_pbr::meshlet_bindings::{
cluster_count,
meshlet_instance_meshlet_counts_prefix_sum,
scene_instance_count,
meshlet_global_cluster_count,
meshlet_instance_meshlet_counts,
meshlet_instance_meshlet_slice_starts,
meshlet_cluster_instance_ids,
meshlet_cluster_meshlet_ids,
}

/// Writes out instance_id and meshlet_id to the global buffers for each cluster in the scene.

var<workgroup> cluster_slice_start_workgroup: u32;

@compute
@workgroup_size(128, 1, 1) // 128 threads per workgroup, 1 cluster per thread
@workgroup_size(1024, 1, 1) // 1024 threads per workgroup, 1 instance per workgroup
fn fill_cluster_buffers(
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
@builtin(local_invocation_index) local_invocation_index: u32,
) {
// Calculate the cluster ID for this thread
let cluster_id = local_invocation_index + 128u * dot(workgroup_id, vec3(num_workgroups.x * num_workgroups.x, num_workgroups.x, 1u));
if cluster_id >= cluster_count { return; } // TODO: Could be an arrayLength?

// Binary search to find the instance this cluster belongs to
var left = 0u;
var right = arrayLength(&meshlet_instance_meshlet_counts_prefix_sum) - 1u;
while left <= right {
let mid = (left + right) / 2u;
if meshlet_instance_meshlet_counts_prefix_sum[mid] <= cluster_id {
left = mid + 1u;
} else {
right = mid - 1u;
}
// Calculate the instance ID for this workgroup
var instance_id = workgroup_id.x + (workgroup_id.y * num_workgroups.x);
if instance_id >= scene_instance_count { return; }

let instance_meshlet_count = meshlet_instance_meshlet_counts[instance_id];
let instance_meshlet_slice_start = meshlet_instance_meshlet_slice_starts[instance_id];

// Reserve cluster slots for the instance and broadcast to the workgroup
if local_invocation_index == 0u {
cluster_slice_start_workgroup = atomicAdd(&meshlet_global_cluster_count, instance_meshlet_count);
}
let instance_id = right;
let cluster_slice_start = workgroupUniformLoad(&cluster_slice_start_workgroup);

// Find the meshlet ID for this cluster within the instance's MeshletMesh
let meshlet_id_local = cluster_id - meshlet_instance_meshlet_counts_prefix_sum[instance_id];
// Loop enough times to write out all the meshlets for the instance given that each thread writes 1 meshlet in each iteration
for (var clusters_written = 0u; clusters_written < instance_meshlet_count; clusters_written += 1024u) {
// Calculate meshlet ID within this instance's MeshletMesh to process for this thread
let meshlet_id_local = clusters_written + local_invocation_index;
if meshlet_id_local >= instance_meshlet_count { return; }

// Find the overall meshlet ID in the global meshlet buffer
let meshlet_id = meshlet_id_local + meshlet_instance_meshlet_slice_starts[instance_id];
// Find the overall cluster ID in the global cluster buffer
let cluster_id = cluster_slice_start + meshlet_id_local;

// Write results to buffers
meshlet_cluster_instance_ids[cluster_id] = instance_id;
meshlet_cluster_meshlet_ids[cluster_id] = meshlet_id;
// Find the overall meshlet ID in the global meshlet buffer
let meshlet_id = instance_meshlet_slice_start + meshlet_id_local;

// Write results to buffers
meshlet_cluster_instance_ids[cluster_id] = instance_id;
meshlet_cluster_meshlet_ids[cluster_id] = meshlet_id;
}
}
48 changes: 27 additions & 21 deletions crates/bevy_pbr/src/meshlet/instance_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,46 @@ use bevy_ecs::{
query::Has,
system::{Local, Query, Res, ResMut, Resource, SystemState},
};
use bevy_render::sync_world::MainEntity;
use bevy_render::{render_resource::StorageBuffer, view::RenderLayers, MainWorld};
use bevy_render::{
render_resource::StorageBuffer, sync_world::MainEntity, view::RenderLayers, MainWorld,
};
use bevy_transform::components::GlobalTransform;
use bevy_utils::{HashMap, HashSet};
use core::ops::{DerefMut, Range};

/// Manages data for each entity with a [`MeshletMesh`].
#[derive(Resource)]
pub struct InstanceManager {
/// Amount of clusters in the scene (sum of all meshlet counts across all instances)
/// Amount of instances in the scene.
pub scene_instance_count: u32,
/// Amount of clusters in the scene.
pub scene_cluster_count: u32,

/// Per-instance [`MainEntity`], [`RenderLayers`], and [`NotShadowCaster`]
/// Per-instance [`MainEntity`], [`RenderLayers`], and [`NotShadowCaster`].
pub instances: Vec<(MainEntity, RenderLayers, bool)>,
/// Per-instance [`MeshUniform`]
/// Per-instance [`MeshUniform`].
pub instance_uniforms: StorageBuffer<Vec<MeshUniform>>,
/// Per-instance material ID
/// Per-instance material ID.
pub instance_material_ids: StorageBuffer<Vec<u32>>,
/// Prefix-sum of meshlet counts per instance
pub instance_meshlet_counts_prefix_sum: StorageBuffer<Vec<u32>>,
/// Per-instance index to the start of the instance's slice of the meshlets buffer
/// Per-instance count of meshlets in the instance's [`MeshletMesh`].
pub instance_meshlet_counts: StorageBuffer<Vec<u32>>,
/// Per-instance index to the start of the instance's slice of the meshlets buffer.
pub instance_meshlet_slice_starts: StorageBuffer<Vec<u32>>,
/// Per-view per-instance visibility bit. Used for [`RenderLayers`] and [`NotShadowCaster`] support.
pub view_instance_visibility: EntityHashMap<StorageBuffer<Vec<u32>>>,

/// Next material ID available for a [`Material`]
/// Next material ID available for a [`Material`].
next_material_id: u32,
/// Map of [`Material`] to material ID
/// Map of [`Material`] to material ID.
material_id_lookup: HashMap<UntypedAssetId, u32>,
/// Set of material IDs used in the scene
/// Set of material IDs used in the scene.
material_ids_present_in_scene: HashSet<u32>,
}

impl InstanceManager {
pub fn new() -> Self {
Self {
scene_instance_count: 0,
scene_cluster_count: 0,

instances: Vec::new(),
Expand All @@ -59,9 +63,9 @@ impl InstanceManager {
buffer.set_label(Some("meshlet_instance_material_ids"));
buffer
},
instance_meshlet_counts_prefix_sum: {
instance_meshlet_counts: {
let mut buffer = StorageBuffer::default();
buffer.set_label(Some("meshlet_instance_meshlet_counts_prefix_sum"));
buffer.set_label(Some("meshlet_instance_meshlet_counts"));
buffer
},
instance_meshlet_slice_starts: {
Expand All @@ -80,7 +84,7 @@ impl InstanceManager {
#[allow(clippy::too_many_arguments)]
pub fn add_instance(
&mut self,
instance: Entity,
instance: MainEntity,
meshlets_slice: Range<u32>,
transform: &GlobalTransform,
previous_transform: Option<&PreviousGlobalTransform>,
Expand Down Expand Up @@ -108,20 +112,21 @@ impl InstanceManager {

// Append instance data
self.instances.push((
instance.into(),
instance,
render_layers.cloned().unwrap_or(RenderLayers::default()),
not_shadow_caster,
));
self.instance_uniforms.get_mut().push(mesh_uniform);
self.instance_material_ids.get_mut().push(0);
self.instance_meshlet_counts_prefix_sum
self.instance_meshlet_counts
.get_mut()
.push(self.scene_cluster_count);
.push(meshlets_slice.len() as u32);
self.instance_meshlet_slice_starts
.get_mut()
.push(meshlets_slice.start);

self.scene_cluster_count += meshlets_slice.end - meshlets_slice.start;
self.scene_instance_count += 1;
self.scene_cluster_count += meshlets_slice.len() as u32;
}

/// Get the material ID for a [`crate::Material`].
Expand All @@ -140,12 +145,13 @@ impl InstanceManager {
}

pub fn reset(&mut self, entities: &Entities) {
self.scene_instance_count = 0;
self.scene_cluster_count = 0;

self.instances.clear();
self.instance_uniforms.get_mut().clear();
self.instance_material_ids.get_mut().clear();
self.instance_meshlet_counts_prefix_sum.get_mut().clear();
self.instance_meshlet_counts.get_mut().clear();
self.instance_meshlet_slice_starts.get_mut().clear();
self.view_instance_visibility
.retain(|view_entity, _| entities.contains(*view_entity));
Expand Down Expand Up @@ -227,7 +233,7 @@ pub fn extract_meshlet_mesh_entities(

// Add the instance's data to the instance manager
instance_manager.add_instance(
instance,
instance.into(),
meshlets_slice,
transform,
previous_transform,
Expand Down
16 changes: 9 additions & 7 deletions crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,27 @@ struct DrawIndirectArgs {
const CENTIMETERS_PER_METER = 100.0;

#ifdef MESHLET_FILL_CLUSTER_BUFFERS_PASS
var<push_constant> cluster_count: u32;
@group(0) @binding(0) var<storage, read> meshlet_instance_meshlet_counts_prefix_sum: array<u32>; // Per entity instance
var<push_constant> scene_instance_count: u32;
@group(0) @binding(0) var<storage, read> meshlet_instance_meshlet_counts: array<u32>; // Per entity instance
@group(0) @binding(1) var<storage, read> meshlet_instance_meshlet_slice_starts: array<u32>; // Per entity instance
@group(0) @binding(2) var<storage, read_write> meshlet_cluster_instance_ids: array<u32>; // Per cluster
@group(0) @binding(3) var<storage, read_write> meshlet_cluster_meshlet_ids: array<u32>; // Per cluster
@group(0) @binding(4) var<storage, read_write> meshlet_global_cluster_count: atomic<u32>; // Single object shared between all workgroups
#endif

#ifdef MESHLET_CULLING_PASS
var<push_constant> meshlet_raster_cluster_rightmost_slot: u32;
struct Constants { scene_cluster_count: u32, meshlet_raster_cluster_rightmost_slot: u32 }
var<push_constant> constants: Constants;
@group(0) @binding(0) var<storage, read> meshlet_cluster_meshlet_ids: array<u32>; // Per cluster
@group(0) @binding(1) var<storage, read> meshlet_bounding_spheres: array<MeshletBoundingSpheres>; // Per meshlet
@group(0) @binding(2) var<storage, read> meshlet_simplification_errors: array<u32>; // Per meshlet
@group(0) @binding(3) var<storage, read> meshlet_cluster_instance_ids: array<u32>; // Per cluster
@group(0) @binding(4) var<storage, read> meshlet_instance_uniforms: array<Mesh>; // Per entity instance
@group(0) @binding(5) var<storage, read> meshlet_view_instance_visibility: array<u32>; // 1 bit per entity instance, packed as a bitmask
@group(0) @binding(6) var<storage, read_write> meshlet_second_pass_candidates: array<atomic<u32>>; // 1 bit per cluster , packed as a bitmask
@group(0) @binding(7) var<storage, read_write> meshlet_software_raster_indirect_args: DispatchIndirectArgs; // Single object shared between all workgroups/clusters/triangles
@group(0) @binding(8) var<storage, read_write> meshlet_hardware_raster_indirect_args: DrawIndirectArgs; // Single object shared between all workgroups/clusters/triangles
@group(0) @binding(9) var<storage, read_write> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups/clusters/triangles
@group(0) @binding(7) var<storage, read_write> meshlet_software_raster_indirect_args: DispatchIndirectArgs; // Single object shared between all workgroups
@group(0) @binding(8) var<storage, read_write> meshlet_hardware_raster_indirect_args: DrawIndirectArgs; // Single object shared between all workgroups
@group(0) @binding(9) var<storage, read_write> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups
@group(0) @binding(10) var depth_pyramid: texture_2d<f32>; // From the end of the last frame for the first culling pass, and from the first raster pass for the second culling pass
@group(0) @binding(11) var<uniform> view: View;
@group(0) @binding(12) var<uniform> previous_view: PreviousViewUniforms;
Expand All @@ -95,7 +97,7 @@ fn cluster_is_second_pass_candidate(cluster_id: u32) -> bool {
@group(0) @binding(3) var<storage, read> meshlet_vertex_positions: array<u32>; // Many per meshlet
@group(0) @binding(4) var<storage, read> meshlet_cluster_instance_ids: array<u32>; // Per cluster
@group(0) @binding(5) var<storage, read> meshlet_instance_uniforms: array<Mesh>; // Per entity instance
@group(0) @binding(6) var<storage, read> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups/clusters/triangles
@group(0) @binding(6) var<storage, read> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups
@group(0) @binding(7) var<storage, read> meshlet_software_raster_cluster_count: u32;
#ifdef MESHLET_VISIBILITY_BUFFER_RASTER_PASS_OUTPUT
@group(0) @binding(8) var<storage, read_write> meshlet_visibility_buffer: array<atomic<u64>>; // Per pixel
Expand Down
11 changes: 9 additions & 2 deletions crates/bevy_pbr/src/meshlet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ use self::{
},
visibility_buffer_raster_node::MeshletVisibilityBufferRasterPassNode,
};
use crate::{graph::NodePbr, Material, MeshMaterial3d};
use crate::{graph::NodePbr, Material, MeshMaterial3d, PreviousGlobalTransform};
use bevy_app::{App, Plugin, PostUpdate};
use bevy_asset::{load_internal_asset, AssetApp, AssetId, Handle};
use bevy_core_pipeline::{
Expand Down Expand Up @@ -129,6 +129,8 @@ pub struct MeshletPlugin {
/// If this number is too low, you'll see rendering artifacts like missing or blinking meshes.
///
/// Each cluster slot costs 4 bytes of VRAM.
///
/// Must not be greater than 2^25.
pub cluster_buffer_slots: u32,
}

Expand All @@ -147,6 +149,11 @@ impl Plugin for MeshletPlugin {
#[cfg(target_endian = "big")]
compile_error!("MeshletPlugin is only supported on little-endian processors.");

if self.cluster_buffer_slots > 2_u32.pow(25) {
error!("MeshletPlugin::cluster_buffer_slots must not be greater than 2^25.");
std::process::exit(1);
}

load_internal_asset!(
app,
MESHLET_BINDINGS_SHADER_HANDLE,
Expand Down Expand Up @@ -293,7 +300,7 @@ impl Plugin for MeshletPlugin {
/// The meshlet mesh equivalent of [`bevy_render::mesh::Mesh3d`].
#[derive(Component, Clone, Debug, Default, Deref, DerefMut, Reflect, PartialEq, Eq, From)]
#[reflect(Component, Default)]
#[require(Transform, Visibility)]
#[require(Transform, PreviousGlobalTransform, Visibility)]
pub struct MeshletMesh3d(pub Handle<MeshletMesh>);

impl From<MeshletMesh3d> for AssetId<MeshletMesh> {
Expand Down
9 changes: 6 additions & 3 deletions crates/bevy_pbr/src/meshlet/pipelines.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl FromWorld for MeshletPipelines {
layout: vec![cull_layout.clone()],
push_constant_ranges: vec![PushConstantRange {
stages: ShaderStages::COMPUTE,
range: 0..4,
range: 0..8,
}],
shader: MESHLET_CULLING_SHADER_HANDLE,
shader_defs: vec![
Expand All @@ -99,7 +99,7 @@ impl FromWorld for MeshletPipelines {
layout: vec![cull_layout],
push_constant_ranges: vec![PushConstantRange {
stages: ShaderStages::COMPUTE,
range: 0..4,
range: 0..8,
}],
shader: MESHLET_CULLING_SHADER_HANDLE,
shader_defs: vec![
Expand Down Expand Up @@ -441,7 +441,10 @@ impl FromWorld for MeshletPipelines {
pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("meshlet_remap_1d_to_2d_dispatch_pipeline".into()),
layout: vec![layout],
push_constant_ranges: vec![],
push_constant_ranges: vec![PushConstantRange {
stages: ShaderStages::COMPUTE,
range: 0..4,
}],
shader: MESHLET_REMAP_1D_TO_2D_DISPATCH_SHADER_HANDLE,
shader_defs: vec![],
entry_point: "remap_dispatch".into(),
Expand Down
9 changes: 6 additions & 3 deletions crates/bevy_pbr/src/meshlet/remap_1d_to_2d_dispatch.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ struct DispatchIndirectArgs {

@group(0) @binding(0) var<storage, read_write> meshlet_software_raster_indirect_args: DispatchIndirectArgs;
@group(0) @binding(1) var<storage, read_write> meshlet_software_raster_cluster_count: u32;
var<push_constant> max_compute_workgroups_per_dimension: u32;

@compute
@workgroup_size(1, 1, 1)
fn remap_dispatch() {
meshlet_software_raster_cluster_count = meshlet_software_raster_indirect_args.x;

let n = u32(ceil(sqrt(f32(meshlet_software_raster_indirect_args.x))));
meshlet_software_raster_indirect_args.x = n;
meshlet_software_raster_indirect_args.y = n;
if meshlet_software_raster_cluster_count > max_compute_workgroups_per_dimension {
let n = u32(ceil(sqrt(f32(meshlet_software_raster_cluster_count))));
meshlet_software_raster_indirect_args.x = n;
meshlet_software_raster_indirect_args.y = n;
}
}
Loading

0 comments on commit 33365c5

Please sign in to comment.