Skip to content

Commit

Permalink
Avoid depends on example project.
Browse files Browse the repository at this point in the history
Change-Id: I9ef03b2303dae8be60cc26122f3574aa3a9bed5b
  • Loading branch information
frankfliu authored and Rakesh Vasudevan committed Nov 29, 2019
1 parent f83bf5c commit fadf648
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 8 deletions.
85 changes: 85 additions & 0 deletions jupyter/TrainingUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import ai.djl.Model;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingListener;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.util.ProgressBar;
import java.io.IOException;
import java.nio.file.Paths;

public class TrainingUtils {

public static void fit(
Trainer trainer,
int numEpoch,
Dataset trainingDataset,
Dataset validateDataset,
String outputDir,
String modelName)
throws IOException {
for (int epoch = 0; epoch < numEpoch; epoch++) {
for (Batch batch : trainer.iterateDataset(trainingDataset)) {
trainer.trainBatch(batch);
trainer.step();
batch.close();
}

if (validateDataset != null) {
for (Batch batch : trainer.iterateDataset(validateDataset)) {
trainer.validateBatch(batch);
batch.close();
}
}
// reset training and validation metric at end of epoch
trainer.resetTrainingMetrics();
// save model at end of each epoch
if (outputDir != null) {
Model model = trainer.getModel();
model.setProperty("Epoch", String.valueOf(epoch));
model.save(Paths.get(outputDir), modelName);
}
}
}

public static TrainingListener getTrainingListener(
ProgressBar trainingProgressBar, ProgressBar validateProgressBar) {
return new SimpleTrainingListener(trainingProgressBar, validateProgressBar);
}

private static final class SimpleTrainingListener implements TrainingListener {

private ProgressBar trainingProgressBar;
private ProgressBar validateProgressBar;
private int trainingProgress;
private int validateProgress;

public SimpleTrainingListener(
ProgressBar trainingProgressBar, ProgressBar validateProgressBar) {
this.trainingProgressBar = trainingProgressBar;
this.validateProgressBar = validateProgressBar;
}

/** {@inheritDoc} */
@Override
public void onTrainingBatch() {
if (trainingProgressBar != null) {
trainingProgressBar.update(trainingProgress++);
}
}

/** {@inheritDoc} */
@Override
public void onValidationBatch() {
if (validateProgressBar != null) {
validateProgressBar.update(validateProgress++);
}
}

/** {@inheritDoc} */
@Override
public void onEpoch() {
trainingProgress = 0;
validateProgress = 0;
}
}
}
23 changes: 15 additions & 8 deletions jupyter/transfer_learning_on_cifar10.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,11 @@
"%maven ai.djl:api:0.2.0\n",
"%maven ai.djl:basicdataset:0.2.0\n",
"%maven ai.djl:model-zoo:0.2.0\n",
"%maven ai.djl:examples:0.2.0\n",
"%maven ai.djl:repository:0.2.0\n",
"%maven ai.djl.mxnet:mxnet-engine:0.2.0\n",
"%maven ai.djl.mxnet:mxnet-model-zoo:0.2.0\n",
"%maven org.apache.logging.log4j:log4j-slf4j-impl:2.12.1\n",
"%maven org.apache.logging.log4j:log4j-core:2.12.1\n",
"%maven org.apache.logging.log4j:log4j-api:2.12.1 \n",
"%maven org.slf4j:slf4j-api:1.7.26\n",
"%maven org.slf4j:slf4j-simple:1.7.26\n",
"%maven net.java.dev.jna:jna:5.3.0"
]
},
Expand Down Expand Up @@ -83,7 +80,6 @@
"source": [
"import ai.djl.*;\n",
"import ai.djl.basicdataset.*;\n",
"import ai.djl.examples.training.util.*;\n",
"import ai.djl.modality.cv.transform.*;\n",
"import ai.djl.mxnet.zoo.*;\n",
"import ai.djl.ndarray.*;\n",
Expand Down Expand Up @@ -238,8 +234,15 @@
"int epoch = 10;\n",
"Trainer trainer = model.newTrainer(config);\n",
"Shape inputShape = new Shape(1, 3, 32, 32);\n",
"trainer.initialize(inputShape);\n",
"\n",
"trainer.initialize(inputShape);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"float trainingAccuracy = 0f;\n",
"for (int i = 0; i < epoch; ++i) {\n",
" int index = 0;\n",
Expand Down Expand Up @@ -284,7 +287,7 @@
"metadata": {},
"source": [
"## Use the `fit` method\n",
"Instead of writing the two `for` loops, you can use the `fit` method in `TrainingUtils`, which will handle everything automatically. Just pass your `trainer`, number of epochs to train, training dataset, validation dataset (if any), model output path, and model name. It will save your model checkpoint at the end of each epoch."
"Instead of writing the two `for` loops, you can use the `fit` method in [TrainingUtils](TrainingUtils.java), which will handle everything automatically. Just pass your `trainer`, number of epochs to train, training dataset, validation dataset (if any), model output path, and model name. It will save your model checkpoint at the end of each epoch."
]
},
{
Expand All @@ -293,6 +296,10 @@
"metadata": {},
"outputs": [],
"source": [
"%load TrainingUtils.java\n",
"\n",
"trainer.setTrainingListener(TrainingUtils.getTrainingListener(progressBar, null));\n",
"\n",
"TrainingUtils.fit(trainer, epoch, trainDataset, null, \"build/resnet\", \"resnetv1\");"
]
},
Expand Down

0 comments on commit fadf648

Please sign in to comment.