lanny xu commited on
Commit
a1bc3ec
·
1 Parent(s): 95a1f44

add relative files

Browse files
test_crossencoder_reranking.py DELETED
@@ -1,229 +0,0 @@
1
- """
2
- 测试 CrossEncoder 重排功能
3
- 对比 Bi-Encoder vs CrossEncoder 的效果
4
- """
5
-
6
- from reranker import create_reranker, TFIDFReranker, BM25Reranker, SemanticReranker, CrossEncoderReranker
7
-
8
-
9
- class MockDoc:
10
- """模拟文档类"""
11
- def __init__(self, content, metadata=None):
12
- self.page_content = content
13
- self.metadata = metadata or {}
14
-
15
-
16
- class MockEmbeddings:
17
- """模拟 Embeddings 类(用于 Semantic Reranker)"""
18
- def embed_query(self, text):
19
- # 简单的字符级向量化(仅用于测试)
20
- return [ord(c) / 100.0 for c in text[:10]]
21
-
22
- def embed_documents(self, texts):
23
- return [self.embed_query(text) for text in texts]
24
-
25
-
26
- def create_test_documents():
27
- """创建测试文档集"""
28
- return [
29
- MockDoc("人工智能是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。"),
30
- MockDoc("机器学习是人工智能的子领域,专注于让计算机从数据中学习并改进。"),
31
- MockDoc("深度学习使用多层神经网络来处理复杂的数据模式,是机器学习的一种方法。"),
32
- MockDoc("自然语言处理(NLP)是人工智能的一个分支,处理计算机与人类语言之间的交互。"),
33
- MockDoc("计算机视觉是人工智能的另一个重要领域,使机器能够理解和解释视觉信息。"),
34
- MockDoc("今天天气很好,适合出去散步和运动。"),
35
- MockDoc("Python 是一种高级编程语言,由 Guido van Rossum 在 1991 年创建。"),
36
- MockDoc("RAG(检索增强生成)是一种结合信息检索和文本生成的技术。"),
37
- ]
38
-
39
-
40
- def test_tfidf_reranking():
41
- """测试 TF-IDF 重排"""
42
- print("\n" + "=" * 60)
43
- print("📊 测试 TF-IDF 重排")
44
- print("=" * 60)
45
-
46
- query = "什么是人工智能和机器学习?"
47
- docs = create_test_documents()
48
-
49
- reranker = TFIDFReranker()
50
- results = reranker.rerank(query, docs, top_k=3)
51
-
52
- print(f"\n查询: {query}")
53
- print("\nTF-IDF 重排结果:")
54
- for i, (doc, score) in enumerate(results, 1):
55
- print(f"{i}. 分数: {score:.4f} | 内容: {doc.page_content[:50]}...")
56
-
57
-
58
- def test_bm25_reranking():
59
- """测试 BM25 重排"""
60
- print("\n" + "=" * 60)
61
- print("📊 测试 BM25 重排")
62
- print("=" * 60)
63
-
64
- query = "什么是人工智能和机器学习?"
65
- docs = create_test_documents()
66
-
67
- reranker = BM25Reranker()
68
- results = reranker.rerank(query, docs, top_k=3)
69
-
70
- print(f"\n查询: {query}")
71
- print("\nBM25 重排结果:")
72
- for i, (doc, score) in enumerate(results, 1):
73
- print(f"{i}. 分数: {score:.4f} | 内容: {doc.page_content[:50]}...")
74
-
75
-
76
- def test_crossencoder_reranking():
77
- """测试 CrossEncoder 重排"""
78
- print("\n" + "=" * 60)
79
- print("🌟 测试 CrossEncoder 重排(推荐)")
80
- print("=" * 60)
81
-
82
- query = "什么是人工智能和机器学习?"
83
- docs = create_test_documents()
84
-
85
- try:
86
- # 使用轻量级模型
87
- reranker = CrossEncoderReranker(
88
- model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"
89
- )
90
- results = reranker.rerank(query, docs, top_k=3)
91
-
92
- print(f"\n查询: {query}")
93
- print("\nCrossEncoder 重排结果:")
94
- for i, (doc, score) in enumerate(results, 1):
95
- print(f"{i}. 分数: {score:.4f} | 内容: {doc.page_content[:50]}...")
96
-
97
- return True
98
-
99
- except Exception as e:
100
- print(f"\n❌ CrossEncoder 测试失败: {e}")
101
- print("💡 提示: 请先安装 sentence-transformers")
102
- print(" 命令: pip install sentence-transformers")
103
- return False
104
-
105
-
106
- def test_factory_function():
107
- """测试工厂函数"""
108
- print("\n" + "=" * 60)
109
- print("🏭 测试重排器工厂函数")
110
- print("=" * 60)
111
-
112
- query = "深度学习和神经网络"
113
- docs = create_test_documents()
114
-
115
- # 测试各种类型
116
- reranker_types = ['tfidf', 'bm25']
117
-
118
- for rtype in reranker_types:
119
- try:
120
- reranker = create_reranker(rtype)
121
- results = reranker.rerank(query, docs, top_k=2)
122
- print(f"\n✅ {rtype.upper()} 重排器创建成功")
123
- print(f" Top 1: {results[0][1]:.4f} | {results[0][0].page_content[:40]}...")
124
- except Exception as e:
125
- print(f"\n❌ {rtype.upper()} 重排器失败: {e}")
126
-
127
- # 测试 CrossEncoder
128
- try:
129
- reranker = create_reranker('crossencoder')
130
- results = reranker.rerank(query, docs, top_k=2)
131
- print(f"\n✅ CROSSENCODER 重排器创建成功")
132
- print(f" Top 1: {results[0][1]:.4f} | {results[0][0].page_content[:40]}...")
133
- except Exception as e:
134
- print(f"\n❌ CROSSENCODER 重排器失败: {e}")
135
-
136
-
137
- def compare_all_methods():
138
- """对比所有重排方法"""
139
- print("\n" + "=" * 60)
140
- print("⚖️ 对比所有重排方法")
141
- print("=" * 60)
142
-
143
- query = "解释一下人工智能、机器学习和深度学习的关系"
144
- docs = create_test_documents()
145
-
146
- methods = {
147
- 'TF-IDF': TFIDFReranker(),
148
- 'BM25': BM25Reranker(),
149
- }
150
-
151
- # 尝试添加 CrossEncoder
152
- try:
153
- methods['CrossEncoder'] = CrossEncoderReranker()
154
- except:
155
- print("\n⚠️ CrossEncoder 不可用,跳过")
156
-
157
- print(f"\n查询: {query}\n")
158
-
159
- for method_name, reranker in methods.items():
160
- try:
161
- results = reranker.rerank(query, docs, top_k=3)
162
- print(f"\n{'=' * 40}")
163
- print(f"{method_name} 重排结果:")
164
- print('=' * 40)
165
- for i, (doc, score) in enumerate(results, 1):
166
- print(f"{i}. [{score:.4f}] {doc.page_content[:60]}...")
167
- except Exception as e:
168
- print(f"\n{method_name} 失败: {e}")
169
-
170
-
171
- def performance_comparison():
172
- """性能对比"""
173
- print("\n" + "=" * 60)
174
- print("⚡ 性能与准确性对比")
175
- print("=" * 60)
176
-
177
- print("""
178
- 重排方法对比:
179
-
180
- ┌─────────────────┬──────────┬──────────┬──────────┬────────────┐
181
- │ 方法 │ 准确率 │ 速度 │ 成本 │ 适用场景 │
182
- ├─────────────────┼──────────┼──────────┼──────────┼────────────┤
183
- │ TF-IDF │ ⭐⭐ │ ⚡⚡⚡ │ 极低 │ 关键词匹配 │
184
- │ BM25 │ ⭐⭐⭐ │ ⚡⚡⚡ │ 极低 │ 文本检索 │
185
- │ Bi-Encoder │ ⭐⭐⭐⭐ │ ⚡⚡ │ 低 │ 语义检索 │
186
- │ CrossEncoder 🌟 │ ⭐⭐⭐⭐⭐│ ⚡ │ 中 │ 精准重排 │
187
- │ Hybrid │ ⭐⭐⭐⭐ │ ⚡⚡ │ 低 │ 综合场景 │
188
- └─────────────────┴──────────┴──────────┴──────────┴────────────┘
189
-
190
- 推荐配置:
191
- 1️⃣ 两阶段检索:Bi-Encoder (快速召回) + CrossEncoder (精准重排)
192
- 2️⃣ 准确率优先:纯 CrossEncoder
193
- 3️⃣ 速度优先:BM25 或 Hybrid
194
-
195
- 当前项目配置:
196
- ✅ 已切换到 CrossEncoder 重排
197
- 📈 准确率预期提升:15-20%
198
- ⚡ 速度:单次重排 20-100ms (Top 20 文档)
199
- """)
200
-
201
-
202
- if __name__ == "__main__":
203
- print("\n🚀 开始测试 CrossEncoder 重排功能...\n")
204
-
205
- # 1. 测试 TF-IDF
206
- test_tfidf_reranking()
207
-
208
- # 2. 测试 BM25
209
- test_bm25_reranking()
210
-
211
- # 3. 测试 CrossEncoder (重点)
212
- crossencoder_available = test_crossencoder_reranking()
213
-
214
- # 4. 测试工厂函数
215
- test_factory_function()
216
-
217
- # 5. 对比所有方法
218
- compare_all_methods()
219
-
220
- # 6. 性能对比总结
221
- performance_comparison()
222
-
223
- print("\n" + "=" * 60)
224
- if crossencoder_available:
225
- print("✅ 所有测试完成!CrossEncoder 重排已就绪")
226
- else:
227
- print("⚠️ 测试完成,但 CrossEncoder 不可用")
228
- print(" 请运行: pip install sentence-transformers")
229
- print("=" * 60 + "\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_hallucination_detector.py DELETED
@@ -1,173 +0,0 @@
1
- """
2
- 测试专业幻觉检测器
3
- 对比 LLM-as-a-Judge vs Vectara/NLI
4
- """
5
-
6
- from hallucination_detector import (
7
- VectaraHallucinationDetector,
8
- NLIHallucinationDetector,
9
- HybridHallucinationDetector
10
- )
11
-
12
-
13
- def test_vectara_detector():
14
- """测试 Vectara 检测器"""
15
- print("=" * 60)
16
- print("🧪 测试 Vectara 幻觉检测器")
17
- print("=" * 60)
18
-
19
- detector = VectaraHallucinationDetector()
20
-
21
- # 测试用例 1: 正常回答(无幻觉)
22
- documents = """
23
- Python是一种高级编程语言。它由Guido van Rossum在1991年创建。
24
- Python强调代码可读性,使用缩进来定义代码块。
25
- """
26
- generation = "Python是由Guido van Rossum在1991年创建的高级编程语言。"
27
-
28
- print("\n📝 测试用例 1: 正常回答")
29
- print(f"文档: {documents[:100]}...")
30
- print(f"生成: {generation}")
31
- result = detector.detect(generation, documents)
32
- print(f"结果: {result}")
33
-
34
- # 测试用例 2: 幻觉回答
35
- generation_hallucinated = "Python是由Dennis Ritchie在1972年创建的。"
36
-
37
- print("\n📝 测试用例 2: 幻觉回答")
38
- print(f"生成: {generation_hallucinated}")
39
- result = detector.detect(generation_hallucinated, documents)
40
- print(f"结果: {result}")
41
-
42
- print("\n" + "=" * 60)
43
-
44
-
45
- def test_nli_detector():
46
- """测试 NLI 检测器"""
47
- print("\n" + "=" * 60)
48
- print("🧪 测试 NLI 幻觉检测器")
49
- print("=" * 60)
50
-
51
- detector = NLIHallucinationDetector()
52
-
53
- documents = """
54
- LangChain是一个用于构建LLM应用的框架。
55
- 它提供了链式调用、提示模板、内存管理等功能。
56
- """
57
-
58
- # 测试用例 1: 正常回答
59
- generation = "LangChain提供了链式调用和提示模板功能。"
60
-
61
- print("\n📝 测试用例 1: 正常回答")
62
- print(f"生成: {generation}")
63
- result = detector.detect(generation, documents)
64
- print(f"结果: {result}")
65
-
66
- # 测试用例 2: 幻觉回答
67
- generation_hallucinated = "LangChain是由OpenAI开发的数据库系统。它主要用于存储图片。"
68
-
69
- print("\n📝 测试用例 2: 幻觉回答")
70
- print(f"生成: {generation_hallucinated}")
71
- result = detector.detect(generation_hallucinated, documents)
72
- print(f"结果: {result}")
73
-
74
- print("\n" + "=" * 60)
75
-
76
-
77
- def test_hybrid_detector():
78
- """测试混合检测器"""
79
- print("\n" + "=" * 60)
80
- print("🧪 测试混合幻觉检测器 (推荐)")
81
- print("=" * 60)
82
-
83
- detector = HybridHallucinationDetector(use_vectara=True, use_nli=True)
84
-
85
- documents = """
86
- GraphRAG是一种结合图结构和RAG的方法。
87
- 它通过构建知识图谱来增强检索效果。
88
- 主要步骤包括实体提取、关系识别、社区检测和摘要生成。
89
- """
90
-
91
- # 测试用例 1: 正常回答
92
- generation = "GraphRAG通过知识图谱增强检索,包含实体提取和社区检测等步骤。"
93
-
94
- print("\n📝 测试用例 1: 正常回答")
95
- print(f"生成: {generation}")
96
- result = detector.detect(generation, documents)
97
- print(f"结果: {result}")
98
-
99
- # 测试用例 2: 幻觉回答
100
- generation_hallucinated = "GraphRAG是一个数据库管理系统,主要用于存储用户密码和财务数据。"
101
-
102
- print("\n📝 测试用例 2: 幻觉回答")
103
- print(f"生成: {generation_hallucinated}")
104
- result = detector.detect(generation_hallucinated, documents)
105
- print(f"结果: {result}")
106
-
107
- # 测试 grade 方法(兼容接口)
108
- print("\n📝 测试 grade 方法(兼容原有接口)")
109
- score = detector.grade(generation, documents)
110
- print(f"Grade 结果: {score} (yes=无幻觉, no=有幻觉)")
111
-
112
- print("\n" + "=" * 60)
113
-
114
-
115
- def compare_performance():
116
- """对比性能"""
117
- print("\n" + "=" * 60)
118
- print("📊 性能对比总结")
119
- print("=" * 60)
120
-
121
- print("""
122
- 方法对比:
123
-
124
- 1️⃣ LLM-as-a-Judge (原方法)
125
- 准确率: 60-75%
126
- 速度: 慢 (每次 2-5 秒)
127
- 成本: 高 (调用 LLM)
128
-
129
- 2️⃣ Vectara 专门检测模型
130
- 准确率: 90-95%
131
- 速度: 快 (每次 0.1-0.3 秒)
132
- 成本: 低 (本地推理)
133
-
134
- 3️⃣ NLI 模型
135
- 准确率: 85-90%
136
- 速度: 快 (每次 0.2-0.5 秒)
137
- 成本: 低 (本地推理)
138
-
139
- 4️⃣ 混合检测器 (推荐) ⭐
140
- 准确率: 95%+
141
- 速度: 中等 (每次 0.3-0.8 秒)
142
- 成本: 低
143
- 优势: 综合多个模型,准确率最高
144
- """)
145
-
146
- print("=" * 60)
147
-
148
-
149
- if __name__ == "__main__":
150
- print("\n🚀 开始测试专业幻觉检测器...\n")
151
-
152
- try:
153
- # 测试 Vectara
154
- test_vectara_detector()
155
- except Exception as e:
156
- print(f"❌ Vectara 测试失败: {e}")
157
-
158
- try:
159
- # 测试 NLI
160
- test_nli_detector()
161
- except Exception as e:
162
- print(f"❌ NLI 测试失败: {e}")
163
-
164
- try:
165
- # ���试混合检测器
166
- test_hybrid_detector()
167
- except Exception as e:
168
- print(f"❌ 混合检测器测试失败: {e}")
169
-
170
- # 性能对比
171
- compare_performance()
172
-
173
- print("\n✅ 测试完成!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_lightweight_detector.py DELETED
@@ -1,152 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- 轻量级幻觉检测器测试脚本
4
- 测试效果与性能,替代 Vectara 模型
5
- """
6
-
7
- import time
8
- from lightweight_hallucination_detector import LightweightHallucinationDetector
9
-
10
- def test_performance():
11
- """测试不同模型的性能和效果"""
12
- print("="*70)
13
- print("🚀 轻量级幻觉检测器性能测试")
14
- print("="*70)
15
-
16
- # 测试不同模型
17
- models_to_test = [
18
- "cross-encoder/nli-MiniLM2-L6-H768", # 推荐轻量方案
19
- "cross-encoder/nli-deberta-v3-xsmall", # 超轻量方案
20
- "cross-encoder/nli-roberta-base", # 高准确率方案
21
- ]
22
-
23
- # 测试数据
24
- documents = "巴黎是法国的首都,这是一座美丽的城市,拥有许多历史地标和博物馆。"
25
-
26
- test_cases = [
27
- ("完全正确", "巴黎是法国的首都。"),
28
- ("事实错误", "柏林是法国的首都。"),
29
- ("部分正确", "巴黎是德国的首都,但很美丽。"),
30
- ("语义等价", "法国的首都是巴黎。"),
31
- ("无关信息", "纽约是美国的一个大城市。"),
32
- ]
33
-
34
- results = []
35
-
36
- for model_name in models_to_test:
37
- print(f"\n📊 测试模型: {model_name}")
38
- print("-" * 50)
39
-
40
- try:
41
- detector = LightweightHallucinationDetector(model_name)
42
-
43
- model_results = {
44
- "model": model_name,
45
- "tests": []
46
- }
47
-
48
- for test_name, test_case in test_cases:
49
- start_time = time.time()
50
- result = detector.detect(test_case, documents)
51
- end_time = time.time()
52
-
53
- print(f" {test_name}:")
54
- print(f" 假设: {test_case}")
55
- print(f" 是否幻觉: {result['has_hallucination']}")
56
- print(f" 幻觉分数: {result['hallucination_score']:.3f}")
57
- print(f" 推理时间: {end_time - start_time:.3f}秒")
58
- print()
59
-
60
- model_results["tests"].append({
61
- "name": test_name,
62
- "case": test_case,
63
- "result": result,
64
- "time": end_time - start_time
65
- })
66
-
67
- results.append(model_results)
68
-
69
- except Exception as e:
70
- print(f" ❌ 模型测试失败: {e}")
71
-
72
- # 总结
73
- print("\n" + "="*70)
74
- print("📋 测试总结")
75
- print("="*70)
76
-
77
- for model_result in results:
78
- model = model_result["model"]
79
- tests = model_result["tests"]
80
-
81
- avg_time = sum(t["time"] for t in tests) / len(tests)
82
- correct_count = 0
83
-
84
- # 评估准确性
85
- expected_results = [False, True, True, False, False] # 预期结果
86
- for i, test in enumerate(tests):
87
- if test["result"]["has_hallucination"] == expected_results[i]:
88
- correct_count += 1
89
-
90
- accuracy = correct_count / len(tests) * 100
91
-
92
- print(f"\n🤖 {model}:")
93
- print(f" ⚡ 平均推理时间: {avg_time:.3f}秒")
94
- print(f" 🎯 准确率: {accuracy:.1f}% ({correct_count}/{len(tests)})")
95
- print(f" 📊 幻觉检测评分: {sum(t['result']['hallucination_score'] for t in tests):.2f}")
96
-
97
- def test_rag_scenarios():
98
- """测试RAG场景下的幻觉检测"""
99
- print("\n" + "="*70)
100
- print("🔍 RAG场景测试")
101
- print("="*70)
102
-
103
- # RAG测试数据
104
- rag_documents = """
105
- 产品信息:iPhone 14 Pro 是苹果公司在2022年9月发布的旗舰智能手机。
106
- 主要特性:配备6.1英寸Super Retina XDR显示屏,A16仿生芯片,4800万像素主摄像头。
107
- 电池续航:视频播放最长可达23小时,支持20W有线快充。
108
- 价格:起售价为799美元。
109
- """
110
-
111
- rag_test_cases = [
112
- ("准确信息", "iPhone 14 Pro配备了A16仿生芯片和4800万像素摄像头。"),
113
- ("规格错误", "iPhone 14 Pro配备A15仿生芯片和1200万像素摄像头。"),
114
- ("价格错误", "iPhone 14 Pro的起售价为999美元。"),
115
- ("无关信息", "iPhone 14 Pro支持手写笔输入。"),
116
- ("混合信息", "iPhone 14 Pro配备A16芯片,起售价999美元,支持手写笔。"),
117
- ]
118
-
119
- detector = LightweightHallucinationDetector()
120
-
121
- print("🧪 RAG幻觉检测测试:\n")
122
-
123
- for test_name, test_case in rag_test_cases:
124
- result = detector.detect(test_case, rag_documents, method="sentence_level")
125
-
126
- print(f"📋 {test_name}:")
127
- print(f" 生成内容: {test_case}")
128
- print(f" 检测结果: {'🚨 检测到幻觉' if result['has_hallucination'] else '✅ 未检测到幻觉'}")
129
- print(f" 幻觉分数: {result['hallucination_score']:.3f}")
130
- print(f" 事实性分数: {result['factuality_score']:.3f}")
131
-
132
- if result['details'].get('problematic_sentences'):
133
- print(f" 问题句子数: {len(result['details']['problematic_sentences'])}")
134
- for i, prob in enumerate(result['details']['problematic_sentences'], 1):
135
- print(f" {i}. {prob['sentence']} (分数: {prob['score']:.3f})")
136
-
137
- print()
138
-
139
- if __name__ == "__main__":
140
- # 1. 性能测试
141
- test_performance()
142
-
143
- # 2. RAG场景测试
144
- test_rag_scenarios()
145
-
146
- print("\n" + "="*70)
147
- print("💡 使用建议:")
148
- print("1. 生产环境推荐使用 cross-encoder/nli-MiniLM2-L6-H768")
149
- print("2. 资源受限环境可使用 cross-encoder/nli-deberta-v3-xsmall")
150
- print("3. 高准确率需求可使用 cross-encoder/nli-roberta-base")
151
- print("4. 建议设置幻觉分数阈值为 0.6-0.7")
152
- print("="*70)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_reranking.py DELETED
@@ -1,224 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- 重排功能测试脚本
4
- 演示不同重排策略的效果
5
- """
6
-
7
- import sys
8
- import os
9
- sys.path.append(os.path.dirname(__file__))
10
-
11
- from document_processor import DocumentProcessor
12
- from reranker import *
13
- try:
14
- from langchain_core.documents import Document
15
- except ImportError:
16
- try:
17
- from langchain_core.documents import Document
18
- except ImportError:
19
- from langchain.schema import Document
20
- import time
21
-
22
-
23
- def create_test_documents():
24
- """创建测试文档"""
25
- return [
26
- Document(
27
- page_content="人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。",
28
- metadata={"source": "ai_intro.txt", "category": "AI基础"}
29
- ),
30
- Document(
31
- page_content="机器学习是人工智能的一个重要子领域,通过算法让计算机从数据中学习模式和规律。",
32
- metadata={"source": "ml_basics.txt", "category": "机器学习"}
33
- ),
34
- Document(
35
- page_content="深度学习是机器学习的一个分支,使用多层神经网络来模拟人脑的学习过程。",
36
- metadata={"source": "dl_guide.txt", "category": "深度学习"}
37
- ),
38
- Document(
39
- page_content="自然语言处理(NLP)是人工智能领域的一个重要分支,专注于使计算机理解和处理人类语言。",
40
- metadata={"source": "nlp_overview.txt", "category": "自然语言处理"}
41
- ),
42
- Document(
43
- page_content="计算机视觉是人工智能的另一个重要领域,使计算机能够识别和理解图像和视频内容。",
44
- metadata={"source": "cv_intro.txt", "category": "计算机视觉"}
45
- ),
46
- Document(
47
- page_content="强化学习是机器学习的一种类型,通过与环境交互来学习最优的行为策略。",
48
- metadata={"source": "rl_basics.txt", "category": "强化学习"}
49
- ),
50
- Document(
51
- page_content="今天的天气非常好,阳光明媚,适合外出游玩和运动。",
52
- metadata={"source": "weather.txt", "category": "天气"}
53
- ),
54
- Document(
55
- page_content="区块链是一种分布式账本技术,具有去中心化、不可篡改等特点。",
56
- metadata={"source": "blockchain.txt", "category": "区块链"}
57
- )
58
- ]
59
-
60
-
61
- def test_reranker_comparison():
62
- """比较不同重排器的效果"""
63
- print("🔍 重排器效果比较测试")
64
- print("=" * 60)
65
-
66
- # 创建测试数据
67
- query = "什么是人工智能和机器学习?"
68
- documents = create_test_documents()
69
-
70
- # 创建一个简单的嵌入模型(用于测试)
71
- try:
72
- from langchain_community.embeddings import HuggingFaceEmbeddings
73
- embeddings = HuggingFaceEmbeddings(
74
- model_name="sentence-transformers/all-MiniLM-L6-v2",
75
- model_kwargs={'device': 'cpu'}
76
- )
77
- print("✅ 成功加载嵌入模型")
78
- except Exception as e:
79
- print(f"❌ 嵌入模型加载失败: {e}")
80
- print("将使用基础重排器进行测试")
81
- embeddings = None
82
-
83
- # 测试不同的重排器
84
- rerankers = []
85
-
86
- # TF-IDF重排器
87
- rerankers.append(("TF-IDF", TFIDFReranker()))
88
-
89
- # BM25重排器
90
- rerankers.append(("BM25", BM25Reranker()))
91
-
92
- if embeddings:
93
- # 语义重排器
94
- rerankers.append(("语义相似度", SemanticReranker(embeddings)))
95
-
96
- # 混合重排器
97
- rerankers.append(("混合策略", HybridReranker(embeddings)))
98
-
99
- # 多样性重排器
100
- rerankers.append(("多样性优化", DiversityReranker(embeddings)))
101
-
102
- # 执行测试
103
- for name, reranker in rerankers:
104
- print(f"\n📊 {name} 重排结果:")
105
- print("-" * 40)
106
-
107
- start_time = time.time()
108
- try:
109
- results = reranker.rerank(query, documents, top_k=5)
110
- end_time = time.time()
111
-
112
- print(f"⏱️ 处理时间: {(end_time - start_time)*1000:.2f}ms")
113
-
114
- for i, (doc, score) in enumerate(results, 1):
115
- content = doc.page_content[:80] + "..." if len(doc.page_content) > 80 else doc.page_content
116
- category = doc.metadata.get('category', '未知')
117
- print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
118
-
119
- except Exception as e:
120
- print(f"❌ 重排失败: {e}")
121
-
122
-
123
- def test_reranking_with_embeddings():
124
- """测试带嵌入的重排功能"""
125
- print("\n\n🧠 嵌入模型重排测试")
126
- print("=" * 60)
127
-
128
- try:
129
- # 创建文档处理器
130
- processor = DocumentProcessor()
131
-
132
- # 创建测试文档
133
- test_docs = create_test_documents()
134
-
135
- # 测试查询
136
- queries = [
137
- "人工智能的定义是什么?",
138
- "机器学习和深度学习的区别",
139
- "自然语言处理的应用",
140
- "今天天气怎么样?"
141
- ]
142
-
143
- for query in queries:
144
- print(f"\n🔍 查询: {query}")
145
- print("-" * 30)
146
-
147
- if processor.reranker:
148
- # 使用重排功能
149
- results = processor.reranker.rerank(query, test_docs, top_k=3)
150
-
151
- for i, (doc, score) in enumerate(results, 1):
152
- content = doc.page_content[:60] + "..." if len(doc.page_content) > 60 else doc.page_content
153
- category = doc.metadata.get('category', '未知')
154
- print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
155
- else:
156
- print("❌ 重排器未初始化")
157
-
158
- except Exception as e:
159
- print(f"❌ 测试失败: {e}")
160
-
161
-
162
- def test_performance_comparison():
163
- """性能对比测试"""
164
- print("\n\n⚡ 性能对比测试")
165
- print("=" * 60)
166
-
167
- documents = create_test_documents() * 10 # 增加文档数量
168
- query = "人工智能技术的发展趋势"
169
-
170
- # 测试不同重排器的性能
171
- rerankers_config = [
172
- ("无重排", None),
173
- ("TF-IDF", TFIDFReranker()),
174
- ("BM25", BM25Reranker())
175
- ]
176
-
177
- for name, reranker in rerankers_config:
178
- times = []
179
-
180
- # 多次测试取平均值
181
- for _ in range(5):
182
- start_time = time.time()
183
-
184
- if reranker:
185
- results = reranker.rerank(query, documents, top_k=5)
186
- else:
187
- # 模拟无重排的情况
188
- results = documents[:5]
189
-
190
- end_time = time.time()
191
- times.append((end_time - start_time) * 1000)
192
-
193
- avg_time = sum(times) / len(times)
194
- print(f"{name}: 平均处理时间 {avg_time:.2f}ms (文档数: {len(documents)})")
195
-
196
-
197
- def main():
198
- """主测试函数"""
199
- print("🚀 向量重排功能综合测试")
200
- print("=" * 80)
201
-
202
- try:
203
- # 基础重排器比较
204
- test_reranker_comparison()
205
-
206
- # 嵌入模型重排测试
207
- test_reranking_with_embeddings()
208
-
209
- # 性能对比测试
210
- test_performance_comparison()
211
-
212
- print("\n\n✅ 所有测试完成!")
213
- print("=" * 80)
214
-
215
- except KeyboardInterrupt:
216
- print("\n❌ 测试被用户中断")
217
- except Exception as e:
218
- print(f"\n❌ 测试过程中发生错误: {e}")
219
- import traceback
220
- traceback.print_exc()
221
-
222
-
223
- if __name__ == "__main__":
224
- main()