Skip to content

Transformer

下面只考虑 Decoder-only 的 Transformer。

计算过程

输入 n 个 Token,通过 Text 和 Position Embedding 后,得到一个尺寸为 (n,dmodel) 的矩阵,其中 dmodel 为 Embedding 向量的长度。

接着,经过 nlayer 个层,每层中要进行如下的计算:

  1. 从上一层的输出 (n,dmodel),每个 head 上乘以三个 (dmodel,dhead) 的矩阵,得到 Q,K 和 V,尺寸为 (n,dhead)
  2. 按照 attention 激活公式,得到 Z,尺寸为 (n,dhead)
  3. 把所有 head 的 Z 矩阵拼起来,保证 dmodel=nheaddhead,那么所有的 Z 拼起来以后得到的矩阵的尺寸为 (n,dhead)
  4. 乘以一个尺寸为 (dmodel,dmodel) 的 Projection 矩阵,得到新矩阵 (n,dmodel)
  5. 经过一个 MLP,MLP 第一层是 (dmodel,4dmodel),第二层是 (4dmodel,dmodel)
  6. MLP 输出的矩阵尺寸为 (n,dmodel)

KV cache

在推理的时候,是在已有的 context 的基础上,生成一个新 token,再把 token 加到 context,继续生成下一个 token。计算的时候,由于 Attention 会带 Mask,旧 token 不会依赖新 token,因此旧 token 的部分不会变,可以只考虑新引入的 token 带来的变化。

那么,计算新的 token 的 Q K V 以后,在进行 Attention 计算的时候,会发现先前的 token 的 K 和 V 部分不变,先前 token 的 Q 不影响当前 token 的结果。因此可以把之前 token 的 K 和 V 保存下来,不用重新计算,这就是 KV cache。

参数量和浮点计算量

考虑每层的参数量和浮点计算量:

  1. 计算 Q,K 和 V:参数是三个 (dmodel,dmodel) 的矩阵,计算量是 23ndmodel2
  2. 乘以 Projection 矩阵:参数是 (dmodel,dmodel) 的矩阵,计算量是 2ndmodel2
  3. 乘以 MLP 第一层:参数是 (dmodel,4dmodel) 的矩阵,计算量是 24ndmodel2
  4. 乘以 MLP 第二层:参数是 (4dmodel,dmodel) 的矩阵,计算量是 24ndmodel2

其余部分的参数量和计算量可以忽略不计。

因此总参数量(不考虑 Embedding)为:

12nlayerdmodel2

每个 Token 的浮点计算量为:

24nlayerdmodel2FLOP

Llama 2 7B

以 Llama 2 7B 为例,下面分析 Transformer 推理的计算过程,它的参数如下:

  • hidden size: 4096
  • intermediate size(MLP 的中间层的维度): 11008
  • hidden layers: 32
  • attention heads: 32
  • key value heads: 32
  • head dim: 4096 / 32 = 128
  • vocab size: 32000

参考 HuggingFace 源码

Llama2 的主要计算过程是 32 层 LlamaDecoderLayer,每个 LlamaDecoderLayer 包括:

  1. hidden_states = self.input_layernorm(hidden_states): 见 LlamaRMSNorm
  2. hidden_states, self_attn_weights, present_key_value = self.self_attn():见 LlamaAttention
  3. hidden_states = residual + hidden_states: aten::add([1, 1, 4096], [1, 1, 4096]) = [1, 1, 4096]
  4. hidden_states = self.post_attention_layernorm(hidden_states): 见 LlamaRMSNorm
  5. hidden_states = self.mlp(hidden_states):见 LlamaMLP
  6. hidden_states = residual + hidden_states: aten::add([1, 1, 4096], [1, 1, 4096]) = [1, 1, 4096]

hidden_states 的规模是 [1, 1, 4096]

LlamaRMSNorm 包括:

  1. hidden_states = hidden_states.to(torch.float32): aten::to([1, 1, 4096]) = [1, 1, 4096]
  2. v1 = hidden_states.pow(2): aten::pow([1, 1, 4096]) = [1, 1, 4096]
  3. variance = v2.mean(-1, keepdim=True): aten::mean([1, 1, 4096]) = [1, 1, 1]
  4. v3 = variance + self.variance_epsilon: aten::add([1, 1, 1]) = [1, 1, 1], 1 FLOP
  5. v4 = torch.rsqrt(v3): aten::rsqrt([1, 1, 1]) = [1, 1, 1]
  6. hidden_states = hidden_states * v4: aten::mul([1, 1, 4096], [1, 1, 1]) = [1, 1, 4096], 4096 FLOP
  7. v5 = hidden_states.to(input_dtype): aten::to([1, 1, 4096]) = [1, 1, 4096]
  8. return self.weight * v5: aten::mul([4096], [1, 1, 4096]), 4096 FLOP

LlamaAttention 包括:

  1. query_states = self.q_proj(hidden_states): aten::linear([1, 1, 4096], [4096, 4096]) = [1, 1, 4096], 33554432 FLOP
  2. key_states = self.k_proj(hidden_states): aten::linear([1, 1, 4096], [4096, 4096]) = [1, 1, 4096], 33554432 FLOP
  3. value_states = self.v_proj(hidden_states): aten::linear([1, 1, 4096], [4096, 4096]) = [1, 1, 4096], 33554432 FLOP
  4. v1 = query_states.view(bsz, q_len, self.num_heads, self.head_dim): aten::view([1, 1, 4096]) = [1, 1, 32, 128]
  5. query_states = v1.transpose(1, 2): aten::transpose([1, 1, 32, 128]) = [1, 32, 1, 128]
  6. v2 = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim): aten::view([1, 1, 4096]) = [1, 1, 32, 128]
  7. key_states = v2.transpose(1, 2): aten::transpose([1, 1, 32, 128]) = [1, 32, 1, 128]
  8. v3 = value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim): aten::view([1, 1, 4096]) = [1, 1, 32, 128]
  9. value_states = v3.transpose(1, 2): aten::transpose([1, 1, 32, 128]) = [1, 32, 1, 128]
  10. cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len): 见 LlamaRotaryEmebdding
  11. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids): 见 apply_rotary_pos_emb
  12. key_states = torch.cat([past_key_value[0], key_states], dim=2): aten::cat([1, 32, C-1, 128], [1, 32, 1, 128]) = [1, 32, C, 128]
  13. value_states = torch.cat([past_key_value[1], value_states], dim=2): aten::cat([1, 32, C-1, 128], [1, 32, 1, 128]) = [1, 32, C, 128]
  14. v4 = key_states.transpose(2, 3): aten::transpose([1, 32, C, 128]) = [1, 32, 128, C]
  15. v5 = torch.matmul(query_states, v4): aten::matmul([1, 32, 1, 128], [1, 32, 128, C]) = [1, 32, 1, C], 8192*C FLOP
  16. attn_weights = v1 / math.sqrt(self.head_dim): aten::div([1, 32, 1, C]) = [1, 32, 1, C]
  17. attn_weights = attn_weights + attention_mask: aten::add([1, 32, 1, C], [1, 1, 1, C]) = [1, 32, 1, C], 32*C FLOP
  18. v6 = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype): aten::softmax([1, 32, 1, C]) = [1, 32, 1, C]
  19. attn_weights = v6.to(query_states.dtype): aten::to([1, 32, 1, C]) = [1, 32, 1, C]
  20. attn_output = torch.matmul(attn_weights, value_states): aten::matmul([1, 32, 1, C], [1, 32, C, 128]) = [1, 32, 1, 128]: 8192*C FLOP
  21. attn_output = attn_output.transpose(1, 2).contiguous(): aten::transpose([1, 32, 1, 128]) = [1, 1, 32, 128]
  22. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size): aten::reshape([1, 1, 32, 128]) = [1, 1, 4096]
  23. attn_output = self.o_proj(attn_output): aten::linear([1, 1, 4096], [4096, 4096]) = [1, 1, 4096], 33554432 FLOP

LlamaRotaryEmbedding 包括:

  1. self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),: aten::slice([1, 1, 4096, 128])
  2. self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),: aten::slice([1, 1, 4096, 128])

apply_rotary_pos_emb 包括:

  1. v1 = cos.squeeze(1): aten::squeeze([1, 1, 10, 128])
  2. cos = v1.squeeze(0): aten::squeeze([1, 10, 128])
  3. v2 = cos.squeeze(1): aten::squeeze([1, 1, 10, 128])
  4. sin = v2.squeeze(0): aten::squeeze([1, 10, 128])
  5. v3 = cos[position_ids]: aten::index([10, 128])
  6. cos = v3.unsqueeze(1): aten::unsqueeze([1, 1, 128]) = [1, 1, 1, 128]
  7. v4 = sin[position_ids]: aten::index([10, 128])
  8. sin = v4.unsqueeze(1): aten::unsqueeze([1, 1, 128]) = [1, 1, 1, 128]
  9. v5 = q * cos: aten::mul([1, 32, 1, 128], [1, 1, 1, 128]) = [1, 32, 1, 128], 4096 FLOP
  10. v6 = rotate_half(q): 见 rotate_half
  11. v7 = v6 * sin: aten::mul([1, 32, 1, 128], [1, 1, 1, 128]) = [1, 32, 1, 128], 4096 FLOP
  12. q_embed = v5 + v7: aten::add([1, 32, 1, 128], [1, 32, 1, 128]) = [1, 32, 1, 128], 4096 FLOP
  13. v8 = k * cos: aten::mul([1, 32, 1, 128], [1, 1, 1, 128]) = [1, 32, 1, 128], 4096 FLOP
  14. v9 = rotate_half(k): 见 rotate_half
  15. v10 = v9 * sin: aten::mul([1, 32, 1, 128], [1, 1, 1, 128]) = [1, 32, 1, 128], 4096 FLOP
  16. k_embed = v8 + v10: aten::add([1, 32, 1, 128], [1, 32, 1, 128]) = [1, 32, 1, 128], 4096 FLOP

rotate_half 包括:

  1. x1 = x[..., : x.shape[-1] // 2]: aten::slice([1, 3, 1, 128])
  2. x2 = x[..., x.shape[-1] // 2 :]: aten::slice([1, 3, 1, 128])
  3. v1 = -x2: aten::neg([1, 3, 1, 64])
  4. return torch.cat((v1, x1), dim=-1): aten::cat([1, 3, 1, 64], [1, 3, 1, 64]) = [1, 3, 1, 128]

LlamaMLP 包括:

  1. v1 = self.gate_proj(x): aten::linear([1, 1, 4096], [11008, 4096]) = [1, 1, 11008], 90177536 FLOP
  2. v2 = self.act_fn(v1): aten::silu([1, 1, 11008]) = [1, 1, 11008]
  3. v3 = self.up_proj(x): aten::linear([1, 1, 4096], [11008, 4096]) = [1, 1, 11008], 90177536 FLOP
  4. v4 = v2 * v3: aten::mul([1, 1, 11008], [1, 1, 11008]) = [1, 1, 11008], 11008 FLOP
  5. v5 = self.down_proj(v4): aten::linear([1, 1, 11008], [4096, 11008]) = [1, 1, 4096], 90177536 FLOP

最后一个 LlamaDecoderLayer 的输出会经过 lm_head 得到 logits:

  1. logits = self.lm_head(hidden_states): aten::linear([1, 1, 4096], [32000, 4096]) = [1, 1, 32000], 262144000 FLOP

所有 aten 算子:

  • aten::add
  • aten::cat
  • aten::div
  • aten::index
  • aten::linear
  • aten::matmul
  • aten::mean
  • aten::mul
  • aten::neg
  • aten::pow
  • aten::reshape
  • aten::rsqrt
  • aten::sequeeze
  • aten::silu
  • aten::slice
  • aten::softmax
  • aten::to
  • aten::transpose
  • aten::unsqueeze
  • aten::view

FlashAttention

论文:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

代码:Dao-AILab/flash-attention

推荐阅读:FlashAttention 核心逻辑以及 V1 V2 差异总结

主要针对 Transformer 训练中的 Attention 计算进行优化。Attention 计算的是:

S=QKT,P=softmax(S),O=PV

其中 Q、K 和 V 矩阵规模都是 Nd,S 和 P 矩阵规模是 NN,O 矩阵规模是 Nd,N 是序列长度,d 是 head dimension。实际上 S 矩阵还会除以一个系数,但是这个优化起来比较简单,就忽略了。朴素算法就是这三个步骤分别算,但是就比较慢。核心问题是 softmax 操作。softmax 的定义是:

softmax(x)=exiΣjexj

分子是对应项的 exp,分母是所有项的 exp 之和。所以 softmax 的按定义算的方法就是:所有元素求 exp,然后求 exp 的和,再集体做除法。

但是这么算会有一个问题:元素的 exp 可能会很大,导致精度比较差。因此实际计算的时候,会先计算出元素的最大值 xm,然后分子分母同时除以最大值的 exp:

softmax(x)=exixmΣjexjxm

这样计算的精度会比较好,因为 exp 的值都在 0 和 1 之间。这也意味着 softmax 的计算需要先求 max,然后每个元素减去 max 后,求 exp,再求元素的 exp 之和,最后每个元素再除以 exp 之和。准确地说,因为这里的输入是矩阵,所以是对矩阵的每一行分别算 softmax。

但是这样的求法意味着需要先把完整的 softmax 输入求出来,使得 tiling 变得困难。为了 tiling,FlashAttention 采用了一种分块的 softmax 计算方法:首先对每一块分别做 softmax,但是按照定义,max 和 sum 都应该是完整向量的 max 和 sum,而如果分块去计算 softmax,此时的 max 和 sum 是块内的,因此需要进行后处理,把分块的 softmax 纠正成正确的 softmax:

假如有两个块分别做了 softmax,结果是两个向量 x1x2,计算时的 max 记为 m1m2,分子记为 f(x1)f(x2),分母记为 l(x1)l(x2)。那么合并后的向量的 max 就是 m=max(m1,m2),接下来计算出完整的 softmax:分子来自于两个向量,但是需要把 m1 m2m 的差异补上,那就是 em1mf(x1)em2mf(x2) 两部分拼接起来;分母则是两个向量的部分,把 max 差异纠正后再求和:em1ml(x1)+em2ml(x2)。论文中是这么表示的:

分块 softmax(来源:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness)

因此计算 softmax 的时候可以先分块进行,最后再把结果纠正过来。能分块了以后,就可以做 kernel fusion,和矩阵乘法、Masking、Dropout 合并起来,让中间结果在 GPU 内部完成,而不是先写到显存里再读回来。

论文的 Algorithm 1 和 2 给出了融合后的 Kernel 的伪代码。思路是:

  1. 对 Q、K 和 V 分块
  2. 对于每个块,计算当前块的 QKT 结果,然后计算 rowmax(每行计算出一个最大值),每个元素减去 rowmax 再 exp,得到 softmax 函数的分子;再求和,得到 softmax 函数的分母
  3. 把刚算出来的 softmax 结果和之前计算的 softmax 结果合并,得到一组新的 softmax 系数
  4. 更新当前的输出矩阵:把旧的分母乘回来,就得到旧的分子,把旧的分子乘上 exp 系数,再加上新的分子,结果再除以新的分子;这一步和迭代更新平均值很像:旧的平均值乘以旧的元素个数,加上新的元素再除以新的元素个数

然后在中间穿插 dropout 和 masking 等细节,就得到了最终的实现。对这个过程讲的比较清楚的是下面这个图:

Flash Attention(来源:FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning)

右下角 Rescaling to correct denominator 就是上面说的,把旧的分母乘回来,再计算出新的值。标量乘矩阵在上面的式子里表示成了对角矩阵和矩阵的乘法。

不过,这份伪代码距离实际的 CUDA 实现,还有很多细节上的优化。此外,论文还对梯度反向传播进行了优化,毕竟是服务于训练:思路是,前向计算的时候,因为 kernel fusion 的原因,跳过了中间结果的运算,所以反向传播的时候,就重新计算一下 attention,再求梯度。

FlashAttention 2

论文:FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

代码:Dao-AILab/flash-attention

FlashAttention 2 相比 FlashAttention 的主要区别是并行的方式。第一个版本是在 batch 和 head 维度上进行并行,也就是说,每个 CUDA thread block 对应一个 batch size 和一个 attention head,一共有 batch size 乘以 head 个数那么多个 thread block。

而第二个版本在 sequence length 维度上也引入了并行,使得 GPU 的利用率可以继续提升。此外,在计算局部 softmax 的时候,也做了修改:不着急计算局部的 softmax,而是分别维护分子和分母,到算完了以后,再算分子除以坟墓。

FlashDecoding

博客:Flash-Decoding for long-context inference

Flash Decoding 是针对长上下文场景下的推理:KV cache 的读取变成了一个瓶颈。所以 Flash-Decoding 的思路就是,并行读取 KV-cache,并且在 sequence length 维度上并行 attention 计算。当然了,并行了以后,就要拆成多块分别求 softmax,也需要 Flash Attention 的合并方法来保证最终 softmax 结果的正确性。

FlashDecoding++

论文:FLASHDECODING++: FASTER LARGE LANGUAGE MODEL INFERENCE ON GPUS

这篇论文也是针对 transformer 推理的优化,主要的优化点:

  1. Flash Attention 论文解决了 softmax 的分块计算问题,但是每次需要计算一个 max;FlashDecoding++ 选择根据数据的分布去估计一个 max,之后再纠正,可以减少一些同步
  2. 针对 batch 维度小的矩阵乘法进行优化

参考文献

Comments