diff --git a/packages/dev/core/src/Compute/computeShader.ts b/packages/dev/core/src/Compute/computeShader.ts index b1c4f3c5a5f..ee1b71eb18a 100644 --- a/packages/dev/core/src/Compute/computeShader.ts +++ b/packages/dev/core/src/Compute/computeShader.ts @@ -230,13 +230,13 @@ export class ComputeShader { * @param name Binding name of the buffer * @param buffer Buffer to bind */ - public setUniformBuffer(name: string, buffer: UniformBuffer): void { + public setUniformBuffer(name: string, buffer: UniformBuffer | DataBuffer): void { const current = this._bindings[name]; this._contextIsDirty ||= !current || current.object !== buffer; this._bindings[name] = { - type: ComputeBindingType.UniformBuffer, + type: ComputeShader._BufferIsDataBuffer(buffer) ? ComputeBindingType.DataBuffer : ComputeBindingType.UniformBuffer, object: buffer, indexInGroupEntries: current?.indexInGroupEntries, }; @@ -247,13 +247,13 @@ export class ComputeShader { * @param name Binding name of the buffer * @param buffer Buffer to bind */ - public setStorageBuffer(name: string, buffer: StorageBuffer): void { + public setStorageBuffer(name: string, buffer: StorageBuffer | DataBuffer): void { const current = this._bindings[name]; this._contextIsDirty ||= !current || current.object !== buffer; this._bindings[name] = { - type: ComputeBindingType.StorageBuffer, + type: ComputeShader._BufferIsDataBuffer(buffer) ? ComputeBindingType.DataBuffer : ComputeBindingType.StorageBuffer, object: buffer, indexInGroupEntries: current?.indexInGroupEntries, }; @@ -498,6 +498,10 @@ export class ComputeShader { return compute; } + + protected static _BufferIsDataBuffer(buffer: UniformBuffer | StorageBuffer | DataBuffer): buffer is DataBuffer { + return (buffer as DataBuffer).underlyingResource !== undefined; + } } RegisterClass("BABYLON.ComputeShader", ComputeShader); diff --git a/packages/dev/core/src/Engines/Extensions/engine.computeShader.ts b/packages/dev/core/src/Engines/Extensions/engine.computeShader.ts index b5e703af3d7..9f1aed59ac9 100644 --- a/packages/dev/core/src/Engines/Extensions/engine.computeShader.ts +++ b/packages/dev/core/src/Engines/Extensions/engine.computeShader.ts @@ -27,6 +27,7 @@ export enum ComputeBindingType { TextureWithoutSampler = 4, Sampler = 5, ExternalTexture = 6, + DataBuffer = 7, } /** @internal */ diff --git a/packages/dev/core/src/Engines/WebGPU/webgpuComputeContext.ts b/packages/dev/core/src/Engines/WebGPU/webgpuComputeContext.ts index 8549ecb9be6..2c1698ab7ff 100644 --- a/packages/dev/core/src/Engines/WebGPU/webgpuComputeContext.ts +++ b/packages/dev/core/src/Engines/WebGPU/webgpuComputeContext.ts @@ -1,3 +1,4 @@ +import type { DataBuffer } from "core/Buffers/dataBuffer"; import type { StorageBuffer } from "../../Buffers/storageBuffer"; import type { IComputeContext } from "../../Compute/IComputeContext"; import type { BaseTexture } from "../../Materials/Textures/baseTexture"; @@ -116,9 +117,14 @@ export class WebGPUComputeContext implements IComputeContext { } case ComputeBindingType.UniformBuffer: - case ComputeBindingType.StorageBuffer: { - const buffer = type === ComputeBindingType.UniformBuffer ? (object as UniformBuffer) : (object as StorageBuffer); - const dataBuffer = buffer.getBuffer()!; + case ComputeBindingType.StorageBuffer: + case ComputeBindingType.DataBuffer: { + const dataBuffer = + type === ComputeBindingType.DataBuffer + ? (object as DataBuffer) + : type === ComputeBindingType.UniformBuffer + ? (object as UniformBuffer).getBuffer()! + : (object as StorageBuffer).getBuffer()!; const webgpuBuffer = dataBuffer.underlyingResource as GPUBuffer; if (indexInGroupEntries !== undefined && bindGroupEntriesExist) { (entries[indexInGroupEntries].resource as GPUBufferBinding).buffer = webgpuBuffer;