From 532f874a21fed1404937c5a0dbf6650b8914eafb Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Thu, 3 Nov 2022 19:29:03 -0700 Subject: [PATCH] local model loading --- .../timeseries/M5ForecastingDeepAR.java | 61 +++++++------------ .../transferlearning/TransferFreshFruit.java | 10 ++- 2 files changed, 25 insertions(+), 46 deletions(-) diff --git a/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java b/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java index 6f785a0c8f98..7a627cf22315 100644 --- a/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java +++ b/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java @@ -51,7 +51,6 @@ import java.nio.FloatBuffer; import java.nio.charset.StandardCharsets; import java.nio.file.Path; -import java.nio.file.Paths; import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Arrays; @@ -77,47 +76,29 @@ public static Map predict() throws IOException, TranslateException, ModelException { NDManager manager = NDManager.newBaseManager(null, "MXNet"); - // If users want to use local repository, then the dataset can be loaded as follows -// Repository repository = Repository.newInstance("local_dataset", -// Paths.get("root/m5-forecasting-accuracy")); -// M5Dataset dataset = M5Dataset.builder() -// .optRepository(repository) -// .build(); - + // To use local dataset, users can load data as follows + // Repository repository = Repository.newInstance("local_dataset", + // Paths.get("rootPath/m5-forecasting-accuracy")); + // Then set `Builder.optRepository(repository)` M5Dataset dataset = M5Dataset.builder().setManager(manager).build(); - // If users want to use local model, do the following: - Path modelPath = Paths.get("/Users/fenkexin/Downloads/m5forecast.zip"); - int predictionLength = 4; - Criteria criteria = - Criteria.builder() - .setTypes(TimeSeriesData.class, Forecast.class) - .optModelPath(modelPath) - .optEngine("MXNet") - .optTranslatorFactory(new DeferredTranslatorFactory()) - .optArgument("prediction_length", predictionLength) - .optArgument("freq", "D") - .optArgument("use_feat_dynamic_real", "false") - .optArgument("use_feat_static_cat", "false") - .optArgument("use_feat_static_real", "false") - .optProgress(new ProgressBar()) - .build(); - -// String modelUrl = "djl://ai.djl.mxnet/deepar/0.0.1/m5forecast"; -// int predictionLength = 4; -// Criteria criteria = -// Criteria.builder() -// .setTypes(TimeSeriesData.class, Forecast.class) -// .optModelUrls(modelUrl) -// .optEngine("MXNet") -// .optTranslatorFactory(new DeferredTranslatorFactory()) -// .optArgument("prediction_length", predictionLength) -// .optArgument("freq", "W") -// .optArgument("use_feat_dynamic_real", "false") -// .optArgument("use_feat_static_cat", "false") -// .optArgument("use_feat_static_real", "false") -// .optProgress(new ProgressBar()) -// .build(); + // The modelUrl can be replaced by local model path. E.g., + // String modelUrl = "rootPath/m5forecast.zip"; + String modelUrl = "djl://ai.djl.mxnet/deepar/0.0.1/m5forecast"; + int predictionLength = 4; + Criteria criteria = + Criteria.builder() + .setTypes(TimeSeriesData.class, Forecast.class) + .optModelUrls(modelUrl) + .optEngine("MXNet") + .optTranslatorFactory(new DeferredTranslatorFactory()) + .optArgument("prediction_length", predictionLength) + .optArgument("freq", "W") + .optArgument("use_feat_dynamic_real", "false") + .optArgument("use_feat_static_cat", "false") + .optArgument("use_feat_static_real", "false") + .optProgress(new ProgressBar()) + .build(); try (ZooModel model = criteria.loadModel(); Predictor predictor = model.newPredictor()) { diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java index 1ce78ae06a7c..e337dea1a526 100644 --- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java +++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java @@ -131,12 +131,10 @@ private static RandomAccessDataset getData(Dataset.Usage usage, int batchSize) float[] mean = {0.485f, 0.456f, 0.406f}; float[] std = {0.229f, 0.224f, 0.225f}; - // If users want to use local repository, then the dataset can be loaded as follows - // Repository repository = Repository.newInstance("banana", Paths.get("local_data_root/banana/train")); - // FruitsFreshAndRotten dataset = - // FruitsFreshAndRotten.builder() - // .optRepository(repository) - // ... + // To use local dataset, users can load it as follows + // Repository repository = Repository.newInstance("banana", + // Paths.get("local_data_root/banana/train")); + // Then set `Builder.optRepository(repository)` FruitsFreshAndRotten dataset = FruitsFreshAndRotten.builder() .optUsage(usage)