Memory-Efficient Attention: MHA vs. MQA vs. GQA vs. MLA

Efficient attention mechanisms are crucial for scaling transformers to large-scale applications. Here we walk through four attention variants: Multi-Head Attention (MHA), Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-Head Latent Attention (MLA), analyzing their trade-offs in memory, speed, and expressivity, and how they improve transformer scalability. 🚀

Figure: Self-attention architecture (image source: DeepSeek-V2)

Multi-Head Attention (MHA)

Standard MHA[1] uses independent queries, keys, and values per head.

Formulation

  1. Linear Projections (Per Head)

     $$Q_i = X W_i^Q, \quad K_i = X W_i^K, \quad V_i = X W_i^V$$

     where:

     • $X \in \mathbb{R}^{n \times d_{model}}$ is the input sequence,
     • $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{model} \times d_k}$ are head-specific projections.
  2. Scaled Dot-Product Attention (Per Head)

     $$\text{head}_i = \text{softmax}\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i$$

  3. Concatenation and Output Projection

     $$\text{MHA}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O$$

     where $W^O \in \mathbb{R}^{d_{model} \times d_{model}}$ is the output projection matrix.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # project x to Q, K, V
        Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)
        # reshape to (bsz, num_heads, seq_len, d_k)
        Q, K, V = (
            rearrange(Q, 'b t (h d) -> b h t d', h=self.num_heads),
            rearrange(K, 'b t (h d) -> b h t d', h=self.num_heads),
            rearrange(V, 'b t (h d) -> b h t d', h=self.num_heads),
        )

        # scaled dot-product attention
        attn = torch.einsum('bhtd,bhTd->bhtT', Q, K) / self.d_k ** 0.5
        attn = F.softmax(attn, dim=-1)
        output = torch.einsum('bhtT,bhTd->bhtd', attn, V)

        # merge heads & project
        output = rearrange(output, 'b h t d -> b t (h d)')
        return self.W_o(output)
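
As a quick sanity check, the module maps a (batch, seq_len, d_model) input to an output of the same shape. The sizes below (d_model=512, num_heads=8, a batch of 2 with 16 tokens) are illustrative choices, not values from any particular model:

# minimal shape check with illustrative sizes
x = torch.randn(2, 16, 512)                        # (batch, seq_len, d_model)
mha = MultiHeadAttention(d_model=512, num_heads=8)
print(mha(x).shape)                                # torch.Size([2, 16, 512])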

Multi-Query Attention (MQA)

MQA[2] optimizes MHA by sharing keys and values across all heads.

Formulation

  1. Query Projection (Per Head)

     $$Q_i = X W_i^Q$$

  2. Shared Key and Value Projections

     $$K = X W^K, \quad V = X W^V$$

     where:

     • $W^K, W^V \in \mathbb{R}^{d_{model} \times d_k}$ are shared across all heads.
  3. Attention Computation

     $$\text{head}_i = \text{softmax}\left(\frac{Q_i K^\top}{\sqrt{d_k}}\right) V$$

  4. Concatenation and Output Projection

     $$\text{MQA}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O$$

Key Difference from MHA: All heads share the same $K$ and $V$, reducing memory usage.

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, self.d_k, bias=False)  # shared K
        self.W_v = nn.Linear(d_model, self.d_k, bias=False)  # shared V
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # project x to Q and the shared K, V
        Q = self.W_q(x)
        K, V = self.W_k(x), self.W_v(x)  # shared across all heads
        # reshape queries into heads; K/V stay single-headed (shared)
        Q = rearrange(Q, 'b t (h d) -> b h t d', h=self.num_heads)

        # scaled dot-product attention against the shared K/V
        attn = torch.einsum('bhtd,bTd->bhtT', Q, K) / self.d_k ** 0.5
        attn = F.softmax(attn, dim=-1)
        output = torch.einsum('bhtT,bTd->bhtd', attn, V)

        # merge heads & project
        output = rearrange(output, 'b h t d -> b t (h d)')
        return self.W_o(output)
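
During autoregressive decoding only K and V are cached, so sharing them across heads shrinks the cache by a factor of num_heads. A rough sketch with assumed sizes (d_model=512, num_heads=8, so d_k=64; the numbers are illustrative, not from the papers):

# illustrative per-token KV-cache comparison (values cached per layer)
d_model, num_heads = 512, 8
d_k = d_model // num_heads
mha_kv = 2 * num_heads * d_k   # MHA: K and V for every head -> 1024 values
mqa_kv = 2 * d_k               # MQA: one shared K and V head -> 128 values
print(mha_kv / mqa_kv)         # 8.0, i.e. num_heads times smaller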

Grouped-Query Attention (GQA)

GQA[3] is a middle ground between MHA and MQA: queries are grouped, with each group sharing key-value projections.

Formulation

  1. Query Projection (Per Head)

     $$Q_i = X W_i^Q$$

  2. Grouped Key-Value Projections

     • Queries are divided into $G$ groups ($G < h$), each with its own shared key-value pair:

       $$K_g = X W_g^K, \quad V_g = X W_g^V$$

       where $g(i) = \lceil i \, G / h \rceil$ maps head $i$ to its group.
  3. Attention Computation (Per Group)

     $$\text{head}_i = \text{softmax}\left(\frac{Q_i K_{g(i)}^\top}{\sqrt{d_k}}\right) V_{g(i)}$$

  4. Concatenation and Output Projection

     $$\text{GQA}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O$$

Key Difference from MQA: Each group of heads has its own key-value pairs, offering better expressivity than fully shared MQA.

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_groups):
        super().__init__()
        assert d_model % num_heads == 0
        assert num_heads % num_groups == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.num_groups = num_groups

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, num_groups * self.d_k, bias=False)  # one K head per group
        self.W_v = nn.Linear(d_model, num_groups * self.d_k, bias=False)  # one V head per group
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # project x to Q, K, V
        Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)

        # reshape Q into heads, K/V into groups
        Q = rearrange(Q, 'b t (h d) -> b h t d', h=self.num_heads)
        K = rearrange(K, 'b t (g d) -> b g t d', g=self.num_groups)
        V = rearrange(V, 'b t (g d) -> b g t d', g=self.num_groups)

        # expand groups so each head attends to its group's K/V
        heads_per_group = self.num_heads // self.num_groups
        K = K.repeat_interleave(heads_per_group, dim=1)
        V = V.repeat_interleave(heads_per_group, dim=1)

        # scaled dot-product attention
        attn = torch.einsum('bhtd,bhTd->bhtT', Q, K) / self.d_k ** 0.5
        attn = F.softmax(attn, dim=-1)
        output = torch.einsum('bhtT,bhTd->bhtd', attn, V)

        # merge heads & project
        output = rearrange(output, 'b h t d -> b t (h d)')
        return self.W_o(output)
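
A minimal usage sketch, assuming 8 query heads sharing 2 key-value groups (illustrative sizes, not taken from the papers):

# hypothetical configuration: 8 query heads, 2 KV groups
x = torch.randn(2, 16, 512)
gqa = GroupedQueryAttention(d_model=512, num_heads=8, num_groups=2)
print(gqa(x).shape)   # torch.Size([2, 16, 512])
# per token, the KV cache holds 2 * num_groups * d_k values instead of 2 * num_heads * d_k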

Multi-Head Latent Attention (MLA)

MLA[4] compresses queries, keys, and values using low-rank projections, reducing KV cache size and activation memory. For simplicity, we omit the RoPE embeddings.

Figure: MLA architecture (source: DeepSeek-V2)

Formulation

  1. Latent Compression for Keys and Values

     $$C^{KV} = X W^{DKV}, \quad W^{DKV} \in \mathbb{R}^{d_{model} \times d_c}$$

     where $d_c \ll d_{model}$ is the compressed KV dimension.

  2. Key-Value Reconstruction

     $$K = C^{KV} W^{UK}, \quad V = C^{KV} W^{UV}$$

     where $W^{UK}, W^{UV} \in \mathbb{R}^{d_c \times d_{model}}$ are up-projection matrices.

  3. Latent Compression for Queries

     $$C^{Q} = X W^{DQ}, \quad W^{DQ} \in \mathbb{R}^{d_{model} \times d_c'}$$

     where $d_c' \ll d_{model}$ is the compressed query dimension.

  4. Query Reconstruction

     $$Q = C^{Q} W^{UQ}$$

     where $W^{UQ} \in \mathbb{R}^{d_c' \times d_{model}}$ is an up-projection matrix.

  5. Attention Computation

     After splitting $Q$, $K$, $V$ into heads as in MHA:

     $$\text{head}_i = \text{softmax}\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i$$

  6. Concatenation and Output Projection

     $$\text{MLA}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O$$
Key Difference: MLA stores compressed latent low-rank representations in the KV cache, reducing memory footprint while maintaining expressivity.

# Here, we omit the branch of RoPE position embeddings.

class MultiHeadLatentAttention(nn.Module):
    def __init__(self, d_model, num_heads, d_c, d_c1):
        """
        Args:
            d_c (int): compression dimension for K/V.
            d_c1 (int): compression dimension for Q.
        """
        super().__init__()
        assert d_model % num_heads == 0
        assert d_c < d_model and d_c1 < d_model
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # head dim
        self.d_c = d_c    # compressed dim for K/V
        self.d_c1 = d_c1  # compressed dim for Q

        # compression (down-projection) layers
        self.W_dkv = nn.Linear(d_model, d_c, bias=False)  # compress KV
        self.W_dq = nn.Linear(d_model, d_c1, bias=False)  # compress Q

        # decompression (up-projection) layers
        self.W_uk = nn.Linear(d_c, d_model, bias=False)   # expand K
        self.W_uv = nn.Linear(d_c, d_model, bias=False)   # expand V
        self.W_uq = nn.Linear(d_c1, d_model, bias=False)  # expand Q
        # output projection
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # latent compression
        C_kv = self.W_dkv(x)  # compressed KV latent (this is what the KV cache stores)
        C_q = self.W_dq(x)    # compressed Q latent

        # reconstruct full-dimensional Q/K/V
        Q = self.W_uq(C_q)
        K = self.W_uk(C_kv)
        V = self.W_uv(C_kv)

        # reshape into multiple heads
        Q = rearrange(Q, 'b t (h d) -> b h t d', h=self.num_heads)
        K = rearrange(K, 'b t (h d) -> b h t d', h=self.num_heads)
        V = rearrange(V, 'b t (h d) -> b h t d', h=self.num_heads)

        # scaled dot-product attention
        attn = torch.einsum('bhtd,bhTd->bhtT', Q, K) / self.d_k ** 0.5
        attn = F.softmax(attn, dim=-1)
        output = torch.einsum('bhtT,bhTd->bhtd', attn, V)

        # merge heads & project
        output = rearrange(output, 'b h t d -> b t (h d)')
        return self.W_o(output)
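
A minimal sketch of how the latent sizes play out, with assumed dimensions (d_model=512, num_heads=8, d_c=64, d_c1=96; illustrative only). At inference, only the d_c-dimensional C_kv latent would need to be cached per token, rather than the full 2 * d_model keys and values:

# illustrative sizes; only the compressed C_kv latent would be cached at inference
x = torch.randn(2, 16, 512)
mla = MultiHeadLatentAttention(d_model=512, num_heads=8, d_c=64, d_c1=96)
print(mla(x).shape)         # torch.Size([2, 16, 512])
print(mla.W_dkv(x).shape)   # torch.Size([2, 16, 64]) -> cached latent per token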

Comparison

| Attention Type | Query Projection | Key Projection | Value Projection | KV Cache Size | Expressivity | Use Case |
| --- | --- | --- | --- | --- | --- | --- |
| MHA (Multi-Head) | Per head | Per head | Per head | High | High | General transformer models |
| MQA (Multi-Query) | Per head | Single | Single | Low | Low | Large decoder models (e.g., LLMs) |
| GQA (Grouped-Query) | Per head | Per group | Per group | Medium | Medium | Balanced trade-off for large models |
| MLA (Multi-Head Latent) | Compressed | Compressed | Compressed | Very low | Moderate | Efficient KV caching for long-sequence processing |
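
To make the KV-cache column concrete, here is a back-of-the-envelope sketch of per-token, per-layer cache size in fp16. The dimensions (d_model=4096, 32 heads, 8 groups, d_c=512) are assumptions for illustration, not the configuration of any specific model:

# rough per-token, per-layer KV-cache footprint in fp16 (illustrative dimensions)
d_model, num_heads, num_groups, d_c = 4096, 32, 8, 512
d_k = d_model // num_heads
bytes_per_value = 2  # fp16
print("MHA:", 2 * num_heads * d_k * bytes_per_value, "bytes")    # 16384
print("MQA:", 2 * d_k * bytes_per_value, "bytes")                # 512
print("GQA:", 2 * num_groups * d_k * bytes_per_value, "bytes")   # 4096
print("MLA:", d_c * bytes_per_value, "bytes")                    # 1024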

References


  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30.
  2. Shazeer, N. (2019). Fast transformer decoding: One write-head is all you need. arXiv preprint arXiv:1911.02150.
  3. Ainslie, J., et al. (2023). GQA: Training generalized multi-query transformer models from multi-head checkpoints. arXiv preprint arXiv:2305.13245.
  4. Liu, A., et al. (2024). DeepSeek-V2: A strong, economical, and efficient mixture-of-experts language model. arXiv preprint arXiv:2405.04434.