Skip to content

Commit

Permalink
[examples] include all engines in example project build.gradle
Browse files Browse the repository at this point in the history
Change-Id: I06de6be02f20bc1f09b9a2a2d067f5c25aba9c8f
  • Loading branch information
frankfliu committed Nov 4, 2021
1 parent 4a94cff commit 4244007
Show file tree
Hide file tree
Showing 46 changed files with 115 additions and 143 deletions.
6 changes: 3 additions & 3 deletions docs/load_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ You can use [ModelZoo.listModels()](https://javadoc.io/static/ai.djl/api/0.13.0/

#### List available models using DJL command line

Use the following command to list models in examples module for MXNet engine:
Use the following command to list models in the DJL model zoo:

```shell
./gradlew :examples:listmodels
Expand All @@ -200,8 +200,8 @@ Use the following command to list models in examples module for MXNet engine:

```

You can list models from your model folder and only list models for specific Engine with debug log:
You can list models from your model folder with debug log:

```shell
./gradlew :examples:listmodels -Dai.djl.default_engine=PyTorch -Dai.djl.logging.level=debug -Dai.djl.repository.zoo.location=file:///mymodels
./gradlew :examples:listmodels -Dai.djl.logging.level=debug -Dai.djl.repository.zoo.location=file:///mymodels
```
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public String processOutput(TranslatorContext ctx, NDList list) {
NDArray endLogits = output.get(1).reshape(new Shape(1, -1));
int startIdx = (int) startLogits.argMax(1).getLong();
int endIdx = (int) endLogits.argMax(1).getLong();
return tokens.subList(startIdx, endIdx + 1).toString();
return tokenizer.tokenToString(tokens.subList(startIdx, endIdx + 1));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public void testForward() throws IOException, ModelNotFoundException, MalformedM
@Test
public void trainWithNewParam()
throws IOException, ModelNotFoundException, MalformedModelException {
if ("MXNet".equals(Engine.getInstance().getEngineName())) {
if ("MXNet".equals(Engine.getDefaultEngineName())) {
// TODO: WARN The gradMeans (but not predictions or loss) changed during the upgrade
// to MXNet 1.8. The issue affect only CPU, but GPU has not changed.
TestRequirements.gpu();
Expand Down
10 changes: 3 additions & 7 deletions examples/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@ dependencies {
implementation project(":basicdataset")
implementation project(":model-zoo")

if (System.getProperty("ai.djl.default_engine") == "PyTorch") {
runtimeOnly project(":engines:pytorch:pytorch-model-zoo")
} else if (System.getProperty("ai.djl.default_engine") == "TensorFlow") {
runtimeOnly project(":engines:tensorflow:tensorflow-model-zoo")
} else {
runtimeOnly project(":engines:mxnet:mxnet-model-zoo")
}
runtimeOnly project(":engines:pytorch:pytorch-model-zoo")
runtimeOnly project(":engines:tensorflow:tensorflow-model-zoo")
runtimeOnly project(":engines:mxnet:mxnet-model-zoo")

testImplementation("org.testng:testng:${testng_version}") {
exclude group: "junit", module: "junit"
Expand Down
2 changes: 1 addition & 1 deletion examples/docs/biggan.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Use the following commands to run the project:

```
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.BigGAN -Dai.djl.default_engine=PyTorch
./gradlew run -Dmain=ai.djl.examples.inference.BigGAN
```

### Output
Expand Down
4 changes: 2 additions & 2 deletions examples/docs/face_detection.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ Use the following command to run the project:

```
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.face.RetinaFaceDetection -Dai.djl.default_engine=PyTorch
./gradlew run -Dmain=ai.djl.examples.inference.face.LightFaceDetection -Dai.djl.default_engine=PyTorch
./gradlew run -Dmain=ai.djl.examples.inference.face.RetinaFaceDetection
./gradlew run -Dmain=ai.djl.examples.inference.face.LightFaceDetection
```

Your output should look like the following:
Expand Down
4 changes: 2 additions & 2 deletions examples/docs/face_recognition.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Use the following command to run the project:

```
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.face.FeatureExtraction -Dai.djl.default_engine=PyTorch
./gradlew run -Dmain=ai.djl.examples.inference.face.FeatureExtraction
```

Your output should look like the following:
Expand All @@ -38,7 +38,7 @@ Your output should look like the following:

```
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.face.FeatureComparison -Dai.djl.default_engine=PyTorch
./gradlew run -Dmain=ai.djl.examples.inference.face.FeatureComparison
```

Your output should look like the following:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Use the following command to run the project:

```
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.ObjectDetectionWithTensorflowSavedModel -Dai.djl.default_engine=TensorFlow
./gradlew run -Dmain=ai.djl.examples.inference.ObjectDetectionWithTensorflowSavedModel
```

Your output should look like the following:
Expand Down
2 changes: 1 addition & 1 deletion examples/docs/sentiment_analysis.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ Follow [setup](../../docs/development/setup.md) to configure your development en

```
cd examples
./gradlew run -Dai.djl.default_engine=PyTorch -Dmain=ai.djl.examples.inference.SentimentAnalysis
./gradlew run -Dmain=ai.djl.examples.inference.SentimentAnalysis
```
2 changes: 1 addition & 1 deletion examples/docs/super_resolution.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Use the following commands to run the project:

```
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.sr.SuperResolution -Dai.djl.default_engine=TensorFlow
./gradlew run -Dmain=ai.djl.examples.inference.sr.SuperResolution
```

### Output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public static Classifications predict() throws IOException, ModelException, Tran
.setTypes(Image.class, Classifications.class)
.optFilter("backbone", "inceptionv3")
.optFilter("dataset", "ucf101")
.optEngine("MXNet")
.optProgress(new ProgressBar())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.DefaultVocabulary;
Expand Down Expand Up @@ -57,21 +56,14 @@ public static void main(String[] args) throws IOException, ModelException, Trans
inputs.add("class3\tDJL is good");

Classifications[] results = predict(inputs);
if (results == null) {
logger.info("This example only works for TensorFlow Engine");
} else {
for (int i = 0; i < inputs.size(); i++) {
logger.info("Prediction for: " + inputs.get(i) + "\n" + results[i].toString());
}
for (int i = 0; i < inputs.size(); i++) {
logger.info("Prediction for: " + inputs.get(i) + "\n" + results[i].toString());
}
}

public static Classifications[] predict(List<String> inputs)
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {
if (!"TensorFlow".equals(Engine.getInstance().getEngineName())) {
return null;
}
// refer to
// https://medium.com/delvify/bert-rest-inference-from-the-fine-tuned-model-499997b32851 and
// https://github.com/google-research/bert
Expand All @@ -84,6 +76,7 @@ public static Classifications[] predict(List<String> inputs)
.setTypes(String[].class, Classifications[].class)
.optModelUrls(modelUrl)
.optTranslator(new MyTranslator(vocabularyPath, 128))
.optEngine("TensorFlow")
.optProgress(new ProgressBar())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
Expand Down Expand Up @@ -68,6 +69,7 @@ public static String predict() throws IOException, TranslateException, ModelExce
.optApplication(Application.NLP.QUESTION_ANSWER)
.setTypes(QAInput.class, String.class)
.optFilter("backbone", "bert")
.optEngine(Engine.getDefaultEngineName())
.optProgress(new ProgressBar())
.build();

Expand Down
8 changes: 1 addition & 7 deletions examples/src/main/java/ai/djl/examples/inference/BigGAN.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
Expand All @@ -36,11 +35,6 @@ public final class BigGAN {
private BigGAN() {}

public static void main(String[] args) throws ModelException, TranslateException, IOException {
if (!"PyTorch".equals(Engine.getInstance().getEngineName())) {
logger.info("This example only works for PyTorch Engine");
return;
}

Image[] generatedImages = BigGAN.generate();
logger.info("Using PyTorch Engine. {} images generated.", generatedImages.length);
saveImages(generatedImages);
Expand All @@ -58,13 +52,13 @@ private static void saveImages(Image[] generatedImages) throws IOException {
}

public static Image[] generate() throws IOException, ModelException, TranslateException {

Criteria<int[], Image[]> criteria =
Criteria.builder()
.optApplication(Application.CV.IMAGE_GENERATION)
.setTypes(int[].class, Image[].class)
.optFilter("size", "256")
.optArgument("truncation", 0.4f)
.optEngine("PyTorch")
.optProgress(new ProgressBar())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
Image img = ImageFactory.getInstance().fromFile(imageFile);

String backbone;
if ("TensorFlow".equals(Engine.getInstance().getEngineName())) {
if ("TensorFlow".equals(Engine.getDefaultEngineName())) {
backbone = "mobilenet_v2";
} else {
backbone = "resnet50";
Expand All @@ -64,6 +64,7 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
.optApplication(Application.CV.OBJECT_DETECTION)
.setTypes(Image.class, DetectedObjects.class)
.optFilter("backbone", backbone)
.optEngine(Engine.getDefaultEngineName())
.optProgress(new ProgressBar())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
Expand Down Expand Up @@ -69,17 +68,11 @@ public final class ObjectDetectionWithTensorflowSavedModel {
private ObjectDetectionWithTensorflowSavedModel() {}

public static void main(String[] args) throws IOException, ModelException, TranslateException {
if (!"TensorFlow".equals(Engine.getInstance().getEngineName())) {
logger.info("This example only works for TensorFlow Engine");
return;
}

DetectedObjects detection = ObjectDetectionWithTensorflowSavedModel.predict();
logger.info("{}", detection);
}

public static DetectedObjects predict() throws IOException, ModelException, TranslateException {

Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
Image img = ImageFactory.getInstance().fromFile(imageFile);

Expand All @@ -94,6 +87,7 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
// saved_model.pb file is in the subfolder of the model archive file
.optModelName("ssd_mobilenet_v2_320x320_coco17_tpu-8/saved_model")
.optTranslator(new MyTranslator())
.optEngine("TensorFlow")
.optProgress(new ProgressBar())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -60,7 +61,7 @@ public static Joints predict() throws IOException, ModelException, TranslateExce

if (person == null) {
logger.warn("No person found in image.");
return null;
return new Joints(Collections.emptyList());
}

return predictJointsInPerson(person);
Expand All @@ -78,6 +79,7 @@ private static Image predictPersonInImage(Image img)
.optFilter("backbone", "resnet50")
.optFilter("flavor", "v1")
.optFilter("dataset", "voc")
.optEngine("MXNet")
.optProgress(new ProgressBar())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
Expand Down Expand Up @@ -46,11 +45,6 @@ public enum Artist {
}

public static void main(String[] args) throws IOException, ModelException, TranslateException {
if (!"PyTorch".equals(Engine.getInstance().getEngineName())) {
logger.info("This example only works for PyTorch Engine");
return;
}

Artist artist = Artist.MONET;
String imagePath = "src/test/resources/mountains.png";
Image input = ImageFactory.getInstance().fromFile(Paths.get(imagePath));
Expand All @@ -63,11 +57,6 @@ public static void main(String[] args) throws IOException, ModelException, Trans
public static Image transfer(Image image, Artist artist)
throws IOException, ModelNotFoundException, MalformedModelException,
TranslateException {

if (!"PyTorch".equals(Engine.getInstance().getEngineName())) {
return null;
}

String modelName = "style_" + artist.toString().toLowerCase() + ".zip";
String modelUrl =
"https://mlrepo.djl.ai/model/cv/image_generation/ai/djl/pytorch/cyclegan/0.0.1/"
Expand All @@ -80,6 +69,7 @@ public static Image transfer(Image image, Artist artist)
.optModelUrls(modelUrl)
.optProgress(new ProgressBar())
.optTranslatorFactory(new StyleTransferTranslatorFactory())
.optEngine("PyTorch")
.build();

try (ZooModel<Image, Image> model = criteria.loadModel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
Expand Down Expand Up @@ -53,23 +52,14 @@ public static void main(String[] args) throws IOException, ModelException, Trans
inputs.add("I am a sentence for which I would like to get its embedding");

float[][] embeddings = UniversalSentenceEncoder.predict(inputs);
if (embeddings == null) {
logger.info("This example only works for TensorFlow Engine");
} else {
for (int i = 0; i < inputs.size(); i++) {
logger.info(
"Embedding for: " + inputs.get(i) + "\n" + Arrays.toString(embeddings[i]));
}
for (int i = 0; i < inputs.size(); i++) {
logger.info("Embedding for: " + inputs.get(i) + "\n" + Arrays.toString(embeddings[i]));
}
}

public static float[][] predict(List<String> inputs)
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {
if (!"TensorFlow".equals(Engine.getInstance().getEngineName())) {
return null;
}

String modelUrl =
"https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder/4.tar.gz";

Expand All @@ -79,6 +69,7 @@ public static float[][] predict(List<String> inputs)
.setTypes(String[].class, float[][].class)
.optModelUrls(modelUrl)
.optTranslator(new MyTranslator())
.optEngine("TensorFlow")
.optProgress(new ProgressBar())
.build();
try (ZooModel<String[], float[][]> model = criteria.loadModel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
package ai.djl.examples.inference.face;

import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.translate.TranslateException;
Expand All @@ -30,11 +29,6 @@ public final class FeatureComparison {
private FeatureComparison() {}

public static void main(String[] args) throws IOException, ModelException, TranslateException {
if (!"PyTorch".equals(Engine.getInstance().getEngineName())) {
logger.info("This example only works for PyTorch.");
return;
}

Path imageFile1 = Paths.get("src/test/resources/kana1.jpg");
Image img1 = ImageFactory.getInstance().fromFile(imageFile1);
Path imageFile2 = Paths.get("src/test/resources/kana2.jpg");
Expand Down
Loading

0 comments on commit 4244007

Please sign in to comment.