Transformer
下面只考虑 Decoder-only 的 Transformer。
计算过程
输入 \(n\) 个 Token,通过 Text 和 Position Embedding 后,得到一个尺寸为 \((n, d_{model})\) 的矩阵,其中 \(d_{model}\) 为 Embedding 向量的长度。
接着,经过 \(n_{layer}\) 个层,每层中要进行如下的计算:
- 从上一层的输出 \((n, d_{model})\),每个 head 上乘以三个 \((d_{model}, d_{head})\) 的矩阵,得到 Q,K 和 V,尺寸为 \((n, d_{head})\)
- 按照 attention 激活公式,得到 Z,尺寸为 \((n, d_{head})\)
- 把所有 head 的 Z 矩阵拼起来,保证 \(d_{model} = n_{head} * d_{head}\),那么所有的 Z 拼起来以后得到的矩阵的尺寸为 \((n, d_{head})\)
- 乘以一个尺寸为 \((d_{model}, d_{model})\) 的 Projection 矩阵,得到新矩阵 \((n, d_{model})\)
- 经过一个 MLP,MLP 第一层是 \((d_{model}, 4*d_{model})\),第二层是 \((4*d_{model}, d_{model})\)
- MLP 输出的矩阵尺寸为 \((n, d_{model})\)
KV cache
在推理的时候,是在已有的 context 的基础上,生成一个新 token,再把 token 加到 context,继续生成下一个 token。计算的时候,由于 Attention 会带 Mask,旧 token 不会依赖新 token,因此旧 token 的部分不会变,可以只考虑新引入的 token 带来的变化。
那么,计算新的 token 的 Q K V 以后,在进行 Attention 计算的时候,会发现先前的 token 的 K 和 V 部分不变,先前 token 的 Q 不影响当前 token 的结果。因此可以把之前 token 的 K 和 V 保存下来,不用重新计算,这就是 KV cache。
参数量和浮点计算量
考虑每层的参数量和浮点计算量:
- 计算 Q,K 和 V:参数是三个 \((d_{model}, d_{model})\) 的矩阵,计算量是 \(2 * 3 * n * d_{model}^2\)
- 乘以 Projection 矩阵:参数是 \((d_{model}, d_{model})\) 的矩阵,计算量是 \(2 * n * d_{model}^2\)
- 乘以 MLP 第一层:参数是 \((d_{model}, 4*d_{model})\) 的矩阵,计算量是 \(2 * 4 * n * d_{model}^2\)
- 乘以 MLP 第二层:参数是 \((4*d_{model}, d_{model})\) 的矩阵,计算量是 \(2 * 4 * n * d_{model}^2\)
其余部分的参数量和计算量可以忽略不计。
因此总参数量(不考虑 Embedding)为:
每个 Token 的浮点计算量为:
Llama 2 7B
以 Llama 2 7B 为例,下面分析 Transformer 推理的计算过程,它的参数如下:
- hidden size: 4096
- intermediate size(MLP 的中间层的维度): 11008
- hidden layers: 32
- attention heads: 32
- key value heads: 32
- head dim: 4096 / 32 = 128
- vocab size: 32000
参考 HuggingFace 源码。
Llama2 的主要计算过程是 32 层 LlamaDecoderLayer,每个 LlamaDecoderLayer 包括:
hidden_states = self.input_layernorm(hidden_states)
: 见 LlamaRMSNormhidden_states, self_attn_weights, present_key_value = self.self_attn()
:见 LlamaAttentionhidden_states = residual + hidden_states
:aten::add([1, 1, 4096], [1, 1, 4096]) = [1, 1, 4096]
hidden_states = self.post_attention_layernorm(hidden_states)
: 见 LlamaRMSNormhidden_states = self.mlp(hidden_states)
:见 LlamaMLPhidden_states = residual + hidden_states
:aten::add([1, 1, 4096], [1, 1, 4096]) = [1, 1, 4096]
hidden_states 的规模是 [1, 1, 4096]
。
LlamaRMSNorm 包括:
hidden_states = hidden_states.to(torch.float32)
:aten::to([1, 1, 4096]) = [1, 1, 4096]
v1 = hidden_states.pow(2)
:aten::pow([1, 1, 4096]) = [1, 1, 4096]
variance = v2.mean(-1, keepdim=True)
:aten::mean([1, 1, 4096]) = [1, 1, 1]
v3 = variance + self.variance_epsilon
:aten::add([1, 1, 1]) = [1, 1, 1]
, 1 FLOPv4 = torch.rsqrt(v3)
:aten::rsqrt([1, 1, 1]) = [1, 1, 1]
hidden_states = hidden_states * v4
:aten::mul([1, 1, 4096], [1, 1, 1]) = [1, 1, 4096]
, 4096 FLOPv5 = hidden_states.to(input_dtype)
:aten::to([1, 1, 4096]) = [1, 1, 4096]
return self.weight * v5
:aten::mul([4096], [1, 1, 4096])
, 4096 FLOP
LlamaAttention 包括:
query_states = self.q_proj(hidden_states)
:aten::linear([1, 1, 4096], [4096, 4096]) = [1, 1, 4096]
, 33554432 FLOPkey_states = self.k_proj(hidden_states)
:aten::linear([1, 1, 4096], [4096, 4096]) = [1, 1, 4096]
, 33554432 FLOPvalue_states = self.v_proj(hidden_states)
:aten::linear([1, 1, 4096], [4096, 4096]) = [1, 1, 4096]
, 33554432 FLOPv1 = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
:aten::view([1, 1, 4096]) = [1, 1, 32, 128]
query_states = v1.transpose(1, 2)
:aten::transpose([1, 1, 32, 128]) = [1, 32, 1, 128]
v2 = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
:aten::view([1, 1, 4096]) = [1, 1, 32, 128]
key_states = v2.transpose(1, 2)
:aten::transpose([1, 1, 32, 128]) = [1, 32, 1, 128]
v3 = value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
:aten::view([1, 1, 4096]) = [1, 1, 32, 128]
value_states = v3.transpose(1, 2)
:aten::transpose([1, 1, 32, 128]) = [1, 32, 1, 128]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
: 见 LlamaRotaryEmebddingquery_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
: 见 apply_rotary_pos_embkey_states = torch.cat([past_key_value[0], key_states], dim=2)
:aten::cat([1, 32, C-1, 128], [1, 32, 1, 128]) = [1, 32, C, 128]
value_states = torch.cat([past_key_value[1], value_states], dim=2)
:aten::cat([1, 32, C-1, 128], [1, 32, 1, 128]) = [1, 32, C, 128]
v4 = key_states.transpose(2, 3)
:aten::transpose([1, 32, C, 128]) = [1, 32, 128, C]
v5 = torch.matmul(query_states, v4)
:aten::matmul([1, 32, 1, 128], [1, 32, 128, C]) = [1, 32, 1, C]
, 8192*C FLOPattn_weights = v1 / math.sqrt(self.head_dim)
:aten::div([1, 32, 1, C]) = [1, 32, 1, C]
attn_weights = attn_weights + attention_mask
:aten::add([1, 32, 1, C], [1, 1, 1, C]) = [1, 32, 1, C]
, 32*C FLOPv6 = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
:aten::softmax([1, 32, 1, C]) = [1, 32, 1, C]
attn_weights = v6.to(query_states.dtype)
:aten::to([1, 32, 1, C]) = [1, 32, 1, C]
attn_output = torch.matmul(attn_weights, value_states)
:aten::matmul([1, 32, 1, C], [1, 32, C, 128]) = [1, 32, 1, 128]
: 8192*C FLOPattn_output = attn_output.transpose(1, 2).contiguous()
:aten::transpose([1, 32, 1, 128]) = [1, 1, 32, 128]
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
:aten::reshape([1, 1, 32, 128]) = [1, 1, 4096]
attn_output = self.o_proj(attn_output)
:aten::linear([1, 1, 4096], [4096, 4096]) = [1, 1, 4096]
, 33554432 FLOP
LlamaRotaryEmbedding 包括:
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
:aten::slice([1, 1, 4096, 128])
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
:aten::slice([1, 1, 4096, 128])
apply_rotary_pos_emb 包括:
v1 = cos.squeeze(1)
:aten::squeeze([1, 1, 10, 128])
cos = v1.squeeze(0)
:aten::squeeze([1, 10, 128])
v2 = cos.squeeze(1)
:aten::squeeze([1, 1, 10, 128])
sin = v2.squeeze(0)
:aten::squeeze([1, 10, 128])
v3 = cos[position_ids]
:aten::index([10, 128])
cos = v3.unsqueeze(1)
:aten::unsqueeze([1, 1, 128]) = [1, 1, 1, 128]
v4 = sin[position_ids]
:aten::index([10, 128])
sin = v4.unsqueeze(1)
:aten::unsqueeze([1, 1, 128]) = [1, 1, 1, 128]
v5 = q * cos
:aten::mul([1, 32, 1, 128], [1, 1, 1, 128]) = [1, 32, 1, 128]
, 4096 FLOPv6 = rotate_half(q)
: 见 rotate_halfv7 = v6 * sin
:aten::mul([1, 32, 1, 128], [1, 1, 1, 128]) = [1, 32, 1, 128]
, 4096 FLOPq_embed = v5 + v7
:aten::add([1, 32, 1, 128], [1, 32, 1, 128]) = [1, 32, 1, 128]
, 4096 FLOPv8 = k * cos
:aten::mul([1, 32, 1, 128], [1, 1, 1, 128]) = [1, 32, 1, 128]
, 4096 FLOPv9 = rotate_half(k)
: 见 rotate_halfv10 = v9 * sin
:aten::mul([1, 32, 1, 128], [1, 1, 1, 128]) = [1, 32, 1, 128]
, 4096 FLOPk_embed = v8 + v10
:aten::add([1, 32, 1, 128], [1, 32, 1, 128]) = [1, 32, 1, 128]
, 4096 FLOP
rotate_half 包括:
x1 = x[..., : x.shape[-1] // 2]
:aten::slice([1, 3, 1, 128])
x2 = x[..., x.shape[-1] // 2 :]
:aten::slice([1, 3, 1, 128])
v1 = -x2
:aten::neg([1, 3, 1, 64])
return torch.cat((v1, x1), dim=-1)
:aten::cat([1, 3, 1, 64], [1, 3, 1, 64]) = [1, 3, 1, 128]
LlamaMLP 包括:
v1 = self.gate_proj(x)
:aten::linear([1, 1, 4096], [11008, 4096]) = [1, 1, 11008]
, 90177536 FLOPv2 = self.act_fn(v1)
:aten::silu([1, 1, 11008]) = [1, 1, 11008]
v3 = self.up_proj(x)
:aten::linear([1, 1, 4096], [11008, 4096]) = [1, 1, 11008]
, 90177536 FLOPv4 = v2 * v3
:aten::mul([1, 1, 11008], [1, 1, 11008]) = [1, 1, 11008]
, 11008 FLOPv5 = self.down_proj(v4)
:aten::linear([1, 1, 11008], [4096, 11008]) = [1, 1, 4096]
, 90177536 FLOP
最后一个 LlamaDecoderLayer 的输出会经过 lm_head 得到 logits:
logits = self.lm_head(hidden_states)
:aten::linear([1, 1, 4096], [32000, 4096]) = [1, 1, 32000]
, 262144000 FLOP
所有 aten 算子:
- aten::add
- aten::cat
- aten::div
- aten::index
- aten::linear
- aten::matmul
- aten::mean
- aten::mul
- aten::neg
- aten::pow
- aten::reshape
- aten::rsqrt
- aten::sequeeze
- aten::silu
- aten::slice
- aten::softmax
- aten::to
- aten::transpose
- aten::unsqueeze
- aten::view
FlashAttention
论文:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
推荐阅读:FlashAttention 核心逻辑以及 V1 V2 差异总结
主要针对 Transformer 训练中的 Attention 计算进行优化。Attention 计算的是:
\(S = QK^T, P = \mathrm{softmax}(S), O = PV\)
其中 Q、K 和 V 矩阵规模都是 \(N*d\),S 和 P 矩阵规模是 \(N*N\),O 矩阵规模是 \(N*d\),N 是序列长度,d 是 head dimension。实际上 \(S\) 矩阵还会除以一个系数,但是这个优化起来比较简单,就忽略了。朴素算法就是这三个步骤分别算,但是就比较慢。核心问题是 softmax 操作。softmax 的定义是:
\(\mathrm{softmax}(x) = \frac{e^{x_i}}{\Sigma_j e^{x_j}}\)
分子是对应项的 exp,分母是所有项的 exp 之和。所以 softmax 的按定义算的方法就是:所有元素求 exp,然后求 exp 的和,再集体做除法。
但是这么算会有一个问题:元素的 exp 可能会很大,导致精度比较差。因此实际计算的时候,会先计算出元素的最大值 \(x_m\),然后分子分母同时除以最大值的 exp:
\(\mathrm{softmax}(x) = \frac{e^{x_i-x_m}}{\Sigma_j e^{x_j-x_m}}\)
这样计算的精度会比较好,因为 exp 的值都在 0 和 1 之间。这也意味着 softmax 的计算需要先求 max,然后每个元素减去 max 后,求 exp,再求元素的 exp 之和,最后每个元素再除以 exp 之和。准确地说,因为这里的输入是矩阵,所以是对矩阵的每一行分别算 softmax。
但是这样的求法意味着需要先把完整的 softmax 输入求出来,使得 tiling 变得困难。为了 tiling,FlashAttention 采用了一种分块的 softmax 计算方法:首先对每一块分别做 softmax,但是按照定义,max 和 sum 都应该是完整向量的 max 和 sum,而如果分块去计算 softmax,此时的 max 和 sum 是块内的,因此需要进行后处理,把分块的 softmax 纠正成正确的 softmax:
假如有两个块分别做了 softmax,结果是两个向量 \(x_1\) 和 \(x_2\),计算时的 max 记为 \(m_1\) 和 \(m_2\),分子记为 \(f(x_1)\) 和 \(f(x_2)\),分母记为 \(l(x_1)\) 和 \(l(x_2)\)。那么合并后的向量的 max 就是 \(m=\max(m_1, m_2)\),接下来计算出完整的 softmax:分子来自于两个向量,但是需要把 \(m_1\) \(m_2\) 和 \(m\) 的差异补上,那就是 \(e^{m_1-m}f(x_1)\) 和 \(e^{m_2-m}f(x_2)\) 两部分拼接起来;分母则是两个向量的部分,把 max 差异纠正后再求和:\(e^{m_1-m}l(x_1)+e^{m_2-m}l(x_2)\)。论文中是这么表示的:
因此计算 softmax 的时候可以先分块进行,最后再把结果纠正过来。能分块了以后,就可以做 kernel fusion,和矩阵乘法、Masking、Dropout 合并起来,让中间结果在 GPU 内部完成,而不是先写到显存里再读回来。
论文的 Algorithm 1 和 2 给出了融合后的 Kernel 的伪代码。思路是:
- 对 Q、K 和 V 分块
- 对于每个块,计算当前块的 \(QK^T\) 结果,然后计算 rowmax(每行计算出一个最大值),每个元素减去 rowmax 再 exp,得到 softmax 函数的分子;再求和,得到 softmax 函数的分母
- 把刚算出来的 softmax 结果和之前计算的 softmax 结果合并,得到一组新的 softmax 系数
- 更新当前的输出矩阵:把旧的分母乘回来,就得到旧的分子,把旧的分子乘上 exp 系数,再加上新的分子,结果再除以新的分子;这一步和迭代更新平均值很像:旧的平均值乘以旧的元素个数,加上新的元素再除以新的元素个数
然后在中间穿插 dropout 和 masking 等细节,就得到了最终的实现。对这个过程讲的比较清楚的是下面这个图:
右下角 Rescaling to correct denominator 就是上面说的,把旧的分母乘回来,再计算出新的值。标量乘矩阵在上面的式子里表示成了对角矩阵和矩阵的乘法。
不过,这份伪代码距离实际的 CUDA 实现,还有很多细节上的优化。此外,论文还对梯度反向传播进行了优化,毕竟是服务于训练:思路是,前向计算的时候,因为 kernel fusion 的原因,跳过了中间结果的运算,所以反向传播的时候,就重新计算一下 attention,再求梯度。
FlashAttention 2
论文:FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
FlashAttention 2 相比 FlashAttention 的主要区别是并行的方式。第一个版本是在 batch 和 head 维度上进行并行,也就是说,每个 CUDA thread block 对应一个 batch size 和一个 attention head,一共有 batch size 乘以 head 个数那么多个 thread block。
而第二个版本在 sequence length 维度上也引入了并行,使得 GPU 的利用率可以继续提升。此外,在计算局部 softmax 的时候,也做了修改:不着急计算局部的 softmax,而是分别维护分子和分母,到算完了以后,再算分子除以坟墓。
FlashDecoding
博客:Flash-Decoding for long-context inference
Flash Decoding 是针对长上下文场景下的推理:KV cache 的读取变成了一个瓶颈。所以 Flash-Decoding 的思路就是,并行读取 KV-cache,并且在 sequence length 维度上并行 attention 计算。当然了,并行了以后,就要拆成多块分别求 softmax,也需要 Flash Attention 的合并方法来保证最终 softmax 结果的正确性。
FlashDecoding++
论文:FLASHDECODING++: FASTER LARGE LANGUAGE MODEL INFERENCE ON GPUS
这篇论文也是针对 transformer 推理的优化,主要的优化点:
- Flash Attention 论文解决了 softmax 的分块计算问题,但是每次需要计算一个 max;FlashDecoding++ 选择根据数据的分布去估计一个 max,之后再纠正,可以减少一些同步
- 针对 batch 维度小的矩阵乘法进行优化
参考文献
- Speeding up the GPT - KV cache
- 最新最全 GPT-3 模型网络结构详细解析
- Transformer Inference Arithmetic
- Understanding FLOPs-per-token estimates from OpenAI’s scaling laws
- 拆 Transformer 系列二:Multi-Head Attention 机制详解
- Past_key_values - why not past_key_values_queries?