from huggingface_hub import notebook_login
notebook_login()
Single GPU Fine-tuning
Fine-tuning a Code LLM on Custom Code on a single GPU
Authored by: Maria Khalusova
Publicly available code LLMs such as Codex, StarCoder, and Code Llama are great at generating code that adheres to general programming principles and syntax, but they may not align with an organization’s internal conventions, or be aware of proprietary libraries.
In this notebook, we’ll show how you can fine-tune a code LLM on private code bases to enhance its contextual awareness and improve its usefulness for your organization’s needs. Since code LLMs are quite large, fine-tuning them in the traditional manner can be resource-intensive. Worry not! We will show how you can optimize fine-tuning to fit on a single GPU.
Dataset
For this example, we picked the top 10 Hugging Face public repositories on GitHub. We have excluded non-code files from the data, such as images, audio files, presentations, and so on. For Jupyter notebooks, we’ve kept only cells containing code. The resulting code is stored as a dataset that you can find on the Hugging Face Hub under smangrul/hf-stack-v1. It contains the repo id, file path, and file content.
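The "keep only code cells" step for notebooks can be sketched with the standard json module. This is a minimal sketch, not the script used to build smangrul/hf-stack-v1: it assumes the nbformat 4 layout, where each cell has a "cell_type" and a "source" stored either as a string or as a list of lines, and extract_code_cells is a hypothetical helper name.

```python
import json


def extract_code_cells(notebook_json: str) -> str:
    """Return the source of all code cells in an .ipynb document, joined by blank lines.

    Hypothetical helper: assumes the nbformat 4 layout, where each cell has
    a "cell_type" and a "source" stored as a string or a list of lines.
    """
    notebook = json.loads(notebook_json)
    chunks = []
    for cell in notebook.get("cells", []):
        if cell.get("cell_type") != "code":
            continue  # drop markdown and raw cells
        source = cell.get("source", "")
        if isinstance(source, list):
            source = "".join(source)
        if source.strip():  # skip empty code cells
            chunks.append(source)
    return "\n\n".join(chunks)


example = json.dumps({
    "cells": [
        {"cell_type": "markdown", "source": ["# Title\n"]},
        {"cell_type": "code", "source": ["import torch\n", "x = torch.ones(2)"]},
        {"cell_type": "code", "source": "print(x)"},
    ]
})
print(extract_code_cells(example))
```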
Model
We’ll fine-tune bigcode/starcoderbase-1b, which is a 1B parameter model trained on 80+ programming languages. This is a gated model, so if you plan to run this notebook with this exact model, you’ll need to gain access to it on the model’s page and log in to your Hugging Face account (the notebook_login() call at the top of this notebook does the login).
To get started, let’s install all the necessary libraries. As you can see, in addition to transformers and datasets, we’ll be using peft, bitsandbytes, and flash-attn to optimize the training.
By employing parameter-efficient training techniques, we can run this notebook on a single A100 High-RAM GPU.
!pip install -q transformers datasets peft bitsandbytes flash-attn
Let’s define some variables now. Feel free to play with these.
MODEL="bigcode/starcoderbase-1b"  # Model checkpoint on the Hugging Face Hub
DATASET="smangrul/hf-stack-v1"  # Dataset on the Hugging Face Hub
DATA_COLUMN="content"  # Column name containing the code content

SEQ_LENGTH=2048  # Sequence length

# Training arguments
MAX_STEPS=2000  # max_steps
BATCH_SIZE=16  # batch_size
GR_ACC_STEPS=1  # gradient_accumulation_steps
LR=5e-4  # learning_rate
LR_SCHEDULER_TYPE="cosine"  # lr_scheduler_type
WEIGHT_DECAY=0.01  # weight_decay
NUM_WARMUP_STEPS=30  # num_warmup_steps
EVAL_FREQ=100  # eval_freq
SAVE_FREQ=100  # save_freq
LOG_FREQ=25  # log_freq
OUTPUT_DIR="peft-starcoder-lora-a100"  # output_dir
BF16=True  # bf16
FP16=False  # no_fp16

# FIM transformations arguments
FIM_RATE=0.5  # fim_rate
FIM_SPM_RATE=0.5  # fim_spm_rate

# LORA
LORA_R=8  # lora_r
LORA_ALPHA=32  # lora_alpha
LORA_DROPOUT=0.0  # lora_dropout
LORA_TARGET_MODULES="c_proj,c_attn,q_attn,c_fc,c_proj"  # lora_target_modules

# bitsandbytes config
USE_NESTED_QUANT=True  # use_nested_quant
BNB_4BIT_COMPUTE_DTYPE="bfloat16"  # bnb_4bit_compute_dtype

SEED=0
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Trainer,
TrainingArguments,
logging,
set_seed,
BitsAndBytesConfig,
)
set_seed(SEED)
Prepare the data
Begin by loading the data. As the dataset is likely to be quite large, make sure to enable the streaming mode. Streaming allows us to load the data progressively as we iterate over the dataset instead of downloading the whole dataset at once.
We’ll reserve the first 4000 examples as the validation set, and everything else will be the training data.
from datasets import load_dataset
import torch
from tqdm import tqdm
dataset = load_dataset(
    DATASET,
    data_dir="data",
    split="train",
    streaming=True,
)

valid_data = dataset.take(4000)
train_data = dataset.skip(4000)
train_data = train_data.shuffle(buffer_size=5000, seed=SEED)
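Because a streamed dataset can’t be shuffled in full, shuffle(buffer_size=...) only approximates a shuffle by sampling from a fixed-size buffer. The sketch below illustrates that idea, together with the take/skip split, on a plain Python iterator (itertools.islice stands in for take and skip). It is a simplified re-implementation for intuition, not the datasets library’s code, and buffered_shuffle is a hypothetical name.

```python
import itertools
import random


def buffered_shuffle(iterable, buffer_size, seed=0):
    """Approximately shuffle a stream using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer first
            continue
        # Emit a random buffered item and put the incoming item in its slot
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Drain whatever is left once the stream ends
    rng.shuffle(buffer)
    yield from buffer


stream = range(20)
valid = list(itertools.islice(stream, 4))  # like dataset.take(4)
train_stream = itertools.islice(stream, 4, None)  # like dataset.skip(4)
train = list(buffered_shuffle(train_stream, buffer_size=5, seed=0))
print(valid, train)
```

With this scheme an item can never be emitted more than buffer_size - 1 positions earlier than it arrived, which is why a larger buffer gives a better (but more memory-hungry) shuffle.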
At this step, the dataset still contains raw data with code of arbitrary length. For training, we need inputs of fixed length. Let’s create an Iterable dataset that returns constant-length chunks of tokens from a stream of text files.
First, let’s estimate the average number of characters per token in the dataset, which will help us later estimate the number of tokens in the text buffer. By default, we’ll only take 400 examples (nb_examples) from the dataset. Using only a subset of the entire dataset reduces computational cost while still providing a reasonable estimate of the overall character-to-token ratio.
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)


def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """

    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        total_characters += len(example[data_column])
        total_tokens += len(tokenizer(example[data_column]).tokens())

    return total_characters / total_tokens


chars_per_token = chars_token_ratio(train_data, tokenizer, DATA_COLUMN)
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
100%|██████████| 400/400 [00:10<00:00, 39.87it/s]
The character to token ratio of the dataset is: 2.43
The character-to-token ratio can also be used as an indicator of the quality of text tokenization. For instance, a character-to-token ratio of 1.0 would mean that each character is represented with a token, which is not very meaningful. This would indicate poor tokenization. In standard English text, one token is typically equivalent to approximately four characters, meaning the character-to-token ratio is around 4.0. We can expect a lower ratio in the code dataset, but generally speaking, a number between 2.0 and 3.5 can be considered good enough.
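To make the ratio concrete, the sketch below computes it with a deliberately crude stand-in tokenizer (a regex that keeps identifiers and numbers whole and splits off single punctuation characters; an assumption for illustration, not StarCoder’s tokenizer). It then shows how the ConstantLengthDataset defined below turns the measured 2.43 into a character budget for its text buffer: seq_length × chars_per_token × num_of_sequences, using the class’s default num_of_sequences=1024.

```python
import re


def toy_tokenize(text):
    # Stand-in tokenizer: runs of word characters, or any single non-space character
    return re.findall(r"\w+|[^\w\s]", text)


def toy_chars_token_ratio(texts, tokenize):
    total_chars = sum(len(t) for t in texts)
    total_tokens = sum(len(tokenize(t)) for t in texts)
    return total_chars / total_tokens


snippet = "def add(a, b):\n    return a + b\n"
ratio = toy_chars_token_ratio([snippet], toy_tokenize)

# Character budget of the text buffer in ConstantLengthDataset (see max_buffer_size below)
seq_length, measured_ratio, num_of_sequences = 2048, 2.43, 1024
max_buffer_size = seq_length * measured_ratio * num_of_sequences
print(f"toy ratio: {ratio:.2f}, buffer budget: {max_buffer_size:,.0f} characters")
```

Whitespace counts toward the characters but produces no tokens here, which is one reason indented code can land on a different ratio than English prose.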
Optional FIM transformations
Autoregressive language models typically generate sequences from left to right. By applying FIM (fill-in-the-middle) transformations, the model can also learn to infill text. Check out the “Efficient Training of Language Models to Fill in the Middle” paper to learn more about the technique. We’ll define the FIM transformations here and use them when creating the Iterable dataset. However, if you want to omit the transformations, feel free to set fim_rate to 0.
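To see what the two FIM modes produce, here is a toy layout that uses strings in place of token IDs. It mirrors the orderings used by the permute function below (PSM: prefix, suffix, middle; SPM: suffix, prefix, middle, with the same sentinel placement as permute), but with fixed split points instead of random ones. fim_layout is a hypothetical helper written only for illustration.

```python
def fim_layout(tokens, start, end, mode="psm"):
    """Split tokens into prefix/middle/suffix at [start, end) and reorder them for FIM."""
    prefix, middle, suffix = tokens[:start], tokens[start:end], tokens[end:]
    if mode == "psm":
        # PSM: <fim_prefix> prefix <fim_suffix> suffix <fim_middle> middle
        return ["<fim_prefix>", *prefix, "<fim_suffix>", *suffix, "<fim_middle>", *middle]
    # SPM (sentinel order as in permute): <fim_prefix> <fim_suffix> suffix <fim_middle> prefix middle
    return ["<fim_prefix>", "<fim_suffix>", *suffix, "<fim_middle>", *prefix, *middle]


toy_tokens = ["def", "f", "(", "x", ")", ":", "return", "x"]
print(fim_layout(toy_tokens, 2, 5, "psm"))
print(fim_layout(toy_tokens, 2, 5, "spm"))
```

In both modes the middle comes last, so ordinary left-to-right next-token prediction teaches the model to generate the missing middle given its surroundings. The same reasoning explains why the inference prompt at the end of this notebook ends with <fim_middle>.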
import functools
import numpy as np


# Helper function to get token ids of the special tokens for prefix, suffix and middle for FIM transformations.
@functools.lru_cache(maxsize=None)
def get_fim_token_ids(tokenizer):
    try:
        FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map["additional_special_tokens"][1:5]
        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (
            tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]
        )
    except KeyError:
        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None
    return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id
## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py
def permute(
    sample,
    np_rng,
    suffix_tok_id,
    prefix_tok_id,
    middle_tok_id,
    pad_tok_id,
    fim_rate=0.5,
    fim_spm_rate=0.5,
    truncate_or_pad=False,
):
    """
    Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes:
    PSM and SPM (with a probability of fim_spm_rate).
    """

    # The if condition will trigger with the probability of fim_rate
    # This means FIM transformations will apply to samples with a probability of fim_rate
    if np_rng.binomial(1, fim_rate):

        # Split the sample into prefix, middle, and suffix, based on randomly generated indices stored in the boundaries list.
        boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2))
        boundaries.sort()

        prefix = np.array(sample[: boundaries[0]], dtype=np.int64)
        middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64)
        suffix = np.array(sample[boundaries[1] :], dtype=np.int64)

        if truncate_or_pad:
            # Calculate the new total length of the sample, taking into account tokens indicating prefix, middle, and suffix
            new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3
            diff = new_length - len(sample)

            # Truncate or pad if there's a difference in length between the new length and the original
            if diff > 0:
                if suffix.shape[0] <= diff:
                    return sample, np_rng
                suffix = suffix[: suffix.shape[0] - diff]
            elif diff < 0:
                suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)])

        # With the probability of fim_spm_rate, apply the SPM variant of FIM transformations
        # SPM: suffix, prefix, middle
        if np_rng.binomial(1, fim_spm_rate):
            new_sample = np.concatenate(
                [
                    [prefix_tok_id, suffix_tok_id],
                    suffix,
                    [middle_tok_id],
                    prefix,
                    middle,
                ]
            )
        # Otherwise, apply the PSM variant of FIM transformations
        # PSM: prefix, suffix, middle
        else:
            new_sample = np.concatenate(
                [
                    [prefix_tok_id],
                    prefix,
                    [suffix_tok_id],
                    suffix,
                    [middle_tok_id],
                    middle,
                ]
            )
    else:
        # Don't apply FIM transformations
        new_sample = sample

    return list(new_sample), np_rng
Let’s define the ConstantLengthDataset, an Iterable dataset that returns constant-length chunks of tokens. To do so, we’ll read a buffer of text from the original dataset until we hit the size limit, and then apply the tokenizer to convert the raw text into tokenized inputs. Optionally, we’ll perform FIM transformations on some sequences (the proportion of sequences affected is controlled by fim_rate). The tokenized sequences are then concatenated, separated by the EOS token, and split into chunks of seq_length tokens.
Once defined, we can create instances of the ConstantLengthDataset from both training and validation data.
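The core of the ConstantLengthDataset is sequence packing: tokenized files are concatenated with an EOS separator and cut into seq_length-sized chunks, and an incomplete final chunk is dropped. The toy sketch below isolates that step on small integer lists. pack_sequences is a hypothetical name, and the real class additionally buffers text, applies FIM, and shuffles the chunks.

```python
def pack_sequences(tokenized_inputs, eos_token_id, seq_length):
    """Concatenate token lists with an EOS separator and split into full-length chunks."""
    all_token_ids = []
    for ids in tokenized_inputs:
        all_token_ids.extend(ids + [eos_token_id])
    chunks = []
    for i in range(0, len(all_token_ids), seq_length):
        chunk = all_token_ids[i : i + seq_length]
        if len(chunk) == seq_length:  # drop the incomplete tail
            chunks.append(chunk)
    return chunks


EOS = 0  # hypothetical EOS token id
packed = pack_sequences([[5, 6, 7], [8, 9], [10, 11, 12, 13]], EOS, seq_length=5)
print(packed)
```

Because chunks cross file boundaries, one training example can contain the end of one file and the start of the next, separated by EOS. No tokens are spent on padding; the cost is that the leftover tail of each buffer is discarded.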
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
import random


# Create an Iterable dataset that returns constant-length chunks of tokens from a stream of text files.
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from stream of text files.
        Args:
            tokenizer (Tokenizer): The processor used for processing the data.
            dataset (dataset.Dataset): Dataset with text files.
            infinite (bool): If True the iterator is reset after dataset reaches end else stops.
            seq_length (int): Length of token sequences to return.
            num_of_sequences (int): Number of token sequences to keep in buffer.
            chars_per_token (float): Number of characters per token used to estimate number of tokens in text buffer.
            fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM.
            fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permutations that will use SPM.
            seed (int): Seed for random number generator.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        content_field="content",
        fim_rate=0.5,
        fim_spm_rate=0.5,
        seed=0,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.eos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.content_field = content_field
        self.fim_rate = fim_rate
        self.fim_spm_rate = fim_spm_rate
        self.seed = seed

        (
            self.suffix_tok_id,
            self.prefix_tok_id,
            self.middle_tok_id,
            self.pad_tok_id,
        ) = get_fim_token_ids(self.tokenizer)
        # Compare against None so that a valid token id of 0 is not mistaken for "missing"
        if self.suffix_tok_id is None and self.fim_rate > 0:
            print("FIM is not supported by tokenizer, disabling FIM")
            self.fim_rate = 0

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        np_rng = np.random.RandomState(seed=self.seed)
        while more_examples:
            # Fill a text buffer up to max_buffer_size characters
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator)[self.content_field])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []

            for tokenized_input in tokenized_inputs:
                # optionally do FIM permutations
                if self.fim_rate > 0:
                    tokenized_input, np_rng = permute(
                        tokenized_input,
                        np_rng,
                        self.suffix_tok_id,
                        self.prefix_tok_id,
                        self.middle_tok_id,
                        self.pad_tok_id,
                        fim_rate=self.fim_rate,
                        fim_spm_rate=self.fim_spm_rate,
                        truncate_or_pad=False,
                    )

                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            # Split the concatenated tokens into full-length chunks
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }
train_dataset = ConstantLengthDataset(
    tokenizer,
    train_data,
    infinite=True,
    seq_length=SEQ_LENGTH,
    chars_per_token=chars_per_token,
    content_field=DATA_COLUMN,
    fim_rate=FIM_RATE,
    fim_spm_rate=FIM_SPM_RATE,
    seed=SEED,
)
eval_dataset = ConstantLengthDataset(
    tokenizer,
    valid_data,
    infinite=False,
    seq_length=SEQ_LENGTH,
    chars_per_token=chars_per_token,
    content_field=DATA_COLUMN,
    fim_rate=FIM_RATE,
    fim_spm_rate=FIM_SPM_RATE,
    seed=SEED,
)
Prepare the model
Now that the data is prepared, it’s time to load the model! We’re going to load the quantized version of the model.
This will allow us to reduce memory usage, as quantization represents data with fewer bits. We’ll use the bitsandbytes library to quantize the model, as it has a nice integration with transformers. All we need to do is define a bitsandbytes config, and then use it when loading the model.
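To build intuition for "fewer bits", here is a minimal blockwise absmax quantizer in NumPy that maps each block of floats to signed integers in [-7, 7] (a 4-bit range) plus one float scale per block. This is a simplified stand-in, not what bitsandbytes does: NF4 uses a non-uniform set of 16 levels derived from a normal distribution rather than evenly spaced integers, and real 4-bit storage packs two values per byte.

```python
import numpy as np


def quantize_absmax_int4(weights, block_size=64):
    """Blockwise absmax quantization to integers in [-7, 7] (one float scale per block)."""
    blocks = weights.reshape(-1, block_size)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / 7.0
    scales[scales == 0] = 1.0  # avoid division by zero for all-zero blocks
    # int8 is used for readability; a packed 4-bit format would store two values per byte
    q = np.clip(np.round(blocks / scales), -7, 7).astype(np.int8)
    return q, scales


def dequantize(q, scales, shape):
    return (q.astype(np.float32) * scales).reshape(shape)


rng = np.random.default_rng(0)
w = rng.normal(size=(4, 128)).astype(np.float32)
q, scales = quantize_absmax_int4(w)
w_hat = dequantize(q, scales, w.shape)
print("max abs error:", np.abs(w - w_hat).max())
```

Each reconstructed weight is off by at most half a quantization step (scale / 2) of its block, which is the precision traded for the memory savings.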
There are different variants of 4-bit quantization, but generally, we recommend using NF4 quantization for better performance (bnb_4bit_quant_type="nf4").
The bnb_4bit_use_double_quant option adds a second quantization step, which quantizes the quantization constants produced by the first one, to save an additional 0.4 bits per parameter.
To learn more about quantization, check out the “Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA” blog post.
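A back-of-the-envelope estimate shows why this matters even for a 1B-parameter model. The sketch below counts weight storage only. It uses the base-model parameter count implied by the print_trainable_parameters() output later in this notebook (1,142,761,472 total minus 5,554,176 LoRA parameters), ignores per-block quantization constants, and ignores that some layers (such as embeddings and norms) typically stay in higher precision, so treat the numbers as rough.

```python
def weight_memory_gb(n_params, bits_per_param):
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


n_base = 1_142_761_472 - 5_554_176  # base-model parameters (rough estimate)
for label, bits in [("bf16", 16), ("4-bit", 4)]:
    print(f"{label}: {weight_memory_gb(n_base, bits):.2f} GB")

# Double quantization saves roughly 0.4 bits per parameter
print(f"double-quant savings: {weight_memory_gb(n_base, 0.4) * 1000:.0f} MB")
```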
Once defined, pass the config to the from_pretrained method to load the quantized version of the model.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.tuners.lora import LoraLayer

load_in_8bit = False

# 4-bit quantization
compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=USE_NESTED_QUANT,
)

device_map = {"": 0}

model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    load_in_8bit=load_in_8bit,
    quantization_config=bnb_config,
    device_map=device_map,
    use_cache=False,  # We will be using gradient checkpointing
    trust_remote_code=True,
    use_flash_attention_2=True,  # newer transformers versions use attn_implementation="flash_attention_2" instead
)
When using a quantized model for training, you need to call the prepare_model_for_kbit_training() function to preprocess the quantized model for training.
model = prepare_model_for_kbit_training(model)
Now that the quantized model is ready, we can set up a LoRA configuration. LoRA makes fine-tuning more efficient by drastically reducing the number of trainable parameters.
To train a model using the LoRA technique, we need to wrap the base model as a PeftModel. This involves defining a LoRA configuration with LoraConfig, and wrapping the original model with get_peft_model() using that LoraConfig.
To learn more about LoRA and its parameters, refer to the PEFT documentation.
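Conceptually, LoRA freezes a weight matrix W and learns a low-rank update B @ A scaled by lora_alpha / r. The NumPy sketch below shows the forward pass and the parameter count for a single linear layer. The layer sizes are made up for illustration, the initialization (small random A, zero B) follows the LoRA paper, and the real PEFT implementation additionally handles dropout, dtype casting, and module replacement.

```python
import numpy as np


class LoRALinear:
    """Frozen linear layer W plus a trainable low-rank update (alpha / r) * B @ A."""

    def __init__(self, d_in, d_out, r=8, alpha=32, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.normal(size=(d_out, d_in))  # frozen pretrained weight
        self.A = rng.normal(scale=0.01, size=(r, d_in))  # trainable, small random init
        self.B = np.zeros((d_out, r))  # trainable, zero init
        self.scaling = alpha / r

    def __call__(self, x):
        return x @ self.W.T + self.scaling * (x @ self.A.T @ self.B.T)

    def trainable_params(self):
        return self.A.size + self.B.size


layer = LoRALinear(d_in=2048, d_out=2048, r=8, alpha=32)
x = np.ones((1, 2048))
# With B initialized to zero, the adapted layer starts out identical to the frozen one
print(np.allclose(layer(x), x @ layer.W.T))
print(layer.trainable_params(), "trainable vs", layer.W.size, "frozen")
```

Only A and B would receive gradients: with r=8 they add 2 × 8 × 2048 = 32,768 parameters next to 4,194,304 frozen ones in this layer, which is why the trainable fraction stays so small.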
# Set up lora
peft_config = LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    r=LORA_R,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=LORA_TARGET_MODULES.split(","),
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
trainable params: 5,554,176 || all params: 1,142,761,472 || trainable%: 0.4860310866343243
As you can see, by applying the LoRA technique we now only need to train about 0.49% of the parameters, well under 1%.
Train the model
Now that we have prepared the data and optimized the model, we are ready to bring everything together and start the training.
To instantiate a Trainer, you need to define the training configuration. The most important part is TrainingArguments, a class that contains all the attributes to configure the training.
These are similar to any other kind of model training you may run, so we won’t go into detail here.
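For reference, LR_SCHEDULER_TYPE="cosine" combined with warmup_steps=NUM_WARMUP_STEPS gives a linear warmup followed by a cosine decay to zero at max_steps. The sketch below re-implements that shape (half a cosine cycle) with the values defined earlier; it is our own re-implementation for illustration, not the transformers scheduler object.

```python
import math


def cosine_lr(step, base_lr=5e-4, warmup_steps=30, max_steps=2000):
    """Linear warmup to base_lr, then half-cosine decay to 0 at max_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for step in (0, 15, 30, 1015, 2000):
    print(step, f"{cosine_lr(step):.2e}")
```

The learning rate peaks at 5e-4 after only 30 steps and spends most of the 2000-step run decaying, reaching half its peak at step 1015.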
train_data.start_iteration = 0


training_args = TrainingArguments(
    output_dir=f"Your_HF_username/{OUTPUT_DIR}",
    dataloader_drop_last=True,
    evaluation_strategy="steps",  # renamed to eval_strategy in newer transformers releases
    save_strategy="steps",
    max_steps=MAX_STEPS,
    eval_steps=EVAL_FREQ,
    save_steps=SAVE_FREQ,
    logging_steps=LOG_FREQ,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LR,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    warmup_steps=NUM_WARMUP_STEPS,
    gradient_accumulation_steps=GR_ACC_STEPS,
    gradient_checkpointing=True,
    fp16=FP16,
    bf16=BF16,
    weight_decay=WEIGHT_DECAY,
    push_to_hub=True,
    include_tokens_per_second=True,
)
As a final step, instantiate the Trainer and call the train method.
trainer = Trainer(
    model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset
)

print("Training...")
trainer.train()
Training...
Step | Training Loss | Validation Loss |
---|---|---|
100 | 5.524600 | 7.456872 |
200 | 5.617800 | 7.262190 |
300 | 5.129100 | 6.410039 |
400 | 5.052200 | 6.306774 |
500 | 5.202900 | 6.117062 |
600 | 4.654100 | 6.018349 |
700 | 5.100200 | 6.000355 |
800 | 5.049800 | 5.889457 |
900 | 4.541200 | 5.813823 |
1000 | 5.000700 | 5.834208 |
1100 | 5.026500 | 5.781939 |
1200 | 4.411800 | 5.720596 |
1300 | 4.782500 | 5.736376 |
1400 | 4.980200 | 5.712276 |
1500 | 4.368700 | 5.689637 |
1600 | 4.884700 | 5.675920 |
1700 | 4.914400 | 5.662421 |
1800 | 4.248700 | 5.660122 |
1900 | 4.798400 | 5.664026 |
2000 | 4.704200 | 5.655665 |
TrainOutput(global_step=2000, training_loss=4.885598585128784, metrics={'train_runtime': 15380.3075, 'train_samples_per_second': 2.081, 'train_steps_per_second': 0.13, 'train_tokens_per_second': 4261.033, 'total_flos': 4.0317260660736e+17, 'train_loss': 4.885598585128784, 'epoch': 1.0})
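As a sanity check on the reported throughput, the number of tokens processed should equal max_steps × per-device batch size × gradient accumulation steps × sequence length (single GPU, every example packed to exactly SEQ_LENGTH tokens). The arithmetic below compares that with train_runtime × train_tokens_per_second from the TrainOutput above.

```python
max_steps, batch_size, grad_acc_steps, seq_length = 2000, 16, 1, 2048
expected_tokens = max_steps * batch_size * grad_acc_steps * seq_length

train_runtime = 15380.3075  # seconds, from TrainOutput
train_tokens_per_second = 4261.033  # from TrainOutput
measured_tokens = train_runtime * train_tokens_per_second

print(f"expected: {expected_tokens:,}  measured: {measured_tokens:,.0f}")
print(f"runtime: {train_runtime / 3600:.1f} hours")
```

The two figures agree to within rounding, so the run processed roughly 65.5M tokens in a little over four hours on the single GPU.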
Finally, you can push the fine-tuned model to your Hub repository to share with your team.
trainer.push_to_hub()
Inference
Once the model is uploaded to the Hub, we can use it for inference. To do so, we first initialize the original base model and its tokenizer. Next, we need to merge the fine-tuned weights with the base model.
from peft import PeftModel
import torch

# load the original model first
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    quantization_config=None,
    device_map=None,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).cuda()

# merge fine-tuned weights with the base model
peft_model_id = f"Your_HF_username/{OUTPUT_DIR}"
model = PeftModel.from_pretrained(base_model, peft_model_id)
model = model.merge_and_unload()  # returns the base model with the LoRA weights folded in
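Merging works because the LoRA update is linear: W x + s·B(A x) equals (W + s·B A) x, so the adapter can be folded into the base weight once and inference needs no extra matrix multiplications. The NumPy check below illustrates this on random matrices; the sizes and values are arbitrary, and s = lora_alpha / r = 32 / 8 matches the configuration above.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 64, 32, 8, 32
scaling = alpha / r

W = rng.normal(size=(d_out, d_in))  # frozen base weight
A = rng.normal(size=(r, d_in))  # LoRA down-projection
B = rng.normal(size=(d_out, r))  # LoRA up-projection (nonzero, as after training)
x = rng.normal(size=(5, d_in))

unmerged = x @ W.T + scaling * (x @ A.T) @ B.T
W_merged = W + scaling * (B @ A)  # fold the adapter into the base weight
merged = x @ W_merged.T

print(np.allclose(unmerged, merged))
```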
Now we can use the merged model for inference. For convenience, we’ll define a get_code_completion function; feel free to experiment with the text generation parameters!
def get_code_completion(prefix, suffix):
    # PSM-style FIM prompt: the model generates the middle after <fim_middle>
    prompt = f"""<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"""
    model.eval()
    outputs = model.generate(
        input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
        max_new_tokens=128,
        temperature=0.2,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.0,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
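To build intuition for the sampling parameters passed to generate(), the sketch below applies temperature, top-k, and top-p (nucleus) filtering to a single vector of logits in NumPy and then samples from the result. It is a simplified illustration of the standard definitions, not the transformers implementation: it handles one sequence only and applies the filters in a fixed order.

```python
import numpy as np


def filter_probs(logits, temperature=0.2, top_k=50, top_p=0.95):
    """Return next-token probabilities after temperature, top-k and top-p filtering."""
    scaled = logits / temperature  # lower temperature -> sharper distribution
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()

    # Top-k: keep only the k most likely tokens (ties at the cutoff are kept)
    if top_k < len(probs):
        kth_largest = np.sort(probs)[-top_k]
        probs = np.where(probs >= kth_largest, probs, 0.0)
        probs /= probs.sum()

    # Top-p: keep the smallest set of tokens whose cumulative probability reaches top_p
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    keep = cumulative - probs[order] < top_p  # mass before each token is still below top_p
    mask = np.zeros_like(probs, dtype=bool)
    mask[order[keep]] = True
    probs = np.where(mask, probs, 0.0)
    return probs / probs.sum()


logits = np.array([2.0, 1.5, 0.5, 0.0, -1.0])
p = filter_probs(logits, temperature=1.0, top_k=3, top_p=0.8)
rng = np.random.default_rng(0)
next_token = rng.choice(len(p), p=p)  # do_sample=True draws from the filtered distribution
print(p, next_token)
```

Here top-k first keeps three candidates and top-p then trims them to the two tokens that cover 80% of the mass. With the notebook’s temperature=0.2 the distribution is much sharper, so completions stay close to the most likely code.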
Now all we need to do to get a code completion is call the get_code_completion function, passing the first few lines that we want completed as the prefix and an empty string as the suffix.
prefix = """from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM
peft_config = LoraConfig(
"""
suffix = ""

print(get_code_completion(prefix, suffix))
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.1,
bias="none",
modules_to_save=["q_proj", "v_proj"],
inference_mode=False,
)
model = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
Having used the PEFT library earlier in this notebook, you can see that the generated code for creating a LoraConfig is rather good!
If you go back to the cell where we instantiate the model for inference, and comment out the lines where we merge the fine-tuned weights, you can see what the original model would’ve generated for the exact same prefix:
prefix = """from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM
peft_config = LoraConfig(
"""
suffix = ""

print(get_code_completion(prefix, suffix))
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM
peft_config = LoraConfig(
model_name_or_path="facebook/wav2vec2-base-960h",
num_labels=1,
num_features=1,
num_hidden_layers=1,
num_attention_heads=1,
num_hidden_layers_per_attention_head=1,
num_attention_heads_per_hidden_layer=1,
hidden_size=1024,
hidden_dropout_prob=0.1,
hidden_act="gelu",
hidden_act_dropout_prob=0.1,
hidden
While the output resembles Python syntax, you can see that the original model has no understanding of what a LoraConfig should contain.
To learn how this kind of fine-tuning compares to full fine-tuning, and how to use a model like this as your copilot in VS Code via Inference Endpoints, or locally, check out the “Personal Copilot: Train Your Own Coding Assistant” blog post. This notebook complements the original blog post.