Multi-NPU Scaling for Transformer Models: Opportunities and Limitations
January 9, 2026
Thoughts prompted by the Groq LPU news: how can multiple NPUs be used to scale Transformer models? A look at the opportunities and limitations of parallelization strategies such as head parallelism, data parallelism, and pipeline parallelism.
Transformer on Chip
Preface
The recent Groq LPU news got me interested in multi-NPU architectures. This is my first research note on the topic, covering some basic concepts.
Disclaimer: This article may contain errors; it is primarily a personal study note. More thorough research and corrections will follow. Feedback is welcome!
Reference: Groq’s LPU Architecture
Core Question: Can we scale up to support large models with multiple NPUs?
1. Introduction
Why Multi-Chip Scaling?
Single-chip transformer accelerators face hard limits:
- On-chip memory capacity (weights, activations, KV cache)
- Fixed datapath widths (d_model, d_head, sequence length, context window)
Multi-chip parallelism via SPMD (Single Program, Multiple Data) becomes necessary for larger/deeper models.
Two Fundamental Constraints
Capacity: Can the model be decomposed so each chip processes only a subset while staying within single-chip limits?
Compatibility: Does the model’s execution graph conform to the hardware’s primitives and datapath?
2. Typical NPU Hardware Constraints
When designing multi-NPU systems for transformers, we must consider typical hardware limitations:
Common Constraints:
- On-chip memory capacity - Limited space for weights, activations, and KV cache
- Maximum sequence length - Upper bound on tokens processable in a single pass
- Attention window size - Local attention window limits for efficiency
- Fixed datapath widths - Predetermined dimensions for d_model, d_head, and number of heads
Key Architectural Characteristics:
- Local windowed attention - Many NPUs support only local attention windows for hardware efficiency
- Fixed datapath widths - No dynamic reshape or routing capabilities
- Static execution model - Limited or no dynamic control flow
- Specialized compute units - Built-in softmax or other operators with fixed input sizes
3. Multi-Chip Scaling Methods
3.1 Head Parallelism
Strategy: Distribute attention heads across chips.
Example (32 heads, 4 chips):
Chip 0: heads 0-7 (all layers)
Chip 1: heads 8-15 (all layers)
Chip 2: heads 16-23 (all layers)
Chip 3: heads 24-31 (all layers)
Communication: 1 all-reduce per layer (output projection)
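To make the sharding concrete, here is a minimal NumPy sketch (not NPU code): chips are simulated as loop iterations, each computing attention for its own head group using only its slice of the Q/K/V and output-projection weights, and the single per-layer all-reduce appears as the final sum of partial output projections. The function name `head_parallel_attention` and the toy dimensions are illustrative.

```python
# Minimal sketch of head parallelism; "chips" are simulated as loop iterations.
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def head_parallel_attention(x, W_q, W_k, W_v, W_o, num_heads=8, num_chips=4):
    """x: [seq, d_model]; each weight: [d_model, d_model]; heads split evenly over chips."""
    seq, d_model = x.shape
    d_head = d_model // num_heads
    heads_per_chip = num_heads // num_chips

    partial_outputs = []
    for chip in range(num_chips):
        lo = chip * heads_per_chip * d_head          # this chip's dimension range
        hi = (chip + 1) * heads_per_chip * d_head
        # Each chip stores only its columns of W_q/W_k/W_v and its rows of W_o.
        q, k, v = x @ W_q[:, lo:hi], x @ W_k[:, lo:hi], x @ W_v[:, lo:hi]
        out = np.zeros((seq, hi - lo))
        for h in range(heads_per_chip):              # attention within this head group
            s, e = h * d_head, (h + 1) * d_head
            scores = softmax(q[:, s:e] @ k[:, s:e].T / np.sqrt(d_head))
            out[:, s:e] = scores @ v[:, s:e]
        partial_outputs.append(out @ W_o[lo:hi, :])  # partial output projection

    # The one all-reduce per layer: sum the partial output projections.
    return np.sum(partial_outputs, axis=0)

d_model = 64
x = np.random.randn(10, d_model)
W_q, W_k, W_v, W_o = (np.random.randn(d_model, d_model) for _ in range(4))
y = head_parallel_attention(x, W_q, W_k, W_v, W_o)   # shape (10, 64)
```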
3.2 Data Parallelism
Strategy: Replicate full model on each chip; process different batch samples.
Example (full model, 4 chips, batch size = 16):
Chip 0: full model (all 32 layers) → processes batch samples 0-3
Chip 1: full model (all 32 layers) → processes batch samples 4-7
Chip 2: full model (all 32 layers) → processes batch samples 8-11
Chip 3: full model (all 32 layers) → processes batch samples 12-15
Model weights: Identical copies on all chips
Data flow: Independent per chip
Communication:
- Inference: ZERO (each chip independent)
- Training: All-reduce gradients after backward pass
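A minimal NumPy sketch of the inference case, with chips simulated as list entries and `toy_forward` standing in for the full 32-layer model: every chip holds identical weights, runs independently on its batch slice, and no communication is needed.

```python
# Data parallelism at inference time; chips simulated as list entries.
import numpy as np

def toy_forward(weights, batch):
    return np.tanh(batch @ weights)                  # stand-in for the full model

d_model, batch_size, num_chips = 128, 16, 4
weights = np.random.randn(d_model, d_model)          # identical copy on every chip
full_batch = np.random.randn(batch_size, d_model)

per_chip_batches = np.split(full_batch, num_chips)   # 4 samples per chip
per_chip_outputs = [toy_forward(weights, b) for b in per_chip_batches]  # zero communication

outputs = np.concatenate(per_chip_outputs)           # same result as one big forward pass
assert np.allclose(outputs, toy_forward(weights, full_batch))
```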
3.3 Pipeline Parallelism
Strategy: Shard by layers; each chip processes a subset of layers sequentially.
Example (32 layers, 4 chips):
Chip 0: Embedding + layers 0-7 → forward → pass to Chip 1
Chip 1: layers 8-15 → forward → pass to Chip 2
Chip 2: layers 16-23 → forward → pass to Chip 3
Chip 3: layers 24-31 + LM head → output
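A minimal NumPy sketch of the forward pass: each pipeline stage is a group of placeholder layers, and the activation hand-off between stages is just a variable passed between loop iterations; in hardware this hand-off is a point-to-point transfer between chips.

```python
# Pipeline parallelism: 32 layers split into 4 stages of 8 layers each.
import numpy as np

num_layers, num_chips, d_model = 32, 4, 64
layers = [np.random.randn(d_model, d_model) * 0.1 for _ in range(num_layers)]
stages = [layers[i * 8:(i + 1) * 8] for i in range(num_chips)]   # 8 layers per "chip"

x = np.random.randn(10, d_model)          # activations entering the pipeline
for chip, stage in enumerate(stages):
    for W in stage:
        x = np.tanh(x @ W)                # stand-in for one transformer block
    # Here the activations would be sent to chip + 1 over the chip-to-chip link.
print(x.shape)                            # (10, 64) after all four stages
```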
3.4 Hidden-Dimension Parallelism
Strategy: Shard d_model across chips; slicing cuts across head boundaries (different from head parallelism).
Example (num_heads = 2, d_head = 1024, d_model = 2048, 4 chips):
Head structure:
Head 0: dims [0:1023]
Head 1: dims [1024:2047]
Hidden-dimension slicing (CUTS ACROSS HEADS):
Chip 0: dims [0:511] ← first half of Head 0 (all 32 layers)
Chip 1: dims [512:1023] ← second half of Head 0 (all 32 layers)
Chip 2: dims [1024:1535] ← first half of Head 1 (all 32 layers)
Chip 3: dims [1536:2047] ← second half of Head 1 (all 32 layers)
Model weights: Column-wise partitioned (W_q, W_k, W_v, W_o sliced by hidden dim)
Data flow: All chips process same tokens, different hidden dimensions
Communication per layer:
1. LayerNorm: All-reduce SUM (global mean)
2. LayerNorm: All-reduce SUM (global variance)
3. Within each head: All-reduce to combine partial results
4. Cross heads: All-reduce to combine multi-head outputs
Total: 3-4 all-reduce ops per layer
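The communication pattern becomes clearer with a minimal NumPy sketch of a single sharded matrix multiply: each chip holds one 512-dim slice of the hidden vector plus the matching rows of the weight, so its local product is only partial, and an all-reduce SUM across chips recovers the full result. This is a simplified illustration of the pattern, not the full per-layer schedule listed above.

```python
# Hidden-dimension parallelism: the full product is the SUM of per-chip partial products.
import numpy as np

d_model, num_chips = 2048, 4
shard = d_model // num_chips             # 512 dims per chip

x = np.random.randn(d_model)             # full hidden vector (kept only for the check)
W = np.random.randn(d_model, d_model)    # full weight matrix (kept only for the check)

partials = []
for chip in range(num_chips):
    lo, hi = chip * shard, (chip + 1) * shard
    x_shard = x[lo:hi]                   # this chip's hidden dims
    W_shard = W[lo:hi, :]                # matching rows of the weight
    partials.append(x_shard @ W_shard)   # partial product of shape [d_model]

y = np.sum(partials, axis=0)             # the all-reduce SUM across chips
assert np.allclose(y, x @ W)             # identical to the unsharded matmul
```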
3.5 Sequence Parallelism
Strategy: Shard by sequence length; each chip processes a subset of tokens.
Example (8192 tokens, 4 chips, 2048 tokens per chip):
Chip 0: tokens [0:2047] → output_0
Chip 1: tokens [2048:4095] → output_1
Chip 2: tokens [4096:6143] → output_2
Chip 3: tokens [6144:8191] → output_3
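A minimal NumPy sketch of the sharding itself, with the four token chunks held as separate arrays: token-wise operations (LayerNorm, MLP) run independently per chunk, whereas full self-attention would additionally require exchanging K/V between chips (not shown).

```python
# Sequence parallelism: 8192 tokens split into 2048-token chunks, one per chip.
import numpy as np

seq_len, num_chips, d_model = 8192, 4, 64
tokens = np.random.randn(seq_len, d_model)

chunks = np.split(tokens, num_chips)                 # chunk i lives on chip i
per_chip_outputs = [np.tanh(c) for c in chunks]      # stand-in for token-wise layers

output = np.concatenate(per_chip_outputs)            # reassemble along the sequence axis
assert output.shape == (seq_len, d_model)
```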
4. Appendix: All-Reduce Operations
What is All-Reduce?
Definition: Each chip has a partial value → All chips end up with the same combined result.
Example (4 chips, SUM operation):
Initial state:
Chip 0: 10
Chip 1: 20
Chip 2: 30
Chip 3: 40
After all-reduce SUM:
All chips: 100 (sum of all values)
Common operations: SUM, MAX, MIN, AVG
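The semantics can be stated in a few lines of Python; chips are simulated as list entries, and the helper name `all_reduce` is just for illustration.

```python
def all_reduce(chip_values, op=sum):
    """Every chip contributes a partial value; every chip receives the same reduced result."""
    reduced = op(chip_values)            # e.g. SUM across all chips
    return [reduced] * len(chip_values)  # identical copy on every chip

print(all_reduce([10, 20, 30, 40]))       # -> [100, 100, 100, 100]
print(all_reduce([10, 20, 30, 40], max))  # MAX variant -> [40, 40, 40, 40]
```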
LayerNorm All-Reduce Example
When d_model is sharded across chips, LayerNorm requires global statistics:
Problem: d_model = 2048 split across 4 chips (512 dims each)
Step 1: Each chip computes local sum
Step 2: All-reduce SUM → get global_sum
Step 3: Compute global_mean = global_sum / 2048
Step 4: Each chip normalizes its shard using global_mean
(Repeat for variance)
Total: 2 all-reduce ops per LayerNorm (mean + variance)
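A NumPy sketch of these four steps, with the shards held as separate arrays and the two all-reduce SUMs replaced by plain Python sums over the per-chip statistics:

```python
# LayerNorm over a hidden vector sharded 512 dims per chip (simulated with arrays).
import numpy as np

d_model, num_chips, eps = 2048, 4, 1e-5
x = np.random.randn(d_model)
shards = np.split(x, num_chips)                       # 512 dims per chip

local_sums = [s.sum() for s in shards]                # Step 1: local sums
global_mean = sum(local_sums) / d_model               # Steps 2-3: all-reduce SUM, then divide

local_sq_sums = [((s - global_mean) ** 2).sum() for s in shards]
global_var = sum(local_sq_sums) / d_model             # all-reduce SUM for the variance

normalized = [(s - global_mean) / np.sqrt(global_var + eps) for s in shards]  # Step 4

# Matches the unsharded LayerNorm (without learnable scale/bias).
assert np.allclose(np.concatenate(normalized),
                   (x - x.mean()) / np.sqrt(x.var() + eps))
```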
All-Gather
Definition: Each chip has a different slice of data → All chips end up with the complete concatenated data.
Example (4 chips):
Initial state (each chip has different data slice):
Chip 0: [10, 20]
Chip 1: [30, 40]
Chip 2: [50, 60]
Chip 3: [70, 80]
After all-gather:
Chip 0: [10, 20, 30, 40, 50, 60, 70, 80] ← everyone gets full data
Chip 1: [10, 20, 30, 40, 50, 60, 70, 80]
Chip 2: [10, 20, 30, 40, 50, 60, 70, 80]
Chip 3: [10, 20, 30, 40, 50, 60, 70, 80]
Use case in transformers: Gathering distributed hidden dimensions
Hidden-dimension parallel (d_model = 2048, 4 chips):
Before all-gather:
Chip 0: hidden dims [0:511]
Chip 1: hidden dims [512:1023]
Chip 2: hidden dims [1024:1535]
Chip 3: hidden dims [1536:2047]
After all-gather:
All chips: complete hidden vector [0:2047]
Why needed? Some operations require the full hidden vector
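In the same list-of-arrays simulation, all-gather is just a concatenation that every chip receives:

```python
import numpy as np

def all_gather(chip_slices):
    """Every chip contributes its slice; every chip receives the full concatenation."""
    full = np.concatenate(chip_slices)
    return [full.copy() for _ in chip_slices]

slices = [np.array([10, 20]), np.array([30, 40]),
          np.array([50, 60]), np.array([70, 80])]
print(all_gather(slices)[0])   # -> [10 20 30 40 50 60 70 80] on every chip
```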
Reduce-Scatter
Definition: Each chip has full data → Perform reduction → Each chip gets a different slice of the result.
Example (4 chips, SUM operation):
Initial state (each chip has full data):
Chip 0: [10, 20, 30, 40, 50, 60, 70, 80]
Chip 1: [1, 2, 3, 4, 5, 6, 7, 8]
Chip 2: [5, 5, 5, 5, 5, 5, 5, 5]
Chip 3: [2, 3, 4, 5, 6, 7, 8, 9]
Step 1 - Divide into chunks (2 elements per chip):
Positions [0:1] → will go to Chip 0
Positions [2:3] → will go to Chip 1
Positions [4:5] → will go to Chip 2
Positions [6:7] → will go to Chip 3
Step 2 - Reduce each chunk across all chips:
Chunk [0:1]: 10+1+5+2=18, 20+2+5+3=30
Chunk [2:3]: 30+3+5+4=42, 40+4+5+5=54
Chunk [4:5]: 50+5+5+6=66, 60+6+5+7=78
Chunk [6:7]: 70+7+5+8=90, 80+8+5+9=102
After reduce-scatter (SUM):
Chip 0: [18, 30] ← sum of positions [0:1] from all chips
Chip 1: [42, 54] ← sum of positions [2:3] from all chips
Chip 2: [66, 78] ← sum of positions [4:5] from all chips
Chip 3: [90, 102] ← sum of positions [6:7] from all chips
Use case in transformers: Distributing aggregated gradients
Training with hidden-dim parallelism:
Before reduce-scatter (each chip computed full gradients):
Chip 0: gradient [0:2047]
Chip 1: gradient [0:2047]
Chip 2: gradient [0:2047]
Chip 3: gradient [0:2047]
After reduce-scatter (SUM):
Chip 0: summed gradient [0:511] ← only its slice
Chip 1: summed gradient [512:1023] ← only its slice
Chip 2: summed gradient [1024:1535] ← only its slice
Chip 3: summed gradient [1536:2047] ← only its slice
Now each chip updates its portion of weights
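The same simulation for reduce-scatter: an element-wise SUM across chips followed by splitting the result, so chip i keeps only the i-th chunk.

```python
import numpy as np

def reduce_scatter(chip_arrays):
    """SUM element-wise across chips, then give chip i the i-th chunk of the result."""
    reduced = np.sum(np.stack(chip_arrays), axis=0)
    return np.split(reduced, len(chip_arrays))   # chunks[i] goes to chip i

chips = [np.array([10, 20, 30, 40, 50, 60, 70, 80]),
         np.array([ 1,  2,  3,  4,  5,  6,  7,  8]),
         np.array([ 5,  5,  5,  5,  5,  5,  5,  5]),
         np.array([ 2,  3,  4,  5,  6,  7,  8,  9])]
print(reduce_scatter(chips))   # -> [18 30], [42 54], [66 78], [90 102]
```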
Broadcast
Definition: One chip has data → Send copy to all other chips.
Example (4 chips):
Initial state:
Chip 0: [100, 200, 300, 400] ← has the data
Chip 1: [ ] ← empty
Chip 2: [ ] ← empty
Chip 3: [ ] ← empty
After broadcast (from Chip 0):
Chip 0: [100, 200, 300, 400]
Chip 1: [100, 200, 300, 400]
Chip 2: [100, 200, 300, 400]
Chip 3: [100, 200, 300, 400]
Use case in transformers: Distributing model weights or hyperparameters
Example: Broadcasting updated learning rate from master chip
Chip 0 (master): learning_rate = 0.001
After broadcast:
All chips: learning_rate = 0.001
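And broadcast, where only the root chip's buffer matters:

```python
import numpy as np

def broadcast(chip_buffers, root=0):
    """Copy the root chip's data to every chip."""
    return [np.array(chip_buffers[root]) for _ in chip_buffers]

buffers = [np.array([100, 200, 300, 400]), None, None, None]   # only Chip 0 has data
print(broadcast(buffers)[3])   # -> [100 200 300 400] now also on Chip 3
```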
Comparison Summary
| Operation | Input | Output | Direction |
|---|---|---|---|
| All-reduce | Each: partial value | Everyone: same combined result | Many → Many (same) |
| All-gather | Each: different slice | Everyone: same full data | Many → Many (same) |
| Reduce-scatter | Each: full data | Each: different slice of result | Many → Many (different) |
| Broadcast | One: data | Everyone: same copy | One → Many |
Relationship: All-reduce = Reduce-scatter + All-gather
All-reduce in two steps:
1. Reduce-scatter: Combine data and distribute slices
2. All-gather: Gather slices to get full result on all chips
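Using the `reduce_scatter` and `all_gather` helpers sketched above (repeated here so the block runs on its own), the decomposition can be checked directly:

```python
import numpy as np

def reduce_scatter(chip_arrays):
    reduced = np.sum(np.stack(chip_arrays), axis=0)
    return np.split(reduced, len(chip_arrays))

def all_gather(chip_slices):
    full = np.concatenate(chip_slices)
    return [full.copy() for _ in chip_slices]

chips = [np.arange(8) + 10 * i for i in range(4)]     # 4 chips, 8 elements each

two_step = all_gather(reduce_scatter(chips))          # reduce-scatter, then all-gather
direct = [np.sum(np.stack(chips), axis=0)] * 4        # plain all-reduce SUM
assert all(np.array_equal(a, b) for a, b in zip(two_step, direct))
```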