From 04823cd706ebd0cccf792d5ddb0fbf529420f138 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 31 Jan 2024 13:05:08 +0800 Subject: [PATCH] [js/webgpu] Use DataType as uniform cpu type (#19281) This saves turning data type to string by tensorDataTypeEnumToString. --- web/lib/wasm/jsep/backend-webgpu.ts | 18 ++++++----- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 7 +++-- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 8 +++-- .../ops/3rd-party/conv_backprop_webgpu.ts | 8 +++-- .../ops/3rd-party/matmul_packed_webgpu.ts | 7 +++-- web/lib/wasm/jsep/webgpu/ops/attention.ts | 30 +++++++++---------- web/lib/wasm/jsep/webgpu/ops/batch-norm.ts | 5 ++-- web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 2 +- web/lib/wasm/jsep/webgpu/ops/common.ts | 5 ++-- web/lib/wasm/jsep/webgpu/ops/concat.ts | 5 ++-- web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 13 ++++---- web/lib/wasm/jsep/webgpu/ops/cumsum.ts | 2 +- web/lib/wasm/jsep/webgpu/ops/einsum.ts | 7 +++-- web/lib/wasm/jsep/webgpu/ops/expand.ts | 2 +- web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 7 +++-- .../wasm/jsep/webgpu/ops/gather-elements.ts | 7 +++-- web/lib/wasm/jsep/webgpu/ops/gather.ts | 6 ++-- web/lib/wasm/jsep/webgpu/ops/gemm.ts | 6 ++-- web/lib/wasm/jsep/webgpu/ops/instance-norm.ts | 14 +++++---- web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 5 ++-- web/lib/wasm/jsep/webgpu/ops/matmul.ts | 5 ++-- .../jsep/webgpu/ops/multi-head-attentiion.ts | 7 +++-- web/lib/wasm/jsep/webgpu/ops/pad.ts | 7 ++--- web/lib/wasm/jsep/webgpu/ops/pool.ts | 20 +++++++------ web/lib/wasm/jsep/webgpu/ops/range.ts | 5 ++-- web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts | 2 +- web/lib/wasm/jsep/webgpu/ops/reduce.ts | 2 +- web/lib/wasm/jsep/webgpu/ops/resize.ts | 7 +++-- .../wasm/jsep/webgpu/ops/skip-layer-norm.ts | 8 ++--- web/lib/wasm/jsep/webgpu/ops/slice.ts | 6 ++-- web/lib/wasm/jsep/webgpu/ops/softmax.ts | 3 +- web/lib/wasm/jsep/webgpu/ops/split.ts | 5 ++-- web/lib/wasm/jsep/webgpu/ops/tile.ts | 2 +- web/lib/wasm/jsep/webgpu/ops/transpose.ts | 3 +- web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 2 +- web/lib/wasm/jsep/webgpu/ops/where.ts | 5 ++-- web/lib/wasm/jsep/webgpu/types.ts | 3 +- 37 files changed, 148 insertions(+), 108 deletions(-) diff --git a/web/lib/wasm/jsep/backend-webgpu.ts b/web/lib/wasm/jsep/backend-webgpu.ts index e1faecfc046e3..58efa795dba48 100644 --- a/web/lib/wasm/jsep/backend-webgpu.ts +++ b/web/lib/wasm/jsep/backend-webgpu.ts @@ -3,7 +3,7 @@ import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; -import {tensorDataTypeEnumToString} from '../wasm-common'; +import {DataType, tensorDataTypeEnumToString} from '../wasm-common'; import {configureLogger, LOG_DEBUG} from './log'; import {createView, TensorView} from './tensor-view'; @@ -453,10 +453,10 @@ export class WebGpuBackend { return; } // https://www.w3.org/TR/WGSL/#alignof - const sizeOfElement = v.type === 'float16' ? 2 : 4; + const sizeOfElement = v.type === DataType.float16 ? 2 : 4; let sizeOfVecOrMat; let baseAlignment; - if (v.type === 'float16') { + if (v.type === DataType.float16) { baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; } else { @@ -470,7 +470,7 @@ export class WebGpuBackend { // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte // length is N * SizeOf(mat2x4). - const elementPerVecOrMat = v.type === 'float16' ? 8 : 4; + const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4; currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : data.length * sizeOfElement; }); @@ -483,15 +483,17 @@ export class WebGpuBackend { programUniforms.forEach((v, i) => { const offset = offsets[i]; const data = typeof v.data === 'number' ? [v.data] : v.data; - if (v.type === 'int32') { + if (v.type === DataType.int32) { new Int32Array(arrayBuffer, offset, data.length).set(data); - } else if (v.type === 'uint32') { + } else if (v.type === DataType.uint32) { new Uint32Array(arrayBuffer, offset, data.length).set(data); - } else if (v.type === 'float16') { + } else if (v.type === DataType.float16) { // TODO: use Float16Array. new Uint16Array(arrayBuffer, offset, data.length).set(data); - } else { + } else if (v.type === DataType.float) { new Float32Array(arrayBuffer, offset, data.length).set(data); + } else { + throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`); } }); diff --git a/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index e5ca3204d4433..bc39bd94e3072 100644 --- a/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -19,6 +19,7 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; @@ -189,9 +190,9 @@ export const createConv2DMatMulProgramInfo = const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, - {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, - {type: 'int32', data: attributes.dilations} + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}, + {type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( diff --git a/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index e50733559dbe9..d18f8586dd071 100644 --- a/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -19,6 +19,7 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; @@ -197,9 +198,10 @@ export const createConv2DTransposeMatMulProgramInfo = ]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, - {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, - {type: 'int32', data: filterDims}, {type: 'int32', data: pads} + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: attributes.strides}, + {type: DataType.int32, data: attributes.dilations}, {type: DataType.int32, data: filterDims}, + {type: DataType.int32, data: pads} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( diff --git a/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 380efc8bc577a..ba6776e9d8c94 100644 --- a/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -17,6 +17,7 @@ // sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; @@ -264,9 +265,10 @@ export const createConvTranspose2DProgramInfo = const outputChannelsPerGroup = wShape[1]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims}, - {type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads}, - {type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup}, + {type: DataType.int32, data: outputSize}, {type: DataType.uint32, data: strides}, + {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations}, + {type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads}, + {type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims) ]; if (hasBias) { diff --git a/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 00c1f86d67419..d9a8d59f731de 100644 --- a/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -19,6 +19,7 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; @@ -447,8 +448,10 @@ export const createMatmulProgramInfo = const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; const bRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner} + ]; appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), diff --git a/web/lib/wasm/jsep/webgpu/ops/attention.ts b/web/lib/wasm/jsep/webgpu/ops/attention.ts index f07a21a343fa8..2cfe6356dd6e7 100644 --- a/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {tensorDataTypeEnumToString} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; @@ -241,9 +241,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView WG = Math.ceil(dComp / 8); } const elementsPerWG = Math.ceil(d / components / WG); - const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type']; - const programUniforms: ProgramUniform[] = - [{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}]; + const programUniforms: ProgramUniform[] = [ + {type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp}, + {type: DataType.uint32, data: elementsPerWG} + ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -336,11 +337,10 @@ const computeAttentionProbs = y: Math.ceil(parameters.sequenceLength / TILE_SIZE), z: parameters.batchSize * parameters.numHeads }; - const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type']; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize}, - {type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength}, - {type: tensorDataType, data: alpha} + {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize}, + {type: DataType.uint32, data: parameters.totalSequenceLength}, + {type: DataType.uint32, data: parameters.kvSequenceLength}, {type: q.dataType, data: alpha} ]; const inputs = [q, key]; @@ -430,9 +430,9 @@ const computeVxAttentionScore = z: params.batchSize * params.numHeads }; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength}, - {type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads}, - {type: 'uint32', data: params.vHiddenSize} + {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength}, + {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads}, + {type: DataType.uint32, data: params.vHiddenSize} ]; const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -526,10 +526,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { }; const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N}, - {type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize}, - {type: 'uint32', data: parameters.hiddenSize}, - {type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} + {type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize}, + {type: DataType.uint32, data: parameters.hiddenSize}, + {type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} ]; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git a/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts index 159b971636765..39b932375891b 100644 --- a/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +++ b/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -3,6 +3,7 @@ import {env} from 'onnxruntime-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -123,11 +124,11 @@ const createBatchNormInferenceProgramInfo = dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: useShapesUniforms ? [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(yShape), ] : [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ], }), }; diff --git a/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 8e144a36dc1b0..51f0c76ed8824 100644 --- a/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -179,7 +179,7 @@ const createBinaryOpProgramInfo = outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, programUniforms: [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, + {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, ...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(b.dims), ...createTensorShapeVariables(outputShape), diff --git a/web/lib/wasm/jsep/webgpu/ops/common.ts b/web/lib/wasm/jsep/webgpu/ops/common.ts index 1bedf31ee4e38..3de57d5ac7f7c 100644 --- a/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -259,8 +259,9 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = return typeof mappedType === 'string' ? mappedType : mappedType[1]; }; -export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => - dims.length === 0 ? [] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}]; +export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => dims.length === 0 ? + [] : + [{type: DataType.uint32, data: dims}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dims)}]; /** * A helper function to get maximum vector size for specified data length diff --git a/web/lib/wasm/jsep/webgpu/ops/concat.ts b/web/lib/wasm/jsep/webgpu/ops/concat.ts index daa326b1a34e2..b06c9fb496d15 100644 --- a/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -95,14 +96,14 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P let previousSum = 0; const inputDependencies: ProgramInputTensorInfoDependency[] = []; const inputRanks = []; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; for (let i = 0; i < inputs.length; ++i) { previousSum += inputs[i].dims[adjustedAxis]; sizeInConcatAxis[i] = previousSum; inputRanks.push(inputs[i].dims.length); inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); inputDependencies.push('rank'); - programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); + programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); } for (let i = 0; i < inputs.length; ++i) { programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); diff --git a/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index c0aaaa7ce134b..3c2c3cc4e046c 100644 --- a/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; @@ -28,9 +29,10 @@ export const createGroupedConvProgramInfo = const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations}, - {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, - {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.dilations}, + {type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]]}, + {type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]]}, + {type: DataType.uint32, data: outputChannelsPerGroup} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( @@ -127,8 +129,9 @@ export const createGroupedConvVectorizeProgramInfo = const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: [attributes.strides[0], attributes.strides[1]]}, - {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]} + {type: DataType.uint32, data: outputSize}, + {type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]]}, + {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( diff --git a/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/web/lib/wasm/jsep/webgpu/ops/cumsum.ts index 2ff909c30e62e..fb17202cd042f 100644 --- a/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +++ b/web/lib/wasm/jsep/webgpu/ops/cumsum.ts @@ -54,7 +54,7 @@ const createCumsumProgramInfo = outputs: [{dims: inputShape, dataType: inputType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: axis}, + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape) ] diff --git a/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/web/lib/wasm/jsep/webgpu/ops/einsum.ts index 9e1f58bbfa127..19a009c2eb79b 100644 --- a/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -272,8 +273,10 @@ const createEinsumProgramInfo = // filter is added to make sure that dimValue is never 0. const programUniformsInit: ProgramUniform[] = uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol)) - .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); - programUniformsInit.push({type: 'uint32', data: outputSize}); + .map( + (symbol) => + ({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); + programUniformsInit.push({type: DataType.uint32, data: outputSize}); const programUniforms: ProgramUniform[] = inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)]) .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); diff --git a/web/lib/wasm/jsep/webgpu/ops/expand.ts b/web/lib/wasm/jsep/webgpu/ops/expand.ts index dd18bd23a5912..f8fdb63160380 100644 --- a/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -85,7 +85,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => }; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(outputShape) ]; return { diff --git a/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index e1dc9a5e0ab7d..60067c014613b 100644 --- a/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {MAX_CLIP, MIN_CLIP} from '../../util'; import {ProgramUniform} from '../types'; @@ -36,9 +37,11 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, v export const appendActivationUniformsData = (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { if (attributes.activation === 'Clip') { - programUniform.push({type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + programUniform.push( + {type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!}); } else if (attributes.activation === 'HardSigmoid') { - programUniform.push({type: 'float32', data: attributes.alpha!}, {type: 'float32', data: attributes.beta!}); + programUniform.push( + {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!}); } }; diff --git a/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index a945954adcaa4..a2d4e3d28f7c5 100644 --- a/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -46,8 +47,10 @@ const createGatherElementsProgramInfo = const output = outputVariable('output', inputOutputDataType, outputShape.length); - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, + {type: DataType.uint32, data: axis} + ]; programUniforms.push(...createTensorShapeVariables(inputShape)); programUniforms.push(...createTensorShapeVariables(indicesShape)); programUniforms.push(...createTensorShapeVariables(outputShape)); diff --git a/web/lib/wasm/jsep/webgpu/ops/gather.ts b/web/lib/wasm/jsep/webgpu/ops/gather.ts index e2a62c6655c72..f2c71a9cd4188 100644 --- a/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -34,9 +34,9 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}, - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims), - ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, + {type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims), + ...createTensorShapeVariables(inputs[1].dims), ...createTensorShapeVariables(outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git a/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/web/lib/wasm/jsep/webgpu/ops/gemm.ts index a0d4021516bf7..76302e1af2e53 100644 --- a/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {GemmUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -45,8 +46,9 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt } const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, {type: 'uint32', data: K}, - {type: 'float32', data: attributes.alpha}, {type: 'float32', data: attributes.beta} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: K}, {type: DataType.float, data: attributes.alpha}, + {type: DataType.float, data: attributes.beta} ]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; if (inputs.length === 3) { diff --git a/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index a835c90bd5451..2096b898b5d40 100644 --- a/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -25,7 +25,7 @@ const createInstanceNormProgramInfo = const inputShape = [xShape[0], xShape[1], normPackedSize]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: normSize}, {type: 'uint32', data: normPackedSize}]; + [{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}]; programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)); const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -132,8 +132,9 @@ const computeMean = const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; const meanProgramUniforms: ProgramUniform[] = [ - {type: 'uint32', data: wgSize}, {type: 'uint32', data: h}, {type: 'uint32', data: Math.floor(c / components)}, - {type: 'uint32', data: Math.floor(h * c / components)} + {type: DataType.uint32, data: wgSize}, {type: DataType.uint32, data: h}, + {type: DataType.uint32, data: Math.floor(c / components)}, + {type: DataType.uint32, data: Math.floor(h * c / components)} ]; const getMeanShaderSource = (shaderHelper: ShaderHelper) => { @@ -182,8 +183,9 @@ const computeMean = {inputs: [input], outputs: [-1]})[0]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: unitsOfWork}, {type: 'uint32', data: h}, - {type: 'uint32', data: Math.floor(c / components)}, {type: 'uint32', data: Math.floor(WG * c / components)} + {type: DataType.uint32, data: unitsOfWork}, {type: DataType.uint32, data: h}, + {type: DataType.uint32, data: Math.floor(c / components)}, + {type: DataType.uint32, data: Math.floor(WG * c / components)} ]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -246,7 +248,7 @@ const createInstanceNormNHWCProgramInfo = const components = getMaxComponents(C); const outputSize = ShapeUtil.size(outputShape) / components; const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: H}, {type: 'uint32', data: Math.floor(C / components)}]; + [{type: DataType.uint32, data: H}, {type: DataType.uint32, data: Math.floor(C / components)}]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; // first compute mean const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); diff --git a/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 3c9f6ce71bb67..3f73d9cb7c5bc 100644 --- a/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -49,8 +49,9 @@ const createLayerNormProgramInfo = const components = getMaxComponents(normSize); const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: normCount}, {type: 'float32', data: normSize}, - {type: 'uint32', data: Math.floor(normSize / components)}, {type: 'float32', data: attributes.epsilon} + {type: DataType.uint32, data: normCount}, {type: DataType.float, data: normSize}, + {type: DataType.uint32, data: Math.floor(normSize / components)}, + {type: DataType.float, data: attributes.epsilon} ]; if (bias) { inputDependencies.push('type'); diff --git a/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 188b88b2510d8..b263451b99134 100644 --- a/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; @@ -29,8 +30,8 @@ export const createNaiveMatmulProgramInfo = const outputShapeInShader = [batchSize, M, N]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, - {type: 'uint32', data: K} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: K} ]; appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( diff --git a/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts index 6d22e3780efd9..5c5c849d99811 100644 --- a/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts +++ b/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -238,8 +239,10 @@ const addBiasTranspose = hiddenSize: number, biasOffset: number) => { const outputShape = [batchSize, sequenceLength, hiddenSize]; const outputSize = ShapeUtil.size(outputShape); - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'uint32', data: biasOffset}, {type: 'uint32', data: hiddenSize}]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: biasOffset}, + {type: DataType.uint32, data: hiddenSize} + ]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape); diff --git a/web/lib/wasm/jsep/webgpu/ops/pad.ts b/web/lib/wasm/jsep/webgpu/ops/pad.ts index c65b741e1105a..9f5e60773f080 100644 --- a/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; @@ -153,10 +153,9 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr const inputDims = inputs[0].dims; const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.pads}]; + [{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.pads}]; if (attributes.mode === 0) { - const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type']; - programUniforms.push({type: tensorDataType, data: attributes.value}); + programUniforms.push({type: inputs[0].dataType, data: attributes.value}); } programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape)); diff --git a/web/lib/wasm/jsep/webgpu/ops/pool.ts b/web/lib/wasm/jsep/webgpu/ops/pool.ts index 9e9b361c1af1c..70b8acc3146a0 100644 --- a/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -3,6 +3,7 @@ import {env} from 'onnxruntime-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -56,7 +57,8 @@ const getUniformAndPadInfo = ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: outputSize}, - programUniforms: [{type: 'uint32', data: reduceSize}] + programUniforms: [{type: DataType.uint32, data: reduceSize}] }), }; }; diff --git a/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/web/lib/wasm/jsep/webgpu/ops/reduce.ts index e8851ac546942..123eb38a1fb93 100644 --- a/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -101,7 +101,7 @@ export const createReduceProgramInfo = outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(outputShape) ] }), diff --git a/web/lib/wasm/jsep/webgpu/ops/resize.ts b/web/lib/wasm/jsep/webgpu/ops/resize.ts index f68526acc0e63..edfd856aeb850 100644 --- a/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -641,9 +642,9 @@ const createResizeProgramInfo = outputs: [{dims: outputShape, dataType: inputTensor.dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, - {type: 'float32', data: scales}, - {type: 'float32', data: roi}, + {type: DataType.uint32, data: outputSize}, + {type: DataType.float, data: scales}, + {type: DataType.float, data: roi}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(outputShape), ] diff --git a/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 509a722f4b52a..7be9ceec6bc65 100644 --- a/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -88,10 +88,10 @@ const createSkipLayerNormProgramInfo = const components = getMaxComponents(hiddenSize); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, - {type: 'uint32', data: components}, - {type: 'uint32', data: hiddenSize}, - {type: 'float32', data: attributes.epsilon}, + {type: DataType.uint32, data: outputSize}, + {type: DataType.uint32, data: components}, + {type: DataType.uint32, data: hiddenSize}, + {type: DataType.float, data: attributes.epsilon}, ]; const getShaderSource = (shaderHelper: ShaderHelper) => { const uniformsArray: UniformsArrayType = [ diff --git a/web/lib/wasm/jsep/webgpu/ops/slice.ts b/web/lib/wasm/jsep/webgpu/ops/slice.ts index 5212c6475dce0..6baa634f69f82 100644 --- a/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -155,9 +155,9 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice ]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs}, - {type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts}, + {type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps}, + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 324dc3af1a710..6f8bfa08d7b62 100644 --- a/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -5,6 +5,7 @@ // performance limitations when the reduced axis is long. Need to add // a optimized codepath for this. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -136,7 +137,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut getRunData: () => ({ outputs: [{dims: shape, dataType: input.dataType}], dispatchGroup: {x: rows}, - programUniforms: [{type: 'uint32', data: packedCols}] + programUniforms: [{type: DataType.uint32, data: packedCols}] }), getShaderSource, }; diff --git a/web/lib/wasm/jsep/webgpu/ops/split.ts b/web/lib/wasm/jsep/webgpu/ops/split.ts index b8582614fa214..0b703de2ffa1c 100644 --- a/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -72,7 +73,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; let previousSum = 0; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: inputSize}]; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: inputSize}]; for (let i = 0; i < attributes.numOutputs; i++) { previousSum += attributes.splitSizes[i]; sizeInSplitAxis[i] = previousSum; @@ -82,7 +83,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split outputs[i] = outputVariable(`output${i}`, dataType, outputShape); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } - programUniforms.push({type: 'uint32', data: sizeInSplitAxis}); + programUniforms.push({type: DataType.uint32, data: sizeInSplitAxis}); programUniforms.push(...createTensorShapeVariables(inputShape)); outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape))); const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/web/lib/wasm/jsep/webgpu/ops/tile.ts b/web/lib/wasm/jsep/webgpu/ops/tile.ts index 90a36a7bec2a9..b080767d2faac 100644 --- a/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -80,7 +80,7 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape) ], }), diff --git a/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/web/lib/wasm/jsep/webgpu/ops/transpose.ts index ab9a9ac8dd1f0..920da04398832 100644 --- a/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -65,7 +66,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape), ], diff --git a/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 76929efb32537..1accfac18b876 100644 --- a/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -53,7 +53,7 @@ const createElementwiseProgramInfo = dispatchGroup: {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}, programUniforms: [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, + {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, ], }) }); diff --git a/web/lib/wasm/jsep/webgpu/ops/where.ts b/web/lib/wasm/jsep/webgpu/ops/where.ts index 2ef9637bcda5e..51e8f56c229bd 100644 --- a/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -98,8 +98,9 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}, programUniforms: [ - {type: 'uint32', data: vecSize}, ...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA), - ...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC), + ...createTensorShapeVariables(dimsA), ...createTensorShapeVariables(dimsB), + ...createTensorShapeVariables(outputShape) ], }), }; diff --git a/web/lib/wasm/jsep/webgpu/types.ts b/web/lib/wasm/jsep/webgpu/types.ts index a34b6190b7244..ba5b84fcfe067 100644 --- a/web/lib/wasm/jsep/webgpu/types.ts +++ b/web/lib/wasm/jsep/webgpu/types.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../wasm-common'; import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; @@ -26,7 +27,7 @@ export interface TensorInfo { } export interface ProgramUniform { - type: 'int32'|'float16'|'float32'|'uint32'; + type: DataType; data: number|readonly number[]; }