Link to original paper: https://arxiv.org/abs/2106.09685

What is LoRA

The Problem

In Natural Language Processing, we often have to fine-tune pre-trained language models to perform better at specific tasks (e.g. train a GPT-3 model to be better at solving math problems).

During fine-tuning, all the parameters of the pre-trained model are updated to produce the new model, so the new model has as many trainable parameters as the original. This makes fine-tuning challenging in terms of computational cost and deployment, especially for large pre-trained models (e.g. GPT-3 has 175 billion trainable parameters).

Prior to LoRA, existing techniques to mitigate this problem included

  1. Limiting the number of parameters that can be updated
  2. Incorporating external modules tailored to new tasks

However, these solutions often introduce inference latency (the time taken for the model to generate output once it is given its input) because they

  1. extend the depth of the model (more layers = more operations)
  2. reduce the model’s usable sequence length (the input must be processed in smaller segments)

Moreover, these methods often fail to match the quality of full fine-tuning, trading model quality for efficiency.

Hypothesis of LoRA

The authors hypothesise that the change in weights during fine-tuning has a low “intrinsic rank”, i.e. the weight changes during adaptation might reside in a lower-dimensional subspace compared to the full parameter space, allowing for efficient low-rank approximations.
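
To make “low-rank approximation” concrete (using the notation the paper adopts later): instead of learning a full update $\Delta W \in \mathbb{R}^{d \times k}$, one can learn two thin matrices whose product approximates it, shrinking the number of trainable parameters from $dk$ to $r(d + k)$:

$$ \Delta W \approx BA, \qquad B \in \mathbb{R}^{d \times r},\quad A \in \mathbb{R}^{r \times k},\quad r \ll \min(d, k) $$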

The rank of a matrix refers to the number of linearly independent rows/columns in that matrix. Let’s use an example: pretend you have a 4x5 matrix $\Delta W$ that represents the weight updates from a single instance of backpropagation

$$ \Delta W = \begin{bmatrix} 1 & 2 & 3 & 4 & 5 \\ 2 & 4 & 6 & 8 & 10 \\ 2 & 5 & 7 & 4 & 8 \\ 4 & 10 & 14 & 8 & 16 \\ \end{bmatrix} $$

We can reduce $\Delta W$ to its reduced row echelon form $B$:

$$ \Delta W = \begin{bmatrix} 1 & 2 & 3 & 4 & 5 \\ 2 & 4 & 6 & 8 & 10 \\ 2 & 5 & 7 & 4 & 8 \\ 4 & 10 & 14 & 8 & 16 \\ \end{bmatrix} \sim \begin{bmatrix} 1 & 0 & 1 & 12 & 9 \\ 0 & 1 & 1 & -4 & -2 \\ 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 \\ \end{bmatrix} = B $$
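
As a quick sanity check (the NumPy snippet is mine, not the paper’s), the code below confirms that the $\Delta W$ above has rank 2 and rebuilds it exactly from a 4x2 and a 2x5 factor, i.e. 18 stored numbers instead of 20:

```python
import numpy as np

# The example weight-update matrix from above
delta_W = np.array([
    [1,  2,  3, 4,  5],
    [2,  4,  6, 8, 10],
    [2,  5,  7, 4,  8],
    [4, 10, 14, 8, 16],
], dtype=float)

# Rank: number of linearly independent rows/columns
print(np.linalg.matrix_rank(delta_W))  # 2

# A rank-2 factorisation delta_W = B @ A via truncated SVD:
# B is 4x2 and A is 2x5, so we store 8 + 10 = 18 numbers
# instead of the original 4 x 5 = 20.
U, s, Vt = np.linalg.svd(delta_W)
r = 2
B = U[:, :r] * s[:r]   # shape (4, r)
A = Vt[:r, :]          # shape (r, 5)
print(np.allclose(B @ A, delta_W))  # True
```

The saving is trivial at this size, but it scales: a 4096x4096 update stored at rank $r = 8$ needs about 65k numbers instead of roughly 16.8 million.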