Skip to content

Commit

Permalink
Use isFlatPatchLayout to determine main header
Browse files Browse the repository at this point in the history
And address comments
  • Loading branch information
haoyunfeix committed Oct 12, 2022
1 parent a10518b commit c2c070d
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ export class DepthwiseConv2DNCHWSharedProgram implements WebGPUProgram {
}
${main()} {
let localIndexI = i32(localIndex);
let coords = getOutputCoords();
let batch = coords[0];
let xRCCorner = vec2<i32>(coords.zw) - uniforms.pad;
Expand All @@ -112,7 +111,7 @@ export class DepthwiseConv2DNCHWSharedProgram implements WebGPUProgram {
}
// Load one tile of W into local memory.
var wIndex = localIndexI;
var wIndex = i32(localIndex);
${
filterSize < workGroupSize ?
`if (wIndex < ${filterSize})` :
Expand Down
115 changes: 60 additions & 55 deletions tfjs-backend-webgpu/src/depthwise_conv2d_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export class DepthwiseConv2DProgram implements WebGPUProgram {
activation: backend_util.Activation;
hasPreluActivation: boolean;
isChannelsLast: boolean;
size = true;

constructor(
convInfo: backend_util.Conv2DInfo, addBias = false,
Expand Down Expand Up @@ -65,70 +66,74 @@ export class DepthwiseConv2DProgram implements WebGPUProgram {
'getX(batch, d1, xR, xC);';

const userCode = `
${activationFnSnippet(this.activation, this.hasPreluActivation, false, 4)}
${main()} {
let coords = getOutputCoords();
let batch = coords[0];
let xRCCorner = vec2<i32>(coords.${
${
activationFnSnippet(
this.activation, this.hasPreluActivation, false, 4)}
${main('index')} {
if (index < uniforms.size) {
let coords = getOutputCoords();
let batch = coords[0];
let xRCCorner = vec2<i32>(coords.${
this.isChannelsLast ? 'yz' : 'zw'}) * uniforms.stride - uniforms.pad;
let d2 = coords[${this.isChannelsLast ? 3 : 1}];
let channelMul = uniforms.wShape[3];
let d1 = d2 / channelMul;
let q = d2 % channelMul;
let inputRowStart = xRCCorner.x;
let inputColStart = xRCCorner.y;
let inputRowEnd = inputRowStart + uniforms.filterHeight *
uniforms.dilation[0];
let inputColEnd = inputColStart + uniforms.filterWidth *
uniforms.dilation[1];
// Convolve x(?, ?, d1)|x(d1, ?, ?) with w(:, :, d1, q) to get
// y(yR, yC, d2)|y(d2, yR, yC). ? = to be determined. : = across all
// values in that axis. x(?, ?, d1) and y(yR, yC, d2) is for NHWC.
// x(d1, ?, ?) and y(d2, yR, yC) is for NCHW.
var value = 0.0;
// Extract if checking out of for loop for performance.
if (inputRowStart >= 0 && inputColStart >= 0 &&
inputRowEnd < uniforms.inDims[0] &&
inputColEnd < uniforms.inDims[1]) {
for (var wR = 0; wR < uniforms.filterHeight; wR = wR + 1) {
let xR = inputRowStart + wR * uniforms.dilation[0];
for (var wC = 0; wC < uniforms.filterWidth; wC = wC + 1) {
let xC = inputColStart + wC * uniforms.dilation[1];
let xVal = ${getXSnippet};
let wVal = getW(wR, wC, d1, q);
value = value + xVal * wVal;
}
}
} else {
for (var wR = 0; wR < uniforms.filterHeight; wR = wR + 1) {
let xR = inputRowStart + wR * uniforms.dilation[0];
if (xR < 0 || xR >= uniforms.inDims[0]) {
continue;
let d2 = coords[${this.isChannelsLast ? 3 : 1}];
let channelMul = uniforms.wShape[3];
let d1 = d2 / channelMul;
let q = d2 % channelMul;
let inputRowStart = xRCCorner.x;
let inputColStart = xRCCorner.y;
let inputRowEnd = inputRowStart + uniforms.filterHeight *
uniforms.dilation[0];
let inputColEnd = inputColStart + uniforms.filterWidth *
uniforms.dilation[1];
// Convolve x(?, ?, d1)|x(d1, ?, ?) with w(:, :, d1, q) to get
// y(yR, yC, d2)|y(d2, yR, yC). ? = to be determined. : = across all
// values in that axis. x(?, ?, d1) and y(yR, yC, d2) is for NHWC.
// x(d1, ?, ?) and y(d2, yR, yC) is for NCHW.
var value = 0.0;
// Extract if checking out of for loop for performance.
if (inputRowStart >= 0 && inputColStart >= 0 &&
inputRowEnd < uniforms.inDims[0] &&
inputColEnd < uniforms.inDims[1]) {
for (var wR = 0; wR < uniforms.filterHeight; wR = wR + 1) {
let xR = inputRowStart + wR * uniforms.dilation[0];
for (var wC = 0; wC < uniforms.filterWidth; wC = wC + 1) {
let xC = inputColStart + wC * uniforms.dilation[1];
let xVal = ${getXSnippet};
let wVal = getW(wR, wC, d1, q);
value = value + xVal * wVal;
}
}
} else {
for (var wR = 0; wR < uniforms.filterHeight; wR = wR + 1) {
let xR = inputRowStart + wR * uniforms.dilation[0];
for (var wC = 0; wC < uniforms.filterWidth; wC = wC + 1) {
let xC = inputColStart + wC * uniforms.dilation[1];
if (xC < 0 || xC >= uniforms.inDims[1]) {
if (xR < 0 || xR >= uniforms.inDims[0]) {
continue;
}
let xVal = ${getXSnippet};
let wVal = getW(wR, wC, d1, q);
value = value + xVal * wVal;
for (var wC = 0; wC < uniforms.filterWidth; wC = wC + 1) {
let xC = inputColStart + wC * uniforms.dilation[1];
if (xC < 0 || xC >= uniforms.inDims[1]) {
continue;
}
let xVal = ${getXSnippet};
let wVal = getW(wR, wC, d1, q);
value = value + xVal * wVal;
}
}
}
${biasActivationSnippet(this.addBias, this.activation)}
if (coordsInBounds4D(coords, uniforms.outShape)) {
setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value);
}
${biasActivationSnippet(this.addBias, this.activation)}
if (coordsInBounds4D(coords, uniforms.outShape)) {
setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value);
}
}
`;
Expand Down
1 change: 0 additions & 1 deletion tfjs-backend-webgpu/src/from_pixels_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ export class FromPixelsProgram implements WebGPUProgram {
variableNames: string[] = [];
workGroupSize: [number, number, number] =
[256, 1, 1]; // The empirical value.
size = true;

constructor(outputShape: number[], numChannels: number, importVideo = false) {
this.outputShape = outputShape;
Expand Down
8 changes: 4 additions & 4 deletions tfjs-backend-webgpu/src/matmul_packed_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ export function makeMatMulPackedVec4Source(
let globalRow = ${isVectorA ? '0' : 'i32(globalId.y) * RowPerThread'};
let globalCol = i32(globalId.x);
let batch = ${splitK ? '0' : 'i32(globalId.z)'};
let globalRowStart = i32(workGroupId.y) * ${tileAOuter};
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
let numTiles = ${
splitK ? `${Math.ceil(splitedDimInner / tileInner)}` :
Expand Down Expand Up @@ -288,8 +288,8 @@ export function makeMatMulPackedSource(
`
let localRow = i32(localId.y);
let localCol = i32(localId.x);
let globalRowStart = i32(workGroupId.y) * ${tileAOuter};
let globalColStart = i32(workGroupId.x) * ${tileBOuter};
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
let globalColStart = i32(workgroupId.x) * ${tileBOuter};
// Loop over shared dimension.
for (var t = 0; t < numTiles; t = t + 1) {
Expand Down Expand Up @@ -349,7 +349,7 @@ export function makeMatMulPackedSource(
let globalRow = i32(globalId.y) * RowPerThread;
let globalCol = i32(globalId.x) * ColPerThread;
let globalRowStart = i32(workGroupId.y) * ${tileAOuter};
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
let tileRowA = i32(localId.y) * ${rowPerThreadA};
let tileColA = i32(localId.x) * ${colPerThreadA};
Expand Down
6 changes: 3 additions & 3 deletions tfjs-backend-webgpu/src/scatter_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ export class ScatterProgram implements WebGPUProgram {
sliceDimGreaterThanOne: boolean;
atomic = true;
type: DataType;
size = true;

constructor(
flattenXShape: number[], sliceDim: number, indicesRank: number,
Expand All @@ -50,7 +49,8 @@ export class ScatterProgram implements WebGPUProgram {
this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${
this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`;
const stridesType = getCoordsDataType(strides.length);
this.uniforms = `sliceDim : i32, strides: ${stridesType}, sizeUpdate: i32,`;
this.uniforms =
`sliceDim : i32, strides: ${stridesType}, updatesSize: i32,`;
this.updatesRank = updatesRank;
this.indicesRank = indicesRank;
}
Expand Down Expand Up @@ -123,7 +123,7 @@ export class ScatterProgram implements WebGPUProgram {
${getUpdatesCoordsFromFlatIndex}
${main('index')} {
if (index < uniforms.sizeUpdate) {
if (index < uniforms.updatesSize) {
let coords = getUpdatesCoordsFromFlatIndex(index);
var flattenedIndex = 0;
for (var j = 0; j < uniforms.sliceDim; j = j + 1) {
Expand Down
8 changes: 4 additions & 4 deletions tfjs-backend-webgpu/src/transpose_shared_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ export class TransposeSharedProgram implements WebGPUProgram {
var<workgroup> tile : array<array<f32, ${this.workGroupSize[0] + 1}>, ${
this.workGroupSize[0]}>;
${main()} {
var x = i32(workGroupId.x) * TILE_DIM + i32(localId.x);
var y = i32(workGroupId.y) * TILE_DIM + i32(localId.y);
var x = i32(workgroupId.x) * TILE_DIM + i32(localId.x);
var y = i32(workgroupId.y) * TILE_DIM + i32(localId.y);
let width = uniforms.outShape[0];
let height = uniforms.outShape[1];
if (x < width && y < height) {
tile[localId.y][localId.x] = A[y * width + x];
}
workgroupBarrier();
x = i32(workGroupId.y) * TILE_DIM + i32(localId.x);
y = i32(workGroupId.x) * TILE_DIM + i32(localId.y);
x = i32(workgroupId.y) * TILE_DIM + i32(localId.x);
y = i32(workgroupId.x) * TILE_DIM + i32(localId.y);
if (x < height && y < width) {
setOutputAtIndex((y * height + x), tile[localId.x]
[localId.y]);
Expand Down
39 changes: 24 additions & 15 deletions tfjs-backend-webgpu/src/webgpu_program.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,21 +124,21 @@ export function getMainHeaderString(...params: string[]): string {
return snippet;
}

export function getStartHeaderString(isUseIndex: boolean): string {
export function getStartHeaderString(useGlobalIndex: boolean): string {
let snippet: string;
snippet = `
${getWorkGroupSizeString()}
fn _start(@builtin(local_invocation_id) LocalId : vec3<u32>,
@builtin(global_invocation_id) GlobalId : vec3<u32>,
@builtin(local_invocation_index) LocalIndex: u32,
@builtin(workgroup_id) WorkGroupId : vec3<u32>,
@builtin(workgroup_id) WorkgroupId : vec3<u32>,
@builtin(num_workgroups) NumWorkgroups : vec3<u32>) {
localId = LocalId;
localIndex = LocalIndex;
globalId = GlobalId;
numWorkgroups = NumWorkgroups;
workGroupId = WorkGroupId;
${isUseIndex ? `main(getGlobalIndex());` : `main();`};
workgroupId = WorkgroupId;
${useGlobalIndex ? `main(getGlobalIndex());` : `main();`};
}
`;
return snippet;
Expand All @@ -163,20 +163,17 @@ function makeShader(
var<private> localIndex: u32;
var<private> globalId: vec3<u32>;
var<private> numWorkgroups: vec3<u32>;
var<private> workGroupId: vec3<u32>;
var<private> workgroupId: vec3<u32>;
// Only used when the y/z dimension of workgroup size is 1.
fn getGlobalIndex() -> i32 {
${
isFlatDispatch(program) ?
` return i32(globalId.x);` :
` let localInvocationIndex = localId.z * workGroupSizeX * workGroupSizeY +
localId.y * workGroupSizeX + localId.x;
return i32((workGroupId.z * numWorkgroups.x * numWorkgroups.y +
workGroupId.y * numWorkgroups.x + workGroupId.x) *
(workGroupSizeX * workGroupSizeY * workGroupSizeZ) +
localInvocationIndex);
` return i32((workgroupId.z * numWorkgroups.x * numWorkgroups.y +
workgroupId.y * numWorkgroups.x + workgroupId.x) *
(workGroupSizeX * workGroupSizeY * workGroupSizeZ) +
localIndex);
`}
}
`);
Expand All @@ -193,12 +190,13 @@ function makeShader(
mapToWgslTypes(outputData.dtype, program.isVec4)}>;
@group(0) @binding(2) var<uniform> uniforms: Uniform;
`);
const useGlobalIndex = isFlatDispatchLayout(program);
return [
commonSnippet,
prefixSnippets.join('\n'),
getCoordsFromIndexSnippet(outputData.shape),
program.getUserCode(),
getStartHeaderString(program.size),
getStartHeaderString(useGlobalIndex),
].join('\n');
}

Expand Down Expand Up @@ -278,9 +276,9 @@ function makeShader(
program.dispatchLayout.x.length === outputData.shape.length))
.join('\n');
sources.push(inputSnippet);

sources.push(program.getUserCode());
sources.push(getStartHeaderString(program.size));
const useGlobalIndex = isFlatDispatchLayout(program);
sources.push(getStartHeaderString(useGlobalIndex));
const source = sources.join('\n');
return source;
}
Expand Down Expand Up @@ -847,3 +845,14 @@ function insertAlignment(uniformShader: string) {
});
return uniformShader;
}
function isFlatDispatchLayout(program: WebGPUProgram): boolean {
if (program.dispatchLayout.hasOwnProperty('y') &&
program.dispatchLayout.y.length !== 0) {
return false;
}
if (program.dispatchLayout.hasOwnProperty('z') &&
program.dispatchLayout.z.length !== 0) {
return false;
}
return true;
}

0 comments on commit c2c070d

Please sign in to comment.