Skip to content
This repository has been archived by the owner on Oct 22, 2024. It is now read-only.

Support latest spec #238

Merged
merged 7 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@webmachinelearning/webnn-polyfill",
"version": "0.1.9",
"version": "0.1.10",
"description": "WebNN API polyfill",
"main": "dist/webnn-polyfill.js",
"jsdelivr": "dist/webnn-polyfill.js",
Expand Down
10 changes: 1 addition & 9 deletions src/ml.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import {MLContext, MLContextOptions} from './nn/context';
import * as utils from './nn/utils';

/**
* [spec](https://webmachinelearning.github.io/webnn/#ml)
* [spec](https://webmachinelearning.github.io/webnn/#api-ml)
*/
export class ML {
/** @ignore */
Expand All @@ -20,11 +19,4 @@ export class ML {
resolve(context);
});
}

createContextSync(options: MLContextOptions = {}): MLContext {
utils.assert(
typeof window === 'undefined' && typeof importScripts === 'function',
'createContextSync() should only be allowed in dedicated worker.');
return new MLContext(options);
}
}
17 changes: 2 additions & 15 deletions src/nn/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ export interface MLContextOptions {
}

/**
* [API spec](https://webmachinelearning.github.io/webnn/#mlcontext)
* [API spec](https://webmachinelearning.github.io/webnn/#api-mlcontext)
*/
export class MLContext {
private options_: MLContextOptions;
Expand Down Expand Up @@ -79,7 +79,7 @@ export class MLContext {
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlcontext-compute)
* [spec](https://webmachinelearning.github.io/webnn/#api-mlcontext-compute)
*/
async compute(
graph: MLGraph,
Expand All @@ -89,19 +89,6 @@ export class MLContext {
return result;
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlcontext-computesync)
*/
computeSync(
graph: MLGraph,
inputs: MLNamedArrayBufferViews,
outputs: MLNamedArrayBufferViews): void {
utils.assert(
typeof window === 'undefined' && typeof importScripts === 'function',
'computeSync() should only be allowed in dedicated worker.');
graph.computeSync(inputs, outputs);
}

/** @internal */
// Expose tf interfance for setting backend.
get tf(): unknown {
Expand Down
46 changes: 4 additions & 42 deletions src/nn/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ export class MLGraph {
utils.assert(
utils.isTypedArray(resource),
'Only resource of ArrayBufferView type is supported.');
utils.validateTypedArray(resource, inputOperand.desc.type, dimensions);
utils.validateTypedArray(
resource, inputOperand.desc.dataType, dimensions);
}
}

Expand Down Expand Up @@ -178,7 +179,7 @@ export class MLGraph {
for (const inputName of this.inputs_.keys()) {
const inputOperand = this.inputs_.get(inputName);
const typedArrayConstructor =
utils.getTypedArray(inputOperand.desc.type);
utils.getTypedArray(inputOperand.desc.dataType);
const inputBuffer = new typedArrayConstructor(
utils.sizeFromDimensions(inputOperand.desc.dimensions));
inputs[inputName] = inputBuffer;
Expand Down Expand Up @@ -208,30 +209,13 @@ export class MLGraph {
const tensor = outputTensors[outputName] as tf.Tensor;
const desc = utils.createOperandDescriptorFromTensor(tensor);
const resource = outputs[outputName] ;
utils.validateTypedArray(resource, desc.type, desc.dimensions);
utils.validateTypedArray(resource, desc.dataType, desc.dimensions);
resource.set(await tensor.data());
tf.dispose(tensor);
}
return {inputs, outputs};
}

/** @internal */
computeSync(
inputs: MLNamedArrayBufferViews,
outputs: MLNamedArrayBufferViews): void {
const outputTensors: tf.TensorContainerObject =
this.computeOutputTensors(inputs, outputs);
// Setup the outputs.
for (const outputName of Object.keys(outputTensors)) {
const tensor = outputTensors[outputName] as tf.Tensor;
const desc = utils.createOperandDescriptorFromTensor(tensor);
const resource = outputs[outputName] ;
utils.validateTypedArray(resource, desc.type, desc.dimensions);
resource.set(tensor.dataSync());
tf.dispose(tensor);
}
}

/** @ignore */
constructor(outputs?: MLNamedOperands) {
utils.assert(outputs !== undefined, 'Invalid argument');
Expand All @@ -252,14 +236,6 @@ export class MLGraph {
return graph;
}

/** @internal */
static buildAndCompileSync(outputs?: MLNamedOperands): MLGraph {
const graph = new MLGraph(outputs);
graph.build();
graph.compileSync();
return graph;
}

private build(): void {
const visitedOps: Set<Operation> = new Set();
for (const output of this.outputs_.values()) {
Expand Down Expand Up @@ -306,11 +282,6 @@ export class MLGraph {
await this.computeOnce();
}

private compileSync(): void {
this.allocateConstants();
this.computeOnceSync();
}

private allocateConstants(): void {
for (const constant of this.constants_) {
this.constantTensors_.set(
Expand All @@ -327,15 +298,6 @@ export class MLGraph {
}
}

private computeOnceSync(): void {
const outputTensors = this.computeOutputTensors();
for (const outputName of Object.keys(outputTensors)) {
const tensor = outputTensors[outputName] as tf.Tensor;
tensor.dataSync();
tf.dispose(tensor);
}
}

/** @ignore */
// For memory leak testing.
dispose(): void {
Expand Down
Loading
Loading