Skip to content

Commit

Permalink
[tvm4j] add GraphRuntime (#1472)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu authored and tqchen committed Jul 24, 2018
1 parent b963cf0 commit 5643846
Show file tree
Hide file tree
Showing 17 changed files with 601 additions and 42 deletions.
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/ml/dmlc/tvm/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ private static Function getGlobalFunc(String name, boolean isResident, boolean a
/**
* Release the Function.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* We highly recommend you to do this manually since the GC strategy is lazy.
* </p>
*/
@Override public void release() {
Expand Down Expand Up @@ -269,6 +268,7 @@ private static void pushArgToStack(Object arg) {
case BYTES:
Base._LIB.tvmFuncPushArgBytes(tvmArg.asBytes());
break;
case HANDLE:
case ARRAY_HANDLE:
case MODULE_HANDLE:
case FUNC_HANDLE:
Expand Down
10 changes: 8 additions & 2 deletions jvm/core/src/main/java/ml/dmlc/tvm/Module.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ private static Function getApi(String name) {
/**
* Release the Module.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* We highly recommend you to do this manually since the GC strategy is lazy.
* </p>
*/
@Override public void release() {
Expand Down Expand Up @@ -122,6 +121,13 @@ public void importModule(Module module) {
Base.checkCall(Base._LIB.tvmModImport(handle, module.handle));
}

/**
* @return type key of the module.
*/
public String typeKey() {
return getApi("_GetTypeKey").pushArg(this).invoke().asString();
}

/**
* Load module from file.
* @param path The path to the module file.
Expand Down
14 changes: 12 additions & 2 deletions jvm/core/src/main/java/ml/dmlc/tvm/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
*/
public class NDArray extends NDArrayBase {
private final TVMType dtype;
private final TVMContext context;

NDArray(long handle, boolean isView, TVMType dtype) {
NDArray(long handle, boolean isView, TVMType dtype, TVMContext ctx) {
super(handle, isView);
this.dtype = dtype;
this.context = ctx;
}

@Override protected void finalize() throws Throwable {
Expand Down Expand Up @@ -361,6 +363,14 @@ private byte[][] groupInternalBytes() {
return units;
}

/**
* Get the context of current array.
* @return the context.
*/
public TVMContext ctx() {
return context;
}

/**
* Create an empty array given shape, type and device.
* @param shape The shape of the array.
Expand All @@ -373,7 +383,7 @@ public static NDArray empty(long[] shape, TVMType dtype, TVMContext ctx) {
Base.checkCall(Base._LIB.tvmArrayAlloc(
shape, dtype.typeCode, dtype.bits, dtype.lanes,
ctx.deviceType, ctx.deviceId, refHandle));
return new NDArray(refHandle.value, false, dtype);
return new NDArray(refHandle.value, false, dtype, ctx);
}

/**
Expand Down
3 changes: 1 addition & 2 deletions jvm/core/src/main/java/ml/dmlc/tvm/NDArrayBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ public NDArrayBase copyTo(NDArrayBase target) {
/**
* Release the NDArray memory.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* We highly recommend you to do this manually since the GC strategy is lazy.
* </p>
*/
public void release() {
Expand Down
16 changes: 8 additions & 8 deletions jvm/core/src/main/java/ml/dmlc/tvm/TVMType.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ public TVMType(String typeStr, int lanes) {
this.lanes = lanes;
int bitsTemp = 0;
if (typeStr.startsWith("int")) {
typeCode = 0;
typeCode = INT;
bitsTemp = Integer.parseInt(typeStr.substring(3));
} else if (typeStr.startsWith("uint")) {
typeCode = 1;
typeCode = UINT;
bitsTemp = Integer.parseInt(typeStr.substring(4));
} else if (typeStr.startsWith("float")) {
typeCode = 2;
typeCode = FLOAT;
bitsTemp = Integer.parseInt(typeStr.substring(5));
} else if (typeStr.startsWith("handle")) {
typeCode = 4;
typeCode = HANDLE;
bitsTemp = 64;
} else {
throw new IllegalArgumentException("Do not know how to handle type " + typeStr);
Expand Down Expand Up @@ -78,16 +78,16 @@ public TVMType(String typeStr) {
@Override public String toString() {
String typeCodeStr;
switch (typeCode) {
case 0:
case INT:
typeCodeStr = "int";
break;
case 1:
case UINT:
typeCodeStr = "uint";
break;
case 2:
case FLOAT:
typeCodeStr = "float";
break;
case 4:
case HANDLE:
typeCodeStr = "handle";
break;
default:
Expand Down
34 changes: 34 additions & 0 deletions jvm/core/src/main/java/ml/dmlc/tvm/TVMValueHandle.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ml.dmlc.tvm;

/**
* Java class related to TVM handles (TypeCode.HANDLE)
*/
public class TVMValueHandle extends TVMValue {
public final long value;

public TVMValueHandle(long value) {
super(TypeCode.HANDLE);
this.value = value;
}

@Override public long asHandle() {
return value;
}
}
170 changes: 170 additions & 0 deletions jvm/core/src/main/java/ml/dmlc/tvm/contrib/GraphModule.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package ml.dmlc.tvm.contrib;

import ml.dmlc.tvm.Function;
import ml.dmlc.tvm.Module;
import ml.dmlc.tvm.NDArray;
import ml.dmlc.tvm.TVMContext;

/**
* Wrapper runtime module.
* This is a thin wrapper of the underlying TVM module.
* you can also directly call set_input, run, and get_output
* of underlying module functions.
*/
public class GraphModule {
private Module module;
private TVMContext ctx;

private Function fsetInput;
private Function frun;
private Function fgetOutput;
private Function fgetInput;
private Function fdebugGetOutput;
private Function floadParams;

GraphModule(Module module, TVMContext ctx) {
this.module = module;
this.ctx = ctx;
fsetInput = module.getFunction("set_input");
frun = module.getFunction("run");
fgetInput = module.getFunction("get_input");
fgetOutput = module.getFunction("get_output");
try {
fdebugGetOutput = module.getFunction("debug_get_output");
} catch (IllegalArgumentException ignored) {
// ignore
}
floadParams = module.getFunction("load_params");
}

/**
* Release the GraphModule.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy.
* </p>
*/
public void release() {
fsetInput.release();
frun.release();
fgetInput.release();
fgetOutput.release();
if (fdebugGetOutput != null) {
fdebugGetOutput.release();
}
floadParams.release();
module.release();
}

/**
* Set inputs to the module.
* @param key The input key.
* @param value The input value
* @return self.
*/
public GraphModule setInput(String key, NDArray value) {
NDArray input = value;
if (!value.ctx().equals(ctx)) {
input = NDArray.empty(value.shape(), ctx);
value.copyTo(input);
}
fsetInput.pushArg(key).pushArg(input).invoke();
return this;
}

/**
* Set inputs to the module
* @param key The input key.
* @param value The input value.
* @return self.
*/
public GraphModule setInput(int key, NDArray value) {
NDArray input = value;
if (!value.ctx().equals(ctx)) {
input = NDArray.empty(value.shape(), ctx);
value.copyTo(input);
}
fsetInput.pushArg(key).pushArg(input).invoke();
return this;
}

/**
* Run forward execution of the graph.
* @return self.
*/
public GraphModule run() {
frun.invoke();
return this;
}

/**
* Get index-th input to out.
* @param index The input index.
* @param out The output array container.
* @return out.
*/
public NDArray getInput(int index, NDArray out) {
fgetInput.pushArg(index).pushArg(out).invoke();
return out;
}

/**
* Get index-th output to out.
* @param index The output index.
* @param out The output array container.
* @return out.
*/
public NDArray getOutput(int index, NDArray out) {
fgetOutput.pushArg(index).pushArg(out).invoke();
return out;
}

/**
* Run graph up to node and get the output to out.
* @param node The node name.
* @param out The output array container.
* @return out.
*/
public NDArray debugGetOutput(String node, NDArray out) {
if (fdebugGetOutput != null) {
fdebugGetOutput.pushArg(node).pushArg(out).invoke();
} else {
throw new RuntimeException("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0");
}
return out;
}

/**
* Run graph up to node and get the output to out.
* @param node The node index.
* @param out The output array container.
* @return out.
*/
public NDArray debugGetOutput(int node, NDArray out) {
if (fdebugGetOutput != null) {
fdebugGetOutput.pushArg(node).pushArg(out).invoke();
} else {
throw new RuntimeException("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0");
}
return out;
}

/**
* Load parameters from serialized byte array of parameter dict.
* @param params The serialized parameter.
* @return self.
*/
public GraphModule loadParams(byte[] params) {
floadParams.pushArg(params).invoke();
return this;
}

/**
* Get internal module function.
* @param key The key to the module.
* @return The function.
* @throws IllegalArgumentException if function does not exist.
*/
public Function getFunction(String key) {
return module.getFunction(key);
}
}
Loading

0 comments on commit 5643846

Please sign in to comment.