注意力机制
注意力机制是 Transformer 的核心,允许模型在处理每个位置时关注输入序列的不同部分。
Scaled Dot-Product Attention
基本原理
python
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
return output, attention_weights计算流程
- 计算相似度: Q 和 K 的点积
- 缩放: 除以 √d_k 防止梯度消失
- 掩码: 可选的掩码操作
- 归一化: softmax 归一化
- 加权求和: 乘以 V 得到输出
Multi-Head Attention
多头注意力机制
python
class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = torch.nn.Linear(d_model, d_model)
self.W_k = torch.nn.Linear(d_model, d_model)
self.W_v = torch.nn.Linear(d_model, d_model)
self.W_o = torch.nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 线性变换
Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 注意力计算
output, _ = scaled_dot_product_attention(Q, K, V, mask)
# 拼接多头输出
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.W_o(output)
return output多头的优势
- 不同子空间: 每个头学习不同的注意力模式
- 丰富表示: 多头可以捕捉多种类型的依赖关系
- 并行计算: 多个头可以并行计算
注意力变体
自注意力
python
# 自注意力:Q=K=V
output = self_attn(x, x, x)交叉注意力
python
# 交叉注意力:Q来自解码器,K和V来自编码器
output = cross_attn(decoder_output, encoder_output, encoder_output)因果注意力
python
# 因果注意力:防止未来位置的信息泄露
mask = torch.tril(torch.ones(n, n))
output = masked_attn(x, x, x, mask)注意力可视化
可视化注意力权重
python
import matplotlib.pyplot as plt
# 获取注意力权重
_, attn_weights = scaled_dot_product_attention(Q, K, V)
# 可视化
plt.figure(figsize=(10, 10))
plt.imshow(attn_weights[0, 0].detach().cpu().numpy(), cmap='viridis')
plt.xlabel('Key positions')
plt.ylabel('Query positions')
plt.title('Attention Weights')
plt.colorbar()
plt.show()高效注意力机制
Flash Attention
Flash Attention 通过分块计算和重计算来优化内存使用。
python
from flash_attn import flash_attn_qkvpacked_func
# 使用 Flash Attention
output = flash_attn_qkvpacked_func(qkv, causal=True)Linear Attention
Linear Attention 将复杂度从 O(n²) 降低到 O(n)。
python
def linear_attention(Q, K, V):
# 简化版本
K = K.softmax(dim=-2)
output = torch.matmul(Q, K.transpose(-2, -1))
output = torch.matmul(output, V)
return output应用场景
机器翻译
- 编码器处理源语言
- 解码器生成目标语言
- 交叉注意力连接两者
文本摘要
- 编码器处理输入文本
- 解码器生成摘要
- 注意力机制选择重要信息
问答系统
- 编码器处理上下文和问题
- 注意力机制定位答案位置