📚 本文是《Transformer 由浅入深》系列第 4 篇 · 阶段二「核心机制」
上一篇结尾我们留了个疑问:一组 \( W_Q / W_K / W_V \) 只能学到一种“什么算相关”。可语言里的关系五花八门——有句法的(谁是主语)、有指代的(“它"指谁)、有语义的(近义、搭配)……指望一种视角全包,太勉强了。
这一篇的主角 多头注意力(Multi-Head Attention) 就是来解决这件事的。它的思路一句话:与其用一个大注意力看全部,不如拆成好几个小注意力,各看一个角度,最后汇总。
打个最直观的比方:你让一个班级讨论一篇课文。如果只叫一个学生发言,他再聪明,关注点也总是有限的——可能他特别在意修辞,却忽略了逻辑结构。但如果你让几个各有所长的学生同时读、再把发言汇总,你得到的理解会立体得多。多头注意力就是把"一个学生"换成"一个学习小组”。
本篇目标
读完你应该能:
- 说清多头的动机和计算流程;
- 算明白"维度怎么对得上"——为什么多头几乎不增加计算量;
- 知道不同的"头"大致在学什么。
阅读前提
第 3 篇的单头注意力公式 \( \text{Attention}(Q,K,V) = \text{softmax}(QK^{\top}/\sqrt{d_k})V \)。
1. 单头的局限
回想单头注意力:输入经过一组 \( W_Q/W_K/W_V \),投影到一个 \( d_k \) 维空间里去比相关性。问题是,这一组矩阵一旦训练好,它衡量"相关"的标准就固定了。
打个比方:这就像只让一个人去读整句话,而这个人只擅长一种分析(比如只看语法)。他能给出一种解读,但"这个词指代谁"“这个词和哪个词是固定搭配"这些别的角度,就被这单一视角压扁、混在了一起,损失了信息。
再换一个镜头的比喻可能更贴切:单头就像你只戴一副滤镜去看照片。一副偏色的滤镜也许能突出冷暖,但同一张照片里的"轮廓"“纹理"“光影"这些维度,被这副滤镜一压,就要么被强化、要么被抹平,你没法同时看清。摄影师真正想要的,是手里有一叠滤镜,需要看什么就叠什么。
考虑这样一句话:“小明把书还给了图书馆,因为它到期了。“要正确理解"它”,模型至少要同时处理几件事:从句法上知道"它"是个代词、需要找先行词;从语义上排除"图书馆”(图书馆不会"到期”)、锁定"书”;从位置上知道"它"和"书"隔了一段距离但仍相关。这三件事如果挤在同一个 \( d_k \) 维空间里争抢,彼此就会互相干扰。我们真正想要的,是一组各有所长的分析师同时上,每人专心干一件事,最后再碰头。
更进一步说,单头的问题不只是"视角少”,还在于冲突。同一组投影矩阵被要求既擅长抓"主谓宾"、又擅长抓"近义搭配",这两种目标在训练时会互相拉扯:为了把语法学好,某些维度需要朝一个方向调;为了把语义学好,同一批维度又被拽向另一个方向。最后的结果往往是两头都学了个半吊子——这就是所谓的"表征瓶颈"。多头的高明之处,正是把这场拔河拆成了几场互不打架的独立比赛:每个头只对自己那一件事负责,梯度信号干净、目标明确,自然学得更专、更深。
2. 多头的核心思路
多头的做法非常直接:
把 \( d \) 维空间切成 \( h \) 份,每一份叫一个"头"(head)。每个头有自己独立的 \( W_Q/W_K/W_V \),在自己那 \( d_k = d/h \) 维的小子空间里,独立地做一遍完整的注意力。最后把 \( h \) 个头的结果拼起来。
\( h \) 个头 = \( h \) 个互不干扰的视角,并行学习。有的头可以专门盯语法,有的盯指代,有的盯搭配——各司其职,互不挤占。
这里有个容易被忽略的精妙之处:每个头之所以能学到不同的东西,根源就在于它们各自拥有独立的、随机初始化的 \( W_Q/W_K/W_V \)。训练一开始,这几组矩阵长得都不一样;在梯度下降的过程中,它们会自发地"分工"——因为如果两个头学得一模一样,那就是浪费,损失函数没有任何理由奖励这种冗余,反而会推着它们朝不同方向走。这有点像一个团队招人:如果两个人能力完全重叠,老板自然会希望其中一个去补别的短板。分工不是我们手写规则强加的,而是优化过程"逼"出来的。
也可以换个更生活化的比方:把一句话想成一桌菜,多头就是请来几位口味不同的食客同时品尝。一位专挑咸淡(对应某种语义关系),一位专评火候(对应句法结构),一位专看摆盘(对应位置邻接)。如果你只请一位"全能美食家",他给的评价必然是个含糊的综合分,你听不出到底哪里好哪里差;但几位专精食客各打各的分,你拿到的就是一张细分到各个维度的"评测表"。后面的 \( W_O \) 再把这张表汇总成一句"这道菜值不值得点",信息既丰富又清晰。这正是"拆分—专精—汇总"这套流程的价值所在。
3. 计算流程拆解
写成公式。对第 \( i \) 个头(\( i = 1, \dots, h \)):
$$ \text{head}_i = \text{Attention}(Q W_Q^{i},\ K W_K^{i},\ V W_V^{i}) $$其中每个头的投影矩阵把 \( d \) 维压到 \( d_k = d/h \) 维:\( W_Q^{i}, W_K^{i}, W_V^{i} \in \mathbb{R}^{d \times d_k} \)。
每个 \( \text{head}_i \) 的形状是 \( n \times d_k \)。把 \( h \) 个头沿特征维拼接(concat),正好拼回 \( n \times d \):
$$ \text{Concat}(\text{head}_1, \dots, \text{head}_h) \in \mathbb{R}^{n \times d} $$最后,再过一个输出投影 \( W_O \in \mathbb{R}^{d \times d} \),把各头的结果融合成最终输出:
$$ \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W_O $$这个 \( W_O \) 很重要:拼接只是把各头结果排在一起,\( W_O \) 才让它们互相交流、整合成一个统一的表示。
为什么单靠 Concat 还不够?可以这样想:拼接好比把几位专家的报告装订成一本册子——纸是放在一起了,但每页还是各说各话,后面那一层(下一个注意力块、前馈网络)直接读这本"拼盘"会很别扭,因为它不知道"第 164 维是指代头说的、第 65128 维是语法头说的"。\( W_O \) 干的事,就是当那个主编:它把各头的结论按权重重新组合、提炼,写成一份前后连贯的综述。具体到数学上,\( W_O \) 的每一行都会从所有头的输出里取一个加权组合,于是"指代头发现的信息"可以和"语法头发现的信息"在这一步真正混合起来,而不是各占一段、老死不相往来。没有 \( W_O \),多头就只是 \( h \) 个并行的小注意力硬凑在一起,缺了最后那道"把大家的意见揉成一个意见"的工序。
整个数据流的形状变化:
| 阶段 | 形状 |
|---|---|
| 输入 | \( n \times d \) |
| 每个头的 Q/K/V | \( n \times d_k \)(共 \( h \) 组) |
| 每个头的输出 | \( n \times d_k \) |
| 拼接后 | \( n \times d \) |
| 过 \( W_O \) 后(最终输出) | \( n \times d \) |
输入输出形状一致,这很关键——意味着多头注意力可以像积木一样反复堆叠(第 6 篇会用到)。
4. 维度账本:参数量没爆炸
一个常见的误解是:“开了 \( h \) 个头,计算量是不是翻了 \( h \) 倍?”
并没有。 关键在 \( d_k = d/h \) 这个设定。单头时,投影在 \( d \times d \) 的尺度上;多头时,每个头只在 \( d \times (d/h) \) 上,\( h \) 个头加起来:
$$ h \times \big(d \times \tfrac{d}{h}\big) = d \times d $$正好和单头一样大。换句话说,多头是把同样大小的"注意力预算",切成几份分给不同视角,而不是凭空多花算力。这是个非常划算的设计:近乎免费地换来了多视角能力。
这里的直觉,可以用"分蛋糕"来理解。假设你总共只有一块固定大小的蛋糕(总维度 \( d \)),你可以一个人独吞(单头,一个 \( d \) 维空间),也可以切成 12 小块分给 12 个人(12 个头,每个 64 维)。蛋糕的总量没变,变的只是怎么分。单头那个人占着整块蛋糕,但他只有一张嘴、只能尝出一种味道;切成小块后,12 个人各尝一块,你反而收集到了 12 种风味的反馈。总成本不变(同一块蛋糕、同样多的参数),收益却从"一种视角"变成了"多种视角"——这就是多头几乎"白嫖"来的好处。
也许你会追问:既然总维度没变,凭什么"切开"就比"不切"强?毕竟数据还是那些数据。关键在于 softmax 是按整个向量来归一化的。单头时,一次 softmax 要在整个 \( d \) 维上同时权衡所有线索,最终只能输出一套注意力权重;它要么重点关注"前一个词",要么重点关注"指代对象",很难两者都顾全——因为这一套权重是大家共用的。切成多头后,每个头有独立的一次 softmax,于是可以各自给出一套互不妥协的权重:邻座头放心地把火力全压在相邻词上,指代头放心地把火力全压在先行词上,谁也不用迁就谁。所以"切开"换来的不是数据变多,而是决策的自由度变多——这正是表达力的来源。
需要稍微补一句严谨的:说"计算量几乎不变",指的是投影和注意力打分这两块的主体开销与单头同量级;真正额外多出来的,只有那个 \( W_O \)(一个 \( d \times d \) 的矩阵)。但相对整个网络,它微不足道,所以工程上我们就当多头是"免费午餐"。
举个具体数字:GPT 类模型常见 \( d = 768 \)、\( h = 12 \),于是每个头 \( d_k = 64 \)。12 个 64 维的头,合起来还是 768 维。
5. 不同头学到了什么(可视化)
训练好的模型,把各个头的注意力热力图画出来,常能看到分工(下面是示意,行=当前词,列=被关注词,█ 越深关注越多):
头 A:关注「前一个词」 头 B:关注「指代对象」
the cat sat on The cat ... it
the [ ░ ░ ░ ░ ] The [ █ ░ ░ ░ ]
cat [ █ ░ ░ ░ ] cat [ ░ █ ░ ░ ]
sat [ ░ █ ░ ░ ] ... [ ░ ░ █ ░ ]
on [ ░ ░ █ ░ ] it [ ░ █ ░ ░ ] ← it 看向 cat
研究者发现真实模型里确实存在这类"专家头":有的稳定地看相邻词,有的负责句法依存(动词找它的宾语),有的做指代消解。当然也有不少头作用模糊、甚至可被剪掉——但整体而言,多头让模型能同时持有多种关于句子的假设,这正是它表达力的来源之一。
打个收尾的比方:多头就像给模型配了一套不同焦段的镜头。广角头负责把整句话的轮廓收进画面(全局语境),长焦头负责死死锁定远处某个特定的词(长距离指代),微距头负责看清紧挨着的两个词之间的细节(局部搭配)。拍同一个场景,摄影师不会只带一只镜头;模型理解同一句话,也不该只用一种视角。多头注意力做的,就是把这一整套镜头同时架好,让模型"咔嚓"一下,把各个层次的信息一次性都拍下来,再交给 \( W_O \) 去后期合成一张信息量最大的成片。
为了让"不同头学到不同东西"更有画面感,我们不妨把几类常见的头拟人化一下:
- “邻座头”:几乎只盯着自己左边或右边的那个词,像个总忍不住偷看同桌答案的学生。它在还原局部语序、捕捉固定搭配(“纽约"“人工智能”)时特别有用。
- “句法头”:专门把动词连到它的宾语、把介词连到它的宾语,像语文老师在句子上画成分结构图。
- “指代头”:看到代词(“它"“他"“这”)就去全句里找它到底指谁,像侦探顺着线索回溯。
- “全局头”:注意力摊得很平,谁都看一点,像个负责"通读全文、把握大意"的组长,给整句话一个背景底色。
要强调的是,没有人给这些头贴标签、派任务。我们从不告诉模型"你是第 3 个头,你去学指代”。这些角色是训练完之后,研究者拿着放大镜(注意力可视化)反过来"发现"的。换句话说,分工是涌现出来的结果,而不是设计出来的指令——这恰恰是多头机制最迷人的地方之一。
6. 代码视角
实现上不会真的开 \( h \) 个 for 循环。常见做法是一次性投影到 \( d \) 维,再 reshape 成多头,用一次批量矩阵乘把所有头并行算完:
# x: (batch, seq, d)
Q = x @ W_Q # (batch, seq, d),一次性投影,后面再切头
K = x @ W_K
V = x @ W_V
# 切成 h 个头:(batch, seq, d) -> (batch, h, seq, d_k)
def split_heads(t, h):
b, s, d = t.shape
return t.reshape(b, s, h, d // h).transpose(0, 2, 1, 3)
Qh, Kh, Vh = split_heads(Q, h), split_heads(K, h), split_heads(V, h)
# 在最后两维 (seq, d_k) 上做缩放点积注意力,h 个头天然并行
scores = Qh @ Kh.transpose(0, 1, 3, 2) / (d_k ** 0.5)
A = softmax(scores) # (batch, h, seq, seq)
heads = A @ Vh # (batch, h, seq, d_k)
# 合并回去:(batch, h, seq, d_k) -> (batch, seq, d),再过 W_O
out = heads.transpose(0, 2, 1, 3).reshape(b, s, d) @ W_O
要点就两个:reshape 把特征维切成 (h, d_k),然后让矩阵乘在最后两维上对每个头并行计算。整段代码从头到尾没有一处显式的 for i in range(h),所谓"多头"完全靠张量维度的巧妙安排实现,这也是它能跑得又快又省的根本原因。
值得回味的是:为什么明明是"多个头”,代码里却看不到一个循环?因为在张量眼里,“头"不过是多出来的一个维度(那个 h)。GPU 最擅长的就是对一堆维度同时做同样的运算。所以"切成 12 个头"在硬件上几乎不带来额外的时间——它们是真正并排同时算的,而不是排队一个接一个算。这也再次印证了第 4 节的结论:多头在算力上是笔划算的买卖。
顺带提一句容易踩的实现坑:split_heads 里那一步 transpose(把头那一维换到序列维前面)不是可有可无的。如果省掉它,矩阵乘就会错误地把"不同头之间"也算进注意力里,等于让本该各管各的食客互相串味。正确的摆法是让每个头的 (seq, d_k) 落在张量的最后两维,这样批量矩阵乘才会对每个头各算各的、彼此井水不犯河水。理解了"头只是一个并行维度”,这些 reshape / transpose 的来龙去脉也就顺理成章了。
常见疑问与易错点
问:头越多越好吗?是不是头数越多模型就越强? 答:不是。在总维度 \( d \) 固定的前提下,头数 \( h \) 越多,每个头分到的 \( d_k = d/h \) 就越小。头太多会导致每个子空间维度过低,小到放不下有意义的"相关性度量”,反而学不好;头太少又回到了视角单一的老问题。所以 \( h \) 是个需要权衡的超参数,常见取值如 8、12、16,要和 \( d \) 配套着调,而不是一味往大开。
问:多头是不是比单头慢很多? 答:几乎不会。如第 4 节所说,因为 \( d_k = d/h \),所有头的投影与打分加起来和单头同量级,额外开销只有一个 \( d \times d \) 的 \( W_O \)。而且各头在 GPU 上是并行计算的(见第 6 节),不存在"头多就要排队"的问题。真正影响速度的是序列长度 \( n \)(注意力是 \( n^2 \) 量级),而非头数。
问:各个头之间是完全独立、互不相干的吗? 答:在"算注意力"这一步是独立的——每个头有自己的 \( W_Q/W_K/W_V \),各算各的。但到了最后过 \( W_O \) 时,所有头的输出被重新混合,信息在这里汇合了。所以更准确的说法是:计算时分头并行,融合时合而为一。 不能说它们从头到尾老死不相往来。
问:我能手动指定"第 3 个头去学指代"吗? 答:不能,也不需要。头的分工是训练过程中自发涌现的,我们既不给头贴标签,也无法强制某个头学某种关系。第 5 节里那些"指代头"“句法头"的名字,都是研究者事后通过可视化观察、反向命名的,而非预先设定的角色。
问:既然有些头作用模糊、还能被剪掉,那它们是不是没用、是冗余? 答:不能简单地这么下结论。确实有研究表明训练好的模型里部分头剪掉后性能损失不大,但这往往是训练完成之后的现象。在训练过程中,这种"冗余"提供了更大的搜索空间和容错余地,有助于模型找到好的解。可以类比团队里的"替补”:平时看着不起眼,但它们参与了"把阵容磨合好"的过程,不等于一开始就可以不要。
问:多头注意力里的"头",和神经网络的"层"是一回事吗? 答:不是,别混淆。头是同一层内部的并行拆分(横向:把一个注意力切成 \( h \) 份同时算),层是纵向的堆叠(一个多头注意力块叠在另一个之上,逐级加深)。一个典型模型可能有 12 层,每层各有 12 个头——它们是两个不同维度上的"多",一个管广度,一个管深度。
问:每个头算出来的注意力权重(那张热力图),会因为头不同而不一样吗? 答:会,而且这正是多头的意义所在。每个头有自己的 \( W_Q/W_K \),所以同一句话在不同头里算出的 \( QK^\top \) 打分矩阵是不同的——也就是说,“谁该关注谁"在每个头里都有一套自己的答案。第 5 节那两张热力图(邻座头 vs 指代头)就是同一句话在两个头里截然不同的注意力分布。正因如此,我们才说多头让模型能同时持有多套关于句子结构的假设。
问:既然各头独立、最后又被 \( W_O \) 混合,那 \( W_O \) 会不会把好不容易分出来的视角又"搅成一锅粥”? 答:不会,恰恰相反。\( W_O \) 不是简单地把各头平均掉,而是一个可学习的矩阵——训练会教它有选择地取用各头的信息:对当前任务有用的头权重大、噪声大的头权重小。所以它扮演的是"主编挑稿"的角色,而非"把所有稿子揉碎重写"。分出来的多视角不但没被毁掉,反而通过 \( W_O \) 被恰当地编排进了最终表示里。
问:\( d_k = d/h \) 是必须的硬性规定吗?如果不整除怎么办? 答:\( d_k = d/h \) 是最常见、最省事的约定,它保证拼接后正好回到 \( d \) 维、参数量与单头持平。但它并非数学上的铁律——原则上每个头的维度可以自由设定,拼接后再用 \( W_O \) 投影回 \( d \) 维即可。只不过工程上为了让 reshape 干净利落、维度对齐方便,我们通常直接要求 \( h \) 能整除 \( d \),于是 \( d_k = d/h \) 就成了默认配置。这也是为什么实践中 \( d \) 常取 \( 64 \) 的倍数(如 768、1024)。
小结 & 下一篇预告
这一篇:
- 单头只有一种"相关"标准,视角太窄;
- 多头 = 把 \( d \) 维切成 \( h \) 个 \( d_k=d/h \) 的子空间,各自独立做注意力,再 Concat + \( W_O \) 融合;
- 因为 \( d_k=d/h \),多头几乎不增加计算量,却换来多视角表达;
- 真实模型里能观察到分工明确的"专家头"。
但到此为止,我们的注意力还有一个致命盲区:它根本不在乎词的顺序!“猫追狗"和"狗追猫”,在纯注意力眼里几乎一样。下一篇,我们就来补上这块——位置编码,看模型怎么把"语序"重新塞回去。
📖 系列目录
见 第 1 篇 文末完整目录。