Multi-NPU Scaling for Transformer Models: Opportunities and Limitations

January 9, 2026

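Thoughts prompted by the Groq LPU news: how can multiple NPUs be used to scale Transformer models? This post explores the opportunities and limitations of several parallelism strategies, including Head Parallelism, Data Parallelism, and Pipeline Parallelism.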

Transformer on Chip


Preface

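I recently came across news about Groq's LPU and became interested in multi-NPU architectures. This is my first research note on the topic, and it mainly covers some basic concepts.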

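Disclaimer: This post may contain errors; it is primarily a personal study note. More in-depth research and corrections will follow. Corrections are welcome!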

Reference: Groq’s LPU Architecture

Core Question: Can we scale up to support large models with multi-NPUs?


1. Introduction

Why Multi-Chip Scaling?

Single-chip transformer accelerators face hard limits:

  • On-chip memory capacity (weights, activations, KV cache)
  • Fixed datapath widths (d_model, d_head, sequence length, context window)

Multi-chip parallelism via SPMD (Single Program, Multiple Data) becomes necessary for larger or deeper models.

Two Fundamental Constraints

Capacity: Can the model be decomposed so that each chip handles only a subset of the weights and computation while staying within single-chip limits?

Compatibility: Does the model’s execution graph conform to the hardware’s primitives and datapath?

2. Typical NPU Hardware Constraints

When designing multi-NPU systems for transformers, we must consider typical hardware limitations:

Common Constraints:

  • On-chip memory capacity - Limited space for weights, activations, and KV cache
  • Maximum sequence length - Upper bound on tokens processable in a single pass
  • Attention window size - Local attention window limits for efficiency
  • Fixed datapath widths - Predetermined dimensions for d_model, d_head, number of heads

Key Architectural Characteristics:

  1. Local windowed attention - Many NPUs support only local attention windows for hardware efficiency
  2. Fixed datapath widths - No dynamic reshape or routing capabilities
  3. Static execution model - Limited or no dynamic control flow
  4. Specialized compute units - Built-in softmax or other operators with fixed input sizes

3. Multi-Chip Scaling Methods

3.1 Head Parallelism

Strategy: Distribute attention heads across chips.

Example (32 heads, 4 chips):
Chip 0: heads 0-7   (all layers)
Chip 1: heads 8-15  (all layers)
Chip 2: heads 16-23 (all layers)
Chip 3: heads 24-31 (all layers)

Communication: 1 all-reduce per layer (summing each chip's partial output-projection result)
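A small, hypothetical pure-Python check of why one all-reduce suffices: each chip computes attention only for its own heads and multiplies by the matching rows of W_o; summing the per-chip partial outputs (the all-reduce) reproduces the unsharded result. The sizes (8 heads, 2 chips, d_head = 2, 3 tokens), the absence of a causal mask, and all helper names are assumptions made to keep the sketch short.

```python
import math
import random

def matmul(a, b):
    # a: n x k, b: k x m -> n x m
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

def softmax_rows(m):
    out = []
    for row in m:
        mx = max(row)
        e = [math.exp(v - mx) for v in row]
        out.append([v / sum(e) for v in e])
    return out

def cols(m, lo, hi):
    # column slice [lo:hi) of every row
    return [row[lo:hi] for row in m]

def head_output(x, wq, wk, wv, h, d_head):
    # attention output (tokens x d_head) for head h, no mask
    lo, hi = h * d_head, (h + 1) * d_head
    q, k, v = (matmul(x, cols(w, lo, hi)) for w in (wq, wk, wv))
    scores = [[s / math.sqrt(d_head) for s in row] for row in matmul(q, transpose(k))]
    return matmul(softmax_rows(scores), v)

def partial_attention(x, wq, wk, wv, wo, heads, d_head):
    # what one chip computes: its heads only, times the matching rows of W_o
    # (a real chip would store only these slices; here we slice full matrices)
    acc = [[0.0] * len(wo[0]) for _ in x]
    for h in heads:
        part = matmul(head_output(x, wq, wk, wv, h, d_head),
                      wo[h * d_head:(h + 1) * d_head])
        acc = [[a + p for a, p in zip(ra, rp)] for ra, rp in zip(acc, part)]
    return acc

def all_reduce_sum(mats):
    # elementwise sum of every chip's matrix; every chip gets this result
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*mats)]

def rand(r, c):
    return [[random.uniform(-1, 1) for _ in range(c)] for _ in range(r)]

random.seed(0)
n_tokens, num_heads, d_head, n_chips = 3, 8, 2, 2
d_model = num_heads * d_head
x = rand(n_tokens, d_model)
wq, wk, wv, wo = (rand(d_model, d_model) for _ in range(4))

full = partial_attention(x, wq, wk, wv, wo, range(num_heads), d_head)
per_chip = num_heads // n_chips
partials = [partial_attention(x, wq, wk, wv, wo,
                              range(c * per_chip, (c + 1) * per_chip), d_head)
            for c in range(n_chips)]
reduced = all_reduce_sum(partials)
max_err = max(abs(a - b) for ra, rb in zip(full, reduced) for a, b in zip(ra, rb))
print(f"max |full - all_reduce(partials)| = {max_err:.2e}")
```

Because the output projection is linear, the sum over heads can be split into per-chip sums and recombined once, which is why the attention block needs only this single reduction.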

3.2 Data Parallelism

Strategy: Replicate full model on each chip; process different batch samples.

Example (full model, 4 chips, batch size = 16):
Chip 0: full model (all 32 layers) → processes batch samples 0-3
Chip 1: full model (all 32 layers) → processes batch samples 4-7
Chip 2: full model (all 32 layers) → processes batch samples 8-11
Chip 3: full model (all 32 layers) → processes batch samples 12-15

Model weights: Identical copies on all chips
Data flow: Independent per chip
Communication:
  - Inference: ZERO (each chip independent)
  - Training: All-reduce gradients after backward pass
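A minimal sketch of why the training-time all-reduce works, assuming a hypothetical one-parameter model y = w * x with a mean-squared-error loss and equal-size shards: averaging the per-chip mean gradients gives exactly the full-batch gradient. The model, data, and function names are illustrative only.

```python
def grad_mse(w, xs, ys):
    # d/dw of mean((w*x - y)^2) over the given samples
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def shard(seq, n_chips):
    # contiguous equal-size shards, like samples 0-3, 4-7, ... in the example
    size = len(seq) // n_chips
    return [seq[c * size:(c + 1) * size] for c in range(n_chips)]

def all_reduce_mean(values):
    avg = sum(values) / len(values)
    return [avg] * len(values)  # every chip ends with the same value

batch_x = [float(i) for i in range(16)]  # batch size 16
batch_y = [3.0 * x + 1.0 for x in batch_x]
w, n_chips = 0.5, 4

# each chip computes a gradient on its own 4 samples, independently
local_grads = [grad_mse(w, xs, ys)
               for xs, ys in zip(shard(batch_x, n_chips), shard(batch_y, n_chips))]
synced = all_reduce_mean(local_grads)
full = grad_mse(w, batch_x, batch_y)
print(synced[0], full)  # -402.5 -402.5
```

At inference time the same sharding applies without the `all_reduce_mean` step, since no gradients need to be synchronized.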

3.3 Pipeline Parallelism

Strategy: Shard by layers; each chip processes a subset of layers sequentially.

Example (32 layers, 4 chips):
Chip 0: Embedding + layers 0-7     → forward → pass to Chip 1
Chip 1: layers 8-15                → forward → pass to Chip 2
Chip 2: layers 16-23               → forward → pass to Chip 3
Chip 3: layers 24-31 + LM head     → output
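A sketch of the layer-to-chip assignment and the chip-to-chip hand-off, assuming each "layer" is a hypothetical scalar affine function standing in for a real transformer block (embedding and LM head are omitted). It checks that running the stages one after another matches the unsharded forward pass.

```python
def assign_stages(num_layers, n_chips):
    # contiguous, balanced layer ranges: chip c gets layers [start, stop)
    base, extra = divmod(num_layers, n_chips)
    stages, start = [], 0
    for c in range(n_chips):
        stop = start + base + (1 if c < extra else 0)
        stages.append(range(start, stop))
        start = stop
    return stages

def make_layer(i):
    # hypothetical stand-in for transformer layer i
    return lambda h: 0.9 * h + 0.01 * i

layers = [make_layer(i) for i in range(32)]

def run_layers(h, layer_ids):
    for i in layer_ids:
        h = layers[i](h)
    return h

stages = assign_stages(32, 4)
for c, s in enumerate(stages):
    print(f"Chip {c}: layers {s.start}-{s.stop - 1}")

h = 1.0  # stand-in for the embedding output
for s in stages:
    h = run_layers(h, s)  # "pass to next chip" = hand h to the next stage
reference = run_layers(1.0, range(32))
```

With a single input, only one chip is busy at any moment; practical pipelines split the batch into micro-batches so that different chips work on different micro-batches at the same time.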

3.4 Hidden-Dimension Parallelism

Strategy: Shard d_model across chips; slicing cuts across head boundaries (different from head parallelism).

Example (num_heads = 2, d_head = 1024, d_model = 2048, 4 chips):

Head structure:
  Head 0: dims [0:1023]
  Head 1: dims [1024:2047]

Hidden-dimension slicing (CUTS ACROSS HEADS):
Chip 0: dims [0:511]     ← first half of Head 0  (all 32 layers)
Chip 1: dims [512:1023]  ← second half of Head 0 (all 32 layers)
Chip 2: dims [1024:1535] ← first half of Head 1  (all 32 layers)
Chip 3: dims [1536:2047] ← second half of Head 1 (all 32 layers)

Model weights: Column-wise partitioned (W_q, W_k, W_v, W_o sliced by hidden dim)
Data flow: All chips process same tokens, different hidden dimensions
Communication per layer:
  1. LayerNorm: All-reduce SUM (global mean)
  2. LayerNorm: All-reduce SUM (global variance)
  3. Within each head: All-reduce between the two chips sharing that head to combine partial QKᵀ dot products
  4. Cross heads: All-reduce to combine multi-head outputs
  Total: 3-4 all-reduce ops per layer
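The within-head reduction (item 3) can be illustrated with a single attention score: Head 0's 1024 dims are split between Chip 0 and Chip 1, so each chip can only compute a partial q·k over its 512 dims. This hypothetical sketch uses random vectors in place of real queries and keys.

```python
import random

random.seed(0)
d_head, half = 1024, 512  # Head 0 spans Chip 0 and Chip 1 in the example

q = [random.uniform(-1, 1) for _ in range(d_head)]  # query, head 0
k = [random.uniform(-1, 1) for _ in range(d_head)]  # key, head 0

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

partial_chip0 = dot(q[:half], k[:half])  # dims [0:511]
partial_chip1 = dot(q[half:], k[half:])  # dims [512:1023]

# all-reduce SUM between Chip 0 and Chip 1: both end up with the full score
score_on_chip0 = score_on_chip1 = partial_chip0 + partial_chip1
full_score = dot(q, k)
print(abs(full_score - score_on_chip0))
```

Softmax needs the complete scores, so this reduction has to finish before softmax runs; on hardware with a fixed-size built-in softmax unit, that adds a synchronization point inside every head.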

3.5 Sequence Parallelism

Strategy: Shard by sequence length; each chip processes a subset of tokens.

Example: Process long sequence in sequential chunks
Chip processes chunk 0: tokens [0:2047]     → output_0
Chip processes chunk 1: tokens [2048:4095]  → output_1
Chip processes chunk 2: tokens [4096:6143]  → output_2
Chip processes chunk 3: tokens [6144:8191]  → output_3
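The chunk boundaries can be computed with a small helper; this sketch assumes an 8192-token sequence split into 2048-token chunks and uses the inclusive [start:end] notation of the example. The function name is hypothetical.

```python
def chunk_ranges(seq_len, chunk_size):
    # inclusive [start:end] token ranges; the last chunk may be shorter
    return [(start, min(start + chunk_size, seq_len) - 1)
            for start in range(0, seq_len, chunk_size)]

for i, (lo, hi) in enumerate(chunk_ranges(8192, 2048)):
    print(f"chunk {i}: tokens [{lo}:{hi}]")
```

This only splits the token range. With full causal attention, tokens in a later chunk still attend to keys and values from earlier chunks, so those must be kept in the KV cache or exchanged between chips.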

4. Appendix: All-Reduce Operations

What is All-Reduce?

Definition: Each chip has a partial value → All chips end up with the same combined result.

Example (4 chips, SUM operation):

Initial state:
Chip 0: 10
Chip 1: 20
Chip 2: 30
Chip 3: 40

After all-reduce SUM:
All chips: 100  (sum of all values)

Common operations: SUM, MAX, MIN, AVG
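A single-process simulation of all-reduce semantics (not a real communication library). It simply combines one value per chip and hands the result back to every chip; real implementations such as NCCL use ring or tree algorithms but produce the same result.

```python
def all_reduce(values, op="SUM"):
    # values[c] is chip c's partial value; returns what each chip holds afterwards
    ops = {
        "SUM": sum,
        "MAX": max,
        "MIN": min,
        "AVG": lambda v: sum(v) / len(v),
    }
    result = ops[op](values)
    return [result] * len(values)

chips = [10, 20, 30, 40]
print(all_reduce(chips))         # [100, 100, 100, 100]
print(all_reduce(chips, "MAX"))  # [40, 40, 40, 40]
print(all_reduce(chips, "AVG"))  # [25.0, 25.0, 25.0, 25.0]
```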

LayerNorm All-Reduce Example

When d_model is sharded across chips, LayerNorm requires global statistics:

Problem: d_model = 2048 split across 4 chips (512 dims each)

Step 1: Each chip computes local sum
Step 2: All-reduce SUM → get global_sum
Step 3: Compute global_mean = global_sum / 2048
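Step 4: Repeat for variance: each chip sums (x - global_mean)² over its shard → All-reduce SUM → divide by 2048 → global_variance
Step 5: Each chip normalizes its shard using global_mean and global_variance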

Total: 2 all-reduce ops per LayerNorm (mean + variance)
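A pure-Python sketch of the sharded LayerNorm procedure, checked against an unsharded reference. It assumes d_model = 2048 split into four 512-dim shards and omits the learnable gain and bias (each chip would simply hold its own slice of them); the helper names are hypothetical.

```python
import math
import random

def all_reduce_sum(values):
    total = sum(values)
    return [total] * len(values)

def layernorm(x, eps=1e-5):
    # unsharded reference, no gain/bias
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sharded_layernorm(shards, d_model, eps=1e-5):
    # all-reduce #1: local sums -> global mean (identical on every chip)
    means = [s / d_model for s in all_reduce_sum([sum(sh) for sh in shards])]
    # all-reduce #2: local sums of squared deviations -> global variance
    local_sq = [sum((v - m) ** 2 for v in sh) for sh, m in zip(shards, means)]
    variances = [s / d_model for s in all_reduce_sum(local_sq)]
    # each chip normalizes only its own shard
    return [[(v - m) / math.sqrt(var + eps) for v in sh]
            for sh, m, var in zip(shards, means, variances)]

random.seed(0)
d_model, n_chips = 2048, 4
x = [random.gauss(0.0, 1.0) for _ in range(d_model)]
size = d_model // n_chips
shards = [x[c * size:(c + 1) * size] for c in range(n_chips)]

out = [v for sh in sharded_layernorm(shards, d_model) for v in sh]
ref = layernorm(x)
print(max(abs(a - b) for a, b in zip(out, ref)))
```

A common variant sends the local sum and the local sum of squares in one all-reduce and computes variance as E[x²] - mean², saving one communication step at some cost in numerical robustness.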

All-Gather

Definition: Each chip has a different slice of data → All chips end up with the complete concatenated data.

Example (4 chips):

Initial state (each chip has different data slice):
Chip 0: [10, 20]
Chip 1: [30, 40]
Chip 2: [50, 60]
Chip 3: [70, 80]

After all-gather:
Chip 0: [10, 20, 30, 40, 50, 60, 70, 80]  ← everyone gets full data
Chip 1: [10, 20, 30, 40, 50, 60, 70, 80]
Chip 2: [10, 20, 30, 40, 50, 60, 70, 80]
Chip 3: [10, 20, 30, 40, 50, 60, 70, 80]

Use case in transformers: Gathering distributed hidden dimensions

Hidden-dimension parallel (d_model = 2048, 4 chips):
Before all-gather:
Chip 0: hidden dims [0:511]
Chip 1: hidden dims [512:1023]
Chip 2: hidden dims [1024:1535]
Chip 3: hidden dims [1536:2047]

After all-gather:
All chips: complete hidden vector [0:2047]

Why needed? Some operations require the full hidden vector
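A single-process simulation of all-gather, covering both the small example and the hidden-dimension case (four 512-dim slices of a 2048-dim vector). The function name is hypothetical.

```python
def all_gather(slices):
    # every chip receives the concatenation of all chips' slices, in chip order
    full = [v for sl in slices for v in sl]
    return [list(full) for _ in slices]

result = all_gather([[10, 20], [30, 40], [50, 60], [70, 80]])
print(result[0])  # [10, 20, 30, 40, 50, 60, 70, 80]

# hidden-dimension case: chip c holds dims [c*512 : (c+1)*512 - 1]
hidden = [list(range(c * 512, (c + 1) * 512)) for c in range(4)]
gathered = all_gather(hidden)
print(all(g == list(range(2048)) for g in gathered))  # True
```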

Reduce-Scatter

Definition: Each chip has full data → Perform reduction → Each chip gets a different slice of the result.

Example (4 chips, SUM operation):

Initial state (each chip has full data):
Chip 0: [10, 20, 30, 40, 50, 60, 70, 80]
Chip 1: [1,  2,  3,  4,  5,  6,  7,  8]
Chip 2: [5,  5,  5,  5,  5,  5,  5,  5]
Chip 3: [2,  3,  4,  5,  6,  7,  8,  9]

Step 1 - Divide into chunks (2 elements per chip):
Positions [0:1] → will go to Chip 0
Positions [2:3] → will go to Chip 1
Positions [4:5] → will go to Chip 2
Positions [6:7] → will go to Chip 3

Step 2 - Reduce each chunk across all chips:
Chunk [0:1]: 10+1+5+2=18, 20+2+5+3=30
Chunk [2:3]: 30+3+5+4=42, 40+4+5+5=54
Chunk [4:5]: 50+5+5+6=66, 60+6+5+7=78
Chunk [6:7]: 70+7+5+8=90, 80+8+5+9=102

After reduce-scatter (SUM):
Chip 0: [18, 30]   ← sum of positions [0:1] from all chips
Chip 1: [42, 54]   ← sum of positions [2:3] from all chips
Chip 2: [66, 78]   ← sum of positions [4:5] from all chips
Chip 3: [90, 102]  ← sum of positions [6:7] from all chips

Use case in transformers: Distributing aggregated gradients

Training with data parallelism and sharded weight updates (ZeRO-style):
Before reduce-scatter (each chip computed full gradients on its own mini-batch):
Chip 0: gradient [0:2047]
Chip 1: gradient [0:2047]
Chip 2: gradient [0:2047]
Chip 3: gradient [0:2047]

After reduce-scatter (SUM):
Chip 0: summed gradient [0:511]     ← only its slice
Chip 1: summed gradient [512:1023]  ← only its slice
Chip 2: summed gradient [1024:1535] ← only its slice
Chip 3: summed gradient [1536:2047] ← only its slice

Now each chip updates its portion of weights
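A single-process simulation of reduce-scatter with SUM, reproducing the numeric example above. It assumes the vector length is divisible by the number of chips; the function name is hypothetical.

```python
def reduce_scatter_sum(buffers):
    # buffers[c] is chip c's full-length vector; chip c keeps chunk c of the sum
    n_chips = len(buffers)
    chunk = len(buffers[0]) // n_chips  # assumes length divisible by n_chips
    out = []
    for c in range(n_chips):
        lo, hi = c * chunk, (c + 1) * chunk
        out.append([sum(buf[i] for buf in buffers) for i in range(lo, hi)])
    return out

chips = [
    [10, 20, 30, 40, 50, 60, 70, 80],
    [1, 2, 3, 4, 5, 6, 7, 8],
    [5, 5, 5, 5, 5, 5, 5, 5],
    [2, 3, 4, 5, 6, 7, 8, 9],
]
for c, part in enumerate(reduce_scatter_sum(chips)):
    print(f"Chip {c}: {part}")
```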

Broadcast

Definition: One chip has data → Send copy to all other chips.

Example (4 chips):

Initial state:
Chip 0: [100, 200, 300, 400]  ← has the data
Chip 1: [ ]                    ← empty
Chip 2: [ ]                    ← empty
Chip 3: [ ]                    ← empty

After broadcast (from Chip 0):
Chip 0: [100, 200, 300, 400]
Chip 1: [100, 200, 300, 400]
Chip 2: [100, 200, 300, 400]
Chip 3: [100, 200, 300, 400]

Use case in transformers: Distributing model weights or hyperparameters

Example: Broadcasting updated learning rate from master chip
Chip 0 (master): learning_rate = 0.001

After broadcast:
All chips: learning_rate = 0.001

Comparison Summary

Operation       Input                   Output                           Direction
All-reduce      Each: partial value     Everyone: same combined result   Many → Many (same)
All-gather      Each: different slice   Everyone: same full data         Many → Many (same)
Reduce-scatter  Each: full data         Each: different slice of result  Many → Many (different)
Broadcast       One: data               Everyone: same copy              One → Many

Relationship: All-reduce = Reduce-scatter + All-gather

All-reduce in two steps:
1. Reduce-scatter: Combine data and distribute slices
2. All-gather: Gather slices to get full result on all chips
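A small single-process check of this relationship: composing a simulated reduce-scatter with a simulated all-gather gives every chip the same result as a direct elementwise sum. Helper names are hypothetical and the vector length is assumed divisible by the chip count.

```python
def reduce_scatter_sum(buffers):
    n, chunk = len(buffers), len(buffers[0]) // len(buffers)
    return [[sum(b[i] for b in buffers) for i in range(c * chunk, (c + 1) * chunk)]
            for c in range(n)]

def all_gather(slices):
    full = [v for s in slices for v in s]
    return [list(full) for _ in slices]

def all_reduce_sum(buffers):
    # step 1: reduce-scatter; step 2: all-gather the reduced slices
    return all_gather(reduce_scatter_sum(buffers))

chips = [
    [10, 20, 30, 40, 50, 60, 70, 80],
    [1, 2, 3, 4, 5, 6, 7, 8],
    [5, 5, 5, 5, 5, 5, 5, 5],
    [2, 3, 4, 5, 6, 7, 8, 9],
]
direct = [sum(col) for col in zip(*chips)]
result = all_reduce_sum(chips)
print(result[0])  # [18, 30, 42, 54, 66, 78, 90, 102]
```

This decomposition is what ring all-reduce implementations exploit: each chip sends and receives roughly 2(N-1)/N times the data size, so per-chip traffic stays nearly constant as the number of chips N grows.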
