When training an ML model, you might want to periodically save a checkpoint with the model weights. You can use a Persistent Volume to save checkpoints and retrieve them in subsequent runs.

In this example, we’ll demonstrate saving and loading a model checkpoint using PyTorch.

Mounting a Persistent Volume

The first thing you’ll do is mount a Persistent Volume to the Beam runtime in order to save your model weights.

Add the following line to your app.py file:

app.Mount.PersistentVolume(name="saved_models", path="./saved_models")

Saving model checkpoints

During your training loop, you can call torch.save() to write the model weights to a file. In this case, the file lives inside the Persistent Volume you defined in app.py.

import torch

def save_model_weights():
    model = MyModel()  # your own torch.nn.Module subclass

    PERSISTENT_VOLUME_PATH = "./saved_models/cifar_net.pth"
    # Save model to persistent volume
    torch.save(model.state_dict(), PERSISTENT_VOLUME_PATH)
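
To show how the save call might sit inside a training loop, here is a minimal sketch. The function name train_with_checkpoints, the save_every parameter, and the MSE loss are illustrative assumptions and are not part of Beam or of the example above; only torch.save() and state_dict() come from PyTorch itself.

```python
import os

import torch
from torch import nn


def train_with_checkpoints(model, optimizer, batches, num_epochs,
                           checkpoint_path="./saved_models/cifar_net.pth",
                           save_every=1):
    """Toy training loop (hypothetical helper) that saves weights periodically."""
    # Make sure the target directory exists before the first save.
    os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
    loss_fn = nn.MSELoss()  # assumption: a regression-style loss, for illustration only
    saved_epochs = []
    for epoch in range(1, num_epochs + 1):
        for inputs, targets in batches:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
        if epoch % save_every == 0:
            # Overwrite the same file so the volume always holds the latest weights.
            torch.save(model.state_dict(), checkpoint_path)
            saved_epochs.append(epoch)
    return saved_epochs
```

Calling it with save_every=2 and num_epochs=4 would write the file after epochs 2 and 4. Overwriting a single file keeps the volume small; if you need to roll back to earlier epochs, you could include the epoch number in the file name instead.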

Loading model checkpoints

You can load model weights by reading the file with torch.load() and passing the result to the model's load_state_dict() method. Use the same Persistent Volume path as in the previous step.

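import torch

def load_model_weights():
    model = MyModel()  # must be the same architecture that was saved

    PERSISTENT_VOLUME_PATH = "./saved_models/cifar_net.pth"
    # Load model from persistent volume
    model.load_state_dict(torch.load(PERSISTENT_VOLUME_PATH))

    # eval() switches layers such as dropout and batch norm to inference mode
    # and returns the model itself
    saved_model = model.eval()
    print(saved_model)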
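
On the very first run there is no checkpoint on the volume yet, so a plain torch.load() call would raise an error. The sketch below shows one way to handle that case. The helper name load_weights_if_present is hypothetical, and the choice of map_location="cpu" is an assumption made so that a checkpoint saved on a GPU can also be read on a CPU-only machine.

```python
import os

import torch


def load_weights_if_present(model, checkpoint_path="./saved_models/cifar_net.pth"):
    """Load saved weights into `model`; return False if no checkpoint exists yet."""
    if not os.path.exists(checkpoint_path):
        # First run: nothing has been written to the volume yet.
        return False
    # map_location="cpu" remaps tensors saved on a GPU so they load without one.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return True
```

The boolean return value lets the caller decide whether to train from scratch or continue from the saved weights. After loading, you can move the model to the device you need with model.to(device).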
