CS336 作业 1:从头构建 Transformer 语言模型

对实现完整 Transformer 语言模型管道的全面反思——从 BPE 分词器到文本生成——在 TinyStories 和 OpenWebText 上进行训练。


目录

  1. 概述
  2. BPE 分词器
  3. Transformer 架构
  4. 训练基础设施
  5. 训练循环
  6. 文本生成
  7. 实验
  8. 反思

1. 概述

本作业从头实现了一个完整的 Transformer 语言模型管道,未依赖于 torch.nn.Lineartorch.nn.Embedding 等高级库。代码库涵盖:

  • 字节对编码 (BPE) 分词器,具有并行预分词
  • 仅解码器 Transformer,使用 RMSNorm、RoPE、SwiGLU 和因果多头注意力
  • 训练基础设施:AdamW 优化器、余弦学习率调度、梯度裁剪、数据加载、检查点
  • 自回归文本生成,带温度和核(top-p)采样
  • 实验:学习率扫描、批量大小研究、架构消融和 OpenWebText 训练

完整代码已发布在 GitHub: https://github.com/XLOverflow/CS336_Transformer_from_Scratch

项目结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
cs336_basics/
├── model/
│ ├── linear.py # 线性层(无偏置)
│ ├── embedding.py # 令牌嵌入
│ ├── normalization.py # RMSNorm
│ ├── positional_encoding.py # RoPE
│ ├── attention.py # Softmax、缩放点积注意力、多头自注意力
│ ├── feedforward.py # SwiGLU、SiLUFFN
│ ├── transformer_block.py # 预归一化 / 后归一化 Transformer 块
│ ├── transformer_lm.py # 完整的 Transformer LM,带有 generate()
│ └── config.py # 模型配置(TinyStories、GPT-2 系列)
├── tokenizers/
│ ├── bpe.py # 带并行预分词的 BPE 训练器
│ └── tokenizer.py # 带并行编码的 BPE 编码/解码
└── training/
├── cross_entropy.py # 数值稳定的交叉熵
├── adamw.py # AdamW 优化器(从头开始)
├── lr_schedule.py # 余弦退火与线性预热
├── gradient_clipping.py # L2 范数梯度裁剪
├── data_loader.py # 随机批量采样
└── checkpointing.py # 保存/加载模型检查点

2. BPE 分词器

2.1 Unicode 基础

问:Unicode 代码点与 UTF-8 编码之间有什么关系?

Unicode 为每个字符分配一个唯一的 代码点(例如,U+0041 表示 ‘A’)。UTF-8 是一种 可变长度编码,将代码点映射到 1–4 字节:

代码点范围 UTF-8 字节 示例
U+0000 – U+007F 1 字节 ASCII 字符
U+0080 – U+07FF 2 字节 拉丁文、希腊文、斯拉夫文
U+0800 – U+FFFF 3 字节 CJK 字符,大多数表情符号
U+10000 – U+10FFFF 4 字节 稀有表情符号、历史脚本

UTF-8 向后兼容 ASCII,并且是自同步的:你总是可以判断一个字节是字符的开始还是继续字节。

问:为什么使用字节级分词而不是字符级分词?

字节级分词以 256 个字节值的基本词汇表开始,可以表示 任何 语言中的文本,而无需未知标记。字符级分词需要处理完整的 Unicode 范围(143,000+ 字符)作为基本词汇表。

2.2 BPE 训练算法

核心 BPE 训练过程:

  1. 初始化词汇表,包含 256 个字节值 + 特殊标记(例如,<|endoftext|>
  2. 预分词,使用 GPT-2 正则表达式模式将文本拆分为“单词”
  3. 迭代合并 最频繁的相邻字节对,将合并的标记添加到词汇表中
  4. 重复直到达到目标词汇表大小
1
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

此正则表达式模式处理英语缩写('s't'll 等)、可选前导空格的单词、数字、标点符号和空白。

2.3 并行化策略

在大型语料库(例如,OpenWebText)上训练 BPE 分词器计算开销大。我的实现使用 并行预分词 和多进程:

  1. 查找与 <|endoftext|> 标记对齐的块边界,以避免拆分文档
  2. 将块分配给工作进程,使用 multiprocessing.Pool
  3. 每个工作进程应用正则表达式预分词并返回 频率计数Counter
  4. 在主进程中逐步合并频率计数,以控制内存使用
  5. BPE 合并顺序进行(因为每次合并依赖于前一次)

关键优化:

  • 内存管理:定期进行垃圾回收和每 5000 次合并重建索引,以减少内存碎片
  • 增量对更新:在每次合并后,不必从头重新计算所有对的频率,而是维护 pair_to_tuplespair_freq 索引,仅更新受影响的条目
  • 批处理:工作进程以 16 的批量处理块,以控制并发内存使用

2.4 分词器实验

在 TinyStories 上的词汇表大小比较:

对于 TinyStories 数据集,我训练了词汇表大小为 10,000 的分词器。分词器成功学习了常见的英语单词和子词模式。例如:

  • 常见单词如 “the”、“and”、“once” 成为单个标记
  • 不太常见的单词被拆分为学习到的子词单元
  • <|endoftext|> 被处理为一个特殊标记,不参与 BPE 合并

编码:编码器贪婪地应用 BPE 合并——对于每个预分词的单词,它从单个字节开始,并重复合并优先级最高的对(在合并列表中最早的)直到没有更多合并适用。

解码:简单地连接每个标记 ID 的字节值,并将结果解码为 UTF-8。


3. Transformer 架构

image-20260208223722664

3.1 线性层(无偏置)

遵循现代 LLM 实践(PaLM、LLaMA),所有线性层省略偏置项:

\[ y = xW^T \]

初始化:截断正态分布 \(\mathcal{N}(0, \sigma^2)\),其中 \(\sigma = \sqrt{2 / (d_{in} + d_{out})}\),截断在 \([-3\sigma, 3\sigma]\)

1
2
3
4
5
6
7
8
9
class Linear(nn.Module):
def __init__(self, in_features, out_features, device=None, dtype=None):
super().__init__()
self.weight = nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
std = (2 / (in_features + out_features)) ** 0.5
nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-3*std, b=3*std)

def forward(self, x):
return einsum(x, self.weight, "... i, o i -> ... o")

3.2 令牌嵌入

简单的查找表,将令牌 ID 映射到密集向量:

\[ \text{embed}(x) = E[x] \]

其中 \(E \in \mathbb{R}^{V \times d_{model}}\) 使用截断正态 \(\mathcal{N}(0, 1)\) 初始化。

3.3 RMSNorm

均方根层归一化(Zhang & Sennrich, 2019),在 LLaMA 中使用,而不是 LayerNorm:

\[ \text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma, \quad \text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^d x_i^2 + \epsilon} \]

关键实现细节:在计算 RMS 之前转换为 float32 以确保数值稳定性,然后再转换回原始数据类型。

1
2
3
4
5
6
def forward(self, x):
original_dtype = x.dtype
x = x.to(torch.float32)
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
normalized = x / rms
return (normalized * self.weight).to(original_dtype)

3.4 旋转位置嵌入 (RoPE)

RoPE(Su et al., 2021)通过对查询和键向量应用旋转来编码 相对 位置信息:

\[ \text{RoPE}(x, m) = \begin{pmatrix} x_0 \cos(m\theta_0) - x_1 \sin(m\theta_0) \\ x_0 \sin(m\theta_0) + x_1 \cos(m\theta_0) \\ \vdots \\ x_{d-2} \cos(m\theta_{d/2-1}) - x_{d-1} \sin(m\theta_{d/2-1}) \\ x_{d-2} \sin(m\theta_{d/2-1}) + x_{d-1} \cos(m\theta_{d/2-1}) \end{pmatrix} \]

其中 \(\theta_k = \theta_{\text{base}}^{-2k/d_k}\),对于 \(k = 0, \ldots, d_k/2 - 1\)

关键特性:

  • 无可学习参数:RoPE 完全由位置和频率计算得出
  • 仅应用于 Q 和 K(不适用于 V)
  • 捕获 相对 位置:\(q_m^T k_n\) 仅依赖于 \(m - n\)
  • 在所有层中共享(一个 RoPE 模块实例)
1
2
3
4
5
6
7
8
9
10
11
12
13
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, theta, d_k, max_seq_len, device=None):
super().__init__()
theta_k = theta ** (-2 * torch.arange(d_k // 2, device=device) / d_k)
positions = torch.arange(max_seq_len, device=device).unsqueeze(1)
angles = positions * theta_k.unsqueeze(0)
self.register_buffer("sin", torch.sin(angles), persistent=False)
self.register_buffer("cos", torch.cos(angles), persistent=False)

def forward(self, x, token_positions):
sin, cos = self.sin[token_positions], self.cos[token_positions]
x1, x2 = x[..., ::2], x[..., 1::2]
return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)

3.5 Softmax

使用最大减法技巧的数值稳定 softmax:

\[ \text{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} \]

3.6 缩放点积注意力

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

\(\sqrt{d_k}\) 的缩放防止点积的大小过大,这会将 softmax 推入极小梯度的区域。

实现使用 einops.einsum 提高可读性,并支持任意批量维度:

1
2
3
4
attn_scores = einsum(q, k, "b ... q d_k, b ... k d_k -> b ... q k") / (d_k ** 0.5)
attn_scores = attn_scores.masked_fill(~mask, float("-inf")) # 因果掩码
attn_weights = softmax(attn_scores, dim=-1)
return einsum(attn_weights, v, "b ... q k, b ... k d_v -> b ... q d_v")

3.7 多头自注意力

将模型维度拆分为多个头以实现并行注意力:

\[ \text{MultiHead}(x) = W_O \cdot \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \]

\[ \text{head}_i = \text{Attention}(xW_Q^i, xW_K^i, xW_V^i) \]

处理流程:

  1. 使用单独的线性层将输入投影到 Q、K、V
  2. 重塑为单独的头部:(batch, seq_len, d_model) → (batch, num_heads, seq_len, d_k)
  3. 对 Q 和 K 应用 RoPE
  4. 使用因果掩码(下三角)应用缩放点积注意力
  5. 连接头部并重新投影

3.8 前馈网络

image-20260208223809903

SwiGLU(Shazeer, 2020):带有 SiLU 激活的门控 FFN \[ \text{SwiGLU}(x) = W_2 \cdot (\text{SiLU}(W_1 x) \odot W_3 x) \]

其中 \(\text{SiLU}(x) = x \cdot \sigma(x)\)\(\odot\) 是逐元素乘法。

SwiGLU 使用 \(d_{ff} \approx \frac{8}{3} d_{model}\)(四舍五入为 64 的倍数),具有 3 个权重矩阵,总参数量约为 \(\approx 3 \times d_{model} \times \frac{8}{3} d_{model} = 8 d_{model}^2\)

SiLU FFN(用于消融):标准的 2 层 FFN

\[ \text{SiLUFFN}(x) = W_2 \cdot \text{SiLU}(W_1 x) \]

使用 \(d_{ff} = 4 \times d_{model}\),具有 2 个权重矩阵,总参数量约为 \(\approx 2 \times d_{model} \times 4 d_{model} = 8 d_{model}^2\)。这与 SwiGLU 的参数计数相匹配,以便进行公平的消融比较。

3.9 Transformer 块

预归一化(默认): \[ z = x + \text{Attention}(\text{RMSNorm}(x)) \]

\[ y = z + \text{FFN}(\text{RMSNorm}(z)) \]

后归一化(消融): \[ z = \text{RMSNorm}(x + \text{Attention}(x)) \]

\[ y = \text{RMSNorm}(z + \text{FFN}(z)) \]

预归一化在现代 LLM 中更受欢迎,因为它稳定了训练——残差连接保留了输入的大小,而在子层之前进行归一化可以防止激活无限增长。

3.10 完整的 Transformer LM

完整的仅解码器架构:

  1. 令牌嵌入token_ids → (batch, seq_len, d_model)
  2. N 个 Transformer 块:应用自注意力 + FFN,带有残差连接
  3. 最终 RMSNorm:对输出进行归一化
  4. LM 头:线性投影到词汇对数 (batch, seq_len, vocab_size)

3.11 Transformer 计算:参数、内存、FLOPs 和训练时间

\(B\) = batch_size,\(T\) = context_length,\(d\) = d_model,\(L\) = num_layers,\(H\) = num_heads,\(V\) = vocab_size,\(d_{ff} = 4d\)

3.11.1 参数计数 \(P\)

组件 参数
每层注意力(\(W_Q, W_K, W_V, W_O\) \(4d^2\)
每层 FFN(SwiGLU: \(W_1, W_2, W_3\) \(3 \times d \times d_{ff} = 12d^2\)
每层 RMSNorm ×2 \(2d\)
令牌嵌入 \(Vd\)
LM 头 \(Vd\)
最终 RMSNorm \(d\)

\[ \boxed{P = L(16d^2 + 2d) + 2Vd + d} \]

3.11.2 训练内存分析

在训练期间,GPU 内存由四部分组成(float32 = 4 字节):

组件 公式 描述
参数 \(4P\) 每个参数 4 字节
梯度 \(4P\) 与参数大小相同
优化器 (m+v) \(8P\) AdamW 存储 2 个与参数形状相同的张量
激活 见下文 与 batch_size 成正比

每层激活内存(为反向传播保存的中间结果):

组件 形状 元素计数
RMSNorm 输入 ×2 \((B,T,d)\) ×2 \(2BTd\)
Q, K, V \((B,T,d)\) ×3 \(3BTd\)
Softmax 输出 \((B,H,T,T)\) \(BHT^2\)
注意力输出 \((B,T,d)\) \(BTd\)
W1 输出(用于 SiLU 反向传播) \((B,T,d_{ff})\) \(4BTd\)
W3 输出 \((B,T,d_{ff})\) \(4BTd\)
SiLU 输出 \((B,T,d_{ff})\) \(4BTd\)
Gate⊙Value = W2 输入 \((B,T,d_{ff})\) \(4BTd\)

每层激活 ≈ \(22BTd + BHT^2\)

加上非层组件:嵌入输出 (\(BTd\)) + 对数 (\(BTV\)) + 交叉熵 softmax (\(BTV\)) ≈ \(BTd + 2BTV\)

\[ \text{总激活内存} = 4 \times \left[L(22BTd + BHT^2) + BTd + 2BTV\right] \text{ 字节} \]

\[ \boxed{\text{峰值内存} = 16P + 4BT\left[L(22d + HT) + d + 2V\right]} \]

3.11.3 GPT-2 XL 具体示例

\(d=1600, L=48, H=25, T=1024, V=50257\)

(a) 详细参数计数:

每层参数:

组件 参数
\(W_Q, W_K, W_V, W_O\) \(4 \times d^2 = 4 \times 2{,}560{,}000 = 10{,}240{,}000\)
\(W_1, W_2, W_3\) (FFN) \(3 \times d \times d_{ff} = 3 \times 10{,}240{,}000 = 30{,}720{,}000\)
2 × RMSNorm \(2 \times 1{,}600 = 3{,}200\)
每层总计 40,963,200

完整模型:

组件 参数
48 层 \(48 \times 40{,}963{,}200 = 1{,}966{,}233{,}600\)
令牌嵌入 (\(V \times d\)) \(80{,}411{,}200\)
LM 头 \(80{,}411{,}200\)
最终 RMSNorm \(1{,}600\)
总计 ≈ 2.13B

参数内存:\(2.13\text{B} \times 4 \text{ 字节} \approx 8.51 \text{ GB}\)

(b) 内存分析:

模型相关内存(固定)\(16P = 16 \times 2.13 \times 10^9 \approx 34.0 \text{ GB}\)

每个批量元素的激活内存\[ L(22d + HT) + d + 2V = 48(22 \times 1600 + 25 \times 1024) + 1600 + 2 \times 50257 \]

\[ = 48(35200 + 25600) + 102114 = 48 \times 60800 + 102114 = 2{,}920{,}514 \]

\[ \text{每个批量元素}: 4 \times 1024 \times 2{,}920{,}514 \approx 12.0 \text{ GB} \]

在 80GB A100 上的最大批量大小\[ \text{总内存}: 34.0 + 12.0 \times B \leq 80 \text{ GB} \]

\[ B \leq (80 - 34) / 12 \approx 3.8 \rightarrow \boxed{B_{\max} = 3} \]

3.11.4 为什么前向传播 ≈ 2 × 参数 FLOPs/标记?

Transformer 计算主要由矩阵乘法主导。对于矩阵乘法 \(Y = X \times W\),其中 \(W\) 的形状为 \((d_{in}, d_{out})\)

  • 每个输出元素需要 \(d_{in}\) 次乘法 + \(d_{in}\) 次加法 = \(2d_{in}\) FLOPs
  • 每个标记有 \(d_{out}\) 个输出元素
  • 总 FLOPs = \(2 \times d_{in} \times d_{out}\) = 2 × 参数计数

每层矩阵乘法细分(×\(L\) 层),使用 GPT-2 XL 数字(\(d=1600, T=1024, H=25, d_k=64, d_{ff}=6400\)):

操作 维度 FLOPs 公式 FLOPs
Q 投影 \((T,d) \times (d,d)\) \(2Td^2\) 5.24B
K 投影 同上 \(2Td^2\) 5.24B
V 投影 同上 \(2Td^2\) 5.24B
O 投影 同上 \(2Td^2\) 5.24B
\(QK^T\)\(H\) 头) \(H \times (T,d_k) \times (d_k,T)\) \(2T^2d\) 3.36B
attn_weights × V \(H \times (T,T) \times (T,d_k)\) \(2T^2d\) 3.36B
FFN W1 \((T,d) \times (d,d_{ff})\) \(2Td \cdot d_{ff}\) 20.97B
FFN W3 同上 \(2Td \cdot d_{ff}\) 20.97B
FFN W2 \((T,d_{ff}) \times (d_{ff},d)\) \(2Td \cdot d_{ff}\) 20.97B
每层总计 90.60B

模型级 FLOPs:

组件 FLOPs
48 层 4,348.7B
LM 头: \((T,d) \times (d,V)\) 164.7B
总计 ≈ 4.51 TFLOPs

3.11.5 FLOPs 在模型大小之间的细分

(c) 每层,FFN 占 ~69.5%(62.91B / 90.60B),使其成为计算最密集的组件。注意力投影占 23.1%,而注意力分数(\(QK^T\) + attn×V)仅占 7.4%。

组件 小型 (12L, 768) 中型 (24L, 1024) 大型 (36L, 1280) XL (48L, 1600)
注意力投影 16.6% 20.0% 21.4% 22.3%
注意力分数 (\(QK^T\) 等) 11.1% 10.0% 8.6% 7.1%
FFN 49.7% 59.9% 64.2% 66.9%
LM 头 22.6% 10.2% 5.8% 3.7%

(d) 趋势:随着模型变大,FFN 的比例增加(50% → 67%),而 LM 头的比例显著下降(23% → 4%)。这是因为 LM 头在每层的大小是固定的(与 vocab_size 相关),而 FFN 随着层数和 d_model 的增加而增长。

3.11.6 上下文长度缩放:为什么 FlashAttention 重要

(e) 将上下文长度从 1024 增加到 16384:

组件 T=1024 T=16384
注意力投影 22.3% 10.8%
注意力分数 7.1% 55.2%
FFN 66.9% 32.3%
LM 头 3.7% 1.8%
总 FLOPs 4.51T ≈ 149.5T (33×)

总 FLOPs 增加约 33×(而不是 16×!),因为注意力分数的缩放为 \(O(T^2)\)。当上下文长度增长 16× 时,注意力分数从 7.1% 跃升至 55.2%,成为主导成本。这正是为什么长上下文模型需要 FlashAttention 和其他 IO 感知的注意力优化——二次注意力成本在长序列中压倒了线性 FFN 成本。

3.11.7 为什么反向传播 ≈ 2× 前向?

对于每个矩阵乘法 \(Y = XW\),反向传播需要计算两个梯度:

  • \(\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} \times W^T\)(一个矩阵乘法)
  • \(\frac{\partial L}{\partial W} = X^T \times \frac{\partial L}{\partial Y}\)(一个矩阵乘法)

这就是反向传播需要 2 次矩阵乘法,而前向传播只需要 1 次。因此:

\[ \boxed{\text{反向} \approx 2 \times \text{前向}} \]

\[ \text{总计 = 前向 + 反向} \approx 3 \times \text{前向} = 6PBT \]

3.11.8 AdamW 每步 FLOPs

每个参数执行的操作:

操作 公式 每个参数的 FLOPs
更新 m \(m = \beta_1 m + (1-\beta_1)g\) 3(2 次乘法 + 1 次加法)
更新 v \(v = \beta_2 v + (1-\beta_2)g^2\) 4(3 次乘法 + 1 次加法)
参数更新 \(p = \alpha_t \cdot m/(\sqrt{v}+\epsilon)\) 5(平方根、加法、除法、乘法、减法)
权重衰减 \(p -= lr \times \lambda \times p\) 2(乘法 + 减法)

\[ \text{AdamW FLOPs} = 14P \]

(偏差校正 \(\alpha_t\) 是标量计算,可以忽略。远小于前向/反向 FLOPs。)

3.11.9 GPT-2 XL 训练时间估计

每步 FLOPs

  • 前向 ≈ \(2P \times B \times T\)(每个参数每个标记大约执行 ~2 次操作)
  • 反向 ≈ \(2 \times\) 前向
  • 总计 ≈ \(3 \times\) 前向 = \(6PBT\)

替换 GPT-2 XL(\(B=1024, T=1024\)): \[ \text{每步} = 6 \times 2.13 \times 10^9 \times 1024 \times 1024 = 1.34 \times 10^{16} \text{ FLOPs/步} \]

总计 400K 步\(400{,}000 \times 1.34 \times 10^{16} = 5.36 \times 10^{21}\) FLOPs

有效吞吐量:50% × 19.5 TFLOP/s = \(9.75 \times 10^{12}\) FLOP/s

\[ \text{时间} = \frac{5.36 \times 10^{21}}{9.75 \times 10^{12}} \approx 5.5 \times 10^8 \text{ 秒} \approx 6{,}360 \text{ 天} \approx \boxed{17.4 \text{ 年}} \]

这解释了为什么大规模模型训练需要大量 GPU 并行性——在单个 A100 上训练 GPT-2 XL 将需要 17 年!


4. 训练基础设施

4.1 交叉熵损失

使用对数和指数技巧的数值稳定实现:

\[ \ell_i = -\log \text{softmax}(o_i)[x_{i+1}] = \log\left(\sum_j e^{o_j - o_{\max}}\right) - (o_{x_{i+1}} - o_{\max}) \]

1
2
3
4
5
def cross_entropy(inputs, targets):
shifted = inputs - inputs.max(dim=-1, keepdim=True).values
log_sum_exp = torch.log(torch.sum(torch.exp(shifted), dim=-1))
target_logits = shifted.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
return (log_sum_exp - target_logits).mean()

4.2 AdamW 优化器

实现 AdamW(Loshchilov & Hutter, 2019)与 解耦权重衰减

\[ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \]

\[ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \]

\[ \hat{\alpha}_t = \alpha \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \]

\[ \theta_t = \theta_{t-1} - \hat{\alpha}_t \cdot \frac{m_t}{\sqrt{v_t} + \epsilon} - \alpha \lambda \theta_{t-1} \]

与 L2 正则化的关键区别:权重衰减作为单独步骤应用,使用 基础学习率 \(\alpha\),而不是偏差校正率。这是 AdamW 的“解耦”部分。

AdamW 内存计算:对于每个参数,AdamW 维护 2 个附加张量(\(m\)\(v\)),因此优化器状态在内存中需要 2× 模型参数。加上模型权重,总计为 3× 模型大小(不包括梯度)。包括梯度后,内存需求为 4× 模型大小,以 float32 计。

4.3 余弦学习率调度与预热

三个阶段(遵循 LLaMA):

  1. 线性预热\(t < T_w\)):\(\alpha_t = \frac{t}{T_w} \cdot \alpha_{\max}\)
  2. 余弦退火\(T_w \leq t \leq T_c\)):\(\alpha_t = \alpha_{\min} + \frac{1}{2}(1 + \cos(\frac{t - T_w}{T_c - T_w} \cdot \pi)) \cdot (\alpha_{\max} - \alpha_{\min})\)
  3. 恒定最小值\(t > T_c\)):\(\alpha_t = \alpha_{\min}\)

预热的目的:在训练的早期阶段,模型参数是随机初始化的,梯度可能非常嘈杂且大。学习率预热可以防止优化器采取过大的步骤,从而可能导致训练不稳定或发散。它为 Adam 的动量估计提供了时间,以便在使用完整学习率之前积累有意义的统计数据。

4.4 梯度裁剪

L2 范数梯度裁剪以确保训练稳定性:

\[ \text{如果 } \|g\|_2 > M: \quad g \leftarrow g \cdot \frac{M}{\|g\|_2 + \epsilon} \]

其中 \(M\) 是最大允许范数(通常为 1.0),\(\epsilon = 10^{-6}\)


5. 训练循环

5.1 数据加载

对于包含 \(n\) 个标记的数据集,每个批次随机采样 \(B\) 个起始位置并创建:

  • 输入dataset[i : i + context_length]
  • 目标dataset[i+1 : i+1 + context_length]

数据存储为内存映射的 uint16 numpy 数组,以便高效随机访问,而无需将整个数据集加载到 RAM 中。

5.2 检查点

检查点保存:

  • model_state_dict:所有模型参数
  • optimizer_state_dict:优化器状态(动量、步数)
  • iteration:当前训练步骤

这使得可以从任何检查点恢复训练,并完全恢复优化器状态。

5.3 训练配置

TinyStories(默认实验):

参数
词汇表大小 10,000
上下文长度 256
d_model 512
层数 4
头数 16
d_ff 1,344
学习率 1e-3(变化)
批量大小 256(变化)
最大步骤 5,000
预热步骤 500
权重衰减 0.1
梯度裁剪 1.0

6. 文本生成

6.1 自回归生成

模型一次生成一个标记的文本:

  1. 将提示编码为令牌 ID
  2. 通过模型获取下一个标记的对数
  3. 应用 温度缩放logits / temperature
  4. (可选)应用 top-p / 核采样:仅保留累积概率 ≤ p 的标记
  5. 从结果分布中采样
  6. 附加采样的标记并重复

温度 控制随机性:

  • T → 0:贪婪(argmax),确定性但重复
  • T = 1.0:从模型的分布中标准采样
  • T > 1.0:更随机,更多样但可能不连贯

Top-p(核)采样(Holtzman et al., 2019):与从完整分布中采样不同,仅保留累积概率超过 \(p\) 的最小标记集,然后重新归一化。这根据模型的置信度动态调整候选标记的数量。

6.2 生成样本

来自 TinyStories 模型的示例生成(温度=0.8,top_p=0.9):

提示:“从前有一个”

从前有一个小女孩,名叫 Lily。她喜欢在公园外面玩耍。一天,她在地上看到一个大红球。她捡起它,开始弹跳。 “看,妈妈!”她说。“我找到一个球了!” 她的妈妈微笑着说:“那是个好发现,Lily……”

模型成功学习到:

  • 连贯的叙事结构,具有开头、中间和结尾
  • 正确的语法和对话格式
  • 角色一致性(名称、代词)
  • 儿童故事的典型故事惯例(道德、简单冲突)

注意:<|endoftext|> 标记有时会出现在生成中。这不是一个错误——它是训练数据中使用的 文档分隔符。模型学习到这个标记标记故事之间的边界,并可能在其后生成新故事。


7. 实验

除非另有说明,所有实验均使用 TinyStories 数据集,配置如第 5.3 节所述。结果通过 Weights & Biases 记录。

7.1 学习率扫描

设置:固定 batch_size=256,max_steps=5000,warmup=500。扫描 lr ∈ {5e-4, 1e-3, 2e-3, 5e-3, 1e-2}。

学习率 最终验证损失 验证困惑度
1e-2 1.3004 3.671
5e-3 1.3171 3.733
2e-3 1.3567 3.883
1e-3 1.3974 4.045
5e-4 1.4930 4.450

分析:较高的学习率在 5000 步内始终实现较低的损失。最佳学习率为 lr=1e-2,验证损失为 1.3004,困惑度为 3.671。这有些令人惊讶——人们可能会期望如此高的学习率会导致不稳定,但预热、余弦退火、梯度裁剪和 RMSNorm 的组合提供了足够的正则化。

在这个范围内的趋势是单调的:更高的 LR → 更低的损失。这表明模型仍处于受益于更激进优化的状态,可能是因为 5000 步对于这个模型大小来说相对较少。

image-20260208225124153

image-20260208224958480

image-20260208225043193

7.2 批量大小实验

设置:固定 lr=1e-3,变化批量大小,并相应调整步骤,以保持大致相同的标记更新数量。

批量大小 步数 最终验证损失 验证困惑度
16 80,000 1.3264 3.768
64 20,000 1.3318 3.788
128 10,000 1.3560 3.881
512 2,500 1.4805 4.395

分析:较小的批量大小在相同的总标记数量下实现了更好的最终损失。最佳结果是批量大小为 16,验证损失为 1.3264。

这与“泛化差距”理论一致:较小的批量引入了更多的梯度估计噪声,这作为隐式正则化,可以导致更平坦的最小值,从而实现更好的泛化。然而,较小的批量由于较低的硬件利用率而计算成本更高。

在实践中,批量大小的选择涉及以下权衡:

  • 计算效率:较大的批量更好地利用 GPU 并行性
  • 泛化:较小的批量通常具有更好的泛化能力
  • 收敛速度:较小的批量需要更多的步骤,但看到相同数量的标记

image-20260208225239934

image-20260208225353926

7.3 消融研究

设置:固定 lr=1e-3,batch_size=256,max_steps=5000。每个消融修改基线架构的一个方面。

配置 最终验证损失 验证困惑度 Δ 损失
基线(预归一化 + RMSNorm + RoPE + SwiGLU) 1.3974 4.045
后归一化(而不是预归一化) 1.4095 4.094 +0.0121
无 RMSNorm(恒等归一化) 1.4400 4.221 +0.0426
SiLU FFN(而不是 SwiGLU) 1.4649 4.327 +0.0675
无 RoPE(NoPE — 无位置编码) 1.4712 4.354 +0.0738

按组件重要性分析(从最重要到最不重要):

  1. RoPE(Δ = +0.074):影响最大的组件。没有位置编码,模型无法区分标记顺序。值得注意的是,NoPE 仍然可以达到合理的困惑度(4.354),这表明模型可以部分通过语义上下文和因果掩码推断顺序。但位置信息显然提供了显著的提升。

  2. SwiGLU(Δ = +0.068):用 SiLU FFN 替换 SwiGLU(匹配参数计数)使损失增加 0.068。SwiGLU 中的门控机制提供了更细致的信息流控制,从而导致更好的表示学习。

  3. RMSNorm(Δ = +0.043):完全去除归一化会降低性能,确认了归一化对训练稳定性和表示质量的重要性。没有它,激活可能通过残差连接无限增长。

  4. 预归一化与后归一化(Δ = +0.012):差异最小。后归一化略微表现不如预归一化,与文献中显示的预归一化更稳定的训练一致。然而,对于这个模型大小和训练时长,差距很小。

image-20260208230125952

image-20260208230146528

7.4 OpenWebText (OWT) 训练

设置:GPT-2 小型架构(117M 参数)在 OpenWebText 上训练。

参数
配置 GPT-2 小型
词汇表大小 50,257
上下文长度 1,024
d_model 768
层数 12
头数 12
d_ff 2,048
批量大小 8
最大步骤 10,000
学习率 1e-3
指标
最终验证损失 3.9364
最终验证困惑度 51.236

分析:OWT 训练在 10K 步的训练中实现了约 51 的验证困惑度,这对于 117M 参数模型来说是合理的。作为参考:

  • GPT-2(117M)训练 300K 步在 WebText 上实现了约 30 的困惑度
  • 我们的模型看到的标记数量远少于此,但在训练过程中显示出明显的学习(损失持续下降)

主要瓶颈是 GPU 内存:batch_size=64 和上下文长度=1024 在单个 80GB A100 上导致 OOM。减少到 batch_size=8 解决了这个问题,但这意味着每步处理的标记更少。可以使用梯度累积来模拟更大的有效批量大小,而无需额外的内存成本。

image-20260208230235210

image-20260208230218497


8. 反思

8.1 我学到了什么

从头实现很重要。nn.Parametertorch.empty 构建每个组件迫使你理解每一步的确切数据流、形状和数值考虑。例如:

  • RMSNorm 精度:如果不转换为 float32,均方根计算可能在 float16/bfloat16 中溢出,导致 NaN 损失。这是一个微妙的错误,在 float32 的单元测试中无法捕获。

  • 权重初始化:截断正态初始化对训练稳定性有显著影响。过宽 → 梯度爆炸;过窄 → 梯度消失。公式 \(\sigma = \sqrt{2/(d_{in} + d_{out})}\)(类似 Glorot)保持各层之间的方差大致恒定。

  • 因果掩码效率:将因果掩码预计算为缓冲区并在每次前向传递时切片,比每次都新建它要高效得多,尤其是对于长序列。

分词器训练是隐藏的瓶颈。 在大型语料库上进行 BPE 训练需要仔细的内存管理:

  • 预分词可能生成数百万个唯一字节序列
  • 对频率表可能增长到数十 GB
  • 如果没有增量索引更新,每次合并迭代的复杂度将是 O(corpus_size)
  • 并行预分词提供了接近线性的加速,但 BPE 合并仍然是顺序的

超参数敏感性因组件而异。 学习率对训练动态的影响最大,而架构选择(预归一化与后归一化)对小模型的影响可能出乎意料地小。这表明,对于快速实验,花时间调整学习率比架构变化更有价值。

8.2 设计决策

  1. 共享 RoPE 模块:而不是每个注意力层创建自己的 RoPE,而是所有层共享一个实例。这节省了内存并确保一致的位置编码。

  2. 通过构造函数参数支持消融:而不是为每个消融创建单独的模型类,TransformerBlockTransformerLM 接受配置标志(norm_typeuse_post_normuse_ropeffn_type)。这保持了代码库的 DRY,同时支持所有实验变体。

  3. 内存映射数据加载:使用 np.memmap 处理训练数据,避免将整个数据集加载到 RAM 中。随机批量采样仅读取所需的切片,使得在比可用内存更大的数据集上进行训练成为可能。

  4. 并行分词器编码encode_parallel 方法在 <|endoftext|> 边界处拆分文本,并行编码块,将中间结果保存到磁盘并合并。这支持恢复并避免在大型文本上出现 OOM。

8.3 我会做的不同的事情

  • 学习率预热调整:我在所有实验中使用了固定的 500 预热步骤。根据配置调整这个值(例如,与总步骤成比例)可能会改善结果。
  • 梯度累积:对于 OWT 实验,实施梯度累积将允许在保持 GPU 内存限制的情况下使用有效批量大小为 64+,每步为 8。
  • 混合精度训练:使用 torch.cuda.amp 和 bfloat16 将减少内存使用并提高吞吐量,可能使更大的批量大小或更多步骤成为可能。
  • 生成的 KV 缓存:当前的生成实现为每个新标记重新计算所有注意力分数。KV 缓存将存储过去的键值对,从而将生成成本从 O(n²) 降低到 O(n)。

这篇博客文章记录了 CS336 作业 1(2025 年春季)的实现。所有代码均在 PyTorch 中从头编写,实验在匹兹堡超级计算中心(PSC)的 NVIDIA H100 GPU 上运行。