Skip to content

Encoder-Decoder

Encoder-Decoder 架构是 Transformer 的核心组成部分,用于处理机器翻译、文本摘要等序列到序列(seq2seq)任务。

架构组成

Encoder

Encoder 负责处理输入序列并提取其特征表示:

python
class Encoder(torch.nn.Module):
    """Transformer encoder: embedding, positional encoding, then a layer stack.

    Args:
        vocab_size: Number of entries in the embedding table.
        d_model: Embedding width; also used to scale the embeddings.
        num_heads: Forwarded to every ``EncoderLayer``.
        d_ff: Forwarded to every ``EncoderLayer``.
        num_layers: Number of ``EncoderLayer`` instances in the stack.
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        # forward() scales by sqrt(d_model), so the value has to live on the
        # instance; without this assignment forward() raises AttributeError.
        self.d_model = d_model
        self.embedding = torch.nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.layers = torch.nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids looked up in the embedding table.
            mask: Optional mask handed unchanged to every encoder layer.

        Returns:
            The output of the last encoder layer.
        """
        # Embedding lookup, scaled by sqrt(d_model)
        x = self.embedding(x) * self.d_model ** 0.5

        # Positional encoding
        x = self.positional_encoding(x)

        # Encoder layers, applied in order
        for layer in self.layers:
            x = layer(x, mask)

        return x

Decoder

Decoder 负责在 Encoder 输出的基础上生成输出序列:

python
class Decoder(torch.nn.Module):
    """Transformer decoder: embedding, positional encoding, layer stack, projection.

    Args:
        vocab_size: Size of the embedding table and of the output projection.
        d_model: Embedding width; also used to scale the embeddings.
        num_heads: Forwarded to every ``DecoderLayer``.
        d_ff: Forwarded to every ``DecoderLayer``.
        num_layers: Number of ``DecoderLayer`` instances in the stack.
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        # forward() scales by sqrt(d_model), so the value has to live on the
        # instance; without this assignment forward() raises AttributeError.
        self.d_model = d_model
        self.embedding = torch.nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.layers = torch.nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        self.fc = torch.nn.Linear(d_model, vocab_size)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Decode target token ids against the encoder output.

        Args:
            x: Integer tensor of target token ids.
            enc_output: Encoder output, handed unchanged to every decoder layer.
            src_mask: Optional mask handed unchanged to every decoder layer.
            tgt_mask: Optional mask handed unchanged to every decoder layer.

        Returns:
            The result of the final linear projection, whose last dimension
            has ``vocab_size`` entries.
        """
        # Embedding lookup, scaled by sqrt(d_model)
        x = self.embedding(x) * self.d_model ** 0.5

        # Positional encoding
        x = self.positional_encoding(x)

        # Decoder layers, applied in order
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)

        # Project to vocabulary size
        output = self.fc(x)

        return output

完整模型

python
class Transformer(torch.nn.Module):
    """Sequence-to-sequence model that feeds an ``Encoder`` into a ``Decoder``.

    Both halves are built with the same ``d_model``, ``num_heads``, ``d_ff``
    and ``num_layers``; only the vocabulary sizes differ.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        shared_args = (d_model, num_heads, d_ff, num_layers)
        self.encoder = Encoder(src_vocab_size, *shared_args)
        self.decoder = Decoder(tgt_vocab_size, *shared_args)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Run the encoder on ``src`` and return the decoder's output for ``tgt``."""
        memory = self.encoder(src, src_mask)
        return self.decoder(tgt, memory, src_mask, tgt_mask)

训练过程

数据准备

python
# Toy batches of token ids (one sequence of five ids each)
src = torch.tensor([[1, 2, 3, 4, 5]])  # source language
tgt = torch.tensor([[1, 6, 7, 8, 9]])  # target language

# Masks: none for the source; a square lower-triangular matrix of ones
# (row i is 1 for columns <= i) for the target.
src_mask = None
tgt_mask = torch.ones(tgt.shape[1], tgt.shape[1]).tril()

损失计算

python
# Teacher forcing with a one-step shift: the decoder reads tgt[:, :-1] and is
# trained to predict tgt[:, 1:]. Using the unshifted tgt as both input and
# label would let position i read token i (the mask's diagonal is set), which
# is exactly the label it is asked to predict.
dec_input = tgt[:, :-1]
labels = tgt[:, 1:]
# Crop the square target mask to the shortened decoder input.
step_mask = None if tgt_mask is None else tgt_mask[:dec_input.size(1), :dec_input.size(1)]

# Forward pass (clear gradients left over from any previous step first)
optimizer.zero_grad()
output = model(src, dec_input, src_mask, step_mask)

# Loss: flatten to (positions, classes); the class count is read from the
# output's last dimension instead of a separate vocab_size variable.
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(output.reshape(-1, output.size(-1)), labels.reshape(-1))

# Backward pass and parameter update
loss.backward()
optimizer.step()

推理过程

自回归生成

python
def generate(model, src, max_length, start_token=1, end_token=2):
    """Greedy autoregressive decoding.

    The source is encoded once; the decoder is then re-run on the growing
    target prefix, and the arg-max token at the last position is appended at
    each step.

    Args:
        model: Object exposing ``encoder(src)`` and
            ``decoder(tgt, enc_output, tgt_mask=...)``.
        src: Source input handed to ``model.encoder``.
        max_length: Maximum number of tokens to append after the start token.
        start_token: Id used to seed every output sequence.
        end_token: Id that marks a sequence as finished.

    Returns:
        Integer tensor with one row per sequence, beginning with
        ``start_token``. Decoding stops once every row has produced
        ``end_token``; rows that finish early are filled with ``end_token``.

    Note:
        Gradients are disabled here, but the module's train/eval mode is not
        changed; callers presumably want ``model.eval()`` first.
    """
    with torch.no_grad():
        # Encode the input once
        enc_output = model.encoder(src)

        # Batch size and device come from the encoder output so the generated
        # ids live next to the model (assumes batch-first output).
        batch_size = enc_output.size(0)
        device = enc_output.device

        # Every sequence starts with the start token
        tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length):
            # Lower-triangular mask over the current prefix
            seq_len = tgt.size(1)
            tgt_mask = torch.tril(torch.ones(seq_len, seq_len, device=device))

            # Decode the whole prefix
            output = model.decoder(tgt, enc_output, tgt_mask=tgt_mask)

            # Greedy choice at the last position, shape (batch, 1)
            next_token = output[:, -1].argmax(dim=-1, keepdim=True)
            # Sequences that already ended keep emitting the end token
            next_token = next_token.masked_fill(finished.unsqueeze(1), end_token)

            # Append to the prefix
            tgt = torch.cat([tgt, next_token], dim=1)

            # Stop once every sequence has produced the end token
            finished = finished | (next_token.squeeze(1) == end_token)
            if finished.all():
                break

    return tgt

应用场景

机器翻译

python
# Input: English sentence, shaped as a batch holding one id sequence
# (assumes tokenizer.encode returns a flat sequence of token ids — confirm
# against the tokenizer in use)
src = torch.as_tensor(tokenizer.encode("Hello, how are you?")).reshape(1, -1)

# Output: French translation
tgt = generate(model, src, max_length=50)
# generate returns a 2-D (batch, length) tensor; decode its single row as a
# plain list of ids
print(tokenizer.decode(tgt[0].tolist()))  # "Bonjour, comment ça va ?"

文本摘要

python
# Input: long text, shaped as a batch holding one id sequence
# (assumes tokenizer.encode returns a flat sequence of token ids — confirm
# against the tokenizer in use)
src = torch.as_tensor(tokenizer.encode("这是一篇很长的文章...")).reshape(1, -1)

# Output: summary
tgt = generate(model, src, max_length=100)
# generate returns a 2-D (batch, length) tensor; decode its single row as a
# plain list of ids
print(tokenizer.decode(tgt[0].tolist()))  # "文章摘要..."

对话系统

python
# Input: user question, shaped as a batch holding one id sequence
# (assumes tokenizer.encode returns a flat sequence of token ids — confirm
# against the tokenizer in use)
src = torch.as_tensor(tokenizer.encode("什么是人工智能?")).reshape(1, -1)

# Output: answer
tgt = generate(model, src, max_length=200)
# generate returns a 2-D (batch, length) tensor; decode its single row as a
# plain list of ids
print(tokenizer.decode(tgt[0].tolist()))  # "人工智能是..."

优化策略

教师强制

教师强制(Teacher Forcing):训练时把真实的目标序列(而不是模型上一步的预测结果)作为解码器的输入。

标签平滑

python
# Cross-entropy loss with a label-smoothing amount of 0.1 (0.0 would disable smoothing)
loss_fn = torch.nn.CrossEntropyLoss(
    label_smoothing=0.1,
)

学习率调度

python
# Cosine-annealing learning-rate schedule attached to the existing optimizer,
# with T_max set to 1000 scheduler steps
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=1000,
)