Spaces:
Paused
Paused
lanny xu
commited on
Commit
·
9cce495
1
Parent(s):
55a0955
add react
Browse files- config.py +2 -2
- main.py +56 -6
- routers_and_graders.py +94 -7
- workflow_nodes.py +160 -18
config.py
CHANGED
|
@@ -51,8 +51,8 @@ LOCAL_LLM = "mistral" # 在Kaggle中可改为 "phi" 或 "tinyllama"
|
|
| 51 |
# 知识库URL配置
|
| 52 |
KNOWLEDGE_BASE_URLS = [
|
| 53 |
"https://lilianweng.github.io/posts/2023-06-23-agent/",
|
| 54 |
-
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
|
| 55 |
-
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
|
| 56 |
]
|
| 57 |
|
| 58 |
# 文档分块配置
|
|
|
|
| 51 |
# 知识库URL配置
|
| 52 |
KNOWLEDGE_BASE_URLS = [
|
| 53 |
"https://lilianweng.github.io/posts/2023-06-23-agent/",
|
| 54 |
+
# "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
|
| 55 |
+
# "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
|
| 56 |
]
|
| 57 |
|
| 58 |
# 文档分块配置
|
main.py
CHANGED
|
@@ -3,13 +3,20 @@
|
|
| 3 |
集成所有模块,构建工作流并运行自适应RAG系统
|
| 4 |
"""
|
| 5 |
|
|
|
|
| 6 |
from langgraph.graph import END, StateGraph, START
|
| 7 |
from pprint import pprint
|
| 8 |
|
| 9 |
-
from config import setup_environment, validate_api_keys
|
| 10 |
from document_processor import initialize_document_processor
|
| 11 |
from routers_and_graders import initialize_graders_and_router
|
| 12 |
from workflow_nodes import WorkflowNodes, GraphState
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class AdaptiveRAGSystem:
|
|
@@ -54,6 +61,23 @@ class AdaptiveRAGSystem:
|
|
| 54 |
print("初始化评分器和路由器...")
|
| 55 |
self.graders = initialize_graders_and_router()
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# 初始化工作流节点
|
| 58 |
print("设置工作流节点...")
|
| 59 |
# WorkflowNodes 将在 _build_workflow 中初始化
|
|
@@ -91,6 +115,8 @@ class AdaptiveRAGSystem:
|
|
| 91 |
workflow.add_node("grade_documents", self.workflow_nodes.grade_documents)
|
| 92 |
workflow.add_node("generate", self.workflow_nodes.generate)
|
| 93 |
workflow.add_node("transform_query", self.workflow_nodes.transform_query)
|
|
|
|
|
|
|
| 94 |
|
| 95 |
# 构建图
|
| 96 |
workflow.add_conditional_edges(
|
|
@@ -98,20 +124,23 @@ class AdaptiveRAGSystem:
|
|
| 98 |
self.workflow_nodes.route_question,
|
| 99 |
{
|
| 100 |
"web_search": "web_search",
|
| 101 |
-
"vectorstore": "
|
| 102 |
},
|
| 103 |
)
|
| 104 |
workflow.add_edge("web_search", "generate")
|
|
|
|
| 105 |
workflow.add_edge("retrieve", "grade_documents")
|
| 106 |
workflow.add_conditional_edges(
|
| 107 |
"grade_documents",
|
| 108 |
self.workflow_nodes.decide_to_generate,
|
| 109 |
{
|
| 110 |
"transform_query": "transform_query",
|
|
|
|
| 111 |
"generate": "generate",
|
| 112 |
},
|
| 113 |
)
|
| 114 |
workflow.add_edge("transform_query", "retrieve")
|
|
|
|
| 115 |
workflow.add_conditional_edges(
|
| 116 |
"generate",
|
| 117 |
self.workflow_nodes.grade_generation_v_documents_and_question,
|
|
@@ -151,10 +180,16 @@ class AdaptiveRAGSystem:
|
|
| 151 |
# 设置配置,增加递归限制
|
| 152 |
config = {"recursion_limit": 50} # 增加到 50,默认是 25
|
| 153 |
|
|
|
|
| 154 |
for output in self.app.stream(inputs, config=config):
|
| 155 |
for key, value in output.items():
|
| 156 |
if verbose:
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
# 可选:在每个节点打印完整状态
|
| 159 |
# pprint(value, indent=2, width=80, depth=None)
|
| 160 |
final_generation = value.get("generation", final_generation)
|
|
@@ -162,11 +197,25 @@ class AdaptiveRAGSystem:
|
|
| 162 |
if "retrieval_metrics" in value:
|
| 163 |
retrieval_metrics = value["retrieval_metrics"]
|
| 164 |
if verbose:
|
| 165 |
-
pprint("\n---\n")
|
|
|
|
| 166 |
|
|
|
|
| 167 |
print("🎯 最终答案:")
|
| 168 |
print("-" * 30)
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
print("=" * 50)
|
| 171 |
|
| 172 |
# 返回包含答案和评估指标的字典
|
|
@@ -220,7 +269,8 @@ def main():
|
|
| 220 |
rag_system: AdaptiveRAGSystem = AdaptiveRAGSystem()
|
| 221 |
|
| 222 |
# 测试查询
|
| 223 |
-
test_question = "AlphaCodium论文讲的是什么?"
|
|
|
|
| 224 |
# test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
|
| 225 |
result = rag_system.query(test_question)
|
| 226 |
|
|
|
|
| 3 |
集成所有模块,构建工作流并运行自适应RAG系统
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
import time
|
| 7 |
from langgraph.graph import END, StateGraph, START
|
| 8 |
from pprint import pprint
|
| 9 |
|
| 10 |
+
from config import setup_environment, validate_api_keys, ENABLE_GRAPHRAG
|
| 11 |
from document_processor import initialize_document_processor
|
| 12 |
from routers_and_graders import initialize_graders_and_router
|
| 13 |
from workflow_nodes import WorkflowNodes, GraphState
|
| 14 |
+
try:
|
| 15 |
+
from knowledge_graph import initialize_knowledge_graph, initialize_community_summarizer
|
| 16 |
+
from graph_retriever import initialize_graph_retriever
|
| 17 |
+
except ImportError:
|
| 18 |
+
print("⚠️ 无法导入知识图谱模块,GraphRAG功能将不可用")
|
| 19 |
+
ENABLE_GRAPHRAG = False
|
| 20 |
|
| 21 |
|
| 22 |
class AdaptiveRAGSystem:
|
|
|
|
| 61 |
print("初始化评分器和路由器...")
|
| 62 |
self.graders = initialize_graders_and_router()
|
| 63 |
|
| 64 |
+
# 初始化知识图谱 (如果启用)
|
| 65 |
+
self.graph_retriever = None
|
| 66 |
+
if ENABLE_GRAPHRAG:
|
| 67 |
+
print("初始化 GraphRAG...")
|
| 68 |
+
try:
|
| 69 |
+
kg = initialize_knowledge_graph()
|
| 70 |
+
# 尝试加载已有的图谱数据
|
| 71 |
+
try:
|
| 72 |
+
kg.load_from_file("knowledge_graph.json")
|
| 73 |
+
except FileNotFoundError:
|
| 74 |
+
print(" 未找到 existing knowledge_graph.json, 将使用空图谱")
|
| 75 |
+
|
| 76 |
+
self.graph_retriever = initialize_graph_retriever(kg)
|
| 77 |
+
print("✅ GraphRAG 初始化成功")
|
| 78 |
+
except Exception as e:
|
| 79 |
+
print(f"⚠️ GraphRAG 初始化失败: {e}")
|
| 80 |
+
|
| 81 |
# 初始化工作流节点
|
| 82 |
print("设置工作流节点...")
|
| 83 |
# WorkflowNodes 将在 _build_workflow 中初始化
|
|
|
|
| 115 |
workflow.add_node("grade_documents", self.workflow_nodes.grade_documents)
|
| 116 |
workflow.add_node("generate", self.workflow_nodes.generate)
|
| 117 |
workflow.add_node("transform_query", self.workflow_nodes.transform_query)
|
| 118 |
+
workflow.add_node("decompose_query", self.workflow_nodes.decompose_query)
|
| 119 |
+
workflow.add_node("prepare_next_query", self.workflow_nodes.prepare_next_query)
|
| 120 |
|
| 121 |
# 构建图
|
| 122 |
workflow.add_conditional_edges(
|
|
|
|
| 124 |
self.workflow_nodes.route_question,
|
| 125 |
{
|
| 126 |
"web_search": "web_search",
|
| 127 |
+
"vectorstore": "decompose_query", # 向量检索前先进行查询分解
|
| 128 |
},
|
| 129 |
)
|
| 130 |
workflow.add_edge("web_search", "generate")
|
| 131 |
+
workflow.add_edge("decompose_query", "retrieve")
|
| 132 |
workflow.add_edge("retrieve", "grade_documents")
|
| 133 |
workflow.add_conditional_edges(
|
| 134 |
"grade_documents",
|
| 135 |
self.workflow_nodes.decide_to_generate,
|
| 136 |
{
|
| 137 |
"transform_query": "transform_query",
|
| 138 |
+
"prepare_next_query": "prepare_next_query",
|
| 139 |
"generate": "generate",
|
| 140 |
},
|
| 141 |
)
|
| 142 |
workflow.add_edge("transform_query", "retrieve")
|
| 143 |
+
workflow.add_edge("prepare_next_query", "retrieve")
|
| 144 |
workflow.add_conditional_edges(
|
| 145 |
"generate",
|
| 146 |
self.workflow_nodes.grade_generation_v_documents_and_question,
|
|
|
|
| 180 |
# 设置配置,增加递归限制
|
| 181 |
config = {"recursion_limit": 50} # 增加到 50,默认是 25
|
| 182 |
|
| 183 |
+
print("\n🤖 思考过程:")
|
| 184 |
for output in self.app.stream(inputs, config=config):
|
| 185 |
for key, value in output.items():
|
| 186 |
if verbose:
|
| 187 |
+
# 简单的节点执行提示,模拟流式感
|
| 188 |
+
print(f" ↳ 执行节点: {key}...", end="\r")
|
| 189 |
+
time.sleep(0.1) # 视觉暂停
|
| 190 |
+
print(f" ✅ 完成节点: {key} ")
|
| 191 |
+
|
| 192 |
+
# pprint(f"节点 '{key}':")
|
| 193 |
# 可选:在每个节点打印完整状态
|
| 194 |
# pprint(value, indent=2, width=80, depth=None)
|
| 195 |
final_generation = value.get("generation", final_generation)
|
|
|
|
| 197 |
if "retrieval_metrics" in value:
|
| 198 |
retrieval_metrics = value["retrieval_metrics"]
|
| 199 |
if verbose:
|
| 200 |
+
# pprint("\n---\n")
|
| 201 |
+
pass
|
| 202 |
|
| 203 |
+
print("\n" + "=" * 50)
|
| 204 |
print("🎯 最终答案:")
|
| 205 |
print("-" * 30)
|
| 206 |
+
|
| 207 |
+
# 模拟流式输出效果 (打字机效果)
|
| 208 |
+
if final_generation:
|
| 209 |
+
import sys
|
| 210 |
+
import time
|
| 211 |
+
for char in final_generation:
|
| 212 |
+
sys.stdout.write(char)
|
| 213 |
+
sys.stdout.flush()
|
| 214 |
+
time.sleep(0.01) # 控制打字速度
|
| 215 |
+
print() # 换行
|
| 216 |
+
else:
|
| 217 |
+
print("未生成答案")
|
| 218 |
+
|
| 219 |
print("=" * 50)
|
| 220 |
|
| 221 |
# 返回包含答案和评估指标的字典
|
|
|
|
| 269 |
rag_system: AdaptiveRAGSystem = AdaptiveRAGSystem()
|
| 270 |
|
| 271 |
# 测试查询
|
| 272 |
+
# test_question = "AlphaCodium论文讲的是什么?"
|
| 273 |
+
test_question = "LangGraph的作者目前在哪家公司工作?"
|
| 274 |
# test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
|
| 275 |
result = rag_system.query(test_question)
|
| 276 |
|
routers_and_graders.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
包含查询路由、文档相关性评分、答案质量评分和幻觉检测
|
| 4 |
"""
|
| 5 |
|
|
|
|
| 6 |
try:
|
| 7 |
from langchain_core.prompts import PromptTemplate
|
| 8 |
except ImportError:
|
|
@@ -152,23 +153,105 @@ class HallucinationGrader:
|
|
| 152 |
return result.get("score", "no")
|
| 153 |
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
class QueryRewriter:
|
| 156 |
"""查询重写器,优化查询以获得更好的检索结果"""
|
| 157 |
|
| 158 |
def __init__(self):
|
| 159 |
self.llm = ChatOllama(model=LOCAL_LLM, temperature=0)
|
| 160 |
self.prompt = PromptTemplate(
|
| 161 |
-
template="""
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
)
|
| 166 |
self.rewriter = self.prompt | self.llm | StrOutputParser()
|
| 167 |
|
| 168 |
-
def rewrite(self, question: str) -> str:
|
| 169 |
"""重写查询以获得更好的检索效果"""
|
| 170 |
print(f"---原始查询: {question}---")
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
| 172 |
print(f"---重写查询: {rewritten_query}---")
|
| 173 |
return rewritten_query
|
| 174 |
|
|
@@ -187,11 +270,15 @@ def initialize_graders_and_router():
|
|
| 187 |
answer_grader = AnswerGrader()
|
| 188 |
hallucination_grader = HallucinationGrader(method=detection_method)
|
| 189 |
query_rewriter = QueryRewriter()
|
|
|
|
|
|
|
| 190 |
|
| 191 |
return {
|
| 192 |
"query_router": query_router,
|
| 193 |
"document_grader": document_grader,
|
| 194 |
"answer_grader": answer_grader,
|
| 195 |
"hallucination_grader": hallucination_grader,
|
| 196 |
-
"query_rewriter": query_rewriter
|
|
|
|
|
|
|
| 197 |
}
|
|
|
|
| 3 |
包含查询路由、文档相关性评分、答案质量评分和幻觉检测
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
from typing import List
|
| 7 |
try:
|
| 8 |
from langchain_core.prompts import PromptTemplate
|
| 9 |
except ImportError:
|
|
|
|
| 153 |
return result.get("score", "no")
|
| 154 |
|
| 155 |
|
| 156 |
+
class QueryDecomposer:
|
| 157 |
+
"""查询分解器,将复杂的多跳问题分解为子问题序列"""
|
| 158 |
+
|
| 159 |
+
def __init__(self):
|
| 160 |
+
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
|
| 161 |
+
self.prompt = PromptTemplate(
|
| 162 |
+
template="""你是一个查询分解专家。你的任务是将一个复杂的多跳问题分解为一系列简单的子问题,这些子问题可以按顺序检索来回答原始问题。
|
| 163 |
+
|
| 164 |
+
分解规则:
|
| 165 |
+
1. 如果问题很简单,不需要分解,返回只包含原始问题的列表。
|
| 166 |
+
2. 如果问题需要多步推理(例如"A的作者的大学在哪里"),分解为逻辑步骤:
|
| 167 |
+
- 步骤1: "谁是A的作者?"
|
| 168 |
+
- 步骤2: "该作者在哪个大学?"
|
| 169 |
+
3. 保持子问题简洁明了。
|
| 170 |
+
4. 即使返回单个问题,也必须包装在JSON的 sub_queries 列表中。
|
| 171 |
+
|
| 172 |
+
输出格式:返回一个包含 'sub_queries' 键的 JSON,其值为字符串列表。
|
| 173 |
+
不要输出任何前言或解释。
|
| 174 |
+
|
| 175 |
+
复杂问题: {question}""",
|
| 176 |
+
input_variables=["question"],
|
| 177 |
+
)
|
| 178 |
+
self.decomposer = self.prompt | self.llm | JsonOutputParser()
|
| 179 |
+
|
| 180 |
+
def decompose(self, question: str) -> List[str]:
|
| 181 |
+
"""分解问题"""
|
| 182 |
+
print(f"---分解问题: {question}---")
|
| 183 |
+
try:
|
| 184 |
+
result = self.decomposer.invoke({"question": question})
|
| 185 |
+
sub_queries = result.get("sub_queries", [question])
|
| 186 |
+
# 确保至少包含原始问题
|
| 187 |
+
if not sub_queries:
|
| 188 |
+
sub_queries = [question]
|
| 189 |
+
print(f"---子问题: {sub_queries}---")
|
| 190 |
+
return sub_queries
|
| 191 |
+
except Exception as e:
|
| 192 |
+
print(f"⚠️ 分解失败: {e},使用原始问题")
|
| 193 |
+
return [question]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class AnswerabilityGrader:
|
| 197 |
+
"""答案可回答性评分器,用于判断当前检索到的文档是否足够回答原始问题"""
|
| 198 |
+
|
| 199 |
+
def __init__(self):
|
| 200 |
+
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
|
| 201 |
+
self.prompt = PromptTemplate(
|
| 202 |
+
template="""你是一个专家评分员,负责评估检索到的文档是否包含足够的信息来回答用户的问题。
|
| 203 |
+
|
| 204 |
+
原始问题: {question}
|
| 205 |
+
|
| 206 |
+
目前检索到的文档集合:
|
| 207 |
+
{documents}
|
| 208 |
+
|
| 209 |
+
任务:
|
| 210 |
+
判断上述文档是否已经包含了回答原始问题所需的全部关键信息。
|
| 211 |
+
- 如果信息充足,可以终止进一步的检索,返回 'yes'。
|
| 212 |
+
- 如果信息缺失,需要继续检索更多信息,返回 'no'。
|
| 213 |
+
|
| 214 |
+
输出格式:
|
| 215 |
+
返回一个只包含 'score' 键的 JSON,值为 'yes' 或 'no'。
|
| 216 |
+
不要输出任何前言或解释。""",
|
| 217 |
+
input_variables=["question", "documents"],
|
| 218 |
+
)
|
| 219 |
+
self.grader = self.prompt | self.llm | JsonOutputParser()
|
| 220 |
+
|
| 221 |
+
def grade(self, question: str, documents: str) -> str:
|
| 222 |
+
"""评估文档是否足以回答问题"""
|
| 223 |
+
result = self.grader.invoke({"question": question, "documents": documents})
|
| 224 |
+
return result.get("score", "no")
|
| 225 |
+
|
| 226 |
+
|
| 227 |
class QueryRewriter:
|
| 228 |
"""查询重写器,优化查询以获得更好的检索结果"""
|
| 229 |
|
| 230 |
def __init__(self):
|
| 231 |
self.llm = ChatOllama(model=LOCAL_LLM, temperature=0)
|
| 232 |
self.prompt = PromptTemplate(
|
| 233 |
+
template="""你是一个问题重写器,负责将输入问题转换为更适合向量存储检索的更好版本。
|
| 234 |
+
|
| 235 |
+
你的目标是根据原始问题和(可选的)之前的检索上下文,生成一个新的查询,以便检索到回答问题所需的缺失信息。
|
| 236 |
+
如果提供了之前的上下文,请分析其中缺少什么信息,并针对缺失的信息构建查询。
|
| 237 |
+
|
| 238 |
+
初始问题: {question}
|
| 239 |
+
|
| 240 |
+
之前的上下文(如果有):
|
| 241 |
+
{context}
|
| 242 |
+
|
| 243 |
+
改进的问题(只输出问题,无前言):""",
|
| 244 |
+
input_variables=["question", "context"],
|
| 245 |
)
|
| 246 |
self.rewriter = self.prompt | self.llm | StrOutputParser()
|
| 247 |
|
| 248 |
+
def rewrite(self, question: str, context: str = "") -> str:
|
| 249 |
"""重写查询以获得更好的检索效果"""
|
| 250 |
print(f"---原始查询: {question}---")
|
| 251 |
+
if context:
|
| 252 |
+
print(f"---参考上下文长度: {len(context)} 字符---")
|
| 253 |
+
|
| 254 |
+
rewritten_query = self.rewriter.invoke({"question": question, "context": context})
|
| 255 |
print(f"---重写查询: {rewritten_query}---")
|
| 256 |
return rewritten_query
|
| 257 |
|
|
|
|
| 270 |
answer_grader = AnswerGrader()
|
| 271 |
hallucination_grader = HallucinationGrader(method=detection_method)
|
| 272 |
query_rewriter = QueryRewriter()
|
| 273 |
+
query_decomposer = QueryDecomposer()
|
| 274 |
+
answerability_grader = AnswerabilityGrader()
|
| 275 |
|
| 276 |
return {
|
| 277 |
"query_router": query_router,
|
| 278 |
"document_grader": document_grader,
|
| 279 |
"answer_grader": answer_grader,
|
| 280 |
"hallucination_grader": hallucination_grader,
|
| 281 |
+
"query_rewriter": query_rewriter,
|
| 282 |
+
"query_decomposer": query_decomposer,
|
| 283 |
+
"answerability_grader": answerability_grader
|
| 284 |
}
|
workflow_nodes.py
CHANGED
|
@@ -49,6 +49,9 @@ class GraphState(TypedDict):
|
|
| 49 |
documents: List[str]
|
| 50 |
retry_count: int
|
| 51 |
retrieval_metrics: dict # 添加检索评估指标
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
class WorkflowNodes:
|
|
@@ -64,14 +67,18 @@ class WorkflowNodes:
|
|
| 64 |
|
| 65 |
# 设置RAG链 - 使用本地提示模板
|
| 66 |
rag_prompt_template = PromptTemplate(
|
| 67 |
-
template="""
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
input_variables=["question", "context"]
|
| 76 |
)
|
| 77 |
llm = ChatOllama(model=LOCAL_LLM, temperature=0)
|
|
@@ -80,6 +87,37 @@ class WorkflowNodes:
|
|
| 80 |
# 设置网络搜索
|
| 81 |
self.web_search_tool = TavilySearchResults(k=WEB_SEARCH_RESULTS_COUNT)
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
def retrieve(self, state):
|
| 84 |
"""
|
| 85 |
检索文档
|
|
@@ -111,7 +149,7 @@ class WorkflowNodes:
|
|
| 111 |
|
| 112 |
# 记录使用的检索方法
|
| 113 |
if ENABLE_HYBRID_SEARCH:
|
| 114 |
-
print("
|
| 115 |
if ENABLE_QUERY_EXPANSION:
|
| 116 |
print("---使用查询扩展---")
|
| 117 |
if image_paths and ENABLE_MULTIMODAL:
|
|
@@ -136,6 +174,26 @@ class WorkflowNodes:
|
|
| 136 |
print(f"❌ 回退检索也失败: {fallback_e}")
|
| 137 |
documents = []
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
# 计算检索时间
|
| 140 |
retrieval_time = time.time() - retrieval_start_time
|
| 141 |
|
|
@@ -161,10 +219,12 @@ class WorkflowNodes:
|
|
| 161 |
"""
|
| 162 |
print("---生成---")
|
| 163 |
question = state["question"]
|
|
|
|
| 164 |
documents = state["documents"]
|
| 165 |
|
| 166 |
-
# RAG生成
|
| 167 |
-
|
|
|
|
| 168 |
return {"documents": documents, "question": question, "generation": generation}
|
| 169 |
|
| 170 |
def grade_documents(self, state):
|
|
@@ -211,8 +271,18 @@ class WorkflowNodes:
|
|
| 211 |
|
| 212 |
print(f" 重试次数: {retry_count}")
|
| 213 |
|
| 214 |
-
#
|
| 215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
return {"documents": documents, "question": better_question, "retry_count": retry_count}
|
| 217 |
|
| 218 |
def web_search(self, state):
|
|
@@ -257,18 +327,65 @@ class WorkflowNodes:
|
|
| 257 |
print("---将问题路由到RAG---")
|
| 258 |
return "vectorstore"
|
| 259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
def decide_to_generate(self, state):
|
| 261 |
"""
|
| 262 |
-
|
| 263 |
|
| 264 |
Args:
|
| 265 |
state (dict): 当前图状态
|
| 266 |
|
| 267 |
Returns:
|
| 268 |
-
str:
|
| 269 |
"""
|
| 270 |
print("---评估已评分的文档---")
|
| 271 |
filtered_documents = state["documents"]
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
if not filtered_documents:
|
| 274 |
# 所有文档都被过滤掉了
|
|
@@ -276,9 +393,34 @@ class WorkflowNodes:
|
|
| 276 |
print("---决策:所有文档都与问题不相关,转换查询---")
|
| 277 |
return "transform_query"
|
| 278 |
else:
|
| 279 |
-
#
|
| 280 |
-
|
| 281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
def grade_generation_v_documents_and_question(self, state):
|
| 284 |
"""
|
|
|
|
| 49 |
documents: List[str]
|
| 50 |
retry_count: int
|
| 51 |
retrieval_metrics: dict # 添加检索评估指标
|
| 52 |
+
sub_queries: List[str] # 分解后的子问题列表
|
| 53 |
+
current_query_index: int # 当前处理的子问题索引
|
| 54 |
+
original_question: str # 原始问题,用于早期终止检查
|
| 55 |
|
| 56 |
|
| 57 |
class WorkflowNodes:
|
|
|
|
| 67 |
|
| 68 |
# 设置RAG链 - 使用本地提示模板
|
| 69 |
rag_prompt_template = PromptTemplate(
|
| 70 |
+
template="""你是一个智能问答助手。使用以下检索到的上下文来回答问题。
|
| 71 |
+
|
| 72 |
+
规则:
|
| 73 |
+
1. 如果你不知道答案,就说你不知道。
|
| 74 |
+
2. 如果用户请求特定格式(如Markdown、列表、代码块等),请严格遵守。
|
| 75 |
+
3. 如果没有特定格式要求,保持答案简洁。
|
| 76 |
+
|
| 77 |
+
问题: {question}
|
| 78 |
+
|
| 79 |
+
上下文: {context}
|
| 80 |
+
|
| 81 |
+
答案:""",
|
| 82 |
input_variables=["question", "context"]
|
| 83 |
)
|
| 84 |
llm = ChatOllama(model=LOCAL_LLM, temperature=0)
|
|
|
|
| 87 |
# 设置网络搜索
|
| 88 |
self.web_search_tool = TavilySearchResults(k=WEB_SEARCH_RESULTS_COUNT)
|
| 89 |
|
| 90 |
+
def decompose_query(self, state):
|
| 91 |
+
"""
|
| 92 |
+
将初始查询分解为子查询
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
state (dict): 当前图状态
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
state (dict): 更新sub_queries和current_query_index
|
| 99 |
+
"""
|
| 100 |
+
print("---查询分解---")
|
| 101 |
+
question = state["question"]
|
| 102 |
+
|
| 103 |
+
# 使用分解器
|
| 104 |
+
sub_queries = self.graders["query_decomposer"].decompose(question)
|
| 105 |
+
|
| 106 |
+
# 如果分解器返回空或只有一个问题,我们仍然将其视为列表
|
| 107 |
+
if not sub_queries:
|
| 108 |
+
sub_queries = [question]
|
| 109 |
+
|
| 110 |
+
print(f" 生成了 {len(sub_queries)} 个子查询")
|
| 111 |
+
|
| 112 |
+
return {
|
| 113 |
+
"sub_queries": sub_queries,
|
| 114 |
+
"current_query_index": 0,
|
| 115 |
+
"question": sub_queries[0], # 将当前问题设置为第一个子查询
|
| 116 |
+
"original_question": question, # 保存原始问题
|
| 117 |
+
"documents": [], # 清空文档,准备开始新的检索
|
| 118 |
+
"retry_count": 0
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
def retrieve(self, state):
|
| 122 |
"""
|
| 123 |
检索文档
|
|
|
|
| 149 |
|
| 150 |
# 记录使用的检索方法
|
| 151 |
if ENABLE_HYBRID_SEARCH:
|
| 152 |
+
print("---使用混合检索(向量+关键词)---")
|
| 153 |
if ENABLE_QUERY_EXPANSION:
|
| 154 |
print("---使用查询扩展---")
|
| 155 |
if image_paths and ENABLE_MULTIMODAL:
|
|
|
|
| 174 |
print(f"❌ 回退检索也失败: {fallback_e}")
|
| 175 |
documents = []
|
| 176 |
|
| 177 |
+
# === 向量多跳检索支持:合并上下文 ===
|
| 178 |
+
# 如果这不是第一次检索(即重试次数 > 0 或 正在处理后续子查询),说明之前的检索结果可能不完整或问题被重写了
|
| 179 |
+
# 我们应该保留之前的有价值文档,实现 "累积式上下文" (Accumulated Context)
|
| 180 |
+
current_query_index = state.get("current_query_index", 0)
|
| 181 |
+
if (retry_count > 0 or current_query_index > 0) and "documents" in state and state["documents"]:
|
| 182 |
+
print(f"---多跳上下文合并 (轮次 {retry_count}, 子查询 {current_query_index})---")
|
| 183 |
+
previous_docs = state["documents"]
|
| 184 |
+
if previous_docs:
|
| 185 |
+
# 简单的去重合并(基于内容)
|
| 186 |
+
current_content = {d.page_content for d in documents}
|
| 187 |
+
merged_count = 0
|
| 188 |
+
for prev_doc in previous_docs:
|
| 189 |
+
# 只有当内容不重复时才添加
|
| 190 |
+
if prev_doc.page_content not in current_content:
|
| 191 |
+
documents.append(prev_doc)
|
| 192 |
+
current_content.add(prev_doc.page_content)
|
| 193 |
+
merged_count += 1
|
| 194 |
+
print(f" 合并了 {merged_count} 个来自上一轮/上一跳的文档,当前总文档数: {len(documents)}")
|
| 195 |
+
# =================================
|
| 196 |
+
|
| 197 |
# 计算检索时间
|
| 198 |
retrieval_time = time.time() - retrieval_start_time
|
| 199 |
|
|
|
|
| 219 |
"""
|
| 220 |
print("---生成---")
|
| 221 |
question = state["question"]
|
| 222 |
+
original_question = state.get("original_question", question) # 优先使用原始问题
|
| 223 |
documents = state["documents"]
|
| 224 |
|
| 225 |
+
# RAG生成 - 使用原始问题以确保回答用户的初始意图
|
| 226 |
+
# 如果用户有特定的格式要求(如Markdown),通常包含在original_question中
|
| 227 |
+
generation = self.rag_chain.invoke({"context": documents, "question": original_question})
|
| 228 |
return {"documents": documents, "question": question, "generation": generation}
|
| 229 |
|
| 230 |
def grade_documents(self, state):
|
|
|
|
| 271 |
|
| 272 |
print(f" 重试次数: {retry_count}")
|
| 273 |
|
| 274 |
+
# 提取当前上下文摘要,帮助重写器理解缺失信息
|
| 275 |
+
context_summary = ""
|
| 276 |
+
if documents:
|
| 277 |
+
# 只提取前两个文档的摘要,避免上下文过长
|
| 278 |
+
docs_content = [d.page_content for d in documents[:2]]
|
| 279 |
+
context_summary = "\n---\n".join(docs_content)
|
| 280 |
+
# 截断以防止过长
|
| 281 |
+
if len(context_summary) > 2000:
|
| 282 |
+
context_summary = context_summary[:2000] + "...(截断)"
|
| 283 |
+
|
| 284 |
+
# 重写问题,传入上下文
|
| 285 |
+
better_question = self.graders["query_rewriter"].rewrite(question, context=context_summary)
|
| 286 |
return {"documents": documents, "question": better_question, "retry_count": retry_count}
|
| 287 |
|
| 288 |
def web_search(self, state):
|
|
|
|
| 327 |
print("---将问题路由到RAG---")
|
| 328 |
return "vectorstore"
|
| 329 |
|
| 330 |
+
def prepare_next_query(self, state):
|
| 331 |
+
"""
|
| 332 |
+
准备下一个子查询:提取桥接实体并重写查询
|
| 333 |
+
|
| 334 |
+
Args:
|
| 335 |
+
state (dict): 当前图状态
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
state (dict): 更新question, current_query_index, retry_count
|
| 339 |
+
"""
|
| 340 |
+
print("---准备下一个子查询---")
|
| 341 |
+
current_query_index = state.get("current_query_index", 0)
|
| 342 |
+
sub_queries = state.get("sub_queries", [])
|
| 343 |
+
documents = state["documents"]
|
| 344 |
+
|
| 345 |
+
# 移动到下一个索引
|
| 346 |
+
next_index = current_query_index + 1
|
| 347 |
+
next_query_raw = sub_queries[next_index]
|
| 348 |
+
|
| 349 |
+
print(f" 原始下一个子查询: {next_query_raw}")
|
| 350 |
+
|
| 351 |
+
# 提取上下文摘要用于重写(桥接实体提取)
|
| 352 |
+
context_summary = ""
|
| 353 |
+
if documents:
|
| 354 |
+
# 使用所有相关文档作为上下文
|
| 355 |
+
docs_content = [d.page_content for d in documents]
|
| 356 |
+
context_summary = "\n---\n".join(docs_content)
|
| 357 |
+
# 截断
|
| 358 |
+
if len(context_summary) > 3000:
|
| 359 |
+
context_summary = context_summary[:3000] + "...(截断)"
|
| 360 |
+
|
| 361 |
+
# 使用重写器将上下文(包含桥接实体)注入到下一个查询中
|
| 362 |
+
# 例如:Q1结果是"作者是J.K. Rowling",Q2是"她出生在哪里?" -> "J.K. Rowling出生在哪里?"
|
| 363 |
+
better_next_query = self.graders["query_rewriter"].rewrite(next_query_raw, context=context_summary)
|
| 364 |
+
|
| 365 |
+
print(f" 优化后的下一个子查询: {better_next_query}")
|
| 366 |
+
|
| 367 |
+
return {
|
| 368 |
+
"question": better_next_query,
|
| 369 |
+
"current_query_index": next_index,
|
| 370 |
+
"retry_count": 0, # 重置重试计数
|
| 371 |
+
"documents": documents # 保留文档作为上下文
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
def decide_to_generate(self, state):
|
| 375 |
"""
|
| 376 |
+
确定是生成答案、继续下一个子查询还是重新生成问题
|
| 377 |
|
| 378 |
Args:
|
| 379 |
state (dict): 当前图状态
|
| 380 |
|
| 381 |
Returns:
|
| 382 |
+
str: 要调用的下一个节点的决策
|
| 383 |
"""
|
| 384 |
print("---评估已评分的文档---")
|
| 385 |
filtered_documents = state["documents"]
|
| 386 |
+
current_query_index = state.get("current_query_index", 0)
|
| 387 |
+
sub_queries = state.get("sub_queries", [])
|
| 388 |
+
original_question = state.get("original_question", "")
|
| 389 |
|
| 390 |
if not filtered_documents:
|
| 391 |
# 所有文档都被过滤掉了
|
|
|
|
| 393 |
print("---决策:所有文档都与问题不相关,转换查询---")
|
| 394 |
return "transform_query"
|
| 395 |
else:
|
| 396 |
+
# 我们有相关文档
|
| 397 |
+
# 检查是否有更多子查询
|
| 398 |
+
if sub_queries and current_query_index < len(sub_queries) - 1:
|
| 399 |
+
# === 早期终止检查 ===
|
| 400 |
+
# 检查当前累积的文档是否已经足以回答原始问题
|
| 401 |
+
if original_question:
|
| 402 |
+
print("---检查是否已获取足够信息 (早期终止)---")
|
| 403 |
+
|
| 404 |
+
# 准备文档上下文
|
| 405 |
+
docs_content = [d.page_content for d in filtered_documents]
|
| 406 |
+
context_summary = "\n---\n".join(docs_content)
|
| 407 |
+
if len(context_summary) > 5000: # 限制上下文长度
|
| 408 |
+
context_summary = context_summary[:5000]
|
| 409 |
+
|
| 410 |
+
score = self.graders["answerability_grader"].grade(original_question, context_summary)
|
| 411 |
+
|
| 412 |
+
if score == "yes":
|
| 413 |
+
print(f"---决策:当前信息已足够回答原始问题,跳过剩余 {len(sub_queries) - 1 - current_query_index} 个子查询---")
|
| 414 |
+
return "generate"
|
| 415 |
+
else:
|
| 416 |
+
print("---决策:信息尚不完整,继续下一个子查询---")
|
| 417 |
+
|
| 418 |
+
print(f"---决策:当前子查询 ({current_query_index + 1}/{len(sub_queries)}) 完成,准备下一个---")
|
| 419 |
+
return "prepare_next_query"
|
| 420 |
+
else:
|
| 421 |
+
# 所有子查询都完成(或没有子查询),生成答案
|
| 422 |
+
print("---决策:所有子查询完成,生成---")
|
| 423 |
+
return "generate"
|
| 424 |
|
| 425 |
def grade_generation_v_documents_and_question(self, state):
|
| 426 |
"""
|