Detecting Dementia Using Instruction-Tuned LLaMA
As part of a project at the University of Stavanger, my fellow student Kinnan Al Amir and I developed multiple AI models to detect dementia from speech transcriptions. I was tasked with creating two deep learning models: a fine-tuned RoBERTa model and an instruction-tuned LLaMA model. This article covers the LLaMA model.
Why this matters
Having a loved one seemingly lose all memory of oneself is a painful experience that the friends and family of over 50 million people have to live with [7]. Dementia affects not only memory but also thinking and hinders patients from living a happy life. But dementia is not a disease in itself.
Dementia is a syndrome caused by a variety of diseases, with 60-70% of cases attributed to Alzheimer's disease [7]. This makes Alzheimer's disease the most common cause of dementia. One of the early signs of Alzheimer's is language impairment, which is noticeable even in the early stages of the disease [4]. Patients have difficulty finding the right words and are often frustrated with themselves, which can lead to anxiety and depression. But these difficulties in expression also give hope for early diagnosis through language analysis.
To help the development of tools that can diagnose Alzheimer's disease early, Saturnino Luz et al. [6] created the Alzheimer's Dementia Recognition through Spontaneous Speech (ADReSS) challenge. Part of the challenge is to predict whether a patient has Alzheimer's disease based on a speech sample. The challenge provides a dataset with transcriptions of speech samples from patients with Alzheimer's disease and healthy controls.
Figure 1: The Cookie Theft picture
The speech samples were taken from patients describing the Cookie Theft picture shown in Figure 1. It is part of the Boston Diagnostic Aphasia Exam [3] and is used to assess the language capabilities of a patient. We used this dataset to train multiple machine learning and deep learning models. This article covers how I instruction-tuned the LLaMA 7B model to achieve an accuracy of 75%.
Instruction-tuning LLaMA
LLaMA is a large language model developed by Meta. To fine-tune LLaMA, we need to go through the following steps:
- Load the model and tokenizer
- Load the dataset
- Create prompts
- Tokenize the data
- Fine-tune the model using PEFT!
We will import all needed libraries with the following code:
import os
from random import randrange
from functools import partial
import torch
from datasets import load_dataset
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer,
                          BitsAndBytesConfig,
                          HfArgumentParser,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          EarlyStoppingCallback,
                          pipeline,
                          logging,
                          set_seed)
import bitsandbytes as bnb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel, AutoPeftModelForCausalLM
from trl import SFTTrainer
Loading Model and Tokenizer
Before loading the model, we set up a BitsAndBytes configuration, which enables 4-bit quantization and improves computational efficiency.
def create_bnb_config(load_in_4bit, bnb_4bit_use_double_quant, bnb_4bit_quant_type, bnb_4bit_compute_dtype):
    bnb_config = BitsAndBytesConfig(
        load_in_4bit = load_in_4bit,
        bnb_4bit_use_double_quant = bnb_4bit_use_double_quant,
        bnb_4bit_quant_type = bnb_4bit_quant_type,
        bnb_4bit_compute_dtype = bnb_4bit_compute_dtype,
    )
    return bnb_config
We load the model from Hugging Face with a helper function. It applies our BitsAndBytes settings and distributes the model across the available GPUs.
def load_model(model_name, bnb_config):
    # Get number of GPU devices and set maximum memory per device
    n_gpus = torch.cuda.device_count()
    max_memory = f'{40960}MB'
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config = bnb_config,
        device_map = "auto",  # dispatch the model efficiently on the available resources
        max_memory = {i: max_memory for i in range(n_gpus)},
    )
    # Load the model tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token = False)
    # Set padding token as EOS token
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
Let us now use the create_bnb_config and load_model functions to load the model and the appropriate tokenizer.
model_name = "beomi/llama-2-ko-7b"
# Activate 4-bit precision base model loading
load_in_4bit = True
# Activate nested quantization for 4-bit base models (double quantization)
bnb_4bit_use_double_quant = True
# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"
# Compute data type for 4-bit base models
bnb_4bit_compute_dtype = torch.bfloat16
bnb_config = create_bnb_config(load_in_4bit, bnb_4bit_use_double_quant, bnb_4bit_quant_type, bnb_4bit_compute_dtype)
model, tokenizer = load_model(model_name, bnb_config)
Loading the Dataset
We want to load the data as a Hugging Face Dataset object. This class, developed by Hugging Face, is optimized for memory efficiency and has built-in preprocessing functions, among other advantages over pandas.
# The instruction dataset to use
dataset_name = ["./Control_db.csv","./Dementia_db.csv", ]
# Load dataset
dataset = load_dataset("csv", data_files = dataset_name, split='train')
print(f'Number of prompts: {len(dataset)}')
print(f'Column names are: {dataset.column_names}')
Creating the Tuning Prompts
To fine-tune the LLaMA model, we want to turn all our data points into long prompts that include an instruction, the transcription, and our label mapped to a string. Data with the label 0 will be mapped to 'healthy' and data with the label 1 will be mapped to 'alzheimers'.
def create_prompt_formats(sample, instruction = None):
    label_map = {0: "healthy", 1: "alzheimers"}
    instruction = "The input is a transcription of a patient who could have the alzheimers disease. Based on the transcription respond with 'healthy' or 'alzheimers' according to the patients diagnosis."
    # Initialize static strings for the prompt template
    INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
    INSTRUCTION_KEY = "### Instruction:"
    INPUT_KEY = "Input:"
    RESPONSE_KEY = "### Response:"
    END_KEY = "### End"
    # Combine a prompt with the static strings
    blurb = f"{INTRO_BLURB}"
    instruction = f"{INSTRUCTION_KEY}\n{instruction}"
    input_context = f"{INPUT_KEY}\n{sample['Transcript']}" if sample['Transcript'] else None
    response = f"{RESPONSE_KEY}\n{label_map[sample['Category']]}"
    end = f"{END_KEY}"
    # Create a list of prompt template elements
    parts = [part for part in [blurb, instruction, input_context, response, end] if part]
    # Join prompt template elements into a single string to create the prompt template
    formatted_prompt = "\n\n".join(parts)
    # Store the formatted prompt template in a new key "text"
    sample["text"] = formatted_prompt
    return sample
We can try out the function on a single sample.
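A minimal sketch of such a check, assuming we simply format a random record and print the result, looks like this:
# Hypothetical example: format one random sample and inspect the added "text" field
print(create_prompt_formats(dataset[randrange(len(dataset))]))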
This will give us the output:
{'Language': 'eng',
'Data': 'Pitt',
'Participant': 'PAR',
'Age': 61,
'Gender': 'female',
'Diagnosis': 'Control',
'Category': 0,
'mmse': 30.0,
'Filename': 'S033',
'Transcript': " mhm . well the water's running over on the floor . &uh the chair [: stool] [* s:r] is tilting . the boy is into the cookie jar . and his sister is reaching for a cookie . the mother's drying dishes . &um do you want action or just want anything I see ? okay . mhm .",
'text': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nThe input is a transcription of a patient who could have the alzheimers disease. Based on the transcription respond with 'healthy' or 'alzheimers' according to the patients diagnosis.\n\nInput:\n mhm . well the water's running over on the floor . &uh the chair [: stool] [* s:r] is tilting . the boy is into the cookie jar . and his sister is reaching for a cookie . the mother's drying dishes . &um do you want action or just want anything I see ? okay . mhm .\n\n### Response:\nhealthy\n\n### End"}
We can see that a new column containing our newly constructed prompt was added.
Tokenizing the data
Before tokenizing, we have to find the maximum token length from the model configuration.
def get_max_length(model):
    # Pull model configuration
    conf = model.config
    # Initialize a "max_length" variable to store maximum sequence length as null
    max_length = None
    # Find maximum sequence length in the model configuration and save it in "max_length" if found
    for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
        max_length = getattr(model.config, length_setting, None)
        if max_length:
            print(f"Found max length: {max_length}")
            break
    # Set "max_length" to 1024 (default value) if maximum sequence length is not found in the model configuration
    if not max_length:
        max_length = 1024
        print(f"Using default max length: {max_length}")
    return max_length
We also prepare a function for tokenizing.
def preprocess_batch(batch, tokenizer, max_length):
    return tokenizer(
        batch["text"],
        max_length = max_length,
        truncation = True,
    )
At this point, we have prepared our data and the functions for prompt generation and tokenization. We are now ready to preprocess our dataset.
The following function takes in our dataset and returns the tokenized data.
def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed, dataset):
    # Add prompt to each sample
    print("Preprocessing dataset…")
    dataset = dataset.map(create_prompt_formats)
    # Apply tokenization to each batch of the dataset and remove the original columns
    _preprocessing_function = partial(preprocess_batch, max_length = max_length, tokenizer = tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched = True,
        remove_columns = ['Language', 'Data', 'Participant', 'Age', 'Gender', 'Diagnosis', 'Category', 'mmse', 'Filename', 'Transcript', 'text'],
    )
    # Filter out samples that have "input_ids" exceeding "max_length"
    dataset = dataset.filter(lambda sample: len(sample["input_ids"]) < max_length)
    # Shuffle dataset
    dataset = dataset.shuffle(seed = seed)
    return dataset
We call it like this:
# Random seed
seed = 33
max_length = get_max_length(model)
preprocessed_dataset = preprocess_dataset(tokenizer, max_length, seed, dataset)
Fine-Tuning the model using PEFT
To fine-tune our LLaMA model efficiently, we implemented Parameter-Efficient Fine-Tuning (PEFT) techniques, focusing on Low-Rank Adaptation (LoRA). Instead of updating all weights, the fine-tuning process adjusts only a small subset of the model's parameters, achieving significant improvements without the need for extensive computational resources.
Initializing PEFT Configuration
We configure our model for LoRA by initializing the necessary parameters. This step ensures that only a small portion of the model's parameters is adjusted during training, making the process resource-efficient.
def create_peft_config(r, lora_alpha, target_modules, lora_dropout, bias, task_type):
    config = LoraConfig(
        r = r,
        lora_alpha = lora_alpha,
        target_modules = target_modules,
        lora_dropout = lora_dropout,
        bias = bias,
        task_type = task_type,
    )
    return config
def find_all_linear_names(model):
    """
    Find modules to apply LoRA to.
    :param model: PEFT model
    """
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    # Do not apply LoRA to the output head
    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    print(f"LoRA module names: {list(lora_module_names)}")
    return list(lora_module_names)
def print_trainable_parameters(model, use_4bit = False):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    if use_4bit:
        # Adjust the count for 4-bit weights (integer division keeps the :,d formatting below valid)
        trainable_params //= 2
    print(f"All Parameters: {all_param:,d} || Trainable Parameters: {trainable_params:,d} || Trainable Parameters %: {100 * trainable_params / all_param}")
Fine-tuning Pre-trained Model
With our LoRA settings configured, we begin the fine-tuning process. The training arguments specify the number of epochs, learning rate, and other hyperparameters crucial for effective learning.
def fine_tune(model, tokenizer, dataset, lora_r, lora_alpha, lora_dropout, bias, task_type,
              per_device_train_batch_size, gradient_accumulation_steps, warmup_steps, max_steps,
              learning_rate, fp16, logging_steps, output_dir, optim):
    # Enable gradient checkpointing to reduce memory usage during fine-tuning
    model.gradient_checkpointing_enable()
    # Prepare the model for k-bit training
    model = prepare_model_for_kbit_training(model)
    # Get LoRA module names
    target_modules = find_all_linear_names(model)
    # Create PEFT configuration for these modules and wrap the model with PEFT
    peft_config = create_peft_config(lora_r, lora_alpha, target_modules, lora_dropout, bias, task_type)
    model = get_peft_model(model, peft_config)
    # Print information about the percentage of trainable parameters
    print_trainable_parameters(model)
    # Training parameters
    trainer = Trainer(
        model = model,
        train_dataset = dataset,
        args = TrainingArguments(
            per_device_train_batch_size = per_device_train_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            warmup_steps = warmup_steps,
            max_steps = max_steps,
            learning_rate = learning_rate,
            fp16 = fp16,
            logging_steps = logging_steps,
            output_dir = output_dir,
            optim = optim,
        ),
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
    )
    # Disable the KV cache during training
    model.config.use_cache = False
    # Launch training and log metrics
    print("Training…")
    train_result = trainer.train()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    print(metrics)
    # Save model
    print("Saving last checkpoint of the model…")
    os.makedirs(output_dir, exist_ok = True)
    trainer.model.save_pretrained(output_dir)
    # Free memory for merging weights
    del model
    del trainer
    torch.cuda.empty_cache()
Training the Model
Finally, we can train the model. You can see our parameters in the following code snippet:
# LoRA attention dimension
lora_r = 16
# Alpha parameter for LoRA scaling
lora_alpha = 64
# Dropout probability for LoRA layers
lora_dropout = 0.1
# Bias
bias = "none"
# Task type
task_type = "CAUSAL_LM"
# Output directory where the model predictions and checkpoints will be stored
output_dir = "./results"
# Batch size per GPU for training
per_device_train_batch_size = 1
# Number of update steps to accumulate the gradients for
gradient_accumulation_steps = 4
# Initial learning rate (AdamW optimizer)
learning_rate = 2e-4
# Optimizer to use
optim = "paged_adamw_32bit"
# Number of training steps (overrides num_train_epochs)
max_steps = 1000
# Linear warmup steps from 0 to learning_rate
warmup_steps = 2
# Enable fp16/bf16 training (set bf16 to True with an A100)
fp16 = True
# Log every X updates steps
logging_steps = 1
fine_tune(model,
          tokenizer,
          preprocessed_dataset,
          lora_r,
          lora_alpha,
          lora_dropout,
          bias,
          task_type,
          per_device_train_batch_size,
          gradient_accumulation_steps,
          warmup_steps,
          max_steps,
          learning_rate,
          fp16,
          logging_steps,
          output_dir,
          optim)
Hooray! We have now successfully fine-tuned the model! All that is left is to load the model and use it to predict our test data.
Testing the model
We can now load the fine-tuned model that was saved in the previous step.
# Load fine-tuned weights
output_dir = "./results"
model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map = "auto", torch_dtype = torch.bfloat16)
We used a similar technique to create the test prompts, but this time the string stops right after "### Response:".
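The prompt-building code for the test set is not reproduced here; the sketch below only illustrates the idea. The names create_test_prompt and test_dataset are hypothetical, and the tokenizer from the training step is assumed to still be in memory.
# Hypothetical sketch: build test prompts that stop right after "### Response:"
def create_test_prompt(sample):
    instruction = "The input is a transcription of a patient who could have the alzheimers disease. Based on the transcription respond with 'healthy' or 'alzheimers' according to the patients diagnosis."
    return ("Below is an instruction that describes a task. Write a response that appropriately completes the request."
            f"\n\n### Instruction:\n{instruction}"
            f"\n\nInput:\n{sample['Transcript']}"
            "\n\n### Response:\n")

# test_dataset is assumed to be our held-out test split, loaded the same way as the training data
queries = [create_test_prompt(sample) for sample in test_dataset]

# model = model.merge_and_unload()  # optionally merge the LoRA weights into the base model first
# Reuse the name "pipeline" so the generation call below works as written
pipeline = pipeline("text-generation", model = model, tokenizer = tokenizer)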
sequences = pipeline(
    queries,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=10,
    early_stopping=True,
    # do_sample=True,
)
To extract the response, we use the following function:
def extract_responses(sequences):
    responses = []
    for sequence in sequences:
        for item in sequence:
            # Split the text to find the part after "### Response:\n"
            parts = item['generated_text'].split("### Response:\n")
            if len(parts) > 1:
                # Further split to isolate the response before "\n\n### End"
                response_part = parts[1].split("\n\n### End")[0]
                responses.append(response_part.strip())
    return responses
responses = extract_responses(sequences)
print(responses)
Finally, with some additional computations, we can create a confusion matrix:
Figure 2: Confusion matrix for the test data.
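Those additional computations are not shown in the article; a minimal sketch using scikit-learn, where true_labels is assumed to hold the 0/1 Category values of the test samples in the same order as the responses, could look like this:
from sklearn.metrics import accuracy_score, confusion_matrix

# Map the generated strings back to numeric labels; unexpected outputs default to 1 (hypothetical choice)
label_to_int = {"healthy": 0, "alzheimers": 1}
predictions = [label_to_int.get(response.lower().strip(), 1) for response in responses]

print(confusion_matrix(true_labels, predictions))
print(f"Accuracy: {accuracy_score(true_labels, predictions):.2f}")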
Closing thoughts
In this project, we were tasked with predicting dementia using large language models. We chose instruction tuning to show an alternative to top-layer fine-tuning. When we first started tuning this model, it could not even restrict its output to 'healthy' and 'alzheimers'. After increasing the training time to 1000 steps, we finally achieved the results shown above. What will happen if we increase it even more?