Fine-tuning Gemma with LoRA
In this example, we are fine-tuning Gemma 2B, an open source model from Google.
View the Code
See the code for this example on GitHub.
Fine-Tuning
In this example, we are using Low-Rank Adaptation (LoRA) to fine-tune the Gemma language model on the Open Assistant dataset.
The goal is to use this dataset to improve Gemma’s ability to engage in helpful conversations, making it more suitable for assistant-like apps.
LoRA
You can read more about LoRA here, but let’s briefly discuss what exactly it does and why we chose to use it for this example.
At a high level, LoRA introduces a small set of new weights to the model, and these are the only weights we train. By limiting training to these additional weights, we can fine-tune the model much more quickly. And since we are not touching the original weights, the model’s initial knowledge base should remain intact.
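As a rough illustration of how this looks in code, here is a minimal sketch using the peft library. The rank, alpha, and target modules below are placeholder values rather than the exact configuration used in this example:
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder LoRA hyperparameters, for illustration only
lora_config = LoraConfig(
    r=8,                                  # rank of the low-rank update matrices
    lora_alpha=16,                        # scaling factor for the update
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt
    task_type="CAUSAL_LM",
)

base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
model = get_peft_model(base_model, lora_config)

# Only the small adapter matrices are trainable; the base weights stay frozen
model.print_trainable_parameters()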
Initial Setup
In this example, we are using an A100-40 GPU with mixed precision (FP16) to optimize for speed and memory usage. We are only training for one epoch; in practice, you can likely train for longer and continue to see improved results.
No surprise here, but we are getting our compute via Beam. We are using the function decorator so that we can run our fine-tuning application as if it were on our local machine.
from beam import Volume, Image, function
# The mount path is the location on the beam volume with the model weights
MOUNT_PATH = "./gemma-ft"
@function(
    volumes=[Volume(name="gemma-ft", mount_path=MOUNT_PATH)],
    image=Image(
        python_packages=["transformers", "torch", "datasets", "peft", "bitsandbytes"]
    ),
    gpu="A100-40",
    cpu=4,
)
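Inside the decorated function, the training loop itself can be a fairly standard Hugging Face Trainer setup. The sketch below is illustrative rather than the exact example code: it assumes the PEFT-wrapped model, tokenizer, and tokenized dataset have already been prepared, and only the FP16 mixed precision and single epoch come from this example (the other hyperparameters are placeholders):
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

# `model`, `tokenizer`, and `train_dataset` are assumed to be prepared earlier in finetune.py
training_args = TrainingArguments(
    output_dir=f"{MOUNT_PATH}/gemma-2b-finetuned",  # MOUNT_PATH as defined above
    num_train_epochs=1,             # one epoch, as in this example
    fp16=True,                      # mixed precision for speed and memory
    per_device_train_batch_size=4,  # placeholder value
    learning_rate=2e-4,             # placeholder value
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
trainer.save_model(training_args.output_dir)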
Mounting Storage Volumes
One interesting thing to note above is that we are mounting a storage volume to our container. This volume is where we have uploaded our initial weights from Hugging Face and our training dataset.
It is also where we will store our additional fine-tuned weights.
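In code, reading from and writing to the volume is just ordinary file access under the mount path. The directory names for the base weights and the dataset below are assumptions; gemma-2b-finetuned matches the output directory you will see listed after training:
import os
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

MOUNT_PATH = "./gemma-ft"
BASE_MODEL_DIR = os.path.join(MOUNT_PATH, "weights")          # assumed: base Gemma weights from Hugging Face
DATASET_PATH = os.path.join(MOUNT_PATH, "openassistant.json") # assumed: the uploaded training data
OUTPUT_DIR = os.path.join(MOUNT_PATH, "gemma-2b-finetuned")   # where the fine-tuned adapter is written

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_DIR)
dataset = load_dataset("json", data_files=DATASET_PATH, split="train")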
Start Training
We can start our training by running python finetune.py. After beginning training, you should see something like the following in your terminal:
=> Building image
=> Syncing files
...
=> Running function: <finetune:gemma_fine_tune>
Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]
...
Generating train split: 12947 examples [00:00, 114393.80 examples/s]
...
Map: 93%|#########2| 12000/12947 [00:13<00:01, 921.12 examples/s]
...
1%| | 6/809 [00:08<16:35, 1.24s/it]
...
{'loss': 1.617, 'grad_norm': 0.4805833399295807, 'learning_rate': 0.00019752781211372064, 'epoch': 0.01}
...
Once it is finished, we can use the beam CLI to look at the resulting files. You should see something like this:
❯ beam ls gemma-ft/gemma-2b-finetuned
Name Size Modified Time IsDir
──────────────────────────────────────────────────────────────────────────────────
gemma-2b-finetuned/README.md 4.97 KiB Aug 10 2024 No
gemma-2b-finetuned/adapter_config.json 644.00 B Aug 10 2024 No
gemma-2b-finetuned/adapter_model.safetensors 12.20 MiB Aug 10 2024 No
gemma-2b-finetuned/checkpoint-700 36.70 MiB Aug 01 2024 Yes
gemma-2b-finetuned/checkpoint-800 36.70 MiB Aug 01 2024 Yes
gemma-2b-finetuned/checkpoint-809 36.70 MiB Aug 01 2024 Yes
gemma-2b-finetuned/special_tokens_map.json 555.00 B Aug 10 2024 No
gemma-2b-finetuned/tokenizer.json 16.71 MiB Aug 10 2024 No
gemma-2b-finetuned/tokenizer_config.json 45.21 KiB Aug 10 2024 No
9 items | 139.06 MiB used
Inference
In inference.py, we are loading up our model with the additional fine-tuned weights and setting up an endpoint to send it requests.
Here, we make use of Beam’s on_start functionality so that we only load the model when the container starts, instead of every time we receive a request. Let’s explore the endpoint decorator below.
from beam import Volume, Image, endpoint
# The mount path is the location on the beam volume with the model weights
MOUNT_PATH = "./gemma-ft"
@endpoint(
    name="gemma-inference",
    on_start=load_finetuned_model,
    volumes=[Volume(name="gemma-ft", mount_path=MOUNT_PATH)],
    cpu=1,
    memory="16Gi",
    gpu="T4",
    image=Image(
        python_version="python3.9",
        python_packages=["transformers==4.42.0", "torch", "peft"],
    ),
)
Once again, we are mounting our storage volume named “gemma-ft”. Since we have already run training, this volume will now contain our fine-tuned weights alongside the base weights we got from Hugging Face.
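For reference, the on_start function might look roughly like the sketch below: load the base model from the volume, then layer the LoRA adapter on top of it with peft. The directory names are assumptions, following the same layout as before:
import os
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

MOUNT_PATH = "./gemma-ft"
BASE_MODEL_DIR = os.path.join(MOUNT_PATH, "weights")          # assumed: base Gemma weights
ADAPTER_DIR = os.path.join(MOUNT_PATH, "gemma-2b-finetuned")  # the fine-tuned LoRA adapter

def load_finetuned_model():
    # Load the base model, then apply the LoRA adapter on top of it
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR)
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_DIR, torch_dtype=torch.float16).to("cuda")
    model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
    model.eval()
    # The return value can then be reused by the request handler
    return model, tokenizer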
Choosing a GPU For Inference
Now that we’ve trained the model, we can run it on a machine with a weaker GPU.
Training requires more memory than inference because it must store gradients and optimizer states for the parameters being trained, in addition to activations for the backward pass, whereas inference only needs to maintain the current layer’s activations during a forward pass. Be sure to keep this in mind as you work on your own applications.
You can use the Beam dashboard to get a sense of GPU utilization in real-time. With this information, you can make a more informed choice about how much compute you require. For this example, we use a T4 GPU. It has 16GB of VRAM and is a good choice for inference with a model this small.
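As a rough back-of-the-envelope check (illustrative numbers, not measurements from this example), the FP16 weights alone of a roughly 2.5B-parameter model come to about 5 GB, leaving headroom on a 16GB T4:
params = 2.5e9        # Gemma 2B has roughly 2.5B parameters
bytes_per_param = 2   # FP16
weights_gb = params * bytes_per_param / 1e9
print(f"~{weights_gb:.0f} GB of FP16 weights")  # ~5 GB, well within a T4's 16 GB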
Monitoring compute usage in real-time in the dashboard
Using Signals to Reload Model Weights Automatically
We use a Signal abstraction to fire an event to the inference app when the model has finished training. This allows us to communicate between apps on Beam. In this example, we have it set up to re-run our on_start method when a signal is received. This way, if we re-train our model, we can load the newest weights without restarting the container.
from beam import experimental

# Register a signal
s = experimental.Signal(
    name="reload-model",
    handler=load_finetuned_model,
)
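On the training side, the signal can then be fired by name once the new adapter has been saved. The call below is an assumption about Beam’s experimental Signal API and may differ from the current SDK:
from beam import experimental

# Assumed usage: fire the signal from the training app after saving the new adapter
experimental.Signal(name="reload-model").set()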
Deploying The Endpoint
Let’s deploy our endpoint! We can do this with the beam CLI.
beam deploy inference.py:predict --name gemma-ft
The output will look something like this:
=> Building image
=> Syncing files
=> Deploying
=> Deployed 🎉
=> Invocation details
curl -X POST 'https://app.beam.cloud/endpoint/gemma-ft/v2' \
-H 'Connection: keep-alive' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer {YOUR_AUTH_TOKEN}' \
-d '{}'
When calling our inference endpoint, we’ll need to include a prompt. For example, we can call the deployed endpoint with -d '{"prompt": "hi"}'. The response we get back will be in the following format:
{"text":"Hello! How can I help you today?<|im_end|>"}
Note that the returned response includes the stop token <|im_end|>. You could strip this token in the endpoint logic if you would like, but it is worth keeping around if you will be appending this response to a longer-running conversation.
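If you do decide to strip it, a small helper in the endpoint logic could look like this (a hypothetical helper, not part of the sample code):
def strip_stop_token(text: str, stop_token: str = "<|im_end|>") -> str:
    # Keep only the text before the stop token
    return text.split(stop_token)[0].rstrip()

strip_stop_token("Hello! How can I help you today?<|im_end|>")
# -> "Hello! How can I help you today?"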