Skip to content

Commit

Permalink
[WebGPU] Ensure tf.any and tf.all return bool tensors (#7928)
Browse files Browse the repository at this point in the history
Fix a bug where tf.all and tf.any return int tensors instead of bool tensors.
  • Loading branch information
mattsoulanille authored Aug 28, 2023
1 parent d42502e commit 93489b2
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 5 deletions.
9 changes: 7 additions & 2 deletions tfjs-backend-webgpu/src/kernel_utils/reduce.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {backend_util, sumOutType, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';
import {backend_util, DataType, sumOutType, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';
import {reshape} from '../kernels/Reshape';
Expand All @@ -26,6 +26,11 @@ import {maxImplCPU} from './shared';
import {prodImplCPU} from './shared';

type ReduceTypes = 'all'|'any'|'max'|'mean'|'min'|'prod'|'sum';
const RETURN_TYPES: {[key in ReduceTypes]?: DataType} = {
'mean': 'float32',
'all': 'bool',
'any': 'bool',
};

export function reduce(
x: TensorInfo, axis: number|number[], keepDims: boolean,
Expand Down Expand Up @@ -79,7 +84,7 @@ export function reduce(
const batchSize = xSize / inSize;

const reduceInfo = {windowSize: inSize, inSize, batchSize, outSize: 1};
const dtype = reduceType === 'mean' ? 'float32' : sumOutType(x.dtype);
const dtype = RETURN_TYPES[reduceType] || sumOutType(x.dtype);
const uniformData = [
{type: 'int32', data: [inSize]},
];
Expand Down
6 changes: 6 additions & 0 deletions tfjs-core/src/ops/all_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,10 @@ describeWithFlags('all', ALL_ENVS, () => {
.toThrowError(
/Argument 'x' passed to 'all' must be bool tensor, but got string/);
});

it('returns a boolean tensor', () => {
const a = tf.tensor1d([1, 0], 'bool');
const r = tf.all(a);
expect(r.dtype).toEqual('bool');
});
});
6 changes: 6 additions & 0 deletions tfjs-core/src/ops/any_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,10 @@ describeWithFlags('any', ALL_ENVS, () => {
expect(() => tf.any(['a']))
.toThrowError(/Argument 'x' passed to 'any' must be bool tensor/);
});

it('returns a boolean tensor', () => {
const a = tf.tensor1d([1, 0], 'bool');
const r = tf.any(a);
expect(r.dtype).toEqual('bool');
});
});
6 changes: 3 additions & 3 deletions tfjs/karma.conf.js
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ module.exports = function(config) {
browser: 'chrome',
browser_version: 'latest',
os: 'OS X',
os_version: 'Sierra'
os_version: 'High Sierra'
},
bs_firefox_mac: {
base: 'BrowserStack',
browser: 'firefox',
browser_version: 'latest',
browser_version: '90',
os: 'OS X',
os_version: 'Sierra'
os_version: 'High Sierra'
},
chrome_with_swift_shader: {
base: 'Chrome',
Expand Down

0 comments on commit 93489b2

Please sign in to comment.