Training Multi-threaded #743

Merged · merged 1 commit on Mar 19, 2021
28 changes: 28 additions & 0 deletions api/src/main/java/ai/djl/training/DefaultTrainingConfig.java
@@ -24,6 +24,8 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Predicate;

/** {@code DefaultTrainingConfig} is an implementation of the {@link TrainingConfig} interface. */
@@ -33,6 +35,7 @@ public class DefaultTrainingConfig implements TrainingConfig {
private Optimizer optimizer;
private Device[] devices;
private Loss loss;
private ExecutorService executorService;
private List<Evaluator> evaluators;
private List<TrainingListener> listeners;

@@ -112,6 +115,26 @@ public DefaultTrainingConfig optOptimizer(Optimizer optimizer) {
return this;
}

/**
* Sets the {@link ExecutorService} to the global {@link ForkJoinPool#commonPool()}.
*
* @return this {@link DefaultTrainingConfig}
*/
public DefaultTrainingConfig optExecutorService() {
return optExecutorService(ForkJoinPool.commonPool());
}

/**
* Sets the {@link ExecutorService} to train with multiple threads.
*
* @param executorService the executor service
* @return this {@link DefaultTrainingConfig}
*/
public DefaultTrainingConfig optExecutorService(ExecutorService executorService) {
this.executorService = executorService;
return this;
}

/**
* Adds an {@link Evaluator} that needs to be computed during training.
*
@@ -161,6 +184,11 @@ public Loss getLossFunction() {
return loss;
}

@Override
public ExecutorService getExecutorService() {
return executorService;
}

/** {@inheritDoc} */
@Override
public List<Evaluator> getEvaluators() {
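For context, a minimal sketch of how the new option might be configured; the loss function and pool size below are illustrative choices, not part of this diff:

import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.loss.Loss;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public final class ExecutorConfigSketch {
    public static void main(String[] args) {
        // Borrow the global ForkJoinPool.commonPool().
        DefaultTrainingConfig common =
                new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optExecutorService();

        // Or supply a dedicated pool, e.g. roughly one thread per device/split.
        ExecutorService pool = Executors.newFixedThreadPool(2);
        DefaultTrainingConfig dedicated =
                new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optExecutorService(pool);

        pool.shutdown();
    }
}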
79 changes: 61 additions & 18 deletions api/src/main/java/ai/djl/training/EasyTrain.java
@@ -20,7 +20,11 @@
import ai.djl.translate.TranslateException;
import ai.djl.util.Preconditions;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;

/** Helper for easy training of a whole model, a training batch, or a validation batch. */
public final class EasyTrain {
@@ -86,24 +90,45 @@ public static void trainBatch(Trainer trainer, Batch batch) {
BatchData batchData =
new BatchData(batch, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());
try (GradientCollector collector = trainer.newGradientCollector()) {
for (Batch split : splits) {
NDList data = split.getData();
NDList labels = split.getLabels();
NDList preds = trainer.forward(data, labels);
long time = System.nanoTime();
NDArray lossValue = trainer.getLoss().evaluate(labels, preds);
collector.backward(lossValue);
trainer.addMetric("backward", time);
time = System.nanoTime();
batchData.getLabels().put(labels.get(0).getDevice(), labels);
batchData.getPredictions().put(preds.get(0).getDevice(), preds);
trainer.addMetric("training-metrics", time);

if (splits.length > 1 && trainer.getExecutorService().isPresent()) {
// multi-threaded
ExecutorService executor = trainer.getExecutorService().get();
List<CompletableFuture<Boolean>> futures = new ArrayList<>(splits.length);
for (Batch split : splits) {
futures.add(
CompletableFuture.supplyAsync(
() -> trainSplit(trainer, collector, batchData, split),
executor));
}
// wait for all splits to finish before the gradient collector closes
CompletableFuture.allOf(futures.stream().toArray(CompletableFuture[]::new)).join();
} else {
// sequential
for (Batch split : splits) {
trainSplit(trainer, collector, batchData, split);
}
}
}

trainer.notifyListeners(listener -> listener.onTrainingBatch(trainer, batchData));
}

private static boolean trainSplit(
Trainer trainer, GradientCollector collector, BatchData batchData, Batch split) {
NDList data = split.getData();
NDList labels = split.getLabels();
NDList preds = trainer.forward(data, labels);
long time = System.nanoTime();
NDArray lossValue = trainer.getLoss().evaluate(labels, preds);
collector.backward(lossValue);
trainer.addMetric("backward", time);
time = System.nanoTime();
batchData.getLabels().put(labels.get(0).getDevice(), labels);
batchData.getPredictions().put(preds.get(0).getDevice(), preds);
trainer.addMetric("training-metrics", time);
return true;
}

/**
* Validates the given batch of data.
*
@@ -122,17 +147,35 @@ public static void validateBatch(Trainer trainer, Batch batch) {
BatchData batchData =
new BatchData(batch, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());

for (Batch split : splits) {
NDList data = split.getData();
NDList labels = split.getLabels();
NDList preds = trainer.evaluate(data);
batchData.getLabels().put(labels.get(0).getDevice(), labels);
batchData.getPredictions().put(preds.get(0).getDevice(), preds);
if (splits.length > 1 && trainer.getExecutorService().isPresent()) {
// multi-threaded
ExecutorService executor = trainer.getExecutorService().get();
List<CompletableFuture<Boolean>> futures = new ArrayList<>(splits.length);
for (Batch split : splits) {
futures.add(
CompletableFuture.supplyAsync(
() -> validateSplit(trainer, batchData, split), executor));
}
// wait for all splits to finish before notifying listeners
CompletableFuture.allOf(futures.stream().toArray(CompletableFuture[]::new)).join();
} else {
// sequential
for (Batch split : splits) {
validateSplit(trainer, batchData, split);
}
}

trainer.notifyListeners(listener -> listener.onValidationBatch(trainer, batchData));
}

private static boolean validateSplit(Trainer trainer, BatchData batchData, Batch split) {
NDList data = split.getData();
NDList labels = split.getLabels();
NDList preds = trainer.evaluate(data);
batchData.getLabels().put(labels.get(0).getDevice(), labels);
batchData.getPredictions().put(preds.get(0).getDevice(), preds);
return true;
}

/**
* Evaluates the test dataset.
*
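The parallel branch in trainBatch and validateBatch is a plain fan-out/join over CompletableFuture. A standalone sketch of that pattern, with work(int) standing in for trainSplit/validateSplit (it is not a DJL method):

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public final class FanOutJoinSketch {
    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(4);
        int[] splits = {0, 1, 2, 3}; // stand-ins for per-device batch splits

        List<CompletableFuture<Boolean>> futures = new ArrayList<>(splits.length);
        for (int split : splits) {
            futures.add(CompletableFuture.supplyAsync(() -> work(split), executor));
        }
        // join() blocks until every split has completed, mirroring the sequential branch.
        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();

        executor.shutdown();
    }

    private static boolean work(int split) {
        System.out.println("split " + split + " on " + Thread.currentThread().getName());
        return true;
    }
}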
89 changes: 0 additions & 89 deletions api/src/main/java/ai/djl/training/ParallelTrain.java

This file was deleted.

15 changes: 14 additions & 1 deletion api/src/main/java/ai/djl/training/Trainer.java
@@ -33,6 +33,8 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -72,6 +74,7 @@ public class Trainer implements AutoCloseable {
private ParameterStore parameterStore;
private List<Evaluator> evaluators;
private Loss loss;
private ExecutorService executorService;

private boolean gradientsChecked;

@@ -91,6 +94,7 @@ public Trainer(Model model, TrainingConfig trainingConfig) {
Objects.requireNonNull(loss, "You must specify a loss for the trainer");
evaluators = new ArrayList<>(trainingConfig.getEvaluators());
evaluators.add(loss); // track loss as an evaluator by default
executorService = trainingConfig.getExecutorService();

ParameterServer parameterServer =
manager.getEngine().newParameterServer(trainingConfig.getOptimizer());
@@ -129,7 +133,7 @@ public void initialize(Shape... shapes) {
* @throws TranslateException if there is an error while processing input
*/
public Iterable<Batch> iterateDataset(Dataset dataset) throws IOException, TranslateException {
return dataset.getData(getManager());
return dataset.getData(getManager(), executorService);
}

/**
@@ -238,6 +242,15 @@ public Model getModel() {
return model;
}

/**
* Returns the {@link ExecutorService} configured for this trainer, if any.
*
* @return an {@link Optional} containing the {@link ExecutorService} if one was configured
*/
public Optional<ExecutorService> getExecutorService() {
return Optional.ofNullable(executorService);
Contributor commented: Should we use Optional for all optional arguments?

Contributor Author replied: We should probably discuss that as a larger improvement, but I would be in support of it.
}

/**
* Gets all {@link Evaluator}s.
*
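Putting the pieces together, a hedged sketch of a training loop using the executor-enabled config; the model and dataset are assumed to be built elsewhere, and the input shape is a placeholder:

import ai.djl.Model;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import java.io.IOException;

public final class MultiThreadedEpochSketch {
    static void trainOneEpoch(Model model, Dataset dataset)
            throws IOException, TranslateException {
        DefaultTrainingConfig config =
                new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                        .optExecutorService(); // enables parallel split training and data loading
        try (Trainer trainer = model.newTrainer(config)) {
            trainer.initialize(new Shape(1, 28 * 28)); // placeholder input shape
            // iterateDataset forwards the executor to the dataset via getData(manager, executorService)
            for (Batch batch : trainer.iterateDataset(dataset)) {
                EasyTrain.trainBatch(trainer, batch);
                trainer.step();
                batch.close();
            }
        }
    }
}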
11 changes: 11 additions & 0 deletions api/src/main/java/ai/djl/training/TrainingConfig.java
@@ -21,6 +21,7 @@
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.PairList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.function.Predicate;

/**
@@ -48,6 +49,9 @@
* minimize the loss function. There are a variety of optimizers, most of which are variants
* of stochastic gradient descent. When you are just starting, you can use the default
* optimizer. Later on, customizing the optimizer can result in faster training.
* <li>{@link ExecutorService} - The executorService is used for parallelization when training
* batches on multiple GPUs or loading data from the dataset. If none is provided, all
* operations will be sequential.
* <li>{@link TrainingListener} - The training listeners add additional functionality to the
* training process through a listener interface. This can include showing training progress,
* stopping early if the training fails, or recording performance metrics. We offer several
@@ -87,6 +91,13 @@
*/
Loss getLossFunction();

/**
* Gets the {@link ExecutorService} for parallelization.
*
* @return an {@link ExecutorService}, or {@code null} for single-threaded training
*/
ExecutorService getExecutorService();

/**
* Returns the list of {@link Evaluator}s that should be computed during training.
*
15 changes: 15 additions & 0 deletions api/src/main/java/ai/djl/training/dataset/Dataset.java
@@ -16,6 +16,7 @@
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;
import java.io.IOException;
import java.util.concurrent.ExecutorService;

/**
* An interface to represent a set of sample data/label pairs to train a model.
@@ -34,6 +35,20 @@ public interface Dataset {
*/
Iterable<Batch> getData(NDManager manager) throws IOException, TranslateException;

/**
* Fetches an iterator that can iterate through the {@link Dataset} with multiple threads.
*
* @param manager the manager used to create the arrays in the batches
* @param executorService the executorService to use for multi-threading
* @return an {@link Iterable} of {@link Batch} that contains batches of data from the dataset
* @throws IOException for various exceptions depending on the dataset
* @throws TranslateException if there is an error while processing input
*/
default Iterable<Batch> getData(NDManager manager, ExecutorService executorService)
throws IOException, TranslateException {
return getData(manager);
}

/**
* Prepares the dataset for use.
*
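Because the new overload defaults to the single-threaded getData(manager), existing Dataset implementations keep working unchanged; implementations that support pre-fetching can override it. A minimal sketch of calling the overload directly, with the dataset assumed to come from elsewhere:

import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public final class DatasetExecutorSketch {
    static void iterate(Dataset dataset) throws IOException, TranslateException {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try (NDManager manager = NDManager.newBaseManager()) {
            // Overriding datasets may use the pool to pre-fetch batches;
            // the default implementation simply falls back to getData(manager).
            for (Batch batch : dataset.getData(manager, pool)) {
                // ... consume the batch ...
                batch.close();
            }
        } finally {
            pool.shutdown();
        }
    }
}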