Presentation

Batch Tiling on Attention: Efficient Mixture of Experts Training on Wafer-Scale Processors
Description

Mixture of Experts (MoE) models face a computational bottleneck on wafer-scale processors due to conflicting batch-size requirements: attention mechanisms need smaller batches to stay within memory constraints, while routable MLP layers require larger batches for optimal compute density.
We introduce \textbf{Batch Tiling on Attention (BTA)}, which decouples batch processing across MoE computation stages by applying dynamic tiling along attention's batch dimension. Our method processes attention operations at a reduced batch size $B$ through tiled computation, then concatenates the tiled outputs to form a larger batch size $\widetilde{B} = G \cdot B$ for the MLP operations, where $G$ is a positive integer. This addresses attention's memory limitations while maximizing hardware utilization in the expert layers.
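The batch decoupling above can be illustrated with a minimal sketch. This is not the Cerebras implementation; the `attention`, `mlp`, and `bta_forward` functions below are illustrative stand-ins written in plain NumPy (parameter-free toy attention, a ReLU in place of the routed expert MLP), showing only the shape logic: attention runs on $G$ tiles of batch size $B$, and the MLP then sees the concatenated batch $\widetilde{B} = G \cdot B$.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(x):
    # Toy single-head self-attention with no learned weights,
    # kept only so the tile shapes are realistic: (B, seq, d) -> (B, seq, d).
    scores = softmax(x @ x.transpose(0, 2, 1) / np.sqrt(x.shape[-1]))
    return scores @ x

def mlp(x):
    # Stand-in for the routable expert MLP; shape-preserving.
    return np.maximum(x, 0.0)

def bta_forward(x_tiles):
    # x_tiles: list of G tensors, each of shape (B, seq, d).
    attn_out = [attention(t) for t in x_tiles]  # attention at small batch B
    big = np.concatenate(attn_out, axis=0)      # batch grows to G * B
    return mlp(big)                             # MLP runs at the larger batch
```

For example, with $G = 2$ tiles of batch size $B = 3$, the MLP input has batch dimension $\widetilde{B} = 6$ while each attention call only ever materializes activations for a batch of 3.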
We demonstrate BTA's effectiveness on Cerebras wafer-scale engines using Qwen3-like models, achieving up to 5$\times$ performance improvements at higher sparsity levels compared to conventional uniform batching. Unlike existing GPU-focused techniques such as FlashAttention and expert parallelism, BTA specifically targets the unique computational characteristics of wafer-scale processors.