Saving/loading models and checkpointing

Author: Marie-Hélène Burle

Once you have trained your model, you will want to save it so that you can reuse it later. You will then need to load it in any session in which you want to use it.

In addition to saving and loading a fully trained model, it is important to know how to create regular checkpoints: training ML models takes a long time, and a cluster crash or countless other issues might interrupt the training process. You don't want to have to restart from scratch if that happens.

Saving/loading models

Saving models

You can save a model by serializing its internal state dictionary (state_dict). The state dictionary is a Python dictionary that maps each layer of your model to its learnable parameters (weights and biases).

torch.save(model.state_dict(), "model.pt")
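
For example, with a small hypothetical model (this Sequential model is used purely as an illustration), you can inspect what a state dictionary contains:

import torch
import torch.nn as nn

# A small hypothetical model, used only for illustration
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Each entry maps a parameter name to its tensor of values
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))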

Loading models

To recreate your model, you first need to recreate its structure:

model = TheModelClass(*args, **kwargs)

Then you can load the state dictionary containing the parameter values into it:

model.load_state_dict(torch.load("model.pt"))
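
Note that if the model was saved on a GPU and you are loading it on a CPU-only machine (or the other way around), you can pass the map_location argument to torch.load to remap the tensors to the available device:

# Remap tensors saved on a GPU onto the CPU at load time
model.load_state_dict(torch.load("model.pt", map_location=torch.device("cpu")))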

If you want to use your model for inference, you must also run:

model.eval()

This sets layers such as dropout and batch normalization to evaluation mode. If instead you want to train your model further, run model.train().
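
Putting it all together, here is a minimal inference sketch (TheModelClass, its arguments, and the input tensor x are placeholders for your own code):

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load("model.pt"))
model.eval()

# Disable gradient tracking: we are not training anymore
with torch.no_grad():
    predictions = model(x)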

Checkpointing

Creating a checkpoint

Save a dictionary containing everything needed to resume training: the states of both the model and the optimizer, plus any bookkeeping values such as the current epoch and loss. Since a checkpoint contains more than the model's state dictionary, it is best saved under a distinct file name:

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ... any other values you want to restore when resuming
}, "checkpoint.pt")

Resuming training from a checkpoint

Recreate the state of your model from the checkpoint:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

Resume training:

model.train()
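
and continue the training loop from where it left off. A sketch, assuming num_epochs and the same hypothetical train_one_epoch() helper as above:

# Restart at the epoch following the one stored in the checkpoint
for epoch in range(epoch + 1, num_epochs):
    loss = train_one_epoch(model, optimizer)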