In this post, I’ll show how a small set of reasonable assumptions can recover the Transformer attention mechanism. Some parts of attention are theoretically motivated, while others are arbitrary choices. I’ll explicitly call out which is which.
To see why attention exists, it helps to recall its predecessor: the recurrent neural network (RNN). Classic encoder-decoder RNNs process a sequence token by token. Each new hidden state incorporates the current token and the previous hidden state, producing a vector you can think of as an “accumulator” of everything seen so far. After ingesting the final token, that accumulated vector is repeatedly fed to the decoder, which predicts output tokens until it emits a STOP symbol.
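The problem is long-range dependence. If an important token appeared far earlier in the sequence (say, the first of 10,000 tokens), its influence is diluted as the RNN processes additional tokens. Every update overwrites part of a fixed-size hidden state, and the gradients that would teach the model to preserve that token tend to vanish over long spans. In effect, the model forgets.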
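The toy sketch below makes the dilution concrete. It assumes a scalar linear recurrence h_t = a·h_{t-1} + x_t with a fixed decay factor a. The function names are hypothetical, and real RNNs use learned matrices and nonlinearities. Unrolling the recurrence shows that the first token's contribution to the final state is multiplied by a^(N-1).

```python
def run_linear_rnn(xs, a):
    """Toy scalar recurrence h_t = a * h_{t-1} + x_t; returns the final state."""
    h = 0.0
    for x in xs:
        h = a * h + x
    return h


def first_token_weight(n, a):
    """Coefficient on x_1 in the final state after n tokens.

    Unrolling gives h_n = sum_t a**(n - t) * x_t, so x_1 is scaled by a**(n - 1).
    """
    return a ** (n - 1)


if __name__ == "__main__":
    a = 0.9
    for n in (10, 100, 1000):
        # Feed a single "important" first token followed by zeros.
        impulse = [1.0] + [0.0] * (n - 1)
        print(n, run_linear_rnn(impulse, a), first_token_weight(n, a))
```

With a = 0.9, the first token's weight after 1,000 steps is on the order of 1e-46, which is effectively zero.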
Ideally, the model should use all previously seen tokens to compute the information needed to predict the next token, weighting each earlier token by how relevant it is for that prediction. That suggests computing a relevance score between a position i and each other position j, and then combining some function of the embeddings accordingly.
Formally, define a scalar relevance function that takes the embeddings X along with indices i and j:
$$u(X, i, j) \in \mathbb{R}$$
For brevity, we write this as u(i, j) from here on.
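We work in embedding space rather than raw token IDs to avoid meaningless geometric assumptions (e.g., token 9 is not inherently “closer” to token 10). One-hot encodings would also avoid this problem. However, they are high-dimensional and sparse, and they make every pair of distinct tokens equally dissimilar.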
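Then, for a sequence of N tokens, the model’s output vector at position i can be written as:
$$y_i = G\Big(\big(x_1, u(i, 1)\big), \big(x_2, u(i, 2)\big), \ldots, \big(x_N, u(i, N)\big)\Big)$$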
where G aggregates some function of the embeddings xj alongside their relevance to position i. (During autoregressive generation, i corresponds to the most recently produced token.) We don’t yet know the form of G or u. Our goal is to characterize the simplest constraints that lead directly to Transformer-style attention.
Observation: Enforce permutation symmetry
We want to constrain the space of possible functions for G and u.
Once we have the relevance scores u(i, j), the output yi should not depend on the order in which the pairs (xj, u(i, j)) are provided. In other words, if we reorder the elements indexed by j, the result should remain the same. This requires G to be “permutation-invariant” over the set {(xj, u(i, j))}j.
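The Deep Sets theorem (Zaheer et al., 2017) says that, under mild conditions, any such function can be written in the form below. (The original proof assumes the set elements come from a countable domain.)
$$y_i = \rho\left(\sum_j \varphi\big(x_j, u(i, j)\big)\right)$$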
Here ρ and φ are otherwise arbitrary functions. We fix the index i, since the invariance only applies over j. We additionally require ρ and φ to be differentiable so that the overall model can be trained with gradient-based methods.
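The sketch below shows why the sum decomposition is order-independent. It uses hypothetical choices of φ and ρ; any functions would do. Because addition is commutative, every permutation of the (xj, u(i, j)) pairs yields the same output. A function that weights elements by their position j, shown as a counterexample, does not.

```python
import itertools
import math


def phi(x, u):
    """Hypothetical per-element map: scale the embedding by u and append tanh(u)."""
    return [u * xi for xi in x] + [math.tanh(u)]


def rho(z):
    """Hypothetical post-processing of the pooled vector."""
    return [2.0 * zi + 1.0 for zi in z]


def deep_sets(pairs):
    """rho(sum_j phi(x_j, u_j)) over a list of (x_j, u_j) pairs."""
    mapped = [phi(x, u) for x, u in pairs]
    pooled = [sum(column) for column in zip(*mapped)]  # element-wise sum over j
    return rho(pooled)


def order_sensitive(pairs):
    """Counterexample: weighting each element by its position j breaks invariance."""
    total = [0.0] * len(pairs[0][0])
    for j, (x, u) in enumerate(pairs, start=1):
        total = [t + j * u * xi for t, xi in zip(total, x)]
    return total


def all_close(a, b, tol=1e-12):
    """True if two vectors match element-wise within tol."""
    return all(abs(p - q) < tol for p, q in zip(a, b))


if __name__ == "__main__":
    pairs = [([1.0, 0.0], 0.5), ([0.0, 2.0], 1.5), ([3.0, -1.0], -0.25)]
    base = deep_sets(pairs)
    for perm in itertools.permutations(pairs):
        # deep_sets agrees with base for every ordering; order_sensitive does not.
        print(all_close(deep_sets(list(perm)), base), order_sensitive(list(perm)))
```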
At this point, ρ and φ are still completely general, and we also need to define u. We will impose further assumptions to narrow down their form.
Assumption 1: ρ is the identity function
ρ could output many different types of objects. For example:
It could output a scalar, but that would discard most of the information from the embeddings.
It could output an O(N)-dimensional vector, with one component per input element, but that would make the output scale with sequence length and defeat the purpose of summarizing information.
It could output a vector in some intermediate dimension, or even map into a different space/manifold entirely.
All of these are technically possible. In practice, Transformers set ρ to be the identity function, so:
$$y_i = \sum_j \varphi\big(x_j, u(i, j)\big)$$
This simplifies the structure of G and lets us focus on constraining φ and u.
Assumption 2: Relevance-contribution proportionality
Even with ρ set to the identity, φ could be any function of the embedding xj and the relevance score u(i, j). To simplify the form, we assume that if a token’s relevance is scaled by a constant k, its contribution scales by the same factor:
$$\varphi\big(x_j, k\,u(i, j)\big) = k\,\varphi\big(x_j, u(i, j)\big)$$
This is not the only possible relationship. For example, the contribution could instead grow quadratically, or according to some other monotonic function of u(i, j). The key requirement is simply that φ should separate into:
A scalar measuring how important xj is
A vector capturing what xj contributes
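Under the linear version of this assumption, take a base relevance of 1 and scale it by k = u(i, j). This gives:
$$\varphi\big(x_j, u(i, j)\big) = \varphi\big(x_j, u(i, j) \cdot 1\big) = u(i, j)\,\varphi(x_j, 1)$$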
Define v(xj) = φ(xj, 1), yielding:
$$y_i = \sum_j u(i, j)\, v(x_j)$$
This makes φ explicitly separable: u(i, j) controls only the magnitude (relevance), while v(xj) determines the content being contributed.
Assumption 3: Linear change of coordinates
At this point, v(xj) could be any function of xj. To simplify the model and keep it efficient to compute, we assume v is a linear transformation of xj, given by a learned matrix W_V:
$$v(x_j) = W_V x_j$$
Substituting this into the previous expression gives:
$$y_i = \sum_j u(i, j)\, W_V x_j$$
This means each token contributes a linearly transformed version of its embedding, weighted by its relevance score u(i, j).
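Because v is linear, the relevance weights and the value matrix commute. Transforming each token and then mixing gives the same result as mixing the raw embeddings and transforming once. The sketch below checks this with made-up numbers: the embeddings, weights, and the matrix W_V (standing in for the learned linear map of Assumption 3) are arbitrary illustrative values.

```python
def matvec(W, x):
    """Multiply a matrix W (list of rows) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def weighted_sum(weights, vectors):
    """Compute sum_j weights[j] * vectors[j] component by component."""
    return [sum(w * v[c] for w, v in zip(weights, vectors)) for c in range(len(vectors[0]))]


def output_transform_then_mix(u_row, X, W_V):
    """y_i = sum_j u(i, j) * (W_V x_j): transform every token, then mix."""
    return weighted_sum(u_row, [matvec(W_V, x) for x in X])


def output_mix_then_transform(u_row, X, W_V):
    """Same y_i by linearity: mix the raw embeddings, then apply W_V once."""
    return matvec(W_V, weighted_sum(u_row, X))


if __name__ == "__main__":
    X = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]    # three 2-d embeddings
    W_V = [[0.5, -1.0], [2.0, 0.0], [1.0, 1.0]]  # maps 2-d embeddings to 3-d values
    u_row = [0.2, 0.5, 0.3]                      # relevance of each token j to position i
    print(output_transform_then_mix(u_row, X, W_V))
    print(output_mix_then_transform(u_row, X, W_V))
```

Both calls compute the same vector (up to floating-point rounding), approximately [0.5, 2.2, 1.15] for these inputs.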
Observation: Constrain u for efficient parallel computation
We want u(i, j) to be computable efficiently on hardware like GPUs. Here, “efficient” refers to low sequential depth in the computational graph, not necessarily a low number of arithmetic operations. GPUs can execute many multiplications in parallel, but long chains of dependent operations create bottlenecks. For example, a recurrent computation with O(N) sequential steps is slow for long sequences. A matrix multiply, by contrast, needs only O(log N) sequential depth, because its sums can be computed as parallel tree reductions, and it is highly parallelizable.
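The sketch below gives a rough picture of the difference in sequential depth. It assumes an idealized model in which each round of independent additions counts as one step and memory traffic is ignored; the function names are illustrative. A left-to-right accumulation needs N - 1 dependent steps. A pairwise tree reduction, the pattern behind the sums inside a parallel matrix multiply, needs only about log2 N rounds.

```python
def sequential_sum(values):
    """Left-to-right accumulation, like a recurrence: each step depends on the last."""
    total, depth = values[0], 0
    for v in values[1:]:
        total += v
        depth += 1
    return total, depth


def tree_sum(values):
    """Pairwise (tree) reduction. All additions within a round are independent,
    so an idealized parallel machine finishes each round in one step."""
    level, depth = list(values), 0
    while len(level) > 1:
        paired = [level[k] + level[k + 1] for k in range(0, len(level) - 1, 2)]
        if len(level) % 2 == 1:
            paired.append(level[-1])  # an odd leftover passes through to the next round
        level, depth = paired, depth + 1
    return level[0], depth


if __name__ == "__main__":
    values = list(range(1024))
    print(sequential_sum(values))  # same total, 1023 dependent steps
    print(tree_sum(values))        # same total, 10 rounds
```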
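If we allowed a fully general relevance function such as:
$$u(i, j) = f\big(x_i, x_j, \mathrm{context}(X)\big)$$
where f is a neural network and context(X) examines all tokens at once, we would need to evaluate this network O(N²) times per layer. Each evaluation would itself read all N tokens, which is too slow.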
Alternatively, we could define a single model:
$$U = g(X) \in \mathbb{R}^{N \times N}, \qquad U_{ij} = u(i, j)$$
that outputs all pairwise relevance values directly. But producing N² outputs requires a parameter count that grows at least as O(N²). That locks the model to a fixed input length and scales poorly.
To keep computation parallelizable and scalable, we restrict u to be built from tensor operations such as:
Linear projections
Element-wise functions
Inner products
Reductions like sums
and avoid control flow or long sequential recurrences.
Assumption 4: Dot product similarity for u
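A simple way to score the interaction between xi and xj is with a dot product. However, we don’t necessarily want similarity in the embedding space; we want similarity in a space optimized for relevance.
So, as we did for v(xj), we first apply learned linear projections. These produce a query vector qi for position i and a key vector kj for position j:
$$q_i = W_Q x_i, \qquad k_j = W_K x_j$$
In this vocabulary, v(xj) is the value vector. Then we define the relevance score as:
$$u'(i, j) = \langle q_i, k_j \rangle = q_i^\top k_j$$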
We denote this version as u’ because additional modifications will be applied later.
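The score computation can be sketched in plain Python. The sketch assumes embeddings are lists and that W_Q and W_K are small hand-picked matrices (illustrative values, not trained weights). Every entry of the N × N score table is an independent inner product, so all of them can be computed in parallel. The example also shows that u’(i, j) need not equal u’(j, i) when the two projections differ, unlike a raw dot product between embeddings.

```python
def matvec(W, x):
    """Multiply a matrix W (list of rows) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(p * q for p, q in zip(a, b))


def relevance_scores(X, W_Q, W_K):
    """Unnormalized scores u'(i, j) = <W_Q x_i, W_K x_j> for every pair (i, j)."""
    Q = [matvec(W_Q, x) for x in X]  # one query vector per position
    K = [matvec(W_K, x) for x in X]  # one key vector per position
    # Each entry depends on one query and one key only, so all N * N
    # entries are independent and could be computed in parallel.
    return [[dot(q, k) for k in K] for q in Q]


if __name__ == "__main__":
    X = [[1.0, 0.0], [0.0, 1.0]]
    W_Q = [[1.0, 0.0], [0.0, 1.0]]  # identity
    W_K = [[0.0, 1.0], [0.0, 0.0]]  # moves the second coordinate into the first
    print(relevance_scores(X, W_Q, W_K))  # [[0.0, 1.0], [0.0, 0.0]]: not symmetric
```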
Assumption 5: Pick a normalization for u
Next, we want the relevance scores u’(i, j) to measure relative importance across j. Adding the same constant to every score for a fixed i should not change the result. The final weights should also be nonnegative and sum to one, so that yi is a weighted average of the contributions. This motivates applying a differentiable normalization function over j.
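There are several possibilities (e.g., softmax, Gumbel-Softmax). In practice, Transformers use softmax. Softmax is unchanged when the same constant is added to every score, but it is not invariant to scaling. Multiplying all scores by a positive constant preserves their ranking, but it makes the distribution sharper (constant greater than 1) or flatter (constant less than 1).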
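The two properties of softmax can be checked numerically with a minimal sketch; the scores below are arbitrary illustrative values. Adding a constant leaves the weights unchanged. Scaling by 3 keeps the ranking but concentrates more weight on the top score.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)  # subtracting the max is safe precisely because of shift invariance
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def ranking(values):
    """Indices sorted from largest to smallest value."""
    return sorted(range(len(values)), key=lambda j: -values[j])


if __name__ == "__main__":
    scores = [2.0, 1.0, -0.5]
    base = softmax(scores)
    shifted = softmax([s + 10.0 for s in scores])  # same distribution as base
    sharper = softmax([3.0 * s for s in scores])   # same ranking, more peaked
    print(base, shifted, sharper)
```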
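One final issue: the dot product ⟨qi, kj⟩ tends to grow in magnitude with the key/query dimension dk. Suppose the components of qi and kj are independent with mean 0 and variance 1. Then their dot product has mean 0 and variance dk, so its typical magnitude grows like √dk. Such large logits push the softmax toward a nearly one-hot distribution with very small gradients. To prevent this, we divide the logits by √dk, which restores unit variance under the same assumption:
$$\frac{u'(i, j)}{\sqrt{d_k}} = \frac{\langle q_i, k_j \rangle}{\sqrt{d_k}}$$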
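The variance argument can be checked with a quick simulation. It assumes query and key entries drawn i.i.d. from a standard normal distribution, and the function name and sample count are illustrative. Without scaling, the empirical variance of the dot product tracks dk. After dividing by √dk, it stays near 1.

```python
import random
import statistics


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(p * q for p, q in zip(a, b))


def dot_product_variance(d_k, n_samples=2000, scale=False, seed=0):
    """Empirical variance of <q, k> for q, k with i.i.d. N(0, 1) entries.

    With scale=True, each dot product is divided by sqrt(d_k) first.
    """
    rng = random.Random(seed)  # fixed seed for reproducibility
    samples = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        s = dot(q, k)
        samples.append(s / d_k ** 0.5 if scale else s)
    return statistics.variance(samples)


if __name__ == "__main__":
    for d_k in (4, 16, 64):
        # Unscaled variance should be close to d_k; scaled variance close to 1.
        print(d_k, dot_product_variance(d_k), dot_product_variance(d_k, scale=True))
```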
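Applying softmax normalization over j then gives the final relevance weights and output:
$$u(i, j) = \frac{\exp\big(\langle q_i, k_j \rangle / \sqrt{d_k}\big)}{\sum_{j'} \exp\big(\langle q_i, k_{j'} \rangle / \sqrt{d_k}\big)}, \qquad y_i = \sum_j u(i, j)\, v(x_j)$$
Stack the queries, keys, and values as the rows of matrices Q, K, and V, and apply the softmax to each row. This yields the familiar matrix form:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$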
This is exactly the scaled dot-product attention used in Transformers.
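Putting the pieces together, here is a minimal single-head sketch in plain Python. It assumes unbatched inputs, no causal mask, no bias terms, and no output projection, and the function name is illustrative rather than any library's API. It returns both the outputs yi and the weights u(i, j), which makes it easy to check that each row of weights sums to one.

```python
import math


def matvec(W, x):
    """Multiply a matrix W (list of rows) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(X, W_Q, W_K, W_V):
    """Single-head attention over embeddings X (no mask, no biases, no output projection).

    Returns (outputs, weights) where outputs[i] = y_i and weights[i][j] = u(i, j).
    """
    Q = [matvec(W_Q, x) for x in X]  # queries q_i
    K = [matvec(W_K, x) for x in X]  # keys k_j
    V = [matvec(W_V, x) for x in X]  # values v(x_j)
    d_k = len(K[0])
    outputs, weights = [], []
    for q in Q:
        logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        u_row = softmax(logits)  # normalize over j
        weights.append(u_row)
        # y_i = sum_j u(i, j) * v(x_j), computed component by component
        outputs.append([sum(u * v[c] for u, v in zip(u_row, V)) for c in range(len(V[0]))])
    return outputs, weights


if __name__ == "__main__":
    X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    identity = [[1.0, 0.0], [0.0, 1.0]]
    Y, U = scaled_dot_product_attention(X, identity, identity, identity)
    for y, u_row in zip(Y, U):
        print([round(c, 3) for c in y], [round(u, 3) for u in u_row])
```

Frameworks express the same computation as a few batched matrix multiplies. The explicit loop here only mirrors the equations term by term.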
Does a better attention mechanism exist?
So there you have it. If we impose the following assumptions:
ρ is the identity function
Each token’s contribution scales proportionally with its relevance score
A linear transformation maps embeddings to the value, key, and query vectors
Relevance is based on a dot product
Relevance scores are normalized with a softmax
we obtain the exact scaled dot-product attention used in Transformers.
Some of these choices follow from the constraints we imposed, but many were conveniences rather than theoretical requirements. There may be better options for ρ, for the similarity measure, or for the normalization function. Even more fundamentally, the Deep Sets form we started from came from imposing permutation invariance, yet in practice Transformers reinject order information through positional encodings. Exploring these variations could reveal new attention mechanisms with different computational or modeling advantages.
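The permutation-invariance point can be checked directly. The sketch below assumes single-head attention without masking and a hypothetical shared matrix W for queries, keys, and values. It also uses a hypothetical additive positional signal (adding the position index to each component), which is only a stand-in for the sinusoidal or learned encodings used in practice. Without positions, permuting the input tokens merely permutes the outputs, so the model cannot tell orderings apart. Once positions are added, it can.

```python
import math


def attention(X, W_Q, W_K, W_V):
    """Single-head scaled dot-product attention with no positional information."""
    def proj(W, x):
        return [sum(w * xi for w, xi in zip(row, x)) for row in W]

    Q = [proj(W_Q, x) for x in X]
    K = [proj(W_K, x) for x in X]
    V = [proj(W_V, x) for x in X]
    scale = math.sqrt(len(K[0]))
    out = []
    for q in Q:
        logits = [sum(a * b for a, b in zip(q, k)) / scale for k in K]
        m = max(logits)
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        out.append([sum(e / total * v[c] for e, v in zip(exps, V)) for c in range(len(V[0]))])
    return out


def add_positions(X):
    """Hypothetical toy positional signal: add the position index p to every component."""
    return [[xi + p for xi in x] for p, x in enumerate(X)]


def close(A, B, tol=1e-9):
    """True if two lists of vectors match element-wise within tol."""
    return all(abs(a - b) < tol for ra, rb in zip(A, B) for a, b in zip(ra, rb))


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W = [[1.0, 0.5], [-0.5, 1.0]]  # hypothetical matrix shared by W_Q, W_K, W_V for brevity
perm = [2, 0, 1]               # new position p holds the old token perm[p]
X_perm = [X[p] for p in perm]

# Without positions, permuting the tokens only permutes the outputs.
Y = attention(X, W, W, W)
Y_perm = attention(X_perm, W, W, W)
print(close(Y_perm, [Y[p] for p in perm]))  # True

# With the toy positional signal, the same permutation changes the outputs themselves.
Y_pos = attention(add_positions(X), W, W, W)
Y_pos_perm = attention(add_positions(X_perm), W, W, W)
print(close(Y_pos_perm, [Y_pos[p] for p in perm]))  # False
```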

