Automatic Differentiation & Deep Learning Frameworks


1. The Learning Problem & Gradient Descent

Given a training set D = {(xₙ, yₙ)}, the goal is to find model parameters θ that minimize a loss function: \[ \min_{\theta} \mathcal{L}(\theta), \quad \mathcal{L} : \mathbb{R}^d \to \mathbb{R} \] For classification, the standard loss is cross-entropy. PyTorch’s CrossEntropyLoss combines log-softmax and negative log-likelihood in a single op and takes raw logits (before softmax) as input, which is numerically stable.

Why Gradient Descent?

Using a first-order Taylor expansion around the current parameters xₜ: \[ f(x_t + \Delta x) \approx f(x_t) + \Delta x^\top \nabla f \big|_{x_t} \] To make f decrease, we want the second term to be as negative as possible. By Cauchy–Schwarz, for a fixed step length the optimal direction is opposite to the gradient, giving the gradient descent update rule: \[ x_{t+1} = x_t - \eta \cdot \nabla f \big|_{x_t} \] where η is the learning rate hyperparameter. SGD applies this update with a gradient estimated from a mini-batch rather than the full dataset.

SGD Algorithm

set learning rate η
1. initialize θ ← θ₀
2. for epoch = 1 to maxEpoch:
3. for each mini-batch in data:
4. total_g = 0
5. for each (x, y) in batch:
6. compute error: err(f(x; θ), y)
7. compute gradient: g = ∂err/∂θ
8. total_g += g
9. update: θ = θ - η * total_g / N   (N = mini-batch size)
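The loop above can be sketched end to end in NumPy; the linear model, learning rate, and data below are illustrative assumptions, not part of the algorithm:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 3))             # toy inputs (assumption)
true_theta = np.array([1.0, -2.0, 0.5])
y = X @ true_theta                        # noiseless targets

theta = np.zeros(3)                       # 1. initialize θ
eta = 0.1                                 # learning rate η
for epoch in range(100):                  # 2. epochs
    for start in range(0, len(X), 32):    # 3. mini-batches
        xb, yb = X[start:start + 32], y[start:start + 32]
        err = xb @ theta - yb             # 6. error of f(x; θ)
        g = xb.T @ err / len(xb)          # 7-9. gradient, averaged over batch
        theta = theta - eta * g           # 9. SGD update

print(np.round(theta, 3))                 # close to [1, -2, 0.5]
```

Averaging the per-example gradients over the mini-batch before the update matches step 9 of the pseudocode.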

Key Question: How do we efficiently compute ∂ℒ/∂θ for every parameter in an arbitrarily deep network? This is exactly what automatic differentiation solves.


2. Computation Graphs

A computation graph is a directed acyclic graph (DAG) where:

  • Nodes represent variables (inputs, parameters, intermediates) or operations.
  • Directed edges indicate data flow — the inputs to each operation.

Forward Evaluation via Topological Sort

To evaluate the graph:

  1. Put all nodes in an unprocessed queue.
  2. Repeatedly find a node whose inputs are all already computed.
  3. Evaluate it, move it to the processed set.
  4. Repeat until all nodes are evaluated.

This guarantees every node is computed exactly once in a valid dependency order.
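A minimal sketch of this procedure, assuming the graph is given as a dict mapping each node to its operation and input list (leaves carry a value instead of an operation):

```python
import math

def evaluate(graph):
    """Evaluate a DAG given as {node: (function, inputs)}; leaves are (None, value)."""
    values = {}
    pending = set(graph)                             # step 1: unprocessed set
    while pending:
        for node in list(pending):
            fn, args = graph[node]
            if fn is None:                           # leaf: input/constant
                values[node] = args
                pending.remove(node)
            elif all(a in values for a in args):     # step 2: inputs ready
                values[node] = fn(*(values[a] for a in args))  # step 3: evaluate
                pending.remove(node)
    return values

# The worked example from Section 4: y = ln(x1) + x1*x2 - sin(x2)
g = {"x1": (None, 2.0), "x2": (None, 5.0),
     "v3": (math.log, ["x1"]),
     "v4": (lambda a, b: a * b, ["x1", "x2"]),
     "v5": (math.sin, ["x2"]),
     "v6": (lambda a, b: a + b, ["v3", "v4"]),
     "y":  (lambda a, b: a - b, ["v6", "v5"])}
print(round(evaluate(g)["y"], 3))  # 11.652
```

Because the graph is acyclic, the inner scan always finds at least one ready node, so every node is evaluated exactly once.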


3. Differentiation Methods: A Taxonomy

| Method | How It Works | Cost for f: ℝⁿ → ℝ | Practical Use |
| --- | --- | --- | --- |
| Numerical | Evaluate f at perturbed points; finite differences | O(n) forward passes | Gradient checking only — too slow & unstable for training |
| Symbolic | Algebraically apply sum/product/chain rules | Can blow up exponentially | CAS tools (Mathematica). Expression swell is impractical for large models |
| Forward-mode AD | Propagate derivative “tangents” alongside values in one forward pass | O(n) forward passes | Good when n is small (e.g., Hessian-vector products) |
| Reverse-mode AD | Run forward, then propagate “adjoints” backward | O(1) backward passes | The workhorse of deep learning — one pass gives all ∂ℒ/∂θ |

Key Insight: For f: ℝⁿ → ℝ (scalar loss, many parameters), reverse-mode AD computes the full gradient in time proportional to a single forward pass. Forward mode would need n passes.


4. Forward-Mode Automatic Differentiation

Forward-mode AD computes derivatives by carrying a tangent (derivative value) alongside each intermediate value, propagating in the same direction as the forward computation.

Definition

For each intermediate value vᵢ, define its tangent with respect to the chosen input (here x₁): \[ \dot{v}_i = \frac{\partial v_i}{\partial x_1} \]

Worked Example

y = ln(x₁) + x₁·x₂ − sin(x₂), computing ∂y/∂x₁:

# Forward evaluation trace (values)
v1 = x1 # = 2
v2 = x2 # = 5
v3 = ln(v1) # = 0.693
v4 = v1 * v2 # = 10
v5 = sin(v2) # = -0.959
v6 = v3 + v4 # = 10.693
v7 = v6 - v5 # = 11.652 → y

# Forward AD trace (tangents, seeded for ∂/∂x₁)
v1_dot = 1 # seed
v2_dot = 0 # not differentiating w.r.t. x₂
v3_dot = v1_dot / v1 # = 0.5 (chain rule on ln)
v4_dot = v1_dot*v2 + v2_dot*v1 # = 5.0 (product rule)
v5_dot = v2_dot * cos(v2) # = 0.0
v6_dot = v3_dot + v4_dot # = 5.5
v7_dot = v6_dot - v5_dot # = 5.5 → ∂y/∂x₁
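This tangent bookkeeping can be automated with dual numbers: each value carries its tangent, and each operator applies the matching derivative rule. A minimal sketch (only the operations used in the trace are implemented):

```python
import math

class Dual:
    """A value paired with its tangent (forward-mode AD)."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __add__(self, o):
        return Dual(self.val + o.val, self.dot + o.dot)
    def __sub__(self, o):
        return Dual(self.val - o.val, self.dot - o.dot)
    def __mul__(self, o):  # product rule
        return Dual(self.val * o.val, self.dot * o.val + self.val * o.dot)

def ln(d):  return Dual(math.log(d.val), d.dot / d.val)          # chain rule on ln
def sin(d): return Dual(math.sin(d.val), d.dot * math.cos(d.val))

def f(x1, x2):
    return ln(x1) + x1 * x2 - sin(x2)

# Seed ẋ1 = 1, ẋ2 = 0 to get ∂y/∂x1
y = f(Dual(2.0, 1.0), Dual(5.0, 0.0))
print(y.val, y.dot)  # value ≈ 11.652, tangent = 5.5
```

Each arithmetic operation updates the tangent exactly as in the trace above, so the derivative falls out of a single forward evaluation.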

Limitation

To get ∂y/∂x₂, we need to re-run the entire trace with a different seed (v̇₁=0, v̇₂=1). For a network with millions of parameters, this is prohibitively expensive — hence reverse mode.


5. Reverse-Mode Automatic Differentiation

Reverse-mode AD is the engine behind loss.backward() in PyTorch. It computes gradients w.r.t. all inputs in a single backward pass.

Definition

Define the adjoint of each node vᵢ:

\[\bar{v}_i = \frac{\partial y}{\partial v_i}\]

Seed v̄_output = 1, then traverse the graph in reverse topological order.

Worked Example (same function)

# Reverse AD trace (adjoints)
v7_bar = 1 # seed: ∂y/∂y = 1
v6_bar = v7_bar * 1 # ∂v7/∂v6 = 1 (subtraction)
v5_bar = v7_bar * (-1) # ∂v7/∂v5 = −1
v4_bar = v6_bar * 1 # ∂v6/∂v4 = 1 (addition)
v3_bar = v6_bar * 1 # ∂v6/∂v3 = 1

# Nodes with MULTIPLE outgoing edges: sum contributions
v2_bar = v5_bar * cos(v2) + v4_bar * v1 # ≈ 1.716
v1_bar = v4_bar * v2 + v3_bar / v1 # = 5.5

# Result: ∂y/∂x₁ = 5.5, ∂y/∂x₂ ≈ 1.716 (both in one pass!)

Multiple Pathway Rule

When a node vᵢ feeds into multiple downstream nodes, its adjoint is the sum of partial contributions:

\[\bar{v}_i = \sum_{j \in \text{next}(i)} \bar{v}_j \cdot \frac{\partial v_j}{\partial v_i}\]

Reverse AD Algorithm

from collections import defaultdict

def gradient(out):
    node_to_grad = defaultdict(list)
    node_to_grad[out].append(1)                # seed: ∂y/∂y = 1

    for i in reverse_topo_order(out):
        v_bar_i = sum(node_to_grad[i])         # sum partial adjoints

        for k in inputs(i):
            # compute partial adjoint: v̄ᵢ · ∂vᵢ/∂vₖ
            v_k_to_i = v_bar_i * local_grad(i, k)
            node_to_grad[k].append(v_k_to_i)   # accumulate

    return {n: sum(g) for n, g in node_to_grad.items()}  # adjoints of all nodes
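A runnable miniature of the same idea, assuming a small Node class that records each operation's inputs and local gradients during the forward pass:

```python
import math
from collections import defaultdict

class Node:
    def __init__(self, val, parents=()):
        self.val = val
        self.parents = parents  # list of (input node, local gradient ∂self/∂input)
    def __add__(self, o): return Node(self.val + o.val, [(self, 1.0), (o, 1.0)])
    def __sub__(self, o): return Node(self.val - o.val, [(self, 1.0), (o, -1.0)])
    def __mul__(self, o): return Node(self.val * o.val, [(self, o.val), (o, self.val)])

def ln(a):  return Node(math.log(a.val), [(a, 1.0 / a.val)])
def sin(a): return Node(math.sin(a.val), [(a, math.cos(a.val))])

def gradient(out):
    """Seed the output adjoint, then propagate in reverse topological order."""
    order, seen = [], set()
    def topo(n):
        if id(n) not in seen:
            seen.add(id(n))
            for p, _ in n.parents:
                topo(p)
            order.append(n)
    topo(out)
    adjoint = defaultdict(float)
    adjoint[id(out)] = 1.0                      # seed: ∂y/∂y = 1
    for n in reversed(order):
        for p, local in n.parents:              # sum contributions from all consumers
            adjoint[id(p)] += adjoint[id(n)] * local
    return adjoint

x1, x2 = Node(2.0), Node(5.0)
y = ln(x1) + x1 * x2 - sin(x2)
adj = gradient(y)
print(adj[id(x1)], round(adj[id(x2)], 3))  # 5.5  1.716
```

Note how x₁'s adjoint accumulates two contributions (through ln and through the product), exactly as the multiple-pathway rule requires.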

6. Vector-Jacobian Products & Tensor Adjoints

In practice, nodes hold tensors (matrices, vectors), not scalars. We generalize using the Jacobian and VJP.

Jacobian

For y ∈ ℝᵐ, x ∈ ℝⁿ, the Jacobian is (shown here for m = n = 2): \[ J = \frac{\partial y}{\partial x} = \begin{bmatrix} \partial y_1/\partial x_1 & \partial y_1/\partial x_2 \\ \partial y_2/\partial x_1 & \partial y_2/\partial x_2 \end{bmatrix} \] We never explicitly form J (it can be enormous). Instead, we compute the Vector-Jacobian Product (VJP): \[ \bar{x} = J^\top \bar{y} \]

Example: Linear Layer

For y = Wx:

| Gradient | Formula |
| --- | --- |
| w.r.t. input x | x̄ = Wᵀ ȳ |
| w.r.t. weight W | W̄ = ȳ xᵀ |
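Both formulas can be sanity-checked numerically. The sketch below (shapes chosen arbitrarily) compares x̄ = Wᵀȳ against centered finite differences of the scalar ȳ·(Wx):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
x = rng.normal(size=4)
y_bar = rng.normal(size=3)         # upstream adjoint ∂L/∂y

x_bar = W.T @ y_bar                # x̄ = Wᵀ ȳ,  shape (4,)
W_bar = np.outer(y_bar, x)         # W̄ = ȳ xᵀ, shape (3, 4)

# Sanity check against the scalar "loss" L = ȳ·(Wx), whose gradient
# w.r.t. x is exactly x_bar:
h = 1e-6
num = np.array([(y_bar @ (W @ (x + h * e)) - y_bar @ (W @ (x - h * e))) / (2 * h)
                for e in np.eye(4)])
print(np.allclose(num, x_bar))     # True
```

Note that x̄ has the shape of x and W̄ the shape of W; an adjoint always matches the shape of the value it corresponds to.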

Implementation Pattern: Every primitive op ships with both a forward() and a vjp() function. The framework only needs to know these primitives; everything else is handled by graph traversal.

7. Backprop vs. Reverse-Mode AD: Graph Extension

There are two implementation strategies, and the distinction matters for higher-order derivatives.

Backpropagation (1st generation)

  • Runs backward operations directly on the forward graph.
  • Computes adjoint values by hand, node by node.
  • Used in early frameworks: Caffe, cuda-convnet.
  • Cannot naturally support gradient-of-gradient.

Reverse-Mode AD by Graph Extension (modern)

  • Constructs new graph nodes for each adjoint computation.
  • The backward pass itself becomes a computation graph.
  • You can run another backward pass on top → higher-order derivatives.
  • Used by PyTorch, JAX, modern TensorFlow.
# The gradient function returns a computation graph, not just a value.
# We can compose and differentiate again:

grad_fn = grad(loss_fn)      # first-order gradient (a graph)
hessian_fn = grad(grad_fn)   # second-order — differentiate the gradient!

# JAX makes this particularly clean. grad-of-grad works for scalar inputs;
# for vector inputs, use jax.hessian (jacfwd composed with jacrev):
import jax
d2 = jax.grad(jax.grad(loss_fn))   # scalar-input second derivative
H = jax.hessian(loss_fn)           # full Hessian for vector inputs

Key Takeaway: Reverse-mode AD by graph extension is strictly more powerful than classic backprop. It’s the reason modern frameworks can compute Hessians, higher-order gradients, and meta-learning objectives.


8. Putting It Together: A Deep Learning Framework

A DL framework needs to be: expressive (any network architecture), productive (hide CUDA, auto-differentiate), and efficient (scale to large models, auto hardware acceleration).

Design Principles

  1. Define the program as a symbolic dataflow graph with placeholders, variables, and operations.
  2. Execute an optimized version of that graph on available devices.

Basic Components (TensorFlow-style)

| Component | Role | Example |
| --- | --- | --- |
| Placeholder | Input data fed at runtime | tf.placeholder(tf.float32, (1, 784)) |
| Variable | Stateful node for parameters; persists across executions | tf.Variable(tf.zeros((100,))) |
| Constant | Static data | tf.constant([[1, 2], [3, 4]]) |
| Operation | Math ops; must define forward + backward | tf.nn.relu(...), tf.matmul(...) |
| Session | Execution context binding graph to a device (CPU/GPU) | tf.Session() |

Implementing an Operation

class AddOperation(Operation):
    """Define the Add operation: output = a + b"""
    def __init__(self, a, b):
        super().__init__([a, b])  # a, b are input nodes

    def forward(self, a, b):
        return a + b

    def backward(self, upstream_grad):
        # ∂(a+b)/∂a = 1, ∂(a+b)/∂b = 1
        return upstream_grad, upstream_grad

Full Training Loop (using AutoGrad)

# Define training objective
def objective(params, iter):
    idx = batch_indices(iter)
    return -log_posterior(params, train_images[idx], train_labels[idx], L2_reg)

# Get gradient via autograd
objective_grad = grad(objective)

# Neural network forward pass
def neural_net_predict(params, inputs):
    """params: list of (W, b) tuples. inputs: (N x D) matrix."""
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        inputs = np.tanh(outputs)
    return outputs - logsumexp(outputs, axis=1, keepdims=True)

# Log posterior = log prior + log likelihood
def log_posterior(params, inputs, targets, L2_reg):
    log_prior = -L2_reg * l2_norm(params)
    log_lik = np.sum(neural_net_predict(params, inputs) * targets)
    return log_prior + log_lik

Implementing Session (Execution)

def run_session(end_node, feed_dict):
    """Execute the computation graph."""
    for node in topological_sort(end_node):
        if isinstance(node, Placeholder):
            node.value = feed_dict[node]
        elif isinstance(node, (Variable, Constant)):
            pass  # value already set
        elif isinstance(node, Operation):
            inputs = [n.value for n in node.input_nodes]
            node.value = node.forward(*inputs)
    return end_node.value

9. Framework Comparison: PyTorch, TensorFlow, JAX

| Aspect | PyTorch | TensorFlow | JAX | NumPy |
| --- | --- | --- | --- | --- |
| Paradigm | Dynamic (eager) | Static graph / Eager | Functional transformations | Procedural |
| Autograd | Dynamic comp graph | Static comp graph | Functional (grad/jit) | None |
| Hardware | CPU, GPU, TPU | CPU, GPU, TPU | CPU, GPU, TPU | CPU only |
| Ease of Use | Pythonic | Steeper learning curve | Pythonic + functional | Very easy |
| Parallelism | DataParallel / DDP | tf.distribute | pmap | None |
| Ecosystem | Lightning, TorchVision | TensorBoard, TF Extended | Integrates with NumPy | |

Dynamic vs. Static Graphs

  • PyTorch (dynamic/eager): The computation graph is built on-the-fly as you write Python code. Easy to debug with standard print statements. The graph can change every iteration (e.g., different sequence lengths in an RNN).
  • TensorFlow v1 (static): You first define the graph symbolically, then execute it in a session. Enables more aggressive optimization but harder to debug.
  • JAX (functional): No explicit graph object. You write pure functions and use transformations (jax.grad, jax.jit, jax.vmap) to get gradients, JIT compilation, and vectorization.

10. Gradient Checking & Debugging

When implementing custom backward passes, numerical gradient checking is your best friend.

Centered Finite Differences

\[\frac{\partial f(x_1, x_2)}{\partial x_1} \approx \frac{f(x_1 + h,\, x_2) - f(x_1 - h,\, x_2)}{2h}\]

This is more accurate than the one-sided formula (error is O(h²) vs. O(h)).
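The gap is easy to see numerically, e.g. for f(x) = eˣ at x = 1, where the true derivative is e:

```python
import math

f, x, h = math.exp, 1.0, 1e-4
true = math.exp(1.0)

one_sided = (f(x + h) - f(x)) / h            # error O(h)
centered  = (f(x + h) - f(x - h)) / (2 * h)  # error O(h²)

print(abs(one_sided - true))   # roughly 1.4e-4
print(abs(centered - true))    # roughly 4.5e-9
```

With h = 1e-4, the centered formula is about five orders of magnitude more accurate, consistent with the O(h) vs. O(h²) error terms.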

Best Practices

# Gradient checking recipe:
# 1. Use double precision (float64) — float32 errors dominate
# 2. Pick a small h (e.g., 1e-6)
# 3. Evaluate f twice per parameter (at x + h and at x − h)
# 4. Compare with your analytical gradient

h = 1e-6

def numerical_grad(f, x, i):
    """Compute ∂f/∂xᵢ via centered differences."""
    x_plus = x.copy();  x_plus[i] += h
    x_minus = x.copy(); x_minus[i] -= h
    return (f(x_plus) - f(x_minus)) / (2 * h)

# More general: pick a random direction δ, check the directional derivative
# δᵀ ∇f(θ) ≈ (f(θ + εδ) − f(θ − εδ)) / (2ε)

Tip: If your analytical gradient and numerical gradient disagree by more than ~1e-5 (relative error), there’s likely a bug in your VJP implementation. Check edge cases like zero inputs and boundary conditions.
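The random-direction check mentioned above can be written out for a function whose gradient is known in closed form; the quadratic below is an illustrative choice:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))

def f(theta):
    return 0.5 * theta @ A.T @ A @ theta   # quadratic objective

def grad_f(theta):
    return A.T @ A @ theta                 # analytical gradient

theta = rng.normal(size=5)
delta = rng.normal(size=5)                 # random direction δ
eps = 1e-6

numeric  = (f(theta + eps * delta) - f(theta - eps * delta)) / (2 * eps)
analytic = delta @ grad_f(theta)           # δᵀ ∇f(θ)
rel_err = abs(numeric - analytic) / max(abs(numeric), abs(analytic))
print(rel_err < 1e-5)                      # True
```

One random direction tests all parameters at once, so this check costs only two extra function evaluations regardless of dimension.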


Summary

| Concept | What It Does |
| --- | --- |
| Computation Graph | DAG representing the program; enables both forward eval and backward differentiation |
| Topological Sort | Defines a valid execution order for the graph |
| Forward-mode AD | Propagates tangents forward; O(n) passes for n inputs |
| Reverse-mode AD | Propagates adjoints backward; O(1) passes for scalar output — the key to efficient training |
| VJP | The primitive operation of reverse AD on tensors; avoids forming the full Jacobian |
| Graph Extension | Modern technique: backward pass builds a new graph, enabling higher-order derivatives |
| DL Framework | Combines: symbolic graph definition → automatic differentiation → optimized execution on hardware |