# llm_router.py - NOVITA AI API ONLY
import logging
import asyncio
from typing import Dict, Optional

from .models_config import LLM_CONFIG
from .config import get_settings

logger = logging.getLogger(__name__)

# Import OpenAI client for Novita AI API
try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False
    logger.error("openai package not available - Novita AI API requires openai package")


class LLMRouter:
    def __init__(self, hf_token=None, use_local_models: bool = False):
        """
        Initialize LLM Router with Novita AI API only.

        Args:
            hf_token: Not used (kept for backward compatibility)
            use_local_models: Must be False (local models disabled)
        """
        if use_local_models:
            raise ValueError("Local models are disabled. Only Novita AI API is supported.")

        self.settings = get_settings()
        self.novita_client = None

        # Validate OpenAI package
        if not OPENAI_AVAILABLE:
            raise ImportError(
                "openai package is required for Novita AI API. "
                "Install it with: pip install openai>=1.0.0"
            )

        # Validate API key
        if not self.settings.novita_api_key:
            raise ValueError(
                "NOVITA_API_KEY is required. "
                "Set it in environment variables or .env file"
            )

        # Initialize Novita AI client
        try:
            self.novita_client = OpenAI(
                base_url=self.settings.novita_base_url,
                api_key=self.settings.novita_api_key,
            )
            logger.info("Novita AI API client initialized")
            logger.info(f"Base URL: {self.settings.novita_base_url}")
            logger.info(f"Model: {self.settings.novita_model}")
        except Exception as e:
            logger.error(f"Failed to initialize Novita AI client: {e}")
            raise RuntimeError(f"Could not initialize Novita AI API client: {e}") from e

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Route inference to Novita AI API.

        Args:
            task_type: Type of task (general_reasoning, intent_classification, etc.)
            prompt: Input prompt
            **kwargs: Additional parameters (max_tokens, temperature, etc.)

        Returns:
            Generated text response
        """
        logger.info(f"Routing inference to Novita AI API for task: {task_type}")

        if not self.novita_client:
            raise RuntimeError("Novita AI client not initialized")

        try:
            # Handle embedding generation (may need special handling)
            if task_type == "embedding_generation":
                logger.warning("Embedding generation via Novita API may require special implementation")
                # For now, use chat completion (may need adjustment based on Novita API capabilities)
                result = await self._call_novita_api(task_type, prompt, **kwargs)
            else:
                result = await self._call_novita_api(task_type, prompt, **kwargs)

            if result is None:
                logger.error(f"Novita AI API returned None for task: {task_type}")
                raise RuntimeError(f"Inference failed for task: {task_type}")

            logger.info(f"Inference complete for {task_type} (Novita AI API)")
            return result

        except Exception as e:
            logger.error(f"Novita AI API inference failed: {e}", exc_info=True)
            raise RuntimeError(
                f"Inference failed for task: {task_type}. "
                f"Novita AI API error: {e}"
            ) from e
" f"Novita AI API error: {e}" ) from e async def _call_novita_api(self, task_type: str, prompt: str, **kwargs) -> Optional[str]: """Call Novita AI API for inference.""" if not self.novita_client: return None # Get model config model_config = self._select_model(task_type) model_name = kwargs.get('model', self.settings.novita_model) # Get optimized parameters requested_max_tokens = kwargs.get('max_tokens', model_config.get('max_tokens', 4096)) temperature = kwargs.get('temperature', model_config.get('temperature', self.settings.deepseek_r1_temperature)) top_p = kwargs.get('top_p', model_config.get('top_p', 0.95)) stream = kwargs.get('stream', False) # Format prompt according to DeepSeek-R1 best practices formatted_prompt = self._format_deepseek_r1_prompt(prompt, task_type, model_config) # IMPORTANT: Calculate safe max_tokens based on input size max_tokens = self._calculate_safe_max_tokens(formatted_prompt, requested_max_tokens) # IMPORTANT: No system prompt - all instructions in user prompt messages = [{"role": "user", "content": formatted_prompt}] # Build request parameters request_params = { "model": model_name, "messages": messages, "stream": stream, "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p, } try: if stream: # Handle streaming response response_text = "" stream_response = self.novita_client.chat.completions.create(**request_params) for chunk in stream_response: if chunk.choices and len(chunk.choices) > 0: delta = chunk.choices[0].delta if delta and delta.content: response_text += delta.content # Clean up reasoning tags if present response_text = self._clean_reasoning_tags(response_text) logger.info(f"Novita AI API generated response (length: {len(response_text)})") return response_text else: # Handle non-streaming response response = self.novita_client.chat.completions.create(**request_params) if response.choices and len(response.choices) > 0: result = response.choices[0].message.content # Clean up reasoning tags if present result = self._clean_reasoning_tags(result) logger.info(f"Novita AI API generated response (length: {len(result)})") return result else: logger.error("Novita AI API returned empty response") return None except Exception as e: logger.error(f"Error calling Novita AI API: {e}", exc_info=True) raise def _calculate_safe_max_tokens(self, prompt: str, requested_max_tokens: int) -> int: """ Calculate safe max_tokens based on input token count and model context window. 
        """
        Calculate safe max_tokens based on input token count and model context window.

        Args:
            prompt: Input prompt text
            requested_max_tokens: Desired max_tokens value

        Returns:
            int: Adjusted max_tokens that fits within the context window
        """
        # Estimate input tokens (rough: 1 token ≈ 4 characters)
        # For more accuracy, you could use tiktoken if available
        input_tokens = len(prompt) // 4

        # Get model context window
        context_window = self.settings.novita_model_context_window

        # Reserve a minimum 100-token safety margin
        available_tokens = context_window - input_tokens - 100

        # Use the smaller of requested or available
        safe_max_tokens = min(requested_max_tokens, available_tokens)

        # Ensure a minimum of 50 tokens for output
        safe_max_tokens = max(50, safe_max_tokens)

        if safe_max_tokens < requested_max_tokens:
            logger.warning(
                f"Reduced max_tokens from {requested_max_tokens} to {safe_max_tokens} "
                f"(input: ~{input_tokens} tokens, context window: {context_window} tokens)"
            )

        return safe_max_tokens

    def _format_deepseek_r1_prompt(self, prompt: str, task_type: str, model_config: dict) -> str:
        """
        Format prompt according to DeepSeek-R1 best practices:
        - No system prompt (all instructions in user prompt)
        - Force reasoning trigger for reasoning tasks
        - Add math directive for mathematical problems
        """
        formatted_prompt = prompt

        # Check if we should force the reasoning prefix
        force_reasoning = (
            self.settings.deepseek_r1_force_reasoning
            and model_config.get("force_reasoning_prefix", False)
        )

        if force_reasoning:
            # Force the model to start with the <think> reasoning trigger
            formatted_prompt = f"<think>\n\n{formatted_prompt}"

        # Add math directive for mathematical problems
        if self._is_math_query(prompt):
            math_directive = "Please reason step by step, and put your final answer within \\boxed{}."
            formatted_prompt = f"{formatted_prompt}\n\n{math_directive}"

        return formatted_prompt

    def _is_math_query(self, prompt: str) -> bool:
        """Detect whether the query is mathematical."""
        math_keywords = [
            "solve", "calculate", "compute", "equation", "formula",
            "mathematical", "algebra", "geometry", "calculus",
            "integral", "derivative", "theorem", "proof", "problem"
        ]
        prompt_lower = prompt.lower()
        return any(keyword in prompt_lower for keyword in math_keywords)

    def _clean_reasoning_tags(self, text: str) -> str:
        """Clean up <think>/</think> reasoning tags from the response."""
        text = text.replace("<think>", "").replace("</think>", "")
        text = text.strip()
        return text

    def _select_model(self, task_type: str) -> dict:
        """Select model configuration based on task type."""
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def get_available_models(self):
        """Get list of available models (Novita AI only)."""
        return ["Novita AI API - DeepSeek-R1-Distill-Qwen-7B"]

    async def health_check(self):
        """Perform health check on Novita AI API."""
        try:
            # Test the API with a simple request
            self.novita_client.chat.completions.create(
                model=self.settings.novita_model,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=5
            )
            return {
                "provider": "novita_api",
                "status": "healthy",
                "model": self.settings.novita_model,
                "base_url": self.settings.novita_base_url
            }
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {
                "provider": "novita_api",
                "status": "unhealthy",
                "error": str(e)
            }
    def prepare_context_for_llm(self,
                                raw_context: Dict,
                                max_tokens: Optional[int] = None,
                                user_input: Optional[str] = None) -> str:
        """
        Smart context windowing with user input priority.
        User input gets the highest-priority share of the budget; other context
        elements are reduced to fit, and user input is only truncated if it
        exceeds its own dedicated budget (user_input_max_tokens).

        Args:
            raw_context: Context dictionary
            max_tokens: Optional override (uses config default if None)
            user_input: Optional explicit user input (takes priority over raw_context['user_input'])
        """
        # Use config budget if not provided
        if max_tokens is None:
            max_tokens = self.settings.context_preparation_budget

        # Get user input (explicit parameter takes priority)
        actual_user_input = user_input or raw_context.get('user_input', '')

        # Calculate user input tokens (simple estimation: 1 token ≈ 4 chars)
        user_input_tokens = len(actual_user_input) // 4

        # Ensure user input fits within its dedicated budget
        user_input_max = self.settings.user_input_max_tokens
        if user_input_tokens > user_input_max:
            logger.warning(f"User input ({user_input_tokens} tokens) exceeds max ({user_input_max}), truncating")
            max_chars = user_input_max * 4
            actual_user_input = actual_user_input[:max_chars - 3] + "..."
            user_input_tokens = user_input_max

        # Reserve space for user input (it has highest priority)
        remaining_tokens = max_tokens - user_input_tokens
        if remaining_tokens < 0:
            logger.warning(f"User input ({user_input_tokens} tokens) exceeds total budget ({max_tokens})")
            remaining_tokens = 0

        logger.info(f"Token allocation: User input={user_input_tokens}, Context budget={remaining_tokens}, Total={max_tokens}")

        # Priority order for context elements (user input already handled)
        priority_elements = [
            ('recent_interactions', 0.8),
            ('user_preferences', 0.6),
            ('session_summary', 0.4),
            ('historical_context', 0.2)
        ]

        formatted_context = []
        total_tokens = user_input_tokens  # Start with user input tokens

        # Add user input first (highest priority)
        if actual_user_input:
            formatted_context.append(f"=== USER INPUT ===\n{actual_user_input}")

        # Now add context elements within the remaining budget
        for element, priority in priority_elements:
            element_key_map = {
                'recent_interactions': raw_context.get('interaction_contexts', []),
                'user_preferences': raw_context.get('preferences', {}),
                'session_summary': raw_context.get('session_context', {}),
                'historical_context': raw_context.get('user_context', '')
            }
            content = element_key_map.get(element, '')

            # Convert to string if needed
            if isinstance(content, dict):
                content = str(content)
            elif isinstance(content, list):
                content = "\n".join([str(item) for item in content[:10]])

            if not content:
                continue

            # Estimate tokens (simple: 1 token ≈ 4 chars)
            tokens = len(content) // 4

            if total_tokens + tokens <= max_tokens:
                formatted_context.append(f"=== {element.upper()} ===\n{content}")
                total_tokens += tokens
            elif priority > 0.5 and remaining_tokens > 0:
                # Critical elements - truncate if needed
                available = max_tokens - total_tokens
                if available > 100:  # Only truncate if we have meaningful space
                    truncated = self._truncate_to_tokens(content, available)
                    formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
                    total_tokens += available
                break

        logger.info(f"Context prepared: {total_tokens}/{max_tokens} tokens (user input: {user_input_tokens}, context: {total_tokens - user_input_tokens})")
        return "\n\n".join(formatted_context)

    def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
        """Truncate content to fit within the token limit."""
        # Simple character-based truncation (1 token ≈ 4 chars)
        max_chars = max_tokens * 4
        if len(content) <= max_chars:
            return content
        return content[:max_chars - 3] + "..."