Modularize batchToSpaceND. (#3272)
FEATURE
annxingyuan authored May 14, 2020
1 parent 602111a commit df06f39
Showing 14 changed files with 403 additions and 284 deletions.
29 changes: 29 additions & 0 deletions tfjs-core/src/gradients/BatchToSpaceND_grad.ts
@@ -0,0 +1,29 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {BatchToSpaceND, BatchToSpaceNDAttrs} from '../kernel_names';
import {GradConfig, NamedAttrMap} from '../kernel_registry';
import {spaceToBatchND} from '../ops/array_ops';
import {Tensor} from '../tensor';

export const batchToSpaceNDGradConfig: GradConfig = {
kernelName: BatchToSpaceND,
gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
const {blockShape, crops} = attrs as {} as BatchToSpaceNDAttrs;
return {x: () => spaceToBatchND(dy, blockShape, crops)};
}
};
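
The whole grad config boils down to one identity: the gradient of batchToSpaceND with respect to `x` is spaceToBatchND applied to the upstream gradient `dy`, with the same `blockShape` and with `crops` reused as paddings. A minimal check of that identity through the public API (a sketch that assumes only `tf.grad`, `tf.batchToSpaceND` and `tf.spaceToBatchND`; the concrete values are illustrative):

```ts
import * as tf from '@tensorflow/tfjs-core';

const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];

// Upstream gradient shaped like the op's output ([1, 2, 2, 1]).
const dy = tf.tensor4d([10, 20, 30, 40], [1, 2, 2, 1]);

// Gradient flowing back through batchToSpaceND...
const dx = tf.grad(
    (t: tf.Tensor4D) => tf.batchToSpaceND(t, blockShape, crops))(x, dy);

// ...which the grad config defines as spaceToBatchND applied to dy.
const expected = tf.spaceToBatchND(dy, blockShape, crops);

dx.print();        // shape [4, 1, 1, 1], same values...
expected.print();  // ...as this
```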
7 changes: 7 additions & 0 deletions tfjs-core/src/kernel_names.ts
@@ -34,6 +34,13 @@ export interface BatchMatMulAttrs {
transposeB: boolean;
}

export const BatchToSpaceND = 'BatchToSpaceND';
export type BatchToSpaceNDInputs = Pick<NamedTensorInfoMap, 'x'>;
export interface BatchToSpaceNDAttrs {
blockShape: number[];
crops: number[][];
}

export type BinaryInputs = Pick<NamedTensorInfoMap, 'a'|'b'>;

export const BroadcastTo = 'BroadcastTo';
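
The kernel name string, input map type, and attrs interface above are what the other modularized pieces key off of: the grad config registers against the BatchToSpaceND name, and the new standalone op (one of the 14 changed files, not shown in this excerpt) passes the same name and attrs to the engine. As a rough sketch only — the real ops/batch_to_space_nd.ts added by this commit is not reproduced here, and the runKernelFunc argument order is assumed from other ops of this era — the op side looks roughly like:

```ts
import {ENGINE} from '../engine';
import {BatchToSpaceND, BatchToSpaceNDAttrs} from '../kernel_names';
import {NamedAttrMap} from '../kernel_registry';
import {Tensor} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';

function batchToSpaceND_<T extends Tensor>(
    x: T|TensorLike, blockShape: number[], crops: number[][]): T {
  const $x = convertToTensor(x, 'x', 'batchToSpaceND');
  const attrs: BatchToSpaceNDAttrs = {blockShape, crops};

  // Same forward function as the deleted array_ops.ts version; the inline
  // gradient is gone because BatchToSpaceND_grad.ts now supplies it, looked
  // up by the BatchToSpaceND kernel name.
  return ENGINE.runKernelFunc(
             backend => backend.batchToSpaceND($x, blockShape, crops), {x: $x},
             null /* grad */, BatchToSpaceND, attrs as {} as NamedAttrMap) as T;
}
```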
80 changes: 1 addition & 79 deletions tfjs-core/src/ops/array_ops.ts
@@ -159,83 +159,6 @@ function stack_<T extends Tensor>(
return concat(expandedTensors, axis);
}

/**
* This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
* shape `blockShape + [batch]`, interleaves these blocks back into the grid
* defined by the spatial dimensions `[1, ..., M]`, to obtain a result with
* the same rank as the input. The spatial dimensions of this intermediate
* result are then optionally cropped according to `crops` to produce the
* output. This is the reverse of `tf.spaceToBatchND`. See below for a precise
* description.
*
* ```js
* const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
* const blockShape = [2, 2];
* const crops = [[0, 0], [0, 0]];
*
* x.batchToSpaceND(blockShape, crops).print();
* ```
*
* @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
* remainingShape`, where spatialShape has `M` dimensions.
* @param blockShape A 1-D array. Must have shape `[M]`, all values must
* be >= 1.
* @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0.
* `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input
* dimension `i + 1`, which corresponds to spatial dimension `i`. It is required
* that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]`
*
* This operation is equivalent to the following steps:
*
* 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ...,
* blockShape[M-1], batch / prod(blockShape), x.shape[1], ...,
* x.shape[N-1]]`
*
 * 2. Permute dimensions of `reshaped` to produce `permuted` of shape `[batch /
 * prod(blockShape), x.shape[1], blockShape[0], ..., x.shape[M],
 * blockShape[M-1], x.shape[M+1], ..., x.shape[N-1]]`
 *
 * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch /
 * prod(blockShape), x.shape[1] * blockShape[0], ..., x.shape[M] *
 * blockShape[M-1], x.shape[M+1], ..., x.shape[N-1]]`
 *
 * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted`
 * according to `crops` to produce the output of shape: `[batch /
 * prod(blockShape), x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1],
 * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] -
 * crops[M-1,1], x.shape[M+1], ..., x.shape[N-1]]`
*/
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function batchToSpaceND_<T extends Tensor>(
x: T|TensorLike, blockShape: number[], crops: number[][]): T {
const $x = convertToTensor(x, 'x', 'batchToSpaceND');
const prod = blockShape.reduce((a, b) => a * b);

util.assert(
$x.rank >= 1 + blockShape.length,
() => `input rank is ${$x.rank} but should be > than blockShape.length ${
blockShape.length}`);

util.assert(
crops.length === blockShape.length,
() => `crops.length is ${
crops.length} but should be equal to blockShape.length ${
blockShape.length}`);

util.assert(
$x.shape[0] % prod === 0,
() => `input tensor batch is ${
$x.shape[0]} but is not divisible by the product of ` +
`the elements of blockShape ${blockShape.join(' * ')} === ${prod}`);

const grad = (dy: T) => {
return {$x: () => dy.spaceToBatchND(blockShape, crops)};
};

return ENGINE.runKernelFunc(
backend => backend.batchToSpaceND($x, blockShape, crops), {$x}, grad);
}

/**
* This operation divides "spatial" dimensions `[1, ..., M]` of the input into
* a grid of blocks of shape `blockShape`, and interleaves these blocks with
@@ -314,7 +237,7 @@ function spaceToBatchND_<T extends Tensor>(
blockShape.toString()}`);

const grad = (dy: T) => {
return {$x: () => dy.batchToSpaceND(blockShape, paddings)};
return {$x: () => dy.batchToSpaceND(blockShape, paddings) as T};
};

return ENGINE.runKernelFunc(
@@ -620,7 +543,6 @@ export {
print // Not wrapped in op() since no need to increase stack trace.
};

export const batchToSpaceND = op({batchToSpaceND_});
export const cast = op({cast_});
export const cumsum = op({cumsum_});
export const depthToSpace = op({depthToSpace_});
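
The four numbered steps in the doc comment removed above can be reproduced by hand for its simplest case (x.shape = [4, 1, 1, 1], blockShape = [2, 2], zero crops). This is a worked sketch of that description using plain reshape/transpose calls, not the kernel implementation; the intermediate names follow the doc comment:

```ts
import * as tf from '@tensorflow/tfjs-core';

const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);  // N = 4, M = 2

// Step 1: [blockShape[0], blockShape[1], batch / prod(blockShape), 1, 1, 1].
const reshaped = tf.reshape(x, [2, 2, 1, 1, 1, 1]);

// Step 2: interleave the block dimensions with the spatial dimensions.
const permuted = tf.transpose(reshaped, [2, 3, 0, 4, 1, 5]);

// Step 3: fold each block dimension into its spatial dimension.
const reshapedPermuted = tf.reshape(permuted, [1, 2, 2, 1]);

// Step 4: cropping is a no-op here because all crops are zero.
reshapedPermuted.print();  // same result as tf.batchToSpaceND(x, [2, 2], [[0, 0], [0, 0]])
```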
194 changes: 0 additions & 194 deletions tfjs-core/src/ops/array_ops_test.ts
@@ -2852,200 +2852,6 @@ describeWithFlags('cumsum', ALL_ENVS, () => {
});
});

describeWithFlags('batchToSpaceND', ALL_ENVS, () => {
it('tensor4d, input shape=[4, 1, 1, 1], blockShape=[2, 2]', async () => {
const t = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([1, 2, 2, 1]);
expectArraysClose(await res.data(), [1, 2, 3, 4]);
});

it('tensor4d, input shape=[4, 1, 1, 3], blockShape=[2, 2]', async () => {
const t =
tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [4, 1, 1, 3]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([1, 2, 2, 3]);
expectArraysClose(
await res.data(), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
});

it('tensor4d, input shape=[4, 2, 2, 1], blockShape=[2, 2]', async () => {
const t = tf.tensor4d(
[1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16], [4, 2, 2, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([1, 4, 4, 1]);
expectArraysClose(
await res.data(),
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
});

it('tensor4d, input shape=[8, 1, 3, 1], blockShape=[2, 2]', async () => {
const t = tf.tensor4d(
[
0, 1, 3, 0, 9, 11, 0, 2, 4, 0, 10, 12,
0, 5, 7, 0, 13, 15, 0, 6, 8, 0, 14, 16
],
[8, 1, 3, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [2, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([2, 2, 4, 1]);
expectArraysClose(
await res.data(),
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
});

it('tensor2d, blockShape [1]', async () => {
const t = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const blockShape = [2];
const crops = [[0, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([1, 4]);
expectArraysClose(await res.data(), [1, 3, 2, 4]);
});

it('tensor3d, blockShape [1]', async () => {
const t = tf.tensor(
[
-61, 37, -68, 72, 31, 62, 0, -13, 28, 54, 96,
44, -55, -64, -88, -94, 65, -32, -96, -73, -2, -77,
-14, 47, 33, 15, 70, 20, 75, 28, 84, -13
],
[8, 2, 2]);
const blockShape = [2];
const crops = [[0, 2]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([4, 2, 2]);
expectArraysClose(
await res.data(),
[-61, 37, 65, -32, 31, 62, -2, -77, 28, 54, 33, 15, -55, -64, 75, 28]);
});

it('tensor3d, blockShape [2]', async () => {
const t = tf.tensor(
[
-61, 37, -68, 72, 31, 62, 0, -13, 28, 54, 96,
44, -55, -64, -88, -94, 65, -32, -96, -73, -2, -77,
-14, 47, 33, 15, 70, 20, 75, 28, 84, -13
],
[8, 2, 2]);
const blockShape = [2, 2];
const crops = [[2, 0], [2, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([2, 2, 2]);
expectArraysClose(await res.data(), [72, 44, -73, 20, -13, -94, 47, -13]);
});

it('throws when blockShape equal to input rank', () => {
const t = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
const blockShape = [2, 2, 2, 2];
const crops = [[0, 0], [0, 0], [0, 0], [0, 0]];

expect(() => tf.batchToSpaceND(t, blockShape, crops))
.toThrowError(
`input rank is ${t.rank} but should be > than blockShape.length ${
blockShape.length}`);
});

it('throws when crops row dimension not equal to blockshape', () => {
const t = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
const blockShape = [2, 2];
const crops = [[0, 0]];

expect(() => tf.batchToSpaceND(t, blockShape, crops))
.toThrowError(`crops.length is ${
crops.length} but should be equal to blockShape.length ${
blockShape.length}`);
});

it('throws when input tensor batch not divisible by prod(blockShape)', () => {
const t = tf.tensor4d([1, 2, 3, 4, 5], [5, 1, 1, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];
const prod = blockShape.reduce((a, b) => a * b);

expect(() => tf.batchToSpaceND(t, blockShape, crops))
.toThrowError(
`input tensor batch is ${t.shape[0]} but is not divisible by the ` +
`product of the elements of blockShape ${
blockShape.join(' * ')} === ${prod}`);
});

it('accepts a tensor-like object', async () => {
const t = [[[[1]]], [[[2]]], [[[3]]], [[[4]]]];
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];

const res = tf.batchToSpaceND(t, blockShape, crops);
expect(res.shape).toEqual([1, 2, 2, 1]);
expectArraysClose(await res.data(), [1, 2, 3, 4]);
});

it('gradients, input shape=[4, 2, 2], block shape=[2]', async () => {
const t = tf.tensor(
[-61, 37, -68, 72, 31, 62, 0, -13, 28, 54, 96, 44, -55, -64, -88, -94],
[4, 2, 2]);
const blockShape = [2];
const crops = [[0, 2]];
const dy = tf.tensor([.01, .02, .03, .04, .05, .06, .07, .08], [2, 2, 2]);

const gradient =
tf.grad(t => tf.batchToSpaceND(t, blockShape, crops))(t, dy);
expect(gradient.shape).toEqual([4, 2, 2]);
expectArraysClose(await gradient.data(), [
0.01, 0.02, 0, 0, 0.05, 0.06, 0, 0, 0.03, 0.04, 0, 0, 0.07, 0.08, 0, 0
]);
});

it('gradients, input shape=[4, 2, 2, 1], block shape=[2, 2]', async () => {
const t = tf.tensor4d(
[1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16], [4, 2, 2, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];
const dy = tf.tensor(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 4, 4, 1]);

const gradient =
tf.grad(t => tf.batchToSpaceND(t, blockShape, crops))(t, dy);
expect(gradient.shape).toEqual([4, 2, 2, 1]);
expectArraysClose(
await gradient.data(),
[1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16]);
});

it('gradient with clones, input=[4, 2, 2, 1], block shape=[2, 2]',
async () => {
const t = tf.tensor4d(
[1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16],
[4, 2, 2, 1]);
const blockShape = [2, 2];
const crops = [[0, 0], [0, 0]];
const dy = tf.tensor(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
[1, 4, 4, 1]);

const gradient = tf.grad(
t => tf.batchToSpaceND(t.clone(), blockShape, crops).clone())(t, dy);
expect(gradient.shape).toEqual([4, 2, 2, 1]);
expectArraysClose(
await gradient.data(),
[1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16]);
});
});

describeWithFlags('spaceToBatchND', ALL_ENVS, () => {
it('tensor4d, input shape=[1, 2, 2, 1], blockShape=[2, 2]', async () => {
const t = tf.tensor4d([[[[1], [2]], [[3], [4]]]], [1, 2, 2, 1]);