Beam includes an optional on_start lifecycle hook that you can add to your functions. The on_start function runs once when your container first starts (or once per worker, if you run multiple workers).

app.py
from beam import endpoint


def download_models():
    """Placeholder for one-time container setup; there is nothing to cache yet."""
    cached = dict()
    return cached


# Beam calls download_models once, when the container starts
@endpoint(on_start=download_models)
def handler():
    """Handle a request; this minimal example responds with an empty object."""
    return dict()

Anything returned from on_start can be retrieved through the context argument that is automatically passed to your handler, as context.on_start_value:

from beam import endpoint


def download_models():
    """Compute a value once per container so handlers can reuse it."""
    # Stand-in for expensive one-time work; the handler reads this dict
    # back via context.on_start_value.
    return dict(x=10)


# Beam calls download_models once, when the container starts
@endpoint(on_start=download_models)
def handler(context):
    """Handle a request using the value produced by ``download_models``."""
    # Whatever on_start returned is exposed on the request context
    cached = context.on_start_value  # e.g. {"x": 10}
    return dict()

Example: Downloading Model Weights

from beam import Image, endpoint, Volume


CACHE_PATH = "./weights"


def download_models():
    """Load the OPT-125m model and tokenizer once, when the container starts.

    Weights are cached under ``CACHE_PATH``, which the endpoint below mounts
    as the ``weights`` volume, so later cold starts can reuse the files.

    Returns:
        A ``(model, tokenizer)`` tuple, exposed to the handler as
        ``context.on_start_value``.
    """
    import torch
    from transformers import AutoTokenizer, OPTForCausalLM

    model = OPTForCausalLM.from_pretrained("facebook/opt-125m", cache_dir=CACHE_PATH)
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", cache_dir=CACHE_PATH)

    # from_pretrained loads weights onto the CPU. Move them to the GPU requested
    # by the endpoint (gpu="T4") so inference actually runs there; fall back to
    # the CPU when no GPU is visible (e.g. when running locally).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    return model, tokenizer


@endpoint(
    on_start=download_models,
    volumes=[Volume(name="weights", mount_path=CACHE_PATH)],
    cpu=1,
    memory="16Gi",
    gpu="T4",
    image=Image(
        python_version="python3.8",
        python_packages=[
            "transformers",
            "torch",
        ],
    ),
)
def predict(context, prompt):
    """Generate a continuation of ``prompt`` with the model loaded in on_start.

    Args:
        context: Beam request context; ``context.on_start_value`` holds the
            ``(model, tokenizer)`` tuple returned by ``download_models``.
        prompt: Text to continue.

    Returns:
        ``{"prediction": <generated text>}``.
    """
    # Retrieve the cached model from the on_start function
    model, tokenizer = context.on_start_value

    # Put the encoded prompt on the same device as the model. Passing the whole
    # encoding (not just input_ids) also forwards attention_mask to generate().
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generate_ids = model.generate(**inputs, max_length=30)
    result = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    print(result)

    return {"prediction": result}

Using Loaders with Multiple Workers

If you scale vertically by running multiple workers, the on_start (loader) function runs once for each worker that starts up.

Was this page helpful?