Saving the model during training is done with the mx.do_checkpoint callback. Its most important parameters are as follows:
- prefix: the prefix of the filenames under which the model is saved
- frequency: how often a checkpoint is saved, measured in epochs (for example, a frequency of 5 saves the model every fifth epoch)
Let's return to the MNIST example we have been working on and adjust its mx.fit call to include mx.do_checkpoint. Here is the call as it currently stands:
mx.fit(nnet, mx.ADAM(), train_data_provider, eval_data = validation_data_provider, n_epoch = 50, callbacks = [mx.speedometer()]);
As you can see, the original version already passes the mx.speedometer callback to mx.fit. The new version will add a call to mx.do_checkpoint to save the model every fifth epoch with ...
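As a rough sketch of what the adjusted call might look like, the checkpoint callback can simply be appended to the existing callbacks vector. This assumes that nnet, train_data_provider and validation_data_provider are already defined as in the earlier MNIST code, and the prefix "checkpoints/mnist" is a hypothetical placeholder, not the value used in the original example.

```julia
using MXNet

# Assumes nnet, train_data_provider and validation_data_provider
# were defined earlier in the MNIST example.
# "checkpoints/mnist" is a hypothetical filename prefix.
mx.fit(nnet, mx.ADAM(), train_data_provider,
       eval_data = validation_data_provider,
       n_epoch = 50,
       callbacks = [
           mx.speedometer(),                                  # existing throughput report
           mx.do_checkpoint("checkpoints/mnist", frequency = 5) # save every 5th epoch
       ]);
```

Because both callbacks travel in the same callbacks vector, adding mx.do_checkpoint leaves the existing mx.speedometer reporting unchanged.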