Multi-Head Attention Dimension
对于每一个“注意力头”来说,Q、K、V 的维度通常远小于 Embedding 的维度。 通常遵循的公式是:
其中:
- 是 Embedding 的维度(比如 512)。
- 是“头”的数量(Heads,比如 8)。
- 是每个头内部 Q、K、V 的维度(比如 64)。
1. 为什么要变小?(为了“多头”)
你可以把 Embedding 想象成一份包含全方位信息的档案。如果只有一个巨大的注意力机制(一个大头)来处理整个 512 维的向量,它可能很难同时关注到所有细节。
Transformer 的做法是:把大任务拆给多个“专家小组”。
- Embedding (512维): 原始档案。
- 头 1 (64维): 负责关注“语法关系”(比如主谓一致)。
- 头 2 (64维): 负责关注“代词指代”(比如 it 指的是谁)。
- 头 3 (64维): 负责关注“时态信息”。
- …
- 头 8 (64维): 负责关注其他特征。
这就是为什么单个 Q、K、V 的维度要变小。因为我们将原始的“宽”向量切分到了不同的子空间里去并行处理。
2. 经典模型中的实际数字
业界最著名的模型是怎么设置的:
| 模型 | Embedding 维度 (dmodel) | 头数 (h) | 单个头的 QKV 维度 (dk) | 计算关系 |
|---|---|---|---|---|
| Transformer (论文原版) | 512 | 8 | 64 | |
| BERT-Base | 768 | 12 | 64 | |
| GPT-3 (175B) | 12288 | 96 | 128 |
你可以看到,无论模型多大,单个头的维度通常保持在 64 或 128 左右,不会特别大。
3. 最后会变回来吗?(拼接)
这也是 Transformer 巧妙的地方。虽然中间变小了,但在这一层结束时,它们会合体。
- 拆分计算: 8 个头分别计算,每个头输出一个 64 维的 向量。
- 拼接 (Concat): 把这 8 个 64 维的向量首尾相接拼起来。
- 融合 (Linear): 最后再乘一个巨大的输出矩阵 ,让不同头的信息混合一下,输出结果依然是 512 维。
结论:
- 局部看: Q、K、V 变小了(为了专注特定特征)。
- 整体看: 所有头的维度加起来,通常正好等于 Embedding 的维度(为了保证输入输出大小一致,方便做残差连接 ResNet)。