Fasten your seat belt, and let’s go over the main techniques to speed up the training of your LLM. This blog post is a follow-up to a previous post on the basic concepts behind LLM training. Check it out before continuing.
Precision?
One way to significantly speed up compute is to reduce the precision of the numbers we are dealing with. What do we mean by precision? Let’s do a quick recap on number representation in computer memory.
Computers use bits (binary digits) to represent data as zeros and ones. Positive integers can be represented as sums of powers of 2. For example, with 16 bits we can represent positive integers in [0, 65535] since 2^0 + 2^1 + … + 2^15 = 2^16 - 1 = 65535. Storing negative integers requires dedicating one bit to the sign, so with the same 16 bits we can represent integers in [-32768, 32767].
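As a quick sanity check, here is the same arithmetic in plain Python:

bits = 16
print(2**bits - 1)                        # 65535: largest unsigned 16-bit integer
print(-2**(bits - 1), 2**(bits - 1) - 1)  # -32768 32767: signed range, one bit spent on the sign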
But what about real numbers? Computers use floating-point numbers, a subset of the real numbers that are represented using an integer scaled by an exponent. Under the IEEE 754 floating-point standard, a floating-point number is represented using a sign bit, an exponent width, and a significand precision. For example, single-precision fp32 uses an exponent width of 8 bits and a significand precision of 24 bits (23 explicitly stored, for a total of 1 + 8 + 23 = 32 bits), and can thus represent numbers in roughly a +/- 3.4 × 10^38 range. Half-precision fp16 uses an exponent width of 5 bits and a significand precision of 11 bits (10 explicitly stored), and can thus represent numbers in a +/- 65504 range. Half-precision has a much lower range than single-precision, and in deep learning applications smaller ranges can produce overflow (trying to represent a number that is very large) or underflow (trying to represent a number that is very small) errors.
More recent GPU architectures, such as NVIDIA Ampere, support additional floating-point formats like bf16. bf16 still uses 16 bits but with an exponent width of 8 bits. Compared to fp16, it thus has a much bigger dynamic range, roughly the same as fp32, at the cost of lower precision. In practice, bf16 prevents overflow and underflow from happening most of the time.
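You can inspect these ranges directly in PyTorch, and see a value where fp16 overflows while bf16 does not (a small illustrative sketch):

import torch

# Compare the dynamic ranges of the three formats
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(dtype, "max:", info.max, "smallest normal:", info.tiny)

x = torch.tensor(70000.0)
print(x.to(torch.float16))   # inf: 70000 overflows the fp16 range (max 65504)
print(x.to(torch.bfloat16))  # a finite value near 70000: representable, but with coarser precision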
Reducing precision, when possible, lowers memory requirements (allowing larger models to fit), lowers memory bandwidth requirements (speeding up data transfer operations), and speeds up math operations (e.g. by using NVIDIA Tensor Cores).
For more practical details: bitsandbytes.
Mixed precision
Let’s do a quick run-through of the memory requirements for training a model in full-precision. You need to store the model weights (4 bytes per parameter), the associated gradients (4 bytes per parameter), and the optimizer states (2*4=8 bytes per parameter for Adam because you need to store the first and second moments). You also need to store the forward activations for gradient computation during backpropagation, and the required memory typically depends on the sequence length, the hidden size, and the batch size. There can be additional sources of memory consumption, e.g. beam search at inference time.
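Here is a back-of-the-envelope estimate of the model-state memory alone (weights, gradients, Adam moments) for a hypothetical 7B-parameter model; activations come on top of this:

n_params = 7e9  # hypothetical 7B-parameter model

weights_bytes   = 4 * n_params  # fp32 weights
gradients_bytes = 4 * n_params  # fp32 gradients
optimizer_bytes = 8 * n_params  # Adam first and second moments
total_gb = (weights_bytes + gradients_bytes + optimizer_bytes) / 1e9
print(f"~{total_gb:.0f} GB before activations")  # ~112 GB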
Mixed precision combines different numerical precisions, by identifying the operations where single-precision is not required and using lower-precision arithmetic for these operations. There are three main types of operations in deep learning: element-wise, reduction, and dot-product. The first two tend to be memory limited while the last one is typically compute limited if the corresponding matrices are large enough. By reducing the amount of memory used, mixed precision reduces the time spent in memory limited layers.
When working with half-precision, it is necessary to scale the loss (to preserve gradient values of small magnitude), and to keep full-precision master weights that accumulate the gradients. Thus, storing the model weights requires 6 bytes per parameter. Wait, haven’t we increased our memory consumption? In practice, training memory is usually dominated by the forward activations (large batch size), and since activations are stored in half-precision, the overall memory consumption is roughly halved.
In PyTorch, with the automatic mixed precision utilities:

from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
for input, target in data_loader:
    optimizer.zero_grad()
    with autocast():  # forward pass runs in mixed precision where safe
        output = model(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()  # scale the loss to preserve small gradients
    scaler.step(optimizer)         # unscale the gradients, then update the weights
    scaler.update()                # adjust the scale factor for the next iteration
For more technical background: Mixed Precision Training.
For more practical details: NVIDIA’s Train With Mixed Precision, PyTorch’s Automatic Mixed Precision (AMP).
Gradient Checkpointing
I previously mentioned that forward activations need to be stored for gradient computation during backpropagation. But one might ask: why not recompute them during backpropagation instead of storing them? This is the motivation behind gradient checkpointing. Rather than storing the inputs of every layer for computing upstream gradients, gradient checkpointing stores the inputs of a few strategically chosen layers and recomputes the rest. This saves memory at the cost of slower training (usually around 20% slower), so gradient checkpointing should only be used once you have saturated your GPU memory.
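As an illustration, here is a minimal sketch using PyTorch’s torch.utils.checkpoint on a toy stack of linear layers (the layer sizes and segment count are arbitrary choices for the example):

import torch
from torch.utils.checkpoint import checkpoint_sequential

# Toy feed-forward stack, only used to illustrate the API
model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])
x = torch.randn(32, 1024, requires_grad=True)

# Split the model into 2 segments: only the segment boundaries keep their
# activations, everything else is recomputed during the backward pass
out = checkpoint_sequential(model, 2, x)
out.sum().backward()

If you are using Hugging Face Transformers, most models expose a gradient_checkpointing_enable() method that turns this on for you.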
For more technical background: Fitting larger networks into memory.
Parameter Efficient Fine-Tuning
Fine-tuning refers to a subsequent training of a pre-trained model. But do we need to re-train, i.e. update, the entire model? The core idea behind parameter efficient fine-tuning is to train only a small subset of the parameters, keeping the rest unchanged, i.e. frozen. The compute cost is thus reduced, and since the updates are more constrained it can also help prevent catastrophic forgetting (the model forgets what it learnt during pre-training).
For example, if you are doing sentiment analysis, you could use a pre-trained BERT model and train just a classification head (linear layer with input size hidden_size and output size num_labels) on top of it.
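A minimal sketch of this setup, assuming a binary sentiment task and bert-base-uncased as the backbone:

import torch
from transformers import AutoModel

backbone = AutoModel.from_pretrained("bert-base-uncased")
for param in backbone.parameters():
    param.requires_grad = False  # freeze the pre-trained encoder

num_labels = 2  # assumption: binary sentiment (positive / negative)
head = torch.nn.Linear(backbone.config.hidden_size, num_labels)
optimizer = torch.optim.AdamW(head.parameters(), lr=1e-3)  # only the head is trained

During training, you would feed the pooled (or [CLS] token) representation from the frozen backbone into the head.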
A more sophisticated approach you may have heard of is Low Rank Adaptation (LoRA). LoRA decomposes weight updates into a product of 2 smaller matrices: W + ∆W = W + BA with B of dimension (d, r), A of dimension (r, k), and r << d, k (the smaller r, the higher the risk of underfitting). B is initialized to zero and A with random Gaussian values, so that ∆W = BA is zero at the start of training. This type of decomposition, called low-rank approximation, can be applied to any matrix, with more or less error. For large language models, it is typically applied to the projection matrices W_query, W_key and W_value in the attention blocks of the Transformer. LoRA is reported to have performance comparable to full fine-tuning, i.e. training all parameters, but it drastically reduces the number of trainable parameters (usually by 10,000x) and GPU memory usage (usually by 3x). There is no additional inference latency after merging, i.e. adding, the adapter weights into the original ones. Since the original model weights are kept frozen, multiple LoRA adapters can be trained and used for different tasks. Note that since most of the memory footprint during training comes from storing the forward activations for backpropagation, aggressively reducing the number of LoRA parameters (e.g. a very small rank) only yields minor benefits after some point.
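To make the decomposition concrete, here is a minimal, self-contained sketch of a LoRA-augmented linear layer (not the reference implementation; in practice you would use a library such as Hugging Face’s peft):

import torch

class LoRALinear(torch.nn.Module):
    def __init__(self, linear: torch.nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.linear = linear
        for p in self.linear.parameters():
            p.requires_grad = False  # freeze the pre-trained weight W (and bias)
        d, k = linear.out_features, linear.in_features
        self.A = torch.nn.Parameter(torch.randn(r, k) * 0.01)  # random Gaussian init
        self.B = torch.nn.Parameter(torch.zeros(d, r))         # zero init, so BA = 0 at first
        self.scaling = alpha / r

    def forward(self, x):
        # W x plus the low-rank update (B A) x; only A and B receive gradients
        return self.linear(x) + (x @ self.A.T @ self.B.T) * self.scaling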
For more technical background: low rank approximation, LoRA: Low-Rank Adaptation of Large Language Models, Prefix-Tuning: Optimizing Continuous Prompts for Generation (additional learnable weights for keys and values at every attention layer), The Power of Scale for Parameter-Efficient Prompt Tuning (prepends k learnable token embeddings to the input), Parameter-Efficient Transfer Learning for NLP (adds a learned bottleneck layer between each frozen layer), QLoRA: Efficient Finetuning of Quantized LLMs (4-bit normal float data type, double quantization and paged optimizers), S-LoRA: Serving Thousands of Concurrent LoRA Adapters (serving up to thousands of concurrent LoRA adapters).
I hope you enjoyed this blog post! For practical tips specific to PyTorch, check out their excellent Performance Tuning Guide.
I have left out parallelism techniques and distributed training for a later blog post. Subscribe for free, and stay tuned!