diff --git a/docs/api/python/contrib.rst b/docs/api/python/contrib.rst
index b482d30515d4..8ac4e1ff7d3a 100644
--- a/docs/api/python/contrib.rst
+++ b/docs/api/python/contrib.rst
@@ -48,9 +48,9 @@ tvm.contrib.dlpack
.. automodule:: tvm.contrib.dlpack
:members:
-tvm.contrib.emscripten
-~~~~~~~~~~~~~~~~~~~~~~
-.. automodule:: tvm.contrib.emscripten
+tvm.contrib.emcc
+~~~~~~~~~~~~~~~~
+.. automodule:: tvm.contrib.emcc
:members:
tvm.contrib.miopen
diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py
index 0d1a4e214791..de8f7b565c09 100644
--- a/python/tvm/_ffi/libinfo.py
+++ b/python/tvm/_ffi/libinfo.py
@@ -88,6 +88,9 @@ def find_lib_path(name=None, search_path=None, optional=False):
dll_path.append(install_lib_dir)
+ if os.path.isdir(source_dir):
+ dll_path.append(os.path.join(source_dir, "web", "dist", "wasm"))
+
dll_path = [os.path.realpath(x) for x in dll_path]
if search_path is not None:
if isinstance(search_path, list):
@@ -154,6 +157,7 @@ def find_include_path(name=None, search_path=None, optional=False):
ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
source_dir = os.path.join(ffi_dir, "..", "..", "..")
install_include_dir = os.path.join(ffi_dir, "..", "..", "..", "..")
+
third_party_dir = os.path.join(source_dir, "3rdparty")
header_path = []
diff --git a/python/tvm/contrib/emscripten.py b/python/tvm/contrib/emcc.py
similarity index 65%
rename from python/tvm/contrib/emscripten.py
rename to python/tvm/contrib/emcc.py
index 7f31273451f7..6df205a030bc 100644
--- a/python/tvm/contrib/emscripten.py
+++ b/python/tvm/contrib/emcc.py
@@ -16,18 +16,16 @@
# under the License.
"""Util to invoke emscripten compilers in the system."""
# pylint: disable=invalid-name
-from __future__ import absolute_import as _abs
-
import subprocess
-from .._ffi.base import py_str
-from .._ffi.libinfo import find_lib_path
+from tvm._ffi.base import py_str
+from tvm._ffi.libinfo import find_lib_path
+
-def create_js(output,
- objects,
- options=None,
- side_module=False,
- cc="emcc"):
- """Create emscripten javascript library.
+def create_tvmjs_wasm(output,
+ objects,
+ options=None,
+ cc="emcc"):
+ """Create wasm that is supposed to run with the tvmjs.
Parameters
----------
@@ -44,25 +42,27 @@ def create_js(output,
The compile string.
"""
cmd = [cc]
- cmd += ["-Oz"]
- if not side_module:
- cmd += ["-s", "RESERVED_FUNCTION_POINTERS=2"]
- cmd += ["-s", "NO_EXIT_RUNTIME=1"]
- extra_methods = ['cwrap', 'getValue', 'setValue', 'addFunction']
- cfg = "[" + (','.join("\'%s\'" % x for x in extra_methods)) + "]"
- cmd += ["-s", "EXTRA_EXPORTED_RUNTIME_METHODS=" + cfg]
- else:
- cmd += ["-s", "SIDE_MODULE=1"]
- cmd += ["-o", output]
+ cmd += ["-O3"]
+
+ cmd += ["-std=c++14"]
+ cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"]
+ cmd += ["-s", "STANDALONE_WASM=1"]
+ cmd += ["-s", "ALLOW_MEMORY_GROWTH=1"]
+
+
objects = [objects] if isinstance(objects, str) else objects
+
with_runtime = False
for obj in objects:
- if obj.find("libtvm_web_runtime.bc") != -1:
+ if obj.find("wasm_runtime.bc") != -1:
with_runtime = True
- if not with_runtime and not side_module:
- objects += [find_lib_path("libtvm_web_runtime.bc")[0]]
+ if not with_runtime:
+ objects += [find_lib_path("wasm_runtime.bc")[0]]
+ objects += [find_lib_path("tvmjs_support.bc")[0]]
+
+ cmd += ["-o", output]
cmd += objects
if options:
@@ -79,4 +79,4 @@ def create_js(output,
msg += py_str(out)
raise RuntimeError(msg)
-create_js.object_format = "bc"
+create_tvmjs_wasm.object_format = "bc"
diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py
index 4cf341335ea7..59da8fa5a3bf 100644
--- a/python/tvm/exec/rpc_proxy.py
+++ b/python/tvm/exec/rpc_proxy.py
@@ -29,12 +29,11 @@
def find_example_resource():
"""Find resource examples."""
curr_path = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
- base_path = os.path.join(curr_path, "../../../")
- index_page = os.path.join(base_path, "web/example_rpc.html")
+ base_path = os.path.abspath(os.path.join(curr_path, "..", "..", ".."))
+ index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html")
js_files = [
- os.path.join(base_path, "web/tvm_runtime.js"),
- os.path.join(base_path, "build/libtvm_web_runtime.js"),
- os.path.join(base_path, "build/libtvm_web_runtime.js.mem")
+ os.path.join(base_path, "web/dist/tvmjs.bundle.js"),
+ os.path.join(base_path, "web/dist/wasm/tvmjs_runtime.wasi.js")
]
for fname in [index_page] + js_files:
if not os.path.exists(fname):
@@ -69,7 +68,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--host', type=str, default="0.0.0.0",
+ parser.add_argument('--host', type=str, default="localhost",
help='the hostname of the server')
parser.add_argument('--port', type=int, default=9090,
help='The port of the RPC')
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index 04d6c94e10e8..da3a456dafb6 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -36,6 +36,7 @@
"scala",
"java",
"go",
+ "ts",
"sh",
"py",
"pyi",
@@ -81,6 +82,7 @@
# List of file names allowed
ALLOW_FILE_NAME = {
".gitignore",
+ ".eslintignore",
".gitattributes",
"README",
"Makefile",
@@ -107,8 +109,7 @@
"rust/runtime/tests/test_wasm32/.cargo/config",
"apps/sgx/.cargo/config",
# html for demo purposes
- "tests/webgl/test_static_webgl_library.html",
- "web/example_rpc.html",
+ "web/apps/browser/rpc_server.html",
# images are normally not allowed
# discuss with committers before add more images
"apps/android_rpc/app/src/main/res/mipmap-hdpi/ic_launcher.png",
diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes
index 5421d22a08aa..0714850287f3 100644
--- a/tests/lint/rat-excludes
+++ b/tests/lint/rat-excludes
@@ -28,6 +28,8 @@ core.cpp
build
_static
_build
+node_modules
+dist
.*~
\#..*\#
\.#.*
@@ -40,6 +42,7 @@ RelayVisitor.py
# Specific files
package-list
MANIFEST
+.eslintignore
.gitignore
.gitattributes
.gitmodules
diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh
index 819961dc6ebd..41006f41f754 100755
--- a/tests/scripts/task_python_docs.sh
+++ b/tests/scripts/task_python_docs.sh
@@ -43,9 +43,6 @@ cd ..
make doc
rm -f docs/doxygen/html/*.map docs/doxygen/html/*.md5
-# JS doc
-jsdoc -c web/.jsdoc_conf.json web/tvm_runtime.js web/README.md
-
# Java doc
make javadoc
@@ -54,7 +51,6 @@ rm -rf _docs
mv docs/_build/html _docs
rm -f _docs/.buildinfo
mv docs/doxygen/html _docs/doxygen
-mv out _docs/jsdoc
mv jvm/core/target/site/apidocs _docs/javadoc
echo "Start creating the docs tarball.."
diff --git a/tests/web/test_packed_func.js b/tests/web/test_packed_func.js
deleted file mode 100644
index d239f7346e74..000000000000
--- a/tests/web/test_packed_func.js
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-// Load Emscripten Module, need to change path to root/build
-const path = require("path");
-process.chdir(path.join(__dirname, "../../build"));
-var Module = require("../../build/libtvm_web_runtime.js");
-// Bootstrap TVMruntime with emscripten module.
-const tvm_runtime = require("../../web/tvm_runtime.js");
-const tvm = tvm_runtime.create(Module);
-
-function testGetGlobal() {
- var targs = [10, 10.0, "hello"]
- tvm.registerFunc("my_packed_func", function () {
- tvm.assert(Array.from(arguments).toString() == targs, "assert fail");
- return 10
- });
- var f = tvm.getGlobalFunc("my_packed_func")
- tvm.assert(tvm.isPackedFunc(f));
- y = f.apply(null, targs);
- tvm.assert(y == 10);
- f.release();
-}
-
-
-function testReturnFunc() {
- function addy(y) {
- function add(x) {
- return x + y;
- }
- return add;
- }
- var myf = tvm.convertFunc(addy);
- var f = myf(10);
- tvm.assert(tvm.isPackedFunc(f));
- tvm.assert(f(11) == 21);
- myf.release();
- f.release();
-}
-
-function testByteArray() {
- var a = new Uint8Array(3);
- a[0] = 1;
- a[1] = 2;
- function myfunc(ss){
- tvm.assert(ss instanceof Uint8Array);
- tvm.assert(ss.toString() == a);
- }
- f = tvm.convertFunc(myfunc);
- f(a);
- f.release();
-}
-
-testGetGlobal();
-testReturnFunc();
-testByteArray();
diff --git a/tests/webgl/README.md b/tests/webgl/README.md
deleted file mode 100644
index 5303cc059740..000000000000
--- a/tests/webgl/README.md
+++ /dev/null
@@ -1,24 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-## Test cases for the WebGL backend
-
-Any test case with name `test_local_...` tests the C++ OpenGL backend on the
-local OS, which can be executed automatically.
-
-Any test case with name `test_remote_...` tests the WebGL backend within the
-browser, which must be run manually. See instruction within the test.
diff --git a/tests/webgl/test_local_gemm.py b/tests/webgl/test_local_gemm.py
deleted file mode 100644
index 6bd22bf0057b..000000000000
--- a/tests/webgl/test_local_gemm.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import tvm
-from tvm import te
-import numpy as np
-
-def test_local_gemm():
- if not tvm.runtime.enabled("opengl"):
- return
- if not tvm.runtime.enabled("llvm"):
- return
-
- nn = 1024
- n = te.var('n')
- n = tvm.runtime.convert(nn)
- m = n
- l = n
- A = te.placeholder((n, l), name='A', dtype='int32')
- B = te.placeholder((m, l), name='B', dtype='int32')
- k = te.reduce_axis((0, l), name='k')
- C = te.compute((n, m), lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k),
- name='CC')
-
- s = te.create_schedule(C.op)
- s[C].opengl()
- print(tvm.lower(s, [A, B, C], simple_mode=True))
-
- f = tvm.build(s, [A, B, C], "opengl", name="gemm")
- print("------opengl code------")
- print(f.imported_modules[0].get_source(fmt="gl"))
-
- ctx = tvm.opengl()
- n, m, l = nn, nn, nn
- a_np = np.random.uniform(low=0, high=10, size=(n, l)).astype(A.dtype)
- b_np = np.random.uniform(low=0, high=10, size=(m, l)).astype(B.dtype)
- a = tvm.nd.array(a_np, ctx)
- b = tvm.nd.array(b_np, ctx)
- c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
- f(a, b, c)
-
- tvm.testing.assert_allclose(c.asnumpy(), np.dot(a_np, b_np.T))
-
-if __name__ == "__main__":
- test_local_gemm()
diff --git a/tests/webgl/test_local_save_load.py b/tests/webgl/test_local_save_load.py
deleted file mode 100644
index cca68020c0c2..000000000000
--- a/tests/webgl/test_local_save_load.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import numpy as np
-import tvm
-from tvm import te
-from tvm import rpc
-from tvm.contrib import util, emscripten
-
-def test_local_save_load():
- if not tvm.runtime.enabled("opengl"):
- return
- if not tvm.runtime.enabled("llvm"):
- return
-
- n = te.var("n")
- A = te.placeholder((n,), name='A', dtype='int32')
- B = te.placeholder((n,), name='B', dtype='int32')
- C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
- s = te.create_schedule(C.op)
- s[C].opengl()
-
- f = tvm.build(s, [A, B, C], "opengl", target_host="llvm", name="myadd")
-
- ctx = tvm.opengl(0)
- n = 10
- a = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(A.dtype), ctx)
- b = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(B.dtype), ctx)
- c = tvm.nd.array(np.zeros((n), dtype=C.dtype), ctx)
- f(a, b, c)
-
- temp = util.tempdir()
- path_so = temp.relpath("myadd.so")
- f.export_library(path_so)
- f1 = tvm.runtime.load_module(path_so)
- f1(a, b, c)
- tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
-
-if __name__ == "__main__":
- test_local_save_load()
diff --git a/tests/webgl/test_local_topi_conv2d_nchw.py b/tests/webgl/test_local_topi_conv2d_nchw.py
deleted file mode 100644
index 0d9b7776096a..000000000000
--- a/tests/webgl/test_local_topi_conv2d_nchw.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Example code to do convolution.
-Copied from topi/tests/python/test_topi_conv2d_nchw.py.
-Should be removed once we fix OpenGL testing on Jenkins."""
-import os
-import numpy as np
-import tvm
-from tvm import te
-import topi
-from tvm.contrib.pickle_memoize import memoize
-from topi.util import get_const_tuple
-
-def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
- in_height = in_width = in_size
-
- A = te.placeholder((batch, in_channel, in_height, in_width), name='A')
- W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W')
- B = topi.nn.conv2d_nchw(A, W, stride, padding)
- C = topi.nn.relu(B)
-
- a_shape = get_const_tuple(A.shape)
- w_shape = get_const_tuple(W.shape)
- dtype = A.dtype
-
- @memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw")
- def get_ref_data():
- a_np = np.random.uniform(size=a_shape).astype(dtype)
- w_np = np.random.uniform(size=w_shape).astype(dtype)
- b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
- c_np = np.maximum(b_np, 0)
- return a_np, w_np, b_np, c_np
-
- a_np, w_np, b_np, c_np = get_ref_data()
-
- def check_device(device):
- ctx = tvm.context(device, 0)
- if not ctx.exist:
- print("Skip because %s is not enabled" % device)
- return
- print("Running on target: %s" % device)
- with tvm.target.create(device):
- s1 = topi.generic.schedule_conv2d_nchw([B])
- s2 = topi.generic.schedule_conv2d_nchw([C])
- a = tvm.nd.array(a_np, ctx)
- w = tvm.nd.array(w_np, ctx)
- b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
- c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
- with tvm.target.build_config(auto_unroll_max_step=1400,
- unroll_explicit=(device != "cuda")):
- func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
- func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
- func1(a, w, b)
- func2(a, w, c)
- tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
- tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
-
- for device in ['opengl']:
- check_device(device)
-
-
-def test_conv2d_nchw():
- # ResNet18 worklaods
- verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3)
- verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
- verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
- verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1)
- verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0)
- verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
- verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1)
- verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0)
- verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
- verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
- verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
- verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
- # Vgg16 workloads
- verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1)
- # Super resolution workloads
- verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2)
- verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1)
- verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1)
- verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1)
-
-if __name__ == "__main__":
- test_conv2d_nchw()
diff --git a/tests/webgl/test_local_topi_dense.py b/tests/webgl/test_local_topi_dense.py
deleted file mode 100644
index 60dfe1ff690f..000000000000
--- a/tests/webgl/test_local_topi_dense.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Test code for dense operator
-Copied from topi/tests/python/test_topi_dense.py.
-Should be removed once we fix OpenGL testing on Jenkins.
-"""
-import numpy as np
-import tvm
-from tvm import te
-import topi
-from topi.util import get_const_tuple
-from tvm.contrib.pickle_memoize import memoize
-
-
-def verify_dense(batch, in_dim, out_dim, use_bias=True):
- A = te.placeholder((batch, in_dim), name='A')
- B = te.placeholder((out_dim, in_dim), name='B')
- C = te.placeholder((out_dim,), name='C')
- D = topi.nn.dense(A, B, C if use_bias else None)
- D = topi.nn.relu(D)
- dtype = A.dtype
-
- # use memoize to pickle the test data for next time use
- @memoize("topi.tests.test_topi_dense")
- def get_ref_data():
- a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
- b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
- c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
- if use_bias:
- d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
- else:
- d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
- return (a_np, b_np, c_np, d_np)
- # get the test data
- a_np, b_np, c_np, d_np = get_ref_data()
-
- def check_device(device):
- if not tvm.runtime.enabled(device):
- print("Skip because %s is not enabled" % device)
- return
- print("Running on target: %s" % device)
- with tvm.target.create(device):
- s = topi.generic.schedule_dense(D)
- ctx = tvm.context(device, 0)
- a = tvm.nd.array(a_np, ctx)
- b = tvm.nd.array(b_np, ctx)
- c = tvm.nd.array(c_np, ctx)
- d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
- f = tvm.build(s, [A, B, C, D], device, name="dense")
- f(a, b, c, d)
- tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
-
- for device in ['opengl']:
- check_device(device)
-
-def test_dense():
- verify_dense(1, 1024, 1000, use_bias=True)
- verify_dense(1, 1024, 1000, use_bias=False)
-
-
-if __name__ == "__main__":
- test_dense()
diff --git a/tests/webgl/test_local_topi_pooling.py b/tests/webgl/test_local_topi_pooling.py
deleted file mode 100644
index 3adae7bba51c..000000000000
--- a/tests/webgl/test_local_topi_pooling.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Test code for pooling
-Copied from topi/tests/python/test_topi_pooling.py.
-Should be removed once we fix OpenGL testing on Jenkins.
-"""
-import numpy as np
-import tvm
-from tvm import te
-import topi
-import math
-from topi.util import get_const_tuple
-
-def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
- iw = ih
- kw = kh
- sw = sh
- ph, pw = padding
- A = te.placeholder((n, ic, ih, iw), name='A')
- B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
- pool_type=pool_type, ceil_mode=ceil_mode)
- B = topi.nn.relu(B)
- dtype = A.dtype
-
- bshape = get_const_tuple(B.shape)
- ashape = get_const_tuple(A.shape)
- if ceil_mode:
- assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1)
- assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1)
- else:
- assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1)
- assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1)
-
-
- a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)
- pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype)
- no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw)))
- pad_np[np.ix_(*no_zero)] = a_np
- _, oc, oh, ow = get_const_tuple(B.shape)
- b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
-
- if pool_type == 'avg':
- for i in range(oh):
- for j in range(ow):
- b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
- elif pool_type =='max':
- for i in range(oh):
- for j in range(ow):
- b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
- b_np = np.maximum(b_np, 0.0)
-
- def check_device(device):
- if not tvm.runtime.enabled(device):
- print("Skip because %s is not enabled" % device)
- return
- print("Running on target: %s" % device)
- with tvm.target.create(device):
- s = topi.generic.schedule_pool(B)
- ctx = tvm.context(device, 0)
- a = tvm.nd.array(a_np, ctx)
- b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
- print(tvm.lower(s, [A, B], simple_mode=True))
-
- f = tvm.build(s, [A, B], device)
- f(a, b)
- tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-
- for device in ['opengl']:
- check_device(device)
-
-def test_pool():
- verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False)
- verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False)
- verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False)
- verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False)
- verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True)
-
-
-
-def verify_global_pool(n, c, h, w, pool_type):
- A = te.placeholder((n, c, h, w), name='A')
- B = topi.nn.global_pool(A, pool_type=pool_type)
- B = topi.nn.relu(B)
-
- a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
- if pool_type == 'avg':
- b_np = np.mean(a_np, axis=(2,3), keepdims=True)
- elif pool_type =='max':
- b_np = np.max(a_np, axis=(2,3), keepdims=True)
- b_np = np.maximum(b_np, 0.0)
-
- def check_device(device):
- if not tvm.runtime.enabled(device):
- print("Skip because %s is not enabled" % device)
- return
- print("Running on target: %s" % device)
- with tvm.target.create(device):
- s = topi.generic.schedule_global_pool(B)
- ctx = tvm.context(device, 0)
- a = tvm.nd.array(a_np, ctx)
- b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
- f = tvm.build(s, [A, B], device)
- f(a, b)
- tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-
- for device in ['opengl']:
- check_device(device)
-
-def test_global_pool():
- verify_global_pool(1, 1024, 7, 7, 'avg')
- verify_global_pool(4, 1024, 7, 7, 'avg')
- verify_global_pool(1, 1024, 7, 7, 'max')
- verify_global_pool(4, 1024, 7, 7, 'max')
-
-
-if __name__ == "__main__":
- test_pool()
- test_global_pool()
diff --git a/tests/webgl/test_local_topi_softmax.py b/tests/webgl/test_local_topi_softmax.py
deleted file mode 100644
index c0ddbf21419a..000000000000
--- a/tests/webgl/test_local_topi_softmax.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Test code for softmax
-Copied from topi/tests/python/test_topi_softmax.py.
-Should be removed once we fix OpenGL testing on Jenkins.
-"""
-
-import os
-import numpy as np
-import tvm
-from tvm import te
-import topi
-import logging
-from topi.util import get_const_tuple
-
-def verify_softmax(m, n):
- A = te.placeholder((m, n), name='A')
- B = topi.nn.softmax(A)
- # confirm lower works
- s = te.create_schedule([B.op])
- tvm.lower(s, [A, B], simple_mode=True)
-
- a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
- b_np = topi.testing.softmax_python(a_np)
-
- def check_device(device):
- if not tvm.runtime.enabled(device):
- print("Skip because %s is not enabled" % device)
- return
- print("Running on target: %s" % device)
- with tvm.target.create(device):
- s = topi.generic.schedule_softmax(B)
- ctx = tvm.context(device, 0)
- a = tvm.nd.array(a_np, ctx)
- b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
- foo = tvm.build(s, [A, B], device, name="softmax")
- foo(a, b)
- tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-
- for device in ["opengl"]:
- check_device(device)
-
-def test_softmax():
- verify_softmax(32, 10)
- verify_softmax(3, 4)
-
-
-def verify_log_softmax(m, n):
- A = te.placeholder((m, n), name='A')
- B = topi.nn.log_softmax(A)
- # confirm lower works
- s = te.create_schedule([B.op])
- tvm.lower(s, [A, B], simple_mode=True)
- a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
- b_np = topi.testing.log_softmax_python(a_np)
-
- def check_device(device):
- if not tvm.runtime.enabled(device):
- print("Skip because %s is not enabled" % device)
- return
- print("Running on target: %s" % device)
- with tvm.target.create(device):
- s = topi.generic.schedule_softmax(B)
- ctx = tvm.context(device, 0)
- a = tvm.nd.array(a_np, ctx)
- b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
- foo = tvm.build(s, [A, B], device, name="log_softmax")
- foo(a, b)
- tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-
- for device in ["opengl"]:
- check_device(device)
-
-
-def test_log_softmax():
- verify_log_softmax(32, 10)
- verify_log_softmax(3, 4)
-
-if __name__ == "__main__":
- logging.basicConfig(level=logging.DEBUG)
- test_softmax()
- test_log_softmax()
diff --git a/tests/webgl/test_remote_save_load.py b/tests/webgl/test_remote_save_load.py
deleted file mode 100644
index 34bbb3fa0f00..000000000000
--- a/tests/webgl/test_remote_save_load.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-The following instruction is based on web/README.md.
-
-Setup an RPC server:
-$ python -m tvm.exec.rpc_proxy --example-rpc=1
-
-Go to http://localhost:9190 in browser.
-
-Click "Connect To Proxy".
-
-Run this test script:
-$ python tests/webgl/test_remote_save_load.py
-"""
-
-import numpy as np
-import tvm
-from tvm import te
-from tvm import rpc
-from tvm.contrib import util, emscripten
-
-proxy_host = "localhost"
-proxy_port = 9090
-
-def try_remote_save_load():
- if not tvm.runtime.enabled("rpc"):
- return
- if not tvm.runtime.enabled("opengl"):
- return
- if not tvm.runtime.enabled("llvm"):
- return
-
- # Build the module.
- n = te.var("n")
- A = te.placeholder((n,), name='A')
- B = te.placeholder((n,), name='B')
- C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
- s = te.create_schedule(C.op)
- s[C].opengl()
- target_host = "llvm -target=asmjs-unknown-emscripten -system-lib"
- f = tvm.build(s, [A, B, C], "opengl", target_host=target_host, name="myadd")
-
- remote = rpc.connect(proxy_host, proxy_port, key="js")
-
- temp = util.tempdir()
- ctx = remote.opengl(0)
- path_obj = temp.relpath("myadd.bc")
- path_dso = temp.relpath("myadd.js")
- path_gl = temp.relpath("myadd.gl")
- path_json = temp.relpath("myadd.tvm_meta.json")
-
- f.save(path_obj)
- emscripten.create_js(path_dso, path_obj, side_module=True)
- f.imported_modules[0].save(path_gl)
-
- remote.upload(path_dso, "myadd.dso")
- remote.upload(path_gl)
- remote.upload(path_json)
-
- remote.download("myadd.dso")
- remote.download("myadd.gl")
- remote.download("myadd.tvm_meta.json")
-
- print('Loading myadd.dso')
- fhost = remote.load_module("myadd.dso")
-
- print('Loading myadd.gl')
- fdev = remote.load_module("myadd.gl")
-
- print('import_module')
- fhost.import_module(fdev)
-
- print('running...')
- a = tvm.nd.array(np.random.uniform(size=16).astype(A.dtype), ctx)
- b = tvm.nd.array(np.zeros(16, dtype=A.dtype), ctx)
- c = tvm.nd.array(np.zeros(16, dtype=C.dtype), ctx)
- fhost(a, b, c)
- tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
-
-if __name__ == "__main__":
- try_remote_save_load()
diff --git a/tests/webgl/test_static_webgl_library.html b/tests/webgl/test_static_webgl_library.html
deleted file mode 100644
index f9268c65edf3..000000000000
--- a/tests/webgl/test_static_webgl_library.html
+++ /dev/null
@@ -1,72 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- TVM RPC Test Page
-
-
-
-
TVM Test Page
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/tests/webgl/test_static_webgl_library.py b/tests/webgl/test_static_webgl_library.py
deleted file mode 100644
index 929da4ca294c..000000000000
--- a/tests/webgl/test_static_webgl_library.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Create a static WebGL library and run it in the browser."""
-
-from __future__ import absolute_import, print_function
-
-import os, shutil, SimpleHTTPServer, SocketServer
-import tvm
-from tvm import te
-from tvm.contrib import emscripten, util
-import numpy as np
-
-def try_static_webgl_library():
- curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-
- # Change to lib/ which contains "libtvm_runtime.bc".
- os.chdir(os.path.join(curr_path, "../../lib"))
-
- # Create OpenGL module.
- n = te.var("n")
- A = te.placeholder((n,), name='A', dtype="float")
- B = te.compute((n,), lambda *i: A[i], name="B")
-
- s = te.create_schedule(B.op)
- s[B].opengl()
-
- target_host = "llvm -target=asmjs-unknown-emscripten -system-lib"
- f = tvm.build(s, [A, B], name="identity", target="opengl",
- target_host=target_host)
-
- # Create a JS library that contains both the module and the tvm runtime.
- path_dso = "identity_static.js"
- f.export_library(path_dso, emscripten.create_js, options=[
- "-s", "USE_GLFW=3",
- "-s", "USE_WEBGL2=1",
- "-lglfw",
- ])
-
- # Create "tvm_runtime.js" and "identity_static.html" in lib/
- shutil.copyfile(os.path.join(curr_path, "../../web/tvm_runtime.js"),
- "tvm_runtime.js")
- shutil.copyfile(os.path.join(curr_path, "test_static_webgl_library.html"),
- "identity_static.html")
-
- port = 8080
- handler = SimpleHTTPServer.SimpleHTTPRequestHandler
- httpd = SocketServer.TCPServer(("", port), handler)
- print("Please open http://localhost:" + str(port) + "/identity_static.html")
- httpd.serve_forever()
-
-if __name__ == "__main__":
- try_static_webgl_library()
diff --git a/topi/tests/python/test_topi_conv2d_nhwc_winograd.py b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py
index 45f0599bae4d..a7e55320d6dd 100644
--- a/topi/tests/python/test_topi_conv2d_nhwc_winograd.py
+++ b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py
@@ -137,7 +137,8 @@ def test_conv2d_nhwc_winograd_direct():
def test_conv2d_nhwc_winograd_tensorcore():
"""Test the conv2d with winograd for nhwc layout"""
- print("test_winograd_tensorcore...")
+ if not nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+ return
verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1, bgemm="tensorcore")
verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1, bgemm="tensorcore")
verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1, bgemm="tensorcore")
@@ -145,8 +146,7 @@ def test_conv2d_nhwc_winograd_tensorcore():
verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, (1, 1), add_relu=True, bgemm="tensorcore")
verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, "SAME", add_relu=True, bgemm="tensorcore")
+
if __name__ == "__main__":
test_conv2d_nhwc_winograd_direct()
-
- if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
- test_conv2d_nhwc_winograd_tensorcore()
+ test_conv2d_nhwc_winograd_tensorcore()
diff --git a/web/.eslintignore b/web/.eslintignore
new file mode 100644
index 000000000000..1521c8b7652b
--- /dev/null
+++ b/web/.eslintignore
@@ -0,0 +1 @@
+dist
diff --git a/web/.gitignore b/web/.gitignore
new file mode 100644
index 000000000000..a3135cf24b9d
--- /dev/null
+++ b/web/.gitignore
@@ -0,0 +1,6 @@
+.vscode
+*~
+out
+node_modules
+package-lock.json
+build
diff --git a/web/.jsdoc_conf.json b/web/.jsdoc_conf.json
deleted file mode 100644
index 33783b3bbb21..000000000000
--- a/web/.jsdoc_conf.json
+++ /dev/null
@@ -1,7 +0,0 @@
-{
- "templates": {
- "default": {
- "includeDate": false
- }
- }
-}
diff --git a/web/Makefile b/web/Makefile
new file mode 100644
index 000000000000..be7fa193c04c
--- /dev/null
+++ b/web/Makefile
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+TVM_ROOT=$(shell cd ..; pwd)
+
+INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\
+ -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include
+
+.PHONY: clean all
+
+all: dist/wasm/tvmjs_runtime.wasm dist/wasm/tvmjs_runtime.wasi.js
+
+EMCC = emcc
+
+EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++14 -Wno-ignored-attributes \
+ -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1 -s ERROR_ON_UNDEFINED_SYMBOLS=0
+
+EMCC_LDFLAGS = --pre-js emcc/preload.js
+
+dist/wasm/%.bc: emcc/%.cc
+ @mkdir -p $(@D)
+ $(EMCC) $(EMCC_CFLAGS) -c -MM -MT dist/wasm/$*.bc $< >dist/wasm/$*.d
+ $(EMCC) $(EMCC_CFLAGS) -c -o dist/wasm/$*.bc $<
+
+
+dist/wasm/tvmjs_runtime.wasm: dist/wasm/wasm_runtime.bc dist/wasm/tvmjs_support.bc
+ @mkdir -p $(@D)
+ $(EMCC) $(EMCC_CFLAGS) -o dist/wasm/tvmjs_runtime.js $+ $(EMCC_LDFLAGS)
+
+
+dist/wasm/tvmjs_runtime.wasi.js: dist/wasm/tvmjs_runtime.wasm emcc/decorate_as_wasi.py
+ python3 emcc/decorate_as_wasi.py dist/wasm/tvmjs_runtime.js $@
+
+clean:
+ @rm -rf dist/wasm
+
+-include dist/wasm/*.d
diff --git a/web/README.md b/web/README.md
index 5dfd6917934b..66a64a3d3d37 100644
--- a/web/README.md
+++ b/web/README.md
@@ -15,163 +15,70 @@
-# TVM WebAssembly and Javascript Backend
+# TVM WebAssembly Runtime
-This folder contains TVM WebAssembly and Javascript backend through Emscripten.
+This folder contains TVM WebAssembly Runtime.
## Installation
-While the LLVM main branch support webassembly as a target. We still need a good runtime with libc and other
-system library support. Emscripten toolchain offers that nicely. The general idea is to build TVM against
-the fastcomp LLVM backend in the Emscripten project and allow us to generate ```asmjs-unknown-emscripten```
-as a backend target.
+
+The LLVM main branch support webassembly as a target, we can directly
+build TVM with LLVM mainline to generate wasm modules.
+Note that, however, we still need emscripten to compile the runtime and provide system library support.
+
+Note that so far we requires everything to be in the source and setup PYTHONPATH(instead of use setup.py install).
### Setup Emscripten
-Checkout [Emscripten Portable SDK Downloads](https://kripken.github.io/emscripten-site/docs/getting_started/downloads.html)
-to download emsdk-portable and unzip it on a local folder. Follow the installation guide from emscripten document.
-```bash
-./emsdk update
-./emsdk install latest
-./emsdk activate latest
-```
+We use emscripten to compile our runtime wasm library as well as a WASI variant that we can deploy
+to the browser environment.
-Because we need to compile against the LLVM backend of emscripten, we will need the source and llvm library.
-Which can be installed via following command.
+Follow [Emscripten](https://emscripten.org/) to download emsdk and install emcc on your local environment.
-```bash
-./emsdk install clang-incoming-64bit
-./emsdk activate clang-incoming-64bit
-```
+### Build TVM Wasm Runtime
-### Setup Environment Variable
+After the emcc is setup correctly. We can build tvm's wasm runtime by typing `make` in the web folder.
-In normal setting, we can setup the necessary environment variable with the following command.
```bash
-source /path-to-emsdk-portable/emsdk_env.sh
+make
```
-However, this will put emscripten's clang and llvm path ahead of the current system path.
-What you can do is to set the path manually, by putting emscripten's path after the PATH like the following ones.
-You can get the detailed path by type ```./emsdk activate```
-```bash
-export PATH=${PATH}:/emsdk-related-path-here
+This command will create the follow files:
+- `dist/wasm/libtvm_runtime.bc` bitcode library `tvm.contrib.emcc` will link into.
+- `dist/wasm/tvmjs_runtime.wasm` a standalone wasm runtime for testing purposes.
+- `dist/wasm/tvmjs_runtime.wasi.js` a WASI compatible library generated by emscripten that can be fed into runtime.
-```
-### Build TVM with Fastcomp LLVM
+### Build TVM Wasm JS Frontend
-To build TVM with Emscripten's Fastcomp LLVM, we can modify the LLVM_CONFIG in ```config.mk```
-to point to fastcomp's llvm-config and build TVM normally.
+Type the following command in the web folder.
```bash
-LLVM_CONFIG = /path/to/emsdk-portable/clang/fastcomp/build_incoming_64/bin/llvm-config
+npm run bundle
```
-### Build TVM Web Runtime
+This command will create the tvmjs library that we can use to interface with the wasm runtime.
-The above command gives us the TVM compiling environment. Now we need to build runtime,
-to do so, make sure we set the environment correctly as in previous section and type
-```bash
-make web
-```
+## Use TVM to Generate Wasm Library and Run it
-This will create ```build/libtvm_web_runtime.bc``` and ```build/libtvm_web_runtime.js```.
-
-## Use TVM to Generate Javascript Library
-
-The general idea is to use TVM as normally and set target to be ```llvm -target=asmjs-unknown-emscripten -system-lib```.
-
-The following code snippet from [tests/web/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/tests/web/prepare_test_libs.py) demonstrate
-the compilation process.
-
-```python
-import tvm
-from tvm import te
-from tvm.contrib import emscripten
-import os
-def prepare_test_libs(base_path):
- target = "llvm -target=asmjs-unknown-emscripten -system-lib"
- if not tvm.runtime.enabled(target):
- raise RuntimeError("Target %s is not enbaled" % target)
- n = te.var("n")
- A = te.placeholder((n,), name='A')
- B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
- s = te.create_schedule(B.op)
- fadd1 = tvm.build(s, [A, B], target, name="add_one")
- obj_path = os.path.join(base_path, "test_add_one.bc")
- fadd1.save(obj_path)
- emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path)
-
-if __name__ == "__main__":
- curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
- prepare_test_libs(os.path.join(curr_path, "../../build"))
-```
+Check code snippet in
-In this workflow, we use TVM to generate a ```.bc``` file and statically link
-that with the ```build/libtvm_web_runtime.bc```(emscripten.create_js will help you do that).
-The result js library is a library that contains both TVM runtime and the compiled function.
-
-
-## Run the Generated Library
-
-The following code snippet from [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/tests/web/test_module_load.js) demonstrate
-how to run the compiled library.
-
-```js
-// Load Emscripten Module, need to change path to root/build
-const path = require("path");
-process.chdir(path.join(__dirname, "../../build"));
-var Module = require("../../build/test_module.js");
-// Bootstrap TVMruntime with emscripten module.
-const tvm_runtime = require("../../web/tvm_runtime.js");
-const tvm = tvm_runtime.create(Module);
-
-// Load system library, the compiled function is registered in sysLib.
-var sysLib = tvm.systemLib();
-
-function randomArray(length, max) {
- return Array.apply(null, Array(length)).map(function() {
- return Math.random() * max;
- });
-}
-
-function testAddOne() {
- // grab pre-loaded function
- var faddOne = sysLib.getFunction("add_one");
- var assert = require('assert');
- tvm.assert(tvm.isPackedFunc(faddOne));
- var n = 124;
- var A = tvm.empty(n).copyFrom(randomArray(n, 1));
- var B = tvm.empty(n);
- // call the function.
- faddOne(A, B);
- AA = A.asArray(); // retrieve values in js array
- BB = B.asArray(); // retrieve values in js array
- // verify
- for (var i = 0; i < BB.length; ++i) {
- assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5);
- }
- faddOne.release();
-}
-
-testAddOne();
-sysLib.release();
-console.log("Finish verifying test_module_load");
-```
+- [tests/python/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/web/tests/pythob/prepare_test_libs.py)
+ shows how to create a wasm library that links with tvm runtime.
+ - Note that all wasm libraries have to created using the `--system-lib` option
+ - emcc.create_wasm will automatically link the runtime library `dist/wasm/libtvm_runtime.bc`
+- [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/web/tests/node/test_module_load.js) demonstrate
+ how to run the generated library through tvmjs API.
-Current example supports static linking, which is the preferred way to get more efficiency
-in javascript backend.
-## Proxy based RPC
+## Run Wasm Remotely through WebSocket RPC.
-We can now use javascript end to start an RPC server and connect to it from python side,
+We can now use js side to start an RPC server and connect to it from python side,
making the testing flow easier.
-The following is an example to reproduce this. This requires everything to be in the git source and setup PYTHONPATH(instead of use setup.py install)
-- run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
-- Open broswer, goto the server webpage click Connect to proxy.
- - Alternatively run "node web/example_rpc_node.js"
-- run "python tests/web/websock_rpc_test.py" to run the rpc client.
-
-The general idea is to use Emscripten's dynamic linking to dynamically load modules.
+The following is an example to reproduce this.
+- run `python -m tvm.exec.rpc_proxy --example-rpc=1` to start proxy.
+- Start the WebSocket RPC
+ - Browswer version: open https://localhost:8888, click connect to proxy
+ - NodeJS version: `npm run rpc`
+- run `python tests/node/websock_rpc_test.py` to run the rpc client.
diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html
new file mode 100644
index 000000000000..22907f1561d1
--- /dev/null
+++ b/web/apps/browser/rpc_server.html
@@ -0,0 +1,79 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ TVM RPC Test Page
+
+
+
+
+
+
TVM WebSocket RPC Server
+ To use this page
+
+
Run "make" and "npm run bundle" to create the libraries.
+
+ run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
+
+
Click Connect to proxy.
+
run "python tests/python/websock_rpc_test.py" to run the rpc client.
+
+
+
Options
+ Proxy URL
+ RPC Server Key
+
+
+
+
+
+
diff --git a/web/apps/node/example.js b/web/apps/node/example.js
new file mode 100644
index 000000000000..f81a9c903e5d
--- /dev/null
+++ b/web/apps/node/example.js
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/**
+ * Example code to start the runtime.
+ */
+const path = require("path");
+const fs = require("fs");
+const tvmjs = require("../../dist");
+
+const wasmPath = tvmjs.wasmPath();
+const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js"));
+const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm"));
+// Here we pass the javascript module generated by emscripten as the
+// LibraryProvider to provide WASI related libraries.
+// the async version of the API.
+tvmjs.instantiate(wasmSource, new EmccWASI())
+.then((tvm) => {
+ // List all the global functions from the runtime.
+ console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames());
+});
+
diff --git a/web/apps/node/wasi_example.js b/web/apps/node/wasi_example.js
new file mode 100644
index 000000000000..95ec2e0b1d07
--- /dev/null
+++ b/web/apps/node/wasi_example.js
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/**
+ * Example code to start the runtime.
+ */
+const { WASI } = require('wasi');
+const path = require("path");
+const fs = require("fs");
+const tvmjs = require("../../dist");
+
+const wasmPath = tvmjs.wasmPath();
+const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm"));
+
+const wasi = new WASI({ args: process.argv, env: process.env });
+// Here we pass the javascript module generated by emscripten as the
+// LibraryProvider to provide WASI related libraries.
+const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), wasi);
+
+// List all the global functions from the runtime.
+console.log("Runtime using WASI\n", tvm.listGlobalFuncNames());
diff --git a/web/example_rpc_node.js b/web/apps/node/wasi_rpc_server.js
similarity index 60%
rename from web/example_rpc_node.js
rename to web/apps/node/wasi_rpc_server.js
index 45f917a3234b..eb4c6ed52be9 100644
--- a/web/example_rpc_node.js
+++ b/web/apps/node/wasi_rpc_server.js
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -17,17 +17,20 @@
* under the License.
*/
-// Javascript RPC server example
-// Start and connect to websocket proxy.
+/**
+ * Example code to start the RPC server on nodejs using WASI
+ */
+const { WASI } = require("wasi");
+const tvmjs = require("../../dist");
+
+// Get import returns a fresh library in each call.
+const getImports = () => {
+ return new WASI({
+ args: process.argv,
+ env: process.env
+ });
+};
-// Load Emscripten Module, need to change path to root/lib
-const path = require("path");
-process.chdir(path.join(__dirname, "../lib"));
-var Module = require("../lib/libtvm_web_runtime.js");
-// Bootstrap TVMruntime with emscripten module.
-const tvm_runtime = require("../web/tvm_runtime.js");
-const tvm = tvm_runtime.create(Module);
+const proxyUrl = "ws://localhost:8888/ws";
-var websock_proxy = "ws://localhost:9190/ws";
-var num_sess = 100;
-tvm.startRPCServer(websock_proxy, "js", num_sess)
+new tvmjs.RPCServer(proxyUrl, "wasm", getImports, console.log);
diff --git a/tests/webgl/test_local_multi_stage.py b/web/emcc/decorate_as_wasi.py
similarity index 50%
rename from tests/webgl/test_local_multi_stage.py
rename to web/emcc/decorate_as_wasi.py
index 54a554b74ed9..741e33bb22ea 100644
--- a/tests/webgl/test_local_multi_stage.py
+++ b/web/emcc/decorate_as_wasi.py
@@ -14,34 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import tvm
-from tvm import te
-import numpy as np
+"""Decorate emcc generated js to a WASI compatible API."""
-def test_local_multi_stage():
- if not tvm.runtime.enabled("opengl"):
- return
- if not tvm.runtime.enabled("llvm"):
- return
+import sys
- n = te.var("n")
- A = te.placeholder((n,), name='A', dtype="int32")
- B = te.compute((n,), lambda i: A[i] + 1, name="B")
- C = te.compute((n,), lambda i: B[i] * 2, name="C")
+template_head = """
+function EmccWASI() {
+"""
- s = te.create_schedule(C.op)
- s[B].opengl()
- s[C].opengl()
+template_tail = """
+ this.Module = Module;
+ this.start = Module.wasmLibraryProvider.start;
+ this.imports = Module.wasmLibraryProvider.imports;
+ this.wasiImport = this.imports["wasi_snapshot_preview1"];
+}
- f = tvm.build(s, [A, C], "opengl", name="multi_stage")
-
- ctx = tvm.opengl(0)
- n = 10
- a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
- c = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx)
- f(a, c)
-
- tvm.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2)
+if (typeof module !== "undefined" && module.exports) {
+ module.exports = EmccWASI;
+}
+"""
if __name__ == "__main__":
- test_local_multi_stage()
+ if len(sys.argv) != 3:
+ print("Usage ")
+ result = template_head + open(sys.argv[1]).read() + template_tail
+ with open(sys.argv[2], "w") as fo:
+ fo.write(result)
diff --git a/web/emcc/preload.js b/web/emcc/preload.js
new file mode 100644
index 000000000000..882280f9cac0
--- /dev/null
+++ b/web/emcc/preload.js
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/* eslint-disable no-unused-vars */
+/**
+ * JS config used by --pre-js in emcc.
+ * Wrap module as a LibraryProvider.
+ */
+
+var __wasmLib = {};
+
+function __wasmLibInstantiateWasm(imports, successCallback) {
+ __wasmLib.imports = imports;
+ __wasmLib.successCallback = successCallback;
+}
+
+function __wasmLibStart(wasmInstance) {
+ __wasmLib.successCallback(wasmInstance);
+}
+
+__wasmLib.start = __wasmLibStart;
+
+var Module = {
+ "instantiateWasm": __wasmLibInstantiateWasm,
+ "wasmLibraryProvider": __wasmLib
+};
diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc
new file mode 100644
index 000000000000..97099e75f16f
--- /dev/null
+++ b/web/emcc/tvmjs_support.cc
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * \file tvmjs_support.cc
+ * \brief Support functions to be linked with wasm_runtime to provide
+ * PackedFunc callbacks in tvmjs.
+ * We do not need to link this file in standalone wasm.
+ */
+
+// configurations for the dmlc log.
+#define DMLC_LOG_CUSTOMIZE 0
+#define DMLC_LOG_STACK_TRACE 0
+#define DMLC_LOG_DEBUG 0
+#define DMLC_LOG_NODATE 1
+#define DMLC_LOG_FATAL_THROW 0
+
+
+#include
+#include
+#include
+#include
+#include
+
+extern "C" {
+// --- Additional C API for the Wasm runtime ---
+/*!
+ * \brief Allocate space aligned to 64 bit.
+ * \param size The size of the space.
+ * \return The allocated space.
+ */
+TVM_DLL void* TVMWasmAllocSpace(int size);
+
+/*!
+ * \brief Free the space allocated by TVMWasmAllocSpace.
+ * \param data The data pointer.
+ */
+TVM_DLL void TVMWasmFreeSpace(void* data);
+
+/*!
+ * \brief Create PackedFunc from a resource handle.
+ * \param resource_handle The handle to the resource.
+ * \param out The output PackedFunc.
+ * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer
+3A * \return 0 if success.
+ */
+TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle,
+ TVMFunctionHandle *out);
+
+// --- APIs to be implemented by the frontend. ---
+/*!
+ * \brief Wasm frontend packed function caller.
+ *
+ * \param args The arguments
+ * \param type_codes The type codes of the arguments
+ * \param num_args Number of arguments.
+ * \param ret The return value handle.
+ * \param resource_handle The handle additional resouce handle from fron-end.
+ * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
+ */
+extern int TVMWasmPackedCFunc(TVMValue* args,
+ int* type_codes,
+ int num_args,
+ TVMRetValueHandle ret,
+ void* resource_handle);
+
+/*!
+ * \brief Wasm frontend resource finalizer.
+ * \param resource_handle The pointer to the external resource.
+ */
+extern void TVMWasmPackedCFuncFinalizer(void* resource_handle);
+} // extern "C"
+
+
+void* TVMWasmAllocSpace(int size) {
+ int num_count = (size + 7) / 8;
+ return new int64_t[num_count];
+}
+
+void TVMWasmFreeSpace(void* arr) {
+ delete[] static_cast(arr);
+}
+
+int TVMWasmFuncCreateFromCFunc(void* resource_handle,
+ TVMFunctionHandle *out) {
+ return TVMFuncCreateFromCFunc(
+ TVMWasmPackedCFunc, resource_handle,
+ TVMWasmPackedCFuncFinalizer, out);
+}
+
+
+namespace tvm {
+namespace runtime {
+
+// chrono in the WASI does not provide very accurate time support
+// and also have problems in the i64 support in browser.
+// We redirect the timer to a JS side time using performance.now
+PackedFunc WrapWasmTimeEvaluator(PackedFunc pf,
+ TVMContext ctx,
+ int number,
+ int repeat,
+ int min_repeat_ms) {
+ auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](
+ TVMArgs args, TVMRetValue *rv) {
+
+ TVMRetValue temp;
+ auto finvoke = [&](int n) {
+ // start timing
+ for (int i = 0; i < n; ++i) {
+ pf.CallPacked(args, &temp);
+ }
+ DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
+ };
+
+ auto* get_timer = runtime::Registry::Get("wasm.GetTimer");
+ CHECK(get_timer != nullptr) << "Cannot find wasm.GetTimer in the global function";
+ TypedPackedFunc timer_ms = (*get_timer)(
+ TypedPackedFunc(finvoke));
+
+ std::ostringstream os;
+ finvoke(1);
+
+ int setup_number = number;
+
+ for (int i = 0; i < repeat; ++i) {
+ double duration_ms = 0.0;
+
+ do {
+ if (duration_ms > 0.0) {
+ setup_number = static_cast(
+ std::max((min_repeat_ms / (duration_ms / number) + 1),
+ number * 1.618)); // 1.618 is chosen by random
+ }
+ duration_ms = timer_ms(setup_number);
+ } while (duration_ms < min_repeat_ms);
+
+ double speed = duration_ms / setup_number / 1000;
+ os.write(reinterpret_cast(&speed), sizeof(speed));
+ }
+
+ std::string blob = os.str();
+ TVMByteArray arr;
+ arr.size = blob.length();
+ arr.data = blob.data();
+ // return the time.
+ *rv = arr;
+ };
+ return PackedFunc(ftimer);
+}
+
+TVM_REGISTER_GLOBAL("wasm.RPCTimeEvaluator")
+.set_body_typed([](Optional opt_mod,
+ std::string name,
+ int device_type,
+ int device_id,
+ int number,
+ int repeat,
+ int min_repeat_ms) {
+ TVMContext ctx;
+ ctx.device_type = static_cast(device_type);
+ ctx.device_id = device_id;
+
+ if (opt_mod.defined()) {
+ Module m = opt_mod.value();
+ std::string tkey = m->type_key();
+ return WrapWasmTimeEvaluator(
+ m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms);
+ } else {
+ auto* pf = runtime::Registry::Get(name);
+ CHECK(pf != nullptr) << "Cannot find " << name << " in the global function";
+ return WrapWasmTimeEvaluator(
+ *pf, ctx, number, repeat, min_repeat_ms);
+ }
+});
+
+} // namespace runtime
+} // namespace tvm
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
new file mode 100644
index 000000000000..6ff652cf408a
--- /dev/null
+++ b/web/emcc/wasm_runtime.cc
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * \file wasm_runtime.cc
+ * \brief TVM wasm runtime library pack.
+ */
+
+// configurations for the dmlc log.
+#define DMLC_LOG_CUSTOMIZE 0
+#define DMLC_LOG_STACK_TRACE 0
+#define DMLC_LOG_DEBUG 0
+#define DMLC_LOG_NODATE 1
+#define DMLC_LOG_FATAL_THROW 0
+
+#include
+#include
+
+#include "src/runtime/c_runtime_api.cc"
+#include "src/runtime/cpu_device_api.cc"
+#include "src/runtime/workspace_pool.cc"
+#include "src/runtime/library_module.cc"
+#include "src/runtime/system_library.cc"
+
+#include "src/runtime/module.cc"
+#include "src/runtime/ndarray.cc"
+#include "src/runtime/object.cc"
+#include "src/runtime/registry.cc"
+#include "src/runtime/file_util.cc"
+#include "src/runtime/graph/graph_runtime.cc"
+#include "src/runtime/rpc/rpc_session.cc"
+#include "src/runtime/rpc/rpc_endpoint.cc"
+#include "src/runtime/rpc/rpc_event_impl.cc"
+#include "src/runtime/rpc/rpc_channel.cc"
+#include "src/runtime/rpc/rpc_local_session.cc"
+#include "src/runtime/rpc/rpc_module.cc"
+
+
+// --- Implementations of backend and wasm runtime API. ---
+
+int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
+ void* cdata,
+ int num_task) {
+ TVMParallelGroupEnv env;
+ env.num_task = 1;
+ flambda(0, &env, cdata);
+ return 0;
+}
+
+int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
+ return 0;
+}
+
+// --- Environment PackedFuncs for testing ---
+namespace tvm {
+namespace runtime {
+
+TVM_REGISTER_GLOBAL("testing.echo")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+ *ret = args[0];
+});
+
+TVM_REGISTER_GLOBAL("testing.add_one")
+.set_body_typed([](int x) {
+ return x + 1;
+});
+
+TVM_REGISTER_GLOBAL("testing.wrap_callback")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+ PackedFunc pf = args[0];
+ *ret = runtime::TypedPackedFunc([pf](){
+ pf();
+ });
+ });
+} // namespace runtime
+} // namespace tvm
diff --git a/web/example_rpc.html b/web/example_rpc.html
deleted file mode 100644
index ae2b1dd9c44b..000000000000
--- a/web/example_rpc.html
+++ /dev/null
@@ -1,61 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- TVM RPC Test Page
-
-
-
-
-
TVM Test Page
- To use this page, the easiest way is to do
-
-
run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
-
Click Connect to proxy.
-
run "python tests/web/websock_rpc_test.py" to run the rpc client.
-
-
Options
- Proxy URL
- RPC Server Key
-
-
-
-
-
-
-
diff --git a/web/package.json b/web/package.json
new file mode 100644
index 000000000000..76aa111e2acf
--- /dev/null
+++ b/web/package.json
@@ -0,0 +1,29 @@
+{
+ "name": "tvmjs",
+ "displayName": "TVM Wasm JS runtime",
+ "license": "Apache-2.0",
+ "version": "0.7.0",
+ "scripts": {
+ "build": "tsc -b",
+ "watch": "tsc -b -w",
+ "lint": "eslint -c .eslintrc.json .",
+ "bundle": "npm run build && rollup -c rollup.config.js",
+ "example": "npm run bundle && node apps/node/example.js",
+ "example:wasi": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_example.js",
+ "rpc": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_rpc_server.js"
+ },
+ "devDependencies": {
+ "typescript": "^3.8.3",
+ "@types/node": "^12.12.37",
+ "eslint": "^6.8.0",
+ "@typescript-eslint/eslint-plugin": "^2.29.0",
+ "@typescript-eslint/parser": "^2.29.0",
+ "typedoc": "^0.17.6",
+ "rollup": "^2.7.6",
+ "ws": "^7.2.5",
+ "@rollup/plugin-commonjs": "^11.1.0",
+ "@rollup/plugin-node-resolve": "^7.1.3",
+ "rollup-plugin-typescript2": "^0.27.0"
+ },
+ "dependencies": {}
+}
diff --git a/web/.eslintrc.js b/web/rollup.config.js
similarity index 69%
rename from web/.eslintrc.js
rename to web/rollup.config.js
index 2e82ba50e3c4..0046e4434076 100644
--- a/web/.eslintrc.js
+++ b/web/rollup.config.js
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -17,29 +17,18 @@
* under the License.
*/
-module.exports = {
- "env": {
- "browser": true,
- "node": true,
- "es6": true
+import commonjs from '@rollup/plugin-commonjs';
+import resolve from '@rollup/plugin-node-resolve';
+
+export default {
+ input: 'dist/index.js',
+ output: {
+ file: 'dist/tvmjs.bundle.js',
+ format: 'umd',
+ name: 'tvmjs',
+ exports: 'named',
+ globals: {'ws': 'ws'}
},
- "extends": "eslint:recommended",
- "rules": {
- "indent": [
- "error",
- 2
- ],
- "linebreak-style": [
- "error",
- "unix"
- ],
- "quotes": [
- "error",
- "double"
- ],
- "semi": [
- "error",
- "always"
- ]
- }
+ plugins: [commonjs(), resolve()],
+ external: ['ws']
};
diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts
new file mode 100644
index 000000000000..f533b4e491a6
--- /dev/null
+++ b/web/src/ctypes.ts
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Types for C API.
+ */
+
+/** A pointer to points to the raw address space. */
+export type Pointer = number;
+
+/** A pointer offset, need to add a base address to get a valid ptr. */
+export type PtrOffset = number;
+
+// -- TVM runtime C API --
+/**
+ * const char *TVMGetLastError();
+ */
+export type FTVMGetLastError = () => Pointer;
+
+/**
+ * int TVMModGetFunction(TVMModuleHandle mod,
+ * const char* func_name,
+ * int query_imports,
+ * TVMFunctionHandle *out);
+ */
+export type FTVMModGetFunction = (
+ mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) => number;
+/**
+ * int TVMModImport(TVMModuleHandle mod,
+ * TVMModuleHandle dep);
+ */
+export type FTVMModImport = (mod: Pointer, dep: Pointer) => number;
+/**
+ * int TVMModFree(TVMModuleHandle mod);
+ */
+export type FTVMModFree = (mod: Pointer) => number;
+
+/**
+ * int TVMFuncFree(TVMFunctionHandle func);
+ */
+export type FTVMFuncFree = (func: Pointer) => number;
+
+/**
+ * int TVMFuncCall(TVMFunctionHandle func,
+ * TVMValue* arg_values,
+ * int* type_codes,
+ * int num_args,
+ * TVMValue* ret_val,
+ * int* ret_type_code);
+ */
+export type FTVMFuncCall = (
+ func: Pointer, argValues: Pointer, typeCode: Pointer,
+ nargs: number, retValue: Pointer, retCode: Pointer) => number;
+
+/**
+ * int TVMCFuncSetReturn(TVMRetValueHandle ret,
+ * TVMValue* value,
+ * int* type_code,
+ * int num_ret);
+ */
+export type FTVMCFuncSetReturn = (
+ ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number;
+
+/**
+ * int TVMCbArgToReturn(TVMValue* value, int* code);
+ */
+export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number;
+
+/**
+ * int TVMFuncListGlobalNames(int* outSize, const char*** outArray);
+ */
+export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) => number;
+
+/**
+ * int TVMFuncRegisterGlobal(
+ * const char* name, TVMFunctionHandle f, int override);
+ */
+export type FTVMFuncRegisterGlobal = (
+ name: Pointer, f: Pointer, override: number) => number;
+
+/**
+ *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
+ */
+export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number;
+
+/**
+ * int TVMArrayAlloc(const tvm_index_t* shape,
+ * int ndim,
+ * int dtype_code,
+ * int dtype_bits,
+ * int dtype_lanes,
+ * int device_type,
+ * int device_id,
+ * TVMArrayHandle* out);
+ */
+export type FTVMArrayAlloc = (
+ shape: Pointer, ndim: number,
+ dtypeCode: number, dtypeBits: number,
+ dtypeLanes: number, deviceType: number, deviceId: number,
+ out: Pointer) => number;
+
+/**
+ * int TVMArrayFree(TVMArrayHandle handle);
+ */
+export type FTVMArrayFree = (handle: Pointer) => number;
+
+/**
+ * int TVMArrayCopyFromBytes(TVMArrayHandle handle,
+ * void* data,
+ * size_t nbytes);
+ */
+export type FTVMArrayCopyFromBytes = (
+ handle: Pointer, data: Pointer, nbytes: number) => number;
+
+/**
+ * int TVMArrayCopyToBytes(TVMArrayHandle handle,
+ * void* data,
+ * size_t nbytes);
+ */
+export type FTVMArrayCopyToBytes = (
+ handle: Pointer, data: Pointer, nbytes: number) => number;
+
+/**
+ * int TVMArrayCopyFromTo(TVMArrayHandle from,
+ * TVMArrayHandle to,
+ * TVMStreamHandle stream);
+ */
+export type FTVMArrayCopyFromTo = (
+ from: Pointer, to: Pointer, stream: Pointer) => number;
+
+/**
+ * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream);
+ */
+export type FTVMSynchronize = (
+ deviceType: number, deviceId: number, stream: Pointer) => number;
+
+/**
+ * typedef int (*TVMBackendPackedCFunc)(TVMValue* args,
+ * int* type_codes,
+ * int num_args,
+ * TVMValue* out_ret_value,
+ * int* out_ret_tcode);
+ */
+export type FTVMBackendPackedCFunc = (
+ argValues: Pointer, argCodes: Pointer, nargs: number,
+ outValue: Pointer, outCode: Pointer) => number;
+
+// -- TVM Wasm Auxiliary C API --
+
+/** void* TVMWasmAllocSpace(int size); */
+export type FTVMWasmAllocSpace = (size: number) => Pointer;
+
+/** void TVMWasmFreeSpace(void* data); */
+export type FTVMWasmFreeSpace = (ptr: Pointer) => void;
+
+/**
+ * int TVMWasmPackedCFunc(TVMValue* args,
+ * int* type_codes,
+ * int num_args,
+ * TVMRetValueHandle ret,
+ * void* resource_handle);
+ */
+export type FTVMWasmPackedCFunc = (
+ args: Pointer, typeCodes: Pointer, nargs: number,
+ ret: Pointer, resourceHandle: Pointer) => number;
+
+/**
+ * int TVMWasmFuncCreateFromCFunc(void* resource_handle,
+ * TVMFunctionHandle *out);
+ */
+export type FTVMWasmFuncCreateFromCFunc = (
+ resource: Pointer, out: Pointer) => number;
+
+/**
+ * void TVMWasmPackedCFuncFinalizer(void* resource_handle);
+ */
+export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void;
+
+/**
+ * Size of common data types.
+ */
+export const enum SizeOf {
+ U8 = 1,
+ U16 = 2,
+ I32 = 4,
+ I64 = 8,
+ F32 = 4,
+ F64 = 8,
+ TVMValue = 8,
+ DLDataType = I32,
+ DLContext = I32 + I32,
+}
+
+/**
+ * Type code in TVM FFI.
+ */
+export const enum TypeCode {
+ Int = 0,
+ UInt = 1,
+ Float = 2,
+ TVMOpaqueHandle = 3,
+ Null = 4,
+ TVMDataType = 5,
+ TVMContext = 6,
+ TVMDLTensorHandle = 7,
+ TVMObjectHandle = 8,
+ TVMModuleHandle = 9,
+ TVMPackedFuncHandle = 10,
+ TVMStr = 11,
+ TVMBytes = 12,
+ TVMNDArrayHandle = 13,
+ TVMObjectRValueRefArg = 14
+}
\ No newline at end of file
diff --git a/web/src/environment.ts b/web/src/environment.ts
new file mode 100644
index 000000000000..df0fe68c81e0
--- /dev/null
+++ b/web/src/environment.ts
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/**
+ * Runtime environment that provide js libaries calls.
+ */
+import { Pointer } from "./ctypes";
+import { LibraryProvider } from "./types";
+import { assert } from "./support";
+import * as ctypes from "./ctypes";
+
+/**
+ * Detect library provider from the importObject.
+ *
+ * @param importObject The import object.
+ */
+function detectLibraryProvider(
+ importObject: Record
+): LibraryProvider | undefined {
+ if (
+ importObject["wasmLibraryProvider"] &&
+ importObject["wasmLibraryProvider"]["start"] &&
+ importObject["wasmLibraryProvider"]["imports"] !== undefined
+ ) {
+ const item = importObject as { wasmLibraryProvider: LibraryProvider };
+ // create provider so that we capture imports in the provider.
+ return {
+ imports: item.wasmLibraryProvider.imports,
+ start: (inst: WebAssembly.Instance): void => {
+ item.wasmLibraryProvider.start(inst);
+ },
+ };
+ } else if (importObject["imports"] && importObject["start"] !== undefined) {
+ return importObject as LibraryProvider;
+ } else if (importObject["wasiImport"] && importObject["start"] !== undefined) {
+ // WASI
+ return {
+ imports: {
+ "wasi_snapshot_preview1": importObject["wasiImport"],
+ },
+ start: (inst: WebAssembly.Instance): void => {
+ importObject["start"](inst);
+ }
+ };
+ } else {
+ return undefined;
+ }
+}
+
+/**
+ * Environment to impelement most of the JS library functions.
+ */
+export class Environment implements LibraryProvider {
+ logger: (msg: string) => void;
+ imports: Record;
+ /**
+ * Maintains a table of FTVMWasmPackedCFunc that the C part
+ * can call via TVMWasmPackedCFunc.
+ *
+ * We maintain a separate table so that we can have un-limited amount
+ * of functions that do not maps to the address space.
+ */
+ packedCFuncTable: Array = [
+ undefined,
+ ];
+ /**
+ * Free table index that can be recycled.
+ */
+ packedCFuncTableFreeId: Array = [];
+
+ private libProvider?: LibraryProvider;
+
+ constructor(
+ importObject: Record = {},
+ logger: (msg: string) => void = console.log
+ ) {
+ this.logger = logger;
+ this.libProvider = detectLibraryProvider(importObject);
+ // get imports from the provider
+ if (this.libProvider !== undefined) {
+ this.imports = this.libProvider.imports;
+ } else {
+ this.imports = importObject;
+ }
+ // update with more functions
+ this.imports.env = this.environment(this.imports.env);
+ }
+
+ /** Mark the start of the instance. */
+ start(inst: WebAssembly.Instance): void {
+ if (this.libProvider !== undefined) {
+ this.libProvider.start(inst);
+ }
+ }
+
+ private environment(initEnv: Record): Record {
+ // default env can be be overriden by libraries.
+ const defaultEnv = {
+ "__cxa_thread_atexit": (): void => {},
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ "emscripten_notify_memory_growth": (index: number): void => {}
+ };
+ const wasmPackedCFunc: ctypes.FTVMWasmPackedCFunc = (
+ args: Pointer,
+ typeCodes: Pointer,
+ nargs: number,
+ ret: Pointer,
+ resourceHandle: Pointer
+ ): number => {
+ const cfunc = this.packedCFuncTable[resourceHandle];
+ assert(cfunc !== undefined);
+ return cfunc(args, typeCodes, nargs, ret, resourceHandle);
+ };
+
+ const wasmPackedCFuncFinalizer: ctypes.FTVMWasmPackedCFuncFinalizer = (
+ resourceHandle: Pointer
+ ): void => {
+ this.packedCFuncTable[resourceHandle] = undefined;
+ this.packedCFuncTableFreeId.push(resourceHandle);
+ };
+
+ const newEnv = {
+ TVMWasmPackedCFunc: wasmPackedCFunc,
+ TVMWasmPackedCFuncFinalizer: wasmPackedCFuncFinalizer,
+ "__console_log": (msg: string): void => {
+ this.logger(msg);
+ }
+ };
+ return Object.assign(defaultEnv, initEnv, newEnv);
+ }
+}
\ No newline at end of file
diff --git a/web/src/index.ts b/web/src/index.ts
new file mode 100644
index 000000000000..5d7d7ccc39cc
--- /dev/null
+++ b/web/src/index.ts
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+export {
+ Scalar, DLContext, DLDataType,
+ PackedFunc, Module, NDArray, Instance,
+ instantiate
+} from "./runtime";
+export { Disposable, LibraryProvider } from "./types";
+export { RPCServer } from "./rpc_server";
+export { wasmPath } from "./support";
\ No newline at end of file
diff --git a/web/src/memory.ts b/web/src/memory.ts
new file mode 100644
index 000000000000..ac737b7c297d
--- /dev/null
+++ b/web/src/memory.ts
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/**
+ * Classes to manipulate Wasm memories.
+ */
+import { Pointer, PtrOffset, SizeOf } from "./ctypes";
+import { Disposable } from "./types";
+import { assert, StringToUint8Array } from "./support";
+
+import * as ctypes from "./ctypes";
+
+/**
+ * Wasm Memory wrapper to perform JS side raw memory access.
+ */
+export class Memory {
+ memory: WebAssembly.Memory;
+ wasm32 = true;
+ private buffer: ArrayBuffer | SharedArrayBuffer;
+ private viewU8: Uint8Array;
+ private viewU16: Uint16Array;
+ private viewI32: Int32Array;
+ private viewU32: Uint32Array;
+ private viewF32: Float32Array;
+ private viewF64: Float64Array;
+
+ constructor(memory: WebAssembly.Memory) {
+ this.memory = memory;
+ this.buffer = this.memory.buffer;
+ this.viewU8 = new Uint8Array(this.buffer);
+ this.viewU16 = new Uint16Array(this.buffer);
+ this.viewI32 = new Int32Array(this.buffer);
+ this.viewU32 = new Uint32Array(this.buffer);
+ this.viewF32 = new Float32Array(this.buffer);
+ this.viewF64 = new Float64Array(this.buffer);
+ }
+
+ loadU8(ptr: Pointer): number {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ return this.viewU8[ptr >> 0];
+ }
+
+ loadU16(ptr: Pointer): number {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ return this.viewU16[ptr >> 1];
+ }
+
+ loadU32(ptr: Pointer): number {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ return this.viewU32[ptr >> 2];
+ }
+
+ loadI32(ptr: Pointer): number {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ return this.viewI32[ptr >> 2];
+ }
+
+ loadI64(ptr: Pointer): number {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ const base = ptr >> 2;
+ // assumes little endian, for now truncate high.
+ return this.viewI32[base];
+ }
+
+ loadF32(ptr: Pointer): number {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ return this.viewF32[ptr >> 2];
+ }
+
+ loadF64(ptr: Pointer): number {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ return this.viewF64[ptr >> 3];
+ }
+
+ loadPointer(ptr: Pointer): Pointer {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ if (this.wasm32) {
+ return this.loadU32(ptr);
+ } else {
+ return this.loadI64(ptr);
+ }
+ }
+ loadUSize(ptr: Pointer): Pointer {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ if (this.wasm32) {
+ return this.loadU32(ptr);
+ } else {
+ return this.loadI64(ptr);
+ }
+ }
+ sizeofPtr(): number {
+ return this.wasm32 ? SizeOf.I32 : SizeOf.I64;
+ }
+ /**
+ * Load raw bytes from ptr.
+ * @param ptr The head address
+ * @param numBytes The number
+ */
+ loadRawBytes(ptr: Pointer, numBytes: number): Uint8Array {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ const result = new Uint8Array(numBytes);
+ result.set(this.viewU8.slice(ptr, ptr + numBytes));
+ return result;
+ }
+ /**
+ * Load TVMByteArray from ptr.
+ *
+ * @param ptr The address of the header.
+ */
+ loadTVMBytes(ptr: Pointer): Uint8Array {
+ const data = this.loadPointer(ptr);
+ const length = this.loadUSize(ptr + this.sizeofPtr());
+ return this.loadRawBytes(data, length);
+ }
+ /**
+ * Load null-terminated C-string from ptr.
+ * @param ptr The head address
+ */
+ loadCString(ptr: Pointer): string {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ // NOTE: the views are still valid for read.
+ const ret = [];
+ let ch = 1;
+ while (ch != 0) {
+ ch = this.viewU8[ptr];
+ if (ch != 0) {
+ ret.push(String.fromCharCode(ch));
+ }
+ ++ptr;
+ }
+ return ret.join("");
+ }
+ /**
+ * Store raw bytes to the ptr.
+ * @param ptr The head address.
+ * @param bytes The bytes content.
+ */
+ storeRawBytes(ptr: Pointer, bytes: Uint8Array): void {
+ if (this.buffer != this.memory.buffer) {
+ this.updateViews();
+ }
+ this.viewU8.set(bytes, ptr);
+ }
+
+ /**
+ * Update memory view after the memory growth.
+ */
+ private updateViews(): void {
+ this.buffer = this.memory.buffer;
+ this.viewU8 = new Uint8Array(this.buffer);
+ this.viewU16 = new Uint16Array(this.buffer);
+ this.viewI32 = new Int32Array(this.buffer);
+ this.viewU32 = new Uint32Array(this.buffer);
+ this.viewF32 = new Float32Array(this.buffer);
+ this.viewF64 = new Float64Array(this.buffer);
+ }
+}
+
+/**
+ * Auxiliary call stack for the FFI calls.
+ *
+ * Lifecyle of a call stack.
+ * - Calls into allocXX to allocate space, mixed with storeXXX to store data.
+ * - Calls into ptrFromOffset, no further allocation(as ptrFromOffset can change),
+ * can still call into storeXX
+ * - Calls into commitToWasmMemory once.
+ * - reset.
+ */
+export class CachedCallStack implements Disposable {
+ /** List of temporay arguments that can be disposed during reset. */
+ tempArgs: Array = [];
+
+ private memory: Memory;
+ private cAllocSpace: ctypes.FTVMWasmAllocSpace;
+ private cFreeSpace: ctypes.FTVMWasmFreeSpace;
+
+ private buffer: ArrayBuffer;
+ private viewU8: Uint8Array;
+ private viewI32: Int32Array;
+ private viewU32: Uint32Array;
+ private viewF64: Float64Array;
+
+ private stackTop: PtrOffset = 0;
+ private basePtr: Pointer = 0;
+
+ private addressToSetTargetValue: Array<[PtrOffset, PtrOffset]> = [];
+
+ constructor(
+ memory: Memory,
+ allocSpace: ctypes.FTVMWasmAllocSpace,
+ freeSpace: ctypes.FTVMWasmFreeSpace
+ ) {
+ const initCallStackSize = 128;
+ this.memory = memory;
+ this.cAllocSpace = allocSpace;
+ this.cFreeSpace = freeSpace;
+ this.buffer = new ArrayBuffer(initCallStackSize);
+ this.basePtr = this.cAllocSpace(initCallStackSize);
+ this.viewU8 = new Uint8Array(this.buffer);
+ this.viewI32 = new Int32Array(this.buffer);
+ this.viewU32 = new Uint32Array(this.buffer);
+ this.viewF64 = new Float64Array(this.buffer);
+ this.updateViews();
+ }
+
+ dispose(): void {
+ if (this.basePtr != 0) {
+ this.cFreeSpace(this.basePtr);
+ this.basePtr = 0;
+ }
+ }
+ /**
+ * Rest the call stack so that it can be reused again.
+ */
+ reset(): void {
+ this.stackTop = 0;
+ assert(this.addressToSetTargetValue.length == 0);
+ while (this.tempArgs.length != 0) {
+ (this.tempArgs.pop() as Disposable).dispose();
+ }
+ }
+
+ /**
+ * Commit all the cached data to WasmMemory.
+ * This function can only be called once.
+ * No further store function should be called.
+ *
+ * @param nbytes Number of bytes to be stored.
+ */
+ commitToWasmMemory(nbytes: number = this.stackTop): void {
+ // commit all pointer values.
+ while (this.addressToSetTargetValue.length != 0) {
+ const [targetOffset, valueOffset] = this.addressToSetTargetValue.pop() as [
+ number,
+ number
+ ];
+ this.storePtr(targetOffset, this.ptrFromOffset(valueOffset));
+ }
+ this.memory.storeRawBytes(this.basePtr, this.viewU8.slice(0, nbytes));
+ }
+
+ /**
+ * Allocate space by number of bytes
+ * @param nbytes Number of bytes.
+ * @note This function always allocate space that aligns to 64bit.
+ */
+ allocRawBytes(nbytes: number): PtrOffset {
+ // always aligns to 64bit
+ nbytes = ((nbytes + 7) >> 3) << 3;
+
+ if (this.stackTop + nbytes > this.buffer.byteLength) {
+ const newSize = Math.max(
+ this.buffer.byteLength * 2,
+ this.stackTop + nbytes
+ );
+ const oldU8 = this.viewU8;
+ this.buffer = new ArrayBuffer(newSize);
+ this.updateViews();
+ this.viewU8.set(oldU8);
+ if (this.basePtr != 0) {
+ this.cFreeSpace(this.basePtr);
+ }
+ this.basePtr = this.cAllocSpace(newSize);
+ }
+ const retOffset = this.stackTop;
+ this.stackTop += nbytes;
+ return retOffset;
+ }
+
+ /**
+ * Allocate space for pointers.
+ * @param count Number of pointers.
+ * @returns The allocated pointer array.
+ */
+ allocPtrArray(count: number): PtrOffset {
+ return this.allocRawBytes(this.memory.sizeofPtr() * count);
+ }
+
+ /**
+ * Get the real pointer from offset values.
+ * Note that the returned value becomes obsolete if alloc is called on the stack.
+ * @param offset The allocated offset.
+ */
+ ptrFromOffset(offset: PtrOffset): Pointer {
+ return this.basePtr + offset;
+ }
+
+ // Store APIs
+ storePtr(offset: PtrOffset, value: Pointer): void {
+ if (this.memory.wasm32) {
+ this.storeU32(offset, value);
+ } else {
+ this.storeI64(offset, value);
+ }
+ }
+
+ storeUSize(offset: PtrOffset, value: Pointer): void {
+ if (this.memory.wasm32) {
+ this.storeU32(offset, value);
+ } else {
+ this.storeI64(offset, value);
+ }
+ }
+
+ storeI32(offset: PtrOffset, value: number): void {
+ this.viewI32[offset >> 2] = value;
+ }
+
+ storeU32(offset: PtrOffset, value: number): void {
+ this.viewU32[offset >> 2] = value;
+ }
+
+ storeI64(offset: PtrOffset, value: number): void {
+ // For now, just store as 32bit
+ // NOTE: wasm always uses little endian.
+ const low = value & 0xffffffff;
+ const base = offset >> 2;
+ this.viewI32[base] = low;
+ this.viewI32[base + 1] = 0;
+ }
+
+ storeF64(offset: PtrOffset, value: number): void {
+ this.viewF64[offset >> 3] = value;
+ }
+
+ storeRawBytes(offset: PtrOffset, bytes: Uint8Array): void {
+ this.viewU8.set(bytes, offset);
+ }
+
+ /**
+ * Allocate then set C-String pointer to the offset.
+ * This function will call into allocBytes to allocate necessary data.
+ * The address won't be set immediately(because the possible change of basePtr)
+ * and will be filled when we commit the data.
+ *
+ * @param offset The offset to set ot data pointer.
+ * @param data The string content.
+ */
+ allocThenSetArgString(offset: PtrOffset, data: string): void {
+ const strOffset = this.allocRawBytes(data.length + 1);
+ this.storeRawBytes(strOffset, StringToUint8Array(data));
+ this.addressToSetTargetValue.push([offset, strOffset]);
+ }
+ /**
+ * Allocate then set the argument location with a TVMByteArray.
+ * Allocate new temporary space for bytes.
+ *
+ * @param offset The offset to set ot data pointer.
+ * @param data The string content.
+ */
+ allocThenSetArgBytes(offset: PtrOffset, data: Uint8Array): void {
+ // Note: size of size_t equals sizeof ptr.
+ const headerOffset = this.allocRawBytes(this.memory.sizeofPtr() * 2);
+ const dataOffset = this.allocRawBytes(data.length);
+ this.storeRawBytes(dataOffset, data);
+ this.storeUSize(headerOffset + this.memory.sizeofPtr(), data.length);
+
+ this.addressToSetTargetValue.push([offset, headerOffset]);
+ this.addressToSetTargetValue.push([headerOffset, dataOffset]);
+ }
+
+ /**
+ * Update internal cache views.
+ */
+ private updateViews(): void {
+ this.viewU8 = new Uint8Array(this.buffer);
+ this.viewI32 = new Int32Array(this.buffer);
+ this.viewU32 = new Uint32Array(this.buffer);
+ this.viewF64 = new Float64Array(this.buffer);
+ }
+}
diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts
new file mode 100644
index 000000000000..054a1b6019cc
--- /dev/null
+++ b/web/src/rpc_server.ts
@@ -0,0 +1,379 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import { SizeOf, TypeCode } from "./ctypes";
+import { assert, StringToUint8Array, Uint8ArrayToString } from "./support";
+import * as runtime from "./runtime";
+import { Class } from "estree";
+
+enum RPCServerState {
+ InitHeader,
+ InitHeaderKey,
+ InitServer,
+ WaitForCallback,
+ ReceivePacketHeader,
+ ReceivePacketBody,
+}
+
+/** RPC magic header */
+const RPC_MAGIC = 0xff271;
+
+/**
+ * An utility class to read from binary bytes.
+ */
+class ByteStreamReader {
+ offset = 0;
+ bytes: Uint8Array;
+
+ constructor(bytes: Uint8Array) {
+ this.bytes = bytes;
+ }
+
+ readU32(): number {
+ const i = this.offset;
+ const b = this.bytes;
+ const val = b[i] | (b[i + 1] << 8) | (b[i + 2] << 16) | (b[i + 3] << 24);
+ this.offset += 4;
+ return val;
+ }
+
+ readU64(): number {
+ const val = this.readU32();
+ this.offset += 4;
+ return val;
+ }
+
+ readByteArray(): Uint8Array {
+ const len = this.readU64();
+ assert(this.offset + len <= this.bytes.byteLength);
+ const ret = new Uint8Array(len);
+ ret.set(this.bytes.slice(this.offset, this.offset + len));
+ this.offset += len;
+ return ret;
+ }
+}
+
+/**
+ * A websocket based RPC
+ */
+export class RPCServer {
+ url: string;
+ key: string;
+ socket: WebSocket;
+ state: RPCServerState = RPCServerState.InitHeader;
+ logger: (msg: string) => void;
+ getImports: () => Record;
+ private name: string;
+ private inst?: runtime.Instance = undefined;
+ private serverRecvData?: (header: Uint8Array, body: Uint8Array) => void;
+ private currPacketHeader?: Uint8Array;
+ private currPacketLength = 0;
+ private remoteKeyLength = 0;
+ private pendingBytes = 0;
+ private buffredBytes = 0;
+ private messageQueue: Array = [];
+
+ constructor(
+ url: string,
+ key: string,
+ getImports: () => Record,
+ logger: (msg: string) => void = console.log
+ ) {
+ this.url = url;
+ this.key = key;
+ this.name = "WebSocketRPCServer[" + this.key + "]: ";
+ this.getImports = getImports;
+ this.logger = logger;
+
+ this.checkLittleEndian();
+
+ if (typeof WebSocket == "undefined") {
+ // eslint-disable-next-line @typescript-eslint/no-var-requires
+ const WebSocket = require("ws");
+ this.socket = new WebSocket(url);
+ } else {
+ this.socket = new (WebSocket as any)(url);
+ }
+
+ //this.socket = this.getSocket(url);
+ this.socket.binaryType = "arraybuffer";
+
+ this.socket.addEventListener("open", (event: Event) => {
+ return this.onOpen(event);
+ });
+ this.socket.addEventListener("message", (event: MessageEvent) => {
+ return this.onMessage(event);
+ });
+ this.socket.addEventListener("close", (event: CloseEvent) => {
+ return this.onClose(event);
+ });
+ }
+
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ private onClose(_event: CloseEvent): void {
+ if (this.inst !== undefined) {
+ this.inst.dispose();
+ }
+ if (this.state == RPCServerState.ReceivePacketHeader) {
+ this.log("Closing the server in clean state");
+ } else {
+ this.log("Closing the server, final state=" + this.state);
+ }
+ }
+
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ private onOpen(_event: Event): void {
+ // Send the headers
+ let bkey = StringToUint8Array("server:" + this.key);
+ bkey = bkey.slice(0, bkey.length - 1);
+ const intbuf = new Int32Array(1);
+ intbuf[0] = RPC_MAGIC;
+ this.socket.send(intbuf);
+ intbuf[0] = bkey.length;
+ this.socket.send(intbuf);
+ this.socket.send(bkey);
+ this.log("connected...");
+ // request bytes: magic + keylen
+ this.requestBytes(SizeOf.I32 + SizeOf.I32);
+ this.state = RPCServerState.InitHeader;
+ }
+
+ /** Handler for raw message. */
+ private onMessage(event: MessageEvent): void {
+ const buffer = event.data;
+ this.buffredBytes += buffer.byteLength;
+ this.messageQueue.push(new Uint8Array(buffer));
+ this.processEvents();
+ }
+ /** Process ready events. */
+ private processEvents(): void {
+ while (this.buffredBytes >= this.pendingBytes && this.pendingBytes != 0) {
+ this.onDataReady();
+ }
+ }
+ /** State machine to handle each request */
+ private onDataReady(): void {
+ switch (this.state) {
+ case RPCServerState.InitHeader: {
+ this.handleInitHeader();
+ break;
+ }
+ case RPCServerState.InitHeaderKey: {
+ this.handleInitHeaderKey();
+ break;
+ }
+ case RPCServerState.ReceivePacketHeader: {
+ this.currPacketHeader = this.readFromBuffer(SizeOf.I64);
+ const reader = new ByteStreamReader(this.currPacketHeader);
+ this.currPacketLength = reader.readU64();
+ assert(this.pendingBytes == 0);
+ this.requestBytes(this.currPacketLength);
+ this.state = RPCServerState.ReceivePacketBody;
+ break;
+ }
+ case RPCServerState.ReceivePacketBody: {
+ const body = this.readFromBuffer(this.currPacketLength);
+ assert(this.pendingBytes == 0);
+ assert(this.currPacketHeader !== undefined);
+ this.onPacketReady(this.currPacketHeader, body);
+ break;
+ }
+ case RPCServerState.WaitForCallback: {
+ assert(this.pendingBytes == 0);
+ break;
+ }
+ default: {
+ throw new Error("Cannot handle state " + this.state);
+ }
+ }
+ }
+
+ private onPacketReady(header: Uint8Array, body: Uint8Array): void {
+ if (this.inst === undefined) {
+ // initialize server.
+ const reader = new ByteStreamReader(body);
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ const code = reader.readU32();
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ const ver = Uint8ArrayToString(reader.readByteArray());
+ const nargs = reader.readU32();
+ const tcodes = [];
+ const args = [];
+ for (let i = 0; i < nargs; ++i) {
+ tcodes.push(reader.readU32());
+ }
+
+ for (let i = 0; i < nargs; ++i) {
+ const tcode = tcodes[i];
+ if (tcode == TypeCode.TVMStr) {
+ const str = Uint8ArrayToString(reader.readByteArray());
+ args.push(str);
+ } else if (tcode == TypeCode.TVMBytes) {
+ args.push(reader.readByteArray());
+ } else {
+ throw new Error("cannot support type code " + tcode);
+ }
+ }
+ this.onInitServer(args, header, body);
+ } else {
+ assert(this.serverRecvData !== undefined);
+ this.serverRecvData(header, body);
+ this.requestBytes(SizeOf.I64);
+ this.state = RPCServerState.ReceivePacketHeader;
+ }
+ }
+
+ /** Event handler during server initialization. */
+ private onInitServer(
+ args: Array,
+ header: Uint8Array,
+ body: Uint8Array
+ ): void {
+ // start the server
+ assert(args[0] == "rpc.WasmSession");
+ assert(args[1] instanceof Uint8Array);
+ assert(this.pendingBytes == 0);
+
+ runtime.instantiate(args[1].buffer, this.getImports())
+ .then((inst: runtime.Instance) => {
+ this.inst = inst;
+ const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer");
+
+ const messageHandler = fcreate(
+ (cbytes: Uint8Array): runtime.Scalar => {
+ assert(this.inst !== undefined);
+ if (this.socket.readyState == 1) {
+ this.socket.send(cbytes);
+ return this.inst.scalar(cbytes.length, "int32");
+ } else {
+ return this.inst.scalar(0, "int32");
+ }
+ },
+ this.name,
+ this.key
+ );
+
+ fcreate.dispose();
+ const writeFlag = this.inst.scalar(3, "int32");
+
+ this.serverRecvData = (header: Uint8Array, body: Uint8Array): void => {
+ if (messageHandler(header, writeFlag) == 0) {
+ this.socket.close();
+ }
+ if (messageHandler(body, writeFlag) == 0) {
+ this.socket.close();
+ }
+ };
+
+ // Forward the same init sequence to the wasm RPC.
+ // The RPC will look for "rpc.wasmSession"
+ // and we will redirect it to the correct local session.
+ // register the callback to redirect the session to local.
+ const flocal = this.inst.getGlobalFunc("rpc.LocalSession");
+ const localSession = flocal();
+ flocal.dispose();
+ assert(localSession instanceof runtime.Module);
+
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ this.inst.registerFunc(
+ "rpc.WasmSession",
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ (_args: unknown): runtime.Module => {
+ return localSession;
+ }
+ );
+ messageHandler(header, writeFlag);
+ messageHandler(body, writeFlag);
+ localSession.dispose();
+
+ this.log("Finish initializing the Wasm Server..");
+ this.requestBytes(SizeOf.I64);
+ this.state = RPCServerState.ReceivePacketHeader;
+ // call process events in case there are bufferred data.
+ this.processEvents();
+ });
+ this.state = RPCServerState.WaitForCallback;
+ }
+
+ private log(msg: string): void {
+ this.logger(this.name + msg);
+ }
+
+ private handleInitHeader(): void {
+ const reader = new ByteStreamReader(this.readFromBuffer(SizeOf.I32 * 2));
+ const magic = reader.readU32();
+ if (magic == RPC_MAGIC + 1) {
+ throw new Error("key: " + this.key + " has already been used in proxy");
+ } else if (magic == RPC_MAGIC + 2) {
+ throw new Error("RPCProxy do not have matching client key " + this.key);
+ }
+ assert(magic == RPC_MAGIC, this.url + " is not an RPC Proxy");
+ this.remoteKeyLength = reader.readU32();
+ assert(this.pendingBytes == 0);
+ this.requestBytes(this.remoteKeyLength);
+ this.state = RPCServerState.InitHeaderKey;
+ }
+
+ private handleInitHeaderKey(): void {
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ const remoteKey = Uint8ArrayToString(
+ this.readFromBuffer(this.remoteKeyLength)
+ );
+ assert(this.pendingBytes == 0);
+ this.requestBytes(SizeOf.I64);
+ this.state = RPCServerState.ReceivePacketHeader;
+ }
+
+ private checkLittleEndian(): void {
+ const a = new ArrayBuffer(4);
+ const b = new Uint8Array(a);
+ const c = new Uint32Array(a);
+ b[0] = 0x11;
+ b[1] = 0x22;
+ b[2] = 0x33;
+ b[3] = 0x44;
+ assert(c[0] === 0x44332211, "RPCServer little endian to work");
+ }
+
+ private requestBytes(nbytes: number): void {
+ this.pendingBytes += nbytes;
+ }
+
+ private readFromBuffer(nbytes: number): Uint8Array {
+ const ret = new Uint8Array(nbytes);
+ let ptr = 0;
+ while (ptr < nbytes) {
+ assert(this.messageQueue.length != 0);
+ const nleft = nbytes - ptr;
+ if (this.messageQueue[0].byteLength <= nleft) {
+ const buffer = this.messageQueue.shift() as Uint8Array;
+ ret.set(buffer, ptr);
+ ptr += buffer.byteLength;
+ } else {
+ const buffer = this.messageQueue[0];
+ ret.set(buffer.slice(0, nleft), ptr);
+ this.messageQueue[0] = buffer.slice(nleft, buffer.byteLength);
+ ptr += nleft;
+ }
+ }
+ this.buffredBytes -= nbytes;
+ this.pendingBytes -= nbytes;
+ return ret;
+ }
+}
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
new file mode 100644
index 000000000000..cd9b967596af
--- /dev/null
+++ b/web/src/runtime.ts
@@ -0,0 +1,1113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * TVM JS Wasm Runtime library.
+ */
+import { Pointer, PtrOffset, SizeOf, TypeCode } from "./ctypes";
+import { Disposable } from "./types";
+import { Memory, CachedCallStack } from "./memory";
+import { assert, StringToUint8Array } from "./support";
+import { Environment } from "./environment";
+
+import * as ctypes from "./ctypes";
+
+/**
+ * Type for PackedFunc inthe TVMRuntime.
+ */
+export type PackedFunc = ((...args: any) => any) &
+ Disposable & { _tvmPackedCell: PackedFuncCell };
+
+/**
+ * @internal
+ * FFI Library wrapper, maintains most runtime states.
+ */
+class FFILibrary implements Disposable {
+ wasm32: boolean;
+ memory: Memory;
+ exports: Record;
+ private wasmInstance: WebAssembly.Instance;
+
+ private recycledCallStacks: Array = [];
+
+ constructor(
+ wasmInstance: WebAssembly.Instance,
+ imports: Record
+ ) {
+ this.wasmInstance = wasmInstance;
+ this.memory = new Memory(this.detectWasmMemory(this.wasmInstance, imports));
+ assert(
+ this.wasmInstance.exports !== undefined,
+ "Expect the library module contains exports"
+ );
+ this.exports = this.wasmInstance.exports as Record;
+ this.wasm32 = this.memory.wasm32;
+ this.validateInstance();
+ }
+
+ dispose(): void {
+ while (this.recycledCallStacks.length != 0) {
+ (this.recycledCallStacks.pop() as Disposable).dispose();
+ }
+ }
+
+ sizeofPtr(): number {
+ return this.memory.sizeofPtr();
+ }
+
+ checkCall(code: number): void {
+ if (code != 0) {
+ const msgPtr = (this.exports
+ .TVMGetLastError as ctypes.FTVMGetLastError)();
+ throw new Error("TVMError: " + this.memory.loadCString(msgPtr));
+ }
+ }
+
+ getOrAllocCallStack(): CachedCallStack {
+ if (this.recycledCallStacks.length != 0) {
+ return this.recycledCallStacks.pop() as CachedCallStack;
+ }
+ return new CachedCallStack(
+ this.memory,
+ this.exports.TVMWasmAllocSpace as ctypes.FTVMWasmAllocSpace,
+ this.exports.TVMWasmFreeSpace as ctypes.FTVMWasmFreeSpace
+ );
+ }
+
+ recycleCallStack(callstack: CachedCallStack): void {
+ callstack.reset();
+ this.recycledCallStacks.push(callstack);
+ }
+
+ private validateInstance(): void {
+ this.checkExports(["TVMWasmAllocSpace", "TVMWasmFreeSpace", "TVMFuncFree"]);
+ }
+
+ private checkExports(funcNames: Array): void {
+ const missList = [];
+ for (const name of funcNames) {
+ const f = this.exports[name];
+ if (!(f instanceof Function)) {
+ missList.push(name);
+ }
+ }
+ if (missList.length != 0) {
+ throw new Error("Cannot find " + missList + " in exports");
+ }
+ }
+
+ private detectWasmMemory(
+ instance: WebAssembly.Instance,
+ imports: Record
+ ): WebAssembly.Memory {
+ if (instance.exports.memory instanceof WebAssembly.Memory) {
+ return instance.exports.memory;
+ }
+ if (imports.env && imports.env.memory instanceof WebAssembly.Memory) {
+ return imports.env.memory;
+ }
+
+ throw new Error(
+ "Cannt detect wasm memory from imports " +
+ imports +
+ " or exports" +
+ instance.exports
+ );
+ }
+}
+
+/**
+ * A typed scalar constant used to represent a typed number
+ * argument to PackedFunc calls.
+ */
+export class Scalar {
+ /** The value. */
+ value: number;
+ /** The data type of the scalar. */
+ dtype: string;
+
+ constructor(value: number, dtype: string) {
+ this.value = value;
+ this.dtype = dtype;
+ }
+}
+
+/**
+ * Cell holds the PackedFunc object.
+ */
+class PackedFuncCell implements Disposable {
+ handle: Pointer;
+ private lib: FFILibrary;
+
+ constructor(handle: Pointer, lib: FFILibrary) {
+ this.handle = handle;
+ this.lib = lib;
+ }
+
+ dispose(): void {
+ if (this.handle != 0) {
+ this.lib.checkCall(
+ (this.lib.exports.TVMFuncFree as ctypes.FTVMFuncFree)(this.handle)
+ );
+ this.handle = 0;
+ }
+ }
+}
+
+const DeviceEnumToStr: Record = {
+ 1: "cpu",
+ 2: "gpu",
+ 4: "opencl",
+ 7: "vulkan",
+ 8: "metal",
+};
+
+const DeviceStrToEnum: Record = {
+ cpu: 1,
+ gpu: 2,
+ cuda: 2,
+ cl: 4,
+ opencl: 4,
+ vulkan: 7,
+ metal: 8,
+};
+
+/**
+ * Represent a runtime context where a NDArray can reside.
+ */
+export class DLContext {
+ /** The device type code of the context. */
+ deviceType: number;
+ /** The device index. */
+ deviceId: number;
+
+ private lib: FFILibrary;
+
+ constructor(deviceType: number | string, deviceId: number, lib: FFILibrary) {
+ const tp = typeof deviceType;
+ if (tp == "string") {
+ this.deviceType = DeviceStrToEnum[deviceType];
+ } else if (tp == "number") {
+ this.deviceType = deviceType as number;
+ } else {
+ throw new Error("Cannot take type " + tp + " as deviceType");
+ }
+ this.deviceId = deviceId;
+ this.lib = lib;
+ }
+
+ /**
+ * Synchronize the context
+ */
+ sync(): void {
+ this.lib.checkCall(
+ (this.lib.exports.TVMSynchronize as ctypes.FTVMSynchronize)(
+ this.deviceType,
+ this.deviceId,
+ 0
+ )
+ );
+ }
+
+ toString(): string {
+ return (
+ DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")"
+ );
+ }
+}
+
+const DLDataTypeCodeToStr: Record = {
+ 0: "int",
+ 1: "uint",
+ 2: "float",
+ 4: "handle",
+};
+
+/**
+ * Runtime data type of NDArray.
+ */
+export class DLDataType {
+ /** The type code */
+ code: number;
+ /** Number of bits in the data type. */
+ bits: number;
+ /** Number of vector lanes. */
+ lanes: number;
+
+ constructor(code: number, bits: number, lanes: number) {
+ this.code = code;
+ this.bits = bits;
+ this.lanes = lanes;
+ }
+
+ toString(): string {
+ const ret = DLDataTypeCodeToStr[this.code] + this.bits.toString();
+ if (this.lanes != 1) {
+ return ret + "x" + this.lanes.toString();
+ } else {
+ return ret;
+ }
+ }
+
+ numStorageBytes(): number {
+ return (this.bits * this.lanes + 7) >> 3;
+ }
+}
+
+/**
+ * n-dimnesional array.
+ */
+export class NDArray implements Disposable {
+ /** Internal array handle. */
+ handle: Pointer;
+ /** Number of dimensions. */
+ ndim: number;
+ /** Data type of the array. */
+ dtype: string;
+ /** Shape of the array. */
+ shape: Array;
+ /** Context of the array. */
+ context: DLContext;
+
+ private byteOffset: number;
+ private dltensor: Pointer;
+ private lib: FFILibrary;
+ private dlDataType: DLDataType;
+
+ constructor(handle: Pointer, lib: FFILibrary) {
+ this.handle = handle;
+ this.lib = lib;
+
+ this.dltensor = this.getDLTensorFromArrayHandle(this.handle);
+ // constant offsets.
+ const arrayOffsetData = 0;
+ const arrayOffsetContext = arrayOffsetData + this.lib.sizeofPtr();
+ const arrayOffsetDevType = arrayOffsetContext;
+ const arrayOffsetDevId = arrayOffsetContext + SizeOf.I32;
+ const arrayOffsetNdim = arrayOffsetContext + SizeOf.DLContext;
+ const arrayOffsetDtype = arrayOffsetNdim + SizeOf.I32;
+ const arrayOffsetDtypeCode = arrayOffsetDtype;
+ const arrayOffsetDtypeBits = arrayOffsetDtype + SizeOf.U8;
+ const arrayOffsetDtypeLanes = arrayOffsetDtypeBits + SizeOf.U8;
+ const arrayOffsetShape = arrayOffsetDtype + SizeOf.DLDataType;
+ const arrayOffsetStrides = arrayOffsetShape + this.lib.sizeofPtr();
+ const arrayOffsetByteOffset = arrayOffsetStrides + this.lib.sizeofPtr();
+ // ndim
+ this.ndim = lib.memory.loadI32(this.dltensor + arrayOffsetNdim);
+ // shape
+ const cshapePtr = lib.memory.loadPointer(this.dltensor + arrayOffsetShape);
+ this.shape = [];
+ for (let i = 0; i < this.ndim; ++i) {
+ this.shape.push(lib.memory.loadI64(cshapePtr + i * SizeOf.I64));
+ }
+ // dtype
+ const code = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeCode);
+ const bits = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeBits);
+ const lanes = lib.memory.loadU16(this.dltensor + arrayOffsetDtypeLanes);
+ this.dlDataType = new DLDataType(code, bits, lanes);
+ this.dtype = this.dlDataType.toString();
+
+ // ctx
+ const deviceType = lib.memory.loadI32(this.dltensor + arrayOffsetDevType);
+ const deviceId = lib.memory.loadI32(this.dltensor + arrayOffsetDevId);
+ this.context = new DLContext(deviceType, deviceId, lib);
+
+ // byte_offset
+ this.byteOffset = lib.memory.loadI64(this.dltensor + arrayOffsetByteOffset);
+ }
+
+ dispose(): void {
+ if (this.handle != 0) {
+ this.lib.checkCall(
+ (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle)
+ );
+ this.handle = 0;
+ }
+ }
+ /**
+ * Copy data from another NDArray or javascript array.
+ * The number of elements must match.
+ *
+ * @param data The source data array.
+ * @returns this
+ */
+ copyFrom(data: NDArray | Array): this {
+ if (data instanceof NDArray) {
+ this.lib.checkCall(
+ (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)(
+ data.handle,
+ this.handle,
+ 0
+ )
+ );
+ return this;
+ } else {
+ const size = this.shape.reduce((a, b) => {
+ return a * b;
+ }, 1);
+ if (data.length != size) {
+ throw new Error(
+ "data size and shape mismatch data.length" +
+ data.length +
+ " vs " +
+ size
+ );
+ }
+ let buffer: ArrayBuffer;
+ if (this.dtype == "float32") {
+ buffer = Float32Array.from(data).buffer;
+ } else if (this.dtype == "float64") {
+ buffer = Float64Array.from(data).buffer;
+ } else if (this.dtype == "int32") {
+ buffer = Int32Array.from(data).buffer;
+ } else if (this.dtype == "int8") {
+ buffer = Int8Array.from(data).buffer;
+ } else if (this.dtype == "uint8") {
+ buffer = Uint8Array.from(data).buffer;
+ } else {
+ throw new Error("Unsupported data type " + this.dtype);
+ }
+ return this.copyFromRawBytes(new Uint8Array(buffer));
+ }
+ }
+ /**
+ * Copy data from raw bytes.
+ * @param data Uint8Array of bytes.
+ * @returns this
+ */
+ copyFromRawBytes(data: Uint8Array): this {
+ const size = this.shape.reduce((a, b) => {
+ return a * b;
+ }, 1);
+ const nbytes = this.dlDataType.numStorageBytes() * size;
+ if (nbytes != data.length) {
+ throw new Error("Expect the data's length equals nbytes=" + nbytes);
+ }
+
+ const stack = this.lib.getOrAllocCallStack();
+
+ const tempOffset = stack.allocRawBytes(nbytes);
+ const tempPtr = stack.ptrFromOffset(tempOffset);
+ this.lib.memory.storeRawBytes(tempPtr, data);
+ this.lib.checkCall(
+ (this.lib.exports.TVMArrayCopyFromBytes as ctypes.FTVMArrayCopyFromBytes)(
+ this.handle,
+ tempPtr,
+ nbytes
+ )
+ );
+
+ this.lib.recycleCallStack(stack);
+ return this;
+ }
+ /**
+ * Return a copied Uint8Array of the raw bytes in the NDArray.
+ * @returns The result array.
+ */
+ toRawBytes(): Uint8Array {
+ const size = this.shape.reduce((a, b) => {
+ return a * b;
+ }, 1);
+ const nbytes = this.dlDataType.numStorageBytes() * size;
+ const stack = this.lib.getOrAllocCallStack();
+
+ const tempOffset = stack.allocRawBytes(nbytes);
+ const tempPtr = stack.ptrFromOffset(tempOffset);
+ this.lib.checkCall(
+ (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)(
+ this.handle,
+ tempPtr,
+ nbytes
+ )
+ );
+ const ret = this.lib.memory.loadRawBytes(tempPtr, nbytes);
+
+ this.lib.recycleCallStack(stack);
+ return ret;
+ }
+
+ /**
+ * Return a TypedArray copy of the NDArray, the specific type depends on
+ * the dtype of the NDArray.
+ * @returns The result array.
+ */
+ toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array {
+ const stype = this.dtype;
+ if (stype == "float32") {
+ return new Float32Array(this.toRawBytes().buffer);
+ } else if (stype == "float64") {
+ return new Float64Array(this.toRawBytes().buffer);
+ } else if (stype == "int32") {
+ return new Int32Array(this.toRawBytes().buffer);
+ } else if (stype == "int8") {
+ return new Int8Array(this.toRawBytes().buffer);
+ } else if (stype == "uint8") {
+ return new Uint8Array(this.toRawBytes().buffer);
+ } else {
+ throw new Error("Unsupported data type " + this.dtype);
+ }
+ }
+
+ private getDLTensorFromArrayHandle(handle: Pointer): Pointer {
+ // Note: this depends on the NDArray C ABI.
+ // keep this function in case of ABI change.
+ return handle;
+ }
+}
+
+/**
+ * Runtime Module.
+ */
+export class Module implements Disposable {
+ handle: Pointer;
+ private lib: FFILibrary;
+ private makePackedFunc: (ptr: Pointer) => PackedFunc;
+
+ constructor(
+ handle: Pointer,
+ lib: FFILibrary,
+ makePackedFunc: (ptr: Pointer) => PackedFunc
+ ) {
+ this.handle = handle;
+ this.lib = lib;
+ this.makePackedFunc = makePackedFunc;
+ }
+
+ dispose(): void {
+ if (this.handle != 0) {
+ this.lib.checkCall(
+ (this.lib.exports.TVMModFree as ctypes.FTVMModFree)(this.handle)
+ );
+ this.handle = 0;
+ }
+ }
+
+ /**
+ * Get a function in the module.
+ * @param name The name of the function.
+ * @returns The result function.
+ */
+ getFunction(name: string): PackedFunc {
+ const stack = this.lib.getOrAllocCallStack();
+ const nameOffset = stack.allocRawBytes(name.length + 1);
+ stack.storeRawBytes(nameOffset, StringToUint8Array(name));
+
+ const outOffset = stack.allocPtrArray(1);
+ const outPtr = stack.ptrFromOffset(outOffset);
+
+ stack.commitToWasmMemory(outOffset);
+
+ this.lib.checkCall(
+ (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)(
+ this.handle,
+ stack.ptrFromOffset(nameOffset),
+ 1,
+ outPtr
+ )
+ );
+ const handle = this.lib.memory.loadPointer(outPtr);
+ this.lib.recycleCallStack(stack);
+ if (handle == 0) {
+ throw Error("Cannot find function " + name);
+ }
+ const ret = this.makePackedFunc(handle);
+ return ret;
+ }
+
+ /**
+ * Import another module into the current runtime module.
+ * @param mod The module to be imported.
+ */
+ importModule(mod: Module): void {
+ this.lib.checkCall(
+ (this.lib.exports.TVMModImport as ctypes.FTVMModImport)(
+ this.handle,
+ mod.handle
+ )
+ );
+ }
+}
+
+/**
+ * TVM runtime instance.
+ */
+export class Instance implements Disposable {
+ memory: Memory;
+ exports: Record;
+ private lib: FFILibrary;
+ private env: Environment;
+
+ /**
+ * Internal function(registered by the runtime)
+ */
+ private wasmCreateLibraryModule?: PackedFunc &
+ ((getFunc: PackedFunc, getGlobal: PackedFunc) => PackedFunc);
+
+ /**
+ * Constructor
+ *
+ * importObject can also be a {@link LibraryProvider} object,
+ * a WASI object, or an object containing wasmLibraryProvider field.
+ *
+ * @param wasmModule The input module or instance.
+ * @param importObject The imports to initialize the wasmInstance if it is not provided.
+ * @param wasmInstance Additional wasm instance argument for deferred construction.
+ * @param env Directly specified environment module.
+ *
+ * @see Please use the async version {@link instantiate} when targeting browsers.
+ */
+ constructor(
+ wasmModule: WebAssembly.Module,
+ importObject: Record = {},
+ wasmInstance?: WebAssembly.Instance,
+ env?: Environment
+ ) {
+ if (wasmInstance instanceof WebAssembly.Instance) {
+ assert(
+ env instanceof Environment,
+ "env must be provided when passing in instance"
+ );
+ } else {
+ assert(env === undefined);
+ env = new Environment(importObject);
+ wasmInstance = new WebAssembly.Instance(wasmModule, env.imports);
+ }
+
+ env.start(wasmInstance);
+ this.env = env;
+ this.lib = new FFILibrary(wasmInstance, env.imports);
+ this.memory = this.lib.memory;
+ this.exports = this.lib.exports;
+ this.registerEnvGlobalPackedFuncs();
+ }
+
+ dispose(): void {
+ this.lib.dispose();
+ }
+ /**
+ * Get system-wide library module in the wasm.
+ * System lib is a global module that contains self register functions in startup.
+ * @returns The system library module.
+ */
+ systemLib(): Module {
+ const getSysLib = this.getGlobalFunc("runtime.SystemLib");
+ const mod = getSysLib() as Module;
+ getSysLib.dispose();
+ return mod;
+ }
+ /**
+ * List all the global function names registered in the runtime.
+ * @returns The name list.
+ */
+ listGlobalFuncNames(): Array {
+ const stack = this.lib.getOrAllocCallStack();
+
+ const outSizeOffset = stack.allocPtrArray(2);
+
+ const outSizePtr = stack.ptrFromOffset(outSizeOffset);
+ const outArrayPtr = stack.ptrFromOffset(
+ outSizeOffset + this.lib.sizeofPtr()
+ );
+
+ this.lib.checkCall(
+ (this.exports.TVMFuncListGlobalNames as ctypes.FTVMFuncListGlobalNames)(
+ outSizePtr,
+ outArrayPtr
+ )
+ );
+
+ const size = this.memory.loadI32(outSizePtr);
+ const array = this.memory.loadPointer(outArrayPtr);
+ const names: Array = [];
+
+ for (let i = 0; i < size; ++i) {
+ names.push(
+ this.memory.loadCString(
+ this.memory.loadPointer(array + this.lib.sizeofPtr() * i)
+ )
+ );
+ }
+
+ this.lib.recycleCallStack(stack);
+ return names;
+ }
+
+ /**
+ * Register function to be global function in tvm runtime.
+ * @param name The name of the function.
+ * @param f function to be registered.
+ * @param override Whether overwrite function in existing registry.
+ */
+ registerFunc(
+ name: string,
+ func: PackedFunc | Function,
+ override = false
+ ): void {
+ const packedFunc = this.toPackedFunc(func);
+ const ioverride = override ? 1 : 0;
+
+ const stack = this.lib.getOrAllocCallStack();
+ const nameOffset = stack.allocRawBytes(name.length + 1);
+ stack.storeRawBytes(nameOffset, StringToUint8Array(name));
+ stack.commitToWasmMemory();
+
+ this.lib.checkCall(
+ (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)(
+ stack.ptrFromOffset(nameOffset),
+ packedFunc._tvmPackedCell.handle,
+ ioverride
+ )
+ );
+ }
+
+ /**
+ * Get global PackedFunc from the runtime.
+ * @param name The name of the function.
+ * @returns The result function.
+ */
+ getGlobalFunc(name: string): PackedFunc {
+ const stack = this.lib.getOrAllocCallStack();
+ const nameOffset = stack.allocRawBytes(name.length + 1);
+ stack.storeRawBytes(nameOffset, StringToUint8Array(name));
+ const outOffset = stack.allocPtrArray(1);
+ const outPtr = stack.ptrFromOffset(outOffset);
+
+ stack.commitToWasmMemory(outOffset);
+
+ this.lib.checkCall(
+ (this.exports.TVMFuncGetGlobal as ctypes.FTVMFuncGetGlobal)(
+ stack.ptrFromOffset(nameOffset),
+ outPtr
+ )
+ );
+ const handle = this.memory.loadPointer(outPtr);
+ this.lib.recycleCallStack(stack);
+ if (handle == 0) {
+ throw Error("Cannot find global function " + name);
+ }
+ const ret = this.makePackedFunc(handle);
+ return ret;
+ }
+
+ /**
+ * Check if func is PackedFunc.
+ *
+ * @param func The input.
+ * @returns The check result.
+ */
+ isPackedFunc(func: unknown): boolean {
+ // eslint-disable-next-line no-prototype-builtins
+ return typeof func == "function" && func.hasOwnProperty("_tvmPackedCell");
+ }
+
+ /**
+ * Convert func to PackedFunc
+ *
+ * @param func Input function.
+ * @returns The converted function.
+ */
+ toPackedFunc(func: Function): PackedFunc {
+ if (this.isPackedFunc(func)) return func as PackedFunc;
+ return this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func));
+ }
+
+ /**
+ * Convert dtype to {@link DLDataType}
+ *
+ * @param dtype The input dtype string or DLDataType.
+ * @returns The converted result.
+ */
+ toDLDataType(dtype: string | DLDataType): DLDataType {
+ if (dtype instanceof DLDataType) return dtype;
+ if (typeof dtype == "string") {
+ let pattern = dtype;
+ let code,
+ bits = 32,
+ lanes = 1;
+ if (pattern.substring(0, 5) == "float") {
+ pattern = pattern.substring(5, pattern.length);
+ code = TypeCode.Float;
+ } else if (pattern.substring(0, 3) == "int") {
+ pattern = pattern.substring(3, pattern.length);
+ code = TypeCode.Int;
+ } else if (pattern.substring(0, 4) == "uint") {
+ pattern = pattern.substring(4, pattern.length);
+ code = TypeCode.UInt;
+ } else if (pattern.substring(0, 6) == "handle") {
+ pattern = pattern.substring(5, pattern.length);
+ code = TypeCode.TVMOpaqueHandle;
+ bits = 64;
+ } else {
+ throw new Error("Unknown dtype " + dtype);
+ }
+
+ const arr = pattern.split("x");
+ if (arr.length >= 1) {
+ const parsed = parseInt(arr[0]);
+ if (parsed + "" == arr[0]) {
+ bits = parsed;
+ }
+ }
+ if (arr.length >= 2) {
+ lanes = parseInt(arr[1]);
+ }
+ return new DLDataType(code, bits, lanes);
+ } else {
+ throw new Error("Unknown dtype " + dtype);
+ }
+ }
+
+ /**
+ * Create a new {@link Scalar} that can be passed to a PackedFunc.
+ * @param value The number value.
+ * @param dtype The dtype string.
+ * @returns The created scalar.
+ */
+ scalar(value: number, dtype: string): Scalar {
+ return new Scalar(value, dtype);
+ }
+
+ /**
+ * Create a new {@link DLContext}
+ * @param deviceType The device type.
+ * @param deviceId The device index.
+ * @returns The created context.
+ */
+ context(deviceType: number | string, deviceId: number): DLContext {
+ return new DLContext(deviceType, deviceId, this.lib);
+ }
+
+ /**
+ * Create an empty {@link NDArray} with given shape and dtype.
+ *
+ * @param shape The shape of the array.
+ * @param dtype The data type of the array.
+ * @param ctx The context of the ndarray.
+ * @returns The created ndarray.
+ */
+ empty(
+ shape: Array | number,
+ dtype: string | DLDataType = "float32",
+ ctx: DLContext = this.context("cpu", 0)
+ ): NDArray {
+ dtype = this.toDLDataType(dtype);
+ shape = typeof shape == "number" ? [shape] : shape;
+
+ const stack = this.lib.getOrAllocCallStack();
+ const shapeOffset = stack.allocRawBytes(shape.length * SizeOf.I64);
+ for (let i = 0; i < shape.length; ++i) {
+ stack.storeI64(shapeOffset + i * SizeOf.I64, shape[i]);
+ }
+
+ const outOffset = stack.allocPtrArray(1);
+ const outPtr = stack.ptrFromOffset(outOffset);
+ stack.commitToWasmMemory(outOffset);
+
+ this.lib.checkCall(
+ (this.exports.TVMArrayAlloc as ctypes.FTVMArrayAlloc)(
+ stack.ptrFromOffset(shapeOffset),
+ shape.length,
+ dtype.code,
+ dtype.bits,
+ dtype.lanes,
+ ctx.deviceType,
+ ctx.deviceId,
+ outPtr
+ )
+ );
+ const ret = new NDArray(this.memory.loadPointer(outPtr), this.lib);
+ this.lib.recycleCallStack(stack);
+ return ret;
+ }
+
+ /** Register global packed functions needed by the backend to the env. */
+ private registerEnvGlobalPackedFuncs(): void {
+ // Register the timer function to enable the time_evaluator.
+ let perf: Performance;
+ if (typeof performance == "undefined") {
+ // eslint-disable-next-line @typescript-eslint/no-var-requires
+ const performanceNode = require('perf_hooks');
+ perf = performanceNode.performance as Performance;
+ } else {
+ perf = performance as Performance;
+ }
+
+ const getTimer = (func: PackedFunc) => {
+ return (n: number): number => {
+ const nscalar = this.scalar(n, "int32");
+ const tstart: number = perf.now();
+ func(nscalar);
+ const tend: number = perf.now();
+ return tend - tstart;
+ }
+ };
+ this.registerFunc("wasm.GetTimer", getTimer);
+ const rpcWrapTimeEvaluator = this.getGlobalFunc("wasm.RPCTimeEvaluator");
+ this.registerFunc("runtime.RPCTimeEvaluator", rpcWrapTimeEvaluator, true);
+ rpcWrapTimeEvaluator.dispose();
+ }
+
+ private createPackedFuncFromCFunc(
+ func: ctypes.FTVMWasmPackedCFunc
+ ): PackedFunc {
+ let findex = this.env.packedCFuncTable.length;
+ if (this.env.packedCFuncTableFreeId.length != 0) {
+ findex = this.env.packedCFuncTableFreeId.pop() as number;
+ } else {
+ this.env.packedCFuncTable.push(undefined);
+ }
+ this.env.packedCFuncTable[findex] = func;
+
+ const stack = this.lib.getOrAllocCallStack();
+ const outOffset = stack.allocPtrArray(1);
+ const outPtr = stack.ptrFromOffset(outOffset);
+ this.lib.checkCall(
+ (this.exports
+ .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)(
+ findex,
+ outPtr
+ )
+ );
+ const ret = this.makePackedFunc(this.memory.loadPointer(outPtr));
+ this.lib.recycleCallStack(stack);
+ return ret;
+ }
+
+ /**
+ * Set packed function arguments into the location indicated by argsValue and argsCode.
+ * Allocate new temporary space from the stack if necessary.
+ *
+ * @parma stack The call stack
+ * @param args The input arguments.
+ * @param argsValue The offset of argsValue.
+ * @param argsCode The offset of argsCode.
+ */
+ setPackedArguments(
+ stack: CachedCallStack,
+ args: Array,
+ argsValue: PtrOffset,
+ argsCode: PtrOffset
+ ): void {
+ for (let i = 0; i < args.length; ++i) {
+ let val = args[i];
+ const tp = typeof val;
+ const valueOffset = argsValue + i * SizeOf.TVMValue;
+ const codeOffset = argsCode + i * SizeOf.I32;
+ if (val instanceof NDArray) {
+ stack.storePtr(valueOffset, val.handle);
+ stack.storeI32(codeOffset, TypeCode.TVMNDArrayHandle);
+ } else if (val instanceof Scalar) {
+ if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) {
+ stack.storeI64(valueOffset, val.value);
+ stack.storeI32(codeOffset, TypeCode.Int);
+ } else if (val.dtype.startsWith("float")) {
+ stack.storeF64(valueOffset, val.value);
+ stack.storeI32(codeOffset, TypeCode.Float);
+ } else {
+ assert(val.dtype == "handle", "Expect handle");
+ stack.storePtr(valueOffset, val.value);
+ stack.storeI32(codeOffset, TypeCode.TVMOpaqueHandle);
+ }
+ } else if (tp == "number") {
+ stack.storeF64(valueOffset, val);
+ stack.storeI32(codeOffset, TypeCode.Float);
+ // eslint-disable-next-line no-prototype-builtins
+ } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) {
+ stack.storePtr(valueOffset, val._tvmPackedCell.handle);
+ stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle);
+ } else if (val === null || val == undefined) {
+ stack.storePtr(valueOffset, 0);
+ stack.storeI32(codeOffset, TypeCode.Null);
+ } else if (tp == "string") {
+ stack.allocThenSetArgString(valueOffset, val);
+ stack.storeI32(codeOffset, TypeCode.TVMStr);
+ } else if (val instanceof Uint8Array) {
+ stack.allocThenSetArgBytes(valueOffset, val);
+ stack.storeI32(codeOffset, TypeCode.TVMBytes);
+ } else if (val instanceof Function) {
+ val = this.toPackedFunc(val);
+ stack.tempArgs.push(val);
+ stack.storePtr(valueOffset, val._tvmPackedCell.handle);
+ stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle);
+ } else if (val instanceof Module) {
+ stack.storePtr(valueOffset, val.handle);
+ stack.storeI32(codeOffset, TypeCode.TVMModuleHandle);
+ } else {
+ throw new Error("Unsupported argument type " + tp);
+ }
+ }
+ }
+
+ private wrapJSFuncAsPackedCFunc(func: Function): ctypes.FTVMWasmPackedCFunc {
+ const lib = this.lib;
+ return (
+ argValues: Pointer,
+ argCodes: Pointer,
+ nargs: number,
+ ret: Pointer,
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ _handle: Pointer
+ ): number => {
+ const jsArgs = [];
+ for (let i = 0; i < nargs; ++i) {
+ const valuePtr = argValues + i * SizeOf.TVMValue;
+ const codePtr = argCodes + i * SizeOf.I32;
+ let tcode = lib.memory.loadI32(codePtr);
+
+ if (
+ tcode == TypeCode.TVMObjectHandle ||
+ tcode == TypeCode.TVMObjectRValueRefArg ||
+ tcode == TypeCode.TVMPackedFuncHandle ||
+ tcode == TypeCode.TVMModuleHandle
+ ) {
+ lib.checkCall(
+ (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)(
+ valuePtr,
+ codePtr
+ )
+ );
+ }
+ tcode = lib.memory.loadI32(codePtr);
+ jsArgs.push(this.retValueToJS(valuePtr, tcode));
+ }
+
+ const rv = func(...jsArgs);
+
+ if (rv !== undefined && rv !== null) {
+ const stack = lib.getOrAllocCallStack();
+ const valueOffset = stack.allocRawBytes(SizeOf.TVMValue);
+ const codeOffset = stack.allocRawBytes(SizeOf.I32);
+ this.setPackedArguments(stack, [rv], valueOffset, codeOffset);
+ const valuePtr = stack.ptrFromOffset(valueOffset);
+ const codePtr = stack.ptrFromOffset(codeOffset);
+ stack.commitToWasmMemory();
+ lib.checkCall(
+ (lib.exports.TVMCFuncSetReturn as ctypes.FTVMCFuncSetReturn)(
+ ret,
+ valuePtr,
+ codePtr,
+ 1
+ )
+ );
+ lib.recycleCallStack(stack);
+ }
+ return 0;
+ };
+ }
+
+ private makePackedFunc(handle: Pointer): PackedFunc {
+ const cell = new PackedFuncCell(handle, this.lib);
+
+ const packedFunc = (...args: any): any => {
+ const stack = this.lib.getOrAllocCallStack();
+
+ const valueOffset = stack.allocRawBytes(SizeOf.TVMValue * args.length);
+ const tcodeOffset = stack.allocRawBytes(SizeOf.I32 * args.length);
+
+ this.setPackedArguments(stack, args, valueOffset, tcodeOffset);
+
+ const rvalueOffset = stack.allocRawBytes(SizeOf.TVMValue);
+ const rcodeOffset = stack.allocRawBytes(SizeOf.I32);
+ const rvaluePtr = stack.ptrFromOffset(rvalueOffset);
+ const rcodePtr = stack.ptrFromOffset(rcodeOffset);
+
+ // commit to wasm memory, till rvalueOffset (the return value don't need to be committed)
+ stack.commitToWasmMemory(rvalueOffset);
+
+ this.lib.checkCall(
+ (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)(
+ handle,
+ stack.ptrFromOffset(valueOffset),
+ stack.ptrFromOffset(tcodeOffset),
+ args.length,
+ rvaluePtr,
+ rcodePtr
+ )
+ );
+
+ const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr));
+ this.lib.recycleCallStack(stack);
+ return ret;
+ };
+ // Attach attributes to the function type.
+ // This is because javascript do not allow us to overload call.
+ const ret: any = packedFunc;
+ ret.dispose = (): void => {
+ cell.dispose();
+ };
+ ret._tvmPackedCell = cell;
+ return ret as PackedFunc;
+ }
+
+ private retValueToJS(rvaluePtr: Pointer, tcode: number): any {
+ switch (tcode) {
+ case TypeCode.Int:
+ case TypeCode.UInt:
+ return this.memory.loadI64(rvaluePtr);
+ case TypeCode.Float:
+ return this.memory.loadF64(rvaluePtr);
+ case TypeCode.TVMNDArrayHandle: {
+ return new NDArray(this.memory.loadPointer(rvaluePtr), this.lib);
+ }
+ case TypeCode.TVMPackedFuncHandle: {
+ return this.makePackedFunc(this.memory.loadPointer(rvaluePtr));
+ }
+ case TypeCode.TVMModuleHandle: {
+ return new Module(
+ this.memory.loadPointer(rvaluePtr),
+ this.lib,
+ (ptr: Pointer) => {
+ return this.makePackedFunc(ptr);
+ }
+ );
+ }
+ case TypeCode.Null:
+ return undefined;
+ case TypeCode.TVMStr: {
+ return this.memory.loadCString(this.memory.loadPointer(rvaluePtr));
+ }
+ case TypeCode.TVMBytes: {
+ return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr));
+ }
+ default:
+ throw new Error("Unsupported return type code=" + tcode);
+ }
+ }
+}
+
+/**
+ * Asynchrously instantiate a new {@link Instance}.
+ *
+ * importObject can also be a {@link LibraryProvider} object,
+ * a WASI object, or an object containing wasmLibraryProvider field.
+ * We can take benefit of syslib implementations from the Emscripten
+ * by passing its generated js Module as the imports.
+ */
+export function instantiate(
+ bufferSource: ArrayBuffer,
+ importObject: Record = {}
+): Promise {
+ const env = new Environment(importObject);
+
+ return WebAssembly.instantiate(bufferSource, env.imports).then(
+ (result: WebAssembly.WebAssemblyInstantiatedSource): Instance => {
+ return new Instance(result.module, {}, result.instance, env);
+ }
+ );
+}
diff --git a/web/src/support.ts b/web/src/support.ts
new file mode 100644
index 000000000000..7a2667a2299f
--- /dev/null
+++ b/web/src/support.ts
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Convert string to Uint8array.
+ * @param str The string.
+ * @returns The corresponding Uint8Array.
+ */
+export function StringToUint8Array(str: string): Uint8Array {
+ const arr = new Uint8Array(str.length + 1);
+ for (let i = 0; i < str.length; ++i) {
+ arr[i] = str.charCodeAt(i);
+ }
+ arr[str.length] = 0;
+ return arr;
+}
+
+/**
+ * Convert Uint8array to string.
+ * @param array The array.
+ * @returns The corresponding string.
+ */
+export function Uint8ArrayToString(arr: Uint8Array): string {
+ const ret = [];
+ for (const ch of arr) {
+ ret.push(String.fromCharCode(ch));
+ }
+ return ret.join("");
+}
+
+/**
+ * Internal assert helper
+ * @param condition condition The condition to fail.
+ * @param msg msg The message.
+ */
+export function assert(condition: boolean, msg?: string): asserts condition {
+ if (!condition) {
+ throw new Error("AssertError:" + (msg || ""));
+ }
+}
+
+/**
+ * Get the path to the wasm library in nodejs.
+ * @return The wasm path.
+ */
+export function wasmPath(): string {
+ return __dirname + "/wasm";
+}
\ No newline at end of file
diff --git a/web/src/types.ts b/web/src/types.ts
new file mode 100644
index 000000000000..621375a23f5f
--- /dev/null
+++ b/web/src/types.ts
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/** Common type definitions. */
+
+/**
+ * Library interface provider that can provide
+ * syslibs(e.g. libs provided by WASI and beyond) for the Wasm runtime.
+ *
+ * It can be viewed as a generalization of imports used in WebAssembly instance creation.
+ *
+ * The {@link LibraryProvider.start} callback will be called
+ * to allow the library provider to initialize related resources during startup time.
+ *
+ * We can use Emscripten generated js Module as a { wasmLibraryProvider: LibraryProvider }.
+ */
+export interface LibraryProvider {
+ /** The imports that can be passed to WebAssembly instance creation. */
+ imports: Record;
+ /**
+ * Callback function to notify the provider the created instance.
+ * @param inst The created instance.
+ */
+ start: (inst: WebAssembly.Instance) => void;
+}
+
+/**
+ * Disposable classes that contains resources (WasmMemory, GPU buffer)
+ * which needs to be explicitly disposed.
+ */
+export interface Disposable {
+ /**
+ * Dispose the internal resource
+ * This function can be called multiple times,
+ * only the first call will take effect.
+ */
+ dispose: () => void;
+}
diff --git a/tests/web/test_module_load.js b/web/tests/node/test_module_load.js
similarity index 64%
rename from tests/web/test_module_load.js
rename to web/tests/node/test_module_load.js
index f4c809536bb5..45e84fd404a9 100644
--- a/tests/web/test_module_load.js
+++ b/web/tests/node/test_module_load.js
@@ -19,14 +19,18 @@
// Load Emscripten Module, need to change path to root/lib
const path = require("path");
-process.chdir(path.join(__dirname, "../../build"));
-var Module = require("../../build/test_module.js");
-// Bootstrap TVMruntime with emscripten module.
-const tvm_runtime = require("../../web/tvm_runtime.js");
-const tvm = tvm_runtime.create(Module);
+const fs = require("fs");
+const assert = require("assert");
+const tvmjs = require("../../dist");
+
+const wasmPath = tvmjs.wasmPath();
+const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js"));
+const wasmSource = fs.readFileSync(path.join(wasmPath, "test_addone.wasm"));
+
+const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI());
// Load system library
-var sysLib = tvm.systemLib();
+const sysLib = tvm.systemLib();
function randomArray(length, max) {
return Array.apply(null, Array(length)).map(function() {
@@ -36,23 +40,22 @@ function randomArray(length, max) {
function testAddOne() {
// grab pre-loaded function
- var faddOne = sysLib.getFunction("add_one");
- var assert = require('assert');
- tvm.assert(tvm.isPackedFunc(faddOne));
- var n = 124;
- var A = tvm.empty(n).copyFrom(randomArray(n, 1));
- var B = tvm.empty(n);
+ const faddOne = sysLib.getFunction("add_one");
+ assert(tvm.isPackedFunc(faddOne));
+ const n = 124;
+ const A = tvm.empty(n).copyFrom(randomArray(n, 1));
+ const B = tvm.empty(n);
// call the function.
faddOne(A, B);
- AA = A.asArray(); // retrieve values in js array
- BB = B.asArray(); // retrieve values in js array
+ const AA = A.toArray(); // retrieve values in js array
+ const BB = B.toArray(); // retrieve values in js array
// verify
for (var i = 0; i < BB.length; ++i) {
assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5);
}
- faddOne.release();
+ faddOne.dispose();
}
testAddOne();
-sysLib.release();
+sysLib.dispose();
console.log("Finish verifying test_module_load");
diff --git a/tests/web/test_basic.js b/web/tests/node/test_ndarray.js
similarity index 55%
rename from tests/web/test_basic.js
rename to web/tests/node/test_ndarray.js
index 6852319dbc12..ba43621ecb05 100644
--- a/tests/web/test_basic.js
+++ b/web/tests/node/test_ndarray.js
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -16,31 +16,34 @@
* specific language governing permissions and limitations
* under the License.
*/
-
-// Load Emscripten Module, need to change path to root/build
const path = require("path");
-process.chdir(path.join(__dirname, "../../build"));
-var Module = require("../../build/libtvm_web_runtime.js");
-// Bootstrap TVMruntime with emscripten module.
-const tvm_runtime = require("../../web/tvm_runtime.js");
-const tvm = tvm_runtime.create(Module);
+const fs = require("fs");
+const assert = require("assert");
+const tvmjs = require("../../dist/tvmjs.bundle")
+
+const wasmPath = tvmjs.wasmPath();
+const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js"));
+const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm"));
+
+let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI());
// Basic fields.
-tvm.assert(tvm.float32 == "float32");
-tvm.assert(tvm.listGlobalFuncNames() !== "undefined");
-var sysLib = tvm.systemLib();
-tvm.assert(typeof sysLib.getFunction !== "undefined");
-sysLib.release();
+assert(tvm.listGlobalFuncNames() !== undefined);
// Test ndarray
-function testArrayCopy(dtype, arr) {
- var data = [1, 2, 3, 4, 5, 6];
- var a = tvm.empty([2, 3], dtype);
- a.copyFrom(data);
- var ret = a.asArray();
- tvm.assert(ret instanceof arr);
- tvm.assert(ret.toString() == arr.from(data));
- a.release();
+function testArrayCopy(dtype, arrayType) {
+ let data = [1, 2, 3, 4, 5, 6];
+ let a = tvm.empty([2, 3], dtype).copyFrom(data);
+
+ assert(a.context.toString() == "cpu(0)");
+ assert(a.shape[0] == 2 && a.shape[1] == 3);
+
+ let ret = a.toArray();
+ assert(ret instanceof arrayType);
+ assert(ret.toString() == arrayType.from(data).toString());
+ // test multiple dispose.
+ a.dispose();
+ a.dispose();
}
testArrayCopy("float32", Float32Array);
@@ -48,8 +51,3 @@ testArrayCopy("int", Int32Array);
testArrayCopy("int8", Int8Array);
testArrayCopy("uint8", Uint8Array);
testArrayCopy("float64", Float64Array);
-
-// Function registration
-tvm.registerFunc("xyz", function(x, y) {
- return x + y;
-});
diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js
new file mode 100644
index 000000000000..c961f9576e3f
--- /dev/null
+++ b/web/tests/node/test_packed_func.js
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+const path = require("path");
+const fs = require("fs");
+const assert = require('assert');
+const tvmjs = require("../../dist")
+
+const wasmPath = tvmjs.wasmPath();
+const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js"));
+const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm"));
+
+let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI());
+
+function testGetGlobal() {
+ let flist = tvm.listGlobalFuncNames();
+ let faddOne = tvm.getGlobalFunc("testing.add_one");
+ let fecho = tvm.getGlobalFunc("testing.echo");
+
+ assert(faddOne(tvm.scalar(1, "int")) == 2);
+ // check function argument with different types.
+ assert(fecho(1123) == 1123);
+ assert(fecho("xyz") == "xyz");
+
+ let bytes = new Uint8Array([1, 2, 3]);
+ let rbytes = fecho(bytes);
+ assert(rbytes.length == bytes.length);
+
+ for (let i = 0; i < bytes.length; ++i) {
+ assert(rbytes[i] == bytes[i]);
+ }
+
+ assert(fecho(undefined) == undefined);
+
+ let arr = tvm.empty([2, 2]).copyFrom([1, 2, 3, 4]);
+ let arr2 = fecho(arr);
+ assert(arr.handle == arr2.handle);
+ assert(arr2.toArray().toString() == arr.toArray().toString());
+
+ let mod = tvm.systemLib();
+ let ret = fecho(mod);
+ assert(ret.handle == mod.handle);
+ assert(flist.length != 0);
+
+ mod.dispose();
+ ret.dispose();
+ arr.dispose();
+ arr2.dispose();
+ fecho.dispose();
+ faddOne.dispose();
+}
+
+function testReturnFunc() {
+ function addy(y) {
+ function add(x, z) {
+ return x + y + z;
+ }
+ return add;
+ }
+
+ let fecho = tvm.getGlobalFunc("testing.echo");
+ let myf = tvm.toPackedFunc(addy);
+ assert(tvm.isPackedFunc(myf));
+ let myf2 = tvm.toPackedFunc(myf);
+ assert(myf2._tvmPackedCell.handle === myf._tvmPackedCell.handle);
+ let f = myf(10);
+
+ assert(tvm.isPackedFunc(f));
+ assert(f(11, 0) == 21);
+ assert(f("x", 1) == "x101");
+ assert(f("x", "yz") == "x10yz");
+
+ fecho.dispose();
+ myf.dispose();
+ myf2.dispose();
+ // test multiple dispose.
+ f.dispose();
+ f.dispose();
+}
+
+function testRegisterGlobal() {
+ tvm.registerFunc("xyz", function (x, y) {
+ return x + y;
+ });
+
+ let f = tvm.getGlobalFunc("xyz");
+ assert(f(1, 2) == 3);
+ f.dispose();
+
+ let syslib = tvm.systemLib();
+ syslib.dispose();
+}
+
+function testTimer() {
+ const fecho = tvm.getGlobalFunc("testing.echo");
+ const fgetTimer = tvm.getGlobalFunc("wasm.GetTimer");
+
+ let finvoke = (n) => {
+ let x = "xyz";
+ for (let i = 0; i < n; ++i) {
+ x = fecho(x);
+ }
+ };
+ const number = 10000;
+ const invokeTimer = fgetTimer(finvoke);
+ console.log("Time cost:", number / invokeTimer(number) * 1000, " ops/sec");
+ fecho.dispose();
+ invokeTimer.dispose();
+ fgetTimer.dispose();
+}
+
+testGetGlobal();
+testRegisterGlobal();
+testReturnFunc();
+testTimer();
diff --git a/tests/web/prepare_test_libs.py b/web/tests/python/prepare_test_libs.py
similarity index 69%
rename from tests/web/prepare_test_libs.py
rename to web/tests/python/prepare_test_libs.py
index a0e2c13eab82..ec4eb5be1536 100644
--- a/tests/web/prepare_test_libs.py
+++ b/web/tests/python/prepare_test_libs.py
@@ -14,27 +14,28 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# Prepare test library for js.
+# Prepare test library for standalone wasm runtime test.
+
import tvm
from tvm import te
-from tvm.contrib import emscripten
+from tvm.contrib import emcc
import os
+
def prepare_test_libs(base_path):
- target = "llvm -target=asmjs-unknown-emscripten -system-lib"
+ target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib"
if not tvm.runtime.enabled(target):
raise RuntimeError("Target %s is not enbaled" % target)
n = te.var("n")
A = te.placeholder((n,), name='A')
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = te.create_schedule(B.op)
- fadd1 = tvm.build(s, [A, B], target, name="add_one")
- obj_path = os.path.join(base_path, "test_add_one.bc")
- fadd1.save(obj_path)
- emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path,
- options=["-s", "WASM=0", "-s", "USE_GLFW=3", "-s",
- "USE_WEBGL2=1", "-lglfw"])
+ fadd = tvm.build(s, [A, B], target, name="add_one")
+
+ wasm_path = os.path.join(base_path, "test_addone.wasm")
+ fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+
if __name__ == "__main__":
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
- prepare_test_libs(os.path.join(curr_path, "../../build"))
+ prepare_test_libs(os.path.join(curr_path, "../../dist/wasm"))
diff --git a/tests/web/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py
similarity index 55%
rename from tests/web/websock_rpc_test.py
rename to web/tests/python/websock_rpc_test.py
index 8be8ce04cb75..7fa0c6bdfb57 100644
--- a/tests/web/websock_rpc_test.py
+++ b/web/tests/python/websock_rpc_test.py
@@ -22,45 +22,61 @@
import tvm
from tvm import te
-import os
from tvm import rpc
-from tvm.contrib import util, emscripten
+from tvm.contrib import util, emcc
import numpy as np
proxy_host = "localhost"
proxy_port = 9090
-def test_rpc_array():
+def test_rpc():
if not tvm.runtime.enabled("rpc"):
return
- # graph
- n = tvm.runtime.convert(1024)
+ # generate the wasm library
+ target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib"
+ if not tvm.runtime.enabled(target):
+ raise RuntimeError("Target %s is not enbaled" % target)
+ n = te.var("n")
A = te.placeholder((n,), name='A')
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = te.create_schedule(B.op)
- remote = rpc.connect(proxy_host, proxy_port, key="js")
- target = "llvm -target=asmjs-unknown-emscripten -system-lib"
- def check_remote():
- if not tvm.runtime.enabled(target):
- print("Skip because %s is not enabled" % target)
- return
- temp = util.tempdir()
+
+ fadd = tvm.build(s, [A, B], target, name="addone")
+ temp = util.tempdir()
+
+ wasm_path = temp.relpath("addone.wasm")
+ fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+
+ wasm_binary = open(wasm_path, "rb").read()
+
+ remote = rpc.connect(proxy_host, proxy_port, key="wasm",
+ session_constructor_args=["rpc.WasmSession", wasm_binary])
+
+ def check(remote):
+ # basic function checks.
+ fecho = remote.get_function("testing.echo")
+ assert(fecho(1, 2, 3) == 1)
+ assert(fecho(100, 2, 3) == 100)
+ assert(fecho("xyz") == "xyz")
+ assert(bytes(fecho(bytearray(b"123"))) == b"123")
+
+ # run the generated library.
+ f1 = remote.system_lib()
ctx = remote.cpu(0)
- f = tvm.build(s, [A, B], target, name="myadd")
- path_obj = temp.relpath("dev_lib.bc")
- path_dso = temp.relpath("dev_lib.js")
- f.save(path_obj)
- emscripten.create_js(path_dso, path_obj, side_module=True)
- # Upload to suffix as dso so it can be loaded remotely
- remote.upload(path_dso, "dev_lib.dso")
- data = remote.download("dev_lib.dso")
- f1 = remote.load_module("dev_lib.dso")
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
- time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
+ # invoke the function
+ addone = f1.get_function("addone")
+ addone(a, b)
+
+ # time evaluator
+ time_f = f1.time_evaluator("addone", ctx, number=10)
+ time_f(a, b)
cost = time_f(a, b).mean
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
- check_remote()
-test_rpc_array()
+ check(remote)
+
+
+test_rpc()
diff --git a/web/tsconfig.json b/web/tsconfig.json
new file mode 100644
index 000000000000..3c20b3d20692
--- /dev/null
+++ b/web/tsconfig.json
@@ -0,0 +1,13 @@
+{
+ "compilerOptions": {
+ "module": "commonjs",
+ "target": "es6",
+ "outDir": "dist",
+ "rootDir": "src",
+ "declaration": true,
+ "sourceMap": true,
+ "strict": true,
+ },
+ "include": ["src"],
+ "exclude": ["node_modules"]
+}
diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js
deleted file mode 100644
index 86ef59cb73b1..000000000000
--- a/web/tvm_runtime.js
+++ /dev/null
@@ -1,1274 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/**
- * TVM Javascript web runtime library.
- *
- * @projectname tvm
- * @version 0.7.dev1
- */
-/* eslint no-unused-vars: "off" */
-/* eslint no-unexpected-multiline: "off" */
-/* eslint indent: "off" */
-/* eslint no-console: "off" */
-/**
- * TVM Runtime namespace.
- * Provide tvm_runtime.create to create a {@link tvm.TVMRuntime}.
- *
- * @namespace tvm_runtime
- */
-var tvm_runtime = tvm_runtime || {};
-
-/**
- * TVM root namespace.
- * The classes inside this namespace need to be constructed by factory functions.
- * Use {@link tvm_runtime}.create to get started.
- *
- * @namespace tvm
- */
-(function() {
- /**
- * TVMRuntime object for interacting with TVM runtime.
- * This object can be constructed using {@link tvm_runtime}.create
- *
- * @class
- * @memberof tvm
- */
- function TVMRuntime() {
- "use strict";
- var runtime_ref = this;
- // Utility function to throw error
- function throwError(message) {
- if (typeof runtime_ref.logger !== "undefined") {
- runtime_ref.logger(message);
- }
- if (typeof Error !== "undefined") {
- throw new Error(message);
- }
- throw message;
- }
- var Module = this.Module;
- var Runtime = this.Runtime;
- if (typeof Module === "undefined") {
- throwError("Emscripten Module is not available");
- }
- // constants
- var SIZEOF_POINTER = 4;
- var SIZEOF_SIZE_T = 4;
- var SIZEOF_FLOAT = 4;
- var SIZEOF_INT = 4;
- var SIZEOF_INT8 = 1;
- var SIZEOF_INT64 = 8;
- var SIZEOF_DOUBLE = 8;
- var SIZEOF_TYPE = 4;
- var SIZEOF_CTX = SIZEOF_INT + SIZEOF_INT;
- var SIZEOF_TVMVALUE = SIZEOF_DOUBLE;
- var ARRAY_OFFSET_DATA = 0;
- var ARRAY_OFFSET_CTX = ARRAY_OFFSET_DATA + SIZEOF_POINTER;
- var ARRAY_OFFSET_DEV_TYPE = ARRAY_OFFSET_CTX;
- var ARRAY_OFFSET_DEV_ID = ARRAY_OFFSET_CTX + SIZEOF_INT;
- var ARRAY_OFFSET_NDIM = ARRAY_OFFSET_CTX + SIZEOF_CTX;
- var ARRAY_OFFSET_DTYPE = ARRAY_OFFSET_NDIM + SIZEOF_INT;
- var ARRAY_OFFSET_DTYPE_CODE = ARRAY_OFFSET_DTYPE;
- var ARRAY_OFFSET_DTYPE_BITS = ARRAY_OFFSET_DTYPE_CODE + SIZEOF_INT8;
- var ARRAY_OFFSET_DTYPE_LANES = ARRAY_OFFSET_DTYPE_BITS + SIZEOF_INT8;
- var ARRAY_OFFSET_SHAPE = ARRAY_OFFSET_DTYPE + SIZEOF_TYPE;
- var ARRAY_OFFSET_STRIDES = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER;
- var ARRAY_OFFSET_BYTE_OFFSET = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER;
- // Type codes
- var kInt = 0;
- var kUInt = 1;
- var kFloat = 2;
- var kTVMOpaqueHandle = 3;
- var kNull = 4;
- var kTVMDataType = 5;
- var kTVMContext = 6;
- var kTVMDLTensorHandle = 7;
- var kTVMObjectHandle = 8;
- var kTVMModuleHandle = 9;
- var kTVMPackedFuncHandle = 10;
- var kTVMStr = 11;
- var kTVMBytes = 12;
- var kTVMObjectRValueRefArg = 14;
- //-----------------------------------------
- // TVM CWrap library
- // ----------------------------------------
- var TVMGetLastError = Module.cwrap(
- "TVMGetLastError",
- "string", // const char*
- []);
-
- var TVMAPISetLastError = Module.cwrap
- ("TVMAPISetLastError",
- null,
- ["string" // const char*
- ]);
-
- var TVMModImport = Module.cwrap
- ("TVMModImport",
- "number",
- ["number", // TVMModuleHandle mod
- "number" // TVMModuleHandle dep
- ]);
-
- var TVMModGetFunction = Module.cwrap
- ("TVMModGetFunction",
- "number",
- ["number", // TVMModuleHandle mod
- "string", // const char* func_name
- "number", // int query_imports
- "number" // TVMFunctionHandle *out
- ]);
-
- var TVMModFree = Module.cwrap
- ("TVMModFree",
- "number",
- ["number" // TVMModeHandle mod
- ]);
-
- var TVMFuncFree = Module.cwrap
- ("TVMFuncFree",
- "number",
- ["number" // TVMFunctionHandle func
- ]);
-
- var TVMFuncCall = Module.cwrap
- ("TVMFuncCall",
- "number",
- ["number", // TVMFunctionHandle func
- "number", // TVMValue* arg_values
- "number", // int* arg_tcodes
- "number", // int num_args
- "number", // int ret_val
- "number" // int ret_type_code
- ]);
-
- var TVMCFuncSetReturn = Module.cwrap
- ("TVMCFuncSetReturn",
- "number",
- ["number", // TVMRetValueHandle ret
- "number", // TVMValue* value
- "number", // int* type_code
- "number" // int num_ret
- ]);
-
- var TVMCbArgToReturn = Module.cwrap
- ("TVMCbArgToReturn",
- "number",
- ["number", // TVMValue* value
- "number" // int* code
- ]);
-
- var TVMFuncCreateFromCFunc = Module.cwrap
- ("TVMFuncCreateFromCFunc",
- "number",
- ["number", // TVMPackedCFunc func,
- "number", // void* resource_handle
- "number", // TVMPackedCFuncFinalizer fin
- "number" // TVMFunctionHandle *out
- ]);
-
- var TVMFuncRegisterGlobal = Module.cwrap
- ("TVMFuncRegisterGlobal",
- "number",
- ["string", // name
- "number", // TVMFunctionHandle f
- "number" // int override
- ]);
-
- var TVMFuncGetGlobal = Module.cwrap
- ("TVMFuncGetGlobal",
- "number",
- ["string", // const char* name
- "number" // TVMFunctionHandle* out
- ]);
-
- var TVMFuncListGlobalNames = Module.cwrap
- ("TVMFuncListGlobalNames",
- "number",
- ["number", // int* out_size
- "number" // const char*** out_array
- ]);
-
-
- var TVMArrayAlloc = Module.cwrap
- ("TVMArrayAlloc",
- "number",
- ["number", // const tvm_index_t* shape
- "number", // int ndim
- "number", // int dtype_code
- "number", // int dtype_bits
- "number", // int dtype_lanes
- "number", // int device_type
- "number", // int device_id
- "number" // int TVMArrayHandle* out
- ]);
-
- var TVMArrayFree = Module.cwrap
- ("TVMArrayFree",
- "number",
- ["number" // TVMArrayHandle handle
- ]);
-
- var TVMArrayCopyFromTo = Module.cwrap
- ("TVMArrayCopyFromTo",
- "number",
- ["number", // TVMArrayHandle from
- "number" // TVMArrayHandle to
- ]);
-
- var TVMArrayCopyFromBytes = Module.cwrap
- ("TVMArrayCopyFromBytes",
- "number",
- ["number", // TVMArrayHandle handle
- "number", // int data
- "number" // size_t nbytes
- ]);
-
- var TVMArrayCopyToBytes = Module.cwrap
- ("TVMArrayCopyToBytes",
- "number",
- ["number", // TVMArrayHandle handle
- "number", // int data
- "number" // size_t nbytes
- ]);
-
- var TVMModLoadFromFile = Module.cwrap
- ("TVMModLoadFromFile",
- "number",
- ["string", // const char* file_name
- "string", // const char* format
- "number" // TVMModuleHandle* out
- ])
-
- //-----------------------------------------
- // Static utility functions
- // ----------------------------------------
- this.assert = function(condition, message) {
- if (!condition) {
- message = message || "assert failed";
- throwError(message);
- }
- };
- /**
- * Logging function.
- * Override this to change logger behavior.
- *
- * @param {string} message
- */
- this.logger = function(message) {
- console.log(message);
- };
-
- function logging(message) {
- runtime_ref.logger(message);
- }
- // Override print error to logging
- Module.printErr = logging;
- var CHECK = this.assert;
-
- function TVM_CALL(ret) {
- if (ret != 0) {
- throwError(TVMGetLastError());
- }
- }
-
- function CInt64ArrayToJS(ptr, size) {
- var ret = [];
- for (var i = 0; i < size; ++i) {
- ret.push(Module.getValue(ptr + i * SIZEOF_INT64, "i64"));
- }
- return ret;
- }
-
- function CStringToJS(ptr) {
- var ret = [];
- var ch = 1;
- while (ch != 0) {
- ch = Module.getValue(ptr, "i8");
- if (ch != 0) {
- ret.push(String.fromCharCode(ch));
- }
- ++ptr;
- }
- return ret.join("");
- }
-
- function CBytesToJS(ptr) {
- var data = Module.getValue(ptr, "*");
- var size = Module.getValue(ptr + SIZEOF_POINTER, "i32");
- var ret = new Uint8Array(new ArrayBuffer(size));
- ret.set(new Uint8Array(Module.HEAPU8.buffer, data, size));
- return ret;
- }
-
- function StringToUint8Array(str) {
- var arr = new Uint8Array(str.length + 1);
- for(var i = 0; i < str.length; ++i) {
- arr[i] = str.charCodeAt(i);
- }
- arr[str.length] = 0;
- return arr;
- }
- //-----------------------------------------
- // Class declarations
- // ----------------------------------------
- function CBuffer(nbytes) {
- this.data = Module._malloc(nbytes);
- }
-
- function RefTVMValue() {
- this.data = Module._malloc(SIZEOF_TVMVALUE);
- }
-
- function TVMArgs(nargs) {
- this.nargs = nargs;
- this.value = Module._malloc(SIZEOF_TVMVALUE * nargs);
- this.tcode = Module._malloc(SIZEOF_INT * nargs);
- this.temp = [];
- }
-
- function TVMType(code, bits, lanes) {
- this.code = code;
- this.bits = bits;
- this.lanes = lanes;
- }
- /**
- * TVM device context.
- * @class
- * @memberof tvm
- */
- function TVMContext(device_type, device_id) {
- this.device_type = device_type;
- this.device_id = device_id;
- }
- /**
- * TVM n-dimensional array.
- *
- * Use {@link tvm.TVMRuntime}.empty to create an instance.
- * @class
- * @memberof tvm
- */
- function NDArray(handle) {
- this.handle = handle;
- this.ndim = Module.getValue(this.handle + ARRAY_OFFSET_NDIM, "i32");
- // shape
- var cshape = Module.getValue(this.handle + ARRAY_OFFSET_SHAPE, "*");
- this.shape = CInt64ArrayToJS(cshape, this.ndim);
- // dtype
- var code = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_CODE, "i8");
- var bits = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_BITS, "i8");
- var lanes = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_LANES, "i16");
- var dtype = new TVMType(code, bits, lanes);
- this.dtype = dtype;
- this.BYTES_PER_ELEMENT = (dtype.bits * dtype.lanes / 8);
- // ctx
- var device_type = Module.getValue(this.handle + ARRAY_OFFSET_DEV_TYPE, "i32");
- var device_id = Module.getValue(this.handle + ARRAY_OFFSET_DEV_ID, "i32");
- this.context = new TVMContext(device_type, device_id);
- // byte_offset
- this.byteOffset = Module.getValue(this.handle + ARRAY_OFFSET_BYTE_OFFSET, "i64");
- }
-
- function TVMFunction(handle) {
- this.handle = handle;
- }
- /**
- * Module container of TVM generated functions.
- *
- * @class
- * @memberof tvm
- */
- function TVMModule(handle) {
- this.handle = handle;
- }
- /**
- * A typed scalar constant.
- * This can be used to pass number as integer types to tvm function.
- * Use {@link tvm.TVMRuntime}.constant to create an instance.
- * @class
- * @memberof tvm
- */
- function TVMConstant(value, dtype) {
- this.value = value;
- this.dtype = dtype;
- }
- //-----------------------------------------
- // Private Functions
- // ----------------------------------------
- function getTVMType(dtype) {
- if (dtype instanceof TVMType) return dtype;
- if (typeof dtype == "string") {
- var pattern = dtype;
- var code, bits = 32, lanes = 1;
- if (pattern.substring(0, 5) == "float") {
- pattern = pattern.substring(5, pattern.length);
- code = kFloat;
- } else if (pattern.substring(0, 3) == "int") {
- pattern = pattern.substring(3, pattern.length);
- code = kInt;
- } else if (pattern.substring(0, 4) == "uint") {
- pattern = pattern.substring(4, pattern.length);
- code = kUInt;
- } else if (pattern.substring(0, 6) == "handle") {
- pattern = pattern.substring(5, pattern.length);
- code = kTVMOpaqueHandle;
- bits = 64;
- } else {
- throw throwError("Unknown dtype " + dtype);
- }
- var arr = pattern.split("x");
- if (arr.length >= 1) {
- var parsed = parseInt(arr[0]);
- if (parsed == arr[0]) {
- bits = parsed;
- }
- }
- if (arr.length >= 2) {
- lanes = parseInt(arr[1]);
- }
- return new TVMType(code, bits, lanes);
- } else {
- throw throwError("Unknown dtype " + dtype);
- }
- }
-
- function TVMRetValueToJS(vptr, tcode) {
- switch (tcode) {
- case kInt:
- case kUInt: return Module.getValue(vptr, "i64");
- case kFloat: return Module.getValue(vptr, "double");
- case kTVMPackedFuncHandle: return makeTVMFunction(Module.getValue(vptr, "*"));
- case kTVMModuleHandle: return new TVMModule(Module.getValue(vptr, "*"));
- case kNull: return null;
- case kTVMStr: return CStringToJS(Module.getValue(vptr, "*"));
- case kTVMBytes: return CBytesToJS(Module.getValue(vptr, "*"));
- default: throwError("Unsupported return type code=" + tcode);
- }
- }
-
- function makeTVMFunction(handle) {
- var func = new TVMFunction(handle);
- var ret = function () {
- // alloc
- var args = new TVMArgs(arguments.length);
- var rvalue = new RefTVMValue();
- var rtcode = new RefTVMValue();
- args.setArguments(arguments);
- TVM_CALL(TVMFuncCall(handle, args.value, args.tcode,
- args.nargs, rvalue.data, rtcode.data));
- var rv = TVMRetValueToJS(rvalue.data, rtcode.asInt());
- // release
- args.release();
- rvalue.release();
- rtcode.release();
- return rv;
- };
- var release = function() {
- func.release();
- };
- ret._tvm_function = func;
- ret.release = release;
- return ret;
- }
- //-----------------------------------------
- // Javascript PackedCallback System
- // ----------------------------------------
- var funcTable = [0];
- var freeFuncId = [];
-
- function invokeCallback(arg_value, arg_tcode, nargs, ret, handle) {
- var args = [];
- for (var i = 0; i < nargs; ++i) {
- var vptr = arg_value + i * SIZEOF_TVMVALUE;
- var tcodeptr = arg_tcode + i * SIZEOF_INT;
- var tcode = Module.getValue(tcodeptr, "i32");
- if (tcode == kTVMObjectHandle ||
- tcode == kTVMObjectRValueRefArg ||
- tcode == kTVMPackedFuncHandle ||
- tcode == kTVMModuleHandle) {
- TVM_CALL(TVMCbArgToReturn(vptr, tcodeptr));
- }
- tcode = Module.getValue(tcodeptr, "i32");
- args.push(TVMRetValueToJS(vptr, tcode));
- }
- var rv = funcTable[handle].apply(null, args);
- if (typeof rv !== "undefined") {
- // alloc
- var rarg = new TVMArgs(1);
- rarg.setArguments([rv]);
- TVM_CALL(TVMCFuncSetReturn(ret, rarg.value, rarg.tcode, 1));
- // release
- rarg.release();
- }
- return 0;
- }
- function freeCallback(handle) {
- funcTable[handle] = 0;
- freeFuncId.push(handle);
- }
- var fptrInvokeCallback = null;
- var fptrFreeCallback = null;
- if (typeof Runtime !== "undefined" &&
- typeof Runtime.addFunction !== "undefined") {
- fptrInvokeCallback = Runtime.addFunction(invokeCallback);
- fptrFreeCallback = Runtime.addFunction(freeCallback);
- }
- /**
- * Check if a function is TVM PackedFunc
- * @param {Function} f function to be checked.
- * @return {boolean} Whether f is PackedFunc
- */
- this.isPackedFunc = function(f) {
- return (typeof f == "function") && f.hasOwnProperty("_tvm_function");
- };
- var isPackedFunc = this.isPackedFunc;
- /**
- * Convert a javascript function to TVM function.
- * @param {Function} f javascript function.
- * @return {Function} The created TVMFunction.
- */
- this.convertFunc = function(f) {
- if (isPackedFunc(f)) return f;
- CHECK(fptrInvokeCallback !== null,
- "Emscripten Runtime addFunction is not available");
- var fid;
- if (freeFuncId.length != 0) {
- fid = freeFuncId.pop();
- } else {
- fid = funcTable.length;
- funcTable.push(0);
- }
- funcTable[fid] = f;
- // alloc
- var out = new RefTVMValue();
- TVM_CALL(TVMFuncCreateFromCFunc(
- fptrInvokeCallback, fid, fptrFreeCallback, out.data));
- var out_handle = out.asHandle();
- // release
- out.release();
- return makeTVMFunction(out_handle);
- };
- var convertFunc = this.convertFunc;
- //-----------------------------------------
- // Private Class declarations
- // ----------------------------------------
- CBuffer.prototype = {
- /**
- * Finalizer: resources from the object.
- */
- release : function() {
- if (this.data != 0) {
- Module._free(this.data);
- this.data = 0;
- }
- },
- };
- // RefTVMValue
- RefTVMValue.prototype = {
- /**
- * Finalizer: resources from the object.
- */
- release : function() {
- if (this.data != 0) {
- Module._free(this.data);
- this.data = 0;
- }
- },
- asInt : function() {
- return Module.getValue(this.data, "i32");
- },
- asInt64 : function() {
- return Module.getValue(this.data, "i64");
- },
- asDouble : function() {
- return Module.getValue(this.data, "double");
- },
- asHandle : function() {
- return Module.getValue(this.data, "*");
- }
- };
- // TVMArgs
- TVMArgs.prototype = {
- release : function() {
- if (this.value != 0) {
- Module._free(this.value);
- Module._free(this.tcode);
- this.value = 0;
- for (var i = 0; i< this.temp.length; ++i) {
- if (this.temp[i].release instanceof Function) {
- this.temp[i].release();
- }
- }
- }
- },
- setInt : function(index, value) {
- Module.setValue(this.tcode + index * SIZEOF_INT, kInt, "i32");
- Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "i64");
- },
- setDouble : function(index, value) {
- Module.setValue(this.tcode + index * SIZEOF_INT, kFloat, "i32");
- Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "double");
- },
- setHandle : function(index, value, tcode) {
- Module.setValue(this.tcode + index * SIZEOF_INT, tcode, "i32");
- Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "*");
- },
- setString : function(index, value) {
- var sdata = new CBuffer(value.length + 1);
- Module.HEAPU8.set(StringToUint8Array(value), sdata.data);
- this.temp.push(sdata);
- Module.setValue(this.tcode + index * SIZEOF_INT, kTVMStr, "i32");
- Module.setValue(this.value + index * SIZEOF_TVMVALUE, sdata.data, "*");
- },
- setBytes : function(index, value) {
- CHECK(value instanceof Uint8Array);
- var sdata = new CBuffer(value.length);
- var sheader = new CBuffer(SIZEOF_POINTER + SIZEOF_SIZE_T);
- Module.HEAPU8.set(new Uint8Array(value), sdata.data);
- Module.setValue(sheader.data, sdata.data, "*");
- Module.setValue(sheader.data + SIZEOF_POINTER, value.length, "i32");
- this.temp.push(sdata);
- this.temp.push(sheader);
- Module.setValue(this.tcode + index * SIZEOF_INT, kTVMBytes, "i32");
- Module.setValue(this.value + index * SIZEOF_TVMVALUE, sheader.data, "*");
- },
- setArguments : function(args) {
- for (var i = 0; i < args.length; ++i) {
- var v = args[i];
- var tp = typeof v;
- if (v instanceof NDArray) {
- this.setHandle(i, v.handle, kTVMDLTensorHandle);
- } else if (v instanceof TVMConstant) {
- var code = getTVMType(v.dtype).code;
- if (code == kInt || code == kUInt) {
- this.setInt(i, v.value);
- } else if (code == kFloat) {
- this.setDouble(i, v.value);
- } else {
- CHECK(code == kTVMOpaqueHandle);
- this.setHandle(i, v.value, kTVMOpaqueHandle);
- }
- } else if (tp == "number") {
- this.setDouble(i, v);
- } else if (tp == "function" && v.hasOwnProperty("_tvm_function")) {
- this.setString(i, v._tvm_function.handle, kTVMPackedFuncHandle);
- } else if (v === null) {
- this.setHandle(i, 0, kNull);
- } else if (tp == "string") {
- this.setString(i, v);
- } else if (v instanceof Uint8Array) {
- this.setBytes(i, v);
- } else if (v instanceof Function) {
- v = convertFunc(v);
- this.temp.push(v);
- this.setHandle(i, v._tvm_function.handle, kTVMPackedFuncHandle);
- } else if (v instanceof TVMModule) {
- this.setHandle(i, v.handle, kTVMModuleHandle);
- } else {
- throwError("Unsupported argument type " + tp);
- }
- }
- }
- };
- // TVMType
- var TYPE_CODE2STR = {
- 0 : "int",
- 1 : "uint",
- 2 : "float",
- 4 : "handle"
- };
-
- TVMType.prototype = {
- toString : function() {
- var ret = TYPE_CODE2STR[this.code] + this.bits.toString();
- if (this.lanes != 1) {
- return ret + "x" + this.lanes.toString();
- } else {
- return ret;
- }
- }
- };
- // TVMFunction
- TVMFunction.prototype = {
- release : function() {
- if (this.handle != 0) {
- TVM_CALL(TVMFuncFree(this.handle));
- this.handle = 0;
- }
- }
- };
- // TVMContext
- var CTX_MASK2STR = {
- 1 : "cpu",
- 2 : "gpu",
- 4 : "opencl",
- 7 : "vulkan",
- 8 : "metal",
- 9 : "vpi",
- 11 : "opengl",
- };
- var CTX_STR2MASK = {
- "cpu": 1,
- "gpu": 2,
- "cuda": 2,
- "cl": 4,
- "opencl": 4,
- "vulkan": 7,
- "metal": 8,
- "vpi": 9,
- "opengl": 11,
- };
- TVMContext.prototype = {
- toString : function() {
- return CTX_MASK2STR[this.device_type] + "(" + this.device_id.toString() + ")";
- }
- };
- //-----------------------------------------
- // Public Functions
- // ----------------------------------------
- /**
- * Construct a TVMContext given device type and id.
- *
- * @param {number} device_type, string or int, The device type.
- * @param {number} device_id, the device id.
- * @return {tvm.TVMContext} The created TVMContext
- */
- this.context = function(device_type, device_id) {
- if (typeof device_type == "string") {
- device_type = CTX_STR2MASK[device_type];
- }
- return new TVMContext(device_type, device_id);
- };
- var context = this.context;
- /**
- * Create empty ndarray with given shape.
- *
- * @param {Array.} shape The shape of the array.
- * @param {string} dtype The data type of the array, optional, default="float32"
- * @param {tvm.TVMContext} ctx The context of the array, optional, default=cpu(0).
- * @return {tvm.NDArray} The created ndarray.
- */
- this.empty = function(shape, dtype, ctx) {
- dtype = (typeof dtype !== "undefined") ? dtype: "float32";
- ctx = (typeof ctx !== "undefined") ? ctx : context("cpu", 0);
- shape = (typeof shape == "number") ? [shape] : shape;
- // alloc
- var cshape = Module._malloc(SIZEOF_INT64 * shape.length);
- var out = new RefTVMValue();
- for (var i = 0; i < shape.length; ++i) {
- Module.setValue(cshape + i * SIZEOF_INT64, shape[i], "i64");
- }
- dtype = getTVMType(dtype);
- TVM_CALL(TVMArrayAlloc(cshape, shape.length,
- dtype.code, dtype.bits, dtype.lanes,
- ctx.device_type, ctx.device_id,
- out.data));
- var out_handle = out.asHandle();
- // release
- Module._free(cshape);
- out.release();
- return new NDArray(out_handle);
- };
- /**
- * List all global function names in the TVM runtime.
- * @return {Array.} List of global function names.
- */
- this.listGlobalFuncNames = function() {
- // alloc
- var out_size = new RefTVMValue();
- var out_array = new RefTVMValue();
- TVM_CALL(TVMFuncListGlobalNames(out_size.data, out_array.data));
- var length = out_size.asInt();
- var base = out_array.asHandle();
- var names = [];
- for (var i = 0 ; i < length; ++i) {
- names.push(
- CStringToJS(Module.getValue(base + i * SIZEOF_POINTER, "*")));
- }
- // release
- out_size.release();
- out_array.release();
- return names;
- };
- var listGlobalFuncNames = this.listGlobalFuncNames;
- /**
- * Get a global function from TVM runtime.
- *
- * @param {string} The name of the function.
- * @return {Function} The corresponding function, null if function do not exist
- */
- this.getGlobalFunc = function (name) {
- // alloc
- var out = new RefTVMValue();
- TVM_CALL(TVMFuncGetGlobal(name, out.data));
- var out_handle = out.asHandle();
- // release
- out.release();
- if (out_handle != 0) {
- return makeTVMFunction(out_handle);
- } else {
- return null;
- }
- };
- var getGlobalFunc = this.getGlobalFunc;
- /**
- * Register function to be global function in tvm runtime.
- * @param {string} name The name of the function.
- * @param {Function} f function to be registered.
- * @param {boolean} override Whether overwrite function in existing registry.
- */
- this.registerFunc = function(name, f, override) {
- f = convertFunc(f);
- override = (typeof override !== "undefined") ? override: false;
- var ioverride = override ? 1 : 0;
- TVM_CALL(TVMFuncRegisterGlobal(name, f._tvm_function.handle, ioverride));
- };
- /**
- * Create a typed scalar constant.
- * This can be used to pass number as integer types to tvm function.
- *
- * @param {number} value The value of the data.
- * @param {string} dtype The data type.
- * @param {tvm.TVMConstant} The created typed scalar.
- */
- this.constant = function(value, dtype) {
- return new TVMConstant(value, dtype);
- };
- //-----------------------------------------
- // Wrap of TVM Functions.
- // ----------------------------------------
- var systemFunc = {};
- /**
- * Get system-wide library module singleton.5A
- * System lib is a global module that contains self register functions in startup.
- * @return {tvm.TVMModule} The system module singleton.
- */
- this.systemLib = function() {
- if (typeof systemFunc.fGetSystemLib === "undefined") {
- systemFunc.fGetSystemLib = getGlobalFunc("runtime.SystemLib");
- }
- return systemFunc.fGetSystemLib();
- };
-
- this.startRPCServer = function(url, key, counter) {
- if (typeof key === "undefined") {
- key = "";
- }
- if (typeof counter === "undefined") {
- counter = 1;
- }
- // Node js, import websocket
- var bkey = StringToUint8Array("server:" + key);
- bkey = bkey.slice(0, bkey.length - 1);
- var server_name = "WebSocketRPCServer[" + key + "]";
- var RPC_MAGIC = 0xff271;
- function checkEndian() {
- var a = new ArrayBuffer(4);
- var b = new Uint8Array(a);
- var c = new Uint32Array(a);
- b[0] = 0x11;
- b[1] = 0x22;
- b[2] = 0x33;
- b[3] = 0x44;
- CHECK(c[0] === 0x44332211, "Need little endian to work");
- }
- checkEndian();
- // start rpc
- function RPCServer(counter) {
- var socket;
- if (typeof module !== "undefined" && module.exports) {
- // WebSocket for nodejs
- const WebSocket = require("ws");
- socket = new WebSocket(url);
- } else {
- socket = new WebSocket(url);
- }
- var self = this;
- socket.binaryType = "arraybuffer";
- this.init = true;
- this.counter = counter;
-
- if (typeof systemFunc.fcreateServer === "undefined") {
- systemFunc.fcreateServer =
- getGlobalFunc("rpc.CreateEventDrivenServer");
- }
- if (systemFunc.fcreateServer == null) {
- throwError("RPCServer is not included in runtime");
- }
-
- var message_handler = systemFunc.fcreateServer(
- function(cbytes) {
- if (socket.readyState == 1) {
- socket.send(cbytes);
- return new TVMConstant(cbytes.length, "int32");
- } else {
- return new TVMConstant(0, "int32");
- }
- } , server_name, "%toinit");
-
- function on_open(event) {
- var intbuf = new Int32Array(1);
- intbuf[0] = RPC_MAGIC;
- socket.send(intbuf);
- intbuf[0] = bkey.length;
- socket.send(intbuf);
- socket.send(bkey);
- logging(server_name + " connected...");
- }
-
- function on_message(event) {
- if (self.init) {
- var msg = new Uint8Array(event.data);
- CHECK(msg.length >= 4, "Need message header to be bigger than 4");
- var magic = new Int32Array(event.data)[0];
-
- if (magic == RPC_MAGIC + 1) {
- throwError("key: " + key + " has already been used in proxy");
- } else if (magic == RPC_MAGIC + 2) {
- logging(server_name + ": RPCProxy do not have matching client key " + key);
- } else {
- CHECK(magic == RPC_MAGIC, url + "is not RPC Proxy");
- self.init = false;
- }
- logging(server_name + "init end...");
- if (msg.length > 4) {
- if (message_handler(
- new Uint8Array(event.data, 4, msg.length -4),
- new TVMConstant(3, "int32")) == 0) {
- socket.close();
- }
- }
- } else {
- if (message_handler(new Uint8Array(event.data),
- new TVMConstant(3, "int32")) == 0) {
- socket.close();
- }
- }
- }
- function on_close(event) {
- message_handler.release();
- logging(server_name + ": closed finish...");
- if (!self.init && self.counter != 0) {
- logging(server_name + ":reconnect to serve another request, session left=" + counter);
- // start a new server.
- new RPCServer(counter - 1);
- }
- }
- socket.addEventListener("open", on_open);
- socket.addEventListener("message", on_message);
- socket.addEventListener("close", on_close);
- }
- return new RPCServer(counter);
- };
-
- /**
- * Load a TVM module from a library file.
- * The file must be present in the Emscripten virtual file system.
- * For example, you can pass "--preload-file file" or "--preload-file dir/"
- * to "emcc" when compiling the TVM library, in order to populate files into
- * the file system.
- * For more detail, see:
- * https://kripken.github.io/emscripten-site/docs/porting/files/packaging_files
- * @param {string} file_name Path of the file to be loaded. The path refers
- * to the Emscripten virtual file system.
- * @param {string} format The format of the file.
- * @return {tvm.TVMModule} The loaded module.
- */
- this.loadModuleFromFile = function (file_name, format) {
- // alloc
- var out = new RefTVMValue();
- TVM_CALL(TVMModLoadFromFile(file_name, format, out.data));
- var out_handle = out.asHandle();
- // release
- out.release();
- if (out_handle != 0) {
- return new TVMModule(out_handle);
- } else {
- return null;
- }
- };
- var loadModuleFromFile = this.loadModuleFromFile;
-
- /**
- * Wrapper runtime module.
- * Wraps around set_input, load_params, run, and get_output.
- *
- * @class
- * @memberof tvm
- */
- function GraphModule(tvm_graph_module, ctx) {
- CHECK(tvm_graph_module instanceof TVMModule,
- "tvm_graph_module must be TVMModule");
- CHECK(ctx instanceof TVMContext, "ctx must be TVMContext");
-
- this.tvm_graph_module = tvm_graph_module;
- this.ctx = ctx;
- this._set_input = tvm_graph_module.getFunction("set_input");
- this._load_params = tvm_graph_module.getFunction("load_params");
- this._run = tvm_graph_module.getFunction("run");
- this._get_output = tvm_graph_module.getFunction("get_output");
- };
-
- GraphModule.prototype = {
- /**
- * Set input to graph module.
- *
- * @param {string} key The name of the input.
- * @param {NDArray} value The input value.
- */
- "set_input" : function(key, value) {
- CHECK(typeof key == "string", "key must be string");
- CHECK(value instanceof NDArray, "value must be NDArray");
- this._set_input(key, value);
- },
-
- /**
- * Load parameters from serialized byte array of parameter dict.
- *
- * @param {Uint8Array} params The serialized parameter dict.
- */
- "load_params" : function(params) {
- CHECK(params instanceof Uint8Array, "params must be Uint8Array");
- this._load_params(params);
- },
-
- /**
- * Load parameters from serialized base64 string of parameter dict.
- *
- * @param {string} base64_params The serialized parameter dict.
- */
- "load_base64_params" : function(base64_params) {
- CHECK(typeof base64_params == "string", "base64_params must be string");
- var decoded_string = atob(base64_params);
- var decoded_u8 = new Uint8Array(decoded_string.length);
- for (var i = 0; i < decoded_string.length; i++) {
- decoded_u8[i] = decoded_string[i].charCodeAt(0);
- }
- this.load_params(decoded_u8);
- },
-
- /**
- * Run forward execution of the graph.
- */
- "run" : function() {
- this._run();
- },
-
- /**
- * Get index-th output to out.
- *
- * @param {NDArray} out The output array container.
- * @return {NDArray} The output array container.
- */
- "get_output" : function(index, out) {
- CHECK(typeof index == "number", "index must be number");
- CHECK(out instanceof NDArray, "out must be NDArray");
- this._get_output(new TVMConstant(index, "int32"), out);
- return out;
- }
- };
-
- /**
- * Create a runtime executor module given a graph and a module.
- * @param {string} graph_json_str The Json string of the graph.
- * @param {TVMModule} libmod The TVM module.
- * @param {TVMContext} ctx The context to deploy the module.
- * @return {GraphModule} Runtime graph module for executing the graph.
- */
- this.createGraphRuntime = function(graph_json_str, libmod, ctx) {
- CHECK(typeof graph_json_str == "string", "graph_json_str must be string");
- CHECK(libmod instanceof TVMModule, "libmod must be TVMModule");
- CHECK(ctx instanceof TVMContext, "ctx must be TVMContext");
-
- var fcreate = getGlobalFunc("tvm.graph_runtime.create");
- CHECK(fcreate != null, "Cannot find tvm.graph_runtime.create");
-
- var tvm_graph_module = fcreate(graph_json_str, libmod,
- new TVMConstant(ctx.device_type, "int32"),
- new TVMConstant(ctx.device_id, "int32"));
-
- return new GraphModule(tvm_graph_module, ctx);
- };
-
- //-----------------------------------------
- // Class defintions
- // ----------------------------------------
- // NDArray.
- NDArray.prototype = {
- /**
- * Finalizer: resources from the object.
- */
- release : function() {
- if (this.handle != 0) {
- TVM_CALL(TVMArrayFree(this.handle));
- this.handle = 0;
- }
- },
- /**
- * Copy data from another NDArray or javascript array.
- * The number of elements must match.
- *
- * @param {Array} data The source data array.
- */
- copyFrom : function(data) {
- if (data instanceof NDArray) {
- TVM_CALL(TVMArrayCopyFromTo(data.handle, this.handle));
- } else {
- var size = this.shape.reduce(function(a, b) { return a * b; }, 1);
- if (data.length != size) {
- throwError("data size and shape mismatch data.length" + data.length + " vs " + size);
- }
- if (this.dtype == "float32") {
- data = Float32Array.from(data);
- } else if (this.dtype == "float64") {
- data = Float64Array.from(data);
- } else if (this.dtype == "int32") {
- data = Int32Array.from(data);
- } else if (this.dtype == "int8") {
- data = Int8Array.from(data);
- } else if (this.dtype == "uint8") {
- data = Uint8Array.from(data);
- } else {
- throwError("Unsupported data type " + this.dtype);
- }
- return this.copyFromRawBytes(new Uint8Array(data.buffer));
- }
- },
- /**
- * Copy data from raw bytes.
- * @param {Uint8Array} data Uint8Array of bytes.
- */
- copyFromRawBytes : function(data) {
- var size = this.shape.reduce(function(a, b) { return a * b; }, 1);
- var dtype = getTVMType(this.dtype);
- var nbytes = this.BYTES_PER_ELEMENT * size;
- CHECK(data instanceof Uint8Array);
- CHECK(data.length == nbytes,
- "Data length and bytes do not match " + data.length +
- " vs " + nbytes);
- var temp = Module._malloc(nbytes);
- Module.HEAPU8.set(data, temp);
- TVM_CALL(TVMArrayCopyFromBytes(this.handle, temp, nbytes));
- Module._free(temp);
- return this;
- },
- /**
- * Return a copied Uint8Array of the raw bytes in the NDArray.
- * @return {Uint8Array} The created array.
- */
- asRawBytes : function() {
- var size = this.shape.reduce(function(a, b) { return a * b; }, 1);
- var nbytes = this.BYTES_PER_ELEMENT * size;
- var temp = Module._malloc(nbytes);
- TVM_CALL(TVMArrayCopyToBytes(this.handle, temp, nbytes));
- var ret = new Uint8Array(new ArrayBuffer(nbytes));
- ret.set(new Uint8Array(Module.HEAPU8.buffer, temp, nbytes));
- Module._free(temp);
- return ret;
- },
- /**
- * Return Array data content as javascript typed array.
- * @return {TypedArray} The created array.
- */
- asArray : function() {
- if (this.dtype == "float32") {
- return new Float32Array(this.asRawBytes().buffer);
- } else if (this.dtype == "float64") {
- return new Float64Array(this.asRawBytes().buffer);
- } else if (this.dtype == "int32") {
- return new Int32Array(this.asRawBytes().buffer);
- } else if (this.dtype == "int8") {
- return new Int8Array(this.asRawBytes().buffer);
- } else if (this.dtype == "uint8") {
- return new Uint8Array(this.asRawBytes().buffer);
- } else {
- throwError("Unsupported data type " + this.dtype);
- }
- }
- };
-
- TVMModule.prototype = {
- /**
- * Finalizer: resources from the object.
- */
- release : function() {
- if (this.handle != 0) {
- TVM_CALL(TVMModFree(this.handle));
- this.handle = 0;
- }
- },
- /**
- * Get function from the module.
- * @param {string} name The name of the function.
- * @return {Function} The correspondin function.
- */
- getFunction : function(name) {
- // alloc
- var out = new RefTVMValue();
- TVM_CALL(TVMModGetFunction(this.handle, name, 0, out.data));
- var out_handle = out.asHandle();
- // release
- out.release();
- if (out_handle == 0) {
- throwError("Module has no function " + name);
- }
- return makeTVMFunction(out_handle);
- },
- /**
- * Add module to the import list of current one.
- * @param {tvm.TVMModule} mod The other module to be imported.
- */
- import_module : function(mod) {
- CHECK(mod instanceof TVMModule, "mod must be instance of TVMModule");
- TVM_CALL(TVMModImport(this.handle, mod.handle));
- }
- };
- //-----------------------------------------
- // Static variables.
- // ----------------------------------------
- /** Float32 type */
- this.float32 = "float32";
- /** Int32 type */
- this.int32 = "int32";
- }
- /**
- * Create a TVM runtime given emscripten module.
- * @property {string} create
- * @memberof tvm_runtime
- * @param Module The emscripten module.
- * @return {tvm.TVMRuntime} The created TVM runtime.
- */
- this.create = function(Module) {
- var tvm = {};
- tvm.Module = Module;
- if (typeof Module.addFunction !== "undefined") {
- tvm.Runtime = Module;
- } else {
- tvm.Runtime = Module.Runtime;
- }
- TVMRuntime.apply(tvm);
- return tvm;
- };
-}).apply(tvm_runtime);
-
-// export things in node
-if (typeof module !== "undefined" && module.exports) {
- module.exports = tvm_runtime;
-}
diff --git a/web/web_runtime.cc b/web/web_runtime.cc
deleted file mode 100644
index 701ded76288e..000000000000
--- a/web/web_runtime.cc
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file web_runtime.cc
- */
-#include
-#include
-
-#include "../src/runtime/c_runtime_api.cc"
-#include "../src/runtime/cpu_device_api.cc"
-#include "../src/runtime/workspace_pool.cc"
-#include "../src/runtime/library_module.cc"
-#include "../src/runtime/system_library.cc"
-#include "../src/runtime/module.cc"
-#include "../src/runtime/ndarray.cc"
-#include "../src/runtime/object.cc"
-#include "../src/runtime/registry.cc"
-#include "../src/runtime/file_util.cc"
-#include "../src/runtime/dso_library.cc"
-#include "../src/runtime/rpc/rpc_session.cc"
-#include "../src/runtime/rpc/rpc_event_impl.cc"
-#include "../src/runtime/rpc/rpc_server_env.cc"
-#include "../src/runtime/graph/graph_runtime.cc"
-#include "../src/runtime/opengl/opengl_device_api.cc"
-#include "../src/runtime/opengl/opengl_module.cc"
-
-namespace tvm {
-namespace contrib {
-
-struct RPCEnv {
- public:
- RPCEnv() {
- base_ = "/rpc";
- mkdir(&base_[0], 0777);
- }
- // Get Path.
- std::string GetPath(const std::string& file_name) {
- return base_ + "/" + file_name;
- }
-
- private:
- std::string base_;
-};
-
-TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
-.set_body_typed([](std::string path) {
- static RPCEnv env;
- return env.GetPath(path);
- });
-
-TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
-.set_body_typed([](std::string path) {
- std::string file_name = "/rpc/" + path;
- LOG(INFO) << "Load module from " << file_name << " ...";
- return Module::LoadFromFile(file_name, "");
- });
-} // namespace contrib
-} // namespace tvm
-
-// dummy parallel runtime
-int TVMBackendParallelLaunch(
- FTVMParallelLambda flambda,
- void* cdata,
- int num_task) {
- TVMAPISetLastError("Parallel is not supported in Web runtime");
- return -1;
-}
-
-int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
- return 0;
-}