Skip to content

Commit

Permalink
Fix transfer learning jupyter notebook.
Browse files Browse the repository at this point in the history
Change-Id: I0263fc6eab7646fe63f7f10a59119543070b75f0
  • Loading branch information
frankfliu authored and zachgk committed Feb 24, 2020
1 parent 69f5acf commit 03ab63c
Showing 1 changed file with 1 addition and 45 deletions.
46 changes: 1 addition & 45 deletions jupyter/TrainingUtils.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import ai.djl.Model;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingListener;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.util.ProgressBar;
import java.io.IOException;
import java.nio.file.Paths;

Expand Down Expand Up @@ -31,7 +29,7 @@ public static void fit(
}
}
// reset training and validation evaluators at end of epoch
trainer.resetEvaluators();
trainer.endEpoch();
// save model at end of each epoch
if (outputDir != null) {
Model model = trainer.getModel();
Expand All @@ -40,46 +38,4 @@ public static void fit(
}
}
}

public static TrainingListener getTrainingListener(
ProgressBar trainingProgressBar, ProgressBar validateProgressBar) {
return new SimpleTrainingListener(trainingProgressBar, validateProgressBar);
}

private static final class SimpleTrainingListener implements TrainingListener {

private ProgressBar trainingProgressBar;
private ProgressBar validateProgressBar;
private int trainingProgress;
private int validateProgress;

public SimpleTrainingListener(
ProgressBar trainingProgressBar, ProgressBar validateProgressBar) {
this.trainingProgressBar = trainingProgressBar;
this.validateProgressBar = validateProgressBar;
}

/** {@inheritDoc} */
@Override
public void onTrainingBatch() {
if (trainingProgressBar != null) {
trainingProgressBar.update(trainingProgress++);
}
}

/** {@inheritDoc} */
@Override
public void onValidationBatch() {
if (validateProgressBar != null) {
validateProgressBar.update(validateProgress++);
}
}

/** {@inheritDoc} */
@Override
public void onEpoch() {
trainingProgress = 0;
validateProgress = 0;
}
}
}

0 comments on commit 03ab63c

Please sign in to comment.