diff --git a/api/src/main/java/ai/djl/training/dataset/ArrayDataset.java b/api/src/main/java/ai/djl/training/dataset/ArrayDataset.java
index 997d705270c..2122d06faff 100644
--- a/api/src/main/java/ai/djl/training/dataset/ArrayDataset.java
+++ b/api/src/main/java/ai/djl/training/dataset/ArrayDataset.java
@@ -21,19 +21,27 @@
/**
* {@code ArrayDataset} is an implementation of {@link RandomAccessDataset} that consist entirely of
- * large {@link NDArray}s. There can be multiple data and label {@link NDArray}s within the dataset.
- * Each sample will be retrieved by indexing each {@link NDArray} along the first dimension.
+ * large {@link NDArray}s. It is recommended only for datasets small enough to fit in memory that
+ * come in array formats. Otherwise, consider directly using the {@link RandomAccessDataset}
+ * instead.
+ *
+ *
There can be multiple data and label {@link NDArray}s within the dataset. Each sample will be
+ * retrieved by indexing each {@link NDArray} along the first dimension.
*
*
The following is an example of how to use ArrayDataset:
*
*
Suppose you get a {@link Batch} from {@code trainer.iterateDataset(dataset)} or {@code
+ * dataset.getData(manager)}. In the data of this batch, it will be an NDList with one NDArray for
+ * each data input. In this case, it would be 2 arrays. Similarly, the labels would have 3 arrays.
+ *
* @see Dataset
*/
public class ArrayDataset extends RandomAccessDataset {
diff --git a/api/src/main/java/ai/djl/training/dataset/Dataset.java b/api/src/main/java/ai/djl/training/dataset/Dataset.java
index 90a21da7904..9dacf169437 100644
--- a/api/src/main/java/ai/djl/training/dataset/Dataset.java
+++ b/api/src/main/java/ai/djl/training/dataset/Dataset.java
@@ -21,7 +21,9 @@
/**
* An interface to represent a set of sample data/label pairs to train a model.
*
- * @see The guide on datasets
+ * @see The guide to datasets
+ * @see The guide to
+ * implementing a custom dataset
*/
public interface Dataset {
diff --git a/api/src/main/java/ai/djl/training/dataset/RandomAccessDataset.java b/api/src/main/java/ai/djl/training/dataset/RandomAccessDataset.java
index eab783fece4..0d5e20623e6 100644
--- a/api/src/main/java/ai/djl/training/dataset/RandomAccessDataset.java
+++ b/api/src/main/java/ai/djl/training/dataset/RandomAccessDataset.java
@@ -34,6 +34,11 @@
/**
* RandomAccessDataset represent the dataset that support random access reads. i.e. it could access
* a specific data item given the index.
+ *
+ *
Almost all datasets in DJL extend, either directly or indirectly, {@link RandomAccessDataset}.
+ *
+ * @see The guide to
+ * implementing a custom dataset
*/
public abstract class RandomAccessDataset implements Dataset {
diff --git a/basicdataset/README.md b/basicdataset/README.md
index c342cd15d43..4cd0afbc85c 100644
--- a/basicdataset/README.md
+++ b/basicdataset/README.md
@@ -4,17 +4,7 @@
This module contains a number of basic and standard datasets in the Deep Java Library's (DJL). These datasets are used to train deep learning models.
-## List of datasets
-
-This module contains the following datasets:
-
-- [MNIST](http://yann.lecun.com/exdb/mnist/) - A handwritten digits dataset
-- [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) - A dataset consisting of 60,000 32x32 color images in 10 classes
-- [Coco](http://cocodataset.org) - A large-scale object detection, segmentation, and captioning dataset that contains 1.5 million object instances
- - You have to manually add `com.twelvemonkeys.imageio:imageio-jpeg:3.5` dependency to your project
-- [ImageNet](http://www.image-net.org/) - An image database organized according to the WordNet hierarchy
- >**Note**: You have to manually download the ImageNet dataset due to licensing requirements.
-- [Pikachu](http://d2l.ai/chapter_computer-vision/object-detection-dataset.html) - 1000 Pikachu images of different angles and sizes created using an open source 3D Pikachu model
+You can find the datasets provided by this module on our [docs](http://docs.djl.ai/docs/dataset.html).
## Documentation
diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java
index 3cdb4c0223f..b36f2843db9 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java
@@ -34,6 +34,13 @@
/**
* Coco image detection dataset from http://cocodataset.org/#home.
*
+ *
Coco is a large-scale object detection, segmentation, and captioning dataset although only
+ * object detection is implemented at thsi time. It contains 1.5 million object instances and is one
+ * of the standard benchmark object detection datasets.
+ *
+ *
To use this dataset, you have to manually add {@code
+ * com.twelvemonkeys.imageio:imageio-jpeg:3.5} as a dependency in your project.
+ *
*
Each image might have different {@link ai.djl.ndarray.types.Shape}s.
*/
public class CocoDetection extends ObjectDetectionDataset {
diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java
index 0a7130dfac7..f1ad1f3870a 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java
@@ -39,7 +39,14 @@
import java.util.Map;
import java.util.Optional;
-/** Pikachu image detection dataset that contains multiple Pikachus in each image. */
+/**
+ * Pikachu image detection dataset that contains multiple Pikachus in each image.
+ *
+ *
It was based on a section from the [Dive into Deep Learning
+ * book](http://d2l.ai/chapter_computer-vision/object-detection-dataset.html). It contains 1000
+ * Pikachu images of different angles and sizes created using an open source 3D Pikachu model. Each
+ * image contains only a single pikachu.
+ */
public class PikachuDetection extends ObjectDetectionDataset {
private static final String VERSION = "1.0";
diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java
index fe8ccad10ca..b370ad5d6fd 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java
@@ -34,7 +34,11 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-/** A dataset for loading image files stored in a folder structure. */
+/**
+ * A dataset for loading image files stored in a folder structure.
+ *
+ *
Usually, you want to use {@link ImageFolder} instead.
+ */
public abstract class AbstractImageFolder extends ImageClassificationDataset {
private static final Logger logger = LoggerFactory.getLogger(AbstractImageFolder.class);
diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java
index 21d7a5bc0f6..5e94027281e 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java
@@ -34,6 +34,9 @@
/**
* CIFAR10 image classification dataset from https://www.cs.toronto.edu/~kriz/cifar.html.
*
+ *
It consists of 60,000 32x32 color images with 10 classes. It can train in a few hours with a
+ * GPU.
+ *
*
Each sample is an image (in 3-D {@link NDArray}) with shape (32, 32, 3).
*/
public final class Cifar10 extends ArrayDataset {
diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java
index abbad625b92..90f1c0e44f8 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java
@@ -33,9 +33,11 @@
/**
* FashMnist is a dataset from Zalando article images
- * https://github.com/zalandoresearch/fashion-mnist.
+ * (https://github.com/zalandoresearch/fashion-mnist).
*
- *
Each sample is an image (in 3-D NDArray) with shape (28, 28, 1).
+ *
Each sample is a grayscale image (in 3-D NDArray) with shape (28, 28, 1).
+ *
+ *
It was created to be a drop in replacement for {@link Mnist}, but have a less simplistic task.
*/
public final class FashionMnist extends ArrayDataset {
diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java
index fa2251477d9..b9da9ee0211 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java
@@ -24,6 +24,8 @@
/**
* A dataset for loading image files stored in a folder structure.
*
+ *
Below is an example directory layout for the image folder:
+ *
*
* The image folder should be structured as follows:
* root/shoes/Aerobic Shoes1.png
@@ -32,11 +34,35 @@
* root/boots/Black Boots.png
* root/boots/White Boots.png
* ...
- * root/pumps/Red Pumps
- * root/pumps/Pink Pumps
+ * root/pumps/Red Pumps.png
+ * root/pumps/Pink Pumps.png
* ...
+ *
* here shoes, boots, pumps are your labels
*
+ *
+ *
Here, the dataset will take the folder names (shoes, boots, bumps) in sorted order as your
+ * labels. Nested folder structures are not currently supported.
+ *
+ *
Then, you can create your instance of the dataset as follows:
+ *
+ *
+ * // set the image folder path
+ * Repository repository = Repository.newInstance("folder", Paths.get("/path/to/imagefolder/root");
+ * ImageFolder dataset =
+ * new ImageFolder.Builder()
+ * .setRepository(repository)
+ * .addTransform(new Resize(100, 100)) // Use image transforms as necessary for your data
+ * .addTransform(new ToTensor()) // Usually required as the last transform to convert images to tensors
+ * .setSampling(batchSize, true)
+ * .build();
+ *
+ * // call prepare before using
+ * dataset.prepare();
+ *
+ * // to get the synset or label names
+ * List>String< synset = dataset.getSynset();
+ *
*/
public final class ImageFolder extends AbstractImageFolder {
diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java
index 257fd6bc44d..1cb87668d83 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java
@@ -34,7 +34,11 @@
/**
* MNIST handwritten digits dataset from http://yann.lecun.com/exdb/mnist.
*
- *
Each sample is an image (in 3-D NDArray) with shape (28, 28, 1).
+ *
Each sample is a grayscale image (in 3-D NDArray) with shape (28, 28, 1).
+ *
+ *
It is a common starting dataset because it is small and can train within minutes. However, it
+ * is an overly easy task that even poor models can still perform very well on. Instead, consider
+ * {@link FashionMnist} which offers a comparable speed but a more reasonable difficulty task.
*/
public final class Mnist extends ArrayDataset {
diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java
index 7716d5ab1a9..aa3d0280db0 100644
--- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java
+++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordQuestionAnsweringDataset.java
@@ -39,6 +39,8 @@
* questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every
* question is a segment of text, or span, from the corresponding reading passage, or the question
* might be unanswerable.
+ *
+ * @see Dataset website
*/
@SuppressWarnings("unchecked")
public class StanfordQuestionAnsweringDataset extends TextDataset implements RawDataset