diff --git a/docs/api/python/contrib.rst b/docs/api/python/contrib.rst index b482d30515d4..8ac4e1ff7d3a 100644 --- a/docs/api/python/contrib.rst +++ b/docs/api/python/contrib.rst @@ -48,9 +48,9 @@ tvm.contrib.dlpack .. automodule:: tvm.contrib.dlpack :members: -tvm.contrib.emscripten -~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: tvm.contrib.emscripten +tvm.contrib.emcc +~~~~~~~~~~~~~~~~ +.. automodule:: tvm.contrib.emcc :members: tvm.contrib.miopen diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 0d1a4e214791..de8f7b565c09 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -88,6 +88,9 @@ def find_lib_path(name=None, search_path=None, optional=False): dll_path.append(install_lib_dir) + if os.path.isdir(source_dir): + dll_path.append(os.path.join(source_dir, "web", "dist", "wasm")) + dll_path = [os.path.realpath(x) for x in dll_path] if search_path is not None: if isinstance(search_path, list): @@ -154,6 +157,7 @@ def find_include_path(name=None, search_path=None, optional=False): ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) source_dir = os.path.join(ffi_dir, "..", "..", "..") install_include_dir = os.path.join(ffi_dir, "..", "..", "..", "..") + third_party_dir = os.path.join(source_dir, "3rdparty") header_path = [] diff --git a/python/tvm/contrib/emscripten.py b/python/tvm/contrib/emcc.py similarity index 65% rename from python/tvm/contrib/emscripten.py rename to python/tvm/contrib/emcc.py index 7f31273451f7..6df205a030bc 100644 --- a/python/tvm/contrib/emscripten.py +++ b/python/tvm/contrib/emcc.py @@ -16,18 +16,16 @@ # under the License. """Util to invoke emscripten compilers in the system.""" # pylint: disable=invalid-name -from __future__ import absolute_import as _abs - import subprocess -from .._ffi.base import py_str -from .._ffi.libinfo import find_lib_path +from tvm._ffi.base import py_str +from tvm._ffi.libinfo import find_lib_path + -def create_js(output, - objects, - options=None, - side_module=False, - cc="emcc"): - """Create emscripten javascript library. +def create_tvmjs_wasm(output, + objects, + options=None, + cc="emcc"): + """Create wasm that is supposed to run with the tvmjs. Parameters ---------- @@ -44,25 +42,27 @@ def create_js(output, The compile string. """ cmd = [cc] - cmd += ["-Oz"] - if not side_module: - cmd += ["-s", "RESERVED_FUNCTION_POINTERS=2"] - cmd += ["-s", "NO_EXIT_RUNTIME=1"] - extra_methods = ['cwrap', 'getValue', 'setValue', 'addFunction'] - cfg = "[" + (','.join("\'%s\'" % x for x in extra_methods)) + "]" - cmd += ["-s", "EXTRA_EXPORTED_RUNTIME_METHODS=" + cfg] - else: - cmd += ["-s", "SIDE_MODULE=1"] - cmd += ["-o", output] + cmd += ["-O3"] + + cmd += ["-std=c++14"] + cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"] + cmd += ["-s", "STANDALONE_WASM=1"] + cmd += ["-s", "ALLOW_MEMORY_GROWTH=1"] + + objects = [objects] if isinstance(objects, str) else objects + with_runtime = False for obj in objects: - if obj.find("libtvm_web_runtime.bc") != -1: + if obj.find("wasm_runtime.bc") != -1: with_runtime = True - if not with_runtime and not side_module: - objects += [find_lib_path("libtvm_web_runtime.bc")[0]] + if not with_runtime: + objects += [find_lib_path("wasm_runtime.bc")[0]] + objects += [find_lib_path("tvmjs_support.bc")[0]] + + cmd += ["-o", output] cmd += objects if options: @@ -79,4 +79,4 @@ def create_js(output, msg += py_str(out) raise RuntimeError(msg) -create_js.object_format = "bc" +create_tvmjs_wasm.object_format = "bc" diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py index 4cf341335ea7..59da8fa5a3bf 100644 --- a/python/tvm/exec/rpc_proxy.py +++ b/python/tvm/exec/rpc_proxy.py @@ -29,12 +29,11 @@ def find_example_resource(): """Find resource examples.""" curr_path = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - base_path = os.path.join(curr_path, "../../../") - index_page = os.path.join(base_path, "web/example_rpc.html") + base_path = os.path.abspath(os.path.join(curr_path, "..", "..", "..")) + index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html") js_files = [ - os.path.join(base_path, "web/tvm_runtime.js"), - os.path.join(base_path, "build/libtvm_web_runtime.js"), - os.path.join(base_path, "build/libtvm_web_runtime.js.mem") + os.path.join(base_path, "web/dist/tvmjs.bundle.js"), + os.path.join(base_path, "web/dist/wasm/tvmjs_runtime.wasi.js") ] for fname in [index_page] + js_files: if not os.path.exists(fname): @@ -69,7 +68,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--host', type=str, default="0.0.0.0", + parser.add_argument('--host', type=str, default="localhost", help='the hostname of the server') parser.add_argument('--port', type=int, default=9090, help='The port of the RPC') diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 04d6c94e10e8..da3a456dafb6 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -36,6 +36,7 @@ "scala", "java", "go", + "ts", "sh", "py", "pyi", @@ -81,6 +82,7 @@ # List of file names allowed ALLOW_FILE_NAME = { ".gitignore", + ".eslintignore", ".gitattributes", "README", "Makefile", @@ -107,8 +109,7 @@ "rust/runtime/tests/test_wasm32/.cargo/config", "apps/sgx/.cargo/config", # html for demo purposes - "tests/webgl/test_static_webgl_library.html", - "web/example_rpc.html", + "web/apps/browser/rpc_server.html", # images are normally not allowed # discuss with committers before add more images "apps/android_rpc/app/src/main/res/mipmap-hdpi/ic_launcher.png", diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes index 5421d22a08aa..0714850287f3 100644 --- a/tests/lint/rat-excludes +++ b/tests/lint/rat-excludes @@ -28,6 +28,8 @@ core.cpp build _static _build +node_modules +dist .*~ \#..*\# \.#.* @@ -40,6 +42,7 @@ RelayVisitor.py # Specific files package-list MANIFEST +.eslintignore .gitignore .gitattributes .gitmodules diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 819961dc6ebd..41006f41f754 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -43,9 +43,6 @@ cd .. make doc rm -f docs/doxygen/html/*.map docs/doxygen/html/*.md5 -# JS doc -jsdoc -c web/.jsdoc_conf.json web/tvm_runtime.js web/README.md - # Java doc make javadoc @@ -54,7 +51,6 @@ rm -rf _docs mv docs/_build/html _docs rm -f _docs/.buildinfo mv docs/doxygen/html _docs/doxygen -mv out _docs/jsdoc mv jvm/core/target/site/apidocs _docs/javadoc echo "Start creating the docs tarball.." diff --git a/tests/web/test_packed_func.js b/tests/web/test_packed_func.js deleted file mode 100644 index d239f7346e74..000000000000 --- a/tests/web/test_packed_func.js +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -// Load Emscripten Module, need to change path to root/build -const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); - -function testGetGlobal() { - var targs = [10, 10.0, "hello"] - tvm.registerFunc("my_packed_func", function () { - tvm.assert(Array.from(arguments).toString() == targs, "assert fail"); - return 10 - }); - var f = tvm.getGlobalFunc("my_packed_func") - tvm.assert(tvm.isPackedFunc(f)); - y = f.apply(null, targs); - tvm.assert(y == 10); - f.release(); -} - - -function testReturnFunc() { - function addy(y) { - function add(x) { - return x + y; - } - return add; - } - var myf = tvm.convertFunc(addy); - var f = myf(10); - tvm.assert(tvm.isPackedFunc(f)); - tvm.assert(f(11) == 21); - myf.release(); - f.release(); -} - -function testByteArray() { - var a = new Uint8Array(3); - a[0] = 1; - a[1] = 2; - function myfunc(ss){ - tvm.assert(ss instanceof Uint8Array); - tvm.assert(ss.toString() == a); - } - f = tvm.convertFunc(myfunc); - f(a); - f.release(); -} - -testGetGlobal(); -testReturnFunc(); -testByteArray(); diff --git a/tests/webgl/README.md b/tests/webgl/README.md deleted file mode 100644 index 5303cc059740..000000000000 --- a/tests/webgl/README.md +++ /dev/null @@ -1,24 +0,0 @@ - - - - - - - - - - - - - - - - - -## Test cases for the WebGL backend - -Any test case with name `test_local_...` tests the C++ OpenGL backend on the -local OS, which can be executed automatically. - -Any test case with name `test_remote_...` tests the WebGL backend within the -browser, which must be run manually. See instruction within the test. diff --git a/tests/webgl/test_local_gemm.py b/tests/webgl/test_local_gemm.py deleted file mode 100644 index 6bd22bf0057b..000000000000 --- a/tests/webgl/test_local_gemm.py +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te -import numpy as np - -def test_local_gemm(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - nn = 1024 - n = te.var('n') - n = tvm.runtime.convert(nn) - m = n - l = n - A = te.placeholder((n, l), name='A', dtype='int32') - B = te.placeholder((m, l), name='B', dtype='int32') - k = te.reduce_axis((0, l), name='k') - C = te.compute((n, m), lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k), - name='CC') - - s = te.create_schedule(C.op) - s[C].opengl() - print(tvm.lower(s, [A, B, C], simple_mode=True)) - - f = tvm.build(s, [A, B, C], "opengl", name="gemm") - print("------opengl code------") - print(f.imported_modules[0].get_source(fmt="gl")) - - ctx = tvm.opengl() - n, m, l = nn, nn, nn - a_np = np.random.uniform(low=0, high=10, size=(n, l)).astype(A.dtype) - b_np = np.random.uniform(low=0, high=10, size=(m, l)).astype(B.dtype) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) - f(a, b, c) - - tvm.testing.assert_allclose(c.asnumpy(), np.dot(a_np, b_np.T)) - -if __name__ == "__main__": - test_local_gemm() diff --git a/tests/webgl/test_local_save_load.py b/tests/webgl/test_local_save_load.py deleted file mode 100644 index cca68020c0c2..000000000000 --- a/tests/webgl/test_local_save_load.py +++ /dev/null @@ -1,53 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import numpy as np -import tvm -from tvm import te -from tvm import rpc -from tvm.contrib import util, emscripten - -def test_local_save_load(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - n = te.var("n") - A = te.placeholder((n,), name='A', dtype='int32') - B = te.placeholder((n,), name='B', dtype='int32') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - s[C].opengl() - - f = tvm.build(s, [A, B, C], "opengl", target_host="llvm", name="myadd") - - ctx = tvm.opengl(0) - n = 10 - a = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(B.dtype), ctx) - c = tvm.nd.array(np.zeros((n), dtype=C.dtype), ctx) - f(a, b, c) - - temp = util.tempdir() - path_so = temp.relpath("myadd.so") - f.export_library(path_so) - f1 = tvm.runtime.load_module(path_so) - f1(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - -if __name__ == "__main__": - test_local_save_load() diff --git a/tests/webgl/test_local_topi_conv2d_nchw.py b/tests/webgl/test_local_topi_conv2d_nchw.py deleted file mode 100644 index 0d9b7776096a..000000000000 --- a/tests/webgl/test_local_topi_conv2d_nchw.py +++ /dev/null @@ -1,99 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Example code to do convolution. -Copied from topi/tests/python/test_topi_conv2d_nchw.py. -Should be removed once we fix OpenGL testing on Jenkins.""" -import os -import numpy as np -import tvm -from tvm import te -import topi -from tvm.contrib.pickle_memoize import memoize -from topi.util import get_const_tuple - -def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding): - in_height = in_width = in_size - - A = te.placeholder((batch, in_channel, in_height, in_width), name='A') - W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W') - B = topi.nn.conv2d_nchw(A, W, stride, padding) - C = topi.nn.relu(B) - - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s1 = topi.generic.schedule_conv2d_nchw([B]) - s2 = topi.generic.schedule_conv2d_nchw([C]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) - with tvm.target.build_config(auto_unroll_max_step=1400, - unroll_explicit=(device != "cuda")): - func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) - func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) - func1(a, w, b) - func2(a, w, c) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - - -def test_conv2d_nchw(): - # ResNet18 worklaods - verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) - verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1) - verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0) - verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) - verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1) - verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0) - verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1) - verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) - verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) - verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) - # Vgg16 workloads - verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1) - # Super resolution workloads - verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2) - verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1) - verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1) - -if __name__ == "__main__": - test_conv2d_nchw() diff --git a/tests/webgl/test_local_topi_dense.py b/tests/webgl/test_local_topi_dense.py deleted file mode 100644 index 60dfe1ff690f..000000000000 --- a/tests/webgl/test_local_topi_dense.py +++ /dev/null @@ -1,76 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for dense operator -Copied from topi/tests/python/test_topi_dense.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" -import numpy as np -import tvm -from tvm import te -import topi -from topi.util import get_const_tuple -from tvm.contrib.pickle_memoize import memoize - - -def verify_dense(batch, in_dim, out_dim, use_bias=True): - A = te.placeholder((batch, in_dim), name='A') - B = te.placeholder((out_dim, in_dim), name='B') - C = te.placeholder((out_dim,), name='C') - D = topi.nn.dense(A, B, C if use_bias else None) - D = topi.nn.relu(D) - dtype = A.dtype - - # use memoize to pickle the test data for next time use - @memoize("topi.tests.test_topi_dense") - def get_ref_data(): - a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype) - b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype) - c_np = np.random.uniform(size=(out_dim,)).astype(dtype) - if use_bias: - d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0) - else: - d_np = np.maximum(np.dot(a_np, b_np.T), 0.0) - return (a_np, b_np, c_np, d_np) - # get the test data - a_np, b_np, c_np, d_np = get_ref_data() - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_dense(D) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(c_np, ctx) - d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx) - f = tvm.build(s, [A, B, C, D], device, name="dense") - f(a, b, c, d) - tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_dense(): - verify_dense(1, 1024, 1000, use_bias=True) - verify_dense(1, 1024, 1000, use_bias=False) - - -if __name__ == "__main__": - test_dense() diff --git a/tests/webgl/test_local_topi_pooling.py b/tests/webgl/test_local_topi_pooling.py deleted file mode 100644 index 3adae7bba51c..000000000000 --- a/tests/webgl/test_local_topi_pooling.py +++ /dev/null @@ -1,132 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for pooling -Copied from topi/tests/python/test_topi_pooling.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" -import numpy as np -import tvm -from tvm import te -import topi -import math -from topi.util import get_const_tuple - -def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): - iw = ih - kw = kh - sw = sh - ph, pw = padding - A = te.placeholder((n, ic, ih, iw), name='A') - B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, - pool_type=pool_type, ceil_mode=ceil_mode) - B = topi.nn.relu(B) - dtype = A.dtype - - bshape = get_const_tuple(B.shape) - ashape = get_const_tuple(A.shape) - if ceil_mode: - assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1) - else: - assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1) - - - a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype) - pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype) - no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw))) - pad_np[np.ix_(*no_zero)] = a_np - _, oc, oh, ow = get_const_tuple(B.shape) - b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype) - - if pool_type == 'avg': - for i in range(oh): - for j in range(ow): - b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) - elif pool_type =='max': - for i in range(oh): - for j in range(ow): - b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) - b_np = np.maximum(b_np, 0.0) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_pool(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) - print(tvm.lower(s, [A, B], simple_mode=True)) - - f = tvm.build(s, [A, B], device) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_pool(): - verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False) - verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False) - verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True) - - - -def verify_global_pool(n, c, h, w, pool_type): - A = te.placeholder((n, c, h, w), name='A') - B = topi.nn.global_pool(A, pool_type=pool_type) - B = topi.nn.relu(B) - - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - if pool_type == 'avg': - b_np = np.mean(a_np, axis=(2,3), keepdims=True) - elif pool_type =='max': - b_np = np.max(a_np, axis=(2,3), keepdims=True) - b_np = np.maximum(b_np, 0.0) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_global_pool(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - f = tvm.build(s, [A, B], device) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_global_pool(): - verify_global_pool(1, 1024, 7, 7, 'avg') - verify_global_pool(4, 1024, 7, 7, 'avg') - verify_global_pool(1, 1024, 7, 7, 'max') - verify_global_pool(4, 1024, 7, 7, 'max') - - -if __name__ == "__main__": - test_pool() - test_global_pool() diff --git a/tests/webgl/test_local_topi_softmax.py b/tests/webgl/test_local_topi_softmax.py deleted file mode 100644 index c0ddbf21419a..000000000000 --- a/tests/webgl/test_local_topi_softmax.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for softmax -Copied from topi/tests/python/test_topi_softmax.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" - -import os -import numpy as np -import tvm -from tvm import te -import topi -import logging -from topi.util import get_const_tuple - -def verify_softmax(m, n): - A = te.placeholder((m, n), name='A') - B = topi.nn.softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = topi.testing.softmax_python(a_np) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_softmax(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - foo = tvm.build(s, [A, B], device, name="softmax") - foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ["opengl"]: - check_device(device) - -def test_softmax(): - verify_softmax(32, 10) - verify_softmax(3, 4) - - -def verify_log_softmax(m, n): - A = te.placeholder((m, n), name='A') - B = topi.nn.log_softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = topi.testing.log_softmax_python(a_np) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_softmax(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - foo = tvm.build(s, [A, B], device, name="log_softmax") - foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ["opengl"]: - check_device(device) - - -def test_log_softmax(): - verify_log_softmax(32, 10) - verify_log_softmax(3, 4) - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - test_softmax() - test_log_softmax() diff --git a/tests/webgl/test_remote_save_load.py b/tests/webgl/test_remote_save_load.py deleted file mode 100644 index 34bbb3fa0f00..000000000000 --- a/tests/webgl/test_remote_save_load.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -The following instruction is based on web/README.md. - -Setup an RPC server: -$ python -m tvm.exec.rpc_proxy --example-rpc=1 - -Go to http://localhost:9190 in browser. - -Click "Connect To Proxy". - -Run this test script: -$ python tests/webgl/test_remote_save_load.py -""" - -import numpy as np -import tvm -from tvm import te -from tvm import rpc -from tvm.contrib import util, emscripten - -proxy_host = "localhost" -proxy_port = 9090 - -def try_remote_save_load(): - if not tvm.runtime.enabled("rpc"): - return - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - # Build the module. - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.placeholder((n,), name='B') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - s[C].opengl() - target_host = "llvm -target=asmjs-unknown-emscripten -system-lib" - f = tvm.build(s, [A, B, C], "opengl", target_host=target_host, name="myadd") - - remote = rpc.connect(proxy_host, proxy_port, key="js") - - temp = util.tempdir() - ctx = remote.opengl(0) - path_obj = temp.relpath("myadd.bc") - path_dso = temp.relpath("myadd.js") - path_gl = temp.relpath("myadd.gl") - path_json = temp.relpath("myadd.tvm_meta.json") - - f.save(path_obj) - emscripten.create_js(path_dso, path_obj, side_module=True) - f.imported_modules[0].save(path_gl) - - remote.upload(path_dso, "myadd.dso") - remote.upload(path_gl) - remote.upload(path_json) - - remote.download("myadd.dso") - remote.download("myadd.gl") - remote.download("myadd.tvm_meta.json") - - print('Loading myadd.dso') - fhost = remote.load_module("myadd.dso") - - print('Loading myadd.gl') - fdev = remote.load_module("myadd.gl") - - print('import_module') - fhost.import_module(fdev) - - print('running...') - a = tvm.nd.array(np.random.uniform(size=16).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(16, dtype=A.dtype), ctx) - c = tvm.nd.array(np.zeros(16, dtype=C.dtype), ctx) - fhost(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - -if __name__ == "__main__": - try_remote_save_load() diff --git a/tests/webgl/test_static_webgl_library.html b/tests/webgl/test_static_webgl_library.html deleted file mode 100644 index f9268c65edf3..000000000000 --- a/tests/webgl/test_static_webgl_library.html +++ /dev/null @@ -1,72 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - TVM RPC Test Page - - - -

TVM Test Page

-
- - - - - - - - \ No newline at end of file diff --git a/tests/webgl/test_static_webgl_library.py b/tests/webgl/test_static_webgl_library.py deleted file mode 100644 index 929da4ca294c..000000000000 --- a/tests/webgl/test_static_webgl_library.py +++ /dev/null @@ -1,66 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Create a static WebGL library and run it in the browser.""" - -from __future__ import absolute_import, print_function - -import os, shutil, SimpleHTTPServer, SocketServer -import tvm -from tvm import te -from tvm.contrib import emscripten, util -import numpy as np - -def try_static_webgl_library(): - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - - # Change to lib/ which contains "libtvm_runtime.bc". - os.chdir(os.path.join(curr_path, "../../lib")) - - # Create OpenGL module. - n = te.var("n") - A = te.placeholder((n,), name='A', dtype="float") - B = te.compute((n,), lambda *i: A[i], name="B") - - s = te.create_schedule(B.op) - s[B].opengl() - - target_host = "llvm -target=asmjs-unknown-emscripten -system-lib" - f = tvm.build(s, [A, B], name="identity", target="opengl", - target_host=target_host) - - # Create a JS library that contains both the module and the tvm runtime. - path_dso = "identity_static.js" - f.export_library(path_dso, emscripten.create_js, options=[ - "-s", "USE_GLFW=3", - "-s", "USE_WEBGL2=1", - "-lglfw", - ]) - - # Create "tvm_runtime.js" and "identity_static.html" in lib/ - shutil.copyfile(os.path.join(curr_path, "../../web/tvm_runtime.js"), - "tvm_runtime.js") - shutil.copyfile(os.path.join(curr_path, "test_static_webgl_library.html"), - "identity_static.html") - - port = 8080 - handler = SimpleHTTPServer.SimpleHTTPRequestHandler - httpd = SocketServer.TCPServer(("", port), handler) - print("Please open http://localhost:" + str(port) + "/identity_static.html") - httpd.serve_forever() - -if __name__ == "__main__": - try_static_webgl_library() diff --git a/topi/tests/python/test_topi_conv2d_nhwc_winograd.py b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py index 45f0599bae4d..a7e55320d6dd 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc_winograd.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py @@ -137,7 +137,8 @@ def test_conv2d_nhwc_winograd_direct(): def test_conv2d_nhwc_winograd_tensorcore(): """Test the conv2d with winograd for nhwc layout""" - print("test_winograd_tensorcore...") + if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): + return verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1, bgemm="tensorcore") verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1, bgemm="tensorcore") verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1, bgemm="tensorcore") @@ -145,8 +146,7 @@ def test_conv2d_nhwc_winograd_tensorcore(): verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, (1, 1), add_relu=True, bgemm="tensorcore") verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, "SAME", add_relu=True, bgemm="tensorcore") + if __name__ == "__main__": test_conv2d_nhwc_winograd_direct() - - if nvcc.have_tensorcore(tvm.gpu(0).compute_version): - test_conv2d_nhwc_winograd_tensorcore() + test_conv2d_nhwc_winograd_tensorcore() diff --git a/web/.eslintignore b/web/.eslintignore new file mode 100644 index 000000000000..1521c8b7652b --- /dev/null +++ b/web/.eslintignore @@ -0,0 +1 @@ +dist diff --git a/web/.gitignore b/web/.gitignore new file mode 100644 index 000000000000..a3135cf24b9d --- /dev/null +++ b/web/.gitignore @@ -0,0 +1,6 @@ +.vscode +*~ +out +node_modules +package-lock.json +build diff --git a/web/.jsdoc_conf.json b/web/.jsdoc_conf.json deleted file mode 100644 index 33783b3bbb21..000000000000 --- a/web/.jsdoc_conf.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "templates": { - "default": { - "includeDate": false - } - } -} diff --git a/web/Makefile b/web/Makefile new file mode 100644 index 000000000000..be7fa193c04c --- /dev/null +++ b/web/Makefile @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +TVM_ROOT=$(shell cd ..; pwd) + +INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ + -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include + +.PHONY: clean all + +all: dist/wasm/tvmjs_runtime.wasm dist/wasm/tvmjs_runtime.wasi.js + +EMCC = emcc + +EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++14 -Wno-ignored-attributes \ + -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1 -s ERROR_ON_UNDEFINED_SYMBOLS=0 + +EMCC_LDFLAGS = --pre-js emcc/preload.js + +dist/wasm/%.bc: emcc/%.cc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -c -MM -MT dist/wasm/$*.bc $< >dist/wasm/$*.d + $(EMCC) $(EMCC_CFLAGS) -c -o dist/wasm/$*.bc $< + + +dist/wasm/tvmjs_runtime.wasm: dist/wasm/wasm_runtime.bc dist/wasm/tvmjs_support.bc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -o dist/wasm/tvmjs_runtime.js $+ $(EMCC_LDFLAGS) + + +dist/wasm/tvmjs_runtime.wasi.js: dist/wasm/tvmjs_runtime.wasm emcc/decorate_as_wasi.py + python3 emcc/decorate_as_wasi.py dist/wasm/tvmjs_runtime.js $@ + +clean: + @rm -rf dist/wasm + +-include dist/wasm/*.d diff --git a/web/README.md b/web/README.md index 5dfd6917934b..66a64a3d3d37 100644 --- a/web/README.md +++ b/web/README.md @@ -15,163 +15,70 @@ -# TVM WebAssembly and Javascript Backend +# TVM WebAssembly Runtime -This folder contains TVM WebAssembly and Javascript backend through Emscripten. +This folder contains TVM WebAssembly Runtime. ## Installation -While the LLVM main branch support webassembly as a target. We still need a good runtime with libc and other -system library support. Emscripten toolchain offers that nicely. The general idea is to build TVM against -the fastcomp LLVM backend in the Emscripten project and allow us to generate ```asmjs-unknown-emscripten``` -as a backend target. + +The LLVM main branch support webassembly as a target, we can directly +build TVM with LLVM mainline to generate wasm modules. +Note that, however, we still need emscripten to compile the runtime and provide system library support. + +Note that so far we requires everything to be in the source and setup PYTHONPATH(instead of use setup.py install). ### Setup Emscripten -Checkout [Emscripten Portable SDK Downloads](https://kripken.github.io/emscripten-site/docs/getting_started/downloads.html) -to download emsdk-portable and unzip it on a local folder. Follow the installation guide from emscripten document. -```bash -./emsdk update -./emsdk install latest -./emsdk activate latest -``` +We use emscripten to compile our runtime wasm library as well as a WASI variant that we can deploy +to the browser environment. -Because we need to compile against the LLVM backend of emscripten, we will need the source and llvm library. -Which can be installed via following command. +Follow [Emscripten](https://emscripten.org/) to download emsdk and install emcc on your local environment. -```bash -./emsdk install clang-incoming-64bit -./emsdk activate clang-incoming-64bit -``` +### Build TVM Wasm Runtime -### Setup Environment Variable +After the emcc is setup correctly. We can build tvm's wasm runtime by typing `make` in the web folder. -In normal setting, we can setup the necessary environment variable with the following command. ```bash -source /path-to-emsdk-portable/emsdk_env.sh +make ``` -However, this will put emscripten's clang and llvm path ahead of the current system path. -What you can do is to set the path manually, by putting emscripten's path after the PATH like the following ones. -You can get the detailed path by type ```./emsdk activate``` -```bash -export PATH=${PATH}:/emsdk-related-path-here +This command will create the follow files: +- `dist/wasm/libtvm_runtime.bc` bitcode library `tvm.contrib.emcc` will link into. +- `dist/wasm/tvmjs_runtime.wasm` a standalone wasm runtime for testing purposes. +- `dist/wasm/tvmjs_runtime.wasi.js` a WASI compatible library generated by emscripten that can be fed into runtime. -``` -### Build TVM with Fastcomp LLVM +### Build TVM Wasm JS Frontend -To build TVM with Emscripten's Fastcomp LLVM, we can modify the LLVM_CONFIG in ```config.mk``` -to point to fastcomp's llvm-config and build TVM normally. +Type the following command in the web folder. ```bash -LLVM_CONFIG = /path/to/emsdk-portable/clang/fastcomp/build_incoming_64/bin/llvm-config +npm run bundle ``` -### Build TVM Web Runtime +This command will create the tvmjs library that we can use to interface with the wasm runtime. -The above command gives us the TVM compiling environment. Now we need to build runtime, -to do so, make sure we set the environment correctly as in previous section and type -```bash -make web -``` +## Use TVM to Generate Wasm Library and Run it -This will create ```build/libtvm_web_runtime.bc``` and ```build/libtvm_web_runtime.js```. - -## Use TVM to Generate Javascript Library - -The general idea is to use TVM as normally and set target to be ```llvm -target=asmjs-unknown-emscripten -system-lib```. - -The following code snippet from [tests/web/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/tests/web/prepare_test_libs.py) demonstrate -the compilation process. - -```python -import tvm -from tvm import te -from tvm.contrib import emscripten -import os -def prepare_test_libs(base_path): - target = "llvm -target=asmjs-unknown-emscripten -system-lib" - if not tvm.runtime.enabled(target): - raise RuntimeError("Target %s is not enbaled" % target) - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') - s = te.create_schedule(B.op) - fadd1 = tvm.build(s, [A, B], target, name="add_one") - obj_path = os.path.join(base_path, "test_add_one.bc") - fadd1.save(obj_path) - emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path) - -if __name__ == "__main__": - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - prepare_test_libs(os.path.join(curr_path, "../../build")) -``` +Check code snippet in -In this workflow, we use TVM to generate a ```.bc``` file and statically link -that with the ```build/libtvm_web_runtime.bc```(emscripten.create_js will help you do that). -The result js library is a library that contains both TVM runtime and the compiled function. - - -## Run the Generated Library - -The following code snippet from [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/tests/web/test_module_load.js) demonstrate -how to run the compiled library. - -```js -// Load Emscripten Module, need to change path to root/build -const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/test_module.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); - -// Load system library, the compiled function is registered in sysLib. -var sysLib = tvm.systemLib(); - -function randomArray(length, max) { - return Array.apply(null, Array(length)).map(function() { - return Math.random() * max; - }); -} - -function testAddOne() { - // grab pre-loaded function - var faddOne = sysLib.getFunction("add_one"); - var assert = require('assert'); - tvm.assert(tvm.isPackedFunc(faddOne)); - var n = 124; - var A = tvm.empty(n).copyFrom(randomArray(n, 1)); - var B = tvm.empty(n); - // call the function. - faddOne(A, B); - AA = A.asArray(); // retrieve values in js array - BB = B.asArray(); // retrieve values in js array - // verify - for (var i = 0; i < BB.length; ++i) { - assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5); - } - faddOne.release(); -} - -testAddOne(); -sysLib.release(); -console.log("Finish verifying test_module_load"); -``` +- [tests/python/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/web/tests/pythob/prepare_test_libs.py) + shows how to create a wasm library that links with tvm runtime. + - Note that all wasm libraries have to created using the `--system-lib` option + - emcc.create_wasm will automatically link the runtime library `dist/wasm/libtvm_runtime.bc` +- [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/web/tests/node/test_module_load.js) demonstrate + how to run the generated library through tvmjs API. -Current example supports static linking, which is the preferred way to get more efficiency -in javascript backend. -## Proxy based RPC +## Run Wasm Remotely through WebSocket RPC. -We can now use javascript end to start an RPC server and connect to it from python side, +We can now use js side to start an RPC server and connect to it from python side, making the testing flow easier. -The following is an example to reproduce this. This requires everything to be in the git source and setup PYTHONPATH(instead of use setup.py install) -- run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy. -- Open broswer, goto the server webpage click Connect to proxy. - - Alternatively run "node web/example_rpc_node.js" -- run "python tests/web/websock_rpc_test.py" to run the rpc client. - -The general idea is to use Emscripten's dynamic linking to dynamically load modules. +The following is an example to reproduce this. +- run `python -m tvm.exec.rpc_proxy --example-rpc=1` to start proxy. +- Start the WebSocket RPC + - Browswer version: open https://localhost:8888, click connect to proxy + - NodeJS version: `npm run rpc` +- run `python tests/node/websock_rpc_test.py` to run the rpc client. diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html new file mode 100644 index 000000000000..22907f1561d1 --- /dev/null +++ b/web/apps/browser/rpc_server.html @@ -0,0 +1,79 @@ + + + + + + + + + + + + + + + + + + + + TVM RPC Test Page + + + + + +

TVM WebSocket RPC Server

+ To use this page + + +

Options

+ Proxy URL
+ RPC Server Key
+ + +
+ + + diff --git a/web/apps/node/example.js b/web/apps/node/example.js new file mode 100644 index 000000000000..f81a9c903e5d --- /dev/null +++ b/web/apps/node/example.js @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Example code to start the runtime. + */ +const path = require("path"); +const fs = require("fs"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); +// Here we pass the javascript module generated by emscripten as the +// LibraryProvider to provide WASI related libraries. +// the async version of the API. +tvmjs.instantiate(wasmSource, new EmccWASI()) +.then((tvm) => { + // List all the global functions from the runtime. + console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames()); +}); + diff --git a/web/apps/node/wasi_example.js b/web/apps/node/wasi_example.js new file mode 100644 index 000000000000..95ec2e0b1d07 --- /dev/null +++ b/web/apps/node/wasi_example.js @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Example code to start the runtime. + */ +const { WASI } = require('wasi'); +const path = require("path"); +const fs = require("fs"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +const wasi = new WASI({ args: process.argv, env: process.env }); +// Here we pass the javascript module generated by emscripten as the +// LibraryProvider to provide WASI related libraries. +const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), wasi); + +// List all the global functions from the runtime. +console.log("Runtime using WASI\n", tvm.listGlobalFuncNames()); diff --git a/web/example_rpc_node.js b/web/apps/node/wasi_rpc_server.js similarity index 60% rename from web/example_rpc_node.js rename to web/apps/node/wasi_rpc_server.js index 45f917a3234b..eb4c6ed52be9 100644 --- a/web/example_rpc_node.js +++ b/web/apps/node/wasi_rpc_server.js @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,17 +17,20 @@ * under the License. */ -// Javascript RPC server example -// Start and connect to websocket proxy. +/** + * Example code to start the RPC server on nodejs using WASI + */ +const { WASI } = require("wasi"); +const tvmjs = require("../../dist"); + +// Get import returns a fresh library in each call. +const getImports = () => { + return new WASI({ + args: process.argv, + env: process.env + }); +}; -// Load Emscripten Module, need to change path to root/lib -const path = require("path"); -process.chdir(path.join(__dirname, "../lib")); -var Module = require("../lib/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +const proxyUrl = "ws://localhost:8888/ws"; -var websock_proxy = "ws://localhost:9190/ws"; -var num_sess = 100; -tvm.startRPCServer(websock_proxy, "js", num_sess) +new tvmjs.RPCServer(proxyUrl, "wasm", getImports, console.log); diff --git a/tests/webgl/test_local_multi_stage.py b/web/emcc/decorate_as_wasi.py similarity index 50% rename from tests/webgl/test_local_multi_stage.py rename to web/emcc/decorate_as_wasi.py index 54a554b74ed9..741e33bb22ea 100644 --- a/tests/webgl/test_local_multi_stage.py +++ b/web/emcc/decorate_as_wasi.py @@ -14,34 +14,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te -import numpy as np +"""Decorate emcc generated js to a WASI compatible API.""" -def test_local_multi_stage(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return +import sys - n = te.var("n") - A = te.placeholder((n,), name='A', dtype="int32") - B = te.compute((n,), lambda i: A[i] + 1, name="B") - C = te.compute((n,), lambda i: B[i] * 2, name="C") +template_head = """ +function EmccWASI() { +""" - s = te.create_schedule(C.op) - s[B].opengl() - s[C].opengl() +template_tail = """ + this.Module = Module; + this.start = Module.wasmLibraryProvider.start; + this.imports = Module.wasmLibraryProvider.imports; + this.wasiImport = this.imports["wasi_snapshot_preview1"]; +} - f = tvm.build(s, [A, C], "opengl", name="multi_stage") - - ctx = tvm.opengl(0) - n = 10 - a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx) - c = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx) - f(a, c) - - tvm.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2) +if (typeof module !== "undefined" && module.exports) { + module.exports = EmccWASI; +} +""" if __name__ == "__main__": - test_local_multi_stage() + if len(sys.argv) != 3: + print("Usage ") + result = template_head + open(sys.argv[1]).read() + template_tail + with open(sys.argv[2], "w") as fo: + fo.write(result) diff --git a/web/emcc/preload.js b/web/emcc/preload.js new file mode 100644 index 000000000000..882280f9cac0 --- /dev/null +++ b/web/emcc/preload.js @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* eslint-disable no-unused-vars */ +/** + * JS config used by --pre-js in emcc. + * Wrap module as a LibraryProvider. + */ + +var __wasmLib = {}; + +function __wasmLibInstantiateWasm(imports, successCallback) { + __wasmLib.imports = imports; + __wasmLib.successCallback = successCallback; +} + +function __wasmLibStart(wasmInstance) { + __wasmLib.successCallback(wasmInstance); +} + +__wasmLib.start = __wasmLibStart; + +var Module = { + "instantiateWasm": __wasmLibInstantiateWasm, + "wasmLibraryProvider": __wasmLib +}; diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc new file mode 100644 index 000000000000..97099e75f16f --- /dev/null +++ b/web/emcc/tvmjs_support.cc @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvmjs_support.cc + * \brief Support functions to be linked with wasm_runtime to provide + * PackedFunc callbacks in tvmjs. + * We do not need to link this file in standalone wasm. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + + +#include +#include +#include +#include +#include + +extern "C" { +// --- Additional C API for the Wasm runtime --- +/*! + * \brief Allocate space aligned to 64 bit. + * \param size The size of the space. + * \return The allocated space. + */ +TVM_DLL void* TVMWasmAllocSpace(int size); + +/*! + * \brief Free the space allocated by TVMWasmAllocSpace. + * \param data The data pointer. + */ +TVM_DLL void TVMWasmFreeSpace(void* data); + +/*! + * \brief Create PackedFunc from a resource handle. + * \param resource_handle The handle to the resource. + * \param out The output PackedFunc. + * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer +3A * \return 0 if success. + */ +TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, + TVMFunctionHandle *out); + +// --- APIs to be implemented by the frontend. --- +/*! + * \brief Wasm frontend packed function caller. + * + * \param args The arguments + * \param type_codes The type codes of the arguments + * \param num_args Number of arguments. + * \param ret The return value handle. + * \param resource_handle The handle additional resouce handle from fron-end. + * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. + */ +extern int TVMWasmPackedCFunc(TVMValue* args, + int* type_codes, + int num_args, + TVMRetValueHandle ret, + void* resource_handle); + +/*! + * \brief Wasm frontend resource finalizer. + * \param resource_handle The pointer to the external resource. + */ +extern void TVMWasmPackedCFuncFinalizer(void* resource_handle); +} // extern "C" + + +void* TVMWasmAllocSpace(int size) { + int num_count = (size + 7) / 8; + return new int64_t[num_count]; +} + +void TVMWasmFreeSpace(void* arr) { + delete[] static_cast(arr); +} + +int TVMWasmFuncCreateFromCFunc(void* resource_handle, + TVMFunctionHandle *out) { + return TVMFuncCreateFromCFunc( + TVMWasmPackedCFunc, resource_handle, + TVMWasmPackedCFuncFinalizer, out); +} + + +namespace tvm { +namespace runtime { + +// chrono in the WASI does not provide very accurate time support +// and also have problems in the i64 support in browser. +// We redirect the timer to a JS side time using performance.now +PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, + TVMContext ctx, + int number, + int repeat, + int min_repeat_ms) { + auto ftimer = [pf, ctx, number, repeat, min_repeat_ms]( + TVMArgs args, TVMRetValue *rv) { + + TVMRetValue temp; + auto finvoke = [&](int n) { + // start timing + for (int i = 0; i < n; ++i) { + pf.CallPacked(args, &temp); + } + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + }; + + auto* get_timer = runtime::Registry::Get("wasm.GetTimer"); + CHECK(get_timer != nullptr) << "Cannot find wasm.GetTimer in the global function"; + TypedPackedFunc timer_ms = (*get_timer)( + TypedPackedFunc(finvoke)); + + std::ostringstream os; + finvoke(1); + + int setup_number = number; + + for (int i = 0; i < repeat; ++i) { + double duration_ms = 0.0; + + do { + if (duration_ms > 0.0) { + setup_number = static_cast( + std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random + } + duration_ms = timer_ms(setup_number); + } while (duration_ms < min_repeat_ms); + + double speed = duration_ms / setup_number / 1000; + os.write(reinterpret_cast(&speed), sizeof(speed)); + } + + std::string blob = os.str(); + TVMByteArray arr; + arr.size = blob.length(); + arr.data = blob.data(); + // return the time. + *rv = arr; + }; + return PackedFunc(ftimer); +} + +TVM_REGISTER_GLOBAL("wasm.RPCTimeEvaluator") +.set_body_typed([](Optional opt_mod, + std::string name, + int device_type, + int device_id, + int number, + int repeat, + int min_repeat_ms) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + + if (opt_mod.defined()) { + Module m = opt_mod.value(); + std::string tkey = m->type_key(); + return WrapWasmTimeEvaluator( + m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); + } else { + auto* pf = runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; + return WrapWasmTimeEvaluator( + *pf, ctx, number, repeat, min_repeat_ms); + } +}); + +} // namespace runtime +} // namespace tvm diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc new file mode 100644 index 000000000000..6ff652cf408a --- /dev/null +++ b/web/emcc/wasm_runtime.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file wasm_runtime.cc + * \brief TVM wasm runtime library pack. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + +#include +#include + +#include "src/runtime/c_runtime_api.cc" +#include "src/runtime/cpu_device_api.cc" +#include "src/runtime/workspace_pool.cc" +#include "src/runtime/library_module.cc" +#include "src/runtime/system_library.cc" + +#include "src/runtime/module.cc" +#include "src/runtime/ndarray.cc" +#include "src/runtime/object.cc" +#include "src/runtime/registry.cc" +#include "src/runtime/file_util.cc" +#include "src/runtime/graph/graph_runtime.cc" +#include "src/runtime/rpc/rpc_session.cc" +#include "src/runtime/rpc/rpc_endpoint.cc" +#include "src/runtime/rpc/rpc_event_impl.cc" +#include "src/runtime/rpc/rpc_channel.cc" +#include "src/runtime/rpc/rpc_local_session.cc" +#include "src/runtime/rpc/rpc_module.cc" + + +// --- Implementations of backend and wasm runtime API. --- + +int TVMBackendParallelLaunch(FTVMParallelLambda flambda, + void* cdata, + int num_task) { + TVMParallelGroupEnv env; + env.num_task = 1; + flambda(0, &env, cdata); + return 0; +} + +int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { + return 0; +} + +// --- Environment PackedFuncs for testing --- +namespace tvm { +namespace runtime { + +TVM_REGISTER_GLOBAL("testing.echo") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = args[0]; +}); + +TVM_REGISTER_GLOBAL("testing.add_one") +.set_body_typed([](int x) { + return x + 1; +}); + +TVM_REGISTER_GLOBAL("testing.wrap_callback") +.set_body([](TVMArgs args, TVMRetValue *ret) { + PackedFunc pf = args[0]; + *ret = runtime::TypedPackedFunc([pf](){ + pf(); + }); + }); +} // namespace runtime +} // namespace tvm diff --git a/web/example_rpc.html b/web/example_rpc.html deleted file mode 100644 index ae2b1dd9c44b..000000000000 --- a/web/example_rpc.html +++ /dev/null @@ -1,61 +0,0 @@ - - - - - - - - - - - - - - - - - - - TVM RPC Test Page - - - - -

TVM Test Page

- To use this page, the easiest way is to do -
    -
  • run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy. -
  • Click Connect to proxy. -
  • run "python tests/web/websock_rpc_test.py" to run the rpc client. -
-

Options

- Proxy URL
- RPC Server Key
- - -
- - - - diff --git a/web/package.json b/web/package.json new file mode 100644 index 000000000000..76aa111e2acf --- /dev/null +++ b/web/package.json @@ -0,0 +1,29 @@ +{ + "name": "tvmjs", + "displayName": "TVM Wasm JS runtime", + "license": "Apache-2.0", + "version": "0.7.0", + "scripts": { + "build": "tsc -b", + "watch": "tsc -b -w", + "lint": "eslint -c .eslintrc.json .", + "bundle": "npm run build && rollup -c rollup.config.js", + "example": "npm run bundle && node apps/node/example.js", + "example:wasi": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_example.js", + "rpc": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_rpc_server.js" + }, + "devDependencies": { + "typescript": "^3.8.3", + "@types/node": "^12.12.37", + "eslint": "^6.8.0", + "@typescript-eslint/eslint-plugin": "^2.29.0", + "@typescript-eslint/parser": "^2.29.0", + "typedoc": "^0.17.6", + "rollup": "^2.7.6", + "ws": "^7.2.5", + "@rollup/plugin-commonjs": "^11.1.0", + "@rollup/plugin-node-resolve": "^7.1.3", + "rollup-plugin-typescript2": "^0.27.0" + }, + "dependencies": {} +} diff --git a/web/.eslintrc.js b/web/rollup.config.js similarity index 69% rename from web/.eslintrc.js rename to web/rollup.config.js index 2e82ba50e3c4..0046e4434076 100644 --- a/web/.eslintrc.js +++ b/web/rollup.config.js @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,29 +17,18 @@ * under the License. */ -module.exports = { - "env": { - "browser": true, - "node": true, - "es6": true +import commonjs from '@rollup/plugin-commonjs'; +import resolve from '@rollup/plugin-node-resolve'; + +export default { + input: 'dist/index.js', + output: { + file: 'dist/tvmjs.bundle.js', + format: 'umd', + name: 'tvmjs', + exports: 'named', + globals: {'ws': 'ws'} }, - "extends": "eslint:recommended", - "rules": { - "indent": [ - "error", - 2 - ], - "linebreak-style": [ - "error", - "unix" - ], - "quotes": [ - "error", - "double" - ], - "semi": [ - "error", - "always" - ] - } + plugins: [commonjs(), resolve()], + external: ['ws'] }; diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts new file mode 100644 index 000000000000..f533b4e491a6 --- /dev/null +++ b/web/src/ctypes.ts @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Types for C API. + */ + +/** A pointer to points to the raw address space. */ +export type Pointer = number; + +/** A pointer offset, need to add a base address to get a valid ptr. */ +export type PtrOffset = number; + +// -- TVM runtime C API -- +/** + * const char *TVMGetLastError(); + */ +export type FTVMGetLastError = () => Pointer; + +/** + * int TVMModGetFunction(TVMModuleHandle mod, + * const char* func_name, + * int query_imports, + * TVMFunctionHandle *out); + */ +export type FTVMModGetFunction = ( + mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) => number; +/** + * int TVMModImport(TVMModuleHandle mod, + * TVMModuleHandle dep); + */ +export type FTVMModImport = (mod: Pointer, dep: Pointer) => number; +/** + * int TVMModFree(TVMModuleHandle mod); + */ +export type FTVMModFree = (mod: Pointer) => number; + +/** + * int TVMFuncFree(TVMFunctionHandle func); + */ +export type FTVMFuncFree = (func: Pointer) => number; + +/** + * int TVMFuncCall(TVMFunctionHandle func, + * TVMValue* arg_values, + * int* type_codes, + * int num_args, + * TVMValue* ret_val, + * int* ret_type_code); + */ +export type FTVMFuncCall = ( + func: Pointer, argValues: Pointer, typeCode: Pointer, + nargs: number, retValue: Pointer, retCode: Pointer) => number; + +/** + * int TVMCFuncSetReturn(TVMRetValueHandle ret, + * TVMValue* value, + * int* type_code, + * int num_ret); + */ +export type FTVMCFuncSetReturn = ( + ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number; + +/** + * int TVMCbArgToReturn(TVMValue* value, int* code); + */ +export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number; + +/** + * int TVMFuncListGlobalNames(int* outSize, const char*** outArray); + */ +export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) => number; + +/** + * int TVMFuncRegisterGlobal( + * const char* name, TVMFunctionHandle f, int override); + */ +export type FTVMFuncRegisterGlobal = ( + name: Pointer, f: Pointer, override: number) => number; + +/** + *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); + */ +export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number; + +/** + * int TVMArrayAlloc(const tvm_index_t* shape, + * int ndim, + * int dtype_code, + * int dtype_bits, + * int dtype_lanes, + * int device_type, + * int device_id, + * TVMArrayHandle* out); + */ +export type FTVMArrayAlloc = ( + shape: Pointer, ndim: number, + dtypeCode: number, dtypeBits: number, + dtypeLanes: number, deviceType: number, deviceId: number, + out: Pointer) => number; + +/** + * int TVMArrayFree(TVMArrayHandle handle); + */ +export type FTVMArrayFree = (handle: Pointer) => number; + +/** + * int TVMArrayCopyFromBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyFromBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyToBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyToBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyFromTo(TVMArrayHandle from, + * TVMArrayHandle to, + * TVMStreamHandle stream); + */ +export type FTVMArrayCopyFromTo = ( + from: Pointer, to: Pointer, stream: Pointer) => number; + +/** + * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); + */ +export type FTVMSynchronize = ( + deviceType: number, deviceId: number, stream: Pointer) => number; + +/** + * typedef int (*TVMBackendPackedCFunc)(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMValue* out_ret_value, + * int* out_ret_tcode); + */ +export type FTVMBackendPackedCFunc = ( + argValues: Pointer, argCodes: Pointer, nargs: number, + outValue: Pointer, outCode: Pointer) => number; + +// -- TVM Wasm Auxiliary C API -- + +/** void* TVMWasmAllocSpace(int size); */ +export type FTVMWasmAllocSpace = (size: number) => Pointer; + +/** void TVMWasmFreeSpace(void* data); */ +export type FTVMWasmFreeSpace = (ptr: Pointer) => void; + +/** + * int TVMWasmPackedCFunc(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMRetValueHandle ret, + * void* resource_handle); + */ +export type FTVMWasmPackedCFunc = ( + args: Pointer, typeCodes: Pointer, nargs: number, + ret: Pointer, resourceHandle: Pointer) => number; + +/** + * int TVMWasmFuncCreateFromCFunc(void* resource_handle, + * TVMFunctionHandle *out); + */ +export type FTVMWasmFuncCreateFromCFunc = ( + resource: Pointer, out: Pointer) => number; + +/** + * void TVMWasmPackedCFuncFinalizer(void* resource_handle); + */ +export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void; + +/** + * Size of common data types. + */ +export const enum SizeOf { + U8 = 1, + U16 = 2, + I32 = 4, + I64 = 8, + F32 = 4, + F64 = 8, + TVMValue = 8, + DLDataType = I32, + DLContext = I32 + I32, +} + +/** + * Type code in TVM FFI. + */ +export const enum TypeCode { + Int = 0, + UInt = 1, + Float = 2, + TVMOpaqueHandle = 3, + Null = 4, + TVMDataType = 5, + TVMContext = 6, + TVMDLTensorHandle = 7, + TVMObjectHandle = 8, + TVMModuleHandle = 9, + TVMPackedFuncHandle = 10, + TVMStr = 11, + TVMBytes = 12, + TVMNDArrayHandle = 13, + TVMObjectRValueRefArg = 14 +} \ No newline at end of file diff --git a/web/src/environment.ts b/web/src/environment.ts new file mode 100644 index 000000000000..df0fe68c81e0 --- /dev/null +++ b/web/src/environment.ts @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Runtime environment that provide js libaries calls. + */ +import { Pointer } from "./ctypes"; +import { LibraryProvider } from "./types"; +import { assert } from "./support"; +import * as ctypes from "./ctypes"; + +/** + * Detect library provider from the importObject. + * + * @param importObject The import object. + */ +function detectLibraryProvider( + importObject: Record +): LibraryProvider | undefined { + if ( + importObject["wasmLibraryProvider"] && + importObject["wasmLibraryProvider"]["start"] && + importObject["wasmLibraryProvider"]["imports"] !== undefined + ) { + const item = importObject as { wasmLibraryProvider: LibraryProvider }; + // create provider so that we capture imports in the provider. + return { + imports: item.wasmLibraryProvider.imports, + start: (inst: WebAssembly.Instance): void => { + item.wasmLibraryProvider.start(inst); + }, + }; + } else if (importObject["imports"] && importObject["start"] !== undefined) { + return importObject as LibraryProvider; + } else if (importObject["wasiImport"] && importObject["start"] !== undefined) { + // WASI + return { + imports: { + "wasi_snapshot_preview1": importObject["wasiImport"], + }, + start: (inst: WebAssembly.Instance): void => { + importObject["start"](inst); + } + }; + } else { + return undefined; + } +} + +/** + * Environment to impelement most of the JS library functions. + */ +export class Environment implements LibraryProvider { + logger: (msg: string) => void; + imports: Record; + /** + * Maintains a table of FTVMWasmPackedCFunc that the C part + * can call via TVMWasmPackedCFunc. + * + * We maintain a separate table so that we can have un-limited amount + * of functions that do not maps to the address space. + */ + packedCFuncTable: Array = [ + undefined, + ]; + /** + * Free table index that can be recycled. + */ + packedCFuncTableFreeId: Array = []; + + private libProvider?: LibraryProvider; + + constructor( + importObject: Record = {}, + logger: (msg: string) => void = console.log + ) { + this.logger = logger; + this.libProvider = detectLibraryProvider(importObject); + // get imports from the provider + if (this.libProvider !== undefined) { + this.imports = this.libProvider.imports; + } else { + this.imports = importObject; + } + // update with more functions + this.imports.env = this.environment(this.imports.env); + } + + /** Mark the start of the instance. */ + start(inst: WebAssembly.Instance): void { + if (this.libProvider !== undefined) { + this.libProvider.start(inst); + } + } + + private environment(initEnv: Record): Record { + // default env can be be overriden by libraries. + const defaultEnv = { + "__cxa_thread_atexit": (): void => {}, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + "emscripten_notify_memory_growth": (index: number): void => {} + }; + const wasmPackedCFunc: ctypes.FTVMWasmPackedCFunc = ( + args: Pointer, + typeCodes: Pointer, + nargs: number, + ret: Pointer, + resourceHandle: Pointer + ): number => { + const cfunc = this.packedCFuncTable[resourceHandle]; + assert(cfunc !== undefined); + return cfunc(args, typeCodes, nargs, ret, resourceHandle); + }; + + const wasmPackedCFuncFinalizer: ctypes.FTVMWasmPackedCFuncFinalizer = ( + resourceHandle: Pointer + ): void => { + this.packedCFuncTable[resourceHandle] = undefined; + this.packedCFuncTableFreeId.push(resourceHandle); + }; + + const newEnv = { + TVMWasmPackedCFunc: wasmPackedCFunc, + TVMWasmPackedCFuncFinalizer: wasmPackedCFuncFinalizer, + "__console_log": (msg: string): void => { + this.logger(msg); + } + }; + return Object.assign(defaultEnv, initEnv, newEnv); + } +} \ No newline at end of file diff --git a/web/src/index.ts b/web/src/index.ts new file mode 100644 index 000000000000..5d7d7ccc39cc --- /dev/null +++ b/web/src/index.ts @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export { + Scalar, DLContext, DLDataType, + PackedFunc, Module, NDArray, Instance, + instantiate +} from "./runtime"; +export { Disposable, LibraryProvider } from "./types"; +export { RPCServer } from "./rpc_server"; +export { wasmPath } from "./support"; \ No newline at end of file diff --git a/web/src/memory.ts b/web/src/memory.ts new file mode 100644 index 000000000000..ac737b7c297d --- /dev/null +++ b/web/src/memory.ts @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Classes to manipulate Wasm memories. + */ +import { Pointer, PtrOffset, SizeOf } from "./ctypes"; +import { Disposable } from "./types"; +import { assert, StringToUint8Array } from "./support"; + +import * as ctypes from "./ctypes"; + +/** + * Wasm Memory wrapper to perform JS side raw memory access. + */ +export class Memory { + memory: WebAssembly.Memory; + wasm32 = true; + private buffer: ArrayBuffer | SharedArrayBuffer; + private viewU8: Uint8Array; + private viewU16: Uint16Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF32: Float32Array; + private viewF64: Float64Array; + + constructor(memory: WebAssembly.Memory) { + this.memory = memory; + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } + + loadU8(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU8[ptr >> 0]; + } + + loadU16(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU16[ptr >> 1]; + } + + loadU32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU32[ptr >> 2]; + } + + loadI32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewI32[ptr >> 2]; + } + + loadI64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const base = ptr >> 2; + // assumes little endian, for now truncate high. + return this.viewI32[base]; + } + + loadF32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF32[ptr >> 2]; + } + + loadF64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF64[ptr >> 3]; + } + + loadPointer(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + loadUSize(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + sizeofPtr(): number { + return this.wasm32 ? SizeOf.I32 : SizeOf.I64; + } + /** + * Load raw bytes from ptr. + * @param ptr The head address + * @param numBytes The number + */ + loadRawBytes(ptr: Pointer, numBytes: number): Uint8Array { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const result = new Uint8Array(numBytes); + result.set(this.viewU8.slice(ptr, ptr + numBytes)); + return result; + } + /** + * Load TVMByteArray from ptr. + * + * @param ptr The address of the header. + */ + loadTVMBytes(ptr: Pointer): Uint8Array { + const data = this.loadPointer(ptr); + const length = this.loadUSize(ptr + this.sizeofPtr()); + return this.loadRawBytes(data, length); + } + /** + * Load null-terminated C-string from ptr. + * @param ptr The head address + */ + loadCString(ptr: Pointer): string { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + // NOTE: the views are still valid for read. + const ret = []; + let ch = 1; + while (ch != 0) { + ch = this.viewU8[ptr]; + if (ch != 0) { + ret.push(String.fromCharCode(ch)); + } + ++ptr; + } + return ret.join(""); + } + /** + * Store raw bytes to the ptr. + * @param ptr The head address. + * @param bytes The bytes content. + */ + storeRawBytes(ptr: Pointer, bytes: Uint8Array): void { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + this.viewU8.set(bytes, ptr); + } + + /** + * Update memory view after the memory growth. + */ + private updateViews(): void { + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} + +/** + * Auxiliary call stack for the FFI calls. + * + * Lifecyle of a call stack. + * - Calls into allocXX to allocate space, mixed with storeXXX to store data. + * - Calls into ptrFromOffset, no further allocation(as ptrFromOffset can change), + * can still call into storeXX + * - Calls into commitToWasmMemory once. + * - reset. + */ +export class CachedCallStack implements Disposable { + /** List of temporay arguments that can be disposed during reset. */ + tempArgs: Array = []; + + private memory: Memory; + private cAllocSpace: ctypes.FTVMWasmAllocSpace; + private cFreeSpace: ctypes.FTVMWasmFreeSpace; + + private buffer: ArrayBuffer; + private viewU8: Uint8Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF64: Float64Array; + + private stackTop: PtrOffset = 0; + private basePtr: Pointer = 0; + + private addressToSetTargetValue: Array<[PtrOffset, PtrOffset]> = []; + + constructor( + memory: Memory, + allocSpace: ctypes.FTVMWasmAllocSpace, + freeSpace: ctypes.FTVMWasmFreeSpace + ) { + const initCallStackSize = 128; + this.memory = memory; + this.cAllocSpace = allocSpace; + this.cFreeSpace = freeSpace; + this.buffer = new ArrayBuffer(initCallStackSize); + this.basePtr = this.cAllocSpace(initCallStackSize); + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + this.updateViews(); + } + + dispose(): void { + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + this.basePtr = 0; + } + } + /** + * Rest the call stack so that it can be reused again. + */ + reset(): void { + this.stackTop = 0; + assert(this.addressToSetTargetValue.length == 0); + while (this.tempArgs.length != 0) { + (this.tempArgs.pop() as Disposable).dispose(); + } + } + + /** + * Commit all the cached data to WasmMemory. + * This function can only be called once. + * No further store function should be called. + * + * @param nbytes Number of bytes to be stored. + */ + commitToWasmMemory(nbytes: number = this.stackTop): void { + // commit all pointer values. + while (this.addressToSetTargetValue.length != 0) { + const [targetOffset, valueOffset] = this.addressToSetTargetValue.pop() as [ + number, + number + ]; + this.storePtr(targetOffset, this.ptrFromOffset(valueOffset)); + } + this.memory.storeRawBytes(this.basePtr, this.viewU8.slice(0, nbytes)); + } + + /** + * Allocate space by number of bytes + * @param nbytes Number of bytes. + * @note This function always allocate space that aligns to 64bit. + */ + allocRawBytes(nbytes: number): PtrOffset { + // always aligns to 64bit + nbytes = ((nbytes + 7) >> 3) << 3; + + if (this.stackTop + nbytes > this.buffer.byteLength) { + const newSize = Math.max( + this.buffer.byteLength * 2, + this.stackTop + nbytes + ); + const oldU8 = this.viewU8; + this.buffer = new ArrayBuffer(newSize); + this.updateViews(); + this.viewU8.set(oldU8); + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + } + this.basePtr = this.cAllocSpace(newSize); + } + const retOffset = this.stackTop; + this.stackTop += nbytes; + return retOffset; + } + + /** + * Allocate space for pointers. + * @param count Number of pointers. + * @returns The allocated pointer array. + */ + allocPtrArray(count: number): PtrOffset { + return this.allocRawBytes(this.memory.sizeofPtr() * count); + } + + /** + * Get the real pointer from offset values. + * Note that the returned value becomes obsolete if alloc is called on the stack. + * @param offset The allocated offset. + */ + ptrFromOffset(offset: PtrOffset): Pointer { + return this.basePtr + offset; + } + + // Store APIs + storePtr(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeUSize(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeI32(offset: PtrOffset, value: number): void { + this.viewI32[offset >> 2] = value; + } + + storeU32(offset: PtrOffset, value: number): void { + this.viewU32[offset >> 2] = value; + } + + storeI64(offset: PtrOffset, value: number): void { + // For now, just store as 32bit + // NOTE: wasm always uses little endian. + const low = value & 0xffffffff; + const base = offset >> 2; + this.viewI32[base] = low; + this.viewI32[base + 1] = 0; + } + + storeF64(offset: PtrOffset, value: number): void { + this.viewF64[offset >> 3] = value; + } + + storeRawBytes(offset: PtrOffset, bytes: Uint8Array): void { + this.viewU8.set(bytes, offset); + } + + /** + * Allocate then set C-String pointer to the offset. + * This function will call into allocBytes to allocate necessary data. + * The address won't be set immediately(because the possible change of basePtr) + * and will be filled when we commit the data. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgString(offset: PtrOffset, data: string): void { + const strOffset = this.allocRawBytes(data.length + 1); + this.storeRawBytes(strOffset, StringToUint8Array(data)); + this.addressToSetTargetValue.push([offset, strOffset]); + } + /** + * Allocate then set the argument location with a TVMByteArray. + * Allocate new temporary space for bytes. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgBytes(offset: PtrOffset, data: Uint8Array): void { + // Note: size of size_t equals sizeof ptr. + const headerOffset = this.allocRawBytes(this.memory.sizeofPtr() * 2); + const dataOffset = this.allocRawBytes(data.length); + this.storeRawBytes(dataOffset, data); + this.storeUSize(headerOffset + this.memory.sizeofPtr(), data.length); + + this.addressToSetTargetValue.push([offset, headerOffset]); + this.addressToSetTargetValue.push([headerOffset, dataOffset]); + } + + /** + * Update internal cache views. + */ + private updateViews(): void { + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts new file mode 100644 index 000000000000..054a1b6019cc --- /dev/null +++ b/web/src/rpc_server.ts @@ -0,0 +1,379 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { SizeOf, TypeCode } from "./ctypes"; +import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; +import * as runtime from "./runtime"; +import { Class } from "estree"; + +enum RPCServerState { + InitHeader, + InitHeaderKey, + InitServer, + WaitForCallback, + ReceivePacketHeader, + ReceivePacketBody, +} + +/** RPC magic header */ +const RPC_MAGIC = 0xff271; + +/** + * An utility class to read from binary bytes. + */ +class ByteStreamReader { + offset = 0; + bytes: Uint8Array; + + constructor(bytes: Uint8Array) { + this.bytes = bytes; + } + + readU32(): number { + const i = this.offset; + const b = this.bytes; + const val = b[i] | (b[i + 1] << 8) | (b[i + 2] << 16) | (b[i + 3] << 24); + this.offset += 4; + return val; + } + + readU64(): number { + const val = this.readU32(); + this.offset += 4; + return val; + } + + readByteArray(): Uint8Array { + const len = this.readU64(); + assert(this.offset + len <= this.bytes.byteLength); + const ret = new Uint8Array(len); + ret.set(this.bytes.slice(this.offset, this.offset + len)); + this.offset += len; + return ret; + } +} + +/** + * A websocket based RPC + */ +export class RPCServer { + url: string; + key: string; + socket: WebSocket; + state: RPCServerState = RPCServerState.InitHeader; + logger: (msg: string) => void; + getImports: () => Record; + private name: string; + private inst?: runtime.Instance = undefined; + private serverRecvData?: (header: Uint8Array, body: Uint8Array) => void; + private currPacketHeader?: Uint8Array; + private currPacketLength = 0; + private remoteKeyLength = 0; + private pendingBytes = 0; + private buffredBytes = 0; + private messageQueue: Array = []; + + constructor( + url: string, + key: string, + getImports: () => Record, + logger: (msg: string) => void = console.log + ) { + this.url = url; + this.key = key; + this.name = "WebSocketRPCServer[" + this.key + "]: "; + this.getImports = getImports; + this.logger = logger; + + this.checkLittleEndian(); + + if (typeof WebSocket == "undefined") { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const WebSocket = require("ws"); + this.socket = new WebSocket(url); + } else { + this.socket = new (WebSocket as any)(url); + } + + //this.socket = this.getSocket(url); + this.socket.binaryType = "arraybuffer"; + + this.socket.addEventListener("open", (event: Event) => { + return this.onOpen(event); + }); + this.socket.addEventListener("message", (event: MessageEvent) => { + return this.onMessage(event); + }); + this.socket.addEventListener("close", (event: CloseEvent) => { + return this.onClose(event); + }); + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onClose(_event: CloseEvent): void { + if (this.inst !== undefined) { + this.inst.dispose(); + } + if (this.state == RPCServerState.ReceivePacketHeader) { + this.log("Closing the server in clean state"); + } else { + this.log("Closing the server, final state=" + this.state); + } + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onOpen(_event: Event): void { + // Send the headers + let bkey = StringToUint8Array("server:" + this.key); + bkey = bkey.slice(0, bkey.length - 1); + const intbuf = new Int32Array(1); + intbuf[0] = RPC_MAGIC; + this.socket.send(intbuf); + intbuf[0] = bkey.length; + this.socket.send(intbuf); + this.socket.send(bkey); + this.log("connected..."); + // request bytes: magic + keylen + this.requestBytes(SizeOf.I32 + SizeOf.I32); + this.state = RPCServerState.InitHeader; + } + + /** Handler for raw message. */ + private onMessage(event: MessageEvent): void { + const buffer = event.data; + this.buffredBytes += buffer.byteLength; + this.messageQueue.push(new Uint8Array(buffer)); + this.processEvents(); + } + /** Process ready events. */ + private processEvents(): void { + while (this.buffredBytes >= this.pendingBytes && this.pendingBytes != 0) { + this.onDataReady(); + } + } + /** State machine to handle each request */ + private onDataReady(): void { + switch (this.state) { + case RPCServerState.InitHeader: { + this.handleInitHeader(); + break; + } + case RPCServerState.InitHeaderKey: { + this.handleInitHeaderKey(); + break; + } + case RPCServerState.ReceivePacketHeader: { + this.currPacketHeader = this.readFromBuffer(SizeOf.I64); + const reader = new ByteStreamReader(this.currPacketHeader); + this.currPacketLength = reader.readU64(); + assert(this.pendingBytes == 0); + this.requestBytes(this.currPacketLength); + this.state = RPCServerState.ReceivePacketBody; + break; + } + case RPCServerState.ReceivePacketBody: { + const body = this.readFromBuffer(this.currPacketLength); + assert(this.pendingBytes == 0); + assert(this.currPacketHeader !== undefined); + this.onPacketReady(this.currPacketHeader, body); + break; + } + case RPCServerState.WaitForCallback: { + assert(this.pendingBytes == 0); + break; + } + default: { + throw new Error("Cannot handle state " + this.state); + } + } + } + + private onPacketReady(header: Uint8Array, body: Uint8Array): void { + if (this.inst === undefined) { + // initialize server. + const reader = new ByteStreamReader(body); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const code = reader.readU32(); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const ver = Uint8ArrayToString(reader.readByteArray()); + const nargs = reader.readU32(); + const tcodes = []; + const args = []; + for (let i = 0; i < nargs; ++i) { + tcodes.push(reader.readU32()); + } + + for (let i = 0; i < nargs; ++i) { + const tcode = tcodes[i]; + if (tcode == TypeCode.TVMStr) { + const str = Uint8ArrayToString(reader.readByteArray()); + args.push(str); + } else if (tcode == TypeCode.TVMBytes) { + args.push(reader.readByteArray()); + } else { + throw new Error("cannot support type code " + tcode); + } + } + this.onInitServer(args, header, body); + } else { + assert(this.serverRecvData !== undefined); + this.serverRecvData(header, body); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + } + + /** Event handler during server initialization. */ + private onInitServer( + args: Array, + header: Uint8Array, + body: Uint8Array + ): void { + // start the server + assert(args[0] == "rpc.WasmSession"); + assert(args[1] instanceof Uint8Array); + assert(this.pendingBytes == 0); + + runtime.instantiate(args[1].buffer, this.getImports()) + .then((inst: runtime.Instance) => { + this.inst = inst; + const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer"); + + const messageHandler = fcreate( + (cbytes: Uint8Array): runtime.Scalar => { + assert(this.inst !== undefined); + if (this.socket.readyState == 1) { + this.socket.send(cbytes); + return this.inst.scalar(cbytes.length, "int32"); + } else { + return this.inst.scalar(0, "int32"); + } + }, + this.name, + this.key + ); + + fcreate.dispose(); + const writeFlag = this.inst.scalar(3, "int32"); + + this.serverRecvData = (header: Uint8Array, body: Uint8Array): void => { + if (messageHandler(header, writeFlag) == 0) { + this.socket.close(); + } + if (messageHandler(body, writeFlag) == 0) { + this.socket.close(); + } + }; + + // Forward the same init sequence to the wasm RPC. + // The RPC will look for "rpc.wasmSession" + // and we will redirect it to the correct local session. + // register the callback to redirect the session to local. + const flocal = this.inst.getGlobalFunc("rpc.LocalSession"); + const localSession = flocal(); + flocal.dispose(); + assert(localSession instanceof runtime.Module); + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + this.inst.registerFunc( + "rpc.WasmSession", + // eslint-disable-next-line @typescript-eslint/no-unused-vars + (_args: unknown): runtime.Module => { + return localSession; + } + ); + messageHandler(header, writeFlag); + messageHandler(body, writeFlag); + localSession.dispose(); + + this.log("Finish initializing the Wasm Server.."); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + // call process events in case there are bufferred data. + this.processEvents(); + }); + this.state = RPCServerState.WaitForCallback; + } + + private log(msg: string): void { + this.logger(this.name + msg); + } + + private handleInitHeader(): void { + const reader = new ByteStreamReader(this.readFromBuffer(SizeOf.I32 * 2)); + const magic = reader.readU32(); + if (magic == RPC_MAGIC + 1) { + throw new Error("key: " + this.key + " has already been used in proxy"); + } else if (magic == RPC_MAGIC + 2) { + throw new Error("RPCProxy do not have matching client key " + this.key); + } + assert(magic == RPC_MAGIC, this.url + " is not an RPC Proxy"); + this.remoteKeyLength = reader.readU32(); + assert(this.pendingBytes == 0); + this.requestBytes(this.remoteKeyLength); + this.state = RPCServerState.InitHeaderKey; + } + + private handleInitHeaderKey(): void { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const remoteKey = Uint8ArrayToString( + this.readFromBuffer(this.remoteKeyLength) + ); + assert(this.pendingBytes == 0); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + + private checkLittleEndian(): void { + const a = new ArrayBuffer(4); + const b = new Uint8Array(a); + const c = new Uint32Array(a); + b[0] = 0x11; + b[1] = 0x22; + b[2] = 0x33; + b[3] = 0x44; + assert(c[0] === 0x44332211, "RPCServer little endian to work"); + } + + private requestBytes(nbytes: number): void { + this.pendingBytes += nbytes; + } + + private readFromBuffer(nbytes: number): Uint8Array { + const ret = new Uint8Array(nbytes); + let ptr = 0; + while (ptr < nbytes) { + assert(this.messageQueue.length != 0); + const nleft = nbytes - ptr; + if (this.messageQueue[0].byteLength <= nleft) { + const buffer = this.messageQueue.shift() as Uint8Array; + ret.set(buffer, ptr); + ptr += buffer.byteLength; + } else { + const buffer = this.messageQueue[0]; + ret.set(buffer.slice(0, nleft), ptr); + this.messageQueue[0] = buffer.slice(nleft, buffer.byteLength); + ptr += nleft; + } + } + this.buffredBytes -= nbytes; + this.pendingBytes -= nbytes; + return ret; + } +} diff --git a/web/src/runtime.ts b/web/src/runtime.ts new file mode 100644 index 000000000000..cd9b967596af --- /dev/null +++ b/web/src/runtime.ts @@ -0,0 +1,1113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * TVM JS Wasm Runtime library. + */ +import { Pointer, PtrOffset, SizeOf, TypeCode } from "./ctypes"; +import { Disposable } from "./types"; +import { Memory, CachedCallStack } from "./memory"; +import { assert, StringToUint8Array } from "./support"; +import { Environment } from "./environment"; + +import * as ctypes from "./ctypes"; + +/** + * Type for PackedFunc inthe TVMRuntime. + */ +export type PackedFunc = ((...args: any) => any) & + Disposable & { _tvmPackedCell: PackedFuncCell }; + +/** + * @internal + * FFI Library wrapper, maintains most runtime states. + */ +class FFILibrary implements Disposable { + wasm32: boolean; + memory: Memory; + exports: Record; + private wasmInstance: WebAssembly.Instance; + + private recycledCallStacks: Array = []; + + constructor( + wasmInstance: WebAssembly.Instance, + imports: Record + ) { + this.wasmInstance = wasmInstance; + this.memory = new Memory(this.detectWasmMemory(this.wasmInstance, imports)); + assert( + this.wasmInstance.exports !== undefined, + "Expect the library module contains exports" + ); + this.exports = this.wasmInstance.exports as Record; + this.wasm32 = this.memory.wasm32; + this.validateInstance(); + } + + dispose(): void { + while (this.recycledCallStacks.length != 0) { + (this.recycledCallStacks.pop() as Disposable).dispose(); + } + } + + sizeofPtr(): number { + return this.memory.sizeofPtr(); + } + + checkCall(code: number): void { + if (code != 0) { + const msgPtr = (this.exports + .TVMGetLastError as ctypes.FTVMGetLastError)(); + throw new Error("TVMError: " + this.memory.loadCString(msgPtr)); + } + } + + getOrAllocCallStack(): CachedCallStack { + if (this.recycledCallStacks.length != 0) { + return this.recycledCallStacks.pop() as CachedCallStack; + } + return new CachedCallStack( + this.memory, + this.exports.TVMWasmAllocSpace as ctypes.FTVMWasmAllocSpace, + this.exports.TVMWasmFreeSpace as ctypes.FTVMWasmFreeSpace + ); + } + + recycleCallStack(callstack: CachedCallStack): void { + callstack.reset(); + this.recycledCallStacks.push(callstack); + } + + private validateInstance(): void { + this.checkExports(["TVMWasmAllocSpace", "TVMWasmFreeSpace", "TVMFuncFree"]); + } + + private checkExports(funcNames: Array): void { + const missList = []; + for (const name of funcNames) { + const f = this.exports[name]; + if (!(f instanceof Function)) { + missList.push(name); + } + } + if (missList.length != 0) { + throw new Error("Cannot find " + missList + " in exports"); + } + } + + private detectWasmMemory( + instance: WebAssembly.Instance, + imports: Record + ): WebAssembly.Memory { + if (instance.exports.memory instanceof WebAssembly.Memory) { + return instance.exports.memory; + } + if (imports.env && imports.env.memory instanceof WebAssembly.Memory) { + return imports.env.memory; + } + + throw new Error( + "Cannt detect wasm memory from imports " + + imports + + " or exports" + + instance.exports + ); + } +} + +/** + * A typed scalar constant used to represent a typed number + * argument to PackedFunc calls. + */ +export class Scalar { + /** The value. */ + value: number; + /** The data type of the scalar. */ + dtype: string; + + constructor(value: number, dtype: string) { + this.value = value; + this.dtype = dtype; + } +} + +/** + * Cell holds the PackedFunc object. + */ +class PackedFuncCell implements Disposable { + handle: Pointer; + private lib: FFILibrary; + + constructor(handle: Pointer, lib: FFILibrary) { + this.handle = handle; + this.lib = lib; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMFuncFree as ctypes.FTVMFuncFree)(this.handle) + ); + this.handle = 0; + } + } +} + +const DeviceEnumToStr: Record = { + 1: "cpu", + 2: "gpu", + 4: "opencl", + 7: "vulkan", + 8: "metal", +}; + +const DeviceStrToEnum: Record = { + cpu: 1, + gpu: 2, + cuda: 2, + cl: 4, + opencl: 4, + vulkan: 7, + metal: 8, +}; + +/** + * Represent a runtime context where a NDArray can reside. + */ +export class DLContext { + /** The device type code of the context. */ + deviceType: number; + /** The device index. */ + deviceId: number; + + private lib: FFILibrary; + + constructor(deviceType: number | string, deviceId: number, lib: FFILibrary) { + const tp = typeof deviceType; + if (tp == "string") { + this.deviceType = DeviceStrToEnum[deviceType]; + } else if (tp == "number") { + this.deviceType = deviceType as number; + } else { + throw new Error("Cannot take type " + tp + " as deviceType"); + } + this.deviceId = deviceId; + this.lib = lib; + } + + /** + * Synchronize the context + */ + sync(): void { + this.lib.checkCall( + (this.lib.exports.TVMSynchronize as ctypes.FTVMSynchronize)( + this.deviceType, + this.deviceId, + 0 + ) + ); + } + + toString(): string { + return ( + DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")" + ); + } +} + +const DLDataTypeCodeToStr: Record = { + 0: "int", + 1: "uint", + 2: "float", + 4: "handle", +}; + +/** + * Runtime data type of NDArray. + */ +export class DLDataType { + /** The type code */ + code: number; + /** Number of bits in the data type. */ + bits: number; + /** Number of vector lanes. */ + lanes: number; + + constructor(code: number, bits: number, lanes: number) { + this.code = code; + this.bits = bits; + this.lanes = lanes; + } + + toString(): string { + const ret = DLDataTypeCodeToStr[this.code] + this.bits.toString(); + if (this.lanes != 1) { + return ret + "x" + this.lanes.toString(); + } else { + return ret; + } + } + + numStorageBytes(): number { + return (this.bits * this.lanes + 7) >> 3; + } +} + +/** + * n-dimnesional array. + */ +export class NDArray implements Disposable { + /** Internal array handle. */ + handle: Pointer; + /** Number of dimensions. */ + ndim: number; + /** Data type of the array. */ + dtype: string; + /** Shape of the array. */ + shape: Array; + /** Context of the array. */ + context: DLContext; + + private byteOffset: number; + private dltensor: Pointer; + private lib: FFILibrary; + private dlDataType: DLDataType; + + constructor(handle: Pointer, lib: FFILibrary) { + this.handle = handle; + this.lib = lib; + + this.dltensor = this.getDLTensorFromArrayHandle(this.handle); + // constant offsets. + const arrayOffsetData = 0; + const arrayOffsetContext = arrayOffsetData + this.lib.sizeofPtr(); + const arrayOffsetDevType = arrayOffsetContext; + const arrayOffsetDevId = arrayOffsetContext + SizeOf.I32; + const arrayOffsetNdim = arrayOffsetContext + SizeOf.DLContext; + const arrayOffsetDtype = arrayOffsetNdim + SizeOf.I32; + const arrayOffsetDtypeCode = arrayOffsetDtype; + const arrayOffsetDtypeBits = arrayOffsetDtype + SizeOf.U8; + const arrayOffsetDtypeLanes = arrayOffsetDtypeBits + SizeOf.U8; + const arrayOffsetShape = arrayOffsetDtype + SizeOf.DLDataType; + const arrayOffsetStrides = arrayOffsetShape + this.lib.sizeofPtr(); + const arrayOffsetByteOffset = arrayOffsetStrides + this.lib.sizeofPtr(); + // ndim + this.ndim = lib.memory.loadI32(this.dltensor + arrayOffsetNdim); + // shape + const cshapePtr = lib.memory.loadPointer(this.dltensor + arrayOffsetShape); + this.shape = []; + for (let i = 0; i < this.ndim; ++i) { + this.shape.push(lib.memory.loadI64(cshapePtr + i * SizeOf.I64)); + } + // dtype + const code = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeCode); + const bits = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeBits); + const lanes = lib.memory.loadU16(this.dltensor + arrayOffsetDtypeLanes); + this.dlDataType = new DLDataType(code, bits, lanes); + this.dtype = this.dlDataType.toString(); + + // ctx + const deviceType = lib.memory.loadI32(this.dltensor + arrayOffsetDevType); + const deviceId = lib.memory.loadI32(this.dltensor + arrayOffsetDevId); + this.context = new DLContext(deviceType, deviceId, lib); + + // byte_offset + this.byteOffset = lib.memory.loadI64(this.dltensor + arrayOffsetByteOffset); + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle) + ); + this.handle = 0; + } + } + /** + * Copy data from another NDArray or javascript array. + * The number of elements must match. + * + * @param data The source data array. + * @returns this + */ + copyFrom(data: NDArray | Array): this { + if (data instanceof NDArray) { + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( + data.handle, + this.handle, + 0 + ) + ); + return this; + } else { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + if (data.length != size) { + throw new Error( + "data size and shape mismatch data.length" + + data.length + + " vs " + + size + ); + } + let buffer: ArrayBuffer; + if (this.dtype == "float32") { + buffer = Float32Array.from(data).buffer; + } else if (this.dtype == "float64") { + buffer = Float64Array.from(data).buffer; + } else if (this.dtype == "int32") { + buffer = Int32Array.from(data).buffer; + } else if (this.dtype == "int8") { + buffer = Int8Array.from(data).buffer; + } else if (this.dtype == "uint8") { + buffer = Uint8Array.from(data).buffer; + } else { + throw new Error("Unsupported data type " + this.dtype); + } + return this.copyFromRawBytes(new Uint8Array(buffer)); + } + } + /** + * Copy data from raw bytes. + * @param data Uint8Array of bytes. + * @returns this + */ + copyFromRawBytes(data: Uint8Array): this { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + const nbytes = this.dlDataType.numStorageBytes() * size; + if (nbytes != data.length) { + throw new Error("Expect the data's length equals nbytes=" + nbytes); + } + + const stack = this.lib.getOrAllocCallStack(); + + const tempOffset = stack.allocRawBytes(nbytes); + const tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.memory.storeRawBytes(tempPtr, data); + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyFromBytes as ctypes.FTVMArrayCopyFromBytes)( + this.handle, + tempPtr, + nbytes + ) + ); + + this.lib.recycleCallStack(stack); + return this; + } + /** + * Return a copied Uint8Array of the raw bytes in the NDArray. + * @returns The result array. + */ + toRawBytes(): Uint8Array { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + const nbytes = this.dlDataType.numStorageBytes() * size; + const stack = this.lib.getOrAllocCallStack(); + + const tempOffset = stack.allocRawBytes(nbytes); + const tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)( + this.handle, + tempPtr, + nbytes + ) + ); + const ret = this.lib.memory.loadRawBytes(tempPtr, nbytes); + + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Return a TypedArray copy of the NDArray, the specific type depends on + * the dtype of the NDArray. + * @returns The result array. + */ + toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array { + const stype = this.dtype; + if (stype == "float32") { + return new Float32Array(this.toRawBytes().buffer); + } else if (stype == "float64") { + return new Float64Array(this.toRawBytes().buffer); + } else if (stype == "int32") { + return new Int32Array(this.toRawBytes().buffer); + } else if (stype == "int8") { + return new Int8Array(this.toRawBytes().buffer); + } else if (stype == "uint8") { + return new Uint8Array(this.toRawBytes().buffer); + } else { + throw new Error("Unsupported data type " + this.dtype); + } + } + + private getDLTensorFromArrayHandle(handle: Pointer): Pointer { + // Note: this depends on the NDArray C ABI. + // keep this function in case of ABI change. + return handle; + } +} + +/** + * Runtime Module. + */ +export class Module implements Disposable { + handle: Pointer; + private lib: FFILibrary; + private makePackedFunc: (ptr: Pointer) => PackedFunc; + + constructor( + handle: Pointer, + lib: FFILibrary, + makePackedFunc: (ptr: Pointer) => PackedFunc + ) { + this.handle = handle; + this.lib = lib; + this.makePackedFunc = makePackedFunc; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMModFree as ctypes.FTVMModFree)(this.handle) + ); + this.handle = 0; + } + } + + /** + * Get a function in the module. + * @param name The name of the function. + * @returns The result function. + */ + getFunction(name: string): PackedFunc { + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)( + this.handle, + stack.ptrFromOffset(nameOffset), + 1, + outPtr + ) + ); + const handle = this.lib.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find function " + name); + } + const ret = this.makePackedFunc(handle); + return ret; + } + + /** + * Import another module into the current runtime module. + * @param mod The module to be imported. + */ + importModule(mod: Module): void { + this.lib.checkCall( + (this.lib.exports.TVMModImport as ctypes.FTVMModImport)( + this.handle, + mod.handle + ) + ); + } +} + +/** + * TVM runtime instance. + */ +export class Instance implements Disposable { + memory: Memory; + exports: Record; + private lib: FFILibrary; + private env: Environment; + + /** + * Internal function(registered by the runtime) + */ + private wasmCreateLibraryModule?: PackedFunc & + ((getFunc: PackedFunc, getGlobal: PackedFunc) => PackedFunc); + + /** + * Constructor + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * + * @param wasmModule The input module or instance. + * @param importObject The imports to initialize the wasmInstance if it is not provided. + * @param wasmInstance Additional wasm instance argument for deferred construction. + * @param env Directly specified environment module. + * + * @see Please use the async version {@link instantiate} when targeting browsers. + */ + constructor( + wasmModule: WebAssembly.Module, + importObject: Record = {}, + wasmInstance?: WebAssembly.Instance, + env?: Environment + ) { + if (wasmInstance instanceof WebAssembly.Instance) { + assert( + env instanceof Environment, + "env must be provided when passing in instance" + ); + } else { + assert(env === undefined); + env = new Environment(importObject); + wasmInstance = new WebAssembly.Instance(wasmModule, env.imports); + } + + env.start(wasmInstance); + this.env = env; + this.lib = new FFILibrary(wasmInstance, env.imports); + this.memory = this.lib.memory; + this.exports = this.lib.exports; + this.registerEnvGlobalPackedFuncs(); + } + + dispose(): void { + this.lib.dispose(); + } + /** + * Get system-wide library module in the wasm. + * System lib is a global module that contains self register functions in startup. + * @returns The system library module. + */ + systemLib(): Module { + const getSysLib = this.getGlobalFunc("runtime.SystemLib"); + const mod = getSysLib() as Module; + getSysLib.dispose(); + return mod; + } + /** + * List all the global function names registered in the runtime. + * @returns The name list. + */ + listGlobalFuncNames(): Array { + const stack = this.lib.getOrAllocCallStack(); + + const outSizeOffset = stack.allocPtrArray(2); + + const outSizePtr = stack.ptrFromOffset(outSizeOffset); + const outArrayPtr = stack.ptrFromOffset( + outSizeOffset + this.lib.sizeofPtr() + ); + + this.lib.checkCall( + (this.exports.TVMFuncListGlobalNames as ctypes.FTVMFuncListGlobalNames)( + outSizePtr, + outArrayPtr + ) + ); + + const size = this.memory.loadI32(outSizePtr); + const array = this.memory.loadPointer(outArrayPtr); + const names: Array = []; + + for (let i = 0; i < size; ++i) { + names.push( + this.memory.loadCString( + this.memory.loadPointer(array + this.lib.sizeofPtr() * i) + ) + ); + } + + this.lib.recycleCallStack(stack); + return names; + } + + /** + * Register function to be global function in tvm runtime. + * @param name The name of the function. + * @param f function to be registered. + * @param override Whether overwrite function in existing registry. + */ + registerFunc( + name: string, + func: PackedFunc | Function, + override = false + ): void { + const packedFunc = this.toPackedFunc(func); + const ioverride = override ? 1 : 0; + + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + stack.commitToWasmMemory(); + + this.lib.checkCall( + (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)( + stack.ptrFromOffset(nameOffset), + packedFunc._tvmPackedCell.handle, + ioverride + ) + ); + } + + /** + * Get global PackedFunc from the runtime. + * @param name The name of the function. + * @returns The result function. + */ + getGlobalFunc(name: string): PackedFunc { + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.exports.TVMFuncGetGlobal as ctypes.FTVMFuncGetGlobal)( + stack.ptrFromOffset(nameOffset), + outPtr + ) + ); + const handle = this.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find global function " + name); + } + const ret = this.makePackedFunc(handle); + return ret; + } + + /** + * Check if func is PackedFunc. + * + * @param func The input. + * @returns The check result. + */ + isPackedFunc(func: unknown): boolean { + // eslint-disable-next-line no-prototype-builtins + return typeof func == "function" && func.hasOwnProperty("_tvmPackedCell"); + } + + /** + * Convert func to PackedFunc + * + * @param func Input function. + * @returns The converted function. + */ + toPackedFunc(func: Function): PackedFunc { + if (this.isPackedFunc(func)) return func as PackedFunc; + return this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func)); + } + + /** + * Convert dtype to {@link DLDataType} + * + * @param dtype The input dtype string or DLDataType. + * @returns The converted result. + */ + toDLDataType(dtype: string | DLDataType): DLDataType { + if (dtype instanceof DLDataType) return dtype; + if (typeof dtype == "string") { + let pattern = dtype; + let code, + bits = 32, + lanes = 1; + if (pattern.substring(0, 5) == "float") { + pattern = pattern.substring(5, pattern.length); + code = TypeCode.Float; + } else if (pattern.substring(0, 3) == "int") { + pattern = pattern.substring(3, pattern.length); + code = TypeCode.Int; + } else if (pattern.substring(0, 4) == "uint") { + pattern = pattern.substring(4, pattern.length); + code = TypeCode.UInt; + } else if (pattern.substring(0, 6) == "handle") { + pattern = pattern.substring(5, pattern.length); + code = TypeCode.TVMOpaqueHandle; + bits = 64; + } else { + throw new Error("Unknown dtype " + dtype); + } + + const arr = pattern.split("x"); + if (arr.length >= 1) { + const parsed = parseInt(arr[0]); + if (parsed + "" == arr[0]) { + bits = parsed; + } + } + if (arr.length >= 2) { + lanes = parseInt(arr[1]); + } + return new DLDataType(code, bits, lanes); + } else { + throw new Error("Unknown dtype " + dtype); + } + } + + /** + * Create a new {@link Scalar} that can be passed to a PackedFunc. + * @param value The number value. + * @param dtype The dtype string. + * @returns The created scalar. + */ + scalar(value: number, dtype: string): Scalar { + return new Scalar(value, dtype); + } + + /** + * Create a new {@link DLContext} + * @param deviceType The device type. + * @param deviceId The device index. + * @returns The created context. + */ + context(deviceType: number | string, deviceId: number): DLContext { + return new DLContext(deviceType, deviceId, this.lib); + } + + /** + * Create an empty {@link NDArray} with given shape and dtype. + * + * @param shape The shape of the array. + * @param dtype The data type of the array. + * @param ctx The context of the ndarray. + * @returns The created ndarray. + */ + empty( + shape: Array | number, + dtype: string | DLDataType = "float32", + ctx: DLContext = this.context("cpu", 0) + ): NDArray { + dtype = this.toDLDataType(dtype); + shape = typeof shape == "number" ? [shape] : shape; + + const stack = this.lib.getOrAllocCallStack(); + const shapeOffset = stack.allocRawBytes(shape.length * SizeOf.I64); + for (let i = 0; i < shape.length; ++i) { + stack.storeI64(shapeOffset + i * SizeOf.I64, shape[i]); + } + + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.exports.TVMArrayAlloc as ctypes.FTVMArrayAlloc)( + stack.ptrFromOffset(shapeOffset), + shape.length, + dtype.code, + dtype.bits, + dtype.lanes, + ctx.deviceType, + ctx.deviceId, + outPtr + ) + ); + const ret = new NDArray(this.memory.loadPointer(outPtr), this.lib); + this.lib.recycleCallStack(stack); + return ret; + } + + /** Register global packed functions needed by the backend to the env. */ + private registerEnvGlobalPackedFuncs(): void { + // Register the timer function to enable the time_evaluator. + let perf: Performance; + if (typeof performance == "undefined") { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const performanceNode = require('perf_hooks'); + perf = performanceNode.performance as Performance; + } else { + perf = performance as Performance; + } + + const getTimer = (func: PackedFunc) => { + return (n: number): number => { + const nscalar = this.scalar(n, "int32"); + const tstart: number = perf.now(); + func(nscalar); + const tend: number = perf.now(); + return tend - tstart; + } + }; + this.registerFunc("wasm.GetTimer", getTimer); + const rpcWrapTimeEvaluator = this.getGlobalFunc("wasm.RPCTimeEvaluator"); + this.registerFunc("runtime.RPCTimeEvaluator", rpcWrapTimeEvaluator, true); + rpcWrapTimeEvaluator.dispose(); + } + + private createPackedFuncFromCFunc( + func: ctypes.FTVMWasmPackedCFunc + ): PackedFunc { + let findex = this.env.packedCFuncTable.length; + if (this.env.packedCFuncTableFreeId.length != 0) { + findex = this.env.packedCFuncTableFreeId.pop() as number; + } else { + this.env.packedCFuncTable.push(undefined); + } + this.env.packedCFuncTable[findex] = func; + + const stack = this.lib.getOrAllocCallStack(); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + this.lib.checkCall( + (this.exports + .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)( + findex, + outPtr + ) + ); + const ret = this.makePackedFunc(this.memory.loadPointer(outPtr)); + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Set packed function arguments into the location indicated by argsValue and argsCode. + * Allocate new temporary space from the stack if necessary. + * + * @parma stack The call stack + * @param args The input arguments. + * @param argsValue The offset of argsValue. + * @param argsCode The offset of argsCode. + */ + setPackedArguments( + stack: CachedCallStack, + args: Array, + argsValue: PtrOffset, + argsCode: PtrOffset + ): void { + for (let i = 0; i < args.length; ++i) { + let val = args[i]; + const tp = typeof val; + const valueOffset = argsValue + i * SizeOf.TVMValue; + const codeOffset = argsCode + i * SizeOf.I32; + if (val instanceof NDArray) { + stack.storePtr(valueOffset, val.handle); + stack.storeI32(codeOffset, TypeCode.TVMNDArrayHandle); + } else if (val instanceof Scalar) { + if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { + stack.storeI64(valueOffset, val.value); + stack.storeI32(codeOffset, TypeCode.Int); + } else if (val.dtype.startsWith("float")) { + stack.storeF64(valueOffset, val.value); + stack.storeI32(codeOffset, TypeCode.Float); + } else { + assert(val.dtype == "handle", "Expect handle"); + stack.storePtr(valueOffset, val.value); + stack.storeI32(codeOffset, TypeCode.TVMOpaqueHandle); + } + } else if (tp == "number") { + stack.storeF64(valueOffset, val); + stack.storeI32(codeOffset, TypeCode.Float); + // eslint-disable-next-line no-prototype-builtins + } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) { + stack.storePtr(valueOffset, val._tvmPackedCell.handle); + stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle); + } else if (val === null || val == undefined) { + stack.storePtr(valueOffset, 0); + stack.storeI32(codeOffset, TypeCode.Null); + } else if (tp == "string") { + stack.allocThenSetArgString(valueOffset, val); + stack.storeI32(codeOffset, TypeCode.TVMStr); + } else if (val instanceof Uint8Array) { + stack.allocThenSetArgBytes(valueOffset, val); + stack.storeI32(codeOffset, TypeCode.TVMBytes); + } else if (val instanceof Function) { + val = this.toPackedFunc(val); + stack.tempArgs.push(val); + stack.storePtr(valueOffset, val._tvmPackedCell.handle); + stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle); + } else if (val instanceof Module) { + stack.storePtr(valueOffset, val.handle); + stack.storeI32(codeOffset, TypeCode.TVMModuleHandle); + } else { + throw new Error("Unsupported argument type " + tp); + } + } + } + + private wrapJSFuncAsPackedCFunc(func: Function): ctypes.FTVMWasmPackedCFunc { + const lib = this.lib; + return ( + argValues: Pointer, + argCodes: Pointer, + nargs: number, + ret: Pointer, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + _handle: Pointer + ): number => { + const jsArgs = []; + for (let i = 0; i < nargs; ++i) { + const valuePtr = argValues + i * SizeOf.TVMValue; + const codePtr = argCodes + i * SizeOf.I32; + let tcode = lib.memory.loadI32(codePtr); + + if ( + tcode == TypeCode.TVMObjectHandle || + tcode == TypeCode.TVMObjectRValueRefArg || + tcode == TypeCode.TVMPackedFuncHandle || + tcode == TypeCode.TVMModuleHandle + ) { + lib.checkCall( + (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)( + valuePtr, + codePtr + ) + ); + } + tcode = lib.memory.loadI32(codePtr); + jsArgs.push(this.retValueToJS(valuePtr, tcode)); + } + + const rv = func(...jsArgs); + + if (rv !== undefined && rv !== null) { + const stack = lib.getOrAllocCallStack(); + const valueOffset = stack.allocRawBytes(SizeOf.TVMValue); + const codeOffset = stack.allocRawBytes(SizeOf.I32); + this.setPackedArguments(stack, [rv], valueOffset, codeOffset); + const valuePtr = stack.ptrFromOffset(valueOffset); + const codePtr = stack.ptrFromOffset(codeOffset); + stack.commitToWasmMemory(); + lib.checkCall( + (lib.exports.TVMCFuncSetReturn as ctypes.FTVMCFuncSetReturn)( + ret, + valuePtr, + codePtr, + 1 + ) + ); + lib.recycleCallStack(stack); + } + return 0; + }; + } + + private makePackedFunc(handle: Pointer): PackedFunc { + const cell = new PackedFuncCell(handle, this.lib); + + const packedFunc = (...args: any): any => { + const stack = this.lib.getOrAllocCallStack(); + + const valueOffset = stack.allocRawBytes(SizeOf.TVMValue * args.length); + const tcodeOffset = stack.allocRawBytes(SizeOf.I32 * args.length); + + this.setPackedArguments(stack, args, valueOffset, tcodeOffset); + + const rvalueOffset = stack.allocRawBytes(SizeOf.TVMValue); + const rcodeOffset = stack.allocRawBytes(SizeOf.I32); + const rvaluePtr = stack.ptrFromOffset(rvalueOffset); + const rcodePtr = stack.ptrFromOffset(rcodeOffset); + + // commit to wasm memory, till rvalueOffset (the return value don't need to be committed) + stack.commitToWasmMemory(rvalueOffset); + + this.lib.checkCall( + (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)( + handle, + stack.ptrFromOffset(valueOffset), + stack.ptrFromOffset(tcodeOffset), + args.length, + rvaluePtr, + rcodePtr + ) + ); + + const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr)); + this.lib.recycleCallStack(stack); + return ret; + }; + // Attach attributes to the function type. + // This is because javascript do not allow us to overload call. + const ret: any = packedFunc; + ret.dispose = (): void => { + cell.dispose(); + }; + ret._tvmPackedCell = cell; + return ret as PackedFunc; + } + + private retValueToJS(rvaluePtr: Pointer, tcode: number): any { + switch (tcode) { + case TypeCode.Int: + case TypeCode.UInt: + return this.memory.loadI64(rvaluePtr); + case TypeCode.Float: + return this.memory.loadF64(rvaluePtr); + case TypeCode.TVMNDArrayHandle: { + return new NDArray(this.memory.loadPointer(rvaluePtr), this.lib); + } + case TypeCode.TVMPackedFuncHandle: { + return this.makePackedFunc(this.memory.loadPointer(rvaluePtr)); + } + case TypeCode.TVMModuleHandle: { + return new Module( + this.memory.loadPointer(rvaluePtr), + this.lib, + (ptr: Pointer) => { + return this.makePackedFunc(ptr); + } + ); + } + case TypeCode.Null: + return undefined; + case TypeCode.TVMStr: { + return this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); + } + case TypeCode.TVMBytes: { + return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr)); + } + default: + throw new Error("Unsupported return type code=" + tcode); + } + } +} + +/** + * Asynchrously instantiate a new {@link Instance}. + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * We can take benefit of syslib implementations from the Emscripten + * by passing its generated js Module as the imports. + */ +export function instantiate( + bufferSource: ArrayBuffer, + importObject: Record = {} +): Promise { + const env = new Environment(importObject); + + return WebAssembly.instantiate(bufferSource, env.imports).then( + (result: WebAssembly.WebAssemblyInstantiatedSource): Instance => { + return new Instance(result.module, {}, result.instance, env); + } + ); +} diff --git a/web/src/support.ts b/web/src/support.ts new file mode 100644 index 000000000000..7a2667a2299f --- /dev/null +++ b/web/src/support.ts @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Convert string to Uint8array. + * @param str The string. + * @returns The corresponding Uint8Array. + */ +export function StringToUint8Array(str: string): Uint8Array { + const arr = new Uint8Array(str.length + 1); + for (let i = 0; i < str.length; ++i) { + arr[i] = str.charCodeAt(i); + } + arr[str.length] = 0; + return arr; +} + +/** + * Convert Uint8array to string. + * @param array The array. + * @returns The corresponding string. + */ +export function Uint8ArrayToString(arr: Uint8Array): string { + const ret = []; + for (const ch of arr) { + ret.push(String.fromCharCode(ch)); + } + return ret.join(""); +} + +/** + * Internal assert helper + * @param condition condition The condition to fail. + * @param msg msg The message. + */ +export function assert(condition: boolean, msg?: string): asserts condition { + if (!condition) { + throw new Error("AssertError:" + (msg || "")); + } +} + +/** + * Get the path to the wasm library in nodejs. + * @return The wasm path. + */ +export function wasmPath(): string { + return __dirname + "/wasm"; +} \ No newline at end of file diff --git a/web/src/types.ts b/web/src/types.ts new file mode 100644 index 000000000000..621375a23f5f --- /dev/null +++ b/web/src/types.ts @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** Common type definitions. */ + +/** + * Library interface provider that can provide + * syslibs(e.g. libs provided by WASI and beyond) for the Wasm runtime. + * + * It can be viewed as a generalization of imports used in WebAssembly instance creation. + * + * The {@link LibraryProvider.start} callback will be called + * to allow the library provider to initialize related resources during startup time. + * + * We can use Emscripten generated js Module as a { wasmLibraryProvider: LibraryProvider }. + */ +export interface LibraryProvider { + /** The imports that can be passed to WebAssembly instance creation. */ + imports: Record; + /** + * Callback function to notify the provider the created instance. + * @param inst The created instance. + */ + start: (inst: WebAssembly.Instance) => void; +} + +/** + * Disposable classes that contains resources (WasmMemory, GPU buffer) + * which needs to be explicitly disposed. + */ +export interface Disposable { + /** + * Dispose the internal resource + * This function can be called multiple times, + * only the first call will take effect. + */ + dispose: () => void; +} diff --git a/tests/web/test_module_load.js b/web/tests/node/test_module_load.js similarity index 64% rename from tests/web/test_module_load.js rename to web/tests/node/test_module_load.js index f4c809536bb5..45e84fd404a9 100644 --- a/tests/web/test_module_load.js +++ b/web/tests/node/test_module_load.js @@ -19,14 +19,18 @@ // Load Emscripten Module, need to change path to root/lib const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/test_module.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "test_addone.wasm")); + +const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); // Load system library -var sysLib = tvm.systemLib(); +const sysLib = tvm.systemLib(); function randomArray(length, max) { return Array.apply(null, Array(length)).map(function() { @@ -36,23 +40,22 @@ function randomArray(length, max) { function testAddOne() { // grab pre-loaded function - var faddOne = sysLib.getFunction("add_one"); - var assert = require('assert'); - tvm.assert(tvm.isPackedFunc(faddOne)); - var n = 124; - var A = tvm.empty(n).copyFrom(randomArray(n, 1)); - var B = tvm.empty(n); + const faddOne = sysLib.getFunction("add_one"); + assert(tvm.isPackedFunc(faddOne)); + const n = 124; + const A = tvm.empty(n).copyFrom(randomArray(n, 1)); + const B = tvm.empty(n); // call the function. faddOne(A, B); - AA = A.asArray(); // retrieve values in js array - BB = B.asArray(); // retrieve values in js array + const AA = A.toArray(); // retrieve values in js array + const BB = B.toArray(); // retrieve values in js array // verify for (var i = 0; i < BB.length; ++i) { assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5); } - faddOne.release(); + faddOne.dispose(); } testAddOne(); -sysLib.release(); +sysLib.dispose(); console.log("Finish verifying test_module_load"); diff --git a/tests/web/test_basic.js b/web/tests/node/test_ndarray.js similarity index 55% rename from tests/web/test_basic.js rename to web/tests/node/test_ndarray.js index 6852319dbc12..ba43621ecb05 100644 --- a/tests/web/test_basic.js +++ b/web/tests/node/test_ndarray.js @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -16,31 +16,34 @@ * specific language governing permissions and limitations * under the License. */ - -// Load Emscripten Module, need to change path to root/build const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist/tvmjs.bundle") + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); // Basic fields. -tvm.assert(tvm.float32 == "float32"); -tvm.assert(tvm.listGlobalFuncNames() !== "undefined"); -var sysLib = tvm.systemLib(); -tvm.assert(typeof sysLib.getFunction !== "undefined"); -sysLib.release(); +assert(tvm.listGlobalFuncNames() !== undefined); // Test ndarray -function testArrayCopy(dtype, arr) { - var data = [1, 2, 3, 4, 5, 6]; - var a = tvm.empty([2, 3], dtype); - a.copyFrom(data); - var ret = a.asArray(); - tvm.assert(ret instanceof arr); - tvm.assert(ret.toString() == arr.from(data)); - a.release(); +function testArrayCopy(dtype, arrayType) { + let data = [1, 2, 3, 4, 5, 6]; + let a = tvm.empty([2, 3], dtype).copyFrom(data); + + assert(a.context.toString() == "cpu(0)"); + assert(a.shape[0] == 2 && a.shape[1] == 3); + + let ret = a.toArray(); + assert(ret instanceof arrayType); + assert(ret.toString() == arrayType.from(data).toString()); + // test multiple dispose. + a.dispose(); + a.dispose(); } testArrayCopy("float32", Float32Array); @@ -48,8 +51,3 @@ testArrayCopy("int", Int32Array); testArrayCopy("int8", Int8Array); testArrayCopy("uint8", Uint8Array); testArrayCopy("float64", Float64Array); - -// Function registration -tvm.registerFunc("xyz", function(x, y) { - return x + y; -}); diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js new file mode 100644 index 000000000000..c961f9576e3f --- /dev/null +++ b/web/tests/node/test_packed_func.js @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +const path = require("path"); +const fs = require("fs"); +const assert = require('assert'); +const tvmjs = require("../../dist") + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); + +function testGetGlobal() { + let flist = tvm.listGlobalFuncNames(); + let faddOne = tvm.getGlobalFunc("testing.add_one"); + let fecho = tvm.getGlobalFunc("testing.echo"); + + assert(faddOne(tvm.scalar(1, "int")) == 2); + // check function argument with different types. + assert(fecho(1123) == 1123); + assert(fecho("xyz") == "xyz"); + + let bytes = new Uint8Array([1, 2, 3]); + let rbytes = fecho(bytes); + assert(rbytes.length == bytes.length); + + for (let i = 0; i < bytes.length; ++i) { + assert(rbytes[i] == bytes[i]); + } + + assert(fecho(undefined) == undefined); + + let arr = tvm.empty([2, 2]).copyFrom([1, 2, 3, 4]); + let arr2 = fecho(arr); + assert(arr.handle == arr2.handle); + assert(arr2.toArray().toString() == arr.toArray().toString()); + + let mod = tvm.systemLib(); + let ret = fecho(mod); + assert(ret.handle == mod.handle); + assert(flist.length != 0); + + mod.dispose(); + ret.dispose(); + arr.dispose(); + arr2.dispose(); + fecho.dispose(); + faddOne.dispose(); +} + +function testReturnFunc() { + function addy(y) { + function add(x, z) { + return x + y + z; + } + return add; + } + + let fecho = tvm.getGlobalFunc("testing.echo"); + let myf = tvm.toPackedFunc(addy); + assert(tvm.isPackedFunc(myf)); + let myf2 = tvm.toPackedFunc(myf); + assert(myf2._tvmPackedCell.handle === myf._tvmPackedCell.handle); + let f = myf(10); + + assert(tvm.isPackedFunc(f)); + assert(f(11, 0) == 21); + assert(f("x", 1) == "x101"); + assert(f("x", "yz") == "x10yz"); + + fecho.dispose(); + myf.dispose(); + myf2.dispose(); + // test multiple dispose. + f.dispose(); + f.dispose(); +} + +function testRegisterGlobal() { + tvm.registerFunc("xyz", function (x, y) { + return x + y; + }); + + let f = tvm.getGlobalFunc("xyz"); + assert(f(1, 2) == 3); + f.dispose(); + + let syslib = tvm.systemLib(); + syslib.dispose(); +} + +function testTimer() { + const fecho = tvm.getGlobalFunc("testing.echo"); + const fgetTimer = tvm.getGlobalFunc("wasm.GetTimer"); + + let finvoke = (n) => { + let x = "xyz"; + for (let i = 0; i < n; ++i) { + x = fecho(x); + } + }; + const number = 10000; + const invokeTimer = fgetTimer(finvoke); + console.log("Time cost:", number / invokeTimer(number) * 1000, " ops/sec"); + fecho.dispose(); + invokeTimer.dispose(); + fgetTimer.dispose(); +} + +testGetGlobal(); +testRegisterGlobal(); +testReturnFunc(); +testTimer(); diff --git a/tests/web/prepare_test_libs.py b/web/tests/python/prepare_test_libs.py similarity index 69% rename from tests/web/prepare_test_libs.py rename to web/tests/python/prepare_test_libs.py index a0e2c13eab82..ec4eb5be1536 100644 --- a/tests/web/prepare_test_libs.py +++ b/web/tests/python/prepare_test_libs.py @@ -14,27 +14,28 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# Prepare test library for js. +# Prepare test library for standalone wasm runtime test. + import tvm from tvm import te -from tvm.contrib import emscripten +from tvm.contrib import emcc import os + def prepare_test_libs(base_path): - target = "llvm -target=asmjs-unknown-emscripten -system-lib" + target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib" if not tvm.runtime.enabled(target): raise RuntimeError("Target %s is not enbaled" % target) n = te.var("n") A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) - fadd1 = tvm.build(s, [A, B], target, name="add_one") - obj_path = os.path.join(base_path, "test_add_one.bc") - fadd1.save(obj_path) - emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path, - options=["-s", "WASM=0", "-s", "USE_GLFW=3", "-s", - "USE_WEBGL2=1", "-lglfw"]) + fadd = tvm.build(s, [A, B], target, name="add_one") + + wasm_path = os.path.join(base_path, "test_addone.wasm") + fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + if __name__ == "__main__": curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - prepare_test_libs(os.path.join(curr_path, "../../build")) + prepare_test_libs(os.path.join(curr_path, "../../dist/wasm")) diff --git a/tests/web/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py similarity index 55% rename from tests/web/websock_rpc_test.py rename to web/tests/python/websock_rpc_test.py index 8be8ce04cb75..7fa0c6bdfb57 100644 --- a/tests/web/websock_rpc_test.py +++ b/web/tests/python/websock_rpc_test.py @@ -22,45 +22,61 @@ import tvm from tvm import te -import os from tvm import rpc -from tvm.contrib import util, emscripten +from tvm.contrib import util, emcc import numpy as np proxy_host = "localhost" proxy_port = 9090 -def test_rpc_array(): +def test_rpc(): if not tvm.runtime.enabled("rpc"): return - # graph - n = tvm.runtime.convert(1024) + # generate the wasm library + target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib" + if not tvm.runtime.enabled(target): + raise RuntimeError("Target %s is not enbaled" % target) + n = te.var("n") A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) - remote = rpc.connect(proxy_host, proxy_port, key="js") - target = "llvm -target=asmjs-unknown-emscripten -system-lib" - def check_remote(): - if not tvm.runtime.enabled(target): - print("Skip because %s is not enabled" % target) - return - temp = util.tempdir() + + fadd = tvm.build(s, [A, B], target, name="addone") + temp = util.tempdir() + + wasm_path = temp.relpath("addone.wasm") + fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + + wasm_binary = open(wasm_path, "rb").read() + + remote = rpc.connect(proxy_host, proxy_port, key="wasm", + session_constructor_args=["rpc.WasmSession", wasm_binary]) + + def check(remote): + # basic function checks. + fecho = remote.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + # run the generated library. + f1 = remote.system_lib() ctx = remote.cpu(0) - f = tvm.build(s, [A, B], target, name="myadd") - path_obj = temp.relpath("dev_lib.bc") - path_dso = temp.relpath("dev_lib.js") - f.save(path_obj) - emscripten.create_js(path_dso, path_obj, side_module=True) - # Upload to suffix as dso so it can be loaded remotely - remote.upload(path_dso, "dev_lib.dso") - data = remote.download("dev_lib.dso") - f1 = remote.load_module("dev_lib.dso") a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) - time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) + # invoke the function + addone = f1.get_function("addone") + addone(a, b) + + # time evaluator + time_f = f1.time_evaluator("addone", ctx, number=10) + time_f(a, b) cost = time_f(a, b).mean print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - check_remote() -test_rpc_array() + check(remote) + + +test_rpc() diff --git a/web/tsconfig.json b/web/tsconfig.json new file mode 100644 index 000000000000..3c20b3d20692 --- /dev/null +++ b/web/tsconfig.json @@ -0,0 +1,13 @@ +{ + "compilerOptions": { + "module": "commonjs", + "target": "es6", + "outDir": "dist", + "rootDir": "src", + "declaration": true, + "sourceMap": true, + "strict": true, + }, + "include": ["src"], + "exclude": ["node_modules"] +} diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js deleted file mode 100644 index 86ef59cb73b1..000000000000 --- a/web/tvm_runtime.js +++ /dev/null @@ -1,1274 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/** - * TVM Javascript web runtime library. - * - * @projectname tvm - * @version 0.7.dev1 - */ -/* eslint no-unused-vars: "off" */ -/* eslint no-unexpected-multiline: "off" */ -/* eslint indent: "off" */ -/* eslint no-console: "off" */ -/** - * TVM Runtime namespace. - * Provide tvm_runtime.create to create a {@link tvm.TVMRuntime}. - * - * @namespace tvm_runtime - */ -var tvm_runtime = tvm_runtime || {}; - -/** - * TVM root namespace. - * The classes inside this namespace need to be constructed by factory functions. - * Use {@link tvm_runtime}.create to get started. - * - * @namespace tvm - */ -(function() { - /** - * TVMRuntime object for interacting with TVM runtime. - * This object can be constructed using {@link tvm_runtime}.create - * - * @class - * @memberof tvm - */ - function TVMRuntime() { - "use strict"; - var runtime_ref = this; - // Utility function to throw error - function throwError(message) { - if (typeof runtime_ref.logger !== "undefined") { - runtime_ref.logger(message); - } - if (typeof Error !== "undefined") { - throw new Error(message); - } - throw message; - } - var Module = this.Module; - var Runtime = this.Runtime; - if (typeof Module === "undefined") { - throwError("Emscripten Module is not available"); - } - // constants - var SIZEOF_POINTER = 4; - var SIZEOF_SIZE_T = 4; - var SIZEOF_FLOAT = 4; - var SIZEOF_INT = 4; - var SIZEOF_INT8 = 1; - var SIZEOF_INT64 = 8; - var SIZEOF_DOUBLE = 8; - var SIZEOF_TYPE = 4; - var SIZEOF_CTX = SIZEOF_INT + SIZEOF_INT; - var SIZEOF_TVMVALUE = SIZEOF_DOUBLE; - var ARRAY_OFFSET_DATA = 0; - var ARRAY_OFFSET_CTX = ARRAY_OFFSET_DATA + SIZEOF_POINTER; - var ARRAY_OFFSET_DEV_TYPE = ARRAY_OFFSET_CTX; - var ARRAY_OFFSET_DEV_ID = ARRAY_OFFSET_CTX + SIZEOF_INT; - var ARRAY_OFFSET_NDIM = ARRAY_OFFSET_CTX + SIZEOF_CTX; - var ARRAY_OFFSET_DTYPE = ARRAY_OFFSET_NDIM + SIZEOF_INT; - var ARRAY_OFFSET_DTYPE_CODE = ARRAY_OFFSET_DTYPE; - var ARRAY_OFFSET_DTYPE_BITS = ARRAY_OFFSET_DTYPE_CODE + SIZEOF_INT8; - var ARRAY_OFFSET_DTYPE_LANES = ARRAY_OFFSET_DTYPE_BITS + SIZEOF_INT8; - var ARRAY_OFFSET_SHAPE = ARRAY_OFFSET_DTYPE + SIZEOF_TYPE; - var ARRAY_OFFSET_STRIDES = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER; - var ARRAY_OFFSET_BYTE_OFFSET = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER; - // Type codes - var kInt = 0; - var kUInt = 1; - var kFloat = 2; - var kTVMOpaqueHandle = 3; - var kNull = 4; - var kTVMDataType = 5; - var kTVMContext = 6; - var kTVMDLTensorHandle = 7; - var kTVMObjectHandle = 8; - var kTVMModuleHandle = 9; - var kTVMPackedFuncHandle = 10; - var kTVMStr = 11; - var kTVMBytes = 12; - var kTVMObjectRValueRefArg = 14; - //----------------------------------------- - // TVM CWrap library - // ---------------------------------------- - var TVMGetLastError = Module.cwrap( - "TVMGetLastError", - "string", // const char* - []); - - var TVMAPISetLastError = Module.cwrap - ("TVMAPISetLastError", - null, - ["string" // const char* - ]); - - var TVMModImport = Module.cwrap - ("TVMModImport", - "number", - ["number", // TVMModuleHandle mod - "number" // TVMModuleHandle dep - ]); - - var TVMModGetFunction = Module.cwrap - ("TVMModGetFunction", - "number", - ["number", // TVMModuleHandle mod - "string", // const char* func_name - "number", // int query_imports - "number" // TVMFunctionHandle *out - ]); - - var TVMModFree = Module.cwrap - ("TVMModFree", - "number", - ["number" // TVMModeHandle mod - ]); - - var TVMFuncFree = Module.cwrap - ("TVMFuncFree", - "number", - ["number" // TVMFunctionHandle func - ]); - - var TVMFuncCall = Module.cwrap - ("TVMFuncCall", - "number", - ["number", // TVMFunctionHandle func - "number", // TVMValue* arg_values - "number", // int* arg_tcodes - "number", // int num_args - "number", // int ret_val - "number" // int ret_type_code - ]); - - var TVMCFuncSetReturn = Module.cwrap - ("TVMCFuncSetReturn", - "number", - ["number", // TVMRetValueHandle ret - "number", // TVMValue* value - "number", // int* type_code - "number" // int num_ret - ]); - - var TVMCbArgToReturn = Module.cwrap - ("TVMCbArgToReturn", - "number", - ["number", // TVMValue* value - "number" // int* code - ]); - - var TVMFuncCreateFromCFunc = Module.cwrap - ("TVMFuncCreateFromCFunc", - "number", - ["number", // TVMPackedCFunc func, - "number", // void* resource_handle - "number", // TVMPackedCFuncFinalizer fin - "number" // TVMFunctionHandle *out - ]); - - var TVMFuncRegisterGlobal = Module.cwrap - ("TVMFuncRegisterGlobal", - "number", - ["string", // name - "number", // TVMFunctionHandle f - "number" // int override - ]); - - var TVMFuncGetGlobal = Module.cwrap - ("TVMFuncGetGlobal", - "number", - ["string", // const char* name - "number" // TVMFunctionHandle* out - ]); - - var TVMFuncListGlobalNames = Module.cwrap - ("TVMFuncListGlobalNames", - "number", - ["number", // int* out_size - "number" // const char*** out_array - ]); - - - var TVMArrayAlloc = Module.cwrap - ("TVMArrayAlloc", - "number", - ["number", // const tvm_index_t* shape - "number", // int ndim - "number", // int dtype_code - "number", // int dtype_bits - "number", // int dtype_lanes - "number", // int device_type - "number", // int device_id - "number" // int TVMArrayHandle* out - ]); - - var TVMArrayFree = Module.cwrap - ("TVMArrayFree", - "number", - ["number" // TVMArrayHandle handle - ]); - - var TVMArrayCopyFromTo = Module.cwrap - ("TVMArrayCopyFromTo", - "number", - ["number", // TVMArrayHandle from - "number" // TVMArrayHandle to - ]); - - var TVMArrayCopyFromBytes = Module.cwrap - ("TVMArrayCopyFromBytes", - "number", - ["number", // TVMArrayHandle handle - "number", // int data - "number" // size_t nbytes - ]); - - var TVMArrayCopyToBytes = Module.cwrap - ("TVMArrayCopyToBytes", - "number", - ["number", // TVMArrayHandle handle - "number", // int data - "number" // size_t nbytes - ]); - - var TVMModLoadFromFile = Module.cwrap - ("TVMModLoadFromFile", - "number", - ["string", // const char* file_name - "string", // const char* format - "number" // TVMModuleHandle* out - ]) - - //----------------------------------------- - // Static utility functions - // ---------------------------------------- - this.assert = function(condition, message) { - if (!condition) { - message = message || "assert failed"; - throwError(message); - } - }; - /** - * Logging function. - * Override this to change logger behavior. - * - * @param {string} message - */ - this.logger = function(message) { - console.log(message); - }; - - function logging(message) { - runtime_ref.logger(message); - } - // Override print error to logging - Module.printErr = logging; - var CHECK = this.assert; - - function TVM_CALL(ret) { - if (ret != 0) { - throwError(TVMGetLastError()); - } - } - - function CInt64ArrayToJS(ptr, size) { - var ret = []; - for (var i = 0; i < size; ++i) { - ret.push(Module.getValue(ptr + i * SIZEOF_INT64, "i64")); - } - return ret; - } - - function CStringToJS(ptr) { - var ret = []; - var ch = 1; - while (ch != 0) { - ch = Module.getValue(ptr, "i8"); - if (ch != 0) { - ret.push(String.fromCharCode(ch)); - } - ++ptr; - } - return ret.join(""); - } - - function CBytesToJS(ptr) { - var data = Module.getValue(ptr, "*"); - var size = Module.getValue(ptr + SIZEOF_POINTER, "i32"); - var ret = new Uint8Array(new ArrayBuffer(size)); - ret.set(new Uint8Array(Module.HEAPU8.buffer, data, size)); - return ret; - } - - function StringToUint8Array(str) { - var arr = new Uint8Array(str.length + 1); - for(var i = 0; i < str.length; ++i) { - arr[i] = str.charCodeAt(i); - } - arr[str.length] = 0; - return arr; - } - //----------------------------------------- - // Class declarations - // ---------------------------------------- - function CBuffer(nbytes) { - this.data = Module._malloc(nbytes); - } - - function RefTVMValue() { - this.data = Module._malloc(SIZEOF_TVMVALUE); - } - - function TVMArgs(nargs) { - this.nargs = nargs; - this.value = Module._malloc(SIZEOF_TVMVALUE * nargs); - this.tcode = Module._malloc(SIZEOF_INT * nargs); - this.temp = []; - } - - function TVMType(code, bits, lanes) { - this.code = code; - this.bits = bits; - this.lanes = lanes; - } - /** - * TVM device context. - * @class - * @memberof tvm - */ - function TVMContext(device_type, device_id) { - this.device_type = device_type; - this.device_id = device_id; - } - /** - * TVM n-dimensional array. - * - * Use {@link tvm.TVMRuntime}.empty to create an instance. - * @class - * @memberof tvm - */ - function NDArray(handle) { - this.handle = handle; - this.ndim = Module.getValue(this.handle + ARRAY_OFFSET_NDIM, "i32"); - // shape - var cshape = Module.getValue(this.handle + ARRAY_OFFSET_SHAPE, "*"); - this.shape = CInt64ArrayToJS(cshape, this.ndim); - // dtype - var code = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_CODE, "i8"); - var bits = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_BITS, "i8"); - var lanes = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_LANES, "i16"); - var dtype = new TVMType(code, bits, lanes); - this.dtype = dtype; - this.BYTES_PER_ELEMENT = (dtype.bits * dtype.lanes / 8); - // ctx - var device_type = Module.getValue(this.handle + ARRAY_OFFSET_DEV_TYPE, "i32"); - var device_id = Module.getValue(this.handle + ARRAY_OFFSET_DEV_ID, "i32"); - this.context = new TVMContext(device_type, device_id); - // byte_offset - this.byteOffset = Module.getValue(this.handle + ARRAY_OFFSET_BYTE_OFFSET, "i64"); - } - - function TVMFunction(handle) { - this.handle = handle; - } - /** - * Module container of TVM generated functions. - * - * @class - * @memberof tvm - */ - function TVMModule(handle) { - this.handle = handle; - } - /** - * A typed scalar constant. - * This can be used to pass number as integer types to tvm function. - * Use {@link tvm.TVMRuntime}.constant to create an instance. - * @class - * @memberof tvm - */ - function TVMConstant(value, dtype) { - this.value = value; - this.dtype = dtype; - } - //----------------------------------------- - // Private Functions - // ---------------------------------------- - function getTVMType(dtype) { - if (dtype instanceof TVMType) return dtype; - if (typeof dtype == "string") { - var pattern = dtype; - var code, bits = 32, lanes = 1; - if (pattern.substring(0, 5) == "float") { - pattern = pattern.substring(5, pattern.length); - code = kFloat; - } else if (pattern.substring(0, 3) == "int") { - pattern = pattern.substring(3, pattern.length); - code = kInt; - } else if (pattern.substring(0, 4) == "uint") { - pattern = pattern.substring(4, pattern.length); - code = kUInt; - } else if (pattern.substring(0, 6) == "handle") { - pattern = pattern.substring(5, pattern.length); - code = kTVMOpaqueHandle; - bits = 64; - } else { - throw throwError("Unknown dtype " + dtype); - } - var arr = pattern.split("x"); - if (arr.length >= 1) { - var parsed = parseInt(arr[0]); - if (parsed == arr[0]) { - bits = parsed; - } - } - if (arr.length >= 2) { - lanes = parseInt(arr[1]); - } - return new TVMType(code, bits, lanes); - } else { - throw throwError("Unknown dtype " + dtype); - } - } - - function TVMRetValueToJS(vptr, tcode) { - switch (tcode) { - case kInt: - case kUInt: return Module.getValue(vptr, "i64"); - case kFloat: return Module.getValue(vptr, "double"); - case kTVMPackedFuncHandle: return makeTVMFunction(Module.getValue(vptr, "*")); - case kTVMModuleHandle: return new TVMModule(Module.getValue(vptr, "*")); - case kNull: return null; - case kTVMStr: return CStringToJS(Module.getValue(vptr, "*")); - case kTVMBytes: return CBytesToJS(Module.getValue(vptr, "*")); - default: throwError("Unsupported return type code=" + tcode); - } - } - - function makeTVMFunction(handle) { - var func = new TVMFunction(handle); - var ret = function () { - // alloc - var args = new TVMArgs(arguments.length); - var rvalue = new RefTVMValue(); - var rtcode = new RefTVMValue(); - args.setArguments(arguments); - TVM_CALL(TVMFuncCall(handle, args.value, args.tcode, - args.nargs, rvalue.data, rtcode.data)); - var rv = TVMRetValueToJS(rvalue.data, rtcode.asInt()); - // release - args.release(); - rvalue.release(); - rtcode.release(); - return rv; - }; - var release = function() { - func.release(); - }; - ret._tvm_function = func; - ret.release = release; - return ret; - } - //----------------------------------------- - // Javascript PackedCallback System - // ---------------------------------------- - var funcTable = [0]; - var freeFuncId = []; - - function invokeCallback(arg_value, arg_tcode, nargs, ret, handle) { - var args = []; - for (var i = 0; i < nargs; ++i) { - var vptr = arg_value + i * SIZEOF_TVMVALUE; - var tcodeptr = arg_tcode + i * SIZEOF_INT; - var tcode = Module.getValue(tcodeptr, "i32"); - if (tcode == kTVMObjectHandle || - tcode == kTVMObjectRValueRefArg || - tcode == kTVMPackedFuncHandle || - tcode == kTVMModuleHandle) { - TVM_CALL(TVMCbArgToReturn(vptr, tcodeptr)); - } - tcode = Module.getValue(tcodeptr, "i32"); - args.push(TVMRetValueToJS(vptr, tcode)); - } - var rv = funcTable[handle].apply(null, args); - if (typeof rv !== "undefined") { - // alloc - var rarg = new TVMArgs(1); - rarg.setArguments([rv]); - TVM_CALL(TVMCFuncSetReturn(ret, rarg.value, rarg.tcode, 1)); - // release - rarg.release(); - } - return 0; - } - function freeCallback(handle) { - funcTable[handle] = 0; - freeFuncId.push(handle); - } - var fptrInvokeCallback = null; - var fptrFreeCallback = null; - if (typeof Runtime !== "undefined" && - typeof Runtime.addFunction !== "undefined") { - fptrInvokeCallback = Runtime.addFunction(invokeCallback); - fptrFreeCallback = Runtime.addFunction(freeCallback); - } - /** - * Check if a function is TVM PackedFunc - * @param {Function} f function to be checked. - * @return {boolean} Whether f is PackedFunc - */ - this.isPackedFunc = function(f) { - return (typeof f == "function") && f.hasOwnProperty("_tvm_function"); - }; - var isPackedFunc = this.isPackedFunc; - /** - * Convert a javascript function to TVM function. - * @param {Function} f javascript function. - * @return {Function} The created TVMFunction. - */ - this.convertFunc = function(f) { - if (isPackedFunc(f)) return f; - CHECK(fptrInvokeCallback !== null, - "Emscripten Runtime addFunction is not available"); - var fid; - if (freeFuncId.length != 0) { - fid = freeFuncId.pop(); - } else { - fid = funcTable.length; - funcTable.push(0); - } - funcTable[fid] = f; - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMFuncCreateFromCFunc( - fptrInvokeCallback, fid, fptrFreeCallback, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - return makeTVMFunction(out_handle); - }; - var convertFunc = this.convertFunc; - //----------------------------------------- - // Private Class declarations - // ---------------------------------------- - CBuffer.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.data != 0) { - Module._free(this.data); - this.data = 0; - } - }, - }; - // RefTVMValue - RefTVMValue.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.data != 0) { - Module._free(this.data); - this.data = 0; - } - }, - asInt : function() { - return Module.getValue(this.data, "i32"); - }, - asInt64 : function() { - return Module.getValue(this.data, "i64"); - }, - asDouble : function() { - return Module.getValue(this.data, "double"); - }, - asHandle : function() { - return Module.getValue(this.data, "*"); - } - }; - // TVMArgs - TVMArgs.prototype = { - release : function() { - if (this.value != 0) { - Module._free(this.value); - Module._free(this.tcode); - this.value = 0; - for (var i = 0; i< this.temp.length; ++i) { - if (this.temp[i].release instanceof Function) { - this.temp[i].release(); - } - } - } - }, - setInt : function(index, value) { - Module.setValue(this.tcode + index * SIZEOF_INT, kInt, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "i64"); - }, - setDouble : function(index, value) { - Module.setValue(this.tcode + index * SIZEOF_INT, kFloat, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "double"); - }, - setHandle : function(index, value, tcode) { - Module.setValue(this.tcode + index * SIZEOF_INT, tcode, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "*"); - }, - setString : function(index, value) { - var sdata = new CBuffer(value.length + 1); - Module.HEAPU8.set(StringToUint8Array(value), sdata.data); - this.temp.push(sdata); - Module.setValue(this.tcode + index * SIZEOF_INT, kTVMStr, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, sdata.data, "*"); - }, - setBytes : function(index, value) { - CHECK(value instanceof Uint8Array); - var sdata = new CBuffer(value.length); - var sheader = new CBuffer(SIZEOF_POINTER + SIZEOF_SIZE_T); - Module.HEAPU8.set(new Uint8Array(value), sdata.data); - Module.setValue(sheader.data, sdata.data, "*"); - Module.setValue(sheader.data + SIZEOF_POINTER, value.length, "i32"); - this.temp.push(sdata); - this.temp.push(sheader); - Module.setValue(this.tcode + index * SIZEOF_INT, kTVMBytes, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, sheader.data, "*"); - }, - setArguments : function(args) { - for (var i = 0; i < args.length; ++i) { - var v = args[i]; - var tp = typeof v; - if (v instanceof NDArray) { - this.setHandle(i, v.handle, kTVMDLTensorHandle); - } else if (v instanceof TVMConstant) { - var code = getTVMType(v.dtype).code; - if (code == kInt || code == kUInt) { - this.setInt(i, v.value); - } else if (code == kFloat) { - this.setDouble(i, v.value); - } else { - CHECK(code == kTVMOpaqueHandle); - this.setHandle(i, v.value, kTVMOpaqueHandle); - } - } else if (tp == "number") { - this.setDouble(i, v); - } else if (tp == "function" && v.hasOwnProperty("_tvm_function")) { - this.setString(i, v._tvm_function.handle, kTVMPackedFuncHandle); - } else if (v === null) { - this.setHandle(i, 0, kNull); - } else if (tp == "string") { - this.setString(i, v); - } else if (v instanceof Uint8Array) { - this.setBytes(i, v); - } else if (v instanceof Function) { - v = convertFunc(v); - this.temp.push(v); - this.setHandle(i, v._tvm_function.handle, kTVMPackedFuncHandle); - } else if (v instanceof TVMModule) { - this.setHandle(i, v.handle, kTVMModuleHandle); - } else { - throwError("Unsupported argument type " + tp); - } - } - } - }; - // TVMType - var TYPE_CODE2STR = { - 0 : "int", - 1 : "uint", - 2 : "float", - 4 : "handle" - }; - - TVMType.prototype = { - toString : function() { - var ret = TYPE_CODE2STR[this.code] + this.bits.toString(); - if (this.lanes != 1) { - return ret + "x" + this.lanes.toString(); - } else { - return ret; - } - } - }; - // TVMFunction - TVMFunction.prototype = { - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMFuncFree(this.handle)); - this.handle = 0; - } - } - }; - // TVMContext - var CTX_MASK2STR = { - 1 : "cpu", - 2 : "gpu", - 4 : "opencl", - 7 : "vulkan", - 8 : "metal", - 9 : "vpi", - 11 : "opengl", - }; - var CTX_STR2MASK = { - "cpu": 1, - "gpu": 2, - "cuda": 2, - "cl": 4, - "opencl": 4, - "vulkan": 7, - "metal": 8, - "vpi": 9, - "opengl": 11, - }; - TVMContext.prototype = { - toString : function() { - return CTX_MASK2STR[this.device_type] + "(" + this.device_id.toString() + ")"; - } - }; - //----------------------------------------- - // Public Functions - // ---------------------------------------- - /** - * Construct a TVMContext given device type and id. - * - * @param {number} device_type, string or int, The device type. - * @param {number} device_id, the device id. - * @return {tvm.TVMContext} The created TVMContext - */ - this.context = function(device_type, device_id) { - if (typeof device_type == "string") { - device_type = CTX_STR2MASK[device_type]; - } - return new TVMContext(device_type, device_id); - }; - var context = this.context; - /** - * Create empty ndarray with given shape. - * - * @param {Array.} shape The shape of the array. - * @param {string} dtype The data type of the array, optional, default="float32" - * @param {tvm.TVMContext} ctx The context of the array, optional, default=cpu(0). - * @return {tvm.NDArray} The created ndarray. - */ - this.empty = function(shape, dtype, ctx) { - dtype = (typeof dtype !== "undefined") ? dtype: "float32"; - ctx = (typeof ctx !== "undefined") ? ctx : context("cpu", 0); - shape = (typeof shape == "number") ? [shape] : shape; - // alloc - var cshape = Module._malloc(SIZEOF_INT64 * shape.length); - var out = new RefTVMValue(); - for (var i = 0; i < shape.length; ++i) { - Module.setValue(cshape + i * SIZEOF_INT64, shape[i], "i64"); - } - dtype = getTVMType(dtype); - TVM_CALL(TVMArrayAlloc(cshape, shape.length, - dtype.code, dtype.bits, dtype.lanes, - ctx.device_type, ctx.device_id, - out.data)); - var out_handle = out.asHandle(); - // release - Module._free(cshape); - out.release(); - return new NDArray(out_handle); - }; - /** - * List all global function names in the TVM runtime. - * @return {Array.} List of global function names. - */ - this.listGlobalFuncNames = function() { - // alloc - var out_size = new RefTVMValue(); - var out_array = new RefTVMValue(); - TVM_CALL(TVMFuncListGlobalNames(out_size.data, out_array.data)); - var length = out_size.asInt(); - var base = out_array.asHandle(); - var names = []; - for (var i = 0 ; i < length; ++i) { - names.push( - CStringToJS(Module.getValue(base + i * SIZEOF_POINTER, "*"))); - } - // release - out_size.release(); - out_array.release(); - return names; - }; - var listGlobalFuncNames = this.listGlobalFuncNames; - /** - * Get a global function from TVM runtime. - * - * @param {string} The name of the function. - * @return {Function} The corresponding function, null if function do not exist - */ - this.getGlobalFunc = function (name) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMFuncGetGlobal(name, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle != 0) { - return makeTVMFunction(out_handle); - } else { - return null; - } - }; - var getGlobalFunc = this.getGlobalFunc; - /** - * Register function to be global function in tvm runtime. - * @param {string} name The name of the function. - * @param {Function} f function to be registered. - * @param {boolean} override Whether overwrite function in existing registry. - */ - this.registerFunc = function(name, f, override) { - f = convertFunc(f); - override = (typeof override !== "undefined") ? override: false; - var ioverride = override ? 1 : 0; - TVM_CALL(TVMFuncRegisterGlobal(name, f._tvm_function.handle, ioverride)); - }; - /** - * Create a typed scalar constant. - * This can be used to pass number as integer types to tvm function. - * - * @param {number} value The value of the data. - * @param {string} dtype The data type. - * @param {tvm.TVMConstant} The created typed scalar. - */ - this.constant = function(value, dtype) { - return new TVMConstant(value, dtype); - }; - //----------------------------------------- - // Wrap of TVM Functions. - // ---------------------------------------- - var systemFunc = {}; - /** - * Get system-wide library module singleton.5A - * System lib is a global module that contains self register functions in startup. - * @return {tvm.TVMModule} The system module singleton. - */ - this.systemLib = function() { - if (typeof systemFunc.fGetSystemLib === "undefined") { - systemFunc.fGetSystemLib = getGlobalFunc("runtime.SystemLib"); - } - return systemFunc.fGetSystemLib(); - }; - - this.startRPCServer = function(url, key, counter) { - if (typeof key === "undefined") { - key = ""; - } - if (typeof counter === "undefined") { - counter = 1; - } - // Node js, import websocket - var bkey = StringToUint8Array("server:" + key); - bkey = bkey.slice(0, bkey.length - 1); - var server_name = "WebSocketRPCServer[" + key + "]"; - var RPC_MAGIC = 0xff271; - function checkEndian() { - var a = new ArrayBuffer(4); - var b = new Uint8Array(a); - var c = new Uint32Array(a); - b[0] = 0x11; - b[1] = 0x22; - b[2] = 0x33; - b[3] = 0x44; - CHECK(c[0] === 0x44332211, "Need little endian to work"); - } - checkEndian(); - // start rpc - function RPCServer(counter) { - var socket; - if (typeof module !== "undefined" && module.exports) { - // WebSocket for nodejs - const WebSocket = require("ws"); - socket = new WebSocket(url); - } else { - socket = new WebSocket(url); - } - var self = this; - socket.binaryType = "arraybuffer"; - this.init = true; - this.counter = counter; - - if (typeof systemFunc.fcreateServer === "undefined") { - systemFunc.fcreateServer = - getGlobalFunc("rpc.CreateEventDrivenServer"); - } - if (systemFunc.fcreateServer == null) { - throwError("RPCServer is not included in runtime"); - } - - var message_handler = systemFunc.fcreateServer( - function(cbytes) { - if (socket.readyState == 1) { - socket.send(cbytes); - return new TVMConstant(cbytes.length, "int32"); - } else { - return new TVMConstant(0, "int32"); - } - } , server_name, "%toinit"); - - function on_open(event) { - var intbuf = new Int32Array(1); - intbuf[0] = RPC_MAGIC; - socket.send(intbuf); - intbuf[0] = bkey.length; - socket.send(intbuf); - socket.send(bkey); - logging(server_name + " connected..."); - } - - function on_message(event) { - if (self.init) { - var msg = new Uint8Array(event.data); - CHECK(msg.length >= 4, "Need message header to be bigger than 4"); - var magic = new Int32Array(event.data)[0]; - - if (magic == RPC_MAGIC + 1) { - throwError("key: " + key + " has already been used in proxy"); - } else if (magic == RPC_MAGIC + 2) { - logging(server_name + ": RPCProxy do not have matching client key " + key); - } else { - CHECK(magic == RPC_MAGIC, url + "is not RPC Proxy"); - self.init = false; - } - logging(server_name + "init end..."); - if (msg.length > 4) { - if (message_handler( - new Uint8Array(event.data, 4, msg.length -4), - new TVMConstant(3, "int32")) == 0) { - socket.close(); - } - } - } else { - if (message_handler(new Uint8Array(event.data), - new TVMConstant(3, "int32")) == 0) { - socket.close(); - } - } - } - function on_close(event) { - message_handler.release(); - logging(server_name + ": closed finish..."); - if (!self.init && self.counter != 0) { - logging(server_name + ":reconnect to serve another request, session left=" + counter); - // start a new server. - new RPCServer(counter - 1); - } - } - socket.addEventListener("open", on_open); - socket.addEventListener("message", on_message); - socket.addEventListener("close", on_close); - } - return new RPCServer(counter); - }; - - /** - * Load a TVM module from a library file. - * The file must be present in the Emscripten virtual file system. - * For example, you can pass "--preload-file file" or "--preload-file dir/" - * to "emcc" when compiling the TVM library, in order to populate files into - * the file system. - * For more detail, see: - * https://kripken.github.io/emscripten-site/docs/porting/files/packaging_files - * @param {string} file_name Path of the file to be loaded. The path refers - * to the Emscripten virtual file system. - * @param {string} format The format of the file. - * @return {tvm.TVMModule} The loaded module. - */ - this.loadModuleFromFile = function (file_name, format) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMModLoadFromFile(file_name, format, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle != 0) { - return new TVMModule(out_handle); - } else { - return null; - } - }; - var loadModuleFromFile = this.loadModuleFromFile; - - /** - * Wrapper runtime module. - * Wraps around set_input, load_params, run, and get_output. - * - * @class - * @memberof tvm - */ - function GraphModule(tvm_graph_module, ctx) { - CHECK(tvm_graph_module instanceof TVMModule, - "tvm_graph_module must be TVMModule"); - CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); - - this.tvm_graph_module = tvm_graph_module; - this.ctx = ctx; - this._set_input = tvm_graph_module.getFunction("set_input"); - this._load_params = tvm_graph_module.getFunction("load_params"); - this._run = tvm_graph_module.getFunction("run"); - this._get_output = tvm_graph_module.getFunction("get_output"); - }; - - GraphModule.prototype = { - /** - * Set input to graph module. - * - * @param {string} key The name of the input. - * @param {NDArray} value The input value. - */ - "set_input" : function(key, value) { - CHECK(typeof key == "string", "key must be string"); - CHECK(value instanceof NDArray, "value must be NDArray"); - this._set_input(key, value); - }, - - /** - * Load parameters from serialized byte array of parameter dict. - * - * @param {Uint8Array} params The serialized parameter dict. - */ - "load_params" : function(params) { - CHECK(params instanceof Uint8Array, "params must be Uint8Array"); - this._load_params(params); - }, - - /** - * Load parameters from serialized base64 string of parameter dict. - * - * @param {string} base64_params The serialized parameter dict. - */ - "load_base64_params" : function(base64_params) { - CHECK(typeof base64_params == "string", "base64_params must be string"); - var decoded_string = atob(base64_params); - var decoded_u8 = new Uint8Array(decoded_string.length); - for (var i = 0; i < decoded_string.length; i++) { - decoded_u8[i] = decoded_string[i].charCodeAt(0); - } - this.load_params(decoded_u8); - }, - - /** - * Run forward execution of the graph. - */ - "run" : function() { - this._run(); - }, - - /** - * Get index-th output to out. - * - * @param {NDArray} out The output array container. - * @return {NDArray} The output array container. - */ - "get_output" : function(index, out) { - CHECK(typeof index == "number", "index must be number"); - CHECK(out instanceof NDArray, "out must be NDArray"); - this._get_output(new TVMConstant(index, "int32"), out); - return out; - } - }; - - /** - * Create a runtime executor module given a graph and a module. - * @param {string} graph_json_str The Json string of the graph. - * @param {TVMModule} libmod The TVM module. - * @param {TVMContext} ctx The context to deploy the module. - * @return {GraphModule} Runtime graph module for executing the graph. - */ - this.createGraphRuntime = function(graph_json_str, libmod, ctx) { - CHECK(typeof graph_json_str == "string", "graph_json_str must be string"); - CHECK(libmod instanceof TVMModule, "libmod must be TVMModule"); - CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); - - var fcreate = getGlobalFunc("tvm.graph_runtime.create"); - CHECK(fcreate != null, "Cannot find tvm.graph_runtime.create"); - - var tvm_graph_module = fcreate(graph_json_str, libmod, - new TVMConstant(ctx.device_type, "int32"), - new TVMConstant(ctx.device_id, "int32")); - - return new GraphModule(tvm_graph_module, ctx); - }; - - //----------------------------------------- - // Class defintions - // ---------------------------------------- - // NDArray. - NDArray.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMArrayFree(this.handle)); - this.handle = 0; - } - }, - /** - * Copy data from another NDArray or javascript array. - * The number of elements must match. - * - * @param {Array} data The source data array. - */ - copyFrom : function(data) { - if (data instanceof NDArray) { - TVM_CALL(TVMArrayCopyFromTo(data.handle, this.handle)); - } else { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - if (data.length != size) { - throwError("data size and shape mismatch data.length" + data.length + " vs " + size); - } - if (this.dtype == "float32") { - data = Float32Array.from(data); - } else if (this.dtype == "float64") { - data = Float64Array.from(data); - } else if (this.dtype == "int32") { - data = Int32Array.from(data); - } else if (this.dtype == "int8") { - data = Int8Array.from(data); - } else if (this.dtype == "uint8") { - data = Uint8Array.from(data); - } else { - throwError("Unsupported data type " + this.dtype); - } - return this.copyFromRawBytes(new Uint8Array(data.buffer)); - } - }, - /** - * Copy data from raw bytes. - * @param {Uint8Array} data Uint8Array of bytes. - */ - copyFromRawBytes : function(data) { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - var dtype = getTVMType(this.dtype); - var nbytes = this.BYTES_PER_ELEMENT * size; - CHECK(data instanceof Uint8Array); - CHECK(data.length == nbytes, - "Data length and bytes do not match " + data.length + - " vs " + nbytes); - var temp = Module._malloc(nbytes); - Module.HEAPU8.set(data, temp); - TVM_CALL(TVMArrayCopyFromBytes(this.handle, temp, nbytes)); - Module._free(temp); - return this; - }, - /** - * Return a copied Uint8Array of the raw bytes in the NDArray. - * @return {Uint8Array} The created array. - */ - asRawBytes : function() { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - var nbytes = this.BYTES_PER_ELEMENT * size; - var temp = Module._malloc(nbytes); - TVM_CALL(TVMArrayCopyToBytes(this.handle, temp, nbytes)); - var ret = new Uint8Array(new ArrayBuffer(nbytes)); - ret.set(new Uint8Array(Module.HEAPU8.buffer, temp, nbytes)); - Module._free(temp); - return ret; - }, - /** - * Return Array data content as javascript typed array. - * @return {TypedArray} The created array. - */ - asArray : function() { - if (this.dtype == "float32") { - return new Float32Array(this.asRawBytes().buffer); - } else if (this.dtype == "float64") { - return new Float64Array(this.asRawBytes().buffer); - } else if (this.dtype == "int32") { - return new Int32Array(this.asRawBytes().buffer); - } else if (this.dtype == "int8") { - return new Int8Array(this.asRawBytes().buffer); - } else if (this.dtype == "uint8") { - return new Uint8Array(this.asRawBytes().buffer); - } else { - throwError("Unsupported data type " + this.dtype); - } - } - }; - - TVMModule.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMModFree(this.handle)); - this.handle = 0; - } - }, - /** - * Get function from the module. - * @param {string} name The name of the function. - * @return {Function} The correspondin function. - */ - getFunction : function(name) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMModGetFunction(this.handle, name, 0, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle == 0) { - throwError("Module has no function " + name); - } - return makeTVMFunction(out_handle); - }, - /** - * Add module to the import list of current one. - * @param {tvm.TVMModule} mod The other module to be imported. - */ - import_module : function(mod) { - CHECK(mod instanceof TVMModule, "mod must be instance of TVMModule"); - TVM_CALL(TVMModImport(this.handle, mod.handle)); - } - }; - //----------------------------------------- - // Static variables. - // ---------------------------------------- - /** Float32 type */ - this.float32 = "float32"; - /** Int32 type */ - this.int32 = "int32"; - } - /** - * Create a TVM runtime given emscripten module. - * @property {string} create - * @memberof tvm_runtime - * @param Module The emscripten module. - * @return {tvm.TVMRuntime} The created TVM runtime. - */ - this.create = function(Module) { - var tvm = {}; - tvm.Module = Module; - if (typeof Module.addFunction !== "undefined") { - tvm.Runtime = Module; - } else { - tvm.Runtime = Module.Runtime; - } - TVMRuntime.apply(tvm); - return tvm; - }; -}).apply(tvm_runtime); - -// export things in node -if (typeof module !== "undefined" && module.exports) { - module.exports = tvm_runtime; -} diff --git a/web/web_runtime.cc b/web/web_runtime.cc deleted file mode 100644 index 701ded76288e..000000000000 --- a/web/web_runtime.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file web_runtime.cc - */ -#include -#include - -#include "../src/runtime/c_runtime_api.cc" -#include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" -#include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" -#include "../src/runtime/module.cc" -#include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" -#include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" -#include "../src/runtime/rpc/rpc_session.cc" -#include "../src/runtime/rpc/rpc_event_impl.cc" -#include "../src/runtime/rpc/rpc_server_env.cc" -#include "../src/runtime/graph/graph_runtime.cc" -#include "../src/runtime/opengl/opengl_device_api.cc" -#include "../src/runtime/opengl/opengl_module.cc" - -namespace tvm { -namespace contrib { - -struct RPCEnv { - public: - RPCEnv() { - base_ = "/rpc"; - mkdir(&base_[0], 0777); - } - // Get Path. - std::string GetPath(const std::string& file_name) { - return base_ + "/" + file_name; - } - - private: - std::string base_; -}; - -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") -.set_body_typed([](std::string path) { - static RPCEnv env; - return env.GetPath(path); - }); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") -.set_body_typed([](std::string path) { - std::string file_name = "/rpc/" + path; - LOG(INFO) << "Load module from " << file_name << " ..."; - return Module::LoadFromFile(file_name, ""); - }); -} // namespace contrib -} // namespace tvm - -// dummy parallel runtime -int TVMBackendParallelLaunch( - FTVMParallelLambda flambda, - void* cdata, - int num_task) { - TVMAPISetLastError("Parallel is not supported in Web runtime"); - return -1; -} - -int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { - return 0; -}