Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wasm] Complex ops. #3478

Draft
wants to merge 33 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions tfjs-backend-wasm/src/backend_wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
import './flags_wasm';

import {backend_util, BackendTimingInfo, DataStorage, DataType, engine, env, KernelBackend, registerBackend, TensorInfo, util} from '@tensorflow/tfjs-core';
import {backend_util, BackendTimingInfo, DataStorage, DataType, engine, env, KernelBackend, registerBackend, Tensor, TensorInfo, util} from '@tensorflow/tfjs-core';

import {BackendWasmModule, WasmFactoryConfig} from '../wasm-out/tfjs-backend-wasm';
import wasmFactorySimd from '../wasm-out/tfjs-backend-wasm-simd.js';
Expand All @@ -31,6 +31,7 @@ interface TensorData {
dtype: DataType;
/** Only used for string tensors, storing encoded bytes. */
stringBytes?: Uint8Array[];
complexTensors?: {real: TensorInfo, imag: TensorInfo};
}

export type DataId = object; // object instead of {} to force non-primitive.
Expand Down Expand Up @@ -64,6 +65,10 @@ export class BackendWasm extends KernelBackend {
return {kernelMs};
}

clone(x: TensorInfo): Tensor {
return engine().makeTensorFromDataId(x.dataId, x.shape, x.dtype);
}

move(
dataId: DataId, values: backend_util.BackendValues, shape: number[],
dtype: DataType): void {
Expand Down Expand Up @@ -97,14 +102,21 @@ export class BackendWasm extends KernelBackend {
}

readSync(dataId: DataId): backend_util.BackendValues {
const {memoryOffset, dtype, shape, stringBytes} =
const {memoryOffset, dtype, shape, stringBytes, complexTensors} =
this.dataIdMap.get(dataId);
if (dtype === 'string') {
return stringBytes;
}
const bytes = this.wasm.HEAPU8.slice(
memoryOffset,
memoryOffset + util.sizeFromShape(shape) * util.bytesPerElement(dtype));
if (dtype === 'complex64') {
const realValues =
this.readSync(complexTensors.real.dataId) as Float32Array;
const imagValues =
this.readSync(complexTensors.imag.dataId) as Float32Array;
return backend_util.mergeRealAndImagArrays(realValues, imagValues);
}
return typedArrayFromBuffer(bytes.buffer, dtype);
}

Expand All @@ -113,6 +125,11 @@ export class BackendWasm extends KernelBackend {
this.wasm._free(data.memoryOffset);
this.wasm.tfjs.disposeData(data.id);
this.dataIdMap.delete(dataId);

// if (data.complexTensors) {
// this.disposeData(data.complexTensors.real.dataId);
// this.disposeData(data.complexTensors.imag.dataId);
// }
}

floatPrecision(): 32 {
Expand Down Expand Up @@ -167,6 +184,8 @@ export class BackendWasm extends KernelBackend {
return new Int32Array(buffer, memoryOffset, size);
case 'bool':
return new Uint8Array(buffer, memoryOffset, size);
case 'complex64':
return new Float32Array(buffer, memoryOffset, size);
default:
throw new Error(`Uknown dtype ${dtype}`);
}
Expand Down Expand Up @@ -270,6 +289,8 @@ function typedArrayFromBuffer(
return new Int32Array(buffer);
case 'bool':
return new Uint8Array(buffer);
case 'complex64':
return new Float32Array(buffer);
default:
throw new Error(`Unknown dtype ${dtype}`);
}
Expand Down
39 changes: 39 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ tfjs_cc_library(
":Div",
":Equal",
":Exp",
":FFT",
":FloorDiv",
":FusedBatchNorm",
":FusedConv2D",
Expand All @@ -205,6 +206,7 @@ tfjs_cc_library(
":GatherNd",
":Greater",
":GreaterEqual",
":IFFT",
":Less",
":LessEqual",
":Max",
Expand All @@ -224,6 +226,7 @@ tfjs_cc_library(
":Relu",
":Relu6",
":ResizeBilinear",
":Reverse",
":ScatterNd",
":SelectV2",
":Sigmoid",
Expand Down Expand Up @@ -426,6 +429,24 @@ tfjs_cc_library(
],
)

tfjs_cc_library(
name = "FFT",
srcs = ["kernels/FFT.cc"],
deps = [
":backend",
":fft_impl",
],
)

tfjs_cc_library(
name = "fft_impl",
srcs = ["fft_impl.cc"],
hdrs = ["fft_impl.h"],
deps = [
":backend",
],
)

tfjs_cc_library(
name = "FloorDiv",
srcs = ["kernels/FloorDiv.cc"],
Expand Down Expand Up @@ -518,6 +539,15 @@ tfjs_cc_library(
],
)

tfjs_cc_library(
name = "IFFT",
srcs = ["kernels/IFFT.cc"],
deps = [
":backend",
":fft_impl",
],
)

tfjs_cc_library(
name = "Less",
srcs = ["kernels/Less.cc"],
Expand Down Expand Up @@ -746,6 +776,15 @@ tfjs_unit_test(
],
)

tfjs_cc_library(
name = "Reverse",
srcs = ["kernels/Reverse.cc"],
deps = [
":backend",
":util",
],
)

tfjs_cc_library(
name = "ScatterNd",
srcs = ["kernels/ScatterNd.cc"],
Expand Down
84 changes: 84 additions & 0 deletions tfjs-backend-wasm/src/cc/fft_impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <cmath>
#include <vector>

#include "src/cc/backend.h"
#include "src/cc/fft_impl.h"

namespace tfjs {
namespace wasm {

void fft(const size_t real_input_id, const size_t imag_input_id,
const size_t outer_dim, const size_t inner_dim,
const size_t is_real_component, const bool is_inverse,
const size_t out_id) {
auto& real_input_info = backend::get_tensor_info(real_input_id);
const float* real_input_buf = real_input_info.f32();
auto& imag_input_info = backend::get_tensor_info(imag_input_id);
const float* imag_input_buf = imag_input_info.f32();

auto& out_info = backend::get_tensor_info_out(out_id);
float* out_buf_ptr = out_info.f32_write();
const size_t input_size = real_input_info.size;

float exponent_multiplier;
if (is_inverse) {
exponent_multiplier = 2.0 * M_PI;
} else {
exponent_multiplier = -2.0 * M_PI;
}

for (size_t row = 0; row < outer_dim; ++row) {
for (size_t col = 0; col < inner_dim; ++col) {
float index_ratio = float(col) / float(inner_dim);
float exponent_multiplier_times_index_ratio =
exponent_multiplier * index_ratio;

float result = 0.0;

for (size_t i = 0; i < inner_dim; ++i) {
float x = exponent_multiplier_times_index_ratio * float(i);
float exp_r = cos(x);
float exp_i = sin(x);
float real = real_input_buf[row * inner_dim + i];
float imag = imag_input_buf[row * inner_dim + i];

float val;

if (is_real_component > 0) {
val = real * exp_r - imag * exp_i;
} else {
val = real * exp_i + imag * exp_r;
}

if (is_inverse) {
val = val / float(inner_dim);
}

result += val;
}

*out_buf_ptr = result;
out_buf_ptr++;
}
}
}
} // namespace wasm
} // namespace tfjs
32 changes: 32 additions & 0 deletions tfjs-backend-wasm/src/cc/fft_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifndef FFT_IMPL_H_
#define FFT_IMPL_H_

#include <cstddef>

#include "src/cc/backend.h"

namespace tfjs {
namespace wasm {

void fft(const size_t real_input_id, const size_t imag_input_id,
const size_t outer_dim, const size_t inner_dim,
const size_t is_real_component, const bool is_inverse,
const size_t out_id);
} // namespace wasm
} // namespace tfjs

#endif // FFT_IMPL_H_
41 changes: 41 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/FFT.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <cmath>
#include <vector>

#include "src/cc/backend.h"
#include "src/cc/fft_impl.h"

namespace tfjs {
namespace wasm {
extern "C" {
#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif

void FFT(const size_t real_input_id, const size_t imag_input_id,
const size_t outer_dim, const size_t inner_dim,
const size_t is_real_component, const size_t out_id) {
const bool is_inverse = false;
tfjs::wasm::fft(real_input_id, imag_input_id, outer_dim, inner_dim,
is_real_component, is_inverse, out_id);
}
} // extern "C"
} // namespace wasm
} // namespace tfjs
41 changes: 41 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/IFFT.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <cmath>
#include <vector>

#include "src/cc/backend.h"
#include "src/cc/fft_impl.h"

namespace tfjs {
namespace wasm {
extern "C" {
#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif

void IFFT(const size_t real_input_id, const size_t imag_input_id,
const size_t outer_dim, const size_t inner_dim,
const size_t is_real_component, const size_t out_id) {
const bool is_inverse = true;
tfjs::wasm::fft(real_input_id, imag_input_id, outer_dim, inner_dim,
is_real_component, is_inverse, out_id);
}
} // extern "C"
} // namespace wasm
} // namespace tfjs
Loading