Saving Model Checkpoints
Using Persistent Volumes to save and load model weights.
When training an ML model, you might want to periodically save a checkpoint with the model weights. You can use a Persistent Volume to save checkpoints and retrieve them in subsequent runs.
In this example, we’ll demonstrate saving and loading a model checkpoint using PyTorch.
Mounting a Persistent Volume
First, mount a Persistent Volume to the Beam runtime so your model weights can be written to it.
Add the following to your app.py:
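A minimal sketch of mounting a volume with the Beam Python SDK. The volume name `checkpoints`, the mount path `./checkpoints`, and the decorator arguments shown here are illustrative assumptions; substitute your own names and check the SDK reference for the exact signature.

```python
from beam import Volume, function

# Mount a Persistent Volume named "checkpoints" at ./checkpoints inside the
# container. Files written under this path survive across runs.
@function(volumes=[Volume(name="checkpoints", mount_path="./checkpoints")])
def train():
    # Your training loop goes here; anything saved to ./checkpoints persists.
    ...
```

Any file your function writes under the mount path is available the next time a function with the same volume runs.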
Saving model checkpoints
During your training loop, you can call torch.save() to dump the model weights to a file path. In this case, the file path points to the Persistent Volume you defined in app.py.
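A sketch of a training loop that checkpoints each epoch. The `CHECKPOINT_DIR` value and the tiny `nn.Linear` model are placeholders; on Beam, the directory would be your volume's mount path and would already exist.

```python
import os

import torch
import torch.nn as nn

CHECKPOINT_DIR = "./checkpoints"  # assumption: your volume's mount path
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # the mount already exists on Beam

model = nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(3):
    # ... forward pass, loss, backward pass, optimizer step ...

    # Periodically dump model and optimizer state to the volume
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        os.path.join(CHECKPOINT_DIR, f"checkpoint_{epoch}.pt"),
    )
```

Saving the optimizer state alongside the weights lets you resume training rather than only run inference.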
Loading model checkpoints
You can load model weights by calling torch.load() on the Persistent Volume path from the previous step and passing the resulting state dict to the model's load_state_dict() method.
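A self-contained sketch of restoring a checkpoint. The `/tmp/checkpoint.pt` path is hypothetical; in practice you would read from your volume's mount path, where the file was written by an earlier run, instead of saving one here first.

```python
import torch
import torch.nn as nn

# For a runnable demo, write a checkpoint first. On Beam, this file would
# already exist on the Persistent Volume from a previous training run.
path = "/tmp/checkpoint.pt"  # hypothetical path; use your mount path
saved_model = nn.Linear(10, 2)
torch.save({"epoch": 4, "model_state_dict": saved_model.state_dict()}, path)

# Restore the weights into a fresh model instance
model = nn.Linear(10, 2)
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
start_epoch = checkpoint["epoch"] + 1

model.eval()  # or model.train() to resume training
```

Note that the model you load into must have the same architecture as the one that produced the state dict, since load_state_dict() matches parameters by name and shape.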