Skip to content

Commit

Permalink
[js/webgpu] set query type in onRunStart (#19202)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
`env.webgpu.profiling` is a global flag. It may change before each
session.run. So the best place is to update it in `onRunStart` event.
After this, we can directly check `this.queryType`'s value. Without this
pr, we need to make sure that `getCommandEncoder()` is called before
checking `this.queryType`. Otherwise, it may happen that
`pendingKernels`'s length is not equal to `pendingDispatchNumber`'s
length. See the two ugly workarounds
[1)](e630dbf#diff-006fc84d3997f96a29b8033bd2075d6a0a9509211bd5812a6b934fc74fedfd9dR267-R268)
and
[2)](e630dbf#diff-618fe297fbe7a1da586380163b8fd2627311ccc217640a3c5cdc9c17a33472c1R73-R80)
if we don't introduce `onRunStart`. Or we need to call `setQueryType` in
each kernel run.
  • Loading branch information
qjia7 authored and fs-eire committed Mar 15, 2024
1 parent a24273e commit c44d497
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 5 deletions.
4 changes: 4 additions & 0 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule {
jsepCreateDownloader:
(gpuBuffer: GPUBuffer, size: number,
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
/**
* [exported from js_internal_api.js] Called when InferenceSession.run started.
*/
jsepOnRunStart: () => void;
// #endregion
}

Expand Down
9 changes: 5 additions & 4 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ export class WebGpuBackend {

Object.defineProperty(this.env.webgpu, 'device', {value: this.device});

// init queryType, which is necessary for createKernel
// init queryType, which is necessary for InferenceSession.create
this.setQueryType();
}

Expand All @@ -223,8 +223,6 @@ export class WebGpuBackend {
if (!this.commandEncoder) {
this.commandEncoder = this.device.createCommandEncoder();

// refresh queryType, as sometimes we only need to enable query for a specific run
this.setQueryType();
if (this.queryType !== 'none' && typeof this.querySet === 'undefined') {
this.querySet = this.device.createQuerySet({
type: 'timestamp',
Expand Down Expand Up @@ -639,6 +637,7 @@ export class WebGpuBackend {
return createView(data.buffer, type);
};
}
// #endregion
writeTimestamp(index: number): void {
if (this.queryType !== 'inside-passes') {
return;
Expand All @@ -657,5 +656,7 @@ export class WebGpuBackend {
}
}
}
// #endregion
onRunStart(): void {
this.setQueryType();
}
}
2 changes: 1 addition & 1 deletion js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,8 @@ export const run = async(
}
}

wasm.jsepOnRunStart?.();
let errorCode: number;

if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
errorCode = await wasm._OrtRunWithBinding(
sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle);
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/wasm/js_internal_api.js
Original file line number Diff line number Diff line change
Expand Up @@ -186,4 +186,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
return backend['createDownloader'](gpuBuffer, size, type);
};
Module['jsepOnRunStart'] = () => {
return backend['onRunStart']();
};
};

0 comments on commit c44d497

Please sign in to comment.