A simple guide to fine tuning Llama 2 on your own data
In this guide, I show you how it's actually quite easy to fine tune Llama 2 with your own dataset using HuggingFace's libraries!
All you need is your data in this format:
{ "text": "text-for-model-to-predict" }
and HuggingFace's dataloader will handle the rest.
This is the equivalent notebook.
Requirements
To start, use a machine with a reasonably recent version of Python and CUDA. (I used Python 3.10 and CUDA 11.7.)
On the GPU side, I recommend at least 24GB of VRAM: A100, A10, A10G, etc. If you want to fine-tune the larger 13b and 70b models, I'd go straight for the A100 :)
Brev sorts all this out for you.
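As a rough sanity check on those VRAM numbers, here is a back-of-the-envelope sketch of how much memory the model weights alone need at 16-bit versus 4-bit precision. The parameter counts are the nominal 7B/13B/70B figures and the helper name weight_memory_gb is mine, purely for illustration. Activations, adapter and optimizer state, and quantization overhead all come on top, so treat these as lower bounds rather than a sizing guarantee.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Memory needed to hold the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Nominal parameter counts for the Llama 2 family (approximate, an assumption).
MODEL_SIZES = {"7b": 7e9, "13b": 13e9, "70b": 70e9}

for name, num_params in MODEL_SIZES.items():
    fp16 = weight_memory_gb(num_params, 16)
    int4 = weight_memory_gb(num_params, 4)
    print(f"Llama 2 {name}: ~{fp16:.1f} GB in fp16, ~{int4:.1f} GB in 4-bit")
```

This is why the 4-bit 7b model fits comfortably on a 24GB card, while the larger models push you toward an A100.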
Setup
First step is always to install those pesky dependencies:
pip install -q huggingface_hub
pip install -q -U trl transformers accelerate peft
pip install -q -U datasets bitsandbytes einops wandb
pip install -q ipywidgets
pip install -q scipy
And import them:
from huggingface_hub import notebook_login
notebook_login()
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
If the notebook_login step doesn't prompt you for your Huggingface credentials, try running this in a terminal:
huggingface-cli login
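These libraries move fast, and several arguments used later in this guide have been renamed in newer releases, so it can help to record which versions you actually ended up with. A minimal sketch using only the standard library; the package list and the helper name installed_versions are my own choices, not something the libraries require:

```python
from importlib.metadata import PackageNotFoundError, version

# Packages installed in the setup step above.
PACKAGES = ["transformers", "trl", "peft", "accelerate", "datasets", "bitsandbytes"]


def installed_versions(packages):
    """Return {package name: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:15s} {ver or 'NOT INSTALLED'}")
```

If something later fails with an unexpected-argument error, this output is the first thing to compare against the library's changelog.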
Dataset
Preparing data
To prepare your dataset for loading, all you need is a .jsonl file structured something like this:
{"input": "What color is the sky?", "output": "The sky is blue."}
{"input": "Where is the best place to get cloud GPUs?", "output": "Brev.dev"}
{"input": "Why do Americans love guns so much?", "output": "Because of the Spanish."}
You can have any combination of keys/features you want, because Huggingface's SFTTrainer class will format your prompts during training using a function you provide.
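If your data starts out as Python objects rather than a file, a small sketch like the following writes it out as .jsonl (one JSON object per line) and reads it back to confirm every line parses. The file name example.jsonl and the helpers write_jsonl and read_jsonl are hypothetical names for illustration.

```python
import json
import os
import tempfile

records = [
    {"input": "What color is the sky?", "output": "The sky is blue."},
    {"input": "Where is the best place to get cloud GPUs?", "output": "Brev.dev"},
]


def write_jsonl(path, rows):
    """Write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def read_jsonl(path):
    """Read the file back; json.loads raises on any malformed line."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Written to the temp directory here; use your own path in practice.
path = os.path.join(tempfile.gettempdir(), "example.jsonl")
write_jsonl(path, records)
round_trip = read_jsonl(path)
print(f"{len(round_trip)} records written to {path}")
```

Reading the file back before training is a cheap way to catch a stray trailing comma or an unescaped newline.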
Loading data
Load your train and (optionally) evaluation datasets like this:
from datasets import load_dataset
train_dataset = load_dataset('json', data_files='notes.jsonl', split='train')
eval_dataset = load_dataset('json', data_files='notes_validation.jsonl', split='train')
(If you don't want to use an evaluation dataset, just comment out the eval_dataset line above and its other appearances below in SFTTrainer. You will likely also need to drop the evaluation-related settings from TrainingArguments.)
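If you only have a single file and still want an evaluation set, one option is to hold out a slice yourself before writing the two .jsonl files. This is a minimal sketch, not part of the original workflow: the 10% fraction, the seed, and the helper name split_rows are all assumptions you should adjust.

```python
import random


def split_rows(rows, eval_fraction=0.1, seed=42):
    """Shuffle a copy of rows, then hold out the last eval_fraction for evaluation."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)  # seeded so the split is reproducible
    n_eval = max(1, int(len(rows) * eval_fraction))
    return rows[:-n_eval], rows[-n_eval:]


# Stand-in data in the same shape as the notes dataset.
notes = [{"note": f"note-for-model-to-predict-{i}"} for i in range(20)]
train_rows, eval_rows = split_rows(notes)
print(len(train_rows), "training rows,", len(eval_rows), "evaluation rows")
```

You would then write train_rows to notes.jsonl and eval_rows to notes_validation.jsonl, one JSON object per line.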
Formatting prompts
Then create a formatting_func to structure training examples as prompts:
def formatting_func(example):
    text = f"### Question: {example['input']}\n ### Answer: {example['output']}"
    return [text]
In my case, my data was just notes like this:
{"note": "note-for-model-to-predict"}
{"note": "note-for-model-to-predict-1"}
{"note": "note-for-model-to-predict-2"}
So the formatting_func I used was:
def formatting_func(example):
    text = f"### The following is a note: {example['note']}"
    return [text]
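One caveat that depends on your TRL version: in some releases, SFTTrainer calls the formatting function on a batch, so example['input'] arrives as a list of strings rather than a single string. If your training text looks like one giant stringified list, that is the likely cause. Here is a sketch that handles both cases; the name formatting_func_batched is mine, and you should check how your installed version calls the function before relying on it.

```python
def formatting_func_batched(example):
    """Format one example or a batch of examples into a list of prompt strings."""
    inputs, outputs = example["input"], example["output"]
    # A single example holds plain strings; a batch holds lists of strings.
    if isinstance(inputs, str):
        inputs, outputs = [inputs], [outputs]
    return [
        f"### Question: {question}\n ### Answer: {answer}"
        for question, answer in zip(inputs, outputs)
    ]


single = {"input": "What color is the sky?", "output": "The sky is blue."}
batch = {"input": ["Q1", "Q2"], "output": ["A1", "A2"]}

print(formatting_func_batched(single))
print(formatting_func_batched(batch))
```

Either way the function returns a list with one prompt per example, which matches the return type of the functions above.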
Model
Loading model
Load the Llama 2 non-chat model quantized to 4 bits:
base_model_name = "meta-llama/Llama-2-7b-hf"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    use_auth_token=True,  # newer transformers releases name this argument token
)
base_model.config.use_cache = False
# More info: https://github.com/huggingface/transformers/pull/24906
base_model.config.pretraining_tp = 1
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
(If you see a token error, it probably means you haven't requested the model from Meta: go here, then agree to the terms on Huggingface and try again. Make sure the email you use for Huggingface is the same one you used to request Llama 2.)
Fine tuning model
Then set up the training arguments so that we log, save, and evaluate every 50 steps:
output_dir = "./Llama-2-7b-hf-fine-tune-baby"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    logging_steps=50,             # Log metrics every 50 steps
    max_steps=1000,               # Stop after 1000 optimizer steps
    logging_dir="./logs",         # Directory for storing logs
    save_strategy="steps",        # Save checkpoints on a step schedule
    save_steps=50,                # Save a checkpoint every 50 steps
    evaluation_strategy="steps",  # Evaluate on a step schedule
    eval_steps=50,                # Evaluate every 50 steps
    do_eval=True,                 # Run evaluation on eval_dataset
)
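To see what these settings add up to, here is the arithmetic spelled out. It assumes a single GPU and no sequence packing, so one example is one sequence; the variable names simply mirror the arguments above.

```python
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
max_steps = 1000
save_steps = 50
num_gpus = 1  # assumption: single-GPU training

# Gradients from several small batches are accumulated before each optimizer step.
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
examples_seen = effective_batch_size * max_steps
num_checkpoints = max_steps // save_steps

print(f"Effective batch size: {effective_batch_size}")
print(f"Examples seen over the run: {examples_seen}")
print(f"Checkpoints written: {num_checkpoints}")
```

If your dataset has far fewer examples than the run will see, the model makes several passes over it, which is worth knowing when you read the evaluation loss.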
We set the config for the LoRA adapter:
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)
I experimented with higher alpha and r and got slightly worse results :(
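For a feel of what r and lora_alpha control, here is a rough count of the trainable parameters this config adds. The numbers rest on assumptions I am stating rather than measuring: a hidden size of 4096 and 32 layers for the 7b model, and adapters on the q_proj and v_proj projections only (which I believe is PEFT's default for Llama when target_modules is not set). Each adapted weight gets two small matrices, one of shape r x d_in and one of shape d_out x r, and the update is scaled by lora_alpha / r.

```python
def lora_params_per_matrix(d_in: int, d_out: int, r: int) -> int:
    """Parameters in the two low-rank matrices A (r x d_in) and B (d_out x r)."""
    return r * (d_in + d_out)


# Assumed Llama 2 7b shapes and default adapter targets (see the note above).
hidden_size = 4096
num_layers = 32
adapted_matrices_per_layer = 2  # q_proj and v_proj

r = 64
lora_alpha = 16

per_matrix = lora_params_per_matrix(hidden_size, hidden_size, r)
trainable = per_matrix * adapted_matrices_per_layer * num_layers
scaling = lora_alpha / r

print(f"Parameters per adapted matrix: {per_matrix:,}")
print(f"Total trainable parameters:    {trainable:,}")
print(f"LoRA scaling (alpha / r):      {scaling}")
```

Under these assumptions the adapter is well under 1% of the 7B base parameters, and raising r grows it linearly.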
Then leverage the beautiful SFTTrainer class from Huggingface to fine tune Llama 2:
max_seq_length = 512
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    formatting_func=formatting_func,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_args,
)
# pass in resume_from_checkpoint=True to resume from a checkpoint
trainer.train()
On a 40GB A100 training for 1000 steps took about 2 hours.
Running inference on a trained model
By default, the PEFT library saves only the QLoRA adapter weights, not the full model, so we first need to load the base Llama 2 model from the Huggingface Hub:
base_model_name = "meta-llama/Llama-2-7b-hf"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    use_auth_token=True,  # newer transformers releases name this argument token
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
and load the QLoRA adapter from a checkpoint directory:
from peft import PeftModel
model = PeftModel.from_pretrained(base_model, "/root/llama2sfft-testing/Llama-2-7b-hf-qlora-full-dataset/checkpoint-900")
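Since training writes a checkpoint-<step> directory every 50 steps, you may want to pick the newest one programmatically instead of hard-coding the path. A minimal sketch follows; latest_checkpoint is a hypothetical helper of mine, and the temporary directory only stands in for your real output_dir. Note that it compares step numbers numerically, because a plain alphabetical sort would rank checkpoint-900 above checkpoint-1000.

```python
import re
import tempfile
from pathlib import Path


def latest_checkpoint(output_dir):
    """Return the checkpoint-<step> subdirectory with the highest step, or None."""
    best_step, best_path = -1, None
    for path in Path(output_dir).glob("checkpoint-*"):
        match = re.fullmatch(r"checkpoint-(\d+)", path.name)
        if path.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


# Demo on a throwaway directory that mimics a training output folder.
with tempfile.TemporaryDirectory() as tmp:
    for step in (50, 900, 1000):
        (Path(tmp) / f"checkpoint-{step}").mkdir()
    newest = latest_checkpoint(tmp)
    print(newest.name)
```

In practice you would call latest_checkpoint on your own output_dir and pass str(result) to PeftModel.from_pretrained, after checking that the result is not None.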
then run some inference:
eval_prompt = """A note has the following\nTitle: \nLabels: \nContent: i love"""
model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")
model.eval()
with torch.no_grad():
    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))
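If you trained with the question/answer template, it generally helps to prompt with the same template at inference time and then trim the echoed prompt from the decoded text, since generate returns the prompt followed by the continuation. The helpers build_prompt and extract_answer below are hypothetical, and fake_output is a hand-written stand-in for what tokenizer.decode might return, not real model output.

```python
def build_prompt(question):
    """Same template as the training formatting_func, stopping after 'Answer:'."""
    return f"### Question: {question}\n ### Answer:"


def extract_answer(decoded, prompt):
    """Drop the echoed prompt, then cut at the next '###' marker if there is one."""
    completion = decoded[len(prompt):] if decoded.startswith(prompt) else decoded
    return completion.split("###")[0].strip()


prompt = build_prompt("What color is the sky?")

# Stand-in for tokenizer.decode(model.generate(...)[0], skip_special_tokens=True).
fake_output = prompt + " The sky is blue.\n ### Question: What else is blue?"

print(extract_answer(fake_output, prompt))
```

Cutting at the next ### marker is a simple way to stop the model from rambling into a new made-up question, since the base model has no built-in notion of where an answer ends.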
As always, if you have any questions shoot me a message on Discord.