Skip to content

Commit

Permalink
[core] Support explicit padding in tf.conv. (#3427)
Browse files Browse the repository at this point in the history
FEATURE
  • Loading branch information
annxingyuan authored Jun 12, 2020
1 parent 5db28e7 commit 9b50644
Show file tree
Hide file tree
Showing 13 changed files with 150 additions and 22 deletions.
8 changes: 5 additions & 3 deletions tfjs-core/src/kernel_names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
// tslint:disable: variable-name
// Unfortunately just enabling PascalCase per file (tslint:enable:
// allow-pascal-case) doesn't work.
import {ExplicitPadding} from '../src/ops/conv_util';

import {NamedTensorInfoMap, TensorInfo} from './kernel_registry';
import {DataType, PixelData} from './types';

Expand Down Expand Up @@ -104,7 +106,7 @@ export const Conv2D = 'Conv2D';
export type Conv2DInputs = Pick<NamedTensorInfoMap, 'x'|'filter'>;
export interface Conv2DAttrs {
strides: [number, number]|number;
pad: 'valid'|'same'|number;
pad: 'valid'|'same'|number|ExplicitPadding;
dataFormat: 'NHWC'|'NCHW';
dilations: [number, number]|number;
dimRoundingMode?: 'floor'|'round'|'ceil';
Expand All @@ -114,7 +116,7 @@ export const Conv2DBackpropFilter = 'Conv2DBackpropFilter';
export type Conv2DBackpropFilterInputs = Pick<NamedTensorInfoMap, 'x'|'dy'>;
export interface Conv2DBackpropFilterAttrs {
strides: [number, number]|number;
pad: 'valid'|'same'|number;
pad: 'valid'|'same'|number|ExplicitPadding;
dataFormat: 'NHWC'|'NCHW';
dimRoundingMode?: 'floor'|'round'|'ceil';
}
Expand All @@ -123,7 +125,7 @@ export const Conv2DBackpropInput = 'Conv2DBackpropInput';
export type Conv2DBackpropInputInputs = Pick<NamedTensorInfoMap, 'dy'|'filter'>;
export interface Conv2DBackpropInputAttrs {
strides: [number, number]|number;
pad: 'valid'|'same'|number;
pad: 'valid'|'same'|number|ExplicitPadding;
dataFormat: 'NHWC'|'NCHW';
dimRoundingMode?: 'floor'|'round'|'ceil';
}
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/ops/conv1d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ import {op} from './operation';
/** @doc {heading: 'Operations', subheading: 'Convolution'} */
function conv1d_<T extends Tensor2D|Tensor3D>(
x: T|TensorLike, filter: Tensor3D|TensorLike, stride: number,
pad: 'valid'|'same'|number, dataFormat: 'NWC'|'NCW' = 'NWC', dilation = 1,
pad: 'valid'|'same'|number|conv_util.ExplicitPadding,
dataFormat: 'NWC'|'NCW' = 'NWC', dilation = 1,
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
const $x = convertToTensor(x, 'x', 'conv1d');
const $filter = convertToTensor(filter, 'filter', 'conv1d');
Expand Down
20 changes: 20 additions & 0 deletions tfjs-core/src/ops/conv1d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,26 @@ import {expectArraysClose} from '../test_util';
import {Rank} from '../types';

describeWithFlags('conv1d', ALL_ENVS, () => {
it('conv1d input=2x2x1,d2=1,f=1,s=1,d=1,p=explicit', async () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad =
[[0, 0], [0, 0], [0, 0], [0, 0]] as tf.backend_util.ExplicitPadding;
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

const result = tf.conv1d(x, w, stride, pad, dataFormat, dilation);

expect(result.shape).toEqual([2, 2, 1]);
expectArraysClose(await result.data(), [3, 6, 9, 12]);
});

it('conv1d input=2x2x1,d2=1,f=1,s=1,d=1,p=same', async () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/ops/conv2d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ import {op} from './operation';
/** @doc {heading: 'Operations', subheading: 'Convolution'} */
function conv2d_<T extends Tensor3D|Tensor4D>(
x: T|TensorLike, filter: Tensor4D|TensorLike,
strides: [number, number]|number, pad: 'valid'|'same'|number,
strides: [number, number]|number,
pad: 'valid'|'same'|number|conv_util.ExplicitPadding,
dataFormat: 'NHWC'|'NCHW' = 'NHWC',
dilations: [number, number]|number = [1, 1],
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/ops/conv2d_backprop_filter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ import {op} from './operation';
*/
function conv2DBackpropFilter_<T extends Tensor3D|Tensor4D>(
x: T, dy: T, filterShape: [number, number, number, number],
strides: [number, number]|number, pad: 'valid'|'same'|number,
strides: [number, number]|number,
pad: 'valid'|'same'|number|conv_util.ExplicitPadding,
dataFormat: 'NHWC'|'NCHW' = 'NHWC',
dimRoundingMode?: 'floor'|'round'|'ceil'): Tensor4D {
let x4D = x as Tensor4D;
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/ops/conv2d_backprop_input.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ import {op} from './operation';
function conv2DBackpropInput_<T extends Tensor3D|Tensor4D>(
xShape: [number, number, number, number]|[number, number, number], dy: T,
filter: Tensor4D, strides: [number, number]|number,
pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW' = 'NHWC',
pad: 'valid'|'same'|number|conv_util.ExplicitPadding,
dataFormat: 'NHWC'|'NCHW' = 'NHWC',
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
util.assert(
xShape.length === dy.rank,
Expand Down
42 changes: 42 additions & 0 deletions tfjs-core/src/ops/conv2d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,26 @@ describeWithFlags('conv2d', ALL_ENVS, () => {
expectArraysClose(resultData, [133, 66, 200, 102, 108, 58, 56, 58]);
});

it('x=[4,2,1] f=[4,2,1,1] s=1 d=1 p=explicit', async () => {
const inputDepth = 1;
const outputDepth = 1;
const pad =
[[0, 0], [1, 2], [0, 1], [0, 0]] as tf.backend_util.ExplicitPadding;
const stride = 1;
const dataFormat = 'NHWC';
const dilation = 1;

const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [4, 2, inputDepth]);
const w =
tf.tensor4d([3, 1, 5, 0, 2, 7, 8, 9], [4, 2, inputDepth, outputDepth]);

const result = tf.conv2d(x, w, stride, pad, dataFormat, dilation);

const resultData = await result.data();
expect(result.shape).toEqual([4, 2, 1]);
expectArraysClose(resultData, [133, 66, 200, 102, 108, 58, 56, 58]);
});

it('x=[2,2,1] f=[2,2,1,1] s=1 d=1 p=same', async () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
Expand Down Expand Up @@ -208,6 +228,28 @@ describeWithFlags('conv2d', ALL_ENVS, () => {
expectArraysClose(resultData, [20, 26, 13, 12]);
});

it('x=[1,2,2] f=[2,2,1,1] s=1 d=1 p=explicit NCHW', async () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [inputDepth, 2, 2];
const outputDepth = 1;
const fSize = 2;
const pad =
[[0, 0], [0, 0], [0, 1], [0, 1]] as tf.backend_util.ExplicitPadding;
const stride = 1;
const dataFormat = 'NCHW';
const dilation = 1;

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w =
tf.tensor4d([3, 1, 5, 0], [fSize, fSize, inputDepth, outputDepth]);

const result = tf.conv2d(x, w, stride, pad, dataFormat, dilation);

const resultData = await result.data();
expect(result.shape).toEqual([1, 2, 2]);
expectArraysClose(resultData, [20, 26, 13, 12]);
});

it('x=[2,2,2] f=[2,2,2,1] s=1 d=1 p=same NCHW', async () => {
const inputDepth = 2;
const inputShape: [number, number, number] = [inputDepth, 2, 2];
Expand Down
40 changes: 32 additions & 8 deletions tfjs-core/src/ops/conv_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@

import * as util from '../util';

type PadType = 'SAME'|'VALID'|'NUMBER';
type PadType = 'SAME'|'VALID'|'NUMBER'|'EXPLICIT';

// For NHWC should be in the following form:
// [[0, 0], [pad_top,pad_bottom], [pad_left, pad_right], [0, 0]]
// For NCHW should be in the following form:
// [[0, 0], [0, 0], [pad_top,pad_bottom], [pad_left, pad_right]]
// Reference: https://www.tensorflow.org/api_docs/python/tf/nn/conv2d
export type ExplicitPadding =
[[number, number], [number, number], [number, number], [number, number]];

export type PadInfo = {
top: number,
Expand Down Expand Up @@ -126,8 +134,8 @@ export function computeConv2DInfo(
inShape: [number, number, number, number],
filterShape: [number, number, number, number],
strides: number|[number, number], dilations: number|[number, number],
pad: 'same'|'valid'|number, roundingMode?: 'floor'|'round'|'ceil',
depthwise = false,
pad: 'same'|'valid'|number|ExplicitPadding,
roundingMode?: 'floor'|'round'|'ceil', depthwise = false,
dataFormat: 'channelsFirst'|'channelsLast' = 'channelsLast'): Conv2DInfo {
let [batchSize, inHeight, inWidth, inChannels] = [-1, -1, -1, -1];
if (dataFormat === 'channelsLast') {
Expand All @@ -148,7 +156,7 @@ export function computeConv2DInfo(
getEffectiveFilterSize(filterWidth, dilationWidth);
const {padInfo, outHeight, outWidth} = getPadAndOutInfo(
pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight,
effectiveFilterWidth, roundingMode);
effectiveFilterWidth, roundingMode, dataFormat);

const outChannels = depthwise ? filterChannels * inChannels : filterChannels;

Expand Down Expand Up @@ -399,10 +407,12 @@ function getEffectiveFilterSize(filterSize: number, dilation: number) {
}

function getPadAndOutInfo(
pad: 'same'|'valid'|number, inHeight: number, inWidth: number,
strideHeight: number, strideWidth: number, filterHeight: number,
filterWidth: number, roundingMode?: 'floor'|'round'|'ceil'):
{padInfo: PadInfo, outHeight: number, outWidth: number} {
pad: 'same'|'valid'|number|ExplicitPadding, inHeight: number,
inWidth: number, strideHeight: number, strideWidth: number,
filterHeight: number, filterWidth: number,
roundingMode: 'floor'|'round'|'ceil',
dataFormat: 'channelsFirst'|
'channelsLast'): {padInfo: PadInfo, outHeight: number, outWidth: number} {
let padInfo: PadInfo;
let outHeight: number;
let outWidth: number;
Expand Down Expand Up @@ -430,6 +440,20 @@ function getPadAndOutInfo(
padInfo = {top: 0, bottom: 0, left: 0, right: 0, type: 'VALID'};
outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
} else if (typeof pad === 'object') {
const top = dataFormat === 'channelsLast' ? pad[1][0] : pad[2][0];
const bottom = dataFormat === 'channelsLast' ? pad[1][1] : pad[2][1];
const left = dataFormat === 'channelsLast' ? pad[2][0] : pad[3][0];
const right = dataFormat === 'channelsLast' ? pad[2][1] : pad[3][1];
const padType = (top === 0 && bottom === 0 && left === 0 && right === 0) ?
'VALID' :
'EXPLICIT';
padInfo = {top, bottom, left, right, type: padType};
outHeight = conditionalRound(
(inHeight - filterHeight + top + bottom) / strideHeight + 1,
roundingMode);
outWidth = conditionalRound(
(inWidth - filterWidth + left + right) / strideWidth + 1, roundingMode);
} else {
throw Error(`Unknown padding parameter: ${pad}`);
}
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/ops/fused_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ function fusedConv2d_<T extends Tensor3D|Tensor4D>({
x: T|TensorLike,
filter: Tensor4D|TensorLike,
strides: [number, number]|number,
pad: 'valid'|'same'|number,
pad: 'valid'|'same'|number|conv_util.ExplicitPadding,
dataFormat?: 'NHWC'|'NCHW',
dilations?: [number, number]|number,
dimRoundingMode?: 'floor'|'round'|'ceil',
Expand Down
21 changes: 21 additions & 0 deletions tfjs-core/src/ops/fused_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,27 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => {
expectArraysClose(await result.data(), expected);
});

it('basic with explicit padding', async () => {
const inputDepth = 1;
const outputDepth = 1;
const pad =
[[0, 0], [1, 2], [0, 1], [0, 0]] as tf.backend_util.ExplicitPadding;
const stride = 1;
const dataFormat = 'NHWC';
const dilation = 1;

const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [4, 2, inputDepth]);
const w =
tf.tensor4d([3, 1, 5, 0, 2, 7, 8, 9], [4, 2, inputDepth, outputDepth]);

const result = tf.fused.conv2d(
{x, filter: w, strides: stride, pad, dataFormat, dilations: dilation});

const resultData = await result.data();
expect(result.shape).toEqual([4, 2, 1]);
expectArraysClose(resultData, [133, 66, 200, 102, 108, 58, 56, 58]);
});

it('basic with elu', async () => {
const inputDepth = 2;
const inShape: [number, number, number, number] = [2, 2, 2, inputDepth];
Expand Down
11 changes: 6 additions & 5 deletions tfjs-core/src/public/chained_ops/conv1d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,23 @@
* =============================================================================
*/
import {conv1d} from '../../ops/conv1d';
import {ExplicitPadding} from '../../ops/conv_util';
import {Tensor, Tensor2D, Tensor3D} from '../../tensor';
import {Rank, TensorLike3D} from '../../types';

declare module '../../tensor' {
interface Tensor<R extends Rank = Rank> {
conv1d<T extends Tensor2D|Tensor3D>(
filter: Tensor3D|TensorLike3D, stride: number,
pad: 'valid'|'same'|number, dataFormat?: 'NWC'|'NCW', dilation?: number,
dimRoundingMode?: 'floor'|'round'|'ceil'): T;
pad: 'valid'|'same'|number|ExplicitPadding, dataFormat?: 'NWC'|'NCW',
dilation?: number, dimRoundingMode?: 'floor'|'round'|'ceil'): T;
}
}

Tensor.prototype.conv1d = function<T extends Tensor2D|Tensor3D>(
filter: Tensor3D|TensorLike3D, stride: number, pad: 'valid'|'same'|number,
dataFormat?: 'NWC'|'NCW', dilation?: number,
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
filter: Tensor3D|TensorLike3D, stride: number,
pad: 'valid'|'same'|number|ExplicitPadding, dataFormat?: 'NWC'|'NCW',
dilation?: number, dimRoundingMode?: 'floor'|'round'|'ceil'): T {
this.throwIfDisposed();
return conv1d(
this, filter, stride, pad, dataFormat, dilation,
Expand Down
15 changes: 14 additions & 1 deletion tfjs-node/src/nodejs_kernel_backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,8 @@ export class NodeJSKernelBackend extends KernelBackend {

conv2d(x: Tensor4D, filter: Tensor4D, convInfo: backend_util.Conv2DInfo):
Tensor4D {
if (convInfo.padInfo.type !== 'VALID' && convInfo.padInfo.type !== 'SAME') {
if (convInfo.padInfo.type !== 'VALID' && convInfo.padInfo.type !== 'SAME' &&
convInfo.padInfo.type !== 'EXPLICIT') {
throw new Error(
`TF Backend supports only 'valid' and 'same' padding ` +
`while padding was ${convInfo.padInfo.type}`);
Expand All @@ -867,6 +868,18 @@ export class NodeJSKernelBackend extends KernelBackend {
{name: 'use_cudnn_on_gpu', type: this.binding.TF_ATTR_BOOL, value: true},
{name: 'dilations', type: this.binding.TF_ATTR_INT, value: dilations},
];
if (padding === 'EXPLICIT') {
const padValue = [
convInfo.padInfo.top, convInfo.padInfo.bottom, convInfo.padInfo.left,
convInfo.padInfo.right
];
opAttrs.push({
name: 'explicit_paddings',
type: this.binding.TF_ATTR_INT,
value: dataFormat === 'NHWC' ? [0, 0, ...padValue, 0, 0] :
[0, 0, 0, 0, ...padValue]
});
}
return this.executeSingleOutput('Conv2D', opAttrs, [x, filter]) as Tensor4D;
}

Expand Down
1 change: 1 addition & 0 deletions tfjs-node/src/run_tests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ const IGNORE_LIST: string[] = [
'diag test-tensorflow {} bool',
// See https://github.com/tensorflow/tfjs/issues/1891
'conv2d test-tensorflow {} x=[2,1,2,2] f=[1,1,1,1] s=1 d=1 p=0 NCHW',
'conv2d test-tensorflow {} x=[1,2,2] f=[2,2,1,1] s=1 d=1 p=explicit NCHW',
'conv2d test-tensorflow {} x=[1,2,2] f=[2,2,1,1] s=1 d=1 p=same NCHW',
'conv2d test-tensorflow {} x=[2,2,2] f=[2,2,2,1] s=1 d=1 p=same NCHW',
'conv2d test-tensorflow {} x=[2,1,2,2] f=[2,2,1,1] s=1 d=1 p=same NCHW',
Expand Down

0 comments on commit 9b50644

Please sign in to comment.