diff --git a/api/src/main/java/ai/djl/training/dataset/ArrayDataset.java b/api/src/main/java/ai/djl/training/dataset/ArrayDataset.java index 997d705270c..2122d06faff 100644 --- a/api/src/main/java/ai/djl/training/dataset/ArrayDataset.java +++ b/api/src/main/java/ai/djl/training/dataset/ArrayDataset.java @@ -21,19 +21,27 @@ /** * {@code ArrayDataset} is an implementation of {@link RandomAccessDataset} that consist entirely of - * large {@link NDArray}s. There can be multiple data and label {@link NDArray}s within the dataset. - * Each sample will be retrieved by indexing each {@link NDArray} along the first dimension. + * large {@link NDArray}s. It is recommended only for datasets small enough to fit in memory that + * come in array formats. Otherwise, consider directly using the {@link RandomAccessDataset} + * instead. + * + *

There can be multiple data and label {@link NDArray}s within the dataset. Each sample will be + * retrieved by indexing each {@link NDArray} along the first dimension. * *

The following is an example of how to use ArrayDataset: * *

  *     ArrayDataset dataset = new ArrayDataset.Builder()
- *                              .setData(data)
- *                              .optLabels(label)
+ *                              .setData(data1, data2)
+ *                              .optLabels(labels1, labels2, labels3)
  *                              .setSampling(20, false)
  *                              .build();
  * 
* + *

Suppose you get a {@link Batch} from {@code trainer.iterateDataset(dataset)} or {@code + * dataset.getData(manager)}. In the data of this batch, it will be an NDList with one NDArray for + * each data input. In this case, it would be 2 arrays. Similarly, the labels would have 3 arrays. + * * @see Dataset */ public class ArrayDataset extends RandomAccessDataset { diff --git a/api/src/main/java/ai/djl/training/dataset/Dataset.java b/api/src/main/java/ai/djl/training/dataset/Dataset.java index 90a21da7904..9dacf169437 100644 --- a/api/src/main/java/ai/djl/training/dataset/Dataset.java +++ b/api/src/main/java/ai/djl/training/dataset/Dataset.java @@ -21,7 +21,9 @@ /** * An interface to represent a set of sample data/label pairs to train a model. * - * @see The guide on datasets + * @see The guide to datasets + * @see The guide to + * implementing a custom dataset */ public interface Dataset { diff --git a/api/src/main/java/ai/djl/training/dataset/RandomAccessDataset.java b/api/src/main/java/ai/djl/training/dataset/RandomAccessDataset.java index eab783fece4..0d5e20623e6 100644 --- a/api/src/main/java/ai/djl/training/dataset/RandomAccessDataset.java +++ b/api/src/main/java/ai/djl/training/dataset/RandomAccessDataset.java @@ -34,6 +34,11 @@ /** * RandomAccessDataset represent the dataset that support random access reads. i.e. it could access * a specific data item given the index. + * + *

Almost all datasets in DJL extend, either directly or indirectly, {@link RandomAccessDataset}. + * + * @see The guide to + * implementing a custom dataset */ public abstract class RandomAccessDataset implements Dataset { diff --git a/basicdataset/README.md b/basicdataset/README.md index c342cd15d43..4cd0afbc85c 100644 --- a/basicdataset/README.md +++ b/basicdataset/README.md @@ -4,17 +4,7 @@ This module contains a number of basic and standard datasets in the Deep Java Library's (DJL). These datasets are used to train deep learning models. -## List of datasets - -This module contains the following datasets: - -- [MNIST](http://yann.lecun.com/exdb/mnist/) - A handwritten digits dataset -- [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) - A dataset consisting of 60,000 32x32 color images in 10 classes -- [Coco](http://cocodataset.org) - A large-scale object detection, segmentation, and captioning dataset that contains 1.5 million object instances - - You have to manually add `com.twelvemonkeys.imageio:imageio-jpeg:3.5` dependency to your project -- [ImageNet](http://www.image-net.org/) - An image database organized according to the WordNet hierarchy - >**Note**: You have to manually download the ImageNet dataset due to licensing requirements. -- [Pikachu](http://d2l.ai/chapter_computer-vision/object-detection-dataset.html) - 1000 Pikachu images of different angles and sizes created using an open source 3D Pikachu model +You can find the datasets provided by this module on our [docs](http://docs.djl.ai/docs/dataset.html). ## Documentation diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java index 3cdb4c0223f..b36f2843db9 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java @@ -34,6 +34,13 @@ /** * Coco image detection dataset from http://cocodataset.org/#home. 
* + *

Coco is a large-scale object detection, segmentation, and captioning dataset, although only + * object detection is implemented at this time. It contains 1.5 million object instances and is one + * of the standard benchmark object detection datasets. + * + *

To use this dataset, you have to manually add {@code + * com.twelvemonkeys.imageio:imageio-jpeg:3.5} as a dependency in your project. + * *

Each image might have different {@link ai.djl.ndarray.types.Shape}s. */ public class CocoDetection extends ObjectDetectionDataset { diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java index 0a7130dfac7..f1ad1f3870a 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java @@ -39,7 +39,14 @@ import java.util.Map; import java.util.Optional; -/** Pikachu image detection dataset that contains multiple Pikachus in each image. */ +/** + * Pikachu image detection dataset that contains multiple Pikachus in each image. + * + *

It was based on a section from the [Dive into Deep Learning + * book](http://d2l.ai/chapter_computer-vision/object-detection-dataset.html). It contains 1000 + * Pikachu images of different angles and sizes created using an open source 3D Pikachu model. Each + * image contains only a single pikachu. + */ public class PikachuDetection extends ObjectDetectionDataset { private static final String VERSION = "1.0"; diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java index fe8ccad10ca..b370ad5d6fd 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java @@ -34,7 +34,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** A dataset for loading image files stored in a folder structure. */ +/** + * A dataset for loading image files stored in a folder structure. + * + *

Usually, you want to use {@link ImageFolder} instead. + */ public abstract class AbstractImageFolder extends ImageClassificationDataset { private static final Logger logger = LoggerFactory.getLogger(AbstractImageFolder.class); diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java index 21d7a5bc0f6..5e94027281e 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java @@ -34,6 +34,9 @@ /** * CIFAR10 image classification dataset from https://www.cs.toronto.edu/~kriz/cifar.html. * + *

It consists of 60,000 32x32 color images with 10 classes. It can train in a few hours with a + * GPU. + * *

Each sample is an image (in 3-D {@link NDArray}) with shape (32, 32, 3). */ public final class Cifar10 extends ArrayDataset { diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java index abbad625b92..90f1c0e44f8 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java @@ -33,9 +33,11 @@ /** * FashMnist is a dataset from Zalando article images - * https://github.com/zalandoresearch/fashion-mnist. + * (https://github.com/zalandoresearch/fashion-mnist). * - *

Each sample is an image (in 3-D NDArray) with shape (28, 28, 1). + *

Each sample is a grayscale image (in 3-D NDArray) with shape (28, 28, 1). + * + *

It was created to be a drop-in replacement for {@link Mnist}, but with a less simplistic task. */ public final class FashionMnist extends ArrayDataset { diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java index fa2251477d9..b9da9ee0211 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java @@ -24,6 +24,8 @@ /** * A dataset for loading image files stored in a folder structure. * + *

Below is an example directory layout for the image folder: + * *

  *  The image folder should be structured as follows:
  *       root/shoes/Aerobic Shoes1.png
@@ -32,11 +34,35 @@
  *       root/boots/Black Boots.png
  *       root/boots/White Boots.png
  *       ...
- *       root/pumps/Red Pumps
- *       root/pumps/Pink Pumps
+ *       root/pumps/Red Pumps.png
+ *       root/pumps/Pink Pumps.png
  *       ...
+ *
  *  here shoes, boots, pumps are your labels
  *  
+ * + *

Here, the dataset will take the folder names (boots, pumps, shoes) in sorted order as your + * labels. Nested folder structures are not currently supported. + * + *

Then, you can create your instance of the dataset as follows: + * + *

+ * // set the image folder path
+ * Repository repository = Repository.newInstance("folder", Paths.get("/path/to/imagefolder/root"));
+ * ImageFolder dataset =
+ *     new ImageFolder.Builder()
+ *         .setRepository(repository)
+ *         .addTransform(new Resize(100, 100)) // Use image transforms as necessary for your data
+ *         .addTransform(new ToTensor()) // Usually required as the last transform to convert images to tensors
+ *         .setSampling(batchSize, true)
+ *         .build();
+ *
+ * // call prepare before using
+ * dataset.prepare();
+ *
+ * // to get the synset or label names
+ * List<String> synset = dataset.getSynset();
+ * 
*/ public final class ImageFolder extends AbstractImageFolder { diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java index 257fd6bc44d..1cb87668d83 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java @@ -34,7 +34,11 @@ /** * MNIST handwritten digits dataset from http://yann.lecun.com/exdb/mnist. * - *

Each sample is an image (in 3-D NDArray) with shape (28, 28, 1). + *

Each sample is a grayscale image (in 3-D NDArray) with shape (28, 28, 1). + * + *

It is a common starting dataset because it is small and can train within minutes. However, it + * is an overly easy task that even poor models can still perform very well on. Instead, consider + * {@link FashionMnist} which offers a comparable speed but a more reasonable difficulty task. */ public final class Mnist extends ArrayDataset { diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java index 7716d5ab1a9..aa3d0280db0 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java @@ -39,6 +39,8 @@ * questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every * question is a segment of text, or span, from the corresponding reading passage, or the question * might be unanswerable. + * + * @see Dataset website */ @SuppressWarnings("unchecked") public class StanfordQuestionAnsweringDataset extends TextDataset implements RawDataset { diff --git a/docs/dataset.md b/docs/dataset.md index 5bbddeb20e8..cb6d10778c2 100644 --- a/docs/dataset.md +++ b/docs/dataset.md @@ -1,6 +1,6 @@ # Dataset -A dataset (or data set) is a collection of data that are used for machine-learning training job. +A dataset (or data set) is a collection of data that is used for training a machine learning model. Machine learning typically works with three datasets: @@ -11,7 +11,7 @@ Machine learning typically works with three datasets: - Validation dataset The validation set is used to evaluate a given model during the training process. 
It helps machine learning - engineers to fine-tune the [HyperParameter](https://github.com/deepjavalibrary/djl/blob/master/api/src/main/java/ai/djl/training/hyperparameter/param/Hyperparameter.java) + engineers to fine-tune the [HyperParameters](https://github.com/deepjavalibrary/djl/blob/master/api/src/main/java/ai/djl/training/hyperparameter/param/Hyperparameter.java) at model development stage. The model doesn't learn from validation dataset; and validation dataset is optional. @@ -19,9 +19,59 @@ Machine learning typically works with three datasets: The Test dataset provides the gold standard used to evaluate the model. It is only used once a model is completely trained. + The test dataset should more accurately evaluate how the model will perform on new data. See [Jason Brownlee’s article](https://machinelearningmastery.com/difference-test-validation-datasets/) for more detail. ## [Basic Dataset](../basicdataset/README.md) DJL provides a number of built-in basic and standard datasets. These datasets are used to train deep learning models. +This module contains the following datasets: + +### CV + +#### Image Classification + +- [MNIST](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/classification/Mnist.html) - A small and fast handwritten digits dataset +- [Fashion MNIST](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/classification/FashionMnist.html) - A small and fast clothing type classification dataset +- [CIFAR10](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/classification/Cifar10.html) - A dataset consisting of 60,000 32x32 color images in 10 classes +- [ImageNet](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/classification/ImageNet.html) - An image database organized according to the WordNet hierarchy + >**Note**: You have to manually download the ImageNet dataset due to licensing requirements. 
+ +#### Object Detection + +- [Pikachu](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/PikachuDetection.html) - 1000 Pikachu images of different angles and sizes created using an open source 3D Pikachu model +- [Banana Detection](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/BananaDetection.html) - A testing single object detection dataset + +#### Other CV + +- [Captcha](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/classification/CaptchaDataset.html) - A dataset for a grayscale 6-digit CAPTCHA task +- [Coco](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/CocoDetection.html) - A large-scale object detection, segmentation, and captioning dataset that contains 1.5 million object instances + - You have to manually add `com.twelvemonkeys.imageio:imageio-jpeg:3.5` dependency to your project + +### NLP + +#### Text Classification and Sentiment Analysis + +- [AmazonReview](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/nlp/AmazonReview.html) - A sentiment analysis dataset of Amazon Reviews with their ratings +- [Stanford Movie Review](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/nlp/StanfordMovieReview.html) - A sentiment analysis dataset of movie reviews and sentiments sourced from IMDB +- [GoEmotions](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/nlp/GoEmotions.html) - A dataset classifying 50k curated reddit comments into either 27 emotion categories or neutral + +#### Unlabeled Text + +- [Penn Treebank Text](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/nlp/PennTreebankText.html) - The text (not POS tags) from the Penn Treebank, a collection of Wall Street Journal stories +- [WikiText2](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/nlp/WikiText2.html) - A collection of over 100 million tokens extracted from good and featured articles on wikipedia 
+ +#### Other NLP + +- [Stanford Question Answering Dataset (SQuAD)](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.html) - A reading comprehension dataset with text from Wikipedia articles +- [Tatoeba English French Dataset](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.html) - An English-French translation dataset from the Tatoeba Project + +### Tabular + +- [Airfoil Self-Noise](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/tabular/AirfoilRandomAccess.html) - A 6-feature dataset from NASA tests of airfoils +- [Ames House Pricing](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/tabular/AmesRandomAccess.html) - An 80-feature dataset to predict house prices + +### Time Series + +- [Daily Delhi Climate](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/tabular/DailyDelhiClimate.html) \ No newline at end of file diff --git a/docs/development/add_dataset_to_djl.md b/docs/development/add_dataset_to_djl.md new file mode 100644 index 00000000000..2660806d255 --- /dev/null +++ b/docs/development/add_dataset_to_djl.md @@ -0,0 +1,66 @@ +# Add a new dataset to DJL basic datasets + +This document outlines the procedure to add new datasets to DJL. + +## Step 1: Prepare the folder structure + +1. Navigate to the `test/resources/mlrepo/dataset` folder and create a folder in it to store your dataset based on its category. + For example, `cv/ai/djl/basicdataset/mnist`. +2. Create a version folder within your newly created dataset's folder (e.g., `0.0.1`). The version should match your dataset version. + +### Step 2: Create a `metadata.json` file + +You need to create a `metadata.json` file for the repository to load the dataset. You can refer to the format in the `metadata.json` files for existing datasets to create your own. 
+ +**Note:** You need to update the sha1 hash of each file in your `metadata.json` file. Use the following command to get the sha1Hash value: + +```shell +$ shasum -a 1 +``` + +### Step 3: Create a Dataset implementation + +Create a class that implements the dataset and loads it. +For more details on creating datasets, see the [dataset creation guide](how_to_use_dataset.md). +You should also look at examples of official DJL datasets such as [`AmesRandomAccess`](https://github.com/deepjavalibrary/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/tabular/AmesRandomAccess.java) +or [`Cifar10`](https://github.com/deepjavalibrary/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java). + +Then, add some tests for the dataset. +For testing, you can use a local repository such as: + +```java +Repository repository = Repository.newInstance("testRepository", Paths.get("/test/resources/mlrepo")); +``` + +### Step 4: Update the datasets list + +Add your dataset to the [list of built-in datasets](../dataset.md). + +### Step 5: Upload metadata + +The official DJL ML repository is located on an S3 bucket managed by the AWS DJL team. +You have to add the metadata and any dataset files to the repository. + +For non-AWS team members, go straight to Step 6 and open a pull request. +Within the pull request, you can coordinate with an AWS team member to add the necessary files. + +For AWS team members, run the following command to upload your dataset files to the S3 bucket: + +```shell +$ ./gradlew syncS3 +``` + +The `metadata.json` in DJL is mainly a repository of metadata. +The metadata typically contains only links indicating where the actual data can be found. + +However, some datasets can be distributed by DJL if doing so makes them easier to use and the dataset permits redistribution. +In that case, coordinate with an AWS team member in your pull request. 
+ +### Step 6: Open a PR to add your dataset and metadata files to the git repository + +**Note**: Avoid checking in binary files to git. Binary files should only be uploaded to the S3 bucket. + +If you are relying on an AWS team member, you should leave your code with the local test repositories. +If you try to use the official repository before it contains your metadata, the tests will not pass CI. + +Once an AWS team member adds your metadata, they will prompt you to update your PR to use the official repository. diff --git a/docs/development/add_model_to_model-zoo.md b/docs/development/add_model_to_model-zoo.md index 4dfa517a36f..201bef46544 100644 --- a/docs/development/add_model_to_model-zoo.md +++ b/docs/development/add_model_to_model-zoo.md @@ -1,6 +1,6 @@ -# Add a new model to the model zoo +# Add a new model to the DJL model zoo -This document outlines the procedure to add new models into the model zoo. +This document outlines the procedure to add new models into the DJL model zoo. ## Step 1: Prepare the model files @@ -31,6 +31,7 @@ For example, `image_classification/ai/djl/resnet`. 3. Copy model files into the version folder. ### Step 3: Create a `metadata.json` file + You need to create a `metadata.json` file for the model zoo to load the model. You can refer to the format in the `metadata.json` files for existing models to create your own. For a model built as a DJL block, you must recreate the block before loading the parameters. As part of your `metadata.json` file, you should use the `arguments` property to specify the arguments required for the model loader to create another `Block` matching the one used to train the model. @@ -51,7 +52,7 @@ Verify that your folder has the following files (see Step 1 for additional files The official DJL ML repository is located on an S3 bucket managed by the AWS DJL team. -For non-team members, coordinate with a team member in your pull request to coordinate adding the necessary files. 
+For non-team members, coordinate with a team member in your pull request to add the necessary files. For AWS team members, run the following command to upload your model to the S3 bucket: diff --git a/docs/development/example_dataset.md b/docs/development/example_dataset.md new file mode 100644 index 00000000000..78ed77a8ca8 --- /dev/null +++ b/docs/development/example_dataset.md @@ -0,0 +1,143 @@ +## Example CSV Dataset + +If the provided Datasets don't meet your requirements, you can also easily extend our dataset to create your own customized dataset. + +Let's take CSVDataset, which can load a csv file, for example. + +### Step 1: Prerequisites +For this example, we'll use [malicious_url_data.csv](https://github.com/incertum/cyber-matrix-ai/blob/master/Malicious-URL-Detection-Deep-Learning/data/url_data_mega_deep_learning.csv). + +The CSV file has the following format. + +| URL | isMalicious | +| ----------- | ----------- | +| sample.url.good.com | 0 | +| sample.url.bad.com | 1 | + +We'll also use the 3rd party [Apache Commons](https://commons.apache.org/) library to read the CSV file. To use the library, include the following dependency: + +``` +api group: 'org.apache.commons', name: 'commons-csv', version: '1.7' +``` + +### Step 2: Implementation +In order to extend the dataset, the following dependencies are required: + +``` +api "ai.djl:api:0.17.0" +api "ai.djl:basicdataset:0.17.0" +``` + +There are four parts we need to implement for CSVDataset. + +1. Constructor and Builder + +First, we need a private field that holds the CSVRecord list from the csv file. +We create a constructor and pass the CSVRecord list from builder to the class field. +For builder, we have all we need in `BaseBuilder` so we only need to include the two minimal methods as shown. +In the *build()* method, we take advantage of CSVParser to get the record of each CSV file and put them in CSVRecord list. 
+ +```java +public class CSVDataset extends RandomAccessDataset { + + private final List<CSVRecord> csvRecords; + + private CSVDataset(Builder builder) { + super(builder); + csvRecords = builder.csvRecords; + } + ... + public static final class Builder extends BaseBuilder<Builder> { + List<CSVRecord> csvRecords; + + @Override + protected Builder self() { + return this; + } + + CSVDataset build() throws IOException { + String csvFilePath = "path/malicious_url_data.csv"; + try (Reader reader = Files.newBufferedReader(Paths.get(csvFilePath)); + CSVParser csvParser = + new CSVParser( + reader, + CSVFormat.DEFAULT + .withHeader("url", "isMalicious") + .withFirstRecordAsHeader() + .withIgnoreHeaderCase() + .withTrim())) { + csvRecords = csvParser.getRecords(); + } + return new CSVDataset(this); + } + } + +} +``` + +2. Prepare + +As mentioned, in this example we are taking advantage of CSVParser to prepare the data for us. To prepare +the data on our own, we use the `prepare()` method. Normally here we would load or create any data +for our dataset and then save it in one of the private fields previously created. This `prepare()` method +is called every time we call `getData()`, so whenever we want to load the data only once, we use a +boolean variable called `prepared` to check if it has previously been loaded or prepared. + +Since we don't have to prepare any data on our own for this example, we only have to override it. + +```java +@Override +public void prepare(Progress progress) {} +``` + +There are great [examples](https://github.com/deepjavalibrary/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/nlp/AmazonReview.java) +in our [basicdataset](https://github.com/deepjavalibrary/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset) +folder that show use cases for `prepare()`. + + + +3. Getter + +The getter returns a Record object which contains encoded inputs and labels. 
+Here, we use simple encoding to transform the URL string to an int array and create an NDArray on top of it. +The reason why we use NDList here is that you might have multiple inputs and labels in different tasks. + +```java +@Override +public Record get(NDManager manager, long index) { + // get a CSVRecord given an index + CSVRecord record = csvRecords.get(Math.toIntExact(index)); + NDArray datum = manager.create(encode(record.get("url"))); + NDArray label = manager.create(Float.parseFloat(record.get("isMalicious"))); + return new Record(new NDList(datum), new NDList(label)); +} +``` + +4. Size + +`availableSize()` returns the number of records available to be read in this Dataset. +Here, we can directly use the size of the List. + +```java +@Override +public long availableSize() { + return csvRecords.size(); +} +``` + +Done! +Now, you can use the CSVDataset with the following code snippet: + +```java +CSVDataset dataset = new CSVDataset.Builder().setSampling(batchSize, false).build(); +for (Batch batch : dataset.getData(model.getNDManager())) { + // use head to get first NDArray + batch.getData().head(); + batch.getLabels().head(); + ... + // don't forget to close the batch in the end + batch.close(); +} +``` + +The full example code can be found in [CSVDataset.java](https://github.com/deepjavalibrary/djl/blob/master/docs/development/CSVDataset.java). diff --git a/docs/development/how_to_use_dataset.md b/docs/development/how_to_use_dataset.md index 206957b04a8..8f6c3abd7a0 100644 --- a/docs/development/how_to_use_dataset.md +++ b/docs/development/how_to_use_dataset.md @@ -1,227 +1,58 @@ # Dataset Creation -The Dataset in DJL represents both the raw data and loading process. -RandomAccessDataset implements the Dataset interface and provides comprehensive data loading functionality. -RandomAccessDataset is also a basic dataset that supports random or sequential access of data using indices. -You can easily customize your own dataset by extending RandomAccessDataset. 
+The [Dataset](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Dataset.html) in DJL represents both the raw data and the data loading process. +For this reason, training in DJL usually requires that your data be wrapped in a dataset class. +You can choose to use one of the well-known datasets we have [built in](../dataset.md). +Or, you can create a custom dataset. -We provide several well-known datasets that you can use in our [Basic Datasets module](https://javadoc.io/doc/ai.djl/basicdataset/latest/index.html). +## Dataset Helpers -We also provide several built-in datasets that you can easily wrap around existing NDArrays and images. +There are a number of helpers provided by DJL to make it easy to create custom datasets. +If a helper is available, it can make it easier to implement the dataset than building it from scratch: -- ArrayDataset - a dataset that wraps your existing NDArrays as inputs and labels. -- ImageFolder - a dataset that wraps your images under folders for image classification. +### CV -If none of the provided datasets meet your requirements, you can also easily customize you own dataset by extending -the RandomAccessDataset. +- [ImageDataset](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/ImageDataset.html) - An abstract dataset for tasks where the input is an image, such as image classification, object detection, and image segmentation +- [ImageClassificationDataset](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/classification/ImageClassificationDataset.html) - An abstract dataset for image classification +- [AbstractImageFolder](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/classification/AbstractImageFolder.html) - An abstract dataset for loading images in a folder structure. Usually, you want ImageFolder instead. 
+- [ImageFolder](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/classification/ImageFolder.html) - A dataset for loading image files stored in a folder structure +- [ObjectDetectionDataset](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/cv/ObjectDetectionDataset.html) - An abstract dataset for object detection -## How to use the ArrayDataset +### NLP -The following code illustrates an implementation of [ArrayDataset](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/ArrayDataset.html). -The ArrayDataset is recommended only if your dataset is small enough to fit in memory. +- [TextDataset](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/nlp/TextDataset.html) - An abstract dataset for NLP where either the input or labels are text-based. +- [TextData](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/utils/TextData.html) - A utility for managing the text within a dataset -```java -// given you have data1, data2 and label1, label2, label3 -ArrayDataset dataset = new ArrayDataset.Builder() - .setData(data1, data2) - .optLabels(label1, label2, label3) - .setSampling(20, false) - .build(); +### Other -``` +- [CsvDataset](https://javadoc.io/doc/ai.djl/basicdataset/latest/ai/djl/basicdataset/tabular/CsvDataset.html) - A dataset for loading data from a .csv file -When you get the `Batch` from `trainer.iterateDataset(dataset)` or `dataset.getData(manager)`, -you can use ``batch.getData()`` to get a NDList. The size of this NDList is the amount of data NDArrays -entered in `setData()`, in this case the size is 2. You can then use `NDList.get(0)` to get your first -data and `NDList.get(1)` to get your second data. -Similarly, you can use `batch.getLabels()` to get a NDList with size 3. +## Custom Datasets -## How to use the ImageFolder dataset +If none of the provided datasets meet your requirements, you can also easily create your own dataset in a custom class. 
+While technically the dataset only needs to implement [`Dataset`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/Dataset.html), it is best to instead extend [`RandomAccessDataset`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/training/dataset/RandomAccessDataset.html). +It manages data randomization and provides comprehensive data loading functionality. -This section outlines the procedure to use the ImageFolder dataset. +The `RandomAccessDataset` treats your data records as a list in which each record has an index. +Then, it only needs to know how many records there are and how to load each record given its index. -The ImageFolder dataset is recommended only if you want to iterate through your images with their single label. For example, when training classic image classification. +As part of implementing the dataset, there are two methods that must be defined: -### Step 1: Prepare Images -Arrange your image folder structure as follows: +- `Record get(NDManager manager, long index)` - Returns the record (both input data and output label) for a particular index +- `long availableSize()` - Returns the number of records in the dataset -``` -dataset_root/shoes/Aerobic Shoes1.png -dataset_root/shoes/Aerobic Shose2.png -... -dataset_root/boots/Black Boots.png -dataset_root/boots/White Boots.png -... -dataset_root/pumps/Red Pumps -dataset_root/pumps/Pink Pumps -... -``` - -The dataset will take the folder name e.g. `boots`, `pumps`, `shoes` in sorted order as your labels -**Note:** Nested folder structures are not currently supported - -### Step 2: Use the Dataset -Add the following code snippet to your project to use the ImageFolder dataset. +In addition, the dataset should have a nested builder class to contain details on how to load the dataset. +The builder would extend `RandomAccessDataset.BaseBuilder`. +This provides an avenue to modify how `RandomAccessDataset` loads the data. +You can also add your own options into the builder.
+For an example of what this looks like, see [`ImageFolder.Builder`](https://github.com/deepjavalibrary/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java). -```java -// set the image folder path -Repository repository = Repository.newInstance("folder", Paths.get("/imagefolder"); -ImageFolder dataset = - new ImageFolder.Builder() - .setRepository(repository) - .addTransform(new Resize(100, 100)) - .addTransform(new ToTensor()) - .setSampling(batchSize, true) - .build(); -// call prepare before using -dataset.prepare(); - -// to get the synset -List synset = dataset.getSynset(); -``` +You can also view this example of creating a [new CSV dataset](example_dataset.md). -Typically, you would add pre-processing pipelines like Resize(), in order to batchify the dataset, and ToTensor(), which converts the image NDArray to Tensor NDArray. +Many of the abstract dataset helpers above also extend `RandomAccessDataset`. +When using them, most of the same information applies. +You may be asked to implement slightly different methods depending on the particular extended class. +You will also want to extend that class's `BaseBuilder` instead of the one found in `RandomAccessDataset` to get the additional data loading options from the helper. -## How to create a custom dataset - -If the provided Datasets don't meet your requirements, you can also easily extend our dataset to create your own customized dataset. - -Let's take CSVDataset, which can load a csv file, for example. - -### Step 1: Prerequisites -For this example, we'll use [malicious_url_data.csv](https://github.com/incertum/cyber-matrix-ai/blob/master/Malicious-URL-Detection-Deep-Learning/data/url_data_mega_deep_learning.csv). - -The CSV file has the following format.
- -| URL | isMalicious | -| ----------- | ----------- | -| sample.url.good.com | 0 | -| sample.url.bad.com | 1 | - -We'll also use the 3rd party [Apache Commons](https://commons.apache.org/) library to read the CSV file. To use the library, include the following dependency: - -``` -api group: 'org.apache.commons', name: 'commons-csv', version: '1.7' -``` - -### Step 2: Implementation -In order to extend the dataset, the following dependencies are required: - -``` -api "ai.djl:api:0.17.0" -api "ai.djl:basicdataset:0.17.0" -``` - -There are four parts we need to implement for CSVDataset. - -1. Constructor and Builder - -First, we need a private field that holds the CSVRecord list from the csv file. -We create a constructor and pass the CSVRecord list from builder to the class field. -For builder, we have all we need in `BaseBuilder` so we only need to include the two minimal methods as shown. -In the *build()* method, we take advantage of CSVParser to get the record of each CSV file and put them in CSVRecord list. - -```java -public class CSVDataset extends RandomAccessDataset { - - private final List csvRecords; - - private CSVDataset(Builder builder) { - super(builder); - csvRecords = builder.csvRecords; - } - ... - public static final class Builder extends BaseBuilder { - List csvRecords; - - @Override - protected Builder self() { - return this; - } - - CSVDataset build() throws IOException { - String csvFilePath = "path/malicious_url_data.csv"; - try (Reader reader = Files.newBufferedReader(Paths.get(csvFilePath)); - CSVParser csvParser = - new CSVParser( - reader, - CSVFormat.DEFAULT - .withHeader("url", "isMalicious") - .withFirstRecordAsHeader() - .withIgnoreHeaderCase() - .withTrim())) { - csvRecords = csvParser.getRecords(); - } - return new CSVDataset(this); - } - } - -} -``` - -2. Prepare - -As mentioned, in this example we are taking advantage of CSVParser to prepare the data for us. To prepare -the data on our own, we use the `prepare()` method. 
Normally here we would load or create any data -for our dataset and then save it in one of the private fields previously created. This `prepare()` method -is called everytime we call `getData()` so in every case we want to only load the data once, we use a -boolean variable called `prepared` to check if it has previously been loaded or prepared. - -Since we don't have to prepare any data on our own for this example, we only have to override it. - -```java -@Override -public void prepare(Progress progress) {} -``` - -There are great [examples](https://github.com/deepjavalibrary/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/nlp/AmazonReview.java) -in our [basicdataset](https://github.com/deepjavalibrary/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset) -folder that show use cases for `prepare()`. - - - -3. Getter - -The getter returns a Record object which contains encoded inputs and labels. -Here, we use simple encoding to transform the url String to an int array and create a NDArray on top of it. -The reason why we use NDList here is that you might have multiple inputs and labels in different tasks. - -```java -@Override -public Record get(NDManager manager, long index) { - // get a CSVRecord given an index - CSVRecord record = csvRecords.get(Math.toIntExact(index)); - NDArray datum = manager.create(encode(record.get("url"))); - NDArray label = manager.create(Float.parseFloat(record.get("isMalicious"))); - return new Record(new NDList(datum), new NDList(label)); -} -``` - -4. Size - -The number of records available to be read in this Dataset. -Here, we can directly use the size of the List. - -```java -@Override -public long availableSize() { - return csvRecords.size(); -} -``` - -Done! 
-Now, you can use the CSVDataset with the following code snippet: - -```java -CSVDataset dataset = new CSVDataset.Builder().setSampling(batchSize, false).build(); -for (Batch batch : dataset.getData(model.getNDManager())) { - // use head to get first NDArray - batch.getData().head(); - batch.getLabels().head(); - ... - // don't forget to close the batch in the end - batch.close(); -} -``` - -Full example code could be found in [CSVDataset.java](https://github.com/deepjavalibrary/djl/blob/master/docs/development/CSVDataset.java). +If you create a new dataset for a public dataset, consider contributing that dataset back to DJL for others to use. +You can follow [these instructions](add_dataset_to_djl.md) for adding it. \ No newline at end of file diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 041ae3d3e97..026ccb953cd 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -54,6 +54,7 @@ nav: - 'docs/development/troubleshooting.md' - 'docs/development/dependency_management.md' - 'docs/development/add_model_to_model-zoo.md' + - 'docs/development/add_dataset_to_djl.md' - 'docs/roadmap.md' - 'docs/faq.md' - Tutorials: @@ -66,7 +67,7 @@ nav: - 'jupyter/transfer_learning_on_cifar10.ipynb' - Load your own BERT: - BERT with MXNet: 'jupyter/mxnet/load_your_own_mxnet_bert.ipynb' - - BERT with PyTorch: 'jupyter/pytorch/load_your_own_pytorch.ipynb' + - BERT with PyTorch: 'jupyter/pytorch/load_your_own_pytorch_bert.ipynb' - Guides: - 'docs/engine.md' - Models: diff --git a/engines/mxnet/mxnet-engine/README.md b/engines/mxnet/mxnet-engine/README.md index 50d1e4df1ad..5713b8712f7 100644 --- a/engines/mxnet/mxnet-engine/README.md +++ b/engines/mxnet/mxnet-engine/README.md @@ -41,7 +41,10 @@ You can pull the MXNet engine from the central Maven repository by including the By default, DJL will download the Apache MXNet native libraries into [cache folder](../../../docs/development/cache_management.md) the first time you run DJL.
It will automatically determine the appropriate jars for your system based on the platform and GPU support. -You can choose a native library based on your platform if you don't have network access at runtime. +If you do not want to rely on the download because you don't have network access at runtime or for other reasons, there are additional options. +The easiest option is to add a DJL native library package to your project dependencies. +The available packages for your platform can be found below. +Finally, you can also specify the path to a valid MXNet build using the `MXNET_LIBRARY_PATH` environment variable. ### macOS @@ -116,15 +119,14 @@ Apache MXNet requires Visual C++ Redistributable Packages. If you encounter an U DJL on Windows, please download and install [Visual C++ 2019 Redistributable Packages](https://support.microsoft.com/en-us/help/2977003/the-latest-supported-visual-c-downloads) and reboot. -For the Windows platform, you can use CPU package. MXNet windows GPU native -library size are large, we no longer provide GPU package, instead you have to -use the Automatic package. +For the Windows platform, we support both CPU and GPU. +The CPU library is available either through automatic runtime detection or by adding the CPU jar to your dependencies. +However, due to the size of the Windows GPU native library, we do not offer GPU support through a dependency jar. +You can still use the GPU on Windows through the [automatic runtime download](#Installation). #### Windows GPU -- ai.djl.mxnet:mxnet-native-auto:1.8.0 - - This package supports CUDA 11.0 and CUDA 10.2 for Windows. +This package supports CUDA 11.0 and CUDA 10.2 for Windows. ### Windows CPU