Saving Model Checkpoints
Using Persistent Volumes to save and load model weights.
When training an ML model, you might want to periodically save a checkpoint with the model weights. You can use a Persistent Volume to save checkpoints and retrieve them in subsequent runs.
In this example, we’ll demonstrate saving and loading a model checkpoint using PyTorch.
Mounting a Persistent Volume
First, mount a Persistent Volume to the Beam runtime so you have a place to save your model weights.
Add the following line to your app.py file:
app.Mount.PersistentVolume(name="saved_models", path="./saved_models")
Saving model checkpoints
During your training loop, you can call torch.save() to write the model weights to a file path. In this case, the path points to the Persistent Volume you defined in app.py.
def save_model_weights():
    model = MyModel()
    PERSISTENT_VOLUME_PATH = "./saved_models/cifar_net.pth"
    # Save model to persistent volume
    torch.save(model.state_dict(), PERSISTENT_VOLUME_PATH)
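To checkpoint periodically, you can call torch.save() every few epochs inside the training loop. Below is a minimal sketch of that pattern; MyModel, the toy training data, and the checkpoint file names are all illustrative, not part of the Beam API.

```python
import os

import torch
import torch.nn as nn

PERSISTENT_VOLUME_DIR = "./saved_models"


# Hypothetical toy model for illustration; substitute your own MyModel.
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def train_with_checkpoints(num_epochs=3, save_every=1):
    os.makedirs(PERSISTENT_VOLUME_DIR, exist_ok=True)
    model = MyModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Dummy training data, for illustration only
    inputs = torch.randn(8, 4)
    targets = torch.randn(8, 2)

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Save a checkpoint to the persistent volume every `save_every` epochs
        if (epoch + 1) % save_every == 0:
            path = os.path.join(
                PERSISTENT_VOLUME_DIR, f"checkpoint_epoch_{epoch + 1}.pth"
            )
            torch.save(model.state_dict(), path)

    return model
```

Because the directory is a Persistent Volume, the numbered checkpoints remain available to subsequent runs, so you can always fall back to an earlier epoch if training degrades.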
Loading model checkpoints
You can load model weights by passing the Persistent Volume path from the previous step to torch.load(), then handing the result to the model's load_state_dict() method.
def load_model_weights():
    model = MyModel()
    PERSISTENT_VOLUME_PATH = "./saved_models/cifar_net.pth"
    # Load model from persistent volume
    model.load_state_dict(torch.load(PERSISTENT_VOLUME_PATH))
    saved_model = model.eval()
    print(saved_model)
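If you want to resume training rather than just run inference, a common PyTorch pattern is to checkpoint the optimizer state and the epoch number alongside the model weights. The sketch below shows that pattern; the file path and helper names are illustrative, not part of the Beam API.

```python
import os

import torch

# Hypothetical checkpoint path on the persistent volume
CHECKPOINT_PATH = "./saved_models/training_checkpoint.pth"


def save_training_checkpoint(model, optimizer, epoch):
    # Bundle everything needed to resume training into one file
    os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        CHECKPOINT_PATH,
    )


def load_training_checkpoint(model, optimizer):
    # Restore model and optimizer state, and return the epoch to resume from
    checkpoint = torch.load(CHECKPOINT_PATH)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]
```

Saving the optimizer state matters for optimizers like Adam that keep running statistics: restoring only the weights would reset those statistics and can briefly destabilize training after a resume.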