Skip to content

Commit

Permalink
[JS/Web] Enabled 1d spacial input to GlobalAveragePool (#17973)
Browse files Browse the repository at this point in the history
### Description
Enable one-dim special  input to GlobalAveragePoll input



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Currently only 2D input is supported.
  • Loading branch information
satyajandhyala authored Oct 23, 2023
1 parent 780ee18 commit f3cfe08
Showing 1 changed file with 23 additions and 27 deletions.
50 changes: 23 additions & 27 deletions js/web/lib/wasm/jsep/webgpu/ops/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Pool ops requires 1 input.');
}
if (inputs[0].dims.length !== 4) {
throw new Error('Pool ops supports 2-D inputs only for now.');
if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) {
throw new Error('Pool ops supports 1-D or 2-D inputs only for now.');
}
};

const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => {
const isChannelsLast = attributes.format === 'NHWC';
const inputShapeAsChannelFirst =
isChannelsLast ? [input.dims[0], input.dims[3], input.dims[1], input.dims[2]] : input.dims.slice();
const inputShapeAsChannelFirst = input.dims.slice();
if (isChannelsLast) {
inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position.
}
const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations');
const kernelShape = attributes.kernelShape.slice();
const strides = attributes.strides.slice();
Expand All @@ -44,15 +46,9 @@ const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePo
} else {
Object.assign(newAttributes, {kernelShape, strides, pads, cacheKey: attributes.cacheKey});
}
return [
newAttributes,
isChannelsLast ?
[
outputShapeAsChannelFirst[0], outputShapeAsChannelFirst[2], outputShapeAsChannelFirst[3],
outputShapeAsChannelFirst[1]
] :
outputShapeAsChannelFirst
];
const outputShapeAsChannelLast = outputShapeAsChannelFirst.slice();
outputShapeAsChannelLast.push(outputShapeAsChannelLast.splice(1, 1)[0]);
return [newAttributes, isChannelsLast ? outputShapeAsChannelLast : outputShapeAsChannelFirst];
};

const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
Expand All @@ -76,22 +72,22 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool
let codeHEnd = '';
if (pwStart + pwEnd !== 0) {
codeW = `
for (var i: u32 = 0u; i < ${kw}u; i++) {
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) {
pad++;
continue;
}
let x_val = x[${x.indicesToOffset('xIndices')}];
${op1}
}`;
for (var i: u32 = 0u; i < ${kw}u; i++) {
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) {
pad++;
continue;
}
let x_val = x[${x.indicesToOffset('xIndices')}];
${op1}
}`;
} else {
codeW = `
for (var i: u32 = 0u; i < ${kw}u; i++) {
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
let x_val = x[${x.indicesToOffset('xIndices')}];
${op1}
}`;
for (var i: u32 = 0u; i < ${kw}u; i++) {
xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
let x_val = x[${x.indicesToOffset('xIndices')}];
${op1}
}`;
}

if (attributes.kernelShape.length === 2) {
Expand Down

0 comments on commit f3cfe08

Please sign in to comment.