Skip to content

Commit

Permalink
[js/webgpu] support customop FastGelu (#19392)
Browse files Browse the repository at this point in the history
### Description
Support WebGPU custom operator FastGelu.
  • Loading branch information
fs-eire committed Mar 15, 2024
1 parent c3e7768 commit 6239c2c
Show file tree
Hide file tree
Showing 10 changed files with 353 additions and 8 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Do not modify directly.*
| Erf | ai.onnx(9-12,13+) | |
| Exp | ai.onnx(6-12,13+) | |
| Expand | ai.onnx(8-12,13+) | |
| FastGelu | com.microsoft(1+) | |
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| FusedConv | com.microsoft(1+) | |
Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'
import {cumsum, parseCumSumAttributes} from './ops/cumsum';
import {einsum, parseEinsumAttributes} from './ops/einsum';
import {expand} from './ops/expand';
import {fastGelu} from './ops/fast-gelu';
import {gather, parseGatherAttributes} from './ops/gather';
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
import {gemm, parseGemmAttributes} from './ops/gemm';
Expand Down Expand Up @@ -72,6 +73,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Erf', [unaryOps.erf]],
['Exp', [unaryOps.exp]],
['Expand', [expand]],
['FastGelu', [fastGelu]],
['Floor', [unaryOps.floor]],
['FusedConv', [conv, parseConvAttributes]],
['Gather', [gather, parseGatherAttributes]],
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI
${shaderHelper.declareVariables(input, bias, output)}
${erfImpl(`vec4<${dataType}>`, dataType)}
${erfImpl(dataType)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
Expand Down
69 changes: 69 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';

import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType, UniformsArrayType, WORKGROUP_SIZE} from './common';
import * as unary from './unary-op';

// GELU is defined as Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)), where X may pre-add a bias.

const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): ProgramInfo => {
const dataType = inputTensors[0].dataType;
const outputSize = ShapeUtil.size(inputTensors[0].dims);
const biasLength = ShapeUtil.size(inputTensors[1].dims);
// can only use vec4 when bias length is multiple of 4
const useVec4 = biasLength % 4 === 0;
const getShaderSource = (shaderHelper: ShaderHelper): string => {
const x = inputVariable('x', dataType, [1], 4);
const bias = inputVariable('bias', dataType, [1], 4);
const y = outputVariable('y', dataType, [1], 4);

const uniforms: UniformsArrayType = [{name: 'output_vec_size', type: 'u32'}, {name: 'bias_size', type: 'u32'}];

const singleElementBias = (i: 0|1|2|3) => `
let bias${i}_offset: u32 = (global_idx * 4 + ${i}) % uniforms.bias_size;
let bias${i} = ${bias.getByOffset(`bias${i}_offset / 4`)}[bias${i}_offset % 4];`;
const biasGetExpression = useVec4 ?
`
let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};` :
`${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)}
let bias = ${x.type.value}(bias0, bias1, bias2, bias3);`;

return `${shaderHelper.registerUniforms(uniforms).declareVariables(x, bias, y)}
${unary.fastGeluImpl(tensorTypeToWsglValueType(dataType))}
${shaderHelper.mainStart(WORKGROUP_SIZE)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_vec_size')}
let x = ${x.getByOffset('global_idx')};
${biasGetExpression}
let x_in = x + bias;
${y.setByOffset('global_idx', unary.fastGeluExpression('x_in'))}
}`;
};

return {
name: 'FastGeluWithBias',
shaderCache: {hint: `${useVec4}`, inputDependencies: ['type', 'type']},
getShaderSource,
getRunData: (inputs) => ({
outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}],
programUniforms:
[{type: DataType.uint32, data: Math.ceil(outputSize / 4)}, {type: DataType.uint32, data: biasLength}],
dispatchGroup: {x: Math.ceil(outputSize / WORKGROUP_SIZE / 4)}
})
};
};

export const fastGelu = (context: ComputeContext): void => {
if (context.inputs.length < 2 || ShapeUtil.size(context.inputs[1].dims) === 0) {
unary.fastGelu(context);
} else {
context.compute(createFastGeluProgramInfo(context.inputs));
}
};
33 changes: 26 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,24 +178,23 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
attributes.cacheKey));
};

export const erfImpl = (dataType: string, varType = 'f32') => `
export const erfImpl = (varType = 'f32') => `
const r0: ${varType} = 0.3275911;
const r1: ${varType} = 0.254829592;
const r2: ${varType} = -0.284496736;
const r3: ${varType} = 1.421413741;
const r4: ${varType} = -1.453152027;
const r5: ${varType} = 1.061405429;
fn erf_vf32(v: ${dataType}) -> ${dataType} {
fn erf_vf32(v: vec4<${varType}>) -> vec4<${varType}> {
let absv = abs(v);
let x = 1.0 / (1.0 + r0 * absv);
return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
}`;

export const erf = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(dataType)));
};

export const exp = (context: ComputeContext): void => {
Expand All @@ -209,8 +208,7 @@ export const floor = (context: ComputeContext): void => {
export const gelu = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
erfImpl(`vec4<${dataType}>`, dataType)));
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(dataType)));
};

export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
Expand Down Expand Up @@ -278,10 +276,31 @@ export const tan = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan'));
};

export const tanhExpression = (a: string) => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`;

export const tanh = (context: ComputeContext): void => {
// TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', tanhExpression));
};

export const fastGeluImpl = (varType = 'f32') => `
const fast_gelu_a: ${varType} = 0.5;
const fast_gelu_b: ${varType} = 0.7978845608028654;
const fast_gelu_c: ${varType} = 0.035677408136300125;
fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> {
return ${tanhExpression('v')};
}
`;

export const fastGeluExpression = (x: string) =>
`(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`;

export const fastGelu = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`));
context.inputs[0], 'FastGelu', fastGeluExpression, fastGeluImpl(dataType), undefined,
context.inputs[0].dataType));
};

export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
Expand Down
211 changes: 211 additions & 0 deletions js/web/test/data/ops/fast-gelu.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
[
{
"name": "FastGelu test without bias",
"operator": "FastGelu",
"opset": { "domain": "com.microsoft", "version": 1 },
"cases": [
{
"name": "scalar",
"inputs": [
{
"data": [1],
"dims": [],
"type": "float32"
}
],
"outputs": [
{
"data": [0.841192],
"dims": [],
"type": "float32"
}
]
},
{
"name": "[2x4]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
"dims": [2, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.435415, 0.53057, 0.630432],
"dims": [2, 4],
"type": "float32"
}
]
},
{
"name": "[3x5]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
"dims": [3, 5],
"type": "float32"
}
],
"outputs": [
{
"data": [
0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.841192, 1.9546, 2.99636, 3.99993, 5, 0.950581,
1.0617, 1.17393, 1.28671, 1.39957
],
"dims": [3, 5],
"type": "float32"
}
]
}
]
},
{
"name": "FastGelu test with bias",
"operator": "FastGelu",
"opset": { "domain": "com.microsoft", "version": 1 },
"cases": [
{
"name": "scalar",
"inputs": [
{
"data": [1],
"dims": [],
"type": "float32"
},
{
"data": [0.5],
"dims": [],
"type": "float32"
}
],
"outputs": [
{
"data": [1.39957],
"dims": [],
"type": "float32"
}
]
},
{
"name": "[2x4], [4]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
"dims": [2, 4],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [4],
"type": "float32"
}
],
"outputs": [
{
"data": [0.950581, 2.16968, 3.29869, 4.39999, 1.39957, 2.58835, 3.69973, 4.8],
"dims": [2, 4],
"type": "float32"
}
]
},
{
"name": "[2x4], [3]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
"dims": [2, 4],
"type": "float32"
},
{
"data": [1, 2, 3],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [0.950581, 2.16968, 3.29869, 1.28671, 2.48492, 3.59959, 1.62411, 2.79331],
"dims": [2, 4],
"type": "float32"
}
]
},
{
"name": "[3x5], [2]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
"dims": [3, 5],
"type": "float32"
},
{
"data": [2, 3],
"dims": [2],
"type": "float32"
}
],
"outputs": [
{
"data": [
2.06267, 3.19813, 2.27567, 3.39909, 2.48492, 3.99993, 3.99993, 6, 6, 8, 3.09737, 4.19997, 3.29869,
4.39999, 3.49938
],
"dims": [3, 5],
"type": "float32"
}
]
},
{
"name": "[3x5], [7]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
"dims": [3, 5],
"type": "float32"
},
{
"data": [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7],
"dims": [7],
"type": "float32"
}
],
"outputs": [
{
"data": [
2.16968, 2.38072, 2.58835, 2.79331, 2.99636, 3.59959, 4.7, 5.1, 6.2, 7.3, 3.49938, 3.69973, 3.89989,
4.09996, 3.59959
],
"dims": [3, 5],
"type": "float32"
}
]
},
{
"name": "[4x4], [8]",
"inputs": [
{
"data": [0.8, -0.5, 0.0, 1, 1.3, 2.1, -0.2, 1.1, 0.5, 0.2, 0.3, -0.6, 3.1, 2.2, -1.1, 0.0],
"dims": [4, 4],
"type": "float32"
},
{
"data": [-0.5, 0.6, 1.2, 2.1, 1.3, -1, 0, 3.1],
"dims": [8],
"type": "float32"
}
],
"outputs": [
{
"data": [
0.185371, 0.0539828, 1.0617, 3.09737, 2.58835, 0.950581, -0.0841486, 4.19997, 0, 0.630432, 1.39957,
1.39957, 4.39999, 1.0617, -0.149419, 3.09737
],
"dims": [4, 4],
"type": "float32"
}
]
}
]
}
]
1 change: 1 addition & 0 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,7 @@
"equal.jsonc",
"exp.jsonc",
"expand.jsonc",
"fast-gelu.jsonc",
"floor.jsonc",
"gather-elements.jsonc",
"gemm.jsonc",
Expand Down
Loading

0 comments on commit 6239c2c

Please sign in to comment.