From c2c070d4bd4839c936be4901b5aade528417b989 Mon Sep 17 00:00:00 2001 From: Yunfei Hao Date: Wed, 12 Oct 2022 17:16:25 +0800 Subject: [PATCH] Use isFlatDispatchLayout to determine main header And address comments --- .../depthwise_conv2d_nchw_shared_webgpu.ts | 3 +- .../src/depthwise_conv2d_webgpu.ts | 115 +++++++++--------- tfjs-backend-webgpu/src/from_pixels_webgpu.ts | 1 - .../src/matmul_packed_webgpu.ts | 8 +- tfjs-backend-webgpu/src/scatter_webgpu.ts | 6 +- .../src/transpose_shared_webgpu.ts | 8 +- tfjs-backend-webgpu/src/webgpu_program.ts | 39 +++--- 7 files changed, 96 insertions(+), 84 deletions(-) diff --git a/tfjs-backend-webgpu/src/depthwise_conv2d_nchw_shared_webgpu.ts b/tfjs-backend-webgpu/src/depthwise_conv2d_nchw_shared_webgpu.ts index 2efa5f75127..2db5d08ad5b 100644 --- a/tfjs-backend-webgpu/src/depthwise_conv2d_nchw_shared_webgpu.ts +++ b/tfjs-backend-webgpu/src/depthwise_conv2d_nchw_shared_webgpu.ts @@ -86,7 +86,6 @@ export class DepthwiseConv2DNCHWSharedProgram implements WebGPUProgram { } ${main()} { - let localIndexI = i32(localIndex); let coords = getOutputCoords(); let batch = coords[0]; let xRCCorner = vec2(coords.zw) - uniforms.pad; @@ -112,7 +111,7 @@ export class DepthwiseConv2DNCHWSharedProgram implements WebGPUProgram { } // Load one tile of W into local memory. - var wIndex = localIndexI; + var wIndex = i32(localIndex); ${ filterSize < workGroupSize ? 
`if (wIndex < ${filterSize})` : diff --git a/tfjs-backend-webgpu/src/depthwise_conv2d_webgpu.ts b/tfjs-backend-webgpu/src/depthwise_conv2d_webgpu.ts index 915ecd7de88..8d79c8e75b8 100644 --- a/tfjs-backend-webgpu/src/depthwise_conv2d_webgpu.ts +++ b/tfjs-backend-webgpu/src/depthwise_conv2d_webgpu.ts @@ -36,6 +36,7 @@ export class DepthwiseConv2DProgram implements WebGPUProgram { activation: backend_util.Activation; hasPreluActivation: boolean; isChannelsLast: boolean; + size = true; constructor( convInfo: backend_util.Conv2DInfo, addBias = false, @@ -65,70 +66,74 @@ export class DepthwiseConv2DProgram implements WebGPUProgram { 'getX(batch, d1, xR, xC);'; const userCode = ` - ${activationFnSnippet(this.activation, this.hasPreluActivation, false, 4)} - - ${main()} { - let coords = getOutputCoords(); - let batch = coords[0]; - let xRCCorner = vec2(coords.${ + ${ + activationFnSnippet( + this.activation, this.hasPreluActivation, false, 4)} + + ${main('index')} { + if (index < uniforms.size) { + let coords = getOutputCoords(); + let batch = coords[0]; + let xRCCorner = vec2(coords.${ this.isChannelsLast ? 'yz' : 'zw'}) * uniforms.stride - uniforms.pad; - let d2 = coords[${this.isChannelsLast ? 3 : 1}]; - let channelMul = uniforms.wShape[3]; - let d1 = d2 / channelMul; - let q = d2 % channelMul; - - let inputRowStart = xRCCorner.x; - let inputColStart = xRCCorner.y; - let inputRowEnd = inputRowStart + uniforms.filterHeight * - uniforms.dilation[0]; - let inputColEnd = inputColStart + uniforms.filterWidth * - uniforms.dilation[1]; - - // Convolve x(?, ?, d1)|x(d1, ?, ?) with w(:, :, d1, q) to get - // y(yR, yC, d2)|y(d2, yR, yC). ? = to be determined. : = across all - // values in that axis. x(?, ?, d1) and y(yR, yC, d2) is for NHWC. - // x(d1, ?, ?) and y(d2, yR, yC) is for NCHW. - var value = 0.0; - - // Extract if checking out of for loop for performance. 
- if (inputRowStart >= 0 && inputColStart >= 0 && - inputRowEnd < uniforms.inDims[0] && - inputColEnd < uniforms.inDims[1]) { - for (var wR = 0; wR < uniforms.filterHeight; wR = wR + 1) { - let xR = inputRowStart + wR * uniforms.dilation[0]; - - for (var wC = 0; wC < uniforms.filterWidth; wC = wC + 1) { - let xC = inputColStart + wC * uniforms.dilation[1]; - - let xVal = ${getXSnippet}; - let wVal = getW(wR, wC, d1, q); - value = value + xVal * wVal; - } - } - } else { - for (var wR = 0; wR < uniforms.filterHeight; wR = wR + 1) { - let xR = inputRowStart + wR * uniforms.dilation[0]; - - if (xR < 0 || xR >= uniforms.inDims[0]) { - continue; + let d2 = coords[${this.isChannelsLast ? 3 : 1}]; + let channelMul = uniforms.wShape[3]; + let d1 = d2 / channelMul; + let q = d2 % channelMul; + + let inputRowStart = xRCCorner.x; + let inputColStart = xRCCorner.y; + let inputRowEnd = inputRowStart + uniforms.filterHeight * + uniforms.dilation[0]; + let inputColEnd = inputColStart + uniforms.filterWidth * + uniforms.dilation[1]; + + // Convolve x(?, ?, d1)|x(d1, ?, ?) with w(:, :, d1, q) to get + // y(yR, yC, d2)|y(d2, yR, yC). ? = to be determined. : = across all + // values in that axis. x(?, ?, d1) and y(yR, yC, d2) is for NHWC. + // x(d1, ?, ?) and y(d2, yR, yC) is for NCHW. + var value = 0.0; + + // Extract if checking out of for loop for performance. 
+ if (inputRowStart >= 0 && inputColStart >= 0 && + inputRowEnd < uniforms.inDims[0] && + inputColEnd < uniforms.inDims[1]) { + for (var wR = 0; wR < uniforms.filterHeight; wR = wR + 1) { + let xR = inputRowStart + wR * uniforms.dilation[0]; + + for (var wC = 0; wC < uniforms.filterWidth; wC = wC + 1) { + let xC = inputColStart + wC * uniforms.dilation[1]; + + let xVal = ${getXSnippet}; + let wVal = getW(wR, wC, d1, q); + value = value + xVal * wVal; + } } + } else { + for (var wR = 0; wR < uniforms.filterHeight; wR = wR + 1) { + let xR = inputRowStart + wR * uniforms.dilation[0]; - for (var wC = 0; wC < uniforms.filterWidth; wC = wC + 1) { - let xC = inputColStart + wC * uniforms.dilation[1]; - - if (xC < 0 || xC >= uniforms.inDims[1]) { + if (xR < 0 || xR >= uniforms.inDims[0]) { continue; } - let xVal = ${getXSnippet}; - let wVal = getW(wR, wC, d1, q); - value = value + xVal * wVal; + for (var wC = 0; wC < uniforms.filterWidth; wC = wC + 1) { + let xC = inputColStart + wC * uniforms.dilation[1]; + + if (xC < 0 || xC >= uniforms.inDims[1]) { + continue; + } + + let xVal = ${getXSnippet}; + let wVal = getW(wR, wC, d1, q); + value = value + xVal * wVal; + } } } + ${biasActivationSnippet(this.addBias, this.activation)} + if (coordsInBounds4D(coords, uniforms.outShape)) { + setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value); } - ${biasActivationSnippet(this.addBias, this.activation)} - if (coordsInBounds4D(coords, uniforms.outShape)) { - setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value); } } `; diff --git a/tfjs-backend-webgpu/src/from_pixels_webgpu.ts b/tfjs-backend-webgpu/src/from_pixels_webgpu.ts index 51d27f48e17..de63617db09 100644 --- a/tfjs-backend-webgpu/src/from_pixels_webgpu.ts +++ b/tfjs-backend-webgpu/src/from_pixels_webgpu.ts @@ -28,7 +28,6 @@ export class FromPixelsProgram implements WebGPUProgram { variableNames: string[] = []; workGroupSize: [number, number, number] = [256, 1, 1]; // The empirical value. 
- size = true; constructor(outputShape: number[], numChannels: number, importVideo = false) { this.outputShape = outputShape; diff --git a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts index 231ddc1094c..b21008c0dbe 100644 --- a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts @@ -188,7 +188,7 @@ export function makeMatMulPackedVec4Source( let globalRow = ${isVectorA ? '0' : 'i32(globalId.y) * RowPerThread'}; let globalCol = i32(globalId.x); let batch = ${splitK ? '0' : 'i32(globalId.z)'}; - let globalRowStart = i32(workGroupId.y) * ${tileAOuter}; + let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; let numTiles = ${ splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : @@ -288,8 +288,8 @@ export function makeMatMulPackedSource( ` let localRow = i32(localId.y); let localCol = i32(localId.x); - let globalRowStart = i32(workGroupId.y) * ${tileAOuter}; - let globalColStart = i32(workGroupId.x) * ${tileBOuter}; + let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; + let globalColStart = i32(workgroupId.x) * ${tileBOuter}; // Loop over shared dimension. 
for (var t = 0; t < numTiles; t = t + 1) { @@ -349,7 +349,7 @@ export function makeMatMulPackedSource( let globalRow = i32(globalId.y) * RowPerThread; let globalCol = i32(globalId.x) * ColPerThread; - let globalRowStart = i32(workGroupId.y) * ${tileAOuter}; + let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; let tileRowA = i32(localId.y) * ${rowPerThreadA}; let tileColA = i32(localId.x) * ${colPerThreadA}; diff --git a/tfjs-backend-webgpu/src/scatter_webgpu.ts b/tfjs-backend-webgpu/src/scatter_webgpu.ts index 66f36722680..eb32619a6c1 100644 --- a/tfjs-backend-webgpu/src/scatter_webgpu.ts +++ b/tfjs-backend-webgpu/src/scatter_webgpu.ts @@ -33,7 +33,6 @@ export class ScatterProgram implements WebGPUProgram { sliceDimGreaterThanOne: boolean; atomic = true; type: DataType; - size = true; constructor( flattenXShape: number[], sliceDim: number, indicesRank: number, @@ -50,7 +49,8 @@ export class ScatterProgram implements WebGPUProgram { this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${ this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`; const stridesType = getCoordsDataType(strides.length); - this.uniforms = `sliceDim : i32, strides: ${stridesType}, sizeUpdate: i32,`; + this.uniforms = + `sliceDim : i32, strides: ${stridesType}, updatesSize: i32,`; this.updatesRank = updatesRank; this.indicesRank = indicesRank; } @@ -123,7 +123,7 @@ export class ScatterProgram implements WebGPUProgram { ${getUpdatesCoordsFromFlatIndex} ${main('index')} { - if (index < uniforms.sizeUpdate) { + if (index < uniforms.updatesSize) { let coords = getUpdatesCoordsFromFlatIndex(index); var flattenedIndex = 0; for (var j = 0; j < uniforms.sliceDim; j = j + 1) { diff --git a/tfjs-backend-webgpu/src/transpose_shared_webgpu.ts b/tfjs-backend-webgpu/src/transpose_shared_webgpu.ts index 43cec049e3d..5431212d5a8 100644 --- a/tfjs-backend-webgpu/src/transpose_shared_webgpu.ts +++ b/tfjs-backend-webgpu/src/transpose_shared_webgpu.ts @@ -46,8 +46,8 @@ export class 
TransposeSharedProgram implements WebGPUProgram { var tile : array, ${ this.workGroupSize[0]}>; ${main()} { - var x = i32(workGroupId.x) * TILE_DIM + i32(localId.x); - var y = i32(workGroupId.y) * TILE_DIM + i32(localId.y); + var x = i32(workgroupId.x) * TILE_DIM + i32(localId.x); + var y = i32(workgroupId.y) * TILE_DIM + i32(localId.y); let width = uniforms.outShape[0]; let height = uniforms.outShape[1]; if (x < width && y < height) { @@ -55,8 +55,8 @@ export class TransposeSharedProgram implements WebGPUProgram { } workgroupBarrier(); - x = i32(workGroupId.y) * TILE_DIM + i32(localId.x); - y = i32(workGroupId.x) * TILE_DIM + i32(localId.y); + x = i32(workgroupId.y) * TILE_DIM + i32(localId.x); + y = i32(workgroupId.x) * TILE_DIM + i32(localId.y); if (x < height && y < width) { setOutputAtIndex((y * height + x), tile[localId.x] [localId.y]); diff --git a/tfjs-backend-webgpu/src/webgpu_program.ts b/tfjs-backend-webgpu/src/webgpu_program.ts index fad1eceeab0..39f182d2581 100644 --- a/tfjs-backend-webgpu/src/webgpu_program.ts +++ b/tfjs-backend-webgpu/src/webgpu_program.ts @@ -124,21 +124,21 @@ export function getMainHeaderString(...params: string[]): string { return snippet; } -export function getStartHeaderString(isUseIndex: boolean): string { +export function getStartHeaderString(useGlobalIndex: boolean): string { let snippet: string; snippet = ` ${getWorkGroupSizeString()} fn _start(@builtin(local_invocation_id) LocalId : vec3, @builtin(global_invocation_id) GlobalId : vec3, @builtin(local_invocation_index) LocalIndex: u32, - @builtin(workgroup_id) WorkGroupId : vec3, + @builtin(workgroup_id) WorkgroupId : vec3, @builtin(num_workgroups) NumWorkgroups : vec3) { localId = LocalId; localIndex = LocalIndex; globalId = GlobalId; numWorkgroups = NumWorkgroups; - workGroupId = WorkGroupId; - ${isUseIndex ? `main(getGlobalIndex());` : `main();`}; + workgroupId = WorkgroupId; + ${useGlobalIndex ? 
`main(getGlobalIndex());` : `main();`}; } `; return snippet; @@ -163,20 +163,17 @@ function makeShader( var localIndex: u32; var globalId: vec3; var numWorkgroups: vec3; - var workGroupId: vec3; + var workgroupId: vec3; // Only used when the y/z dimension of workgroup size is 1. fn getGlobalIndex() -> i32 { ${ isFlatDispatch(program) ? ` return i32(globalId.x);` : - ` let localInvocationIndex = localId.z * workGroupSizeX * workGroupSizeY + - localId.y * workGroupSizeX + localId.x; - - return i32((workGroupId.z * numWorkgroups.x * numWorkgroups.y + - workGroupId.y * numWorkgroups.x + workGroupId.x) * - (workGroupSizeX * workGroupSizeY * workGroupSizeZ) + - localInvocationIndex); + ` return i32((workgroupId.z * numWorkgroups.x * numWorkgroups.y + + workgroupId.y * numWorkgroups.x + workgroupId.x) * + (workGroupSizeX * workGroupSizeY * workGroupSizeZ) + + localIndex); `} } `); @@ -193,12 +190,13 @@ function makeShader( mapToWgslTypes(outputData.dtype, program.isVec4)}>; @group(0) @binding(2) var uniforms: Uniform; `); + const useGlobalIndex = isFlatDispatchLayout(program); return [ commonSnippet, prefixSnippets.join('\n'), getCoordsFromIndexSnippet(outputData.shape), program.getUserCode(), - getStartHeaderString(program.size), + getStartHeaderString(useGlobalIndex), ].join('\n'); } @@ -278,9 +276,9 @@ function makeShader( program.dispatchLayout.x.length === outputData.shape.length)) .join('\n'); sources.push(inputSnippet); - sources.push(program.getUserCode()); - sources.push(getStartHeaderString(program.size)); + const useGlobalIndex = isFlatDispatchLayout(program); + sources.push(getStartHeaderString(useGlobalIndex)); const source = sources.join('\n'); return source; } @@ -847,3 +845,14 @@ function insertAlignment(uniformShader: string) { }); return uniformShader; } +function isFlatDispatchLayout(program: WebGPUProgram): boolean { + if (program.dispatchLayout.hasOwnProperty('y') && + program.dispatchLayout.y.length !== 0) { + return false; + } + if 
(program.dispatchLayout.hasOwnProperty('z') && + program.dispatchLayout.z.length !== 0) { + return false; + } + return true; +}