"如果你能把"预测下一个词"这件事做到极致,你就解决了几乎所有 NLP 问题。" — 这是 ChatGPT 之后越来越多人意识到的真相。
在 2018 年 OpenAI 那篇 GPT-1 论文以前,"语言模型 (Language Model, LM)" 这个词在 NLP 工程师心中常常只是统计机器翻译里那个用来给候选译文打分的小工具。 但短短七年后,GPT-4 / Claude / Gemini 都是语言模型——更确切地说,是自回归语言模型 (autoregressive LM)。 所有所谓的"会写代码、能解题、能对话"的能力,本质都浓缩在一个数学对象里:
本课(CS224N Lecture 4)正是把这个对象第一次拿到台面上详细讲解。它不仅是后续 attention、Transformer、预训练、RLHF 的共同地基, 也是理解现代 LLM 行为(hallucination、长上下文、推理)必须的入门。
注意几条"演化驱动"的虚线:每一步都不是凭空冒出来的新模型,而是上一步的具体失败模式所逼出来的。这种"沿着失败前进"的学习方式,比单独背诵每个模型更牢固。
最直白的定义:语言建模 = 预测下一个词的概率分布。 给定一段已知文本,例如 "the students opened their ___",模型需要为词表中每一个可能的词输出一个概率: "books" 0.21、"laptops" 0.18、"minds" 0.07、"exams" 0.06、"refrigerator" 0.00003、…,所有概率之和等于 1。
更形式地,给定词序列 $\boldsymbol{x}^{(1)}, \boldsymbol{x}^{(2)}, \ldots, \boldsymbol{x}^{(t)}$,语言模型计算:
这里 $V$ 是词表 (vocabulary),$|V|$ 通常在 $10^4$(小型)到 $10^5$(GPT-2 是 50,257)甚至 $10^6$(多语言模型)之间。 注意几个细节:
有了"下一个词"的条件概率,我们就能用概率乘法链式法则 (chain rule) 计算整段文本的概率:
这个分解之所以叫"自回归 (autoregressive)",是因为每一步预测都依赖前面所有时刻已知的输出,像把一个长问题拆成 $T$ 个串行的小问题。 这也意味着推断(生成文本)天然是顺序的,无法像 BERT 那样并行——这正是现代 LLM 推断慢的根本原因,也是 KV-cache、speculative decoding 等优化技术的出发点。
Diyi Yang 在课程里说:"Language Modeling is the most important concept in this class. It leads to most of modern NLP." 这句话不是夸张,而是基于两个事实:
| 下游任务 | 为什么需要 LM |
|---|---|
| 输入法预测/手机自动补全 | 直接输出 $P(\boldsymbol{x}^{(t+1)}\mid \text{已输入})$,挑前 3 个候选 |
| 语音识别 | 声学模型给出多个候选转录,LM 打分挑出"最像人话"的 |
| 手写识别 / OCR | 同上:多个候选 → LM 重排 |
| 拼写/语法纠错 | "He |
| 统计机器翻译 (SMT) | 翻译模型 + LM 联合打分(IBM 模型时代的经典分解) |
| 文档摘要 / 对话生成 | 生成式输出本身就是从 LM 采样 |
| 作者识别 | 每个候选作者训练一个 LM,看测试文本在哪个 LM 下概率最高 |
LM 的损失(cross-entropy / perplexity)本身就是衡量模型有多懂语言的客观指标,不需要任务特定的标注。 这意味着任何研究团队都能用同样的 corpus(如 WikiText-103、The Pile)对比自己的方法。整个 2010s 的神经 LM 论文都在比较这一个数字。
为了感受"足够强的 LM"能做什么,看几个填空题——每一题都隐含着不同层级的语言学/世界知识:
| 填空 | 所需知识层级 |
|---|---|
| Stanford University is located in ___, California. | 百科事实 Trivia |
| I put ___ fork down on the table. | 语法 Syntax (need "the"/"my") |
| The woman walked across the street, checking for traffic over ___ shoulder. | 指代消解 Coreference (her) |
| I went to the ocean to see the fish, turtles, seals, and ___. | 词汇语义/主题 Lexical semantics |
| Overall, the value I got from the two hours was the sum of popcorn and drink. The movie was ___. | 情感推理 Sentiment |
| Iroh went into the kitchen. Standing next to Iroh, Zuko pondered his destiny. Zuko left the ___. | 常识推理 Reasoning (kitchen) |
| I was thinking about the sequence 1, 1, 2, 3, 5, 8, 13, 21, ___ | 数学 Arithmetic (Fibonacci) |
注意一个事实:所有这些任务,从语法填空到斐波那契数列,都可以被表述成"预测下一个词"。 所以一个理想的 LM——只要它真的能把所有 next-token 概率算准——就必须隐含语法、语义、世界知识、常识推理、甚至数学能力。 这就是 Sutskever 的著名论断:"Just predicting the next token is enough." 当然,"enough" 的代价是巨大的数据和算力。
在 2003 年 Bengio 提出神经语言模型以前,从 1980s 到 2000s 主导 NLP 的 LM 范式叫做 n-gram language model—— 完全基于计数和统计,没有任何"学习"过程(没有梯度,没有参数优化)。理解它对于理解神经 LM 的动机至关重要: 神经 LM 的每一处改进,几乎都是针对 n-gram 的具体缺陷。
n-gram LM 的核心是一个简化假设:下一个词只依赖于前面 $n-1$ 个词,而非全部历史。这就是经典的 $(n-1)$ 阶 Markov 假设。
让我们先约定术语:
| 名称 | 定义 | 例子:"the students opened their" |
|---|---|---|
| unigram | 1 个连续词 | "the", "students", "opened", "their" |
| bigram | 2 个连续词 | "the students", "students opened", "opened their" |
| trigram | 3 个连续词 | "the students opened", "students opened their" |
| 4-gram | 4 个连续词 | "the students opened their" |
应用条件概率定义 $P(A\mid B) = P(A,B)/P(B)$:
分子是 n-gram 的概率,分母是 $(n-1)$-gram 的概率。问题转化为:如何估计这两个联合概率?
最直接的答案:最大似然估计 (Maximum Likelihood Estimation, MLE)——在一个大语料库里数频次!
假设我们的语料是 "as the proctor started the clock, the students opened their ___"。在 4-gram 假设下, 模型只看 "students opened their",前面的 "as the proctor started the clock," 全部丢弃——这是 Markov 假设的代价。
n-gram LM 有两个致命的稀疏性问题:
如果某个 4-gram "students opened their petunias" 在语料中从未出现,则 count = 0,模型直接判定 $P(\text{petunias} \mid \cdot) = 0$。 但这显然不对——"petunias (矮牵牛花)" 罕见但不是不可能。
解决方案:加性平滑 (additive / Laplace smoothing)。给每个词的计数都加一个小常数 $\delta$(典型 $\delta=1$ 或 $\delta=0.01$):
分母里加 $\delta\cdot|V|$ 保证概率仍然归一化(因为我们对 $|V|$ 个词都做了加 $\delta$)。这种朴素方法在小词表上能用,但把概率质量从高频词偷给低频词的力度太粗暴。 现代 n-gram 系统采用更精细的方案:
如果连 "students opened their" 这个 trigram 在语料中都没出现,那分母也是 0,整个表达式没有意义。
解决方案:回退 (backoff)。当 $n$-gram 不可用时,退到 $(n-1)$-gram。比如 4-gram "students opened their _" 不够用,就退到 trigram "opened their _"; 还不够,再退到 bigram "their _"。系统化的实现叫 Stupid Backoff (Brants 2007) 或者上面提到的 KN 的迭代版本。
除了稀疏性,n-gram LM 还需要把语料中出现过的所有 n-gram 计数全部存起来。 对于 trigram + 1.7M 词的 Reuters 语料,模型大小约几百 MB;但如果跨到 5-gram、上万亿词的 Web 语料(Google 2006 的 Web 5-gram 数据集解压后 ~24 GB),存储就成为工程瓶颈。 更大的问题是:语料越大,模型越大;模型大小线性甚至超线性增长——这与神经 LM "数据再多模型大小不变" 形成鲜明对比。
让我们用代码感受 n-gram LM 的能力极限。下面是一个 50 行的 trigram LM,训练在 Reuters 商业新闻语料上:
from collections import defaultdict, Counter
import random
import nltk
from nltk.corpus import reuters
nltk.download('reuters', quiet=True)
nltk.download('punkt', quiet=True)
# 1. 收集所有 trigram 计数
tri_counts = defaultdict(Counter) # context (w1,w2) -> {w3: count}
for sent in reuters.sents():
tokens = ["<s>", "<s>"] + [w.lower() for w in sent] + ["</s>"]
for w1, w2, w3 in zip(tokens, tokens[1:], tokens[2:]):
tri_counts[(w1, w2)][w3] += 1
# 2. 从 trigram 计数得到条件概率(无平滑版)
def next_word_dist(context):
counts = tri_counts[context]
total = sum(counts.values())
return {w: c/total for w, c in counts.items()}
# 3. 自回归采样生成
def generate(max_len=40, seed=("<s>", "<s>")):
out, context = [], seed
for _ in range(max_len):
dist = next_word_dist(context)
if not dist: break
words, probs = zip(*dist.items())
nxt = random.choices(words, weights=probs)[0]
if nxt == "</s>": break
out.append(nxt)
context = (context[1], nxt)
return " ".join(out)
print(generate(seed=("today", "the")))
# 可能输出:
# today the price of gold per ton, while production of shoe
# lasts and shoe industry, the bank intervened just after it
# considered and rejected an imf demand to rebuild depleted
# european stocks, sept 30 end primary 76 cts a share.
总结 n-gram LM 的四宗罪:
第 4 条尤其重要。n-gram LM 把每个词当作原子符号,不知道 "cat" 和 "feline" 在意义上相似。这一缺陷只能由分布式表示 (distributed representations)——也就是词向量——来根本解决。 这正是 Bengio 2003 神经 LM 的核心贡献,下一节详谈。
回想 Lecture 3 讲的 window-based neural classifier(用于命名实体识别): 取一个固定大小的词窗口(比如 5 个词),把每个词的 embedding 拼起来,过一个隐藏层,最后输出类别。
Bengio 等人 (2000, 2003) 的洞见非常简单:把这个分类器的输出层从"实体类别"换成"词表上的概率分布",就得到了第一个神经语言模型。
Bengio et al. (JMLR 2003) 在论文 "A Neural Probabilistic Language Model" 里指出了三个突破:
定窗口神经 LM 解决了 n-gram 的稀疏性,但没有解决上下文长度问题,还引入了新的问题:
RNN 的核心 idea 一句话:在每个时刻应用同一个权重矩阵 $\boldsymbol{W}$,把前一时刻的"记忆"和当前输入合成新的"记忆"。
这就解决了上一节的三个病灶:
RNN-LM 的前向传播由 4 个公式定义:
其中 $\sigma$ 通常是 $\tanh$(取值 $[-1,1]$,零中心)。初始隐藏状态 $\boldsymbol{h}^{(0)}$ 可以设为零向量或者一个可学习参数。
注意:参数集合 $\{\boldsymbol{E}, \boldsymbol{W}_e, \boldsymbol{W}_h, \boldsymbol{U}, \boldsymbol{b}_1, \boldsymbol{b}_2\}$ 不随序列长度 $T$ 增长——这是 RNN 相对于定窗口 LM 的关键胜利。
| 优点 ✓ | 说明 |
|---|---|
| 处理任意长输入 | 循环结构,$T$ 多大都不变模型大小 |
| 原理上可看到所有历史 | $\boldsymbol{h}^{(t)}$ 是 $\boldsymbol{x}^{(1)}, \ldots, \boldsymbol{x}^{(t)}$ 的函数 |
| 模型大小恒定 | 仅与 $d_h, d_e, |V|$ 有关,与上下文长度无关 |
| 输入处理对称 | 所有时间步用同一个 $\boldsymbol{W}_h$,无位置偏倚 |
| 可端到端训练 | SGD + BPTT,全程可微 |
| 缺点 ✗ | 说明 |
|---|---|
| 计算慢(顺序依赖) | $\boldsymbol{h}^{(t)}$ 依赖 $\boldsymbol{h}^{(t-1)}$,无法跨时间并行——GPU 利用率低 |
| 难以记住远距离信息 | 梯度消失/爆炸,原理上能看到却实际学不到——下一章重点 |
理解了数学,写代码就是直译。下面是一个最简版本(不用 nn.RNN 黑盒):
import torch
import torch.nn as nn
class VanillaRNNLM(nn.Module):
def __init__(self, vocab_size, embed_dim=128, hidden_dim=256):
super().__init__()
self.embed = nn.Embedding(vocab_size, embed_dim)
self.W_e = nn.Linear(embed_dim, hidden_dim, bias=False)
self.W_h = nn.Linear(hidden_dim, hidden_dim, bias=True) # bias 合并在此
self.U = nn.Linear(hidden_dim, vocab_size, bias=True)
self.tanh = nn.Tanh()
self.hidden_dim = hidden_dim
def forward(self, x, h0=None):
"""
x: (B, T) long tensor, token ids
h0: (B, H) initial hidden state, or None
returns: logits (B, T, V)
"""
B, T = x.shape
h = h0 if h0 is not None else x.new_zeros(B, self.hidden_dim, dtype=torch.float)
e = self.embed(x) # (B, T, d_e)
logits = []
for t in range(T): # ← 顺序循环,无法并行
h = self.tanh(self.W_e(e[:, t]) + self.W_h(h)) # (B, H)
logits.append(self.U(h)) # (B, V)
return torch.stack(logits, dim=1) # (B, T, V)
nn.RNN / nn.LSTM / nn.GRU,底层用 cuDNN 实现的融合 CUDA kernel,比 Python for 循环快 10–50 倍。
手写仅用于教学和需要奇异变体的研究场景。即便如此,仍然是时间串行——这是 RNN 的本质瓶颈。
训练 RNN-LM 的目标:让模型预测出的下一个词分布 $\hat{\boldsymbol{y}}^{(t)}$ 尽可能接近真实下一个词的 one-hot 分布 $\boldsymbol{y}^{(t)}$。 形式化即最小化交叉熵:
由于 $\boldsymbol{y}^{(t)}$ 是 one-hot(只在真实下一词位置为 1),求和坍缩为负对数似然 (NLL):模型给"正确下一词" 分配的概率越高,损失越小。
全序列损失是对所有时刻求平均:
注意上图的一个关键细节:训练时第 $t$ 步输入的是真实的 $\boldsymbol{x}^{(t)}$(即语料中真实出现的词),而不是模型上一步预测的 $\hat{\boldsymbol{x}}^{(t)}$。 这种做法叫做 Teacher Forcing——老师在每一步把"正确答案"塞回模型,强迫它在已知正确历史上学习。
理论上 $J(\theta) = \frac{1}{T}\sum_t J^{(t)}$ 是对整个语料 $T$ 求平均,但语料动辄数十亿 token,一次性算完损失和梯度内存爆炸。 实际做法:
PyTorch 代码骨架:
model = VanillaRNNLM(vocab_size).cuda()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(NUM_EPOCHS):
for x_batch in loader: # x_batch: (B, T)
# 用前 T-1 个 token 预测后 T-1 个 token(teacher forcing)
inp, tgt = x_batch[:, :-1], x_batch[:, 1:]
logits = model(inp.cuda()) # (B, T-1, V)
loss = loss_fn(logits.reshape(-1, vocab_size),
tgt.cuda().reshape(-1))
opt.zero_grad()
loss.backward() # ← BPTT 自动展开
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) # 见第七部分
opt.step()
为了讲清 BPTT,先回顾多变量微积分。给定 $f(x, y)$,且 $x = x(t), y = y(t)$ 都是 $t$ 的函数,则复合函数的导数:
直觉:$t$ 通过 $x$ 影响 $f$ 的一份,加上通过 $y$ 影响 $f$ 的另一份。
对计算图,这条规则的推广是:"梯度在出度分支处相加 (Gradients sum at outward branches)"。 即:如果某个变量被下游多个节点使用,反向传播时来自各下游的梯度要全部加起来。
现在回到 RNN。我们关心一个看似奇怪的问题:损失 $J^{(t)}(\theta)$ 对重复使用的权重 $\boldsymbol{W}_h$ 的梯度是什么?
关键洞察是:把 $\boldsymbol{W}_h$ 看作在每一步 $i=1,\ldots,t$ 都"复制了一份" $\boldsymbol{W}_h\big|_{(i)}$, 然后对每一份的梯度求和,最后再利用 $\frac{\partial \boldsymbol{W}_h|_{(i)}}{\partial \boldsymbol{W}_h} = \boldsymbol{I}$ (即一份份其实就是同一个)。
为了简化,记 $\boldsymbol{z}^{(t)} = \boldsymbol{W}_h \boldsymbol{h}^{(t-1)} + \boldsymbol{W}_e \boldsymbol{e}^{(t)}$,所以 $\boldsymbol{h}^{(t)} = \tanh(\boldsymbol{z}^{(t)})$。
loss.backward() 一行就自动完成。
但理解推导是必要的——因为梯度消失/爆炸正是从步骤 2 的连乘里诞生的,光会调 API 你无法诊断这些问题。
完整 BPTT 要把整个序列的计算图保存在显存里,对于长序列(如一本小说的几千 tokens)会爆显存。 工程实践用 Truncated BPTT (TBPTT):每隔 $k$ 步切断反向梯度。
典型 $k = 20\sim50$。前向继续维护 $\boldsymbol{h}^{(t)}$,但反向只在最近 $k$ 步内传播。
# Truncated BPTT 模式
hidden = None
for chunk in long_sequence_chunks(text, chunk_size=35): # 每 35 tokens 一段
logits, hidden = model(chunk, hidden)
loss = loss_fn(logits, chunk_targets)
loss.backward()
opt.step()
hidden = hidden.detach() # ← 关键:切断对前面 chunk 的反传依赖
opt.zero_grad()
训练完 RNN-LM 后,怎么用它写新文本?流程叫 autoregressive rollout:
<s> 句首符号),喂入模型,得到下一个词的概率分布 $\hat{\boldsymbol{y}}^{(1)}$。</s> 或达到最大长度时停止。"采样"这一步有大学问。给定一个 $|V|$ 维分布 $\hat{\boldsymbol{y}}^{(t)}$,怎么选下一个词?这就是 解码策略 (decoding strategy)。
| 策略 | 公式 | 特点 |
|---|---|---|
| 贪心 (greedy) | $\arg\max_w \hat{y}^{(t)}_w$ | 最确定,但容易重复/无趣 |
| 纯随机采样 | $w \sim \hat{\boldsymbol{y}}^{(t)}$ | 多样性最大,但容易跑偏(罕见词坍塌) |
| 温度采样 | $\hat{y}_w \propto \exp(z_w / \tau)$ | $\tau<1$ 更尖锐(保守),$\tau>1$ 更平坦(创意),$\tau\to 0$ = greedy |
| Top-k | 只在 top-$k$ 词上重新归一化采样 | 截断长尾,避免奇怪词;$k\in[10,100]$ 常见 |
| Top-p (nucleus) | 选累积概率达到 $p$ 的最小词集,重归一化采样 | Holtzman 2019 提出,目前 LLM 默认;$p\in[0.9,0.95]$ 常见 |
| Beam Search | 同时保持 $k$ 条候选,每步扩展 → 取 top-$k$ | 用于翻译/摘要等需要"找最大概率序列"的任务 |
怎么客观比较两个 LM 谁更好?标准答案是 perplexity (PPL, 困惑度)——一个 LM 对测试集的预测有多"惊讶"。
直觉解读:
所以 perplexity 越低越好。直觉上,perplexity = $k$ 意味着"模型在每一步都好像在 $k$ 个等概率词中犹豫"。 $k=1$ 是完美预测(一定知道下一个词是什么),$k=|V|$ 是完全随机猜。
Perplexity 跟 cross-entropy loss 在数学上是同一件事的两个皮肤。
也就是:perplexity 就是 cross-entropy 损失的指数。因为 $\exp$ 是单调递增函数,最小化 cross-entropy 等价于最小化 perplexity。
| 模型 / 时代 | WikiText-103 Test PPL | 备注 |
|---|---|---|
| 5-gram KN smoothing (经典统计) | ~140 | n-gram 的极限 |
| LSTM (Merity 2018) | ~40 | + tied embeddings + dropout |
| Transformer-XL (Dai 2019) | ~18 | RNN 全面退场 |
| GPT-2 1.5B (zero-shot) | ~17 | 大规模预训练范式 |
| GPT-3 175B | ~10 | 规模定律 + RLHF 前夜 |
现在我们抵达本课最重要、也最数学化的一节。RNN 在原理上能记住任意远的历史(隐状态能传递无穷长),但实际训练中极难学到长程依赖。 原因正是 Pascanu et al. (2013) 论文 "On the difficulty of training recurrent neural networks" 揭示的两个数学病理: 梯度消失 (vanishing gradient) 与 梯度爆炸 (exploding gradient)。
我们之前推导过 BPTT 的核心:
$$ \frac{\partial J^{(t)}}{\partial \boldsymbol{h}^{(i)}} = \frac{\partial J^{(t)}}{\partial \boldsymbol{h}^{(t)}} \cdot \prod_{j=i+1}^{t} \frac{\partial \boldsymbol{h}^{(j)}}{\partial \boldsymbol{h}^{(j-1)}} $$注意中间那个连乘符号 $\prod$。每多一个时间步,就要多乘一个雅可比矩阵 $\partial \boldsymbol{h}^{(j)} / \partial \boldsymbol{h}^{(j-1)}$。 连乘 $T$ 个矩阵的结果,会按矩阵谱半径 (spectral radius) 的 $T$ 次方变化:
我们来精确化"梯度消失"。考虑无激活函数的简化 RNN:$\boldsymbol{h}^{(t)} = \boldsymbol{W}_h \boldsymbol{h}^{(t-1)}$(去掉 $\sigma$)。
那么 $\boldsymbol{h}^{(t)} = \boldsymbol{W}_h^{\,t-i}\,\boldsymbol{h}^{(i)}$,所以
对 $\boldsymbol{W}_h$ 做特征分解 $\boldsymbol{W}_h = \boldsymbol{Q}\boldsymbol{\Lambda}\boldsymbol{Q}^{-1}$(假设可对角化),则 $\boldsymbol{W}_h^{\,t-i} = \boldsymbol{Q}\boldsymbol{\Lambda}^{\,t-i}\boldsymbol{Q}^{-1}$。 对角矩阵 $\boldsymbol{\Lambda}^{\,t-i}$ 的元素是 $\lambda_k^{\,t-i}$。如果某些 $|\lambda_k| < 1$,对应的分量按 $\lambda_k^{\,t-i}$ 指数衰减;如果 $|\lambda_k| > 1$,则指数爆炸。
对于带 $\tanh$ 的真实 RNN,Pascanu 等人证明了如下充分条件:
为什么梯度消失是致命的?看下面这个真实例子:
When she tried to print her tickets, she found that the printer was out of toner. She went to the stationery store to buy more toner. It was very overpriced. After installing the toner into the printer, she finally printed her ___
正确答案是 "tickets"。它在第 7 个 token 处出现,目标位置在第 40+ 个 token——距离 33+ 步。 要学会这种填空,模型必须把第 7 步的 "tickets" 信息保留到第 40 步。但因为梯度消失,反传时第 40 步的 loss 几乎完全不会更新第 7 步周围的权重。 模型学不到这种长程关联,所以测试时也无法预测。
如果谱半径 > 1,梯度会按 $\rho^T$ 指数爆炸。SGD 更新规则 $\theta^{\text{new}} = \theta^{\text{old}} - \alpha\nabla_\theta J$ 中, 一个 $10^{10}$ 量级的梯度乘上学习率 $\alpha=0.01$,就会让参数瞬间跳到 $10^8$ 量级——直接 NaN,训练崩溃。
"You think you've found a hill to climb, but suddenly you're in Iowa." — Diyi Yang 课堂金句
grad_norm = total_norm_of_gradients(grads)
if grad_norm > threshold:
for g in grads:
g *= threshold / grad_norm # 等比例缩小
# 然后 SGD 更新
直觉:方向不变(仍指向下降)只是步子小一点,避免一脚踩进灾难区域。PyTorch 一行实现:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
对付梯度消失的两条治本路径:
vanilla RNN 的隐藏状态 $\boldsymbol{h}^{(t)}$ 每一步都被重写: $\boldsymbol{h}^{(t)} = \sigma(\boldsymbol{W}_h\boldsymbol{h}^{(t-1)} + \cdots)$,前一时刻的信息每次都要过激活函数,自然衰减。
LSTM (Long Short-Term Memory, Hochreiter & Schmidhuber 1997) 的关键 idea: 引入一条独立的 cell state $\boldsymbol{c}^{(t)}$,它的更新是加法而非完全重写:
当 $\boldsymbol{f}^{(t)} \approx 1$ 时,$\boldsymbol{c}^{(t)} \approx \boldsymbol{c}^{(t-1)}$——信息可以原封不动跨越任意步。 反向传播时 $\partial \boldsymbol{c}^{(t)} / \partial \boldsymbol{c}^{(t-1)} \approx \boldsymbol{1}$,不连乘衰减。这就是 LSTM 缓解梯度消失的几何原理。
更激进的方案:让梯度不经过那一长串雅可比矩阵的连乘。这就是 残差连接 (residual connection) 和 attention 的核心: 建立 $O(1)$ 长度的"梯度高速公路",让损失信号可以直接传到任意远的位置。
本课最后一个章节,我们把 RNN-LM 推广到一个具体应用:机器翻译 (Machine Translation, MT)。 这个推广不仅是工程练习,更是 attention 机制和 Transformer 诞生的直接历史路径。
定义:给定源语言句子 $\boldsymbol{x} = x_1, x_2, \ldots, x_n$(如英文 "I like deep learning"), 输出目标语言句子 $\boldsymbol{y} = y_1, y_2, \ldots, y_m$(如中文"我喜欢深度学习")。 注意:$n$ 和 $m$ 通常不相等。
| 时代 | 方法 | 特点 |
|---|---|---|
| 1950s | Rule-based (规则翻译) | 专家手写翻译规则。冷战军用。完全无泛化。 |
| 1990s–2010s | SMT (Statistical MT, IBM 模型) | 词对齐 + 短语翻译 + n-gram LM + 重排序,几百位工程师维护 |
| 2014–2016 | NMT 萌芽 (Seq2Seq) | 一个神经网络端到端学翻译 |
| 2016 | Google 切到 GNMT | NMT 全面取代 SMT |
| 2017+ | Transformer 时代 | WMT 等竞赛上 BLEU 持续刷新 |
Sutskever, Vinyals, Le 在 NeurIPS 2014 论文 "Sequence to Sequence Learning with Neural Networks" 提出 seq2seq。 核心想法极其优雅:用两个 RNN 串联——一个把源句子"压缩"成向量,另一个把这个向量"解压"成目标句子。
关键观察:Decoder 就是一个 RNN-LM,只不过它的 hidden state 一开始被源句"初始化"了。所以 seq2seq 是一个条件语言模型 (Conditional LM):
对比 vanilla LM 的 $P(\boldsymbol{y})$,NMT 的 $P(\boldsymbol{y}\mid \boldsymbol{x})$ 多了一份源句条件——其余数学结构完全一样。 这种"万物皆条件 LM"的视角,是后来 instruction-following、prompt engineering、in-context learning 等概念的基础。
训练 NMT 需要平行语料 (parallel corpus)——同一句话的源语言和目标语言对照(如 WMT、OPUS 数据集)。 对每个 $(\boldsymbol{x}, \boldsymbol{y})$ 对,用 Teacher Forcing 喂入 Decoder,计算每个时刻的交叉熵,求平均:
Sutskever 2014 原版用 4 层 LSTM;Luong 2015 进一步堆叠,发现深度对翻译质量重要。多层架构里,第 $i$ 层 RNN 的隐藏状态作为第 $i+1$ 层的输入:
深层 RNN 能够捕捉层次化的语言结构(底层:词形态;中层:短语;高层:语义),但深度受梯度消失影响,2-4 层通常是 RNN 的上限。 Transformer 把这个上限推到 6, 12, 96, 200+ 层。
看图 8.1 那个橙色高亮的 $\boldsymbol{c}$ 向量——它要承担一项不可能的任务: 把整个源句子的语义压缩到一个 $d_h$ 维向量里(典型 $d_h=512$)。
Attention 的根本想法(Bahdanau 2014, Luong 2015):解码每一步不止看 $\boldsymbol{c}$,直接回头看 Encoder 的所有隐藏状态,加权求和挑出当前最相关的源词。 这就把"一根独木桥"变成了"一座立交桥",瓶颈消失。具体数学和实现,请见下一课。
"""
最小可运行的 RNN-LM 训练脚本
依赖: torch, torchtext, datasets
"""
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
# ---------- 1. 数据准备 ----------
ds = load_dataset("wikitext", "wikitext-2-raw-v1")
# 简化处理:用空格分词 + 构建词表
from collections import Counter
def tokenize(s): return s.split()
counter = Counter()
for ex in ds["train"]: counter.update(tokenize(ex["text"]))
vocab = ["<pad>", "<unk>", "<bos>", "<eos>"] + \
[w for w,c in counter.most_common(20000)]
stoi = {w:i for i,w in enumerate(vocab)}
def encode(s):
return [stoi.get(w, 1) for w in tokenize(s)]
# 拼接所有训练文本,切成长度 35 的块(典型 TBPTT 设置)
BPTT = 35
def make_chunks(split):
ids = []
for ex in ds[split]:
ids += encode(ex["text"])
n_chunks = len(ids) // BPTT
ids = ids[:n_chunks * BPTT]
return torch.tensor(ids).view(-1, BPTT)
train_data = make_chunks("train")
val_data = make_chunks("validation")
# ---------- 2. 模型 ----------
class RNNLM(nn.Module):
def __init__(self, vocab_size, emb=200, hidden=200, n_layers=2, dropout=0.2):
super().__init__()
self.embed = nn.Embedding(vocab_size, emb)
self.drop = nn.Dropout(dropout)
self.rnn = nn.LSTM(emb, hidden, n_layers, dropout=dropout, batch_first=True)
self.fc = nn.Linear(hidden, vocab_size)
def forward(self, x, h=None):
e = self.drop(self.embed(x))
out, h = self.rnn(e, h)
return self.fc(self.drop(out)), h
device = "cuda" if torch.cuda.is_available() else "cpu"
model = RNNLM(len(vocab)).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
crit = nn.CrossEntropyLoss()
def evaluate(data):
model.eval()
total_loss, n = 0., 0
with torch.no_grad():
for i in range(0, len(data)-1, 64): # batch=64 (粗略)
batch = data[i:i+64].to(device)
x, y = batch[:, :-1], batch[:, 1:]
logits, _ = model(x)
loss = crit(logits.reshape(-1, len(vocab)), y.reshape(-1))
total_loss += loss.item() * y.numel()
n += y.numel()
return total_loss / n
# ---------- 3. 训练循环 ----------
for epoch in range(10):
model.train()
hidden = None
for i in range(0, len(train_data)-1, 64):
batch = train_data[i:i+64].to(device)
x, y = batch[:, :-1], batch[:, 1:]
logits, hidden = model(x, hidden)
loss = crit(logits.reshape(-1, len(vocab)), y.reshape(-1))
opt.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
opt.step()
# 关键:detach hidden 以截断 BPTT
hidden = tuple(h.detach() for h in hidden)
val_loss = evaluate(val_data)
print(f"Epoch {epoch}: val_loss={val_loss:.3f}, val_ppl={math.exp(val_loss):.1f}")
# ---------- 4. 文本生成 ----------
@torch.no_grad()
def generate(prompt, max_len=50, temp=1.0):
model.eval()
ids = torch.tensor([encode(prompt)], device=device)
h = None
out = list(ids[0].cpu().numpy())
for _ in range(max_len):
logits, h = model(ids[:, -1:], h)
probs = torch.softmax(logits[0, -1] / temp, dim=-1)
nxt = torch.multinomial(probs, 1).item()
out.append(nxt)
ids = torch.tensor([[nxt]], device=device)
return " ".join(vocab[i] for i in out)
print(generate("the meaning of life is", temp=0.8))