14. FlashAttention Memory Model | FlashAttention Memory Model
难度: Medium | 标签: Attention, FlashAttention, Memory Model | 目标人群: 准备进入 Chapter 2 的学习者
这一页把 Chapter 1 的 Attention 和显存直觉,接到 Chapter 2 的 FlashAttention 实现上。重点是它如何减少中间计算和显存访问。
前置关系
- Chapter 1 已经讲过 Attention 为什么会吃显存
- Chapter 2 会直接进入 Attention 实现、FlashAttention 和 PagedAttention
- 先把“它到底省了什么”讲清楚
你应该先建立的直觉
1. 标准 Attention 的问题不只是算得多
标准 Attention 的麻烦有两层:
- 计算本身很重
- 中间矩阵会占用大量显存
也就是说,问题不只是“算慢”,而是“算完以后还要保存很多东西”。
2. FlashAttention 的核心不是改公式,而是改组织方式
FlashAttention 的直觉可以先记成一句话:
text
把大矩阵切成小块,在更快的片上存储里分块处理,尽量减少不必要的中间显存读写它不是单纯换了一个数学公式,而是换了处理数据的方式。
3. 显存访问方式会决定实际性能
在 GPU 上,很多时候真正慢的不是算,而是:
- 反复访问慢内存
- 中间结果太多
- 数据搬来搬去
FlashAttention 的价值,就是尽量让计算和存储更贴近硬件实际。
一个最常见的理解路径
你可以先按下面这条线理解 FlashAttention:
text
标准 Attention -> 中间矩阵太大 -> 显存压力高 -> 分块处理 -> 减少中间显存 -> 性能提升这条线比直接背实现细节更重要。
后面 Chapter 2 的代码实现,基本都建立在这个直觉上。
常见误区
- 以为 FlashAttention 只是“更快的 softmax”
- 只关注算法表达,忽略显存访问方式
- 把 FlashAttention 和 PagedAttention 混为一谈
- 认为它解决了所有推理显存问题,其实它主要解决的是 Attention 的中间计算和访问模式
这一页学完后,你应该能回答
- 为什么标准 Attention 会很吃显存
- FlashAttention 到底省了什么
- 为什么分块处理能改善性能
- 为什么 Chapter 2 的 Attention 相关内容要连着学
和后续章节的联系
Chapter 2: Attention MHA/GQA
你会看到注意力计算本身如何组织Chapter 2: FlashAttention
你会看到分块、在线 softmax 和 memory-friendly 的实现思路Chapter 2: vLLM PagedAttention
你会看到 Attention 之外,KV cache 的组织方式也会影响推理显存
小结
这一页的作用也很简单:
- 先让你知道标准 Attention 为什么贵
- 再让你知道 FlashAttention 为什么有效
- 最后让你知道 Chapter 2 里为什么要把它单独讲
如果你能把“计算方式”和“显存访问方式”联系起来,这一页的目标基本达成,后面再看 Chapter 2 的 Attention 章节会更顺。
