This example demonstrates an AI Avatar app, built using DreamBooth and Stable Diffusion v1.5.

Overview

This app exposes two APIs: one to start a fine-tuning job on a batch of image URLs, and one to generate images with the resulting fine-tuned model.

Training

This endpoint takes a list of input image URLs and fine-tunes Stable Diffusion on those images. It also takes a user ID, so you can reference the specific fine-tuned model later when you generate customized images.
app-training.py
from beam import App, Runtime, Image, Output, Volume

import pathlib
import requests
import subprocess
import hashlib
import os

"""
This function:
- takes a list of image URLs
- saves them to a storage volume
- trains Dreambooth on the images
- saves the trained model in a user-specific directory on the volume
"""

BASE_ROUTE = "./dreambooth"
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"


app = App(
    name="dreambooth-training",
    runtime=Runtime(
        gpu="A10G",
        cpu=4,
        memory="32Gi",
        image=Image(
            python_version="python3.8",
            python_packages="requirements.txt",
        ),
    ),
    # Shared Volume to store the trained models
    volumes=[Volume(path="./dreambooth", name="dreambooth")]
)

# Deploys function as async task queue
@app.task_queue()
def train_dreambooth(**inputs):

    user_id = inputs["user_id"]
    urls = inputs["image_urls"]
    instance_prompt = inputs["instance_prompt"]
    class_prompt = inputs["class_prompt"]

    # Create directories in storage volume
    pathlib.Path(BASE_ROUTE).mkdir(parents=True, exist_ok=True)
    pathlib.Path(f"{BASE_ROUTE}/images/{user_id}").mkdir(parents=True, exist_ok=True)

    training_images_path = f"{BASE_ROUTE}/images/{user_id}"

    # Loop through the list of URLs provided and download each to a volume
    for url in urls:
        response = requests.get(url)
        image_url_hash = hashlib.md5(url.encode("utf-8")).hexdigest()

        if response.status_code == 200:
            with open(
                os.path.join(training_images_path, image_url_hash + ".png"), "wb"
            ) as f:
                f.write(response.content)
        else:
            print(f"Failed to save image from URL: {url}")

    # Dreambooth commands
    subprocess.run(
        [
            "python3.8",
            "-m",
            "accelerate.commands.accelerate_cli",
            "launch",
            f"--config_file=/workspace/default-config.yaml",
            "train_dreambooth.py",
            # Path to the pre-trained model
            f"--pretrained_model_name_or_path={pretrained_model_name_or_path}",
            # Path to the training data
            f"--instance_data_dir={training_images_path}",
            # Save trained model in the volume, based on the user UUID
            f"--output_dir={BASE_ROUTE}/trained_models/{user_id}",
            "--prior_loss_weight=1.0",
            # Instance Prompt -- the specific instance of the image being fine-tuned, e.g. a [sks] man wearing sunglasses
            f"--instance_prompt={instance_prompt}",
            # Class Prompt -- the general category of the image being fine-tuned e.g. a man wearing sunglasses
            f"--class_prompt={class_prompt}",
            "--mixed_precision=no",
            "--resolution=512",
            "--train_batch_size=1",
            "--gradient_accumulation_steps=1",
            "--use_8bit_adam",
            "--gradient_checkpointing",
            "--set_grads_to_none",
            "--lr_scheduler=constant",
            "--lr_warmup_steps=0",
            # The two most useful levers in the training process
            # If the generated images don't match your prompt, you should consider increasing or decreasing the training steps and learning rate
            "--learning_rate=2e-6",
            "--max_train_steps=400",
        ],
        stdin=subprocess.PIPE,
        cwd="/workspace",
        env={**os.environ, "PYTHONPATH": "/workspace/__pypackages__:/workspace"},
    )


if __name__ == "__main__":
    user_id = "111111"
    instance_prompt = "a photo of a sks toy"
    class_prompt = "a photo of a toy"
    urls = [
        "https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg",
        "https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg",
        "https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg",
        "https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg",
    ]
    train_dreambooth(
        user_id=user_id,
        image_urls=urls,
        instance_prompt=instance_prompt,
        class_prompt=class_prompt,
    )
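The Runtime above installs dependencies from a requirements.txt that isn't shown on this page. As a rough, unpinned sketch, a DreamBooth training environment like this one typically needs packages along these lines; match the exact versions to your own setup:
requirements.txt (sketch)
accelerate
diffusers
transformers
torch
torchvision
bitsandbytes  # used by the --use_8bit_adam flag
ftfy
Pillow
requests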
We’ll deploy the training API by running:
beam deploy app-training.py
Once the app spins up, you can find the API URL in the web dashboard and send a request to start a training job.

Starting a fine-tuning task

After deploying the app, you can kick-off a fine-tuning job by calling the API with a JSON payload like this:
{
  "user_id": "111111",
  "instance_prompt": "a photo of a sks toy",
  "class_prompt": "a photo of a toy",
  "image_urls": [
    "https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg"
  ]
}
In this example, we’ll pass in a handful of photos of a cat toy.
Here’s what the complete cURL request will look like:
curl -X POST --compressed "https://api.beam.cloud/lnmfd" \
    -H 'Accept: */*' \
    -H 'Accept-Encoding: gzip, deflate' \
    -H 'Authorization: Basic [YOUR_AUTH_TOKEN]' \
    -H 'Connection: keep-alive' \
    -H 'Content-Type: application/json' \
    -d '{"user_id": "111111", "image_urls": "[\"https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg\", \"https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg\", \"https://huggingface.co/datasets/valhalla/images/resolve/main/4.jpeg\"]", "class_prompt": "a photo of a toy", "instance_prompt": "a photo of a sks toy"}'
This code runs asynchronously, so a task ID is returned from the request:
{ "task_id": "403f3a8e-503c-427a-8085-7d59384a2566" }
We can view the status of the training job by querying the task API:
curl -X POST --compressed "https://api.beam.cloud/task" \
  -H 'Accept: */*' \
  -H 'Accept-Encoding: gzip, deflate' \
  -H 'Authorization: Basic [YOUR_AUTH_TOKEN]' \
  -H 'Content-Type: application/json' \
  -d '{"action": "retrieve", "task_id": "403f3a8e-503c-427a-8085-7d59384a2566"}'
This returns the task status. If the task is completed, we can call the inference API to use our newly fine-tuned model.
{
  "outputs": {},
  "outputs_list": [],
  "started_at": "2023-02-15T22:26:11.941531Z",
  "ended_at": "2023-02-15T22:30:20.875621Z",
  "status": "COMPLETE",
  "task_id": "403f3a8e-503c-427a-8085-7d59384a2566"
}
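If you're scripting the whole workflow, you can poll the task API from Python until the job finishes. Here's a minimal sketch, assuming the same placeholder auth token; COMPLETE is the only terminal status shown in this guide, so adapt the exit condition to whatever failure states you want to handle.
import time
import requests

TASK_URL = "https://api.beam.cloud/task"
HEADERS = {
    "Authorization": "Basic [YOUR_AUTH_TOKEN]",  # placeholder
    "Content-Type": "application/json",
}


def wait_for_task(task_id, poll_interval=15, timeout=3600):
    # Poll the task API until the task reports COMPLETE, or give up after `timeout` seconds
    deadline = time.time() + timeout
    while time.time() < deadline:
        response = requests.post(
            TASK_URL,
            headers=HEADERS,
            json={"action": "retrieve", "task_id": task_id},
        )
        task = response.json()
        if task["status"] == "COMPLETE":
            return task
        time.sleep(poll_interval)
    raise TimeoutError(f"Task {task_id} did not finish within {timeout} seconds")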

Inference

Now that we’ve set up our fine-tuning API, we’ll move on to the code that runs inference with the fine-tuned model:
app-inference.py
from beam import App, Runtime, Image, Output, Volume

import torch
from diffusers import StableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"


# The environment your code will run on
app = App(
    name="dreambooth-inference",
    runtime=Runtime(
        cpu=4,
        memory="32Gi",
        gpu="A10G",
        image=Image(
            python_version="python3.8",
            python_packages="requirements.txt",
        ),
    ),
    volumes=[Volume(path="./dreambooth", name="dreambooth")],
)


# TaskQueue API will take two inputs:
# - user_id, to identify which user's fine-tuned model to load
# - prompt, the text prompt used to generate the image
@app.task_queue(outputs=[Output(path="./dreambooth")])
def generate_images(**inputs):
    # Takes in a prompt and userID from the API request
    prompt = inputs["prompt"]
    user_id = inputs["user_id"]

    # Path to the unique model trained for this userID
    model_path = f"./dreambooth/trained_models/{user_id}"

    # Allow TF32 on matmul for faster inference on Ampere GPUs (like the A10G)
    torch.backends.cuda.matmul.allow_tf32 = True

    pipe = StableDiffusionPipeline.from_pretrained(
        # Run inference on the specific model trained for this user ID
        model_path,
        revision="fp16",
        torch_dtype=torch.float16,
        # The `cache_dir` arg is used to cache the model in between requests
        cache_dir=model_path,
    ).to("cuda")

    pipe.enable_xformers_memory_efficient_attention()

    # Image generation
    with torch.inference_mode():
        with torch.autocast("cuda"):
            image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

    print(f"Generated Image: {image}")
    image.save("output.png")


if __name__ == "__main__":
    user_id = "111111"
    generate_images(
        user_id=user_id,
        prompt=f"a photo of a sks toy riding the subway",
    )

Deployment

You can deploy this by running beam deploy app-inference.py. Once it’s deployed, you can find the web URL in the dashboard.
Here’s what a request will look like:
curl -X POST --compressed "https://api.beam.cloud/lnmfd" \
    -H 'Accept: */*' \
    -H 'Accept-Encoding: gzip, deflate' \
    -H 'Authorization: Basic [YOUR_AUTH_TOKEN]' \
    -H 'Connection: keep-alive' \
    -H 'Content-Type: application/json' \
    -d '{"prompt": "photo of a sks riding the subway", "user_id": "111111"}'
This function also runs asynchronously, so a task ID is returned:
{ "task_id": "403f3a8e-503c-427a-8085-7d59384a2566" }

Querying task status

We can view the status of the inference request by querying the task API:
curl -X POST --compressed "https://api.beam.cloud/task" \
  -H 'Accept: */*' \
  -H 'Accept-Encoding: gzip, deflate' \
  -H 'Authorization: Basic [YOUR_AUTH_TOKEN]' \
  -H 'Content-Type: application/json' \
  -d '{"action": "retrieve", "task_id": "403f3a8e-503c-427a-8085-7d59384a2566"}'
If the request is completed, you’ll see an image-output field in the response.
{
  "outputs": {
    "image-output": "https://beam.cloud/data/f2c8760c63d6e403729a212f1c19b597692b1c26c1c65"
  },
  "outputs_list": [
    {
      "id": "63ed62d4a6b28b22fbfd58bf",
      "created": "2023-02-15T22:55:16.347656Z",
      "name": "image-output",
      "updated": "2023-02-15T22:55:16.347674Z",
      "output_type": "file",
      "task": "403f3a8e-503c-427a-8085-7d59384a2566"
    }
  ],
  "started_at": "2023-02-15T22:54:43.156854Z",
  "ended_at": "2023-02-15T22:55:16.438379Z",
  "status": "COMPLETE",
  "task_id": "403f3a8e-503c-427a-8085-7d59384a2566"
}

Retrieving image outputs

Enter the image-output URL in your browser. It will download a zip file with the image generated from the model. And there you go: a cat toy riding the subway.
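If you'd rather retrieve the result programmatically, here's a short sketch that downloads the zip and extracts it locally. The URL is the image-output value from the task response above.
import io
import zipfile

import requests

# The image-output URL returned in the task status response
output_url = "https://beam.cloud/data/f2c8760c63d6e403729a212f1c19b597692b1c26c1c65"

response = requests.get(output_url)
response.raise_for_status()

# The URL serves a zip archive containing the generated image
with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
    archive.extractall("./outputs")
    print(archive.namelist())  # e.g. ['output.png']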