Skip to content

注意力机制

注意力机制是 Transformer 的核心,允许模型在处理每个位置时关注输入序列的不同部分。

Scaled Dot-Product Attention

基本原理

python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

计算流程

  1. 计算相似度: Q 和 K 的点积
  2. 缩放: 除以 √d_k 防止梯度消失
  3. 掩码: 可选的掩码操作
  4. 归一化: softmax 归一化
  5. 加权求和: 乘以 V 得到输出

Multi-Head Attention

多头注意力机制

python
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        # 线性变换
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 注意力计算
        output, _ = scaled_dot_product_attention(Q, K, V, mask)
        
        # 拼接多头输出
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(output)
        
        return output

多头的优势

  • 不同子空间: 每个头学习不同的注意力模式
  • 丰富表示: 多头可以捕捉多种类型的依赖关系
  • 并行计算: 多个头可以并行计算

注意力变体

自注意力

python
# 自注意力:Q=K=V
output = self_attn(x, x, x)

交叉注意力

python
# 交叉注意力:Q来自解码器,K和V来自编码器
output = cross_attn(decoder_output, encoder_output, encoder_output)

因果注意力

python
# 因果注意力:防止未来位置的信息泄露
mask = torch.tril(torch.ones(n, n))
output = masked_attn(x, x, x, mask)

注意力可视化

可视化注意力权重

python
import matplotlib.pyplot as plt

# 获取注意力权重
_, attn_weights = scaled_dot_product_attention(Q, K, V)

# 可视化
plt.figure(figsize=(10, 10))
plt.imshow(attn_weights[0, 0].detach().cpu().numpy(), cmap='viridis')
plt.xlabel('Key positions')
plt.ylabel('Query positions')
plt.title('Attention Weights')
plt.colorbar()
plt.show()

高效注意力机制

Flash Attention

Flash Attention 通过分块计算和重计算来优化内存使用。

python
from flash_attn import flash_attn_qkvpacked_func

# 使用 Flash Attention
output = flash_attn_qkvpacked_func(qkv, causal=True)

Linear Attention

Linear Attention 将复杂度从 O(n²) 降低到 O(n)。

python
def linear_attention(Q, K, V):
    # 简化版本
    K = K.softmax(dim=-2)
    output = torch.matmul(Q, K.transpose(-2, -1))
    output = torch.matmul(output, V)
    return output

应用场景

机器翻译

  • 编码器处理源语言
  • 解码器生成目标语言
  • 交叉注意力连接两者

文本摘要

  • 编码器处理输入文本
  • 解码器生成摘要
  • 注意力机制选择重要信息

问答系统

  • 编码器处理上下文和问题
  • 注意力机制定位答案位置