CS336-Lec2 PyTorch & Resource accounting
Lec 2: PyTorch & Resource accounting
数值格式(Float Pointing Formats)
深入计算之前,我们需要知道我们在计算的是什么。这其中涉及到很多的对运算效率和数值精度的Tradeoffs。
任何浮点数在计算机中都由3部分的bit位组成:符号位(Sign)、指数位(Exponent)和尾数位(Fraction)。其计算公式为 \[ (-1)^S \times 1.M \times 2^{E-Bias} \]
- S决定正负
- M决定精度,小数点后的部分,位数越多,数字越密
- E决定动态的范围,位数越多,表示的范围越大
- Bias偏置,固定常数,让指数可以是负数,平衡正负范围
为什么不全用FP32
标准的FP32(Single Precision)拥有1个符号、8个指数、23位位数
- 优点:精度高,范围大,训练stable
- 缺点:慢,占用显存大
- 现代大模型的训练中,FP32主要用于优化器更新时的备份,而非核心计算
FP16和BF16
为了寻求更快的计算效率,业界主要是16-bit的格式,这里主要是由FP16和BF16组成。
FP16(Half Precision):
结构:5位指数+10位尾数
缺陷:指数位太少,意味着动态范围很小,如果数值太小直接会Underflow,比如\(1e^-8\)就会有这种情况
BF16(Brain Float 16):
- 结构:8位指数+7位尾数
- 核心优势:指数位和FP32一样,意味着拥有和FP32同样的动态范围,精度稍微差点。
- 结论:深度学习对”范围”敏感,对“精度”不敏感,神经网络本身就是模糊近似,因此BF16是目前大模型训练的标准格式。
算力的度量:FLOPs
矩阵乘法(MatMul)
深度学习的核心是矩阵乘法\(Y = X \times W\),计算代价的公式如下: \[ FLOPs = 2 \cdot B \cdot D \cdot K \]
- B: Batch Size
- D: Input Dimension
- K: Output Dimension
- Origin of 2: 计算机底层的一次FMA(Fused Multiply-Add)指令通常是2个FLOPs,1次乘法和1次加法
训练成本定律:1:2:6
这是本节课最重要的Rule of Thumb,为什么训练比推理贵那么多?
整个训练过程分为前向和反向传播,分别计算代价
- Forward Pass: 计算\(H = XW\)。消耗1个单位的算力
- Backward Pass: 需要计算2份梯度,第一个是权重梯度\(dW\),输入梯度\(dX\)。前者用于更新参数,后者用于传播给上一层。
对于一个包含P个参数的模型,是用N个Token进行训练,总浮点数的运算量约为: \[ C \approx 6 \cdot N \cdot P \] 前向2个FLOPs,反向4个FLOPs,所以总共是6个,由此得到了上面的比率。
我们在这里所讨论的参数量(Parameter Count),永远指的是物理上独立存在(不包括共享参数)、需要被优化器更新的变量总数。
对于共享参数而言,其的确能够减少参数量,但计算层面其实计算量FLOPs并没有减少,因为你还是需要穿过这些层的网络结构执行对应的计算。
之所以在这里使用了\(\approx\)是因为我们忽略了两样东西:
- Embedding层和Softmax层:这两层的参数量P并不少,词表很大,但是并不完全遵循2P的计算逻辑,不过相对于中间几十层的Tranformer Block,误差可以接受。
- Attention的\(N^2\)计算:我们此前只假设算力和参数量P有关,但是Attention己之力,Query和Key的相乘计算量是\(N^2\),这部分的计算不消耗参数,但只要N(序列长度)远小于P,这一项就能忽略不计。
这里可能会有一个疑问Attention为什么是\(N^2\)?
其实这里更像是时间复杂度的简化,精确的计算公式是:
\[ 12N^2LH = \underbrace{L}_{\text{层数}} \times \underbrace{(4N^2H)}_{\text{前向传播}} \times \underbrace{3}_{\text{训练系数}} \]
- L是Transformer的层数,模型是由L个完全一样的Block堆叠起来的
- 前向传播的部分是QKV矩阵的乘法运算
- 训练系数遵循前向传播1,反向传播2,总共是3
我们在比较的是:“巨大的参数量 P” vs “序列长度 N”。在目前的标准模型配置下,前者比后者大太多了。
硬件:H100与Tensor Cores
GPU中有通用核心和专用核心:
- CUDA Cores: 通用灵活但慢,处理FP32/FP64
- Tensor Cores: 专为矩阵乘法设计,H100的Tensor Core每个时钟周期能执行规模\(4 \times 4\) 矩阵运算,支持混合精度,输入可以是FP16/BF16,内部累加用FP32,兼顾速度和数值的稳定性。
稀疏性(Sparity):
NVIDIA宣传H100拥有1979TFLOPs的FP16算力。但真相是这是开启(Structured Sparsity(2:4))后的理论值,要求每4个元素必须有2个是0才能跳过0计算。
现实中,大部分训练都是Dense训练,所以算力直接减半。
工程实践:代码和维度管理
维度地狱 (Dimensionality Hell)
在 Transformer 代码中,手动处理维度(如 view,
transpose)极易出错且难以调试:
1 | x = x.view(B, N, H, D).permute(0, 2, 1, 3) # 这里的 0,2,1,3 到底是谁? |
Einops库
课程强烈推荐使用 einops
库,它让维度变换变得“声明式”和“自文档化”:
1 | from einops import rearrange |
原则:代码不仅是给机器跑的,更是给人看的。显式优于隐式。
Optimizer 优化器
AdamW来源
我们现在主流的优化器是AdamW,但他是如何进化来的?
- SGD:朴素的梯度下降,根据直接的梯度来走
- Momentum:加上“惯性”,物理上的动量,用于在陡坡上累计的动量,冲过平坦的区域,因为平坦的地方可能是鞍点,局部最优,这时候普通的SGD可能直接停下来了
- RMSProp:加上自适应的步长(二阶动量),路陡步子小,路平步子大。
- Adam:Momentum + RMSProp。集成前两者优点。
什么是动量
在深度学习的优化器中,主要分为一阶和二阶动量:
- 一阶动量(Momentum)是梯度的平均值,保留正负号,作用是往前冲还是往后退。考虑基于历史速度和现在加速度(梯度)得出来的新速度。
- 二阶动量(Second Moment)是梯度平方的平均值,因为有平方,所以符号消失了,证明它只在乎大小幅度而不在乎方向,衡量的是梯度到底多陡。
RMSProp(Root Mean Square Propagation): \[ 参数更新量 = \frac{\text{学习率}}{\sqrt{\text{二阶动量}}} \times \text{梯度} \] 他可以做到如果在梯度很大的情况下利用分母的二阶动量让步长变小(路陡步子小),小的时候又可以放大(路平步子大),解决了峡谷的震荡问题以及鞍点的收敛问题。
Adam (Adaptive Moment Estimation)
\[ \begin{aligned} \text{1. 计算梯度:} & \quad g_t = \nabla_\theta J(\theta_{t-1}) \\ \text{2. 更新一阶动量 (Momentum):} & \quad m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t \\ \text{3. 更新二阶动量 (RMSProp):} & \quad v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\ \text{4. 偏差修正 (Bias Correction):} & \quad \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\ \text{5. 参数更新 (Final Update):} & \quad \theta_t = \theta_{t-1} - \eta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \end{aligned} \]
\[ \theta_t = \theta_{t-1} - \eta \cdot \frac{\overbrace{\left( \frac{\beta_1 m_{t-1} + (1-\beta_1)g_t}{1-\beta_1^t} \right)}^{\text{修正后的一阶动量 (方向 + 惯性)}}}{\underbrace{\sqrt{\frac{\beta_2 v_{t-1} + (1-\beta_2)g_t^2}{1-\beta_2^t}}}_{\text{修正后的二阶动量 (自适应步长)}} + \epsilon} \] 简单分析一下,分子是惯性,分母是阻力,或者说叫归一化。其实这个公式的本质就是代码Momentum的惯性去冲,但是同时有RMSProp的自适应刹车。
关键总结 (Takeaways)
- 算力去哪了? 绝大部分算力(>95%)都在 Forward 和 Backward 的矩阵乘法上。
- 显存去哪了?
- 静态:模型参数 (Weights) + 优化器状态 (Optimizer States)。
- 动态:中间激活值 (Activations)。Batch Size 越大,Context Length 越长,激活值显存占用越恐怖。
- 瓶颈在哪?
- 矩阵乘法层通常是 Compute Bound (卡在算力上)。
- LayerNorm, Softmax, CrossEntropy 通常是 Memory Bound (卡在显存带宽上)。




