|
|
""" |
|
|
Semantic retrieval functionality for RAG system. |
|
|
|
|
|
Handles document retrieval, filtering, and relevance scoring. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
from typing import List, Dict, Any, Optional, Tuple |
|
|
import numpy as np |
|
|
|
|
|
from .qdrant_client import QdrantManager |
|
|
from .embeddings import EmbeddingGenerator |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class RetrievalEngine:
    """Handles semantic retrieval of documents from the vector store."""

    def __init__(
        self,
        qdrant_manager: QdrantManager,
        embedder: EmbeddingGenerator,
        default_k: int = 5,
        score_threshold: float = 0.3,
        max_context_tokens: int = 4000
    ):
        """
        Args:
            qdrant_manager: Client wrapper for the Qdrant vector store
            embedder: Embedding generator used to embed queries
            default_k: Default number of documents to retrieve
            score_threshold: Minimum similarity score for a match
            max_context_tokens: Token budget for assembled context strings
        """
        self.qdrant_manager = qdrant_manager
        self.embedder = embedder
        self.default_k = default_k
        self.score_threshold = score_threshold
        self.max_context_tokens = max_context_tokens

    async def retrieve(
        self,
        query: str,
        k: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None,
        rerank: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Retrieve relevant documents for a query.

        Args:
            query: Search query
            k: Number of documents to retrieve (defaults to self.default_k)
            filters: Metadata filters passed to the vector store
            rerank: Whether to apply reranking

        Returns:
            List of retrieved documents with scores
        """
        k = k or self.default_k

        try:
            # Embed the query text
            logger.info(f"Generating embedding for query: {query[:100]}...")
            embedding_result = await self.embedder.generate_embedding(query)
            query_embedding = embedding_result["embedding"]

            # Over-fetch when reranking so the reranker has candidates to drop
            logger.info(f"Retrieving {k} documents from vector store...")
            results = await self.qdrant_manager.search_similar(
                query_embedding=query_embedding,
                limit=k * 2 if rerank else k,
                score_threshold=self.score_threshold,
                filters=filters
            )

            if not results:
                logger.warning("No documents retrieved")
                return []

            # Rerank the over-fetched candidates down to the requested k
            if rerank and len(results) > k:
                results = await self._rerank_results(query, results, k)

            results = results[:k]

            # Annotate final ranking metadata
            for i, result in enumerate(results):
                result["rank"] = i + 1
                result["retrieval_method"] = "semantic"

            logger.info(f"Retrieved {len(results)} documents")
            return results

        except Exception as e:
            logger.error(f"Retrieval failed: {str(e)}")
            raise

    async def retrieve_with_context(
        self,
        query: str,
        max_tokens: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None
    ) -> Tuple[List[Dict[str, Any]], str]:
        """
        Retrieve documents and format them as a context string within a token limit.

        Returns:
            Tuple of (retrieved_documents, context_string)
        """
        max_tokens = max_tokens or self.max_context_tokens

        # Retrieve a generous candidate set; the token budget trims it below
        results = await self.retrieve(
            query=query,
            k=10,
            filters=filters,
            rerank=True
        )

        # Greedily add documents until the token budget is exhausted
        context_parts = []
        current_tokens = 0
        selected_docs = []

        for doc in results:
            doc_text = self._format_document(doc)
            doc_tokens = self.embedder.get_token_count(doc_text)

            if current_tokens + doc_tokens > max_tokens:
                # If even the first document overflows, include a truncated slice
                # (character-based truncation as a rough, conservative fallback)
                if not context_parts:
                    context_parts.append(doc_text[:max_tokens])
                    selected_docs.append({**doc, "content": doc["content"][:max_tokens]})
                break

            context_parts.append(doc_text)
            selected_docs.append(doc)
            current_tokens += doc_tokens

        context_string = "\n\n".join(context_parts)

        return selected_docs, context_string
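
    # Illustrative usage sketch: the (documents, context) pair returned above is
    # typically spliced into an LLM prompt. The `engine` variable and the prompt
    # wording below are assumptions for illustration, not part of this module.
    #
    #     docs, context = await engine.retrieve_with_context("What is chunking?")
    #     prompt = f"Answer using only the context below.\n\n{context}\n\nQuestion: ..."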
|
|
|
|
|
    async def _rerank_results(
        self,
        query: str,
        results: List[Dict[str, Any]],
        k: int
    ) -> List[Dict[str, Any]]:
        """
        Rerank retrieval results using cross-encoder or other methods.

        For now, implements a simple keyword-based reranking.
        Production systems might use a cross-encoder model.
        """
        logger.info(f"Reranking {len(results)} results to top {k}")

        query_terms = set(query.lower().split())

        for result in results:
            content = result.get("content", "").lower()

            # Fraction of query terms that appear in the document
            term_matches = sum(1 for term in query_terms if term in content)
            term_coverage = term_matches / len(query_terms) if query_terms else 0

            # Prefer shorter chunks; guard against empty content
            length_penalty = min(1.0, 500 / max(len(content), 1))

            # Blend the original semantic score with the keyword signals
            semantic_score = result.get("score", 0.0)
            rerank_score = (
                0.7 * semantic_score +
                0.2 * term_coverage +
                0.1 * length_penalty
            )

            result["rerank_score"] = rerank_score
            result["term_coverage"] = term_coverage

        results.sort(key=lambda x: x.get("rerank_score", 0), reverse=True)

        return results[:k]
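
    # Illustrative sketch only: the docstring above notes that production systems
    # might use a cross-encoder. The method below shows how that could look with
    # the sentence-transformers CrossEncoder API, assuming that package is
    # installed and the checkpoint name is acceptable. It is not wired into
    # retrieve(); the method name and model choice are assumptions.
    def _cross_encoder_rerank_sketch(
        self,
        query: str,
        results: List[Dict[str, Any]],
        k: int
    ) -> List[Dict[str, Any]]:
        """Hedged example of cross-encoder reranking (not used by retrieve())."""
        from sentence_transformers import CrossEncoder  # assumed dependency

        model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
        pairs = [(query, r.get("content", "")) for r in results]
        scores = model.predict(pairs)  # one relevance score per (query, doc) pair

        for result, score in zip(results, scores):
            result["rerank_score"] = float(score)

        return sorted(results, key=lambda r: r["rerank_score"], reverse=True)[:k]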
|
|
|
|
|
    def _format_document(self, doc: Dict[str, Any]) -> str:
        """Format a document for inclusion in a context string."""
        metadata = doc.get("metadata", {})

        # Build a citation header from whatever structural metadata is present,
        # e.g. "[Chapter: 3 - Section: 3.2]" followed by the chunk content
        parts = []
        if metadata.get("chapter"):
            parts.append(f"Chapter: {metadata['chapter']}")
        if metadata.get("section"):
            parts.append(f"Section: {metadata['section']}")
        if metadata.get("subsection"):
            parts.append(f"Subsection: {metadata['subsection']}")

        citation = " - ".join(parts) if parts else "Source"

        formatted = f"[{citation}]\n{doc.get('content', '')}"

        return formatted

    async def hybrid_search(
        self,
        query: str,
        k: int = 5,
        semantic_weight: float = 0.7,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Perform hybrid search combining semantic and keyword search.

        For now, this focuses on semantic search; keyword search would need
        additional indexing infrastructure.
        """
        semantic_results = await self.retrieve(
            query=query,
            k=k,
            filters=filters,
            rerank=True
        )

        # Until a keyword index exists, the final score is just the weighted
        # semantic score
        for result in semantic_results:
            result["final_score"] = result.get("score", 0.0) * semantic_weight
            result["search_type"] = "hybrid"

        return semantic_results
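
    # Illustrative sketch only: the hybrid_search docstring notes that keyword
    # search would need additional indexing. The helper below shows one way the
    # keyword side could be scored with the rank_bm25 package and blended with
    # the semantic score, assuming rank_bm25 is installed and candidate chunk
    # texts are available in memory. The helper name and the simple whitespace
    # tokenization are assumptions, not part of the existing API.
    @staticmethod
    def _bm25_fusion_sketch(
        query: str,
        results: List[Dict[str, Any]],
        semantic_weight: float = 0.7
    ) -> List[Dict[str, Any]]:
        """Hedged example of semantic + BM25 score fusion (not used above)."""
        from rank_bm25 import BM25Okapi  # assumed dependency

        if not results:
            return results

        corpus = [r.get("content", "") for r in results]
        bm25 = BM25Okapi([doc.lower().split() for doc in corpus])
        keyword_scores = bm25.get_scores(query.lower().split())

        # Normalize BM25 scores to [0, 1] so they are comparable to the
        # cosine-style semantic scores before blending
        max_kw = float(max(keyword_scores)) if len(keyword_scores) else 0.0
        for result, kw_score in zip(results, keyword_scores):
            kw_norm = float(kw_score) / max_kw if max_kw > 0 else 0.0
            result["final_score"] = (
                semantic_weight * result.get("score", 0.0)
                + (1 - semantic_weight) * kw_norm
            )

        return sorted(results, key=lambda r: r["final_score"], reverse=True)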
|
|
|
|
|
    async def get_similar_documents(
        self,
        document_id: str,
        k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Find documents similar to a given document.

        This would require storing document vectors for similarity search.
        """
        logger.warning("Similar documents search not fully implemented")
        return []
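
    # Illustrative sketch only: one way get_similar_documents() could work once
    # stored vectors can be looked up. `get_vector_by_id` is a hypothetical
    # QdrantManager method (it does not exist in this codebase); the reuse of
    # search_similar() mirrors how retrieve() already calls it.
    #
    #     async def get_similar_documents(self, document_id, k=5):
    #         vector = await self.qdrant_manager.get_vector_by_id(document_id)  # hypothetical
    #         results = await self.qdrant_manager.search_similar(
    #             query_embedding=vector,
    #             limit=k + 1,  # the document tends to match itself, so fetch one extra
    #             score_threshold=self.score_threshold,
    #         )
    #         return [r for r in results if r.get("chunk_id") != document_id][:k]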
|
|
|
|
|
    def explain_retrieval(
        self,
        query: str,
        results: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Explain why documents were retrieved for debugging/analysis.

        Returns explanation of retrieval process and scoring.
        """
        explanation = {
            "query": query,
            "retrieved_count": len(results),
            "score_threshold": self.score_threshold,
            "reranking_applied": any("rerank_score" in r for r in results),
            "results": []
        }

        for i, result in enumerate(results):
            result_explanation = {
                "rank": i + 1,
                "document_id": result.get("chunk_id"),
                "semantic_score": result.get("score", 0.0),
                "rerank_score": result.get("rerank_score"),
                "term_coverage": result.get("term_coverage", 0.0),
                "metadata": result.get("metadata", {})
            }
            explanation["results"].append(result_explanation)

        return explanation
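

# Illustrative usage sketch (assumptions: QdrantManager and EmbeddingGenerator
# construct as shown here, which this module does not define; treat the
# constructor arguments and the asyncio wiring as placeholders, not a tested
# entry point):
#
#     import asyncio
#
#     async def main():
#         engine = RetrievalEngine(
#             qdrant_manager=QdrantManager(),
#             embedder=EmbeddingGenerator(),
#             default_k=5,
#         )
#         docs, context = await engine.retrieve_with_context(
#             "How does the ingestion pipeline chunk documents?"
#         )
#         print(engine.explain_retrieval("How does the ingestion pipeline chunk documents?", docs))
#
#     asyncio.run(main())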