11868 LLM Sys & 15642 ML Sys: DL Frameworks and Auto Differentiation
Automatic Differentiation & Deep Learning Frameworks
1. The Learning Problem & Gradient Descent
Given a training set D = {(xₙ, yₙ)}, the goal is to
find model parameters θ that minimize a loss function:
\[
\min_{\theta} \mathcal{L}(\theta), \quad \mathcal{L} : \mathbb{R}^d \to
\mathbb{R}
\] For classification, the standard loss is
cross-entropy. PyTorch’s CrossEntropyLoss
combines log-softmax and negative log-likelihood in a single op applied
directly to raw logits, which is more numerically stable than computing
softmax explicitly first.
Why Gradient Descent?
Using a first-order Taylor expansion around the current parameters xₜ:
\[
f(x_t + \Delta x) \approx f(x_t) + \Delta x^T \nabla f \big|_{x_t}
\]
To make f decrease, we want the second term to be as negative as possible. By Cauchy–Schwarz, for a step of fixed length, Δxᵀ∇f is most negative when Δx points opposite to the gradient. This gives the gradient descent update rule:
\[
x_{t+1} = x_t - \eta \cdot \nabla f \big|_{x_t}
\]
where η is the learning rate hyperparameter.
SGD Algorithm
- Set the learning rate η and initialize θ.
- Repeat until convergence:
  - Sample a minibatch B ⊂ D.
  - Compute the minibatch gradient g = ∇θ ℒ_B(θ).
  - Update θ ← θ − η · g.
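This loop can be sketched in plain NumPy for least-squares regression; the dataset, batch size, and step count below are illustrative, and the gradient is derived by hand rather than by autodiff:

```python
import numpy as np

# Minimal SGD sketch for the least-squares loss L(θ) = ||Xθ - y||² / 2N.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 3))
true_theta = np.array([1.0, -2.0, 0.5])
y = X @ true_theta                             # noise-free synthetic targets

eta = 0.1                                      # learning rate η
theta = np.zeros(3)
for step in range(500):
    idx = rng.integers(0, len(X), size=32)     # sample a minibatch B
    Xb, yb = X[idx], y[idx]
    grad = Xb.T @ (Xb @ theta - yb) / len(yb)  # minibatch gradient g = ∇L_B(θ)
    theta -= eta * grad                        # θ ← θ − η·g
```

With noise-free targets the iterates converge to `true_theta`; in a real framework the `grad` line is exactly what automatic differentiation replaces.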
Key Question: How do we efficiently compute ∂ℒ/∂θ for every parameter in an arbitrarily deep network? This is exactly what automatic differentiation solves.
2. Computation Graphs
A computation graph is a directed acyclic graph (DAG) where:
- Nodes represent variables (inputs, parameters, intermediates) or operations.
- Directed edges indicate data flow — the inputs to each operation.
Forward Evaluation via Topological Sort
To evaluate the graph:
- Put all nodes in an unprocessed queue.
- Repeatedly find a node whose inputs are all already computed.
- Evaluate it, move it to the processed set.
- Repeat until all nodes are evaluated.
This guarantees every node is computed exactly once in a valid dependency order.
3. Differentiation Methods: A Taxonomy
| Method | How It Works | Cost for f: ℝⁿ → ℝ | Practical Use |
|---|---|---|---|
| Numerical | Evaluate f at perturbed points; finite differences | O(n) forward passes | Gradient checking only — too slow & unstable for training |
| Symbolic | Algebraically apply sum/product/chain rules | Can blow up exponentially | CAS tools (Mathematica). Expression swell is impractical for large models |
| Forward-mode AD | Propagate derivative “tangents” alongside values in one forward pass | O(n) forward passes | Good when n is small (e.g., Hessian-vector products) |
| Reverse-mode AD | Run forward, then propagate “adjoints” backward | O(1) backward passes | The workhorse of deep learning — one pass gives all ∂ℒ/∂θ |
Key Insight: For f: ℝⁿ → ℝ (scalar loss, many parameters), reverse-mode AD computes the full gradient in time proportional to a single forward pass. Forward mode would need n passes.
4. Forward-Mode Automatic Differentiation
Forward-mode AD computes derivatives by carrying a tangent (derivative value) alongside each intermediate value, propagating in the same direction as the forward computation.
Definition
For each intermediate value vᵢ, define: \[ \dot{v}_i = \frac{\partial v_i}{\partial x_1} \]
Worked Example
y = ln(x₁) + x₁·x₂ − sin(x₂), computing ∂y/∂x₁:
Introduce intermediates v₁ = x₁, v₂ = x₂ and seed the tangents to target ∂y/∂x₁:

| Forward trace (values) | Tangent trace |
|---|---|
| v₁ = x₁ | v̇₁ = 1 (seed) |
| v₂ = x₂ | v̇₂ = 0 (seed) |
| v₃ = ln v₁ | v̇₃ = v̇₁ / v₁ |
| v₄ = v₁ · v₂ | v̇₄ = v̇₁ · v₂ + v₁ · v̇₂ |
| v₅ = sin v₂ | v̇₅ = cos(v₂) · v̇₂ |
| y = v₃ + v₄ − v₅ | ẏ = v̇₃ + v̇₄ − v̇₅ |

With this seed, ẏ = 1/x₁ + x₂ = ∂y/∂x₁.
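Forward mode is often implemented with dual numbers: each value carries its tangent, and every op applies the chain rule to both components. A minimal sketch (the `Dual` class and helper functions are illustrative, not a library API):

```python
import math

# Dual-number forward-mode AD: a value is (val, tan) where tan = d(val)/dx
# for whichever input was seeded with tangent 1.
class Dual:
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan
    def __add__(self, o): return Dual(self.val + o.val, self.tan + o.tan)
    def __sub__(self, o): return Dual(self.val - o.val, self.tan - o.tan)
    def __mul__(self, o): return Dual(self.val * o.val,
                                      self.tan * o.val + self.val * o.tan)

def ln(d):  return Dual(math.log(d.val), d.tan / d.val)
def sin(d): return Dual(math.sin(d.val), math.cos(d.val) * d.tan)

def f(x1, x2):
    return ln(x1) + x1 * x2 - sin(x2)

# Seed ẋ₁ = 1, ẋ₂ = 0 to get ∂y/∂x₁ = 1/x₁ + x₂ at (2, 5).
y = f(Dual(2.0, 1.0), Dual(5.0, 0.0))
print(y.tan)  # 1/2 + 5 = 5.5
```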
Limitation
To get ∂y/∂x₂, we need to re-run the entire trace with a different seed (v̇₁=0, v̇₂=1). For a network with millions of parameters, this is prohibitively expensive — hence reverse mode.
5. Reverse-Mode Automatic Differentiation
Reverse-mode AD is the engine behind loss.backward() in
PyTorch. It computes gradients w.r.t. all inputs in a
single backward pass.
Definition
Define the adjoint of each node vᵢ:
\[\bar{v}_i = \frac{\partial y}{\partial v_i}\]
Seed v̄_output = 1, then traverse the graph in reverse topological order.
Worked Example (same function)
With the trace v₁ = x₁, v₂ = x₂, v₃ = ln v₁, v₄ = v₁ · v₂, v₅ = sin v₂, y = v₃ + v₄ − v₅, seed ȳ = 1 and traverse backward:

| Adjoint | Value |
|---|---|
| ȳ | 1 (seed) |
| v̄₅ = ȳ · ∂y/∂v₅ = −ȳ | −1 |
| v̄₄ = ȳ · ∂y/∂v₄ = ȳ | 1 |
| v̄₃ = ȳ · ∂y/∂v₃ = ȳ | 1 |
| v̄₂ = v̄₄ · v₁ + v̄₅ · cos v₂ | x₁ − cos x₂ |
| v̄₁ = v̄₃ / v₁ + v̄₄ · v₂ | 1/x₁ + x₂ |

A single backward pass yields both ∂y/∂x₁ = v̄₁ and ∂y/∂x₂ = v̄₂.
Multiple Pathway Rule
When a node vᵢ feeds into multiple downstream nodes, its adjoint is the sum of partial contributions:
\[\bar{v}_i = \sum_{j \in \text{next}(i)} \bar{v}_j \cdot \frac{\partial v_j}{\partial v_i}\]
Reverse AD Algorithm
The algorithm keeps, for each node, the partial adjoint contributions arriving from its consumers; visiting nodes in reverse topological order, it sums the contributions into v̄ᵢ and then pushes v̄ᵢ through the node's local derivatives to its inputs.
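The whole algorithm fits in a few dozen lines for scalar nodes. This is a hedged sketch: `Node`, the op constructors, and `gradient` are illustrative names mirroring the description above, not a real framework's API:

```python
import math

# Minimal reverse-mode AD. Each Node stores its forward value, its parent
# nodes, and a vjp closure mapping the node's adjoint to per-input adjoints.
class Node:
    def __init__(self, val, inputs=(), vjp=None):
        self.val = val
        self.inputs = inputs
        self.vjp = vjp or (lambda adj: ())

def add(a, b): return Node(a.val + b.val, (a, b), lambda adj: (adj, adj))
def sub(a, b): return Node(a.val - b.val, (a, b), lambda adj: (adj, -adj))
def mul(a, b): return Node(a.val * b.val, (a, b),
                           lambda adj: (adj * b.val, adj * a.val))
def ln(a):     return Node(math.log(a.val), (a,), lambda adj: (adj / a.val,))
def sin(a):    return Node(math.sin(a.val), (a,),
                           lambda adj: (adj * math.cos(a.val),))

def gradient(out):
    """Return {node: adjoint} for every node reachable from `out`."""
    topo, seen = [], set()
    def build(n):                     # DFS produces a topological order
        if n not in seen:
            seen.add(n)
            for p in n.inputs:
                build(p)
            topo.append(n)
    build(out)

    adjoint = {out: 1.0}              # seed: ∂y/∂y = 1
    for n in reversed(topo):          # reverse topological order
        for parent, contrib in zip(n.inputs, n.vjp(adjoint[n])):
            # multiple-pathway rule: sum contributions from all consumers
            adjoint[parent] = adjoint.get(parent, 0.0) + contrib
    return adjoint

# y = ln(x1) + x1*x2 - sin(x2) at (x1, x2) = (2, 5)
x1, x2 = Node(2.0), Node(5.0)
y = sub(add(ln(x1), mul(x1, x2)), sin(x2))
g = gradient(y)
print(g[x1], g[x2])   # 1/x1 + x2 = 5.5,  x1 - cos(x2)
```

One call to `gradient` returns the adjoints of every node, matching the worked example: both partials come out of a single backward sweep.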
6. Vector-Jacobian Products & Tensor Adjoints
In practice, nodes hold tensors (matrices, vectors), not scalars. We generalize using the Jacobian and VJP.
Jacobian
For y ∈ ℝᵐ, x ∈ ℝⁿ, the Jacobian is the m × n matrix of all partial derivatives:
\[
J = \frac{\partial y}{\partial x} =
\begin{bmatrix}
\partial y_1/\partial x_1 & \cdots & \partial y_1/\partial x_n \\
\vdots & \ddots & \vdots \\
\partial y_m/\partial x_1 & \cdots & \partial y_m/\partial x_n
\end{bmatrix}
\]
We never explicitly form J (it can be enormous). Instead, we compute the Vector-Jacobian Product (VJP) directly:
\[
\bar{x} = J^T \bar{y}
\]
Example: Linear Layer
For y = Wx:
| Gradient | Formula |
|---|---|
| w.r.t. input x | x̄ = Wᵀ ȳ |
| w.r.t. weight W | W̄ = ȳ xᵀ |
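These two formulas can be sanity-checked numerically. The sketch below picks a scalar loss L = c·y (so ȳ = c) and compares the analytic adjoints against centered finite differences; the shapes and loss are illustrative:

```python
import numpy as np

# Check x̄ = Wᵀȳ and W̄ = ȳxᵀ for y = Wx under the scalar loss L = c·y.
rng = np.random.default_rng(0)
m, n = 3, 4
W, x, c = rng.normal(size=(m, n)), rng.normal(size=n), rng.normal(size=m)

def loss(W, x):
    return c @ (W @ x)        # ȳ = ∂L/∂y = c

x_bar = W.T @ c               # analytic x̄ = Wᵀ ȳ
W_bar = np.outer(c, x)        # analytic W̄ = ȳ xᵀ

h = 1e-6
x_num = np.zeros(n)
for i in range(n):            # centered differences w.r.t. each x_i
    e = np.zeros(n); e[i] = h
    x_num[i] = (loss(W, x + e) - loss(W, x - e)) / (2 * h)

e = np.zeros((m, n)); e[0, 0] = h   # and w.r.t. one weight entry
w00_num = (loss(W + e, x) - loss(W - e, x)) / (2 * h)
```

Since the loss is linear, the finite-difference estimates agree with the analytic adjoints up to floating-point rounding.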
Implementation Pattern: Every primitive op ships with both a
forward() and a vjp() function. The framework only needs to know these primitives; everything else is handled by graph traversal.
7. Backprop vs. Reverse-Mode AD: Graph Extension
There are two implementation strategies, and the distinction matters for higher-order derivatives.
Backpropagation (1st generation)
- Runs backward operations directly on the forward graph.
- Computes adjoint values by hand, node by node.
- Used in early frameworks: Caffe, cuda-convnet.
- Cannot naturally support gradient-of-gradient.
Reverse-Mode AD by Graph Extension (modern)
- Constructs new graph nodes for each adjoint computation.
- The backward pass itself becomes a computation graph.
- You can run another backward pass on top → higher-order derivatives.
- Used by PyTorch, JAX, modern TensorFlow.
The gradient function returns a computation graph, not just a value, so its output can itself be differentiated.
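To make "the derivative is a graph" concrete, here is a toy symbolic sketch (not reverse mode itself, and all names are illustrative): `grad` emits a new expression tree rather than a number, so applying it twice yields the second derivative:

```python
# Expressions are nested tuples ('op', left, right); variables are strings.
# Only +, * and constants are supported.
def grad(expr, var):
    if expr == var:
        return 1.0
    if not isinstance(expr, tuple):           # constant or other variable
        return 0.0
    op, a, b = expr
    if op == '+':
        return ('+', grad(a, var), grad(b, var))
    if op == '*':                             # product rule builds NEW nodes
        return ('+', ('*', grad(a, var), b), ('*', a, grad(b, var)))

def evaluate(expr, env):
    if not isinstance(expr, tuple):
        return env.get(expr, expr) if isinstance(expr, str) else expr
    op, a, b = expr
    va, vb = evaluate(a, env), evaluate(b, env)
    return va + vb if op == '+' else va * vb

x = 'x'
f = ('*', x, ('*', x, x))     # f(x) = x³
df = grad(f, x)               # a new graph computing 3x² (unsimplified)
ddf = grad(df, x)             # differentiate the gradient graph: 6x
print(evaluate(df, {'x': 1.0}), evaluate(ddf, {'x': 1.0}))  # 3.0 6.0
```

Classic backprop would have produced only the number `3.0`; because `grad` returns a graph, the second call is possible at all.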
Key Takeaway: Reverse-mode AD by graph extension is strictly more powerful than classic backprop. It’s the reason modern frameworks can compute Hessians, higher-order gradients, and meta-learning objectives.
8. Putting It Together: A Deep Learning Framework
A DL framework needs to be: expressive (any network architecture), productive (hide CUDA, auto-differentiate), and efficient (scale to large models, auto hardware acceleration).
Design Principles
- Define the program as a symbolic dataflow graph with placeholders, variables, and operations.
- Execute an optimized version of that graph on available devices.
Basic Components (TensorFlow-style)
| Component | Role | Example |
|---|---|---|
| Placeholder | Input data fed at runtime | tf.placeholder(tf.float32, (1, 784)) |
| Variable | Stateful node for parameters; persists across executions | tf.Variable(tf.zeros((100,))) |
| Constant | Static data | tf.constant([[1, 2], [3, 4]]) |
| Operation | Math ops; must define forward + backward | tf.nn.relu(...), tf.matmul(...) |
| Session | Execution context binding graph to a device (CPU/GPU) | tf.Session() |
Implementing an Operation
Each op subclasses a common Operation base class and implements the forward computation plus a VJP returning one adjoint per input.
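A hedged sketch of this pattern; `Operation`, `AddOperation`, and `MatMulOperation` are illustrative class names, not a real framework's:

```python
import numpy as np

class Operation:
    def forward(self, *inputs):
        raise NotImplementedError
    def vjp(self, out_bar, *inputs):
        raise NotImplementedError

class AddOperation(Operation):
    def forward(self, a, b):
        return a + b
    def vjp(self, out_bar, a, b):
        # ∂(a+b)/∂a = ∂(a+b)/∂b = I, so the adjoint passes through unchanged
        return out_bar, out_bar

class MatMulOperation(Operation):
    def forward(self, W, x):
        return W @ x
    def vjp(self, out_bar, W, x):
        # W̄ = ȳ xᵀ,  x̄ = Wᵀ ȳ  (the linear-layer VJPs from Section 6)
        return np.outer(out_bar, x), W.T @ out_bar
```

The framework's backward traversal only ever calls `forward` and `vjp`; adding a new op never touches the traversal code.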
Full Training Loop (using AutoGrad)
The training objective is defined once as a graph, autograd machinery adds the gradient nodes, and the update step is then executed repeatedly with fresh data fed in at each iteration.
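The same loop structure can be sketched in plain NumPy, standing in for the symbolic version (the placeholder/Variable/Session plumbing is elided; the model, data, and hand-derived gradient are illustrative):

```python
import numpy as np

# Logistic-regression training loop: forward pass, gradient, parameter update.
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
y = (X[:, 0] - X[:, 1] > 0).astype(float)   # separable toy labels

w = np.zeros(2)                             # "Variable": persists across steps
eta = 0.5

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

for step in range(300):
    p = sigmoid(X @ w)                      # forward pass
    grad_w = X.T @ (p - y) / len(y)         # ∇w of mean cross-entropy
    w -= eta * grad_w                       # SGD update (full batch here)

acc = np.mean((sigmoid(X @ w) > 0.5) == (y > 0.5))
```

In the framework version, the `grad_w` line is produced automatically by reverse-mode AD rather than derived by hand.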
Implementing Session (Execution)
run_session evaluates the requested end node by computing its dependencies in topological order, substituting feed_dict values for placeholders.
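A minimal sketch of that execution loop, assuming a toy graph representation (nodes as plain dicts with `op`, `inputs`, and either `value`, `name`, or `fn`; these field names are hypothetical):

```python
def run_session(end_node, feed_dict):
    order, seen = [], set()

    def visit(node):                    # DFS produces a topological order
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.get('inputs', ()):
            visit(parent)
        order.append(node)

    visit(end_node)
    values = {}
    for node in order:                  # every node computed exactly once
        if node['op'] == 'placeholder':
            values[id(node)] = feed_dict[node['name']]
        elif node['op'] == 'const':
            values[id(node)] = node['value']
        else:                           # apply the op to computed inputs
            args = [values[id(p)] for p in node['inputs']]
            values[id(node)] = node['fn'](*args)
    return values[id(end_node)]

# (a + b) * 2 with a fed at runtime
a = {'op': 'placeholder', 'name': 'a'}
b = {'op': 'const', 'value': 3.0}
s = {'op': 'add', 'inputs': (a, b), 'fn': lambda u, v: u + v}
out = {'op': 'mul', 'inputs': (s, {'op': 'const', 'value': 2.0}),
       'fn': lambda u, v: u * v}
print(run_session(out, {'a': 1.0}))     # (1 + 3) * 2 = 8.0
```

Because the graph is known before execution, a real session can also prune unused nodes, fuse ops, and place subgraphs on different devices before running this loop.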
9. Framework Comparison: PyTorch, TensorFlow, JAX
| Aspect | PyTorch | TensorFlow | JAX | NumPy |
|---|---|---|---|---|
| Paradigm | Dynamic (eager) | Static graph / Eager | Functional transformations | Procedural |
| Autograd | Dynamic comp graph | Static comp graph | Functional (grad/jit) | None |
| Hardware | CPU, GPU, TPU | CPU, GPU, TPU | CPU, GPU, TPU | CPU only |
| Ease of Use | Pythonic | Steeper learning curve | Pythonic + functional | Very easy |
| Parallelism | DataParallel / DDP | tf.distribute | pmap | None |
| Ecosystem | Lightning, TorchVision | TensorBoard, TF Extended | Integrates with NumPy | — |
Dynamic vs. Static Graphs
- PyTorch (dynamic/eager): The computation graph is built on-the-fly as you write Python code. Easy to debug with standard print statements. The graph can change every iteration (e.g., different sequence lengths in an RNN).
- TensorFlow v1 (static): You first define the graph symbolically, then execute it in a session. Enables more aggressive optimization but harder to debug.
- JAX (functional): No explicit graph object. You write pure functions and use transformations (jax.grad, jax.jit, jax.vmap) to get gradients, JIT compilation, and vectorization.
10. Gradient Checking & Debugging
When implementing custom backward passes, numerical gradient checking is your best friend.
Centered Finite Differences
\[\frac{\partial f(x_1, x_2)}{\partial x_1} \approx \frac{f(x_1 + h, x_2) - f(x_1 - h, x_2)}{2h}\]
This is more accurate than the one-sided formula (error is O(h²) vs. O(h)).
Best Practices
- Check at a few random inputs; avoid special points (e.g., exactly zero for ReLU).
- Compute the analytical gradient via your VJP and the numerical gradient via centered differences with h ≈ 1e-5.
- Compare with relative error: |g_num − g_ana| / max(|g_num|, |g_ana|, ε).
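The recipe above fits in a few lines; the test function and step size here are illustrative:

```python
import numpy as np

# Gradient check: analytic gradient vs. centered finite differences.
def numerical_grad(f, x, h=1e-5):
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x); e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2 * h)   # O(h²) centered formula
    return g

f = lambda x: np.sum(x ** 3)       # analytic gradient: 3x²
x = np.array([0.5, -1.2, 2.0])
g_ana = 3 * x ** 2
g_num = numerical_grad(f, x)

rel_err = np.max(np.abs(g_num - g_ana) / np.maximum(np.abs(g_ana), 1e-8))
print(rel_err)                     # should be well below 1e-5
```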
Tip: If your analytical gradient and numerical gradient disagree by more than ~1e-5 (relative error), there’s likely a bug in your VJP implementation. Check edge cases like zero inputs and boundary conditions.
Summary
| Concept | What It Does |
|---|---|
| Computation Graph | DAG representing the program; enables both forward eval and backward differentiation |
| Topological Sort | Defines a valid execution order for the graph |
| Forward-mode AD | Propagates tangents forward; O(n) passes for n inputs |
| Reverse-mode AD | Propagates adjoints backward; O(1) passes for scalar output — the key to efficient training |
| VJP | The primitive operation of reverse AD on tensors; avoids forming the full Jacobian |
| Graph Extension | Modern technique: backward pass builds a new graph, enabling higher-order derivatives |
| DL Framework | Combines: symbolic graph definition → automatic differentiation → optimized execution on hardware |