Linear Transformers Are Faster After All
\[ \newcommand{\R}{\mathbb{R}} \newcommand{\Z}{\mathbb{Z}} \newcommand{\N}{\mathbb{N}} \newcommand{\sft}{\text{softmax}} \newcommand{\List}{\text{List}} \newcommand{\Seq}{\text{Seq}} \newcommand{\SeqT}{\text{SeqT}} \newcommand{\CSeqT}{\text{CSeqT}} \newcommand{\Dist}{\text{Dist}} \newcommand{\SM}{\text{SM}} \newcommand{\Fn}{\text{Fn}} \newcommand{\Tok}{\text{Tok}} \newcommand{\Aij}{ A_{[i,j]}} \]
It is well-known that removing the exponential from the attention layer of a transformer allows for a recurrent reformulation, with computational cost that is linear, rather than quadratic, in the context length [1]. One would expect that such an architecture would be far faster to train, especially when the context size is large. However, when training deep neural networks on hardware accelerators (e.g. GPUs), reductions in FLOPs do not always straightforwardly translate into practical speedups. Initial empirical experiments showed that large language models based on linear transformers train more slowly than classic transformers, and led many to dismiss the approach as nice-in-theory but unhelpful-in-practice [2].
At the moment, the conventional wisdom in the field is that a quadratic-cost algorithm with highly optimized hardware utilization (e.g. FlashAttention) gives the best training throughput [3]. But this is mistaken. In this post, we explain several different ways of implementing linear transformers and we show that, in fact, they can produce massive speed-ups.
Experimental setup. Each timing experiment was run on an H100 with batch size 1 and vocabulary size 50k. We start with an exact JAX replication of GPT2 (numerically tested against nanoGPT), and modify only the self-attention layers. Our implementation also includes modern bells-and-whistles such as mixed-precision training.
The experiment below showcases the central takeaway. We compare the speed of an optimized transformer, a straightforward recurrent implementation of the linear transformer (as described in [1]), and an optimized linear transformer (described in Section 3). We see that, despite exhibiting better growth with respect to context size, the straightforward linear transformer trains much more slowly than the transformer in wall-clock time. In contrast, our optimized algorithm is strictly faster than even the highly-optimized FlashAttention baseline, at all context and model sizes. At moderately large context sizes, this speedup has a larger impact than FlashAttention itself.
As we increase the context size, we utilize more and more GPU DRAM, until eventually we encounter an out of memory (OOM) error and cannot run a training step anymore. Different algorithms have different memory requirements (for example, FlashAttention utilizes dramatically less memory than traditional attention), so the various algorithms we tested have different ranges on this plot.
But speed is only one part of the picture. We also need to know whether linear transformers actually minimize the loss as well as standard transformers do. Unfortunately, as the next plot shows, directly converting GPT2 to use linear transformer layers hurts learning significantly.
Experimental setup. These experiments are each run on an 8xH100 node for 24 hours. For all runs, we use flash-attention. For the linear transformer runs, we use the chunked algorithm with the optimal chunk size selected empirically (see Section 3). The dataset used was c4 [4], tokenized by tiktoken’s GPT2 encoder. We plot the train loss because the dataset is large enough that all training was done in the single-epoch setting, and so there is no difference between train and heldout loss.
These results are discussed in detail in Section 5, but in short: linear transformers seem to become increasingly unstable as the sequence length grows, negatively affecting learning. This means that our linear transformers are somewhat useless,1 as the positive impact from the speedup seen in long contexts is undermined by the negative impact of degraded learning.
Other variants of linear transformers have been proposed that claim to resolve these learning issues [5]–[11], but we do not survey them today. Since our main motivation in this post is to share our insights on how to implement efficient linear transformers, we stick with the most natural variant, which is most closely analogous to standard transformers. In a future post, we will explain how to improve the learning of linear transformers, allowing us to efficiently train unprecedentedly-long-context language models with super-fast sampling.
1. Linear Transformers
The inputs to a transformer layer are sequences of queries, keys, and values \(Q_i, K_i, V_i \in \R^d\), where \(i\) ranges from \(1\) to the sequence length \(t\). The outputs are a sequence \(Y_i\in \R^d\). The well-known formula for the transformer layer, first popularized by Vaswani et al [12], is: \[ Y_i^\text{Transformer} = \sum_{j=1}^i e^{Q^T_i K_j} V_j \] Here we are excluding the denominator of the softmax for simplicity of notation. Although the normalization provided by the denominator is important for good learning performance, it doesn’t have a major impact on the computational cost or implementation speed, which are our main focus here.
Even though we omit the normalization for the mathematical exposition, it is included in our implementation for all experiments.
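For reference, here is a minimal numpy sketch of the unnormalized formula above (for exposition only, not the implementation we benchmark), together with a brute-force check of a single position against the definition:

import numpy as np

def attention_unnormalized(Q, K, V):
    """Y_i = sum_{j <= i} exp(Q_i . K_j) V_j for all i. Q, K, V: [t, d]."""
    A = np.exp(Q @ K.T)          # A[i, j] = exp(Q_i . K_j)
    A = np.tril(A)               # causal mask: zero out j > i
    return A @ V                 # [t, d]

# Brute-force check of one position against the definition.
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 6, 4))
Y = attention_unnormalized(Q, K, V)
i = 3
Y_i = sum(np.exp(Q[i] @ K[j]) * V[j] for j in range(i + 1))
assert np.allclose(Y[i], Y_i)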
The formula for the linear transformer (LT) layer is quite similar: just change the term \(e^{Q^T_i K_j} \to Q^T_i K_j\) yielding \[ Y_i^\text{LinearTransformer} = \sum_{j=1}^i Q^T_i K_j V_j \]
All our experiments with linear transformers also include a normalization factor, which we empirically found to be important for learning performance. Drawing on [1], we divide each \(Y_i\) by \(\sum_{j=1}^i Q^T_i K_j\), after ensuring the sum is positive by making keys and queries live in the positive quadrant using a softplus.
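Concretely, this is what we mean. The sketch below is ours and slightly simplified: it assumes the softplus is applied elementwise to \(Q\) and \(K\) before attention, and adds a small eps of our own to guard the division.

import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def LT_attention_normalized(Q, K, V, eps=1e-6):
    """Linear attention with the normalization described above. Q, K, V: [t, d]."""
    Q, K = softplus(Q), softplus(K)       # queries and keys are now elementwise positive
    A = np.tril(Q @ K.T)                  # A[i, j] = Q_i . K_j for j <= i, else 0
    Z = A.sum(axis=1, keepdims=True)      # Z[i] = sum_{j <= i} Q_i . K_j > 0
    return (A @ V) / (Z + eps)            # divide each Y_i by its normalizer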
This layer is “linear” in that the outputs \(Y\) are linearly related to all of \(Q\), \(K\), and \(V\).2 From now on, we will omit the superscript of \(Y_i^\text{LinearTransformer}\) and just write \(Y_i\). To begin our exploration of the computational cost of linear transformers, consider the following implementation.
def LT_attention(Q, K, V):
"""
Shapes of inputs are
Q: [t, d] K: [t, d] V: [t, d]
Shapes of outputs are
Y: [t, d]
"""
    t, d = Q.shape
    Y_list = []
    for i in range(t):                    # loop cost: O(t^2 d)
        Y_i = zeros(d)
        Q_i = Q[i]
        for j in range(i+1):              # loop cost: O(i d)
            A_ij = inner(K[j], Q_i)       # cost: O(d)
            Y_i += A_ij * V[j]            # cost: O(d)
        Y_list.append(Y_i)
    return stack(Y_list)
Anyone who has worked with self-attention will recognize that this is a terrible implementation, since nested for-loops are not GPU-friendly. Don’t worry: in the next section, we will discuss implementations that parallelize computation, and thus run much faster on modern hardware. For now, the main point of this pseudocode is to highlight that this first mode of computation, which we call the attention formulation, has a FLOP cost of \(O(t^2 d)\).
The key to the massive speedups we saw in the introduction comes from leveraging linearity to restructure the computation. Consider the following factorization: \[ Y_i = \sum_{j=1}^i Q^T_i K_j V_j = \underbrace{ \left ( \sum_{j=1}^i V_j K_j^T\right )}_{S_i} \; \; Q_i \] Written in this form, we notice that the term labeled \(S_i \in \R^{d\times d}\) can be thought of as a state summarizing all the relevant information up to time \(i\). It’s easy to rewrite into the following recurrent equations \[ Y_{i} = S_i Q_i \;\;\;\;\;\;\;\;\;\;\;\; S_i = S_{i-1} + V_i K_i^T \] where we assume \(S_{0} = 0\in \R^{d\times d}\). Written in this form, we realize that a linear transformer is an RNN. We can also write pseudocode for this approach, which we call the state formulation, and analyze the cost:
def LT_state(Q, K, V):
"""
Shapes of inputs are
Q: [t, d] K: [t, d] V: [t, d]
Shapes of outputs are
Y: [t, d]
"""
    t, d = Q.shape
    S_i = zeros(d, d)                     # shape [d, d]
    Y_list = []
    for i in range(t):                    # loop cost: O(t d^2)
        S_i += outer(V[i], K[i])          # S_i = S_{i-1} + V_i K_i^T, cost: O(d^2)
        Y_i = S_i @ Q[i]                  # Y_i = S_i Q_i, cost: O(d^2)
        Y_list.append(Y_i)
    return stack(Y_list)
We see that the cost here is \(O(t d^2)\).
So, while a standard transformer layer always has cost \(O(t^2 d)\), linear transformers have two formulations with different costs. By switching from the attention formulation to the state formulation, we can change the cost from \(O(t^2 d)\) to \(O(t d^2)\), trading a \(t\) term for a \(d\) term; in terms of raw FLOPs, the state formulation is cheaper whenever \(t > d\).
2. Parallel Implementations
In general, for-loops are a terrible way to implement anything that will run on a GPU. When using high-level frameworks like PyTorch or JAX, the easiest way to get high GPU utilization is to only use primitives that are already highly optimized for GPU: matrix multiplication, elementwise operations, etc. We can rewrite our attention and state algorithms in this style to make them efficient.
First, let’s do this for attention. Our main technique is to compute the attention matrix \(A\), which contains all the terms inner(K[j], Q[i])
that appeared inside the for-loops of LT_attention
, using a single heavyweight matrix multiply.
def LT_attention_parallel_no_flash(Q, K, V):
"""
Shapes of inputs are
Q: [t, d] K: [t, d] V: [t, d]
Shapes of outputs are
Y: [t, d]
"""
    t = Q.shape[0]
    M = causal_mask(t)
    A_raw = Q @ K.T                       # cost O(t^2 d)
    A = A_raw * M                         # cost O(t^2)
    Y = A @ V                             # cost O(t^2 d)
    return Y
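The helper causal_mask is not spelled out above; a minimal numpy version (our own) is just a lower-triangular matrix of ones:

import numpy as np

def causal_mask(t):
    """[t, t] mask with M[i, j] = 1 if j <= i else 0."""
    return np.tril(np.ones((t, t)))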
This implementation lies at the core of nearly every transformer of the past five years, powering all of the AI models that have recently exploded in popularity. Since it is composed of such a small number of ultra-parallelizable ops, it obtains high GPU utilization. Recently, specialized flash attention kernels [3] have been used to get even further speedups by avoiding explicitly storing the attention matrix \(A\), and thereby saving both memory and time spent on memory-transfers. Algorithmically, though, flash attention is a variant of parallel attention, and has the same computational cost (as measured in FLOPs). We use LT_attention_parallel
to refer to the flash attention implementation.
Next, let’s parallelize the recurrent formulation. This optimization is less well-known. The key is to compute all the terms \(V_i K^T_i\) in parallel, and then use a cumulative-sum, which can be parallelized, to combine them.
def LT_state_parallel(Q, K, V):
"""
Shapes of inputs are
Q: [t, d] K: [t, d] V: [t, d]
Shapes of outputs are
Y: [t, d]
"""
    P = V[:, :, None] @ K[:, None, :]     # P_i = V_i K_i^T, cost: O(t d^2)
    S = cumsum(P, axis=0)                 # cost: O(log_2(t) t d^2)
    Y = S @ Q[:, :, None]                 # cost: O(t d^2)
    return Y[:, :, 0]
The cost in FLOPs of this algorithm is \(O(\log_2(t) t d^2)\).3
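Since the formulations are mathematically equivalent, it is easy to check them against each other on random inputs. Here is a quick numpy sanity check of our own (toy shapes, no normalization) comparing the parallel attention and parallel state computations:

import numpy as np

rng = np.random.default_rng(0)
t, d = 16, 8
Q, K, V = rng.normal(size=(3, t, d))

# Attention formulation: masked Q K^T, then multiply by V.
Y_attention = np.tril(Q @ K.T) @ V

# State formulation: cumulative sum of outer products V_i K_i^T, applied to each Q_i.
P = V[:, :, None] @ K[:, None, :]        # [t, d, d], P[i] = V_i K_i^T
S = np.cumsum(P, axis=0)                 # S[i] = sum_{j <= i} V_j K_j^T
Y_state = (S @ Q[:, :, None])[:, :, 0]   # Y_i = S_i Q_i

assert np.allclose(Y_attention, Y_state)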
Now that we have four mathematically-equivalent implementations of a linear transformer, let’s time the training step of our GPT2 models and see how big of a difference it makes. For our LT_attention_parallel
implementation, we use a custom linear self-attention flash kernel we implemented in Triton [13] based on OpenAI’s FlashAttention2 implementation.
Here are some takeaways:
- As expected, the attention variants all have a quadratic asymptotic cost (slope of 2 on a log-log plot4). The state variants all have linear asymptotic cost (slope 1).5
- LT_state_parallel is an order-of-magnitude faster than LT_state. LT_attention_parallel_no_flash is two orders-of-magnitude faster than LT_attention. LT_attention_parallel seems to asymptotically stabilize at an order of magnitude faster than LT_attention_parallel_no_flash.
- For the majority of settings, LT_attention_parallel is the fastest. (This is the linear version of the algorithm used by the standard transformer.)
- Parallel attention is the fastest algorithm for small context sizes. However, LT_state_parallel overtakes LT_attention_parallel_no_flash at around 13k context size, and overtakes LT_attention_parallel at around 100k.
Overall, these results paint a clear picture: use the attention algorithm for small contexts and the state algorithm for large contexts. But do we really face a binary choice? Can we combine the state and attention ideas and get the best of both worlds?
3. Chunked Formulation
It’s evident that, for small context sizes, computing the \(t\) by \(t\) attention matrix is much more efficient than computing many \(d\) by \(d\) state matrices. But as \(t\) grows, there is a point where the quadratic cost of the attention matrix ends up dominating. Noticing that attention is extremely efficient for small \(t\) and that states are necessary for large \(t\) motivates doing one last reworking of the LT equation.
Let \(c \in \N\) be a positive integer that we’ll call the chunk size. For any \(i\in \N\) find the unique \(n\in \Z\) s.t. \(cn < i \le c(n+1)\). We can easily see that the following equations are equivalent to the previous ones. \[ Y_{i} = S_{cn}Q_i + \sum_{j=cn+1}^i Q_i^T K_j V_j \;\;\;\;\;\;\;\;\;\;\;\; S_{c(n+1)} = S_{cn} + \sum_{j=cn+1}^{c(n+1)} V_j K_j^T \] The key idea is that we are only going to compute a subset of all states: \(S_0, S_c, S_{2c}, \cdots\). Then, to compute each output \(Y_i\), we need only to take into account the contribution via the most recent state \(S_{cn}\), as well as the contribution (computed via attention) of all moments in time \(j\) in the range \(cn < j \le i\).
As pseudocode, this looks like:
def LT_attention_with_initial_state(S, Q, K, V):
"""
Shapes of inputs are
S: [d, d] Q: [c, d] K: [c, d] V: [c, d]
Shapes of outputs are
Y: [c, d]
"""
    Y_state = Q @ S                                    # cost O(c d^2)
    Y_attention = LT_attention_parallel(Q, K, V)       # cost O(c^2 d)
    Y = Y_state + Y_attention                          # cost O(c d)
    return Y
def LT_chunked(Q, K, V, c):
"""
Shapes of inputs are
Q: [t, d] K: [t, d] V: [t, d], c: int
Shapes of outputs are
Y: [t, d]
"""
    t, d = Q.shape
    assert t % c == 0
    Q_, K_, V_ = [arr.reshape(t//c, c, d) for arr in [Q, K, V]]
    P_ = K_.transpose([0, 2, 1]) @ V_     # per-chunk sum of K_j V_j^T, cost O(t d^2)
    S_ = cumsum(P_, axis=0) - P_          # state before each chunk, cost O(log_2(t/c)(t/c)d^2)
    Y_ = vmap(LT_attention_with_initial_state, axis=0)(
        S_, Q_, K_, V_)                   # cost O(td^2 + tcd)
    return Y_.reshape(t, d)
The cost is \(O\left(td^2 + tcd + \log_2(t/c)(t/c)d^2\right)\), once again avoiding a quadratic dependency on \(t\). Also, note that this algorithm makes an inner call to LT_attention_parallel
, so we can use a flash-attention kernel to do that part of the computation.
This algorithm has a hyperparameter, the chunk size, which must be set correctly for optimal performance. Fortunately, it is inexpensive to identify the best chunk size empirically, since measuring just a few steps of training at each chunk size is sufficient to identify which is the fastest. In the plot below, we plot the speed of the optimally-chunked algorithm in each setting.
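Such a search can be as simple as a timing loop over a handful of candidate sizes. Here is a sketch of what we mean; the train_step callable and the candidate sizes are illustrative, not our actual tuning harness:

import time

def pick_chunk_size(train_step, candidates=(16, 64, 256, 1024), n_steps=5):
    """Time a few training steps at each candidate chunk size and return the fastest.
    `train_step(c)` is assumed to run one full training step with chunk size c."""
    timings = {}
    for c in candidates:
        train_step(c)                              # warm-up (compilation, caches)
        start = time.perf_counter()
        for _ in range(n_steps):
            train_step(c)
        timings[c] = (time.perf_counter() - start) / n_steps
    return min(timings, key=timings.get)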
We see LT_chunked
gives the desired best-of-both-worlds behavior: it is equal to or better than all other approaches in all settings. And we now see the massive (& rapidly growing) speedup relative to standard self-attention, finally unlocking the true power of linear transformers.
4. Sampling
When working with language models, efficient training is not the only kind of performance that deserves consideration. Once the model is trained, how expensive is it to use? When we are sampling, we have a sequence of tokens \(z_1 \cdots z_t\), and we want to sample the next token, \(z_{t+1}\).
The most efficient algorithm to sample from traditional transformers is called the KV-cache algorithm [14]. This algorithm assumes that when we generate token \(z_{t+1}\), we have already computed and cached \(K_i, V_i\) for all \(1 \le i \le t\). In order to compute the output of the attention layer at time \(t+1\) given this cached information, we can use \[ Y_{t+1}^\text{Transformer} = \sum_{j=1}^{t+1} e^{Q^T_{t+1} K_j} V_j \] It is easy to see that this is an \(O(td)\) operation. In other words, as the sequence length grows, sampling each subsequent token becomes more computationally expensive.6 This is one of the major limitations of the classic transformer architecture.
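As a minimal numpy sketch of ours (one layer, unnormalized, with a cache-append convention of our own choosing), a single KV-cache decoding step looks like this:

import numpy as np

def kv_cache_step(q_new, k_new, v_new, K_cache, V_cache):
    """One decoding step of (unnormalized) attention with a KV cache.
    q_new, k_new, v_new: [d] vectors for the new token.
    K_cache, V_cache: [t, d] arrays of cached keys and values.
    Returns the new output [d] and the updated caches. Cost: O(t d)."""
    K_cache = np.concatenate([K_cache, k_new[None]], axis=0)   # now [t+1, d]
    V_cache = np.concatenate([V_cache, v_new[None]], axis=0)
    weights = np.exp(K_cache @ q_new)                          # exp(Q_{t+1} . K_j), shape [t+1]
    y = weights @ V_cache                                      # sum_j exp(...) V_j
    return y, K_cache, V_cache

Each step touches the entire cache, which is where the \(O(td)\) per-token cost comes from.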
With linear transformers, however, we have access to the recurrent formulation. A linear transformer is an RNN where the state has size \(O(d^2)\). \[ Y_{i} = S_i Q_i \;\;\;\;\;\;\;\;\;\;\;\; S_i = S_{i-1} + V_i K_i^T \] We can compare the time it takes to generate any particular token when sampling a sequence:
As expected, we see that the per-token sampling time of the linear transformer is independent of context length, whereas that of the KV-cache transformer grows linearly. This leads to a large gap in inference speed at large contexts.7
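Concretely, a linear-transformer decoding step only has to update and apply the \(d \times d\) state; here is a minimal numpy sketch of ours:

import numpy as np

def linear_transformer_step(q_new, k_new, v_new, S):
    """One decoding step of a linear transformer layer, carrying the [d, d] state S.
    Cost is O(d^2) regardless of how many tokens came before."""
    S = S + np.outer(v_new, k_new)    # S_i = S_{i-1} + V_i K_i^T
    y = S @ q_new                     # Y_i = S_i Q_i
    return y, S

No matter how many tokens precede it, the work per sampled token stays \(O(d^2)\).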
5. Learning Performance
Until now, our focus has been on how to implement linear transformers efficiently. But an architecture that runs really fast is useless unless it also learns well. We will now compare the learning curves of the GPT2 baselines with their linear transformer counterparts, at various context sizes.
In order to control for the fact that longer contexts help learning by introducing more tokens per update, we hold the number of tokens-per-update constant across context sizes by decreasing batch size. At all context-lengths, each update sees \(2^{19}\) tokens.8 Importantly, for this set of experiments, we have used the dataset c4 [4], which does not have much long-term structure: it consists of a shuffled collection of short articles, of average length ~500 tokens.
First, let’s look at the parameter scaling law of GPT2 and our modified Linear-GPT2. The following plot shows the final performance after 24h of training on our 8xH100 server for each scale and each context size. Click the second tab if you want to see the full learning curves.
Both architectures scale similarly with respect to model size, but there is a gap in performance. This gap seems to grow as context size increases, together with more loss spikes and general instability. It’s possible that the gap can be explained by the linear transformer not effectively utilizing its context. To explore whether or not this is the case, we can use these same experiments, but visualize all context-lengths together.
We see a dramatically different trend between the two architectures. For GPT2, each increase in context length slows down the initial pace of learning, but ultimately all context-lengths (beyond at least 1024) converge to a similar final loss. For Linear-GPT2, in contrast, increasing the context not only slows down learning in the beginning, but also causes convergence to a worse final performance, along with nasty-looking loss spikes.
The results for GPT2 are what we would expect from a healthy model, given the setting. Note that since our dataset consists of short articles, of average length ~500 tokens, we can assume that even a context window of 1024 would include nearly everything relevant. Thus, we wouldn’t expect that increasing the context length beyond this point would decrease the loss the model ultimately converges to. Longer-context models do however need to learn to ignore many more irrelevant tokens, explaining the slowed initial learning.9
In contrast, the results for Linear-GPT2 seem to indicate some underlying instability of the linear transformer architecture. It doesn’t even seem capable of ignoring irrelevant information, let alone exploiting useful long-term structure.
Remedying these learning issues will be the main focus of our future work in linear transformers. Our current architecture is essentially a one-to-one clone of GPT2, sticking as architecturally close to the original as possible; aspects such as initializations, normalizations and hyperparameters were directly copied. It is well-known that these decisions can have a large impact on scaling and stability and often need to be tuned in an architecture-specific way. In the literature, various other linear variants of self-attention have been proposed that incorporate techniques such as gating mechanisms in order to improve stability and learning [5]–[11]. A future post will include a thorough study of the impact of all of these choices.
Ultimately, it may be impossible for any linear transformer to perfectly match the context scaling laws of the classic transformer baseline. Although removing the exponential seems like a minor change, it represents a meaningful decrease in expressivity.10 But as we’ve shown, linear transformers can be trained orders-of-magnitude more efficiently. On equivalent compute, linear transformers can be trained for many more steps; or, keeping steps constant as well, we can train a much larger model. If this additional scaling is sufficient to compensate for the reduced expressivity, linear transformers will be the more efficient architecture overall.
Acknowledgments
We would like to thank: Jono Ridgway for helping to prepare the release; Eric Alcaide for introducing us to the associative scan algorithm; Jannis Fengler for working on the tooling to confirm our JAX GPT2 numerically replicates NanoGPT; Joel Einbinder and Aaron Mondal for their assistance in setting up our engineering stack; Warfa Jibril, Tony Pezzullo, and Desh Raj for feedback on a draft of the blog post; Sander Dieleman for pointing us towards some early linear transformer literature. …
References
Footnotes
As discussed in Section 4, a second benefit of linear transformers is that the cost to sample a token does not grow with context size. Perhaps one could argue that this improvement in sampling speed could, on its own, justify using linear transformers for applications where the inference costs vastly exceed training costs. But it is evident to us that, for linear transformers to become actually useful, we need to address these instability issues.↩︎
It is not named for the fact that the computational cost is linear with respect to \(t\)! That is just a coincidence. (And is not even always true, as we will see.)↩︎
Interestingly, LT_state_parallel is actually more expensive in FLOPs than LT_state. (This is in contrast with the attention formulation, where LT_attention and LT_attention_parallel share the same \(O(t^2 d)\) cost.) As we will see in a moment, this extra \(\log_2(t)\) factor is well worth the parallelization benefits.↩︎
If \(y = x^2\), then on a log-log plot, where \(y' = \log_a(y)\) and \(x' = \log_a(x)\) for any base \(a\), we have \(y' = \log_a(x^2) = 2\log_a(x) = 2x'\). So the graph will be a line with slope 2.↩︎
The reason we see the expected slopes asymptotically is that we are timing a full GPT2 architecture, which has many other components besides the attention layer. If we were only timing the attention layer, the plots would all be straight lines.↩︎
An interesting connection is that the KV-cache can be understood as the state of an RNN with non-constant state size; namely, one whose state-size is \(O(td)\).↩︎
This comparison may not be completely fair. In these experiments, our implementation of neither sampling algorithm makes use of specialized kernels. A lot of the ideas of flash attention can be used to write a much faster KV cache sampling algorithm; on the other hand, it’s unclear if much improvement is possible on the recurrent sampling. Thus, it’s possible that with engineering effort the gap between the two algorithms could become smaller. However, the overall pattern will certainly remain the same.↩︎
e.g. runs with context-size 1024 would have batch-size of \(2^{19} / 2^{10} = 2^{9} = 512\).↩︎
Put another way: doubling the size of the input vastly increases the size of the function space over which gradient descent must search, and it’s intuitive that in a larger space it takes somewhat longer to find a good solution.↩︎
We plan to elaborate on this topic in a future blog post.↩︎
Citation
@misc{buckman2024,
author = {Buckman, Jacob and Gelada, Carles},
publisher = {Manifest AI},
title = {Linear {Transformers} {Are} {Faster} {After} {All}},
date = {2024-01-05},
langid = {en}
}