[webgpu] Update shader to support non module-level scoping function #6918
Can we add a generator function
let localRow = i32(localId.y);
let tileRow = ${isVectorA ? '0' : 'localRow * RowPerThread'};
let tileCol = i32(localId.x);

let globalRow = ${isVectorA ? '0' : 'i32(globalId.y) * RowPerThread'};
let globalCol = i32(globalId.x);
let batch = ${splitK ? '0' : 'i32(globalId.z)'};
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
let globalRowStart = i32(workGroupId.y) * ${tileAOuter};
Shouldn't workgroup be a single word?
Updated.
@@ -49,7 +50,7 @@ export class ScatterProgram implements WebGPUProgram {
this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${
this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`;
const stridesType = getCoordsDataType(strides.length);
this.uniforms = `sliceDim : i32, strides: ${stridesType}, size: i32,`;
this.uniforms = `sliceDim : i32, strides: ${stridesType}, sizeUpdate: i32,`;
sizeUpdate -> updateSize?
Use updatesSize, since the buffer name is updates.
Updated.
program.size does not seem like a good name, but we can do the renaming in a separate PR.
@@ -116,11 +112,12 @@ export class DepthwiseConv2DNCHWSharedProgram implements WebGPUProgram {
}

// Load one tile of W into local memory.
var wIndex = localIndex;
var wIndex = localIndexI;
How about var wIndex = i32(localIndex);? Then you can remove L89?
Yes, updated.
@@ -144,6 +124,26 @@ export function getMainHeaderString(...params: string[]): string {
return snippet;
}

export function getStartHeaderString(isUseIndex: boolean): string {
Rename: isUseIndex -> useGlobalIndex?
Updated.
var<private> globalId: vec3<u32>;
var<private> numWorkgroups: vec3<u32>;
var<private> workGroupId: vec3<u32>;
workGroupId -> workgroupId
Done.
@@ -198,6 +198,7 @@ function makeShader(
prefixSnippets.join('\n'),
getCoordsFromIndexSnippet(outputData.shape),
program.getUserCode(),
getStartHeaderString(program.size),
Maybe just use getStartHeaderString(true) so that you don't need to change from_pixels_webgpu.ts?
On second thought, I'd prefer to pass whether the dispatch layout is flat as the parameter of getStartHeaderString, which is easier to understand. The same applies to L283.
const isFlatDispatchLayout = (program.dispatchLayout.y === null && program.dispatchLayout.z === null) || (program.dispatchLayout.y.length === 0 && program.dispatchLayout.z.length === 0);
getStartHeaderString(isFlatDispatchLayout)
Added a function to do this, PTAL.
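For readers following the thread, the flat-dispatch check discussed above might be factored into a helper along these lines. This is only a sketch: the DispatchLayout interface is a hypothetical, simplified stand-in rather than the library's actual type, and unlike the inline expression in the comment it treats any mix of missing and empty y/z as flat, which also avoids reading .length on a null value.

```typescript
// Hypothetical, simplified stand-in for a program's dispatch layout.
// Only the fields needed for the check are modeled here.
interface DispatchLayout {
  x: number[];
  y?: number[] | null;
  z?: number[] | null;
}

// Returns true when no output dimensions are mapped to the y or z dispatch
// axes. Assumption: a missing (null/undefined) axis and an empty axis are
// both treated as "unused".
function isFlatDispatchLayout(layout: DispatchLayout): boolean {
  const isUnused = (dims?: number[] | null): boolean =>
      dims == null || dims.length === 0;
  return isUnused(layout.y) && isUnused(layout.z);
}
```

A caller could then write something like getStartHeaderString(isFlatDispatchLayout(program.dispatchLayout)), assuming the program exposes a layout of this shape.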
return i32((workGroupID.z * numWorkgroups.x * numWorkgroups.y +
workGroupID.y * numWorkgroups.x + workGroupID.x) *
return i32((workGroupId.z * numWorkgroups.x * numWorkgroups.y +
workGroupId.y * numWorkgroups.x + workGroupId.x) *
(workGroupSizeX * workGroupSizeY * workGroupSizeZ) +
localInvocationIndex);
The localInvocationIndex can also be replaced by localIndex?
Yes, that may also reduce computing cost.
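The arithmetic in the snippet under review can be checked on the CPU. The sketch below mirrors the WGSL expression in TypeScript; the function name, the Vec3 alias, and the parameter names are illustrative assumptions, not part of the PR.

```typescript
type Vec3 = [number, number, number];

// Mirrors the shader expression: linearize the workgroup id across the
// dispatch grid, scale by the number of invocations per workgroup, then add
// the invocation's index within its workgroup.
function getGlobalIndex(
    workgroupId: Vec3, numWorkgroups: Vec3, workgroupSize: Vec3,
    localIndex: number): number {
  const [wx, wy, wz] = workgroupId;
  const [nx, ny] = numWorkgroups;
  const invocationsPerWorkgroup =
      workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
  const linearWorkgroup = wz * nx * ny + wy * nx + wx;
  return linearWorkgroup * invocationsPerWorkgroup + localIndex;
}
```

For example, with a 2x2x1 grid of 4x4x1 workgroups, workgroup (1, 1, 0) is the fourth workgroup (linear id 3), so its invocation with local index 3 maps to 3 * 16 + 3 = 51.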
FIXES tensorflow#6842. To support shader translation libraries that do not implement module-level scoping, such as naga,
use main() to generate the user function and getStartHeaderString() to generate the entry point function
And address comments
ca06402
to
c2c070d
LGTM. We need another PR to change all the occurrences of workgroup to a single word.
LGTM with one nit.
}
}
${biasActivationSnippet(this.addBias, this.activation)}
if (coordsInBounds4D(coords, uniforms.outShape)) {
Since you already used if (index < uniforms.size), if (coordsInBounds4D(coords, uniforms.outShape)) is not needed anymore.
Thanks, updated.
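The redundancy noted above can be illustrated on the CPU. The sketch assumes that uniforms.size equals the product of the output shape and that coords are derived from the flat index in row-major order; the helper names here are illustrative, not the library's.

```typescript
// Converts a flat row-major index into per-dimension coordinates.
// No modulo is applied to the leading dimension, so an out-of-range index
// yields an out-of-range leading coordinate.
function getCoordsFromIndex(index: number, shape: number[]): number[] {
  const coords: number[] = [];
  let rest = index;
  for (let d = 0; d < shape.length; d++) {
    const stride = shape.slice(d + 1).reduce((a, b) => a * b, 1);
    const c = Math.floor(rest / stride);
    coords.push(c);
    rest -= c * stride;
  }
  return coords;
}

// True when every coordinate lies inside the corresponding dimension.
function coordsInBounds(coords: number[], shape: number[]): boolean {
  return coords.every((c, d) => c >= 0 && c < shape[d]);
}

// Checks that every index below the total size maps to in-bounds coords.
function allIndicesInBounds(shape: number[]): boolean {
  const size = shape.reduce((a, b) => a * b, 1);
  for (let index = 0; index < size; index++) {
    if (!coordsInBounds(getCoordsFromIndex(index, shape), shape)) {
      return false;
    }
  }
  return true;
}
```

Under those assumptions, any index that passes the size guard already produces in-bounds coordinates, so a second bounds check on the coordinates adds nothing.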
FIXES #6842
Declare the user-defined function before the entry point function to support shader translation libraries that do not yet implement module-level scoping, such as naga (gfx-rs/naga#2075).
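The ordering described here can be sketched as a small shader-assembly function. This is an illustration under stated assumptions, not the PR's implementation: makeShaderSketch, the _start entry point name, and the single global_invocation_id parameter are hypothetical simplifications.

```typescript
// Assembles WGSL so that the user-defined main() is declared before the
// entry point that calls it. Translators that resolve names in declaration
// order (rather than at module scope) can then find main() when they reach
// the call site.
function makeShaderSketch(
    userCode: string, workgroupSize: [number, number, number]): string {
  const [x, y, z] = workgroupSize;
  const entryPoint = `
@compute @workgroup_size(${x}, ${y}, ${z})
fn _start(@builtin(global_invocation_id) globalId : vec3<u32>) {
  main(i32(globalId.x));
}`;
  // User code first, entry point last.
  return [userCode, entryPoint].join('\n');
}

const userCode = `
fn main(index : i32) {
  // user-defined kernel body goes here
}`;

const shader = makeShaderSketch(userCode, [64, 1, 1]);
```

In the generated string, the declaration of main appears ahead of _start, which is the property the description relies on.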