After being motivated to learn about transformers in more depth, I spent a good chunk of time implementing the GPT model from scratch, following a tutorial by Andrej Karpathy. I typed every line of code myself to understand the algorithm, and I was able to reproduce the same performance as in the tutorial.

However, my goal is not just that. I wanted to find out the relationship between the hyperparameters and the complexity of the transformer, measured by (1) the number of trainable parameters and (2) the time complexity of generating one token. Therefore, throughout the following sections, we will pay close attention to two things – the shape of the data and the shapes of the trainable matrices.

This post is long and is meant to serve as a reference when reasoning about transformers. If you want new insights, skip to the last section – I have some seemingly mind-blowing answers to the questions above!

Huge thanks to Andrej Karpathy for the tutorial on building a decoder-only transformer from scratch.

Preprocessing

First things first, we are given a corpus, which is plain text. Therefore, we need two functions to convert between text and sequences of token indices (which we will then handle as tensors):

  • encode(str) -> list[int]
  • decode(list[int]) -> str or list[str]
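For a character-level corpus like the tutorial's, these two functions can simply be a pair of lookup dictionaries built from the set of characters in the corpus. A minimal sketch (assuming the raw corpus has already been read into a string called text):

```python
# Build a character-level vocabulary from the corpus (assumed to be in `text`).
chars = sorted(set(text))
VOCAB_SIZE = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}  # char -> int
itos = {i: ch for ch, i in stoi.items()}      # int -> char

def encode(s: str) -> list[int]:
    return [stoi[ch] for ch in s]

def decode(ids: list[int]) -> str:
    return "".join(itos[i] for i in ids)

# Round trip (works as long as every character of the input appears in the corpus).
assert decode(encode("hello")) == "hello"
```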

From now on, we only work with these tensors. The input shapes are:

  • For training, each training step needs xb and yb, both stored as integer token indices of shape (BATCH_SIZE, BLOCK_SIZE) – conceptually (BATCH_SIZE, BLOCK_SIZE, VOCAB_SIZE) if each token is viewed as a one-hot vector.
    • yb is obtained by shifting the window one step into the future relative to xb.
    • (There will be a way to prevent data leaking in from the future in the self-attention mechanism.)
  • For inference (or text generation in particular)
    • start with some text – it can be a whole prompt or just a single character (like \n),
    • encode it into a batch with essentially the same shape as above, except that the batch size is 1 and the sequence length depends on the given prompt.

Given this similarity in input structure, we can consider one unifying input shape moving forward: BxTxC before embedding (with B = 1 and a variable T at inference time), which the embedding step below turns into BxTxE.
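To make that concrete, here is a sketch of how training batches can be sampled from the encoded corpus, in the spirit of the tutorial (variable names are mine). Note that in code xb and yb are stored as integer indices of shape (BATCH_SIZE, BLOCK_SIZE); the one-hot BxTxC view is only conceptual:

```python
import torch

BATCH_SIZE, BLOCK_SIZE = 64, 256
data = torch.tensor(encode(text), dtype=torch.long)  # the whole encoded corpus, shape (N,)

def get_batch(data: torch.Tensor):
    # Pick BATCH_SIZE random start positions, each leaving room for a full window plus the shifted target.
    ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
    xb = torch.stack([data[i : i + BLOCK_SIZE] for i in ix])          # (B, T)
    yb = torch.stack([data[i + 1 : i + BLOCK_SIZE + 1] for i in ix])  # (B, T), shifted one step into the future
    return xb, yb
```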

Model architecture (decoder-only)

Hyperparams

Important hyperparameters, with the values used in the tutorial (structurally similar to the 2017 ‘Attention Is All You Need’ architecture, but scaled down):

  • BATCH_SIZE (\(B\)) = 64
  • BLOCK_SIZE (\(T\)) = 256; context window length
  • VOCAB_SIZE (\(C\)) = depends on the corpus
  • EMBED_SIZE (\(E\)) = 384
  • N_BLOCKS (\(L\)) = 6; num decoder blocks
  • N_HEADS = 6; num attention heads inside a multi-head attention component (which is inside a decoder block)
  • DROPOUT = 0.2; dropout probability
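For reference, the same hyperparameters written out as plain Python constants (the names are mine; VOCAB_SIZE is filled in after reading the corpus):

```python
BATCH_SIZE = 64                     # B: sequences processed in parallel
BLOCK_SIZE = 256                    # T: context window length
EMBED_SIZE = 384                    # E: embedding dimension
N_BLOCKS   = 6                      # L: number of decoder blocks
N_HEADS    = 6                      # attention heads per multi-head attention
HEAD_SIZE  = EMBED_SIZE // N_HEADS  # H = 64
DROPOUT    = 0.2                    # dropout probability
# VOCAB_SIZE (C) depends on the corpus, e.g. len(sorted(set(text))) for a character-level tokenizer.
```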

With this setting, a 5000-step training loop takes about 15-20 minutes on a GPU.

Computations

For training and ‘primitive’ inference, i.e., forward()

  • Start with input of size BxTxC
  • Embedding: sum of token embedding and positional embedding
    • Token embedding: simply a trainable matrix of size CxE that acts as a lookup table for each token. This is effectively static embedding, like word2vec. Implementation-wise, the tokens can be represented as a matrix of one-hot vectors and multiplied with this token embedding matrix (or, equivalently and more efficiently, looked up via nn.Embedding).
    • Positional embedding: a fixed matrix of size TxE. It maps each position slot in the input sequence to an (ideally unique) vector that encodes the positional information.
    • Output shape is BxTxE – the data is now officially in the embedding space!
  • Decoder block x N_BLOCKS. Each Decoder block is laid out as follows (a code sketch of all these pieces appears after this list):
    • A LayerNorm, with two trainable parameter vectors (\(\gamma\) and \(\beta\)), each of length E.
    • Multi-head Attention: Attention head x N_HEADS.
      • Each Attention head has:
        • head_size = EMBED_SIZE // N_HEADS (\(H\)) – chosen so that the upcoming concatenation of the head outputs gives back \(E\).
        • Three trainable weight matrices – key, query, and value – each of size ExH.
          • Each of these matrices transforms the data into \(K\), \(Q\), and \(V\) respectively, each of size BxTxH.
          • Attention formula: \(\text{Attention}(Q, K, V) = \text{softmax}\left(\text{mask}\left(\frac{Q\times K^T}{\sqrt{H}}\right)\right)\times V\) (scaling by \(\sqrt{d_k} = \sqrt{H}\), the head size, as in the Attention paper).
            • mask uses a t \(\times\) t (t <= T) lower triangular matrix (tril) to set the scores above the diagonal to \(-\infty\) before the softmax, so each token only draws (softmax-weighted) information from itself and earlier tokens. It is applied elementwise, not as a matrix multiplication.
            • The softmax-ed scores are also dropped out before being multiplied with \(V\).
          • Combining the three this way gives an output that still has shape BxTxH.
        • This is where the tokens talk to each other, thus ‘self-attention’.
      • The outputs of all N_HEADS attention heads are concatenated along the last (channel) dimension. Because H x N_HEADS = E, the data shape is now back to the good old BxTxE!
      • Then it is multiplied with a trainable ExE matrix called the projection layer, which prepares the data for the skip connection later.
      • There is a dropout layer here (parametrized by DROPOUT).
      • There is also a skip connection, where the sub-layer's input is added directly to its output. This is poetically called the residual pathway.
    • FFNN
      • A LayerNorm with two trainable parameter vectors (\(\gamma\) and \(\beta\)), each of length E.
      • A Multi-Layer Perceptron with two trainable matrices of size Ex(4E) and (4E)xE, with an activation (ReLU in the tutorial) in between and Dropout at the end.
      • Also has a skip connection.
      • This is where the tokens ‘think for themselves’.
  • A language modeling head (code name: lm_head)
    • A LayerNorm with two trainable parameter vectors (\(\gamma\) and \(\beta\)).
    • Then a trainable ExC matrix
    • Data is now mapped back to the ‘language space’ BxTxC!
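To make the shapes above concrete, here is a condensed sketch of the whole forward path – one attention head, the multi-head wrapper, the feed-forward sub-layer, a decoder block, and the model with its lm_head. It is reconstructed from the description above rather than copied from the tutorial, so treat the names and details as my own; in particular, the positional encoding is written as a fixed (non-trainable) sinusoidal buffer to match the ‘fixed TxE matrix’ description, whereas many implementations use a learned table instead:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    """One self-attention head: projects E -> H and lets each token attend to earlier tokens."""
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(EMBED_SIZE, head_size, bias=False)    # E x H
        self.query = nn.Linear(EMBED_SIZE, head_size, bias=False)  # E x H
        self.value = nn.Linear(EMBED_SIZE, head_size, bias=False)  # E x H
        self.register_buffer("tril", torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)))
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):                                    # x: (B, T, E)
        B, T, E = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)  # each (B, T, H)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5   # (B, T, T) scaled scores
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # causal mask (elementwise)
        wei = self.dropout(F.softmax(wei, dim=-1))
        return wei @ v                                        # (B, T, H)

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([Head(EMBED_SIZE // N_HEADS) for _ in range(N_HEADS)])
        self.proj = nn.Linear(EMBED_SIZE, EMBED_SIZE)         # the E x E projection layer
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)   # concat heads: (B, T, H*N_HEADS) = (B, T, E)
        return self.dropout(self.proj(out))

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(EMBED_SIZE, 4 * EMBED_SIZE),   # E x 4E
            nn.ReLU(),
            nn.Linear(4 * EMBED_SIZE, EMBED_SIZE),   # 4E x E
            nn.Dropout(DROPOUT),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Pre-norm decoder block: LayerNorm -> attention -> residual, then LayerNorm -> FFNN -> residual."""
    def __init__(self):
        super().__init__()
        self.ln1, self.sa = nn.LayerNorm(EMBED_SIZE), MultiHeadAttention()
        self.ln2, self.ffwd = nn.LayerNorm(EMBED_SIZE), FeedForward()

    def forward(self, x):
        x = x + self.sa(self.ln1(x))    # residual pathway around self-attention
        x = x + self.ffwd(self.ln2(x))  # residual pathway around the FFNN
        return x

class GPTLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(VOCAB_SIZE, EMBED_SIZE)  # C x E lookup table
        # Fixed, non-trainable sinusoidal positional encoding of size T x E.
        pos = torch.arange(BLOCK_SIZE).unsqueeze(1)
        div = torch.exp(torch.arange(0, EMBED_SIZE, 2) * (-math.log(10000.0) / EMBED_SIZE))
        pe = torch.zeros(BLOCK_SIZE, EMBED_SIZE)
        pe[:, 0::2], pe[:, 1::2] = torch.sin(pos * div), torch.cos(pos * div)
        self.register_buffer("pos_embedding", pe)
        self.blocks = nn.Sequential(*[Block() for _ in range(N_BLOCKS)])
        self.ln_f = nn.LayerNorm(EMBED_SIZE)
        self.lm_head = nn.Linear(EMBED_SIZE, VOCAB_SIZE)             # E x C

    def forward(self, idx, targets=None):                    # idx: (B, T) token indices
        B, T = idx.shape
        x = self.token_embedding(idx) + self.pos_embedding[:T]       # (B, T, E): now in embedding space
        x = self.ln_f(self.blocks(x))                                # (B, T, E)
        logits = self.lm_head(x)                                     # (B, T, C): back in 'language space'
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B * T))
        return logits, loss
```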

Training loop

  • Loss: cross-entropy loss between the logits (BxTxC) and the actual next tokens (BxT); in practice both are flattened to (B·T)xC and (B·T) before the loss is computed.
  • AdamW is the optimizer used in the tutorial, and it is a solid default for transformers.
  • Evaluate once in a while by running labeled forward passes on both the train and val sets and averaging the losses. Remember to use torch.no_grad() and model.eval() before, and model.train() after.
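A sketch of this loop (learning rate 3e-4 as in the tutorial; get_batch as sketched earlier, here called on an assumed ~90/10 train/val split of the encoded corpus; forward() is assumed to also return the loss when targets are passed, as in the sketch above):

```python
import torch

model = GPTLanguageModel()  # assembled from the pieces sketched above
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

n = int(0.9 * len(data))
splits = {"train": data[:n], "val": data[n:]}

@torch.no_grad()
def estimate_loss(eval_iters: int = 200) -> dict:
    model.eval()                       # disable dropout for evaluation
    out = {}
    for name, split in splits.items():
        losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            xb, yb = get_batch(split)  # labeled batch from this split
            _, loss = model(xb, yb)
            losses[i] = loss.item()
        out[name] = losses.mean().item()
    model.train()                      # back to training mode
    return out

for step in range(5000):
    if step % 500 == 0:
        print(step, estimate_loss())
    xb, yb = get_batch(splits["train"])
    logits, loss = model(xb, yb)       # cross entropy between (B*T, C) logits and (B*T,) targets
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
```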

Inference

(This is how ChatGPT actually spits out text!!!) So we left off at forward(), which returns logits of shape BxTxC. From here, we do auto-regressive generation by repeatedly:

  • first, taking the logits at the last time step and applying softmax to obtain a probability distribution over the vocabulary,
  • then, sampling the next token from that distribution (so B tokens for the whole batch). (This is how and why ChatGPT produces slightly different outputs given the same input!)
  • finally, appending this new token to the input ids (by concatenation, keeping only the last BLOCK_SIZE tokens) and feeding them to forward() again to get the new logits.

Do this until some stop condition is met, like reaching max_new_tokens. (Beam search could also be used instead of plain sampling.)
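A sketch of this generation loop, assuming the model class from the earlier sketch. Note that the context is cropped to the last BLOCK_SIZE tokens before each forward pass – this is exactly where the fixed-size triangular buffer discussed in the Observations section comes into play:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, idx, max_new_tokens: int):
    # idx: (B, T') tensor of token indices to start from.
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -BLOCK_SIZE:]             # crop the context to the last T tokens
        logits, _ = model(idx_cond)                 # (B, T', C)
        logits = logits[:, -1, :]                   # only the last position predicts the next token
        probs = F.softmax(logits, dim=-1)           # (B, C) probability distribution
        idx_next = torch.multinomial(probs, num_samples=1)  # sample one token per batch item
        idx = torch.cat((idx, idx_next), dim=1)     # append and repeat
    return idx

# Start from a single token (index 0, which is '\n' in the tutorial's character vocabulary).
out = generate(model, torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)
print(decode(out[0].tolist()))
```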

Observations

I roughly estimated the complexity of this algorithm (a small sanity-check sketch follows at the end of this section):

  • Number of trainable params (proxy for model size) \(\approx O[E(C + LE)]\): the token embedding and lm_head contribute \(O(CE)\) each, and each of the \(L\) decoder blocks contributes \(O(E^2)\) (roughly \(12E^2\)). Note that the batch size plays no role here.
    • See something? The model size seems to be independent of the context length?!! I originally did not believe this, so I re-checked my code and Andrej's – the block_size variable indeed does not shape any of the trainable matrices (at least with the fixed, non-trainable positional embedding described above; a learned positional table would add a modest TxE term).
    • If this is true, why is there a context length limit on published models? I suspect that is largely because the internal buffer for the triangular mask (which prevents info flowing backwards from the future) has a fixed size of \(T \times T\). I think there is a way to change that after loading the model weights. Consequently, a standard transformer can theoretically take arbitrarily long input – it will just run for longer (though quality may degrade at positions it never saw during training).
    • It actually still makes sense for this to be true. We usually think that, for a transformer to attend to the past tokens, there should be some TxT matrix representing the mutual importance between tokens. However, the transformer authors decomposed that into \(Q \times K^T\), which only requires two matrices of size ExH each, independent of T. That is definitely an assumption (which I think is highly intentional): that the mutual importance between two tokens can be captured by a dot product of their projections. It removes the need for an extra TxT interaction matrix of parameters – the TxT attention scores still exist, but only as a temporary activation, not as weights.
  • The (sequential) time complexity of one forward pass over a length-\(T\) context – i.e., the cost of training on, or predicting, one next token without any caching – is roughly \(O[TCE + L(T^2E + TE^2)]\). (The causal mask is applied elementwise, so there is no \(T^3\) term; the per-head \(T^2H\) cost summed over all N_HEADS heads gives \(T^2E\).)
    • Here we actually see \(T\) playing a big role.
    • Given that these variables (except \(L\)) are usually on the order of at least \(1000\) in production-scale models, the cost of predicting one token is currently about \(1B\) operations, times a constant (I would guess about 20). Not a small cost!
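One quick way to sanity-check the parameter-count claim is to compare the closed-form estimate with what PyTorch actually registers (a sketch, assuming the model class from the earlier sketch; note that BLOCK_SIZE never appears in the estimate):

```python
# Trainable parameters actually registered by the model.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Closed-form estimate: embedding + lm_head contribute ~2*C*E, each decoder block ~12*E^2
# (3*E*H per head * N_HEADS heads + the E*E projection + 8*E^2 in the FFNN), ignoring biases and LayerNorms.
E, C, L = EMBED_SIZE, VOCAB_SIZE, N_BLOCKS
estimate = 2 * C * E + L * 12 * E * E
print(n_params, estimate)   # the two should roughly agree
```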