Symmetric Power Transformers

Authors

Jacob Buckman

Carles Gelada

Sean Zhang

Published

August 15, 2024

Linear transformers [1] can be formulated as linear-cost RNNs, which have better theoretical context scaling than ordinary transformers. In our previous article [2], we presented an efficient chunked algorithm that turns this theoretical advantage into practical speedups when the context is long: 10x faster training for a 64k-token context. Unfortunately, we also found that vanilla linear transformers suffer from degraded performance, especially at long contexts, rendering any benefits from the speedup useless. This article advances our previous discussion by introducing a linear transformer variant that solves the degraded performance issue while still enabling an efficient linear-cost implementation.

Behind all the algorithms explored in this post there is a central idea: for RNNs, thinking about the size of the model purely in terms of the number of parameters misses something important. An RNN encodes all the information from the past inputs \(X_1, \ldots, X_t\) into a finite-dimensional vector \(S_t\) called the state. If the states are too small, the model will struggle to store all the information it will later require. Could this be the cause of the poor performance of the GPT-2-style linear transformers we evaluated in our previous article? If we look at the state sizes, we notice they are two to three orders of magnitude smaller than the weights of the architecture:

A generic RNN equation would compute next states via \(S_t = g(S_{t-1}, X_t)\) and produce outputs via \(Y_t = f(S_t, X_t)\).

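| Model | Weights | State Size | State-Weight Ratio |
|--------|---------|------------|--------------------|
| Small | 124M | 589K | 0.0048 |
| Medium | 335M | 1.5M | 0.0045 |
| Large | 774M | 2.9M | 0.0037 |
| XL | 1.6B | 4.9M | 0.0031 |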

The formula for the state size of a vanilla linear transformer is layer_n * head_count * key_size * value_size.
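
The formula above is easy to evaluate by hand. The following is a minimal sketch, assuming the commonly published GPT-2 layer and head counts (these are not stated in this article) and a key and value size of 64 per head. The helper name `linear_transformer_state_size` is hypothetical.

```python
# Hypothetical helper illustrating the state-size formula from the text.
def linear_transformer_state_size(layer_n, head_count, key_size, value_size):
    """Number of entries in the state of a vanilla linear transformer."""
    return layer_n * head_count * key_size * value_size

# (name, approximate weight count, layers, heads, per-head key/value size)
# The layer and head counts are assumed GPT-2 configurations.
configs = [
    ("Small", 124e6, 12, 12, 64),
    ("Medium", 335e6, 24, 16, 64),
    ("Large", 774e6, 36, 20, 64),
    ("XL", 1.6e9, 48, 25, 64),
]

for name, weights, layers, heads, head_dim in configs:
    state = linear_transformer_state_size(layers, heads, head_dim, head_dim)
    print(f"{name:6s} state={state:>9,d} ratio={state / weights:.4f}")
```

Under these assumptions the state sizes come out at roughly 589K, 1.5M, 2.9M, and 4.9M entries. The ratios can differ from the table in the last digit, depending on how the inputs are rounded.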

Fortunately, since the architecture is a linear transformer, this imbalance has a straightforward remedy. The size of the state of a linear transformer can be controlled by simply embedding the keys and queries in a higher-dimensional space. (The larger the space in which we embed the key, the larger the state becomes). Previous work [3][10] has already observed that this improves the performance of linear transformers, but the resulting architectures are still not competitive with standard transformers.

In this article we introduce symmetric power transformers, which are a variant of linear transformers with an embedding function based on the theory of symmetric tensors. They have a hyperparameter \(p\) that controls the state size. For \(p=4\) and above, they outperform the transformer baseline, and at \(p=4\) and below, they have a state size small enough to fit on a modern GPU. A second advantage is that, unlike other variants of linear transformers, one can combine symmetric power transformers with the commonly used rotary embeddings [11]. Below, we see a performance comparison (full experimental details are given in Section 2).

In this article, our experiments used the attention formulation of linear transformers (with \(O(t^2)\) cost) rather than the more efficient chunked formulation (with \(O(t)\) cost). This allowed us to validate the learning ability of the architecture without writing custom CUDA kernels (which an efficient implementation of chunked symmetric power transformers requires). We will release the efficient chunked implementation in an upcoming article.

1. Linear Transformers with Embeddings

We begin with a review of linear transformers. The inputs to a transformer layer are sequences \(Q_i, K_i, V_i \in \R^d\) of queries, keys, and values, where \(i\) ranges from \(1\) to the sequence length \(t\). The outputs are a sequence \(Y_i\in \R^d\). The formula for the output vectors is: \[ Y_i = \sum_{j=1}^i A_{ij} V_j \qquad A_{ij} = \frac{ \phi(Q_i)^T \phi(K_j)}{\sum_{k=1}^i \phi(Q_i)^T \phi(K_k) } \] where \(\phi : \R^d \to \R^D\) is an embedding function that maps keys or queries into vectors of dimension \(D\). This formulation is what we call the attention formulation of a linear transformer, because it involves explicitly computing the attention scores used to weight the values \(V_j\).

This linear transformer is the same architecture as in our previous article, but here we present the complete formula in its full generality, including both the normalizing term and embedding function (previously suppressed for clarity).

The exact same outputs can be computed via a recurrent formulation: \[ Y_{i} = \frac{S_i \phi(Q_i)}{Z_i \phi(Q_i)} \qquad Z_i = Z_{i-1} + \phi(K_i)^T \qquad S_i = S_{i-1} + V_i \phi(K_i)^T \] where \(Z_0\) and \(S_0\) are the zero elements of their respective spaces. Since \(S_i \in \R^{d \times D}\) and \(Z_i \in \R^{D}\), the size of the state is \(D(d+1)\).

Note that \(D(d+1)\) gives the size of the state for any one linear transformer head. To compute the state size for an entire multi-head architecture, one must multiply by the number of layers and number of heads per layer.

Expand for a derivation of the recurrent formulation. If we expand the recurrent definitions of \(S_i\) and \(Z_i\) we get that \[ Z_i = \sum_{j=1}^i \phi(K_j)^T \qquad S_i = \sum_{j=1}^i V_j \phi(K_j)^T \] Then, starting with the attention formulation \[ \begin{aligned} Y_i &= \sum_{j=1}^i \frac{ \phi(Q_i)^T \phi(K_j)}{\sum_{k=1}^i \phi(Q_i)^T \phi(K_k) } V_j \\ &= \sum_{j=1}^i V_j \frac{ \phi(K_j)^T \phi(Q_i)}{\sum_{k=1}^i \phi(K_k)^T \phi(Q_i) } \\ &= \frac{ \left( \sum_{j=1}^i V_j \phi(K_j)^T \right) \phi(Q_i)}{ \left(\sum_{m=1}^i \phi(K_m)^T \right) \phi(Q_i) } \\ &= \frac{S_i\phi(Q_i)}{Z_i\phi(Q_i)} \end{aligned} \]
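
The equivalence of the two formulations can also be checked numerically. The following is a minimal NumPy sketch, not the implementation used in our experiments. The embedding `phi` is a placeholder chosen for illustration (the flattened outer product of a vector with itself), and the function names are hypothetical.

```python
import numpy as np

def phi(x):
    # Placeholder embedding (an assumption of this sketch): flattened outer
    # product, so that phi(q) @ phi(k) == (q @ k) ** 2, which is non-negative.
    return np.outer(x, x).ravel()

def attention_form(Q, K, V):
    # Y_i = sum_j A_ij V_j with explicitly normalized attention scores.
    Y = np.zeros_like(V)
    for i in range(Q.shape[0]):
        scores = np.array([phi(Q[i]) @ phi(K[j]) for j in range(i + 1)])
        Y[i] = (scores / scores.sum()) @ V[: i + 1]
    return Y

def recurrent_form(Q, K, V):
    # Y_i = S_i phi(Q_i) / (Z_i phi(Q_i)), with S_0 = 0 and Z_0 = 0.
    t, d = V.shape
    D = phi(K[0]).size
    S, Z = np.zeros((d, D)), np.zeros(D)
    Y = np.zeros_like(V)
    for i in range(t):
        S = S + np.outer(V[i], phi(K[i]))
        Z = Z + phi(K[i])
        Y[i] = (S @ phi(Q[i])) / (Z @ phi(Q[i]))
    return Y

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 6, 4))  # three arrays of shape [t=6, d=4]
assert np.allclose(attention_form(Q, K, V), recurrent_form(Q, K, V))
```

Note that the recurrent loop only ever touches `S` and `Z`, whose combined size is \(D(d+1)\), regardless of the sequence length.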

These two forms give rise to a variety of algorithms for training linear transformers, with differing computational properties. Read our earlier article on linear transformers for a detailed explanation. In particular, there are two algorithms that are relevant to our present discussion: parallel attention and chunked. The parallel attention algorithm is the standard algorithm used to train transformers. It does not require using the state, but its cost is \(O(t^2d)\), where \(t\) is the sequence length. The chunked algorithm can be used to train only linear transformers. It has a cost of \(O(tdD)\), and it requires materializing the state. The chunked algorithm is primarily what makes linear transformers interesting, because it is what allows them to be trained much more efficiently than softmax transformers when \(t\) is large.

Materializing an object refers to storing it in the GPU main memory. Sometimes, a mathematical object is necessary to perform a computation, and yet one can avoid having to store the whole thing in RAM. The most prominent example of this is Flash Attention [12], which avoids having to materialize the [t,t] attention matrix.

With these computational considerations in mind, how should we choose \(\phi\)? Here are some attributes that we want:

  • Adjustable dimensionality. To balance the size of the state with the size of the weights, there should be some hyperparameter controlling the dimension \(D\).
  • Efficient dot product. In certain parts of the algorithm, \(\phi(Q_i)\) and \(\phi(K_j)\) appear as intermediate steps in the computation of \(\phi(Q_i)^T\phi(K_j)\). For some choices of \(\phi\), there is a more efficient formula for \(\phi(Q_i)^T\phi(K_j)\) that does not require computing these intermediate objects.1
  • Positive dot product. We want \(\phi(Q_i)^T \phi(K_j)\) to always be positive. This ensures that each normalized output \(Y_i\) is a convex combination of all preceding values \(V_1, \cdots, V_i\). We found this to be essential for stable and performant learning.
  • Compatible with rotary positional encoding. Rotary encodings take \(Q_i \to R Q_i\) and \(K_j \to R K_j\). To preserve their translational symmetry, we want \(\phi(R Q_i)^T \phi (R K_j) = \phi(Q_i)^T \phi (K_j)\).

Subject to these four constraints, we are searching for the embedding function with the best empirical performance at each state size. Since our ultimate goal is to replace transformers (trained with attention) with linear transformers (trained with the chunked algorithm), the bar for success is simple: an embedding whose performance matches that of a transformer baseline at a state size small enough to be tractable. We’ll use 80 GB as our limit for the state size because that is the entire memory capacity of A100 and H100 GPUs.2

Many possible embedding functions have already been investigated in the literature [3][6], but we are not aware of one that satisfies all of our requirements. In the next section, we’ll describe a variant of attention whose empirical performance is competitive with softmax attention. In the sections that follow, we will show that this can be implemented via a linear transformer that satisfies our desiderata.

2. Experiments

At first, it might not be evident that the following attention layer corresponds to a linear transformer, but it should be clear that it is a reasonable modification to the standard softmax transformer. Let \(p\in\N\) be even, \[ Y_i = \sum_{j=1}^i A_{ij} V_j \qquad A_{ij} = \frac{ (Q_i^T K_j)^p}{\sum_{k=1}^i (Q_i^T K_k)^p } \] A classic softmax transformer would guarantee that all the attention scores \(A_{i1}, \cdots, A_{ii}\) are positive and sum to one by applying the exponential \(e^{Q_i^T K_j}\). Instead, this variant uses \((Q_i^T K_j)^p\) to achieve the same result. Raising each inner product to an even power makes the term positive, and dividing by the sum ensures each row of the attention is a distribution. We call this variant even-power attention. It has been studied before in the literature [5].

Like softmax attention, even-power attention is compatible with rotary embeddings. A key motivation for using rotary embeddings is that they are relative, meaning that only the difference in time \(i-j\) influences the attention score \(A_{ij}\). See [11] for the full details, but in short, the relative property of rotary embeddings is guaranteed when the attention scores \(A_{ij}\) are unchanged by the rotation of \(Q_i\) and \(K_j\) by the same rotation matrix \(R \in \R^{d \times d}\). This holds for even-power attention: \[ \left( (R Q_i)^T R K_j \right)^p = \left( Q_i^T R^T R K_j \right)^p = \left( Q_i^T K_j \right)^p \] Many other variants of linear transformers do not have this property [1].
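
To make the relative property concrete, here is a small pure-Python sketch. It is a simplification of rotary embeddings: every coordinate pair is rotated with the same frequency `omega`, whereas real rotary embeddings use a different frequency per pair. The helper names are hypothetical.

```python
import math

def rotate_pairs(x, theta):
    """Rotate consecutive coordinate pairs (x[0], x[1]), (x[2], x[3]), ... by theta."""
    out = []
    for a, b in zip(x[0::2], x[1::2]):
        out.append(a * math.cos(theta) - b * math.sin(theta))
        out.append(a * math.sin(theta) + b * math.cos(theta))
    return out

def even_power_score(q, k, p):
    """Unnormalized even-power attention score (q^T k)^p."""
    return sum(qi * ki for qi, ki in zip(q, k)) ** p

q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]
p, omega = 4, 0.3

# Rotating q and k by the same rotation leaves the score unchanged.
same = even_power_score(rotate_pairs(q, 0.7), rotate_pairs(k, 0.7), p)
assert math.isclose(same, even_power_score(q, k, p), rel_tol=1e-9)

# Position-dependent rotations: the score depends only on the offset i - j.
score_5_3 = even_power_score(rotate_pairs(q, 5 * omega), rotate_pairs(k, 3 * omega), p)
score_12_10 = even_power_score(rotate_pairs(q, 12 * omega), rotate_pairs(k, 10 * omega), p)
assert math.isclose(score_5_3, score_12_10, rel_tol=1e-9)
```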

Here is a simple JAX implementation of even-power attention:

from jax.numpy import ones, tril, where

def even_power_attention(Q, K, V, p):
    # even only
    assert p % 2 == 0
    # compute inner products
    C = Q @ K.T
    # raise to power
    B = C**p
    # apply causal mask
    B = where(tril(ones(B.shape)), B, 0)
    # project to simplex
    A = B / B.sum(-1, keepdims=True)
    # compute output
    Y = A @ V
    return Y

This implementation turns out to be more pedagogical than practical, because numerical stability is an important empirical consideration. Expand below for an implementation that addresses these issues.

Numerically-stable implementation of power attention.

Numerical instabilities come from numbers underflowing (too small) or overflowing (too large). Solutions typically fall into a few main categories:

  • Make sure a number is not too small or too large, e.g. turn log(x) into log(x + ε). This prevents log(0), which evaluates to -inf, when x underflows to zero.
  • Accumulate in fp32. When accumulating a long list of small values in half-precision, it is sometimes the case that each addition will underflow and no accumulation will occur at all.
  • Manipulate an equation to cancel out some common factors algebraically, rather than letting them cancel out computationally. For example, to calculate \(\frac{x}{y}\) where \(x = Am\) and \(y = An\) for some large \(A\), compute m / n instead of x / y.
  • Separate magnitude and sign, and work with magnitude in log-space, i.e. manipulate sign(x) and log(abs(x)) instead of working with x directly. Convert back to linear-space with sign(x) * exp(f(log(abs(x)))), where f has internally completed the relevant cancellations, so f(log(abs(x))) is small enough to avoid overflow.
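
As a small illustration of the last two techniques, the following pure-Python sketch computes \((x/y)^p\) for inputs whose individual powers do not fit in float64. The function name `stable_power_ratio` is hypothetical, and the sketch assumes `x` and `y` are nonzero and `p` is an integer.

```python
import math

def stable_power_ratio(x, y, p):
    """Compute (x / y) ** p without ever forming x ** p or y ** p."""
    # Track the sign separately; the magnitude lives in log-space.
    sign = (math.copysign(1.0, x) * math.copysign(1.0, y)) ** p
    # The large common factor cancels in the subtraction of logs.
    log_magnitude = p * (math.log(abs(x)) - math.log(abs(y)))
    return sign * math.exp(log_magnitude)

x, y, p = 3e200, 2e200, 4

# Naive evaluation: x ** p is far outside the float64 range.
try:
    naive = x ** p / y ** p
except OverflowError:
    naive = None  # Python could not represent x ** p as a float

stable = stable_power_ratio(x, y, p)
print(naive, stable)  # the stable value is close to 1.5 ** 4
```

The same pattern appears in attention: the scores are kept as logs, a common factor (the row maximum) is subtracted, and only then is the result exponentiated.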

With these techniques in mind, we can implement a numerically-stable version of the attention algorithm. This is the code used to generate the experimental results in this article.

from jax.numpy import abs, exp, float32, inf, log, ones, tril, where

def even_power_attn(Q, K, V, p, ε):
    # even only
    assert p % 2 == 0

    # compute inner products
    D = Q @ K.T

    # raise to power, in log space for numerical stability
    log_C = p * log(abs(D) + ε)
    # apply causal mask
    log_C = where(tril(ones(log_C.shape)), log_C, -inf)
    # subtract rowmax for numerical stability
    log_C -= log_C.max(axis=-1, keepdims=True)

    # Return to linear space
    B = exp(log_C)
    # Compute the normalizing term, accumulating in float32 for numerical stability
    denom = B.sum(-1, keepdims=True, dtype=float32).astype(B.dtype)

    # project to simplex, adding ε for numerical stability
    A = B / (denom + ε)
    # compute output
    Y = A @ V
    return Y

Now we are ready to train some models. For this experiment, we used the LongCrawl64 dataset [13], a context length of 4096, and a batch size of 524288 tokens. The architecture was similar to the 124M-parameter GPT-2 architecture, but with rotary positional encoding and an additional layernorm after input embeddings. The optimization was conducted in bf16 mixed-precision using Adam with learning rate 0.0006 and no scheduling. Each model was trained on a node of 8 H100s.

Performance seems to improve consistently as we increase \(p\). For large enough \(p\), the performance of even-power transformers matches or even surpasses that of the softmax transformer baseline. This architecture is looking promising!

In Section 3, we will show how to use the tensor product to implement even-power attention as a linear transformer, albeit one with a state size so large as to be impractical. In Section 4, we will see that the tensor product embedding is highly symmetric and contains a lot of redundant information. We will exploit this structure to construct an embedding function for even-power attention with tractable state size.

3. Tensor Product

In this section, we show how the tensor product, a deep and ubiquitous mathematical idea, can be used to construct embeddings for the even-power transformer.

3.1. Mathematical Background

The tensor product of vectors generalizes the outer product and formalizes the concept of multi-dimensional arrays. Given two vectors \(v\in \R^{d_1}\) and \(w\in \R^{d_2}\) one can think of their tensor product \(v\otimes w\) as the matrix \(v w^T \in \R^{d_1\times d_2}\), \[ v w^T = \left[ \begin{array}{cccc} v_1w_1 & v_1w_2 & \cdots & v_1w_{d_2} \\ v_2w_1 & v_2w_2 & \cdots & v_2w_{d_2} \\ \vdots & \vdots & \ddots & \vdots \\ v_{d_1}w_1 & v_{d_1}w_2 & \cdots & v_{d_1}w_{d_2} \\ \end{array} \right] \] Intuitively, \(v \otimes w \otimes u\) would be a 3-dimensional table containing all possible entries of the sort \(v_i w_j u_k\). But let’s make the intuition of multi-dimensional tables more rigorous.

Multi-indices. A multi-index \(\alpha\) specifies a location in a \(p\) dimensional table with dimension sizes \(d_1, \cdots, d_p \in \N\). Let \(\N_d\) denote the set \(\{1,2,\cdots,d\}\). Then, the space of multi-indices is \(\N_{d_1} \times \cdots \times \N_{d_p}\) and we refer to a generic multi-index as \(\alpha = [\alpha_1, \cdots, \alpha_p] \in \N_{d_1} \times \cdots \times \N_{d_p}\).

Tensors. A tensor \(T\in \R^{d_1 \times \cdots \times d_p}\) corresponds to a high-dimensional table where every location has a numerical value assigned to it. In other words, it is a map from multi-indices to the reals. \[ T: \N_{d_1} \times \cdots \times \N_{d_p} \to \R \] By convention, we index tensors with subscript notation \(T_\alpha\) instead of functional notation \(T(\alpha)\), but the meaning is the same.

Tensor product of vectors. Given a list of \(p\) vectors \(v_i \in \R^{n_i}\), we denote by \(v_1 \otimes \cdots \otimes v_p\) (or alternatively \(\bigotimes_{i=1}^p v_i\)) the tensor in \(\R^{n_1 \times \cdots \times n_p}\) with entries given by the following formula: \[ \left[\bigotimes_{i=1}^p v_i\right]_\alpha = \prod_{i=1}^p v_{i, \alpha_i} \] where \(v_{i,j}\in \R\) denotes the \(j\)th entry of the \(i\)th vector.

Flattening. To build embedding functions \(\phi\), we are going to use the tensor product to embed lists of vectors into \(\R^{d_1\times \cdots \times d_p}\) tensors. But once we’ve done that, we will no longer care about the tensor structure and we will prefer to think of them as vectors in \(\R^D\), where \(D=\prod_{i=1}^p d_i\). The map \(\text{flat}: \R^{d_1\times \cdots \times d_p} \to \R^D\) implements this transformation by writing every entry of the array into a flat vector. This can be done with any bijective function \(\sigma: \N_D \to \N_{d_1} \times \cdots \times \N_{d_p}\) which effectively imposes an (arbitrary) ordering on the multi-indices.3 The flattening is defined as: \[ \flat{T}_i = T_{\sigma(i)} \] The dot product of flattened tensors satisfies the following property: \[ \flat{\bigotimes_{i=1}^p v_i}^T \flat{\bigotimes_{i=1}^p w_i} = \prod_{i=1}^p v_i^T w_i \qquad \text{(Result 1)} \]

Expand to see a proof. We can just check that both sides of the equation match. First, \[\begin{align} \flat{\bigotimes_{i=1}^p v_i}^T \flat{\bigotimes_{i=1}^p w_i} &= \sum_{l=1}^D \left[ \bigotimes_{i=1}^p v_i \right]_{\sigma(l)} \left [ \bigotimes_{i=1}^p w_i \right]_{\sigma(l)} \\ &= \sum_{l=1}^D \prod_{i=1}^p v_{i, \sigma(l)_i} w_{i, \sigma(l)_i} \\ &= \sum_{j_1=1}^{d_1} \cdots \sum_{j_p=1}^{d_p} \prod_{i=1}^p v_{i, j_i} w_{i, j_i} \\ \end{align}\] where we used the fact that \(\sigma(l)\) ranges over every possible combination of \([j_1, \cdots, j_p]\). On the other hand, \[\begin{align} \prod_{i=1}^p v_i^T w_i &= \prod_{i=1}^p \sum_{j=1}^{d_i} v_{i, j} w_{i, j} \\ &= \sum_{j_1=1}^{d_1} \cdots \sum_{j_p=1}^{d_p} \prod_{i=1}^p v_{i, j_i} w_{i, j_i} \end{align}\] where the last step used a generalization of the distributive property: \(\prod_{i=1}^p \sum_{j=1}^{d_i} v_{i, j} = \sum_{j_1=1}^{d_1} \cdots \sum_{j_p=1}^{d_p} \prod_{i=1}^p v_{i, j_i}\).
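
Result 1 can also be checked on a small example with vectors of different sizes. The following pure-Python sketch represents a tensor as a dictionary from multi-indices to values, mirroring the definitions above. It uses 0-based indices (the text uses 1-based ones) and takes lexicographic order as the bijection \(\sigma\). All function names are hypothetical.

```python
from itertools import product
from math import isclose, prod

def tensor_product(vectors):
    """Map each multi-index alpha to prod_i vectors[i][alpha_i]."""
    ranges = [range(len(v)) for v in vectors]
    return {
        alpha: prod(v[a] for v, a in zip(vectors, alpha))
        for alpha in product(*ranges)
    }

def flat(tensor):
    """Flatten a tensor, ordering multi-indices lexicographically."""
    return [tensor[alpha] for alpha in sorted(tensor)]

def dot(x, y):
    return sum(a * b for a, b in zip(x, y))

# Three vectors of sizes 2, 3, and 2, so D = 2 * 3 * 2 = 12.
vs = [[1.0, 2.0], [0.5, -1.0, 3.0], [2.0, 1.0]]
ws = [[3.0, 0.5], [2.0, 1.0, 0.25], [1.5, -2.0]]

lhs = dot(flat(tensor_product(vs)), flat(tensor_product(ws)))
rhs = prod(dot(v, w) for v, w in zip(vs, ws))
assert isclose(lhs, rhs)
```

Any other ordering of the multi-indices would give the same dot product, as long as the same ordering is used for both tensors.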

3.2. Implementation

Armed with the tensor product, we are ready to define an embedding \(\phi^p_{\text{TP}}\), and in doing so define a linear transformer architecture. The definition is simple: embed a key by taking its tensor product with itself, \(p\) times.

In this section, we focus on the effect of \(\phi\) on keys \(k\), but without loss of generality the whole discussion applies equally to queries.

\[ \phi^p_{\text{TP}}(k) = \text{flat}\left(\bigotimes_{i=1}^p k\right) \in \mathbb{R}^{d^p} \]

If this embedding is used with even \(p\), the resulting architecture is an even-power transformer:

\[ \phi^p_{\text{TP}}(q)^T \phi^p_{\text{TP}}(k) = \text{flat}\left(\bigotimes_{i=1}^p q\right)^T \text{flat}\left(\bigotimes_{i=1}^p k\right) = \prod_{i=1}^p q^T k = (q^T k)^p \]

where in the second step we used Result 1.

The implementation is straightforward.

import numpy as np

def tensor_power_embedding(k, p):
    expanded_k = k
    for _ in range(p-1):
        # append one more tensor factor of k
        expanded_k = expanded_k[...,None] @ k[None,:]
    return expanded_k.flatten()

def even_power(q, k, p):
  return np.inner(q, k) ** p

def tensor_power_inner_product(q, k, p):
  embedded_q = tensor_power_embedding(q, p)
  embedded_k = tensor_power_embedding(k, p)
  return (embedded_q * embedded_k).sum()

d = 8
for p in [2, 4, 6, 8]:
  q = np.random.random(d)
  k = np.random.random(d)
  assert np.allclose(
    even_power(q, k, p), 
    tensor_power_inner_product(q, k, p)
  )

Embedding in hand, we can return to our main objective. Have we found a linear transformer whose performance is competitive with that of a strong transformer baseline, while, at the same time, having a state size small enough to fit on a GPU?

The table below shows the size of a single state, as measured in bytes (assuming fp16/bf16 precision), for a 124M-parameter GPT-2 tensor power transformer at various \(p\).

The formula for the state size of a linear transformer with the tensor power embedding is layer_n * head_count * key_size**p * value_size.

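| p | State Size | Memory ≤ 80 GB? | Relative Loss at 100K Steps | Loss ≤ baseline? |
|---|------------|-----------------|-----------------------------|------------------|
| 2 | 77 MB | Yes | 1.03x | No |
| 4 | 314 GB | No | 0.98x | Yes |
| 6 | 1.3 PB | No | 0.97x | Yes |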
The settings of \(p\) that improve upon the baseline have states that are far too large. So the embedding \(\phi^p_{\text{TP}}\) still does not satisfy all the properties we are after. But we are close.
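
For readers who want to reproduce the order of magnitude of these state sizes, here is a rough sketch. It assumes a configuration with 12 layers, 12 heads, and key and value size 64 (not stated explicitly in this article), 2 bytes per entry, and a per-head state of \(D(d+1)\) entries that includes the normalizer \(Z_i\) from Section 1. The function name is hypothetical.

```python
def tensor_power_state_bytes(p, layer_n=12, head_count=12, key_size=64,
                             value_size=64, bytes_per_entry=2):
    """Approximate state memory for the tensor power embedding."""
    D = key_size ** p                      # embedding dimension d^p
    per_head = D * (value_size + 1)        # S_i has D * d entries, Z_i has D
    return layer_n * head_count * per_head * bytes_per_entry

GPU_MEMORY_BYTES = 80e9  # the 80 GB limit used in the text

for p in [2, 4, 6]:
    size = tensor_power_state_bytes(p)
    print(f"p={p}: {size:.3e} bytes, fits in 80 GB: {size <= GPU_MEMORY_BYTES}")
```

Under these assumptions, only \(p=2\) stays under the 80 GB limit. Each increase of \(p\) by 2 multiplies the state size by \(64^2 = 4096\).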

4. Symmetric Power Transformers

The missing piece is to realize the huge embeddings we’ve been working with are highly symmetric. The theory of symmetric powers will help us compress the same information into much smaller objects. We will begin with an introduction to the relevant mathematical ideas. Then, we will put them to use by proposing symmetric power transformers, whose embedding is \(\phi^p_{\text{SYM}}\).

To build intuition, observe that the embedding \(\phi^2_{\text{TP}}(v)= \flat {v v^T}\) is somewhat wasteful. The matrix \(v v^T\) is symmetric, so all the information we need can be found in the upper triangular part of the matrix. \[ v v^T = \left[ \begin{array}{cccc} v_1v_1 & v_1v_2 & \cdots & v_1v_d \\ v_2v_1 & v_2v_2 & \cdots & v_2v_d \\ \vdots & \vdots & \ddots & \vdots \\ v_dv_1 & v_dv_2 & \cdots & v_dv_d \\ \end{array} \right] \] Entries at indices \((i,i)\) appear a single time, but due to the commutativity of scalar multiplication (i.e. \(v_i v_j = v_j v_i\)), the entries at indices \((i,j)\) each appear twice (if \(i\neq j\)).

Noticing this symmetry in the matrix \(v v^T\) allows us to create an alternative embedding, \(\phi^2_\text{SYM}: \R^d \to \R^{\frac{d^2 +d} 2}\), which can be implemented as:

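import numpy as np

def sym_2_embedding(v):
  x, d = [], v.size
  # visit only the upper triangle of v v^T
  for i in range(d):
    for j in range(i, d):
      # off-diagonal entries appear twice in v v^T
      count = 1 if i==j else 2
      x.append(np.sqrt(count) * v[i] * v[j])
  return np.array(x)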

This construction of \(\phi^2_\text{SYM}\) guarantees that \(\phi^2_\text{SYM}(v)^T \phi^2_\text{SYM}(w) = \phi^2_\text {TP} (v)^T \phi^2_\text{TP}(w)= (v^T w)^2\). Recall that in the attention formulation of the linear transformer the embedding \(\phi\) only influences the outputs via the attention scores, which were defined as \[ A_{ij} = \frac{ \phi(Q_i)^T \phi(K_j)}{\sum_{k=1}^i \phi(Q_i)^T \phi(K_k) } \] Then two linear transformers with embeddings \(\phi^2_\text{TP}(v)\) and \(\phi^2_\text{SYM}(v)\) will have exactly the same outputs, since they have the same inner products \(\phi(Q_i)^T \phi(K_j)\) (namely, \((Q_i^T K_j)^2\)). We’ve been able to exploit the symmetry of \(v v^T\) to construct an equivalent embedding function with approximately half the dimensionality!
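
The claimed identity is easy to verify numerically. The following pure-Python sketch re-implements both embeddings for \(p=2\) on plain lists, so that it stands on its own. The names `tp_2_embedding` and `dot` are hypothetical helpers.

```python
from math import isclose, sqrt

def sym_2_embedding(v):
    """Upper triangle of v v^T, with off-diagonal entries scaled by sqrt(2)."""
    d = len(v)
    return [
        sqrt(1 if i == j else 2) * v[i] * v[j]
        for i in range(d)
        for j in range(i, d)
    ]

def tp_2_embedding(v):
    """All d^2 entries of v v^T, flattened row by row."""
    return [a * b for a in v for b in v]

def dot(x, y):
    return sum(a * b for a, b in zip(x, y))

v = [1.0, -2.0, 0.5]
w = [3.0, 1.0, 2.0]

sym = dot(sym_2_embedding(v), sym_2_embedding(w))
tp = dot(tp_2_embedding(v), tp_2_embedding(w))

assert isclose(sym, tp)
assert isclose(sym, dot(v, w) ** 2)
# The symmetric embedding has (d^2 + d) / 2 entries instead of d^2.
assert len(sym_2_embedding(v)) == 6 and len(tp_2_embedding(v)) == 9
```

The \(\sqrt{2}\) factor is what makes this work: each off-diagonal product contributes twice to the full inner product, and \(\sqrt{2} \cdot \sqrt{2} = 2\) restores that weight.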

In this section, we will generalize this idea to arbitrary powers.

4.1. Mathematical Background

We begin by introducing some key tools.

Permutation group. The first thing we need is the permutation group of \(p\) elements, denoted by \(G_p\), which is defined as the set of all invertible functions \(\rho: \N_p \to \N_p\). We also overload the notation slightly. For a multi-index \(\alpha = [\alpha_1, \cdots, \alpha_p]\) define the permutation of the multi-index as \(\rho(\alpha) = [\alpha_{\rho(1)}, \cdots, \alpha_{\rho(p)}]\). This is useful for defining symmetric tensors.

Symmetric tensors. A tensor \(T \in \R^{\underbrace{d\times \cdots \times d}_p}\) is symmetric if for all multi-indices \(\alpha \in \N_d \times \cdots \times \N_d\) and permutations \(\rho \in G_p\) we have that: \[ T_\alpha = T_{\rho(\alpha)} \]

Symmetric power of vectors. We use the notation \(v^{\otimes p}\) to refer to \(\otimes^p_{i=1} v\) (the tensor product of \(p\) copies of \(v\)), and call it the \(p\)th symmetric power of \(v\). Due to the commutativity of multiplication, all symmetric powers of vectors are symmetric tensors. For example, for a multi-index \([1, 2, 3]\), the entry \([v^{\otimes 3}]_{[1, 2, 3]} = v_1 v_2 v_3\) will equal \([v^{\otimes 3}]_{[3, 2, 1]} = v_3 v_2 v_1\). Showing that a general tensor \(v^{\otimes p}\) is symmetric is simple: \[ \left[ v^{\otimes p} \right ]_{\rho(\alpha)} = \prod_{i=1}^p v_{\alpha_{\rho(i)}} = \prod_{i=1}^p v_{\alpha_i} = \left[ v^{\otimes p} \right ]_{\alpha} \]

To construct embeddings that exploit the symmetries of \(T=v^{\otimes p}\) we will need some key properties about symmetric tensors:

  • Duplication counts: If \(T\) is symmetric, the value \(T_\alpha\) may appear at several multi-indices. To know how many duplicates a multi-index \(\alpha\) has, we first need to count how many times each number \(i\in \{1, \cdots, d\}\) occurs in \(\alpha\). Define the counts \(c_i = \sum_{j=1}^p \delta(\alpha_j, i)\). Then, the number of multi-indices containing the same data as \(\alpha\) is given by the formula \(\frac{p!}{c_1 ! \; \cdots \; c_d!}\).
  • Unique multi-indices: No data is lost if we restrict ourselves to only looking at entries \(T_\alpha\) for multi-indices \(\alpha\) that are non-decreasing (i.e. \(\alpha_i \le \alpha_{i+1}\)). The intuition is that an arbitrary multi-index \(\beta\) can always be transformed into a non-decreasing multi-index \(\alpha\) by applying some permutation \(\rho\). Using the defining property of symmetric tensors, \(T_\beta = T_{\rho(\beta)} = T_\alpha\). Thus, we lose no information by excluding every multi-index that isn’t non-decreasing.
  • Dimension: The space of symmetric tensors has dimension \(\binom{d+p-1}{p}\). This can be derived via a classic combinatorial argument counting the number of non-decreasing sequences.
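
These three properties can be checked by brute force for small \(d\) and \(p\). The sketch below uses only the Python standard library, 0-based indices, and the multinomial count \(\frac{p!}{c_1! \cdots c_d!}\). The function name `duplication_count` is hypothetical.

```python
from collections import Counter
from itertools import combinations_with_replacement, product
from math import comb, factorial, prod

def duplication_count(alpha):
    """Number of multi-indices that are permutations of alpha."""
    counts = Counter(alpha).values()
    return factorial(len(alpha)) // prod(factorial(c) for c in counts)

d, p = 4, 3

# Unique multi-indices: the non-decreasing ones.
unique = list(combinations_with_replacement(range(d), p))

# Dimension: there are C(d + p - 1, p) of them.
assert len(unique) == comb(d + p - 1, p)

# Duplication counts: together, the duplicates cover all d^p multi-indices.
assert sum(duplication_count(alpha) for alpha in unique) == d ** p

# Brute-force check: group every multi-index by its sorted representative.
groups = Counter(tuple(sorted(beta)) for beta in product(range(d), repeat=p))
assert all(groups[alpha] == duplication_count(alpha) for alpha in unique)
```

For \(d=4\) and \(p=3\), the \(64\) entries of the full tensor collapse to \(20\) unique ones.
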
Expand to see a complete derivation of these properties and a few other relevant facts about symmetric tensors.

Duplicate counts. By definition, the only constraint on a symmetric tensor is that all the entries \(T_{\rho(\alpha)}\) must be the same for all permutations \(\rho \in G_p\). Now we want to understand the amount of duplication that any specific \(\alpha\) has. Since the number of permutations of the multi-indices is \(|G_p| = p!\), a naive estimate would be that every entry \(T_\alpha\) appears \(p!\) times in the tensor. And indeed, that is the case for some multi-indices. For example, every permutation \(\rho \in G_3\) sends the multi-index \([1,4,6]\) to a different multi-index, so there are \(3!\) entries with the same value. But, on the other hand, for the multi-index \([1,1,1]\) it doesn’t matter what permutation \(\rho\) we apply, we always have that \(\rho(\alpha) = \alpha\). So the entry at \([1,1,1]\) has no duplicates.

To count the number of duplicates for a generic multi-index \(\alpha\) we are going to use the orbit stabilizer theorem. This theorem tells us that the number of elements in the set \(\orbit(\alpha) = \{ \rho(\alpha) \; | \; \rho \in G_p \}\) is given by the formula: \[ |\orbit(\alpha) | =\frac {|G_p|} {|\stab(\alpha)|} \] where the stabilizer \(\stab(\alpha) = \{ \rho \in G_p \; | \; \rho(\alpha) = \alpha \}\) is the set of permutations that fix \(\alpha\). Working out the size of the stabilizer is not hard. For a permutation \(\rho \in G_p\) to leave \(\alpha\) fixed it must satisfy that \(\alpha_{\rho(i)} = \alpha_i\) for all \(i\in \N_p\). In other words, \(\rho\) must only interchange entries of \(\alpha\) that hold the same index. Say \(\alpha = [1,1,2]\), then we can only exchange the first and second element. Generically, if index \(i\in\N_d\) appears \(c_i\) times in \(\alpha\), then there are \(c_i!\) permutations that move around entries of \(\alpha\) with value \(i\) while keeping the rest fixed. From this, it is clear that: \[ |\stab(\alpha)| = \prod_{i=1}^d c_i ! \] For the example, the counts for \([1,1,2]\) are \(c_1=2, \; c_2 = 1\) so \(|\stab([1,1,2])| = 2! \: 1! =2\). With this, we can get the formula for the number of replicated entries of \(\alpha\) by applying the orbit stabilizer theorem: 4 \[ |\orbit(\alpha) | = \frac {p!} {\prod_{i=1}^d c_i!} \]
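
The orbit and stabilizer of a small multi-index can be enumerated directly, which makes the theorem tangible. This is a brute-force sketch using only the standard library. Positions are 0-based, and the helper names are hypothetical.

```python
from itertools import permutations
from math import factorial

def permute(alpha, rho):
    """Apply a permutation of positions: rho(alpha)_i = alpha_{rho(i)}."""
    return tuple(alpha[r] for r in rho)

def orbit(alpha):
    """All distinct multi-indices reachable by permuting alpha."""
    return {permute(alpha, rho) for rho in permutations(range(len(alpha)))}

def stabilizer(alpha):
    """All permutations that leave alpha unchanged."""
    return [
        rho for rho in permutations(range(len(alpha)))
        if permute(alpha, rho) == tuple(alpha)
    ]

for alpha in [(1, 1, 2), (1, 4, 6), (1, 1, 1)]:
    o, s = len(orbit(alpha)), len(stabilizer(alpha))
    # Orbit stabilizer theorem: |orbit| * |stab| = |G_p| = p!
    assert o * s == factorial(len(alpha))
    print(alpha, "orbit size:", o, "stabilizer size:", s)
```

For \([1,1,2]\) the stabilizer contains the identity and the swap of the first two positions, so the orbit has \(3!/2 = 3\) elements.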

Basis of symmetric tensors. We know that many multi-indices of a symmetric tensor are redundant. To understand the true structure (like the dimensionality) of symmetric tensors, we need a way to select one representative from each group of redundant multi-indices. One way to do that is to restrict ourselves to non-decreasing multi-indices. Denote them by \(P = \{\alpha \in \N_d^{\times p} \; | \; \alpha_i \le \alpha_{i+1} \}\). Then we can construct a basis for the space of symmetric tensors out of \(\alpha \in P\) like \[ S^\alpha = \sum_{\rho \in G_p} E^{\rho(\alpha)} \] where \(E^\alpha = \bigotimes^p_{i=1} e_{\alpha_i}\). (The tensors \(E^\alpha\) are the natural way to construct a basis for the non-symmetric tensors out of the basis \(e_i \in \R^d\).) To convince ourselves that \(\{S^\alpha \; | \; \alpha \in P \}\) forms a basis of the symmetric tensors, we need to check that the set is linearly independent and that it spans all symmetric tensors. Let’s check linear independence first. Assume that we have some coefficients \(x_\alpha \in \R\) s.t. \[ \sum_{\alpha \in P} x_\alpha S^\alpha = 0 \] Then, for any \(\beta \in P\) \[\begin{align} \left[\sum_{\alpha \in P} x_\alpha S^\alpha\right]_\beta &= \sum_{\alpha \in P} x_\alpha S^\alpha_\beta = \sum_{\alpha \in P} x_\alpha \sum_{\rho \in G_p} E^{\rho(\alpha)}_\beta \\ &= \sum_{\alpha \in P} x_\alpha \sum_{\rho \in G_p} \delta(\rho(\alpha) = \beta) \\ \end{align}\] Since \(\alpha, \beta \in P\) the only way there can exist a \(\rho \in G_p\) such that \(\rho(\alpha) = \beta\) is when \(\alpha = \beta\). So \[\begin{align} 0 &= \left[\sum_{\alpha \in P} x_\alpha S^\alpha\right]_\beta \\ &= x_\beta \sum_{\rho \in G_p} \delta(\rho(\beta) = \beta) \\ &= x_\beta \; | \stab (\beta) | \\ \end{align}\] And, since \(| \stab (\beta) | \ge 1\), that implies that \(x_\beta = 0\) for every \(\beta \in P\), so the set of \(S^\alpha\) is linearly independent.
To show that the \(S^\alpha\) span all symmetric tensors, we can just show that, for any symmetric tensor \(T\), if we define \[ Q = \sum_{\alpha \in P} \frac {T_\alpha} {|\stab(\alpha)|} S^\alpha \] then \(T = Q\). That can be easily seen by noticing that \(Q\) is a symmetric tensor and that, evaluating \(Q\) at \(\beta \in P\), \[\begin{align} Q_\beta &= \left[\sum_{\alpha \in P} \frac {T_\alpha} {|\stab(\alpha)|} S^\alpha \right]_\beta = \sum_{\alpha \in P} \frac {T_\alpha} {|\stab(\alpha)|} \sum_{\rho \in G_p} E^{\rho(\alpha)}_\beta \\ &= \sum_{\alpha \in P} \frac {T_\alpha} {|\stab(\alpha)|} \sum_{\rho \in G_p} \delta(\rho(\alpha) = \beta) \\ &= \frac {T_\beta} {|\stab(\beta)|} \sum_{\rho \in G_p} \delta(\rho(\beta) = \beta) \\ &= T_\beta \end{align}\]

Dimension of symmetric tensors. Since we’ve created a basis for the space of symmetric tensors out of non-decreasing sequences, we can establish the dimension of the space by counting all such sequences. This is a standard combinatorial problem solved via the method of stars and bars, which tells us that the dimension is \[\binom{d+p-1}{p}\]

The only thing we must note to apply the standard combinatorial results is that there is a 1-1 correspondence between non-decreasing sequences and multisets of \(\N_d\) with cardinality \(p\). This is because there is a unique way to lay out a multiset into a non-decreasing sequence. However many \(1\)s there are in the multiset, they will all come first, then all the \(2\)s, etc.

Symmetric powers of vectors span the symmetric tensors

Showing that all tensors of the form \(v^{\otimes p}\) are symmetric was trivial, but there is a harder question we might ask ourselves: do we actually need all \(\binom{d + p - 1}{p}\) dimensions of the space of symmetric tensors? The way to formalize this question is to ask whether the space of all symmetric tensors is spanned by rank-1 symmetric tensors \(v^{\otimes p}\). We will prove that the answer is yes by building every basis vector \(S^\alpha\) that way. Concretely, define \[ Z^\alpha = \sum_{b_1, \cdots, b_p = 0}^1 \prod_{i=1}^p (-1)^{b_i -1} \left ( \sum_{j=1}^p b_j e_{\alpha_j} \right)^{\otimes p} \] It turns out that \(Z^\alpha = S^\alpha\). Both \(Z^\alpha\) and \(S^\alpha\) are evidently symmetric tensors, so to convince ourselves that they are equal we just need to compare their entries at a non-decreasing multi-index \(\beta \in P\). First, see that \[\begin{align} S^\alpha_\beta &= \sum_{\rho \in G_p} \left [ \bigotimes^p_{i=1} e_{\rho(\alpha)_i} \right]_\beta \\ &= \sum_{\rho \in G_p} \prod^p_{i=1} \delta(\rho(\alpha)_i = \beta_i) \\ &= \sum_{\rho \in G_p} \delta(\rho(\alpha) = \beta) \\ &= |\stab(\alpha) | \; \delta(\alpha = \beta) \end{align}\]

On the other hand, \[\begin{align} Z^\alpha_\beta &= \sum_{b_1, \cdots, b_p = 0}^1 \prod_{i=1}^p (-1)^{b_i -1} \left [ \left ( \sum_{j=1}^p b_j e_{\alpha_j} \right)^{\otimes p} \; \right ]_\beta \\ &= \sum_{b_1, \cdots, b_p = 0}^1 \prod_{i=1}^p (-1)^{b_i -1} \prod_{i=1}^p \left ( \sum_{j=1}^p b_j \delta(\alpha_j = \beta_i) \right) \\ &= \sum_{b_1, \cdots, b_p = 0}^1 \prod_{i=1}^p (-1)^{b_i -1} \sum_{j_1, \cdots, j_p = 1}^p \prod_{i=1}^p b_{j_i} \delta(\alpha_{j_i} = \beta_i) \\ &= \sum_{j_1, \cdots, j_p = 1}^p \sum_{b_1, \cdots, b_p = 0}^1 \prod_{i=1}^p (-1)^{b_i -1} \prod_{i=1}^p b_{j_i} \delta(\alpha_{j_i} = \beta_i) \end{align}\]

In the sum over all combinations of \(j_1, \cdots, j_p\), if there is any \(l \in \N_p\) that does not appear among \(j_1, \cdots, j_p\), then that term of the sum drops out. This is because the only place where \(b_l\) appears is in the factor \((-1)^{b_l-1}\), so summing over \(b_l \in \{0, 1\}\) produces two terms of opposite sign that cancel. Thus, we can restrict ourselves to summing over \(j_1, \cdots, j_p\) that contain every element of \(\N_p\). In other words, the \(j\) terms must be a permutation \([j_1, \cdots, j_p] = \rho([1, \cdots, p])\) for some \(\rho \in G_p\). So we can continue \[\begin{align} Z^\alpha_\beta &= \sum_{\rho \in G_p} \sum_{b_1, \cdots, b_p = 0}^1 \prod_{i=1}^p (-1)^{b_i -1} \prod_{i=1}^p b_{\rho(i)} \delta(\alpha_{\rho(i)} = \beta_i) \\ &= \sum_{\rho \in G_p} \prod_{i=1}^p \delta(\alpha_{\rho(i)} = \beta_i) \end{align}\] where we used the fact that, since the inner term is multiplied by every \(b_i\), the only way it can be nonzero is if every single \(b_i = 1\), in which case the sign factor equals \(1\).

Finally, we can wrap up the proof: \[\begin{align} Z^\alpha_\beta &= \sum_{\rho \in G_p} \prod_{i=1}^p \delta(\alpha_{\rho(i)} = \beta_i) \\ &= \sum_{\rho \in G_p} \delta(\rho(\alpha) = \beta) \\ &= |\stab(\alpha) | \; \delta(\alpha = \beta) \\ &= S^\alpha_\beta \end{align}\] Since every basis tensor \(S^\alpha\) is a linear combination of tensors of the form \(v^{\otimes p}\), these rank-1 tensors span the whole space of symmetric tensors.
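As a sanity check on the identity \(Z^\alpha = S^\alpha\), the following sketch builds both tensors numerically for small \(d\) and \(p\) and compares them. The function names (`tensor_power`, `S_tensor`, `Z_tensor`) are illustrative and not part of the article's code, and indices are 0-based rather than 1-based as in the text. It is a brute-force check under these assumptions, not a substitute for the proof.

```python
import itertools
import numpy as np

def tensor_power(v, p):
    """v ⊗ ... ⊗ v with p factors, as an array of shape (d,) * p."""
    out = np.array(1.0)
    for _ in range(p):
        out = np.multiply.outer(out, v)
    return out

def S_tensor(alpha, d):
    """Sum over all p! permutations rho of the basis tensor e_{rho(alpha)}."""
    out = np.zeros((d,) * len(alpha))
    for perm in itertools.permutations(alpha):
        out[perm] += 1.0
    return out

def Z_tensor(alpha, d):
    """Signed sum of p-th tensor powers of sum_j b_j e_{alpha_j}, b in {0,1}^p."""
    p = len(alpha)
    basis = np.eye(d)
    out = np.zeros((d,) * p)
    for b in itertools.product([0, 1], repeat=p):
        # prod_i (-1)^(b_i - 1) has the same parity as (-1)^(p - sum(b))
        sign = (-1) ** (p - sum(b))
        v = np.zeros(d)
        for b_j, a_j in zip(b, alpha):
            v += b_j * basis[a_j]
        out += sign * tensor_power(v, p)
    return out

d, p = 3, 3
for alpha in itertools.combinations_with_replacement(range(d), p):
    assert np.allclose(Z_tensor(alpha, d), S_tensor(alpha, d))
```

The cost grows as \(2^p d^p\), so this is only practical for tiny cases.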

4.2 Implementation

The symmetric power embedding \(\phi^p_\text{SYM}(v)\) produces a list of \(\binom{d+p-1}{p}\) numbers, each corresponding to \([v^{\otimes p}]_\alpha\) for a non-decreasing \(\alpha\). Just as we did in the example of \(\phi^2_\text{SYM}\), we also need to scale each entry by the square root of the duplicate count of that particular \(\alpha\). The inner product of two vectors embedded in this way is identical to that of the tensor power embedding, \(\phi^p_\text{SYM}(v)^T \phi^p_\text{SYM}(w) = \phi^p_\text{TP}(v)^T \phi^p_\text{TP}(w)\). The following is an example implementation of this embedding:

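import numpy as np

def symmetric_power_embedding(k, p):
    # one feature per non-decreasing multi-index of length p over range(d)
    d = len(k)
    x = []
    for midx in non_decreasing_multiindices(p, d):
        # duplicate count: number of distinct orderings of midx
        c = count(midx, d)
        xi = np.sqrt(multinomial(c))
        for j in range(p):
            xi *= k[midx[j]]
        x.append(xi)
    return np.array(x)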

# -- helper functions --
# generates list of non-decreasing multiindices
def non_decreasing_multiindices(n, max_idx, starting_from=0):
    if n == 1:
        return [[i] for i in range(starting_from, max_idx)]
    seqs = []
    for i in range(starting_from, max_idx):
        seqs += [[i, *remainder] for remainder in
                    non_decreasing_multiindices(n-1, max_idx, starting_from=i)]
    return seqs

# computes multinomial coefficient
def multinomial(lst):
    res, i = 1, 1
    for a in lst:
        for j in range(1, a + 1):
            res *= i
            res //= j
            i += 1
    return res

# given a multiindex, counts how many times each index appears
def count(midx, d):
    c = [0] * d
    for i in midx:
        c[i] += 1
    return c

# -- sanity check: the embedding reproduces the even-power kernel --
def even_power(q, k, p):
    return np.inner(q, k) ** p

def symmetric_power_inner_product(q, k, p):
    embedded_q = symmetric_power_embedding(q, p)
    embedded_k = symmetric_power_embedding(k, p)
    return (embedded_q * embedded_k).sum()

d = 8
for p in [2, 4, 6, 8]:
    q = np.random.random(d)
    k = np.random.random(d)
    assert np.allclose(
        even_power(q, k, p),
        symmetric_power_inner_product(q, k, p)
    )
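The check above compares against \((q^T k)^p\). A complementary sketch, assuming that \(\phi^p_\text{TP}(v)\) is simply the flattened tensor \(v^{\otimes p}\), compares the two embeddings directly on a small example. The function names here are hypothetical, and the compact symmetric embedding is a standalone rewrite using `itertools` rather than the helpers above.

```python
import itertools
import math
import numpy as np

def tensor_power_embedding(v, p):
    """Flattened v ⊗ ... ⊗ v (p factors): d**p entries."""
    out = np.array(1.0)
    for _ in range(p):
        out = np.multiply.outer(out, v)
    return out.ravel()

def symmetric_power_embedding_compact(v, p):
    """One entry per non-decreasing multi-index, scaled by sqrt(duplicate count)."""
    d = len(v)
    feats = []
    for midx in itertools.combinations_with_replacement(range(d), p):
        counts = np.bincount(midx, minlength=d)
        # multinomial coefficient p! / (c_0! * ... * c_{d-1}!)
        dup = math.factorial(p) // math.prod(math.factorial(int(c)) for c in counts)
        feats.append(math.sqrt(dup) * np.prod(v[list(midx)]))
    return np.array(feats)

rng = np.random.default_rng(0)
d, p = 4, 3
v, w = rng.random(d), rng.random(d)

sym_v = symmetric_power_embedding_compact(v, p)
sym_w = symmetric_power_embedding_compact(w, p)
tp_v = tensor_power_embedding(v, p)
tp_w = tensor_power_embedding(w, p)

# Same inner product, far fewer entries.
assert np.allclose(sym_v @ sym_w, tp_v @ tp_w)
print(len(sym_v), len(tp_v))  # 20 64
```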

Using this embedding produces a massive dimensionality reduction compared to the dimensionality of \(\phi_\text{TP} ^p\). The table below compares the size of the state between repeated tensor products and symmetric powers, as measured in bytes (assuming half-precision), for a 124M-parameter GPT-2 transformer at various \(p\).

| p | Tensor Power | Symmetric Power | Savings |
|---|---|---|---|
| 2 | 77 MB | 39 MB | 49% |
| 4 | 314 GB | 14 GB | 96% |
| 6 | 1.3 PB | 2.2 TB | 99.8% |
| 8 | 5.3 EB | 199 TB | 99.996% |
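The savings come from the per-head embedding dimension: \(d^p\) for the tensor power versus \(\binom{d+p-1}{p}\) for the symmetric power. The sketch below tabulates both, assuming a key size of \(d = 64\) per head (an assumption for illustration). It does not model layer count, head count, value size, or precision, so it is not expected to reproduce the byte counts in the table exactly, only the trend.

```python
import math

def tensor_power_dim(d, p):
    """Number of entries in the flattened p-th tensor power of a d-vector."""
    return d ** p

def symmetric_power_dim(d, p):
    """Number of non-decreasing multi-indices of length p over d symbols."""
    return math.comb(d + p - 1, p)

d = 64  # assumed key size per head
for p in [2, 4, 6, 8]:
    tp = tensor_power_dim(d, p)
    sym = symmetric_power_dim(d, p)
    reduction = 100 * (1 - sym / tp)
    print(f"p={p}: tensor power {tp:.3e}, symmetric power {sym:.3e}, "
          f"reduction {reduction:.3f}%")
```

For large \(p\) the ratio approaches \(1/p!\), which is why the reduction grows so quickly with the power.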

We can evaluate each symmetric power architecture against our two metrics, state size (under 80 GB) and performance (loss below baseline).

| p | State Size | Memory ≤ 80 GB? | Relative Loss at 100K Steps | Loss ≤ baseline? |
|---|---|---|---|---|
| 2 | 39 MB | Yes | 1.03x | No |
| 4 | 14 GB | Yes | 0.98x | Yes |
| 6 | 2.2 TB | No | 0.97x | Yes |
The symmetric power transformer with \(p=4\) passes our bar.

5. Conclusion

In this article, we have introduced the symmetric power transformer, a linear transformer which closes the performance gap to classic softmax transformers using a tractably small state. We replace the exponentiation in a traditional softmax transformer with an even power, and then show that this is equivalent to a linear transformer with the symmetric power embedding. We expect this approach will provide transformer-level performance at greatly reduced training costs when combined with the chunked algorithm. It will also enjoy cheaper inference, thanks to the constant per-token inference cost common to all RNNs. In an upcoming article, we plan to release an open-source model that uses a symmetric power transformer at its core, together with an efficient CUDA kernel implementation. Stay tuned!

Acknowledgments

We would like to thank Warfa Jibril, Jono Ridgway, Saurabh Kumar, Justin Dieter, Fabrice Normandin, and Imanol Schlag for their feedback on an earlier draft of this post, and Txus Bach for correcting the state size calculations.

References

[1] A. Katharopoulos, A. Vyas, N. Pappas, and F. Fleuret, “Transformers are RNNs: Fast autoregressive transformers with linear attention,” in International Conference on Machine Learning, PMLR, 2020, pp. 5156–5165.
[2] J. Buckman and C. Gelada, “Linear Transformers Are Faster.” Manifest AI, Jan. 05, 2024.
[3] S. Wang, B. Z. Li, M. Khabsa, H. Fang, and H. Ma, “Linformer: Self-attention with linear complexity,” arXiv preprint arXiv:2006.04768, 2020.
[4] I. Schlag, K. Irie, and J. Schmidhuber, “Linear transformers are secretly fast weight programmers,” in International Conference on Machine Learning, PMLR, 2021, pp. 9355–9366.
[5] P. Kacham, V. Mirrokni, and P. Zhong, “PolySketchFormer: Fast transformers via sketches for polynomial kernels,” arXiv preprint arXiv:2310.01655, 2023.
[6] M. Zhang, K. Bhatia, H. Kumbong, and C. Ré, “The hedgehog & the porcupine: Expressive linear attentions with softmax mimicry,” arXiv preprint arXiv:2402.04347, 2024.
[7] Z. Qin et al., “HGRN2: Gated linear RNNs with state expansion,” arXiv preprint arXiv:2404.07904, 2024.
[8] H. Peng, N. Pappas, D. Yogatama, R. Schwartz, N. A. Smith, and L. Kong, “Random feature attention,” arXiv preprint arXiv:2103.02143, 2021.
[9] K. Choromanski et al., “Rethinking attention with performers,” arXiv preprint arXiv:2009.14794, 2020.
[10] S. Arora et al., “Simple linear attention language models balance the recall-throughput tradeoff,” arXiv preprint arXiv:2402.18668, 2024.
[11] J. Su, M. Ahmed, Y. Lu, S. Pan, W. Bo, and Y. Liu, “RoFormer: Enhanced transformer with rotary position embedding,” Neurocomputing, vol. 568, p. 127063, 2024.
[12] T. Dao, D. Fu, S. Ermon, A. Rudra, and C. Ré, “FlashAttention: Fast and memory-efficient exact attention with IO-awareness,” Advances in Neural Information Processing Systems, vol. 35, pp. 16344–16359, 2022.
[13] J. Buckman, “LongCrawl64: A Long-Context Natural-Language Dataset.” Manifest AI, May 16, 2024.

Footnotes

  1. This is also known as the kernel method. It is essential when one works with infinite-dimensional embeddings, and it’s also useful in this case to avoid materializing large-but-finite embeddings.

  2. This state-size threshold is admittedly somewhat arbitrary. In principle, larger states are possible with clever sharding; but for excessively large states, which must be sharded across a huge number of GPUs, the hardware cost of sharding becomes completely prohibitive. In this article, we are training models whose parameters fit on a single GPU, so it seems reasonable to use the memory of a single GPU as the threshold for tractability.

  3. A natural choice for the ordering \(\sigma\) is row major ordering.

  4. You might have previously seen this expression in the multinomial theorem. This connection is no coincidence: symmetric powers are closely related to polynomials.

Citation

BibTeX citation:
@misc{buckman2024,
  author = {Buckman, Jacob and Gelada, Carles and Zhang, Sean},
  publisher = {Manifest AI},
  title = {Symmetric {Power} {Transformers}},
  date = {2024-08-15},
  langid = {en}
}
For attribution, please cite this work as:
J. Buckman, C. Gelada, and S. Zhang, “Symmetric Power Transformers.” Manifest AI, Aug. 15, 2024.