"Attention is all you need." — Vaswani et al., NeurIPS 2017
2017 年 6 月的某个深夜,Ashish Vaswani 在 Google Brain 提交了一篇 8 位作者的论文,题目是当时看起来挑衅意味十足的 《Attention Is All You Need》。论文里既没有循环(recurrence)也没有卷积(convolution), 只用一种叫 self-attention 的几何运算,在 WMT 2014 英德翻译上把 BLEU 推到 28.4,训练时间却只是当时最强 GNMT 的零头。
在那之前,整个 NLP 领域有一个稳固的共识——序列必须由循环网络处理。LSTM 是 1997 年的产物, seq2seq + attention(Bahdanau 2014, Luong 2015)经过 3 年打磨已经统治 NMT。突然之间, 一个本质上"非时序"的并行架构,把这个共识彻底颠覆。这场颠覆的余波席卷了之后整个 LLM 浪潮:BERT (2018)、GPT-2 (2019)、 GPT-3 (2020)、ChatGPT (2022)、GPT-4 (2023)——所有这些模型的核心积木,都是 2017 年那篇论文里定义的 Transformer block。
每一节都按这个三段式组织:(1) 直觉 / 历史动机 → (2) 数学公式与推导 → (3) PyTorch 代码 / 思考题。 读者不需要事先了解 attention,但需要熟悉矩阵乘法、softmax、链式法则与反向传播。
要理解 Transformer 为什么这样设计,必须先理解它替代的 RNN 为何失败。本节我们把 Pascanu et al. (2013) 那篇被引超过 6000 次的论文 "On the difficulty of training recurrent neural networks" 的数学结论搬到台面上, 精确说明梯度消失与梯度爆炸是怎么从链式法则的几何结构里诞生的。
一个 vanilla(朴素)RNN 在时刻 $t$ 的更新方程是:
其中 $\sigma$ 通常是 $\tanh$ 或 sigmoid,$\boldsymbol{W}_h, \boldsymbol{W}_x \in \mathbb{R}^{d\times d}$ 是所有时刻共享的权重矩阵。 "所有时刻共享"这五个字是接下来一切问题的根源。
假设我们要计算第 4 步的损失 $J^{(4)}(\theta)$ 对第 1 步隐藏状态 $\boldsymbol{h}^{(1)}$ 的梯度。这是个"反向传播穿越时间"(BPTT)问题:
由链式法则:
其中每一个 Jacobian 矩阵 $A_t = \partial \boldsymbol{h}^{(t)}/\partial \boldsymbol{h}^{(t-1)}$ 都长成同一个形式:
重点是:所有 $A_t$ 共用同一个 $\boldsymbol{W}_h$。如果时间窗口长度为 $T$,那么 $\partial J^{(T)}/\partial \boldsymbol{h}^{(1)}$ 就是 $T-1$ 个几乎相同矩阵的乘积。线性代数告诉我们:同一个矩阵自乘 $k$ 次,结果由其最大奇异值(或最大特征值,谱半径)主导。
为了把直觉变成定理,我们做一个简化:假设激活函数 $\sigma$ 是恒等函数(identity,$\sigma(z)=z$),所以 $\sigma'=1$。 那么 $A_t = \boldsymbol{W}_h$,且:
对 $\boldsymbol{W}_h$ 做特征值分解(假设可对角化):$\boldsymbol{W}_h = Q\,\Lambda\,Q^{-1}$,其中 $\Lambda = \mathrm{diag}(\lambda_1,\dots,\lambda_d)$。那么:
当 $t-k$ 很大(也就是反传跨越很多步)时:
Pascanu, Mikolov & Bengio (2013) 给出的精确不等式是:
对 $\tanh$ 而言 $\sigma'(z) \in (0,1]$,所以即使 $\|\boldsymbol{W}_h\|$ 略大于 1,乘上对角项后仍可能整体收缩。 经验上 $\tanh$ RNN 的有效记忆长度通常只有 10-20 个时刻。
"When she tried to print her tickets, she found that the printer was out of toner. She went to the stationery store to buy more toner. It was very overpriced. After installing the toner into the printer, she finally printed her ______"要在末尾填 "tickets",模型必须将第 7 步的 "tickets" 的信息传到第 50+ 步。 如果梯度按 $0.5^{43} \approx 10^{-13}$ 衰减,反传时这条依赖关系的梯度小到浮点数都无法表示, 所以模型训练时根本学不到"她要打印的是 tickets"这个事实,测试时自然预测错。
反方向上,如果 $\|\boldsymbol{W}_h\|$ 谱半径大于 1(更准确说:动力系统进入混沌区域),梯度就会指数爆炸。 SGD 更新规则是:
当 $\|\nabla_\theta J(\theta)\|$ 极大时(比如 $10^6$),即使学习率 $\alpha=10^{-3}$ 看似合理,单步更新仍会把参数推到一个 非常远且非常糟的位置。Pascanu 描述这种现象的比喻很经典:
"You think you've found a hill to climb, but suddenly you're in Iowa."
(你以为找到了一座要爬的山,结果突然就掉到了爱荷华平原)
最严重的情况是产生 Inf 或 NaN,整个模型权重失效,必须从更早的 checkpoint 重新加载。
解决方案非常朴素:梯度裁剪(gradient clipping)。Pascanu 的伪代码是:
直觉:保持梯度方向不变,但把它的长度缩到阈值以内——"还是上那座山,只是步子小点"。 PyTorch 一行:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
Hochreiter & Schmidhuber 在 1997 年发明 LSTM 的核心洞察是:给隐藏状态一条"线性高速公路"(cell state $\boldsymbol{c}^{(t)}$),让梯度可以无衰减地流过。LSTM 的关键方程是:
关键性质:当 forget gate $\boldsymbol{f}^{(t)} \approx 1$ 时, $\partial \boldsymbol{c}^{(t)}/\partial \boldsymbol{c}^{(t-1)} \approx I$(恒等阵),梯度可以无衰减传播。 这把有效记忆长度从 10-20 步推到 100-200 步。
正是 (2) 和 (3) 这两点,催生了下一节的 attention 机制:与其指望信息慢慢传递过来, 不如让解码端直接跳过去看编码端的所有隐藏状态。
机器翻译(Machine Translation, MT)是第一个被深度学习彻底打败传统方法的 NLP 大任务。 理解 MT 的演化史,相当于理解 NMT、attention、Transformer、乃至现代 LLM 的第一性原理。 本部分按时间顺序梳理:MT 的形式化定义 → 2014 年 seq2seq → 2015 年 attention → 2017 年 Transformer 一统江湖。
机器翻译是:给定源语言句子 $x = (x_1, \dots, x_S)$(比如法语),输出目标语言句子 $y = (y_1, \dots, y_T)$(比如英语)。 比如:
| 变量 | 例子 |
|---|---|
| 源语言 $x$ | "il m'a entarté"(法语) |
| 目标语言 $y$ | "he hit me with a pie"(英语) |
MT 的核心就是建模条件概率 $P(y\mid x)$。利用链式法则,可以拆成"一个一个 token 预测":
注意两个关键词:
Bayesian 视角下,这其实是 noisy-channel model 的简化版。SMT 时代曾用 $P(y\mid x) \propto P(x\mid y)\,P(y)$(翻译模型 × 目标语言模型),而 NMT 直接建模 $P(y\mid x)$,因此模型架构和训练流程都大大简化。
2013 年以前,Google Translate 用的是统计机器翻译(SMT): 基于短语对齐、n-gram 语言模型、log-linear feature combination,由数百名工程师维护多年的庞大代码库。 然后 2014 年 Sutskever, Vinyals & Le 在 Google Brain 提出 seq2seq,2016 年 Google Translate 全线切到 NMT。
| 对比维度 | SMT (2003-2015) | NMT (2014-) |
|---|---|---|
| 开发周期 | 数百工程师 · 多年 | 小团队 · 数月 |
| 组件数 | 10+ 子系统(aligner / language model / reordering / decoder ...) | 1 个端到端神经网络 |
| 需要语言学知识 | 极多(morphology, syntax, alignment) | 几乎为零 |
| WMT'14 EN-DE BLEU | ~20 | 28.4 (Transformer) |
| 训练数据要求 | 平行语料即可 | 同左,但要 10× 量 |
2018 年起,所有主流翻译系统(Google, Microsoft, Baidu, DeepL, Tencent, ByteDance)都迁移到了 NMT。 这是深度学习在 NLP 第一次"完胜并彻底取代"传统方法。
Seq2seq 是一个由两个 RNN组成的架构:
形式化地:
$E_x, E_y$ 是源/目标词嵌入矩阵,$W_o$ 是输出投影矩阵。整个系统的参数:编码器 RNN、解码器 RNN、两个 embedding 矩阵、输出层 —— 可联合训练,反向传播自然贯穿两端。这就是 seq2seq 优雅之处。
Seq2seq 训练目标是极大似然:给定 $N$ 对平行句子 $\{(x^{(i)}, y^{(i)})\}$,最大化:
训练时关键技巧叫 Teacher Forcing:第 $t$ 步预测 $y_t$ 时, decoder 的输入是真实的 $y_{t-1}$(来自 ground truth),而非上一步模型自己预测的 token。 这样训练稳定且高效;缺点是训练-推理分布偏差(exposure bias),后来 scheduled sampling, RL fine-tuning 等技术试图缓解,但 LM 标准做法仍是 teacher forcing。
回到图 2.1 那个金色框:encoder 把整个源句子压缩成一个固定维度向量。 这意味着:无论源句子是 5 个词还是 50 个词,所有的语义信息都必须装进 $\boldsymbol{c} \in \mathbb{R}^d$ 这一个向量里。
实验上这表现为:
更深层次地,瓶颈实际上叠加了三重问题:
解决方案有两条历史路径:
下一部分我们详细看 attention 怎么把这个瓶颈彻底拆掉。
Attention 不是 Vaswani 等人 2017 年的发明 — Bahdanau et al. 在 2014 年 9 月已经把它用在 NMT 上。 但 2014–2016 年的 attention 仍然附加在 RNN 之上,作为"补丁"使用。 2017 年的关键洞察是:既然 attention 这么有用,为什么还要 RNN? 本节先讲 attention 在 seq2seq 中的形式,第四部分再讲怎么"独立"出来变成 self-attention。
回想瓶颈问题:decoder 只能看到一个固定向量 $\boldsymbol{c}$。Attention 的直觉是:
解码每一步可能关注源句的不同部分。生成 "he" 时关注 "il",生成 "hit" 时关注 "entarté",生成 "pie" 时也关注 "entarté"。 让 decoder 主动选择看哪里,比把所有信息塞进一个向量更合理。
这种"主动选择"模仿了人类翻译的过程:你不会先把整个法语句子背下来再开始写英文,而是翻一段对应一段。 人类的视线在源文和译文之间来回扫,这就是 "attention" 一词的来源。
在 seq2seq + attention 模型中,每一步 decoder 都执行下面 4 步:
理解 attention 最有用的角度是把它看作"软化的字典查表"。在普通的 Python 字典里:
d = {"apple": 1, "banana": 2, "cherry": 3}
d["banana"] # → 2,精确匹配
普通查表是硬的(hard):查 "banana",直接返回对应的 value,其他 key 完全忽略。 但如果我们 query 的不是字典里的精确 key,比如 "banan"?只能报 KeyError。
Attention 是软查表(soft lookup):
不再精确匹配,而是 query 与每个 key 做相似度(点积),转成概率分布,再用这个分布对所有 values 加权求和:
为什么"软"很关键?因为软查表是处处可导的。如果你硬要 argmax(hard attention),梯度无法回传,必须用 REINFORCE 等技巧。 而 softmax 让整个流程端到端可微,可用普通反向传播训练。
$e_j = f(s, h_j)$ 这个相似度函数 $f$ 有几种历史选择:
| 名称 | 公式 | 来源 | 特点 |
|---|---|---|---|
| 加性 (additive) aka Bahdanau | $v^\top\tanh(W_1 s + W_2 h_j)$ | Bahdanau 2014 | 表达力强;需要额外参数 $W_1, W_2, v$;计算慢 |
| 点积 (dot product) aka Luong-dot | $s^\top h_j$ | Luong 2015 | 无参数;要求 $s, h$ 维度相同;快 |
| 双线性 (bilinear) aka Luong-general | $s^\top W h_j$ | Luong 2015 | 有参数;允许不同维度;中速 |
| 缩放点积 (scaled dot) | $\dfrac{s^\top h_j}{\sqrt{d}}$ | Vaswani 2017 | 无参数;防止 softmax 饱和(见 5.3 节) |
Transformer 选择缩放点积,因为:
但要注意一个新引入的代价:计算复杂度 $O(S \cdot T)$(源长 × 译长)。在 self-attention 里这会变成 $O(n^2)$, 是后面所有"高效 Transformer"研究的源头。
本节最后我们把 attention 抽象到最一般的形式。一个 attention 模块由三部分输入构成:
核心计算:
"软查找表"的解读:$n_q$ 个 query 同时查询一张含 $n_k$ 条记录的表,每条记录由 key(索引)和 value(内容)组成。
对应到 seq2seq + attention:
| 角色 | seq2seq + attention | 来源 RNN |
|---|---|---|
| Query | decoder 当前隐状态 $s^{(t)}$ | Decoder RNN |
| Key | encoder 隐状态 $h^{(j)}$ | Encoder RNN |
| Value | encoder 隐状态 $h^{(j)}$(与 key 相同) | Encoder RNN |
在 seq2seq + attention 里 key = value,这只是历史巧合。第四部分我们会看到 self-attention 把它们分开了: 通过不同的投影矩阵 $W_K, W_V$ 让"索引信息"和"内容信息"可以各自学习,互不干扰。
2017 年的关键飞跃是:把 attention 从 RNN 上"剥离"出来,让它独立成为一种序列建模单元。 新名字叫 self-attention(自注意力):一个序列对自己做 attention。 本部分逐步推导 self-attention 的数学定义、它无法独立工作的三大障碍、以及对应的工程补丁。
仔细审视一下:上一节的 attention 究竟在做什么?它接受一个 query,对一组 keys/values 做加权检索,输出一个向量。 而 RNN 在做什么?它接受当前输入 $x_t$ 与前一时刻状态 $h_{t-1}$,输出一个向量。
两者的本质都是"信息聚合"。Attention 是"从所有位置加权聚合",RNN 是"沿着时间链聚合"。 那能不能用 attention 完全替代 RNN?
答案是肯定的,而且收益是巨大的:
| 维度 | RNN | Self-Attention |
|---|---|---|
| 并行化 | 串行 $O(n)$ 步 | 完全并行,1 步矩阵乘 |
| 最长路径长度 | $O(n)$(信号要走 n 步) | $O(1)$(任意两位置直接相连) |
| 每层计算复杂度 | $O(n \cdot d^2)$ | $O(n^2 \cdot d)$ |
| 梯度通路 | 易消失/爆炸 | 直接通路,稳定 |
当 $n < d$(短序列、高维度,比如 $n=128, d=1024$)时 self-attention 的总 FLOPs 反而更少。 即使 $n > d$,并行化带来的 wall-clock 加速依然显著。
设输入为一个词序列 $\boldsymbol{w}_1, \dots, \boldsymbol{w}_n$(例如 "Zuko made his uncle tea")。每个 $\boldsymbol{w}_i$ 通过词嵌入矩阵 $E \in \mathbb{R}^{d \times |V|}$ 转成向量: $\boldsymbol{x}_i = E\boldsymbol{w}_i \in \mathbb{R}^d$。
Self-attention 的三步骤:
关键观察:同一个 $\boldsymbol{x}_i$ 同时充当 query、key、value 三个角色,但经过不同投影矩阵。 这就是 "self-attention" 名字的由来——序列对自己做 attention。
现在让我们诚实地问:能不能把 self-attention 一层接一层堆叠,直接做语言模型?答案是不能。 存在三个根本性障碍:
| 障碍 | 问题描述 | 方案 |
|---|---|---|
| 1. 无顺序信息 | Self-attention 对输入做的是置换不变(permutation invariant)的操作 — 把 "猫吃鱼" 重排成 "鱼吃猫" 输出相同! | 位置编码 (positional encoding) |
| 2. 无非线性 | Self-attention 本质是 value 的加权平均(线性),堆叠多层只是重新加权再加权,仍是线性变换 — 无深度学习"魔法" | 逐位置 FFN(feed-forward network) |
| 3. 偷看未来 | 语言模型要求时刻 $t$ 只能看 $1\dots t-1$ 的词。Self-attention 默认看全部位置 — 训练时变成作弊 | 因果掩码 (causal mask) |
解决第一个障碍:给每个位置 $i$ 分配一个位置向量 $\boldsymbol{p}_i \in \mathbb{R}^d$,然后把它加到词嵌入上:
之后所有计算用 $\widetilde{\boldsymbol{x}}_i$ 取代原始 $\boldsymbol{x}_i$。这一步只在第一层输入做(深层 self-attention 网络只在最底层注入位置信息)。
用不同频率的正弦/余弦构造位置向量。第 $i$ 位置的第 $2k$ / $2k+1$ 维:
优点:
缺点:
更简单粗暴的做法:直接学一个 $\boldsymbol{p} \in \mathbb{R}^{d \times n_{\max}}$ 矩阵,每列对应一个位置。BERT、GPT-1/2、RoBERTa 都用这种方式。
| Sinusoidal | Learned | |
|---|---|---|
| 参数量 | 0 | $d \times n_{\max}$(百万级) |
| 外推 | 理论可,实际差 | 不可(超出 $n_{\max}$ 没意义) |
| 灵活性 | 固定模式 | 每位置自由学 |
| 原始论文使用 | Vaswani 2017 | BERT, GPT-1/2 |
后续相对位置编码(relative, Shaw et al. 2018)、RoPE(Su et al. 2021,被 LLaMA / GPT-NeoX / Mistral 采用) 和 ALiBi(Press et al. 2022,BLOOM 采用)在外推性与建模能力上更强,详见第六部分。
第二个障碍:self-attention 是线性的。证明:
修复办法:在每个 self-attention 层之后加一个逐位置(position-wise)的两层 FFN,引入 ReLU 非线性:
"逐位置"意思是每个位置独立用同一个 FFN 处理——这两层 MLP 在所有位置间共享权重,但不混合位置信息。 混合信息的工作完全交给 self-attention。Transformer 因此是"attention 混位置,FFN 加非线性"的清晰分工。
第三个障碍:训练自回归语言模型时,预测 $y_t$ 只能看 $y_1, \dots, y_{t-1}$。但 self-attention 默认看整个序列,会"偷看" $y_t, y_{t+1}, \dots$。 直接做法是逐位置改变 keys/queries 的集合 — 但这无法并行(每个位置不同长度)。
聪明的并行解法:把"未来位置"的 attention score 设为 $-\infty$,softmax 之后这些位置的权重自动变成 0:
PyTorch 实现一行:
mask = torch.triu(torch.full((n, n), float('-inf')), diagonal=1)
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
scores = scores + mask # 上三角变 -inf
attn = scores.softmax(dim=-1) # softmax 后上三角自动 = 0
注意三种 attention 的 mask 配置:
| 模型类型 | Mask | 用途 |
|---|---|---|
| Encoder(双向) | 无 | BERT, T5 encoder:理解任务 |
| Decoder(因果) | 下三角 | GPT 系列:生成任务 |
| Encoder-Decoder Cross-Attention | 无(query 总能看完整 encoder) | 翻译、摘要 |
| 障碍 | 方案 |
|---|---|
| 无顺序信息 | 加位置编码 $\boldsymbol{p}_i$ 到输入 |
| 无非线性 | 逐位置 FFN(两层 + ReLU) |
| 偷看未来 | 上三角设 $-\infty$ 因果掩码 |
把这三块拼起来,就已经是一个最简版 Transformer block(minimal self-attention building block):
下一部分我们再加 4 个工程优化:multi-head、scaled dot-product、残差、layer-norm。这就完整还原 Vaswani 2017 论文里的 Transformer Decoder。
上一节我们得到了一个最简 self-attention building block。这一节加上 4 个让它"真的能 train 得起来"的工程优化: 矩阵化批量计算、multi-head、scaled dot product、残差连接 + Layer Norm。最终拼装成 Vaswani et al. 2017 的标准 Transformer。
之前我们写的 self-attention 是"逐位置"形式(公式里有 $\sum_j$)。为了 GPU 并行,必须改写成矩阵乘法。
令 $X = [\boldsymbol{x}_1; \cdots; \boldsymbol{x}_n] \in \mathbb{R}^{n \times d}$ 为所有输入向量堆叠。则:
所有两两 query-key 点积一次性算出来:
沿最后一维 softmax,再乘以 values:
比起 RNN 的 $n$ 步串行,这里只需两次大矩阵乘。在 GPU 上利用 cuBLAS 可以做到 $> 90\%$ peak FLOPs。
单一 self-attention 有一个根本局限:每个 query 只能产生一组权重分布。 但一个词在句子中可能同时和多个东西有不同性质的关系(语法、语义、共指、长距等)。让我们看一个 PDF 中的例子(p.56):
Vaswani 等人的解决方案:同时训 $h$ 套 Q/K/V 矩阵,每套独立做 attention,最后拼接:
关键设计:每个 head 的维度是 $d/h$(不是 $d$),所以所有 heads 拼起来正好是 $d$ 维。 计算量与单 head $d$-维 attention 相同,但表达力更强。
q = self.W_q(x).view(B, n, h, d_k).transpose(1, 2)
实证上,常用 head 数 $h$ 与维度 $d$:
| 模型 | $d$ | $h$ | $d/h$ |
|---|---|---|---|
| Transformer Base (2017) | 512 | 8 | 64 |
| BERT-Base | 768 | 12 | 64 |
| BERT-Large | 1024 | 16 | 64 |
| GPT-3 175B | 12288 | 96 | 128 |
| LLaMA-2 70B | 8192 | 64 | 128 |
注意 $d/h$ 几乎总是 64–128 —— 这不是巧合,与 GPU tensor core 的 16/32/64 对齐有关。
为什么要除以 $\sqrt{d_k}$?让我们做一个简单概率推导。
假设 $\boldsymbol{q}, \boldsymbol{k} \in \mathbb{R}^{d_k}$ 的各分量都是 $\mathcal{N}(0,1)$ 独立采样。那么它们的点积:
所以 $\boldsymbol{q}^\top \boldsymbol{k}$ 的标准差是 $\sqrt{d_k}$。当 $d_k = 64$ 时,分数典型量级是 $\pm 8$;当 $d_k = 128$ 时,是 $\pm 11.3$。
softmax 对输入尺度极其敏感:输入差 1 时概率比是 $e \approx 2.7$;差 10 时是 $e^{10} \approx 22000$;差 20 时是 $5 \times 10^8$!
解决方案就是缩放:
除以 $\sqrt{d_k}$ 后,分数的方差变回 1,softmax 输入处于"合理量级",梯度健康。
当我们把 Transformer 堆叠到 6 层、12 层、24 层、48 层、96 层(GPT-3)时,纯堆叠会再次遇到梯度消失(不是 RNN 那种沿时间的消失,而是沿深度的消失)。 He et al. 2016 在 ResNet 中提出的残差连接是普世解:
直觉:让恒等映射(identity mapping)成为基线,子层 $\mathrm{Layer}(\cdot)$ 只需要学习"残差"。
反向传播时:
即使 Layer 的 Jacobian 接近 0,梯度仍能通过 $I$(恒等)项无衰减回传。这是训练超深网络(>50 层)的必要条件。 Li et al. 2018 的 loss landscape 可视化证明:有残差的网络 loss landscape 平滑得多,更容易优化。
Layer Normalization (Ba et al. 2016) 解决另一个工程难题:不同层、不同位置的激活值尺度差异巨大,导致训练不稳定。
对每个位置的隐藏向量 $\boldsymbol{x} \in \mathbb{R}^d$,LayerNorm 沿特征维度归一化:
| BatchNorm | LayerNorm | |
|---|---|---|
| 归一化轴 | 沿 batch 维(同一特征在不同样本上) | 沿特征维(同一样本的所有特征) |
| 需要 batch | 是(小 batch 不稳) | 否(单样本也行) |
| 训练/推理一致 | 否(推理用 running stats) | 是 |
| 序列长度敏感 | 是 | 否 |
LayerNorm 在 Transformer 流行后逐渐成了序列模型的事实标准。最近的 LLaMA / GPT-NeoX 使用 RMSNorm(去掉均值减去)也工作得很好,速度更快。
把 5.1–5.5 节的所有组件拼起来:
这是"Post-LN"版本(先 add 再 norm),原始 Vaswani 2017 的形式。
三种 Transformer 变种:
| 变种 | Mask | 典型代表 | 用途 |
|---|---|---|---|
| Decoder-only | 因果(下三角) | GPT 系列, LLaMA, Mistral | 语言生成、对话 |
| Encoder-only | 无 mask(双向) | BERT, RoBERTa, DeBERTa | 分类、检索、NER |
| Encoder-Decoder | encoder 无 / decoder 有 | T5, BART, mT5, NMT | 翻译、摘要、SeqSeq |
Encoder 与 Decoder 唯一区别:去掉因果掩码,让每个位置看整个序列。BERT 因此称为"双向"Transformer。
Encoder-Decoder 则把两者组合:
Cross-attention 是 encoder-decoder Transformer 的关键。它让 decoder 在生成时"回看"源句信息。 形式上与 self-attention几乎一模一样,只是 Q 和 K/V 来源不同:
令:
则:
$$ \boldsymbol{k}_i = K\,\boldsymbol{h}_i,\quad \boldsymbol{v}_i = V\,\boldsymbol{h}_i,\quad \boldsymbol{q}_j = Q\,\boldsymbol{z}_j $$ $$ \mathrm{CrossAttn} = \mathrm{softmax}\!\Big(\frac{QK^\top}{\sqrt{d_k}}\Big)\,V $$2017 年那篇论文只是 Transformer 的起点。过去 8 年里,Transformer 经历了大量"非原始"改进, 其中很多是 PDF 课堂上没有展开的研究生必备内容:Pre-LN、RoPE、Flash Attention、KV cache、MQA/GQA。 本部分系统串讲这些演化。
Transformer 在 WMT 2014 翻译任务上一鸣惊人:
| 模型 | EN-DE BLEU | EN-FR BLEU | 训练 FLOPs |
|---|---|---|---|
| ByteNet (CNN) | 23.75 | — | — |
| GNMT + RL | 24.6 | 39.92 | $2.3 \times 10^{19}$ |
| ConvS2S | 25.16 | 40.46 | $9.6 \times 10^{18}$ |
| MoE | 26.03 | 40.56 | $2.0 \times 10^{19}$ |
| GNMT + RL Ensemble | 26.30 | 41.16 | $1.8 \times 10^{20}$ |
| ConvS2S Ensemble | 26.36 | 41.29 | $7.7 \times 10^{19}$ |
| Transformer (base) | 27.3 | 38.1 | $3.3 \times 10^{18}$ |
| Transformer (big) | 28.4 | 41.8 | $2.3 \times 10^{19}$ |
亮点:Transformer base 用少一个数量级的 FLOPs 击败所有 RNN/CNN baselines; Transformer big 在 EN-DE 上比之前最好高 1.5+ BLEU,在 EN-FR 创下新 SOTA。
更震撼的是文档生成(WikiSum, Liu et al. 2018):原来 seq2seq+attn 的 perplexity 是 5.05、ROUGE-L 12.7, Transformer 把 perplexity 拉到 1.90、ROUGE-L 38.8 — "Transformers all the way down"。
下面几节讲针对这些痛点的现代解。
原始 Transformer 是 Post-LN:先 add 残差,再 LayerNorm。 Xiong et al. 2020 ("On Layer Normalization in the Transformer Architecture") 证明,Pre-LN(先 norm 再 attention,残差直接相加)训练更稳定。
| Post-LN | Pre-LN | |
|---|---|---|
| 训练稳定性 | 需 warmup 学习率,容易发散 | 无需 warmup,可用大学习率 |
| 深度可堆叠性 | >12 层易梯度爆炸 | 可堆 100+ 层(GPT-3, LLaMA) |
| 峰值性能 | 调好后略优 | 略弱但稳定 |
| 采用模型 | 原始 Transformer, BERT | GPT-2/3, LLaMA, PaLM, T5, Mistral |
现代几乎所有 LLM 都用 Pre-LN。BERT 是少有的 Post-LN 例外(因为已工业部署,难以改)。
原始 sinusoidal / learned absolute position 有两个缺点:
不在输入加位置向量,而是在 attention 内部对 query/key 做"旋转"。把 $\boldsymbol{q}_i \in \mathbb{R}^d$ 拆成 $d/2$ 个 2D 子空间,每个子空间用旋转矩阵 $R_\theta$ 旋转 $i\theta$ 角度:
优点:
LLaMA 1/2/3、Mistral、PaLM、Qwen、GLM-4 等绝大多数主流 LLM 都用 RoPE。
更激进:完全不要位置编码,只在 attention 分数上加一个"距离惩罚"线性偏置:
ALiBi 的核心好处是极强外推:训练时 $n=1024$,推理时 $n=16384$ 仍可工作。BLOOM、MPT 等模型采用 ALiBi。
Dao et al. 2022 的 FlashAttention 是工程上最重要的 Transformer 优化之一。 它没有改变数学,只是改了 GPU 内存访问方式。
原始 attention 在 GPU 上的瓶颈不是 FLOPs,而是HBM ↔ SRAM 之间的 IO:
| HBM (High Bandwidth Memory) | SRAM (片上 cache) | |
|---|---|---|
| 容量 | 40-80 GB | ~ 20 MB |
| 带宽 | ~1.5 TB/s | ~19 TB/s(10×+) |
| 访问代价 | 慢 | 快 |
原始实现需要:
显存中实例化的 $n \times n$ 矩阵和反复读写 HBM 是慢的元凶。FlashAttention 用 tiling + online softmax, 把 Q, K, V 分块处理,每个块在 SRAM 里走完整流程,不实例化完整 attention 矩阵。
torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+) 默认调用注意:FlashAttention 的数学结果与原始 attention 完全相同(在浮点误差内),只是更快。这是工程优化的典范。
自回归生成一个长度 $n$ 序列时,第 $t$ 步的 attention 需要重新计算 $1, \dots, t-1$ 的 K, V 投影。 但这些 K, V 不依赖当前 query——是纯函数$f(x_{1:t-1})$,可以缓存!
推理时间从 $O(n^2)$ 降到 $O(n)$(生成 $n$ tokens 总共)。代价是显存——KV cache 大小 $= 2 \cdot L \cdot n \cdot d$($L$ 层)。 对于 LLaMA-2 70B 在 $n=4096$ 时 KV cache 已经超过 5GB / 用户,是当代 LLM 服务的关键瓶颈。
多头 attention 每 head 都有独立的 K, V,缓存巨大。Multi-Query Attention(MQA, Shazeer 2019)让 $h$ 个 head 共享同一组 K, V:
缺点:质量下降。Grouped-Query Attention(GQA, Ainslie et al. 2023)折中——把 $h$ 个 head 分成 $g$ 组(如 $g=8$),每组共享一组 K/V:
| 变种 | $K, V$ 数量 | 缓存大小 | 质量 | 采用 |
|---|---|---|---|---|
| MHA | $h$ 组 | ×$h$ | 最佳 | 原始 Transformer |
| MQA | 1 组 | ×1(极省) | 下降 | PaLM, Falcon |
| GQA | $g$ 组 ($g \ll h$) | ×$g$(中等) | 接近 MHA | LLaMA-2/3, Mistral, Qwen |
课堂 PDF 最后一页(p.71)提了一个有趣问题:"既然 attention 在大模型中只占少部分计算,为什么还要追求线性 attention?"
这一节给出一个完整可运行的 minimal Transformer 实现,约 150 行核心代码。 适合直接复制运行、与 Andrej Karpathy 的 nanoGPT 互相参照。
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def scaled_dot_product_attention(q, k, v, mask=None):
"""
q, k, v: (..., n, d_k)
mask: (..., n, n) of 0/1,1 表示可见,0 表示屏蔽
返回: (..., n, d_k)
"""
d_k = q.size(-1)
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k) # (..., n, n)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1) # (..., n, n)
return attn @ v # (..., n, d_v)
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x_q, x_kv, mask=None):
B, n_q, _ = x_q.shape
_, n_kv, _ = x_kv.shape
# 1) project & reshape to (B, h, n, d_k)
q = self.W_q(x_q ).view(B, n_q, self.n_heads, self.d_k).transpose(1, 2)
k = self.W_k(x_kv).view(B, n_kv, self.n_heads, self.d_k).transpose(1, 2)
v = self.W_v(x_kv).view(B, n_kv, self.n_heads, self.d_k).transpose(1, 2)
# 2) scaled dot-product attention
out = scaled_dot_product_attention(q, k, v, mask) # (B, h, n_q, d_k)
# 3) concat heads and project out
out = out.transpose(1, 2).contiguous().view(B, n_q, self.d_model)
return self.dropout(self.W_o(out))
注:传 x_q == x_kv 就是 self-attention;传不同就是 cross-attention。
class FFN(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.w1 = nn.Linear(d_model, d_ff)
self.w2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.dropout(self.w2(F.relu(self.w1(x))))
# 现代版用 GELU 或 SwiGLU 替换 ReLU
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = MultiHeadAttention(d_model, n_heads, dropout)
self.ln2 = nn.LayerNorm(d_model)
self.ffn = FFN(d_model, d_ff, dropout)
def forward(self, x, mask=None):
# Pre-LN: 先 norm, 再子层, 残差
x = x + self.attn(self.ln1(x), self.ln1(x), mask)
x = x + self.ffn(self.ln2(x))
return x
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe) # 不当作参数
def forward(self, x):
# x: (B, n, d_model)
return x + self.pe[:x.size(1)]
class GPTMini(nn.Module):
def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=6,
d_ff=2048, max_len=512, dropout=0.1):
super().__init__()
self.tok_emb = nn.Embedding(vocab_size, d_model)
self.pos_emb = PositionalEncoding(d_model, max_len)
self.blocks = nn.ModuleList([
TransformerBlock(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
self.ln_f = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, vocab_size, bias=False)
# weight tying(GPT-2 trick)
self.head.weight = self.tok_emb.weight
def forward(self, idx, targets=None):
B, n = idx.shape
x = self.tok_emb(idx) # (B, n, d_model)
x = self.pos_emb(x)
causal = torch.tril(torch.ones(n, n, device=idx.device)).unsqueeze(0).unsqueeze(0)
for blk in self.blocks:
x = blk(x, mask=causal)
x = self.ln_f(x)
logits = self.head(x) # (B, n, vocab_size)
if targets is not None:
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-100
)
return logits, loss
return logits, None
# 训练一个 batch 的 minimal 示例
model = GPTMini(vocab_size=50257)
x = torch.randint(0, 50257, (2, 64)) # B=2, n=64
y = torch.randint(0, 50257, (2, 64))
logits, loss = model(x, y)
loss.backward()
print(f"loss = {loss.item():.4f}, params = {sum(p.numel() for p in model.parameters()):,}")
q @ k.transpose(-2, -1)/ math.sqrt(d_k).view(B, n, h, d_k).transpose(1, 2)torch.tril(torch.ones(n, n))x = x + self.attn(...)nn.LayerNorm(d_model)w2(ReLU(w1(x)))PositionalEncoding 类x + self.attn(self.ln1(x), ...)MultiHeadAttention.forward 中对 $q, k$ 做旋转,不传 absolute position embedding。参考实现。forward 增加 past_kv 参数,递增生成时只算新 token 的 attention。scaled_dot_product_attention 替换为
F.scaled_dot_product_attention(q, k, v, is_causal=True)(PyTorch ≥ 2.0),测 forward+backward 时间。