Skip to content

14. FlashAttention Memory Model | FlashAttention Memory Model

难度: Medium | 标签: Attention, FlashAttention, Memory Model | 目标人群: 准备进入 Chapter 2 的学习者

这一页把 Chapter 1 的 Attention 和显存直觉,接到 Chapter 2 的 FlashAttention 实现上。重点是它如何减少中间计算和显存访问。

前置关系

  • Chapter 1 已经讲过 Attention 为什么会吃显存
  • Chapter 2 会直接进入 Attention 实现、FlashAttention 和 PagedAttention
  • 先把“它到底省了什么”讲清楚

你应该先建立的直觉

1. 标准 Attention 的问题不只是算得多

标准 Attention 的麻烦有两层:

  • 计算本身很重
  • 中间矩阵会占用大量显存

也就是说,问题不只是“算慢”,而是“算完以后还要保存很多东西”。

2. FlashAttention 的核心不是改公式,而是改组织方式

FlashAttention 的直觉可以先记成一句话:

text
把大矩阵切成小块,在更快的片上存储里分块处理,尽量减少不必要的中间显存读写

它不是单纯换了一个数学公式,而是换了处理数据的方式。

3. 显存访问方式会决定实际性能

在 GPU 上,很多时候真正慢的不是算,而是:

  • 反复访问慢内存
  • 中间结果太多
  • 数据搬来搬去

FlashAttention 的价值,就是尽量让计算和存储更贴近硬件实际。

一个最常见的理解路径

你可以先按下面这条线理解 FlashAttention:

text
标准 Attention -> 中间矩阵太大 -> 显存压力高 -> 分块处理 -> 减少中间显存 -> 性能提升

这条线比直接背实现细节更重要。
后面 Chapter 2 的代码实现,基本都建立在这个直觉上。

常见误区

  • 以为 FlashAttention 只是“更快的 softmax”
  • 只关注算法表达,忽略显存访问方式
  • 把 FlashAttention 和 PagedAttention 混为一谈
  • 认为它解决了所有推理显存问题,其实它主要解决的是 Attention 的中间计算和访问模式

这一页学完后,你应该能回答

  • 为什么标准 Attention 会很吃显存
  • FlashAttention 到底省了什么
  • 为什么分块处理能改善性能
  • 为什么 Chapter 2 的 Attention 相关内容要连着学

和后续章节的联系

  • Chapter 2: Attention MHA/GQA
    你会看到注意力计算本身如何组织

  • Chapter 2: FlashAttention
    你会看到分块、在线 softmax 和 memory-friendly 的实现思路

  • Chapter 2: vLLM PagedAttention
    你会看到 Attention 之外,KV cache 的组织方式也会影响推理显存

小结

这一页的作用也很简单:

  • 先让你知道标准 Attention 为什么贵
  • 再让你知道 FlashAttention 为什么有效
  • 最后让你知道 Chapter 2 里为什么要把它单独讲

如果你能把“计算方式”和“显存访问方式”联系起来,这一页的目标基本达成,后面再看 Chapter 2 的 Attention 章节会更顺。

Released under the MIT License.