/**
 * Retrains the model on the examples collected in `controllerDataset`.
 *
 * NOTE(review): the original compile step (optimizer + categoricalCrossentropy
 * loss) was commented out here — this assumes `model` is already compiled
 * before retrain() is invoked; confirm that happens elsewhere.
 *
 * @returns {Promise<void>} Resolves when training has completed.
 * @throws {Error} If no examples have been collected, or if the chosen
 *     batch-size fraction yields a batch size of 0 or NaN.
 */
async function retrain() {
  if (controllerDataset.xs == null) {
    throw new Error('Add some examples before retraining!');
  }
  // We parameterize batch size as a fraction of the entire dataset because the
  // number of examples that are collected depends on how many examples the
  // user collects. This allows us to have a flexible batch size.
  const batchSize =
      Math.floor(controllerDataset.xs.shape[0] * ui.getBatchSizeFraction());
  if (!(batchSize > 0)) {
    throw new Error(
        `Batch size is 0 or NaN. Please choose a non-zero fraction.`);
  }
  // Train the model! Model.fit() will shuffle xs & ys so we don't have to.
  // `await` the fit so callers observe completion and training errors
  // propagate instead of becoming unhandled promise rejections.
  await model.fit(controllerDataset.xs, controllerDataset.ys, {
    batchSize,
    epochs: ui.getEpochs(),
    callbacks: {
      onBatchEnd: async (batch, logs) => {
        ui.trainStatus('Loss: ' + logs.loss.toFixed(5));
      }
    }
  });
}