The standard attention mechanism computes a similarity score between every pair of tokens in a sequence. It uses three matrices: Query (Q), Key (K), and Value (V). The attention formula applies softmax to the scaled dot product of the Query and Key matrices, softmax(QKᵀ / √d) where d is the head dimension, and then multiplies the result by the Value matrix. This creates a fundamental bottleneck: the intermediate attention matrix has N-squared entries, where N is the sequence length. For long sequences, storing and moving this quadratic-size matrix becomes prohibitive, limiting the scalability of transformer models.
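A minimal pure-Python sketch of the standard formula, using plain lists so it runs without extra libraries. The helper names `attention_scores`, `softmax`, and `standard_attention` are illustrative, not taken from any library; the point is that `attention_scores` builds the full N-by-N matrix before the softmax is applied.

```python
import math

def attention_scores(Q, K):
    """Full N x N matrix of scaled dot products q_i . k_j / sqrt(d)."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    return [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]

def softmax(row):
    """Softmax of one row, subtracting the row max for numerical stability."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def standard_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V, with Q, K as N x d and V as N x d_v lists."""
    S = attention_scores(Q, K)          # quadratic: N * N scores held at once
    P = [softmax(row) for row in S]     # a second N x N matrix of probabilities
    d_v = len(V[0])
    return [[sum(p[j] * V[j][c] for j in range(len(V))) for c in range(d_v)]
            for p in P]

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
S = attention_scores(Q, K)
print(len(S), len(S[0]))  # 2 2
print(standard_attention(Q, K, V))
```

Doubling N quadruples the size of `S` and `P`, which is exactly the quadratic growth described above.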
The GPU memory hierarchy makes this bottleneck worse. GPUs have a two-tier memory system: large but comparatively slow high-bandwidth memory (HBM), and small but fast on-chip SRAM. On an NVIDIA A100, for example, HBM provides 40 to 80 gigabytes at roughly 1.5 to 2.0 terabytes per second, while each of the 108 streaming multiprocessors has only 192 kilobytes of SRAM, with an aggregate bandwidth estimated at around 19 terabytes per second. Standard attention writes the full N-by-N score matrix to HBM, reads it back to compute the softmax, writes the resulting probability matrix, and reads that again to multiply by the values. These repeated round trips make the computation memory-bound, creating a memory wall: the GPU cores spend more time waiting for data than actually computing.
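A back-of-the-envelope sketch of why standard attention is memory-bound. It assumes one attention head, fp16 values (2 bytes each), and a sustained HBM bandwidth of 1.5 terabytes per second, and it counts only the N-by-N traffic (write S, read S, write P, read P) while ignoring Q, K, V, and the output. The function names are hypothetical.

```python
# Rough model of HBM traffic for the N x N intermediates of standard attention.
# Assumptions: one head, fp16 (2 bytes per element), only the score matrix S
# and the probability matrix P are counted; Q, K, V, and O traffic is ignored.

def matrix_bytes(n, bytes_per_elem=2):
    """Size of one N x N matrix in bytes."""
    return n * n * bytes_per_elem

def intermediate_hbm_bytes(n, bytes_per_elem=2):
    """Write S, read S, write P, read P: four full N x N transfers."""
    return 4 * matrix_bytes(n, bytes_per_elem)

def transfer_time_s(num_bytes, bandwidth_bytes_per_s):
    """Time to move num_bytes at a sustained bandwidth, ignoring latency."""
    return num_bytes / bandwidth_bytes_per_s

n = 4096
print(matrix_bytes(n))            # 33554432 bytes (32 MiB) for one S matrix
print(intermediate_hbm_bytes(n))  # 134217728 bytes of S/P traffic
print(transfer_time_s(intermediate_hbm_bytes(n), 1.5e12))  # seconds at 1.5 TB/s
```

Under these assumptions a single 4096-token score matrix is 32 MiB, far beyond the 192 KB of SRAM per streaming multiprocessor on an A100, and the four transfers alone take roughly 90 microseconds per head per layer before any useful arithmetic is counted.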
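Tiling is the foundation of Flash Attention's efficiency. Instead of materializing the entire attention matrix, tiling divides the Query, Key, and Value matrices into smaller blocks that fit in the GPU's fast SRAM. Once a block is loaded, all of the work on it happens on-chip, which sharply reduces the number of expensive transfers between HBM and SRAM. The blocks are not fully independent, however: the partial results a query block produces against different key blocks must still be combined into one softmax, which is the job of the online softmax described next. With this block-wise approach, the number of HBM accesses drops from O of N d plus N-squared for standard attention to O of N-squared d-squared over M, where d is the head dimension and M is the SRAM size. Because d is typically 64 to 128 while M is on the order of 100 kilobytes, d-squared is many times smaller than M. The key insight is that attention can be computed incrementally across blocks while remaining exact rather than approximate.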
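As one concrete way to pick tile sizes, the FlashAttention paper's forward-pass algorithm sets the key/value block size B_c = ceil(M / 4d) and the query block size B_r = min(ceil(M / 4d), d), with M the SRAM size in elements. The sketch below computes those sizes and the resulting tiles. The SRAM budget of 49,152 elements is a hypothetical value chosen for round numbers, not a measured hardware figure, and the helper names are illustrative.

```python
import math

def flash_block_sizes(sram_elems, d):
    """Block sizes as in FlashAttention's forward pass:
    B_c = ceil(M / (4d)) key/value rows, B_r = min(ceil(M / (4d)), d) query rows,
    where M (sram_elems) is the SRAM budget in elements and d is the head dimension."""
    b_c = math.ceil(sram_elems / (4 * d))
    b_r = min(b_c, d)
    return b_r, b_c

def tile_ranges(n, block):
    """Half-open [start, end) ranges covering indices 0..n-1 in steps of `block`.
    The last tile is shorter when block does not divide n."""
    return [(start, min(start + block, n)) for start in range(0, n, block)]

def tile_counts(n, b_r, b_c):
    """Number of query tiles T_r and key/value tiles T_c for sequence length n."""
    return len(tile_ranges(n, b_r)), len(tile_ranges(n, b_c))

# Hypothetical SRAM budget of 49,152 elements with head dimension 64.
b_r, b_c = flash_block_sizes(49152, 64)
print(b_r, b_c)                     # 64 192
print(tile_counts(1024, b_r, b_c))  # (16, 6)
```

The factor of 4 leaves room in SRAM for a query block, a key block, a value block, and an output block at the same time, which is why the block sizes shrink as d grows.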
The online softmax algorithm is the mathematical idea that makes Flash Attention possible. Traditional softmax needs the maximum and the sum of exponentials over an entire row before any output can be normalized, so the whole row of scores must be available at once. The online version instead maintains two running statistics per row: the largest score m seen so far and the normalizing sum l of exponentials taken relative to m. When a new block arrives, the algorithm computes the new maximum, rescales the previous sum (and any partially accumulated output) by exp(m_old - m_new), and adds the new block's contributions. Subtracting the running maximum before exponentiating keeps every exponent at or below zero, which prevents overflow. The key insight is that softmax can be computed incrementally without storing the full row of scores, enabling block-wise processing while remaining mathematically equivalent to the standard computation.
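The running-statistics update can be sketched directly. This minimal pure-Python version processes one row of scores block by block and keeps only m and l between blocks; the function names are illustrative, and it assumes every block is non-empty. The final `softmax_from_stats` call re-reads the scores only to show that the resulting probabilities sum to 1; in Flash Attention the output is accumulated alongside m and l instead.

```python
import math

def online_softmax_stats(blocks):
    """One pass over non-empty blocks of scores for a single row.
    Keeps the running max m and the running sum l = sum_j exp(x_j - m)."""
    m = -math.inf
    l = 0.0
    for block in blocks:
        m_new = max(m, max(block))
        # Rescale the old sum to the new max (exp(-inf) == 0.0 on the first block),
        # then add this block's terms; every exponent is <= 0, so nothing overflows.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, l

def softmax_from_stats(scores, m, l):
    """Normalize scores once the final m and l are known."""
    return [math.exp(x - m) / l for x in scores]

blocks = [[1.0, 3.0], [-2.0, 5.0], [0.5]]
m, l = online_softmax_stats(blocks)
probs = softmax_from_stats([x for b in blocks for x in b], m, l)
print(m, sum(probs))
# Scores that would overflow a naive math.exp(1000.0) are handled safely:
print(online_softmax_stats([[1000.0], [1001.0]]))
```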
The complete Flash Attention algorithm combines tiling and online softmax into one efficient kernel. It divides the Query, Key, and Value matrices into blocks that fit in SRAM. For each Query block, it iterates through all Key-Value block pairs, loading each pair into SRAM. Local attention scores are computed with ordinary matrix multiplication; the online softmax then updates the running statistics, rescales the partial output, and accumulates the new block's contribution. After the last Key-Value block, the accumulated output is divided by the running sum and written to HBM once. Because the N-by-N matrix is never materialized, the extra memory grows linearly with N instead of quadratically, and HBM accesses fall to O of N-squared d-squared over M, where d is the head dimension and M is the SRAM size. The result is mathematically equivalent to standard attention; the original FlashAttention authors report 2 to 4 times speedups and memory savings of about 10 times at sequence length 2K and 20 times at 4K, enabling efficient processing of much longer sequences.
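Putting the pieces together, here is a minimal single-head sketch in pure Python that follows the loop order described above (query blocks outer, key-value blocks inner) and checks itself against a direct row-by-row reference. It is meant to show the arithmetic, not performance: there is no real SRAM here, and `block_q` and `block_k` are hypothetical parameters standing in for the query and key/value block sizes. Because K and V are sliced with the same indices, they must have the same number of rows.

```python
import math

def naive_attention(Q, K, V):
    """Reference: softmax(q K^T / sqrt(d)) V computed one query row at a time."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(w[j] * V[j][c] for j in range(len(V))) / z
                    for c in range(len(V[0]))])
    return out

def flash_attention(Q, K, V, block_q=2, block_k=2):
    """Tiled forward pass for one head using online softmax.
    Never builds the N x N score matrix; each query row keeps only
    a running max m, a running sum l, and an unnormalized output accumulator."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    d_v = len(V[0])
    O = []
    for qs in range(0, len(Q), block_q):          # outer loop: query blocks
        q_blk = Q[qs:qs + block_q]
        m = [-math.inf] * len(q_blk)
        l = [0.0] * len(q_blk)
        acc = [[0.0] * d_v for _ in q_blk]
        for ks in range(0, len(K), block_k):      # inner loop: key/value blocks
            k_blk = K[ks:ks + block_k]
            v_blk = V[ks:ks + block_k]
            for i, q in enumerate(q_blk):
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
                m_new = max(m[i], max(s))
                alpha = math.exp(m[i] - m_new)    # rescales earlier blocks
                p = [math.exp(x - m_new) for x in s]
                l[i] = l[i] * alpha + sum(p)
                acc[i] = [acc[i][c] * alpha
                          + sum(p[j] * v_blk[j][c] for j in range(len(v_blk)))
                          for c in range(d_v)]
                m[i] = m_new
        for i in range(len(q_blk)):               # final normalization, one write
            O.append([a / l[i] for a in acc[i]])
    return O

Q = [[((i * 3 + c) % 5) * 0.3 - 0.6 for c in range(4)] for i in range(5)]
K = [[((i * 2 + c * 3) % 7) * 0.2 - 0.5 for c in range(4)] for i in range(5)]
V = [[float((i + c) % 3) - 1.0 for c in range(3)] for i in range(5)]
err = max(abs(a - b)
          for ro, rr in zip(flash_attention(Q, K, V), naive_attention(Q, K, V))
          for a, b in zip(ro, rr))
print(err < 1e-9)  # True
```

Each query block carries only its own m, l, and accumulator, so the working set per block does not depend on N; a real kernel exploits this by keeping those values in SRAM while streaming key-value blocks past them.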