"""
Chat functionality for RAG system.

Handles conversation context, retrieval, generation, and streaming responses.
"""

import asyncio
import json
import logging
import random
import uuid
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Dict, List, Optional

import tiktoken
from openai import AsyncOpenAI

from .qdrant_client import QdrantManager
from .embeddings import EmbeddingGenerator
from .retrieval import RetrievalEngine
from .models import (
    Message, MessageRole, ConversationContext, Citation,
    ChatRequest, ChatResponse
)

logger = logging.getLogger(__name__)
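
# Streaming endpoints in this module emit Server-Sent Events: each event is
# one "data: <json>\n\n" frame (see _format_sse_message below).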


class ChatHandler:
    """Handles chat functionality with RAG retrieval and streaming responses."""

    # Queries matching one of these exactly are answered with a canned
    # greeting instead of triggering retrieval.
    GREETINGS = [
        'hi', 'hello', 'hey', 'yo', 'sup', 'greetings', 'good morning',
        'good afternoon', 'good evening', 'assalamualikum', 'salam',
        'assalam o alaikum'
    ]

    GREETING_RESPONSES = [
        "Hello! I'm here to help you learn about Physical AI and Humanoid Robotics. What would you like to know?",
        "Hi there! I can help you with questions about humanoid robots and physical AI. What topic interests you?",
        "Hey! I'm your AI assistant for the Physical AI & Humanoid Robotics book. How can I assist you today?",
        "Greetings! Feel free to ask me anything about humanoid robotics, AI, or the content of this book.",
        "Wa Alaikum Assalam! I'm happy to help you with Physical AI and Humanoid Robotics topics. What would you like to explore?"
    ]

    def __init__(
        self,
        qdrant_manager: QdrantManager,
        openai_api_key: str,
        model: str = "gpt-4.1-nano",
        embedding_model: str = "text-embedding-3-small",
        max_context_messages: int = 3,
        context_window_size: int = 4000,
        max_retries: int = 3
    ):
        self.qdrant_manager = qdrant_manager
        self.model = model
        self.embedding_model = embedding_model
        self.max_context_messages = max_context_messages
        self.context_window_size = context_window_size
        self.max_retries = max_retries

        # OpenAI client for generation, embedder for query encoding, and a
        # tiktoken encoding for token budgeting.
        self.openai_client = AsyncOpenAI(api_key=openai_api_key)
        self.embedder = EmbeddingGenerator(
            api_key=openai_api_key,
            model=embedding_model
        )
        self.encoding = tiktoken.get_encoding("cl100k_base")

        # Retrieval engine with MMR re-ranking enabled for result diversity.
        self.retrieval_engine = RetrievalEngine(
            qdrant_manager=qdrant_manager,
            embedder=self.embedder,
            score_threshold=0.5,
            enable_mmr=True,
            mmr_lambda=0.5
        )

        # In-memory conversation store, keyed by session ID.
        self.conversations: Dict[str, ConversationContext] = {}
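
    # Minimal usage sketch (illustrative; `qdrant` and `key` are assumed to
    # exist in the caller's environment):
    #
    #   handler = ChatHandler(qdrant_manager=qdrant, openai_api_key=key)
    #   async for event in handler.stream_chat("What is physical AI?"):
    #       ...  # forward each SSE frame to the client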

    def get_adaptive_threshold(self, query_length: int, result_count: int) -> float:
        """
        Get adaptive similarity threshold based on query characteristics.

        Args:
            query_length: Length of the query in characters
            result_count: Number of results found in initial search

        Returns:
            Adaptive threshold value
        """
        base_threshold = 0.5

        # Long queries tend to score lower on similarity, so relax the
        # threshold (floored at 0.3 to keep out unrelated matches).
        if query_length > 100:
            return max(0.3, base_threshold - 0.2)

        # Plenty of hits: we can afford to be stricter.
        if result_count > 20:
            return min(0.9, base_threshold + 0.2)

        # Very few hits: loosen slightly to avoid empty answers.
        if result_count < 3:
            return max(0.3, base_threshold - 0.1)

        return base_threshold
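
    # Worked examples for the thresholds above:
    #   get_adaptive_threshold(150, 10) -> 0.3   (long query)
    #   get_adaptive_threshold(50, 25)  -> 0.7   (many results)
    #   get_adaptive_threshold(50, 2)   -> 0.4   (sparse results)
    #   get_adaptive_threshold(50, 10)  -> 0.5   (default)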

    async def stream_chat(
        self,
        query: str,
        session_id: Optional[str] = None,
        k: int = 5,
        context_window: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None
    ) -> AsyncGenerator[str, None]:
        """
        Stream chat response with Server-Sent Events.

        Yields JSON-formatted SSE messages.
        """
        start_time = datetime.now(timezone.utc)

        try:
            if not session_id:
                session_id = str(uuid.uuid4())

            # Greeting fast path: answer canned greetings (and trivially
            # short queries) without hitting retrieval or the LLM.
            query_lower = query.strip().lower()
            if query_lower in self.GREETINGS or len(query.strip()) <= 2:
                response_text = random.choice(self.GREETING_RESPONSES)

                yield self._format_sse_message({
                    "type": "start",
                    "session_id": session_id,
                    "sources": [],
                    "retrieved_docs": 0
                })

                # Simulate streaming by emitting one word per chunk.
                words = response_text.split()
                for word in words:
                    yield self._format_sse_message({
                        "type": "chunk",
                        "content": word + " "
                    })
                    await asyncio.sleep(0.05)

                yield self._format_sse_message({
                    "type": "done",
                    "session_id": session_id,
                    "response_time": (datetime.now(timezone.utc) - start_time).total_seconds(),
                    "tokens_used": self.count_tokens(response_text)
                })
                return
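
            # From here on, the full RAG path runs. Event sequence emitted:
            #   start  -> {"type": "start", "session_id", "sources", "retrieved_docs"}
            #   chunk* -> {"type": "chunk", "content": "..."}
            #   done   -> {"type": "done", "session_id", "response_time", "tokens_used"}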

            # Record the user turn in the conversation context.
            context = self._get_or_create_context(session_id)
            user_message = Message(
                id=str(uuid.uuid4()),
                role=MessageRole.USER,
                content=query,
                token_count=self.count_tokens(query)
            )
            context.add_message(user_message)

            logger.info(f"Retrieving {k} relevant documents...")

            # Over-fetch (3x) so MMR re-ranking has candidates to diversify,
            # then keep the top k.
            retrieved_docs = await self.retrieval_engine.retrieve(
                query=query,
                k=k * 3,
                filters=filters,
                exclude_templates=True,
                use_mmr=True
            )
            retrieved_docs = retrieved_docs[:k]

            # Fallback for short queries: retry once without MMR re-ranking.
            if not retrieved_docs:
                if len(query.strip()) < 20:
                    logger.info("Short query with no results, retrying without MMR...")
                    retrieved_docs = await self.retrieval_engine.retrieve(
                        query=query,
                        k=k,
                        filters=filters,
                        exclude_templates=True,
                        use_mmr=False
                    )
                    retrieved_docs = retrieved_docs[:k]

            # Still nothing: stream a helpful "no results" message and stop.
            if not retrieved_docs:
                logger.info(f"No content found for query: {query[:100]}...")

                no_content_response = (
                    "I couldn't find specific information about that topic in the book. "
                    "This book covers Physical AI & Humanoid Robotics. Try asking about:\n"
                    "• Introduction to physical AI\n"
                    "• Types of humanoid robots\n"
                    "• AI control systems\n"
                    "• Robot locomotion\n"
                    "• Specific chapters or sections"
                )

                # Emit a start event (with no sources) so clients see the
                # same start -> chunk* -> done sequence as other paths.
                yield self._format_sse_message({
                    "type": "start",
                    "session_id": session_id,
                    "sources": [],
                    "retrieved_docs": 0
                })

                words = no_content_response.split()
                for word in words:
                    yield self._format_sse_message({
                        "type": "chunk",
                        "content": word + " "
                    })
                    await asyncio.sleep(0.05)

                yield self._format_sse_message({
                    "type": "done",
                    "session_id": session_id,
                    "response_time": (datetime.now(timezone.utc) - start_time).total_seconds(),
                    "tokens_used": self.count_tokens(no_content_response),
                    "no_results": True
                })
                return

            # Log retrieval quality metrics for observability.
            logger.info(
                "Retrieval metrics - query_length=%d, retrieved_count=%d, threshold=%.2f, session_id=%s",
                len(query),
                len(retrieved_docs),
                self.retrieval_engine.score_threshold,
                session_id
            )

            scores = [result["similarity_score"] for result in retrieved_docs]
            if scores:
                logger.info(
                    "Similarity scores - min=%.3f, max=%.3f, avg=%.3f, count=%d",
                    min(scores),
                    max(scores),
                    sum(scores) / len(scores),
                    len(scores)
                )

            # Build one Citation per retrieved chunk and collect numbered
            # source snippets for the prompt.
            citations = []
            source_context = []

            for i, result in enumerate(retrieved_docs):
                chunk = result["chunk"]
                metadata = chunk.metadata

                citation = Citation(
                    id=str(uuid.uuid4()),
                    chunk_id=chunk.id,
                    document_id=metadata.get("document_id", ""),
                    text_snippet=chunk.content[:200] + ("..." if len(chunk.content) > 200 else ""),
                    relevance_score=result["similarity_score"],
                    chapter=metadata.get("chapter"),
                    section=metadata.get("section_header") or metadata.get("section"),
                    url=metadata.get("url"),
                    confidence=result["similarity_score"]
                )
                citations.append(citation)

                source_text = chunk.content
                if source_text:
                    source_url = metadata.get("url", "")
                    url_info = f" (URL: {source_url})" if source_url else ""
                    source_context.append(f"[Source {i+1}]{url_info}: {source_text}")
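
            # Each source_context entry now looks like, e.g.:
            #   "[Source 1] (URL: https://example.com/ch1): <chunk text>"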

            # Assemble the prompt within the token budget.
            context_messages = self._build_context_messages(
                context,
                source_context,
                context_window or self.context_window_size
            )

            # Announce the stream, including rendered citations.
            yield self._format_sse_message({
                "type": "start",
                "session_id": session_id,
                "sources": [citation.to_markdown() for citation in citations],
                "retrieved_docs": len(retrieved_docs)
            })

            logger.info("Generating streaming response...")
            full_response = ""

            # Stream tokens from the model and forward each delta as an
            # SSE chunk while accumulating the full answer.
            stream = await self.openai_client.chat.completions.create(
                model=self.model,
                messages=context_messages,
                stream=True,
                max_completion_tokens=1000
            )

            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta:
                    content = chunk.choices[0].delta.content
                    if content:
                        full_response += content
                        yield self._format_sse_message({
                            "type": "chunk",
                            "content": content
                        })

            # Record the assistant turn with its citation IDs.
            assistant_message = Message(
                id=str(uuid.uuid4()),
                role=MessageRole.ASSISTANT,
                content=full_response,
                token_count=self.count_tokens(full_response),
                citations=[citation.id for citation in citations]
            )
            context.add_message(assistant_message)

            response_time = (datetime.now(timezone.utc) - start_time).total_seconds()
            yield self._format_sse_message({
                "type": "done",
                "session_id": session_id,
                "response_time": response_time,
                "tokens_used": user_message.token_count + assistant_message.token_count
            })

        except Exception as e:
            logger.error(f"Chat streaming failed: {str(e)}", exc_info=True)
            yield self._format_sse_message({
                "type": "error",
                "error": str(e)
            })
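
    # Typical (illustrative) FastAPI hookup for the generator above; assumes
    # FastAPI is available and is not part of this module:
    #
    #   from fastapi.responses import StreamingResponse
    #   StreamingResponse(handler.stream_chat(query), media_type="text/event-stream")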

    async def chat(
        self,
        query: str,
        session_id: Optional[str] = None,
        k: int = 5,
        context_window: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None
    ) -> AsyncGenerator[str, None]:
        """
        Non-streaming chat: yields the complete response with citations as
        a single SSE "final" message, followed by a "[DONE]" marker.
        """
        start_time = datetime.now(timezone.utc)

        try:
            if not session_id:
                session_id = str(uuid.uuid4())

            # Greeting fast path: reply without retrieval or the LLM.
            query_lower = query.strip().lower()
            if query_lower in self.GREETINGS or len(query.strip()) <= 2:
                answer = random.choice(self.GREETING_RESPONSES)
                response_time = (datetime.now(timezone.utc) - start_time).total_seconds()

                greeting_response = {
                    "type": "final",
                    "answer": answer,
                    "sources": [],
                    "session_id": session_id,
                    "query": query,
                    "response_time": response_time,
                    "tokens_used": self.count_tokens(answer),
                    "context_used": False,
                    "model": self.model,
                    "has_context": False
                }
                yield f"data: {json.dumps(greeting_response)}\n\n"
                yield "data: [DONE]\n\n"
                return

            # Record the user turn in the conversation context.
            context = self._get_or_create_context(session_id)
            user_message = Message(
                id=str(uuid.uuid4()),
                role=MessageRole.USER,
                content=query,
                token_count=self.count_tokens(query)
            )
            context.add_message(user_message)

            logger.info(f"Retrieving {k} relevant documents...")

            # Over-fetch (3x) for MMR re-ranking, then keep the top k.
            retrieved_docs = await self.retrieval_engine.retrieve(
                query=query,
                k=k * 3,
                filters=filters,
                exclude_templates=True,
                use_mmr=True
            )
            retrieved_docs = retrieved_docs[:k]

            # Fallback for short queries: retry once without MMR re-ranking.
            if not retrieved_docs:
                if len(query.strip()) < 20:
                    logger.info("Short query with no results, retrying without MMR...")
                    retrieved_docs = await self.retrieval_engine.retrieve(
                        query=query,
                        k=k,
                        filters=filters,
                        exclude_templates=True,
                        use_mmr=False
                    )
                    retrieved_docs = retrieved_docs[:k]

            # Still nothing: return a helpful "no results" answer in the
            # same final/[DONE] format as this method's other paths.
            if not retrieved_docs:
                logger.info(f"No content found for query: {query[:100]}...")

                no_content_response = (
                    "I couldn't find specific information about that topic in the book. "
                    "This book covers Physical AI & Humanoid Robotics. Try asking about:\n"
                    "• Introduction to physical AI\n"
                    "• Types of humanoid robots\n"
                    "• AI control systems\n"
                    "• Robot locomotion\n"
                    "• Specific chapters or sections"
                )

                no_results_response = {
                    "type": "final",
                    "answer": no_content_response,
                    "sources": [],
                    "session_id": session_id,
                    "query": query,
                    "response_time": (datetime.now(timezone.utc) - start_time).total_seconds(),
                    "tokens_used": self.count_tokens(no_content_response),
                    "model": self.model,
                    "no_results": True
                }
                yield f"data: {json.dumps(no_results_response)}\n\n"
                yield "data: [DONE]\n\n"
                return

            # Log retrieval quality metrics for observability.
            logger.info(
                "Retrieval metrics - query_length=%d, retrieved_count=%d, threshold=%.2f, session_id=%s",
                len(query),
                len(retrieved_docs),
                self.retrieval_engine.score_threshold,
                session_id
            )

            scores = [result["similarity_score"] for result in retrieved_docs]
            if scores:
                logger.info(
                    "Similarity scores - min=%.3f, max=%.3f, avg=%.3f, count=%d",
                    min(scores),
                    max(scores),
                    sum(scores) / len(scores),
                    len(scores)
                )

            # Build citations and numbered source snippets (with URLs, so the
            # model can cite in the [Chapter - Section](URL) format the
            # system prompt asks for).
            citations = []
            source_context = []

            for i, result in enumerate(retrieved_docs):
                chunk = result["chunk"]
                metadata = chunk.metadata

                citation = Citation(
                    id=str(uuid.uuid4()),
                    chunk_id=chunk.id,
                    document_id=metadata.get("document_id", ""),
                    text_snippet=chunk.content[:200] + ("..." if len(chunk.content) > 200 else ""),
                    relevance_score=result["similarity_score"],
                    chapter=metadata.get("chapter"),
                    section=metadata.get("section_header") or metadata.get("section"),
                    url=metadata.get("url"),
                    confidence=result["similarity_score"]
                )
                citations.append(citation)

                source_text = chunk.content
                if source_text:
                    source_url = metadata.get("url", "")
                    url_info = f" (URL: {source_url})" if source_url else ""
                    source_context.append(f"[Source {i+1}]{url_info}: {source_text}")

            # Assemble the prompt within the token budget and generate.
            context_messages = self._build_context_messages(
                context,
                source_context,
                context_window or self.context_window_size
            )

            logger.info("Generating response...")
            response = await self.openai_client.chat.completions.create(
                model=self.model,
                messages=context_messages,
                max_completion_tokens=1000
            )

            answer = response.choices[0].message.content
            tokens_used = response.usage.total_tokens if response.usage else 0

            # Record the assistant turn with its citation IDs.
            assistant_message = Message(
                id=str(uuid.uuid4()),
                role=MessageRole.ASSISTANT,
                content=answer,
                token_count=self.count_tokens(answer),
                citations=[citation.id for citation in citations]
            )
            context.add_message(assistant_message)

            response_time = (datetime.now(timezone.utc) - start_time).total_seconds()

            def serialize_citation(citation):
                """Convert a Citation object to a JSON-serializable dict."""
                return {
                    "id": getattr(citation, 'id', ''),
                    "chunk_id": getattr(citation, 'chunk_id', ''),
                    "document_id": getattr(citation, 'document_id', ''),
                    "text_snippet": getattr(citation, 'text_snippet', ''),
                    "relevance_score": getattr(citation, 'relevance_score', 0),
                    "chapter": getattr(citation, 'chapter', ''),
                    "section": getattr(citation, 'section', ''),
                    "url": getattr(citation, 'url', ''),
                    "confidence": getattr(citation, 'confidence', 0)
                }

            final_response = {
                "type": "final",
                "answer": answer,
                "sources": [serialize_citation(citation) for citation in citations],
                "session_id": session_id,
                "query": query,
                "response_time": response_time,
                "tokens_used": tokens_used,
                "model": self.model
            }
            yield f"data: {json.dumps(final_response)}\n\n"
            yield "data: [DONE]\n\n"

        except Exception as e:
            logger.error(f"Chat failed: {str(e)}", exc_info=True)
            raise
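
    # Shape of the single "final" SSE payload yielded above, e.g.:
    #   data: {"type": "final", "answer": "...", "sources": [...],
    #          "session_id": "...", "query": "...", "response_time": 1.2,
    #          "tokens_used": 512, "model": "gpt-4.1-nano"}
    #   data: [DONE]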

    def _get_or_create_context(self, session_id: str) -> ConversationContext:
        """Get existing conversation context or create a new one."""
        if session_id not in self.conversations:
            system_prompt = (
                "You are an AI assistant for the book 'Physical AI and Humanoid Robotics'. "
                "This book covers topics including physical AI systems, humanoid robots, "
                "robot sensing, actuation mechanisms, and the convergence of AI with robotics. "
                "Provide accurate, detailed answers based on the provided book content. "
                "Always cite your sources. Use the format [Chapter - Section](URL) if a URL is provided in the context; otherwise, use [Chapter - Section]. "
                "If users ask about topics outside this book (other books, movies, general knowledge), "
                "politely explain: 'I can only provide information about Physical AI, humanoid robots, "
                "and the specific topics covered in this book.' "
                "If the book context doesn't contain relevant information, say so clearly."
            )
            self.conversations[session_id] = ConversationContext(
                session_id=session_id,
                max_messages=self.max_context_messages,
                messages=[
                    Message(
                        id=str(uuid.uuid4()),
                        role=MessageRole.SYSTEM,
                        content=system_prompt,
                        token_count=self.count_tokens(system_prompt)
                    )
                ]
            )
        return self.conversations[session_id]

    def _build_context_messages(
        self,
        context: ConversationContext,
        source_texts: List[str],
        max_tokens: int
    ) -> List[Dict[str, str]]:
        """Build context messages for the OpenAI API within a token budget."""
        messages = []
        current_tokens = 0

        # Always include the system prompt first.
        system_msg = context.messages[0] if context.messages else None
        if system_msg:
            messages.append({
                "role": system_msg.role.value,
                "content": system_msg.content
            })
            current_tokens += system_msg.token_count

        # Add retrieved sources, capped at 60% of the budget so conversation
        # history still fits.
        if source_texts:
            context_content = "\n\n".join(source_texts)
            context_message = {
                "role": "system",
                "content": f"Context from the book:\n\n{context_content}"
            }
            context_tokens = self.count_tokens(context_message["content"])

            if current_tokens + context_tokens < max_tokens * 0.6:
                messages.append(context_message)
                current_tokens += context_tokens

        # Add prior conversation turns, stopping at 90% of the budget.
        for msg in context.get_context_messages():
            if msg.role != MessageRole.SYSTEM:
                msg_tokens = msg.token_count

                if current_tokens + msg_tokens < max_tokens * 0.9:
                    messages.append({
                        "role": msg.role.value,
                        "content": msg.content
                    })
                    current_tokens += msg_tokens
                else:
                    break

        return messages
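
    # Budget example: with max_tokens=4000, retrieved sources may use up to
    # 2400 tokens (60%) and sources plus history up to 3600 (90%); the
    # remainder is headroom for the model's completion.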

    def _format_sse_message(self, data: Dict[str, Any]) -> str:
        """Format a message for Server-Sent Events."""
        return f"data: {json.dumps(data)}\n\n"

    def count_tokens(self, text: str) -> int:
        """Count tokens in text using tiktoken (cl100k_base)."""
        return len(self.encoding.encode(text))
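
    # e.g. count_tokens("Hello world") is 2 with cl100k_base
    # ("Hello" and " world" are single tokens).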

    async def clear_context(self, session_id: str):
        """Clear conversation context for a session."""
        if session_id in self.conversations:
            del self.conversations[session_id]
            logger.info(f"Cleared context for session: {session_id}")

    async def get_context(self, session_id: str) -> Optional[ConversationContext]:
        """Get conversation context for a session."""
        return self.conversations.get(session_id)

    async def close(self):
        """Close the OpenAI client and the embedder."""
        if self.openai_client:
            await self.openai_client.close()
        if self.embedder:
            await self.embedder.close()