diff --git a/jupyter/TrainingUtils.java b/jupyter/TrainingUtils.java
new file mode 100644
index 00000000000..1a5ea729e95
--- /dev/null
+++ b/jupyter/TrainingUtils.java
@@ -0,0 +1,85 @@
+import ai.djl.Model;
+import ai.djl.training.Trainer;
+import ai.djl.training.TrainingListener;
+import ai.djl.training.dataset.Batch;
+import ai.djl.training.dataset.Dataset;
+import ai.djl.training.util.ProgressBar;
+import java.io.IOException;
+import java.nio.file.Paths;
+
+public class TrainingUtils {
+
+    public static void fit(
+            Trainer trainer,
+            int numEpoch,
+            Dataset trainingDataset,
+            Dataset validateDataset,
+            String outputDir,
+            String modelName)
+            throws IOException {
+        for (int epoch = 0; epoch < numEpoch; epoch++) {
+            for (Batch batch : trainer.iterateDataset(trainingDataset)) {
+                trainer.trainBatch(batch);
+                trainer.step();
+                batch.close();
+            }
+
+            if (validateDataset != null) {
+                for (Batch batch : trainer.iterateDataset(validateDataset)) {
+                    trainer.validateBatch(batch);
+                    batch.close();
+                }
+            }
+            // reset training and validation metric at end of epoch
+            trainer.resetTrainingMetrics();
+            // save model at end of each epoch
+            if (outputDir != null) {
+                Model model = trainer.getModel();
+                model.setProperty("Epoch", String.valueOf(epoch));
+                model.save(Paths.get(outputDir), modelName);
+            }
+        }
+    }
+
+    public static TrainingListener getTrainingListener(
+            ProgressBar trainingProgressBar, ProgressBar validateProgressBar) {
+        return new SimpleTrainingListener(trainingProgressBar, validateProgressBar);
+    }
+
+    private static final class SimpleTrainingListener implements TrainingListener {
+
+        private ProgressBar trainingProgressBar;
+        private ProgressBar validateProgressBar;
+        private int trainingProgress;
+        private int validateProgress;
+
+        public SimpleTrainingListener(
+                ProgressBar trainingProgressBar, ProgressBar validateProgressBar) {
+            this.trainingProgressBar = trainingProgressBar;
+            this.validateProgressBar = validateProgressBar;
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public void onTrainingBatch() {
+            if (trainingProgressBar != null) {
+                trainingProgressBar.update(trainingProgress++);
+            }
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public void onValidationBatch() {
+            if (validateProgressBar != null) {
+                validateProgressBar.update(validateProgress++);
+            }
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public void onEpoch() {
+            trainingProgress = 0;
+            validateProgress = 0;
+        }
+    }
+}
diff --git a/jupyter/transfer_learning_on_cifar10.ipynb b/jupyter/transfer_learning_on_cifar10.ipynb
index 96fcbfed474..028d6c011a9 100644
--- a/jupyter/transfer_learning_on_cifar10.ipynb
+++ b/jupyter/transfer_learning_on_cifar10.ipynb
@@ -36,14 +36,11 @@
     "%maven ai.djl:api:0.2.0\n",
     "%maven ai.djl:basicdataset:0.2.0\n",
     "%maven ai.djl:model-zoo:0.2.0\n",
-    "%maven ai.djl:examples:0.2.0\n",
     "%maven ai.djl:repository:0.2.0\n",
     "%maven ai.djl.mxnet:mxnet-engine:0.2.0\n",
     "%maven ai.djl.mxnet:mxnet-model-zoo:0.2.0\n",
-    "%maven org.apache.logging.log4j:log4j-slf4j-impl:2.12.1\n",
-    "%maven org.apache.logging.log4j:log4j-core:2.12.1\n",
-    "%maven org.apache.logging.log4j:log4j-api:2.12.1 \n",
     "%maven org.slf4j:slf4j-api:1.7.26\n",
+    "%maven org.slf4j:slf4j-simple:1.7.26\n",
     "%maven net.java.dev.jna:jna:5.3.0"
    ]
   },
@@ -83,7 +80,6 @@
    "source": [
     "import ai.djl.*;\n",
     "import ai.djl.basicdataset.*;\n",
-    "import ai.djl.examples.training.util.*;\n",
     "import ai.djl.modality.cv.transform.*;\n",
     "import ai.djl.mxnet.zoo.*;\n",
     "import ai.djl.ndarray.*;\n",
@@ -238,8 +234,15 @@
     "int epoch = 10;\n",
     "Trainer trainer = model.newTrainer(config);\n",
     "Shape inputShape = new Shape(1, 3, 32, 32);\n",
-    "trainer.initialize(inputShape);\n",
-    "\n",
+    "trainer.initialize(inputShape);"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "float trainingAccuracy = 0f;\n",
     "for (int i = 0; i < epoch; ++i) {\n",
     "    int index = 0;\n",
@@ -284,7 +287,7 @@
    "metadata": {},
    "source": [
     "## Use the `fit` method\n",
-    "Instead of writing the two `for` loops, you can use the `fit` method in `TrainingUtils`, which will handle everything automatically. Just pass your `trainer`, number of epochs to train, training dataset, validation dataset (if any), model output path, and model name. It will save your model checkpoint at the end of each epoch."
+    "Instead of writing the two `for` loops, you can use the `fit` method in [TrainingUtils](TrainingUtils.java), which will handle everything automatically. Just pass your `trainer`, number of epochs to train, training dataset, validation dataset (if any), model output path, and model name. It will save your model checkpoint at the end of each epoch."
    ]
   },
   {
@@ -293,6 +296,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "%load TrainingUtils.java\n",
+    "\n",
+    "trainer.setTrainingListener(TrainingUtils.getTrainingListener(progressBar, null));\n",
+    "\n",
     "TrainingUtils.fit(trainer, epoch, trainDataset, null, \"build/resnet\", \"resnetv1\");"
    ]
   },
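Note: taken together, the notebook changes above split the trainer setup into its own cell and then drive training through the new helper instead of hand-written loops. A minimal sketch of what the resulting cells execute, using only the calls shown in this diff (`model`, `config`, `progressBar`, `trainDataset`, and `epoch` are assumed to be defined in earlier cells of the notebook):

    // Create and initialize the trainer, as in the cell above.
    Trainer trainer = model.newTrainer(config);
    trainer.initialize(new Shape(1, 3, 32, 32));

    // Attach the progress listener added in TrainingUtils.java, then train;
    // fit() also saves a checkpoint under build/resnet at the end of each epoch.
    trainer.setTrainingListener(TrainingUtils.getTrainingListener(progressBar, null));
    TrainingUtils.fit(trainer, epoch, trainDataset, null, "build/resnet", "resnetv1");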