Multi-head Attention Detail Example
单头的 Self-Attention 就像是你自己在读一本书,一次只能关注一个重点(比如语法)。而 Multi-Head Attention 就像是请了 8 个专家同时读这句话:
- 专家 A (Head 1) 专门负责看 “谁是主语”。
- 专家 B (Head 2) 专门负责看 “发生动作的时间”。
- 专家 C (Head 3) 专门负责看 “代词指的是谁”。
- …
最后,把大家看到的信息汇总起来,理解就更全面。
具体的数字案例:双头行动
为了简化计算,我们设定一个极简场景:
- 输入单词: “AI” (向量 )
- Embedding 维度 (): 4 (假设向量是
[1.0, 2.0, 3.0, 4.0]) - 头数 (Heads): 2 (我们要把它劈成两个头)
- 每个头的维度 ():
我们来看看数据是怎么“分家”又“团聚”的。
第一步:分头行动 (Linear Projections)
我们不再只用一套 ,而是准备 两套 不同的权重矩阵。Head 1 的任务(假设它关注前半部分特征):它有自己的 。
假设 是一个 的矩阵(把 4 维压到 2 维)。
(Head 1 提取出了前两个特征)
Head 2 的任务(假设它关注后半部分特征):它有自己的 。
(Head 2 提取出了后两个特征)
第二步:各自为战 (Scaled Dot-Product Attention)
现在,两个头在平行的宇宙里分别做你刚才学会的那套 Self-Attention 流程。
Head 1 的宇宙:
它拿着 去和它的 算点积、Softmax、乘 。假设经过一系列计算,Head 1 认为“AI”这个词在语法上很重要,得出的结论向量是:
Head 2 的宇宙:
它拿着 去和它的 算。因为它的关注点不同(比如它关注语义),权重分配完全不同。假设它得出的结论向量是:
第三步:破镜重圆 (Concatenation)
现在两个专家都算完了,我们需要把结果拼起来。操作非常简单,就是直接 首尾相接 (Concat)。
注意: 此时向量的维度变回了 4 (2 + 2),和最开始的 Embedding 维度一样了!
第四步:最后融合 (Linear Output)
虽然拼起来了,但 部分只懂语法, 部分只懂语义,它们之间还没“交流”过。所以,最后一步是乘以一个巨大的输出权重矩阵 (Output Weights)。
假设 也是 的矩阵:
这一步相当于把“语法专家”和“语义专家”的意见进行了一次加权混合,生成了一个既包含语法又包含语义的最终向量。
总结:为什么要这么折腾?
你可能会问:“直接用一个 4 维的大头算,和分成两个 2 维的小头算,有啥区别?”
具体的数字区别:
-
单头 (Single Head):
如果只有一个头,Softmax 只有一次机会。 假设句子是 “I gave the dog a bone because it was hungry”。 在算 “it” 的时候,单头注意力可能发现 “dog” 很重要,权重大。 那 “bone” 呢?“hungry” 呢?
因为 Softmax 归一化了,如果你把 80% 的注意力给了 “dog”,剩下能给 “hungry” 的就很少了。你很难在一次计算中同时关注多个不同的重点。
-
多头 (Multi-Head):
- Head 1: 此时 Softmax 高度关注 “dog” (因为它负责找指代对象)。
- Head 2: 此时 Softmax 高度关注 “hungry” (因为它负责找原因)。
- Head 3: 此时 Softmax 高度关注 “because” (因为它负责找逻辑词)。
结论:
多头注意力允许模型在不同的子空间 (Subspaces) 里,学习到不同的关注模式。它打破了“一次只能看一个重点”的限制,让模型变得更加“博学”和“多维”。
Multi-Head Attention Dimension
对于每一个“注意力头”来说,Q、K、V 的维度通常远小于 Embedding 的维度。 通常遵循的公式是:
其中:
- 是 Embedding 的维度(比如 512)。
- 是“头”的数量(Heads,比如 8)。
- 是每个头内部 Q、K、V 的维度(比如 64)。
1. 为什么要变小?(为了“多头”)
你可以把 Embedding 想象成一份包含全方位信息的档案。如果只有一个巨大的注意力机制(一个大头)来处理整个 512 维的向量,它可能很难同时关注到所有细节。
Transformer 的做法是:把大任务拆给多个“专家小组”。
- Embedding (512维): 原始档案。
- 头 1 (64维): 负责关注“语法关系”(比如主谓一致)。
- 头 2 (64维): 负责关注“代词指代”(比如 it 指的是谁)。
- 头 3 (64维): 负责关注“时态信息”。
- …
- 头 8 (64维): 负责关注其他特征。
这就是为什么单个 Q、K、V 的维度要变小。因为我们将原始的“宽”向量切分到了不同的子空间里去并行处理。
2. 经典模型中的实际数字
业界最著名的模型是怎么设置的:
| 模型 | Embedding 维度 (dmodel) | 头数 (h) | 单个头的 QKV 维度 (dk) | 计算关系 |
|---|---|---|---|---|
| Transformer (论文原版) | 512 | 8 | 64 | |
| BERT-Base | 768 | 12 | 64 | |
| GPT-3 (175B) | 12288 | 96 | 128 |
你可以看到,无论模型多大,单个头的维度通常保持在 64 或 128 左右,不会特别大。
3. 最后会变回来吗?(拼接)
这也是 Transformer 巧妙的地方。虽然中间变小了,但在这一层结束时,它们会合体。
- 拆分计算: 8 个头分别计算,每个头输出一个 64 维的 向量。
- 拼接 (Concat): 把这 8 个 64 维的向量首尾相接拼起来。
- 融合 (Linear): 最后再乘一个巨大的输出矩阵 ,让不同头的信息混合一下,输出结果依然是 512 维。
结论:
- 局部看: Q、K、V 变小了(为了专注特定特征)。
- 整体看: 所有头的维度加起来,通常正好等于 Embedding 的维度(为了保证输入输出大小一致,方便做残差连接 ResNet)。