In this blog post, I’ll explain how large language models are typically trained. It is a follow-up to a previous blog post describing the fundamental concepts behind large language models; check it out before reading this one.
Simply put, a large language model is a machine learning model that estimates the probability of words within sequences of words. Training a large language model means finding optimal model parameters to accomplish a given task, for example generating plausible language. The model predictions are quantified as “good” or “bad” according to a chosen training loss function, usually the cross-entropy loss. An optimization algorithm is applied to find a set of parameters that minimizes this training loss. Various optimization algorithms can be used; the most commonly used in deep-learning settings are Stochastic Gradient Descent (SGD) and Adam.
Did you say entropy?
The notion of entropy was introduced by Shannon in his 1948 paper "A Mathematical Theory of Communication". Shannon considered different approaches to communicating messages, and proved that entropy represents the minimum number of bits required to losslessly encode and transmit a message. He defined entropy as:

H(p) = - sum_i p_i log2(p_i)

where the sum runs over all possible outcomes i, each with probability p_i.
Entropy can be viewed as the expected information over all outcomes, where the information -log2(p) is the number of bits required to represent a possible outcome of probability p. Why -log2(p)? For a rough intuition, say you have n possible outcomes, each with probability p_i. For the least likely outcome j, 1 = p_1 + … + p_n >= n p_j, so p_j <= 1/n. Its information is therefore at least the number of bits required to distinguish n equally likely outcomes: log2(n) <= -log2(p_j).
Let’s consider a practical example for natural language: say I choose an English letter at random, and you have to guess it by asking yes-or-no questions (each answer corresponding to one binary digit). Shannon’s source coding theorem tells us that you’ll need at least log2(26) = 4.70 bits on average, i.e. 5 whole yes-or-no questions if you guess one letter at a time, to losslessly encode an English letter chosen uniformly at random. What if, instead of guessing letters chosen at random, you’re trying to guess letters in actual English words? This should be easier, since you can use letter frequencies, language patterns, etc., to help you. Indeed, Shannon estimated the entropy of the English language to be around 2.62 bits per letter.
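To make this concrete, here is a small Python sketch computing the entropy of a uniform distribution over 26 letters and of a made-up skewed distribution (the skewed probabilities are purely illustrative, not real English letter frequencies):

```python
import math

def entropy(probs):
    """Expected information in bits: -sum(p * log2(p)) over outcomes with p > 0."""
    return -sum(p * math.log2(p) for p in probs if p > 0)

# Uniform distribution over the 26 English letters.
uniform = [1 / 26] * 26
print(entropy(uniform))  # ~4.70 bits per letter

# A made-up skewed distribution: a few letters are much more likely than the rest.
skewed = [0.2, 0.15, 0.1, 0.1, 0.05] + [0.4 / 21] * 21
print(sum(skewed))       # sanity check: probabilities sum to 1
print(entropy(skewed))   # fewer bits: skewed distributions are easier to guess
```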
Note: Actually achieving the Shannon limit is non-trivial; compressors are often tailored to a particular setting, say PNG for images.
For more advanced topics: Huffman coding.
… and cross entropy?
The cross-entropy loss between two probability distributions p and q is defined as:

CE(p, q) = - sum_x p(x) log2(q(x))

Applying the ideas from the previous section, the cross entropy CE(p, q) can be viewed as the expected information of a sample x ~ p (or the expected number of bits to encode this sample) using the encoding given by q. It decomposes into the true entropy H(p) plus the Kullback-Leibler divergence KL(p, q) = sum_x p(x) log2(p(x) / q(x)) between p and q:

CE(p, q) = H(p) + KL(p, q)

Since KL(p, q) is always non-negative, the cross entropy CE(p, q) is bounded below by the true entropy H(p). The cross entropy is thus an upper estimate of the true entropy, and better estimates can be obtained by constructing better approximations q.
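As a quick sanity check, here is a small Python sketch with two made-up distributions p and q that verifies the decomposition CE(p, q) = H(p) + KL(p, q) numerically:

```python
import math

def entropy(p):
    return -sum(pi * math.log2(pi) for pi in p if pi > 0)

def cross_entropy(p, q):
    return -sum(pi * math.log2(qi) for pi, qi in zip(p, q) if pi > 0)

def kl_divergence(p, q):
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.5, 0.25, 0.25]   # "true" distribution
q = [0.4, 0.4, 0.2]     # our model's approximation

print(cross_entropy(p, q))               # upper estimate of H(p)
print(entropy(p) + kl_divergence(p, q))  # same value: CE = H + KL
print(entropy(p))                        # the lower bound, reached when q == p
```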
The cross-entropy loss is a widely used training loss for large language models. It is used during pre-training, during supervised fine-tuning on curated input-output pairs (possibly augmented with instructions to help the model generalize better, in which case it is called instruction fine-tuning), and even in reinforcement learning from human feedback (RLHF) when using Direct Preference Optimization (DPO).
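Concretely, for next-token prediction the logits at position t are scored against the token at position t + 1. A minimal PyTorch sketch, with random tensors standing in for a real model's logits and made-up sizes:

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 8, 100             # made-up sizes
token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
logits = torch.randn(batch_size, seq_len, vocab_size)   # stand-in for model(token_ids)

# Next-token prediction: logits at position t are compared to the token at t + 1.
shift_logits = logits[:, :-1, :]
shift_labels = token_ids[:, 1:]

# F.cross_entropy expects (N, C) logits and (N,) class indices;
# it applies log-softmax internally, so raw logits are passed in.
loss = F.cross_entropy(
    shift_logits.reshape(-1, vocab_size),
    shift_labels.reshape(-1),
)
print(loss)  # roughly log(vocab_size) ~ 4.6 nats for random logits
```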
For more advanced topics: Finetuned language models are zero-shot learners, Direct Preference Optimization: Your Language Model is Secretly a Reward Model.
Finding optimal parameters
You have probably heard of gradient descent. The gradient of the loss points in the direction of steepest ascent, so moving in the opposite direction decreases the loss fastest locally. Gradient descent iteratively computes the partial derivatives of the loss function with respect to each model parameter, and adjusts these parameters in the opposite direction of the gradients. Gradient descent is a first-order optimization algorithm, since only first derivatives are used.
Batch gradient descent computes the loss and gradients over the entire training set before updating the parameters. This is very slow, and intractable for datasets that do not fit in memory. Stochastic Gradient Descent (SGD) computes the loss and gradients for a single training example (or a few training examples, in the case of mini-batch SGD). SGD typically converges much faster than batch gradient descent, and it can also be used in an online setting to learn from new data. Intuitively, larger batch sizes yield better gradient estimates, whereas smaller batch sizes produce more frequent updates.
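To make the loop concrete, here is a minimal NumPy sketch of mini-batch SGD on a toy linear regression problem (the data and hyperparameters are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: y = 3x + 1 plus noise.
X = rng.normal(size=(1000, 1))
y = 3 * X[:, 0] + 1 + 0.1 * rng.normal(size=1000)

w, b = 0.0, 0.0
lr, batch_size = 0.1, 32

for epoch in range(20):
    indices = rng.permutation(len(X))
    for start in range(0, len(X), batch_size):
        batch = indices[start:start + batch_size]
        x_b, y_b = X[batch, 0], y[batch]
        # Mean squared error loss: mean((w * x + b - y)^2)
        error = w * x_b + b - y_b
        grad_w = 2 * np.mean(error * x_b)   # partial derivative w.r.t. w
        grad_b = 2 * np.mean(error)         # partial derivative w.r.t. b
        # Update the parameters in the opposite direction of the gradients.
        w -= lr * grad_w
        b -= lr * grad_b

print(w, b)  # should be close to 3 and 1
```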
Adam, i.e. adaptive moment estimation, is a more recent optimizer that uses estimates of the first and second moments of the gradients to update the model parameters:

theta_t = theta_{t-1} - lr * m_hat_t / (sqrt(v_hat_t) + eps)

where

m_hat_t = m_t / (1 - beta1^t) and v_hat_t = v_t / (1 - beta2^t)

are debiased first and second moment estimates adapted from:

m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2

g_t being the gradient computed at time step t. Note that, since m_t and v_t are initialized to zero, they are biased towards zero. One can prove that:

E[m_t] ≈ E[g_t] * (1 - beta1^t) and E[v_t] ≈ E[g_t^2] * (1 - beta2^t)

which is why dividing by (1 - beta1^t) and (1 - beta2^t) corrects the bias.
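The update rule above translates almost line by line into code. Here is a minimal NumPy sketch of Adam applied to a toy quadratic loss (the adam_minimize helper and the toy problem are illustrative, not a production implementation):

```python
import numpy as np

def adam_minimize(grad_fn, theta, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8, steps=2000):
    """Minimize a function given its gradient, following the Adam update rule."""
    m = np.zeros_like(theta)  # first moment estimate, initialized to zero
    v = np.zeros_like(theta)  # second moment estimate, initialized to zero
    for t in range(1, steps + 1):
        g = grad_fn(theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1 - beta1**t)          # debias the first moment
        v_hat = v / (1 - beta2**t)          # debias the second moment
        theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta

# Toy quadratic loss: f(theta) = sum((theta - target)^2), gradient 2 * (theta - target).
target = np.array([3.0, -1.0])
theta = adam_minimize(lambda th: 2 * (th - target), theta=np.zeros(2))
print(theta)  # should be close to [3, -1]
```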
AdamW is a variant of Adam that decouples weight decay from the gradient update. Weight decay is closely related to L2 regularization, a classic regularization method: since large weights produce instability (small changes in inputs can lead to large changes in outputs), L2 regularization penalizes the loss function by adding the sum of squares of the model weights, multiplied by a hyperparameter alpha. In practice, it can be implemented by adding alpha * weight to the gradients rather than actually changing the loss function; the effect is that a small proportion of each weight is subtracted at every iteration, hence the name weight decay. Note that, for vanilla SGD, L2 regularization and weight decay are identical, but that is no longer the case for more sophisticated optimizers like Adam, which is why AdamW applies the decay directly to the weights instead of folding it into the gradients.
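To illustrate the difference on a single update (toy numbers, plain SGD rather than the full AdamW algorithm):

```python
import numpy as np

lr, alpha = 0.1, 0.01
w = np.array([1.0, -2.0, 3.0])
grad = np.array([0.5, 0.5, 0.5])   # pretend gradient of the unregularized loss

# L2 regularization: fold alpha * w into the gradient before the update.
w_l2 = w - lr * (grad + alpha * w)

# Decoupled weight decay: update with the raw gradient, then shrink the weights.
w_decay = (w - lr * grad) - lr * alpha * w

print(w_l2, w_decay)  # identical here, since this is vanilla SGD
```

The two coincide for vanilla SGD, but once the gradient is rescaled by Adam's moment estimates they no longer do, which is why AdamW (e.g. torch.optim.AdamW) applies the decay as a separate, decoupled step.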
For practical details: torch.optim.
For more advanced topics: Tikhonov regularization. Newton-Raphson method.
How to deal with conversations?
Given sequences of text, you can train an autocomplete large language model to predict the next word given the preceding words. But what about conversations? Conversational assistants are trained on dialog data, and this requires special data formatting. ChatML is one such format, designed to make multi-turn conversations easier to manage: each turn is wrapped with special tokens. These special tokens help the model differentiate, for example, between user and assistant turns. They can also be used to mask the loss on the labels associated with user turns, if we are primarily interested in learning assistant behavior or have realized that the user questions in the training data are very noisy (see the sketch after the template below).
As an example, Llama-2 uses the following syntax:
<s>[INST] <<SYS>>
{{ system_prompt }}
<</SYS>>
{{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
where the system prompt is used during training to condition the character of the assistant.
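To mask the loss on the user (and system) turns, a common approach is to copy the input ids into the labels and overwrite the masked positions with the ignore index, which is -100 by default in PyTorch's cross-entropy. A minimal sketch with made-up tensors; a real pipeline would derive the mask from the special tokens of the template above:

```python
import torch
import torch.nn.functional as F

vocab_size, seq_len = 100, 10                      # made-up sizes
input_ids = torch.randint(0, vocab_size, (1, seq_len))
logits = torch.randn(1, seq_len, vocab_size)       # stand-in for model(input_ids)

# Pretend the first 6 tokens belong to the system/user turns, the rest to the assistant.
labels = input_ids.clone()
labels[:, :6] = -100                               # ignored by the loss

loss = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, vocab_size),
    labels[:, 1:].reshape(-1),
    ignore_index=-100,                             # only assistant tokens contribute
)
print(loss)
```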
You should always use the special tokens and formatting that were used during the training of the model you plan to use. Also note that adding new special tokens changes the size of the embedding matrix; to make use of Tensor Cores for fast matrix multiplication, its dimensions should be multiples of 8, so you may need to resize it with pad_to_multiple_of.
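With the Hugging Face transformers library, this typically looks like the sketch below (the model name and added tokens are placeholders, and the pad_to_multiple_of argument of resize_token_embeddings assumes a reasonably recent version of the library):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "my-base-model"  # placeholder: use the checkpoint you are fine-tuning
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Register the chat-template special tokens (example ChatML tokens; adapt to your template).
tokenizer.add_special_tokens({"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]})

# Grow the embedding matrix to fit the new tokens, padded to a multiple of 8
# so the dimensions stay friendly to Tensor Cores.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
```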
I hope you enjoyed this blog post! In a later blog post, I’ll go over different techniques for optimizing your fine-tuning (mixed precision, gradient checkpointing, peft, etc.). Stay tuned!
Fun follow-up exercise: try re-implementing your choice of optimizer in PyTorch or NumPy. For the most motivated, take a stab at the Annotated Transformer.