Finetuning Llava
Llava is a strong open-source vision-language model developed by Haotian Liu’s team. As with almost every other LLM, making Llava really useful for your use case requires finetuning. But building a training loop is a complex task, involving data loading, gradient tracking, GPU management, logging, and so on. Therefore, we usually use libraries to make the process less painful.
What are the requirements of a library used for training Llava? It has to support:
- LoRA finetuning: the smallest Llava model, llava-hf/llava-1.5-7b-hf, already has 7B params. According to NielsRogge (from HF itself), full finetuning of the model requires roughly 18 bytes of GPU RAM per parameter, which is 7B × 18 = 126GB, something that my two 24GB GPUs cannot hold. There is another way, which is to finetune only the last layer of the model. This was popular when finetuning models like BERT, but I don’t see anyone mentioning it for VLMs, so I guess it won’t work as well. For LoRA, there is the peft library, which works quite seamlessly with transformers models.
- Distributed training: again, a memory constraint. A single 24GB GPU cannot hold that many gradients for a 7B model, even in its LoRA form.
- Custom evaluation metrics for text generation models, like BLEU, ROUGE, BERTScore, or even an NLI score. Eval loss is simply inadequate.
- Reliable saving of checkpoints
- Informative logging
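The memory figure in the LoRA bullet above is simple arithmetic, and it can be sketched as follows. This is only a rough estimate under the quoted rule of thumb of 18 bytes per parameter; the function name is hypothetical and the real footprint depends on precision, optimizer, and batch size.

```python
def full_finetune_memory_gb(num_params: float, bytes_per_param: int = 18) -> float:
    """Rough GPU memory needed for full finetuning, in GB (1 GB = 1e9 bytes).

    bytes_per_param=18 is the rule of thumb quoted in the text, not a measured value.
    """
    return num_params * bytes_per_param / 1e9


needed = full_finetune_memory_gb(7e9)  # llava-1.5-7b has about 7B parameters
available = 2 * 24                     # two 24GB GPUs

print(f"needed ~{needed:.0f} GB, available {available} GB")
```

The gap between the two numbers is the reason LoRA and multi-GPU loading both appear in the requirements list.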
Currently, there are broadly two popular approaches for this task:
- Writing your own training loop, usually in PyTorch. TensorFlow would probably also work, but it is less popular now.
- Using Trainer, or one of its subclasses such as Seq2SeqTrainer and SFTTrainer.
Let’s first talk about the new guy – Trainer.
Trainer for Llava
Trainer is a popular utility inside the famous transformers library by HuggingFace (currently the leading open-source library for LLMs). Its popularity is probably due to the fact that it abstracts away many of the tedious things that are common to every training loop, enabling you to write less code to do more. That is best illustrated by the list of command-line arguments it takes (contained in a TrainingArguments object):
- Gradient descent parameters: learning_rate, weight_decay, Adam’s params, and lr_scheduler_type.
- Logging: they are particularly good with logging. By default, a lot of useful information is logged to a wandb dashboard. Notable params are logging_steps and report_to.
- Evaluation-related: eval_steps, metric_for_best_model, greater_is_better, and load_best_model_at_end (not sure why this has to be a param).
- Memory management: auto_find_batch_size (not working for me yet) and torch_empty_cache_steps.
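To make the grouping above concrete, here is a minimal sketch of those arguments collected as keyword arguments. All values are illustrative assumptions rather than recommendations, "rouge" is a hypothetical key that a custom compute_metrics would have to return, and save_steps is added because load_best_model_at_end expects checkpoints to line up with evaluations. The evaluation-strategy argument is left out on purpose, since its name has changed between transformers releases.

```python
# Keyword arguments grouped the same way as the list above.
# Values are illustrative assumptions, not tuned settings.
training_kwargs = {
    # Gradient descent
    "learning_rate": 2e-4,
    "weight_decay": 0.0,
    "lr_scheduler_type": "cosine",
    # Logging
    "logging_steps": 10,
    "report_to": "wandb",
    # Evaluation
    "eval_steps": 200,
    "save_steps": 200,                 # kept a multiple of eval_steps
    "metric_for_best_model": "rouge",  # hypothetical key from compute_metrics
    "greater_is_better": True,
    "load_best_model_at_end": True,
    # Memory management
    "auto_find_batch_size": True,
    "torch_empty_cache_steps": 50,
}

# Intended use (not executed here, requires transformers to be installed):
#   from transformers import TrainingArguments
#   args = TrainingArguments(output_dir="out", **training_kwargs)
```

Keeping the arguments in a plain dict also makes it easy to log the exact configuration next to each run.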
This suite of over 100 arguments is powerful but also overwhelming. It quickly feels like you are learning a new programming language that is not very well-documented. Even though you need to write little code, every line of code you write now takes much more time to consider.
And here are some facts I found after a day of digging through Trainer’s docs, forums, and source code:
- SFTTrainer, even though it is recommended in an official HF blog post (code), does not support custom eval metrics for text generation models. There is a place for you to plug in a custom compute_metrics, but it (1) requires you to implement the autoregressive generation (e.g., greedy search or beam search) from scratch and, more disappointingly, (2) doesn’t seem to provide the pixel_values you need to even do forward passes with Llava.
- Seq2SeqTrainer supports text generation during evaluation, but then requires you to inherit from it and override the evaluation loop.
  - Why? All tutorials show that the collator for training must have the ground-truth response inside the prompt. That cannot happen during evaluation. But Seq2SeqTrainer takes only one data_collator parameter, which simply means it does not natively support two different collating behaviors for training and evaluation (weird!).
  - Even with inheritance in place, after quite some time trying to get it to work, I could not resolve an error in the _merge_input_ids_with_image_features method of LlavaForConditionalGeneration. At this point, I had already spent too much time reading transformers’ source code (which is quite exhausting), so I gave up for now. But this remains the most promising way to use Trainer for my use case.
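The difference between the two collating behaviors mentioned above can be sketched at the text level. This is a minimal illustration, not a full collator: the prompt template is the one I assume llava-1.5 uses (check the model card for your checkpoint), and the "question" and "answer" field names are hypothetical. A real collator would also pass the image and the text through the processor.

```python
# Assumed prompt template for llava-1.5; verify against your checkpoint.
PROMPT = "USER: <image>\n{question} ASSISTANT:"


def build_train_text(example: dict) -> str:
    """Training text: the prompt followed by the ground-truth answer."""
    return PROMPT.format(question=example["question"]) + " " + example["answer"]


def build_eval_text(example: dict) -> str:
    """Evaluation text: the prompt only, so the model has to generate the answer."""
    return PROMPT.format(question=example["question"])


# Hypothetical example record.
example = {"question": "What is in the picture?", "answer": "A cat on a sofa."}

print(build_train_text(example))
print(build_eval_text(example))
```

Because the evaluation text is a strict prefix of the training text, a single collator cannot serve both phases without knowing which one it is in.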
After the entire process, I learned that Trainer and its subclasses sit on top of a pretty big code base. The nice thing is that the source code is relatively intuitive to read and understand. But the fact that I had to go through the code at all highlights the inadequacy of the documentation and support ecosystem for this new library.
PyTorch!
Conclusion? Go back to the traditional training loop! Now the challenge lies in distributed training. But surprisingly, when doing SomeModel.from_pretrained(..., device_map="auto"), the model is loaded across multiple GPUs and trains smoothly! So, I guess I will stick with PyTorch for a while and observe how the Trainer class evolves for new use cases like VLM finetuning with text generation metrics.
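One upside of a hand-written loop is that a text generation metric becomes an ordinary function over decoded strings. As a minimal sketch, here is a token-level F1 score in pure Python. It is a simple stand-in for metrics like ROUGE or BLEU, not an implementation of either, and the function names are hypothetical.

```python
from collections import Counter


def token_f1(prediction: str, reference: str) -> float:
    """Token-level F1 between a generated string and a reference string."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        # Both empty counts as a match; exactly one empty counts as a miss.
        return float(pred_tokens == ref_tokens)
    # Multiset intersection counts each shared token at most as often as it
    # appears in both strings.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def mean_token_f1(predictions: list[str], references: list[str]) -> float:
    """Average token F1 over a batch of decoded generations."""
    scores = [token_f1(p, r) for p, r in zip(predictions, references)]
    return sum(scores) / len(scores) if scores else 0.0
```

In an evaluation loop, the predictions would be the decoded outputs of the model's generate call on prompts that do not contain the answer.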