[webgpu] Update shader to support non module-level scoping function #6918

Merged (5 commits) on Oct 13, 2022

Contributor

@haoyunfeix haoyunfeix commented Oct 9, 2022

Fixes #6842.
Declare user-defined functions before the entry point function to support shader translation libraries that do not yet implement module scoping, such as naga (gfx-rs/naga#2075).
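
To make the ordering concrete, here is a minimal sketch of assembling shader source so that the user function is declared before the entry point that calls it. The helper `assembleShader` and both WGSL fragments are hypothetical illustrations, not the code this PR adds to tfjs.

```typescript
// Hypothetical helper: joins the pieces so that every user-defined function
// appears in the source before the entry point that calls it.
function assembleShader(userCode: string, entryPoint: string): string {
  return [userCode, entryPoint].join('\n');
}

// Illustrative WGSL fragments (assumptions, not actual tfjs output).
const userCode = `
fn main(index: i32) {
  // kernel body goes here
}`;

const entryPoint = `
@compute @workgroup_size(64, 1, 1)
fn _start(@builtin(global_invocation_id) globalId: vec3<u32>) {
  main(i32(globalId.x));
}`;

const shaderSource = assembleShader(userCode, entryPoint);
```

With this order, a translator that resolves names strictly top to bottom has already seen `main` by the time it reaches the call inside `_start`.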


Contributor

hujiajie commented Oct 9, 2022

Can we add a generator function foobar() in webgpu_program.ts, call the first step of foobar() in getMainHeaderString() to get a string like fn main(), and then, after getUserCode() returns, call the second step of foobar() to get the matching definition of _start()?
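
One way to read this suggestion is as a two-step generator, sketched below. The name `shaderHeaders`, its `useGlobalIndex` parameter, and the WGSL strings are all assumptions made for illustration; the PR itself ended up using a separate getStartHeaderString() function instead.

```typescript
// Hypothetical two-step generator: the first step yields the header of the
// user function, the second yields the matching entry point definition.
function* shaderHeaders(useGlobalIndex: boolean): Generator<string, void, void> {
  // Step 1: emitted before the user code.
  yield useGlobalIndex ? 'fn main(index: i32)' : 'fn main()';

  // Step 2: emitted after the user code; the call must match the header above.
  const call = useGlobalIndex ? 'main(i32(globalId.x));' : 'main();';
  yield `@compute @workgroup_size(64, 1, 1)
fn _start(@builtin(global_invocation_id) globalId: vec3<u32>) {
  ${call}
}`;
}

const headers = shaderHeaders(true);
const mainHeader = headers.next().value as string;       // first step
const userBody = `${mainHeader} {\n  // user code\n}`;
const startDefinition = headers.next().value as string;  // second step
const generated = [userBody, startDefinition].join('\n');
```

Keeping both steps in one generator means the header and the call site share the same `useGlobalIndex` decision and cannot drift apart.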

@haoyunfeix
Contributor Author

@qjia7 @xhcao @axinging @hujiajie @gyagp PTAL

let localRow = i32(localId.y);
let tileRow = ${isVectorA ? '0' : 'localRow * RowPerThread'};
let tileCol = i32(localId.x);

let globalRow = ${isVectorA ? '0' : 'i32(globalId.y) * RowPerThread'};
let globalCol = i32(globalId.x);
let batch = ${splitK ? '0' : 'i32(globalId.z)'};
- let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
+ let globalRowStart = i32(workGroupId.y) * ${tileAOuter};
Contributor

Shouldn't workgroup be a single word, i.e. workgroupId?

Contributor Author

Updated.

@@ -49,7 +50,7 @@ export class ScatterProgram implements WebGPUProgram {
this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${
this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`;
const stridesType = getCoordsDataType(strides.length);
- this.uniforms = `sliceDim : i32, strides: ${stridesType}, size: i32,`;
+ this.uniforms = `sliceDim : i32, strides: ${stridesType}, sizeUpdate: i32,`;
Contributor

sizeUpdate -> updateSize?

Contributor

Use updatesSize since the buffer name is updates.

Contributor Author

Updated.

Contributor

@qjia7 qjia7 left a comment

program.size does not seem like a good name, but we can do the renaming in a separate PR.

@@ -116,11 +112,12 @@ export class DepthwiseConv2DNCHWSharedProgram implements WebGPUProgram {
}

// Load one tile of W into local memory.
- var wIndex = localIndex;
+ var wIndex = localIndexI;
Contributor

How about var wIndex = i32(localIndex);? Then you can remove L89?

Contributor Author

Yes, updated.

@@ -144,6 +124,26 @@ export function getMainHeaderString(...params: string[]): string {
return snippet;
}

export function getStartHeaderString(isUseIndex: boolean): string {
Contributor

rename: isUseIndex -> useGlobalIndex?

Contributor Author

Updated.

var<private> globalId: vec3<u32>;
var<private> numWorkgroups: vec3<u32>;
var<private> workGroupId: vec3<u32>;
Contributor

workGroupId -> workgroupId

Contributor Author

Done.

@@ -198,6 +198,7 @@ function makeShader(
prefixSnippets.join('\n'),
getCoordsFromIndexSnippet(outputData.shape),
program.getUserCode(),
getStartHeaderString(program.size),
Contributor

Maybe just use getStartHeaderString(true) so that you don't need to change from_pixels_webgpu.ts?

Contributor

On second thought, I would prefer to pass whether the dispatch layout is flat as the parameter of getStartHeaderString, which is easier to understand. The same applies to L283.

const isFlatDispatchLayout =
    (program.dispatchLayout.y === null && program.dispatchLayout.z === null) ||
    (program.dispatchLayout.y.length === 0 && program.dispatchLayout.z.length === 0);
getStartHeaderString(isFlatDispatchLayout);
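
Written inline, the expression above would dereference a null value when only one of y and z is null, so a small helper is a natural home for the check. The sketch below is one possible shape for such a helper; the `DispatchLayout` interface and its nullable fields are assumptions for illustration, and this is not necessarily the function the PR ended up adding.

```typescript
// Hypothetical dispatch layout: each entry lists the output dimensions mapped
// to that dispatch axis. The y and z entries may be absent, null, or empty.
interface DispatchLayout {
  x: number[];
  y?: number[] | null;
  z?: number[] | null;
}

// Returns true when only the x axis carries any dimensions.
function isFlatDispatchLayout(layout: DispatchLayout): boolean {
  // Treat undefined, null, and [] uniformly as "axis unused".
  const isUnused = (dims?: number[] | null): boolean =>
      dims == null || dims.length === 0;
  return isUnused(layout.y) && isUnused(layout.z);
}
```

Checking each axis independently also covers mixed cases such as a null y with an empty z.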

Contributor Author

Added a function to do this, PTAL.

- return i32((workGroupID.z * numWorkgroups.x * numWorkgroups.y +
-     workGroupID.y * numWorkgroups.x + workGroupID.x) *
+ return i32((workGroupId.z * numWorkgroups.x * numWorkgroups.y +
+     workGroupId.y * numWorkgroups.x + workGroupId.x) *
(workGroupSizeX * workGroupSizeY * workGroupSizeZ) +
localInvocationIndex);
Contributor

Can localInvocationIndex also be replaced by localIndex?

Contributor Author

Yes, that may also reduce computing cost.
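
For reference, the flat-index arithmetic in the hunk above can be mirrored on the CPU as follows. This is only a sketch for checking the formula by hand; `flatGlobalIndex` and the `Vec3` type are hypothetical names, not part of tfjs.

```typescript
type Vec3 = [number, number, number];

// CPU mirror of the shader expression: linearize the workgroup id, scale by
// the number of invocations per workgroup, then add the local index.
function flatGlobalIndex(
    workgroupId: Vec3, numWorkgroups: Vec3, workgroupSize: Vec3,
    localIndex: number): number {
  const [wx, wy, wz] = workgroupId;
  const [nx, ny] = numWorkgroups;
  const flatWorkgroup = wz * nx * ny + wy * nx + wx;
  const invocationsPerWorkgroup =
      workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
  return flatWorkgroup * invocationsPerWorkgroup + localIndex;
}
```

For example, workgroup (1, 1, 0) in a 2 x 2 x 1 dispatch with a workgroup size of 4 x 1 x 1 has flat workgroup number 3, so its first invocation gets index 12.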

Fixes tensorflow#6842
Support shader translation libraries that do not implement module
scoping, such as naga.
Use main() to generate the user function and getStartHeaderString() to
generate the entry point function.
Contributor

@gyagp gyagp left a comment

LGTM. We need another PR to change all the occurrences of workgroup to a single word.

Contributor

@qjia7 qjia7 left a comment

LGTM with one nit.

}
}
${biasActivationSnippet(this.addBias, this.activation)}
if (coordsInBounds4D(coords, uniforms.outShape)) {
Contributor

Since you already check if (index < uniforms.size), the if (coordsInBounds4D(coords, uniforms.outShape)) check is no longer needed.
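
The redundancy can be checked on the CPU with a small sketch, assuming coords are obtained from index by row-major unflattening over the output shape. The functions below are hypothetical TypeScript stand-ins for the shader helpers, written only to illustrate the argument.

```typescript
type Shape4D = [number, number, number, number];

// Row-major unflattening of a flat index into 4D coordinates (assumed layout).
function coordsFromIndex(index: number, shape: Shape4D): Shape4D {
  const [, d1, d2, d3] = shape;
  const c3 = index % d3;
  const c2 = Math.floor(index / d3) % d2;
  const c1 = Math.floor(index / (d3 * d2)) % d1;
  const c0 = Math.floor(index / (d3 * d2 * d1));
  return [c0, c1, c2, c3];
}

// Stand-in for the shader-side bounds check.
function coordsInBounds4D(coords: Shape4D, shape: Shape4D): boolean {
  return coords.every((c, i) => c >= 0 && c < shape[i]);
}

// Verifies that every index below the total size maps to in-bounds coords.
function allFlatIndicesInBounds(shape: Shape4D): boolean {
  const size = shape[0] * shape[1] * shape[2] * shape[3];
  for (let index = 0; index < size; index++) {
    if (!coordsInBounds4D(coordsFromIndex(index, shape), shape)) {
      return false;
    }
  }
  return true;
}
```

Under that assumption, the guard on index already implies the coordinate check, so only indices at or beyond the size can produce out-of-bounds coordinates.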

Contributor Author

Thanks, updated.

Development

Successfully merging this pull request may close these issues.

LayersModel#predict() results in all zeros when using WebGPU backend in Deno
4 participants