MoE Emergence

Checkpoints from a research project studying expert specialization in Mixture-of-Experts models. I fine-tuned GPT-2 small on three domains -- code, math, and prose -- to see whether experts naturally specialize by domain when given the right routing incentives.

Short answer: they do. MoE beats the dense baseline by 3.6% overall and 14% on math, with zero expert collapse across 10,000 training steps. Two ablation runs confirmed that load balancing loss is essential (without it, one expert captures 73.6% of tokens by step 500) and that top-2 routing provides negligible improvement over top-1.


1. Results

Main comparison

| Metric | Dense Baseline | MoE (top-1) | Delta |
|---|---|---|---|
| eval/loss | 2.157 | 2.080 | -3.6% |
| loss_code | 1.554 | 1.521 | -2.1% |
| loss_math | 2.023 | 1.740 | -14.0% |
| loss_prose | 3.485 | 3.541 | +1.6% |
| perplexity | 8.64 | 7.91 | -8.4% |

Math benefits the most from expert routing. Prose is the one domain where dense wins; diverse web text doesn't lend itself to clean expert specialization. The MoE model crossed the dense baseline at step ~3,600 (36% of training).

Ablations

| Run | What it tests | Result |
|---|---|---|
| No-LB ablation | Remove load balancing loss (lb_coef=0.0) | Expert collapse at step 500. A single expert handles 73.6% of tokens. Z-loss alone doesn't prevent it. |
| Top-2 directional | Route to 2 experts instead of 1 | eval/loss = 2.077 vs top-1's 2.080, a 0.14% difference. Not worth 2x expert compute. |

2. Files

dense-baseline/
├── final-model.safetensors    # 622 MB -- dense GPT-2, 124M params
├── final-model.json           # metadata sidecar (config, metrics)
├── ckpt-step-4999.pt          # 1.4 GB -- full resume checkpoint
└── metrics.jsonl              # per-step training + eval metrics

moe-main/
├── final-model.safetensors    # 1.1 GB -- MoE GPT-2, 257M params (8 experts × 4 layers)
├── final-model.json           # metadata sidecar
├── ckpt-step-9999.pt          # 2.9 GB -- full resume checkpoint
└── metrics.jsonl

no-lb-ablation/
├── final-model.safetensors    # 1.1 GB -- collapsed MoE model at step 500
├── best-model.safetensors     # 1.1 GB -- best eval loss (step 400, pre-collapse)
├── ckpt-step-500.pt           # 2.9 GB -- full resume checkpoint
├── config.json, run_summary.json
└── metrics.jsonl

top2-main-10k/
├── final-model.safetensors    # 1.2 GB -- top-2 MoE model at step 9999
├── best-model.safetensors     # 1.2 GB -- best eval loss (step 8000)
├── ckpt-step-9999.pt          # 2.9 GB -- full resume checkpoint
├── config.json, run_summary.json
└── metrics.jsonl

The .safetensors files are the trained model weights. The .pt files contain the full training state for resuming runs (optimizer, LR scheduler, RNG states). The .json sidecars store architecture config and final eval metrics.


3. Usage

Clone the source repo and install dependencies:

git clone https://github.com/sumitdotml/moe-emergence.git
cd moe-emergence
uv sync

Run inference with a trained checkpoint:

# MoE model
uv run python moe_emergence/gpt2_inference.py \
  --checkpoint checkpoints/moe-main/final-model \
  --prompt "def fibonacci(n):" \
  --sample --temperature 0.8

# Dense baseline
uv run python moe_emergence/gpt2_inference.py \
  --checkpoint checkpoints/dense-baseline/final-model \
  --prompt "The meaning of life is"

The inference script reads the .json sidecar to detect mode (dense vs MoE) and architecture config automatically.
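If you script around the checkpoints yourself, the sidecar can be read directly. A minimal sketch, assuming the sidecar shares the checkpoint's stem (e.g. final-model.json next to final-model.safetensors); the key names "config" and "n_experts" are illustrative placeholders, not necessarily the repo's actual schema:

```python
import json
from pathlib import Path

def detect_mode(checkpoint_path: str) -> str:
    """Guess dense vs MoE from the metadata sidecar.

    Looks for a .json file with the same stem as the checkpoint.
    The "config"/"n_experts" keys are hypothetical stand-ins for
    whatever the repo actually stores.
    """
    sidecar = Path(checkpoint_path).with_suffix(".json")
    config = json.loads(sidecar.read_text()).get("config", {})
    return "moe" if config.get("n_experts", 0) > 1 else "dense"
```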

To resume training from a checkpoint:

uv run python -m moe_emergence.train \
  --preset moe-main --run-name moe-main \
  --device cuda \
  --resume checkpoints/moe-main/ckpt-step-9999.pt

4. Architecture

The dense baseline is standard GPT-2 small (124M parameters, 12 transformer layers).

The MoE model takes GPT-2 small and replaces layers 8-11 with MoE layers. Each MoE layer has 8 experts -- deep copies of the original GPT-2 MLP, warm-started from pretrained weights -- and a learned router with top-1 routing. Total: 257M parameters.
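The warm-start amounts to copying the pretrained MLP once per expert and perturbing each copy slightly. A PyTorch sketch (my reconstruction, not the repo's exact code; `mlp` stands for one pretrained GPT-2 MLP block, and the noise scale is illustrative):

```python
import copy
import torch
import torch.nn as nn

def make_experts(mlp: nn.Module, n_experts: int = 8,
                 init_noise: float = 1e-3) -> nn.ModuleList:
    """Warm-start experts as perturbed deep copies of the pretrained MLP.

    Tiny noise breaks the symmetry between the copies so the router has
    something to differentiate; init_noise here is an illustrative value.
    """
    experts = nn.ModuleList(copy.deepcopy(mlp) for _ in range(n_experts))
    with torch.no_grad():
        for expert in experts:
            for p in expert.parameters():
                p.add_(init_noise * torch.randn_like(p))
    return experts
```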

Routing uses the Straight-Through Estimator. Forward pass routes to one expert with weight 1.0, backward pass flows gradients through the soft probability from the router.
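A minimal PyTorch sketch of that STE gate (the standard trick, not necessarily the repo's exact implementation): the forward value is the hard one-hot gate, while the gradient flows through the soft softmax probabilities.

```python
import torch
import torch.nn.functional as F

def ste_top1_gate(router_logits: torch.Tensor) -> torch.Tensor:
    """Top-1 routing gate with a straight-through estimator.

    router_logits: (tokens, n_experts). Returns a (tokens, n_experts)
    gate that is exactly one-hot in the forward pass but whose gradient
    is that of the softmax probabilities.
    """
    probs = router_logits.softmax(dim=-1)
    hard = F.one_hot(probs.argmax(dim=-1), probs.shape[-1]).to(probs.dtype)
    # forward: hard one-hot; backward: gradient of probs (hard is detached
    # implicitly because probs - probs.detach() carries the graph)
    return hard + probs - probs.detach()
```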

| Component | Detail |
|---|---|
| Base model | GPT-2 small (124M) |
| MoE layers | 8, 9, 10, 11 |
| Experts per layer | 8 |
| Routing | Top-1, STE |
| Expert init | deepcopy(original_mlp) + tiny noise |
| Load balance loss | 0.01 × n_experts × Σ(f_i × P_i) |
| Z-loss | 0.001 × mean(logsumexp(logits)²) |
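Plugging in those formulas, both auxiliary losses are a few lines. A pure-Python sketch over per-token router outputs (in the real model these would be batched tensor ops):

```python
import math

def load_balance_loss(assignments, probs, n_experts, lb_coef=0.01):
    """lb_coef * n_experts * sum_i(f_i * P_i), where f_i is the fraction
    of tokens routed to expert i and P_i is the mean router probability
    for expert i. A perfectly uniform split gives the minimum, lb_coef."""
    n_tokens = len(assignments)
    f = [assignments.count(i) / n_tokens for i in range(n_experts)]
    P = [sum(p[i] for p in probs) / n_tokens for i in range(n_experts)]
    return lb_coef * n_experts * sum(fi * Pi for fi, Pi in zip(f, P))

def z_loss(logits_per_token, z_coef=0.001):
    """z_coef * mean(logsumexp(logits)^2): penalizes large router logits."""
    lse = [math.log(sum(math.exp(x) for x in logits))
           for logits in logits_per_token]
    return z_coef * sum(z * z for z in lse) / len(lse)
```

Note how the No-LB collapse shows up here: if one expert takes every token, f and P both concentrate on it and the load balance term grows, which is exactly the pressure that was missing with lb_coef=0.0.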

5. Training

All models trained on ~6.6M tokens across three domains, balanced to equal token counts:

| Domain | Source | Size |
|---|---|---|
| Code | CodeParrot-clean (Python) | 10 MB |
| Math | MathQA (allenai) | 10 MB |
| Prose | C4 English (allenai) | 10 MB |
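One simple way to get that equal balance (a sketch of the idea; the repo's actual preprocessing may differ) is to truncate every domain's token stream to the shortest domain's length:

```python
def balance_domains(domain_tokens: dict) -> dict:
    """Truncate each domain's token list to the shortest domain's length,
    so every domain contributes the same number of tokens."""
    n = min(len(toks) for toks in domain_tokens.values())
    return {domain: toks[:n] for domain, toks in domain_tokens.items()}
```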

Training config:

| Parameter | Dense | MoE (top-1) | MoE (top-2) | No-LB |
|---|---|---|---|---|
| Max steps | 5,000 | 10,000 | 10,000 | 2,000 (early-stopped at 500) |
| Batch size | 8 | 8 | 8 | 8 |
| Block size | 512 | 512 | 512 | 512 |
| Learning rate | 5e-5 | 5e-5 | 5e-5 | 5e-5 |
| lb_coef | — | 0.01 | 0.01 | 0.0 |
| noise_std | — | 0.1 | 0.1 | 0.0 |
| Hardware | 1× RTX 4090 | 1× RTX 4090 | 1× RTX 4090 | 1× RTX 4090 |
| Wall time | ~30 min | ~85 min | ~48 min | ~5 min |

Total GPU cost for all 4 runs: ~$2.79 (including setup/idle overhead).


6. W&B

Training curves are on Weights & Biases:


7. Links


License

MIT. See the source repo for details. Third-party dataset licenses are documented in THIRD-PARTY-NOTICES.md.
