This repository has been archived by the owner on Oct 22, 2024. It is now read-only.

Support latest spec #238

Merged · 7 commits · Mar 26, 2024
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "@webmachinelearning/webnn-polyfill",
"version": "0.1.9",
"version": "0.1.10",
"description": "WebNN API polyfill",
"main": "dist/webnn-polyfill.js",
"jsdelivr": "dist/webnn-polyfill.js",
10 changes: 1 addition & 9 deletions src/ml.ts
@@ -1,8 +1,7 @@
import {MLContext, MLContextOptions} from './nn/context';
- import * as utils from './nn/utils';

/**
- * [spec](https://webmachinelearning.github.io/webnn/#ml)
+ * [spec](https://webmachinelearning.github.io/webnn/#api-ml)
*/
export class ML {
/** @ignore */
@@ -20,11 +19,4 @@ export class ML {
resolve(context);
});
}

- createContextSync(options: MLContextOptions = {}): MLContext {
-   utils.assert(
-       typeof window === 'undefined' && typeof importScripts === 'function',
-       'createContextSync() should only be allowed in dedicated worker.');
-   return new MLContext(options);
- }
}
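
With `createContextSync()` removed from `ML`, the polyfill only exposes the promise-based entry point. A minimal sketch of the async-only flow, assuming the polyfill bundle has been loaded so that `navigator.ml` is populated (that attachment point is an assumption, not shown in this diff):

```ts
// Sketch only: context creation is now async-only.
// `navigator.ml` is assumed to be provided by the loaded polyfill bundle.
async function getContext(): Promise<MLContext> {
  // createContextSync() no longer exists; always await createContext().
  return await navigator.ml.createContext();
}
```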
17 changes: 2 additions & 15 deletions src/nn/context.ts
@@ -44,7 +44,7 @@ export interface MLContextOptions {
}

/**
- * [API spec](https://webmachinelearning.github.io/webnn/#mlcontext)
+ * [API spec](https://webmachinelearning.github.io/webnn/#api-mlcontext)
*/
export class MLContext {
private options_: MLContextOptions;
@@ -79,7 +79,7 @@ export class MLContext {
}

/**
- * [spec](https://webmachinelearning.github.io/webnn/#dom-mlcontext-compute)
+ * [spec](https://webmachinelearning.github.io/webnn/#api-mlcontext-compute)
*/
async compute(
graph: MLGraph,
@@ -89,19 +89,6 @@ export class MLContext {
return result;
}

- /**
-  * [spec](https://webmachinelearning.github.io/webnn/#dom-mlcontext-computesync)
-  */
- computeSync(
-     graph: MLGraph,
-     inputs: MLNamedArrayBufferViews,
-     outputs: MLNamedArrayBufferViews): void {
-   utils.assert(
-       typeof window === 'undefined' && typeof importScripts === 'function',
-       'computeSync() should only be allowed in dedicated worker.');
-   graph.computeSync(inputs, outputs);
- }

/** @internal */
// Expose tf interfance for setting backend.
get tf(): unknown {
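
With `computeSync()` gone, execution always goes through the async `compute()` kept above, which (per the graph.ts change below) resolves with the `{inputs, outputs}` views. A hedged usage sketch; the graph and the buffer names `'x'`/`'y'` are hypothetical:

```ts
// Sketch only: running an already-built graph through the async path.
// The input/output names and shapes are invented for illustration.
async function run(
    context: MLContext, graph: MLGraph): Promise<Float32Array> {
  const inputs: MLNamedArrayBufferViews = {x: new Float32Array([1, 2, 3, 4])};
  const outputs: MLNamedArrayBufferViews = {y: new Float32Array(4)};
  // compute() validates the views against each operand's dataType and
  // dimensions, fills the output views, and resolves with them.
  const result = await context.compute(graph, inputs, outputs);
  return result.outputs.y as Float32Array;
}
```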
46 changes: 4 additions & 42 deletions src/nn/graph.ts
@@ -146,7 +146,8 @@ export class MLGraph {
utils.assert(
utils.isTypedArray(resource),
'Only resource of ArrayBufferView type is supported.');
- utils.validateTypedArray(resource, inputOperand.desc.type, dimensions);
+ utils.validateTypedArray(
+     resource, inputOperand.desc.dataType, dimensions);
}
}

@@ -178,7 +179,7 @@ export class MLGraph {
for (const inputName of this.inputs_.keys()) {
const inputOperand = this.inputs_.get(inputName);
const typedArrayConstructor =
- utils.getTypedArray(inputOperand.desc.type);
+ utils.getTypedArray(inputOperand.desc.dataType);
const inputBuffer = new typedArrayConstructor(
utils.sizeFromDimensions(inputOperand.desc.dimensions));
inputs[inputName] = inputBuffer;
@@ -208,30 +209,13 @@ export class MLGraph {
const tensor = outputTensors[outputName] as tf.Tensor;
const desc = utils.createOperandDescriptorFromTensor(tensor);
const resource = outputs[outputName] ;
- utils.validateTypedArray(resource, desc.type, desc.dimensions);
+ utils.validateTypedArray(resource, desc.dataType, desc.dimensions);
resource.set(await tensor.data());
tf.dispose(tensor);
}
return {inputs, outputs};
}

- /** @internal */
- computeSync(
-     inputs: MLNamedArrayBufferViews,
-     outputs: MLNamedArrayBufferViews): void {
-   const outputTensors: tf.TensorContainerObject =
-       this.computeOutputTensors(inputs, outputs);
-   // Setup the outputs.
-   for (const outputName of Object.keys(outputTensors)) {
-     const tensor = outputTensors[outputName] as tf.Tensor;
-     const desc = utils.createOperandDescriptorFromTensor(tensor);
-     const resource = outputs[outputName] ;
-     utils.validateTypedArray(resource, desc.type, desc.dimensions);
-     resource.set(tensor.dataSync());
-     tf.dispose(tensor);
-   }
- }

/** @ignore */
constructor(outputs?: MLNamedOperands) {
utils.assert(outputs !== undefined, 'Invalid argument');
@@ -252,14 +236,6 @@ export class MLGraph {
return graph;
}

- /** @internal */
- static buildAndCompileSync(outputs?: MLNamedOperands): MLGraph {
-   const graph = new MLGraph(outputs);
-   graph.build();
-   graph.compileSync();
-   return graph;
- }

private build(): void {
const visitedOps: Set<Operation> = new Set();
for (const output of this.outputs_.values()) {
@@ -306,11 +282,6 @@ export class MLGraph {
await this.computeOnce();
}

- private compileSync(): void {
-   this.allocateConstants();
-   this.computeOnceSync();
- }

private allocateConstants(): void {
for (const constant of this.constants_) {
this.constantTensors_.set(
@@ -327,15 +298,6 @@ export class MLGraph {
}
}

- private computeOnceSync(): void {
-   const outputTensors = this.computeOutputTensors();
-   for (const outputName of Object.keys(outputTensors)) {
-     const tensor = outputTensors[outputName] as tf.Tensor;
-     tensor.dataSync();
-     tf.dispose(tensor);
-   }
- }

/** @ignore */
// For memory leak testing.
dispose(): void {
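
The remaining graph.ts edits are a mechanical rename of the operand descriptor member from `type` to `dataType`, matching the renamed spec field; the accepted values and the `dimensions` member are unchanged. A small sketch of a descriptor after this PR, assuming the polyfill's descriptor interface mirrors the spec's `MLOperandDescriptor` (the shape below is made up):

```ts
// Sketch only: descriptor after the rename; `type` became `dataType`.
const desc: MLOperandDescriptor = {
  dataType: 'float32',          // was `type: 'float32'` before this PR
  dimensions: [1, 3, 224, 224], // hypothetical example shape
};
```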