Optimizing Cold Start

If you’ve followed the steps on this page and want to improve your cold start further, reach out to us and we’ll help you optimize your app to load as fast as possible.

There are two things you can do to reduce serverless cold start time: cache your models in Volumes, and load them only once using on_start.

Cache Models in Volumes

To avoid downloading your models from the internet every time a new container starts, you can cache them in a Beam Volume.

In the example below, the models are saved to the Volume by passing the cache_dir argument to the Hugging Face Transformers from_pretrained method:

from beam import Image, endpoint, Volume

# Path to cache model weights
CACHE_PATH = "./weights"

@endpoint(
    volumes=[Volume(name="weights", mount_path=CACHE_PATH)],
    cpu=1,
    memory="16Gi",
    gpu="T4",
    image=Image(
        python_version="python3.9",
        python_packages=[
            "transformers",
            "torch",
        ],
    ),
)
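def predict(**inputs):
    from transformers import AutoTokenizer, OPTForCausalLM
    import torch

    # Weights are read from the Volume; they are only downloaded if not cached there yet
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", cache_dir=CACHE_PATH)
    model = OPTForCausalLM.from_pretrained("facebook/opt-125m", cache_dir=CACHE_PATH)
    model.to("cuda" if torch.cuda.is_available() else "cpu")

    # Run inference ("prompt" is an example input name)
    prompt = inputs.get("prompt", "Hello, my name is")
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    output_ids = model.generate(input_ids, max_new_tokens=30)
    return {"text": tokenizer.decode(output_ids[0], skip_special_tokens=True)}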
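Conceptually, caching comes down to checking a persistent directory before going to the network. The sketch below is a stdlib-only illustration of that cache-or-download pattern, not how Hugging Face actually implements cache_dir. fetch_weights and fake_download are hypothetical names, and the temporary directory stands in for a mounted Volume.

```python
import os
import tempfile


def fetch_weights(name: str, cache_dir: str, download) -> bytes:
    """Return the weights for `name`, downloading only on a cache miss.

    `download` is a hypothetical callable standing in for a network fetch.
    """
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, name.replace("/", "--"))
    if os.path.exists(path):
        # Cache hit: read from the persistent directory (e.g. a mounted Volume)
        with open(path, "rb") as f:
            return f.read()
    # Cache miss: download once and persist the result for later containers
    data = download(name)
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        f.write(data)
    os.replace(tmp_path, path)  # atomic rename, so a partial file is never read as complete
    return data


if __name__ == "__main__":
    calls = []

    def fake_download(name):
        calls.append(name)
        return b"weights-for-" + name.encode()

    with tempfile.TemporaryDirectory() as volume:
        fetch_weights("facebook/opt-125m", volume, fake_download)
        fetch_weights("facebook/opt-125m", volume, fake_download)
        print(len(calls))  # 1: the second call is served from the cache
```

Because the directory lives on a Volume rather than the container's own disk, the cached file survives when the container shuts down, so only the very first container pays the download cost.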

Load Models Using on_start

In addition to using a Volume, it’s best practice to load your models only once, when the container first starts, rather than on every request. Beam lets you define an on_start function that runs exactly once when the container first starts.

This example combines the on_start functionality with Volume caching:

from beam import Image, endpoint, Volume

# Path to cache model weights
CACHE_PATH = "./weights"


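# This runs once when the container first starts
def download_models():
    from transformers import AutoTokenizer, OPTForCausalLM
    import torch

    # Load weights from the Volume (they are only downloaded on the first run)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", cache_dir=CACHE_PATH)
    model = OPTForCausalLM.from_pretrained("facebook/opt-125m", cache_dir=CACHE_PATH).to(device)
    return model, tokenizer


@endpoint(
    on_start=download_models,
    volumes=[Volume(name="weights", mount_path=CACHE_PATH)],
    cpu=1,
    memory="16Gi",
    gpu="T4",
    image=Image(
        python_version="python3.9",
        python_packages=[
            "transformers",
            "torch",
        ],
    ),
)
def predict(context, **inputs):
    # Retrieve the model and tokenizer loaded once by the on_start function
    model, tokenizer = context.on_start_value

    # Run inference ("prompt" is an example input name)
    prompt = inputs.get("prompt", "Hello, my name is")
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    output_ids = model.generate(input_ids, max_new_tokens=30)
    return {"text": tokenizer.decode(output_ids[0], skip_special_tokens=True)}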
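To see why this helps, here is a plain-Python simulation of the once-per-container behavior, with no Beam dependency. SimulatedContainer, Context, load_model, and the uppercasing "model" are all hypothetical stand-ins; only the on_start_value attribute name mirrors the example above.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Context:
    # Hypothetical stand-in for the context object passed to a handler
    on_start_value: Any = None


class SimulatedContainer:
    """Toy model of a container: on_start runs at boot, the handler runs per request."""

    def __init__(self, on_start: Callable[[], Any], handler: Callable[..., Any]):
        # Container boot: on_start runs exactly once, before any request is served
        self.context = Context(on_start_value=on_start())
        self.handler = handler

    def handle(self, **inputs):
        # Every request reuses the value computed at boot
        return self.handler(self.context, **inputs)


load_count = 0


def load_model():
    global load_count
    load_count += 1
    return lambda prompt: prompt.upper()  # stand-in for an expensive model load


def predict(context, **inputs):
    model = context.on_start_value
    return {"text": model(inputs.get("prompt", ""))}


container = SimulatedContainer(load_model, predict)
responses = [container.handle(prompt=p) for p in ["a", "b", "c"]]
print(load_count)  # 1: three requests, but only one load
```

The expensive step is paid once per container, so only the first request on a fresh container sees it; subsequent requests on the same container go straight to inference.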

Measuring Cold Start

We’ve made it easier to optimize your cold starts by adding a cold start profile to each task.

You can view the cold start profile of a task by clicking on any task in the tasks table.

This breakdown shows the entire lifecycle of your task: spinning up a container, running your on_start function, and running the task itself.
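The profile is the authoritative source for these timings. If you also want your own logs to show how long your on_start function takes, a small timing decorator is one option. This is a stdlib-only sketch: timed and last_duration are hypothetical names (not part of Beam), and the sleep stands in for real model loading.

```python
import functools
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("cold_start")


def timed(fn):
    """Log how long `fn` takes; intended for wrapping an on_start function."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            # Record and log the duration even if `fn` raises
            elapsed = time.perf_counter() - start
            wrapper.last_duration = elapsed
            logger.info("%s took %.3fs", fn.__name__, elapsed)

    wrapper.last_duration = None
    return wrapper


@timed
def download_models():
    time.sleep(0.1)  # stand-in for loading model weights
    return "model"
```

Because functools.wraps preserves the function's name and the wrapper returns the original value, the decorated function can be passed as on_start without other changes.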

Here’s a breakdown of a serverless cold start:

  • Container Start Time. This is typically under 1s.
  • Image Load Time. Pulling your container image from our image cache. This varies with the size of your image, which grows with the dependencies you’ve added and with any model files stored in the image itself.
  • Application Start Time. Running your code. This is the time spent running your on_start function, including loading your model onto the GPU.
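Since these phases happen one after another, the cold start you observe is roughly their sum, and the largest phase tells you where to focus. The sketch below encodes that reasoning; ColdStartProfile and its methods are hypothetical helpers, and you would fill in the values yourself from a task’s cold start profile.

```python
from dataclasses import dataclass


@dataclass
class ColdStartProfile:
    """Seconds spent in each phase, as read off a task's cold start profile."""

    container_start: float
    image_load: float
    application_start: float

    def total(self) -> float:
        # Phases run sequentially, so the cold start is approximately their sum
        return self.container_start + self.image_load + self.application_start

    def dominant_phase(self) -> str:
        phases = {
            "container_start": self.container_start,
            "image_load": self.image_load,
            "application_start": self.application_start,
        }
        return max(phases, key=phases.get)

    def suggestion(self) -> str:
        # Hints drawn from the advice on this page
        hints = {
            "container_start": "Container start is usually under 1s; focus elsewhere.",
            "image_load": "Trim the dependencies in your Image to shrink it.",
            "application_start": "Cache weights in a Volume and load them once in on_start.",
        }
        return hints[self.dominant_phase()]
```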
