Skip to content

Commit

Permalink
[timeseries] Update M5Forecast dataset and its unittest (deepjavalibr…
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored and patins1 committed Oct 30, 2022
1 parent 4beb67b commit 7c11b24
Show file tree
Hide file tree
Showing 4 changed files with 3,185 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import ai.djl.repository.Repository;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import ai.djl.util.Utils;

import org.apache.commons.csv.CSVFormat;

Expand All @@ -35,7 +34,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
Expand Down Expand Up @@ -109,9 +107,9 @@ private String getUsagePath(Usage usage) {
// file as 'weekly_***'
switch (usage) {
case TRAIN:
return "fake_weekly_sales_train_validation.csv";
return "weekly_sales_train_validation.csv";
case TEST:
return "fake_weekly_sales_train_evaluation.csv";
return "weekly_sales_train_evaluation.csv";
case VALIDATION:
default:
throw new UnsupportedOperationException("Data not available.");
Expand All @@ -122,11 +120,17 @@ private String getUsagePath(Usage usage) {
public static class Builder extends CsvBuilder<Builder> {

Repository repository;
Usage usage = Usage.TRAIN;
String groupId;
String artifactId;
Usage usage;
M5Features mf;
List<Integer> cardinality;

Builder() {
repository = BasicDatasets.REPOSITORY;
groupId = BasicDatasets.GROUP_ID;
artifactId = ARTIFACT_ID;
usage = Usage.TRAIN;
csvFormat =
CSVFormat.DEFAULT
.builder()
Expand All @@ -139,8 +143,7 @@ public static class Builder extends CsvBuilder<Builder> {
}

MRL getMrl() {
return repository.dataset(
Application.Tabular.ANY, BasicDatasets.GROUP_ID, ARTIFACT_ID, VERSION);
return repository.dataset(Application.Tabular.ANY, groupId, artifactId, VERSION);
}

/** {@inheritDoc} */
Expand All @@ -150,18 +153,46 @@ protected Builder self() {
}

/**
* Set the repository containing the path.
* Sets the optional repository.
*
* @param repository the repository containing the path
* @param repository the repository
* @return this builder
*/
public Builder setRepository(Repository repository) {
public Builder optRepository(Repository repository) {
this.repository = repository;
return this;
}

/**
* Set the optional usage.
* Sets optional groupId.
*
* @param groupId the groupId}
* @return this builder
*/
public Builder optGroupId(String groupId) {
this.groupId = groupId;
return this;
}

/**
* Sets the optional artifactId.
*
* @param artifactId the artifactId
* @return this builder
*/
public Builder optArtifactId(String artifactId) {
if (artifactId.contains(":")) {
String[] tokens = artifactId.split(":");
groupId = tokens[0];
this.artifactId = tokens[1];
} else {
this.artifactId = artifactId;
}
return this;
}

/**
* Sets the optional usage.
*
* @param usage the usage
* @return this builder
Expand Down Expand Up @@ -228,9 +259,8 @@ public M5Forecast build() {

private void parseFeatures() {
if (mf == null) {
String url =
"https://mlrepo.djl.ai/dataset/tabular/ai/djl/basicdataset/m5forecast/1.0/m5forecast_parser.json";
try (InputStream is = Objects.requireNonNull(Utils.openUrl(url));
try (InputStream is =
M5Forecast.class.getResourceAsStream("m5forecast_parser.json");
Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
mf = JsonUtils.GSON.fromJson(reader, M5Features.class);
} catch (IOException e) {
Expand Down
Loading

0 comments on commit 7c11b24

Please sign in to comment.