hamxaameer committed
Commit dab6cfd · verified · 1 Parent(s): 06dde32

Update app.py

Files changed (1)
  1. app.py +62 -105
app.py CHANGED
@@ -61,12 +61,11 @@ CONFIG = {
     "max_tokens": 600,  # Allow natural length responses
 }
 
-# Local LLM configuration for Hugging Face Spaces
-# TinyLlama: 1.1B parameters, fast on CPU, reliable generation
-# Alternative: google/flan-t5-base (smaller, faster)
-LOCAL_LLM_MODEL = os.environ.get("LOCAL_LLM_MODEL", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
-USE_8BIT_QUANTIZATION = False  # Not needed for TinyLlama
-USE_REMOTE_LLM = False
+# LLM configuration - LOCAL ONLY
+# Using Google Flan-T5: fast on CPU, reliable, no timeouts
+LOCAL_LLM_MODEL = os.environ.get("LOCAL_LLM_MODEL", "google/flan-t5-base")
+USE_8BIT_QUANTIZATION = False
+USE_REMOTE_LLM = False  # LOCAL ONLY
 
 # Natural flow mode: No word limits, let model decide length
 MAX_CONTEXT_LENGTH = 400  # Reduced for faster generation
@@ -95,14 +94,15 @@ if HF_INFERENCE_API_KEY:
 # ============================================================================
 
 def initialize_llm():
-    """Initialize TinyLlama model locally with CPU optimizations.
+    """Initialize Flan-T5 for fast local CPU generation.
 
-    TinyLlama is fast, reliable, and works well on CPU without device issues.
+    Flan-T5 is an encoder-decoder model optimized for instruction following.
+    Much faster than decoder-only models like TinyLlama on CPU.
     """
-    global LOCAL_LLM_MODEL, USE_8BIT_QUANTIZATION
+    global LOCAL_LLM_MODEL
 
-    logger.info(f"🔄 Initializing local LLM: {LOCAL_LLM_MODEL}")
-    logger.info("   Using CPU-optimized configuration for Hugging Face Spaces")
+    logger.info(f"🔄 Initializing Flan-T5: {LOCAL_LLM_MODEL}")
+    logger.info("   Optimized for fast CPU inference")
 
     try:
         from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -111,102 +111,71 @@ def initialize_llm():
         device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"   Target device: {device}")
 
+        from transformers import T5ForConditionalGeneration, T5Tokenizer
+
         # Load tokenizer
         logger.info("   Loading tokenizer...")
-        tokenizer = AutoTokenizer.from_pretrained(
-            LOCAL_LLM_MODEL,
-            trust_remote_code=True
-        )
+        tokenizer = T5Tokenizer.from_pretrained(LOCAL_LLM_MODEL)
+        logger.info(f"   Tokenizer ready")
 
-        # Configure tokenizer
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-        if tokenizer.pad_token_id is None:
-            tokenizer.pad_token_id = tokenizer.eos_token_id
-
-        logger.info(f"   Tokenizer ready: {len(tokenizer)} tokens")
-
-        # Load model - simple CPU configuration
-        logger.info("   Loading model (20-40 seconds)...")
-        model = AutoModelForCausalLM.from_pretrained(
+        # Load model - Flan-T5 is much lighter
+        logger.info("   Loading model (10-20 seconds)...")
+        model = T5ForConditionalGeneration.from_pretrained(
             LOCAL_LLM_MODEL,
-            trust_remote_code=True,
-            torch_dtype=torch.float32,
-            low_cpu_mem_usage=True
+            torch_dtype=torch.float32
         )
 
-        # Move to CPU explicitly
         model = model.to('cpu')
+        logger.info("   Model loaded on CPU")
 
-        # Apply advanced optimizations for faster inference
-        if hasattr(model, 'config'):
-            # Reduce attention heads computation for speed
-            model.config.use_cache = True  # Enable KV cache for faster generation
-            model.config.output_attentions = False
-            model.config.output_hidden_states = False
-
-        # Move to eval mode to disable dropout and save memory
+        # Optimize for inference
         model.eval()
-
-        # Skip torch.compile - can cause issues on Hugging Face Spaces
-        logger.info("   Model ready for inference")
+        logger.info("   Model ready")
 
         # Store model and tokenizer directly for faster inference
         # We'll use direct generation instead of pipeline
         logger.info("   Configuring direct model inference (faster than pipeline)...")
 
-        # Create a simple wrapper that mimics pipeline interface
-        class FastLLMGenerator:
+        # Flan-T5 generator - simple and fast
+        class FlanT5Generator:
             def __init__(self, model, tokenizer):
                 self.model = model
                 self.tokenizer = tokenizer
 
-            def __call__(self, prompt, max_new_tokens=150, temperature=0.7, top_p=0.9,
-                         do_sample=True, repetition_penalty=1.1, **kwargs):
-                """Direct generation - faster and more reliable"""
+            def __call__(self, prompt, max_new_tokens=128, temperature=0.7, **kwargs):
+                """Generate with Flan-T5 - fast on CPU"""
                 try:
-                    # Tokenize
-                    inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
-                    input_ids = inputs["input_ids"].to('cpu')
-                    attention_mask = inputs.get("attention_mask", None)
-                    if attention_mask is not None:
-                        attention_mask = attention_mask.to('cpu')
+                    # Tokenize input
+                    inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
+                    inputs = {k: v.to('cpu') for k, v in inputs.items()}
 
-                    # Generate
+                    # Generate - Flan-T5 is fast even on CPU
                     with torch.no_grad():
                         outputs = self.model.generate(
-                            input_ids,
-                            attention_mask=attention_mask,
+                            **inputs,
                             max_new_tokens=max_new_tokens,
-                            temperature=temperature if do_sample else 1.0,
-                            top_p=top_p if do_sample else 1.0,
-                            do_sample=do_sample,
-                            repetition_penalty=repetition_penalty,
-                            pad_token_id=self.tokenizer.pad_token_id,
-                            eos_token_id=self.tokenizer.eos_token_id
+                            num_beams=2,  # Beam search for quality
+                            early_stopping=True,
+                            no_repeat_ngram_size=3
                         )
 
-                    # Decode only the new tokens
-                    generated_ids = outputs[0][input_ids.shape[1]:]
-                    generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
-
+                    # Decode
+                    generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                     return [{"generated_text": generated_text.strip()}]
 
                 except Exception as e:
                     logger.error(f"Generation error: {e}")
-                    import traceback
-                    logger.error(traceback.format_exc())
                     return [{"generated_text": ""}]
 
-        llm_client = FastLLMGenerator(model, tokenizer)
-        llm_client.tokenizer = tokenizer  # Add tokenizer reference for compatibility
+        llm_client = FlanT5Generator(model, tokenizer)
+        llm_client.tokenizer = tokenizer
 
         CONFIG["llm_model"] = LOCAL_LLM_MODEL
-        CONFIG["model_type"] = "tinyllama_local"
+        CONFIG["model_type"] = "flan_t5_local"
 
-        logger.info(f"✅ LLM initialized successfully: {LOCAL_LLM_MODEL}")
-        logger.info(f"   Model size: 1.1B parameters")
-        logger.info(f"   Expected speed: 5-15 seconds per response on CPU")
+        logger.info(f"✅ Flan-T5 initialized: {LOCAL_LLM_MODEL}")
+        logger.info(f"   Size: ~250M parameters (base model)")
+        logger.info(f"   Speed: 3-8 seconds per response")
 
         return llm_client
 
@@ -222,34 +191,30 @@ def initialize_llm():
         raise Exception(f"Failed to initialize LLM: {str(e)}")
 
 
-def remote_generate(prompt: str, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9) -> str:
-    """Call the Hugging Face Inference API for remote generation. Requires
-    `HF_INFERENCE_API_KEY` env var to be set and a model name in
-    `REMOTE_LLM_MODEL`.
+def remote_generate(prompt: str, max_new_tokens: int = 200, temperature: float = 0.7, top_p: float = 0.9) -> str:
+    """Call Hugging Face Inference API - fast and reliable.
 
-    PHI models work best with clear instruction formatting. This function
-    handles both the standard HF Inference API and PHI-specific response parsing.
+    Uses Qwen2.5 model optimized for fast inference.
     """
     if not HF_INFERENCE_API_KEY:
         raise Exception("HF_INFERENCE_API_KEY not set for remote generation")
 
-    # Use the HF Inference API endpoint (not router for better PHI compatibility)
+    # Use Inference API
     api_url = f"https://api-inference.huggingface.co/models/{REMOTE_LLM_MODEL}"
    headers = {"Authorization": f"Bearer {HF_INFERENCE_API_KEY}"}
 
-    # PHI models prefer simple parameters; avoid return_full_text which can cause issues
+    # Simple parameters for fast inference
     payload = {
         "inputs": prompt,
         "parameters": {
             "max_new_tokens": max_new_tokens,
             "temperature": temperature,
             "top_p": top_p,
-            "do_sample": True,
-            "repetition_penalty": 1.1
+            "return_full_text": False
         }
     }
 
-    logger.info(f"   → Remote PHI inference to {REMOTE_LLM_MODEL} (tokens={max_new_tokens}, temp={temperature})")
+    logger.info(f"   → Remote inference (tokens={max_new_tokens})")
     try:
         r = requests.post(api_url, headers=headers, json=payload, timeout=90)
     except Exception as e:
@@ -277,30 +242,26 @@ def remote_generate(prompt: str, max_new_tokens: int = 512, temperature: float =
         logger.error(f"   ✗ Remote inference returned error: {result.get('error')}")
         return ""
 
-    # Parse the generated text from various response formats
+    # Extract generated text
     generated_text = ""
 
     if isinstance(result, list) and result:
-        # HF Inference API returns [{"generated_text": "..."}]
         first = result[0]
         if isinstance(first, dict):
             generated_text = first.get("generated_text", "")
         else:
             generated_text = str(first)
-    elif isinstance(result, dict) and "generated_text" in result:
-        generated_text = result["generated_text"]
+    elif isinstance(result, dict):
+        generated_text = result.get("generated_text", str(result))
     else:
         generated_text = str(result)
-
-    # Clean up: PHI may return the prompt + completion, extract only new text
-    generated_text = generated_text.strip()
 
-    # If the response contains the original prompt, extract only the new completion
+    # Clean up
+    generated_text = generated_text.strip()
     if prompt in generated_text:
-        # Find where the prompt ends and new generation begins
-        prompt_end = generated_text.find(prompt) + len(prompt)
-        generated_text = generated_text[prompt_end:].strip()
+        generated_text = generated_text.replace(prompt, "").strip()
 
+    logger.info(f"   ✅ Generated {len(generated_text.split())} words remotely")
     return generated_text
 
 def initialize_embeddings():
@@ -714,7 +675,7 @@ def generate_llm_answer(
     # Ultra-simple prompt
     formatted_prompt = f"{prompt}\n\nAnswer:"
 
-    logger.info(f"   → Generating with TinyLlama (max_tokens={max_new_tokens})")
+    logger.info(f"   → Generating with Flan-T5 (max_tokens={max_new_tokens})")
 
     # MINIMAL settings - most restrictive for speed
     out = llm_client(
@@ -777,20 +738,16 @@
 
     A:"""
 
-    # AGGRESSIVE speed optimization
+    # Flan-T5 optimized parameters
     if attempt == 1:
-        temperature = 0.6  # Lower = faster
-        max_new_tokens = 150  # Much shorter
-        top_p = 0.85
-        repetition_penalty = 1.2
+        max_new_tokens = 128  # Flan-T5 is concise
+        temperature = 0.7
     else:
+        max_new_tokens = 150
         temperature = 0.7
-        max_new_tokens = 180
-        top_p = 0.9
-        repetition_penalty = 1.25
 
     logger.info(f"   → Starting generation with prompt: {base_prompt[:200]}...")
-    initial_output = call_model(base_prompt, max_new_tokens, temperature, top_p, repetition_penalty)
+    initial_output = call_model(base_prompt, max_new_tokens, temperature)
     response = (initial_output or '').strip()
 
     # Basic sanity checks
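
For readers who want to try the new local path on its own: a minimal sketch of the Flan-T5 generation flow that the diff's FlanT5Generator wrapper performs, assuming transformers, sentencepiece, and torch are installed. The prompt text below is illustrative, not taken from the app.

# Standalone sketch of the local Flan-T5 path introduced in this commit.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model_name = "google/flan-t5-base"  # default LOCAL_LLM_MODEL in the new code
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.float32)
model.to("cpu").eval()  # CPU inference, dropout disabled

prompt = "Summarize: Flan-T5 is an encoder-decoder model tuned to follow instructions."
inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,     # matches the wrapper's default
        num_beams=2,            # beam search, as in FlanT5Generator
        early_stopping=True,
        no_repeat_ngram_size=3,
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))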
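
The remote path is kept in the code but disabled by USE_REMOTE_LLM = False. A minimal sketch of the request shape remote_generate sends to the Hugging Face Inference API follows; the model id and prompt are placeholders, and HF_INFERENCE_API_KEY must be set in the environment.

# Standalone sketch of the remote_generate request (placeholders noted above).
import os
import requests

REMOTE_LLM_MODEL = os.environ.get("REMOTE_LLM_MODEL", "some-org/some-model")  # placeholder id
HF_INFERENCE_API_KEY = os.environ["HF_INFERENCE_API_KEY"]

api_url = f"https://api-inference.huggingface.co/models/{REMOTE_LLM_MODEL}"
headers = {"Authorization": f"Bearer {HF_INFERENCE_API_KEY}"}
payload = {
    "inputs": "Question: What does this Space do?\n\nAnswer:",
    "parameters": {
        "max_new_tokens": 200,
        "temperature": 0.7,
        "top_p": 0.9,
        "return_full_text": False,
    },
}

r = requests.post(api_url, headers=headers, json=payload, timeout=90)
result = r.json()

# The API usually returns [{"generated_text": "..."}]; the app also handles dict and str fallbacks.
if isinstance(result, list) and result and isinstance(result[0], dict):
    print(result[0].get("generated_text", ""))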