Yekun's Note

Machine learning notes and writeup.


Inductive Positions in Transformers

We summarize the positional encoding approaches in transformers.

Summary

| PE | Relative | Trainable | Each Layer | Extrapolation |
|------------|:---:|:---:|:---:|:---:|
| Sinusoidal | ✗ | ✗ | ✗ | ✗ |
| T5 bias    | ✓ | ✓ | ✓ | ✓ |
| RoPE       | ✓ | ✗ | ✓ | ✗ |
| ALiBi      | ✓ | ✗ | ✓ | ✓ |
| KERPLE     | ✓ | ✓ | ✓ | ✓ |
| Sandwich   | ✓ | ✗ | ✓ | ✓ |
| xPos       | ✓ | ✗ | ✓ | ✓ |

Position Encoding

Sinusoidal Position Embeddings

Sinusoidal position embeddings[1] are constant (non-learned) vectors added to the token embeddings at the input of the first transformer layer:

$$
PE_{(\text{pos},\,2i)} = \sin\left(\frac{\text{pos}}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(\text{pos},\,2i+1)} = \cos\left(\frac{\text{pos}}{10000^{2i/d_{\text{model}}}}\right),
$$

where $\text{pos}$ is the position in the sentence and $i$ indexes the embedding dimension. The authors hypothesize that this allows the model to learn to attend by relative positions, since for any fixed offset $k$, $PE_{\text{pos}+k}$ can be represented as a linear function of $PE_{\text{pos}}$.
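To see this claim concretely, fix a frequency $\theta_i = 10000^{-2i/d_{\text{model}}}$. Each $(\sin, \cos)$ pair then satisfies a rotation identity whose coefficients depend only on the offset $k$, not on $\text{pos}$:

$$
\begin{pmatrix} \sin\big((\text{pos}+k)\,\theta_i\big) \\ \cos\big((\text{pos}+k)\,\theta_i\big) \end{pmatrix}
=
\begin{pmatrix} \cos(k\theta_i) & \sin(k\theta_i) \\ -\sin(k\theta_i) & \cos(k\theta_i) \end{pmatrix}
\begin{pmatrix} \sin(\text{pos}\,\theta_i) \\ \cos(\text{pos}\,\theta_i) \end{pmatrix}.
$$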

Position encoding

# Positional encoding layer in PyTorch
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)  # (max_len, 1) position indices
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(1e4) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)                 # (1, max_len, d_model)
        self.register_buffer('pe', pe)       # buffer: constant, not a learnable parameter

    def forward(self, x):
        seq_len = x.size(1)                  # take the sequence length
        x = x + self.pe[:, :seq_len]         # add the constant encodings (no gradient)
        return self.dropout(x)

Rotary Position Embedding (RoPE)

Rotary Position Embedding (RoPE)[3][4] proposes to use complex numbers as the base field of the encoding space. Instead of working in $\mathbb{R}^d$, it treats consecutive pairs of elements of the query and key vectors as single complex numbers in $\mathbb{C}^{d/2}$.

Specifically, instead of viewing $\mathbf{q}_m \in \mathbb{R}^{d}$ as a $d$-dimensional real vector, RoPE views it as $\mathbf{q}_m \in \mathbb{C}^{d/2}$. If $d$ is odd, RoPE pads it with a dummy coordinate to ensure things line up correctly. Alternatively, it simply increases $d$ by one.

Derivation

The complex-number form of RoPE is written as:

$$
f_q(\mathbf{X}_m, m) = \left(\mathbf{W}_q \mathbf{X}_m\right) e^{i m \theta}.
$$

It is convenient to convert this into a matrix equation:

$$
\mathbf{q}_m = \mathbf{\Theta}_m \mathbf{W}_q \mathbf{X}_m, \qquad
\mathbf{\Theta}_m =
\begin{pmatrix}
\cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots \\
\sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots \\
0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots \\
0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots \\
\vdots & \vdots & \vdots & \vdots & \ddots
\end{pmatrix},
$$

where $\mathbf{q}_m$ is the rotated query, $\mathbf{\Theta}_m$ is the block-diagonal rotation matrix, $\mathbf{W}_q$ is the learned query weight matrix, and $\mathbf{X}_m$ is the embedding of the $m$-th token.

Due to the high computational cost of the sparse rotation matrix, it is implemented element-wise as:

$$
\mathbf{\Theta}_m \mathbf{q} =
\begin{pmatrix} q_1 \\ q_2 \\ q_3 \\ q_4 \\ \vdots \end{pmatrix} \odot
\begin{pmatrix} \cos m\theta_1 \\ \cos m\theta_1 \\ \cos m\theta_2 \\ \cos m\theta_2 \\ \vdots \end{pmatrix} +
\begin{pmatrix} -q_2 \\ q_1 \\ -q_4 \\ q_3 \\ \vdots \end{pmatrix} \odot
\begin{pmatrix} \sin m\theta_1 \\ \sin m\theta_1 \\ \sin m\theta_2 \\ \sin m\theta_2 \\ \vdots \end{pmatrix},
$$

where $\odot$ denotes the element-wise (Hadamard) product.
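As a quick sanity check of the relative-position property (a minimal NumPy sketch, separate from the RoFormer implementation below): rotating queries and keys by their own positions makes the attention score depend only on the offset $m-n$.

import numpy as np

d = 8                                       # head dimension (even)
theta = 1e4 ** (-np.arange(0, d, 2) / d)    # per-pair frequencies, shape (d/2,)
rng = np.random.default_rng(0)
q, k = rng.normal(size=d), rng.normal(size=d)

def rotate(x, pos):
    """Apply RoPE by treating consecutive pairs as complex numbers."""
    z = x[0::2] + 1j * x[1::2]              # (d/2,) complex vector
    z = z * np.exp(1j * pos * theta)        # rotate each pair by pos * theta_i
    out = np.empty_like(x)
    out[0::2], out[1::2] = z.real, z.imag
    return out

# The score <R_m q, R_n k> depends only on the relative offset m - n.
s1 = rotate(q, 10) @ rotate(k, 7)           # offset 3
s2 = rotate(q, 103) @ rotate(k, 100)        # offset 3 at different absolute positions
print(np.allclose(s1, s2))                  # True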

RoPE

  • Extension to Multiple Dimensions

Difference from sinusoidal embedding[4]

  1. Sinusoidal embeddings apply to each coordinate individually, while RoPE mixes pairs of coordinates.
  2. Sinusoidal embeddings add a $\cos(m\theta)$ or $\sin(m\theta)$ term, while RoPE uses a multiplicative factor.

Implementation


# source: https://huggingface.co/transformers/v4.8.0/_modules/transformers/models/roformer/modeling_roformer.html
from typing import Optional

import numpy as np
import torch
import torch.nn as nn


# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer
class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]
        """
        n_pos, dim = out.shape
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        bsz, seq_len = input_ids_shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        )
        return super().forward(positions)


def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
    # https://kexue.fm/archives/8265
    # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
    # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
    sin, cos = sinusoidal_pos.chunk(2, dim=-1)
    # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
    sin_pos = torch.repeat_interleave(sin, 2, dim=-1)
    # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
    cos_pos = torch.repeat_interleave(cos, 2, dim=-1)
    # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
    rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
        query_layer
    )
    query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos
    # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
    rotate_half_key_layer = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1).reshape_as(key_layer)
    key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos
    if value_layer is not None:
        # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
        rotate_half_value_layer = torch.stack([-value_layer[..., 1::2], value_layer[..., ::2]], dim=-1).reshape_as(
            value_layer
        )
        value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos
        return query_layer, key_layer, value_layer
    return query_layer, key_layer


# usage sketch (inside an attention module): q, k have shape [bsz, num_heads, seq_len, head_dim]
embed_positions = RoFormerSinusoidalPositionalEmbedding(
    max_position_embeddings,
    config.hidden_size // config.num_attention_heads
)
# look up sin/cos for the current sequence length, then broadcast over batch and heads
sinusoidal_pos = embed_positions((q.shape[0], q.shape[2]))[None, None, :, :]

q, k = apply_rotary_position_embeddings(sinusoidal_pos, q, k)

GPT-NeoX (PyTorch) implementation.

import torch

class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached


# rotary pos emb helpers:

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat(
        (-x2, x1), dim=x1.ndim - 1
    )  # dim=-1 triggers a bug in torch < 1.8.0


@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

RoPE with Bias

[6] finds that RoPE with a bias term (RoPE w/ Bias) can increase the capability of length extrapolation:

$$
\boldsymbol{q}_m = \boldsymbol{\mathcal{R}}_m \left(\boldsymbol{W}_q \boldsymbol{x}_m + \boldsymbol{a}\right), \qquad
\boldsymbol{k}_n = \boldsymbol{\mathcal{R}}_n \left(\boldsymbol{W}_k \boldsymbol{x}_n + \boldsymbol{b}\right),
$$

where $\boldsymbol{\mathcal{R}}_m, \boldsymbol{\mathcal{R}}_n$ are the rotation matrices and $\boldsymbol{a}, \boldsymbol{b}$ are learnable biases.

NB: plain softmax self-attention gives equivalent results with or without the bias term, since the bias can be cancelled out by the softmax normalization.

However, dropping the bias term in self-attention with RoPE no longer yields the same results.[6]

T5 Bias

T5[7] adds no position encoding to the word embeddings. Instead, it adds a learned, shared bias to each query-key self-attention logit that depends only on the distance between the query and the key. Multiple different distances share the same learned bias, which may be beneficial for length extrapolation. Specifically, a fixed number of embeddings are learned, each corresponding to a range of possible key-query offsets.

T5 uses 32 learnable bucket embeddings and assigns each relative position to a bucket with a log-binning strategy: for a relative distance $n \ge 0$, $B$ buckets, and maximum distance $D$,

$$
\text{bucket}(n) =
\begin{cases}
n & n < \tfrac{B}{2}, \\[4pt]
\tfrac{B}{2} + \left\lfloor \dfrac{\log\left(n / \tfrac{B}{2}\right)}{\log\left(D / \tfrac{B}{2}\right)} \cdot \tfrac{B}{2} \right\rfloor & \tfrac{B}{2} \le n < D, \\[8pt]
B - 1 & n \ge D.
\end{cases}
$$

[8] finds that the T5 bias enables length extrapolation.

T5 uses 32 embeddings for all models, with bucket ranges that increase in size logarithmically up to an offset of 128, beyond which all relative positions are assigned to the same embedding. Position embeddings are shared across all layers in T5, though within a given layer each attention head uses a different learned position bias.

import math

import torch
import torch.nn as nn
from transformers import T5Config


class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance

        ...

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
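As a quick illustration of the log-binning behavior (a hedged usage sketch assuming the `_relative_position_bucket` staticmethod above is available as shown), nearby offsets keep their own exact buckets while distant offsets collapse into the last bucket:

import torch

rel = torch.tensor([-1, -8, -16, -32, -64, -128, -300])  # key - query offsets (causal case)
buckets = T5Attention._relative_position_bucket(
    rel, bidirectional=False, num_buckets=32, max_distance=128
)
print(buckets)  # tensor([ 1,  8, 16, 21, 26, 31, 31]): offsets beyond max_distance share bucket 31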

ALiBi (Attention with Linear Biases)

Length extrapolation allows transformers to be trained on short sequences and evaluated on substantially longer sequences, by means of relative positional encoding.

ALiBi[8] adds a static, non-learnable bias to the query-key dot product. Like the T5 bias and RoPE, it injects position information at every layer (here, directly into the attention scores):

$$
\text{softmax}\left(\mathbf{q}_i \mathbf{K}^{\top} + \alpha \cdot \left[-(i-1), \dots, -2, -1, 0\right]\right),
$$

where $\alpha$ is a fixed, head-specific slope. For the $i$-th of $n$ heads, the slope is $\alpha_i = 2^{-8i/n}$; e.g., with 8 heads the slopes are $\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$.
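A minimal sketch of the resulting bias matrix under the slope schedule above (`alibi_bias` is an illustrative helper, not code from the paper):

import torch

def alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    # Head-specific slopes: geometric sequence 2^(-8/n), 2^(-16/n), ...
    slopes = torch.tensor([2 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)])
    # Relative distances j - i for query i attending to key j (past keys give values <= 0).
    pos = torch.arange(seq_len)
    distances = pos[None, :] - pos[:, None]                # (seq_len, seq_len)
    # Bias added to the attention logits before softmax; the causal mask hides the j > i part.
    return slopes[:, None, None] * distances[None, :, :]   # (n_heads, seq_len, seq_len)

bias = alibi_bias(n_heads=8, seq_len=5)
print(bias[0])  # head 0 (slope 1/2): penalty grows linearly with distance to earlier keys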

ALiBi

The ALiBi bias is not multiplied by the $\frac{1}{\sqrt{d_k}}$ scaling factor applied to the query-key dot product in the original transformer.

Extrapolation: (validation-set’s) input sequence length (x-axis), versus perplexity (y-axis, lower is better).

It is observed that ALiBi and the T5 bias show length extrapolation ability[8], while RoPE and sinusoidal position embeddings do not.

KERPLE

Kernelized Relative Positional Embedding for Length Extrapolation (KERPLE)[9] proposes kernelized positional biases, e.g. the power and logarithmic variants:

$$
-r_1 \, |m-n|^{r_2}, \qquad -r_1 \log\left(1 + r_2\,|m-n|\right),
$$

where $r_1, r_2 > 0$ are learnable parameters.

Triangle kernel: $-r_1\,|m-n|$. It reduces to ALiBi.
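A small sketch of the logarithmic variant under the formulation above, with fixed illustrative values standing in for the learnable, per-head $r_1, r_2$:

import torch

def kerple_log_bias(seq_len: int, r1: float = 1.0, r2: float = 1.0) -> torch.Tensor:
    # r1, r2 are positive learnable (per-head) parameters in KERPLE; fixed here for illustration.
    pos = torch.arange(seq_len)
    dist = (pos[None, :] - pos[:, None]).abs().float()   # |m - n|
    return -r1 * torch.log1p(r2 * dist)                  # logarithmic kernel bias

print(kerple_log_bias(4))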

xPos

Extrapolatable Position Embedding (xPos)[10] augments RoPE with a per-dimension exponential decay: the rotated query at position $m$ is additionally scaled by $\zeta^{m}$ and the rotated key at position $n$ by $\zeta^{-n}$ (with $0 < \zeta < 1$), so the attention score carries a relative factor $\zeta^{m-n}$ that attenuates attention between distant positions.

xPos
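A minimal NumPy sketch of the decay idea described above (not the exact xPos parameterization; a single scalar $\zeta$ is assumed here instead of the per-dimension schedule):

import numpy as np

d = 8
theta = 1e4 ** (-np.arange(0, d // 2) * 2 / d)   # RoPE frequencies
zeta = 0.98                                       # illustrative decay base, 0 < zeta < 1

def xpos_like(x, pos, sign):
    """Rotate pairs by pos*theta and scale by zeta**(sign*pos) (sketch with a scalar zeta)."""
    z = (x[0::2] + 1j * x[1::2]) * np.exp(1j * pos * theta) * zeta ** (sign * pos)
    out = np.empty_like(x)
    out[0::2], out[1::2] = z.real, z.imag
    return out

rng = np.random.default_rng(0)
q, k = rng.normal(size=d), rng.normal(size=d)
# The score is damped by zeta**(m-n) relative to plain RoPE as the offset grows.
for m, n in [(10, 9), (20, 9), (40, 9)]:
    print(m - n, xpos_like(q, m, +1) @ xpos_like(k, n, -1))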

Sandwich

The self-attention score between query position $i$ and key position $j$ is calculated as the usual content term plus a temporal bias:

$$
\text{score}(i, j) = \mathbf{q}_i^{\top} \mathbf{k}_j + b_{i,j}.
$$

The temporal bias term is the inner product of the two positions' sinusoidal embeddings, which depends only on the offset $i - j$:

$$
b_{i,j} = \mathbf{p}_i^{\top} \mathbf{p}_j = \sum_{k} \cos\left((i-j)\,\theta_k\right).
$$

Sandwich[11]

import numpy as np

base = 1e4
heads = 12
seq_len = 8192
positions = np.arange(seq_len)[..., None]
bar_d = 128 # This is the hyperparameter of Sandwich
i = np.arange(bar_d // 2)

pos_embs = np.concatenate([np.sin(positions / base ** (2 * i / bar_d)),
                           np.cos(positions / base ** (2 * i / bar_d))],
                          axis=-1)
sandwich = np.matmul(pos_embs, pos_embs.T)
compression_ratio = np.arange(1, heads + 1) * 8 / heads
multi_head_sandwich = sandwich[None, ...] / compression_ratio[..., None, None]

Randomized Position

[13] introduces randomized positional encodings, which simulate the positions of longer sequences by sampling position indices from a larger range and randomly selecting an ordered subset to match the actual sequence length.
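A hedged sketch of the idea (not the official implementation linked below; `max_simulated_len` is an assumed hyperparameter name): sample an ordered subset of positions from a much larger simulated range and use those indices to look up the positional encodings.

import torch

def randomized_positions(seq_len: int, max_simulated_len: int = 2048) -> torch.Tensor:
    """Sample `seq_len` distinct positions from [0, max_simulated_len) and sort them."""
    assert seq_len <= max_simulated_len
    perm = torch.randperm(max_simulated_len)[:seq_len]
    return perm.sort().values  # ordered subset, e.g. tensor([ 3, 17, 94, ...])

positions = randomized_positions(seq_len=16)
# These indices would replace torch.arange(seq_len) when indexing the position table.
print(positions)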

Randomized Position

source code: https://github.com/deepmind/randomized_positional_encodings/blob/main/models/positional_encodings.py#L160

It allows transformers to generalize to sequences of unseen length (increasing test accuracy by 12.0% on average) across 15 algorithmic reasoning tasks.

Randomized position results.

No Position

[12] observes that LMs without any explicit position encoding (NoPos) remain competitive with standard transformers across datasets, model sizes, and sequence lengths, suggesting that causal LMs can derive positional awareness from the causal attention mask rather than only from an explicit positioning mechanism.

No position

Causal transformer LMs achieve results competitive with standard LMs, while bidirectional masked LMs fail to converge. This may be because causal LMs can infer positions from the autoregressive (left-to-right) structure, whereas masked LMs are order-invariant without positional encodings.

Comparison results across different parameter sizes.

NB: LMs without explicit positional encodings (NoPos) are always slightly worse, suggesting the importance of inductive positional bias.

Position Interpolation

Instead of extrapolating, [15][16][17][18] present position interpolation (PI), which directly downscales position indices to (possibly non-integer) values for RoPE-based models, so that the maximum position index stays within the context window limit used in pre-training.
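Concretely, for a RoPE-based model pre-trained with context window $L$ and a longer target window $L' > L$, PI rescales the position index before applying RoPE,

$$
f'(\mathbf{x}, m) = f\!\left(\mathbf{x}, \frac{m L}{L'}\right),
$$

so every position in $[0, L')$ is mapped back into the trained range $[0, L)$.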

Position Interpolation

It only requires adding two lines of code.[18]

import torch

class ScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        max_position_embeddings = 8192

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(
            self.max_seq_len_cached,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )

        # These two lines:
        self.scale = 1 / 4
        t *= self.scale

SuperHOT-13B[17] was uptrained with a scaling factor of 0.25, compared against base LLaMA-13B and a test LoRA trained on 6K sequence length with no scaling.

PPL evaluation

Note that PI requires further fine-tuning to take effect for length extrapolation.

NTK-Aware Scaled RoPE

Background

  1. “Simply interpolating the RoPE’s fourier space “linearly” is very sub-optimal, as it prevents the network to distinguish the order and positions of tokens that are very close by.”[19]
  2. “Scaling down the fourier features too much will eventually even prevent succesful finetunes (this is corroborated by the recent paper by Meta[15] that suggests an upper bound of ~600x)”[19]

NTK-Aware Scaled RoPE[19][20] designs a nonlinear interpolation scheme motivated by Neural Tangent Kernel (NTK) theory. It changes the base of RoPE instead of the scale, which intuitively changes the “spinning” speed at which each of RoPE’s dimensions rotates from one position to the next.

Implementation [21]

import transformers

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
    # The method is just these three lines
    max_position_embeddings = 16384
    a = 8  # Alpha value
    base = base * a ** (dim / (dim - 2))  # Base change formula

    old_init(self, dim, max_position_embeddings, base, device)

# apply ntk-scaled init patch
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init

Average perplexity of LLaMA-7B on a set of 40 very long prompts (12k+ context size).

PI-comparison

Dynamic Linear RoPE

Dynamic linear RoPE: set the scale to the ratio of the current sequence length to the trained context length (max_position_embeddings), so the scale increases gradually as the sequence grows rather than being fixed up front.

import torch

class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
        super().__init__()
        self.scale = scale
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        t /= self.scale
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            t /= self.scale
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

Dynamic NTK-Aware Scaled RoPE

Cons of (static) NTK-aware scaling:

  • Compared to dynamic linear scaling, NTK-aware scaling has higher perplexity for shorter sequences, but better perplexity at the tail end of the sequence lengths.
  • It still suffers from catastrophic perplexity blowup beyond some length, like regular RoPE and static linear scaling.

[22] introduces dynamic NTK-aware scaling: instead of a fixed $\alpha$, the effective value is set from the current sequence length $L$ and the trained context length $L_{\text{train}}$ as

$$
\alpha_{\text{dyn}} = \alpha \cdot \frac{L}{L_{\text{train}}} - (\alpha - 1),
$$

and the RoPE base becomes $\text{base} \cdot \alpha_{\text{dyn}}^{\,d/(d-2)}$. This dynamically increases the scaling as the sequence length grows.

import math
import torch

class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
        super().__init__()
        self.ntk = ntk
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            if self.ntk:  ### dynamic NTK
                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim - 2))
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                self.register_buffer("inv_freq", inv_freq)
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            if not self.ntk:
                t *= self.max_position_embeddings / seq_len
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

Dynamic NTK RoPE[22]

Partial NTK Scaled RoPE

Partial NTK scaling combines standard RoPE extrapolation, linear interpolation, and NTK-aware scaling across different frequency bands.[23]

import torch
import math

def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
    # Inverse dim formula to find number of rotations
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case

def linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func

def find_newbase_ntk(dim, base=10000, scale=1):
    return base * scale ** (dim / (dim - 2))

class LlamaPartNTKScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, ntk_factor=1, extrapolation_factor=1, original_max_position_embeddings=2048, device=None):
        super().__init__()

        # Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
        # Do not change unless there is a good reason for doing so!
        beta_0 = 1.25
        beta_1 = 0.75
        gamma_0 = 16
        gamma_1 = 2

        # Three RoPE extrapolation/interpolation methods
        inv_freq_base = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        inv_freq_linear = 1.0 / (scale * (base ** (torch.arange(0, dim, 2).float().to(device) / dim)))
        inv_freq_ntk = 1.0 / (find_newbase_ntk(dim, base, scale) ** (torch.arange(0, dim, 2).float().to(device) / dim))

        current_dtype = inv_freq_ntk.dtype
        current_device = inv_freq_ntk.device

        # Combine NTK and Linear
        low, high = find_correction_range(beta_0, beta_1, dim, base, original_max_position_embeddings)
        inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).type(current_dtype).to(current_device)) * ntk_factor
        inv_freq = inv_freq_linear * (1 - inv_freq_mask) + inv_freq_ntk * inv_freq_mask

        # Combine Extrapolation and NTK and Linear
        low, high = find_correction_range(gamma_0, gamma_1, dim, base, original_max_position_embeddings)
        inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).type(current_dtype).to(current_device)) * extrapolation_factor
        inv_freq = inv_freq * (1 - inv_freq_mask) + inv_freq_base * inv_freq_mask

        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

$\beta$-ary RoPE

[24]

ReRoPE

[25]

References


  1. Vaswani, Ashish, Noam M. Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. “Attention is All you Need.” NeurIPS (2017).
  2. Press, Ofir, Noah A. Smith and Mike Lewis. “Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.” ICLR 2022.
  3. Su, Jianlin, Yu Lu, Shengfeng Pan, Bo Wen and Yunfeng Liu. “RoFormer: Enhanced Transformer with Rotary Position Embedding.” ArXiv abs/2104.09864 (2021).
  4. Rotary Embeddings: A Relative Revolution
  5. RoFormer blog (Chinese): Transformer升级之路:2、博采众长的旋转式位置编码
  6. Blog: Bias项的神奇作用:RoPE + Bias = 更好的长度外推性
  7. Raffel, C., Shazeer, N.M., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., & Liu, P.J. (2020). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. JMLR.
  8. Press, O., Smith, N.A., & Lewis, M. (2022). Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation. ICLR.
  9. Chi, T., Fan, T., Ramadge, P.J., & Rudnicky, A.I. (2022). KERPLE: Kernelized Relative Positional Embedding for Length Extrapolation. NeurIPS.
  10. Sun, Y., Dong, L., Patra, B., Ma, S., Huang, S., Benhaim, A., Chaudhary, V., Song, X., & Wei, F. (2022). A Length-Extrapolatable Transformer. ArXiv, abs/2212.10554.
  11. Chi, T., Fan, T., Rudnicky, A., & Ramadge, P.J. (2022). Dissecting Transformer Length Extrapolation via the Lens of Receptive Field Analysis.
  12. Haviv, Adi et al. “Transformer Language Models without Positional Encodings Still Learn Positional Information.” EMNLP (2022).
  13. Ruoss, A., Delétang, G., Genewein, T., Grau-Moya, J., Csordás, R., Abbana Bennani, M., Legg, S., & Veness, J. (2023). Randomized Positional Encodings Boost Length Generalization of Transformers. ACL.
  14. Blog: Transformer升级之路:8、长度外推性与位置鲁棒性
  15. Chen, Shouyuan, Sherman Wong, Liangjian Chen and Yuandong Tian. Extending Context Window of Large Language Models via Positional Interpolation. ArXiv abs/2306.15595 (2023).
  16. GitHub discussion: Position Interpolation
  17. Reddit: A simple way to "Extending Context to 8K"
  18. Things I’m Learning While Training SuperHOT
  19. NTK-Aware Scaled RoPE
  20. RoPE is a β-ary encoding (Chinese)
  21. NTK-aware RoPE colab
  22. Dynamic NTK-aware RoPE
  23. GitHub: Dynamic RoPE
  24. β-ary RoPE (in Chinese)
  25. ReRoPE (in Chinese)