In this post we review LongNet, a new research paper by Microsoft titled “LongNet: Scaling Transformers to 1,000,000,000 Tokens”. The paper opens with an amusing chart that plots transformer sequence lengths over time on a non-logarithmic y-axis, where LongNet, with its one billion tokens, sits far above every other model. But how does it work? In this post we’ll cover the key points of the paper to help you keep track of this new advancement.
If you prefer a video format, check out the video version of this post, linked in the references at the end.
Long Context Importance
Let’s start with why long context matters. Modeling long sequences has become a critical need in the era of large language models, for many reasons. As one example, a LongNet model that can process a huge input sequence could take a book, or a stack of books on the same topic, as input and produce a summary that draws on many different information sources at once. Another intriguing example is providing the entire internet as a single sequence and having the model offer insights with all of it as context.
Additionally, we often want a model that is adapted to a specific task. A common approach in that case is to take a dataset that fits the task and a foundational large language model, such as LLaMA, and fine-tune the model on that dataset so that it yields adapted results. With LongNet’s huge context, we could instead provide the dataset as context to the model and get adapted results without fine-tuning.
Before LongNet, existing methods struggled with long sequences: they either had high computational complexity or limited model expressivity, which made it difficult to scale up the context length.
Improving the Attention Mechanism
LongNet scales sequence length to one billion tokens without sacrificing performance on shorter sequences. The key component that makes this possible is dilated attention, which the LongNet paper introduces.
At the core of the transformer model is self-attention, where each token attends to every other token in the input sequence.
For example, an input sequence of 12 tokens, like the one in the image below, yields a 12 × 12 matrix called the attention matrix. The size of this matrix grows quadratically with the sequence length, which makes it very difficult to scale up the context length.
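To make the quadratic cost concrete, here is a minimal pure-Python sketch of vanilla self-attention weights for a 12-token toy sequence. It rests on our own simplifying assumptions: the token vectors are random, they serve directly as queries and keys (no learned projection matrices), and the helper names are ours rather than from the paper.

```python
import math
import random

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def attention_matrix(q, k):
    """Vanilla attention weights: every query row is scored against every key row."""
    d = len(q[0])
    return [softmax([sum(a * b for a, b in zip(qr, kr)) / math.sqrt(d) for kr in k]) for qr in q]

random.seed(0)
n, d = 12, 8  # 12 tokens as in the example; head dimension 8 is an arbitrary choice
x = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

A = attention_matrix(x, x)  # token vectors used directly as queries and keys
print(len(A), len(A[0]))  # 12 12

# The number of attention scores is n * n: doubling the length quadruples it.
for length in (1_000, 2_000, 4_000):
    print(length, length * length)
```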
With dilated attention, we first apply segmentation, splitting the input sequence into equal-length segments; in the image below, a 12-token input becomes three segments of 4 tokens each. Attention is then computed separately within each segment, so here we get three 4 × 4 attention matrices rather than one 12 × 12 matrix.
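The segmentation step can be sketched in a few lines of pure Python. This is an illustration under our own assumptions (random token vectors used directly as queries and keys, a sequence length that is a multiple of the segment length `w`, and hypothetical helper names), not the paper’s implementation.

```python
import math
import random

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def attention_matrix(q, k):
    d = len(q[0])
    return [softmax([sum(a * b for a, b in zip(qr, kr)) / math.sqrt(d) for kr in k]) for qr in q]

def segment(x, w):
    """Split a sequence into consecutive, non-overlapping segments of length w."""
    if len(x) % w != 0:
        raise ValueError("this sketch assumes the sequence length is a multiple of w")
    return [x[i:i + w] for i in range(0, len(x), w)]

random.seed(0)
n, d, w = 12, 8, 4
x = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

segments = segment(x, w)
# Each segment attends only within itself, so these computations are independent.
per_segment = [attention_matrix(s, s) for s in segments]
print([(len(a), len(a[0])) for a in per_segment])  # [(4, 4), (4, 4), (4, 4)]
print(sum(len(a) * len(a[0]) for a in per_segment), n * n)  # 48 144
```

Segmentation alone already cuts the number of scores from 144 to 48 in this toy case, before any sparsification.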
Dilated attention has a second step, sparsification, in which we keep only every r-th row of each segment’s attention matrix. The hyperparameter r, called the dilation rate, controls the interval between the kept rows. In the example image below r is 2, so we keep rows 1 and 3 of each segment’s attention matrix and remove rows 2 and 4.
Because the second and fourth tokens are removed, the remaining rows cannot attend to them either, so the matching columns are removed as well.
Finally, each segment yields a sparse 2 × 2 attention matrix. Together, the three segments hold 12 attention scores, far fewer than the 144 in the original 12 × 12 matrix.
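A minimal sketch of sparsification follows. One assumption worth flagging: instead of deleting cells from an already-normalized matrix, it gathers the kept tokens first and computes attention among them, which is equivalent to keeping the selected rows and columns of the score matrix before the softmax. The token vectors are random, and `dilated_indices` is our own hypothetical helper.

```python
import math
import random

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def attention_matrix(q, k):
    d = len(q[0])
    return [softmax([sum(a * b for a, b in zip(qr, kr)) / math.sqrt(d) for kr in k]) for qr in q]

def dilated_indices(n, w, r):
    """For each length-w segment, the positions kept when keeping every r-th token."""
    return [list(range(start, start + w, r)) for start in range(0, n, w)]

random.seed(0)
n, d, w, r = 12, 8, 4, 2
x = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

groups = dilated_indices(n, w, r)
print(groups)  # [[0, 2], [4, 6], [8, 10]]  (0-based, i.e. rows 1 and 3 of each segment)

# Gather the kept tokens and attend only among them: dropping a row (query) also drops
# the matching column (key), so each segment yields a (w // r) x (w // r) matrix.
mats = [attention_matrix([x[i] for i in g], [x[i] for i in g]) for g in groups]
print([(len(m), len(m[0])) for m in mats])  # [(2, 2), (2, 2), (2, 2)]
print(sum(len(m) * len(m[0]) for m in mats), n * n)  # 12 144
```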
Importantly, the attention for each segment can be computed in parallel, which enables distributed training across multiple GPUs.
This approach is also significantly faster than the original self-attention on long sequences, as the researchers show in the following chart: the runtime of dilated attention stays low even as the sequence length grows dramatically, while the runtime of the original attention grows quadratically with the sequence length.
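The trend in the chart follows from simple counting. A single dilated block with segment length w and dilation rate r has N/w segments of (w/r) × (w/r) scores each, i.e. N·w/r² scores in total, which is linear in N when w and r are fixed. The sketch below counts attention scores only; it ignores head dimension, constant factors and hardware effects, so it explains the shape of the curves rather than reproducing the paper’s measured runtimes. The values w = 2048 and r = 4 are arbitrary illustrative choices, and the sketch assumes N is a multiple of w and w a multiple of r.

```python
def vanilla_scores(n):
    """Entries in the full n x n attention matrix."""
    return n * n

def dilated_scores(n, w, r):
    """Entries kept by one dilated block: n // w segments, each (w // r) x (w // r)."""
    return (n // w) * (w // r) ** 2

print(vanilla_scores(12), dilated_scores(12, 4, 2))  # 144 12

# With w and r fixed, doubling n doubles the dilated count but quadruples the vanilla one.
for n in (8_192, 16_384, 32_768):
    print(n, vanilla_scores(n), dilated_scores(n, w=2_048, r=4))
```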
Mixture of Dilated Attentions
By now you may be wondering whether this approach loses a lot of information from the context. That brings us to the mixture of dilated attentions. Consider the following example from the paper: an attention matrix for a sequence of length 16, divided into two segments of length 8 with a dilation rate of 2, so every second row is removed.
To capture long-range information, the authors add another dilated attention block with a larger segment size and a higher dilation rate, like the example below on the left, with segment size 16 and dilation rate 4. To capture short-range information, they include another block with smaller segments and a lower dilation rate, like the example below on the right, with segment size 4 and dilation rate 1.
All of these dilated attention blocks can be computed in parallel, and their outputs are combined, giving the model diverse information that covers both short-range and long-range dependencies.
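One way to combine the blocks, and our reading of what the paper does, is to weight each block’s output by the denominator of its attention softmax, so that positions a block skipped (denominator zero) take their value entirely from the other blocks. The pure-Python sketch below follows that reading for the three configurations in the example, under our own simplifications: one head, no learned projections, and random token vectors used directly as queries, keys and values. Treat it as an illustration, not the paper’s implementation; `dilated_block` and `mixture` are hypothetical names.

```python
import math
import random

def dilated_block(x, w, r):
    """One dilated-attention block over a sequence x (list of vectors).

    Returns per-position outputs and softmax denominators; positions the block
    skips get a zero output and a zero denominator.
    """
    n, d = len(x), len(x[0])
    out = [[0.0] * d for _ in range(n)]
    denom = [0.0] * n
    for start in range(0, n, w):
        idx = list(range(start, min(start + w, n), r))  # every r-th token of the segment
        for i in idx:
            # Raw exp scores (no max-subtraction) so denominators stay comparable across blocks.
            weights = [math.exp(sum(a * b for a, b in zip(x[i], x[j])) / math.sqrt(d)) for j in idx]
            s = sum(weights)
            denom[i] = s
            out[i] = [sum(wt * x[j][c] for wt, j in zip(weights, idx)) / s for c in range(d)]
    return out, denom

def mixture(x, configs):
    """Combine blocks, weighting each block's output by its softmax denominator."""
    results = [dilated_block(x, w, r) for w, r in configs]
    n, d = len(x), len(x[0])
    mixed = []
    for i in range(n):
        total = sum(den[i] for _, den in results)
        mixed.append([sum(den[i] * o[i][c] for o, den in results) / total for c in range(d)])
    return mixed

random.seed(0)
n, d = 16, 4
x = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
configs = [(8, 2), (16, 4), (4, 1)]  # (segment length w, dilation rate r) from the example
y = mixture(x, configs)
print(len(y), len(y[0]))  # 16 4
```

Position 1 is skipped by both the (8, 2) and (16, 4) blocks, so its mixed output comes entirely from the (4, 1) block. Including a block with dilation rate 1 is what guarantees every token receives an output.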
Multi-Head Dilated Attention
To diversify the captured information even further, on top of the mixture of dilated attentions, LongNet uses multi-head dilated attention. For a given segment length (8 in the example below) and dilation rate (2 in the example below), each head keeps a different set of rows, as the multiple heads in the image below show, so each head looks at different information.
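The per-head row selection can be sketched by giving each head a different starting offset within every segment. The rule used here, offset = head index mod r, is our reading of the paper and should be treated as an assumption; `head_indices` is a hypothetical helper.

```python
def head_indices(n, w, r, head):
    """Positions kept by one head: every r-th token of each length-w segment, shifted by an offset."""
    offset = head % r  # assumed rule: head j starts at offset j mod r
    return [list(range(start + offset, start + w, r)) for start in range(0, n, w)]

n, w, r, num_heads = 16, 8, 2, 4
for h in range(num_heads):
    print(h, head_indices(n, w, r, h))
# 0 [[0, 2, 4, 6], [8, 10, 12, 14]]
# 1 [[1, 3, 5, 7], [9, 11, 13, 15]]
# 2 [[0, 2, 4, 6], [8, 10, 12, 14]]
# 3 [[1, 3, 5, 7], [9, 11, 13, 15]]

# Across heads, every position is covered even though each head alone keeps only half.
covered = {i for h in range(num_heads) for seg in head_indices(n, w, r, h) for i in seg}
print(covered == set(range(n)))  # True
```

With r = 2 there are only two distinct offsets, so heads 2 and 3 repeat the patterns of heads 0 and 1; a larger dilation rate yields more distinct patterns.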
References & Links
- Paper page – https://arxiv.org/abs/2307.02486
- Video – https://youtu.be/VMu0goeii3g
- We use ChatPDF to help us analyze research papers – https://www.chatpdf.com/?via=ai-papers (affiliate)
All credit for the research goes to the researchers who wrote the paper we covered in this post.