检索增强
检索增强是 RAG 技术的核心,通过从外部知识库检索相关信息来增强模型回答。
概述
检索增强生成(RAG)通过检索外部知识来提高模型回答的准确性和可靠性。
RAG 流程
用户查询
↓
检索器从知识库中查找相关文档
↓
将检索到的文档作为上下文
↓
语言模型基于上下文生成回答
↓
输出最终回答
实现示例
使用 LangChain
python
from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
# Load the knowledge-base file. The encoding is passed explicitly so that
# non-ASCII text is decoded as UTF-8 instead of the platform's locale default.
loader = TextLoader("knowledge_base.txt", encoding="utf-8")
documents = loader.load()
# Build the vector store from the loaded documents and the embedding model.
embeddings = OpenAIEmbeddings()
db = Chroma.from_documents(documents, embeddings)
# Build the retrieval question-answering chain on top of the vector store.
qa_chain = RetrievalQA.from_chain_type(
    llm=OpenAI(),
    chain_type="stuff",
    retriever=db.as_retriever(),
)
# Run a query through the chain.
result = qa_chain.run("什么是 RAG 技术?")
print(result)
检索策略
关键词检索
python
# Keyword-based retrieval
def keyword_search(query, documents, top_k=5):
    """Return the first documents that contain any keyword of the query.

    The query is split on whitespace. A document matches when at least one
    of the resulting keywords occurs in it as a case-insensitive substring.
    Matches keep the order of ``documents``; no relevance ranking is applied.

    Args:
        query: Query string; whitespace separates the keywords.
        documents: Iterable of document strings to scan.
        top_k: Maximum number of matches to return (defaults to 5).

    Returns:
        A list with at most ``top_k`` matching documents. The list is empty
        when the query contains no keywords or ``top_k`` is not positive.
    """
    # Lowercase the keywords once instead of once per document.
    keywords = [keyword.lower() for keyword in query.split()]
    if not keywords or top_k <= 0:
        return []

    matches = []
    for doc in documents:
        lowered = doc.lower()
        if any(keyword in lowered for keyword in keywords):
            matches.append(doc)
            # Enough matches collected; no need to scan the remaining documents.
            if len(matches) == top_k:
                break
    return matches

# Section heading: 语义检索
python
# Retrieval based on semantic similarity
def semantic_search(query, documents, model, top_k=5):
    """Return the documents whose embeddings are most similar to the query.

    The query and every document are embedded with ``model.encode`` and
    compared with ``cosine_similarity``. Documents are returned from the
    highest to the lowest similarity; ties keep their original order.

    NOTE(review): ``cosine_similarity`` is neither defined nor imported in
    the code shown here; it must be available in the enclosing module.

    Args:
        query: Query text passed to ``model.encode``.
        documents: Iterable of documents, each passed to ``model.encode``.
        model: Object exposing an ``encode`` method.
        top_k: Maximum number of documents to return (defaults to 5).

    Returns:
        A list with at most ``top_k`` documents, best match first. Empty
        when ``top_k`` is not positive.
    """
    if top_k <= 0:
        return []

    query_embedding = model.encode(query)
    scored = [
        (cosine_similarity(query_embedding, model.encode(doc)), doc)
        for doc in documents
    ]
    # Sort on the score only, so documents themselves are never compared.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in scored[:top_k]]

# Section heading: 混合检索
python
# Combine keyword and semantic retrieval
def hybrid_search(query, documents, model, top_k=5):
    """Merge keyword and semantic search results without duplicates.

    Keyword matches come first, followed by semantic matches that were not
    already found by the keyword search. The order is deterministic, so
    truncating to ``top_k`` always keeps the same documents for the same
    input.

    Args:
        query: Query text.
        documents: Collection of documents (must be hashable, e.g. strings).
        model: Embedding model forwarded to ``semantic_search``.
        top_k: Maximum number of documents to return (defaults to 5). Both
            helper searches are called with their own default limits, so
            the merged list cannot be longer than their combined output.

    Returns:
        A list with at most ``top_k`` unique documents.
    """
    keyword_results = keyword_search(query, documents)
    semantic_results = semantic_search(query, documents, model)
    # dict.fromkeys removes duplicates while keeping first-seen order; a
    # plain set would shuffle the results before they are truncated.
    combined = list(dict.fromkeys(keyword_results + semantic_results))
    return combined[:max(top_k, 0)]

# Section heading: 上下文构建
拼接方式
python
# Plain concatenation
def build_context(documents):
    """Join the document strings into one context, separated by a blank line."""
    separator = "\n\n"
    return separator.join(documents)
# Concatenation with source labels
def build_context_with_sources(documents):
    """Build a context string in which every document carries a source label.

    Each document is rendered as a numbered label line (numbering starts at
    1), the document text, and a trailing blank line.

    Args:
        documents: Iterable of document strings.

    Returns:
        The concatenated, labelled context; an empty string when there are
        no documents.
    """
    # Build all parts first and join once, instead of growing a string with
    # += inside the loop.
    return "".join(
        f"来源{index}:\n{doc}\n\n"
        for index, doc in enumerate(documents, start=1)
    )

# Section heading: 示例
python
# Build the prompt.
# NOTE(review): `context` and `query` are not defined in this snippet and must
# already be in scope; presumably the assembled context string and the user
# question - confirm against the caller.
prompt = f"""
请根据以下上下文回答问题:
{context}
问题:{query}
回答:
"""
评估指标
检索效果
python
# Compute recall
def calculate_recall(retrieved_docs, relevant_docs):
    """Return the fraction of relevant documents that were retrieved.

    Duplicates in either input are counted once.

    Args:
        retrieved_docs: Iterable of retrieved documents (hashable items).
        relevant_docs: Iterable of relevant documents (hashable items).

    Returns:
        Recall as a float in [0, 1]. Returns 0.0 when ``relevant_docs`` is
        empty, instead of raising ``ZeroDivisionError``.
    """
    relevant = set(relevant_docs)
    if not relevant:
        # Nothing to find: avoid dividing by zero.
        return 0.0
    hits = relevant.intersection(retrieved_docs)
    return len(hits) / len(relevant)
# Compute precision
def calculate_precision(retrieved_docs, relevant_docs):
    """Return the fraction of retrieved documents that are relevant.

    Duplicates in either input are counted once.

    Args:
        retrieved_docs: Iterable of retrieved documents (hashable items).
        relevant_docs: Iterable of relevant documents (hashable items).

    Returns:
        Precision as a float in [0, 1]. Returns 0.0 when ``retrieved_docs``
        is empty, instead of raising ``ZeroDivisionError``.
    """
    retrieved = set(retrieved_docs)
    if not retrieved:
        # Nothing was retrieved: avoid dividing by zero.
        return 0.0
    hits = retrieved.intersection(relevant_docs)
    return len(hits) / len(retrieved)

# Section heading: 回答质量
python
# Evaluate answer quality with ROUGE.
from rouge_score import rouge_scorer
# Score ROUGE-1 and ROUGE-L with stemming enabled.
# NOTE(review): `reference` and `prediction` are not defined in this snippet
# and must already be in scope; presumably the gold answer and the generated
# answer - confirm against the caller.
scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
scores = scorer.score(reference, prediction)
# Report the F-measure of each metric.
print(f"ROUGE-1: {scores['rouge1'].fmeasure}")
print(f"ROUGE-L: {scores['rougeL'].fmeasure}")
优化策略
文档分段
python
# Split a long document into chunks
def split_document(document, chunk_size=500, overlap=0):
    """Split a document into chunks of whitespace-separated words.

    The text is split on whitespace and re-joined with single spaces, so
    the original spacing and line breaks are not preserved. Text without
    any whitespace is a single word and therefore ends up in one chunk.

    Args:
        document: Text to split.
        chunk_size: Maximum number of words per chunk (defaults to 500).
        overlap: Number of words shared between consecutive chunks
            (defaults to 0, i.e. chunks do not overlap).

    Returns:
        A list of chunk strings; empty when the document has no words.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or ``overlap`` is
            negative or not smaller than ``chunk_size``.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    words = document.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        # The window already reached the end; a further chunk would only
        # repeat words that are part of this one.
        if start + chunk_size >= len(words):
            break
    return chunks

# Section heading: 重排序
python
# Rerank with a Cross-Encoder.
from sentence_transformers import CrossEncoder
# Module-level model instance; rerank() below reads it as a global.
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
def rerank(query, documents):
    """Order documents by the module-level model's score for the query.

    Every document is paired with the query and scored with
    ``model.predict``. Documents are returned from the highest to the
    lowest score; documents with equal scores keep their input order.

    Args:
        query: Query text.
        documents: Iterable of documents to rerank.

    Returns:
        A new list with the documents sorted by descending score; an empty
        list when there are no documents (the model is not called).
    """
    # Materialize once so one-shot iterables can be both scored and zipped.
    documents = list(documents)
    if not documents:
        return []

    pairs = [[query, doc] for doc in documents]
    scores = model.predict(pairs)
    # Sort on the score only: comparing (score, doc) tuples would fall back
    # to comparing the documents themselves whenever two scores are equal.
    ranked = sorted(zip(scores, documents), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked]