Multi-Head Attention,简称 MHA,是 Transformer 中最经典的 attention 结构。它把 hidden state 投影到多个 attention heads,让模型在多个子空间中并行计算信息路由,再把各个 head 的输出拼接并融合。相比单头 attention,MHA 允许不同 head 学习不同类型的 token 关系。

为什么需要多个 Head

单个 attention head 只能用一组 投影计算相关性。语言和代码中的关系却有很多种:局部语法、长距离依赖、实体指代、格式标记、括号匹配、对话轮次等。如果只用一个投影空间,所有关系都要压到同一组 attention pattern 中。

MHA 的直觉是:

把表示分成多个子空间,让不同 head 并行学习不同的信息路由方式。

这并不意味着每个 head 都有清晰的人类语义标签,但多头结构为模型提供了更丰富的关系建模能力。

结构与形状

给定 hidden states:

设 head 数量为 ,每个 head 的维度为 ,通常:

MHA 先投影:

然后 reshape 为多头形式:

每个 head 独立计算 attention:

最后拼接所有 head,并通过输出投影融合:

其中 把拼接后的表示重新映射回

Output Projection 的作用

如果只是拼接多个 head,每个 head 的信息仍然相对分离。输出投影 负责在 head 之间重新混合信息,让模型学习不同 head 输出如何组合。

因此 MHA 不只是“多个 attention 并排算完”:它包含 head-specific attention 和跨 head 的线性融合。

与 Self-Attention 的关系

Self-Attention 描述的是 Q/K/V 来自同一序列;Multi-Head Attention 描述的是 attention 的多头组织方式。二者可以组合:

  • encoder self-attention 可以是 multi-head;
  • decoder causal self-attention 可以是 multi-head;
  • encoder-decoder cross-attention 也可以是 multi-head。

在 decoder-only LLM 中,常见的是 causal multi-head self-attention。

参数量与计算

标准 MHA 的 Q/K/V/O 投影参数量大致为:

其中 Q、K、V 三个投影各约 ,输出投影 。在固定 下,增加 head 数通常会减少每个 head 的 ,不一定显著改变投影参数总量;但 head 数会影响 attention pattern、KV Cache 组织和 kernel 实现。

标准 MHA 的 attention score 计算仍然与序列长度二次相关:

KV Cache 成本

自回归推理时,每层需要缓存历史 token 的 K/V。标准 MHA 中,每个 query head 都有自己的 K/V head,因此 KV Cache 大小随 增长:

其中:

  • 表示 key 和 value;
  • 是层数;
  • 是上下文长度;
  • 是 heads 数;
  • 是 head dimension。

由于 ,标准 MHA 的每 token KV 表示维度接近 。长上下文和大 batch 下,这会成为 serving 的核心瓶颈。

这也是 MQAGQAMLA 出现的重要原因:它们都试图降低 K/V 表示或缓存成本。

Head 的可解释性边界

一些 head 可能呈现可解释模式,例如关注分隔符、局部邻居、复制位置或实体引用。但不能把每个 head 都理解成独立、稳定、可命名的“功能模块”。原因包括:

  • 多个 head 可能冗余;
  • head 功能可能随层数变化;
  • FFN 和 residual 会重组 head 输出;
  • 相同任务能力可能分布在多个 head 和多层之间。

因此,head visualization 可以作为诊断工具,但不应直接等同于模型解释。

设计取舍

设计点优势代价
多个 query heads多子空间建模关系KV Cache 和实现复杂度增加
较大 head dimension单 head 表达更强head 数减少或总宽度增加
较多 headsattention pattern 更丰富每 head 维度过小时可能受限
输出投影融合不同 head 信息增加参数和计算
MHA表达能力强,结构标准推理 KV Cache 成本最高

与 MQA / GQA / MLA 的关系

MHA 是这些变体的基准:

  • MQA:保留多个 Q heads,但所有 heads 共享一组 K/V。
  • GQA:多个 Q heads 分组共享 K/V。
  • MLA:把 K/V 压缩到 latent representation,再恢复用于 attention。

这些方法主要针对推理时 KV Cache 和 memory bandwidth 成本,而不是否定 MHA 的表达能力。

常见误解

  • 误解:head 越多一定越好。 head 数、head dimension、模型宽度和训练规模要一起考虑。
  • 误解:每个 head 都有固定语义。 head 可以有可解释倾向,但功能通常是分布式的。
  • 误解:MHA 只影响训练。 推理时 MHA 直接决定 KV Cache 规模。
  • 误解:MQA/GQA 只是实现优化。 它们改变了 K/V 参数共享方式,也可能影响表达能力。

相关概念