Pre-Loading Models
This guide shows how you can optimize performance by pre-loading models when your container first starts.
Beam includes an optional `on_start` lifecycle hook that you can add to your functions. The `on_start` function runs exactly once, when your container first starts.
app.py

```python
from beam import endpoint


def download_models():
    # Do something that only needs to happen once
    return {}


# The on_start function runs once when the container starts
@endpoint(on_start=download_models)
def handler():
    return {}
```
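To make the run-once semantics concrete, here is a plain-Python sketch that mimics the lifecycle without Beam. The `SimulatedContainer` class and its `Context` object are hypothetical stand-ins, not Beam's implementation; they only illustrate that `on_start` runs once per container start while the handler runs on every request and reuses the stored result.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Context:
    # Hypothetical stand-in for the context object Beam passes to handlers
    on_start_value: Any = None


class SimulatedContainer:
    """Hypothetical model of a container: runs on_start once, then serves requests."""

    def __init__(self, on_start: Callable[[], Any], handler: Callable[..., Any]):
        # on_start runs exactly once, when the container starts
        self.context = Context(on_start_value=on_start())
        self.handler = handler

    def request(self, **kwargs: Any) -> Any:
        # Every request reuses the same context; on_start is not called again
        return self.handler(self.context, **kwargs)


calls = {"on_start": 0, "handler": 0}


def download_models():
    calls["on_start"] += 1
    return {"x": 10}


def handler(context, n):
    calls["handler"] += 1
    return {"result": context.on_start_value["x"] * n}


container = SimulatedContainer(download_models, handler)
for n in range(3):
    print(container.request(n=n))
print(calls)  # {'on_start': 1, 'handler': 3}
```

The expensive setup cost is paid once at startup, and every subsequent request only pays for the handler itself.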
Anything returned from `on_start` can be retrieved through `context.on_start_value`, where `context` is a variable that is automatically passed to your handler:
```python
from beam import endpoint


def download_models():
    # Do something that only needs to happen once
    x = 10
    return {"x": x}


# The on_start function runs once when the container starts
@endpoint(on_start=download_models)
def handler(context):
    # Retrieve cached values from on_start
    on_start_value = context.on_start_value
    return {}
```
Example: Downloading Model Weights
```python
from beam import Image, endpoint, Volume

CACHE_PATH = "./weights"


def download_models():
    from transformers import AutoTokenizer, OPTForCausalLM

    # cache_dir points at the mounted volume, so weights are stored there
    model = OPTForCausalLM.from_pretrained("facebook/opt-125m", cache_dir=CACHE_PATH)
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", cache_dir=CACHE_PATH)

    return model, tokenizer


@endpoint(
    on_start=download_models,
    volumes=[Volume(name="weights", mount_path=CACHE_PATH)],
    cpu=1,
    memory="16Gi",
    gpu="T4",
    image=Image(
        python_version="python3.8",
        python_packages=[
            "transformers",
            "torch",
        ],
    ),
)
def predict(context, prompt):
    # Retrieve the cached model from the on_start function
    model, tokenizer = context.on_start_value

    # Generate
    inputs = tokenizer(prompt, return_tensors="pt")
    generate_ids = model.generate(inputs.input_ids, max_length=30)
    result = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    print(result)

    return {"prediction": result}
```
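The `Volume` mounted at `CACHE_PATH` is what keeps restarts cheap: `on_start` still runs on every container start, but once the weights are on the volume, `from_pretrained` can typically reuse the cached files in `cache_dir` instead of downloading them again. The sketch below illustrates that download-once, load-many pattern in plain Python. `fake_download`, `load_weights`, and the temporary directory are hypothetical stand-ins for the Hugging Face download, the loader, and the mounted volume.

```python
import os
import tempfile

downloads = 0


def fake_download(path: str) -> None:
    # Hypothetical stand-in for fetching model weights over the network
    global downloads
    downloads += 1
    with open(path, "w") as f:
        f.write("weights")


def load_weights(cache_dir: str) -> str:
    # Only download when the weights are not already in the cache directory
    path = os.path.join(cache_dir, "model.bin")
    if not os.path.exists(path):
        fake_download(path)
    with open(path) as f:
        return f.read()


# The temporary directory plays the role of the persistent volume
with tempfile.TemporaryDirectory() as volume:
    for container_start in range(3):
        # Each simulated container start calls the loader, as on_start would
        weights = load_weights(volume)
    print(downloads)  # 1
```

Three simulated starts trigger only one download, because the cache directory outlives each individual start.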
Using Loaders with Multiple Workers
If you scale vertically by running multiple workers in each container, the loader (`on_start`) function runs once for each worker that starts up.
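One practical consequence is memory: because the loader runs once per worker, each worker generally ends up holding its own copy of whatever `on_start` returns, so four workers loading a model means four loaded copies. The plain-Python sketch below mimics this; `start_workers` is a hypothetical helper, not part of Beam, and it uses a simple loop rather than real worker processes.

```python
from typing import Any, Callable, List

loader_calls = 0


def download_models() -> dict:
    # Stand-in for an expensive model load
    global loader_calls
    loader_calls += 1
    return {"model": object()}


def start_workers(num_workers: int, on_start: Callable[[], Any]) -> List[Any]:
    # Hypothetical helper: each worker runs the loader once at startup
    return [on_start() for _ in range(num_workers)]


worker_states = start_workers(4, download_models)
print(loader_calls)  # 4
# Each worker holds its own, separate loaded object
print(len({id(state["model"]) for state in worker_states}))  # 4
```

When sizing `memory`, it is worth accounting for one loaded model per worker rather than one per container.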