Lec16 Systems for MoE Models

This lecture addresses the systems challenges behind Mixture-of-Experts (MoE) models: how sparse expert architectures scale model capacity without proportionally increasing compute, what communication patterns expert parallelism introduces, and how systems like GShard, DeepSpeed-MoE, and DeepSeek V3 solve routing, load balancing, and inference efficiency at scale.

Motivation: Why Sparse Models?

Dense models are expensive to scale. Training a 175B OPT model takes 992 A100 GPUs for 56 days; PaLM at 540B needs 6144 TPU v4 chips for 57 days. Per-token compute grows roughly linearly with model size (about 2 FLOPs per parameter per token in the forward pass) because every parameter participates in every forward pass.

| Model | Model Size | Hardware | Time to Train |
|---|---|---|---|
| Megatron-LM GPT-2 | 8.3B | 512 V100 GPUs | 9.2 days |
| OPT | 175B | 992 A100 GPUs | 56 days |
| MT-NLG | 530B | 2200 A100 GPUs | 60 days |
| PaLM | 540B | 6144 TPU v4 chips | 57 days |

Sparse models offer a different scaling path: increase total parameters while keeping per-token compute constant. Mixture-of-Experts is the dominant sparse architecture for LLMs. An MoE model reaches the quality of a dense model with substantially less pretraining compute, and MoE inference is faster than a dense model with the same total parameter count because only a subset of parameters is activated per token.

Transformer Mixture-of-Experts Architecture

The core idea of Transformer MoE is simple: replace the single dense FFN in each Transformer block with multiple small expert FFNs and a gating network (router) that selects which expert(s) to activate for each input token.

This is distinct from the classical Mixture-of-Experts learning algorithm, which learns a weighted average of predictor models. In Transformer MoE, the experts are sub-networks within the model itself, and the router makes a hard or soft selection per token.

Switch Transformer

The Switch Transformer (Fedus et al., JMLR 2022) is the canonical MoE architecture: based on the router's output, each token is routed to exactly one FFN expert. It is a special case of the general top-k gating formulation:

  • The gating network computes a probability distribution over experts:

\[G_\sigma(x) = \text{Softmax}(x \cdot W_g)\]

  • The output is a weighted combination of expert outputs:

\[y = \sum_{i=1}^{n} G(x)_i E_i(x)\]

  • Top-k gating keeps only the top-\(k\) expert scores and sets the rest to \(-\infty\):

\[\text{KeepTopK}(v, k)_i = \begin{cases} v_i & \text{if } v_i \text{ is in the top } k \text{ elements of } v \\ -\infty & \text{otherwise} \end{cases}\]

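\[G(x) = \text{Softmax}(\text{KeepTopK}(H(x), k))\]

where \(H(x)\) denotes the router logits (for example \(H(x) = x \cdot W_g\)). Because \(\text{Softmax}\) maps \(-\infty\) to a weight of exactly 0, only the \(k\) selected experts contribute to \(y\), so the remaining experts never need to be evaluated for that token.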

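In the Switch Transformer specifically, \(k=1\): each token goes to exactly one expert. This simplifies routing and reduces communication overhead compared to top-2 designs. With \(k=1\), applying the softmax after KeepTopK would always yield a weight of 1 and give the router no gradient, so Switch instead computes \(p(x) = \text{Softmax}(x \cdot W_g)\) over all experts and scales the selected expert's output by its probability: \(y = p_i(x)\, E_i(x)\) with \(i = \arg\max_j p_j(x)\).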
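The gating equations can be traced in a few lines of plain Python. This is a minimal sketch, not any library's implementation: the list `logits` stands in for \(H(x)\), and the toy "experts" are hypothetical scalar functions chosen only for illustration. `top_k_gate` follows \(\text{Softmax}(\text{KeepTopK}(H(x), k))\), while `switch_gate` is a Switch-style top-1 gate that keeps the selected expert's probability from a softmax over all experts.

```python
import math

def softmax(logits):
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits] # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

def keep_top_k(logits, k):
    # KeepTopK: keep the k largest logits and set all others to -inf
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    return [v if i in top else -math.inf for i, v in enumerate(logits)]

def top_k_gate(logits, k):
    # G(x) = Softmax(KeepTopK(H(x), k)); unselected experts get weight exactly 0
    return softmax(keep_top_k(logits, k))

def switch_gate(logits):
    # Switch-style top-1: softmax over ALL experts, then keep the argmax and its probability
    probs = softmax(logits)
    best = max(range(len(probs)), key=lambda i: probs[i])
    return best, probs[best]

def moe_output(x, experts, gates):
    # y = sum_i G(x)_i * E_i(x); experts with zero weight are skipped entirely
    return sum(g * expert(x) for g, expert in zip(gates, experts) if g > 0.0)

logits = [2.0, 1.0, 0.5, -1.0]                                   # hypothetical router logits
experts = [lambda x, s=s: s * x for s in (1.0, 2.0, 3.0, 4.0)]   # toy scalar "experts"
gates = top_k_gate(logits, k=2)                                  # nonzero only for experts 0 and 1
y = moe_output(1.0, experts, gates)
```

Note how the top-k gate renormalizes over the two survivors, whereas `switch_gate` returns a probability below 1, which is what lets gradients reach the router when \(k = 1\).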

Shared vs. Routed Experts

A refinement introduced in DeepSpeed-MoE and later adopted by DeepSeek is the shared-routed expert design:

| Expert Type | Behavior | Purpose |
|---|---|---|
| Shared expert | Always activated for every token | Captures common knowledge shared across all inputs |
| Routed expert | Conditionally activated by the router | Captures token-specific, specialized knowledge |

The shared expert acts as a fixed FFN that every token passes through, while the routed experts provide conditional capacity. The final output is the sum of the shared expert output and the weighted routed expert outputs:

\[\text{output} = \text{SharedFFN}(h) + \sum_{i \in \text{TopK}} g_i \cdot \text{Expert}_i(h)\]

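where the routing weights \(g_i\) come from the router scores. Two common variants are \(\text{Softmax}(\text{TopK}(h_t \cdot W))\), a softmax over only the selected logits (as in Mixtral), and \(\text{TopK}(\text{Softmax}(h_t \cdot W))\), the selected entries of a softmax over all experts (as in DeepSeekMoE).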
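A minimal sketch of the shared-plus-routed combination, using the \(\text{Softmax}(\text{TopK}(\cdot))\) variant for the routing weights. The function name `shared_plus_routed` and the scalar "experts" are hypothetical illustrations; a real layer would operate on hidden-state vectors.

```python
import math

def shared_plus_routed(h, shared_ffn, routed_experts, logits, k):
    """output = SharedFFN(h) + sum_{i in TopK} g_i * Expert_i(h)."""
    # Select the k experts with the largest router logits
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over only the selected logits, so the k routing weights sum to 1
    m = max(logits[i] for i in top)
    exps = {i: math.exp(logits[i] - m) for i in top}
    z = sum(exps.values())
    routed = sum((exps[i] / z) * routed_experts[i](h) for i in top)
    # The shared expert is applied to every token with an implicit weight of 1
    return shared_ffn(h) + routed
```

For example, with `shared_ffn = lambda h: 10 * h`, routed experts `h`, `2h`, `3h`, logits `[0.0, 5.0, 5.0]` and `k=2`, the input `h = 1.0` yields `10 + 0.5 * 2 + 0.5 * 3 = 12.5`.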

What Experts Learn

Encoder experts tend to specialize in token groups or shallow concepts: punctuation, proper nouns, sentinel tokens, verbs, visual descriptions, counting and numbers. Decoder experts exhibit less specialization. In multilingual setups, experts do not specialize by language, likely because load balancing spreads tokens from every language across all experts and so prevents language-level clustering.

Activated Experts Differ Across Layers

Different tokens activate different experts at each layer. Token “red” might route to Expert 3 at layer 1 and Expert 1 at layer 2, while “fox” routes to Expert 2 and Expert N at the same layers. The routing pattern is dynamic and learned end-to-end.

Parameters of MoE

MoE models have fewer total parameters than the naive multiplication suggests. Mixtral 8x7B has 47B parameters, not 56B, because only the FFN layers are replicated as experts — attention layers, embeddings, and layer norms are shared across all experts.
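The 47B figure can be checked with simple arithmetic. The sketch below assumes Mixtral 8x7B's published configuration as we understand it (hidden size 4096, SwiGLU FFN width 14336, 32 layers, 32 query heads with 8 key/value heads, vocabulary 32000, untied input and output embeddings, top-2 routing); the function name is hypothetical. Only the FFN term is multiplied by the number of experts.

```python
def mixtral_param_count(d_model=4096, d_ff=14336, n_layers=32, n_heads=32,
                        n_kv_heads=8, vocab=32000, n_experts=8, top_k=2):
    """Return (total, active-per-token) parameter counts under the assumed config."""
    head_dim = d_model // n_heads
    # Q and O projections are d x d; K and V are d x (kv_heads * head_dim) (grouped-query attention)
    attn = 2 * d_model * d_model + 2 * d_model * n_kv_heads * head_dim
    expert = 3 * d_model * d_ff                            # SwiGLU: W1, W2, W3
    router = d_model * n_experts if n_experts > 1 else 0   # a dense model has no router
    norms = 2 * d_model                                    # two RMSNorm weight vectors per block
    shared_per_layer = attn + router + norms               # replicated once, not per expert
    embed = 2 * vocab * d_model + d_model                  # input embedding, LM head, final norm
    total = n_layers * (shared_per_layer + n_experts * expert) + embed
    active = n_layers * (shared_per_layer + top_k * expert) + embed
    return total, active

total, active = mixtral_param_count()
print(f"total = {total / 1e9:.1f}B, active per token = {active / 1e9:.1f}B")
```

This prints `total = 46.7B, active per token = 12.9B`. Setting `n_experts=1, top_k=1` gives about 7.2B for the dense counterpart, so eight full copies would be about 58B; sharing attention and embeddings is what brings the MoE down to roughly 47B.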

Overfitting Risk

MoE models are prone to overfitting on small datasets. With many more parameters than a comparable dense model, the sparse model can memorize small training sets while the dense counterpart generalizes better. On larger tasks, MoE performs well. This is a key consideration when fine-tuning MoE models on downstream tasks.

Training and Inference for MoE

Expert Parallelism (GShard)

Expert parallelism (Lepikhin et al., ICLR 2021) is the standard approach to distributing MoE models across devices:

  • Each expert lives on one worker device — experts are split across GPUs
  • All other components (attention, embedding, layer norm) are replicated on every device
  • All-to-all communication dispatches tokens to the correct expert device and collects results

The forward pass follows four steps:

  1. Route: each device runs its local router to determine which expert each token should go to
  2. All-to-All Dispatch: tokens are sent to the device hosting their assigned expert
  3. Expert Compute: each device runs its local expert FFN on the received tokens
  4. All-to-All Combine: expert outputs are sent back to the originating device
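The four steps can be simulated in a single process to make the data movement concrete. This is a sketch under simplifying assumptions: each "device" is just a list index, `route` stands in for every device's local router, experts are plain functions, and the all-to-all is modeled as transposing a table of per-destination send buffers. A real system would use a collective such as `torch.distributed.all_to_all` instead.

```python
def expert_parallel_forward(tokens_per_device, route, experts, expert_to_device):
    """Simulate route -> all-to-all dispatch -> expert compute -> all-to-all combine.

    tokens_per_device: one list of tokens per device
    route: token -> expert id (stand-in for each device's local router)
    experts: dict expert id -> function (the expert FFN)
    expert_to_device: dict expert id -> device hosting that expert
    """
    n = len(tokens_per_device)
    # 1. Route: fill send[src][dst] with (slot, expert, token) triples
    send = [[[] for _ in range(n)] for _ in range(n)]
    for src, tokens in enumerate(tokens_per_device):
        for slot, tok in enumerate(tokens):
            e = route(tok)
            send[src][expert_to_device[e]].append((slot, e, tok))
    # 2. All-to-all dispatch: device dst receives send[src][dst] from every src
    recv = [[send[src][dst] for src in range(n)] for dst in range(n)]
    # 3. Expert compute: each device runs its local experts on what it received
    computed = [[[(slot, experts[e](tok)) for slot, e, tok in recv[dst][src]]
                 for src in range(n)] for dst in range(n)]
    # 4. All-to-all combine: results travel back to the originating device
    back = [[computed[dst][src] for dst in range(n)] for src in range(n)]
    outputs = [[None] * len(tokens) for tokens in tokens_per_device]
    for src in range(n):
        for chunk in back[src]:
            for slot, y in chunk:
                outputs[src][slot] = y   # restore the original token order
    return outputs
```

The slot index carried with each token is what allows step 4 to put results back in their original order, mirroring the bookkeeping a real dispatch/combine implementation has to do.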

GShard also introduces interleaving: MoE layers alternate with standard dense FFN layers (every other layer uses MoE). This balances the communication overhead with the capacity benefits of sparse routing.

Token Computation Path in MoE

During inference with expert parallelism, each token’s computation path crosses device boundaries at the MoE layers. The token starts on its “home” device for embedding and attention, gets dispatched to a remote device for expert computation, then returns. Each device maintains its own KV-caches for the attention layers, but expert computation requires cross-device communication at every MoE layer.

Load Balancing in MoE Training

Without intervention, the router can collapse into sending all tokens to a few popular experts, leaving most experts undertrained. The load balancing loss prevents this:

\[L_{ExpBal} = \alpha_1 M \sum_{i=1}^{\text{\#experts}} f_i P_i\]

where:

  • \(f_i = \frac{\text{\#tokens to expert } i}{\text{\#tokens}}\) is the fraction of tokens routed to expert \(i\)
  • \(P_i = \frac{1}{\text{\#tokens}} \sum_{t=1}^{\text{\#tokens}} s_{i,t}\) is the average routing weight for expert \(i\)
  • \(M\) is the number of experts
  • \(\alpha_1\) is a hyperparameter controlling the balance loss strength

The product \(f_i \cdot P_i\) penalizes experts that receive both a high fraction of tokens and high routing probability. This loss is differentiable through \(P_i\) (the routing weights), encouraging the router to spread tokens more evenly.
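A small computation shows why the factor \(M\) is there: under perfectly uniform top-1 routing \(f_i = P_i = 1/M\), so the loss equals \(\alpha_1\) regardless of the number of experts, and it grows as routing collapses. This is a minimal sketch; `alpha=0.01` is only an illustrative value.

```python
def switch_balance_loss(assignments, router_probs, num_experts, alpha=0.01):
    """L = alpha * M * sum_i f_i * P_i for top-1 routing.

    assignments: chosen expert id per token (not differentiable)
    router_probs: per-token list of routing weights s_{i,t} (differentiable in a real model)
    """
    T = len(assignments)
    f = [sum(1 for a in assignments if a == i) / T for i in range(num_experts)]
    P = [sum(p[i] for p in router_probs) / T for i in range(num_experts)]
    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, P))

# Balanced: 4 tokens spread over 4 experts with uniform routing weights
uniform = switch_balance_loss([0, 1, 2, 3], [[0.25] * 4] * 4, num_experts=4)
# Collapsed: every token goes to expert 0 with full confidence
collapsed = switch_balance_loss([0, 0, 0, 0], [[1.0, 0.0, 0.0, 0.0]] * 4, num_experts=4)
```

Here `uniform` equals \(\alpha_1 = 0.01\) while `collapsed` is four times larger (\(\alpha_1 M = 0.04\)).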

MoE Inference Challenges

MoE inference performance depends on three factors:

  • Overall model size — determines total memory footprint
  • Number of activated experts — determines per-token compute
  • Overall memory bandwidth — since only a subset of experts activate, inference is often memory-bandwidth-bound

The default implementation keeps all experts in GPU memory, which requires large memory capacity even though only a few experts are active per token.
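A quick estimate shows why batching does not make MoE decoding cheap on memory traffic. Assuming, purely for illustration, that each token picks \(k\) of \(N\) experts uniformly at random and independently of other tokens, a given expert is untouched by one token with probability \(1 - k/N\), so a batch of \(B\) tokens reads about \(N(1 - (1 - k/N)^B)\) distinct experts. Real routers are skewed, so treat this as a rough guide.

```python
def expected_distinct_experts(num_experts, k, batch_tokens):
    # Probability that one specific expert is chosen by none of the batch's tokens
    p_unused = (1 - k / num_experts) ** batch_tokens
    return num_experts * (1 - p_unused)

for b in (1, 8, 32, 128):
    print(b, round(expected_distinct_experts(256, 8, b), 1))
```

With 256 experts and 8 active per token, a single token touches 8 experts, but a batch of 32 tokens already touches about 163 of them, so the weights of most experts must be streamed from memory every decoding step.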

Optimizing MoE Inference

The system design goal for MoE inference is to minimize the critical data path per device and maximize aggregate memory bandwidth. Key optimization strategies:

  • Group and route tokens with the same critical data path together to reduce data access per device and achieve maximum aggregate bandwidth
  • Optimize communication scheduling with parallelism coordination
  • Optimize transformer and MoE-related kernels to improve per-device performance

Expert Parallelism and Tensor Parallelism (DeepSpeed-MoE)

DeepSpeed-MoE (Rajbhandari et al., 2022) combines multiple parallelism strategies for MoE models:

  • Expert Parallelism / Expert Slicing: group all tokens assigned to the same experts under the same critical data path, and parallelize processing of token groups with different critical paths across different devices
  • Tensor Parallelism / Tensor Slicing: partition the non-expert parameters (attention) across devices, usually within a node
  • Data Parallelism: replicate the above setup across additional device groups

This multi-dimensional parallelism approach handles expert and non-expert parameters differently, applying the most appropriate parallelism strategy to each.
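The split between expert and non-expert parameters shows up in how ranks are grouped. The sketch below covers only expert parallelism plus data parallelism (tensor slicing is left out) and assumes one common layout: contiguous ranks form an expert-parallel group, and ranks at the same position in different groups hold the same experts. The function names are hypothetical and the layout is an assumption, not a quote of DeepSpeed's code.

```python
def moe_rank_groups(world_size, ep_size):
    if world_size % ep_size != 0:
        raise ValueError("world_size must be divisible by ep_size")
    # Expert-parallel groups: ranks that exchange tokens with all-to-all
    ep_groups = [list(range(s, s + ep_size)) for s in range(0, world_size, ep_size)]
    # Expert-data-parallel groups: ranks holding the same experts; expert gradients
    # are all-reduced only within these groups
    expert_dp_groups = [list(range(pos, world_size, ep_size)) for pos in range(ep_size)]
    # Non-expert parameters (attention, embeddings) are replicated on every rank
    dense_dp_group = list(range(world_size))
    return ep_groups, expert_dp_groups, dense_dp_group

def local_expert_ids(rank, num_experts, ep_size):
    # Each EP group shards all experts evenly across its ep_size ranks
    per_rank = num_experts // ep_size
    start = (rank % ep_size) * per_rank
    return list(range(start, start + per_rank))
```

With 4 ranks, `ep_size=2`, and 8 experts, the expert-parallel groups are `[0, 1]` and `[2, 3]`, ranks 0 and 2 each hold experts 0 to 3, and ranks 1 and 3 each hold experts 4 to 7.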

Optimizing MoE Kernels

MoE-specific kernel optimizations include:

  • Fusing the gating function into a single kernel — avoids multiple kernel launches for the softmax, top-k, and dispatch operations
  • Dense token-to-expert mapping table — enables efficient batched computation

These optimizations achieve over 6x reduction in MoE kernel-related latency.
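The idea behind a dense token-to-expert mapping can be illustrated with a counting sort: after routing, tokens are reordered so that each expert's tokens form one contiguous slice that a single batched matmul can consume. This is a CPU sketch of the concept with a hypothetical function name; it does not reproduce DeepSpeed's fused CUDA kernels.

```python
def build_dispatch_table(expert_ids, num_experts):
    """Counting-sort token indices by their assigned expert.

    Returns (order, offsets): tokens order[offsets[e]:offsets[e + 1]] go to expert e.
    """
    counts = [0] * num_experts
    for e in expert_ids:
        counts[e] += 1
    # Prefix sums give each expert's start position in the packed buffer
    offsets = [0] * (num_experts + 1)
    for e in range(num_experts):
        offsets[e + 1] = offsets[e] + counts[e]
    cursor = offsets[:-1]            # copy: next free slot for each expert
    order = [0] * len(expert_ids)
    for tok, e in enumerate(expert_ids):
        order[cursor[e]] = tok
        cursor[e] += 1
    return order, offsets

order, offsets = build_dispatch_table([2, 0, 2, 1, 0], num_experts=3)
```

Here `order` is `[1, 4, 3, 0, 2]` and `offsets` is `[0, 2, 3, 5]`: tokens 1 and 4 go to expert 0, token 3 to expert 1, and tokens 0 and 2 to expert 2.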

Optimizing All-to-All Communication

Expert parallelism requires all-to-all communication between all expert parallel devices, and the latency increases linearly with the number of devices. Two key optimizations:

  • Hierarchical all-to-all communication: reduces communication hops by first communicating within a node, then across nodes
  • Parallelism-coordinated communication: schedules communications based on the model’s parallelism strategy to minimize overhead
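The benefit of the hierarchical scheme can be seen by counting hops. DeepSpeed-MoE describes the reduction from \(O(p)\) hops for a flat all-to-all over \(p\) devices to \(O(G + p/G)\) when \(G\) GPUs per node first exchange locally and then across the \(p/G\) nodes. The sketch below just evaluates those two expressions; it ignores bandwidth differences between intra- and inter-node links.

```python
def all_to_all_hops(num_devices, gpus_per_node=None):
    """Rough per-device hop count for flat vs. hierarchical all-to-all."""
    if gpus_per_node is None:
        return num_devices                        # flat: O(p) exchanges
    if num_devices % gpus_per_node != 0:
        raise ValueError("num_devices must be a multiple of gpus_per_node")
    # Intra-node all-to-all over G GPUs, then inter-node all-to-all over p / G nodes
    return gpus_per_node + num_devices // gpus_per_node

for p in (64, 256):
    print(p, all_to_all_hops(p), all_to_all_hops(p, gpus_per_node=8))
```

With 8 GPUs per node, 64 devices drop from 64 hops to 16, and 256 devices from 256 to 40, so the gap widens as the expert-parallel degree grows.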

MoE Training and Inference Results

DeepSpeed-MoE demonstrates that a 1.3B+MoE-128 model (52B total parameters) achieves the same validation loss as a 6.7B dense model while training at 5x higher throughput (372 vs. 70 samples/sec on 128 A100 GPUs).

For inference, DeepSpeed achieves significantly lower latency and higher throughput compared to PyTorch baseline across all GPU counts (8, 16, 32, 64 GPUs), with throughput scaling nearly linearly with GPU count.

DeepSeek MoE

Fine-Grained Expert Design

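DeepSeek MoE (Dai et al., 2024) introduces fine-grained experts: each original FFN expert is split into \(m\) smaller experts, each with \(1/m\) of the intermediate dimension, yielding \(mN\) total experts (where \(N\) is the original number of experts). The model then activates \(mK\) experts per token (where \(K\) is the original number of activated experts). Because each expert is \(m\) times smaller, per-token compute matches the coarse design, but the number of possible expert combinations grows sharply, which allows more flexible routing and sharper specialization.

The architecture combines always-active shared experts with the fine-grained routed experts: the layer output is the shared experts' output plus the gate-weighted sum of the top-k routed experts' outputs.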
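The gain in routing flexibility is easy to quantify with binomial coefficients. In this sketch, \(N\) experts with \(K\) active are each split into \(m\) pieces; the width `d_ff=4096` is a hypothetical value used only to show that per-token compute is unchanged.

```python
from math import comb

def fine_grained(num_experts, num_active, d_ff, m):
    """Split each expert into m pieces of width d_ff / m and activate m times as many."""
    return num_experts * m, num_active * m, d_ff // m

coarse = comb(16, 2)                                # N = 16 experts, K = 2 active
n, k, width = fine_grained(16, 2, d_ff=4096, m=4)   # -> 64 experts, 8 active, width 1024
fine = comb(n, k)
# Activated FFN width per token is the same: 2 * 4096 == 8 * 1024
same_compute = 2 * 4096 == k * width
```

The coarse layer can choose among 120 expert combinations, while the fine-grained one can choose among 4,426,165,368, at the same activated FFN width per token.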

DeepSeek V3 MoE (670B)

DeepSeek V3 is a roughly 670B-parameter MoE model (671B total parameters, about 37B activated per token) with the following configuration:

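| Parameter | Value |
|---|---|
| Vocabulary size | 129,280 |
| Hidden dimension | 7,168 |
| Number of layers | 61 |
| Dense (non-MoE) layers at the bottom | 3 |
| Attention heads | 128 |
| Dense FFN intermediate dimension | 18,432 |
| Per-expert (MoE) intermediate dimension | 2,048 |
| Shared experts per MoE layer | 1 |
| Routed experts per MoE layer | 256 |
| Routed experts activated per token | 8 |
| Expert groups | 8 |
| Groups each token may route to (limited groups) | 4 |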

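The architecture uses Multi-head Latent Attention (MLA) instead of standard multi-head attention. Each MoE layer has 1 shared expert FFN that processes every token, plus 256 routed experts of which 8 are activated per token. DeepSeek V3 computes sigmoid affinity scores \(s_{i,t} = \text{Sigmoid}(h_t \cdot W)_i\), keeps the top 8, and normalizes the selected scores so they sum to 1 (the original DeepSeekMoE and DeepSeek-V2 use \(\text{TopK}(\text{Softmax}(h_t \cdot W), k)\) instead). The routed experts are also partitioned into 8 groups, and each token may only pick experts from its 4 highest-scoring groups, which caps how many devices its dispatch has to reach.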
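Group-limited routing can be sketched at toy scale. Assumptions to flag: the group score here is simply the highest affinity in each group (a simplification; the released DeepSeek-V3 code uses its own group-scoring rule and an additional balancing bias), and the sizes in the example are tiny. The sketch uses sigmoid affinities normalized over the selected experts, as described for DeepSeek V3.

```python
import math

def group_limited_route(logits, n_groups, topk_groups, top_k):
    """Pick top_k experts, but only from the topk_groups best groups.

    Returns {expert_id: gating weight}, with weights normalized to sum to 1.
    """
    scores = [1.0 / (1.0 + math.exp(-v)) for v in logits]    # sigmoid affinities
    size = len(scores) // n_groups
    # Simplified group score (an assumption): best affinity inside each group
    group_score = [max(scores[g * size:(g + 1) * size]) for g in range(n_groups)]
    kept = sorted(range(n_groups), key=lambda g: group_score[g], reverse=True)[:topk_groups]
    # Only experts inside the kept groups are eligible
    candidates = [i for g in kept for i in range(g * size, (g + 1) * size)]
    chosen = sorted(candidates, key=lambda i: scores[i], reverse=True)[:top_k]
    total = sum(scores[i] for i in chosen)
    return {i: scores[i] / total for i in chosen}

# 16 experts in 4 groups; each token may use at most 1 group here to make the effect visible
weights = group_limited_route([5, 0, 0, 0, 4, 4, 4, 4, 3, 3, 3, 3, 0, 0, 0, 0],
                              n_groups=4, topk_groups=1, top_k=4)
```

In this example, experts 4 to 7 have higher scores than experts 1 to 3, yet all four chosen experts come from group 0, because only the single best group is allowed. That restriction is what bounds the number of devices a token's dispatch touches.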

Each expert uses the SwiGLU activation:

\[\text{FFN}_{SwiGLU}(x) = (\text{Swish}(x \cdot W_1) \odot (x \cdot W_2)) \cdot W_3\]
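The SwiGLU formula maps directly to code. This plain-Python sketch uses lists as vectors and row-major lists of rows as matrices; in DeepSeek V3's routed experts, \(W_1\) and \(W_2\) would be \(7168 \times 2048\) and \(W_3\) would be \(2048 \times 7168\), per the table above. Swish is taken with \(\beta = 1\) (also known as SiLU).

```python
import math

def matvec(x, W):
    # Row vector x (length d) times matrix W (d x h, list of rows) -> length-h vector
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def swish(v):
    return v / (1.0 + math.exp(-v))                  # Swish with beta = 1 (SiLU)

def swiglu_ffn(x, W1, W2, W3):
    gate = [swish(v) for v in matvec(x, W1)]         # Swish(x W1)
    up = matvec(x, W2)                               # x W2
    hidden = [g * u for g, u in zip(gate, up)]       # elementwise product
    return matvec(hidden, W3)                        # project back to the hidden dimension

I = [[1.0, 0.0], [0.0, 1.0]]                         # identity weights for a readable example
out = swiglu_ffn([1.0, 2.0], I, I, I)
```

With identity weights the output reduces to \(x_j^2 \, \sigma(x_j)\) per component, so `out` is \([\sigma(1), 4\sigma(2)]\), where \(\sigma\) is the logistic sigmoid.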

The first 3 layers are dense (standard FFN), and the remaining 58 layers use MoE. This design lets the lowest layers build general representations before introducing sparse routing.

Load Balancing in DeepSeek MoE

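DeepSeekMoE (Dai et al., 2024) uses two levels of load balancing loss. DeepSeek-V2 adopts the same expert- and device-level losses, while DeepSeek V3 relies mainly on an auxiliary-loss-free scheme that adds a per-expert bias to the routing scores, keeping only a small sequence-wise balance loss: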

Expert-Level Balance Loss (to avoid routing collapse):

\[L_{ExpBal} = \alpha_1 \sum_{i=1}^{\text{\#experts}} f_i P_i\]

where \(f_i = \frac{\text{\#experts}}{\text{\#activated\_experts}} \cdot \frac{\text{\#tokens to expert } i}{\text{\#tokens}}\) and \(P_i = \frac{1}{\text{\#tokens}} \sum_{t=1}^{\text{\#tokens}} s_{i,t}\) (the average routing weight).

Device-Level Balance Loss (to balance computation across devices):

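\[L_{DevBal} = \alpha_2 \sum_{j=1}^{D} f'_j P'_j\]

where the routed experts are partitioned into \(D\) device groups \(\mathcal{E}_1, \dots, \mathcal{E}_D\), \(f'_j = \frac{1}{|\mathcal{E}_j|} \sum_{i \in \mathcal{E}_j} f_i\) is the average \(f_i\) over the experts in group \(j\), and \(P'_j = \sum_{i \in \mathcal{E}_j} P_i\) is the sum of \(P_i\) over the experts in group \(j\).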

The two-level design ensures both individual expert utilization and cross-device computational balance.
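Both losses can be computed from the same per-token routing data. This is a minimal sketch of the DeepSeekMoE-style formulas; the values `alpha1=0.01` and `alpha2=0.05` are arbitrary illustrative choices, not the paper's settings. The \(N/K\) factor in \(f_i\) plays the same role as the factor \(M\) in the Switch loss: under uniform routing each \(f_i = 1\), so the expert-level loss equals \(\alpha_1\).

```python
def deepseekmoe_balance_losses(topk_sets, probs, groups, k, alpha1=0.01, alpha2=0.05):
    """Expert-level and device-level balance losses.

    topk_sets: per token, the set of selected routed experts (size k)
    probs: per token, routing weights s_{i,t} over all N routed experts
    groups: list of expert-id lists, one per device group
    """
    T, N = len(probs), len(probs[0])
    # f_i = (N / K) * fraction of tokens that selected expert i
    f = [N / k * sum(1 for s in topk_sets if i in s) / T for i in range(N)]
    P = [sum(p[i] for p in probs) / T for i in range(N)]
    exp_bal = alpha1 * sum(fi * pi for fi, pi in zip(f, P))
    # f'_j averages f over group j; P'_j sums P over group j
    dev_bal = alpha2 * sum(
        (sum(f[i] for i in g) / len(g)) * sum(P[i] for i in g) for g in groups)
    return exp_bal, dev_bal

groups = [[0, 1], [2, 3]]                       # two devices, two experts each
balanced = deepseekmoe_balance_losses([{0, 1}, {2, 3}], [[0.25] * 4] * 2, groups, k=2)
skewed = deepseekmoe_balance_losses([{0, 1}, {0, 1}], [[0.5, 0.5, 0.0, 0.0]] * 2, groups, k=2)
```

In the balanced case both losses sit at their baseline (`0.01`, `0.05`); when every token lands on the first device, both double (`0.02`, `0.1`).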

DeepSeek Libraries for MoE

DeepSeek has open-sourced two libraries to accelerate MoE:

  • DeepEP: a communication library tailored for MoE and expert parallelism, optimizing the all-to-all dispatch and combine operations
  • EPLB (Expert Parallelism Load Balancer): handles dynamic load balancing across expert parallel devices
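One sub-problem such a load balancer faces is placing experts with uneven loads onto devices. The sketch below is a generic greedy heuristic (longest-processing-time packing), shown only to illustrate the goal; it is not EPLB's algorithm, which according to its README also replicates heavily loaded experts before packing them. The expert loads in the example are hypothetical.

```python
import heapq

def pack_experts(expert_loads, num_devices):
    """Assign each expert to the currently least-loaded device, heaviest experts first."""
    heap = [(0.0, d) for d in range(num_devices)]      # (current load, device id)
    placement = {d: [] for d in range(num_devices)}
    for e in sorted(range(len(expert_loads)), key=lambda e: expert_loads[e], reverse=True):
        load, d = heapq.heappop(heap)                  # least-loaded device
        placement[d].append(e)
        heapq.heappush(heap, (load + expert_loads[e], d))
    return placement

loads = [10, 9, 2, 2, 1, 1]                            # hypothetical tokens per expert
placement = pack_experts(loads, num_devices=2)
device_loads = [sum(loads[e] for e in placement[d]) for d in range(2)]
```

A naive contiguous placement (experts 0 to 2 on one device, 3 to 5 on the other) gives device loads of 21 and 4; the greedy packing yields 13 and 12.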

DeepSpeed MoE Code Example

DeepSpeed provides a simple API for adding MoE layers:

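```python
import torch
import deepspeed
from deepspeed.moe.layer import MoE

WORLD_SIZE = 4      # total number of ranks in the job
EP_WORLD_SIZE = 2   # ranks per expert-parallel group
EXPERTS = 8         # experts per MoE layer -> 8 / 2 = 4 local experts per rank

deepspeed.init_distributed()  # MoE needs torch.distributed to build its expert groups

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        fc3 = torch.nn.Linear(84, 84)  # expert template, deep-copied for each local expert
        self.fc3 = MoE(hidden_size=84, expert=fc3, num_experts=EXPERTS,
                       ep_size=EP_WORLD_SIZE, k=1)
        self.fc4 = torch.nn.Linear(84, 10)

    def forward(self, x):
        # MoE returns (output, auxiliary load-balancing loss, per-expert token counts)
        x, l_aux, _ = self.fc3(x)
        return self.fc4(x), l_aux  # add l_aux to the training loss
```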

The MoE class wraps an existing nn.Module (the expert template) and handles replication, gating, and expert-parallel communication. Key parameters include:

  • hidden_size: model hidden dimension (input and output dimension)
  • expert: the base expert module (e.g., an MLP or torch.nn.Linear)
  • num_experts: total number of experts per layer
  • ep_size: number of ranks in the expert parallel group
  • k: top-k gating value (1 or 2)
  • capacity_factor: controls expert capacity at training time
  • drop_tokens: whether to drop tokens exceeding expert capacity
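To see what `capacity_factor` and `drop_tokens` control, here is a minimal sketch of Switch-style expert capacity: each expert accepts at most \(\lceil (\text{tokens}/\text{experts}) \times \text{capacity\_factor} \rceil\) tokens, and overflow tokens are dropped (in Switch they skip the expert and pass through the residual connection). The helper names are hypothetical, and this sketch drops overflow in arrival order; real implementations may choose which tokens to drop differently.

```python
import math

def expert_capacity(num_tokens, num_experts, capacity_factor):
    # Switch-style capacity: (tokens / experts) * capacity_factor, rounded up
    return math.ceil(num_tokens / num_experts * capacity_factor)

def apply_capacity(expert_ids, num_experts, capacity):
    """Return (kept, dropped) token indices; tokens beyond capacity overflow."""
    load = [0] * num_experts
    kept, dropped = [], []
    for tok, e in enumerate(expert_ids):
        if load[e] < capacity:
            load[e] += 1
            kept.append(tok)
        else:
            dropped.append(tok)
    return kept, dropped

routing = [0, 0, 0, 1, 2, 3, 3, 0]        # hypothetical top-1 assignments for 8 tokens
cap = expert_capacity(len(routing), 4, capacity_factor=1.0)   # -> 2
kept, dropped = apply_capacity(routing, 4, cap)
```

With a capacity factor of 1.0, expert 0 receives four tokens but can hold only two, so tokens 2 and 7 are dropped; raising the factor to 2.0 gives a capacity of 4 and drops nothing, at the cost of larger padded buffers per expert.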

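Internally, the Experts class deep-copies the expert module for each local expert, sets param.allreduce = False on expert parameters (so the regular data-parallel gradient all-reduce skips them; their gradients are instead reduced only among the ranks that hold copies of the same experts), and chunks its input so that each local expert processes its own slice: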

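```python
import copy
import torch
from torch import nn

class Experts(nn.Module):
    def __init__(self, expert, num_local_experts=1, expert_group_name=None):
        super().__init__()
        self.num_local_experts = num_local_experts  # used by forward()
        # One independent copy of the expert template per local expert
        self.deepspeed_experts = nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts)])
        for expert in self.deepspeed_experts:
            for param in expert.parameters():
                # Skip the global data-parallel all-reduce; reduce only within
                # the expert-data-parallel group identified by group_name
                param.allreduce = False
                param.group_name = expert_group_name

    def forward(self, inputs):
        # inputs: (ep_size, num_local_experts, capacity, hidden); one slice per expert
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        expert_outputs = [expert(chunk)
                          for chunk, expert in zip(chunks, self.deepspeed_experts)]
        return torch.cat(expert_outputs, dim=1)
```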

Lec16 Takeaways

  • MoE scales capacity without scaling compute: by activating only a subset of experts per token, MoE models can have trillions of parameters while maintaining the per-token FLOPS of a much smaller dense model
  • Expert parallelism is the natural distribution strategy: split experts across devices, replicate everything else, and use all-to-all communication to dispatch tokens — but all-to-all becomes the bottleneck at scale
  • Load balancing is critical: without auxiliary losses, routers collapse to a few popular experts, wasting capacity and creating compute imbalance across devices
  • DeepSeek V3 pushes fine-grained design: 256 routed experts with only 8 activated, shared experts for common knowledge, load balancing across experts and devices, and MLA attention in a 670B model
  • Multi-dimensional parallelism (EP + TP + DP) is necessary at scale, with each strategy applied to the part of the model it fits best

Final Summary

| Topic | Key Idea |
|---|---|
| Sparse vs. Dense | MoE increases model capacity without proportionally increasing per-token compute |
| Switch Transformer | Routes each token to one expert via top-1 gating with softmax over expert scores |
| Shared vs. Routed Experts | Shared experts capture common knowledge; routed experts specialize per token |
| Expert Parallelism (GShard) | Split experts across devices, replicate non-expert params, use all-to-all communication |
| Load Balancing Loss | Auxiliary loss penalizing uneven token distribution across experts to prevent routing collapse |
| MoE Inference | Memory-bandwidth-bound; optimize by grouping tokens, fusing kernels, hierarchical all-to-all |
| DeepSpeed-MoE | Combines expert parallelism, tensor parallelism, and data parallelism with optimized kernels |
| DeepSeek V3 (670B) | Fine-grained experts (256 routed, 8 activated), shared expert, expert/device load balancing, SwiGLU |
| DeepEP / EPLB | DeepSeek libraries for optimized MoE communication and dynamic expert load balancing |

Key takeaway: MoE is the dominant approach to scaling LLMs beyond the compute budget of dense models. The systems challenge shifts from parallelizing a uniform computation to orchestrating dynamic, token-dependent routing across devices — requiring expert parallelism with all-to-all communication, load balancing to prevent routing collapse, and multi-dimensional parallelism to handle both expert and non-expert parameters efficiently.

References

  1. Fedus et al. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.” JMLR 2022.
  2. Lepikhin et al. “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding.” ICLR 2021.
  3. Rajbhandari et al. “DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale.” 2022.
  4. Dai et al. “DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models.” 2024.
  5. DeepSeek-V3 inference code: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
  6. DeepEP: https://github.com/deepseek-ai/DeepEP
  7. EPLB: https://github.com/deepseek-ai/EPLB

This post is based on lecture materials from CMU 11-868 LLM Systems by Lei Li (System for MOE Models, Lec16).