Positional Encoding in Transformers: From Sinusoidal to RoPE

Transformers are permutation-equivariant by design—without explicit positional information, a self-attention layer cannot distinguish “the cat sat on the mat” from any permutation of those tokens. How we inject position has turned out to matter far more than the original sinusoidal encoding suggested: it directly governs length generalization, relative-distance sensitivity, and even training stability at scale.

The field has converged on a key insight: encoding relative distances within the attention logits (rather than adding absolute embeddings to token representations) yields architectures that generalize more gracefully beyond their training context length. But the design space is rich. RoPE achieves relative encoding through rotation matrices in complex space; ALiBi takes the minimalist route of a fixed linear bias; T5 bias learns the entire distance-to-bias mapping; KERPLE generalizes to learnable kernel functions.

This post is a comprehensive survey: we derive each method from first principles, compare their properties (trainability, per-layer application, extrapolation capability), and cover the increasingly important topic of context extension—position interpolation, NTK-aware scaling, and dynamic RoPE—that enables pretrained models to handle sequences far longer than they were trained on. Reference implementations are included throughout.

Summary

PE Relative Trainable Each Layer Extrapolation
Sinusoidal
T5 bias
RoPE
ALiBi
KERPLE
Sandwich
xPos

Position Encoding

Sinusoidal Position Embeddings

Sinusoidal position embeddings1 are fixed vectors added to the token embeddings at the first transformer layer.

\[\text{PE}_{(\text{pos}, 2i)} = \sin\left(\frac{\text{pos}}{10000^{2i/d_\text{model}}}\right)\]

\[\text{PE}_{(\text{pos}, 2i+1)} = \cos\left(\frac{\text{pos}}{10000^{2i/d_\text{model}}}\right)\]

where $\text{pos}$ is the position in the sequence and $i$ indexes the sine/cosine pairs along the embedding vector ($0 \le i < d_\text{model}/2$). This design allows the model to learn to attend by relative positions, since for any fixed offset $k$, \(\text{PE}_{\text{pos}+k}\) can be represented as a linear function of \(\text{PE}_{\text{pos}}\) (a fixed rotation of each sine/cosine pair).
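The linearity claim can be checked numerically. The following is a minimal sketch in plain Python; the helper names `sinusoidal_pe` and `shift` are illustrative and not taken from any library. It builds the encoding for one position and maps it to the encoding of `pos + k` using only a rotation that depends on `k`.

```python
import math

def sinusoidal_pe(pos, d_model, base=10000.0):
    """Return the d_model-dimensional sinusoidal encoding of one position."""
    pe = []
    for i in range(d_model // 2):
        angle = pos / base ** (2 * i / d_model)
        pe.extend([math.sin(angle), math.cos(angle)])  # dims 2i and 2i+1
    return pe

def shift(pe, k, d_model, base=10000.0):
    """Map PE_pos to PE_{pos+k} with a position-independent 2x2 rotation per pair."""
    out = []
    for i in range(d_model // 2):
        w = k / base ** (2 * i / d_model)
        s, c = pe[2 * i], pe[2 * i + 1]
        out.extend([s * math.cos(w) + c * math.sin(w),   # sin(angle + w)
                    c * math.cos(w) - s * math.sin(w)])  # cos(angle + w)
    return out

d_model, pos, k = 8, 5, 3
shifted = shift(sinusoidal_pe(pos, d_model), k, d_model)
direct = sinusoidal_pe(pos + k, d_model)
max_err = max(abs(a - b) for a, b in zip(shifted, direct))
print(max_err)  # tiny floating-point error
```

The rotation in `shift` never looks at `pos`, which is exactly what "linear function of the encoding at `pos`" means here.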

Position encoding

# Positional encoding layer in PyTorch
import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)  # [max_len, 1]
        # 1 / 10000^(2i/d_model), computed in log space
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(1e4) / d_model))
        pe[:, ::2] = torch.sin(position * div_term)   # even dims
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims (assumes even d_model)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)  # buffer: follows the module's device, never trained

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len]  # buffers carry no gradient, so no Variable wrapper is needed
        return self.dropout(x)

Rotary Position Embedding (RoPE)

Rotary Position Embedding (RoPE)23 encodes positional information by operating in complex space. Instead of working in $\mathbb{R}^d$, it treats consecutive pairs of query and key dimensions as complex numbers in $\mathbb{C}^{d/2}$.

Specifically, instead of viewing \(\mathbf{q}=(q_1,q_2,q_3,q_4,\ldots,q_{d})\) as a $d$-dimensional real vector, RoPE views it as \(\mathbf{q}=(q_1+iq_2, q_3+iq_4,\ldots, q_{d-1} + iq_{d})\in\mathbb{C}^{d/2}\). If $d$ is odd, RoPE pads with a dummy coordinate (or simply increments $d$ by one) to ensure dimensional alignment.

Derivation

The complex number format of RoPE is written as:

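\[\begin{align} f(\mathbf{q}, m) = R_f(\mathbf{q}, m)e^{i\Theta_f(\mathbf{q}, m)}=\lvert\mathbf{q}\rvert e^{i(\Theta(\mathbf{q})+m\mathbf{\theta})} = \sum_{j=1}^{d/2} q_je^{im\theta_j} \vec{e_j} \end{align}\]

Here, $m$ is the position in the sequence, $\lvert\mathbf{q}\rvert$ and $\Theta(\mathbf{q})$ are the component-wise modulus and phase of $\mathbf{q}$ (so that $q_j = \lvert q_j\rvert e^{i\Theta(q_j)}$), and $\theta_j$ is the rotation frequency of the $j$-th complex coordinate.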
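To see why this form yields a relative encoding, here is a small sketch using Python's built-in complex numbers. The function names are illustrative, and the frequencies assume the usual choice $\theta_j = 10000^{-2j/d}$ with a zero-based $j$. The attention score is the real part of the Hermitian inner product, and it comes out the same for any two positions with the same offset.

```python
import cmath

def rope_complex(vec, m, base=10000.0):
    """Rotate a vector in C^{d/2} to position m: q_j * exp(i * m * theta_j)."""
    half = len(vec)  # d/2 complex coordinates, so theta_j = base^(-2j/d) = base^(-j/half)
    return [z * cmath.exp(1j * m * base ** (-j / half)) for j, z in enumerate(vec)]

def score(q, k):
    """Dot product of the underlying real vectors: Re(sum_j q_j * conj(k_j))."""
    return sum(a * b.conjugate() for a, b in zip(q, k)).real

q = [1 + 2j, -0.5 + 0.3j]
k = [0.7 - 1j, 2 + 0.1j]

s_near = score(rope_complex(q, 10), rope_complex(k, 7))     # offset 3
s_far = score(rope_complex(q, 103), rope_complex(k, 100))   # offset 3 again
print(s_near, s_far)  # equal up to rounding
```

The absolute positions cancel because $e^{im\theta_j}\,\overline{e^{in\theta_j}} = e^{i(m-n)\theta_j}$.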

It is convenient to rewrite this as a matrix equation:

\[\begin{align} f(\mathbf{q}, m) = \begin{pmatrix} M_1 & & & \\ & M_2 & & \\ & & \ddots & \\ & & & M_{d/2} \end{pmatrix} \begin{pmatrix} q_1\\ q_2\\ \vdots\\ q_d \end{pmatrix} = \mathbf{\Theta_m Q_m} = \mathbf{\Theta_m W_q X_m} \end{align}\]

where \(M_j=\begin{pmatrix}\cos m\theta_j & -\sin m\theta_j \\ \sin m\theta_j & \cos m\theta_j\end{pmatrix}\), \(\mathbf{\Theta_m}\) is the block-diagonal rotation matrix, \(\mathbf{W_q}\) is the learned query weight, and \(\mathbf{X_m}\) is the embedding of the $m$-th token.

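Because multiplying by the sparse block-diagonal matrix wastes computation on zeros, RoPE is implemented with element-wise operations instead (indices are zero-based here):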

\[\begin{align} f(\mathbf{q}, m) = \begin{pmatrix}q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix}\odot\begin{pmatrix}\cos m\theta_0 \\ \cos m\theta_0 \\ \cos m\theta_1 \\ \cos m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \cos m\theta_{d/2-1} \end{pmatrix} + \begin{pmatrix}-q_1 \\ q_0 \\ -q_3 \\ q_2 \\ \vdots \\ -q_{d-1} \\ q_{d-2} \end{pmatrix}\odot\begin{pmatrix}\sin m\theta_0 \\ \sin m\theta_0 \\ \sin m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \sin m\theta_{d/2-1} \\ \sin m\theta_{d/2-1} \end{pmatrix} \end{align}\]

where $\odot$ denotes the element-wise (Hadamard) product.
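As a sanity check, the element-wise form can be compared against the complex multiplication it replaces. This is a plain-Python sketch with illustrative function names, assuming $\theta_j = 10000^{-2j/d}$ and adjacent-pair (interleaved) coordinates.

```python
import cmath
import math

def rope_elementwise(x, m, base=10000.0):
    """x * cos + (-x1, x0, -x3, x2, ...) * sin, written out per pair."""
    d = len(x)
    out = [0.0] * d
    for j in range(d // 2):
        theta = base ** (-2 * j / d)
        c, s = math.cos(m * theta), math.sin(m * theta)
        x0, x1 = x[2 * j], x[2 * j + 1]
        out[2 * j] = x0 * c - x1 * s
        out[2 * j + 1] = x1 * c + x0 * s
    return out

def rope_via_complex(x, m, base=10000.0):
    """Same rotation, treating (x[2j], x[2j+1]) as the complex number x[2j] + i*x[2j+1]."""
    d = len(x)
    out = []
    for j in range(d // 2):
        z = complex(x[2 * j], x[2 * j + 1]) * cmath.exp(1j * m * base ** (-2 * j / d))
        out += [z.real, z.imag]
    return out

x = [0.3, -1.2, 0.8, 0.5]
a = rope_elementwise(x, 7)
b = rope_via_complex(x, 7)
max_diff = max(abs(u - v) for u, v in zip(a, b))
```

Since each pair is only rotated, the vector norm is unchanged by either version.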

RoPE

  • Extension to Multiple Dimensions
\[\begin{align*} \langle f(\mathbf{q}, m, i), f(\mathbf{k}, n, j) \rangle &= \langle f_1(\mathbf{q}_{:d/2}, m), f_1(\mathbf{k}_{:d/2}, n) \rangle + \langle f_2(\mathbf{q}_{d/2:}, i), f_2(\mathbf{k}_{d/2:}, j) \rangle \\ &= g_1(\mathbf{q}_{:d/2}, \mathbf{k}_{:d/2}, m - n) + g_2(\mathbf{q}_{d/2:}, \mathbf{k}_{d/2:}, i - j) \\ &= g(\mathbf{q}, \mathbf{k}, m - n, i - j) \end{align*}\]
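The two-dimensional case can be sketched by rotating the first half of the vector with the row index and the second half with the column index. This is an illustrative construction that follows the equation above; the names `rope_1d` and `rope_2d` are made up for the example.

```python
import math

def rope_1d(x, pos, base=10000.0):
    """Rotate adjacent pairs of x by pos * theta_j."""
    d, out = len(x), []
    for j in range(d // 2):
        a = pos * base ** (-2 * j / d)
        x0, x1 = x[2 * j], x[2 * j + 1]
        out += [x0 * math.cos(a) - x1 * math.sin(a),
                x1 * math.cos(a) + x0 * math.sin(a)]
    return out

def rope_2d(x, row, col):
    """First half of x encodes the row index, second half the column index."""
    half = len(x) // 2
    return rope_1d(x[:half], row) + rope_1d(x[half:], col)

def dot(a, b):
    return sum(u * v for u, v in zip(a, b))

q = [0.3, -1.2, 0.8, 0.5, 1.1, -0.4, 0.2, 0.9]
k = [1.0, 0.4, -0.7, 0.9, -0.2, 0.6, 0.5, -1.3]

s_a = dot(rope_2d(q, 5, 9), rope_2d(k, 2, 4))       # offsets (3, 5)
s_b = dot(rope_2d(q, 25, 19), rope_2d(k, 22, 14))   # offsets (3, 5) again
```

Both scores agree because each half contributes a term that depends only on its own offset.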

Note: Difference from sinusoidal embedding3

  1. Sinusoidal embeddings apply to each coordinate individually, while RoPE mixes pairs of coordinates.
  2. Sinusoidal embeddings add a $\cos(m\theta)$ or $\sin(m\theta)$ term, while RoPE uses a multiplicative factor.

Implementation

Huggingface version
# Transformers implementation: https://github.com/huggingface/transformers/blob/e42587f596181396e1c4b63660abf0c736b10dae/src/transformers/models/llama/modeling_llama.py#L180
import torch

class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embedding=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embedding = max_position_embedding
        self.base = base
        # $$\theta_i = 10000^{-2i/d}$$
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer('inv_freq', inv_freq)

        self._set_cos_sin_cache(
            seq_len=max_position_embedding, device=device, dtype=self.inv_freq.dtype)

    def _set_cos_sin_cache(self, seq_len, device=None, dtype=None):
        self.max_seq_len_cached = seq_len
        # $m$
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        # $m \theta$
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # $$R_{m\theta}$$
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None,:,:].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None,:,:].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bsz, num_heads, seq_len, d_head]
        # Rebuild the cache only when the request exceeds what is currently cached
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:,:,:seq_len,...].to(dtype=x.dtype),
            self.sin_cached[:,:,:seq_len,...].to(dtype=x.dtype),
        )

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # the first two dim of cos and sin are always 1, so we can squeeze them
    cos = cos.squeeze(1).squeeze(0) # seq_len, dim
    sin = sin.squeeze(1).squeeze(0) # seq_len, dim
    cos = cos[position_ids].unsqueeze(1) # bsz, 1, seq_len, dim
    sin = sin[position_ids].unsqueeze(1) # bsz, 1, seq_len, dim
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

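def rotate_half(x):
    """Rotate half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    x1 = x[..., :x.shape[-1]//2]
    x2 = x[..., x.shape[-1]//2:]
    return torch.cat((-x2, x1), dim=-1)

# Usage sketch with toy shapes; in the model these tensors come from the attention layer.
bsz, num_heads, seq_length, head_dim = 2, 4, 16, 64
past_key_values_length = 0
query_states = torch.randn(bsz, num_heads, seq_length, head_dim)
key_states = torch.randn(bsz, num_heads, seq_length, head_dim)
kv_seq_len = seq_length + past_key_values_length

position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

rotary_emb = LlamaRotaryEmbedding(head_dim, max_position_embedding=2048)

cos, sin = rotary_emb(query_states, seq_len=kv_seq_len) # get cos/sin cache

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)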

Llama version
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Args:
        dim (int): dimension of the frequency tensor.
    """
    freqs = 1.0 / (theta ** (torch.arange(0,dim,2)[:(dim//2)].float()/dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex 64
    return freqs_cis


def apply_rotary_emb(xq, xk, freqs_cis):
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.
    The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broad_cast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def reshape_for_broad_cast(freqs_cis, x):
    """
    Reshape frequency tensor for broadcasting it with another tensor.
    This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of broadcasting the frequency tensor during element-wise operations.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

# Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
# Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
freqs_cis = precompute_freqs_cis(d_model//num_heads, max_seq_len*2)

# start_pos: starting position for attention caching
freqs_cis = freqs_cis[start_pos: start_pos+seqlen]

xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)


RoFormer version
# source: https://huggingface.co/transformers/v4.8.0/_modules/transformers/models/roformer/modeling_roformer.html
from typing import Optional

import numpy as np
import torch
from torch import nn


# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer
class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]
        """
        n_pos, dim = out.shape
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        bsz, seq_len = input_ids_shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        )
        return super().forward(positions)


def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
    # https://kexue.fm/archives/8265
    # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
    # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
    sin, cos = sinusoidal_pos.chunk(2, dim=-1)
    # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
    sin_pos = torch.repeat_interleave(sin, 2, dim=-1)
    # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
    cos_pos = torch.repeat_interleave(cos, 2, dim=-1)
    # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
    rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
        query_layer
    )
    query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos
    # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
    rotate_half_key_layer = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1).reshape_as(key_layer)
    key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos
    if value_layer is not None:
        # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
        rotate_half_value_layer = torch.stack([-value_layer[..., 1::2], value_layer[..., ::2]], dim=-1).reshape_as(
            value_layer
        )
        value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos
        return query_layer, key_layer, value_layer
    return query_layer, key_layer


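embed_positions = RoFormerSinusoidalPositionalEmbedding(
    max_position_embeddings,
    config.hidden_size // config.num_attention_heads
)

# The module must be called to obtain the sin/cos table.
# hidden_states: [bsz, seq_len, hidden_size]; result: [1, 1, seq_len, head_dim]
sinusoidal_pos = embed_positions(hidden_states.shape[:-1])[None, None, :, :]

q, k = apply_rotary_position_embeddings(sinusoidal_pos, q, k)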

GPT-NeoX (PyTorch) implementation.

import torch

class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached


# rotary pos emb helpers:
def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat(
        (-x2, x1), dim=x1.ndim - 1
    )  # dim=-1 triggers a bug in torch < 1.8.0


@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


# q, k: [seq_len, batch, num_heads, d_head], so the sequence axis is dim 0
rotary_emb = Rotary(d_head)
cos, sin = rotary_emb(q, seq_dim=0)  # each [seq_len, 1, 1, d_head]
q, k = apply_rotary_pos_emb(q, k, cos, sin)

Notes:

| Equation | Code |
| --- | --- |
| \(\theta_i = 10000^{-2i/d}\) | `inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))` |
| \(m\theta_i\) | `freqs = torch.einsum('i,j->ij', t, self.inv_freq)` |
| \(R_{m\theta_i} = \begin{bmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{bmatrix}\) | `emb.cos()`, `emb.sin()` |
| \(x_m' = R_{m\theta} \cdot x_m\) | `(q * cos) + (rotate_half(q) * sin)`, `(k * cos) + (rotate_half(k) * sin)` |
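The implementations above pair coordinates in two different ways: the `rotate_half` style pairs dimension $j$ with $j + d/2$, while the complex-view and RoFormer styles pair adjacent dimensions $(2j, 2j+1)$. The sketch below, in plain Python with illustrative function names, suggests that the two layouts differ only by a fixed permutation of the head dimension, so the attention scores agree once queries and keys are permuted consistently.

```python
import math

def rope_interleaved(x, m, base=10000.0):
    """Adjacent-pair layout: rotate (x[2j], x[2j+1]) by m * theta_j."""
    d, out = len(x), list(x)
    for j in range(d // 2):
        a = m * base ** (-2 * j / d)
        out[2 * j] = x[2 * j] * math.cos(a) - x[2 * j + 1] * math.sin(a)
        out[2 * j + 1] = x[2 * j + 1] * math.cos(a) + x[2 * j] * math.sin(a)
    return out

def rope_half_split(x, m, base=10000.0):
    """rotate_half layout: rotate (x[j], x[j + d/2]) by m * theta_j."""
    d, out = len(x), list(x)
    h = d // 2
    for j in range(h):
        a = m * base ** (-2 * j / d)
        out[j] = x[j] * math.cos(a) - x[j + h] * math.sin(a)
        out[j + h] = x[j + h] * math.cos(a) + x[j] * math.sin(a)
    return out

def to_half_split(x):
    """Permute an interleaved vector into the half-split layout (evens, then odds)."""
    return x[0::2] + x[1::2]

def dot(a, b):
    return sum(u * v for u, v in zip(a, b))

q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.7, 0.9]
s_interleaved = dot(rope_interleaved(q, 6), rope_interleaved(k, 2))
s_half_split = dot(rope_half_split(to_half_split(q), 6),
                   rope_half_split(to_half_split(k), 2))
```

In a trained model the permutation is absorbed into the learned projections, which is why checkpoints must be converted when moving between the two conventions.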

RoPE with Bias

Su4 finds that adding a learnable bias to RoPE can increase the capability of length extrapolation.

\[\begin{equation}\boldsymbol{q}_m^{\top}\boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n\boldsymbol{k}_n \quad\to\quad (\boldsymbol{q}_m + \boldsymbol{a})^{\top}\boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n(\boldsymbol{k}_n + \boldsymbol{b})\end{equation}\]

where \(\boldsymbol{\mathcal{R}}_m, \boldsymbol{\mathcal{R}}_n\) are rotation matrices and $\boldsymbol{a}, \boldsymbol{b}$ are learnable biases.

Note: Without RoPE, softmax self-attention gives identical results with or without the key bias $\boldsymbol{b}$, because the term $\boldsymbol{q}\cdot\boldsymbol{b}$ is the same for every key and cancels in the softmax normalization:

\[\begin{equation}\frac{e^{\boldsymbol{q}\cdot(\boldsymbol{k}_n + \boldsymbol{b})}}{\sum\limits_n e^{\boldsymbol{q}\cdot(\boldsymbol{k}_n + \boldsymbol{b})}} = \frac{e^{\boldsymbol{q}\cdot\boldsymbol{k}_n}e^{\boldsymbol{q}\cdot\boldsymbol{b}}}{\sum\limits_n e^{\boldsymbol{q}\cdot\boldsymbol{k}_n} e^{\boldsymbol{q}\cdot\boldsymbol{b}}}= \frac{e^{\boldsymbol{q}\cdot\boldsymbol{k}_n}}{\sum\limits_n e^{\boldsymbol{q}\cdot\boldsymbol{k}_n}}\end{equation}\]

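With RoPE, however, the bias does not cancel: the term \(\boldsymbol{q}_m^{\top}\boldsymbol{\mathcal{R}}_m^{\top}\boldsymbol{\mathcal{R}}_n\boldsymbol{b}\) depends on the key position $n$, so dropping the bias changes the attention weights.4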
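A toy check of both statements, using 2-dimensional vectors and a single rotation frequency; all values and helper names are arbitrary and chosen only for illustration.

```python
import math

def softmax(xs):
    mx = max(xs)
    es = [math.exp(x - mx) for x in xs]
    total = sum(es)
    return [e / total for e in es]

def rot(v, pos, theta=1.0):
    """Rotate a 2-D vector by pos * theta (one RoPE block)."""
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    return (v[0] * c - v[1] * s, v[1] * c + v[0] * s)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

def add(a, b):
    return (a[0] + b[0], a[1] + b[1])

q = (0.5, -1.0)
keys = [(1.0, 0.2), (-0.3, 0.8), (0.6, 0.6)]
b = (0.7, -0.4)          # key bias
m = len(keys) - 1        # the query sits at the last position

# Without RoPE: the bias adds the same constant q.b to every logit.
plain = softmax([dot(q, k) for k in keys])
plain_bias = softmax([dot(q, add(k, b)) for k in keys])

# With RoPE: the bias is rotated by the key position, so its contribution varies with n.
rope = softmax([dot(rot(q, m), rot(k, n)) for n, k in enumerate(keys)])
rope_bias = softmax([dot(rot(q, m), rot(add(k, b), n)) for n, k in enumerate(keys)])

diff_plain = max(abs(x - y) for x, y in zip(plain, plain_bias))
diff_rope = max(abs(x - y) for x, y in zip(rope, rope_bias))
```

Here `diff_plain` is zero up to rounding, while `diff_rope` is clearly non-zero.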

T5 Bias

T55 adds no position encoding to word embeddings. Instead, it adds a learned, shared bias to each query-key attention score that depends only on the distance between query and key. Multiple distances share the same learned bias through bucketing, which may benefit length interpolation. Specifically, a fixed number of embeddings are learned, each corresponding to a range of possible key-query offsets.

T5 uses 32 buckets, each holding a learnable bias, and assigns relative positions to buckets with a log-binning strategy. In the causal (unidirectional) case, with query position $m$ and key position $n$:

\[\begin{align} b_{m-n} = \left\{ \begin{array}{ll} \text{bucket}[0] &{} \text{if } m-n<0\\ \text{bucket}[m-n] &{} \text{if } 0 \leq m-n < 16\\ \text{bucket}\left[\min\left(31,\; 16 + \left\lfloor \frac{\log \frac{m-n}{16}}{\log \frac{128}{16}} \cdot 16 \right\rfloor\right)\right] &{} \text{if } m-n \geq 16 \end{array} \right. \end{align}\]

Press et al.6 find that T5 bias enables length extrapolation.

Note: T5 uses 32 embeddings for all models with ranges that increase in size logarithmically up to an offset of 128 beyond which it assigns all relative positions to the same embedding. All position embeddings are shared across all layers in T5, though within a given layer each attention head uses a different learned position embedding.
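To make the binning concrete, here is a scalar re-implementation in plain Python. It is a sketch that mirrors the logic of the tensor version from the Transformers library rather than calling it; `t5_bucket` is an illustrative name, and `rel_pos` follows that library's convention of key position minus query position.

```python
import math

def t5_bucket(rel_pos, bidirectional=False, num_buckets=32, max_distance=128):
    """Map one relative position (key_pos - query_pos) to a bucket index."""
    bucket = 0
    if bidirectional:
        num_buckets //= 2            # half the buckets per direction
        if rel_pos > 0:
            bucket += num_buckets
        rel_pos = abs(rel_pos)
    else:
        rel_pos = max(-rel_pos, 0)   # future keys collapse to distance 0
    max_exact = num_buckets // 2
    if rel_pos < max_exact:          # small distances get one bucket each
        return bucket + rel_pos
    log_bucket = max_exact + int(
        math.log(rel_pos / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    )
    return bucket + min(log_bucket, num_buckets - 1)

for distance in (0, 5, 15, 16, 32, 128, 1000):
    print(distance, t5_bucket(-distance))
```

Distances below 16 map to their own bucket, larger ones share logarithmically growing bins, and everything from 128 onward lands in bucket 31.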

import math

import torch
from torch import nn
from transformers import T5Config


class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance

        ...

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)


    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

ALiBi (Attention with Linear Biases)

Warning: Length extrapolation means training a transformer on short sequences and evaluating it on substantially longer ones; relative positional encodings are the usual means of achieving it.

ALiBi6 adds a static, non-learnable bias to the query-key dot product. Like T5 bias and RoPE, it injects positional information at every layer.

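\[\mathbf{q}_m^T \mathbf{k}_n - \alpha \vert m-n \vert\]

where $\alpha > 0$ is a fixed, head-specific slope, so the bias penalizes distant keys. For a model with $H$ heads, the slope of the $i$-th head ($i = 1, \ldots, H$) is \(\alpha_i = 2^{-8i/H}\); with 8 heads this gives $\frac{1}{2}, \frac{1}{4}, \ldots, \frac{1}{256}$.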
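A minimal sketch of the slope schedule and the resulting bias matrix, assuming the convention of the ALiBi paper: slopes form a geometric sequence starting at $2^{-8/H}$, and the bias is a penalty that grows with distance. The function names are illustrative, and the simple formula assumes the number of heads is a power of two.

```python
def alibi_slopes(num_heads):
    """Geometric slopes 2^(-8/H), 2^(-16/H), ..., 2^(-8)."""
    return [2 ** (-8 * (i + 1) / num_heads) for i in range(num_heads)]

def alibi_bias(seq_len, slope):
    """Causal bias for one head: -slope * (m - n) for keys n <= m, -inf for future keys."""
    return [
        [-slope * (m - n) if n <= m else float("-inf") for n in range(seq_len)]
        for m in range(seq_len)
    ]

slopes = alibi_slopes(8)
bias = alibi_bias(4, slopes[0])
for row in bias:
    print(row)
```

The matrix is added to the query-key logits before the softmax; because it depends only on `m - n`, it can be built for any sequence length at inference time.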

ALiBi

Note: The ALiBi bias is added as is; unlike the query-key dot product, it is not scaled by the \(1/\sqrt{d_k}\) factor of the original transformer.

Extrapolation: (validation-set’s) input sequence length (x-axis), versus perplexity (y-axis, lower is better).

Empirically, ALiBi and T5 bias exhibit strong length extrapolation6, whereas RoPE and sinusoidal PE do not.

KERPLE

KERPLE (Kernelized Relative Positional Embedding for Length Extrapolation)7 adds a kernelized relative-position bias to the attention logits:

\[\begin{align} a_{m,n} = \frac{\exp\left(\frac{\mathbf{q}_m^T \mathbf{k}_n + \tilde{k}_{r_1,\cdots, r_{\ell}}(m, n)}{\sqrt{d}}\right)}{\sum_{i=1}^L \exp\left(\frac{\mathbf{q}_m^T \mathbf{k}_i + \tilde{k}_{r_1,\cdots, r_{\ell}}(m, i)}{\sqrt{d}}\right)} \end{align}\]

where \(r_1,\cdots, r_{\ell}\) are learnable parameters and $\tilde{k}$ is a kernel function of the two positions.

Two kernel families are proposed. Writing only the biased logit \(\mathbf{q}_m^{\top}\mathbf{k}_n + \tilde{k}_{r_1, r_2}(m, n)\), they are:

\[\begin{align} \mathbf{q}_m^{\top}\mathbf{k}_n + \tilde{k}_{r_1, r_2}(m, n) = \left\{ \begin{array}{ll} \text{(power)}&{} \mathbf{q}_m^{\top}\mathbf{k}_n - r_1\vert m - n\vert^{r_2} ,&{} r_1 >0, 0 < r_2 \leq 2\\ \text{(logarithmic)}&{} \mathbf{q}_m^{\top}\mathbf{k}_n - r_1\log(1+r_2\vert m - n\vert),&{} r_1, r_2 > 0 \end{array} \right.\end{align}\]

Note: Triangle kernel: $c-\vert m-n\vert$. This reduces to ALiBi.
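The two bias functions are easy to write down directly. The sketch below uses plain Python with illustrative names and treats `r1` and `r2` as fixed numbers, whereas in KERPLE they are learned per head. It also shows that the power variant with `r2 = 1` gives a linear penalty of the ALiBi form.

```python
import math

def kerple_power(m, n, r1, r2):
    """Power-kernel bias: -r1 * |m - n|^r2, with r1 > 0 and 0 < r2 <= 2."""
    assert r1 > 0 and 0 < r2 <= 2
    return -r1 * abs(m - n) ** r2

def kerple_log(m, n, r1, r2):
    """Logarithmic-kernel bias: -r1 * log(1 + r2 * |m - n|), with r1, r2 > 0."""
    assert r1 > 0 and r2 > 0
    return -r1 * math.log(1 + r2 * abs(m - n))

# With r2 = 1 the power bias is linear in the distance, as in ALiBi with slope r1.
linear_penalty = kerple_power(5, 2, r1=0.5, r2=1.0)

# The logarithmic bias penalizes long distances far more gently than the power bias.
far_log = kerple_log(100, 0, r1=1.0, r2=1.0)
far_power = kerple_power(100, 0, r1=1.0, r2=1.0)
```

The slower decay of the logarithmic variant leaves distant tokens some attention mass.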

xPos

Extrapolatable Position Embedding (XPOS)8 proposes:

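\[\begin{align} f_q (q,n) &{}= A_q qe^{\lambda n} &{} \text{let }\lambda = \xi+i\theta \in \mathbb{C}^{d/2} \\&{}:= qe^{\xi n + i\theta n} &{} \text{dropping the constant factor $A_q$} \\ \nonumber \\ f_k (k,n) &{}= A_k ke^{-\bar{\lambda} n} &{} \text{with }\bar{\lambda} = \xi-i\theta \text{ the complex conjugate of } \lambda \\&{}:= ke^{-\xi n + i\theta n} &{} \text{dropping the constant factor $A_k$} \end{align}\]

The query-key score \(\mathrm{Re}[f_q(q,m)\,\overline{f_k(k,n)}] = \mathrm{Re}[q\bar{k}\,e^{(\xi+i\theta)(m-n)}]\) therefore depends only on the offset $m-n$: the rotation \(e^{i\theta(m-n)}\) is exactly RoPE, and the extra factor \(e^{\xi(m-n)}\) rescales the score with distance.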
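A one-coordinate sketch of this construction using Python complex numbers. It assumes the form $f_q(q,n) = qe^{\xi n + i\theta n}$ and $f_k(k,n) = ke^{-\xi n + i\theta n}$, under which the score depends only on the offset; the values of `xi` and `theta` and the function names are arbitrary choices for illustration.

```python
import cmath

def xpos_q(q, n, xi, theta):
    """Query transform: q * exp((xi + i*theta) * n)."""
    return q * cmath.exp((xi + 1j * theta) * n)

def xpos_k(k, n, xi, theta):
    """Key transform: k * exp((-xi + i*theta) * n)."""
    return k * cmath.exp((-xi + 1j * theta) * n)

def score(q, m, k, n, xi=-0.01, theta=0.3):
    """Re(f_q(q, m) * conj(f_k(k, n))) = Re(q * conj(k) * exp((xi + i*theta) * (m - n)))."""
    return (xpos_q(q, m, xi, theta) * xpos_k(k, n, xi, theta).conjugate()).real

q, k = 1 + 2j, 0.7 - 1j

same_offset_a = score(q, 10, k, 7)
same_offset_b = score(q, 50, k, 47)

# With the rotation switched off, only the exponential scaling remains.
near = abs(score(q, 3, k, 0, theta=0.0))
far = abs(score(q, 30, k, 0, theta=0.0))
```

With a negative `xi` and a causal mask (`m >= n`), the magnitude of the score shrinks as the distance grows.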

xPos

Sandwich

With sinusoidal embeddings \(\mathbf{p}\) added to the token embeddings \(\mathbf{e}\), the pre-softmax attention logit is approximately:

\[\begin{align} (\mathbf{W}_q(\mathbf{e}_m + \mathbf{p}_m))^T (\mathbf{W}_k(\mathbf{e}_n + \mathbf{p}_n)) &{} \approx \underbrace{\mathbf{e}_m^T\mathbf{W}_q^T\mathbf{W}_k\mathbf{e}_n}_{\text{semantic}} + \underbrace{\mathbf{p}_m^T\mathbf{p}_n}_{\text{position}} \end{align}\]

The temporal bias term is:

\[\begin{align} \mathbf{p}_m^T\mathbf{p}_n =&{} \sum_{i=1}^{\bar{d}/2} \sin \big( \frac{m}{10000^{2i/\bar{d}}} \big) \sin \big( \frac{n}{10000^{2i/\bar{d}}} \big) + \cos \big( \frac{m}{10000^{2i/\bar{d}}} \big) \cos \big( \frac{n}{10000^{2i/\bar{d}}} \big)\\ =&{} \sum_{i=1}^{\bar{d}/2} \cos \big( \frac{m-n}{10000^{2i/\bar{d}}} \big) \end{align}\]
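The identity follows from $\sin a \sin b + \cos a \cos b = \cos(a - b)$ and can be verified numerically. The sketch below uses plain Python with illustrative names; the frequency index starts at zero, as in the NumPy listing that follows, rather than at one as in the formula.

```python
import math

def sinusoid(pos, bar_d, base=1e4):
    """Sinusoidal embedding p_pos: all sines followed by all cosines."""
    freqs = [base ** (-2 * i / bar_d) for i in range(bar_d // 2)]
    return [math.sin(pos * f) for f in freqs] + [math.cos(pos * f) for f in freqs]

def temporal_bias(m, n, bar_d, base=1e4):
    """Closed form: sum_i cos((m - n) / base^(2i / bar_d))."""
    return sum(math.cos((m - n) * base ** (-2 * i / bar_d)) for i in range(bar_d // 2))

def dot(a, b):
    return sum(u * v for u, v in zip(a, b))

bar_d = 16
lhs = dot(sinusoid(40, bar_d), sinusoid(33, bar_d))
rhs = temporal_bias(40, 33, bar_d)
```

The bias is largest at zero distance, where every cosine equals one and the sum is `bar_d / 2`.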

Sandwich9 compresses the temporal bias into a low-rank form, scaling the positional dot product per head.

import numpy as np

base = 1e4
heads = 12
seq_len = 8192
positions = np.arange(seq_len)[..., None]
bar_d = 128 # This is the hyperparameter of Sandwich
i = np.arange(bar_d // 2)

pos_embs = np.concatenate([np.sin(positions / base ** (2 * i / bar_d)),
                           np.cos(positions / base ** (2 * i / bar_d))],
                           axis=-1)
sandwich = np.matmul(pos_embs, pos_embs.T)
compression_ratio = np.arange(1, heads + 1) * 8 / heads
multi_head_sandwich = sandwich[None, ...] / compression_ratio[..., None, None]

Randomized Position

Ruoss et al.10 introduce randomized positional encodings that simulate positions of longer sequences by randomly selecting an ordered subset of indices to match the input length.
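The sampling step can be sketched in a few lines. This is an illustration of the idea, not the authors' code; `max_len` stands for the largest position the model should be able to handle at test time, and both names are made up for the example.

```python
import random

def randomized_positions(seq_len, max_len, rng):
    """Sample seq_len distinct position ids from [0, max_len) and keep them in order."""
    return sorted(rng.sample(range(max_len), seq_len))

rng = random.Random(0)  # seeded for reproducibility
positions = randomized_positions(8, 2048, rng)
print(positions)
```

During training each sequence gets a fresh sample, so the model sees large position values even though its inputs are short; the sorted ids replace `0, 1, ..., seq_len - 1` as input to the positional encoding.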

Randomized Position

source code: https://github.com/deepmind/randomized_positional_encodings/blob/main/models/positional_encodings.py#L160

It allows transformers to generalize to sequences of unseen length (increasing test accuracy by 12.0% on average) across 15 algorithmic reasoning tasks.

Randomized position results.

No Position

Haviv et al.11 observe that LMs without any explicit position encoding (NoPE) are still competitive with standard transformers across datasets, model sizes, and sequence lengths. This suggests that causal LMs may derive positional awareness not only from explicit positioning mechanisms but also from the causal attention mask itself.

No position

Note: Causal transformer LMs without positional encodings achieve results competitive with position-aware baselines, while bidirectional masked LMs fail to converge. A likely reason is that causal LMs can infer positions from the left-to-right attention mask, whereas masked LMs without positional information are order-invariant.

Comparison results across different parameter sizes.

Warning: LMs without explicit positional encodings (NoPE) are consistently slightly worse, underscoring the importance of inductive positional bias.

Position Interpolation

Instead of extrapolating, Chen et al.12131415 propose position interpolation (PI), which linearly down-scales the position indices fed to RoPE (yielding non-integer positions) so that the maximum position index matches the context window limit used in pre-training.
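In terms of position indices, the idea amounts to a single multiplication. The sketch below is plain Python; the function and argument names are illustrative.

```python
def interpolate_positions(seq_len, train_len, target_len):
    """Scale integer positions so that a target_len window fits inside the trained range."""
    scale = train_len / target_len   # e.g. 2048 / 8192 = 0.25
    return [p * scale for p in range(seq_len)]

positions = interpolate_positions(seq_len=8192, train_len=2048, target_len=8192)
print(positions[:5])   # [0.0, 0.25, 0.5, 0.75, 1.0]
print(positions[-1])   # 2047.75
```

Every rotation angle then stays inside the range seen during pre-training, at the cost of neighbouring tokens being only a quarter of a position apart.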

Position Interpolation

It simply adds two lines of code.15

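import torch


class ScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        max_position_embeddings = 8192  # extended context window

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(
            self.max_seq_len_cached,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )

        # These two lines:
        self.scale = 1 / 4
        t *= self.scale  # positions 0..8191 become 0..2047.75

        # The rest is the usual cos/sin cache, built as in the unscaled rotary embedding
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)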

SuperHOT-13B14 uptrained on scaling factor of 0.25, compared to base LLaMa 13B and a test LoRA trained on 6K sequence length with no scaling.

PPL evaluation

Warning: PI requires further fine-tuning before it takes effect for length extrapolation.

NTK-Aware Scaled RoPE

Important: Background

  1. “Simply interpolating the RoPE’s fourier space “linearly” is very sub-optimal, as it prevents the network to distinguish the order and positions of tokens that are very close by.”16
  2. “Scaling down the fourier features too much will eventually even prevent succesful finetunes (this is corroborated by the recent paper by Meta12 that suggests an upper bound of ~600x)”16

NTK-Aware Scaled RoPE1617 designs a nonlinear interpolation scheme using Neural Tangent Kernel (NTK) theory. It changes the base of the RoPE instead of the scale, which intuitively changes the “spinning” speed from which each of the RoPE’s dimension vectors shifts to the next.
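The effect of changing the base can be read off the per-dimension frequencies. The sketch below is plain Python and uses the base-change formula `base * alpha ** (dim / (dim - 2))`; the function names are illustrative. It compares the rotation frequencies before and after the change.

```python
def rope_inv_freq(dim, base=10000.0):
    """theta_i = base^(-2i/dim) for i = 0 .. dim/2 - 1."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]

def ntk_base(base, alpha, dim):
    """NTK-aware base change: base * alpha^(dim / (dim - 2))."""
    return base * alpha ** (dim / (dim - 2))

dim, alpha = 128, 8
original = rope_inv_freq(dim)
scaled = rope_inv_freq(dim, ntk_base(10000.0, alpha, dim))

ratio_high = scaled[0] / original[0]    # fastest-rotating pair: unchanged
ratio_low = scaled[-1] / original[-1]   # slowest-rotating pair: slowed by 1 / alpha
print(ratio_high, ratio_low)
```

High-frequency dimensions, which separate neighbouring tokens, are left almost untouched, while the lowest frequency is interpolated by the full factor `alpha`; linear interpolation would instead slow every dimension by the same factor.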

Implementation 18

import transformers

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):

    #The method is just these three lines
    max_position_embeddings = 16384
    a = 8 # Alpha value
    base = base * a ** (dim / (dim-2)) #Base change formula

    old_init(self, dim, max_position_embeddings, base, device)

# apply ntk-scaled init patch
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init

Average perplexity of LLaMA-7B on a set of 40 very long prompts (12k+ context size).

PI-comparison

Dynamic Linear RoPE

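Dynamic linear RoPE rescales positions as a function of the current sequence length rather than by a fixed factor: positions are multiplied by max_seq_len / current_position_length (equivalently, divided by scale = current_position_length / max_seq_len), so the scale grows gradually as the sequence gets longer. The class below implements the linear scaling step for a given scale.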
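A sketch of how the scale could be chosen per sequence, under the assumption that no scaling is applied while the sequence still fits into the trained window. The function names are illustrative, and the scale is expressed as a divisor of the positions.

```python
def dynamic_linear_scale(seq_len, max_trained_len):
    """Divisor for positions: 1 inside the trained window, seq_len / max_trained_len beyond it."""
    return max(1.0, seq_len / max_trained_len)

def scaled_positions(seq_len, max_trained_len):
    scale = dynamic_linear_scale(seq_len, max_trained_len)
    return [p / scale for p in range(seq_len)]

print(dynamic_linear_scale(1000, 2048))   # 1.0
print(dynamic_linear_scale(4096, 2048))   # 2.0
print(scaled_positions(4096, 2048)[-1])   # 2047.5
```

Because the scale changes as tokens are generated, cached keys rotated under an earlier scale no longer match the current one, so an implementation has to decide whether to recompute them.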

import torch

class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
        super().__init__()
        self.scale = scale
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        t /= self.scale
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]  # fall back to the sequence length of x
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            t /= self.scale
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
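
The class above takes a fixed `scale`. As a minimal sketch of the dynamic rule described in this section, the plain-Python helpers below recompute the scale from the current sequence length. The function names are hypothetical and the default context of 2048 is only an example; this is an illustration of the idea, not the code from the cited posts.

```python
def dynamic_linear_scale(seq_len, max_position_embeddings=2048):
    """Interpolation factor: 1 inside the trained context, growing with length beyond it."""
    return max(1.0, seq_len / max_position_embeddings)


def dynamic_linear_positions(seq_len, max_position_embeddings=2048):
    """Position indices fed to RoPE after dynamic linear scaling."""
    scale = dynamic_linear_scale(seq_len, max_position_embeddings)
    return [p / scale for p in range(seq_len)]


for n in (1024, 2048, 4096, 8192):
    positions = dynamic_linear_positions(n)
    # Print the length, the scale in effect, and the largest position index used.
    print(n, dynamic_linear_scale(n), positions[-1])
```

Short sequences keep their exact positions, and longer ones are squeezed just enough to stay inside the trained position range.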

Dynamic NTK-Aware Scaled RoPE

Cons of static NTK-aware scaling:

  • Compared to dynamic linear scaling, NTK-Aware has higher perplexity for shorter sequences, but better perplexity at the tail end of the sequence lengths.
  • NTK-aware RoPE suffers from catastrophic perplexity blowup, like regular RoPE and static linear scaling.

Emozilla19 introduces dynamic NTK-aware scaling, which replaces the fixed $\alpha$ with a value $\alpha'$ that depends on the current sequence length:

\[\alpha' = \alpha \cdot \frac{L_{\text{cur}}}{L_{\text{orig}}} - (\alpha - 1)\]

where \(L_{\text{cur}}\) is the current sequence length and \(L_{\text{orig}}\) is the original model context length.

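At \(L_{\text{cur}} = L_{\text{orig}}\) this gives $\alpha' = 1$, so the base is unchanged within the original context; beyond it, $\alpha'$ grows linearly with the sequence length.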
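
As a small worked sketch of this formula, the helpers below compute $\alpha'$ and the resulting RoPE base. The function names and the defaults (`alpha=2.0`, `orig_len=2048`) are assumptions made for illustration; the exponent `dim / (dim - 2)` is the same base-change exponent used in the implementations in this post.

```python
def dynamic_ntk_alpha(alpha, cur_len, orig_len):
    """alpha' = alpha * L_cur / L_orig - (alpha - 1)."""
    return alpha * cur_len / orig_len - (alpha - 1)


def dynamic_ntk_base(dim, cur_len, orig_len=2048, base=10000.0, alpha=2.0):
    """RoPE base after dynamic NTK-aware scaling; unchanged within the original context."""
    if cur_len <= orig_len:
        return base
    return base * dynamic_ntk_alpha(alpha, cur_len, orig_len) ** (dim / (dim - 2))


for n in (1024, 2048, 3072, 4096):
    # Print the sequence length and the base that would be used for it.
    print(n, dynamic_ntk_base(dim=128, cur_len=n))
```

For example, with `alpha = 2` a sequence twice the original length gives $\alpha' = 2 \cdot 2 - 1 = 3$.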

import math
import torch

class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
        super().__init__()
        self.ntk = ntk
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]  # fall back to the sequence length of x
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            if self.ntk:  # dynamic NTK; `self.ntk` doubles as the alpha factor, so pass a number (e.g. 2), not True
                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                self.register_buffer("inv_freq", inv_freq)
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            if not self.ntk:
                t *= self.max_position_embeddings / seq_len
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

Figure: Dynamic NTK RoPE

Partial NTK Scaled RoPE

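Partial NTK scaled RoPE20 combines three schemes across the frequency spectrum: unmodified RoPE (extrapolation), linear interpolation, and NTK-aware scaling. Each dimension pair gets a blend of the three, with linear ramp masks chosen by how many full rotations that pair completes within the original context length.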
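
Read off the implementation below, the blend can be written as follows. The notation is introduced here for illustration and is not taken from the cited source: $\theta_j$ is the rotation frequency of pair $j$, $b$ the base, $d$ the head dimension, $s$ the scale, \(L_{\text{orig}}\) the original context length, and \(\lambda_{\text{ntk}}\), \(\lambda_{\text{ext}}\) stand for ntk_factor and extrapolation_factor. The clamping of the range endpoints to $[0, d-1]$ and the guard against a zero-width ramp are omitted.

\[\theta_j^{\text{base}} = b^{-2j/d}, \qquad \theta_j^{\text{linear}} = \frac{\theta_j^{\text{base}}}{s}, \qquad \theta_j^{\text{ntk}} = \left(b \cdot s^{d/(d-2)}\right)^{-2j/d}\]

The pair index whose frequency completes $r$ full rotations over \(L_{\text{orig}}\) positions is

\[j(r) = \frac{d \, \ln\!\left(\frac{L_{\text{orig}}}{2\pi r}\right)}{2 \ln b}\]

and the ramp mask between two rotation counts $r_0 > r_1$ is

\[m_j(r_0, r_1) = 1 - \operatorname{clamp}\!\left(\frac{j - \lfloor j(r_0) \rfloor}{\lceil j(r_1) \rceil - \lfloor j(r_0) \rfloor},\ 0,\ 1\right)\]

The two blending stages are then

\[\tilde{\theta}_j = \left(1 - \lambda_{\text{ntk}}\, m_j(\beta_0, \beta_1)\right) \theta_j^{\text{linear}} + \lambda_{\text{ntk}}\, m_j(\beta_0, \beta_1)\, \theta_j^{\text{ntk}}\]

\[\theta_j = \left(1 - \lambda_{\text{ext}}\, m_j(\gamma_0, \gamma_1)\right) \tilde{\theta}_j + \lambda_{\text{ext}}\, m_j(\gamma_0, \gamma_1)\, \theta_j^{\text{base}}\]

With both factors set to 1, the fastest pairs (many rotations within the original context) keep the original RoPE frequency, the slowest pairs are linearly interpolated, and NTK-aware scaling covers the band in between.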

import torch
import math

def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
    return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) #Inverse dim formula to find number of rotations

def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim-1) #Clamp values just in case

def linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001 #Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func

def find_newbase_ntk(dim, base=10000, scale=1):
    return base * scale ** (dim / (dim-2))

class LlamaPartNTKScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, ntk_factor=1, extrapolation_factor=1, original_max_position_embeddings=2048, device=None):
        super().__init__()

        #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
        #Do not change unless there is a good reason for doing so!
        beta_0 = 1.25
        beta_1 = 0.75
        gamma_0 = 16
        gamma_1 = 2

        #Three RoPE extrapolation/interpolation methods
        inv_freq_base = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        inv_freq_linear = 1.0 / (scale * (base ** (torch.arange(0, dim, 2).float().to(device) / dim)))
        inv_freq_ntk = 1.0 / (find_newbase_ntk(dim, base, scale) ** (torch.arange(0, dim, 2).float().to(device) / dim))

        current_dtype = inv_freq_ntk.dtype
        current_device = inv_freq_ntk.device

        #Combine NTK and Linear
        low, high = find_correction_range(beta_0, beta_1, dim, base, original_max_position_embeddings)
        inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).type(current_dtype).to(current_device)) * ntk_factor
        inv_freq = inv_freq_linear * (1 - inv_freq_mask) + inv_freq_ntk * inv_freq_mask

        #Combine Extrapolation and NTK and Linear
        low, high = find_correction_range(gamma_0, gamma_1, dim, base, original_max_position_embeddings)
        inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).type(current_dtype).to(current_device)) * extrapolation_factor
        inv_freq = inv_freq * (1 - inv_freq_mask) + inv_freq_base * inv_freq_mask

        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]  # fall back to the sequence length of x
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

$\beta$-ary RoPE

$\beta$-ary RoPE21 reinterprets the RoPE base frequency through the lens of $\beta$-ary positional encoding, providing theoretical insight into the relationship between the base and the effective context window.
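
One way to make this view concrete is sketched below, under the assumption that pair $i$ of RoPE plays the role of the $i$-th digit of the position in base $\beta = \text{base}^{2/d}$, with the rotation acting as a smooth, periodic stand-in for the digit. The helper names and the values `dim=128`, `alpha=8` are hypothetical choices for illustration and are not taken from the cited post.

```python
import math


def rope_beta(dim, base=10000.0):
    """Radix beta such that pair i rotates at frequency beta^(-i)."""
    return base ** (2.0 / dim)


def beta_ary_digit(n, i, beta):
    """i-th digit of position n in base beta: floor(n / beta^i) mod beta."""
    return math.floor(n / beta ** i) % beta


dim, alpha = 128, 8.0  # illustrative values
beta = rope_beta(dim)
# The NTK-aware base change seen as a change of radix.
beta_ntk = rope_beta(dim, base=10000.0 * alpha ** (dim / (dim - 2)))

# Growth of the period of the slowest pair after the change of radix.
slowest_period_growth = (beta_ntk / beta) ** (dim // 2 - 1)
print(beta, beta_ntk, slowest_period_growth)
```

Under this reading, enlarging the base is a change of radix: the same number of "digits" covers a range that is `alpha` times larger, which connects the base to the effective context window.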

ReRoPE

ReRoPE22 modifies the RoPE computation by truncating or redistributing position indices during attention, achieving better length generalization without modifying the base frequency.
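
As a rough sketch of what truncating or redistributing position indices can mean, the two helpers below map a query-key distance to the relative position that would be fed to RoPE. The parameter names `window` and `k` and both function names are hypothetical; the code illustrates the idea only and is not taken from the ReRoPE reference implementation.

```python
def rerope_distance(distance, window):
    """Truncated relative position: exact inside the window, clipped to `window` beyond it."""
    return min(distance, window)


def leaky_rerope_distance(distance, window, k):
    """Redistributed relative position: unit steps inside the window, steps of 1/k beyond it."""
    if distance < window:
        return float(distance)
    return window + (distance - window) / k


for d in (3, 8, 16, 100):
    # Print the raw distance next to its truncated and redistributed versions.
    print(d, rerope_distance(d, window=8), leaky_rerope_distance(d, window=8, k=4))
```

In both variants nearby tokens keep their exact relative positions, and distant tokens are mapped to relative positions that stay small, without touching the base frequency.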

References

  1. Vaswani, Ashish, Noam M. Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. “Attention is All you Need.” NeurIPS (2017).

  2. Su, Jianlin, Yu Lu, Shengfeng Pan, Bo Wen and Yunfeng Liu. “RoFormer: Enhanced Transformer with Rotary Position Embedding.” ArXiv abs/2104.09864 (2021).

  3. Rotary Embeddings: A Relative Revolution.

  4. Blog: Bias项的神奇作用:RoPE + Bias = 更好的长度外推性 (“The Magical Effect of the Bias Term: RoPE + Bias = Better Length Extrapolation”, in Chinese).

  5. Raffel, C., Shazeer, N.M., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., & Liu, P.J. (2020). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. JMLR.

  6. Press, O., Smith, N.A., & Lewis, M. (2022). Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation. ICLR.

  7. Chi, T., Fan, T., Ramadge, P.J., & Rudnicky, A.I. (2022). KERPLE: Kernelized Relative Positional Embedding for Length Extrapolation. NeurIPS.

  8. Sun, Y., Dong, L., Patra, B., Ma, S., Huang, S., Benhaim, A., Chaudhary, V., Song, X., & Wei, F. (2022). A Length-Extrapolatable Transformer. ArXiv, abs/2212.10554.

  9. Chi, T., Fan, T., Rudnicky, A., & Ramadge, P.J. (2022). Dissecting Transformer Length Extrapolation via the Lens of Receptive Field Analysis.

  10. Ruoss, A., Delétang, G., Genewein, T., Grau-Moya, J., Csordás, R., Abbana Bennani, M., Legg, S., & Veness, J. (2023). Randomized Positional Encodings Boost Length Generalization of Transformers. ACL.

  11. Haviv, Adi et al. “Transformer Language Models without Positional Encodings Still Learn Positional Information.” Conference on Empirical Methods in Natural Language Processing (2022).

  12. Chen, Shouyuan, Sherman Wong, Liangjian Chen and Yuandong Tian. “Extending Context Window of Large Language Models via Positional Interpolation.” ArXiv abs/2306.15595 (2023).

  13. GitHub discussion: Position Interpolation.

  14. Reddit: A simple way to “Extending Context to 8K”.

  15. Things I’m Learning While Training SuperHOT.

  16. NTK-Aware Scaled RoPE.

  17. RoPE is a β-ary encoding (in Chinese).

  18. NTK-aware RoPE colab.

  19. Dynamic NTK-aware RoPE.

  20. GitHub: Dynamic RoPE.

  21. β-ary RoPE (in Chinese).

  22. ReRoPE (in Chinese).



