From 5643846eeadd38b3bd0c1d110ca6db593a441a27 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 24 Jul 2018 14:45:18 -0700 Subject: [PATCH] [tvm4j] add GraphRuntime (#1472) --- .../src/main/java/ml/dmlc/tvm/Function.java | 4 +- .../src/main/java/ml/dmlc/tvm/Module.java | 10 +- .../src/main/java/ml/dmlc/tvm/NDArray.java | 14 +- .../main/java/ml/dmlc/tvm/NDArrayBase.java | 3 +- .../src/main/java/ml/dmlc/tvm/TVMType.java | 16 +- .../main/java/ml/dmlc/tvm/TVMValueHandle.java | 34 ++++ .../java/ml/dmlc/tvm/contrib/GraphModule.java | 170 ++++++++++++++++++ .../ml/dmlc/tvm/contrib/GraphRuntime.java | 121 +++++++++++++ .../src/main/java/ml/dmlc/tvm/rpc/RPC.java | 5 + .../main/java/ml/dmlc/tvm/rpc/RPCSession.java | 4 +- .../ml/dmlc/tvm/rpc/TVMRemoteContext.java | 30 ++++ .../src/test/java/ml/dmlc/tvm/TestUtils.java | 26 +++ .../ml/dmlc/tvm/contrib/GraphRuntimeTest.java | 114 ++++++++++++ .../test/java/ml/dmlc/tvm/rpc/RPCTest.java | 34 ++-- .../src/test/scripts/test_graph_runtime.py | 47 +++++ jvm/native/src/main/native/jni_helper_func.h | 10 ++ tests/scripts/task_java_unittest.sh | 1 + 17 files changed, 601 insertions(+), 42 deletions(-) create mode 100644 jvm/core/src/main/java/ml/dmlc/tvm/TVMValueHandle.java create mode 100644 jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphModule.java create mode 100644 jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphRuntime.java create mode 100644 jvm/core/src/main/java/ml/dmlc/tvm/rpc/TVMRemoteContext.java create mode 100644 jvm/core/src/test/java/ml/dmlc/tvm/TestUtils.java create mode 100644 jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java create mode 100644 jvm/core/src/test/scripts/test_graph_runtime.py diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/Function.java b/jvm/core/src/main/java/ml/dmlc/tvm/Function.java index 63602f3a14d0..5b2008a757ed 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/Function.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/Function.java @@ -109,8 +109,7 @@ private static Function getGlobalFunc(String name, boolean isResident, boolean a /** * Release the Function. *

- * We highly recommend you to do this manually since the GC strategy is lazy - * and `finalize()` is not guaranteed to be called when GC happens. + * We highly recommend you to do this manually since the GC strategy is lazy. *

*/ @Override public void release() { @@ -269,6 +268,7 @@ private static void pushArgToStack(Object arg) { case BYTES: Base._LIB.tvmFuncPushArgBytes(tvmArg.asBytes()); break; + case HANDLE: case ARRAY_HANDLE: case MODULE_HANDLE: case FUNC_HANDLE: diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/Module.java b/jvm/core/src/main/java/ml/dmlc/tvm/Module.java index 6aa417e889f5..7c55add36639 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/Module.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/Module.java @@ -72,8 +72,7 @@ private static Function getApi(String name) { /** * Release the Module. *

- * We highly recommend you to do this manually since the GC strategy is lazy - * and `finalize()` is not guaranteed to be called when GC happens. + * We highly recommend you to do this manually since the GC strategy is lazy. *

*/ @Override public void release() { @@ -122,6 +121,13 @@ public void importModule(Module module) { Base.checkCall(Base._LIB.tvmModImport(handle, module.handle)); } + /** + * @return type key of the module. + */ + public String typeKey() { + return getApi("_GetTypeKey").pushArg(this).invoke().asString(); + } + /** * Load module from file. * @param path The path to the module file. diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/NDArray.java b/jvm/core/src/main/java/ml/dmlc/tvm/NDArray.java index 1aea1a35c96d..431924c4c9b0 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/NDArray.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/NDArray.java @@ -27,10 +27,12 @@ */ public class NDArray extends NDArrayBase { private final TVMType dtype; + private final TVMContext context; - NDArray(long handle, boolean isView, TVMType dtype) { + NDArray(long handle, boolean isView, TVMType dtype, TVMContext ctx) { super(handle, isView); this.dtype = dtype; + this.context = ctx; } @Override protected void finalize() throws Throwable { @@ -361,6 +363,14 @@ private byte[][] groupInternalBytes() { return units; } + /** + * Get the context of current array. + * @return the context. + */ + public TVMContext ctx() { + return context; + } + /** * Create an empty array given shape, type and device. * @param shape The shape of the array. @@ -373,7 +383,7 @@ public static NDArray empty(long[] shape, TVMType dtype, TVMContext ctx) { Base.checkCall(Base._LIB.tvmArrayAlloc( shape, dtype.typeCode, dtype.bits, dtype.lanes, ctx.deviceType, ctx.deviceId, refHandle)); - return new NDArray(refHandle.value, false, dtype); + return new NDArray(refHandle.value, false, dtype, ctx); } /** diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/NDArrayBase.java b/jvm/core/src/main/java/ml/dmlc/tvm/NDArrayBase.java index d15caa79384f..11c77207fd1c 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/NDArrayBase.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/NDArrayBase.java @@ -57,8 +57,7 @@ public NDArrayBase copyTo(NDArrayBase target) { /** * Release the NDArray memory. *

- * We highly recommend you to do this manually since the GC strategy is lazy - * and `finalize()` is not guaranteed to be called when GC happens. + * We highly recommend you to do this manually since the GC strategy is lazy. *

*/ public void release() { diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/TVMType.java b/jvm/core/src/main/java/ml/dmlc/tvm/TVMType.java index 86d6efbb908b..e6b5e2ea1b9c 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/TVMType.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/TVMType.java @@ -37,16 +37,16 @@ public TVMType(String typeStr, int lanes) { this.lanes = lanes; int bitsTemp = 0; if (typeStr.startsWith("int")) { - typeCode = 0; + typeCode = INT; bitsTemp = Integer.parseInt(typeStr.substring(3)); } else if (typeStr.startsWith("uint")) { - typeCode = 1; + typeCode = UINT; bitsTemp = Integer.parseInt(typeStr.substring(4)); } else if (typeStr.startsWith("float")) { - typeCode = 2; + typeCode = FLOAT; bitsTemp = Integer.parseInt(typeStr.substring(5)); } else if (typeStr.startsWith("handle")) { - typeCode = 4; + typeCode = HANDLE; bitsTemp = 64; } else { throw new IllegalArgumentException("Do not know how to handle type " + typeStr); @@ -78,16 +78,16 @@ public TVMType(String typeStr) { @Override public String toString() { String typeCodeStr; switch (typeCode) { - case 0: + case INT: typeCodeStr = "int"; break; - case 1: + case UINT: typeCodeStr = "uint"; break; - case 2: + case FLOAT: typeCodeStr = "float"; break; - case 4: + case HANDLE: typeCodeStr = "handle"; break; default: diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/TVMValueHandle.java b/jvm/core/src/main/java/ml/dmlc/tvm/TVMValueHandle.java new file mode 100644 index 000000000000..b4316b7e72f3 --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/TVMValueHandle.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm; + +/** + * Java class related to TVM handles (TypeCode.HANDLE) + */ +public class TVMValueHandle extends TVMValue { + public final long value; + + public TVMValueHandle(long value) { + super(TypeCode.HANDLE); + this.value = value; + } + + @Override public long asHandle() { + return value; + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphModule.java b/jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphModule.java new file mode 100644 index 000000000000..208006886cac --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphModule.java @@ -0,0 +1,170 @@ +package ml.dmlc.tvm.contrib; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.Module; +import ml.dmlc.tvm.NDArray; +import ml.dmlc.tvm.TVMContext; + +/** + * Wrapper runtime module. + * This is a thin wrapper of the underlying TVM module. + * you can also directly call set_input, run, and get_output + * of underlying module functions. + */ +public class GraphModule { + private Module module; + private TVMContext ctx; + + private Function fsetInput; + private Function frun; + private Function fgetOutput; + private Function fgetInput; + private Function fdebugGetOutput; + private Function floadParams; + + GraphModule(Module module, TVMContext ctx) { + this.module = module; + this.ctx = ctx; + fsetInput = module.getFunction("set_input"); + frun = module.getFunction("run"); + fgetInput = module.getFunction("get_input"); + fgetOutput = module.getFunction("get_output"); + try { + fdebugGetOutput = module.getFunction("debug_get_output"); + } catch (IllegalArgumentException ignored) { + // ignore + } + floadParams = module.getFunction("load_params"); + } + + /** + * Release the GraphModule. + *

+ * We highly recommend you to do this manually since the GC strategy is lazy. + *

+ */ + public void release() { + fsetInput.release(); + frun.release(); + fgetInput.release(); + fgetOutput.release(); + if (fdebugGetOutput != null) { + fdebugGetOutput.release(); + } + floadParams.release(); + module.release(); + } + + /** + * Set inputs to the module. + * @param key The input key. + * @param value The input value + * @return self. + */ + public GraphModule setInput(String key, NDArray value) { + NDArray input = value; + if (!value.ctx().equals(ctx)) { + input = NDArray.empty(value.shape(), ctx); + value.copyTo(input); + } + fsetInput.pushArg(key).pushArg(input).invoke(); + return this; + } + + /** + * Set inputs to the module + * @param key The input key. + * @param value The input value. + * @return self. + */ + public GraphModule setInput(int key, NDArray value) { + NDArray input = value; + if (!value.ctx().equals(ctx)) { + input = NDArray.empty(value.shape(), ctx); + value.copyTo(input); + } + fsetInput.pushArg(key).pushArg(input).invoke(); + return this; + } + + /** + * Run forward execution of the graph. + * @return self. + */ + public GraphModule run() { + frun.invoke(); + return this; + } + + /** + * Get index-th input to out. + * @param index The input index. + * @param out The output array container. + * @return out. + */ + public NDArray getInput(int index, NDArray out) { + fgetInput.pushArg(index).pushArg(out).invoke(); + return out; + } + + /** + * Get index-th output to out. + * @param index The output index. + * @param out The output array container. + * @return out. + */ + public NDArray getOutput(int index, NDArray out) { + fgetOutput.pushArg(index).pushArg(out).invoke(); + return out; + } + + /** + * Run graph up to node and get the output to out. + * @param node The node name. + * @param out The output array container. + * @return out. + */ + public NDArray debugGetOutput(String node, NDArray out) { + if (fdebugGetOutput != null) { + fdebugGetOutput.pushArg(node).pushArg(out).invoke(); + } else { + throw new RuntimeException("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0"); + } + return out; + } + + /** + * Run graph up to node and get the output to out. + * @param node The node index. + * @param out The output array container. + * @return out. + */ + public NDArray debugGetOutput(int node, NDArray out) { + if (fdebugGetOutput != null) { + fdebugGetOutput.pushArg(node).pushArg(out).invoke(); + } else { + throw new RuntimeException("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0"); + } + return out; + } + + /** + * Load parameters from serialized byte array of parameter dict. + * @param params The serialized parameter. + * @return self. + */ + public GraphModule loadParams(byte[] params) { + floadParams.pushArg(params).invoke(); + return this; + } + + /** + * Get internal module function. + * @param key The key to the module. + * @return The function. + * @throws IllegalArgumentException if function does not exist. + */ + public Function getFunction(String key) { + return module.getFunction(key); + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphRuntime.java b/jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphRuntime.java new file mode 100644 index 000000000000..edcde0cc65ec --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphRuntime.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.contrib; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.Module; +import ml.dmlc.tvm.TVMContext; +import ml.dmlc.tvm.TVMValue; +import ml.dmlc.tvm.rpc.RPC; +import ml.dmlc.tvm.rpc.RPCSession; +import ml.dmlc.tvm.rpc.TVMRemoteContext; + +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +public class GraphRuntime { + /** + * Create a runtime executor module given a graph and module. + * @param graphJson The graph deployed in json format output by nnvm graph. + * @param libmod The module of the corresponding function. + * @param ctx The local or remote context to deploy the module. + * @return Runtime graph module that can be used to execute the graph. + */ + public static GraphModule create(String graphJson, Module libmod, TVMContext ctx) { + Module graphModule = null; + if (ctx.deviceType >= RPC.RPC_SESS_MASK) { + if (!(ctx instanceof TVMRemoteContext)) { + throw new IllegalArgumentException( + "Looks like you are using remote context with no RPCSession bind." + + "Use session.context instead."); + } + RPCSession rpcSession = ((TVMRemoteContext) ctx).rpcSession; + // check arguments + if (!"rpc".equals(libmod.typeKey())) { + throw new IllegalArgumentException("libmod.typeKey != rpc"); + } + final int sessIndex = (int) ((Function) reflectionStaticCall( + RPC.class, "getApi", "_SessTableIndex")) + .pushArg(libmod).invoke().asLong(); + if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) { + throw new IllegalArgumentException(String.format( + "libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d", + sessIndex, reflectionGetField(rpcSession, "tblIndex"))); + } + + Function rpcModuleHandle = (Function) reflectionStaticCall( + RPC.class, "getApi","_ModuleHandle"); + if (rpcModuleHandle == null) { + throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle." + + "Did you compile tvm_runtime with the correct version?"); + } + + Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create"); + if (fcreate == null) { + throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create." + + "Did you compile tvm_runtime with correct version?"); + } + + TVMValue hmod = rpcModuleHandle.pushArg(libmod).invoke(); + graphModule = fcreate.call(graphJson, hmod, + ctx.deviceType % RPC.RPC_SESS_MASK, ctx.deviceId).asModule(); + } else { + Function fcreate = Function.getFunction("tvm.graph_runtime.create"); + if (fcreate == null) { + throw new RuntimeException("Cannot find global function tvm.graph_runtime.create." + + "Did you compile tvm_runtime with correct version?"); + } + graphModule = fcreate.pushArg(graphJson) + .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId) + .invoke().asModule(); + } + + return new GraphModule(graphModule, ctx); + } + + private static Object reflectionGetField(Object obj, String fieldName) { + try { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(obj); + } catch (NoSuchFieldException e) { + throw new RuntimeException(e); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private static Object reflectionStaticCall(Class clazz, String methodName, Object ... args) { + Class[] types = new Class[args.length]; + for (int i = 0; i < args.length; ++i) { + types[i] = args[i].getClass(); + } + try { + Method method = clazz.getDeclaredMethod(methodName, types); + method.setAccessible(true); + return method.invoke(null, args); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } catch (InvocationTargetException e) { + throw new RuntimeException(e); + } + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java index 757fc0df3265..ee763fc41e18 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java @@ -44,6 +44,11 @@ protected Map initialValue() { } }; + /** + * Get internal function starts with namespace tvm.rpc. + * @param name function name. + * @return the function, null if not exists. + */ static Function getApi(String name) { Function func = apiFuncs.get().get(name); if (func == null) { diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java index 59da849000e9..0eec9224a40c 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java @@ -60,7 +60,7 @@ public Function getFunction(String name) { public TVMContext context(String devType, int devId) { TVMContext ctx = new TVMContext(devType, devId); int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK; - return new TVMContext(ctx.deviceType + encode, devId); + return new TVMRemoteContext(ctx.deviceType + encode, devId, this); } /** @@ -80,7 +80,7 @@ public TVMContext context(String devType) { */ public TVMContext context(int devType, int devId) { int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK; - return new TVMContext(devType + encode, devId); + return new TVMRemoteContext(devType + encode, devId, this); } /** diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/TVMRemoteContext.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/TVMRemoteContext.java new file mode 100644 index 000000000000..8b4449aee44d --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/TVMRemoteContext.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.TVMContext; + +// always related to RPCSession. Cannot construct by users. +public class TVMRemoteContext extends TVMContext { + public final RPCSession rpcSession; + + TVMRemoteContext(int deviceType, int deviceId, RPCSession rpcSession) { + super(deviceType, deviceId); + this.rpcSession = rpcSession; + } +} diff --git a/jvm/core/src/test/java/ml/dmlc/tvm/TestUtils.java b/jvm/core/src/test/java/ml/dmlc/tvm/TestUtils.java new file mode 100644 index 000000000000..23e22779adae --- /dev/null +++ b/jvm/core/src/test/java/ml/dmlc/tvm/TestUtils.java @@ -0,0 +1,26 @@ +package ml.dmlc.tvm; + +import ml.dmlc.tvm.rpc.Server; + +import java.io.IOException; + +public class TestUtils { + public static class RefInt { + public int value; + } + + public static Server startServer(RefInt portRef) { + Server server = null; + int port = 9981; + for (int i = 0; i < 10; ++i) { + try { + server = new Server(port + i); + server.start(); + portRef.value = port + i; + return server; + } catch (IOException e) { + } + } + throw new RuntimeException("Cannot find an available port."); + } +} diff --git a/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java b/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java new file mode 100644 index 000000000000..d719eb6f61e7 --- /dev/null +++ b/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.contrib; + +import ml.dmlc.tvm.*; +import ml.dmlc.tvm.rpc.Client; +import ml.dmlc.tvm.rpc.RPCSession; +import ml.dmlc.tvm.rpc.Server; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.util.Scanner; + +import static org.junit.Assert.assertArrayEquals; + +public class GraphRuntimeTest { + private final Logger logger = LoggerFactory.getLogger(GraphRuntime.class); + private static String loadingDir; + + @BeforeClass + public static void beforeClass() { + loadingDir = System.getProperty("test.tempdir"); + } + + @Test + public void test_add_one_local() throws IOException { + Module libmod = Module.load(loadingDir + File.separator + "graph_addone_lib.so"); + String graphJson = new Scanner(new File( + loadingDir + File.separator + "graph_addone.json")) + .useDelimiter("\\Z").next(); + + TVMContext ctx = TVMContext.cpu(); + GraphModule graph = GraphRuntime.create(graphJson, libmod, ctx); + + long[] shape = new long[]{4}; + NDArray arr = NDArray.empty(shape, ctx); + arr.copyFrom(new float[]{1f, 2f, 3f, 4f}); + + NDArray out = NDArray.empty(shape, ctx); + + graph.setInput("x", arr).run(); + graph.getOutput(0, out); + + assertArrayEquals(new float[]{2f, 3f, 4f, 5f}, out.asFloatArray(), 1e-3f); + + arr.release(); + out.release(); + graph.release(); + } + + @Test + public void test_add_one_remote() throws IOException { + if (!Module.enabled("rpc")) { + logger.warn("RPC is not enabled. Skip."); + return; + } + + String libPath = loadingDir + File.separator + "graph_addone_lib.so"; + String graphJson = new Scanner(new File( + loadingDir + File.separator + "graph_addone.json")) + .useDelimiter("\\Z").next(); + + TestUtils.RefInt port = new TestUtils.RefInt(); + Server server = null; + try { + server = TestUtils.startServer(port); + RPCSession remote = Client.connect("localhost", port.value); + TVMContext ctx = remote.cpu(); + + remote.upload(new File(libPath)); + Module mlib = remote.loadModule("graph_addone_lib.so"); + + GraphModule graph = GraphRuntime.create(graphJson, mlib, ctx); + + long[] shape = new long[]{4}; + NDArray arr = NDArray.empty(shape, ctx); + arr.copyFrom(new float[]{1f, 2f, 3f, 4f}); + + NDArray out = NDArray.empty(shape, ctx); + + graph.setInput("x", arr).run(); + graph.getOutput(0, out); + + assertArrayEquals(new float[]{2f, 3f, 4f, 5f}, out.asFloatArray(), 1e-3f); + + arr.release(); + out.release(); + graph.release(); + } finally { + if (server != null) { + server.terminate(); + } + } + } +} diff --git a/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java b/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java index 982be0b8117d..63cf5575b37d 100644 --- a/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java +++ b/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java @@ -20,36 +20,21 @@ import ml.dmlc.tvm.Function; import ml.dmlc.tvm.Module; import ml.dmlc.tvm.TVMValue; +import ml.dmlc.tvm.TestUtils; import org.junit.Ignore; import org.junit.Test; - -import java.io.IOException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import static org.junit.Assert.assertEquals; public class RPCTest { - static class RefInt { - public int value; - } - - private static Server startServer(RefInt portRef) { - Server server = null; - int port = 9981; - for (int i = 0; i < 10; ++i) { - try { - server = new Server(port + i); - server.start(); - portRef.value = port + i; - return server; - } catch (IOException e) { - } - } - throw new RuntimeException("Cannot find an available port."); - } + private final Logger logger = LoggerFactory.getLogger(RPCTest.class); @Test public void test_addone() { if (!Module.enabled("rpc")) { + logger.warn("RPC is not enabled. Skip."); return; } Function.register("test.rpc.addone", new Function.Callback() { @@ -58,10 +43,10 @@ public void test_addone() { } }); - RefInt port = new RefInt(); + TestUtils.RefInt port = new TestUtils.RefInt(); Server server = null; try { - server = startServer(port); + server = TestUtils.startServer(port); RPCSession client = Client.connect("localhost", port.value); Function func = client.getFunction("test.rpc.addone"); assertEquals(11L, func.call(10).asLong()); @@ -75,6 +60,7 @@ public void test_addone() { @Test public void test_strcat() { if (!Module.enabled("rpc")) { + logger.warn("RPC is not enabled. Skip."); return; } Function.register("test.rpc.strcat", new Function.Callback() { @@ -83,10 +69,10 @@ public void test_strcat() { } }); - RefInt port = new RefInt(); + TestUtils.RefInt port = new TestUtils.RefInt(); Server server = null; try { - server = startServer(port); + server = TestUtils.startServer(port); RPCSession client = Client.connect("localhost", port.value); Function func = client.getFunction("test.rpc.strcat"); assertEquals("abc:11", func.call("abc", 11L).asString()); diff --git a/jvm/core/src/test/scripts/test_graph_runtime.py b/jvm/core/src/test/scripts/test_graph_runtime.py new file mode 100644 index 000000000000..a60736c2468d --- /dev/null +++ b/jvm/core/src/test/scripts/test_graph_runtime.py @@ -0,0 +1,47 @@ +import os + +import tvm +import json +from tvm.contrib import graph_runtime + +def dump_graph_lib(target_dir): + dim = 4 + A = tvm.placeholder((dim,), name='A') + B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') + sched = tvm.create_schedule(B.op) + + node0 = {"op": "null", "name": "x", "inputs": []} + node1 = {"op": "tvm_op", "name": "add", + "inputs": [[0, 0, 0]], + "attrs": {"func_name": "myadd", + "flatten_data": "1", + "num_inputs" : "1", + "num_outputs" : "1"}} + nodes = [node0, node1] + arg_nodes = [0] + node_row_ptr = [0, 1, 2] + outputs = [[1, 0, 0]] + shape = (4,) + attrs = { + "shape" : ["list_shape", [shape, shape]], + "dltype" : ["list_str", ["float32", "float32"]], + "storage_id" : ["list_int", [0, 1]], + } + graph = {"nodes": nodes, + "arg_nodes": arg_nodes, + "node_row_ptr": node_row_ptr, + "heads": outputs, + "attrs": attrs} + + graph = json.dumps(graph) + mlib = tvm.build(sched, [A, B], "llvm", name="myadd") + + mlib.export_library(os.path.join(target_dir, "graph_addone_lib.so")) + with open(os.path.join(target_dir, "graph_addone.json"), "w") as fo: + fo.write(graph) + +if __name__ == "__main__": + import sys + if len(sys.argv) != 2: + sys.exit(-1) + dump_graph_lib(sys.argv[1]) diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index dc04f4191d1a..d4435bdaaba8 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -72,6 +72,14 @@ jstring getTVMValueStringField(JNIEnv *env, jobject obj) { return ret; } +jobject newTVMValueHandle(JNIEnv *env, jlong value) { + jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueHandle"); + jmethodID constructor = env->GetMethodID(cls, "", "(J)V"); + jobject object = env->NewObject(cls, constructor, value); + env->DeleteLocalRef(cls); + return object; +} + jobject newTVMValueLong(JNIEnv *env, jlong value) { jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueLong"); jmethodID constructor = env->GetMethodID(cls, "", "(J)V"); @@ -166,6 +174,8 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { return newTVMValueLong(env, static_cast(value.v_int64)); case kDLFloat: return newTVMValueDouble(env, static_cast(value.v_float64)); + case kHandle: + return newTVMValueHandle(env, reinterpret_cast(value.v_handle)); case kModuleHandle: return newModule(env, reinterpret_cast(value.v_handle)); case kFuncHandle: diff --git a/tests/scripts/task_java_unittest.sh b/tests/scripts/task_java_unittest.sh index 8ae79b5c52b2..df85e496b226 100755 --- a/tests/scripts/task_java_unittest.sh +++ b/tests/scripts/task_java_unittest.sh @@ -8,6 +8,7 @@ TEMP_DIR=$(mktemp -d) python $SCRIPT_DIR/test_add_cpu.py $TEMP_DIR || exit -1 python $SCRIPT_DIR/test_add_gpu.py $TEMP_DIR || exit -1 +python $SCRIPT_DIR/test_graph_runtime.py $TEMP_DIR || exit -1 # start rpc proxy server PORT=$(( ( RANDOM % 1000 ) + 9000 ))