Skip to content

Commit

Permalink
Add math.concat[1-4]D and math.slice[1-4]D (tensorflow#151)
Browse files Browse the repository at this point in the history
* add slice1d

* fix lstm bug

* generalize concat to support 1-4D and add slice[1-4]D

* update .gitignore

* add slice implementations for math_cpu and tests

* address comments

* Merge master into rnn
  • Loading branch information
dsmilkov authored and mnottheone committed Dec 1, 2018
1 parent c31ef75 commit 34039a3
Show file tree
Hide file tree
Showing 26 changed files with 1,454 additions and 623 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ index.ts
npm-debug.log
.DS_Store
dist/
.idea/
.idea/
6 changes: 2 additions & 4 deletions src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -623,12 +623,10 @@ export class Concat3DNode extends Node {
public axis: number) {
super(
graph, 'Concat3D', {x1, x2},
new Tensor(
concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis)));
new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis)));
}
validate() {
concat_util.assertConcatShapesMatch(
this.x1.shape, this.x2.shape, 3, this.axis);
concat_util.assertParams(this.x1.shape, this.x2.shape, this.axis);
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/graph/ops/concat3d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ export class Concat3D extends Operation {
private x1Tensor: Tensor, private x2Tensor: Tensor, private axis: number,
private yTensor: Tensor) {
super();
concat_util.assertConcatShapesMatch(
x1Tensor.shape, x2Tensor.shape, 3, axis);
concat_util.assertParams(x1Tensor.shape, x2Tensor.shape, axis);
}

feedForward(math: NDArrayMath, inferenceArrays: TensorArrayMap) {
Expand Down
9 changes: 3 additions & 6 deletions src/graph/ops/concat3d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ describe('concat3d operation', () => {

x1Tensor = new Tensor(x1.shape);
x2Tensor = new Tensor(x2.shape);
yTensor = new Tensor(
concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis));
yTensor = new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis));

tensorArrayMap.set(x1Tensor, x1);
tensorArrayMap.set(x2Tensor, x2);
Expand All @@ -75,8 +74,7 @@ describe('concat3d operation', () => {

x1Tensor = new Tensor(x1.shape);
x2Tensor = new Tensor(x2.shape);
yTensor = new Tensor(
concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis));
yTensor = new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis));

tensorArrayMap.set(x1Tensor, x1);
tensorArrayMap.set(x2Tensor, x2);
Expand All @@ -99,8 +97,7 @@ describe('concat3d operation', () => {

x1Tensor = new Tensor(x1.shape);
x2Tensor = new Tensor(x2.shape);
yTensor = new Tensor(
concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis));
yTensor = new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis));

tensorArrayMap.set(x1Tensor, x1);
tensorArrayMap.set(x2Tensor, x2);
Expand Down
34 changes: 16 additions & 18 deletions src/math/concat_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,33 @@

import * as util from '../util';

export function assertConcatShapesMatch(
x1Shape: number[], x2Shape: number[], rank: number, axis: number,
errorMessagePrefix = '') {
export function assertParams(aShape: number[], bShape: number[], axis: number) {
const aRank = aShape.length;
const bRank = bShape.length;
util.assert(
x1Shape.length === rank,
errorMessagePrefix + `x1 shape should be of rank ${rank}.`);
util.assert(
x2Shape.length === rank,
errorMessagePrefix + `x2 shape should be of rank ${rank}.`);
aShape.length === bShape.length,
`Error in concat${aRank}D: rank of x1 (${aRank}) and x2 (${bRank}) ` +
`must be the same.`);

util.assert(
axis >= 0 && axis < rank, `axis must be between 0 and ${rank - 1}.`);
axis >= 0 && axis < aRank,
`Error in concat${aRank}D: axis must be ` +
`between 0 and ${aRank - 1}.`);

for (let i = 0; i < rank; i++) {
for (let i = 0; i < aRank; i++) {
util.assert(
(i === axis) || (x1Shape[i] === x2Shape[i]),
errorMessagePrefix +
`Shape (${x1Shape}) does not match (${x2Shape}) along ` +
`the non-concatenated axis ${i}.`);
(i === axis) || (aShape[i] === bShape[i]),
`Error in concat${aRank}D: Shape (${aShape}) does not match ` +
`(${bShape}) along the non-concatenated axis ${i}.`);
}
}

export function computeConcatOutputShape(
x1Shape: number[], x2Shape: number[],
axis: number): [number, number, number] {
export function computeOutShape(
x1Shape: number[], x2Shape: number[], axis: number): number[] {
util.assert(
x1Shape.length === x2Shape.length,
'x1 and x2 should have the same rank.');
const outputShape = x1Shape.slice();
outputShape[axis] += x2Shape[axis];
return outputShape as [number, number, number];
return outputShape;
}
15 changes: 8 additions & 7 deletions src/math/concat_util_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,39 @@ import * as concat_util from './concat_util';
describe('concat_util.assertConcatShapesMatch rank=3D', () => {
it('Non-3D tensor x1', () => {
const assertFn = () => {
concat_util.assertConcatShapesMatch([1], [1, 2, 3], 3, 1);
concat_util.assertParams([1], [1, 2, 3], 1);
};

expect(assertFn).toThrow();
});

it('Non-3D tensor x2', () => {
const assertFn = () => {
concat_util.assertConcatShapesMatch([1, 2, 3], [2, 3], 3, 1);
concat_util.assertParams([1, 2, 3], [2, 3], 1);
};

expect(assertFn).toThrow();
});

it('axis out of bound', () => {
const assertFn = () => {
concat_util.assertConcatShapesMatch([1, 2, 3], [1, 2, 3], 3, 4);
concat_util.assertParams([1, 2, 3], [1, 2, 3], 4);
};

expect(assertFn).toThrow();
});

it('non-axis shape mismatch', () => {
const assertFn = () => {
concat_util.assertConcatShapesMatch([2, 3, 3], [2, 2, 4], 3, 2);
concat_util.assertParams([2, 3, 3], [2, 2, 4], 2);
};

expect(assertFn).toThrow();
});

it('shapes line up', () => {
const assertFn = () => {
concat_util.assertConcatShapesMatch([2, 3, 3], [2, 3, 4], 3, 2);
concat_util.assertParams([2, 3, 3], [2, 3, 4], 2);
};

expect(assertFn).not.toThrow();
Expand All @@ -61,7 +61,8 @@ describe('concat_util.assertConcatShapesMatch rank=3D', () => {

describe('concat_util.computeConcatOutputShape', () => {
it('compute output shape, axis=0', () => {
expect(concat_util.computeConcatOutputShape([2, 2, 3], [1, 2, 3], 0))
.toEqual([3, 2, 3]);
expect(concat_util.computeOutShape([2, 2, 3], [1, 2, 3], 0)).toEqual([
3, 2, 3
]);
});
});
Loading

0 comments on commit 34039a3

Please sign in to comment.