CS336 Assignment 1: Building a Transformer Language Model from Scratch
CS336 作业 1:从头构建 Transformer 语言模型
对实现完整 Transformer 语言模型管道的全面反思——从 BPE 分词器到文本生成——在 TinyStories 和 OpenWebText 上进行训练。
目录
1. 概述
本作业从头实现了一个完整的 Transformer 语言模型管道,未依赖于
torch.nn.Linear 或 torch.nn.Embedding
等高级库。代码库涵盖:
- 字节对编码 (BPE) 分词器,具有并行预分词
- 仅解码器 Transformer,使用 RMSNorm、RoPE、SwiGLU 和因果多头注意力
- 训练基础设施:AdamW 优化器、余弦学习率调度、梯度裁剪、数据加载、检查点
- 自回归文本生成,带温度和核(top-p)采样
- 实验:学习率扫描、批量大小研究、架构消融和 OpenWebText 训练
完整代码已发布在 GitHub: https://github.com/XLOverflow/CS336_Transformer_from_Scratch
项目结构
1 | cs336_basics/ |
2. BPE 分词器
2.1 Unicode 基础
问:Unicode 代码点与 UTF-8 编码之间有什么关系?
Unicode 为每个字符分配一个唯一的
代码点(例如,U+0041 表示 ‘A’)。UTF-8
是一种 可变长度编码,将代码点映射到 1–4 字节:
| 代码点范围 | UTF-8 字节 | 示例 |
|---|---|---|
| U+0000 – U+007F | 1 字节 | ASCII 字符 |
| U+0080 – U+07FF | 2 字节 | 拉丁文、希腊文、斯拉夫文 |
| U+0800 – U+FFFF | 3 字节 | CJK 字符,大多数表情符号 |
| U+10000 – U+10FFFF | 4 字节 | 稀有表情符号、历史脚本 |
UTF-8 向后兼容 ASCII,并且是自同步的:你总是可以判断一个字节是字符的开始还是继续字节。
问:为什么使用字节级分词而不是字符级分词?
字节级分词以 256 个字节值的基本词汇表开始,可以表示 任何 语言中的文本,而无需未知标记。字符级分词需要处理完整的 Unicode 范围(143,000+ 字符)作为基本词汇表。
2.2 BPE 训练算法
核心 BPE 训练过程:
- 初始化词汇表,包含 256 个字节值 +
特殊标记(例如,
<|endoftext|>) - 预分词,使用 GPT-2 正则表达式模式将文本拆分为“单词”
- 迭代合并 最频繁的相邻字节对,将合并的标记添加到词汇表中
- 重复直到达到目标词汇表大小
1 | PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" |
此正则表达式模式处理英语缩写('s、't、'll
等)、可选前导空格的单词、数字、标点符号和空白。
2.3 并行化策略
在大型语料库(例如,OpenWebText)上训练 BPE 分词器计算开销大。我的实现使用 并行预分词 和多进程:
- 查找与
<|endoftext|>标记对齐的块边界,以避免拆分文档 - 将块分配给工作进程,使用
multiprocessing.Pool - 每个工作进程应用正则表达式预分词并返回
频率计数(
Counter) - 在主进程中逐步合并频率计数,以控制内存使用
- BPE 合并顺序进行(因为每次合并依赖于前一次)
关键优化:
- 内存管理:定期进行垃圾回收和每 5000 次合并重建索引,以减少内存碎片
- 增量对更新:在每次合并后,不必从头重新计算所有对的频率,而是维护
pair_to_tuples和pair_freq索引,仅更新受影响的条目 - 批处理:工作进程以 16 的批量处理块,以控制并发内存使用
2.4 分词器实验
在 TinyStories 上的词汇表大小比较:
对于 TinyStories 数据集,我训练了词汇表大小为 10,000 的分词器。分词器成功学习了常见的英语单词和子词模式。例如:
- 常见单词如 “the”、“and”、“once” 成为单个标记
- 不太常见的单词被拆分为学习到的子词单元
<|endoftext|>被处理为一个特殊标记,不参与 BPE 合并
编码:编码器贪婪地应用 BPE 合并——对于每个预分词的单词,它从单个字节开始,并重复合并优先级最高的对(在合并列表中最早的)直到没有更多合并适用。
解码:简单地连接每个标记 ID 的字节值,并将结果解码为 UTF-8。
3. Transformer 架构

3.1 线性层(无偏置)
遵循现代 LLM 实践(PaLM、LLaMA),所有线性层省略偏置项:
\[ y = xW^T \]
初始化:截断正态分布 \(\mathcal{N}(0, \sigma^2)\),其中 \(\sigma = \sqrt{2 / (d_{in} + d_{out})}\),截断在 \([-3\sigma, 3\sigma]\)。
1 | class Linear(nn.Module): |
3.2 令牌嵌入
简单的查找表,将令牌 ID 映射到密集向量:
\[ \text{embed}(x) = E[x] \]
其中 \(E \in \mathbb{R}^{V \times d_{model}}\) 使用截断正态 \(\mathcal{N}(0, 1)\) 初始化。
3.3 RMSNorm
均方根层归一化(Zhang & Sennrich, 2019),在 LLaMA 中使用,而不是 LayerNorm:
\[ \text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma, \quad \text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^d x_i^2 + \epsilon} \]
关键实现细节:在计算 RMS 之前转换为 float32 以确保数值稳定性,然后再转换回原始数据类型。
1 | def forward(self, x): |
3.4 旋转位置嵌入 (RoPE)
RoPE(Su et al., 2021)通过对查询和键向量应用旋转来编码 相对 位置信息:
\[ \text{RoPE}(x, m) = \begin{pmatrix} x_0 \cos(m\theta_0) - x_1 \sin(m\theta_0) \\ x_0 \sin(m\theta_0) + x_1 \cos(m\theta_0) \\ \vdots \\ x_{d-2} \cos(m\theta_{d/2-1}) - x_{d-1} \sin(m\theta_{d/2-1}) \\ x_{d-2} \sin(m\theta_{d/2-1}) + x_{d-1} \cos(m\theta_{d/2-1}) \end{pmatrix} \]
其中 \(\theta_k = \theta_{\text{base}}^{-2k/d_k}\),对于 \(k = 0, \ldots, d_k/2 - 1\)。
关键特性:
- 无可学习参数:RoPE 完全由位置和频率计算得出
- 仅应用于 Q 和 K(不适用于 V)
- 捕获 相对 位置:\(q_m^T k_n\) 仅依赖于 \(m - n\)
- 在所有层中共享(一个 RoPE 模块实例)
1 | class RotaryPositionalEmbedding(nn.Module): |
3.5 Softmax
使用最大减法技巧的数值稳定 softmax:
\[ \text{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} \]
3.6 缩放点积注意力
\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]
\(\sqrt{d_k}\) 的缩放防止点积的大小过大,这会将 softmax 推入极小梯度的区域。
实现使用 einops.einsum
提高可读性,并支持任意批量维度:
1 | attn_scores = einsum(q, k, "b ... q d_k, b ... k d_k -> b ... q k") / (d_k ** 0.5) |
3.7 多头自注意力
将模型维度拆分为多个头以实现并行注意力:
\[ \text{MultiHead}(x) = W_O \cdot \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \]
\[ \text{head}_i = \text{Attention}(xW_Q^i, xW_K^i, xW_V^i) \]
处理流程:
- 使用单独的线性层将输入投影到 Q、K、V
- 重塑为单独的头部:
(batch, seq_len, d_model) → (batch, num_heads, seq_len, d_k) - 对 Q 和 K 应用 RoPE
- 使用因果掩码(下三角)应用缩放点积注意力
- 连接头部并重新投影
3.8 前馈网络

SwiGLU(Shazeer, 2020):带有 SiLU 激活的门控 FFN \[ \text{SwiGLU}(x) = W_2 \cdot (\text{SiLU}(W_1 x) \odot W_3 x) \]
其中 \(\text{SiLU}(x) = x \cdot \sigma(x)\),\(\odot\) 是逐元素乘法。
SwiGLU 使用 \(d_{ff} \approx \frac{8}{3} d_{model}\)(四舍五入为 64 的倍数),具有 3 个权重矩阵,总参数量约为 \(\approx 3 \times d_{model} \times \frac{8}{3} d_{model} = 8 d_{model}^2\)。
SiLU FFN(用于消融):标准的 2 层 FFN
\[ \text{SiLUFFN}(x) = W_2 \cdot \text{SiLU}(W_1 x) \]
使用 \(d_{ff} = 4 \times d_{model}\),具有 2 个权重矩阵,总参数量约为 \(\approx 2 \times d_{model} \times 4 d_{model} = 8 d_{model}^2\)。这与 SwiGLU 的参数计数相匹配,以便进行公平的消融比较。
3.9 Transformer 块
预归一化(默认): \[ z = x + \text{Attention}(\text{RMSNorm}(x)) \]
\[ y = z + \text{FFN}(\text{RMSNorm}(z)) \]
后归一化(消融): \[ z = \text{RMSNorm}(x + \text{Attention}(x)) \]
\[ y = \text{RMSNorm}(z + \text{FFN}(z)) \]
预归一化在现代 LLM 中更受欢迎,因为它稳定了训练——残差连接保留了输入的大小,而在子层之前进行归一化可以防止激活无限增长。
3.10 完整的 Transformer LM
完整的仅解码器架构:
- 令牌嵌入:
token_ids → (batch, seq_len, d_model) - N 个 Transformer 块:应用自注意力 + FFN,带有残差连接
- 最终 RMSNorm:对输出进行归一化
- LM 头:线性投影到词汇对数
(batch, seq_len, vocab_size)
3.11 Transformer 计算:参数、内存、FLOPs 和训练时间
设 \(B\) = batch_size,\(T\) = context_length,\(d\) = d_model,\(L\) = num_layers,\(H\) = num_heads,\(V\) = vocab_size,\(d_{ff} = 4d\)
3.11.1 参数计数 \(P\)
| 组件 | 参数 |
|---|---|
| 每层注意力(\(W_Q, W_K, W_V, W_O\)) | \(4d^2\) |
| 每层 FFN(SwiGLU: \(W_1, W_2, W_3\)) | \(3 \times d \times d_{ff} = 12d^2\) |
| 每层 RMSNorm ×2 | \(2d\) |
| 令牌嵌入 | \(Vd\) |
| LM 头 | \(Vd\) |
| 最终 RMSNorm | \(d\) |
\[ \boxed{P = L(16d^2 + 2d) + 2Vd + d} \]
3.11.2 训练内存分析
在训练期间,GPU 内存由四部分组成(float32 = 4 字节):
| 组件 | 公式 | 描述 |
|---|---|---|
| 参数 | \(4P\) | 每个参数 4 字节 |
| 梯度 | \(4P\) | 与参数大小相同 |
| 优化器 (m+v) | \(8P\) | AdamW 存储 2 个与参数形状相同的张量 |
| 激活 | 见下文 | 与 batch_size 成正比 |
每层激活内存(为反向传播保存的中间结果):
| 组件 | 形状 | 元素计数 |
|---|---|---|
| RMSNorm 输入 ×2 | \((B,T,d)\) ×2 | \(2BTd\) |
| Q, K, V | \((B,T,d)\) ×3 | \(3BTd\) |
| Softmax 输出 | \((B,H,T,T)\) | \(BHT^2\) |
| 注意力输出 | \((B,T,d)\) | \(BTd\) |
| W1 输出(用于 SiLU 反向传播) | \((B,T,d_{ff})\) | \(4BTd\) |
| W3 输出 | \((B,T,d_{ff})\) | \(4BTd\) |
| SiLU 输出 | \((B,T,d_{ff})\) | \(4BTd\) |
| Gate⊙Value = W2 输入 | \((B,T,d_{ff})\) | \(4BTd\) |
每层激活 ≈ \(22BTd + BHT^2\)
加上非层组件:嵌入输出 (\(BTd\)) + 对数 (\(BTV\)) + 交叉熵 softmax (\(BTV\)) ≈ \(BTd + 2BTV\)
\[ \text{总激活内存} = 4 \times \left[L(22BTd + BHT^2) + BTd + 2BTV\right] \text{ 字节} \]
\[ \boxed{\text{峰值内存} = 16P + 4BT\left[L(22d + HT) + d + 2V\right]} \]
3.11.3 GPT-2 XL 具体示例
\(d=1600, L=48, H=25, T=1024, V=50257\)
(a) 详细参数计数:
每层参数:
| 组件 | 参数 |
|---|---|
| \(W_Q, W_K, W_V, W_O\) | \(4 \times d^2 = 4 \times 2{,}560{,}000 = 10{,}240{,}000\) |
| \(W_1, W_2, W_3\) (FFN) | \(3 \times d \times d_{ff} = 3 \times 10{,}240{,}000 = 30{,}720{,}000\) |
| 2 × RMSNorm | \(2 \times 1{,}600 = 3{,}200\) |
| 每层总计 | 40,963,200 |
完整模型:
| 组件 | 参数 |
|---|---|
| 48 层 | \(48 \times 40{,}963{,}200 = 1{,}966{,}233{,}600\) |
| 令牌嵌入 (\(V \times d\)) | \(80{,}411{,}200\) |
| LM 头 | \(80{,}411{,}200\) |
| 最终 RMSNorm | \(1{,}600\) |
| 总计 | ≈ 2.13B |
参数内存:\(2.13\text{B} \times 4 \text{ 字节} \approx 8.51 \text{ GB}\)
(b) 内存分析:
模型相关内存(固定):\(16P = 16 \times 2.13 \times 10^9 \approx 34.0 \text{ GB}\)
每个批量元素的激活内存: \[ L(22d + HT) + d + 2V = 48(22 \times 1600 + 25 \times 1024) + 1600 + 2 \times 50257 \]
\[ = 48(35200 + 25600) + 102114 = 48 \times 60800 + 102114 = 2{,}920{,}514 \]
\[ \text{每个批量元素}: 4 \times 1024 \times 2{,}920{,}514 \approx 12.0 \text{ GB} \]
在 80GB A100 上的最大批量大小: \[ \text{总内存}: 34.0 + 12.0 \times B \leq 80 \text{ GB} \]
\[ B \leq (80 - 34) / 12 \approx 3.8 \rightarrow \boxed{B_{\max} = 3} \]
3.11.4 为什么前向传播 ≈ 2 × 参数 FLOPs/标记?
Transformer 计算主要由矩阵乘法主导。对于矩阵乘法 \(Y = X \times W\),其中 \(W\) 的形状为 \((d_{in}, d_{out})\):
- 每个输出元素需要 \(d_{in}\) 次乘法 + \(d_{in}\) 次加法 = \(2d_{in}\) FLOPs
- 每个标记有 \(d_{out}\) 个输出元素
- 总 FLOPs = \(2 \times d_{in} \times d_{out}\) = 2 × 参数计数
每层矩阵乘法细分(×\(L\) 层),使用 GPT-2 XL 数字(\(d=1600, T=1024, H=25, d_k=64, d_{ff}=6400\)):
| 操作 | 维度 | FLOPs 公式 | FLOPs |
|---|---|---|---|
| Q 投影 | \((T,d) \times (d,d)\) | \(2Td^2\) | 5.24B |
| K 投影 | 同上 | \(2Td^2\) | 5.24B |
| V 投影 | 同上 | \(2Td^2\) | 5.24B |
| O 投影 | 同上 | \(2Td^2\) | 5.24B |
| \(QK^T\)(\(H\) 头) | \(H \times (T,d_k) \times (d_k,T)\) | \(2T^2d\) | 3.36B |
| attn_weights × V | \(H \times (T,T) \times (T,d_k)\) | \(2T^2d\) | 3.36B |
| FFN W1 | \((T,d) \times (d,d_{ff})\) | \(2Td \cdot d_{ff}\) | 20.97B |
| FFN W3 | 同上 | \(2Td \cdot d_{ff}\) | 20.97B |
| FFN W2 | \((T,d_{ff}) \times (d_{ff},d)\) | \(2Td \cdot d_{ff}\) | 20.97B |
| 每层总计 | 90.60B |
模型级 FLOPs:
| 组件 | FLOPs |
|---|---|
| 48 层 | 4,348.7B |
| LM 头: \((T,d) \times (d,V)\) | 164.7B |
| 总计 | ≈ 4.51 TFLOPs |
3.11.5 FLOPs 在模型大小之间的细分
(c) 每层,FFN 占 ~69.5%(62.91B / 90.60B),使其成为计算最密集的组件。注意力投影占 23.1%,而注意力分数(\(QK^T\) + attn×V)仅占 7.4%。
| 组件 | 小型 (12L, 768) | 中型 (24L, 1024) | 大型 (36L, 1280) | XL (48L, 1600) |
|---|---|---|---|---|
| 注意力投影 | 16.6% | 20.0% | 21.4% | 22.3% |
| 注意力分数 (\(QK^T\) 等) | 11.1% | 10.0% | 8.6% | 7.1% |
| FFN | 49.7% | 59.9% | 64.2% | 66.9% |
| LM 头 | 22.6% | 10.2% | 5.8% | 3.7% |
(d) 趋势:随着模型变大,FFN 的比例增加(50% → 67%),而 LM 头的比例显著下降(23% → 4%)。这是因为 LM 头在每层的大小是固定的(与 vocab_size 相关),而 FFN 随着层数和 d_model 的增加而增长。
3.11.6 上下文长度缩放:为什么 FlashAttention 重要
(e) 将上下文长度从 1024 增加到 16384:
| 组件 | T=1024 | T=16384 |
|---|---|---|
| 注意力投影 | 22.3% | 10.8% |
| 注意力分数 | 7.1% | 55.2% |
| FFN | 66.9% | 32.3% |
| LM 头 | 3.7% | 1.8% |
| 总 FLOPs | 4.51T | ≈ 149.5T (33×) |
总 FLOPs 增加约 33×(而不是 16×!),因为注意力分数的缩放为 \(O(T^2)\)。当上下文长度增长 16× 时,注意力分数从 7.1% 跃升至 55.2%,成为主导成本。这正是为什么长上下文模型需要 FlashAttention 和其他 IO 感知的注意力优化——二次注意力成本在长序列中压倒了线性 FFN 成本。
3.11.7 为什么反向传播 ≈ 2× 前向?
对于每个矩阵乘法 \(Y = XW\),反向传播需要计算两个梯度:
- \(\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} \times W^T\)(一个矩阵乘法)
- \(\frac{\partial L}{\partial W} = X^T \times \frac{\partial L}{\partial Y}\)(一个矩阵乘法)
这就是反向传播需要 2 次矩阵乘法,而前向传播只需要 1 次。因此:
\[ \boxed{\text{反向} \approx 2 \times \text{前向}} \]
\[ \text{总计 = 前向 + 反向} \approx 3 \times \text{前向} = 6PBT \]
3.11.8 AdamW 每步 FLOPs
每个参数执行的操作:
| 操作 | 公式 | 每个参数的 FLOPs |
|---|---|---|
| 更新 m | \(m = \beta_1 m + (1-\beta_1)g\) | 3(2 次乘法 + 1 次加法) |
| 更新 v | \(v = \beta_2 v + (1-\beta_2)g^2\) | 4(3 次乘法 + 1 次加法) |
| 参数更新 | \(p = \alpha_t \cdot m/(\sqrt{v}+\epsilon)\) | 5(平方根、加法、除法、乘法、减法) |
| 权重衰减 | \(p -= lr \times \lambda \times p\) | 2(乘法 + 减法) |
\[ \text{AdamW FLOPs} = 14P \]
(偏差校正 \(\alpha_t\) 是标量计算,可以忽略。远小于前向/反向 FLOPs。)
3.11.9 GPT-2 XL 训练时间估计
每步 FLOPs:
- 前向 ≈ \(2P \times B \times T\)(每个参数每个标记大约执行 ~2 次操作)
- 反向 ≈ \(2 \times\) 前向
- 总计 ≈ \(3 \times\) 前向 = \(6PBT\)
替换 GPT-2 XL(\(B=1024, T=1024\)): \[ \text{每步} = 6 \times 2.13 \times 10^9 \times 1024 \times 1024 = 1.34 \times 10^{16} \text{ FLOPs/步} \]
总计 400K 步:\(400{,}000 \times 1.34 \times 10^{16} = 5.36 \times 10^{21}\) FLOPs
有效吞吐量:50% × 19.5 TFLOP/s = \(9.75 \times 10^{12}\) FLOP/s
\[ \text{时间} = \frac{5.36 \times 10^{21}}{9.75 \times 10^{12}} \approx 5.5 \times 10^8 \text{ 秒} \approx 6{,}360 \text{ 天} \approx \boxed{17.4 \text{ 年}} \]
这解释了为什么大规模模型训练需要大量 GPU 并行性——在单个 A100 上训练 GPT-2 XL 将需要 17 年!
4. 训练基础设施
4.1 交叉熵损失
使用对数和指数技巧的数值稳定实现:
\[ \ell_i = -\log \text{softmax}(o_i)[x_{i+1}] = \log\left(\sum_j e^{o_j - o_{\max}}\right) - (o_{x_{i+1}} - o_{\max}) \]
1 | def cross_entropy(inputs, targets): |
4.2 AdamW 优化器
实现 AdamW(Loshchilov & Hutter, 2019)与 解耦权重衰减:
\[ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \]
\[ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \]
\[ \hat{\alpha}_t = \alpha \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \]
\[ \theta_t = \theta_{t-1} - \hat{\alpha}_t \cdot \frac{m_t}{\sqrt{v_t} + \epsilon} - \alpha \lambda \theta_{t-1} \]
与 L2 正则化的关键区别:权重衰减作为单独步骤应用,使用 基础学习率 \(\alpha\),而不是偏差校正率。这是 AdamW 的“解耦”部分。
AdamW 内存计算:对于每个参数,AdamW 维护 2 个附加张量(\(m\) 和 \(v\)),因此优化器状态在内存中需要 2× 模型参数。加上模型权重,总计为 3× 模型大小(不包括梯度)。包括梯度后,内存需求为 4× 模型大小,以 float32 计。
4.3 余弦学习率调度与预热
三个阶段(遵循 LLaMA):
- 线性预热(\(t < T_w\)):\(\alpha_t = \frac{t}{T_w} \cdot \alpha_{\max}\)
- 余弦退火(\(T_w \leq t \leq T_c\)):\(\alpha_t = \alpha_{\min} + \frac{1}{2}(1 + \cos(\frac{t - T_w}{T_c - T_w} \cdot \pi)) \cdot (\alpha_{\max} - \alpha_{\min})\)
- 恒定最小值(\(t > T_c\)):\(\alpha_t = \alpha_{\min}\)
预热的目的:在训练的早期阶段,模型参数是随机初始化的,梯度可能非常嘈杂且大。学习率预热可以防止优化器采取过大的步骤,从而可能导致训练不稳定或发散。它为 Adam 的动量估计提供了时间,以便在使用完整学习率之前积累有意义的统计数据。
4.4 梯度裁剪
L2 范数梯度裁剪以确保训练稳定性:
\[ \text{如果 } \|g\|_2 > M: \quad g \leftarrow g \cdot \frac{M}{\|g\|_2 + \epsilon} \]
其中 \(M\) 是最大允许范数(通常为 1.0),\(\epsilon = 10^{-6}\)。
5. 训练循环
5.1 数据加载
对于包含 \(n\) 个标记的数据集,每个批次随机采样 \(B\) 个起始位置并创建:
- 输入:
dataset[i : i + context_length] - 目标:
dataset[i+1 : i+1 + context_length]
数据存储为内存映射的 uint16 numpy 数组,以便高效随机访问,而无需将整个数据集加载到 RAM 中。
5.2 检查点
检查点保存:
model_state_dict:所有模型参数optimizer_state_dict:优化器状态(动量、步数)iteration:当前训练步骤
这使得可以从任何检查点恢复训练,并完全恢复优化器状态。
5.3 训练配置
TinyStories(默认实验):
| 参数 | 值 |
|---|---|
| 词汇表大小 | 10,000 |
| 上下文长度 | 256 |
| d_model | 512 |
| 层数 | 4 |
| 头数 | 16 |
| d_ff | 1,344 |
| 学习率 | 1e-3(变化) |
| 批量大小 | 256(变化) |
| 最大步骤 | 5,000 |
| 预热步骤 | 500 |
| 权重衰减 | 0.1 |
| 梯度裁剪 | 1.0 |
6. 文本生成
6.1 自回归生成
模型一次生成一个标记的文本:
- 将提示编码为令牌 ID
- 通过模型获取下一个标记的对数
- 应用
温度缩放:
logits / temperature - (可选)应用 top-p / 核采样:仅保留累积概率 ≤ p 的标记
- 从结果分布中采样
- 附加采样的标记并重复
温度 控制随机性:
T → 0:贪婪(argmax),确定性但重复T = 1.0:从模型的分布中标准采样T > 1.0:更随机,更多样但可能不连贯
Top-p(核)采样(Holtzman et al., 2019):与从完整分布中采样不同,仅保留累积概率超过 \(p\) 的最小标记集,然后重新归一化。这根据模型的置信度动态调整候选标记的数量。
6.2 生成样本
来自 TinyStories 模型的示例生成(温度=0.8,top_p=0.9):
提示:“从前有一个”
从前有一个小女孩,名叫 Lily。她喜欢在公园外面玩耍。一天,她在地上看到一个大红球。她捡起它,开始弹跳。 “看,妈妈!”她说。“我找到一个球了!” 她的妈妈微笑着说:“那是个好发现,Lily……”
模型成功学习到:
- 连贯的叙事结构,具有开头、中间和结尾
- 正确的语法和对话格式
- 角色一致性(名称、代词)
- 儿童故事的典型故事惯例(道德、简单冲突)
注意:<|endoftext|>
标记有时会出现在生成中。这不是一个错误——它是训练数据中使用的
文档分隔符。模型学习到这个标记标记故事之间的边界,并可能在其后生成新故事。
7. 实验
除非另有说明,所有实验均使用 TinyStories 数据集,配置如第 5.3 节所述。结果通过 Weights & Biases 记录。
7.1 学习率扫描
设置:固定 batch_size=256,max_steps=5000,warmup=500。扫描 lr ∈ {5e-4, 1e-3, 2e-3, 5e-3, 1e-2}。
| 学习率 | 最终验证损失 | 验证困惑度 |
|---|---|---|
| 1e-2 | 1.3004 | 3.671 |
| 5e-3 | 1.3171 | 3.733 |
| 2e-3 | 1.3567 | 3.883 |
| 1e-3 | 1.3974 | 4.045 |
| 5e-4 | 1.4930 | 4.450 |
分析:较高的学习率在 5000 步内始终实现较低的损失。最佳学习率为 lr=1e-2,验证损失为 1.3004,困惑度为 3.671。这有些令人惊讶——人们可能会期望如此高的学习率会导致不稳定,但预热、余弦退火、梯度裁剪和 RMSNorm 的组合提供了足够的正则化。
在这个范围内的趋势是单调的:更高的 LR → 更低的损失。这表明模型仍处于受益于更激进优化的状态,可能是因为 5000 步对于这个模型大小来说相对较少。



7.2 批量大小实验
设置:固定 lr=1e-3,变化批量大小,并相应调整步骤,以保持大致相同的标记更新数量。
| 批量大小 | 步数 | 最终验证损失 | 验证困惑度 |
|---|---|---|---|
| 16 | 80,000 | 1.3264 | 3.768 |
| 64 | 20,000 | 1.3318 | 3.788 |
| 128 | 10,000 | 1.3560 | 3.881 |
| 512 | 2,500 | 1.4805 | 4.395 |
分析:较小的批量大小在相同的总标记数量下实现了更好的最终损失。最佳结果是批量大小为 16,验证损失为 1.3264。
这与“泛化差距”理论一致:较小的批量引入了更多的梯度估计噪声,这作为隐式正则化,可以导致更平坦的最小值,从而实现更好的泛化。然而,较小的批量由于较低的硬件利用率而计算成本更高。
在实践中,批量大小的选择涉及以下权衡:
- 计算效率:较大的批量更好地利用 GPU 并行性
- 泛化:较小的批量通常具有更好的泛化能力
- 收敛速度:较小的批量需要更多的步骤,但看到相同数量的标记


7.3 消融研究
设置:固定 lr=1e-3,batch_size=256,max_steps=5000。每个消融修改基线架构的一个方面。
| 配置 | 最终验证损失 | 验证困惑度 | Δ 损失 |
|---|---|---|---|
| 基线(预归一化 + RMSNorm + RoPE + SwiGLU) | 1.3974 | 4.045 | — |
| 后归一化(而不是预归一化) | 1.4095 | 4.094 | +0.0121 |
| 无 RMSNorm(恒等归一化) | 1.4400 | 4.221 | +0.0426 |
| SiLU FFN(而不是 SwiGLU) | 1.4649 | 4.327 | +0.0675 |
| 无 RoPE(NoPE — 无位置编码) | 1.4712 | 4.354 | +0.0738 |
按组件重要性分析(从最重要到最不重要):
RoPE(Δ = +0.074):影响最大的组件。没有位置编码,模型无法区分标记顺序。值得注意的是,NoPE 仍然可以达到合理的困惑度(4.354),这表明模型可以部分通过语义上下文和因果掩码推断顺序。但位置信息显然提供了显著的提升。
SwiGLU(Δ = +0.068):用 SiLU FFN 替换 SwiGLU(匹配参数计数)使损失增加 0.068。SwiGLU 中的门控机制提供了更细致的信息流控制,从而导致更好的表示学习。
RMSNorm(Δ = +0.043):完全去除归一化会降低性能,确认了归一化对训练稳定性和表示质量的重要性。没有它,激活可能通过残差连接无限增长。
预归一化与后归一化(Δ = +0.012):差异最小。后归一化略微表现不如预归一化,与文献中显示的预归一化更稳定的训练一致。然而,对于这个模型大小和训练时长,差距很小。


7.4 OpenWebText (OWT) 训练
设置:GPT-2 小型架构(117M 参数)在 OpenWebText 上训练。
| 参数 | 值 |
|---|---|
| 配置 | GPT-2 小型 |
| 词汇表大小 | 50,257 |
| 上下文长度 | 1,024 |
| d_model | 768 |
| 层数 | 12 |
| 头数 | 12 |
| d_ff | 2,048 |
| 批量大小 | 8 |
| 最大步骤 | 10,000 |
| 学习率 | 1e-3 |
| 指标 | 值 |
|---|---|
| 最终验证损失 | 3.9364 |
| 最终验证困惑度 | 51.236 |
分析:OWT 训练在 10K 步的训练中实现了约 51 的验证困惑度,这对于 117M 参数模型来说是合理的。作为参考:
- GPT-2(117M)训练 300K 步在 WebText 上实现了约 30 的困惑度
- 我们的模型看到的标记数量远少于此,但在训练过程中显示出明显的学习(损失持续下降)
主要瓶颈是 GPU 内存:batch_size=64 和上下文长度=1024 在单个 80GB A100 上导致 OOM。减少到 batch_size=8 解决了这个问题,但这意味着每步处理的标记更少。可以使用梯度累积来模拟更大的有效批量大小,而无需额外的内存成本。


8. 反思
8.1 我学到了什么
从头实现很重要。 从 nn.Parameter 和
torch.empty
构建每个组件迫使你理解每一步的确切数据流、形状和数值考虑。例如:
RMSNorm 精度:如果不转换为 float32,均方根计算可能在 float16/bfloat16 中溢出,导致 NaN 损失。这是一个微妙的错误,在 float32 的单元测试中无法捕获。
权重初始化:截断正态初始化对训练稳定性有显著影响。过宽 → 梯度爆炸;过窄 → 梯度消失。公式 \(\sigma = \sqrt{2/(d_{in} + d_{out})}\)(类似 Glorot)保持各层之间的方差大致恒定。
因果掩码效率:将因果掩码预计算为缓冲区并在每次前向传递时切片,比每次都新建它要高效得多,尤其是对于长序列。
分词器训练是隐藏的瓶颈。 在大型语料库上进行 BPE 训练需要仔细的内存管理:
- 预分词可能生成数百万个唯一字节序列
- 对频率表可能增长到数十 GB
- 如果没有增量索引更新,每次合并迭代的复杂度将是 O(corpus_size)
- 并行预分词提供了接近线性的加速,但 BPE 合并仍然是顺序的
超参数敏感性因组件而异。 学习率对训练动态的影响最大,而架构选择(预归一化与后归一化)对小模型的影响可能出乎意料地小。这表明,对于快速实验,花时间调整学习率比架构变化更有价值。
8.2 设计决策
共享 RoPE 模块:而不是每个注意力层创建自己的 RoPE,而是所有层共享一个实例。这节省了内存并确保一致的位置编码。
通过构造函数参数支持消融:而不是为每个消融创建单独的模型类,
TransformerBlock和TransformerLM接受配置标志(norm_type、use_post_norm、use_rope、ffn_type)。这保持了代码库的 DRY,同时支持所有实验变体。内存映射数据加载:使用
np.memmap处理训练数据,避免将整个数据集加载到 RAM 中。随机批量采样仅读取所需的切片,使得在比可用内存更大的数据集上进行训练成为可能。并行分词器编码:
encode_parallel方法在<|endoftext|>边界处拆分文本,并行编码块,将中间结果保存到磁盘并合并。这支持恢复并避免在大型文本上出现 OOM。
8.3 我会做的不同的事情
- 学习率预热调整:我在所有实验中使用了固定的 500 预热步骤。根据配置调整这个值(例如,与总步骤成比例)可能会改善结果。
- 梯度累积:对于 OWT 实验,实施梯度累积将允许在保持 GPU 内存限制的情况下使用有效批量大小为 64+,每步为 8。
- 混合精度训练:使用
torch.cuda.amp和 bfloat16 将减少内存使用并提高吞吐量,可能使更大的批量大小或更多步骤成为可能。 - 生成的 KV 缓存:当前的生成实现为每个新标记重新计算所有注意力分数。KV 缓存将存储过去的键值对,从而将生成成本从 O(n²) 降低到 O(n)。
这篇博客文章记录了 CS336 作业 1(2025 年春季)的实现。所有代码均在 PyTorch 中从头编写,实验在匹兹堡超级计算中心(PSC)的 NVIDIA H100 GPU 上运行。




