Positional Encoding in Transformers: From Sinusoidal to RoPE
Transformers are permutation-equivariant by design—without explicit positional information, a self-attention layer cannot distinguish “the cat sat on the mat” from any permutation of those tokens. How we inject position has turned out to matter far more than the original sinusoidal encoding suggested: it directly governs length generalization, relative-distance sensitivity, and even training stability at scale.
The field has converged on a key insight: encoding relative distances within the attention logits (rather than adding absolute embeddings to token representations) yields architectures that generalize more gracefully beyond their training context length. But the design space is rich. RoPE achieves relative encoding through rotation matrices in complex space; ALiBi takes the minimalist route of a fixed linear bias; T5 bias learns the entire distance-to-bias mapping; KERPLE generalizes to learnable kernel functions.
This post is a comprehensive survey: we derive each method from first principles, compare their properties (trainability, per-layer application, extrapolation capability), and cover the increasingly important topic of context extension—position interpolation, NTK-aware scaling, and dynamic RoPE—that enables pretrained models to handle sequences far longer than they were trained on. Reference implementations are included throughout.
Summary
| PE | Relative | Trainable | Applied at Each Layer | Extrapolation |
|---|---|---|---|---|
| Sinusoidal | ✘ | ✘ | ✘ | ✘ |
| T5 bias | ✔ | ✔ | ✔ | ✔ |
| RoPE | ✔ | ✘ | ✔ | ✘ |
| ALiBi | ✔ | ✘ | ✔ | ✔ |
| KERPLE | ✔ | ✔ | ✔ | ✔ |
| Sandwich | ✔ | ✘ | ✔ | ✔ |
| xPos | ✔ | ✘ | ✔ | ✔ |
Position Encoding
Sinusoidal Position Embeddings
Sinusoidal position embeddings1 are fixed vectors added to the token embeddings at the first transformer layer.
\(\text{PE}_{(\text{pos}, 2i)} = \sin(\frac{\text{pos}}{10000^{2i/d_\text{model}}})\) \(\text{PE}_{(\text{pos}, 2i+1)} = \cos(\frac{\text{pos}}{10000^{2i/d_\text{model}}})\) where $\text{pos}$ is the position in the sentence and $i$ is the dimension index along the embedding vector. This design allows the model to learn to attend by relative positions, since for any fixed offset $k$, \(\text{PE}_{\text{pos}+k}\) can be represented as a linear function of \(\text{PE}_{\text{pos}}\).

import math
import torch
import torch.nn as nn

# Positional encoding layer in PyTorch
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)  # (max_len, 1): all positions up to max_len
        # 1 / 10000^(2i/d_model), computed in log space
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(1e4) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)  # saved with the model, but not a trainable parameter

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        seq_len = x.size(1)
        # buffers receive no gradient, so the deprecated Variable(..., requires_grad=False) wrapper is unnecessary
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)
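The linearity claim can be checked directly. The following pure-Python sketch is an illustration, not part of the original implementation (the helper names `frequencies`, `sinusoidal_pe` and `shift_pe` are hypothetical). It builds the encoding with the same interleaved sin/cos layout and maps $\text{PE}_{\text{pos}}$ to $\text{PE}_{\text{pos}+k}$ by rotating each (sin, cos) pair by the angle $k\,\omega_i$, a linear map whose coefficients depend on $k$ but not on $\text{pos}$.

```python
import math


def frequencies(d_model, base=10000.0):
    """omega_i = 1 / base^(2i / d_model) for i = 0 .. d_model/2 - 1."""
    return [1.0 / base ** (2 * i / d_model) for i in range(d_model // 2)]


def sinusoidal_pe(pos, d_model, base=10000.0):
    """PE[2i] = sin(pos * omega_i), PE[2i+1] = cos(pos * omega_i)."""
    pe = []
    for w in frequencies(d_model, base):
        pe += [math.sin(pos * w), math.cos(pos * w)]
    return pe


def shift_pe(pe, k, d_model, base=10000.0):
    """Linear map PE_pos -> PE_{pos+k}: rotate each (sin, cos) pair by k * omega_i.

    Uses sin(a+b) = sin a cos b + cos a sin b and cos(a+b) = cos a cos b - sin a sin b,
    so the coefficients depend on k only.
    """
    out = []
    for i, w in enumerate(frequencies(d_model, base)):
        s, c = pe[2 * i], pe[2 * i + 1]
        out += [s * math.cos(k * w) + c * math.sin(k * w),
                c * math.cos(k * w) - s * math.sin(k * w)]
    return out


shifted = shift_pe(sinusoidal_pe(5, 8), k=3, d_model=8)
target = sinusoidal_pe(8, 8)
print(max(abs(a - b) for a, b in zip(shifted, target)))  # only floating-point rounding remains
```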
Note:
- Refer to Attention Mechanisms and the Transformer.
- See also: BERTology: From XLNet to ELECTRA for Transformer-XL’s relative positional encoding.
Rotary Position Embedding (RoPE)
Rotary Position Embedding (RoPE)23 encodes positional information by operating in complex space. Instead of working in $\mathbb{R}^d$, it treats consecutive pairs of query and key dimensions as complex numbers in $\mathbb{C}^{d/2}$.
Specifically, instead of viewing \(\mathbf{q}=(q_1,q_2,q_3,q_4,\ldots,q_{d})\) as a $d$-dimensional real vector, RoPE views it as \(\mathbf{q}=(q_1+iq_2, q_3+iq_4,\ldots, q_{d-1} + iq_{d})\in\mathbb{C}^{d/2}\). If $d$ is odd, RoPE pads with a dummy coordinate (or simply increments $d$ by one) to ensure dimensional alignment.
Derivation
The complex number format of RoPE is written as:
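\[\begin{align} f(\mathbf{q}, m) = R_f(\mathbf{q}, m)e^{i\Theta_f(\mathbf{q}, m)}=\mathbf{q}e^{i(\Theta(\mathbf{q})+m\mathbf{\theta})} = \sum_{j=1}^{d/2} q_je^{im\theta_j} \vec{e_j} \end{align}\]Here, $m$ is the position of the token in the sequence, the sum runs over the $d/2$ complex coordinates of $\mathbf{q}$ (written $q_j$ here), $\theta_j$ is the rotation frequency of the $j$-th pair, and $\vec{e_j}$ is the $j$-th basis vector of $\mathbb{C}^{d/2}$. Applying the same map to a key $\mathbf{k}$ at position $n$, the attention score becomes $\mathrm{Re}\big[\sum_{j=1}^{d/2} q_j \bar{k}_j e^{i(m-n)\theta_j}\big]$, which depends on the positions only through the offset $m-n$.
Equivalently, in matrix form: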
\[\begin{align} f(\mathbf{q}, m) = \begin{pmatrix} M_1 & & & \\ & M_2 & & \\ & & \ddots & \\ & & & M_{d/2} \end{pmatrix} \begin{pmatrix} q_1\\ q_2\\ \vdots\\ q_d \end{pmatrix} = \mathbf{\Theta_m Q_m} = \mathbf{\Theta_m W_q X_m} \end{align}\]where \(M_j=\begin{pmatrix}\cos m\theta_j & -\sin m\theta_j \\ \sin m\theta_j & \cos m\theta_j\end{pmatrix}\), \(\mathbf{\Theta_m}\) is the block-diagonal rotation matrix, \(\mathbf{W_q}\) is the learned query weight, and \(\mathbf{X_m}\) is the embedding of the $m$-th token.
Because multiplying by this sparse block-diagonal matrix wastes computation, RoPE is implemented with element-wise products instead:
\[\begin{align} f(\mathbf{q}, m) = \begin{pmatrix}q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix}\odot\begin{pmatrix}\cos m\theta_0 \\ \cos m\theta_0 \\ \cos m\theta_1 \\ \cos m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \cos m\theta_{d/2-1} \end{pmatrix} + \begin{pmatrix}-q_1 \\ q_0 \\ -q_3 \\ q_2 \\ \vdots \\ -q_{d-1} \\ q_{d-2} \end{pmatrix}\odot\begin{pmatrix}\sin m\theta_0 \\ \sin m\theta_0 \\ \sin m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \sin m\theta_{d/2-1} \\ \sin m\theta_{d/2-1} \end{pmatrix} \end{align}\]where $\odot$ denotes the element-wise (Hadamard) product.
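The element-wise form can be exercised in a few lines of pure Python. This is a sketch that uses the interleaved pairing $(q_{2j}, q_{2j+1})$ of the equation above; real implementations operate on tensors, and the names `rope_rotate` and `dot` are illustrative. Rotating $\mathbf{q}$ to position $m$ and $\mathbf{k}$ to position $n$ yields a dot product that depends only on $m-n$, and the rotation preserves vector norms.

```python
import math


def rope_rotate(x, m, base=10000.0):
    """Rotate x (even length d) to position m with the element-wise form
    x * cos(m*theta) + [-x1, x0, -x3, x2, ...] * sin(m*theta), theta_j = base^(-2j/d)."""
    d = len(x)
    out = []
    for j in range(d // 2):
        theta = base ** (-2 * j / d)
        c, s = math.cos(m * theta), math.sin(m * theta)
        x0, x1 = x[2 * j], x[2 * j + 1]
        out += [x0 * c - x1 * s, x1 * c + x0 * s]
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.5, 2.0, -0.7, 0.1]
k = [1.1, 0.4, -0.6, 0.9, 0.2, -1.5]
# The same offset m - n = 3 at different absolute positions gives the same score.
print(dot(rope_rotate(q, 10), rope_rotate(k, 7)), dot(rope_rotate(q, 103), rope_rotate(k, 100)))
```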

- Extension to Multiple Dimensions
Note: Difference from sinusoidal embedding3
- Sinusoidal embeddings apply to each coordinate individually, while RoPE mixes pairs of coordinates.
- Sinusoidal embeddings add a $\cos(m\theta)$ or $\sin(m\theta)$ term, while RoPE uses a multiplicative factor.
Implementation
Huggingface version
# Transformers implementation: https://github.com/huggingface/transformers/blob/e42587f596181396e1c4b63660abf0c736b10dae/src/transformers/models/llama/modeling_llama.py#L180
import torch

class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embedding=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embedding = max_position_embedding
        self.base = base
        # theta_i = base^(-2i/d)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer('inv_freq', inv_freq)
        self._set_cos_sin_cache(
            seq_len=max_position_embedding, device=device, dtype=self.inv_freq.dtype)

    def _set_cos_sin_cache(self, seq_len, device=None, dtype=None):
        self.max_seq_len_cached = seq_len
        # positions m = 0, ..., seq_len - 1
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        # m * theta_i, shape [seq_len, dim // 2]
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # duplicate to [seq_len, dim]; matches the half-split layout used by rotate_half
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bsz, num_heads, seq_len, d_head]
        # rebuild the cache only when seq_len exceeds the longest length cached so far
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # the first two dims of cos and sin are always 1, so we can squeeze them
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bsz, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bsz, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def rotate_half(x):
    """Rotate half the hidden dims of the input: [x1, x2] -> [-x2, x1]."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

# Usage inside the attention layer:
# position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length)
# position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
rotary_emb = LlamaRotaryEmbedding(head_dim, max_position_embeddings)
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)  # get cos/sin cache
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
Llama version
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the unit complex numbers e^{i m theta_j} for positions m in [0, end).
    Args:
        dim (int): dimension of the frequency tensor (the per-head dimension).
        end (int): number of positions to precompute.
        theta (float): RoPE base.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()  # [end, dim // 2], entries m * theta_j
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64, unit magnitude
    return freqs_cis

def apply_rotary_emb(xq, xk, freqs_cis):
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.
    The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors.
    """
    # xq, xk: [bsz, seqlen, n_heads, head_dim]; adjacent dims (2j, 2j+1) form one complex number
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # complex multiplication is the rotation; convert back to real and flatten the (real, imag) pairs
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

def reshape_for_broadcast(freqs_cis, x):
    """
    Reshape frequency tensor for broadcasting it with another tensor.
    This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of broadcasting the frequency tensor during element-wise operations.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

# Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
# Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
freqs_cis = precompute_freqs_cis(d_model // num_heads, max_seq_len * 2)
# start_pos: starting position for attention caching
freqs_cis = freqs_cis[start_pos: start_pos + seqlen]
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
Roformer version
# source: https://huggingface.co/transformers/v4.8.0/_modules/transformers/models/roformer/modeling_roformer.html
from typing import Optional
import numpy as np
import torch
import torch.nn as nn

# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer
class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
    """This module produces sinusoidal positional embeddings of any length."""
    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]
        """
        n_pos, dim = out.shape
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        bsz, seq_len = input_ids_shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        )
        return super().forward(positions)

def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
    # https://kexue.fm/archives/8265
    # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
    # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
    sin, cos = sinusoidal_pos.chunk(2, dim=-1)
    # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
    sin_pos = torch.repeat_interleave(sin, 2, dim=-1)
    # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
    cos_pos = torch.repeat_interleave(cos, 2, dim=-1)
    # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
    rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
        query_layer
    )
    query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos
    # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
    rotate_half_key_layer = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1).reshape_as(key_layer)
    key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos
    if value_layer is not None:
        # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
        rotate_half_value_layer = torch.stack([-value_layer[..., 1::2], value_layer[..., ::2]], dim=-1).reshape_as(
            value_layer
        )
        value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos
        return query_layer, key_layer, value_layer
    return query_layer, key_layer

embed_positions = RoFormerSinusoidalPositionalEmbedding(
    max_position_embeddings,
    config.hidden_size // config.num_attention_heads
)
# Pass the embedding tensor, not the module: hidden_states is [bsz, seq_len, hidden_size], and
# [seq_len, head_dim] -> [1, 1, seq_len, head_dim] broadcasts over batch and heads
sinusoidal_pos = embed_positions(hidden_states.shape[:-1])[None, None, :, :]
q, k = apply_rotary_position_embeddings(sinusoidal_pos, q, k)
GPT-NeoX (PyTorch) implementation.
import torch

class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        # recompute the cos/sin tables only when the sequence length changes
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            # [seq_len, 1, 1, dim]: broadcasts over tensors laid out as [seq_len, batch, heads, dim]
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached

# rotary pos emb helpers:
def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat(
        (-x2, x1), dim=x1.ndim - 1
    )  # dim=-1 triggers a bug in torch < 1.8.0

@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

# q, k: [seq_len, batch, heads, d_head], so the sequence dimension is 0
rotary_emb = Rotary(d_head)
cos, sin = rotary_emb(q, seq_dim=0)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
Notes:
| Equation | Code |
|---|---|
| \(\theta_i = 10000^{-2i/d}\) | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) |
| \(m\theta_i\) | freqs = torch.einsum('i,j->ij', t, self.inv_freq) |
| \(R_{m\theta_i} = \begin{bmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{bmatrix}\) | emb.cos(), emb.sin() |
| \(x_m' = R_{m\theta} \cdot x_m\) | (q * cos) + (rotate_half(q) * sin); (k * cos) + (rotate_half(k) * sin) |
RoPE with Bias
Su4 finds that adding learnable bias vectors to the queries and keys before the RoPE rotation can improve length extrapolation:
\[\begin{equation}\boldsymbol{q}_m^{\top}\boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n\boldsymbol{k}_n \quad\to\quad (\boldsymbol{q}_m + \boldsymbol{a})^{\top}\boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n(\boldsymbol{k}_n + \boldsymbol{b})\end{equation}\]where \(\boldsymbol{\mathcal{R}}_m, \boldsymbol{\mathcal{R}}_n\) are rotation matrices and $\boldsymbol{a}, \boldsymbol{b}$ are learnable biases.
Note: In plain self-attention (without RoPE), a key bias $\boldsymbol{b}$ has no effect: it contributes the same factor $e^{\boldsymbol{q}\cdot\boldsymbol{b}}$ to every term, which cancels in the softmax normalization.
\[\begin{equation}\frac{e^{\boldsymbol{q}\cdot(\boldsymbol{k}_n + \boldsymbol{b})}}{\sum\limits_n e^{\boldsymbol{q}\cdot(\boldsymbol{k}_n + \boldsymbol{b})}} = \frac{e^{\boldsymbol{q}\cdot\boldsymbol{k}_n}e^{\boldsymbol{q}\cdot\boldsymbol{b}}}{\sum\limits_n e^{\boldsymbol{q}\cdot\boldsymbol{k}_n} e^{\boldsymbol{q}\cdot\boldsymbol{b}}}= \frac{e^{\boldsymbol{q}\cdot\boldsymbol{k}_n}}{\sum\limits_n e^{\boldsymbol{q}\cdot\boldsymbol{k}_n}}\end{equation}\]With RoPE, however, the bias is rotated together with the key, so the extra term $\boldsymbol{q}_m^{\top}\boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n\boldsymbol{b} = \boldsymbol{q}_m^{\top}\boldsymbol{\mathcal{R}}_{n-m}\boldsymbol{b}$ varies with $n$ and does not cancel; removing the bias therefore changes the result.4
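A tiny numeric check of both claims, as a sketch: a single 2-D rotation with frequency $\theta = 1$ stands in for RoPE, and the vectors are arbitrary. Without RoPE the key bias leaves the attention weights unchanged; with RoPE the same bias shifts them.

```python
import math


def softmax(xs):
    mx = max(xs)
    es = [math.exp(x - mx) for x in xs]
    z = sum(es)
    return [e / z for e in es]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def add(a, b):
    return [x + y for x, y in zip(a, b)]


def rotate(v, angle):
    """2-D rotation by `angle`: a one-pair stand-in for RoPE's R_m with theta = 1."""
    c, s = math.cos(angle), math.sin(angle)
    return [v[0] * c - v[1] * s, v[0] * s + v[1] * c]


q = [0.8, -0.3]
keys = [[0.2, 1.0], [-0.5, 0.4], [1.3, -0.7]]  # keys at positions n = 0, 1, 2
b = [0.9, 0.6]                                 # key bias
m = 2                                          # query position

# No RoPE: q.(k_n + b) = q.k_n + q.b, a constant shift that softmax removes.
plain = softmax([dot(q, k) for k in keys])
plain_biased = softmax([dot(q, add(k, b)) for k in keys])

# RoPE: (R_m q).(R_n (k_n + b)) adds q.R_{n-m} b, which changes with n.
roped = softmax([dot(rotate(q, m), rotate(k, n)) for n, k in enumerate(keys)])
roped_biased = softmax([dot(rotate(q, m), rotate(add(k, b), n)) for n, k in enumerate(keys)])

print(max(abs(x - y) for x, y in zip(plain, plain_biased)))  # rounding only
print(max(abs(x - y) for x, y in zip(roped, roped_biased)))  # clearly non-zero
```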
T5 Bias
T55 adds no position encoding to word embeddings. Instead, it adds a learned, shared bias to each query-key attention score that depends only on the distance between query and key. Multiple distances share the same learned bias through bucketing, which may benefit length interpolation. Specifically, a fixed number of embeddings are learned, each corresponding to a range of possible key-query offsets.
T5 uses 32 buckets of learnable parameters (one scalar per bucket and head) and assigns the relative position bias with a log-binning strategy; in the unidirectional (decoder) case:
\[\begin{align} b_{m-n} = \left\{ \begin{array}{ll} \text{bucket}[0] &{} \text{if } m-n<0\\ \text{bucket}[m-n] &{} \text{if } 0 \leq m-n < 16\\ \text{bucket}[\min(31, 16 + \lfloor \frac{\log \frac{m-n}{16}}{\log \frac{128}{16}} \cdot 16 \rfloor)] &{} \text{if } m-n \geq 16 \end{array} \right. \end{align}\]Press et al.6 find that T5 bias enables length extrapolation.
Note: T5 uses 32 embeddings for all models with ranges that increase in size logarithmically up to an offset of 128 beyond which it assigns all relative positions to the same embedding. All position embeddings are shared across all layers in T5, though within a given layer each attention head uses a different learned position embedding.
import math
import torch
import torch.nn as nn
from transformers import T5Config

class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        ...
        if self.has_relative_attention_bias:
            # one learned scalar bias per (bucket, head)
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
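To see which distances share a bucket, here is a scalar, pure-Python port of `_relative_position_bucket`. It is an illustrative re-implementation (the name `t5_bucket` is hypothetical), not the Hugging Face code itself. In the decoder setting, distances 0 to 15 get their own buckets, larger distances are log-binned into buckets 16 to 31, and every distance of 128 or more is clamped into bucket 31.

```python
import math


def t5_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    """Scalar port of `_relative_position_bucket`; relative_position = key_pos - query_pos."""
    bucket = 0
    if bidirectional:
        num_buckets //= 2
        if relative_position > 0:  # keys to the right use the upper half of the buckets
            bucket += num_buckets
        distance = abs(relative_position)
    else:
        distance = max(-relative_position, 0)  # future keys collapse to distance 0
    max_exact = num_buckets // 2
    if distance < max_exact:  # one exact bucket per small distance
        return bucket + distance
    # log-spaced buckets up to max_distance, clamped to the last bucket
    large = max_exact + int(
        math.log(distance / max_exact) / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    )
    return bucket + min(large, num_buckets - 1)


# Decoder (unidirectional) case: a query at m attends to a key at n = m - d.
print([t5_bucket(-d, bidirectional=False) for d in (0, 5, 15, 16, 32, 64, 1000)])
# [0, 5, 15, 16, 21, 26, 31]
```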
ALiBi (Attention with Linear Biases)
Warning: Length extrapolation means training a transformer on short sequences and testing it on substantially longer ones; relative positional encodings are the main means of achieving it.
ALiBi6 adds a static, non-learnable bias to the query-key dot product. Like T5 bias and RoPE, it injects positional information at every layer.
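\[\mathbf{q}_m^T \mathbf{k}_n - \alpha \vert m-n \vert\]where $\alpha$ is a fixed, head-specific slope. With $h$ heads, the $i$-th head ($i = 1, \ldots, h$) uses \(\alpha_i = 2^{-8i/h}\); for 8 heads this gives the geometric sequence $\frac{1}{2}, \frac{1}{4}, \ldots, \frac{1}{256}$.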

Note: ALiBi bias is not multiplied by the \(\sqrt{d_k}\) scaling factor as in the original transformer.
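A minimal sketch of the bias computation in pure Python, assuming a causal model and the ALiBi paper's geometric slope rule (for $h$ heads, slopes $2^{-8/h}, 2^{-16/h}, \ldots, 2^{-8}$, stated for $h$ a power of two). The function names are illustrative. The bias matrix is added to the unscaled-by-$\sqrt{d_k}$ logits before the softmax.

```python
def alibi_slopes(num_heads):
    """Head-specific slopes 2^(-8i/h) for heads i = 1..h (h a power of two)."""
    return [2 ** (-8 * i / num_heads) for i in range(1, num_heads + 1)]


def alibi_bias(seq_len, slope):
    """Causal ALiBi bias: slope * (n - m) = -slope * |m - n| for keys n <= m,
    and -inf for future keys n > m (the causal mask)."""
    return [[slope * (n - m) if n <= m else float("-inf") for n in range(seq_len)]
            for m in range(seq_len)]


print(alibi_slopes(8))        # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
print(alibi_bias(4, 0.5)[3])  # [-1.5, -1.0, -0.5, 0.0]
```

Because the penalty grows linearly with distance and does not depend on any trained parameter, the same rule applies unchanged to sequences longer than those seen in training.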

Empirically, ALiBi and T5 bias exhibit strong length extrapolation6, whereas RoPE and sinusoidal PE do not.
KERPLE
KErnelize Relative Positional Embedding for Length Extrapolation (KERPLE)7 proposes kernelized positional embeddings as follows:
\[\begin{align} a_{m,n} = \frac{\exp(\frac{\mathbf{q}_m^T \mathbf{k}_n + \tilde{k}_{r_1,\cdots, r_{\ell}}(m, n)}{\sqrt{d}})}{\sum_{i=1}^L \exp(\frac{\mathbf{q}_m^T \mathbf{k}_i + \tilde{k}_{r_1,\cdots, r_{\ell}}(m, i)}{\sqrt{d}})} \end{align}\]where \(\tilde{k}_{r_1,\cdots, r_{\ell}}\) is the kernelized positional bias and \(r_1,\cdots, r_{\ell}\) are its learnable parameters.
\[\begin{align} \mathbf{q}_m^{\top}\mathbf{k}_n + \tilde{k}_{r_1, r_2}(m, n) = \left\{ \begin{array}{ll} \text{(power)}&{} \mathbf{q}_m^{\top}\mathbf{k}_n - r_1\vert m - n\vert^{r_2} ,&{} r_1 >0, 0 < r_2 \leq 2\\ \text{(logarithmic)}&{} \mathbf{q}_m^{\top}\mathbf{k}_n - r_1\log(1+r_2\vert m - n\vert),&{} r_1, r_2 > 0 \end{array} \right.\label{eq:kerple}\end{align}\]Note: With the triangle kernel $c-\vert m-n\vert$, KERPLE reduces to ALiBi.
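In code, the two kernels are scalar penalties added to $\mathbf{q}_m^{\top}\mathbf{k}_n$ before the softmax. This is a sketch in which $r_1, r_2$ are fixed floats with the constraints enforced by a check; in KERPLE they are learned per head, and how the paper keeps them in range during training is not shown here. With $r_2 = 1$ the power variant is a linear distance penalty, i.e. an ALiBi-style slope of $r_1$, while the logarithmic variant grows much more slowly with distance.

```python
import math


def kerple_power_bias(distance, r1, r2):
    """Power variant: -r1 * |m - n|^r2, with r1 > 0 and 0 < r2 <= 2."""
    if not (r1 > 0 and 0 < r2 <= 2):
        raise ValueError("need r1 > 0 and 0 < r2 <= 2")
    return -r1 * abs(distance) ** r2


def kerple_log_bias(distance, r1, r2):
    """Logarithmic variant: -r1 * log(1 + r2 * |m - n|), with r1, r2 > 0."""
    if not (r1 > 0 and r2 > 0):
        raise ValueError("need r1, r2 > 0")
    return -r1 * math.log(1 + r2 * abs(distance))


print([kerple_power_bias(d, 0.5, 1.0) for d in range(1, 4)])  # [-0.5, -1.0, -1.5]
print(kerple_log_bias(1000, 1.0, 1.0), kerple_power_bias(1000, 1.0, 1.0))
```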
xPos
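Extrapolatable Position Embedding (XPOS)8 extends RoPE with an exponential scaling factor that is applied in opposite directions to queries and keys, so that only the relative offset survives in the attention score:
\[\begin{align} f_q (q,n) &{}= A_q qe^{\lambda n} &{} \text{let }\lambda = \xi+i\theta \in \mathbb{C}^{d/2} \\&{}:= qe^{\xi n + i\theta n} &{} \text{Remove linear factor $A_q$} \\ f_k (k,n) &{}= A_k ke^{-\lambda^{*} n} &{} \lambda^{*} = \xi-i\theta \text{ is the conjugate of } \lambda \\&{}:= ke^{-\xi n + i\theta n} &{} \text{Remove linear factor $A_k$} \end{align}\]The score between a query at position $m$ and a key at position $n$ is then $\mathrm{Re}\big[q\bar{k}\,e^{\xi(m-n)}e^{i\theta(m-n)}\big]$: RoPE's rotation multiplied by a factor that depends only on $m-n$.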
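A one-coordinate sketch of this relative property using Python's `cmath`, assuming the conjugate convention of the xPos paper (queries scaled by $e^{\xi n}$, keys by $e^{-\xi n}$, both rotated by $e^{i\theta n}$). The values of $\xi$ and $\theta$ are arbitrary; a negative $\xi$ makes the score magnitude decay as the query moves further to the right of the key.

```python
import cmath


def xpos_query(q, pos, xi, theta):
    """f_q(q, pos) = q * exp(xi*pos + i*theta*pos) for one complex coordinate."""
    return q * cmath.exp(complex(xi * pos, theta * pos))


def xpos_key(k, pos, xi, theta):
    """f_k(k, pos) = k * exp(-xi*pos + i*theta*pos): opposite scaling, same rotation."""
    return k * cmath.exp(complex(-xi * pos, theta * pos))


def xpos_score(q, k, m, n, xi, theta):
    """Re[f_q(q, m) * conj(f_k(k, n))] = Re[q * conj(k) * exp((xi + i*theta) * (m - n))]."""
    return (xpos_query(q, m, xi, theta) * xpos_key(k, n, xi, theta).conjugate()).real


q, k = complex(0.4, -1.1), complex(0.9, 0.3)
xi, theta = -0.05, 0.7
# Same offset m - n = 6 at different absolute positions: equal up to rounding.
print(xpos_score(q, k, 10, 4, xi, theta), xpos_score(q, k, 56, 50, xi, theta))
```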
Sandwich
The self-attention is calculated as:
\[\begin{align} (\mathbf{W}_q(\mathbf{e}_m + \mathbf{p}_m))^T (\mathbf{W}_k(\mathbf{e}_n + \mathbf{p}_n)) &{} \approx \underbrace{\mathbf{e}_m^T\mathbf{W}_q^T\mathbf{W}_k\mathbf{e}_n}_{\text{semantic}} + \underbrace{\mathbf{p}_m^T\mathbf{p}_n}_{\text{position}} \end{align}\]The temporal bias term is:
\[\begin{align} \mathbf{p}_m^T\mathbf{p}_n =&{} \sum_{i=1}^{\bar{d}/2} \sin \big( \frac{m}{10000^{2i/\bar{d}}} \big) \sin \big( \frac{n}{10000^{2i/\bar{d}}} \big) + \cos \big( \frac{m}{10000^{2i/\bar{d}}} \big) \cos \big( \frac{n}{10000^{2i/\bar{d}}} \big)\\ =&{} \sum_{i=1}^{\bar{d}/2} \cos \big( \frac{m-n}{10000^{2i/\bar{d}}} \big) \end{align}\]Sandwich9 keeps only this position term, which depends only on $m-n$, and adds it directly to the attention logits as a parameter-free relative bias; each head divides it by its own compression ratio (`compression_ratio` in the code that follows).
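import numpy as np
base = 1e4
heads = 12
seq_len = 8192
positions = np.arange(seq_len)[..., None]
bar_d = 128  # This is the hyperparameter of Sandwich
i = np.arange(bar_d // 2)
# p_m = [sin(m / base^(2i/bar_d)), cos(m / base^(2i/bar_d))], shape [seq_len, bar_d]
pos_embs = np.concatenate([np.sin(positions / base ** (2 * i / bar_d)),
                           np.cos(positions / base ** (2 * i / bar_d))],
                          axis=-1)
# sandwich[m, n] = p_m . p_n = sum_i cos((m - n) / base^(2i/bar_d))
sandwich = np.matmul(pos_embs, pos_embs.T)
# one compression ratio per head: 8h / heads for h = 1..heads
compression_ratio = np.arange(1, heads + 1) * 8 / heads
# [heads, seq_len, seq_len]; note that in float64 this array alone takes about 6.4 GB
multi_head_sandwich = sandwich[None, ...] / compression_ratio[..., None, None]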
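The trigonometric identity behind the Sandwich bias can be checked in pure Python. This sketch uses the same sin-then-cos layout as the NumPy code, with frequencies indexed from $i = 0$ as in that code; the helper names `sinusoid` and `sandwich_bias` are hypothetical. The direct dot product $\mathbf{p}_m^{\top}\mathbf{p}_n$ matches the closed-form sum of cosines, which depends only on $m-n$.

```python
import math


def sinusoid(pos, bar_d, base=1e4):
    """[sin(pos * w_i) ..., cos(pos * w_i) ...] with w_i = base^(-2i / bar_d), i < bar_d / 2."""
    ws = [base ** (-2 * i / bar_d) for i in range(bar_d // 2)]
    return [math.sin(pos * w) for w in ws] + [math.cos(pos * w) for w in ws]


def sandwich_bias(m, n, bar_d, base=1e4):
    """Closed form of p_m . p_n: sum_i cos((m - n) * w_i), a function of m - n only."""
    return sum(math.cos((m - n) * base ** (-2 * i / bar_d)) for i in range(bar_d // 2))


direct = sum(a * b for a, b in zip(sinusoid(37, 16), sinusoid(5, 16)))
# All three agree up to floating-point rounding.
print(direct, sandwich_bias(37, 5, 16), sandwich_bias(132, 100, 16))
```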
Randomized Position
Ruoss et al.10 introduce randomized positional encodings that simulate positions of longer sequences by randomly selecting an ordered subset of indices to match the input length.

source code: https://github.com/deepmind/randomized_positional_encodings/blob/main/models/positional_encodings.py#L160
It allows transformers to generalize to sequences of unseen length (increasing test accuracy by 12.0% on average) across 15 algorithmic reasoning tasks.
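A minimal sketch of the sampling step, assuming positions are drawn uniformly without replacement from `range(max_len)`, where `max_len` is a hyperparameter larger than any test length. The sorted sample replaces `0..seq_len-1` as the position indices fed to the model's positional encoding. This mirrors the description above and is not the DeepMind code line by line; `randomized_positions` is a hypothetical name.

```python
import random


def randomized_positions(seq_len, max_len, rng=random):
    """Sample an increasing subset of `seq_len` distinct positions from range(max_len)."""
    if seq_len > max_len:
        raise ValueError("max_len must be at least seq_len")
    return sorted(rng.sample(range(max_len), seq_len))


rng = random.Random(0)
positions = randomized_positions(8, 64, rng)  # used instead of [0, 1, ..., 7]
print(positions)
```

Because the indices are kept in order, relative order is preserved while the model sees position values up to `max_len - 1` even when training only on short sequences.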

No Position
Haviv et al.11 observe that LMs without any explicit position encoding (NoPE) are still competitive with standard transformers across datasets, model sizes, and sequence lengths. This suggests that causal LMs may derive positional awareness not only from explicit positioning mechanisms but also from the causal attention mask itself.

Note: Causal transformer LMs without positional encodings achieve results competitive with standard LMs, while bidirectional masked LMs without them fail to converge. A plausible explanation is that causal LMs can infer positions from the left-to-right attention mask, whereas bidirectional attention without positional information is order-invariant.
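One simple mechanism often used to illustrate this is sketched below. It is a toy example, not the analysis of Haviv et al. If every attention score is equal, the causal mask alone makes position $t$ average over $t+1$ tokens, so a marker value carried only by the first token (for example a BOS token) is diluted to $1/(t+1)$, an explicit position signal. With bidirectional attention every position averages the same set, and the signal vanishes.

```python
def causal_uniform_attention(values):
    """Output at position t when all scores are equal: the mean of values[0..t]."""
    out, running = [], 0.0
    for t, v in enumerate(values):
        running += v
        out.append(running / (t + 1))
    return out


# Identical content everywhere except a marker at position 0.
values = [1.0] + [0.0] * 5
causal = causal_uniform_attention(values)                  # position t receives 1 / (t + 1)
bidirectional = [sum(values) / len(values)] * len(values)  # every position sees the same mean
print(causal, bidirectional)
```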

Warning: LMs without explicit positional encodings (NoPE) are nonetheless consistently slightly worse, underscoring the importance of a positional inductive bias.
Position Interpolation
Instead of extrapolating, Chen et al.12131415 propose position interpolation (PI) for RoPE-based models: it linearly downscales the position indices (which therefore become non-integer) so that the maximum position index matches the context window limit used in pre-training.

It simply adds two lines of code.15
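import torch

class ScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        max_position_embeddings = 8192  # hard-coded extended context window
        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(
            self.max_seq_len_cached,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        # These two lines: squeeze positions 0..8191 into the original 0..2047 range
        self.scale = 1 / 4
        t *= self.scale
        # The rest follows LlamaRotaryEmbedding: cache cos/sin of m * theta_i
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )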
SuperHOT-13B14, uptrained with a scaling factor of 0.25, compared to base LLaMA 13B and a test LoRA trained on 6K sequence length with no scaling.

Warning: PI requires further fine-tuning to take effect for context extension.
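The effect of PI on the position indices is simple arithmetic. As a sketch with a hypothetical helper `interpolated_positions`, take a model pre-trained with a 2048-token window and extended to 8192 tokens, i.e. the scale of 1/4 used in the `ScaledRotaryEmbedding` snippet. The largest position stays below 2048, at the cost of adjacent tokens now being only 0.25 apart, the crowding that the NTK-aware approach later criticizes.

```python
def interpolated_positions(seq_len, orig_ctx, new_ctx):
    """Position interpolation: multiply integer positions by orig_ctx / new_ctx before RoPE."""
    scale = orig_ctx / new_ctx
    return [m * scale for m in range(seq_len)]


pos = interpolated_positions(8192, orig_ctx=2048, new_ctx=8192)
print(pos[:5], pos[-1])  # [0.0, 0.25, 0.5, 0.75, 1.0] 2047.75
```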
NTK-Aware Scaled RoPE
Important: Background
- “Simply interpolating the RoPE’s fourier space “linearly” is very sub-optimal, as it prevents the network to distinguish the order and positions of tokens that are very close by.”16
- “Scaling down the fourier features too much will eventually even prevent succesful finetunes (this is corroborated by the recent paper by Meta12 that suggests an upper bound of ~600x)”16
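NTK-Aware Scaled RoPE1617 designs a nonlinear interpolation scheme motivated by Neural Tangent Kernel (NTK) theory. Instead of scaling the positions, it increases the RoPE base, which slows the rotation (“spinning”) speed of each dimension pair by a different amount: the highest-frequency pairs are left almost unchanged, preserving the resolution between nearby tokens, while the lowest-frequency pairs are interpolated.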
Implementation 18
import transformers
# Monkey-patches the older LlamaRotaryEmbedding.__init__(dim, max_position_embeddings, base, device) signature
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
    # The method is just these three lines
    max_position_embeddings = 16384
    a = 8  # Alpha value
    base = base * a ** (dim / (dim - 2))  # Base change formula
    old_init(self, dim, max_position_embeddings, base, device)

# apply ntk-scaled init patch
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init
Average perplexity of LLaMA-7B on a set of 40 very long prompts (12k+ context size).
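Why the base change behaves as described can be verified numerically. With $\text{base}' = \text{base}\cdot\alpha^{d/(d-2)}$, each frequency becomes $\theta_i' = \theta_i\,\alpha^{-2i/(d-2)}$: the highest frequency ($i=0$) is untouched, and the lowest ($i = d/2-1$) is divided by exactly $\alpha$, i.e. linearly interpolated. The sketch below uses the patch's $\alpha = 8$ and assumes a head dimension of 128; the helper names are illustrative.

```python
def rope_freqs(dim, base=10000.0):
    """theta_i = base^(-2i / dim) for i = 0 .. dim/2 - 1."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


def ntk_scaled_base(base, alpha, dim):
    """NTK-aware base change: base * alpha^(dim / (dim - 2))."""
    return base * alpha ** (dim / (dim - 2))


dim, alpha = 128, 8
old = rope_freqs(dim)
new = rope_freqs(dim, ntk_scaled_base(10000.0, alpha, dim))
ratios = [n / o for n, o in zip(new, old)]  # alpha^(-2i / (dim - 2)) mathematically
print(ratios[0], ratios[-1])  # 1.0 for the highest frequency, about 1/alpha for the lowest
```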

Dynamic Linear RoPE
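Dynamic linear RoPE: while the sequence fits in the original context window, positions are left unchanged; beyond it, positions are multiplied by L_orig / L_cur (max_position_embeddings / seq_len in the code), so the interpolation strength grows with the current sequence length instead of being fixed in advance. The class that follows is the static linear variant with a fixed `scale`; the dynamic variant is the `ntk=False` branch of `LlamaDynamicScaledRotaryEmbedding`.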
import torch

class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
        super().__init__()
        self.scale = scale
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        t /= self.scale  # linear interpolation: positions divided by a fixed scale
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            t /= self.scale
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
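The dynamic rule reduces to one multiplier per sequence length. This sketch (the name `dynamic_linear_multiplier` is hypothetical) mirrors the `t *= self.max_position_embeddings / seq_len` line of the `ntk=False` branch, assuming a 2048-token original window. Within the window nothing changes; beyond it, the largest rotated position always stays just below the original limit.

```python
def dynamic_linear_multiplier(seq_len, orig_ctx):
    """Factor applied to positions: 1 inside the original window, orig_ctx / seq_len beyond it."""
    return 1.0 if seq_len <= orig_ctx else orig_ctx / seq_len


for L in (1024, 2048, 3072, 4096):
    mult = dynamic_linear_multiplier(L, orig_ctx=2048)
    print(L, mult, (L - 1) * mult)  # largest rotated position stays below 2048
```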
Dynamic NTK-Aware Scaled RoPE
Warning: Cons:
- Compared to dynamic linear scaling, NTK-Aware has higher perplexity for shorter sequences, but better perplexity at the tail end of the sequence lengths.
- NTK-aware RoPE suffers from catastrophic perplexity blowup, like regular RoPE and static linear scaling.
Emozilla19 introduces dynamic NTK-aware scaling, which replaces the fixed $\alpha$ with:
\[\alpha' = \alpha \cdot \frac{L_{\text{cur}}}{L_{\text{orig}}} - (\alpha - 1)\]where \(L_{\text{cur}}\) is the current sequence length and \(L_{\text{orig}}\) is the original model context length.
This dynamically scales the $\alpha$ as the sequence length increases.
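The resulting base can be computed directly. This sketch (hypothetical name `dynamic_ntk_base`) follows the base-change line of the `ntk` branch in `LlamaDynamicScaledRotaryEmbedding`, where the `ntk` argument plays the role of $\alpha$. At $L_{\text{cur}} = L_{\text{orig}}$ the formula gives $\alpha' = 1$, so the base changes continuously from its original value as the sequence grows.

```python
def dynamic_ntk_base(seq_len, orig_ctx, alpha, dim, base=10000.0):
    """RoPE base for dynamic NTK-aware scaling.

    alpha' = alpha * L_cur / L_orig - (alpha - 1), then base' = base * alpha'^(dim / (dim - 2)).
    Within the original window the base is unchanged.
    """
    if seq_len <= orig_ctx:
        return base
    alpha_prime = alpha * seq_len / orig_ctx - (alpha - 1)
    return base * alpha_prime ** (dim / (dim - 2))


for L in (2048, 4096, 8192):
    print(L, dynamic_ntk_base(L, orig_ctx=2048, alpha=2, dim=128))
```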
```python
import math
import torch


class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
    """Rotary embedding whose scaling grows with the sequence length.

    `ntk` doubles as a switch and as the NTK scale factor alpha:
    ntk=False -> dynamic linear interpolation; ntk=alpha (e.g. 1 or 2) -> dynamic NTK-aware scaling.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
        super().__init__()
        self.ntk = ntk
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]
        # Scaling kicks in only once the input is longer than the cached (original) length.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            if self.ntk:  # dynamic NTK: base' = base * alpha' ** (dim / (dim - 2))
                alpha_dyn = (self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)
                base = self.base * alpha_dyn ** (self.dim / (self.dim - 2))
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                self.register_buffer("inv_freq", inv_freq)
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            if not self.ntk:  # dynamic linear: squeeze positions back into the original range
                t *= self.max_position_embeddings / seq_len
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
```

Partial NTK Scaled RoPE
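Partial NTK scaling20 combines three ways of handling RoPE frequencies. The choice is made per dimension pair, according to how many full rotations that pair completes over the original context:
- unscaled RoPE (extrapolation) for high-frequency pairs;
- NTK-aware scaling for intermediate ones;
- linear interpolation for low-frequency pairs.

Linear ramps blend neighboring regimes.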
```python
import math

import torch


def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
    # Inverse of the RoPE wavelength formula: index of the dimension pair that completes
    # `num_rotations` full turns over `max_position_embeddings` positions.
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def linear_ramp_mask(low, high, dim):
    # 0 at indices <= low, 1 at indices >= high, linear in between.
    if low == high:
        high += 0.001  # Prevent division by zero
    linear_func = (torch.arange(dim, dtype=torch.float32) - low) / (high - low)
    return torch.clamp(linear_func, 0, 1)


def find_newbase_ntk(dim, base=10000, scale=1):
    # NTK-aware base change: base' = base * scale ** (dim / (dim - 2))
    return base * scale ** (dim / (dim - 2))


class LlamaPartNTKScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, ntk_factor=1,
                 extrapolation_factor=1, original_max_position_embeddings=2048, device=None):
        super().__init__()
        # Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
        # Do not change unless there is a good reason for doing so!
        beta_0 = 1.25
        beta_1 = 0.75
        gamma_0 = 16
        gamma_1 = 2

        # Three RoPE extrapolation/interpolation methods
        inv_freq_base = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        inv_freq_linear = 1.0 / (scale * (base ** (torch.arange(0, dim, 2).float().to(device) / dim)))
        inv_freq_ntk = 1.0 / (find_newbase_ntk(dim, base, scale) ** (torch.arange(0, dim, 2).float().to(device) / dim))
        current_dtype = inv_freq_ntk.dtype
        current_device = inv_freq_ntk.device

        # Combine NTK and Linear: NTK above ~beta_0 rotations, linear below ~beta_1 rotations
        low, high = find_correction_range(beta_0, beta_1, dim, base, original_max_position_embeddings)
        inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).type(current_dtype).to(current_device)) * ntk_factor
        inv_freq = inv_freq_linear * (1 - inv_freq_mask) + inv_freq_ntk * inv_freq_mask

        # Combine Extrapolation and NTK and Linear: original frequencies above ~gamma_0 rotations
        low, high = find_correction_range(gamma_0, gamma_1, dim, base, original_max_position_embeddings)
        inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).type(current_dtype).to(current_device)) * extrapolation_factor
        inv_freq = inv_freq * (1 - inv_freq_mask) + inv_freq_base * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]
        # Rebuild the cache (with the same blended inv_freq) only for inputs longer than the cache.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
```
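The blending is easier to follow one dimension pair at a time. The plain-Python sketch below counts how many full turns each pair makes over the original 2048-token context. It then labels the scheme that `LlamaPartNTKScaledRotaryEmbedding` roughly assigns to that pair when `ntk_factor = extrapolation_factor = 1`. The helpers `rotations`, `pair_index_for` and `regime` are hypothetical and are not part of the code above. The class rounds the boundaries to whole indices with `floor`/`ceil`, so labels near a threshold are only approximate.

```python
import math

# Rotation thresholds copied from LlamaPartNTKScaledRotaryEmbedding.
GAMMA_0, GAMMA_1 = 16, 2     # extrapolation <-> NTK ramp
BETA_0, BETA_1 = 1.25, 0.75  # NTK <-> linear ramp


def rotations(i: float, dim: int, base: float = 10000.0, orig_len: int = 2048) -> float:
    """Full turns made by dimension pair i over the original context: L / (2*pi * base**(2i/dim))."""
    return orig_len / (2 * math.pi * base ** (2 * i / dim))


def pair_index_for(num_rotations: float, dim: int, base: float = 10000.0, orig_len: int = 2048) -> float:
    """Inverse of `rotations`; same formula as find_correction_factor."""
    return dim * math.log(orig_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))


def regime(i: int, dim: int, base: float = 10000.0, orig_len: int = 2048) -> str:
    """Approximate scheme for pair i (ignores the floor/ceil rounding of the real class)."""
    r = rotations(i, dim, base, orig_len)
    if r >= GAMMA_0:
        return "extrapolate"  # original RoPE frequency kept
    if r > GAMMA_1:
        return "ramp: extrapolate -> NTK"
    if r >= BETA_0:
        return "NTK"
    if r > BETA_1:
        return "ramp: NTK -> linear"
    return "linear"  # frequency divided by `scale`


dim = 128
for i in range(0, dim // 2, 8):
    print(i, round(rotations(i, dim), 3), regime(i, dim))
```

For `dim = 128` and base 10000, the thresholds fall near pair indices 20.9, 35.4, 38.7 and 42.2. By this calculation only a narrow band (about indices 36 to 38) receives pure NTK scaling. Most high-frequency pairs keep their original frequency, and most low-frequency pairs are linearly interpolated.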
$\beta$-ary RoPE
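$\beta$-ary RoPE21 reinterprets RoPE as a base-$\beta$ (radix-$\beta$) encoding of the position. For head dimension $d$, write the frequencies as $\theta_i = \text{base}^{-2i/d} = \beta^{-i}$ with $\beta = \text{base}^{2/d}$. The angle $n\theta_i = n/\beta^i$ then plays the role of the $i$-th digit $\lfloor n/\beta^i \rfloor \bmod \beta$ of position $n$, with the periodicity of $\cos$/$\sin$ taking the place of the modulo. This view gives theoretical insight into the relationship between the base and the effective context window.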
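The digit analogy can be checked numerically. The sketch below uses the hypothetical helpers `rope_beta`, `beta_ary_digits` and `rope_angles`. It compares ordinary base-$\beta$ digits, $\lfloor n/\beta^i \rfloor \bmod \beta$, with RoPE angles, $n/\beta^i \bmod 2\pi$, where $\beta = \text{base}^{2/d}$ so that $\theta_i = \text{base}^{-2i/d} = \beta^{-i}$. It is only an illustration of the interpretation, not code from the $\beta$-ary RoPE write-up.

```python
import math


def rope_beta(base: float, dim: int) -> float:
    """Radix beta = base ** (2 / dim), so that theta_i = base ** (-2i / dim) = beta ** -i."""
    return base ** (2.0 / dim)


def beta_ary_digits(n: int, beta: int, num_digits: int) -> list[int]:
    """Ordinary base-beta digits of n, least significant first: floor(n / beta**i) mod beta."""
    return [(n // beta**i) % beta for i in range(num_digits)]


def rope_angles(n: int, base: float, dim: int) -> list[float]:
    """RoPE angles of position n, wrapped into [0, 2*pi): (n / beta**i) mod 2*pi."""
    beta = rope_beta(base, dim)
    return [math.fmod(n / beta**i, 2 * math.pi) for i in range(dim // 2)]


print(beta_ary_digits(1234, 10, 4))  # [4, 3, 2, 1]
print(rope_beta(10000.0, 128))       # radix for base 10000 and head dim 128 (just above 1)
print([round(a, 3) for a in rope_angles(1234, 10000.0, 8)])
```

With $d = 8$ and base $10000$, $\beta = 10$. The unwrapped angles $n/\beta^i$ are then just $n$ shifted by one decimal place per dimension pair. The only differences from ordinary decimal digits are the missing floor and the $2\pi$ (rather than $\beta$) period.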
ReRoPE
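ReRoPE22 modifies the RoPE computation by clipping relative positions during attention. Query-key distances inside a window $w$ are kept as-is, while every larger distance is treated as exactly $w$. Because attention then never encounters a relative position beyond $w$, this aims at better length generalization without modifying the base frequency.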
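A naive sketch of the clipped logits follows, assuming the plain clipping rule: every distance $\ge w$ is treated as exactly $w$. It relies on the fact that a RoPE logit depends only on the query-key offset. The clipped branch can therefore be obtained by rotating every query to position $w$ and leaving keys at position 0. The helpers `rope_rotate` and `rerope_scores` are hypothetical and use the same rotate-half layout as the `cos_cached`/`sin_cached` code in this post. This is not the reference ReRoPE implementation.

```python
import torch


def rope_rotate(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate x: [seq, dim] to positions pos: [seq] (rotate-half layout, dim even)."""
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = pos.to(torch.float32)[:, None] * inv_freq[None, :]  # [seq, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                      # [seq, dim]
    x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]
    rotated_half = torch.cat((-x2, x1), dim=-1)
    return x * emb.cos() + rotated_half * emb.sin()


def rerope_scores(q: torch.Tensor, k: torch.Tensor, window: int) -> torch.Tensor:
    """Pre-softmax clipped-RoPE logits for q, k: [seq, dim] (no causal mask, no 1/sqrt(dim))."""
    n = q.shape[0]
    pos = torch.arange(n)
    # Branch 1, ordinary RoPE: the logit for (i, j) depends only on the distance i - j.
    s_rope = rope_rotate(q, pos) @ rope_rotate(k, pos).T
    # Branch 2, clipped: queries at position `window`, keys at 0, so the distance is always `window`.
    s_clip = rope_rotate(q, torch.full((n,), window)) @ k.T
    rel = pos[:, None] - pos[None, :]
    return torch.where(rel < window, s_rope, s_clip)


torch.manual_seed(0)
q, k = torch.randn(10, 8), torch.randn(10, 8)
scores = rerope_scores(q, k, window=4)
causal = torch.tril(torch.ones(10, 10)).bool()
attn = scores.masked_fill(~causal, float("-inf")).softmax(dim=-1)
```

This naive version materializes two full score matrices, so it costs roughly twice a standard attention pass.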
References
1. Vaswani, A., Shazeer, N.M., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). Attention Is All You Need. NeurIPS.
2. Su, J., Lu, Y., Pan, S., Wen, B., & Liu, Y. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv abs/2104.09864.
3. Raffel, C., Shazeer, N.M., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., & Liu, P.J. (2020). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. JMLR.
4. Press, O., Smith, N.A., & Lewis, M. (2022). Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation. ICLR.
5. Chi, T., Fan, T., Ramadge, P.J., & Rudnicky, A.I. (2022). KERPLE: Kernelized Relative Positional Embedding for Length Extrapolation. NeurIPS.
6. Sun, Y., Dong, L., Patra, B., Ma, S., Huang, S., Benhaim, A., Chaudhary, V., Song, X., & Wei, F. (2022). A Length-Extrapolatable Transformer. arXiv abs/2212.10554.
7. Chi, T., Fan, T., Rudnicky, A., & Ramadge, P.J. (2022). Dissecting Transformer Length Extrapolation via the Lens of Receptive Field Analysis.
8. Ruoss, A., Delétang, G., Genewein, T., Grau-Moya, J., Csordás, R., Abbana Bennani, M., Legg, S., & Veness, J. (2023). Randomized Positional Encodings Boost Length Generalization of Transformers. ACL.
9. Haviv, A., et al. (2022). Transformer Language Models without Positional Encodings Still Learn Positional Information. EMNLP.
10. Chen, S., Wong, S., Chen, L., & Tian, Y. (2023). Extending Context Window of Large Language Models via Positional Interpolation. arXiv abs/2306.15595.