Skip to content

Commit

Permalink
feature: change dataset source and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Carkham authored and KexinFeng committed Nov 4, 2022
1 parent 22d6ae5 commit b8572c2
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

package ai.djl.examples.inference.timeseries;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.basicdataset.tabular.utils.DynamicBuffer;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.inference.Predictor;
Expand All @@ -22,6 +24,9 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.timeseries.Forecast;
Expand All @@ -46,7 +51,6 @@
import java.nio.FloatBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -70,17 +74,22 @@ public static void main(String[] args) throws IOException, TranslateException, M

public static Map<String, Float> predict()
throws IOException, TranslateException, ModelException {
// M5 Forecasting - Accuracy dataset requires manual download
String pathToData = "/Desktop/m5example/m5-forecasting-accuracy";
Path m5ForecastFile = Paths.get(System.getProperty("user.home") + pathToData);
NDManager manager = NDManager.newBaseManager(null, "MXNet");
M5Dataset dataset = M5Dataset.builder().setManager(manager).setRoot(m5ForecastFile).build();

// To use local dataset, users can load data as follows
// Repository repository = Repository.newInstance("local_dataset",
// Paths.get("rootPath/m5-forecasting-accuracy"));
// Then set `Builder.optRepository(repository)`
M5Dataset dataset = M5Dataset.builder().setManager(manager).build();

// The modelUrl can be replaced by local model path. E.g.,
// String modelUrl = "rootPath/m5forecast.zip";
String modelUrl = "djl://ai.djl.mxnet/deepar/0.0.1/m5forecast";
int predictionLength = 4;
Criteria<TimeSeriesData, Forecast> criteria =
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelUrls("djl://ai.djl.mxnet/deepar/0.0.1/m5forecast")
.optModelUrls(modelUrl)
.optEngine("MXNet")
.optTranslatorFactory(new DeferredTranslatorFactory())
.optArgument("prediction_length", predictionLength)
Expand All @@ -104,13 +113,11 @@ public static Map<String, Float> predict()
input.setStartTime(LocalDateTime.parse("2011-01-29T00:00"));
input.setField(FieldName.TARGET, pastTarget);
Forecast forecast = predictor.predict(input);
// Here we focus on the metric Weighted Root Mean Squared Scaled Error (RMSSE) same
// as
// We focus on the metric Weighted Root Mean Squared Scaled Error (RMSSE) same as
// https://www.kaggle.com/competitions/m5-forecasting-accuracy/overview/evaluation
// The error is not small compared to the data values (sale amount). This is because
// The model is trained on a sparse data with many zeros. This will be improved by
// aggregating/coarse graining the data which will appear in the next PR.
// TODO: coarse graining the data.
// aggregating/coarse graining the data. See https://github.com/Carkham/m5_blog
evaluator.aggregateMetrics(evaluator.getMetricsPerTs(gt, pastTarget, forecast));
progress.increment(1);
}
Expand Down Expand Up @@ -146,7 +153,14 @@ private static final class M5Dataset implements Iterable<NDList>, Iterator<NDLis
}

private void prepare(Builder builder) throws IOException {
URL csvUrl = builder.root.resolve("weekly_sales_train_evaluation.csv").toUri().toURL();
MRL mrl = builder.getMrl();
Artifact artifact = mrl.getDefaultArtifact();
mrl.prepare(artifact, null);

Path root = mrl.getRepository().getResourceDirectory(artifact);
Path csvFile = root.resolve("weekly_sales_train_evaluation.csv");

URL csvUrl = csvFile.toUri().toURL();
try (Reader reader =
new InputStreamReader(
new BufferedInputStream(csvUrl.openStream()), StandardCharsets.UTF_8)) {
Expand Down Expand Up @@ -197,9 +211,17 @@ public static final class Builder {
NDManager manager;
List<Feature> target;
CSVFormat csvFormat;
Path root;

Repository repository;
String groupId;
String artifactId;
String version;

Builder() {
repository = BasicDatasets.REPOSITORY;
groupId = BasicDatasets.GROUP_ID;
artifactId = "m5forecast-unittest";
version = "1.0";
csvFormat =
CSVFormat.DEFAULT
.builder()
Expand All @@ -214,8 +236,8 @@ public static final class Builder {
}
}

public Builder setRoot(Path root) {
this.root = root;
public Builder optRepository(Repository repository) {
this.repository = repository;
return this;
}

Expand All @@ -227,6 +249,10 @@ public Builder setManager(NDManager manager) {
public M5Dataset build() {
return new M5Dataset(this);
}

MRL getMrl() {
return repository.dataset(Application.Tabular.ANY, groupId, artifactId, version);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.engine.Engine;
import ai.djl.examples.inference.timeseries.M5ForecastingDeepAR;
Expand All @@ -26,7 +27,6 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.repository.Repository;
import ai.djl.timeseries.Forecast;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
Expand Down Expand Up @@ -83,13 +83,6 @@ public static void main(String[] args) throws IOException, TranslateException, M
}

public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
// use data path to create a custom repository
Repository repository =
Repository.newInstance(
"test",
Paths.get(
System.getProperty("user.home")
+ "/Desktop/m5-forecasting-accuracy"));

Arguments arguments = new Arguments().parseArgs(args);
try (Model model = Model.newInstance("deepar")) {
Expand All @@ -106,8 +99,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
int contextLength = trainingNetwork.getContextLength();

M5Forecast trainSet =
getDataset(
trainingTransformation, repository, contextLength, Dataset.Usage.TRAIN);
getDataset(trainingTransformation, contextLength, Dataset.Usage.TRAIN);

try (Trainer trainer = model.newTrainer(config)) {
trainer.setMetrics(new Metrics());
Expand Down Expand Up @@ -144,13 +136,6 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans

public static Map<String, Float> predict(String outputDir)
throws IOException, TranslateException, ModelException {
Repository repository =
Repository.newInstance(
"test",
Paths.get(
System.getProperty("user.home")
+ "/Desktop/m5-forecasting-accuracy"));

try (Model model = Model.newInstance("deepar")) {
DeepARNetwork predictionNetwork = getDeepARModel(new NegativeBinomialOutput(), false);
model.setBlock(predictionNetwork);
Expand All @@ -159,7 +144,6 @@ public static Map<String, Float> predict(String outputDir)
M5Forecast testSet =
getDataset(
new ArrayList<>(),
repository,
predictionNetwork.getContextLength(),
Dataset.Usage.TEST);

Expand Down Expand Up @@ -262,17 +246,16 @@ private static DeepARNetwork getDeepARModel(
}

private static M5Forecast getDataset(
List<TimeSeriesTransform> transformation,
Repository repository,
int contextLength,
Dataset.Usage usage)
List<TimeSeriesTransform> transformation, int contextLength, Dataset.Usage usage)
throws IOException {
// In order to create a TimeSeriesDataset, you must specify the transformation of the data
// preprocessing
M5Forecast.Builder builder =
M5Forecast.builder()
.optUsage(usage)
.optRepository(repository)
.optRepository(BasicDatasets.REPOSITORY)
.optGroupId(BasicDatasets.GROUP_ID)
.optArtifactId("m5forecast-unittest")
.setTransformation(transformation)
.setContextLength(contextLength)
.setSampling(32, usage == Dataset.Usage.TRAIN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,10 @@ private static RandomAccessDataset getData(Dataset.Usage usage, int batchSize)
float[] mean = {0.485f, 0.456f, 0.406f};
float[] std = {0.229f, 0.224f, 0.225f};

// If users want to use local repository, then the dataset can be loaded as follows
// Repository repository = Repository.newInstance("banana", Paths.get(LOCAL_FOLDER/{train OR
// test}));
// FruitsFreshAndRotten dataset =
// FruitsFreshAndRotten.builder()
// .optRepository(repository)
// .build()
// To use local dataset, users can load it as follows
// Repository repository = Repository.newInstance("banana",
// Paths.get("local_data_root/banana/train"));
// Then add the setting `Builder.optRepository(repository)` to the builder below
FruitsFreshAndRotten dataset =
FruitsFreshAndRotten.builder()
.optUsage(usage)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.examples.inference;

import ai.djl.ModelException;
import ai.djl.examples.inference.timeseries.AirPassengersDeepAR;
import ai.djl.examples.inference.timeseries.M5ForecastingDeepAR;
import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.util.Map;

public class TimeSeriesTest {

private static final Logger logger = LoggerFactory.getLogger(TimeSeriesTest.class);

@Test
public void testM5Forecasting() throws ModelException, TranslateException, IOException {
TestRequirements.engine("MXNet");

Map<String, Float> result = M5ForecastingDeepAR.predict();

String[] metricNames =
new String[] {
"RMSSE",
"MSE",
"abs_error",
"abs_target_sum",
"abs_target_mean",
"MAPE",
"sMAPE",
"ND"
};
for (String metricName : metricNames) {
Assert.assertTrue(result.containsKey(metricName));
}
}

@Test
public void testAirPassenger() throws ModelException, TranslateException, IOException {
TestRequirements.engine("MXNet");

float[] result = AirPassengersDeepAR.predict();
logger.info("{}", result);

Assert.assertEquals(result.length, 12);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;

public class TrainTimeSeriesTest {

@Test
public void testTrainTimeSeries() throws TranslateException, IOException {
TestRequirements.engine("MXNet");

String[] args = new String[] {"-g", "1", "-e", "5", "-b", "32"};
TrainingResult result = TrainTimeSeries.runExample(args);
Assert.assertNotNull(result);
float loss = result.getTrainLoss();
Assert.assertTrue(loss < 10f, "Loss: " + loss);
}
}

0 comments on commit b8572c2

Please sign in to comment.