Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Change-Id: I7f44c815d0d9293c5493057c5255d247fbb98e18
  • Loading branch information
Qing Lan committed Mar 5, 2021
1 parent 21a5467 commit 1948ef9
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ public final class DlrEngine extends Engine {
public static final String ENGINE_NAME = "DLR";

private Engine alternativeEngine;
private boolean disableAlternative;

private DlrEngine() {}
private DlrEngine() {
disableAlternative =
Boolean.parseBoolean(System.getProperty("djl_dlr_disable_alternative", "false"));
}

static Engine newInstance() {
try {
Expand All @@ -47,6 +51,7 @@ static Engine newInstance() {
}

private Engine getAlternativeEngine() {
if (disableAlternative) return null;
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ public final class OrtEngine extends Engine {

private Engine alternativeEngine;
private OrtEnvironment env;
private boolean disableAlternative;

private OrtEngine() {
// init OrtRuntime
this.env = OrtEnvironment.getEnvironment();
disableAlternative =
Boolean.parseBoolean(System.getProperty("djl_onnx_disable_alternative", "false"));
}

static Engine newInstance() {
Expand All @@ -57,6 +60,7 @@ public int getRank() {
}

private Engine getAlternativeEngine() {
if (disableAlternative) return null;
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private OrtNDManager(NDManager parent, Device device, OrtEnvironment env) {
this.env = env;
}

public static OrtNDManager getSystemManager() {
static OrtNDManager getSystemManager() {
return SYSTEM_MANAGER;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public void testOrt() throws TranslateException, ModelException, IOException {
public void testStringTensor()
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {
System.setProperty("djl_onnx_disable_alternative", "true");
Criteria<NDList, NDList> criteria =
Criteria.builder()
.setTypes(NDList.class, NDList.class)
Expand All @@ -82,12 +83,13 @@ public void testStringTensor()
.build();
try (ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
Predictor<NDList, NDList> predictor = model.newPredictor()) {
OrtNDManager manager = (OrtNDManager) OrtNDManager.getSystemManager().newSubManager();
OrtNDManager manager = (OrtNDManager) model.getNDManager();
NDArray stringNd =
manager.create(
new String[] {" Re: Jack can't hide from keith@cco.", " I like dogs"},
new Shape(1, 2));
predictor.predict(new NDList(stringNd));
}
System.clearProperty("djl_onnx_disable_alternative");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ public final class PpEngine extends Engine {

private Engine alternativeEngine;
private String version;
private boolean disableAlternative;

private PpEngine() {
version = JniUtils.getVersion();
disableAlternative =
Boolean.parseBoolean(System.getProperty("djl_paddle_disable_alternative", "false"));
}

static Engine newInstance() {
Expand All @@ -57,6 +60,7 @@ public int getRank() {
}

Engine getAlternativeEngine() {
if (disableAlternative) return null;
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ public final class TfLiteEngine extends Engine {
public static final String ENGINE_NAME = "TFLite";

private Engine alternativeEngine;
private boolean disableAlternative;

private TfLiteEngine() {
LibUtils.loadLibrary();
disableAlternative =
Boolean.parseBoolean(System.getProperty("djl_tflite_disable_alternative", "false"));
}

static Engine newInstance() {
Expand All @@ -54,6 +57,7 @@ public int getRank() {
}

private Engine getAlternativeEngine() {
if (disableAlternative) return null;
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand Down

0 comments on commit 1948ef9

Please sign in to comment.