lanny xu commited on
Commit
9cce495
·
1 Parent(s): 55a0955
Files changed (4) hide show
  1. config.py +2 -2
  2. main.py +56 -6
  3. routers_and_graders.py +94 -7
  4. workflow_nodes.py +160 -18
config.py CHANGED
@@ -51,8 +51,8 @@ LOCAL_LLM = "mistral" # 在Kaggle中可改为 "phi" 或 "tinyllama"
51
  # 知识库URL配置
52
  KNOWLEDGE_BASE_URLS = [
53
  "https://lilianweng.github.io/posts/2023-06-23-agent/",
54
- "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
55
- "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
56
  ]
57
 
58
  # 文档分块配置
 
51
  # 知识库URL配置
52
  KNOWLEDGE_BASE_URLS = [
53
  "https://lilianweng.github.io/posts/2023-06-23-agent/",
54
+ # "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
55
+ # "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
56
  ]
57
 
58
  # 文档分块配置
main.py CHANGED
@@ -3,13 +3,20 @@
3
  集成所有模块,构建工作流并运行自适应RAG系统
4
  """
5
 
 
6
  from langgraph.graph import END, StateGraph, START
7
  from pprint import pprint
8
 
9
- from config import setup_environment, validate_api_keys
10
  from document_processor import initialize_document_processor
11
  from routers_and_graders import initialize_graders_and_router
12
  from workflow_nodes import WorkflowNodes, GraphState
 
 
 
 
 
 
13
 
14
 
15
  class AdaptiveRAGSystem:
@@ -54,6 +61,23 @@ class AdaptiveRAGSystem:
54
  print("初始化评分器和路由器...")
55
  self.graders = initialize_graders_and_router()
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # 初始化工作流节点
58
  print("设置工作流节点...")
59
  # WorkflowNodes 将在 _build_workflow 中初始化
@@ -91,6 +115,8 @@ class AdaptiveRAGSystem:
91
  workflow.add_node("grade_documents", self.workflow_nodes.grade_documents)
92
  workflow.add_node("generate", self.workflow_nodes.generate)
93
  workflow.add_node("transform_query", self.workflow_nodes.transform_query)
 
 
94
 
95
  # 构建图
96
  workflow.add_conditional_edges(
@@ -98,20 +124,23 @@ class AdaptiveRAGSystem:
98
  self.workflow_nodes.route_question,
99
  {
100
  "web_search": "web_search",
101
- "vectorstore": "retrieve",
102
  },
103
  )
104
  workflow.add_edge("web_search", "generate")
 
105
  workflow.add_edge("retrieve", "grade_documents")
106
  workflow.add_conditional_edges(
107
  "grade_documents",
108
  self.workflow_nodes.decide_to_generate,
109
  {
110
  "transform_query": "transform_query",
 
111
  "generate": "generate",
112
  },
113
  )
114
  workflow.add_edge("transform_query", "retrieve")
 
115
  workflow.add_conditional_edges(
116
  "generate",
117
  self.workflow_nodes.grade_generation_v_documents_and_question,
@@ -151,10 +180,16 @@ class AdaptiveRAGSystem:
151
  # 设置配置,增加递归限制
152
  config = {"recursion_limit": 50} # 增加到 50,默认是 25
153
 
 
154
  for output in self.app.stream(inputs, config=config):
155
  for key, value in output.items():
156
  if verbose:
157
- pprint(f"节点 '{key}':")
 
 
 
 
 
158
  # 可选:在每个节点打印完整状态
159
  # pprint(value, indent=2, width=80, depth=None)
160
  final_generation = value.get("generation", final_generation)
@@ -162,11 +197,25 @@ class AdaptiveRAGSystem:
162
  if "retrieval_metrics" in value:
163
  retrieval_metrics = value["retrieval_metrics"]
164
  if verbose:
165
- pprint("\n---\n")
 
166
 
 
167
  print("🎯 最终答案:")
168
  print("-" * 30)
169
- print(final_generation)
 
 
 
 
 
 
 
 
 
 
 
 
170
  print("=" * 50)
171
 
172
  # 返回包含答案和评估指标的字典
@@ -220,7 +269,8 @@ def main():
220
  rag_system: AdaptiveRAGSystem = AdaptiveRAGSystem()
221
 
222
  # 测试查询
223
- test_question = "AlphaCodium论文讲的是什么?"
 
224
  # test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
225
  result = rag_system.query(test_question)
226
 
 
3
  集成所有模块,构建工作流并运行自适应RAG系统
4
  """
5
 
6
+ import time
7
  from langgraph.graph import END, StateGraph, START
8
  from pprint import pprint
9
 
10
+ from config import setup_environment, validate_api_keys, ENABLE_GRAPHRAG
11
  from document_processor import initialize_document_processor
12
  from routers_and_graders import initialize_graders_and_router
13
  from workflow_nodes import WorkflowNodes, GraphState
14
+ try:
15
+ from knowledge_graph import initialize_knowledge_graph, initialize_community_summarizer
16
+ from graph_retriever import initialize_graph_retriever
17
+ except ImportError:
18
+ print("⚠️ 无法导入知识图谱模块,GraphRAG功能将不可用")
19
+ ENABLE_GRAPHRAG = False
20
 
21
 
22
  class AdaptiveRAGSystem:
 
61
  print("初始化评分器和路由器...")
62
  self.graders = initialize_graders_and_router()
63
 
64
+ # 初始化知识图谱 (如果启用)
65
+ self.graph_retriever = None
66
+ if ENABLE_GRAPHRAG:
67
+ print("初始化 GraphRAG...")
68
+ try:
69
+ kg = initialize_knowledge_graph()
70
+ # 尝试加载已有的图谱数据
71
+ try:
72
+ kg.load_from_file("knowledge_graph.json")
73
+ except FileNotFoundError:
74
+ print(" 未找到 existing knowledge_graph.json, 将使用空图谱")
75
+
76
+ self.graph_retriever = initialize_graph_retriever(kg)
77
+ print("✅ GraphRAG 初始化成功")
78
+ except Exception as e:
79
+ print(f"⚠️ GraphRAG 初始化失败: {e}")
80
+
81
  # 初始化工作流节点
82
  print("设置工作流节点...")
83
  # WorkflowNodes 将在 _build_workflow 中初始化
 
115
  workflow.add_node("grade_documents", self.workflow_nodes.grade_documents)
116
  workflow.add_node("generate", self.workflow_nodes.generate)
117
  workflow.add_node("transform_query", self.workflow_nodes.transform_query)
118
+ workflow.add_node("decompose_query", self.workflow_nodes.decompose_query)
119
+ workflow.add_node("prepare_next_query", self.workflow_nodes.prepare_next_query)
120
 
121
  # 构建图
122
  workflow.add_conditional_edges(
 
124
  self.workflow_nodes.route_question,
125
  {
126
  "web_search": "web_search",
127
+ "vectorstore": "decompose_query", # 向量检索前先进行查询分解
128
  },
129
  )
130
  workflow.add_edge("web_search", "generate")
131
+ workflow.add_edge("decompose_query", "retrieve")
132
  workflow.add_edge("retrieve", "grade_documents")
133
  workflow.add_conditional_edges(
134
  "grade_documents",
135
  self.workflow_nodes.decide_to_generate,
136
  {
137
  "transform_query": "transform_query",
138
+ "prepare_next_query": "prepare_next_query",
139
  "generate": "generate",
140
  },
141
  )
142
  workflow.add_edge("transform_query", "retrieve")
143
+ workflow.add_edge("prepare_next_query", "retrieve")
144
  workflow.add_conditional_edges(
145
  "generate",
146
  self.workflow_nodes.grade_generation_v_documents_and_question,
 
180
  # 设置配置,增加递归限制
181
  config = {"recursion_limit": 50} # 增加到 50,默认是 25
182
 
183
+ print("\n🤖 思考过程:")
184
  for output in self.app.stream(inputs, config=config):
185
  for key, value in output.items():
186
  if verbose:
187
+ # 简单的节点执行提示,模拟流式感
188
+ print(f" ↳ 执行节点: {key}...", end="\r")
189
+ time.sleep(0.1) # 视觉暂停
190
+ print(f" ✅ 完成节点: {key} ")
191
+
192
+ # pprint(f"节点 '{key}':")
193
  # 可选:在每个节点打印完整状态
194
  # pprint(value, indent=2, width=80, depth=None)
195
  final_generation = value.get("generation", final_generation)
 
197
  if "retrieval_metrics" in value:
198
  retrieval_metrics = value["retrieval_metrics"]
199
  if verbose:
200
+ # pprint("\n---\n")
201
+ pass
202
 
203
+ print("\n" + "=" * 50)
204
  print("🎯 最终答案:")
205
  print("-" * 30)
206
+
207
+ # 模拟流式输出效果 (打字机效果)
208
+ if final_generation:
209
+ import sys
210
+ import time
211
+ for char in final_generation:
212
+ sys.stdout.write(char)
213
+ sys.stdout.flush()
214
+ time.sleep(0.01) # 控制打字速度
215
+ print() # 换行
216
+ else:
217
+ print("未生成答案")
218
+
219
  print("=" * 50)
220
 
221
  # 返回包含答案和评估指标的字典
 
269
  rag_system: AdaptiveRAGSystem = AdaptiveRAGSystem()
270
 
271
  # 测试查询
272
+ # test_question = "AlphaCodium论文讲的是什么?"
273
+ test_question = "LangGraph的作者目前在哪家公司工作?"
274
  # test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
275
  result = rag_system.query(test_question)
276
 
routers_and_graders.py CHANGED
@@ -3,6 +3,7 @@
3
  包含查询路由、文档相关性评分、答案质量评分和幻觉检测
4
  """
5
 
 
6
  try:
7
  from langchain_core.prompts import PromptTemplate
8
  except ImportError:
@@ -152,23 +153,105 @@ class HallucinationGrader:
152
  return result.get("score", "no")
153
 
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  class QueryRewriter:
156
  """查询重写器,优化查询以获得更好的检索结果"""
157
 
158
  def __init__(self):
159
  self.llm = ChatOllama(model=LOCAL_LLM, temperature=0)
160
  self.prompt = PromptTemplate(
161
- template="""你是一个问题重写器,将输入问题转换为更适合向量存储检索的更好版本。
162
- 查看初始问题并制定一个改进的问题。
163
- 这里是初始问题:\n\n {question}。改进的问题(无前言):\n """,
164
- input_variables=["question"],
 
 
 
 
 
 
 
 
165
  )
166
  self.rewriter = self.prompt | self.llm | StrOutputParser()
167
 
168
- def rewrite(self, question: str) -> str:
169
  """重写查询以获得更好的检索效果"""
170
  print(f"---原始查询: {question}---")
171
- rewritten_query = self.rewriter.invoke({"question": question})
 
 
 
172
  print(f"---重写查询: {rewritten_query}---")
173
  return rewritten_query
174
 
@@ -187,11 +270,15 @@ def initialize_graders_and_router():
187
  answer_grader = AnswerGrader()
188
  hallucination_grader = HallucinationGrader(method=detection_method)
189
  query_rewriter = QueryRewriter()
 
 
190
 
191
  return {
192
  "query_router": query_router,
193
  "document_grader": document_grader,
194
  "answer_grader": answer_grader,
195
  "hallucination_grader": hallucination_grader,
196
- "query_rewriter": query_rewriter
 
 
197
  }
 
3
  包含查询路由、文档相关性评分、答案质量评分和幻觉检测
4
  """
5
 
6
+ from typing import List
7
  try:
8
  from langchain_core.prompts import PromptTemplate
9
  except ImportError:
 
153
  return result.get("score", "no")
154
 
155
 
156
+ class QueryDecomposer:
157
+ """查询分解器,将复杂的多跳问题分解为子问题序列"""
158
+
159
+ def __init__(self):
160
+ self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
161
+ self.prompt = PromptTemplate(
162
+ template="""你是一个查询分解专家。你的任务是将一个复杂的多跳问题分解为一系列简单的子问题,这些子问题可以按顺序检索来回答原始问题。
163
+
164
+ 分解规则:
165
+ 1. 如果问题很简单,不需要分解,返回只包含原始问题的列表。
166
+ 2. 如果问题需要多步推理(例如"A的作者的大学在哪里"),分解为逻辑步骤:
167
+ - 步骤1: "谁是A的作者?"
168
+ - 步骤2: "该作者在哪个大学?"
169
+ 3. 保持子问题简洁明了。
170
+ 4. 即使返回单个问题,也必须包装在JSON的 sub_queries 列表中。
171
+
172
+ 输出格式:返回一个包含 'sub_queries' 键的 JSON,其值为字符串列表。
173
+ 不要输出任何前言或解释。
174
+
175
+ 复杂问题: {question}""",
176
+ input_variables=["question"],
177
+ )
178
+ self.decomposer = self.prompt | self.llm | JsonOutputParser()
179
+
180
+ def decompose(self, question: str) -> List[str]:
181
+ """分解问题"""
182
+ print(f"---分解问题: {question}---")
183
+ try:
184
+ result = self.decomposer.invoke({"question": question})
185
+ sub_queries = result.get("sub_queries", [question])
186
+ # 确保至少包含原始问题
187
+ if not sub_queries:
188
+ sub_queries = [question]
189
+ print(f"---子问题: {sub_queries}---")
190
+ return sub_queries
191
+ except Exception as e:
192
+ print(f"⚠️ 分解失败: {e},使用原始问题")
193
+ return [question]
194
+
195
+
196
+ class AnswerabilityGrader:
197
+ """答案可回答性评分器,用于判断当前检索到的文档是否足够回答原始问题"""
198
+
199
+ def __init__(self):
200
+ self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
201
+ self.prompt = PromptTemplate(
202
+ template="""你是一个专家评分员,负责评估检索到的文档是否包含足够的信息来回答用户的问题。
203
+
204
+ 原始问题: {question}
205
+
206
+ 目前检索到的文档集合:
207
+ {documents}
208
+
209
+ 任务:
210
+ 判断上述文档是否已经包含了回答原始问题所需的全部关键信息。
211
+ - 如果信息充足,可以终止进一步的检索,返回 'yes'。
212
+ - 如果信息缺失,需要继续检索更多信息,返回 'no'。
213
+
214
+ 输出格式:
215
+ 返回一个只包含 'score' 键的 JSON,值为 'yes' 或 'no'。
216
+ 不要输出任何前言或解释。""",
217
+ input_variables=["question", "documents"],
218
+ )
219
+ self.grader = self.prompt | self.llm | JsonOutputParser()
220
+
221
+ def grade(self, question: str, documents: str) -> str:
222
+ """评估文档是否足以回答问题"""
223
+ result = self.grader.invoke({"question": question, "documents": documents})
224
+ return result.get("score", "no")
225
+
226
+
227
  class QueryRewriter:
228
  """查询重写器,优化查询以获得更好的检索结果"""
229
 
230
  def __init__(self):
231
  self.llm = ChatOllama(model=LOCAL_LLM, temperature=0)
232
  self.prompt = PromptTemplate(
233
+ template="""你是一个问题重写器,负责将输入问题转换为更适合向量存储检索的更好版本。
234
+
235
+ 你的目标是根据原始问题和(可选的)之前的检索上下文,生成一个新的查询,以便检索到回答问题所需的缺失信息。
236
+ 如果提供了之前的上下文,请分析其中缺少什么信息,并针对缺失的信息构建查询。
237
+
238
+ 初始问题: {question}
239
+
240
+ 之前的上下文(如果有):
241
+ {context}
242
+
243
+ 改进的问题(只输出问题,无前言):""",
244
+ input_variables=["question", "context"],
245
  )
246
  self.rewriter = self.prompt | self.llm | StrOutputParser()
247
 
248
+ def rewrite(self, question: str, context: str = "") -> str:
249
  """重写查询以获得更好的检索效果"""
250
  print(f"---原始查询: {question}---")
251
+ if context:
252
+ print(f"---参考上下文长度: {len(context)} 字符---")
253
+
254
+ rewritten_query = self.rewriter.invoke({"question": question, "context": context})
255
  print(f"---重写查询: {rewritten_query}---")
256
  return rewritten_query
257
 
 
270
  answer_grader = AnswerGrader()
271
  hallucination_grader = HallucinationGrader(method=detection_method)
272
  query_rewriter = QueryRewriter()
273
+ query_decomposer = QueryDecomposer()
274
+ answerability_grader = AnswerabilityGrader()
275
 
276
  return {
277
  "query_router": query_router,
278
  "document_grader": document_grader,
279
  "answer_grader": answer_grader,
280
  "hallucination_grader": hallucination_grader,
281
+ "query_rewriter": query_rewriter,
282
+ "query_decomposer": query_decomposer,
283
+ "answerability_grader": answerability_grader
284
  }
workflow_nodes.py CHANGED
@@ -49,6 +49,9 @@ class GraphState(TypedDict):
49
  documents: List[str]
50
  retry_count: int
51
  retrieval_metrics: dict # 添加检索评估指标
 
 
 
52
 
53
 
54
  class WorkflowNodes:
@@ -64,14 +67,18 @@ class WorkflowNodes:
64
 
65
  # 设置RAG链 - 使用本地提示模板
66
  rag_prompt_template = PromptTemplate(
67
- template="""你是一个问答助手。使用以下检索到的上下文来回答问题。
68
- 如果你不知道答案,就说你不知道。最多使用三句话并保持答案简洁。
69
-
70
- 问题: {question}
71
-
72
- 上下文: {context}
73
-
74
- 答案:""",
 
 
 
 
75
  input_variables=["question", "context"]
76
  )
77
  llm = ChatOllama(model=LOCAL_LLM, temperature=0)
@@ -80,6 +87,37 @@ class WorkflowNodes:
80
  # 设置网络搜索
81
  self.web_search_tool = TavilySearchResults(k=WEB_SEARCH_RESULTS_COUNT)
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def retrieve(self, state):
84
  """
85
  检索文档
@@ -111,7 +149,7 @@ class WorkflowNodes:
111
 
112
  # 记录使用的检索方法
113
  if ENABLE_HYBRID_SEARCH:
114
- print("---使用混合检索---")
115
  if ENABLE_QUERY_EXPANSION:
116
  print("---使用查询扩展---")
117
  if image_paths and ENABLE_MULTIMODAL:
@@ -136,6 +174,26 @@ class WorkflowNodes:
136
  print(f"❌ 回退检索也失败: {fallback_e}")
137
  documents = []
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  # 计算检索时间
140
  retrieval_time = time.time() - retrieval_start_time
141
 
@@ -161,10 +219,12 @@ class WorkflowNodes:
161
  """
162
  print("---生成---")
163
  question = state["question"]
 
164
  documents = state["documents"]
165
 
166
- # RAG生成
167
- generation = self.rag_chain.invoke({"context": documents, "question": question})
 
168
  return {"documents": documents, "question": question, "generation": generation}
169
 
170
  def grade_documents(self, state):
@@ -211,8 +271,18 @@ class WorkflowNodes:
211
 
212
  print(f" 重试次数: {retry_count}")
213
 
214
- # 重写问题
215
- better_question = self.graders["query_rewriter"].rewrite(question)
 
 
 
 
 
 
 
 
 
 
216
  return {"documents": documents, "question": better_question, "retry_count": retry_count}
217
 
218
  def web_search(self, state):
@@ -257,18 +327,65 @@ class WorkflowNodes:
257
  print("---将问题路由到RAG---")
258
  return "vectorstore"
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  def decide_to_generate(self, state):
261
  """
262
- 确定是生成答案还是重新生成问题
263
 
264
  Args:
265
  state (dict): 当前图状态
266
 
267
  Returns:
268
- str: 要调用的下一个节点的二进制决策
269
  """
270
  print("---评估已评分的文档---")
271
  filtered_documents = state["documents"]
 
 
 
272
 
273
  if not filtered_documents:
274
  # 所有文档都被过滤掉了
@@ -276,9 +393,34 @@ class WorkflowNodes:
276
  print("---决策:所有文档都与问题不相关,转换查询---")
277
  return "transform_query"
278
  else:
279
- # 我们有相关文档,所以生成答案
280
- print("---决策:生成---")
281
- return "generate"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
  def grade_generation_v_documents_and_question(self, state):
284
  """
 
49
  documents: List[str]
50
  retry_count: int
51
  retrieval_metrics: dict # 添加检索评估指标
52
+ sub_queries: List[str] # 分解后的子问题列表
53
+ current_query_index: int # 当前处理的子问题索引
54
+ original_question: str # 原始问题,用于早期终止检查
55
 
56
 
57
  class WorkflowNodes:
 
67
 
68
  # 设置RAG链 - 使用本地提示模板
69
  rag_prompt_template = PromptTemplate(
70
+ template="""你是一个智能问答助手。使用以下检索到的上下文来回答问题。
71
+
72
+ 规则:
73
+ 1. 如果你不知道答案,就说你不知道。
74
+ 2. 如果用户请求特定格式(如Markdown、列表、代码块等),请严格遵守。
75
+ 3. 如果没有特定格式要求,保持答案简洁。
76
+
77
+ 问题: {question}
78
+
79
+ 上下文: {context}
80
+
81
+ 答案:""",
82
  input_variables=["question", "context"]
83
  )
84
  llm = ChatOllama(model=LOCAL_LLM, temperature=0)
 
87
  # 设置网络搜索
88
  self.web_search_tool = TavilySearchResults(k=WEB_SEARCH_RESULTS_COUNT)
89
 
90
+ def decompose_query(self, state):
91
+ """
92
+ 将初始查询分解为子查询
93
+
94
+ Args:
95
+ state (dict): 当前图状态
96
+
97
+ Returns:
98
+ state (dict): 更新sub_queries和current_query_index
99
+ """
100
+ print("---查询分解---")
101
+ question = state["question"]
102
+
103
+ # 使用分解器
104
+ sub_queries = self.graders["query_decomposer"].decompose(question)
105
+
106
+ # 如果分解器返回空或只有一个问题,我们仍然将其视为列表
107
+ if not sub_queries:
108
+ sub_queries = [question]
109
+
110
+ print(f" 生成了 {len(sub_queries)} 个子查询")
111
+
112
+ return {
113
+ "sub_queries": sub_queries,
114
+ "current_query_index": 0,
115
+ "question": sub_queries[0], # 将当前问题设置为第一个子查询
116
+ "original_question": question, # 保存原始问题
117
+ "documents": [], # 清空文档,准备开始新的检索
118
+ "retry_count": 0
119
+ }
120
+
121
  def retrieve(self, state):
122
  """
123
  检索文档
 
149
 
150
  # 记录使用的检索方法
151
  if ENABLE_HYBRID_SEARCH:
152
+ print("---使用混合检索(向量+关键词)---")
153
  if ENABLE_QUERY_EXPANSION:
154
  print("---使用查询扩展---")
155
  if image_paths and ENABLE_MULTIMODAL:
 
174
  print(f"❌ 回退检索也失败: {fallback_e}")
175
  documents = []
176
 
177
+ # === 向量多跳检索支持:合并上下文 ===
178
+ # 如果这不是第一次检索(即重试次数 > 0 或 正在处理后续子查询),说明之前的检索结果可能不完整或问题被重写了
179
+ # 我们应该保留之前的有价值文档,实现 "累积式上下文" (Accumulated Context)
180
+ current_query_index = state.get("current_query_index", 0)
181
+ if (retry_count > 0 or current_query_index > 0) and "documents" in state and state["documents"]:
182
+ print(f"---多跳上下文合并 (轮次 {retry_count}, 子查询 {current_query_index})---")
183
+ previous_docs = state["documents"]
184
+ if previous_docs:
185
+ # 简单的去重合并(基于内容)
186
+ current_content = {d.page_content for d in documents}
187
+ merged_count = 0
188
+ for prev_doc in previous_docs:
189
+ # 只有当内容不重复时才添加
190
+ if prev_doc.page_content not in current_content:
191
+ documents.append(prev_doc)
192
+ current_content.add(prev_doc.page_content)
193
+ merged_count += 1
194
+ print(f" 合并了 {merged_count} 个来自上一轮/上一跳的文档,当前总文档数: {len(documents)}")
195
+ # =================================
196
+
197
  # 计算检索时间
198
  retrieval_time = time.time() - retrieval_start_time
199
 
 
219
  """
220
  print("---生成---")
221
  question = state["question"]
222
+ original_question = state.get("original_question", question) # 优先使用原始问题
223
  documents = state["documents"]
224
 
225
+ # RAG生成 - 使用原始问题以确保回答用户的初始意图
226
+ # 如果用户有特定的格式要求(如Markdown),通常包含在original_question中
227
+ generation = self.rag_chain.invoke({"context": documents, "question": original_question})
228
  return {"documents": documents, "question": question, "generation": generation}
229
 
230
  def grade_documents(self, state):
 
271
 
272
  print(f" 重试次数: {retry_count}")
273
 
274
+ # 提取当前上下文摘要,帮助重写器理解缺失信息
275
+ context_summary = ""
276
+ if documents:
277
+ # 只提取前两个文档的摘要,避免上下文过长
278
+ docs_content = [d.page_content for d in documents[:2]]
279
+ context_summary = "\n---\n".join(docs_content)
280
+ # 截断以防止过长
281
+ if len(context_summary) > 2000:
282
+ context_summary = context_summary[:2000] + "...(截断)"
283
+
284
+ # 重写问题,传入上下文
285
+ better_question = self.graders["query_rewriter"].rewrite(question, context=context_summary)
286
  return {"documents": documents, "question": better_question, "retry_count": retry_count}
287
 
288
  def web_search(self, state):
 
327
  print("---将问题路由到RAG---")
328
  return "vectorstore"
329
 
330
+ def prepare_next_query(self, state):
331
+ """
332
+ 准备下一个子查询:提取桥接实体并重写查询
333
+
334
+ Args:
335
+ state (dict): 当前图状态
336
+
337
+ Returns:
338
+ state (dict): 更新question, current_query_index, retry_count
339
+ """
340
+ print("---准备下一个子查询---")
341
+ current_query_index = state.get("current_query_index", 0)
342
+ sub_queries = state.get("sub_queries", [])
343
+ documents = state["documents"]
344
+
345
+ # 移动到下一个索引
346
+ next_index = current_query_index + 1
347
+ next_query_raw = sub_queries[next_index]
348
+
349
+ print(f" 原始下一个子查询: {next_query_raw}")
350
+
351
+ # 提取上下文摘要用于重写(桥接实体提取)
352
+ context_summary = ""
353
+ if documents:
354
+ # 使用所有相关文档作为上下文
355
+ docs_content = [d.page_content for d in documents]
356
+ context_summary = "\n---\n".join(docs_content)
357
+ # 截断
358
+ if len(context_summary) > 3000:
359
+ context_summary = context_summary[:3000] + "...(截断)"
360
+
361
+ # 使用重写器将上下文(包含桥接实体)注入到下一个查询中
362
+ # 例如:Q1结果是"作者是J.K. Rowling",Q2是"她出生在哪里?" -> "J.K. Rowling出生在哪里?"
363
+ better_next_query = self.graders["query_rewriter"].rewrite(next_query_raw, context=context_summary)
364
+
365
+ print(f" 优化后的下一个子查询: {better_next_query}")
366
+
367
+ return {
368
+ "question": better_next_query,
369
+ "current_query_index": next_index,
370
+ "retry_count": 0, # 重置重试计数
371
+ "documents": documents # 保留文档作为上下文
372
+ }
373
+
374
  def decide_to_generate(self, state):
375
  """
376
+ 确定是生成答案、继续下一个子查询还是重新生成问题
377
 
378
  Args:
379
  state (dict): 当前图状态
380
 
381
  Returns:
382
+ str: 要调用的下一个节点的决策
383
  """
384
  print("---评估已评分的文档---")
385
  filtered_documents = state["documents"]
386
+ current_query_index = state.get("current_query_index", 0)
387
+ sub_queries = state.get("sub_queries", [])
388
+ original_question = state.get("original_question", "")
389
 
390
  if not filtered_documents:
391
  # 所有文档都被过滤掉了
 
393
  print("---决策:所有文档都与问题不相关,转换查询---")
394
  return "transform_query"
395
  else:
396
+ # 我们有相关文档
397
+ # 检查是否有更多子查询
398
+ if sub_queries and current_query_index < len(sub_queries) - 1:
399
+ # === 早期终止检查 ===
400
+ # 检查当前累积的文档是否已经足以回答原始问题
401
+ if original_question:
402
+ print("---检查是否已获取足够信息 (早期终止)---")
403
+
404
+ # 准备文档上下文
405
+ docs_content = [d.page_content for d in filtered_documents]
406
+ context_summary = "\n---\n".join(docs_content)
407
+ if len(context_summary) > 5000: # 限制上下文长度
408
+ context_summary = context_summary[:5000]
409
+
410
+ score = self.graders["answerability_grader"].grade(original_question, context_summary)
411
+
412
+ if score == "yes":
413
+ print(f"---决策:当前信息已足够回答原始问题,跳过剩余 {len(sub_queries) - 1 - current_query_index} 个子查询---")
414
+ return "generate"
415
+ else:
416
+ print("---决策:信息尚不完整,继续下一个子查询---")
417
+
418
+ print(f"---决策:当前子查询 ({current_query_index + 1}/{len(sub_queries)}) 完成,准备下一个---")
419
+ return "prepare_next_query"
420
+ else:
421
+ # 所有子查询都完成(或没有子查询),生成答案
422
+ print("---决策:所有子查询完成,生成---")
423
+ return "generate"
424
 
425
  def grade_generation_v_documents_and_question(self, state):
426
  """