forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pack.ts
138 lines (119 loc) · 3.82 KB
/
pack.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {Tensor} from '../../../tensor';
import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, ProgramInfoLoader, TextureType} from '../types';
import {getCoordsDataType} from '../utils';
import {getChannels} from './packing-utils';
/**
 * Static metadata for the pack program: a single input named 'A', read from a texture of
 * type TextureType.unpackedReversed. It is spread into both the ProgramInfo and its loader.
 */
const packProgramMetadata = {
  name: 'pack',
  inputNames: ['A'],
  inputTypes: [TextureType.unpackedReversed],
};
/**
 * Build the WebGL ProgramInfo that converts `input` into a packed texture.
 *
 * The output keeps the input's dims and type (textureType = packed). For every output
 * coordinate `rc`, the generated shader writes vec4(0) when `rc` lies outside the input's
 * last-two-dimension bounds. Otherwise it runs the neighbourhood setup snippet and writes the
 * four sampled components.
 *
 * @param handler WebGL inference handler; only its GL context version is read (to pick GLSL syntax).
 * @param input   tensor to be packed.
 * @returns the full ProgramInfo, including the generated shader source.
 */
const createPackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfo => {
  const glsl = getGlsl(handler.session.backend.glContext.version);
  const inputShape = input.dims;
  // The output dims are input.dims (see the returned `output`), so input and output ranks coincide.
  const rank = inputShape.length;
  const coordsDataType = getCoordsDataType(rank);
  const channels = getChannels('rc', rank);
  const setup = getSetup(rank, channels, inputShape[rank - 2], inputShape[rank - 1]);

  // [last dim, second-to-last dim], padded with 1s when the tensor has fewer than two dims.
  const reversedInputWH = rank === 0 ? [1, 1] :
      rank === 1                     ? [inputShape[0], 1] :
                                       [inputShape[rank - 1], inputShape[rank - 2]];

  const outOfBoundsCondition = getOutOfBoundsCondition(rank, reversedInputWH, channels);
  const output = getOutput(inputShape, channels);

  // Joined line by line so the emitted GLSL is independent of this file's indentation.
  const shaderSource = [
    '',
    'void main() {',
    `${coordsDataType} rc = getOutputCoords();`,
    `if(${outOfBoundsCondition}) {`,
    `${glsl.output} = vec4(0);`,
    '} else {',
    setup,
    `${glsl.output} = vec4(${output});`,
    '}',
    '}',
    '',
  ].join('\n');

  return {
    ...packProgramMetadata,
    hasMain: true,
    output: {dims: input.dims, type: input.type, textureType: TextureType.packed},
    shaderSource,
  };
};
/**
 * Create a ProgramInfoLoader for the pack program. The static pack metadata is available
 * immediately. The full ProgramInfo, including the shader source, is only built when
 * `get()` is called.
 *
 * @param handler WebGL inference handler forwarded to the program builder.
 * @param input   tensor to be packed.
 */
export const createPackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => {
  const get = () => createPackProgramInfo(handler, input);
  return {...packProgramMetadata, get};
};
/**
 * Build a GLSL boolean expression that is TRUE when the current output coordinate lies
 * outside the input's width/height boundary. The pack shader writes vec4(0) in that case.
 *
 * @param rank  rank of the tensor being packed.
 * @param shape boundary values. For rank 1 only shape[0] is used, compared against `rc`.
 *              For rank >= 2, shape[0] is compared against dims[rank - 2] and shape[1]
 *              against dims[rank - 1].
 * @param dims  GLSL coordinate component names, one per dimension.
 * @returns a GLSL expression: 'false' for rank 0, otherwise one `>=` test per checked
 *          dimension, OR-ed together with '||'.
 */
function getOutOfBoundsCondition(rank: number, shape: readonly number[], dims: string[]): string {
  if (rank === 0) {
    return 'false';
  }
  if (rank === 1) {
    // Valid indices are 0..shape[0]-1, so shape[0] itself is already out of bounds.
    // This uses `>=`, like the rank >= 2 branch below; the previous `>` let rc == shape[0] through.
    return `rc >= ${shape[0]}`;
  }
  const conditions: string[] = [];
  for (let i = rank - 2; i < rank; i++) {
    conditions.push(`${dims[i]} >= ${shape[i - rank + 2]}`);
  }
  return conditions.join('||');
}
/**
 * Build the comma-separated contents of the `vec4(...)` written by the pack shader.
 *
 * - rank 0: the single scalar getA() followed by three zeros.
 * - rank 1: getA(rc) and getA(rc + 1), with the second replaced by 0. when rc + 1 is past
 *   shape[0]; the remaining two components are 0.
 * - rank >= 2: samples at the GLSL locals (r, c), (rp1, c), (r, cp1), (rp1, cp1), in that
 *   order. The leading rank-2 coordinates from `dims` are prepended to every getA call.
 *   Neighbours are replaced by 0. according to the GLSL flags rEdge / cEdge. Those flags and
 *   r/c/rp1/cp1 are not declared here; they must already be in scope in the shader.
 *
 * @param shape input tensor dims; only its length and, for rank 1, shape[0] are read.
 * @param dims  GLSL coordinate component names, one per dimension.
 * @returns GLSL source fragment.
 */
function getOutput(shape: readonly number[], dims: string[]): string {
  const rank = shape.length;
  switch (rank) {
    case 0:
      return 'getA(), 0, 0, 0';
    case 1:
      return [
        'getA(rc),',
        `rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1),`,
        '0, 0',
      ].join('\n');
    default: {
      const batchPrefix = Array.from({length: rank - 2}, (_, i) => `${dims[i]},`).join('');
      const sample = (row: string, col: string) => `getA(${batchPrefix}${row}, ${col})`;
      return [
        `${sample('r', 'c')},`,
        `rEdge ? 0. : ${sample('rp1', 'c')},`,
        `cEdge ? 0. : ${sample('r', 'cp1')},`,
        `rEdge || cEdge ? 0. : ${sample('rp1', 'cp1')}`,
      ].join('\n');
    }
  }
}
/**
 * Build the GLSL snippet that declares the neighbourhood coordinates and edge flags used by
 * the pack shader:
 *   r / c     = the last two output coordinates (dims[rank - 2], dims[rank - 1]),
 *   rp1 / cp1 = those coordinates plus one,
 *   rEdge     = rp1 >= cols,
 *   cEdge     = cp1 >= rows.
 * Returns '' for rank 0 and rank 1, where no two-dimensional neighbourhood is declared.
 *
 * NOTE(review): rEdge is tested against `cols` and cEdge against `rows`. That pairing is
 * crossed relative to the parameter names and is kept exactly as in the original. Confirm it
 * matches the argument order and texture-layout convention used by callers.
 *
 * @param rank tensor rank.
 * @param dims GLSL coordinate component names, one per dimension.
 * @param rows bound interpolated into the cEdge test.
 * @param cols bound interpolated into the rEdge test.
 * @returns GLSL declarations (leading and trailing newline included), or ''.
 */
function getSetup(rank: number, dims: string[], rows: number, cols: number): string {
  const needsNeighbourhood = rank !== 0 && rank !== 1;
  if (!needsNeighbourhood) {
    return '';
  }
  const rowCoord = dims[rank - 2];
  const colCoord = dims[rank - 1];
  return [
    '',
    `int r = ${rowCoord};`,
    `int c = ${colCoord};`,
    `int rp1 = ${rowCoord} + 1;`,
    `int cp1 = ${colCoord} + 1;`,
    `bool rEdge = rp1 >= ${cols};`,
    `bool cEdge = cp1 >= ${rows};`,
    '',
  ].join('\n');
}