Build a Large Language Model (From Scratch) · Chapter 3 中文译文

编码注意力机制

译自 Build a Large Language Model (From Scratch) MEAP V08。本文保留章节结构、代码、图、表与公式标记;图像从原 PDF 页段抽取并嵌入。

第 3 章 编写注意力机制代码

本章涵盖以下内容

  • 探索在神经网络中使用注意力机制的原因
  • 介绍一个基础的自注意力框架,并逐步推进到增强版自注意力机制
  • 实现一个因果注意力模块,使 LLM 能够一次生成一个词元
  • 用 dropout 随机屏蔽选定的注意力权重,以减少过拟合
  • 把多个因果注意力模块堆叠成一个多头注意力模块

在上一章中,你学习了如何为训练 LLM 准备输入文本。这包括把文本拆分成单个单词词元和子词词元,并把它们编码成向量表示,也就是 LLM 所用的嵌入。在本章中,我们将考察 LLM 架构本身的一个组成部分:注意力机制,如图 3.1 所示。

图 3.1
图 3.1 关于编写 LLM、在通用文本数据集上预训练 LLM、以及在带标签数据集上微调 LLM 这三个主要阶段的思维模型。本章重点介绍注意力机制,它是 LLM 架构的组成部分。

注意力机制是一个内容丰富的主题,因此我们用整整一章来讲解它。我们将在很大程度上单独考察这些注意力机制,并聚焦在其机制层面。下一章中,我们会编写 LLM 中围绕自注意力机制的其余部分,看看它如何工作,并创建一个用于生成文本的模型。在本章中,我们将实现四种不同的注意力机制变体,如图 3.2 所示。

图 3.2
图 3.2 该图展示了本章将编写的不同注意力机制:先从一个简化版自注意力开始,再添加可训练权重。因果注意力机制会给自注意力添加掩码,使 LLM 能够一次生成一个词。最后,多头注意力把注意力机制组织成多个头,使模型能够并行捕捉输入数据的不同方面。

图 3.2 所示的这些不同注意力变体彼此递进。本章末尾的目标是得到一个紧凑且高效的多头注意力实现,然后可以把它接入我们将在下一章编写的 LLM 架构中。

3.1 对长序列建模的问题

在本章后面深入介绍 LLM 核心的自注意力机制之前,先看一个问题:早于 LLM 且没有注意力机制的架构有什么问题?假设我们想开发一个语言翻译模型,把文本从一种语言翻译成另一种语言。如图 3.3 所示,由于源语言和目标语言的语法结构不同,我们不能简单地逐词翻译文本。

图 3.3
图 3.3 将文本从一种语言翻译成另一种语言时,例如从德语翻译成英语,不能只是逐词翻译。相反,翻译过程需要上下文理解和语法对齐。

为了解决不能逐词翻译文本的问题,常见做法是使用一个包含两个子模块的深度神经网络,即所谓的编码器和解码器。编码器的任务是先读入并处理整段文本,然后由解码器生成翻译后的文本。

我们在第 1 章介绍 Transformer 架构时,已经简要讨论过编码器-解码器网络(第 1.4 节“将 LLM 用于不同任务”)。在 Transformer 出现之前,循环神经网络(RNN)是语言翻译中最流行的编码器-解码器架构。

RNN 是一种神经网络:前面步骤的输出会作为当前步骤的输入,因此非常适合文本这类序列数据。如果你不熟悉 RNN,不用担心;你不需要了解 RNN 的详细工作原理也能跟上这里的讨论。我们在这里更关注编码器-解码器设置的一般概念。

在编码器-解码器 RNN 中,输入文本被送入编码器并按顺序处理。编码器在每一步都会更新其隐藏状态(隐藏层中的内部值),试图在最终隐藏状态中捕捉整个输入句子的含义,如图 3.4 所示。随后,解码器使用这个最终隐藏状态开始逐词生成翻译后的句子。它也会在每一步更新自己的隐藏状态,而这个隐藏状态应当携带下一词预测所需的上下文。

图 3.4
图 3.4 在 Transformer 模型出现之前,编码器-解码器 RNN 是机器翻译的常见选择。编码器接收源语言中的词元序列作为输入,编码器的一个隐藏状态(中间神经网络层)会编码整个输入序列的压缩表示。然后,解码器使用其当前隐藏状态开始逐个词元地进行翻译。

虽然我们不需要了解这些编码器-解码器 RNN 的内部工作方式,但这里的关键思想是:编码器部分把整段输入文本处理成一个隐藏状态(记忆单元)。随后,解码器接收这个隐藏状态来生成输出。你可以把这个隐藏状态看作一个嵌入向量,这是我们在第 2 章讨论过的概念。

编码器-解码器 RNN 的大问题和局限在于,在解码阶段,RNN 无法直接访问编码器中更早的隐藏状态。因此,它只能依赖当前隐藏状态,而该状态封装了所有相关信息。这可能导致上下文丢失,尤其是在复杂句子中,依赖关系可能跨越很长距离。

对于不熟悉 RNN 的读者来说,理解或学习这种架构并不是必需的,因为本书不会使用它。本节的要点是,编码器-解码器 RNN 存在一个缺陷,而这个缺陷推动了注意力机制的设计。

3.2 用注意力机制捕捉数据依赖关系

如前所述,在 Transformer LLM 出现之前,人们通常使用 RNN 来处理语言建模任务,例如语言翻译。RNN 在翻译短句时效果不错,但对较长文本效果不佳,因为它们无法直接访问输入中的先前词语。

这种方法的一个主要缺点是,RNN 必须在把信息传递给解码器之前,把整个编码后的输入都记在单个隐藏状态中,如上一节图 3.4 所示。

因此,研究人员在 2014 年提出了所谓的 Bahdanau 注意力机制(以相应论文第一作者命名)用于 RNN。它修改了编码器-解码器 RNN,使解码器在每个解码步骤都能选择性地访问输入序列的不同部分,如图 3.5 所示。

图 3.5
图 3.5 使用注意力机制后,网络中负责生成文本的解码器部分可以选择性地访问所有输入词元。这意味着,在生成某个给定输出词元时,一些输入词元比其他词元更重要。重要性由所谓的注意力权重决定,我们稍后会计算这些权重。注意,本图展示的是注意力背后的一般思想,并没有描绘 Bahdanau 机制的精确实现;Bahdanau 机制是一种 RNN 方法,超出了本书范围。

有趣的是,仅仅三年后,研究人员发现,构建用于自然语言处理的深度神经网络并不需要 RNN 架构,于是提出了原始 Transformer 架构(第 1 章已讨论),其中包含受 Bahdanau 注意力机制启发的自注意力机制。

自注意力是一种机制:在计算一个序列的表示时,它允许输入序列中的每个位置关注同一序列中的所有位置。自注意力是当代基于 Transformer 架构的 LLM(例如 GPT 系列)的关键组成部分。

本章聚焦于编写并理解 GPT 类模型中使用的这种自注意力机制,如图 3.6 所示。在下一章中,我们会继续编写 LLM 的其余部分。

图 3.6
图 3.6 自注意力是 Transformer 中的一种机制,它通过允许序列中的每个位置与同一序列内所有其他位置交互并衡量其重要性,来计算更高效的输入表示。在下一章编写 GPT 类 LLM 的其余部分之前,本章会从零开始编写这个自注意力机制。

3.3 用自注意力关注输入的不同部分

现在我们将介绍自注意力机制的内部工作方式,并学习如何从零开始编写它。自注意力是每个基于 Transformer 架构的 LLM 的基石。值得注意的是,这个主题可能需要大量专注和注意力(不是双关);但一旦掌握其基础,你就攻克了本书以及实现 LLM 过程中最困难的部分之一。

由于自注意力可能显得复杂,尤其是你第一次接触它时,我们会在下一小节先介绍一个简化版自注意力。之后,在第 3.4 节中,我们再实现 LLM 中使用的、带可训练权重的自注意力机制。

3.3.1 不带可训练权重的简单自注意力机制

在本节中,我们实现一种简化的自注意力变体,它不包含任何可训练权重,如图 3.7 所概括。本节的目标是在下一节 3.4 添加可训练权重之前,说明自注意力中的几个关键概念。

图 3.7
图 3.7 自注意力的目标是为每个输入元素计算一个上下文向量,该向量组合了所有其他输入元素的信息。在图中示例里,我们计算上下文向量 \(z^{(2)}\)。每个输入元素对计算 \(z^{(2)}\) 的重要性或贡献由注意力权重 \(\alpha_{21}\) 到 \(\alpha_{2T}\) 决定。计算 \(z^{(2)}\) 时,注意力权重是相对于输入元素 \(x^{(2)}\) 和所有其他输入计算出来的。这些注意力权重的精确计算方式会在本节稍后讨论。

图 3.7 展示了一个输入序列,记为 \(x\),它由 \(T\) 个元素组成,表示为 \(x^{(1)}\) 到 \(x^{(T)}\)。该序列通常表示文本,例如一个句子,并且已经转换成词元嵌入,如第 2 章所解释。

例如,考虑输入文本 “Your journey starts with one step.” 在这种情况下,序列中的每个元素(例如 \(x^{(1)}\))都对应一个 \(d\) 维嵌入向量,表示一个特定词元,例如 “Your”。在图 3.7 中,这些输入向量显示为 \(3\) 维嵌入。

在自注意力中,我们的目标是为输入序列中的每个元素 \(x^{(i)}\) 计算上下文向量 \(z^{(i)}\)。上下文向量可以解释为一种增强的嵌入向量。

为了说明这个概念,我们聚焦于第二个输入元素 \(x^{(2)}\)(对应词元 “journey”)的嵌入向量,以及图 3.7 底部所示的对应上下文向量 \(z^{(2)}\)。这个增强后的上下文向量 \(z^{(2)}\) 是一个嵌入,其中包含关于 \(x^{(2)}\) 以及所有其他输入元素 \(x^{(1)}\) 到 \(x^{(T)}\) 的信息。

在自注意力中,上下文向量起着关键作用。它们的用途是通过纳入序列中所有其他元素的信息,为输入序列(例如一个句子)中的每个元素创建增强表示,如图 3.7 所示。这对 LLM 至关重要,因为 LLM 需要理解句子中词与词之间的关系和相关性。稍后,我们会添加可训练权重,帮助 LLM 学习构造这些上下文向量,使其与 LLM 生成下一个词元的任务相关。

在本节中,我们逐步实现一个简化的自注意力机制,用来计算这些权重以及由此得到的上下文向量。

考虑下面这个输入句子,它已经按照第 2 章所讨论的方式嵌入为 \(3\) 维向量。出于说明目的,我们选择较小的嵌入维度,以确保它无需换行就能放在页面上:

import torch
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your       (x^1)
    [0.55, 0.87, 0.66], # journey     (x^2)
    [0.57, 0.85, 0.64], # starts      (x^3)
    [0.22, 0.58, 0.33], # with        (x^4)
    [0.77, 0.25, 0.10], # one         (x^5)
    [0.05, 0.80, 0.55]] # step        (x^6)
)

实现自注意力的第一步是计算中间值 \(\omega\),它们称为注意力分数,如图 3.8 所示。(请注意,图 3.8 以截断形式显示了前面 inputs 张量中的值;例如,由于空间限制,0.87 被截断为 0.8。在这个截断版本中,单词 “journey” 和 “starts” 的嵌入可能会因为随机性而显得相似。)

图 3.8
图 3.8 本节的总体目标是展示如何使用第二个输入元素 \(x^{(2)}\) 作为查询来计算上下文向量 \(z^{(2)}\)。该图展示了第一个中间步骤:把查询 \(x^{(2)}\) 与所有其他输入元素之间的注意力分数 \(\omega\) 计算为点积。(注意,为减少视觉杂乱,图中的数字被截断为小数点后一位。)

图 3.8 说明了我们如何计算查询词元与每个输入词元之间的中间注意力分数。我们通过计算查询 \(x^{(2)}\) 与每个其他输入词元的点积来确定这些分数:

query = inputs[1]                                                                   #A
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
     attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)
#A
第二个输入词元作为查询。

计算得到的注意力分数如下:

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

下一步,如图 3.9 所示,我们对先前计算出的每个注意力分数进行归一化。

图 3.9
图 3.9 在相对于输入查询 \(x^{(2)}\) 计算出注意力分数 \(\omega_{21}\) 到 \(\omega_{2T}\) 之后,下一步是通过归一化这些注意力分数,得到注意力权重 \(\alpha_{21}\) 到 \(\alpha_{2T}\)。

图 3.9 所示归一化背后的主要目标,是得到总和为 \(1\) 的注意力权重。这种归一化是一种约定,有助于解释结果,也有助于维持 LLM 训练的稳定性。下面是一种实现这个归一化步骤的直接方法:

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

输出显示,注意力权重现在总和为 \(1\):

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)

在实践中,更常见也更推荐使用 softmax 函数进行归一化。这种方法更善于处理极端值,并且在训练期间提供更有利的梯度性质。下面是一个用于归一化注意力分数的基础 softmax 函数实现:

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)
attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

输出显示,softmax 函数同样达到了目标,把注意力权重归一化到总和为 \(1\):

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)

此外,softmax 函数确保注意力权重始终为正。这使得输出可以解释为概率或相对重要性,其中更高的权重表示更高的重要性。

注意,这个朴素的 softmax 实现(softmax_naive)在处理很大或很小的输入值时,可能遇到数值不稳定问题,例如上溢和下溢。因此,在实践中,建议使用 PyTorch 的 softmax 实现,它已经针对性能进行了大量优化:

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

在这个例子中,我们可以看到它产生的结果与前面的 softmax_naive 函数相同:

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)

现在已经计算出归一化后的注意力权重,我们就可以进入图 3.10 所示的最后一步:通过把嵌入后的输入词元 \(x^{(i)}\) 与对应的注意力权重相乘,然后把得到的向量求和,计算上下文向量 \(z^{(2)}\)。

图 3.10
图 3.10 在计算并归一化注意力分数以得到查询 \(x^{(2)}\) 的注意力权重之后,最后一步是计算上下文向量 \(z^{(2)}\)。这个上下文向量是所有输入向量 \(x^{(1)}\) 到 \(x^{(T)}\) 按注意力权重加权后的组合。

图 3.10 中描绘的上下文向量 \(z^{(2)}\) 是所有输入向量的加权和。它涉及把每个输入向量乘以对应的注意力权重:

query = inputs[1] # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
     context_vec_2 += attn_weights_2[i]*x_i
print(context_vec_2)

该计算的结果如下:

tensor([0.4419, 0.6515, 0.5683])

在下一节中,我们会把这个计算上下文向量的过程推广为同时计算所有上下文向量。

3.3.2 计算所有输入词元的注意力权重

上一节中,我们计算了输入 2 的注意力权重和上下文向量,如图 3.11 中高亮的行所示。现在,我们将扩展这个计算,以计算所有输入的注意力权重和上下文向量。

图 3.11
图 3.11 高亮行展示了以第二个输入元素作为查询时的注意力权重,也就是上一节中我们计算出的结果。本节会推广该计算,以得到所有其他注意力权重。

我们遵循与之前相同的三个步骤,如图 3.12 所概括,只是在代码中做少量修改,以计算所有上下文向量,而不只是第二个上下文向量 \(z^{(2)}\)。

图 3.12
图 3.12 在自注意力中,我们先计算注意力分数,然后把它们归一化,得到总和为 \(1\) 的注意力权重。这些注意力权重用于把输入加权求和,从而计算上下文向量。

首先,在图 3.12 所示的步骤 1 中,我们添加一个额外的 for 循环,计算所有输入对之间的点积。

attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
         attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

得到的注意力分数如下:

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
         [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
         [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
         [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
         [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
         [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

前面张量中的每个元素都表示一对输入之间的注意力分数,如图 3.11 所示。注意,图 3.11 中的值是归一化后的,因此它们不同于前面张量中的未归一化注意力分数。我们稍后会处理归一化。

在计算前面的注意力分数张量时,我们使用了 Python 中的 for 循环。然而,for 循环通常较慢,我们可以用矩阵乘法得到相同结果:

attn_scores = inputs @ inputs.T
print(attn_scores)

我们可以直观确认结果与之前相同:

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
         [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
         [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
         [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
         [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
         [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

在图 3.12 所示的步骤 2 中,我们现在对每一行进行归一化,使每行中的值总和为 \(1\):

attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

这会返回下面的注意力权重张量,与图 3.10 中显示的值相匹配:

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

在使用 PyTorch 的语境中,torch.softmax 等函数里的 dim 参数指定函数将沿输入张量的哪个维度计算。通过设置 dim=-1,我们是在指示 softmax 函数沿 attn_scores 张量的最后一个维度应用归一化。如果 attn_scores 是一个 2D 张量(例如形状为 [rows, columns]),那么 dim=-1 会沿列方向归一化,使每一行中的值(沿列维度求和)总和为 \(1\)。

在进入图 3.12 所示的步骤 3,也就是最后一步之前,我们先简要验证这些行确实都总和为 \(1\):

row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 sum:", row_2_sum)
print("All row sums:", attn_weights.sum(dim=-1))

结果如下:

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

在第三步也是最后一步中,我们现在使用这些注意力权重,通过矩阵乘法计算所有上下文向量:

all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

在得到的输出张量中,每一行都包含一个 \(3\) 维上下文向量:

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

我们可以把第 2 行与上一节 3.3.1 中先前计算的上下文向量 \(z^{(2)}\) 进行比较,以再次确认代码正确:

print("Previous 2nd context vector:", context_vec_2)

根据结果可以看到,先前计算的 context_vec_2 与前一个张量中的第二行完全匹配:

Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])

至此,简单自注意力机制的代码讲解结束。在下一节中,我们将添加可训练权重,使 LLM 能够从数据中学习,并提升其在特定任务上的性能。

3.4 实现带可训练权重的自注意力

在本节中,我们将实现原始 transformer 架构、GPT 模型以及大多数其他流行 LLM 所使用的自注意力机制。这个自注意力机制也称为缩放点积注意力。图 3.13 提供了一个心智模型,说明这个自注意力机制如何融入实现 LLM 的更大背景。

图 3.13
图 3.13 一个心智模型,说明本节编写的自注意力机制如何融入本书和本章的更大背景。在上一节中,我们编写了一个简化的注意力机制,用来理解注意力机制背后的基本机制。在本节中,我们为该注意力机制加入可训练权重。在接下来的小节中,我们还会通过加入因果掩码和多个头来扩展这个自注意力机制。

如图 3.13 所示,带可训练权重的自注意力机制建立在前面的概念之上:我们希望针对某个输入元素,把上下文向量计算为输入向量的加权和。你将看到,与我们之前在 3.3 节编写的基本自注意力机制相比,这里只有细微差别。

最显著的差别是引入了会在模型训练期间更新的权重矩阵。这些可训练权重矩阵至关重要,因为它们让模型(更具体地说,是模型内部的注意力模块)能够学习生成“好的”上下文向量。(注意,我们会在第 5 章训练 LLM。)

我们将在两个小节中处理这个自注意力机制。首先,我们会像之前一样逐步编写代码。其次,我们会把代码组织成一个紧凑的 Python 类,这个类可以导入到 LLM 架构中;该架构会在第 4 章编写。

3.4.1 逐步计算注意力权重

我们将通过引入三个可训练权重矩阵 \(W_q\)、\(W_k\) 和 \(W_v\),逐步实现自注意力机制。这三个矩阵用于把嵌入后的输入词元 \(x^{(i)}\) 投影为查询、键和值向量,如图 3.14 所示。

图 3.14
图 3.14 在带可训练权重矩阵的自注意力机制的第一步中,我们为输入元素 \(x\) 计算查询(\(q\))、键(\(k\))和值(\(v\))向量。与前几节类似,我们把第二个输入 \(x^{(2)}\) 指定为查询输入。查询向量 \(q^{(2)}\) 通过输入 \(x^{(2)}\) 与权重矩阵 \(W_q\) 的矩阵乘法得到。类似地,我们通过涉及权重矩阵 \(W_k\) 和 \(W_v\) 的矩阵乘法得到键向量和值向量。

前面在 3.3.1 节中,当我们计算简化的注意力权重以得到上下文向量 \(z^{(2)}\) 时,把第二个输入元素 \(x^{(2)}\) 定义为查询。随后,在 3.3.2 节中,我们把它推广为对六词输入句子 “Your journey starts with one step.” 计算所有上下文向量 \(z^{(1)} \ldots z^{(T)}\)。

类似地,为了说明,我们先只计算一个上下文向量 \(z^{(2)}\)。在下一节中,我们会修改这段代码,以计算所有上下文向量。

先定义几个变量:

x_2 = inputs[1]                                                                       #A
d_in = inputs.shape[1]                                                                #B
d_out = 2                                                                             #C

#A The second input element
#B The input embedding size, d=3
#C The output embedding size, d_out=2

注意,在类似 GPT 的模型中,输入维度和输出维度通常相同;但出于说明目的,为了更容易跟随计算过程,我们在这里选择不同的输入维度(\(d_{\mathrm{in}}=3\))和输出维度(\(d_{\mathrm{out}}=2\))。

接下来,我们初始化图 3.14 中显示的三个权重矩阵 \(W_q\)、\(W_k\) 和 \(W_v\):

torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

注意,出于说明目的,我们把 requires_grad=False 设置为 false,以减少输出中的杂乱内容;但如果要把这些权重矩阵用于模型训练,就会设置 requires_grad=True,从而在模型训练期间更新这些矩阵。

接下来,我们按照前面图 3.14 所示计算查询、键和值向量:

query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

从查询的输出可以看到,由于我们通过 d_out 把相应权重矩阵的列数设为 \(2\),结果是一个二维向量:

tensor([0.4306, 1.4551])

尽管我们的临时目标只是计算一个上下文向量 \(z^{(2)}\),但仍然需要所有输入元素的键向量和值向量,因为它们参与计算相对于查询 \(q^{(2)}\) 的注意力权重,如图 3.14 所示。

我们可以通过矩阵乘法得到所有键和值:

keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

从输出可以看出,我们成功地把 6 个输入词元从三维嵌入空间投影到了二维嵌入空间:

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])

第二步是计算注意力分数,如图 3.15 所示。

图 3.15
图 3.15 注意力分数计算是一个点积计算,类似于我们在 3.3 节的简化自注意力机制中使用过的计算。这里的新内容是,我们不是直接计算输入元素之间的点积,而是使用通过相应权重矩阵变换输入后得到的查询和键。

首先,计算注意力分数 \(\omega_{22}\):

keys_2 = keys[1]                                                                  #A
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

#A Remember that Python starts indexing at 0

这得到以下未归一化的注意力分数:

tensor(1.8524)

同样,我们可以通过矩阵乘法把该计算推广到所有注意力分数:

attn_scores_2 = query_2 @ keys.T # All attention scores for given query
print(attn_scores_2)

可以看到,作为快速检查,输出中的第二个元素与我们之前计算的 attn_score_22 一致:

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

第三步是从注意力分数得到注意力权重,如图 3.16 所示。

图 3.16
图 3.16 在计算注意力分数 \(\omega\) 之后,下一步是使用 softmax 函数对这些分数进行归一化,从而得到注意力权重 \(\alpha\)。

接下来,如图 3.16 所示,我们通过缩放注意力分数并使用前面用过的 softmax 函数来计算注意力权重。与之前的差别在于,现在我们会把注意力分数除以键的嵌入维度的平方根来缩放它们(注意,取平方根在数学上等同于取 \(0.5\) 次幂):

d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

得到的注意力权重如下:

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])

现在,最后一步是计算上下文向量,如图 3.17 所示。

图 3.17
图 3.17 在自注意力计算的最后一步中,我们通过注意力权重组合所有值向量来计算上下文向量。

类似于 3.3 节中把上下文向量计算为输入向量的加权和,这里我们把上下文向量计算为值向量的加权和。在这里,注意力权重作为加权因子,用来衡量每个值向量各自的重要性。与 3.3 节类似,我们可以使用矩阵乘法一步得到输出:

context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

所得向量的内容如下:

tensor([0.3061, 0.8210])

到目前为止,我们只计算了一个上下文向量 \(z^{(2)}\)。在下一节中,我们会把代码推广为计算输入序列中的所有上下文向量,即从 \(z^{(1)}\) 到 \(z^{(T)}\)。

3.4.2 实现紧凑的自注意力 Python 类

在前几节中,我们经历了许多步骤来计算自注意力输出。这样做主要是出于说明目的,使我们能够一次只处理一个步骤。在实践中,考虑到下一章中的 LLM 实现,把这段代码组织成如下 Python 类会很有帮助:

清单 3.1 一个紧凑的自注意力类
import torch.nn as nn
class SelfAttention_v1(nn.Module):
     def __init__(self, d_in, d_out):
          super().__init__()
          self.d_out = d_out
          self.W_query = nn.Parameter(torch.rand(d_in, d_out))
          self.W_key      = nn.Parameter(torch.rand(d_in, d_out))
          self.W_value = nn.Parameter(torch.rand(d_in, d_out))


     def forward(self, x):
          keys = x @ self.W_key
          queries = x @ self.W_query
          values = x @ self.W_value
          attn_scores = queries @ keys.T # omega
          attn_weights = torch.softmax(
               attn_scores / keys.shape[-1]**0.5, dim=-1)
          context_vec = attn_weights @ values
          return context_vec

在这段 PyTorch 代码中,SelfAttention_v1 是一个派生自 nn.Module 的类;nn.Module 是 PyTorch 模型的基本构建块,提供创建和管理模型层所需的功能。

__init__ 方法会为查询、键和值初始化可训练权重矩阵(W_queryW_keyW_value),每个矩阵都把输入维度 d_in 变换到输出维度 d_out

在使用 forward 方法进行前向传播时,我们通过查询与键相乘来计算注意力分数(attn_scores),并用 softmax 对这些分数进行归一化。最后,我们用这些归一化的注意力分数对值进行加权,从而创建上下文向量。

我们可以像下面这样使用这个类:

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

由于 inputs 包含六个嵌入向量,因此结果是一个存储六个上下文向量的矩阵:

tensor([[0.2996, 0.8053],
          [0.3061, 0.8210],
          [0.3058, 0.8203],
          [0.2948, 0.7939],
          [0.2927, 0.7891],
          [0.2990, 0.8040]], grad_fn=<MmBackward0>)

作为快速检查,请注意第二行([0.3061, 0.8210])与上一节中 context_vec_2 的内容一致。

图 3.18 总结了我们刚刚实现的自注意力机制。

图 3.18
图 3.18 在自注意力中,我们使用三个权重矩阵 \(W_q\)、\(W_k\) 和 \(W_v\) 变换输入矩阵 \(X\) 中的输入向量。然后,我们基于所得查询(\(Q\))和键(\(K\))计算注意力权重矩阵。接着,使用注意力权重和值(\(V\))计算上下文向量(\(Z\))。(为便于视觉呈现,图中关注的是一个包含 \(n\) 个词元的单个输入文本,而不是由多个输入组成的批次。因此,在这个语境中,\(3D\) 输入张量被简化为 \(2D\) 矩阵。这种做法让相关过程的可视化和理解更加直接。另外,为了与后续图保持一致,注意力矩阵中的数值并不表示真实的注意力权重。)

如图 3.18 所示,自注意力涉及可训练权重矩阵 \(W_q\)、\(W_k\) 和 \(W_v\)。这些矩阵把输入数据变换为查询、键和值;它们是注意力机制的关键组成部分。随着模型在训练期间接触到更多数据,它会调整这些可训练权重,正如我们会在后续章节中看到的那样。

我们还可以利用 PyTorch 的 nn.Linear 层进一步改进 SelfAttention_v1 的实现;在禁用偏置单元时,这些层实际上会执行矩阵乘法。此外,与手动实现 nn.Parameter(torch.rand(...)) 相比,使用 nn.Linear 的一个重要优势是,nn.Linear 具有优化过的权重初始化方案,有助于实现更稳定、更有效的模型训练。

清单 3.2 使用 PyTorch 的 Linear 层的自注意力类
class SelfAttention_v2(nn.Module):
     def __init__(self, d_in, d_out, qkv_bias=False):
          super().__init__()
          self.d_out = d_out
          self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
          self.W_key       = nn.Linear(d_in, d_out, bias=qkv_bias)
          self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)


     def forward(self, x):
          keys = self.W_key(x)
          queries = self.W_query(x)
          values = self.W_value(x)
          attn_scores = queries @ keys.T
          attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
          context_vec = attn_weights @ values
          return context_vec

你可以像使用 SelfAttention_v1 一样使用 SelfAttention_v2

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

输出为:

tensor([[-0.0739,   0.0713],
        [-0.0748,   0.0703],
        [-0.0749,   0.0702],
        [-0.0760,   0.0685],
        [-0.0763,   0.0679],
        [-0.0754,   0.0693]], grad_fn=<MmBackward0>)

注意,SelfAttention_v1SelfAttention_v2 会给出不同输出,因为它们对权重矩阵使用了不同的初始权重;这是因为 nn.Linear 使用了更复杂的权重初始化方案。

在下一节中,我们将增强自注意力机制,重点纳入因果和多头元素。因果方面涉及修改注意力机制,防止模型访问序列中的未来信息;这对于语言建模等任务至关重要,因为每次词预测都应该只依赖先前的词。

多头组件涉及把注意力机制拆分成多个“头”。每个头学习数据的不同方面,使模型能够同时关注不同位置上来自不同表示子空间的信息。这会提升模型在复杂任务中的性能。

3.5 用因果注意力隐藏未来词

在本节中,我们将修改标准的自注意力机制,创建一种因果注意力机制;这对于后续章节开发 LLM 是必不可少的。

因果注意力也称为掩码注意力,是自注意力的一种特殊形式。它限制模型在处理序列中的任意给定 token 时,只能考虑此前的输入以及当前输入。这与标准自注意力机制形成对比,后者允许模型一次性访问整个输入序列。

因此,在计算注意力分数时,因果注意力机制会确保模型只把序列中出现在当前 token 位置或其之前的 token 纳入考虑。

为了在类似 GPT 的 LLM 中实现这一点,对于每个被处理的 token,我们会把输入文本中位于当前 token 之后的未来 token 掩蔽掉,如图 3.19 所示。

图 3.19
图 3.19 在因果注意力中,我们将对角线上方的注意力权重掩蔽掉,使得对于给定输入,LLM 在使用注意力权重计算上下文向量时无法访问未来 token。例如,对于第二行中的词 "journey",我们只保留它之前的词("Your")以及当前位置的词("journey")对应的注意力权重。

如图 3.19 所示,我们将对角线上方的注意力权重掩蔽掉,并对未被掩蔽的注意力权重进行归一化,使每一行中的注意力权重之和为 1。在下一节中,我们将用代码实现这种掩蔽和归一化过程。

3.5.1 应用因果注意力掩码

在本节中,我们用代码实现因果注意力掩码。我们先从图 3.20 总结的流程开始。

图 3.20
图 3.20 在因果注意力中,获得带掩码的注意力权重矩阵的一种方式是:先对注意力分数应用 softmax 函数,再将对角线上方的元素置零,并对得到的矩阵进行归一化。

为了实现图 3.20 总结的步骤,即应用因果注意力掩码以获得带掩码的注意力权重,我们继续使用上一节中的注意力分数和权重来编写因果注意力机制的代码。

在图 3.20 所示的第一步中,我们像前几节一样使用 softmax 函数计算注意力权重:

queries = sa_v2.W_query(inputs)                                                    #A
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)
print(attn_weights)
  • #A 为方便起见,复用上一节中 SelfAttention_v2 对象的查询和键权重矩阵

这会得到如下注意力权重:

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
          [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
          [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
          [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
          [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
          [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
          grad_fn=<SoftmaxBackward0>)

我们可以使用 PyTorch 的 tril 函数来实现图 3.20 中的第二步,创建一个对角线上方取值为零的掩码:

context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

得到的掩码如下:

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

现在,我们可以将这个掩码与注意力权重相乘,把对角线上方的值置零:

masked_simple = attn_weights*mask_simple
print(masked_simple)

可以看到,对角线上方的元素已被成功置零:

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

图 3.20 中的第三步,是重新归一化注意力权重,使每一行的和再次为 1。我们可以通过将每一行中的每个元素除以该行的总和来实现:

row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

结果是一个注意力权重矩阵,其中对角线上方的注意力权重被置零,并且每一行的和为 1:

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
          [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
          [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
          [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
        grad_fn=<DivBackward0>)

虽然从技术上说,到这里我们已经可以算完成了因果注意力的实现,但我们还可以利用 softmax 函数的一个数学性质,用更少步骤更高效地计算带掩码的注意力权重,如图 3.21 所示。

图 3.21
图 3.21 在因果注意力中,获得带掩码的注意力权重矩阵的一种更高效方式是:在应用 softmax 函数之前,先用负无穷值掩蔽注意力分数。

softmax 函数会把输入转换为一个概率分布。当某一行中存在负无穷值(\(-\infty\))时,softmax 函数会将它们视为零概率。(从数学上说,这是因为 \(e^{-\infty}\) 趋近于 0。)

我们可以通过创建一个对角线上方为 1 的掩码,然后将这些 1 替换为负无穷(-inf)值,来实现这种更高效的掩蔽“技巧”:

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

这会得到如下掩码结果:

tensor([[0.2899,    -inf,    -inf,    -inf,    -inf,     -inf],
        [0.4656, 0.1723,     -inf,    -inf,    -inf,     -inf],
        [0.4594, 0.1703, 0.1731,      -inf,    -inf,     -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,       -inf,     -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,         -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)

现在,我们只需要对这些带掩码的结果应用 softmax 函数,就完成了:

attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

从输出可以看到,每一行中的值之和为 1,因此不需要再做额外的归一化:

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

现在,我们可以像第 3.4 节那样,通过 context_vec = attn_weights @ values 使用修改后的注意力权重来计算上下文向量。不过,在下一节中,我们先介绍因果注意力机制的另一个小调整,它有助于在训练 LLM 时降低过拟合。

3.5.2 使用 dropout 掩蔽额外的注意力权重

深度学习中的 dropout 是一种技术:在训练期间随机选择一些隐藏层单元并忽略它们,也就是有效地将它们“丢弃”出去。这种方法通过确保模型不会过度依赖任何一组特定的隐藏层单元,来帮助防止过拟合。需要强调的是,dropout 只在训练期间使用,之后会被禁用。

在 transformer 架构中,包括 GPT 这样的模型,注意力机制中的 dropout 通常应用在两个特定位置:计算注意力分数之后,或者将注意力权重应用到值向量之后。

这里,我们会在计算注意力权重之后应用 dropout 掩码,如图 3.22 所示,因为这是实践中更常见的变体。

图 3.22
图 3.22 使用因果注意力掩码(左上)之后,我们再应用一个额外的 dropout 掩码(右上),将更多注意力权重置零,以减少训练期间的过拟合。

在下面的代码示例中,我们使用 50% 的 dropout 率,这意味着掩蔽掉一半的注意力权重。(当我们在后续章节训练 GPT 模型时,会使用更低的 dropout 率,例如 0.1 或 0.2。)

在下面的代码中,为了便于演示,我们先将 PyTorch 的 dropout 实现应用到一个由 1 组成的 \(6 \times 6\) 张量上:

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)                                                  #A
example = torch.ones(6, 6)                                                       #B
print(dropout(example))
  • #A 我们选择 50% 的 dropout 率
  • #B 这里,我们创建一个由 1 组成的矩阵

可以看到,大约一半的值被置零:

tensor([[2., 2., 0., 2., 2., 0.],
          [0., 0., 0., 2., 0., 2.],
          [2., 2., 2., 2., 0., 2.],
          [0., 2., 2., 0., 0., 2.],
          [0., 2., 0., 2., 0., 2.],
          [0., 2., 2., 2., 2., 0.]])

当以 50% 的比例对注意力权重矩阵应用 dropout 时,矩阵中一半的元素会被随机设置为零。为了补偿活跃元素数量的减少,矩阵中剩余元素的值会按 \(1/0.5 = 2\) 的因子放大。这种缩放对于保持注意力权重的整体平衡至关重要,它能确保注意力机制在训练和推理阶段的平均影响保持一致。

现在,让我们将 dropout 应用到注意力权重矩阵本身:

torch.manual_seed(123)
print(dropout(attn_weights))

得到的注意力权重矩阵现在有更多元素被置零,其余元素则被重新缩放:

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
         [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
        grad_fn=<MulBackward0>

请注意,根据你的操作系统不同,得到的 dropout 输出可能看起来不一样;你可以在 PyTorch 的 issue tracker 上的 https://github.com/pytorch/pytorch/issues/121595 阅读更多关于这种不一致的信息。

在理解了因果注意力和 dropout 掩蔽之后,我们将在下一节开发一个简洁的 Python 类。这个类旨在便于高效应用这两项技术。

3.5.3 实现一个紧凑的因果注意力类

在本节中,我们现在会把因果注意力和 dropout 修改整合进第 3.4 节开发的 SelfAttention Python 类。随后,这个类会作为开发下一节多头注意力的模板;多头注意力是我们在本章中实现的最后一个注意力类。

不过,在开始之前,还有一件事需要确保:代码能够处理由多个输入组成的批次,这样 CausalAttention 类才能支持我们在第 2 章实现的数据加载器所产生的批次输出。

为简单起见,为了模拟这样的批次输入,我们复制输入文本示例:

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)                                                            #A
  • #A 2 个输入,每个输入有 6 个 token,并且每个 token 的嵌入维度为 3

这会得到一个三维张量,其中包含 2 段输入文本,每段文本有 6 个 token,每个 token 是一个 3 维嵌入向量:

torch.Size([2, 6, 3])

下面的 CausalAttention 类与我们之前实现的 SelfAttention 类类似,只是现在加入了 dropout 和因果掩码组件,如以下代码中突出标注的部分所示:

清单 3.3 一个紧凑的因果注意力类

class CausalAttention(nn.Module):
     def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
          super().__init__()
          self.d_out = d_out
          self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
          self.W_key       = nn.Linear(d_in, d_out, bias=qkv_bias)
          self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
          self.dropout = nn.Dropout(dropout)                                           #A
          self.register_buffer(
              'mask',
              torch.triu(torch.ones(context_length, context_length),
              diagonal=1)
          )                                                                            #B


     def forward(self, x):
          b, num_tokens, d_in = x.shape                                                #C
          keys = self.W_key(x)
          queries = self.W_query(x)
          values = self.W_value(x)


          attn_scores = queries @ keys.transpose(1, 2)                                 #C
          attn_scores.masked_fill_(                                                    #D
               self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
          attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
          attn_weights = self.dropout(attn_weights)


          context_vec = attn_weights @ values
          return context_vec
  • #A 与之前的 SelfAttention_v1 类相比,我们添加了一个 dropout 层
  • #B register_buffer 调用也是一个新增内容(更多信息见下文)
  • #C 我们转置维度 1 和 2,同时让批次维度保持在第一个位置(0)
  • #D 在 PyTorch 中,带有尾随下划线的操作会原地执行,从而避免不必要的内存复制

虽然所有新增代码行都应该已经可以从前几节中找到对应概念,但我们现在在 __init__ 方法中添加了一个 self.register_buffer() 调用。在 PyTorch 中使用 register_buffer 并非对所有使用场景都严格必要,但在这里有几个优点。例如,当我们在 LLM 中使用 CausalAttention 类时,buffer 会随模型一起自动移动到合适的设备(CPU 或 GPU)上;这在未来章节训练 LLM 时会很有用。这意味着我们不需要手动确保这些张量与模型参数位于同一设备上,从而避免设备不匹配错误。

我们可以像之前使用 SelfAttention 那样使用 CausalAttention 类:

torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

得到的上下文向量是一个三维张量,其中每个 token 现在由一个二维嵌入表示:

context_vecs.shape: torch.Size([2, 6, 2])

图 3.23 提供了一个心智模型,总结了到目前为止我们已经完成的内容。

图 3.23
图 3.23 一个心智模型,总结了我们在本章中编写的四种不同注意力模块。我们从简化的注意力机制开始,加入可训练权重,然后加入因果注意力掩码。在本章剩余部分中,我们将扩展因果注意力机制并编写多头注意力;它是我们将在下一章的 LLM 实现中使用的最终模块。

如图 3.23 所示,在本节中,我们重点介绍了神经网络中因果注意力的概念和实现。在下一节中,我们将扩展这一概念,并实现一个多头注意力模块,它会并行实现多个这样的因果注意力机制。

3.6 将单头注意力扩展为多头注意力

在本章最后一节中,我们将把前面实现的因果注意力类扩展到多个头。这也称为多头注意力。

“多头”一词指的是把注意力机制划分为多个“头”,每个头独立运行。在这个语境中,一个单独的因果注意力模块可以看作单头注意力,其中只有一组注意力权重按顺序处理输入。

在下面的小节中,我们将处理从因果注意力到多头注意力的这一扩展。第一小节会为了说明而通过堆叠多个 CausalAttention 模块,直观地构建一个多头注意力模块。第二小节随后会用一种更复杂但计算效率更高的方式实现同一个多头注意力模块。

3.6.1 堆叠多个单头注意力层

从实践角度看,实现多头注意力意味着创建自注意力机制的多个实例(前面在第 3.4.1 节图 3.18 中展示过),每个实例都有自己的权重,然后把它们的输出组合起来。使用自注意力机制的多个实例在计算上可能很昂贵,但这对于 transformer 类 LLM 所擅长的复杂模式识别至关重要。

图 3.24 展示了一个多头注意力模块的结构,它由多个单头注意力模块组成;这些单头注意力模块如前面的图 3.18 所示,被彼此堆叠起来。

图 3.24
图 3.24 本图中的多头注意力模块展示了两个单头注意力模块上下堆叠。因此,在有两个头的多头注意力模块中,我们不再只使用一个矩阵 \(W_v\) 来计算值矩阵,而是有两个值权重矩阵:\(W_{v1}\) 和 \(W_{v2}\)。其他权重矩阵 \(W_q\) 和 \(W_k\) 也是如此。我们得到两组上下文向量 \(Z_1\) 和 \(Z_2\),可以将它们合并为一个上下文向量矩阵 \(Z\)。

如前所述,多头注意力背后的主要思想是用不同的、学习得到的线性投影多次(并行)运行注意力机制。这些线性投影是将输入数据(例如注意力机制中的查询、键和值向量)与权重矩阵相乘得到的结果。

在代码中,我们可以通过实现一个简单的 MultiHeadAttentionWrapper 类来做到这一点,它会堆叠多个我们之前实现的 CausalAttention 模块实例:

代码清单 3.4 用于实现多头注意力的包装类
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                     dropout, num_heads, qkv_bias=False):
          super().__init__()
          self.heads = nn.ModuleList(
              [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
                for _ in range(num_heads)]
          )


    def forward(self, x):
          return torch.cat([head(x) for head in self.heads], dim=-1)

例如,如果我们使用这个 MultiHeadAttentionWrapper 类,并设置两个注意力头(通过 num_heads=2)以及 CausalAttention 的输出维度 d_out=2,就会得到 4 维上下文向量(\(d_{\mathrm{out}} \times \mathrm{num\_heads} = 4\)),如图 3.25 所示。

图 3.25
图 3.25 使用 MultiHeadAttentionWrapper 时,我们指定了注意力头的数量(num_heads)。如果像本图这样设置 num_heads=2,就会得到一个包含两组上下文向量矩阵的张量。在每个上下文向量矩阵中,行表示与 token 对应的上下文向量,列对应通过 d_out=4 指定的嵌入维度。我们沿列维度拼接这些上下文向量矩阵。由于有 2 个注意力头,并且嵌入维度为 2,最终嵌入维度为 \(2 \times 2 = 4\)。

为了用一个具体例子进一步说明图 3.25,我们可以像之前使用 CausalAttention 类那样使用 MultiHeadAttentionWrapper 类:

torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)


print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

这会得到下面表示上下文向量的张量:

tensor([[[-0.4519,   0.2216,   0.4772,    0.1063],
         [-0.5874,   0.0058,   0.5891,    0.3257],
         [-0.6300, -0.0632,    0.6202,    0.3860],
         [-0.5675, -0.0843,    0.5478,    0.3589],
         [-0.5526, -0.0981,    0.5321,    0.3428],
         [-0.5299, -0.1081,    0.5077,    0.3493]],


        [[-0.4519,   0.2216,   0.4772,    0.1063],
         [-0.5874,   0.0058,   0.5891,    0.3257],
         [-0.6300, -0.0632,    0.6202,    0.3860],
         [-0.5675, -0.0843,    0.5478,    0.3589],
         [-0.5526, -0.0981,    0.5321,    0.3428],
         [-0.5299, -0.1081,    0.5077,    0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])

得到的 context_vecs 张量的第一维是 2,因为我们有两个输入文本(这些输入文本是重复的,所以它们的上下文向量完全相同)。第二维指的是每个输入中的 6 个 token。第三维指的是每个 token 的 4 维嵌入。

在本节中,我们实现了一个 MultiHeadAttentionWrapper,它组合了多个单头注意力模块。不过请注意,在 forward 方法中,这些模块是通过 [head(x) for head in self.heads] 顺序处理的。我们可以通过并行处理这些头来改进这个实现。实现这一点的一种方式是通过矩阵乘法同时计算所有注意力头的输出,我们将在下一节中探讨。

3.6.2 使用权重拆分实现多头注意力

在上一节中,我们创建了一个 MultiHeadAttentionWrapper,通过堆叠多个单头注意力模块来实现多头注意力。这是通过实例化并组合多个 CausalAttention 对象完成的。

我们不必维护 MultiHeadAttentionWrapperCausalAttention 这两个独立的类,而是可以把这两个概念合并到一个 MultiHeadAttention 类中。此外,除了把 MultiHeadAttentionWrapperCausalAttention 的代码合并之外,我们还会做一些其他修改,以更高效地实现多头注意力。

MultiHeadAttentionWrapper 中,多头是通过创建一个 CausalAttention 对象列表(self.heads)来实现的,每个对象表示一个独立的注意力头。CausalAttention 类独立执行注意力机制,并把各个头的结果拼接起来。相比之下,下面的 MultiHeadAttention 类把多头功能集成在一个类内部。它通过重塑投影后的查询、键和值张量,把输入拆分为多个头,然后在计算注意力之后合并这些头的结果。

在进一步讨论之前,先来看 MultiHeadAttention 类:

代码清单 3.5 一个高效的多头注意力类
class MultiHeadAttention(nn.Module):
     def __init__(self, d_in, d_out,
                     context_length, dropout, num_heads, qkv_bias=False):
          super().__init__()
          assert d_out % num_heads == 0, "d_out must be divisible by num_heads"


          self.d_out = d_out
          self.num_heads = num_heads
          self.head_dim = d_out // num_heads                                         #A
          self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
          self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
          self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
          self.out_proj = nn.Linear(d_out, d_out)                                    #B
          self.dropout = nn.Dropout(dropout)
          self.register_buffer(
               'mask',
                torch.triu(torch.ones(context_length, context_length), diagonal=1)
          )


     def forward(self, x):
          b, num_tokens, d_in = x.shape
          keys = self.W_key(x)                                                       #C
          queries = self.W_query(x)                                                  #C
          values = self.W_value(x)                                                   #C


          keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) #D
          values = values.view(b, num_tokens, self.num_heads, self.head_dim) #D
          queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)#D


          keys = keys.transpose(1, 2)                                               #E
          queries = queries.transpose(1, 2)                                         #E
          values = values.transpose(1, 2)                                           #E


          attn_scores = queries @ keys.transpose(2, 3)              #F
          mask_bool = self.mask.bool()[:num_tokens, :num_tokens]                    #G


          attn_scores.masked_fill_(mask_bool, -torch.inf)                           #H


          attn_weights = torch.softmax(
               attn_scores / keys.shape[-1]**0.5, dim=-1)
          attn_weights = self.dropout(attn_weights)


          context_vec = (attn_weights @ values).transpose(1, 2) #I
                                                                                 #J
          context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
          context_vec = self.out_proj(context_vec)                                  #K
          return context_vec
  • #A 将投影维度缩小,以匹配期望的输出维度。
  • #B 使用一个 Linear 层来合并各个头的输出。
  • #C 张量形状:\((b, \mathrm{num\_tokens}, d_{\mathrm{out}})\)。
  • #D 我们通过添加一个 num_heads 维度来隐式拆分矩阵。然后展开最后一维:\((b, \mathrm{num\_tokens}, d_{\mathrm{out}}) \rightarrow (b, \mathrm{num\_tokens}, \mathrm{num\_heads}, \mathrm{head\_dim})\)。
  • #E 从形状 \((b, \mathrm{num\_tokens}, \mathrm{num\_heads}, \mathrm{head\_dim})\) 转置为 \((b, \mathrm{num\_heads}, \mathrm{num\_tokens}, \mathrm{head\_dim})\)。
  • #F 为每个头计算点积。
  • #G 将掩码截断到 token 的数量。
  • #H 使用掩码填充注意力分数。
  • #I 张量形状:\((b, \mathrm{num\_tokens}, n_{\mathrm{heads}}, \mathrm{head\_dim})\)。
  • #J 合并各个头,其中 self.d_out = self.num_heads * self.head_dim
  • #K 添加一个可选的线性投影。

尽管 MultiHeadAttention 类内部对张量进行重塑(.view)和转置(.transpose)看起来非常复杂,但从数学上说,MultiHeadAttention 类实现的概念与前面的 MultiHeadAttentionWrapper 相同。

从宏观层面看,在前面的 MultiHeadAttentionWrapper 中,我们堆叠了多个单头注意力层,并把它们组合成一个多头注意力层。MultiHeadAttention 类采用一种集成式做法。它从一个多头层开始,然后在内部把这个层拆分为各个独立的注意力头,如图 3.26 所示。

图 3.26
图 3.26 在有两个注意力头的 MultiheadAttentionWrapper 类中,我们初始化两个权重矩阵 \(W_{q1}\) 和 \(W_{q2}\),并如本图上半部分所示计算两个查询矩阵 \(Q_1\) 和 \(Q_2\)。在 MultiheadAttention 类中,我们初始化一个更大的权重矩阵 \(W_q\),只对输入执行一次矩阵乘法以得到查询矩阵 \(Q\),然后如本图下半部分所示把查询矩阵拆分为 \(Q_1\) 和 \(Q_2\)。对于键和值,我们也执行同样的操作;图中未显示它们,以减少视觉杂乱。

如图 3.26 所示,查询、键和值张量的拆分是通过使用 PyTorch 的 .view.transpose 方法进行张量重塑和转置操作实现的。输入首先被转换(通过用于查询、键和值的线性层),然后被重塑为表示多个头的形式。

关键操作是把 d_out 维度拆分为 num_headshead_dim,其中 \(\mathrm{head\_dim} = d_{\mathrm{out}} / \mathrm{num\_heads}\)。随后使用 .view 方法完成这种拆分:一个维度为 \((b, \mathrm{num\_tokens}, d_{\mathrm{out}})\) 的张量被重塑为维度 \((b, \mathrm{num\_tokens}, \mathrm{num\_heads}, \mathrm{head\_dim})\)。

然后,这些张量会被转置,把 num_heads 维度移到 num_tokens 维度之前,从而得到形状 \((b, \mathrm{num\_heads}, \mathrm{num\_tokens}, \mathrm{head\_dim})\)。这种转置对于在不同头之间正确对齐查询、键和值,并高效执行批量矩阵乘法至关重要。

为了说明这种批量矩阵乘法,假设我们有如下示例张量:

a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],                                #A
                         [0.8993, 0.0390, 0.9268, 0.7388],
                         [0.7179, 0.7058, 0.9156, 0.4340]],


                        [[0.0772, 0.3565, 0.1479, 0.5331],
                         [0.4066, 0.2318, 0.4545, 0.9737],
                         [0.4606, 0.5159, 0.4220, 0.5786]]]])
  • #A 这个张量的形状是 \((b, \mathrm{num\_heads}, \mathrm{num\_tokens}, \mathrm{head\_dim}) = (1, 2, 3, 4)\)。

现在,我们在这个张量本身与该张量的一个视图之间执行批量矩阵乘法;在这个视图中,我们转置了最后两个维度 num_tokenshead_dim

print(a @ a.transpose(2, 3))

结果如下:

tensor([[[[1.3208, 1.1631, 1.2879],
            [1.1631, 2.2150, 1.8424],
            [1.2879, 1.8424, 2.0402]],


           [[0.4391, 0.7003, 0.5903],
            [0.7003, 1.3737, 1.0620],
            [0.5903, 1.0620, 0.9912]]]])

在这种情况下,PyTorch 中的矩阵乘法实现会处理这个 4 维输入张量,使矩阵乘法在最后两个维度(num_tokenshead_dim)之间执行,然后对各个头重复这一操作。

例如,上面的写法就变成了一种更紧凑的方式,用于分别为每个头计算矩阵乘法:

first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)


second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)

这些结果与前面使用批量矩阵乘法 print(a @ a.transpose(2, 3)) 得到的结果完全相同:

First head:
 tensor([[1.3208, 1.1631, 1.2879],
         [1.1631, 2.2150, 1.8424],
         [1.2879, 1.8424, 2.0402]])


Second head:
 tensor([[0.4391, 0.7003, 0.5903],
         [0.7003, 1.3737, 1.0620],
         [0.5903, 1.0620, 0.9912]])

继续回到 MultiHeadAttention:在计算注意力权重和上下文向量之后,来自所有头的上下文向量会被转置回形状 \((b, \mathrm{num\_tokens}, \mathrm{num\_heads}, \mathrm{head\_dim})\)。然后,这些向量会被重塑(展平)为形状 \((b, \mathrm{num\_tokens}, d_{\mathrm{out}})\),从而有效地合并所有头的输出。

此外,在合并各个头之后,我们为 MultiHeadAttention 添加了一个所谓的输出投影层(self.out_proj),这在 CausalAttention 类中并不存在。这个输出投影层并非严格必要(更多细节见附录 B 的参考文献部分),但它在许多 LLM 架构中很常用,因此我们在这里为了完整性而加入它。

尽管由于额外的张量重塑和转置,MultiHeadAttention 类看起来比 MultiHeadAttentionWrapper 更复杂,但它效率更高。原因是,例如在计算键时,我们只需要一次矩阵乘法 keys = self.W_key(x)(查询和值也是如此)。在 MultiHeadAttentionWrapper 中,我们需要为每个注意力头重复这种矩阵乘法,而这在计算上是最昂贵的步骤之一。

MultiHeadAttention 类的使用方式可以与我们之前实现的 SelfAttentionCausalAttention 类类似:

torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

从结果可以看到,输出维度直接由 d_out 参数控制:

tensor([[[0.3190, 0.4858],
            [0.2943, 0.3897],
            [0.2856, 0.3593],
            [0.2693, 0.3873],
            [0.2639, 0.3928],
            [0.2575, 0.4028]],


        [[0.3190, 0.4858],
            [0.2943, 0.3897],
            [0.2856, 0.3593],
            [0.2693, 0.3873],
            [0.2639, 0.3928],
            [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])

在本节中,我们实现了 MultiHeadAttention 类;在后续章节实现并训练 LLM 本身时会用到它。请注意,虽然代码完全可运行,但为了让输出保持可读,我们使用了相对较小的嵌入大小和注意力头数量。

作为对比,最小的 GPT-2 模型(1.17 亿参数)有 12 个注意力头,上下文向量嵌入大小为 768。最大的 GPT-2 模型(15 亿参数)有 25 个注意力头,上下文向量嵌入大小为 1600。请注意,在 GPT 模型中,token 输入的嵌入大小与上下文嵌入大小相同(\(d_{\mathrm{in}} = d_{\mathrm{out}}\))。

3.7 小结

  • 注意力机制会把输入元素转换为增强的上下文向量表示,这些表示包含关于所有输入的信息。
  • 自注意力机制会把上下文向量表示计算为输入上的加权和。
  • 在简化的注意力机制中,注意力权重通过点积计算。
  • 点积只是将两个向量逐元素相乘、然后对乘积求和的一种简洁方式。
  • 矩阵乘法虽然并非严格必需,但有助于用它替代嵌套的 for 循环,从而更高效、更紧凑地实现计算。
  • 在 LLM 使用的自注意力机制中,也就是所谓的缩放点积注意力中,我们加入可训练的权重矩阵,用于计算输入的中间变换:查询、值和键。
  • 在处理从左到右读取并生成文本的 LLM 时,我们会添加因果注意力掩码,以防止 LLM 访问未来 token。
  • 除了用于将注意力权重清零的因果注意力掩码之外,我们还可以添加 dropout 掩码,以减少 LLM 中的过拟合。
  • 基于 transformer 的 LLM 中的注意力模块包含多个因果注意力实例,这称为多头注意力。
  • 我们可以通过堆叠多个因果注意力模块实例来创建多头注意力模块。
  • 创建多头注意力模块的一种更高效方式涉及批量矩阵乘法。

章节范围:原 PDF 物理第 64-112 页。图片从同一页段抽取并嵌入本文。