Improving Symmetric Power Transformers with Conformal Transformations
\[ \newcommand{\R}{\mathbb{R}} \newcommand{\Z}{\mathbb{Z}} \newcommand{\N}{\mathbb{N}} \newcommand{\sft}{\text{softmax}} \newcommand{\List}{\text{List}} \newcommand{\Seq}{\text{Seq}} \newcommand{\SeqT}{\text{SeqT}} \newcommand{\CSeqT}{\text{CSeqT}} \newcommand{\Dist}{\text{Dist}} \newcommand{\SM}{\text{SM}} \newcommand{\Fn}{\text{Fn}} \newcommand{\Tok}{\text{Tok}} \newcommand{\Aij}{ A_{[i,j]}} \newcommand{\ten}{\small\text{tensor}} \newcommand{\sym}{\small\text{symmetric}} \]
In a previous article, we introduced symmetric power transformers, a variant of linear transformers with an embedding function based on the theory of symmetric tensors. We demonstrated that on the LongCrawl64 dataset with a context size of \(4096\), the symmetric power transformer closes the performance gap to classic softmax transformers using a tractably small state. Specifically, at context size \(4096\) a symmetric power transformer with \(p=4\) satisfies two desired properties: (1) its performance matches that of a softmax transformer baseline and (2) its state is small enough to be tractable (\(\leq 80\) GB).
While having a recurrent formulation allows for efficient training and inference (training cost is linear instead of quadratic, and inference cost is constant instead of linear), the capacity constraint on the recurrent state poses a fundamental limitation. A symmetric power transformer with a finite dimensional state can only store a small fraction of the information that a softmax transformer stores. Since symmetric power transformers have no easy way to erase information from their states, one might worry that performance suffers when the context size is large. Indeed, we see this issue emerge in two ways: (1) at training time and (2) at inference time.
1. At training time. Symmetric power transformers suffer degraded performance compared to the softmax baseline as the context size grows. We see this trend begin to emerge at symmetric power \(p = 4\), and it appears at even smaller context sizes when \(p = 2\).
The effect is more pronounced for the \(p=2\) model (\(39\) MB state size) than for the \(p=4\) model (\(14\) GB state size). But solving this poor scaling with training context size will be important even for the larger models if we want to push them to millions of tokens of context.
In the linear transformer literature, gating has been shown to help with this problem [1], [2], [3], [4]. Gating allows the network to erase some of the information that is placed on the state. A common approach is to set a different rate of forgetting per head [5].
2. At inference time. If we evaluate the ability of a trained model to make predictions at different context lengths (including ones longer than the ones used during training), we see that there is hardly any ability to generalize beyond the training context. The following contextwise plot shows just that. Following our previous article, it shows the average loss at different context lengths. The training context size is \(16384\) which is indicated by the dashed red line.
In the transformer literature, one approach has become widely adopted to remedy this problem: ALiBi positional embeddings [6], which have been shown to make transformers generalize well beyond the training context size. Surprisingly, it turns out that ALiBi and gating are mathematically equivalent (Section 5 gives more details). Further, when we think of standard transformers as linear transformers using an infinite dimensional embedding function, the linear bias of ALiBi corresponds exactly to multiplicative gating on this infinite dimensional state.
One important thing to keep in mind is that any architectural modification to a linear transformer variant must be compatible with both the recurrent and attention formulations. We need to ensure this compatibility both when introducing gating and when integrating rotary embeddings with symmetric power transformers.
In our previous article we briefly discussed the fact that symmetric power transformers are compatible with rotary embeddings, but we didn’t expand on the recurrent implementation of rotary embeddings.
In this article, we make the following contributions:
- We apply learned (and data dependent) multiplicative gating developed in prior work [3] to symmetric power transformers. We describe how to implement gating in both the attention and recurrent forms of the symmetric power transformer architecture.
- We discuss how to implement rotary embeddings with symmetric power transformers in both the attention and recurrent formulations.
- We introduce learned and data dependent rotary embeddings by allowing the model to learn adjustments to the standard rotations applied at each time step.
- We highlight that the ideas of gating and rotary embeddings can be conceptually unified into a single object: a conformal transformation of the state. This unification does not hold only for our sympow transformers; we’ll see that any classic transformer with rotary embeddings and ALiBi can be thought of as an RNN with an infinite dimensional state that undergoes a conformal transformation.
All of these ideas combine to form the conformal symmetric power transformer, which we call conformal-sympow. Our experiments on the LongCrawl64 dataset demonstrate that conformal-sympow further improves the performance of sympow, especially for longer context lengths. You can see that it doesn’t suffer from the degraded scaling of training context:
And, at inference time generalizes well beyond the training context:
We are happy with this architecture. Our next post will bring everything together by implementing conformal-sympow efficiently with its recurrent and chunked formulations.
1. Symmetric Power Transformers
We begin with a high level overview of symmetric power transformers. The inputs to the layer are sequences of \(Q_i, K_i, V_i \in \R^d\) of queries, keys, and values, where \(i\) ranges from \(1\) to the sequence length \(t\). The outputs are a sequence \(Y_i\in \R^d\). In the attention formulation, the formula for the output vectors is: \[ Y_i = \sum_{j=1}^i A_{ij} V_j \qquad A_{ij} = \frac{B_{ij}}{\sum_{k=1}^i B_{ik}} \qquad B_{ij} = (Q_i^T K_j)^p \qquad \text{(sympow)} \] We refer to \(A_{ij}\) as the attention scores and \(B_{ij}\) as the preattention scores (mirroring the preactivation/activation language often used to describe the hidden values of an MLP before and after the nonlinearity).
It is important that the power \(p\) is even because that guarantees the denominator is positive, which makes \(A_{i1}, \cdots, A_{ii}\) a valid probability distribution. In turn, this makes the outputs \(Y_i\) a convex combination of \(V_1, \cdots, V_i\).
The exact same outputs \(Y_i\) can be computed via a recurrent formulation. Doing so involves an embedding function \(\phi^p : \R^d \to \R^D\). The vector \(\phi^p(k)\) contains the same information as \({k\otimes \cdots \otimes k}\), the tensor product taken \(p\) times, but it does so much more efficiently because it removes the redundancy arising from symmetry in the tensor product. Thus \(D \ll d^p\). Using this embedding function, we can write the recurrent equations: \[ Y_{i} = \frac{S_i \phi^p(Q_i)}{Z_i \phi^p(Q_i)} \qquad Z_i = Z_{i-1} + \phi^p(K_i)^T \qquad S_i = S_{i-1} + V_i \phi^p(K_i)^T \] where \(Z_0\) and \(S_0\) are zero vectors in their respective spaces. Since \(S_i \in \R^{d \times D}\) and \(Z_i \in \R^{D}\), the size of the state is \(D(d+1)\).
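To make the two formulations concrete, here is a small numerical sketch (our own illustration, not the production kernel). It stands in for \(\phi^p\) with the naive flattened tensor product, so \(D = d^p\) rather than the compact symmetric dimension, and checks that the attention and recurrent forms produce identical outputs:

```python
import numpy as np

def phi(k, p):
    # Illustrative embedding: the full p-fold tensor product, flattened.
    # The real sympow embedding removes symmetric redundancy (D << d**p);
    # this naive version has D = d**p but satisfies the same identity
    # phi(q, p) @ phi(k, p) == (q @ k)**p, which is all the check needs.
    out = k
    for _ in range(p - 1):
        out = np.outer(out, k).ravel()
    return out

def sympow_attention(Q, K, V, p):
    t = Q.shape[0]
    Y = np.zeros_like(V)
    for i in range(t):
        B = (Q[i] @ K[: i + 1].T) ** p     # preattention scores B_{ij}
        Y[i] = (B / B.sum()) @ V[: i + 1]  # attention-weighted values
    return Y

def sympow_recurrent(Q, K, V, p):
    t, d = Q.shape
    S = np.zeros((d, d ** p))  # state matrix S_i
    Z = np.zeros(d ** p)       # normalizer vector Z_i
    Y = np.zeros_like(V)
    for i in range(t):
        Z = Z + phi(K[i], p)
        S = S + np.outer(V[i], phi(K[i], p))
        q = phi(Q[i], p)
        Y[i] = (S @ q) / (Z @ q)
    return Y
```

Running both on random inputs with an even \(p\) yields the same outputs up to floating-point error.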
These two forms give rise to a variety of algorithms for training linear transformers, with differing computational properties. Read our earlier article on linear transformers for a detailed explanation.
1.1 Rotary embeddings (RoPE)
In our previous post on symmetric power transformers, we briefly discussed that even symmetric power transformers were compatible with RoPE. In this section, we will give the issue the proper discussion it deserves by deriving the attention and recurrent implementations.
Rotary embeddings [7] encode time information by rotating the keys and queries by an amount proportional to their corresponding timestep. The rotation matrix \(R\in \R^{d\times d}\) specifies how much we rotate at each timestep, so that: \[ Q'_{i} = R^i Q_i \qquad K'_j = R^j K_j \] Then the preattention becomes: \[ B_{ij} = \left({Q'_{i}}^T K'_j \right)^p = \left({Q_{i}}^T (R^{i-j})^T K_j \right)^p \qquad \text{(sympow rotary)} \] The effect of the rotation is relative: it modulates the interaction between \(Q_i\) and \(K_j\) in a way that depends only on the time difference \(i-j\).
The rotation matrix \(R\) is constructed in a particular way. We start with some range of rotation rates \(\theta_1, \theta_2, \cdots, \theta_{\frac d 2}\) defined by the formula \(\theta_i = \frac{2\pi}{N^{\frac{2(i-1)}{d}}}\), where \(N\) is the maximum document size. The vector \(\theta\) contains these rotation rates. Then, the rotation matrix is \[\small R(\theta) = \begin{pmatrix} \cos(\theta_1) & -\sin(\theta_1) & 0 & 0 & \cdots & 0 & 0 \\ \sin(\theta_1) & \cos(\theta_1) & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos(\theta_2) & -\sin(\theta_2) & \cdots & 0 & 0 \\ 0 & 0 & \sin(\theta_2) & \cos(\theta_2) & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos(\theta_{d/2}) & -\sin(\theta_{d/2}) \\ 0 & 0 & 0 & 0 & \cdots & \sin(\theta_{d/2}) & \cos(\theta_{d/2}) \end{pmatrix} \]
When multiplying a query or key vector by this rotation matrix, each pair of dimensions indexed by \(2j - 1\) and \(2j\) for \(j \in \{1, 2, ..., \frac{d}{2} \}\) is rotated by a different amount \(\theta_j\). Rotating each pair by a different angle helps break symmetry and increases the expressiveness of the positional encodings.
A computational advantage of using rotation matrices of this form is that \(R(\theta)^k = R(k \theta)\), which massively simplifies the cost of computing all the \(Q'_i\) and \(K'_j\).
Expand to see a proof of this fact
We are given a block diagonal rotation matrix \(R(\theta)\), where each \(2 \times 2\) block corresponds to a rotation by some angle \(\theta_i\):
\[ R(\theta) = \begin{pmatrix} \cos(\theta_1) & -\sin(\theta_1) & 0 & 0 & \cdots & 0 & 0 \\ \sin(\theta_1) & \cos(\theta_1) & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos(\theta_2) & -\sin(\theta_2) & \cdots & 0 & 0 \\ 0 & 0 & \sin(\theta_2) & \cos(\theta_2) & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos(\theta_{d/2}) & -\sin(\theta_{d/2}) \\ 0 & 0 & 0 & 0 & \cdots & \sin(\theta_{d/2}) & \cos(\theta_{d/2}) \end{pmatrix}. \]
We aim to prove that \(R(\theta)^k = R(k \theta)\), where \(k\) is a positive integer, and \(k \theta = (k \theta_1, k \theta_2, \dots, k \theta_{d/2})\).
We prove this statement by induction on \(k\) for a single \(2 \times 2\) rotation matrix \(R(\theta_i)\), and then extend it to the full block diagonal matrix.
For the base case \(k = 1\), we have:
\[R(\theta_i)^1 = R(\theta_i)\],
which is equivalent to \(R(1 \cdot \theta_i) = R(\theta_i)\). Thus, the base case holds.
Assume that for some positive integer \(k\), the property holds:
\[R(\theta_i)^k = R(k \theta_i).\]
We need to show that \(R(\theta_i)^{k+1} = R((k+1)\theta_i)\). Using the definition of matrix exponentiation:
\[R(\theta_i)^{k+1} = R(\theta_i) R(\theta_i)^k.\]
By the inductive hypothesis, \(R(\theta_i)^k = R(k\theta_i)\). Substituting this:
\[R(\theta_i)^{k+1} = R(\theta_i) R(k \theta_i).\]
The product of two rotation matrices corresponds to a rotation by the sum of their angles. Therefore:
\[R(\theta_i) R(k \theta_i) = R(\theta_i + (k \theta_i)) = R((k+1)\theta_i).\]
Thus, \(R(\theta_i)^{k+1} = R((k+1)\theta_i)\), completing the inductive step. By induction, the property holds for all \(k\).
Consider the block diagonal matrix \(R(\theta)\), where: \[ R(\theta) = \begin{pmatrix} R(\theta_1) & 0 & \cdots & 0 \\ 0 & R(\theta_2) & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & R(\theta_{d/2}) \end{pmatrix}. \]
Since each block \(R(\theta_i)\) is independent of the others, the \(k\)-th power of \(R(\theta)\) is the block diagonal matrix with each block raised to the \(k\)-th power: \[ R(\theta)^k = \begin{pmatrix} R(\theta_1)^k & 0 & \cdots & 0 \\ 0 & R(\theta_2)^k & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & R(\theta_{d/2})^k \end{pmatrix}. \]
Using the result for a single rotation matrix, \(R(\theta_i)^k = R(k\theta_i)\), we get: \[ R(\theta)^k = \begin{pmatrix} R(k \theta_1) & 0 & \cdots & 0 \\ 0 & R(k \theta_2) & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & R(k \theta_{d/2}) \end{pmatrix}. \]
This is equivalent to the block diagonal matrix \(R(k\theta)\), where \(k\theta = (k\theta_1, k\theta_2, \dots, k\theta_{d/2})\).
Thus, by induction and the block diagonal structure, \(R(\theta)^k = R(k\theta)\) for any positive integer \(k\).
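The property is also easy to confirm numerically. This small sketch (our own illustration, with assumed angle values) builds the block diagonal matrix \(R(\theta)\) and compares \(R(\theta)^k\) against \(R(k\theta)\):

```python
import numpy as np

def rope_matrix(theta):
    # Block-diagonal rotation matrix: one 2x2 rotation block per angle.
    d = 2 * len(theta)
    R = np.zeros((d, d))
    for j, th in enumerate(theta):
        c, s = np.cos(th), np.sin(th)
        R[2 * j : 2 * j + 2, 2 * j : 2 * j + 2] = [[c, -s], [s, c]]
    return R

# R(theta)^k equals R(k * theta), so every R^i can be built directly
# from the scaled angles instead of by repeated matrix multiplication.
theta = np.array([0.3, 1.1, 2.0])
lhs = np.linalg.matrix_power(rope_matrix(theta), 5)
rhs = rope_matrix(5 * theta)
```

This is what makes computing all the \(Q'_i\) and \(K'_j\) cheap: each timestep's rotation comes from a single evaluation of \(R(i\theta)\).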
Now we want to find the recurrent formulation of rotary embeddings with symmetric power transformers. A simple way to do that is to include one extra vector in the recurrent state, which is now a tuple \((S, Z, \mu)\), where \(\mu \in \R^{\frac d 2}\). The recurrent equations are given by
\[ Z_i = Z_{i-1} + \phi^p(R(\mu_i) K_i)^T \qquad S_i = S_{i-1} + V_i \phi^p(R(\mu_i) K_i)^T \qquad \mu_i = \mu_{i-1} + \theta \]
Note we rotate the keys by \(R(\mu_i)\) before using them.
Given \(S_i\) and \(Z_i\), the outputs are the same as before, except that we rotate the queries by \(R(\mu)\) before using them: \[ Y_{i} = \frac{S_i \phi^p( R(\mu_i) Q_i)}{Z_i \phi^p( R(\mu_i) Q_i)} \]
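As a sanity check, the rotary recurrent form can be compared numerically against the rotated-attention form. The sketch below is our own illustration: it uses the naive flattened tensor product for \(\phi^p\) (so \(D = d^p\)) and small assumed values for \(\theta\), not the actual embedding or kernels.

```python
import numpy as np

def phi(k, p):
    # Naive p-fold tensor-product embedding, flattened (D = d**p).
    out = k
    for _ in range(p - 1):
        out = np.outer(out, k).ravel()
    return out

def rot(angles):
    # Block-diagonal rotation matrix with one 2x2 block per angle.
    d = 2 * len(angles)
    R = np.zeros((d, d))
    for j, th in enumerate(angles):
        c, s = np.cos(th), np.sin(th)
        R[2 * j : 2 * j + 2, 2 * j : 2 * j + 2] = [[c, -s], [s, c]]
    return R

def rotary_recurrent(Q, K, V, theta, p):
    # Recurrent sympow with RoPE: rotate keys and queries by R(mu_i) as
    # they arrive, where mu_i = mu_{i-1} + theta is carried in the state.
    t, d = Q.shape
    S = np.zeros((d, d ** p))
    Z = np.zeros(d ** p)
    mu = np.zeros_like(theta)
    Y = np.zeros_like(V)
    for i in range(t):
        mu = mu + theta
        R = rot(mu)
        Z = Z + phi(R @ K[i], p)
        S = S + np.outer(V[i], phi(R @ K[i], p))
        q = phi(R @ Q[i], p)
        Y[i] = (S @ q) / (Z @ q)
    return Y
```

On random inputs this matches the attention form in which \(Q_i\) and \(K_j\) are rotated by their own timesteps before computing preattention.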
Expand for a proof of the equivalence between the attention and recurrent formulations
We begin by writing the output \(Y_i\) at time step \(i\) in the attention formulation. For notational simplicity, let \(C_i = \sum_{k=1}^i (Q_i'^T K_k')^p = \sum_{k=1}^i \phi^p(Q_i')^T \phi^p(K_k')\).
\[ \begin{align} Y_i &= \sum_{j=1}^i \frac{\left( Q_i^T (R^{i-j})^T K_j\right) ^p V_j}{C_i} \\ &= \sum_{j=1}^i \frac{\left( Q_i^T (R^i)^T R^j K_j\right)^p V_j}{C_i} \\ &= \sum_{j=1}^i \frac{\left(Q_i^T R(\mu_i)^T R(\mu_j) K_j\right)^p V_j}{C_i} \\ &= \sum_{j=1}^i \frac{\left((R(\mu_i) Q_i)^T R(\mu_j) K_j\right)^p V_j}{C_i} \\ &= \sum_{j=1}^i \frac{\left(\phi^p(R(\mu_i) Q_i)^T \phi^p(R(\mu_j) K_j)\right) V_j}{C_i} \\ &= \sum_{j=1}^i \frac{V_j \phi^p(R(\mu_j) K_j)^T \phi^p(R(\mu_i) Q_i)}{C_i} \\ &= \frac{\left( \sum_{j=1}^i V_j \phi^p(R(\mu_j) K_j)^T \right) \phi^p(R(\mu_i) Q_i) }{C_i} \\ &= \frac{S_i \phi^p(R(\mu_i) Q_i) }{Z_i \phi^p(R(\mu_i) Q_i)} \end{align} \]
which is the recurrent formulation of the output \(Y_i\). The last line above uses the fact that
\[ \sum_{j=1}^i V_j \phi^p(R(\mu_j) K_j)^T = S_i \]
and
\[ \begin{align} C_i &= \sum_{k=1}^i \phi^p(Q_i')^T \phi^p(K_k') \\ &= \sum_{k=1}^i \phi^p(R(\mu_i) Q_i)^T \phi^p(R(\mu_k) K_k) \\ &= \sum_{k=1}^i \phi^p(R(\mu_k) K_k)^T \phi^p(R(\mu_i) Q_i) \\ &= Z_i \phi^p(R(\mu_i) Q_i) \end{align} \]

2. Gating
The basic idea of gating is that at each time step the state matrix \(S\in \R^{d \times D}\) is discounted by a scalar \(\gamma \in [0, 1]\). Discounting the state “erases” past information stored in the state. This technique has been used extensively throughout the linear transformer literature [8], [9], [10]. One common approach to implementing gating is to manually pick a gating value for each head, usually using a range of large and small \(\gamma\) across heads to let the model track both short and long term interactions. But, naturally, the gating values can also be learnable parameters or even data dependent values, as has been thoroughly explored in prior work [1], [2], [3], [4].
After exploring a few variations of gating (fixed, learnable, and data dependent), we converged on the technique used in [3]. The discount value at timestep \(i\) is \(\gamma_i = \sigma(W_\gamma X_i)\), where \(\sigma\) is the sigmoid function, \(W_\gamma \in \mathbb{R}^{1 \times d}\), and \(X_i\) is the input used to compute the keys, queries, and values (e.g. \(K_i = W_K X_i\)). When using symmetric power attention with power \(p\), the recurrent state update is simply \[ Z_i = \gamma_i Z_{i-1} + \phi^p(K'_i)^T \qquad S_i = \gamma_i S_{i-1} + V_i \phi^p(K'_i)^T \qquad \mu_i = \mu_{i-1} + \theta \] Recall that \(K'_i = R(\mu_i) K_i\).
To write the attention formulation, we define \(\beta_{ij} = \prod_{k=j+1}^i \gamma_k\). The preattention then becomes \[ B_{ij} = \beta_{ij} \; ( {Q'_i}^T K'_j)^p \qquad \text{(sympow gated)} \]
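The match between the gated recurrent update and the discounted preattention can be verified numerically. The sketch below is our own illustration: it uses the naive flattened tensor product for \(\phi^p\) and omits the rotary part for clarity, so it is not the actual kernel.

```python
import numpy as np

def phi(k, p):
    # Naive p-fold tensor-product embedding (D = d**p, not the compact
    # symmetric version); satisfies phi(q, p) @ phi(k, p) == (q @ k)**p.
    out = k
    for _ in range(p - 1):
        out = np.outer(out, k).ravel()
    return out

def gated_recurrent(Q, K, V, gamma, p):
    # gamma[i] is the (data-dependent) discount applied at step i.
    t, d = Q.shape
    S = np.zeros((d, d ** p))
    Z = np.zeros(d ** p)
    Y = np.zeros_like(V)
    for i in range(t):
        Z = gamma[i] * Z + phi(K[i], p)
        S = gamma[i] * S + np.outer(V[i], phi(K[i], p))
        q = phi(Q[i], p)
        Y[i] = (S @ q) / (Z @ q)
    return Y

def gated_attention(Q, K, V, gamma, p):
    # beta[j] = prod of gamma over steps j+1..i discounts old scores.
    t = Q.shape[0]
    Y = np.zeros_like(V)
    for i in range(t):
        beta = np.array([np.prod(gamma[j + 1 : i + 1]) for j in range(i + 1)])
        B = beta * (Q[i] @ K[: i + 1].T) ** p
        Y[i] = (B / B.sum()) @ V[: i + 1]
    return Y
```

With random data and discounts in \((0, 1)\), the two functions agree to floating-point precision.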
Expand for a derivation of the equivalence between the state and attention formulations
We begin by writing the output \(Y_i\) at time step \(i\) in the attention formulation. For notational simplicity, let \(C_i = \sum_{k=1}^i \beta_{ik} (Q_i'^T K_k')^p = \sum_{k=1}^i \beta_{ik} \phi^p(Q_i')^T \phi^p(K_k')\).
\[ \begin{align} Y_i &= \sum_{j=1}^i \frac{\beta_{ij} \left(Q_i'^T K'_j\right)^p V_j}{C_i} \\ &= \sum_{j=1}^i \frac{\beta_{ij} \left(\phi^p(Q'_i)^T \phi^p(K'_j)\right) V_j}{C_i} \\ &= \sum_{j=1}^i \frac{\beta_{ij} V_j \phi^p(K'_j)^T \phi^p(Q'_i)}{C_i} \\ &= \frac{\left( \sum_{j=1}^i \beta_{ij} V_j \phi^p(K'_j)^T \right) \phi^p(Q'_i) }{C_i} \\ &= \frac{S_i \phi^p(Q'_i) }{Z_i \phi^p(Q'_i)} \end{align} \]
which is the recurrent formulation of the output \(Y_i\). The last line above uses the fact that
\[ \sum_{j=1}^i \beta_{ij} V_j \phi^p(K'_j)^T = S_i \]
and
\[ \begin{align} C_i &= \sum_{k=1}^i \beta_{ik} \phi^p(Q_i')^T \phi^p(K_k') \\ &= \sum_{k=1}^i \beta_{ik} \phi^p(K_k')^T \phi^p(Q_i') \\ &= Z_i \phi^p(Q'_i) \end{align} \]
We prove that \(S_i = \sum_{j=1}^i \beta_{ij} V_j \phi^p(K'_j)^T\) by induction.
As the base case, note that \(S_1 = V_1 \phi^p(K'_1)^T = \beta_{11} V_1 \phi^p(K'_1)^T\), since the empty product \(\beta_{11}\) equals \(1\).
For the inductive step, suppose that for some \(k \geq 1\), \(S_k = \sum_{j=1}^k \beta_{kj} V_j \phi^p(K'_j)^T\). Then \[ \begin{align} S_{k+1} &= \gamma_{k+1} S_k + V_{k+1}\phi^p(K'_{k+1})^T \\ &= \gamma_{k+1} \left( \sum_{j=1}^k \beta_{kj} V_j \phi^p(K'_j)^T \right) + V_{k+1}\phi^p(K'_{k+1})^T \\ &= \left( \sum_{j=1}^k \beta_{(k+1)j} V_j \phi^p(K'_j)^T \right) + V_{k+1}\phi^p(K'_{k+1})^T \\ &= \sum_{j=1}^{k+1} \beta_{(k+1)j} V_j \phi^p(K'_j)^T \end{align} \] This completes the inductive step.

Let’s see what difference gating makes, and whether it solves the issues we encountered in the intro. We can look at the loss at the end of training on 400k documents. As we grow the train context size, gated sympow stays better than the baseline as far as we were able to test it!
We also see that after adding learned gating, the symmetric power transformer is able to successfully generalize past the train context size.
3. Learned Rotary Embeddings
We now explore an intuitive idea: learning the rotation rates in rotary positional embeddings. There are many ways to approach this. For example, the network could independently decide how much to rotate each of the 2D subspaces. We found that an efficient way to implement learned rotation rates, one that results in performance improvements, is for the network to scale the fixed \(\theta\) vector that RoPE usually applies. Similar to gating, we add parameters \(W_\beta\) to each attention head. The network outputs a scalar \(\beta_i = 1 + \tanh(W_\beta X_i)\) and multiplies this value with the original fixed vector \(\theta\). In the recurrent formulation, this produces the equation \[ \mu_i = \mu_{i-1} + \beta_i \theta \] and in the attention formulation \[ B_{ij} = \left(Q_i^T R(c_{ij}\theta)^T K_j \right)^p \qquad c_{ij} = \sum_{k=j+1}^i \beta_k \qquad \text{(conformal-sympow)} \] For reasons discussed in Section 4, we refer to this approach as conformal-sympow.
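A minimal sketch of the data-dependent angle accumulation (our own illustration; `W_beta` stands in for the per-head learned parameter vector and `X` for the head inputs):

```python
import numpy as np

def learned_rotation_angles(X, W_beta, theta):
    # Accumulate mu_i = mu_{i-1} + beta_i * theta, where the scalar
    # beta_i = 1 + tanh(W_beta @ X_i) scales the fixed RoPE rates theta.
    t = X.shape[0]
    mu = np.zeros((t, len(theta)))
    acc = np.zeros_like(theta)
    for i in range(t):
        beta_i = 1.0 + np.tanh(W_beta @ X[i])
        acc = acc + beta_i * theta
        mu[i] = acc
    return mu
```

With \(W_\beta = 0\), every \(\beta_i = 1\) and the update reduces to standard RoPE, \(\mu_i = i\theta\); nonzero weights let the data stretch or compress the effective positions.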
We can see that learning the rotary embeddings in addition to learning gating values further improves performance over sympow+gating with fixed rotary embeddings:
and generalizes past the train context size:
Full Training Curves
Here we display the training curves for all methods using symmetric power \(p=2\) and \(p=4\) at different context lengths. We can see that sympow+gating and conformal-sympow improve optimization throughout training.
4. Conformal State Transformations (Optional Reading)
Note: this section is conceptual and has no practical implications. We show that gating and rotations can be unified into a single mathematical idea: a conformal state transformation.
We refer to the combination of gating and rotary embeddings as conformal-sympow. The reason stems from the fact that the combination of gating and rotary embeddings can be interpreted as applying a conformal linear transformation to the state in the recurrent formulation. A conformal linear transformation is a type of linear transformation that preserves angles between vectors while allowing uniform scaling of lengths. Mathematically, a conformal linear transformation in \(n\)-dimensional Euclidean space can be expressed as: \[ \mathbf{T}(\mathbf{x}) = s \mathbf{R} \mathbf{x}, \] where \(s > 0\) is a scalar representing the scaling factor, \(\mathbf{R}\) is an orthogonal matrix \(\mathbf{R}^T \mathbf{R} = \mathbf{I}\), and \(\mathbf{x}\) is the input vector.
We will show that updating the recurrent state of a sympow transformer by applying gating and rotary embeddings is equivalent to right multiplying the state with a conformal linear transformation before adding new information:
\[ S_{i} = S_{i-1} (s\mathbf{R}) + V_i \phi^p(K_i)^T \]
When applying gating, we can see that the discount value \(\gamma\) plays the role of \(s\) in the above equation.
Now, recall that the recurrent state update when applying rotary embeddings is \[ Z_i = Z_{i-1} + \phi^p(K'_i)^T \qquad S_i = S_{i-1} + V_i \phi^p(K'_i)^T \qquad \mu_i = \mu_{i-1} + \theta \] where \(K'_i = R(\mu_i) K_i\).
We will show that the update equation \(S_i = S_{i-1} + V_i \phi^p(R(\mu_i)K_i)^T\) is equivalent to the following update equation: \[ S_i = S_{i-1} \bar{R}(\theta) + V_i \phi^p(K_i)^T \] for some rotation matrix \(\bar{R}(\theta)\).
This equivalence stems from the following result. If \(P\in \R^{d\times d}\) is a rotation matrix, then there exists another rotation matrix \(\bar P \in \R^{D \times D}\) s.t. \[ \phi^p( P k) = \bar P \phi^p(k) \]
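This fact can also be checked numerically. In the sketch below (our own illustration, using the naive flattened tensor-product embedding rather than the compact symmetric one), the induced map \(\bar P\) is simply the \(p\)-fold Kronecker power of \(P\):

```python
import numpy as np

def phi(k, p):
    # Naive flattened p-fold tensor product standing in for phi^p.
    out = k
    for _ in range(p - 1):
        out = np.outer(out, k).ravel()
    return out

def lifted_rotation(P, p):
    # Induced map on the tensor space: the p-fold Kronecker power of P.
    # It is orthogonal whenever P is, and phi(P @ k) == barP @ phi(k).
    barP = P
    for _ in range(p - 1):
        barP = np.kron(barP, P)
    return barP
```

For a \(2 \times 2\) rotation \(P\) and \(p = 3\), `lifted_rotation` produces an \(8 \times 8\) orthogonal matrix that commutes with the embedding in exactly the stated sense.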
Expand for a proof of this fact
Note that the symmetric power embedding function is equivalent to applying the tensor product and removing redundant information resulting from symmetry. For mathematical simplicity, we prove the corresponding result for the case where the embedding function is the repeated tensor product \(\otimes^p\). The proposition is stated below.
Let \(V\) be a vector space with dimension \(d\) and basis vectors \(\{ v_1, v_2, \dots, v_d \}\), and let \(P \in \mathbb{R}^{d \times d}\) be a rotation matrix. Define the linear map \(\bar{P} : V^{\otimes p} \to V^{\otimes p}\) (the tensor product of \(p\) copies of \(V\)) by its action on the basis elements as \[ \bar{P}(v_{i_1} \otimes v_{i_2} \otimes \dots \otimes v_{i_p}) = (P v_{i_1}) \otimes (P v_{i_2}) \otimes \dots \otimes (P v_{i_p}) \quad \text{for all } i_1, i_2, \dots, i_p. \] Then \(\bar{P} \in \mathbb{R}^{d^p \times d^p}\) is a rotation matrix.
We need to show that \(\bar{P}\) satisfies the properties of a rotation matrix, namely: 1. \(\bar{P}\) is an orthogonal matrix, i.e., \(\bar{P}^T \bar{P} = I\). 2. \(\det(\bar{P}) = 1\), so that \(\bar{P}\) represents a proper rotation.
Step 1: Orthogonality of \(\bar{P}\).
Since \(\bar{P}\) is defined by its action on the basis elements of \(V^{\otimes p}\) as \[ \bar{P}(v_{i_1} \otimes v_{i_2} \otimes \dots \otimes v_{i_p}) = (P v_{i_1}) \otimes (P v_{i_2}) \otimes \dots \otimes (P v_{i_p}), \] and \(P\) is an orthogonal matrix, i.e., \(P^T P = I_d\), where \(I_d\) is the identity matrix in \(\mathbb{R}^{d}\), we need to verify that \(\bar{P}\) preserves the inner product in the tensor product space. The inner product of two basis elements \(v_{i_1} \otimes v_{i_2} \otimes \dots \otimes v_{i_p}\) and \(v_{j_1} \otimes v_{j_2} \otimes \dots \otimes v_{j_p}\) in \(V^{\otimes p}\) is given by: \[ \langle v_{i_1} \otimes v_{i_2} \otimes \dots \otimes v_{i_p}, v_{j_1} \otimes v_{j_2} \otimes \dots \otimes v_{j_p} \rangle = \prod_{m=1}^{p} \langle v_{i_m}, v_{j_m} \rangle. \] Applying \(\bar{P}\) to this inner product, we get: \[ \langle \bar{P}(v_{i_1} \otimes \dots \otimes v_{i_p}), \bar{P}(v_{j_1} \otimes \dots \otimes v_{j_p}) \rangle = \prod_{m=1}^{p} \langle P v_{i_m}, P v_{j_m} \rangle. \] Since \(P\) is orthogonal, we have \(\langle P v_{i_m}, P v_{j_m} \rangle = \langle v_{i_m}, v_{j_m} \rangle\) for each \(m\). Therefore, \(\bar{P}\) preserves the inner product, meaning that \(\bar{P}\) is an orthogonal matrix, i.e., \(\bar{P}^T \bar{P} = I_{d^p}\).
Step 2: Determinant of \(\bar{P}\).
Next, we show that \(\det(\bar{P}) = 1\). Since \(\bar{P} = P \otimes P \otimes \dots \otimes P\) (a \(p\)-fold tensor product of \(P\) with itself), we can use the property of the determinant for tensor products of matrices. Specifically, if \(A \in \R^{n \times n}\) and \(B \in \R^{m \times m}\), then: \[ \det(A \otimes B) = \det(A)^{m} \det(B)^{n}. \] Applying this repeatedly to the \(p\)-fold product gives: \[ \det(\bar{P}) = \det(P)^{p \cdot d^{p-1}}. \] Since \(P\) is a rotation matrix in \(\mathbb{R}^d\), we know that \(\det(P) = 1\). Therefore: \[ \det(\bar{P}) = 1^{p \cdot d^{p-1}} = 1. \] Thus, \(\bar{P}\) is a proper rotation matrix.

Using the above result, the conformal-sympow recurrent state update can be written as follows:
\[ Z_i = Z_{i-1} (\gamma_i \bar{R}(\theta, \beta_i)) + \phi^p(K_i)^T \qquad S_i = S_{i-1} (\gamma_i \bar{R}(\theta, \beta_i)) + V_i \phi^p(K_i)^T \] where \(\bar{R}(\theta, \beta_i)\) is a rotation matrix that depends on the fixed rotation rates \(\theta\) and the scalar \(\beta_i\).
5. Equivalence between Gating and ALiBi (Optional Reading)
Attention with Linear Biases (ALiBi) [6] is a type of positional encoding that significantly improves the ability of softmax transformers to extrapolate to evaluation contexts longer than the training context size. ALiBi biases query-key attention scores with a penalty that is proportional to the distance between the query and key. We now show that ALiBi is equivalent to applying scalar gating.
In a softmax transformer, the attention scores are computed as \[ A_{ij} = \frac{B_{ij}}{\sum_{k=1}^i B_{ik}} \qquad B_{ij} = \text{exp}(Q_i^T K_j) \qquad \text{(softmax)} \] Recall that we refer to \(A_{ij}\) as the attention scores and \(B_{ij}\) as the pre-attention scores.
The pre-attention scores after applying ALiBi are \[ B_{ij} = \text{exp}(Q_i^T K_j + m(j -i)) \qquad \text{(softmax + ALiBi)} \] where \(0 < m < 1\) is a head-specific value that is fixed before training.
Note that \[\text{exp}(Q_i^T K_j + m(j - i)) = \gamma^{(i - j)} \text{exp}(Q_i^T K_j)\] where \(\gamma = \text{exp}(-m)\). Since \(-m < 0\), \(0 < \gamma < 1\). Thus, the application of ALiBi is equivalent to applying scalar gating.
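This identity is easy to confirm numerically. The sketch below (our own illustration with assumed values of \(m\)) computes the same pre-attention row both ways:

```python
import numpy as np

def alibi_preattention(scores, m):
    # scores[j] = Q_i^T K_j for j = 0..i; bias each score by m * (j - i).
    i = len(scores) - 1
    return np.exp(scores + m * (np.arange(i + 1) - i))

def gated_preattention(scores, gamma):
    # Discount each exponentiated score by gamma^(i - j): scalar gating.
    i = len(scores) - 1
    return gamma ** (i - np.arange(i + 1)) * np.exp(scores)
```

Setting \(\gamma = e^{-m}\) makes the two rows identical, which is exactly the equivalence derived above.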
References
Citation
@misc{buckman2024,
author = {Buckman, Jacob and Gelada, Carles and Kumar, Saurabh and
Zhang, Sean},
publisher = {Manifest AI},
title = {Improving {Symmetric} {Power} {Transformers} with {Conformal}
{Transformations}},
date = {2024-12-10},
langid = {en}
}