# RetroHackerTerminalUI / model_utls.py
# Author: Canstralian
# Commit: "Create model_utls.py" (dae0b2e, verified)
"""Utilities for loading Hugging Face causal language models and generating text with them."""
import logging
from typing import Optional, Dict, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.pipelines.text_generation import TextGenerationPipeline
# Process-wide logging setup: root logger at INFO, each record rendered as
# "<time> - <logger name> - <level> - <message>".
# NOTE(review): configuring the root logger at import time is unusual for a
# library module; logging.basicConfig is a no-op if the root logger already
# has handlers attached.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)
class ModelManager:
    """Load Hugging Face causal language models and generate text with them.

    Every successfully built text-generation pipeline is cached in
    ``loaded_models`` under its model name, so later requests for the same
    model reuse it. Failed loads are not cached; the next request for that
    model attempts the load again.
    """

    def __init__(self):
        """Initialize the model manager with an empty pipeline cache."""
        # model name -> text-generation pipeline built by ``load_model``.
        self.loaded_models: Dict[str, TextGenerationPipeline] = {}
        logger.info("Model manager initialized")

    def load_model(self, model_name: str) -> Optional[TextGenerationPipeline]:
        """Return a text-generation pipeline for ``model_name``, loading it if needed.

        The tokenizer and model are loaded with ``from_pretrained``. If the
        tokenizer has no pad token, its EOS token is reused as the pad token.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.

        Returns:
            The cached or newly built pipeline, or ``None`` if loading failed.
            Load failures are logged, not raised.
        """
        if model_name in self.loaded_models:
            logger.info("Using cached model: %s", model_name)
            return self.loaded_models[model_name]

        try:
            logger.info("Loading model: %s", model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
            if tokenizer.pad_token is None:
                # Some tokenizers ship without a pad token; reuse EOS so that
                # the model and pipeline have a pad_token_id to work with.
                logger.warning("Pad token not found; setting pad_token to eos_token.")
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                pad_token_id=tokenizer.pad_token_id,
            )
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        except OSError as e:
            logger.error("Model files not found for %s: %s", model_name, e)
            return None
        except ValueError as e:
            logger.error("Invalid configuration for model %s: %s", model_name, e)
            return None
        except Exception as e:
            # Unexpected failure: log at ERROR level *with* the traceback so
            # the cause can actually be diagnosed.
            logger.exception("Unexpected error while loading model %s: %s", model_name, e)
            return None

        self.loaded_models[model_name] = pipe
        logger.info("Successfully loaded model: %s", model_name)
        return pipe

    def generate_response(
        self,
        model_name: str,
        prompt: str,
        max_new_tokens: int = 50,
        **generation_kwargs: Any,
    ) -> str:
        """Generate a continuation of ``prompt`` using the named model.

        Args:
            model_name: Model to use; obtained through ``load_model`` (cached).
            prompt: Input text passed to the pipeline.
            max_new_tokens: Maximum number of new tokens to generate.
            **generation_kwargs: Extra keyword arguments forwarded to the
                pipeline call. They take precedence over this method's
                default of ``return_full_text=False``.

        Returns:
            The generated text, or a string starting with ``"Error:"`` that
            describes the failure. Errors raised during loading or generation
            are reported through the return value instead of propagating.
        """
        try:
            pipe = self.load_model(model_name)
            if pipe is None:
                return "Error: Model loading failed"

            call_kwargs = {
                "max_new_tokens": max_new_tokens,
                "return_full_text": False,
                **generation_kwargs,
            }
            response = pipe(prompt, **call_kwargs)

            # This code expects a non-empty list whose first item is a dict
            # carrying the 'generated_text' key; anything else is rejected.
            first = response[0] if response else None
            if not isinstance(first, dict) or "generated_text" not in first:
                logger.warning("Unexpected response format from pipeline.")
                return "Error: Failed to generate text"
            return first["generated_text"]
        except KeyError as e:
            logger.error("Key error during text generation: %s", e)
            return "Error: Missing expected keys in the response"
        except Exception as e:
            logger.exception("Error generating response: %s", e)
            return f"Error: {str(e)}"