Saving/loading models and checkpointing
After you have trained your model, you will want to save it so that you can reuse it later. You will then need to load it in any session in which you want to use it.
In addition to saving and loading a fully trained model, it is important to know how to create regular checkpoints: training ML models can take a long time, and a cluster crash or any number of other issues can interrupt the training process. You don't want to have to restart from scratch if that happens.
Saving/loading models
Saving models
You can save a model by serializing its internal state dictionary. The state dictionary is a Python dictionary that maps each layer of your model to its learnable parameters (weights and biases).
"model.pt") torch.save(model.state_dict(),
Loading models
To recreate your model, you first need to recreate its structure:
model = TheModelClass(*args, **kwargs)
Then you can load the state dictionary containing the parameter values into it:
"model.pt")) model.load_state_dict(torch.load(
Assuming you want to use your model for inference, you must also run:
model.eval()
If instead you want to continue training your model, run model.train() instead.
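For example, a minimal inference sketch could look like this (x stands for a hypothetical input tensor of your own):

import torch

model.eval()                 # switch dropout and batch norm layers to evaluation behaviour
with torch.no_grad():        # no gradients are needed for inference
    prediction = model(x)    # x is a placeholder for your input tensor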
Checkpointing
Creating a checkpoint
torch.save({'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...}, "model.pt")
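In practice, you would create such checkpoints periodically inside your training loop. Here is a rough sketch under assumed names (num_epochs, the train_one_epoch helper, and the file name checkpoint.pt are hypothetical):

for epoch in range(num_epochs):
    loss = train_one_epoch(model, optimizer)   # hypothetical helper running one epoch of training
    if epoch % 10 == 0:                        # save a checkpoint every 10 epochs, for example
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss},
                   "checkpoint.pt")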
Resuming training from a checkpoint
Recreate the state of your model from the checkpoint:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load("model.pt")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
Resume training:
model.train()
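From there, a sketch of the resumed loop (using the same hypothetical num_epochs and train_one_epoch helper as above) might be:

start_epoch = epoch + 1                        # continue from the epoch after the saved one
for epoch in range(start_epoch, num_epochs):
    loss = train_one_epoch(model, optimizer)   # hypothetical helper running one epoch of training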