diff --git a/tfjs-backend-wasm/src/backend_wasm.ts b/tfjs-backend-wasm/src/backend_wasm.ts index 0299073da66..4d1ef29fa3c 100644 --- a/tfjs-backend-wasm/src/backend_wasm.ts +++ b/tfjs-backend-wasm/src/backend_wasm.ts @@ -16,7 +16,7 @@ */ import './flags_wasm'; -import {backend_util, BackendTimingInfo, DataStorage, DataType, engine, env, KernelBackend, registerBackend, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, BackendTimingInfo, DataStorage, DataType, engine, env, KernelBackend, registerBackend, Tensor, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasmModule, WasmFactoryConfig} from '../wasm-out/tfjs-backend-wasm'; import wasmFactorySimd from '../wasm-out/tfjs-backend-wasm-simd.js'; @@ -31,6 +31,7 @@ interface TensorData { dtype: DataType; /** Only used for string tensors, storing encoded bytes. */ stringBytes?: Uint8Array[]; + complexTensors?: {real: TensorInfo, imag: TensorInfo}; } export type DataId = object; // object instead of {} to force non-primitive. @@ -64,6 +65,10 @@ export class BackendWasm extends KernelBackend { return {kernelMs}; } + clone(x: TensorInfo): Tensor { + return engine().makeTensorFromDataId(x.dataId, x.shape, x.dtype); + } + move( dataId: DataId, values: backend_util.BackendValues, shape: number[], dtype: DataType): void { @@ -97,7 +102,7 @@ export class BackendWasm extends KernelBackend { } readSync(dataId: DataId): backend_util.BackendValues { - const {memoryOffset, dtype, shape, stringBytes} = + const {memoryOffset, dtype, shape, stringBytes, complexTensors} = this.dataIdMap.get(dataId); if (dtype === 'string') { return stringBytes; @@ -105,6 +110,13 @@ export class BackendWasm extends KernelBackend { const bytes = this.wasm.HEAPU8.slice( memoryOffset, memoryOffset + util.sizeFromShape(shape) * util.bytesPerElement(dtype)); + if (dtype === 'complex64') { + const realValues = + this.readSync(complexTensors.real.dataId) as Float32Array; + const imagValues = + this.readSync(complexTensors.imag.dataId) as Float32Array; + return backend_util.mergeRealAndImagArrays(realValues, imagValues); + } return typedArrayFromBuffer(bytes.buffer, dtype); } @@ -113,6 +125,11 @@ export class BackendWasm extends KernelBackend { this.wasm._free(data.memoryOffset); this.wasm.tfjs.disposeData(data.id); this.dataIdMap.delete(dataId); + + // if (data.complexTensors) { + // this.disposeData(data.complexTensors.real.dataId); + // this.disposeData(data.complexTensors.imag.dataId); + // } } floatPrecision(): 32 { @@ -167,6 +184,8 @@ export class BackendWasm extends KernelBackend { return new Int32Array(buffer, memoryOffset, size); case 'bool': return new Uint8Array(buffer, memoryOffset, size); + case 'complex64': + return new Float32Array(buffer, memoryOffset, size); default: throw new Error(`Uknown dtype ${dtype}`); } @@ -270,6 +289,8 @@ function typedArrayFromBuffer( return new Int32Array(buffer); case 'bool': return new Uint8Array(buffer); + case 'complex64': + return new Float32Array(buffer); default: throw new Error(`Unknown dtype ${dtype}`); } diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 0ef8f64fc2c..51d4f97a52d 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -197,6 +197,7 @@ tfjs_cc_library( ":Div", ":Equal", ":Exp", + ":FFT", ":FloorDiv", ":FusedBatchNorm", ":FusedConv2D", @@ -205,6 +206,7 @@ tfjs_cc_library( ":GatherNd", ":Greater", ":GreaterEqual", + ":IFFT", ":Less", ":LessEqual", ":Max", @@ -224,6 +226,7 @@ tfjs_cc_library( ":Relu", ":Relu6", ":ResizeBilinear", + ":Reverse", ":ScatterNd", ":SelectV2", ":Sigmoid", @@ -426,6 +429,24 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "FFT", + srcs = ["kernels/FFT.cc"], + deps = [ + ":backend", + ":fft_impl", + ], +) + +tfjs_cc_library( + name = "fft_impl", + srcs = ["fft_impl.cc"], + hdrs = ["fft_impl.h"], + deps = [ + ":backend", + ], +) + tfjs_cc_library( name = "FloorDiv", srcs = ["kernels/FloorDiv.cc"], @@ -518,6 +539,15 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "IFFT", + srcs = ["kernels/IFFT.cc"], + deps = [ + ":backend", + ":fft_impl", + ], +) + tfjs_cc_library( name = "Less", srcs = ["kernels/Less.cc"], @@ -746,6 +776,15 @@ tfjs_unit_test( ], ) +tfjs_cc_library( + name = "Reverse", + srcs = ["kernels/Reverse.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "ScatterNd", srcs = ["kernels/ScatterNd.cc"], diff --git a/tfjs-backend-wasm/src/cc/fft_impl.cc b/tfjs-backend-wasm/src/cc/fft_impl.cc new file mode 100644 index 00000000000..b5a269bb938 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/fft_impl.cc @@ -0,0 +1,84 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/fft_impl.h" + +namespace tfjs { +namespace wasm { + +void fft(const size_t real_input_id, const size_t imag_input_id, + const size_t outer_dim, const size_t inner_dim, + const size_t is_real_component, const bool is_inverse, + const size_t out_id) { + auto& real_input_info = backend::get_tensor_info(real_input_id); + const float* real_input_buf = real_input_info.f32(); + auto& imag_input_info = backend::get_tensor_info(imag_input_id); + const float* imag_input_buf = imag_input_info.f32(); + + auto& out_info = backend::get_tensor_info_out(out_id); + float* out_buf_ptr = out_info.f32_write(); + const size_t input_size = real_input_info.size; + + float exponent_multiplier; + if (is_inverse) { + exponent_multiplier = 2.0 * M_PI; + } else { + exponent_multiplier = -2.0 * M_PI; + } + + for (size_t row = 0; row < outer_dim; ++row) { + for (size_t col = 0; col < inner_dim; ++col) { + float index_ratio = float(col) / float(inner_dim); + float exponent_multiplier_times_index_ratio = + exponent_multiplier * index_ratio; + + float result = 0.0; + + for (size_t i = 0; i < inner_dim; ++i) { + float x = exponent_multiplier_times_index_ratio * float(i); + float exp_r = cos(x); + float exp_i = sin(x); + float real = real_input_buf[row * inner_dim + i]; + float imag = imag_input_buf[row * inner_dim + i]; + + float val; + + if (is_real_component > 0) { + val = real * exp_r - imag * exp_i; + } else { + val = real * exp_i + imag * exp_r; + } + + if (is_inverse) { + val = val / float(inner_dim); + } + + result += val; + } + + *out_buf_ptr = result; + out_buf_ptr++; + } + } +} +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/fft_impl.h b/tfjs-backend-wasm/src/cc/fft_impl.h new file mode 100644 index 00000000000..3bc0875e746 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/fft_impl.h @@ -0,0 +1,32 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifndef FFT_IMPL_H_ +#define FFT_IMPL_H_ + +#include + +#include "src/cc/backend.h" + +namespace tfjs { +namespace wasm { + +void fft(const size_t real_input_id, const size_t imag_input_id, + const size_t outer_dim, const size_t inner_dim, + const size_t is_real_component, const bool is_inverse, + const size_t out_id); +} // namespace wasm +} // namespace tfjs + +#endif // FFT_IMPL_H_ diff --git a/tfjs-backend-wasm/src/cc/kernels/FFT.cc b/tfjs-backend-wasm/src/cc/kernels/FFT.cc new file mode 100644 index 00000000000..5aaab0f7004 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/FFT.cc @@ -0,0 +1,41 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/fft_impl.h" + +namespace tfjs { +namespace wasm { +extern "C" { +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +void FFT(const size_t real_input_id, const size_t imag_input_id, + const size_t outer_dim, const size_t inner_dim, + const size_t is_real_component, const size_t out_id) { + const bool is_inverse = false; + tfjs::wasm::fft(real_input_id, imag_input_id, outer_dim, inner_dim, + is_real_component, is_inverse, out_id); +} +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/IFFT.cc b/tfjs-backend-wasm/src/cc/kernels/IFFT.cc new file mode 100644 index 00000000000..1bbbd173ffe --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/IFFT.cc @@ -0,0 +1,41 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/fft_impl.h" + +namespace tfjs { +namespace wasm { +extern "C" { +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +void IFFT(const size_t real_input_id, const size_t imag_input_id, + const size_t outer_dim, const size_t inner_dim, + const size_t is_real_component, const size_t out_id) { + const bool is_inverse = true; + tfjs::wasm::fft(real_input_id, imag_input_id, outer_dim, inner_dim, + is_real_component, is_inverse, out_id); +} +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Reverse.cc b/tfjs-backend-wasm/src/cc/kernels/Reverse.cc new file mode 100644 index 00000000000..7f1459decc2 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Reverse.cc @@ -0,0 +1,68 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Reverse(const size_t x_id, const size_t* axes_ptr, + const size_t axes_length, const size_t* out_shape_ptr, + const size_t out_shape_length, const size_t out_id) { + auto out_shape = + std::vector(out_shape_ptr, out_shape_ptr + out_shape_length); + auto axes = std::vector(axes_ptr, axes_ptr + axes_length); + + auto& x_info = backend::get_tensor_info(x_id); + const float* x_buf = x_info.f32(); + + auto& out_info = backend::get_tensor_info_out(out_id); + float* out_buf = out_info.f32_write(); + + size_t x_size = x_info.size; + + const std::vector out_strides = + tfjs::util::compute_strides(out_shape); + + for (size_t i = 0; i < x_size; ++i) { + const std::vector out_loc = + tfjs::util::offset_to_loc(i, out_strides); + + std::vector in_loc = out_loc; + for (size_t ax_i = 0; ax_i < axes_length; ++ax_i) { + size_t ax = axes[ax_i]; + in_loc[ax] = out_shape[ax] - 1 - in_loc[ax]; + } + + const size_t x_position = tfjs::util::loc_to_offset(in_loc, out_strides); + out_buf[i] = x_buf[x_position]; + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index d7f2bd6e914..881d997fcfc 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -73,8 +73,8 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { }, 100); // Silences backend registration warnings. - spyOn(console, 'warn'); - spyOn(console, 'log'); + // spyOn(console, 'warn'); + // spyOn(console, 'log'); }); afterEach(() => { @@ -138,4 +138,29 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { expect(() => setWasmPath('too/late')) .toThrowError(/The WASM backend was already initialized. Make sure/); }); + + // fit('3 fftLength, 5 frameLength, 2 step', async () => { + // const input = tf.tensor1d([1, 1, 1, 1, 1, 1]); + // const frameLength = 5; + // const frameStep = 1; + // const fftLength = 3; + // const output = tf.signal.stft(input, frameLength, frameStep, fftLength); + // expect(output.shape[0]).toEqual(2); + // const data = await output.data(); + // console.log(data); + // // expectArraysClose( + // // await output.data(), + // // [1.5, 0.0, -0.749999, 0.433, 1.5, 0.0, -0.749999, 0.433]); + // }); + + // it('test', async () => { + // const t1Real = tf.tensor1d([1, 2, 3]); + // const t1Imag = tf.tensor1d([0, 0, 0]); + // const t1 = tf.complex(t1Real, t1Imag); + // const data = await tf.spectral.fft(t1).data(); + // console.log(data); + // // expectArraysClose( + // // await tf.spectral.fft(t1).data(), + // // [6, 0, -1.5, 0.866025, -1.5, -0.866025]); + // }); }); diff --git a/tfjs-backend-wasm/src/kernels/Complex.ts b/tfjs-backend-wasm/src/kernels/Complex.ts new file mode 100644 index 00000000000..939f857837f --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Complex.ts @@ -0,0 +1,34 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +import {Complex, ComplexInputs, engine, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +export function complex(args: {backend: BackendWasm, inputs: ComplexInputs}): + TensorInfo { + const {backend, inputs} = args; + const {real, imag} = inputs; + + const out = backend.makeOutput(real.shape, 'complex64'); + const outData = backend.dataIdMap.get(out.dataId); + outData.complexTensors = { + real: engine().keep(backend.clone(real)), + imag: engine().keep(backend.clone(imag)) + }; + + return out; +} + +registerKernel({kernelName: Complex, backendName: 'wasm', kernelFunc: complex}); diff --git a/tfjs-backend-wasm/src/kernels/FFT.ts b/tfjs-backend-wasm/src/kernels/FFT.ts new file mode 100644 index 00000000000..5c0252a9e8d --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/FFT.ts @@ -0,0 +1,76 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +import {FFT, FFTInputs, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; +import {complex} from './Complex'; +import {reshape} from './Reshape'; + +let wasmFFT: ( + realInputId: number, imagInputId: number, outerDim: number, + innerDim: number, isRealComponent: number, outputId: number) => void; + +function setup(backend: BackendWasm): void { + wasmFFT = backend.wasm.cwrap(FFT, null, [ + 'number', // realInputId + 'number', // imagInputId + 'number', // outerDim + 'number', // innerDim + 'number', // isRealComponent + 'number', // outputId + ]); +} + +function fft(args: {backend: BackendWasm, inputs: FFTInputs}): TensorInfo { + const {backend, inputs} = args; + const {input} = inputs; + + const innerDimensionSize = input.shape[input.shape.length - 1]; + const batch = util.sizeFromShape(input.shape) / innerDimensionSize; + const input2D = reshape({ + inputs: {x: input}, + attrs: {shape: [batch, innerDimensionSize]}, + backend + }); + + const inputData = backend.dataIdMap.get(input2D.dataId); + const realInput = inputData.complexTensors.real; + const imagInput = inputData.complexTensors.imag; + const realInputId = backend.dataIdMap.get(realInput.dataId).id; + const imagInputId = backend.dataIdMap.get(imagInput.dataId).id; + + const real = backend.makeOutput(realInput.shape, realInput.dtype); + const imag = backend.makeOutput(imagInput.shape, imagInput.dtype); + const realId = backend.dataIdMap.get(real.dataId).id; + const imagId = backend.dataIdMap.get(imag.dataId).id; + + const [outerDim, innerDim] = input2D.shape; + + wasmFFT( + realInputId, imagInputId, outerDim, innerDim, 1 /* is real component */, + realId); + wasmFFT( + realInputId, imagInputId, outerDim, innerDim, + 0 /* is not real component */, imagId); + + const out = complex({backend, inputs: {real, imag}}); + backend.disposeData(realInput.dataId); + backend.disposeData(imagInput.dataId); + + return reshape({inputs: {x: out}, attrs: {shape: input2D.shape}, backend}); +} + +registerKernel( + {kernelName: FFT, backendName: 'wasm', setupFunc: setup, kernelFunc: fft}); diff --git a/tfjs-backend-wasm/src/kernels/IFFT.ts b/tfjs-backend-wasm/src/kernels/IFFT.ts new file mode 100644 index 00000000000..962be10b9fd --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/IFFT.ts @@ -0,0 +1,73 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +import {IFFT, IFFTInputs, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; +import {complex} from './Complex'; +import {reshape} from './Reshape'; + +let wasmFFT: ( + realInputId: number, imagInputId: number, outerDim: number, + innerDim: number, isRealComponent: number, outputId: number) => void; + +function setup(backend: BackendWasm): void { + wasmFFT = backend.wasm.cwrap(IFFT, null, [ + 'number', // realInputId + 'number', // imagInputId + 'number', // outerDim + 'number', // innerDim + 'number', // isRealComponent + 'number', // outputId + ]); +} + +function fft(args: {backend: BackendWasm, inputs: IFFTInputs}): TensorInfo { + const {backend, inputs} = args; + const {input} = inputs; + + const innerDimensionSize = input.shape[input.shape.length - 1]; + const batch = util.sizeFromShape(input.shape) / innerDimensionSize; + const input2D = reshape({ + inputs: {x: input}, + attrs: {shape: [batch, innerDimensionSize]}, + backend + }); + + const inputData = backend.dataIdMap.get(input2D.dataId); + const realInput = inputData.complexTensors.real; + const imagInput = inputData.complexTensors.imag; + const realInputId = backend.dataIdMap.get(realInput.dataId).id; + const imagInputId = backend.dataIdMap.get(imagInput.dataId).id; + + const real = backend.makeOutput(realInput.shape, realInput.dtype); + const imag = backend.makeOutput(imagInput.shape, imagInput.dtype); + const realId = backend.dataIdMap.get(real.dataId).id; + const imagId = backend.dataIdMap.get(imag.dataId).id; + + const [outerDim, innerDim] = input2D.shape; + + wasmFFT( + realInputId, imagInputId, outerDim, innerDim, 1 /* is real component */, + realId); + wasmFFT( + realInputId, imagInputId, outerDim, innerDim, + 0 /* is not real component */, imagId); + + const out = complex({backend, inputs: {real, imag}}); + return out; +} + +registerKernel( + {kernelName: IFFT, backendName: 'wasm', setupFunc: setup, kernelFunc: fft}); diff --git a/tfjs-backend-wasm/src/kernels/Imag.ts b/tfjs-backend-wasm/src/kernels/Imag.ts new file mode 100644 index 00000000000..277608b6fee --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Imag.ts @@ -0,0 +1,29 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +import {Imag, ImagInputs, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +export function imag(args: {backend: BackendWasm, inputs: ImagInputs}): + TensorInfo { + const {backend, inputs} = args; + const {input} = inputs; + + const inputData = backend.dataIdMap.get(input.dataId); + const imagPart = inputData.complexTensors.imag; + return backend.clone(imagPart); +} + +registerKernel({kernelName: Imag, backendName: 'wasm', kernelFunc: imag}); diff --git a/tfjs-backend-wasm/src/kernels/Real.ts b/tfjs-backend-wasm/src/kernels/Real.ts new file mode 100644 index 00000000000..27bad0717ea --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Real.ts @@ -0,0 +1,29 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +import {Real, RealInputs, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +export function real(args: {backend: BackendWasm, inputs: RealInputs}): + TensorInfo { + const {backend, inputs} = args; + const {input} = inputs; + + const inputData = backend.dataIdMap.get(input.dataId); + const realPart = inputData.complexTensors.real; + return backend.clone(realPart); +} + +registerKernel({kernelName: Real, backendName: 'wasm', kernelFunc: real}); diff --git a/tfjs-backend-wasm/src/kernels/Reverse.ts b/tfjs-backend-wasm/src/kernels/Reverse.ts new file mode 100644 index 00000000000..571f4a2bab7 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Reverse.ts @@ -0,0 +1,71 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, Reverse, ReverseAttrs, ReverseInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; +import {reshape} from './Reshape'; + +let wasmReverse: ( + xId: number, axes: Uint8Array, axesLength: number, outShape: Uint8Array, + outShapeLength: number, outId: number) => void; + +function setup(backend: BackendWasm) { + wasmReverse = backend.wasm.cwrap(Reverse, null, [ + 'number', // x_id + 'array', // axes + 'number', // axes_length + 'array', // out_shape + 'number', // out_shape_length + 'number' // out_id + ]); +} + +export function reverse(args: { + inputs: NamedTensorInfoMap, + backend: BackendWasm, + attrs: NamedAttrMap +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {x} = inputs as {} as ReverseInputs; + const {dims} = attrs as {} as ReverseAttrs; + + const axes = util.parseAxisParam(dims, x.shape); + + if (x.shape.length === 0) { + return backend.clone(x); + } + + const out = backend.makeOutput(x.shape, x.dtype); + const xId = backend.dataIdMap.get(x.dataId).id; + const outId = backend.dataIdMap.get(out.dataId).id; + + const axesBytes = new Uint8Array(new Int32Array(axes).buffer); + const outShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer); + + wasmReverse( + xId, axesBytes, axes.length, outShapeBytes, x.shape.length, outId); + + return reshape({inputs: {x: out}, attrs: {shape: x.shape}, backend}); +} + +registerKernel({ + kernelName: Reverse, + backendName: 'wasm', + kernelFunc: reverse, + setupFunc: setup +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index ee032cfc119..04ae06ef920 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -29,6 +29,7 @@ import './Cast'; import './ClipByValue'; import './Concat'; import './Conv2D'; +import './Complex'; import './Conv2DBackpropInput'; import './Cos'; import './CropAndResize'; @@ -36,6 +37,7 @@ import './DepthwiseConv2dNative'; import './Div'; import './Equal'; import './Exp'; +import './FFT'; import './Fill'; import './FloorDiv'; import './FusedBatchNorm'; @@ -45,6 +47,8 @@ import './Gather'; import './GatherNd'; import './Greater'; import './GreaterEqual'; +import './IFFT'; +import './Imag'; import './Less'; import './LessEqual'; import './Log'; @@ -64,10 +68,12 @@ import './OnesLike'; import './PadV2'; import './Pow'; import './Prelu'; +import './Real'; import './Relu'; import './Relu6'; import './Reshape'; import './ResizeBilinear'; +import './Reverse'; import './Rsqrt'; import './SelectV2'; import './ScatterNd'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 0c14b49c728..1e6679c9b82 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -188,6 +188,11 @@ const TEST_FILTERS: TestFilter[] = [ 'gradient' // Gradient is missing. ] }, + { + include: 'fft', + excludes: ['stft'] // Complex support for concat not implemented yet. + }, + {include: 'fft'}, {include: 'slice '}, {include: 'square '}, { @@ -308,6 +313,7 @@ const TEST_FILTERS: TestFilter[] = [ 'axis=0', // Reduction not supported along inner dimensions. ] }, + {startsWith: 'reverse'}, {startsWith: 'sum '}, { startsWith: 'logicalAnd ', diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index b705fdb228f..ad2f09bebb8 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -244,6 +244,9 @@ export type EluGradInputs = Pick; export const Equal = 'Equal'; export type EqualInputs = BinaryInputs; +export const FFT = 'FFT'; +export type FFTInputs = Pick; + export const FloorDiv = 'FloorDiv'; export type FloorDivInputs = BinaryInputs; @@ -273,6 +276,9 @@ export type GreaterEqualInputs = BinaryInputs; export const Identity = 'Identity'; export type IdentityInputs = Pick; +export const IFFT = 'IFFT'; +export type IFFTInputs = Pick; + export const Imag = 'Imag'; export type ImagInputs = Pick; diff --git a/tfjs-core/src/ops/spectral_ops.ts b/tfjs-core/src/ops/spectral_ops.ts index 8cfd88fec69..2a32cda14f7 100644 --- a/tfjs-core/src/ops/spectral_ops.ts +++ b/tfjs-core/src/ops/spectral_ops.ts @@ -16,11 +16,13 @@ */ import {ENGINE} from '../engine'; +import {FFT, FFTInputs, IFFT, IFFTInputs} from '../kernel_names'; import {complex} from '../ops/complex'; import {imag} from '../ops/imag'; import {op} from '../ops/operation'; import {real} from '../ops/real'; import {Tensor, Tensor2D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {assert} from '../util'; import {scalar, zeros} from './tensor_ops'; @@ -49,12 +51,15 @@ function fft_(input: Tensor): Tensor { () => `The dtype for tf.spectral.fft() must be complex64 ` + `but got ${input.dtype}.`); - // Collapse all outer dimensions to a single batch dimension. - const innerDimensionSize = input.shape[input.shape.length - 1]; - const batch = input.size / innerDimensionSize; - const input2D = input.as2D(batch, innerDimensionSize); + const inputs: FFTInputs = {input}; - const ret = ENGINE.runKernelFunc(backend => backend.fft(input2D), {input}); + const ret = ENGINE.runKernelFunc(backend => { + // Collapse all outer dimensions to a single batch dimension. + const innerDimensionSize = input.shape[input.shape.length - 1]; + const batch = input.size / innerDimensionSize; + const input2D = input.as2D(batch, innerDimensionSize); + return backend.fft(input2D); + }, inputs as {} as NamedTensorMap, null /* gradient */, FFT); return ret.reshape(input.shape); } @@ -83,12 +88,15 @@ function ifft_(input: Tensor): Tensor { () => `The dtype for tf.spectral.ifft() must be complex64 ` + `but got ${input.dtype}.`); - // Collapse all outer dimensions to a single batch dimension. - const innerDimensionSize = input.shape[input.shape.length - 1]; - const batch = input.size / innerDimensionSize; - const input2D = input.as2D(batch, innerDimensionSize); + const inputs: IFFTInputs = {input}; - const ret = ENGINE.runKernelFunc(backend => backend.ifft(input2D), {input}); + const ret = ENGINE.runKernelFunc(backend => { + // Collapse all outer dimensions to a single batch dimension. + const innerDimensionSize = input.shape[input.shape.length - 1]; + const batch = input.size / innerDimensionSize; + const input2D = input.as2D(batch, innerDimensionSize); + return backend.ifft(input2D); + }, inputs as {} as NamedTensorMap, null /* gradient */, IFFT); return ret.reshape(input.shape); }