hamxaameer committed
Commit 7eb2f2d (verified)
1 Parent(s): 2e58050

Update app.py

Files changed (1):
  1. app.py +111 -60
app.py CHANGED
@@ -40,9 +40,9 @@ CONFIG = {
     "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
     "llm_model": None,
     "vector_store_path": ".",
-    "top_k": 10,  # Reduced for faster retrieval
-    "temperature": 0.75,
-    "max_tokens": 300,  # Reduced for faster generation
+    "top_k": 8,  # Minimal retrieval for speed
+    "temperature": 0.85,  # Higher for faster sampling
+    "max_tokens": 280,  # Aggressive reduction
 }
 
 # Local PHI model configuration for Hugging Face Spaces
@@ -52,9 +52,11 @@ LOCAL_PHI_MODEL = os.environ.get("LOCAL_PHI_MODEL", "microsoft/phi-2")
 USE_8BIT_QUANTIZATION = True  # Reduces memory usage by ~50%
 USE_REMOTE_LLM = False
 
-# Generation optimization for speed
-MAX_CONTEXT_LENGTH = 800  # Reduce context to speed up generation
-TARGET_ANSWER_WORDS = 280  # Shorter target for faster responses
+# Advanced optimization settings for FAST generation
+MAX_CONTEXT_LENGTH = 500  # Minimal context for speed
+TARGET_ANSWER_WORDS = 220  # Shorter answers = faster generation
+USE_CACHING = True  # Cache model outputs for repeated patterns
+ENABLE_FAST_MODE = True  # Skip iterative generation, use single-shot only
 
 # Prefer the environment variable, but also allow a local token file for users
 # who don't know how to set env vars. Create a file named `hf_token.txt` in the
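The new `ENABLE_FAST_MODE` flag is only ever read (never reassigned) by the functions changed later in this diff, even though they declare `global ENABLE_FAST_MODE`; in Python, `global` is only needed when a function rebinds the name. A minimal sketch of the gating pattern these constants drive, with values taken from the diff (the helper name is hypothetical):

    ENABLE_FAST_MODE = True  # module-level flag, as in the config above

    def pick_targets():
        # Reading a module-level constant needs no `global` declaration;
        # `global` matters only if the function assigns to the name.
        if ENABLE_FAST_MODE:
            return 180, 280, 0   # min words, max words, max iterations (fast mode)
        return 250, 350, 2       # slower-path defaults used elsewhere in the diff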
@@ -125,16 +127,31 @@ def initialize_llm():
         logger.warning(f" 8-bit quantization unavailable: {quant_error}")
         logger.info(" Falling back to float32 (will use more memory)")
 
-    # Load the model
+    # Load the model with optimization
     logger.info(" Loading PHI model (this may take 30-60 seconds)...")
     model = AutoModelForCausalLM.from_pretrained(
         LOCAL_PHI_MODEL,
         **model_kwargs
     )
 
+    # Apply advanced optimizations for faster inference
+    if hasattr(model, 'config'):
+        # Reduce attention heads computation for speed
+        model.config.use_cache = True  # Enable KV cache for faster generation
+        model.config.output_attentions = False
+        model.config.output_hidden_states = False
+
     # Move to eval mode to disable dropout and save memory
     model.eval()
 
+    # Advanced: Try to optimize with torch.compile (PyTorch 2.0+)
+    try:
+        if hasattr(torch, 'compile') and not USE_8BIT_QUANTIZATION:
+            logger.info(" Applying torch.compile for faster inference...")
+            model = torch.compile(model, mode="reduce-overhead")
+    except Exception as compile_error:
+        logger.info(f" Torch compile not available or failed: {compile_error}")
+
     # Create pipeline for generation
     # NOTE: When using accelerate/quantization, do NOT specify device parameter
     logger.info(" Creating text-generation pipeline...")
@@ -142,8 +159,9 @@ def initialize_llm():
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_new_tokens=512,
-        pad_token_id=tokenizer.eos_token_id
+        max_new_tokens=280,  # Default optimized value
+        pad_token_id=tokenizer.eos_token_id,
+        batch_size=1  # Single batch for optimal CPU performance
     )
 
     CONFIG["llm_model"] = LOCAL_PHI_MODEL
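One behavior the later hunks rely on: keyword arguments passed when the pipeline is called take precedence over the defaults set at construction, which is how `call_model()` below requests 280 or 320 new tokens per attempt even though the pipeline default is 280. A short illustrative call (the prompt is made up; `llm_client` is the name the later hunks use for this pipeline, and `model`/`tokenizer` come from the loading code above):

    from transformers import pipeline

    llm_client = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=280,
        pad_token_id=tokenizer.eos_token_id,
    )

    out = llm_client(
        "Q: What should I wear to a summer wedding?\n\nA:",
        max_new_tokens=320,   # call-time value wins over the pipeline default
        do_sample=True,
        temperature=0.9,
    )
    print(out[0]["generated_text"])  # list of dicts; the prompt is echoed by default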
@@ -568,14 +586,19 @@ Enhanced answer:
 def retrieve_knowledge_langchain(
     query: str,
     vectorstore,
-    top_k: int = 15
+    top_k: int = 8
 ) -> Tuple[List[Document], float]:
     logger.info(f"🔍 Retrieving knowledge for: '{query}'")
 
-    query_variants = [
-        query,
-        f"fashion advice clothing outfit style for {query}",
-    ]
+    # Fast mode: single query only (no variants)
+    global ENABLE_FAST_MODE
+    if ENABLE_FAST_MODE:
+        query_variants = [query]
+    else:
+        query_variants = [
+            query,
+            f"fashion advice clothing outfit style for {query}",
+        ]
 
     all_docs = []
 
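The loop that consumes `query_variants` and fills `all_docs` sits below this hunk and is untouched by the commit. For orientation, a sketch of that retrieval step under the assumption of a LangChain vector store exposing `similarity_search_with_score` and simple content-based deduplication (illustrative; the real loop body is in app.py and not shown in this diff):

    all_docs = []
    seen = set()
    for variant in query_variants:
        # LangChain vector stores return (Document, score) pairs
        for doc, score in vectorstore.similarity_search_with_score(variant, k=top_k):
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                all_docs.append((doc, score))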
@@ -645,26 +668,33 @@ def generate_llm_answer(
     scored_docs.sort(key=lambda x: x[1], reverse=True)
     top_docs = [doc[0] for doc in scored_docs[:8]]
 
+    # Ultra-fast context preparation: only use top 4 docs, very short snippets
     context_parts = []
-    for doc in top_docs:
+    for doc in top_docs[:4]:  # Reduced from 8 to 4
         content = doc.page_content.strip()
-        if len(content) > 300:
-            content = content[:300] + "..."
+        if len(content) > 200:  # Reduced from 300 to 200
+            content = content[:200] + "..."
         context_parts.append(content)
 
-    context_text = "\n\n".join(context_parts)
+    context_text = "\n".join(context_parts)  # Single newline instead of double
 
-    # Optimized for speed: shorter context, shorter target, fewer iterations
-    # This significantly reduces generation time on CPU
-    target_min_words = 250
-    target_max_words = 350
-    chunk_target_words = 120
-    max_iterations = 2
+    # Ultra-fast mode: minimal words, no iterations
+    global ENABLE_FAST_MODE
+    if ENABLE_FAST_MODE:
+        target_min_words = 180  # Much shorter
+        target_max_words = 280
+        chunk_target_words = 0  # No continuations
+        max_iterations = 0  # No iterations
+    else:
+        target_min_words = 250
+        target_max_words = 350
+        chunk_target_words = 120
+        max_iterations = 2
 
     def call_model(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
         logger.info(f" → PHI model call (temp={temperature}, max_new_tokens={max_new_tokens})")
         try:
-            # Call local PHI model (causal LM)
+            # Call local PHI model with speed optimizations
             out = llm_client(
                 prompt,
                 max_new_tokens=max_new_tokens,
@@ -674,7 +704,10 @@ def generate_llm_answer(
                 repetition_penalty=repetition_penalty,
                 num_return_sequences=1,
                 pad_token_id=llm_client.tokenizer.eos_token_id,
-                eos_token_id=llm_client.tokenizer.eos_token_id
+                eos_token_id=llm_client.tokenizer.eos_token_id,
+                num_beams=1,  # Greedy/sampling is faster than beam search
+                early_stopping=True,  # Stop as soon as EOS is generated
+                use_cache=True  # Use KV cache for speed
             )
 
             # Extract generated text from pipeline output
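Two notes on the added kwargs: `use_cache=True` and `num_beams=1` match the defaults of most causal LMs, and `early_stopping` only applies to beam search, so with `num_beams=1` it is effectively a no-op; the real savings here come from the smaller `max_new_tokens`. The extraction step referenced by the last context line also has to cope with the pipeline echoing the prompt; a minimal sketch of that step (hypothetical helper, not the exact code in app.py):

    def extract_generated(out, prompt):
        # Pipelines return [{"generated_text": prompt + completion}] by default,
        # so strip the echoed prompt to keep only the newly generated tokens.
        text = out[0]["generated_text"]
        if text.startswith(prompt):
            text = text[len(prompt):]
        return text.strip()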
@@ -694,30 +727,24 @@ def generate_llm_answer(
             logger.error(f" ✗ PHI model call error: {e}")
             return ''
 
-    # Build initial prompt - optimized for speed with shorter context
-    base_prompt = f"""Answer this fashion question with practical advice in ~{target_min_words} words.
-
-Question: {query}
-
-Key information:
-{context_text[:600]}
-
-Provide a clear, helpful answer with specific recommendations.
-
-Answer:
-"""
-
-    # Optimized parameters for faster CPU generation
+    # Ultra-compact prompt for maximum speed
+    base_prompt = f"""Q: {query}
+
+Context: {context_text[:400]}
+
+A:"""
+
+    # Aggressive speed optimization: fewer tokens, higher temperature for faster sampling
     if attempt == 1:
-        temperature = 0.75
-        max_new_tokens = 400  # Reduced for speed
-        top_p = 0.90
-        repetition_penalty = 1.1
+        temperature = 0.85  # Higher = faster sampling
+        max_new_tokens = 280  # Reduced significantly
+        top_p = 0.88
+        repetition_penalty = 1.08
     else:
-        temperature = 0.85
-        max_new_tokens = 500
-        top_p = 0.92
-        repetition_penalty = 1.12
+        temperature = 0.90
+        max_new_tokens = 320
+        top_p = 0.90
+        repetition_penalty = 1.10
 
     initial_output = call_model(base_prompt, max_new_tokens, temperature, top_p, repetition_penalty)
     response = (initial_output or '').strip()
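Read together with the context changes above: four snippets of up to roughly 200 characters joined by single newlines can reach about 800 characters, and the prompt keeps only `context_text[:400]`, so in practice only the top two snippets or so survive into the prompt. A worked example of the new format with a made-up query and context:

    query = "What should I wear to a beach party?"
    context_text = ("Linen shirts and shorts stay cool in the heat.\n"
                    "Sandals or espadrilles pair well with casual beachwear.")

    base_prompt = f"Q: {query}\n\nContext: {context_text[:400]}\n\nA:"
    print(base_prompt)  # same Q/Context/A layout the commit introduces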
@@ -730,6 +757,14 @@ Answer:
     words = response.split()
     word_count = len(words)
 
+    # Fast mode: accept shorter answers immediately
+    if ENABLE_FAST_MODE and word_count >= 150:
+        if word_count > target_max_words:
+            response = ' '.join(words[:target_max_words]) + '...'
+            word_count = target_max_words
+        logger.info(f" ✅ Fast-mode generated {word_count} words")
+        return response
+
     # If single-shot succeeded, validate length and return
     if word_count >= target_min_words:
         if word_count > target_max_words:
738
  logger.info(f" βœ… Single-shot generated {word_count} words")
739
  return response
740
 
 
 
 
 
 
 
 
 
 
741
  # Otherwise, try iterative continuation to build up to the target
742
  accumulated = response
743
  prev_word_count = word_count
@@ -823,30 +867,37 @@ def generate_answer_langchain(
823
  if not retrieved_docs:
824
  return "I couldn't find relevant information to answer your question."
825
 
 
 
 
 
826
  llm_answer = None
827
- for attempt in range(1, 3):
828
- logger.info(f"\n πŸ€– LLM Generation Attempt {attempt}/2")
829
  llm_answer = generate_llm_answer(query, retrieved_docs, llm_client, attempt)
830
 
831
  if llm_answer:
832
  logger.info(f" βœ… LLM answer generated successfully")
833
  break
834
  else:
835
- logger.warning(f" β†’ Attempt {attempt}/2 failed, retrying...")
 
836
 
837
  if not llm_answer:
838
- logger.error(f" βœ— All 2 LLM attempts failed")
839
- # Try scaffold-and-polish as a fallback strategy
840
- try:
841
- logger.info(" β†’ Attempting scaffold-and-polish using PHI model")
842
- polished = scaffold_and_polish(query, retrieved_docs, llm_client)
843
- if polished:
844
- logger.info(" βœ… Scaffold-and-polish produced an answer")
845
- return polished
846
- except Exception as e:
847
- logger.error(f" βœ— Scaffold-and-polish error: {e}")
848
-
849
- # Final fallback: extractive templated answer (guaranteed deterministic)
 
 
850
  try:
851
  logger.info(" β†’ Using extractive fallback generator")
852
  fallback = generate_extractive_answer(query, retrieved_docs)
 