Skip to content

Commit

Permalink
[timeseries] add some basic block and deepAR model (deepjavalibrary#2027
Browse files Browse the repository at this point in the history
)

* feature: add TimeSeriesDataset and training transform
* feature: some basic block and deepar model
* feature: add train example
* feature: add m5-demo and air passengers demo

Co-authored-by: Carkham <[email protected]>
Co-authored-by: Frank Liu <[email protected]>
Co-authored-by: KexinFeng <[email protected]>
  • Loading branch information
4 people authored Oct 28, 2022
1 parent b7fce8c commit fe4a9f3
Show file tree
Hide file tree
Showing 32 changed files with 3,264 additions and 105 deletions.
10 changes: 10 additions & 0 deletions api/src/main/java/ai/djl/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -314,4 +314,14 @@ public interface Audio {
/** Any audio application, including those in {@link Audio}. */
Application ANY = new Application("audio");
}

/** The common set of applications for timeseries extension. */
public interface TimeSeries {

/**
* An application that take a past target vector with corresponding feature and predicts a
* probability distribution based on it.
*/
Application FORECASTING = new Application("timeseries/forecasting");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.Application.CV;
import ai.djl.Application.NLP;
import ai.djl.Application.TimeSeries;
import ai.djl.mxnet.engine.MxEngine;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.ModelZoo;
Expand Down Expand Up @@ -53,6 +54,7 @@ public class MxModelZoo extends ModelZoo {
addModel(REPOSITORY.model(CV.ACTION_RECOGNITION, GROUP_ID, "action_recognition", "0.0.1"));
addModel(REPOSITORY.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa", "0.0.1"));
addModel(REPOSITORY.model(NLP.WORD_EMBEDDING, GROUP_ID, "glove", "0.0.2"));
addModel(REPOSITORY.model(TimeSeries.FORECASTING, GROUP_ID, "deepar", "0.0.1"));
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "timeseries/forecasting",
"groupId": "ai.djl.mxnet",
"artifactId": "resnest",
"name": "deepar",
"description": "DeepAR model for timeseries forecasting",
"website": "http://www.djl.ai/engines/mxnet/model-zoo",
"licenses": {
"apache": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "airpassengers",
"properties": {
"dataset": "airpassengers"
},
"arguments": {
"prediction_length": 12,
"freq": "M",
"use_feat_dynamic_real": false,
"use_feat_static_cat": false,
"use_feat_static_real": false,
"translatorFactory": "ai.djl.timeseries.translator.DeepARTranslatorFactory"
},
"files": {
"model": {
"uri": "0.0.1/airpassengers.zip",
"sha1Hash": "1c99cdaefb79c3e63bc7ff1965b0fb2ba45e96c3",
"name": "",
"size": 106895
}
}
},
{
"version": "0.0.1",
"snapshot": false,
"name": "m5forecast",
"properties": {
"dataset": "m5forecast"
},
"arguments": {
"prediction_length": 4,
"freq": "W",
"use_feat_dynamic_real": false,
"use_feat_static_cat": false,
"use_feat_static_real": false,
"translatorFactory": "ai.djl.timeseries.translator.DeepARTranslatorFactory"
},
"files": {
"model": {
"uri": "0.0.1/m5forecast.zip",
"sha1Hash": "e251628df3a246911479de0ed36762515a5df241",
"name": "",
"size": 96363
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.examples.inference.timeseries;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.timeseries.Forecast;
import ai.djl.timeseries.SampleForecast;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.DeferredTranslatorFactory;
import ai.djl.translate.TranslateException;

import com.google.gson.GsonBuilder;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.Reader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Date;

public final class AirPassengersDeepAR {

private static final Logger logger = LoggerFactory.getLogger(AirPassengersDeepAR.class);

private AirPassengersDeepAR() {}

public static void main(String[] args) throws IOException, TranslateException, ModelException {
float[] results = predict();
logger.info("{}", results);
}

public static float[] predict() throws IOException, TranslateException, ModelException {
Criteria<TimeSeriesData, Forecast> criteria =
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelUrls("djl://ai.djl.mxnet/deepar/0.0.1/airpassengers")
.optEngine("MXNet")
.optTranslatorFactory(new DeferredTranslatorFactory())
.optArgument("prediction_length", 12)
.optArgument("freq", "M")
.optArgument("use_feat_dynamic_real", false)
.optArgument("use_feat_static_cat", false)
.optArgument("use_feat_static_real", false)
.optProgress(new ProgressBar())
.build();

String url = "https://resources.djl.ai/test-models/mxnet/timeseries/air_passengers.json";

try (ZooModel<TimeSeriesData, Forecast> model = criteria.loadModel();
Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor();
NDManager manager = NDManager.newBaseManager(null, "MXNet")) {
TimeSeriesData data = getTimeSeriesData(manager, new URL(url));

// save data for plotting
NDArray target = data.get(FieldName.TARGET);
target.setName("target");
saveNDArray(target);

Forecast forecast = predictor.predict(data);

// save data for plotting. Please see the corresponding python script from
// https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008
NDArray samples = ((SampleForecast) forecast).getSortedSamples();
samples.setName("samples");
saveNDArray(samples);
return forecast.mean().toFloatArray();
}
}

private static TimeSeriesData getTimeSeriesData(NDManager manager, URL url) throws IOException {
try (Reader reader = new InputStreamReader(url.openStream(), StandardCharsets.UTF_8)) {
AirPassengers passengers =
new GsonBuilder()
.setDateFormat("yyyy-MM")
.create()
.fromJson(reader, AirPassengers.class);

LocalDateTime start =
passengers.start.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
NDArray target = manager.create(passengers.target);
TimeSeriesData data = new TimeSeriesData(10);
data.setStartTime(start);
data.setField(FieldName.TARGET, target);
return data;
}
}

private static void saveNDArray(NDArray array) throws IOException {
Path path = Paths.get("build").resolve(array.getName() + ".npz");
try (OutputStream os = Files.newOutputStream(path)) {
new NDList(new NDList(array)).encode(os, true);
}
}

private static final class AirPassengers {

Date start;
float[] target;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
* and limitations under the License.
*/

package ai.djl.examples.inference;
package ai.djl.examples.inference.timeseries;

import ai.djl.ModelException;
import ai.djl.basicdataset.tabular.utils.DynamicBuffer;
Expand All @@ -27,8 +27,8 @@
import ai.djl.timeseries.Forecast;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.timeseries.translator.DeepARTranslator;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.DeferredTranslatorFactory;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;

Expand All @@ -55,17 +55,16 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public final class DeepARTimeSeries {
public final class M5ForecastingDeepAR {

private static final Logger logger = LoggerFactory.getLogger(DeepARTimeSeries.class);
private static final Logger logger = LoggerFactory.getLogger(M5ForecastingDeepAR.class);

private DeepARTimeSeries() {}
private M5ForecastingDeepAR() {}

public static void main(String[] args) throws IOException, TranslateException, ModelException {
logger.info("model: DeepAR");
Map<String, Float> metrics = predict();
for (Map.Entry<String, Float> entry : metrics.entrySet()) {
logger.info(String.format("metric: %s:\t%.2f", entry.getKey(), entry.getValue()));
logger.info("{}", String.format("metric: %s:\t%.2f", entry.getKey(), entry.getValue()));
}
}

Expand All @@ -74,25 +73,21 @@ public static Map<String, Float> predict()
// M5 Forecasting - Accuracy dataset requires manual download
String pathToData = "/Desktop/m5example/m5-forecasting-accuracy";
Path m5ForecastFile = Paths.get(System.getProperty("user.home") + pathToData);
NDManager manager = NDManager.newBaseManager();
NDManager manager = NDManager.newBaseManager(null, "MXNet");
M5Dataset dataset = M5Dataset.builder().setManager(manager).setRoot(m5ForecastFile).build();

String modelUrl = "https://resources.djl.ai/test-models/mxnet/timeseries/deepar.zip";
Map<String, Object> arguments = new ConcurrentHashMap<>();
int predictionLength = 28;
arguments.put("prediction_length", predictionLength);
arguments.put("freq", "D");
arguments.put("use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false);
arguments.put("use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), false);
arguments.put("use_" + FieldName.FEAT_STATIC_REAL.name().toLowerCase(), false);

DeepARTranslator.Builder builder = DeepARTranslator.builder(arguments);
DeepARTranslator translator = builder.build();
int predictionLength = 4;
Criteria<TimeSeriesData, Forecast> criteria =
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelUrls(modelUrl)
.optTranslator(translator)
.optModelUrls("djl://ai.djl.mxnet/deepar/0.0.1/m5forecast")
.optEngine("MXNet")
.optTranslatorFactory(new DeferredTranslatorFactory())
.optArgument("prediction_length", predictionLength)
.optArgument("freq", "W")
.optArgument("use_feat_dynamic_real", "false")
.optArgument("use_feat_static_cat", "false")
.optArgument("use_feat_static_real", "false")
.optProgress(new ProgressBar())
.build();

Expand All @@ -119,6 +114,8 @@ public static Map<String, Float> predict()
evaluator.aggregateMetrics(evaluator.getMetricsPerTs(gt, pastTarget, forecast));
progress.increment(1);
}

manager.close();
return evaluator.computeTotalMetrics();
}
}
Expand All @@ -143,14 +140,13 @@ private static final class M5Dataset implements Iterable<NDList>, Iterator<NDLis
try {
prepare(builder);
} catch (Exception e) {
throw new AssertionError(
"Failed to read m5-forecast-accuracy/sales_train_evaluation.csv file.", e);
throw new AssertionError("Failed to read files.", e);
}
size = csvRecords.size();
}

private void prepare(Builder builder) throws IOException {
URL csvUrl = builder.root.resolve("sales_train_evaluation.csv").toUri().toURL();
URL csvUrl = builder.root.resolve("weekly_sales_train_evaluation.csv").toUri().toURL();
try (Reader reader =
new InputStreamReader(
new BufferedInputStream(csvUrl.openStream()), StandardCharsets.UTF_8)) {
Expand Down Expand Up @@ -213,8 +209,8 @@ public static final class Builder {
.setTrim(true)
.build();
target = new ArrayList<>();
for (int i = 1; i <= 1941; i++) {
target.add(new Feature("d_" + i, true));
for (int i = 1; i <= 277; i++) {
target.add(new Feature("w_" + i, true));
}
}

Expand All @@ -234,7 +230,8 @@ public M5Dataset build() {
}
}

private static final class M5Evaluator {
/** An evaluator that calculates performance metrics. */
public static final class M5Evaluator {
private float[] quantiles;
Map<String, Float> totalMetrics;
Map<String, Integer> totalNum;
Expand All @@ -252,15 +249,14 @@ public Map<String, Float> getMetricsPerTs(
new ConcurrentHashMap<>((8 + quantiles.length * 2) * 3 / 2);
NDArray meanFcst = forecast.mean();
NDArray medianFcst = forecast.median();
NDArray target = NDArrays.concat(new NDList(pastTarget, gtTarget), -1);

NDArray successiveDiff = target.get("1:").sub(target.get(":-1"));
successiveDiff = successiveDiff.square();
successiveDiff = successiveDiff.get(":{}", -forecast.getPredictionLength());
NDArray denom = successiveDiff.mean();
NDArray meanSquare = gtTarget.sub(meanFcst).square().mean();
NDArray scaleDenom = gtTarget.get("1:").sub(gtTarget.get(":-1")).square().mean();

NDArray rmsse = meanSquare.div(scaleDenom).sqrt();
rmsse = NDArrays.where(scaleDenom.eq(0), rmsse.onesLike(), rmsse);

NDArray num = gtTarget.sub(meanFcst).square().mean();
retMetrics.put("RMSSE", num.getFloat() / denom.getFloat());
retMetrics.put("RMSSE", rmsse.getFloat());

retMetrics.put("MSE", gtTarget.sub(meanFcst).square().mean().getFloat());
retMetrics.put("abs_error", gtTarget.sub(medianFcst).abs().sum().getFloat());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

/** Contains examples of time series forecasting. */
package ai.djl.examples.inference.timeseries;
Loading

0 comments on commit fe4a9f3

Please sign in to comment.