GPT From Scratch
After being motivated to learn about transformers in more depth, I spent a good chunk of time implementing the GPT model from scratch, following a tutorial by Andrej Karpathy. I typed every line of code myself to understand the algorithm, and I was actually able to reproduce the same performance as in the tutorial.

However, my goal was not just that. I wanted to find out the relationship between the hyperparameters and the complexity of the transformer, measured by (1) the number of trainable parameters and (2) the time complexity of generating one token. Therefore, in the subsequent sections, we will pay close attention to two things – the shape of the data and the shapes of the trainable matrices.

This post is long and is best used as a reference when reasoning about transformers. If you just want the new insights, skip to the last section – I have some seemingly mind-blowing answers to the questions above!

Huge thanks to Andrej Karpathy for his tutorial on building a decoder-only transformer from scratch.
Preprocessing
First things first: we are given a corpus, which is plain text. Therefore, we need two functions to convert between text and token IDs:

- `encode(str) -> list[int]`
- `decode(list[int]) -> str or list[str]`
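A minimal sketch of these two functions, assuming a character-level tokenizer like the one in the tutorial (the corpus path and variable names here are placeholders of my own choosing):

```python
# Character-level tokenizer: one integer ID per unique character in the corpus.
with open("input.txt", "r", encoding="utf-8") as f:   # path is a placeholder
    text = f.read()

chars = sorted(set(text))                 # the vocabulary
VOCAB_SIZE = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}   # char -> ID
itos = {i: ch for i, ch in enumerate(chars)}   # ID -> char

def encode(s: str) -> list[int]:
    return [stoi[c] for c in s]

def decode(ids: list[int]) -> str:
    return "".join(itos[i] for i in ids)
```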
Now we only work with token IDs. The input shape is:

- For training, each training step needs `xb` and `yb`, both with shape `(BATCH_SIZE, BLOCK_SIZE, VOCAB_SIZE)`. `yb` is the result of shifting the window one step into the future from `xb`. (There will be a way to prevent data leaking from the future in the self-attention mechanism.)
- For inference (or text generation in particular):
  - Start with some text – it can be a whole prompt or just a single character (like `\n`).
  - Encode it into a batch with basically the same general shape as above, except that the batch size is now 1 and the block size depends on the given prompt.

Given this similarity in input structure, we can consider one unifying input shape moving forward, which is `BxTxC`.
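A sketch of the batch preparation for training, using the tokenizer above and the `BATCH_SIZE`/`BLOCK_SIZE` values introduced in the next section (the helper name `get_batch` follows the tutorial; note that in code the batches are `(B, T)` tensors of integer IDs – the `(B, T, C)` view simply treats each ID as a one-hot vector over the vocabulary):

```python
import torch

BATCH_SIZE = 64    # B
BLOCK_SIZE = 256   # T

data = torch.tensor(encode(text), dtype=torch.long)   # the whole encoded corpus

def get_batch(data: torch.Tensor):
    # Pick BATCH_SIZE random starting positions, each with room for a full window.
    ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
    xb = torch.stack([data[i : i + BLOCK_SIZE] for i in ix])           # (B, T)
    yb = torch.stack([data[i + 1 : i + BLOCK_SIZE + 1] for i in ix])   # (B, T), shifted by one
    return xb, yb

xb, yb = get_batch(data)   # xb.shape == yb.shape == (64, 256)
```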
Model architecture (decoder-only)
Hyperparams
Important hyperparameters, with values from the tutorial – probably matching the 2017 Attention paper:

- `BATCH_SIZE` (\(B\)) = 64
- `BLOCK_SIZE` (\(T\)) = 256; context window length
- `VOCAB_SIZE` (\(C\)) = depends on the corpus
- `EMBED_SIZE` (\(E\)) = 384
- `N_BLOCKS` (\(L\)) = 6; number of decoder blocks
- `N_HEADS` = 6; number of attention heads inside a multi-head attention component (which is itself inside a decoder block)
- `DROPOUT` = 0.2; dropout probability
With this setting, a 5000-step training loop takes about 15-20 minutes on a GPU.
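For reference, these are the constants the code sketches in the rest of this post will assume (the tutorial uses lowercase names like `n_embd` and `n_layer` for the same quantities):

```python
BATCH_SIZE = 64    # B: sequences processed in parallel
BLOCK_SIZE = 256   # T: context window length
EMBED_SIZE = 384   # E: embedding dimension
N_BLOCKS   = 6     # L: number of decoder blocks
N_HEADS    = 6     # attention heads per multi-head attention component
DROPOUT    = 0.2   # dropout probability
# VOCAB_SIZE (C) comes from the corpus, e.g. VOCAB_SIZE = len(chars) above.
```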
Computations
For training and ‘primitive’ inference, i.e., `forward()` (a code sketch of the whole pass follows this list):

- Start with input of size `BxTxC`.
- Embedding: the sum of a token embedding and a positional embedding.
  - Token embedding: simply a trainable matrix of size `CxE`. It acts as a lookup table for each token. This is actually a static embedding, like word2vec. Implementation-wise, the tokens can be represented as a matrix of one-hot vectors, then multiplied with this token embedding matrix.
  - Positional embedding: a fixed matrix of size `TxE`. It maps each position slot in the input sequence to an (ideally unique) vector that contains the positional info.
  - Output shape is `BxTxE` – the data is now officially in the embedding space!
- Decoder block x `N_BLOCKS`. Each decoder block is laid out as follows:
  - A LayerNorm, with two trainable parameter vectors (\(\gamma\) and \(\beta\), each of size \(E\)).
  - Multi-head attention: attention head x `N_HEADS`.
    - Each attention head has:
      - `head_size = EMBED_SIZE // N_HEADS` (\(H\)) (related to the upcoming concatenation of outputs).
      - Three trainable weight matrices – key, query, and value – each of size `ExH`.
      - Each of these matrices transforms the data into \(K\), \(Q\), and \(V\), each of size `BxTxH`.
      - Attention formula: \(\text{Attention}(Q, K, V) = \text{softmax}\left(\text{mask}\left(\frac{Q K^T}{\sqrt{H}}\right)\right) V\)
        - \(\text{mask}\) uses `tril`, a \(t \times t\) (\(t \le T\)) lower-triangular matrix, to set the entries above the diagonal to \(-\infty\) before the softmax, so that each token only receives (a weighted average of) information from the previous tokens.
        - The softmax-ed data is actually dropped out before being multiplied with \(V\).
      - The data merged from these three still has shape `BxTxH`.
    - This is where the tokens talk to each other, hence ‘self-attention’.
    - The outputs of all attention heads are concatenated along the 2nd (i.e., last) dimension. Because `H x N_HEADS = E`, the data shape is now back to the good old `BxTxE`!
    - Then it is multiplied with a trainable `ExE` matrix called the projection layer, which prepares the data for the skip connection later.
    - There is a dropout layer here (parametrized by `DROPOUT`).
    - There is also a skip connection, where the input is added directly to the output. This is poetically called the residual pathway.
  - FFNN
    - A LayerNorm with two trainable parameter vectors.
    - A multi-layer perceptron with two trainable matrices – one of size `Ex(4E)` and its transpose – with an activation in the middle and dropout at the end.
    - Also has a skip connection.
    - This is where the tokens ‘think for themselves’.
- A language modeling head (code name: `lm_head`)
  - A LayerNorm with two trainable parameter vectors.
  - Then a trainable `ExC` matrix.
  - The data is now mapped back to the ‘language space’ `BxTxC`!
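To make these shapes concrete, here is a minimal PyTorch sketch of the forward pass, written to follow the layout above and using the hyperparameter constants defined earlier. It is a simplification of the tutorial's code: weight initialization and the loss computation are omitted, and the positional embedding is kept as a fixed (non-trainable) buffer as described in the text, whereas the tutorial learns it with an `nn.Embedding(BLOCK_SIZE, EMBED_SIZE)`.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    """One self-attention head: three ExH matrices producing K, Q, and V."""
    def __init__(self, head_size):
        super().__init__()
        self.key   = nn.Linear(EMBED_SIZE, head_size, bias=False)
        self.query = nn.Linear(EMBED_SIZE, head_size, bias=False)
        self.value = nn.Linear(EMBED_SIZE, head_size, bias=False)
        # Fixed lower-triangular mask -- a buffer, not a trainable parameter.
        self.register_buffer("tril", torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)))
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):                                             # x: (B, T, E)
        B, T, E = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)           # each (B, T, H)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5           # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # no peeking ahead
        wei = self.dropout(F.softmax(wei, dim=-1))
        return wei @ v                                                # (B, T, H)

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([Head(EMBED_SIZE // N_HEADS) for _ in range(N_HEADS)])
        self.proj = nn.Linear(EMBED_SIZE, EMBED_SIZE)        # the ExE projection layer
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # concat back to (B, T, E)
        return self.dropout(self.proj(out))

class FeedForward(nn.Module):
    """The FFNN: Ex(4E), activation, (4E)xE, dropout."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(EMBED_SIZE, 4 * EMBED_SIZE),
            nn.ReLU(),
            nn.Linear(4 * EMBED_SIZE, EMBED_SIZE),
            nn.Dropout(DROPOUT),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """One decoder block: LayerNorm + multi-head attention, then LayerNorm + FFNN."""
    def __init__(self):
        super().__init__()
        self.ln1, self.sa = nn.LayerNorm(EMBED_SIZE), MultiHeadAttention()
        self.ln2, self.ffwd = nn.LayerNorm(EMBED_SIZE), FeedForward()

    def forward(self, x):
        x = x + self.sa(self.ln1(x))      # residual pathway around attention
        x = x + self.ffwd(self.ln2(x))    # residual pathway around the FFNN
        return x

class GPT(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, EMBED_SIZE)   # CxE, trainable
        # Positional embedding, treated here as a fixed TxE buffer per the text above
        # (the tutorial uses a trainable nn.Embedding(BLOCK_SIZE, EMBED_SIZE) instead).
        self.register_buffer("position_embedding", torch.randn(BLOCK_SIZE, EMBED_SIZE))
        self.blocks = nn.Sequential(*[Block() for _ in range(N_BLOCKS)])
        self.ln_f = nn.LayerNorm(EMBED_SIZE)
        self.lm_head = nn.Linear(EMBED_SIZE, vocab_size)              # ExC

    def forward(self, idx):                                           # idx: (B, T) token IDs
        B, T = idx.shape
        x = self.token_embedding(idx) + self.position_embedding[:T]   # (B, T, E)
        x = self.blocks(x)                                            # (B, T, E)
        return self.lm_head(self.ln_f(x))                             # logits: (B, T, C)
```

Note that `nn.Embedding` applied to the `(B, T)` tensor of token IDs is exactly the one-hot-times-`CxE`-matrix lookup described above, just implemented without materializing the one-hot vectors.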
Training loop
- Loss: cross-entropy loss between the logits (`BxTxC`) and the actual next tokens (`BxT`).
- The `AdamW` optimizer should be used. It seems to work best here.
- Evaluate once in a while by doing with-label inference on both the train and val sets, then taking the average loss. Remember to set `torch.no_grad()` and `model.eval()` before, and `model.train()` after.
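A condensed sketch of this loop, assuming the `GPT` model, `get_batch`, and `data` tensor from the earlier sketches (the 90/10 split, the learning rate, and evaluating on a single batch per split are simplifications of my own here):

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
n = int(0.9 * len(data))                       # simple train/val split
train_data, val_data = data[:n], data[n:]

model = GPT(VOCAB_SIZE).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

@torch.no_grad()                               # no gradients needed for evaluation
def estimate_loss():
    model.eval()                               # switch off dropout
    out = {}
    for split, d in [("train", train_data), ("val", val_data)]:
        xb, yb = get_batch(d)
        logits = model(xb.to(device))          # (B, T, C)
        out[split] = F.cross_entropy(logits.view(-1, VOCAB_SIZE),
                                     yb.to(device).view(-1)).item()
    model.train()                              # back to training mode
    return out

for step in range(5000):
    if step % 500 == 0:
        print(step, estimate_loss())
    xb, yb = get_batch(train_data)
    logits = model(xb.to(device))                                        # (B, T, C)
    loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), yb.to(device).view(-1))
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
```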
Inference
(This is how ChatGPT actually spits out text!!!) So we left off at `forward()`, which returns logits of shape `BxTxC`. From here, we do auto-regressive generation by repeatedly:

- first, obtaining a probability distribution from the last time step of the `logits` by applying softmax,
- then, sampling a new token for each sequence (so `B` new tokens for the whole batch). (This is how and why ChatGPT produces slightly different outputs given the same input!)
- finally, from this new token, preparing the next batch of input IDs to feed to `forward()` (by concatenation) and getting the new `logits`.

Do this until some stop condition, like `max_new_tokens`, is met. (Beam search may also be used?)
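A sketch of this loop, assuming the `model`, `device`, `encode`/`decode`, and constants from the earlier sketches (the tutorial implements this as a `generate` method on the model; cropping to the last `BLOCK_SIZE` tokens is needed because the `tril` buffer has a fixed size, which comes up again in the observations below):

```python
@torch.no_grad()
def generate(model, idx, max_new_tokens):
    # idx: (B, T') tensor of token IDs -- here B = 1 and T' is the prompt length.
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -BLOCK_SIZE:]                     # crop to the context window
        logits = model(idx_cond)                            # (B, T', C)
        probs = F.softmax(logits[:, -1, :], dim=-1)         # distribution for the last position
        idx_next = torch.multinomial(probs, num_samples=1)  # sample one new token per sequence
        idx = torch.cat((idx, idx_next), dim=1)             # append it and go again
    return idx

# Start from a single newline character and generate 500 tokens.
context = torch.tensor([encode("\n")], dtype=torch.long, device=device)
print(decode(generate(model, context, max_new_tokens=500)[0].tolist()))
```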
Observations
I roughly estimated the complexity of this algorithm:
- Number of trainable params (a proxy for model size) ≈ \(O[E(C + LE)]\).
  - See something? The model size seems to be independent of the context length?!! I originally did not believe this, so I re-checked my code and Andrej’s code – the `block_size` variable indeed plays no role in shaping any of the trainable weight matrices (it only sizes the positional-embedding table and the `tril` mask buffer)! A back-of-the-envelope count is given after this list.
  - If this is true, why is there a context-length limit on published models? I suspect that is just because the internal buffer for the triangular matrix (which prevents info from flowing backwards) has a fixed size of \(T \times T\). I think there is a way to change that after loading the model weights. Consequently, any standard transformer can theoretically take arbitrarily long input – it will just run for longer.
  - It actually still makes sense for this to be true. We usually think that, for a transformer to attend to the past tokens, there should be some `TxT` matrix representing the mutual importance between tokens. However, the transformer authors decomposed that into \(Q \times K^T\), which only requires two matrices of size `ExH`, independent of `T`. That is definitely an assumption (which I think is highly intentional): that the mutual importance between two tokens can be captured by dot products. That removes the need for an extra `TxT` interaction matrix in the middle.
- The (sequential) time complexity of predicting/training one token is roughly \(O[TCE + L(T^2E + TE^2)]\).
  - Here we actually see \(T\) playing a big role.
  - Given that these variables (except \(L\)) are usually on the order of at least \(1000\), the cost of predicting one token is currently about \(1B\) operations, times a constant (I would guess about 20). Not a small cost!
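As a back-of-the-envelope check of the parameter count (a rough calculation under this post's assumptions: fixed positional embeddings, with biases and LayerNorm parameters ignored; the vocabulary size of 65 is the character-level vocabulary of the tutorial's Shakespeare corpus):

```python
C, E, H, L, N_HEADS = 65, 384, 64, 6, 6    # VOCAB_SIZE, EMBED_SIZE, head_size, N_BLOCKS, N_HEADS

token_embedding = C * E                    # the CxE lookup table
per_block = (
    N_HEADS * 3 * E * H                    # key/query/value matrices, ExH each, per head
    + E * E                                # the ExE projection layer
    + 2 * E * 4 * E                        # FFNN: Ex(4E) and its transpose
)
lm_head = E * C                            # the final ExC matrix

total = token_embedding + L * per_block + lm_head
print(f"{total:,}")                        # ~10.7M -- and BLOCK_SIZE (T) appears nowhere
```

This lands in the same ballpark as the roughly 10.8M parameters the tutorial's model prints; the small gap comes from the biases, LayerNorm parameters, and the trainable positional table that are ignored here.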