Encoder-Decoder
Encoder-Decoder（编码器-解码器）架构是原始 Transformer 模型的整体结构，用于处理序列到序列（seq2seq）任务：编码器先把输入序列编码为上下文表示，解码器再基于这些表示逐步生成输出序列。
架构组成
Encoder
Encoder 负责处理输入序列，提取其特征表示：
python
class Encoder(torch.nn.Module):
    """Transformer encoder: token embedding, positional encoding, then N layers.

    Args:
        vocab_size: number of rows in the source embedding table.
        d_model: embedding / model width; also used to scale the embeddings.
        num_heads: forwarded to each ``EncoderLayer``.
        d_ff: forwarded to each ``EncoderLayer``.
        num_layers: how many ``EncoderLayer`` modules are stacked.
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        # forward() reads this to scale the embeddings, so it must be stored.
        self.d_model = d_model
        self.embedding = torch.nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.layers = torch.nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        """Encode source token ids ``x``.

        Args:
            x: token-id tensor for ``nn.Embedding`` (presumably
                (batch, src_len) -- TODO confirm against PositionalEncoding).
            mask: passed unchanged to every encoder layer (default ``None``).

        Returns:
            The output of the last encoder layer.
        """
        # Embedding lookup, scaled by sqrt(d_model).
        x = self.embedding(x) * (self.d_model ** 0.5)
        # Apply the positional-encoding module.
        x = self.positional_encoding(x)
        # Run the encoder layers in order, each seeing the same mask.
        for layer in self.layers:
            x = layer(x, mask)
        return x

# Decoder
Decoder 结合编码器的输出，负责生成输出序列：
python
class Decoder(torch.nn.Module):
    """Transformer decoder: embedding, positional encoding, N layers, output head.

    Args:
        vocab_size: size of the target vocabulary (embedding rows and the
            width of the final linear projection).
        d_model: embedding / model width; also used to scale the embeddings.
        num_heads: forwarded to each ``DecoderLayer``.
        d_ff: forwarded to each ``DecoderLayer``.
        num_layers: how many ``DecoderLayer`` modules are stacked.
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        # forward() reads this to scale the embeddings, so it must be stored.
        self.d_model = d_model
        self.embedding = torch.nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.layers = torch.nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        # Projects each position's d_model vector to vocab_size scores.
        self.fc = torch.nn.Linear(d_model, vocab_size)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Decode target token ids ``x`` against the encoder output.

        Args:
            x: target token-id tensor for ``nn.Embedding``.
            enc_output: encoder output, passed to every decoder layer.
            src_mask: passed unchanged to every decoder layer.
            tgt_mask: passed unchanged to every decoder layer.

        Returns:
            ``self.fc`` applied to the last layer's output: ``vocab_size``
            unnormalized scores per position.
        """
        # Embedding lookup, scaled by sqrt(d_model).
        x = self.embedding(x) * (self.d_model ** 0.5)
        # Apply the positional-encoding module.
        x = self.positional_encoding(x)
        # Run the decoder layers in order.
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        # Output projection to the target vocabulary.
        output = self.fc(x)
        return output

# 完整模型
python
class Transformer(torch.nn.Module):
    """Encoder-decoder model that wires an ``Encoder`` to a ``Decoder``.

    Both halves are built with the same ``d_model``, ``num_heads``, ``d_ff``
    and ``num_layers``; only the vocabulary sizes differ.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        shared_hparams = (d_model, num_heads, d_ff, num_layers)
        self.encoder = Encoder(src_vocab_size, *shared_hparams)
        self.decoder = Decoder(tgt_vocab_size, *shared_hparams)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src``, then decode ``tgt`` against the encoded result.

        ``src_mask`` is handed to both the encoder and the decoder;
        ``tgt_mask`` only to the decoder. Returns the decoder's output.
        """
        memory = self.encoder(src, src_mask)
        return self.decoder(tgt, memory, src_mask, tgt_mask)

# 训练过程
数据准备
python
# Source-language token ids, shape (1, 5)
src = torch.tensor([[1, 2, 3, 4, 5]])
# Target-language token ids, shape (1, 5)
tgt = torch.tensor([[1, 6, 7, 8, 9]])

# Masks: none for the source; a lower-triangular (causal) 0/1 mask for the
# target, so position i can only see positions 0..i.
src_mask = None
tgt_len = tgt.size(1)
tgt_mask = torch.ones(tgt_len, tgt_len).tril()

# 损失计算
python
# Teacher forcing: the decoder reads the target without its last token and is
# trained to predict the target without its first token, so position i
# predicts token i + 1. Feeding and scoring the same sequence would let the
# causal mask expose each answer, and the model would only learn to copy.
tgt_input = tgt[:, :-1]
tgt_output = tgt[:, 1:]
# The causal mask must match the shortened decoder input.
input_len = tgt_input.size(1)
tgt_input_mask = torch.tril(torch.ones(input_len, input_len, device=tgt.device))
# Forward pass
output = model(src, tgt_input, src_mask, tgt_input_mask)
# Loss: flatten the logits to (N, vocab) and the labels to (N,). The
# vocabulary size is read from the logits instead of an external variable.
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
# Backward pass: clear gradients from the previous step first, because
# PyTorch accumulates them in .grad.
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 推理过程
自回归生成
python
def generate(model, src, max_length, bos_id=1, eos_id=2):
    """Greedily decode target sequences for ``src`` one token at a time.

    Args:
        model: object exposing ``encoder(src)`` and
            ``decoder(tgt, enc_output, tgt_mask=...)``; the decoder's output
            is assumed to be (batch, tgt_len, vocab) scores -- TODO confirm.
        src: source token ids, assumed (batch, src_len).
        max_length: maximum number of tokens generated after the start token.
        bos_id: start token placed at position 0 of every row (default 1).
        eos_id: end token; decoding stops once every row has produced it
            (default 2).

    Returns:
        A (batch, 1 + generated) LongTensor beginning with ``bos_id``. Rows
        that finish early are padded with ``eos_id``.
    """
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        enc_output = model.encoder(src)
        batch_size = src.size(0)
        # One start token per row, on the same device as the input.
        tgt = torch.full((batch_size, 1), bos_id, dtype=torch.long, device=src.device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)
        for _ in range(max_length):
            seq_len = tgt.size(1)
            # Causal mask: position i may only look at positions <= i.
            tgt_mask = torch.tril(torch.ones(seq_len, seq_len, device=src.device))
            output = model.decoder(tgt, enc_output, tgt_mask=tgt_mask)
            # Greedy pick at the last position, kept as a (batch, 1) column.
            next_token = output[:, -1].argmax(dim=-1, keepdim=True)
            # Rows that already emitted EOS keep emitting it as padding.
            next_token = next_token.masked_fill(finished.unsqueeze(1), eos_id)
            tgt = torch.cat([tgt, next_token], dim=1)
            # Works for any batch size (``.item()`` only works for batch 1).
            finished |= next_token.squeeze(1) == eos_id
            if finished.all():
                break
    return tgt

# 应用场景
机器翻译
python
# Input: an English sentence
sentence = "Hello, how are you?"
src = tokenizer.encode(sentence)
# Output: its French translation (at most 50 generated tokens)
# NOTE(review): generate() hands src straight to model.encoder and returns a
# 2-D tensor that begins with the start token. If this tokenizer's
# encode/decode work on flat Python lists, src needs wrapping into a
# (1, len) LongTensor and tgt[0] should be decoded -- TODO confirm.
tgt = generate(model, src, max_length=50)
print(tokenizer.decode(tgt))  # "Bonjour, comment ça va ?"

# 文本摘要
python
# Input: a long document
document = "这是一篇很长的文章..."
src = tokenizer.encode(document)
# Output: a summary of at most 100 generated tokens
# NOTE(review): src must be something model.encoder accepts, and tgt comes
# back as a 2-D tensor -- TODO confirm the tokenizer converts both ways.
tgt = generate(model, src, max_length=100)
print(tokenizer.decode(tgt))  # e.g. "文章摘要..."

# 对话系统
python
# Input: the user's question
question = "什么是人工智能?"
src = tokenizer.encode(question)
# Output: the generated answer (at most 200 generated tokens)
# NOTE(review): same src/tgt tensor-conversion caveat as any generate() call
# -- TODO confirm against the tokenizer in use.
tgt = generate(model, src, max_length=200)
print(tokenizer.decode(tgt))  # e.g. "人工智能是..."

# 优化策略
教师强制
在训练时，把真实的目标序列（去掉最后一个 token）作为解码器输入，并以右移一位后的序列（去掉第一个 token）作为预测目标，而不是把模型上一步自己的预测结果作为输入。
标签平滑
python
# Label smoothing: soften the one-hot targets with smoothing factor 0.1
# (see the torch.nn.CrossEntropyLoss docs for the exact target formula).
smoothing = 0.1
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)

# 学习率调度
python
# Cosine annealing over T_max calls to scheduler.step(); relies on an
# `optimizer` created earlier (not shown in this section).
cosine_period = 1000
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_period)