Attention 与 Transformer 完全指南

面向 NLP 研究生的深度教材 · 改编自 Stanford CS224N Winter 2026 Lecture 5 (Diyi Yang)
覆盖:RNN 梯度病理 · Seq2Seq 与 NMT · 加性/点积注意力 · Self-Attention · 位置编码 (sinusoidal/learned/RoPE) · Multi-Head · 残差与 LayerNorm · Transformer Encoder/Decoder/Cross-Attention · Pre-LN vs Post-LN · Flash Attention · KV cache
版本 1.0 · 2026 春 · 配套姊妹篇 Pretraining · Post-training · PEFT

引论:为什么 Transformer 是分水岭

"Attention is all you need." — Vaswani et al., NeurIPS 2017

2017 年 6 月的某个深夜,Ashish Vaswani 在 Google Brain 提交了一篇 8 位作者的论文,题目是当时看起来挑衅意味十足的 《Attention Is All You Need》。论文里既没有循环(recurrence)也没有卷积(convolution), 只用一种叫 self-attention 的几何运算,在 WMT 2014 英德翻译上把 BLEU 推到 28.4,训练时间却只是当时最强 GNMT 的零头。

在那之前,整个 NLP 领域有一个稳固的共识——序列必须由循环网络处理。LSTM 是 1997 年的产物, seq2seq + attention(Bahdanau 2014, Luong 2015)经过 3 年打磨已经统治 NMT。突然之间, 一个本质上"非时序"的并行架构,把这个共识彻底颠覆。这场颠覆的余波席卷了之后整个 LLM 浪潮:BERT (2018)、GPT-2 (2019)、 GPT-3 (2020)、ChatGPT (2022)、GPT-4 (2023)——所有这些模型的核心积木,都是 2017 年那篇论文里定义的 Transformer block。

本教材的写法
我们不会把 Transformer 当作天上掉下来的奇迹,而是沿着 2014–2017 历史路径,从 RNN 的两个数学病理(梯度消失、信息瓶颈)出发, 一步一步推导:为什么需要 attention → 为什么 attention 单独不够 → 为什么需要 position 信息、非线性、masking → 为什么需要 multi-head → 为什么需要残差和 LayerNorm。每一个组件都对应一个具体的失败模式,而非任意选择。

本课在 CS224N 中的位置

前置课程
Lecture 1-4:词向量、神经网络、依存解析、RNN 语言模型与反向传播
本课主线
Lecture 5:从 RNN 的失败 → Attention → Self-Attention → 完整 Transformer
后续课程
Lecture 7:预训练范式 (BERT/T5/GPT) — Lecture 8:后训练 (SFT/RLHF) — Lecture 9:PEFT — Lecture 10:RAG
作业映射
Assignment 4:Seq2Seq + Attention NMT(动手实现本课所有数学)

本教材的结构图

flowchart TD A["第一部分
RNN 的两个病理"] --> B["第二部分
Seq2Seq 与瓶颈"] B --> C["第三部分
Cross-Attention 救场"] C --> D["第四部分
Self-Attention 抛弃 RNN"] D --> E["第五部分
完整 Transformer
(Multi-Head + Add&Norm + FFN)"] E --> F["第六部分
现代演化
(Pre-LN/RoPE/Flash/KV)"] F --> G["附录
PyTorch 实现"] style A fill:#fbecec style B fill:#fff7e6 style C fill:#eaf3fb style D fill:#eafaf0 style E fill:#e8f1f4 style F fill:#fdf6e9 style G fill:#f4f1e8

每一节都按这个三段式组织:(1) 直觉 / 历史动机 → (2) 数学公式与推导 → (3) PyTorch 代码 / 思考题。 读者不需要事先了解 attention,但需要熟悉矩阵乘法、softmax、链式法则与反向传播。

第一部分:RNN 的两个梯度病理

要理解 Transformer 为什么这样设计,必须先理解它替代的 RNN 为何失败。本节我们把 Pascanu et al. (2013) 那篇被引超过 6000 次的论文 "On the difficulty of training recurrent neural networks" 的数学结论搬到台面上, 精确说明梯度消失梯度爆炸是怎么从链式法则的几何结构里诞生的。

1.1 RNN 反向传播链:直觉与矩阵化

一个 vanilla(朴素)RNN 在时刻 $t$ 的更新方程是:

RNN 隐藏状态更新
$$ \boldsymbol{h}^{(t)} = \sigma\!\big(\boldsymbol{W}_{\!h}\,\boldsymbol{h}^{(t-1)} + \boldsymbol{W}_{\!x}\,\boldsymbol{x}^{(t)} + \boldsymbol{b}\big) $$

其中 $\sigma$ 通常是 $\tanh$ 或 sigmoid,$\boldsymbol{W}_h, \boldsymbol{W}_x \in \mathbb{R}^{d\times d}$ 是所有时刻共享的权重矩阵。 "所有时刻共享"这五个字是接下来一切问题的根源。

假设我们要计算第 4 步的损失 $J^{(4)}(\theta)$ 对第 1 步隐藏状态 $\boldsymbol{h}^{(1)}$ 的梯度。这是个"反向传播穿越时间"(BPTT)问题:

h⁽¹⁾ h⁽²⁾ h⁽³⁾ h⁽⁴⁾ J⁽⁴⁾(θ) W W W ∂J⁽⁴⁾/∂h⁽¹⁾ = ∂h⁽²⁾/∂h⁽¹⁾ · ∂h⁽³⁾/∂h⁽²⁾ · ∂h⁽⁴⁾/∂h⁽³⁾ · ∂J⁽⁴⁾/∂h⁽⁴⁾
图 1.1:RNN 前向传播(黑色)与反向传播(蓝色)。损失从 J⁽⁴⁾ 沿时序反传,需要连乘 3 次 ∂h⁽t⁾/∂h⁽t-1⁾。

由链式法则:

$$ \frac{\partial J^{(4)}}{\partial \boldsymbol{h}^{(1)}} = \underbrace{\frac{\partial \boldsymbol{h}^{(2)}}{\partial \boldsymbol{h}^{(1)}}}_{A_2} \cdot \underbrace{\frac{\partial \boldsymbol{h}^{(3)}}{\partial \boldsymbol{h}^{(2)}}}_{A_3} \cdot \underbrace{\frac{\partial \boldsymbol{h}^{(4)}}{\partial \boldsymbol{h}^{(3)}}}_{A_4} \cdot \frac{\partial J^{(4)}}{\partial \boldsymbol{h}^{(4)}} $$

其中每一个 Jacobian 矩阵 $A_t = \partial \boldsymbol{h}^{(t)}/\partial \boldsymbol{h}^{(t-1)}$ 都长成同一个形式:

$$ A_t = \mathrm{diag}\!\big(\sigma'(\boldsymbol{z}^{(t)})\big)\,\boldsymbol{W}_{\!h} $$ 其中 $\boldsymbol{z}^{(t)} = \boldsymbol{W}_h \boldsymbol{h}^{(t-1)} + \boldsymbol{W}_x \boldsymbol{x}^{(t)} + \boldsymbol{b}$

重点是:所有 $A_t$ 共用同一个 $\boldsymbol{W}_h$。如果时间窗口长度为 $T$,那么 $\partial J^{(T)}/\partial \boldsymbol{h}^{(1)}$ 就是 $T-1$ 个几乎相同矩阵的乘积。线性代数告诉我们:同一个矩阵自乘 $k$ 次,结果由其最大奇异值(或最大特征值,谱半径)主导

1.2 梯度消失:Pascanu 不等式与谱半径

为了把直觉变成定理,我们做一个简化:假设激活函数 $\sigma$ 是恒等函数(identity,$\sigma(z)=z$),所以 $\sigma'=1$。 那么 $A_t = \boldsymbol{W}_h$,且:

$$ \frac{\partial \boldsymbol{h}^{(t)}}{\partial \boldsymbol{h}^{(k)}} = \boldsymbol{W}_h^{\,t-k} $$

对 $\boldsymbol{W}_h$ 做特征值分解(假设可对角化):$\boldsymbol{W}_h = Q\,\Lambda\,Q^{-1}$,其中 $\Lambda = \mathrm{diag}(\lambda_1,\dots,\lambda_d)$。那么:

$$ \boldsymbol{W}_h^{\,t-k} = Q\,\Lambda^{t-k}\,Q^{-1} = Q\,\mathrm{diag}\!\big(\lambda_1^{\,t-k},\dots,\lambda_d^{\,t-k}\big)\,Q^{-1} $$

当 $t-k$ 很大(也就是反传跨越很多步)时:

Pascanu, Mikolov & Bengio (2013) 给出的精确不等式是:

Pascanu 上界 (sufficient condition for vanishing)
$$ \left\|\frac{\partial \boldsymbol{h}^{(t)}}{\partial \boldsymbol{h}^{(k)}}\right\| \;\le\; \prod_{i=k+1}^{t}\,\big\|\,\mathrm{diag}(\sigma'(\boldsymbol{z}^{(i)}))\,\big\|\cdot\|\boldsymbol{W}_h\| \;\le\;\gamma^{t-k} $$ 若 $\gamma < 1$,梯度按 $\gamma^{t-k}$ 指数衰减。

对 $\tanh$ 而言 $\sigma'(z) \in (0,1]$,所以即使 $\|\boldsymbol{W}_h\|$ 略大于 1,乘上对角项后仍可能整体收缩。 经验上 $\tanh$ RNN 的有效记忆长度通常只有 10-20 个时刻

为什么梯度消失是真正的灾难
不是"训练慢"那么简单——而是远距离信号根本传不到。看这个语言模型例句(来自 PDF p.10):
"When she tried to print her tickets, she found that the printer was out of toner. She went to the stationery store to buy more toner. It was very overpriced. After installing the toner into the printer, she finally printed her ______"
要在末尾填 "tickets",模型必须将第 7 步的 "tickets" 的信息传到第 50+ 步。 如果梯度按 $0.5^{43} \approx 10^{-13}$ 衰减,反传时这条依赖关系的梯度小到浮点数都无法表示, 所以模型训练时根本学不到"她要打印的是 tickets"这个事实,测试时自然预测错。

1.3 梯度爆炸与梯度裁剪

反方向上,如果 $\|\boldsymbol{W}_h\|$ 谱半径大于 1(更准确说:动力系统进入混沌区域),梯度就会指数爆炸。 SGD 更新规则是:

$$ \theta^{\text{new}} = \theta^{\text{old}} - \alpha\,\nabla_\theta J(\theta) $$

当 $\|\nabla_\theta J(\theta)\|$ 极大时(比如 $10^6$),即使学习率 $\alpha=10^{-3}$ 看似合理,单步更新仍会把参数推到一个 非常远且非常糟的位置。Pascanu 描述这种现象的比喻很经典:

"You think you've found a hill to climb, but suddenly you're in Iowa."
(你以为找到了一座要爬的山,结果突然就掉到了爱荷华平原)

最严重的情况是产生 InfNaN,整个模型权重失效,必须从更早的 checkpoint 重新加载。

解决方案非常朴素:梯度裁剪(gradient clipping)。Pascanu 的伪代码是:

Algorithm 1: Norm clipping
$$ \hat{\boldsymbol{g}} \leftarrow \frac{\partial \mathcal{E}}{\partial \theta} $$ $$ \text{if } \|\hat{\boldsymbol{g}}\| \ge \text{threshold:} $$ $$ \quad \hat{\boldsymbol{g}} \leftarrow \frac{\text{threshold}}{\|\hat{\boldsymbol{g}}\|}\,\hat{\boldsymbol{g}} $$

直觉:保持梯度方向不变,但把它的长度缩到阈值以内——"还是上那座山,只是步子小点"。 PyTorch 一行:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
实务记忆点
梯度爆炸容易解决(一行 clip 即可),梯度消失难以解决(需要改架构)。 所以本课大半篇幅是在讲怎么用 LSTM/attention/Transformer 绕开梯度消失,而不是讨论裁剪。

1.4 LSTM/GRU 为何只是创可贴

Hochreiter & Schmidhuber 在 1997 年发明 LSTM 的核心洞察是:给隐藏状态一条"线性高速公路"(cell state $\boldsymbol{c}^{(t)}$),让梯度可以无衰减地流过。LSTM 的关键方程是:

$$ \boldsymbol{c}^{(t)} = \boldsymbol{f}^{(t)} \odot \boldsymbol{c}^{(t-1)} + \boldsymbol{i}^{(t)} \odot \tilde{\boldsymbol{c}}^{(t)} $$ 其中 $\boldsymbol{f}^{(t)}$ 是 forget gate (sigmoid,∈[0,1]),$\odot$ 是逐元素乘。

关键性质:当 forget gate $\boldsymbol{f}^{(t)} \approx 1$ 时, $\partial \boldsymbol{c}^{(t)}/\partial \boldsymbol{c}^{(t-1)} \approx I$(恒等阵),梯度可以无衰减传播。 这把有效记忆长度从 10-20 步推到 100-200 步。

但 LSTM 仍有三个根本缺陷
  1. 仍是串行:要算 $\boldsymbol{h}^{(t)}$ 必须先算完 $\boldsymbol{h}^{(t-1)}$。GPU 并行化效率极低。
  2. 信息瓶颈:最后一个隐藏状态是固定维度向量,必须承担整段输入的所有信息。后面会看到这是 seq2seq 的致命伤。
  3. 仍是"消息传递":第 1 步要影响第 50 步,必须把信息原样保留经过 49 次门控运算——即使数学上可能,统计上仍然脆弱。

正是 (2) 和 (3) 这两点,催生了下一节的 attention 机制:与其指望信息慢慢传递过来, 不如让解码端直接跳过去看编码端的所有隐藏状态。

本节思考题(研究生)

  1. 用初始化技巧(如 orthogonal initialization)让 $\boldsymbol{W}_h$ 的奇异值恰为 1,能避免梯度问题吗?工程上为何不流行?
  2. 为什么 Pascanu 等人后来又提出 echo state networks 和 IRNN,但 LSTM 反而胜出?
  3. 如果 RNN 完全用 ReLU 激活(IRNN),最坏情况下梯度消失/爆炸还会发生吗?给出一个反例。

第二部分:机器翻译与 Seq2Seq 瓶颈

机器翻译(Machine Translation, MT)是第一个被深度学习彻底打败传统方法的 NLP 大任务。 理解 MT 的演化史,相当于理解 NMT、attention、Transformer、乃至现代 LLM 的第一性原理。 本部分按时间顺序梳理:MT 的形式化定义 → 2014 年 seq2seq → 2015 年 attention → 2017 年 Transformer 一统江湖。

2.1 MT 任务的概率视角:条件语言模型

机器翻译是:给定源语言句子 $x = (x_1, \dots, x_S)$(比如法语),输出目标语言句子 $y = (y_1, \dots, y_T)$(比如英语)。 比如:

变量例子
源语言 $x$"il m'a entarté"(法语)
目标语言 $y$"he hit me with a pie"(英语)

MT 的核心就是建模条件概率 $P(y\mid x)$。利用链式法则,可以拆成"一个一个 token 预测":

条件语言模型(Conditional Language Model)
$$ P(y\mid x) = \prod_{t=1}^{T}\,P(y_t \mid y_1, y_2, \dots, y_{t-1},\,x) $$

注意两个关键词:

Bayesian 视角下,这其实是 noisy-channel model 的简化版。SMT 时代曾用 $P(y\mid x) \propto P(x\mid y)\,P(y)$(翻译模型 × 目标语言模型),而 NMT 直接建模 $P(y\mid x)$,因此模型架构和训练流程都大大简化。

2.2 SMT 到 NMT 的范式跳变 (2014–2016)

2013 年以前,Google Translate 用的是统计机器翻译(SMT): 基于短语对齐、n-gram 语言模型、log-linear feature combination,由数百名工程师维护多年的庞大代码库。 然后 2014 年 Sutskever, Vinyals & Le 在 Google Brain 提出 seq2seq,2016 年 Google Translate 全线切到 NMT。

对比维度SMT (2003-2015)NMT (2014-)
开发周期数百工程师 · 多年小团队 · 数月
组件数10+ 子系统(aligner / language model / reordering / decoder ...)1 个端到端神经网络
需要语言学知识极多(morphology, syntax, alignment)几乎为零
WMT'14 EN-DE BLEU~2028.4 (Transformer)
训练数据要求平行语料即可同左,但要 10× 量

2018 年起,所有主流翻译系统(Google, Microsoft, Baidu, DeepL, Tencent, ByteDance)都迁移到了 NMT。 这是深度学习在 NLP 第一次"完胜并彻底取代"传统方法。

2.3 Encoder-Decoder 数学定义

Seq2seq 是一个由两个 RNN组成的架构:

il m' a entarté x₁ x₂ x₃ x₄ Source 句子 (input) context vector c ⟨START⟩ he hit me with a pie he hit me with a pie ⟨END⟩ Target 句子 (output) Encoder RNN Decoder RNN(语言模型,条件于 c)
图 2.1:Seq2Seq 翻译模型。最后一个 encoder 隐藏状态被截断成唯一一条信息通道(金色框)传给 decoder。

形式化地:

Encoder
$$ \boldsymbol{h}^{(t)} = \mathrm{RNN}_{\text{enc}}(\boldsymbol{h}^{(t-1)}, E_x x_t),\quad t=1,\dots,S $$ $$ \boldsymbol{c} = \boldsymbol{h}^{(S)} $$
Decoder
$$ \boldsymbol{s}^{(0)} = \boldsymbol{c} $$ $$ \boldsymbol{s}^{(t)} = \mathrm{RNN}_{\text{dec}}(\boldsymbol{s}^{(t-1)}, E_y y_{t-1}) $$ $$ P(y_t \mid y_{<t},\,x) = \mathrm{softmax}(W_o\,\boldsymbol{s}^{(t)}) $$

$E_x, E_y$ 是源/目标词嵌入矩阵,$W_o$ 是输出投影矩阵。整个系统的参数:编码器 RNN、解码器 RNN、两个 embedding 矩阵、输出层 —— 可联合训练,反向传播自然贯穿两端。这就是 seq2seq 优雅之处。

2.4 训练:Teacher Forcing 与负对数似然

Seq2seq 训练目标是极大似然:给定 $N$ 对平行句子 $\{(x^{(i)}, y^{(i)})\}$,最大化:

$$ \mathcal{L}(\theta) = \sum_{i=1}^{N}\sum_{t=1}^{T_i}\,\log P_\theta\!\big(y^{(i)}_t \mid y^{(i)}_{<t},\,x^{(i)}\big) $$ 等价于最小化 token 级的负对数似然(即交叉熵): $$ J = -\frac{1}{T}\sum_{t=1}^{T}\,\log P_\theta(y_t \mid y_{<t}, x) $$

训练时关键技巧叫 Teacher Forcing:第 $t$ 步预测 $y_t$ 时, decoder 的输入是真实的 $y_{t-1}$(来自 ground truth),而非上一步模型自己预测的 token。 这样训练稳定且高效;缺点是训练-推理分布偏差(exposure bias),后来 scheduled sampling, RL fine-tuning 等技术试图缓解,但 LM 标准做法仍是 teacher forcing。

所有 NLP 任务都能写成 seq2seq
  • Summarization:长文本 → 短文本
  • Dialogue:之前的对话 → 下一句
  • Parsing:句子 → 依存/成分句法树(线性化后)
  • Code generation:自然语言 → Python
  • QA:问题+文档 → 答案
这种统一性是 T5 和 instruction tuning 的思想源头。

2.5 瓶颈问题:一切的起源

回到图 2.1 那个金色框:encoder 把整个源句子压缩成一个固定维度向量。 这意味着:无论源句子是 5 个词还是 50 个词,所有的语义信息都必须装进 $\boldsymbol{c} \in \mathbb{R}^d$ 这一个向量里。

The quick brown fox jumps over the lazy dog despite the warning signs ... A cat is on the mat which is in the corner of the room ... Once upon a time in a far away kingdom there lived a brave knight ... c ∈ ℝᵈ 固定维度 译文 token 1 译文 token 2 ... 所有源句信息必须穿过这个"针孔" → bottleneck problem
图 2.2:Seq2Seq 的瓶颈:长句子的所有语义压缩成单一固定维度向量。

实验上这表现为:

Seq2Seq 性能随句长崩塌
Bahdanau et al. (2014) 在 WMT'14 EN-FR 上发现:vanilla seq2seq 在句长 30 词以内表现尚可, 超过 30 词后 BLEU 急速下降,60 词以上几乎不可用。原因就是单一 $\boldsymbol{c}$ 容量不够。

更深层次地,瓶颈实际上叠加了三重问题:

  1. 容量瓶颈(capacity):固定维度 $d$ 装不下任意长度的语义。
  2. 位置遗忘(recency bias):由于梯度消失,远古的源 token 信息更难保存到 $\boldsymbol{c}$ 中。
  3. 翻译对齐丢失(alignment):MT 本质上是局部对齐的(一个目标词对应几个源词),但 $\boldsymbol{c}$ 抹掉了所有位置结构。

解决方案有两条历史路径:

  1. 加深 + 加宽:用 8 层 LSTM、4096 维隐藏状态(GNMT, 2016 年 Google 上线版)— 缓解但不根治。
  2. 砍掉瓶颈:让 decoder 直接看到 encoder 的每一个隐藏状态 — 这就是 attention。

下一部分我们详细看 attention 怎么把这个瓶颈彻底拆掉。

本节思考题(研究生)

  1. 如果把 encoder 的所有隐藏状态 $\{\boldsymbol{h}^{(1)},\dots,\boldsymbol{h}^{(S)}\}$ 直接拼接给 decoder,会有什么问题?为什么非要用 attention 这种加权和形式?
  2. 用 bidirectional encoder(双向 LSTM)替代单向,能缓解瓶颈到什么程度?请从信息论角度估计。
  3. 从概率视角看,seq2seq 隐含了什么 Markov 假设?为什么 attention 让这个假设变得无关紧要?

第三部分:Attention 机制

Attention 不是 Vaswani 等人 2017 年的发明 — Bahdanau et al. 在 2014 年 9 月已经把它用在 NMT 上。 但 2014–2016 年的 attention 仍然附加在 RNN 之上,作为"补丁"使用。 2017 年的关键洞察是:既然 attention 这么有用,为什么还要 RNN? 本节先讲 attention 在 seq2seq 中的形式,第四部分再讲怎么"独立"出来变成 self-attention。

3.1 直觉:从单一向量到加权检索

回想瓶颈问题:decoder 只能看到一个固定向量 $\boldsymbol{c}$。Attention 的直觉是:

解码每一步可能关注源句的不同部分。生成 "he" 时关注 "il",生成 "hit" 时关注 "entarté",生成 "pie" 时也关注 "entarté"。 让 decoder 主动选择看哪里,比把所有信息塞进一个向量更合理。

这种"主动选择"模仿了人类翻译的过程:你不会先把整个法语句子背下来再开始写英文,而是翻一段对应一段。 人类的视线在源文和译文之间来回扫,这就是 "attention" 一词的来源。

3.2 数学定义:4 步计算流程

在 seq2seq + attention 模型中,每一步 decoder 都执行下面 4 步:

1计算 attention 分数(score)。对当前 decoder 隐藏状态 $\boldsymbol{s}^{(t)}$, 和每一个 encoder 隐藏状态 $\boldsymbol{h}^{(j)}$ 计算相似度:
$$ e_j^{(t)} = \boldsymbol{s}^{(t)\top}\boldsymbol{h}^{(j)},\quad j=1,\dots,S $$ (这里用点积;其他打分函数见 3.4 节)
2softmax 归一化成概率分布
$$ \alpha_j^{(t)} = \frac{\exp(e_j^{(t)})}{\sum_{j'=1}^{S}\,\exp(e_{j'}^{(t)})} $$ $\boldsymbol{\alpha}^{(t)} = (\alpha_1^{(t)},\dots,\alpha_S^{(t)})$ 称为 attention distribution,是源句长度上的概率分布。
3加权求和得到 attention output(也叫 context vector):
$$ \boldsymbol{a}^{(t)} = \sum_{j=1}^{S}\,\alpha_j^{(t)}\,\boldsymbol{h}^{(j)} $$ $\boldsymbol{a}^{(t)} \in \mathbb{R}^d$ 主要由权重大的 encoder 隐藏状态构成。
4用 attention output 帮助预测:拼接 $\boldsymbol{a}^{(t)}$ 与 decoder 隐状态 $\boldsymbol{s}^{(t)}$,然后投影出词表分布:
$$ P(y_t \mid y_{<t}, x) = \mathrm{softmax}\!\big(W_o\,[\boldsymbol{a}^{(t)};\,\boldsymbol{s}^{(t)}]\big) $$
il a m' entarté h⁴ s⁽¹⁾ ⟨START⟩ scores eⱼ αⱼ 分布 attention output a ŷ₁ = "he" softmax → weighted sum (Σ αⱼ hⱼ) ① 点积分数 ② softmax ③ 加权求和 ④ 用于预测
图 3.1:Seq2Seq + Attention 在第 1 步的完整计算流程。Attention 选中了 "il" 对应到 "he"。

3.3 Attention 作为软查找表 (soft lookup)

理解 attention 最有用的角度是把它看作"软化的字典查表"。在普通的 Python 字典里:

d = {"apple": 1, "banana": 2, "cherry": 3}
d["banana"]  # → 2,精确匹配

普通查表是的(hard):查 "banana",直接返回对应的 value,其他 key 完全忽略。 但如果我们 query 的不是字典里的精确 key,比如 "banan"?只能报 KeyError。

Attention 是查表(soft lookup):

不再精确匹配,而是 query 与每个 key 做相似度(点积),转成概率分布,再用这个分布对所有 values 加权求和:

Attention as soft lookup
$$ \text{Attention}(q, K, V) = \sum_{j=1}^{n}\,\underbrace{\frac{\exp(q^\top k_j)}{\sum_{j'} \exp(q^\top k_{j'})}}_{\alpha_j\;\text{(soft index)}}\;v_j $$
Attention(软查找) q k₁ v₁ k₂ v₂ k₃ v₃ k₄ v₄ k₅ v₅ Σ output query 与所有 key 都匹配(粗细=权重) 普通字典(硬查找) d a v₁ b v₂ c v₃ d v₄ e v₅ → v₄ 仅匹配一个 key(完全选中)
图 3.2:左 — 软 attention(query 与所有 key 计算相似度,加权求和 values);右 — 普通字典(精确匹配单个 key)。

为什么"软"很关键?因为软查表是处处可导的。如果你硬要 argmax(hard attention),梯度无法回传,必须用 REINFORCE 等技巧。 而 softmax 让整个流程端到端可微,可用普通反向传播训练。

3.4 三种打分函数对比 (加性/点积/缩放)

$e_j = f(s, h_j)$ 这个相似度函数 $f$ 有几种历史选择:

名称公式来源特点
加性 (additive)
aka Bahdanau
$v^\top\tanh(W_1 s + W_2 h_j)$Bahdanau 2014表达力强;需要额外参数 $W_1, W_2, v$;计算慢
点积 (dot product)
aka Luong-dot
$s^\top h_j$Luong 2015无参数;要求 $s, h$ 维度相同;快
双线性 (bilinear)
aka Luong-general
$s^\top W h_j$Luong 2015有参数;允许不同维度;中速
缩放点积
(scaled dot)
$\dfrac{s^\top h_j}{\sqrt{d}}$Vaswani 2017无参数;防止 softmax 饱和(见 5.3 节)

Transformer 选择缩放点积,因为:

  1. 无额外参数 — 比加性少一组矩阵。
  2. 可用高度优化的矩阵乘法实现 — GPU 的 BLAS / cuBLAS 内核高效。
  3. 缩放因子 $\sqrt{d}$ 解决了点积在高维下方差爆炸的问题(详见 5.3 节推导)。

3.5 Attention 的五大收益

为什么 attention 被誉为深度学习史上最重要的 idea 之一?
  1. 显著提升 NMT 性能:Bahdanau 等人在 WMT'14 EN-FR 上将 BLEU 从 17.3 推到 28.5(+11 个点,巨大)。
  2. 解决瓶颈问题:decoder 可以直接看 encoder 的每个位置,不再受 $\boldsymbol{c}$ 容量限制。
  3. 缓解梯度消失:attention 在 encoder 与 decoder 之间架起直接捷径,反向传播梯度可以一步到达任何 encoder 位置,不必沿 $T$ 步链式衰减。
  4. 提供可解释性:attention distribution 可视化为热图,直观地看到"翻译某个词时模型在看源句哪里"——这是 NLP 史上第一次免费得到对齐信息(之前 SMT 要专门训 aligner)。
  5. 更"人类"的认知模型:人类翻译时也是回头看源文,不是把整段记忆死。

但要注意一个新引入的代价:计算复杂度 $O(S \cdot T)$(源长 × 译长)。在 self-attention 里这会变成 $O(n^2)$, 是后面所有"高效 Transformer"研究的源头。

3.6 通用框架:Query / Key / Value

本节最后我们把 attention 抽象到最一般的形式。一个 attention 模块由三部分输入构成:

Queries $Q \in \mathbb{R}^{n_q \times d_k}$
"我在找什么"
Keys $K \in \mathbb{R}^{n_k \times d_k}$
"我有什么可以匹配"
Values $V \in \mathbb{R}^{n_k \times d_v}$
"匹配上之后给你什么内容"

核心计算:

General Attention (Q, K, V 框架)
$$ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\Big(\frac{QK^\top}{\sqrt{d_k}}\Big)\,V $$ 其中 $QK^\top \in \mathbb{R}^{n_q \times n_k}$(每对 q-k 一个分数),softmax 沿 $n_k$ 维度做归一化。

"软查找表"的解读:$n_q$ 个 query 同时查询一张含 $n_k$ 条记录的表,每条记录由 key(索引)和 value(内容)组成。

对应到 seq2seq + attention:

角色seq2seq + attention来源 RNN
Querydecoder 当前隐状态 $s^{(t)}$Decoder RNN
Keyencoder 隐状态 $h^{(j)}$Encoder RNN
Valueencoder 隐状态 $h^{(j)}$(与 key 相同)Encoder RNN

在 seq2seq + attention 里 key = value,这只是历史巧合。第四部分我们会看到 self-attention 把它们分开了: 通过不同的投影矩阵 $W_K, W_V$ 让"索引信息"和"内容信息"可以各自学习,互不干扰。

Q / K / V 是同一个东西吗?
在数学上,它们可以是同一组向量经过不同投影,也可以是完全不同来源的向量。 两种用法:
  • Cross-attention(如 seq2seq):Q 来自 decoder,K/V 来自 encoder — 两端不同。
  • Self-attention:Q / K / V 都来自同一序列(如同一个句子)的不同投影 — 序列对自己做 attention。

本节思考题(研究生)

  1. 为什么 attention 缓解梯度消失?请用反向传播链分析:从 $J$ 到 $\boldsymbol{h}^{(1)}$ 在带 attention 的 seq2seq 中梯度长什么样?
  2. "Hard attention"(argmax 而非 softmax)需要怎么训练?为什么 soft attention 是默认选择?请用 REINFORCE 与 reparameterization trick 比较。
  3. Attention 分布常常被解读为"模型对齐",但 Jain & Wallace (2019) 论文标题就是 "Attention is not Explanation"。读完该论文后简要说明:为什么 attention 权重 ≠ 因果重要性?

第四部分:Self-Attention 的诞生

2017 年的关键飞跃是:把 attention 从 RNN 上"剥离"出来,让它独立成为一种序列建模单元。 新名字叫 self-attention(自注意力):一个序列对自己做 attention。 本部分逐步推导 self-attention 的数学定义、它无法独立工作的三大障碍、以及对应的工程补丁。

4.1 抛弃 RNN 的动机

仔细审视一下:上一节的 attention 究竟在做什么?它接受一个 query,对一组 keys/values 做加权检索,输出一个向量。 而 RNN 在做什么?它接受当前输入 $x_t$ 与前一时刻状态 $h_{t-1}$,输出一个向量。

两者的本质都是"信息聚合"。Attention 是"从所有位置加权聚合",RNN 是"沿着时间链聚合"。 那能不能用 attention 完全替代 RNN

答案是肯定的,而且收益是巨大的:

维度RNNSelf-Attention
并行化串行 $O(n)$ 步完全并行,1 步矩阵乘
最长路径长度$O(n)$(信号要走 n 步)$O(1)$(任意两位置直接相连)
每层计算复杂度$O(n \cdot d^2)$$O(n^2 \cdot d)$
梯度通路易消失/爆炸直接通路,稳定

当 $n < d$(短序列、高维度,比如 $n=128, d=1024$)时 self-attention 的总 FLOPs 反而更少。 即使 $n > d$,并行化带来的 wall-clock 加速依然显著。

4.2 Self-Attention 数学定义

设输入为一个词序列 $\boldsymbol{w}_1, \dots, \boldsymbol{w}_n$(例如 "Zuko made his uncle tea")。每个 $\boldsymbol{w}_i$ 通过词嵌入矩阵 $E \in \mathbb{R}^{d \times |V|}$ 转成向量: $\boldsymbol{x}_i = E\boldsymbol{w}_i \in \mathbb{R}^d$。

Self-attention 的三步骤

1三组线性投影得到 query / key / value。引入三个可学习矩阵 $Q, K, V \in \mathbb{R}^{d\times d}$(注意:这里复用了字母):
$$ \boldsymbol{q}_i = Q\boldsymbol{x}_i \quad \text{(query)}\qquad \boldsymbol{k}_i = K\boldsymbol{x}_i \quad \text{(key)}\qquad \boldsymbol{v}_i = V\boldsymbol{x}_i \quad \text{(value)} $$

关键观察:同一个 $\boldsymbol{x}_i$ 同时充当 query、key、value 三个角色,但经过不同投影矩阵。 这就是 "self-attention" 名字的由来——序列对自己做 attention。

2计算两两相似度并 softmax 归一化
$$ e_{ij} = \boldsymbol{q}_i^\top \boldsymbol{k}_j,\qquad \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{j'} \exp(e_{ij'})} $$ $\alpha_{ij}$ 表示"位置 $i$ 关注位置 $j$ 的程度"。注意它不对称:$\alpha_{ij} \ne \alpha_{ji}$。
3加权求和 values,得到每个位置的输出:
$$ \boldsymbol{o}_i = \sum_{j=1}^{n}\,\alpha_{ij}\,\boldsymbol{v}_j $$ $\boldsymbol{o}_i \in \mathbb{R}^d$ 是位置 $i$ 的新表示,融合了整个序列所有相关位置的信息。
I went to Stanford CS 224n and learned v v v v v v v v k k k k k k k k q attention weights for "learned"
图 4.1:Self-attention 假想示例(PDF p.42)。"learned" 作为 query 查询整句,所有词(包括自己)按相似度产生权重, 最终 "learned" 的新表示主要由 "CS"、"224n"、"learned" 自身的 value 加权而成。

4.3 三大障碍与对应方案

现在让我们诚实地问:能不能把 self-attention 一层接一层堆叠,直接做语言模型?答案是不能。 存在三个根本性障碍:

障碍问题描述方案
1. 无顺序信息Self-attention 对输入做的是置换不变(permutation invariant)的操作 — 把 "猫吃鱼" 重排成 "鱼吃猫" 输出相同!位置编码 (positional encoding)
2. 无非线性Self-attention 本质是 value 的加权平均(线性),堆叠多层只是重新加权再加权,仍是线性变换 — 无深度学习"魔法"逐位置 FFN(feed-forward network)
3. 偷看未来语言模型要求时刻 $t$ 只能看 $1\dots t-1$ 的词。Self-attention 默认看全部位置 — 训练时变成作弊因果掩码 (causal mask)
为什么 self-attention 置换不变?
设 $\pi$ 是输入位置的任意排列,输入变为 $\{x_{\pi(1)}, \dots, x_{\pi(n)}\}$。重新计算: $\boldsymbol{q}_{\pi(i)} = Q x_{\pi(i)}$ 等等。可以验证输出集合 $\{o_{\pi(1)}, \dots, o_{\pi(n)}\}$ 与原输出 $\{o_1, \dots, o_n\}$ 完全相同(只是位置同步重排)。所以模型无法区分 "I love you" 和 "you love I"。

4.4 位置编码:sinusoidal vs learned

解决第一个障碍:给每个位置 $i$ 分配一个位置向量 $\boldsymbol{p}_i \in \mathbb{R}^d$,然后把它加到词嵌入上:

注入位置信息
$$ \widetilde{\boldsymbol{x}}_i = \boldsymbol{x}_i + \boldsymbol{p}_i $$

之后所有计算用 $\widetilde{\boldsymbol{x}}_i$ 取代原始 $\boldsymbol{x}_i$。这一步只在第一层输入做(深层 self-attention 网络只在最底层注入位置信息)。

Sinusoidal Position Encoding(Vaswani 2017)

不同频率的正弦/余弦构造位置向量。第 $i$ 位置的第 $2k$ / $2k+1$ 维:

$$ \boldsymbol{p}_i \;=\; \begin{pmatrix} \sin(i / 10000^{2\cdot 1 / d}) \\ \cos(i / 10000^{2\cdot 1 / d}) \\ \vdots \\ \sin(i / 10000^{2\cdot d/2 / d}) \\ \cos(i / 10000^{2\cdot d/2 / d}) \end{pmatrix} $$
维度(高频→低频) 序列位置 i 0 256 512
图 4.2:Sinusoidal position encoding 的"条纹"模式。低维(底部)周期短、变化快;高维(顶部)周期长、变化慢。每个位置都得到一个独特的 d 维"指纹"。

优点:

缺点:

Learned Absolute Position Encoding

更简单粗暴的做法:直接学一个 $\boldsymbol{p} \in \mathbb{R}^{d \times n_{\max}}$ 矩阵,每列对应一个位置。BERT、GPT-1/2、RoBERTa 都用这种方式。

SinusoidalLearned
参数量0$d \times n_{\max}$(百万级)
外推理论可,实际差不可(超出 $n_{\max}$ 没意义)
灵活性固定模式每位置自由学
原始论文使用Vaswani 2017BERT, GPT-1/2

后续相对位置编码(relative, Shaw et al. 2018)、RoPE(Su et al. 2021,被 LLaMA / GPT-NeoX / Mistral 采用) 和 ALiBi(Press et al. 2022,BLOOM 采用)在外推性与建模能力上更强,详见第六部分。

4.5 前馈网络注入非线性

第二个障碍:self-attention 是线性的。证明:

$$ \boldsymbol{o}_i = \sum_j \alpha_{ij}\,\boldsymbol{v}_j = \sum_j \alpha_{ij}\,(V\boldsymbol{x}_j) = V\Big(\sum_j \alpha_{ij}\,\boldsymbol{x}_j\Big) $$ 即使 $\alpha_{ij}$ 是 softmax 算出来的(非线性),但 $\boldsymbol{o}_i$ 关于 values 是线性组合。 深层堆叠等价于"加权再加权"——多层 self-attention 的表示能力等价于一层(在 attention pattern 固定的意义下)。

修复办法:在每个 self-attention 层之后加一个逐位置(position-wise)的两层 FFN,引入 ReLU 非线性:

Position-wise Feed-Forward Network
$$ \boldsymbol{m}_i = \mathrm{MLP}(\boldsymbol{o}_i) = W_2\,\mathrm{ReLU}(W_1\,\boldsymbol{o}_i + \boldsymbol{b}_1) + \boldsymbol{b}_2 $$ 其中 $W_1 \in \mathbb{R}^{d_{\text{ff}}\times d}$, $W_2 \in \mathbb{R}^{d\times d_{\text{ff}}}$, $d_{\text{ff}} = 4d$ 是标准选择(如 $d=512, d_{\text{ff}}=2048$)。

"逐位置"意思是每个位置独立用同一个 FFN 处理——这两层 MLP 在所有位置间共享权重,但不混合位置信息。 混合信息的工作完全交给 self-attention。Transformer 因此是"attention 混位置,FFN 加非线性"的清晰分工。

FFN 占了多少参数?
对一个 Transformer block:
  • Self-attention 参数:$4 d^2$($Q, K, V, W_O$ 各 $d \times d$)
  • FFN 参数:$2 \times d \times 4d = 8 d^2$
所以 FFN 占了一个 block 中 ~67% 的参数。最近几年的研究(如 Geva et al. 2021 "Transformer FFN as Key-Value Memory")发现 FFN 实际上充当了模型的事实记忆库

4.6 因果掩码:防止偷看未来

第三个障碍:训练自回归语言模型时,预测 $y_t$ 只能看 $y_1, \dots, y_{t-1}$。但 self-attention 默认看整个序列,会"偷看" $y_t, y_{t+1}, \dots$。 直接做法是逐位置改变 keys/queries 的集合 — 但这无法并行(每个位置不同长度)。

聪明的并行解法:把"未来位置"的 attention score 设为 $-\infty$,softmax 之后这些位置的权重自动变成 0:

Causal masking
$$ e_{ij} = \begin{cases} \boldsymbol{q}_i^\top \boldsymbol{k}_j, & j \le i \\ -\infty, & j > i \end{cases} $$
⟨START⟩ The chef who cooks ↑ 可以看的(keys, j) ⟨START⟩ The chef who cooks ← 编码的(queries, i) -∞ -∞ -∞ -∞ -∞ -∞ -∞ -∞ -∞ -∞ 绿色:可注意(j ≤ i) 灰色:屏蔽未来(j > i)
图 4.3:因果掩码矩阵(lower triangular)。第 $i$ 行表示位置 $i$ 的 attention,绿色单元可参与 softmax,灰色单元被强制为 0。

PyTorch 实现一行:

mask = torch.triu(torch.full((n, n), float('-inf')), diagonal=1)
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
scores = scores + mask          # 上三角变 -inf
attn = scores.softmax(dim=-1)   # softmax 后上三角自动 = 0

注意三种 attention 的 mask 配置:

模型类型Mask用途
Encoder(双向)BERT, T5 encoder:理解任务
Decoder(因果)下三角GPT 系列:生成任务
Encoder-Decoder Cross-Attention无(query 总能看完整 encoder)翻译、摘要

4.7 障碍-方案总结

Self-attention 三大障碍 + 方案对照(PDF p.53)
障碍方案
无顺序信息加位置编码 $\boldsymbol{p}_i$ 到输入
无非线性逐位置 FFN(两层 + ReLU)
偷看未来上三角设 $-\infty$ 因果掩码

把这三块拼起来,就已经是一个最简版 Transformer block(minimal self-attention building block):

flowchart TB X["输入序列 x₁ … xₙ"] --> P["+ 位置编码 pᵢ"] P --> A["Masked Self-Attention
(Q, K, V)"] A --> F["Position-wise FFN
(W₂ ReLU(W₁ x))"] F --> O["输出表示 m₁ … mₙ"] O --> S["重复 N 次(堆叠)"] S --> L["最终 LM 头:softmax 词表"] style A fill:#fbecec style F fill:#eafaf0 style P fill:#fff5d6

下一部分我们再加 4 个工程优化:multi-head、scaled dot-product、残差、layer-norm。这就完整还原 Vaswani 2017 论文里的 Transformer Decoder。

本节思考题(研究生)

  1. 证明:如果在 self-attention 之后不加 FFN,堆叠两层 self-attention(不带 mask)等价于一层 self-attention 加一组新的投影矩阵。
  2. Sinusoidal 位置编码的"线性外推"性质:证明存在矩阵 $M_k$ 使得 $\boldsymbol{p}_{i+k} = M_k \boldsymbol{p}_i$。这与相对位置如何相关?
  3. 因果掩码的实现里,把 $-\infty$ 改成 $-10^9$ 实际效果一样吗?数值上有什么风险?

第五部分:完整 Transformer 架构

上一节我们得到了一个最简 self-attention building block。这一节加上 4 个让它"真的能 train 得起来"的工程优化: 矩阵化批量计算、multi-head、scaled dot product、残差连接 + Layer Norm。最终拼装成 Vaswani et al. 2017 的标准 Transformer。

5.1 矩阵化 Attention:批量并行

之前我们写的 self-attention 是"逐位置"形式(公式里有 $\sum_j$)。为了 GPU 并行,必须改写成矩阵乘法。

令 $X = [\boldsymbol{x}_1; \cdots; \boldsymbol{x}_n] \in \mathbb{R}^{n \times d}$ 为所有输入向量堆叠。则:

$$ XQ \in \mathbb{R}^{n\times d},\qquad XK \in \mathbb{R}^{n\times d},\qquad XV \in \mathbb{R}^{n\times d} $$

所有两两 query-key 点积一次性算出来:

$$ S = XQ\,(XK)^\top \in \mathbb{R}^{n\times n} $$ $S_{ij} = \boldsymbol{q}_i^\top \boldsymbol{k}_j$,即所有 $(i,j)$ 对的 attention 分数。

沿最后一维 softmax,再乘以 values:

矩阵化 Self-Attention
$$ \mathrm{Attention}(X) = \mathrm{softmax}\!\Big(\,XQ\,(XK)^\top\,\Big)\,XV \;\in\; \mathbb{R}^{n\times d} $$
XQ n × d KᵀXᵀ d × n = XQKᵀXᵀ n × n(所有分数) softmax( XQKᵀXᵀ ) · XV = output ∈ ℝⁿˣᵈ All pairs of attention scores!
图 5.1:矩阵化的 self-attention。两次矩阵乘法 + 一次 softmax 即可完成所有位置的 attention 计算(PDF p.57)。

比起 RNN 的 $n$ 步串行,这里只需两次大矩阵乘。在 GPU 上利用 cuBLAS 可以做到 $> 90\%$ peak FLOPs。

5.2 Multi-Head Attention

单一 self-attention 有一个根本局限:每个 query 只能产生一组权重分布。 但一个词在句子中可能同时和多个东西有不同性质的关系(语法、语义、共指、长距等)。让我们看一个 PDF 中的例子(p.56):

Head 1:关注实体 I went to Stanford CS 224n and Head 2:关注句法 I went to Stanford CS 224n 同一个 query("learned")在不同 head 下学到不同的"看哪里"策略
图 5.2:Multi-Head Attention 的直觉。Head 1 关注实体(Stanford / CS / 224n),Head 2 关注主语 (I, went)。

Vaswani 等人的解决方案:同时训 $h$ 套 Q/K/V 矩阵,每套独立做 attention,最后拼接:

Multi-Head Attention
$$ Q_\ell, K_\ell, V_\ell \in \mathbb{R}^{d \times d/h},\quad \ell = 1, \dots, h $$ $$ \text{head}_\ell = \mathrm{softmax}\!\Big(\frac{X Q_\ell K_\ell^\top X^\top}{\sqrt{d/h}}\Big)\,X V_\ell \;\in\; \mathbb{R}^{n \times d/h} $$ $$ \mathrm{MultiHead}(X) = [\text{head}_1; \cdots; \text{head}_h]\,W_O,\quad W_O \in \mathbb{R}^{d\times d} $$

关键设计:每个 head 的维度是 $d/h$(不是 $d$),所以所有 heads 拼起来正好是 $d$ 维。 计算量与单 head $d$-维 attention 相同,但表达力更强

高效实现要点
虽然概念上是 $h$ 套 Q/K/V,工程实现时通常用一次大投影 + reshape
  1. 用单个矩阵 $W_Q \in \mathbb{R}^{d \times d}$ 把 $X$ 投影到 $\mathbb{R}^{n \times d}$
  2. Reshape 成 $\mathbb{R}^{n \times h \times d/h}$,再 transpose 成 $\mathbb{R}^{h \times n \times d/h}$
  3. 把 head 维度当成 batch 维度处理("为啥不利用 GPU 的批量?")
PyTorch 一句:q = self.W_q(x).view(B, n, h, d_k).transpose(1, 2)

实证上,常用 head 数 $h$ 与维度 $d$:

模型$d$$h$$d/h$
Transformer Base (2017)512864
BERT-Base7681264
BERT-Large10241664
GPT-3 175B1228896128
LLaMA-2 70B819264128

注意 $d/h$ 几乎总是 64–128 —— 这不是巧合,与 GPU tensor core 的 16/32/64 对齐有关。

5.3 Scaled Dot-Product 的数学缘由

为什么要除以 $\sqrt{d_k}$?让我们做一个简单概率推导。

假设 $\boldsymbol{q}, \boldsymbol{k} \in \mathbb{R}^{d_k}$ 的各分量都是 $\mathcal{N}(0,1)$ 独立采样。那么它们的点积:

$$ \mathbb{E}[\boldsymbol{q}^\top \boldsymbol{k}] = \sum_{i=1}^{d_k} \mathbb{E}[q_i k_i] = 0 $$ $$ \mathrm{Var}[\boldsymbol{q}^\top \boldsymbol{k}] = \sum_{i=1}^{d_k} \mathrm{Var}[q_i k_i] = d_k $$

所以 $\boldsymbol{q}^\top \boldsymbol{k}$ 的标准差是 $\sqrt{d_k}$。当 $d_k = 64$ 时,分数典型量级是 $\pm 8$;当 $d_k = 128$ 时,是 $\pm 11.3$。

softmax 对输入尺度极其敏感:输入差 1 时概率比是 $e \approx 2.7$;差 10 时是 $e^{10} \approx 22000$;差 20 时是 $5 \times 10^8$!

大尺度 + softmax = 梯度死亡
当 softmax 输入差距 ≫ 1,输出会极度尖锐(接近 one-hot)。此时:
  • 梯度在非主导位置都接近 0(softmax 进入饱和区)
  • 训练几乎学不到新的 attention pattern
  • 初始训练阶段尤其严重,模型容易"卡死"

解决方案就是缩放

Scaled Dot-Product Attention
$$ \mathrm{Attention}(Q,K,V) = \mathrm{softmax}\!\Big(\frac{QK^\top}{\sqrt{d_k}}\Big)\,V $$

除以 $\sqrt{d_k}$ 后,分数的方差变回 1,softmax 输入处于"合理量级",梯度健康。

5.4 残差连接:梯度高速公路

当我们把 Transformer 堆叠到 6 层、12 层、24 层、48 层、96 层(GPT-3)时,纯堆叠会再次遇到梯度消失(不是 RNN 那种沿时间的消失,而是沿深度的消失)。 He et al. 2016 在 ResNet 中提出的残差连接是普世解:

残差连接 (Residual Connection)
$$ X^{(i)} = X^{(i-1)} + \mathrm{Layer}(X^{(i-1)}) $$

直觉:让恒等映射(identity mapping)成为基线,子层 $\mathrm{Layer}(\cdot)$ 只需要学习"残差"。

无残差:X⁽ⁱ⁾ = Layer(X⁽ⁱ⁻¹⁾) Layer X⁽ⁱ⁻¹⁾ X⁽ⁱ⁾ 梯度需穿过 Layer 的所有非线性 有残差:X⁽ⁱ⁾ = X⁽ⁱ⁻¹⁾ + Layer(X⁽ⁱ⁻¹⁾) Layer + skip connection(恒等) X⁽ⁱ⁻¹⁾ X⁽ⁱ⁾ 梯度有"高速公路"直达浅层(系数=1)
图 5.3:残差连接(右)相对于普通堆叠(左),多出一条"梯度高速公路"。

反向传播时:

$$ \frac{\partial X^{(i)}}{\partial X^{(i-1)}} = I + \frac{\partial\,\mathrm{Layer}(X^{(i-1)})}{\partial X^{(i-1)}} $$

即使 Layer 的 Jacobian 接近 0,梯度仍能通过 $I$(恒等)项无衰减回传。这是训练超深网络(>50 层)的必要条件。 Li et al. 2018 的 loss landscape 可视化证明:有残差的网络 loss landscape 平滑得多,更容易优化。

5.5 Layer Normalization

Layer Normalization (Ba et al. 2016) 解决另一个工程难题:不同层、不同位置的激活值尺度差异巨大,导致训练不稳定。

对每个位置的隐藏向量 $\boldsymbol{x} \in \mathbb{R}^d$,LayerNorm 沿特征维度归一化:

Layer Normalization
$$ \mu = \frac{1}{d}\sum_{j=1}^{d} x_j,\qquad \sigma^2 = \frac{1}{d}\sum_{j=1}^{d}(x_j - \mu)^2 $$ $$ \mathrm{LayerNorm}(\boldsymbol{x}) = \frac{\boldsymbol{x} - \mu}{\sqrt{\sigma^2 + \epsilon}}\,\odot\,\boldsymbol{\gamma} + \boldsymbol{\beta} $$ 其中 $\boldsymbol{\gamma}, \boldsymbol{\beta} \in \mathbb{R}^d$ 是可学习的"放大"与"偏移"参数("gain" and "bias")。
BatchNormLayerNorm
归一化轴沿 batch 维(同一特征在不同样本上)沿特征维(同一样本的所有特征)
需要 batch是(小 batch 不稳)否(单样本也行)
训练/推理一致否(推理用 running stats)
序列长度敏感

LayerNorm 在 Transformer 流行后逐渐成了序列模型的事实标准。最近的 LLaMA / GPT-NeoX 使用 RMSNorm(去掉均值减去)也工作得很好,速度更快。

5.6 Transformer Decoder Block 拼装

把 5.1–5.5 节的所有组件拼起来:

完整 Transformer Decoder Block(一层)
$$ Z = \mathrm{LayerNorm}\!\big(X + \mathrm{MultiHeadAttn}^{\text{causal}}(X)\big) $$ $$ X^{\text{out}} = \mathrm{LayerNorm}\!\big(Z + \mathrm{FFN}(Z)\big) $$

这是"Post-LN"版本(先 add 再 norm),原始 Vaswani 2017 的形式。

Embeddings Decoder Inputs Add Position Embeddings Masked Multi-Head Attention Add & Norm 残差 Feed-Forward (FFN) Add & Norm 重复 N 次(堆叠 Block) Linear Softmax Probabilities
图 5.4:完整 Transformer Decoder。一个 block 由 Masked MHA → Add&Norm → FFN → Add&Norm 组成,重复 N 次后接 Linear + Softmax。

5.7 Transformer Encoder 与 Encoder-Decoder

三种 Transformer 变种:

变种Mask典型代表用途
Decoder-only因果(下三角)GPT 系列, LLaMA, Mistral语言生成、对话
Encoder-only无 mask(双向)BERT, RoBERTa, DeBERTa分类、检索、NER
Encoder-Decoderencoder 无 / decoder 有T5, BART, mT5, NMT翻译、摘要、SeqSeq

Encoder 与 Decoder 唯一区别:去掉因果掩码,让每个位置看整个序列。BERT 因此称为"双向"Transformer。

Encoder-Decoder 则把两者组合:

Encoder Block Embeddings Encoder Inputs Add Position Embeddings Multi-Head Attention (无 mask) Add & Norm Feed-Forward Add & Norm Decoder Block Embeddings Decoder Inputs Add Position Embeddings Masked MHA (self) Add & Norm Cross MHA (encoder K,V) Add & Norm Feed-Forward + Add & Norm Linear + Softmax encoder 输出 → K, V
图 5.5:Encoder-Decoder Transformer。Decoder 在 self-attention 之后多一层 cross-attention,从 encoder 拿到 K, V。

5.8 Cross-Attention 细节

Cross-attention 是 encoder-decoder Transformer 的关键。它让 decoder 在生成时"回看"源句信息。 形式上与 self-attention几乎一模一样,只是 Q 和 K/V 来源不同:

Cross-Attention

令:

  • $\boldsymbol{h}_1, \dots, \boldsymbol{h}_n$:来自 encoder 的输出向量
  • $\boldsymbol{z}_1, \dots, \boldsymbol{z}_m$:来自 decoder 的输入向量(已经过 self-attention 处理)

则:

$$ \boldsymbol{k}_i = K\,\boldsymbol{h}_i,\quad \boldsymbol{v}_i = V\,\boldsymbol{h}_i,\quad \boldsymbol{q}_j = Q\,\boldsymbol{z}_j $$ $$ \mathrm{CrossAttn} = \mathrm{softmax}\!\Big(\frac{QK^\top}{\sqrt{d_k}}\Big)\,V $$
Cross-attention 的两个用法
  • 翻译:decoder 生成英语 token 时回看法语 encoder 输出。
  • 多模态:decoder(文本)回看 vision encoder 的图像 patch tokens(如 BLIP-2、Flamingo)。
Cross-attention 是跨模态信息融合的最经典机制,多模态大模型几乎人人都用。

本节思考题(研究生)

  1. Multi-head 的 $h$ 个 head 维度都是 $d/h$,能否让某些 head 维度更大、某些更小?这就是"Mixture-of-Heads"。读 Karpathy minGPT 后讨论可行性。
  2. 为什么 LayerNorm 用 $\gamma, \beta$ 重新引入缩放和偏移,难道不抵消了归一化的效果吗?给出信息论或优化论解释。
  3. BERT 不用 cross-attention 是不是反而成了缺点?为什么 T5 把 encoder-decoder 重新带回主流?

第六部分:实验结果与现代演化

2017 年那篇论文只是 Transformer 的起点。过去 8 年里,Transformer 经历了大量"非原始"改进, 其中很多是 PDF 课堂上没有展开的研究生必备内容:Pre-LN、RoPE、Flash Attention、KV cache、MQA/GQA。 本部分系统串讲这些演化。

6.1 原始 Vaswani 2017 结果

Transformer 在 WMT 2014 翻译任务上一鸣惊人:

模型EN-DE BLEUEN-FR BLEU训练 FLOPs
ByteNet (CNN)23.75
GNMT + RL24.639.92$2.3 \times 10^{19}$
ConvS2S25.1640.46$9.6 \times 10^{18}$
MoE26.0340.56$2.0 \times 10^{19}$
GNMT + RL Ensemble26.3041.16$1.8 \times 10^{20}$
ConvS2S Ensemble26.3641.29$7.7 \times 10^{19}$
Transformer (base)27.338.1$3.3 \times 10^{18}$
Transformer (big)28.441.8$2.3 \times 10^{19}$

亮点:Transformer base 用少一个数量级的 FLOPs 击败所有 RNN/CNN baselines; Transformer big 在 EN-DE 上比之前最好高 1.5+ BLEU,在 EN-FR 创下新 SOTA。

更震撼的是文档生成(WikiSum, Liu et al. 2018):原来 seq2seq+attn 的 perplexity 是 5.05、ROUGE-L 12.7, Transformer 把 perplexity 拉到 1.90、ROUGE-L 38.8 — "Transformers all the way down"

6.2 Transformer 的痛点清单

原始 Transformer 的几个根本痛点
  1. $O(n^2)$ 复杂度:序列长度翻倍,attention FLOPs 翻 4 倍。$n=50,000$ 时单层 $2.5 \times 10^9$ 次操作。
  2. $O(n^2)$ 显存:attention matrix 要存下 $n \times n$。GPT-3 的 $n=2048$ 已是显存瓶颈。
  3. 位置编码不外推:训练时 $n_{\max}=512$,推理时输入 $n=1024$,模型崩溃。
  4. 推理慢:自回归生成时每生成一个 token 要从头重算所有 attention 矩阵。
  5. 训练不稳:原始 Post-LN 配方需要 warmup learning rate、精心调梯度裁剪,对超参敏感。

下面几节讲针对这些痛点的现代解。

6.3 Pre-LN vs Post-LN 之争

原始 Transformer 是 Post-LN:先 add 残差,再 LayerNorm。 Xiong et al. 2020 ("On Layer Normalization in the Transformer Architecture") 证明,Pre-LN(先 norm 再 attention,残差直接相加)训练更稳定。

两种 LayerNorm 位置
Post-LN(原始): $$ X^{\text{out}} = \mathrm{LN}\!\big(X + \mathrm{SubLayer}(X)\big) $$ Pre-LN(现代): $$ X^{\text{out}} = X + \mathrm{SubLayer}\!\big(\mathrm{LN}(X)\big) $$
Post-LN(Vaswani 2017) Input X SubLayer (Attn/FFN) + LayerNorm Pre-LN(GPT-2 及之后) Input X LayerNorm SubLayer (Attn/FFN) +
图 6.1:Post-LN(左)vs Pre-LN(右)。Pre-LN 让残差直通到加法点,梯度路径更纯净。
Post-LNPre-LN
训练稳定性需 warmup 学习率,容易发散无需 warmup,可用大学习率
深度可堆叠性>12 层易梯度爆炸可堆 100+ 层(GPT-3, LLaMA)
峰值性能调好后略优略弱但稳定
采用模型原始 Transformer, BERTGPT-2/3, LLaMA, PaLM, T5, Mistral

现代几乎所有 LLM 都用 Pre-LN。BERT 是少有的 Post-LN 例外(因为已工业部署,难以改)。

6.4 现代位置编码:RoPE 与 ALiBi

原始 sinusoidal / learned absolute position 有两个缺点:

  1. 外推差:训练时见过 $n_{\max}=512$,推理超过就崩溃。
  2. 建模"相对位置"间接:通常我们关心 "i 比 j 早多少位",而非具体位置。

RoPE: Rotary Position Embedding (Su et al. 2021)

不在输入加位置向量,而是在 attention 内部对 query/key 做"旋转"。把 $\boldsymbol{q}_i \in \mathbb{R}^d$ 拆成 $d/2$ 个 2D 子空间,每个子空间用旋转矩阵 $R_\theta$ 旋转 $i\theta$ 角度:

RoPE 旋转
$$ \mathrm{RoPE}(\boldsymbol{q}_i, i) = R_{\boldsymbol{\Theta}, i}\,\boldsymbol{q}_i,\quad R_{\boldsymbol{\Theta},i} = \begin{pmatrix} \cos(i\theta_1) & -\sin(i\theta_1) & & \\ \sin(i\theta_1) & \cos(i\theta_1) & & \\ & & \ddots & \\ & & & \cos(i\theta_{d/2}) \end{pmatrix} $$ 点积 $\boldsymbol{q}_i^\top \boldsymbol{k}_j$ 自然变成 $\boldsymbol{q}_i^\top R_{\boldsymbol{\Theta}, j-i} \boldsymbol{k}_j$,只依赖相对位置 $j-i$

优点:

LLaMA 1/2/3、Mistral、PaLM、Qwen、GLM-4 等绝大多数主流 LLM 都用 RoPE。

ALiBi: Attention with Linear Biases (Press et al. 2022)

更激进:完全不要位置编码,只在 attention 分数上加一个"距离惩罚"线性偏置:

$$ e_{ij} = \boldsymbol{q}_i^\top \boldsymbol{k}_j - m_h \cdot |i - j| $$ 其中 $m_h$ 是每个 head 不同的斜率(如 $m_h = 2^{-8h/H}$)。

ALiBi 的核心好处是极强外推:训练时 $n=1024$,推理时 $n=16384$ 仍可工作。BLOOM、MPT 等模型采用 ALiBi。

6.5 Flash Attention 与 IO 复杂度

Dao et al. 2022 的 FlashAttention 是工程上最重要的 Transformer 优化之一。 它没有改变数学,只是改了 GPU 内存访问方式

原始 attention 在 GPU 上的瓶颈不是 FLOPs,而是HBM ↔ SRAM 之间的 IO

HBM (High Bandwidth Memory)SRAM (片上 cache)
容量40-80 GB~ 20 MB
带宽~1.5 TB/s~19 TB/s(10×+)
访问代价

原始实现需要:

  1. 读 Q, K 到 SRAM,算 $S = QK^\top \in \mathbb{R}^{n\times n}$,写回 HBM
  2. 读 $S$ 回 SRAM,softmax,写回
  3. 读 softmax 和 V,算 output,写回

显存中实例化的 $n \times n$ 矩阵和反复读写 HBM 是慢的元凶。FlashAttention 用 tiling + online softmax, 把 Q, K, V 分块处理,每个块在 SRAM 里走完整流程,不实例化完整 attention 矩阵

FlashAttention 的实际收益
  • 训练速度:1.7-2.4 倍(A100 上 GPT-2 实测)
  • 显存占用:$O(n^2) \to O(n)$(关于 $n$ 线性!)
  • 实现:torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+) 默认调用
  • FlashAttention-2(2023)和 FlashAttention-3(2024)进一步优化,是当今 LLM 训练事实标准

注意:FlashAttention 的数学结果与原始 attention 完全相同(在浮点误差内),只是更快。这是工程优化的典范。

6.6 推理优化:KV Cache 与 MQA/GQA

自回归生成一个长度 $n$ 序列时,第 $t$ 步的 attention 需要重新计算 $1, \dots, t-1$ 的 K, V 投影。 但这些 K, V 不依赖当前 query——是纯函数$f(x_{1:t-1})$,可以缓存

KV Cache

KV 缓存策略
生成第 $t$ 个 token 时:
  1. 从缓存读出 $K_{1:t-1}, V_{1:t-1}$
  2. 仅计算新 token 的 $\boldsymbol{k}_t = K\boldsymbol{x}_t$, $\boldsymbol{v}_t = V\boldsymbol{x}_t$(常数代价
  3. append 到缓存:$K_{1:t} = [K_{1:t-1}; \boldsymbol{k}_t]$
  4. 当前 query $\boldsymbol{q}_t$ 与完整 $K_{1:t}, V_{1:t}$ 做 attention

推理时间从 $O(n^2)$ 降到 $O(n)$(生成 $n$ tokens 总共)。代价是显存——KV cache 大小 $= 2 \cdot L \cdot n \cdot d$($L$ 层)。 对于 LLaMA-2 70B 在 $n=4096$ 时 KV cache 已经超过 5GB / 用户,是当代 LLM 服务的关键瓶颈

MQA & GQA:减少 KV cache

多头 attention 每 head 都有独立的 K, V,缓存巨大。Multi-Query Attention(MQA, Shazeer 2019)让 $h$ 个 head 共享同一组 K, V

$$ \text{MHA: } \;K, V \in \mathbb{R}^{h \times n \times d_k} $$ $$ \text{MQA: } \;K, V \in \mathbb{R}^{1 \times n \times d_k} $$ KV cache 缩小 $h$ 倍(如 $h=64$ 缩小到 1/64)!

缺点:质量下降。Grouped-Query Attention(GQA, Ainslie et al. 2023)折中——把 $h$ 个 head 分成 $g$ 组(如 $g=8$),每组共享一组 K/V:

变种$K, V$ 数量缓存大小质量采用
MHA$h$ 组×$h$最佳原始 Transformer
MQA1 组×1(极省)下降PaLM, Falcon
GQA$g$ 组 ($g \ll h$)×$g$(中等)接近 MHALLaMA-2/3, Mistral, Qwen

其他推理优化(速览)

Speculative Decoding
用小模型先猜几个 token,大模型只做验证(一次 forward 多 token)
PagedAttention (vLLM)
把 KV cache 像虚拟内存一样分页管理,提升吞吐量 5-10×
Quantization
权重和 KV 缓存用 INT4/INT8 存储,省 2-4× 显存
Continuous Batching
不同请求长度不同时,动态打包 GPU batch,提升 GPU 利用率

6.7 Transformer 的"终极痛点"思考

课堂 PDF 最后一页(p.71)提了一个有趣问题:"既然 attention 在大模型中只占少部分计算,为什么还要追求线性 attention?"

研究生级的悖论
当 Transformer 长到 GPT-4 量级时:
  • FFN 参数量 ≫ attention 参数量 → 多数 FLOPs 在 FFN 上
  • 但 attention 仍然是 $O(n^2)$,是长上下文(>100k tokens)的瓶颈
  • 所以"线性 attention"(如 Linear Transformer, RWKV, Mamba, Linear Attention)仍有研究价值,但商业模型很少采用——因为质量损失难以接受
2024 年的 SSM (State Space Models, e.g., Mamba) 试图重新挑战 Transformer 在长序列上的霸主地位,可视为一个未结案的开放问题。

本节思考题(研究生)

  1. Pre-LN 的"梯度高速公路"性质:用反向传播证明,在 Pre-LN 下从输出到第 1 层的 Jacobian 包含恒等项 $I$。
  2. RoPE 实际上把 absolute → relative。这与 Shaw et al. 2018 的 relative position embedding 有什么本质差别?
  3. 读 FlashAttention 论文,解释 "online softmax" 怎么在不实例化完整 $S$ 矩阵的情况下做归一化。
  4. MQA 让 KV cache 缩 $h$ 倍但质量下降 — 为什么?从信息论给出一个解释。

附录:PyTorch 完整实现

这一节给出一个完整可运行的 minimal Transformer 实现,约 150 行核心代码。 适合直接复制运行、与 Andrej Karpathy 的 nanoGPT 互相参照。

A.1 Scaled Dot-Product Attention

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    """
    q, k, v: (..., n, d_k)
    mask:    (..., n, n) of 0/1,1 表示可见,0 表示屏蔽
    返回:    (..., n, d_k)
    """
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)      # (..., n, n)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = F.softmax(scores, dim=-1)                       # (..., n, n)
    return attn @ v                                        # (..., n, d_v)

A.2 Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k     = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_q, x_kv, mask=None):
        B, n_q, _ = x_q.shape
        _, n_kv, _ = x_kv.shape
        # 1) project & reshape to (B, h, n, d_k)
        q = self.W_q(x_q ).view(B, n_q,  self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x_kv).view(B, n_kv, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x_kv).view(B, n_kv, self.n_heads, self.d_k).transpose(1, 2)

        # 2) scaled dot-product attention
        out = scaled_dot_product_attention(q, k, v, mask)  # (B, h, n_q, d_k)

        # 3) concat heads and project out
        out = out.transpose(1, 2).contiguous().view(B, n_q, self.d_model)
        return self.dropout(self.W_o(out))

注:传 x_q == x_kv 就是 self-attention;传不同就是 cross-attention

A.3 Position-wise FFN

class FFN(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.relu(self.w1(x))))
        # 现代版用 GELU 或 SwiGLU 替换 ReLU

A.4 Decoder Block(Pre-LN 现代版)

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1  = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ln2  = nn.LayerNorm(d_model)
        self.ffn  = FFN(d_model, d_ff, dropout)

    def forward(self, x, mask=None):
        # Pre-LN: 先 norm, 再子层, 残差
        x = x + self.attn(self.ln1(x), self.ln1(x), mask)
        x = x + self.ffn(self.ln2(x))
        return x

A.5 Positional Encoding (sinusoidal)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)   # 不当作参数

    def forward(self, x):
        # x: (B, n, d_model)
        return x + self.pe[:x.size(1)]

A.6 完整 Decoder-only LM (GPT-mini)

class GPTMini(nn.Module):
    def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=6,
                 d_ff=2048, max_len=512, dropout=0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len)
        self.blocks  = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        # weight tying(GPT-2 trick)
        self.head.weight = self.tok_emb.weight

    def forward(self, idx, targets=None):
        B, n = idx.shape
        x = self.tok_emb(idx)                # (B, n, d_model)
        x = self.pos_emb(x)
        causal = torch.tril(torch.ones(n, n, device=idx.device)).unsqueeze(0).unsqueeze(0)
        for blk in self.blocks:
            x = blk(x, mask=causal)
        x = self.ln_f(x)
        logits = self.head(x)                # (B, n, vocab_size)
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100
            )
            return logits, loss
        return logits, None

# 训练一个 batch 的 minimal 示例
model = GPTMini(vocab_size=50257)
x = torch.randint(0, 50257, (2, 64))         # B=2, n=64
y = torch.randint(0, 50257, (2, 64))
logits, loss = model(x, y)
loss.backward()
print(f"loss = {loss.item():.4f}, params = {sum(p.numel() for p in model.parameters()):,}")
代码与本课公式的对应
  • 5.1 矩阵化 attentionq @ k.transpose(-2, -1)
  • 5.3 scaled dot product/ math.sqrt(d_k)
  • 5.2 multi-head.view(B, n, h, d_k).transpose(1, 2)
  • 4.6 因果掩码torch.tril(torch.ones(n, n))
  • 5.4 残差x = x + self.attn(...)
  • 5.5 LayerNormnn.LayerNorm(d_model)
  • 4.5 FFNw2(ReLU(w1(x)))
  • 4.4 位置编码PositionalEncoding
  • 6.3 Pre-LNx + self.attn(self.ln1(x), ...)

A.7 进阶练习

  1. 实现 RoPE:在 MultiHeadAttention.forward 中对 $q, k$ 做旋转,不传 absolute position embedding。参考实现
  2. 实现 KV cache:在 forward 增加 past_kv 参数,递增生成时只算新 token 的 attention。
  3. 用 nanoGPT 训一个 char-level Shakespeare 模型:6 层、64 head_dim、约 1 万行代码训完一个 demo,是理解 Transformer 内部最快路径。
  4. 实现 cross-attention:把上面 Block 改成 Encoder-Decoder 形式,加上一层 cross-attn。
  5. 实测 FlashAttention 加速:把 scaled_dot_product_attention 替换为 F.scaled_dot_product_attention(q, k, v, is_causal=True)(PyTorch ≥ 2.0),测 forward+backward 时间。