JatsTheAIGen committed
Commit 13fa6c4 · 1 Parent(s): 9959ea9

Fix: BitsAndBytes compatibility and error handling

CRITICAL FIXES:
- Fixed error handling for BitsAndBytes kernel registration failures
- Distinguished bitsandbytes errors from gated repository errors
- Added automatic fallback to loading without quantization when bitsandbytes fails
- Changed the fallback model to microsoft/Phi-3-mini-4k-instruct (verified non-gated)

Changes:
- src/local_model_loader.py:
  - Better error detection for bitsandbytes vs. gated repo errors
  - Automatic fallback to no quantization if bitsandbytes fails
  - Improved error messages to distinguish error types

- src/llm_router.py:
  - Added catch for bitsandbytes errors at router level
  - Automatic retry without quantization on bitsandbytes failures

- src/models_config.py:
  - Changed fallback from mistralai/Mistral-7B-Instruct-v0.2 to microsoft/Phi-3-mini-4k-instruct
  - Phi-3-mini is verified non-gated and smaller (3.8B vs. 7B parameters)
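
For the models_config.py change, the "verified non-gated" status of the new fallback can be spot-checked against the Hub metadata. A minimal sketch, assuming a recent huggingface_hub release in which ModelInfo exposes a gated field; this check is not part of the repo:

# Spot-check that the fallback repo is publicly accessible (not gated).
# Assumes huggingface_hub is installed and ModelInfo carries a `gated` field:
# False for public repos, "auto"/"manual" for gated ones.
from huggingface_hub import model_info

info = model_info("microsoft/Phi-3-mini-4k-instruct")
print(info.gated)  # expected: False for the non-gated fallback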

Fixes:
- RuntimeError: int8_mm_dequant kernel registration conflict
- ModuleNotFoundError: validate_bnb_backend_availability
- False positive 'gated repository' errors from bitsandbytes failures

Now properly handles:
- BitsAndBytes compatibility issues → fall back to loading without quantization
- Actual gated repository errors → switch to the configured fallback model
- Either error → a clear, type-specific error message
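
The classification above comes down to substring checks on the exception text, mirrored at both the loader and the router level in the diffs below. A minimal standalone sketch of that heuristic; the function name classify_model_load_error and the marker tuples are illustrative only, not identifiers from this repo:

# Illustrative sketch of the error-classification heuristic used in this commit.
BNB_MARKERS = ("bitsandbytes", "int8_mm_dequant", "validate_bnb_backend")
GATED_MARKERS = ("gated", "authorized", "access")

def classify_model_load_error(exc: Exception) -> str:
    """Return 'bitsandbytes', 'gated_repo', or 'other' for a model-load failure."""
    text = str(exc).lower()
    if any(marker in text for marker in BNB_MARKERS):
        return "bitsandbytes"  # retry the same model without quantization
    if any(marker in text for marker in GATED_MARKERS):
        return "gated_repo"    # switch to the configured non-gated fallback model
    return "other"             # re-raise unchanged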

src/llm_router.py CHANGED
@@ -169,6 +169,30 @@ class LLMRouter:
                    raise
            else:
                raise RuntimeError(f"Model {model_id} is a gated repository and no fallback available") from e
+        except (RuntimeError, ModuleNotFoundError, ImportError) as e:
+            # Check if this is a bitsandbytes error (not a gated repo error)
+            error_str = str(e).lower()
+            if "bitsandbytes" in error_str or "int8_mm_dequant" in error_str or "validate_bnb_backend" in error_str:
+                logger.warning(f"⚠ BitsAndBytes compatibility issue detected: {e}")
+                logger.warning(f"⚠ Model {model_id} will be loaded without quantization")
+                # Retry without quantization
+                try:
+                    # Disable quantization for this attempt
+                    fallback_config = model_config.copy()
+                    fallback_config["use_4bit_quantization"] = False
+                    fallback_config["use_8bit_quantization"] = False
+                    return await self._call_local_model(
+                        fallback_config,
+                        prompt,
+                        task_type,
+                        **kwargs
+                    )
+                except Exception as retry_error:
+                    logger.error(f"Failed to load model even without quantization: {retry_error}")
+                    raise RuntimeError(f"Model loading failed: {retry_error}") from retry_error
+            else:
+                # Not a bitsandbytes error, re-raise
+                raise

        # Format as chat messages if needed
        messages = [{"role": "user", "content": prompt}]

src/local_model_loader.py CHANGED
@@ -110,6 +110,7 @@ class LocalModelLoader:
            logger.info(f"Stripping API suffix from {model_id}, using base model: {base_model_id}")

        # Load tokenizer with cache directory
+        # This will fail with actual GatedRepoError if model is gated
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                base_model_id,
@@ -117,15 +118,27 @@
                token=self.hf_token if self.hf_token else None,
                trust_remote_code=True
            )
-        except GatedRepoError as e:
-            logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
-            logger.error(f"   Access to model {base_model_id} is restricted and you are not in the authorized list.")
-            logger.error(f"   Visit https://huggingface.co/{base_model_id} to request access.")
-            logger.error(f"   Error details: {e}")
-            raise GatedRepoError(
-                f"Cannot access gated repository {base_model_id}. "
-                f"Visit https://huggingface.co/{base_model_id} to request access."
-            ) from e
+        except Exception as e:
+            # Check if this is actually a gated repo error
+            error_str = str(e).lower()
+            if "gated" in error_str or "authorized" in error_str or "access" in error_str:
+                # This might be a gated repo error
+                try:
+                    from huggingface_hub.exceptions import GatedRepoError as RealGatedRepoError
+                    if isinstance(e, RealGatedRepoError):
+                        logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
+                        logger.error(f"   Access to model {base_model_id} is restricted and you are not in the authorized list.")
+                        logger.error(f"   Visit https://huggingface.co/{base_model_id} to request access.")
+                        logger.error(f"   Error details: {e}")
+                        raise RealGatedRepoError(
+                            f"Cannot access gated repository {base_model_id}. "
+                            f"Visit https://huggingface.co/{base_model_id} to request access."
+                        ) from e
+                except ImportError:
+                    pass
+
+            # If it's not a gated repo error, re-raise as-is
+            raise

        # Determine quantization config
        if load_in_4bit and self.device == "cuda":
@@ -151,46 +164,92 @@
            quantization_config = None

        # Load model with GPU optimization and cache directory
-        try:
-            load_kwargs = {
-                "cache_dir": self.cache_dir,
-                "token": self.hf_token if self.hf_token else None,
-                "trust_remote_code": True
-            }
-
-            if self.device == "cuda":
-                load_kwargs.update({
-                    "device_map": "auto",  # Automatically uses GPU
-                    "torch_dtype": torch.float16,  # Use FP16 for memory efficiency
-                })
-                if quantization_config:
-                    if isinstance(quantization_config, dict):
-                        load_kwargs.update(quantization_config)
-                    else:
-                        load_kwargs["quantization_config"] = quantization_config
+        # Try with quantization first, fallback to no quantization if bitsandbytes fails
+        load_kwargs = {
+            "cache_dir": self.cache_dir,
+            "token": self.hf_token if self.hf_token else None,
+            "trust_remote_code": True
+        }
+
+        if self.device == "cuda":
+            load_kwargs.update({
+                "device_map": "auto",  # Automatically uses GPU
+                "torch_dtype": torch.float16,  # Use FP16 for memory efficiency
+            })
+
+        # Try loading with quantization first
+        model = None
+        quantization_failed = False
+
+        if quantization_config and self.device == "cuda":
+            try:
+                if isinstance(quantization_config, dict):
+                    load_kwargs.update(quantization_config)
+                else:
+                    load_kwargs["quantization_config"] = quantization_config

                model = AutoModelForCausalLM.from_pretrained(
                    base_model_id,
                    **load_kwargs
                )
-            else:
-                load_kwargs.update({
-                    "torch_dtype": torch.float32,
-                })
-                model = AutoModelForCausalLM.from_pretrained(
-                    base_model_id,
-                    **load_kwargs
-                )
-                model = model.to(self.device)
-        except GatedRepoError as e:
-            logger.error(f" Gated Repository Error: Cannot access gated repo {base_model_id}")
-            logger.error(f"   Access to model {base_model_id} is restricted and you are not in the authorized list.")
-            logger.error(f"   Visit https://huggingface.co/{base_model_id} to request access.")
-            logger.error(f"   Error details: {e}")
-            raise GatedRepoError(
-                f"Cannot access gated repository {base_model_id}. "
-                f"Visit https://huggingface.co/{base_model_id} to request access."
-            ) from e
+                logger.info("✓ Model loaded with quantization")
+            except (RuntimeError, ModuleNotFoundError, ImportError) as e:
+                error_str = str(e).lower()
+                # Check if error is related to bitsandbytes
+                if "bitsandbytes" in error_str or "int8_mm_dequant" in error_str or "validate_bnb_backend" in error_str:
+                    logger.warning(f"⚠ BitsAndBytes error detected: {e}")
+                    logger.warning("⚠ Falling back to loading without quantization")
+                    quantization_failed = True
+                    # Remove quantization config and retry
+                    load_kwargs.pop("quantization_config", None)
+                    load_kwargs.pop("load_in_8bit", None)
+                    load_kwargs.pop("load_in_4bit", None)
+                else:
+                    # Re-raise if it's not a bitsandbytes error
+                    raise
+
+        # If quantization failed or not using quantization, load without it
+        if model is None:
+            try:
+                if self.device == "cuda":
+                    model = AutoModelForCausalLM.from_pretrained(
+                        base_model_id,
+                        **load_kwargs
+                    )
+                else:
+                    load_kwargs.update({
+                        "torch_dtype": torch.float32,
+                    })
+                    model = AutoModelForCausalLM.from_pretrained(
+                        base_model_id,
+                        **load_kwargs
+                    )
+                    model = model.to(self.device)
+            except Exception as e:
+                # Check if this is a gated repo error (not bitsandbytes)
+                error_str = str(e).lower()
+                if "bitsandbytes" in error_str or "int8_mm_dequant" in error_str:
+                    # BitsAndBytes error - should have been caught earlier
+                    logger.error(f"❌ Unexpected BitsAndBytes error: {e}")
+                    raise RuntimeError(f"BitsAndBytes compatibility issue: {e}") from e
+
+                # Check for actual gated repo error
+                try:
+                    from huggingface_hub.exceptions import GatedRepoError as RealGatedRepoError
+                    if isinstance(e, RealGatedRepoError) or "gated" in error_str or "authorized" in error_str:
+                        logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
+                        logger.error(f"   Access to model {base_model_id} is restricted and you are not in the authorized list.")
+                        logger.error(f"   Visit https://huggingface.co/{base_model_id} to request access.")
+                        logger.error(f"   Error details: {e}")
+                        raise RealGatedRepoError(
+                            f"Cannot access gated repository {base_model_id}. "
+                            f"Visit https://huggingface.co/{base_model_id} to request access."
+                        ) from e
+                except ImportError:
+                    pass
+
+                # Re-raise other errors as-is
+                raise

        # Ensure padding token is set
        if tokenizer.pad_token is None:
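
The hunks above pass a quantization_config through to from_pretrained without showing how it is constructed. For context, a typical 4-bit setup with transformers' BitsAndBytesConfig might look like the sketch below; the exact flags this repo uses are not visible in the diff, so treat the values as an assumption:

# Hedged sketch of a 4-bit quantization config; values are common defaults,
# not necessarily what src/local_model_loader.py builds.
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)
# Then passed as AutoModelForCausalLM.from_pretrained(..., quantization_config=quantization_config)
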
src/models_config.py CHANGED
@@ -9,7 +9,7 @@ LLM_CONFIG = {
            "task": "general_reasoning",
            "max_tokens": 8000, # Reduced from 10000
            "temperature": 0.7,
-            "fallback": "mistralai/Mistral-7B-Instruct-v0.2", # Non-gated fallback model
+            "fallback": "microsoft/Phi-3-mini-4k-instruct", # Non-gated fallback model (3.8B, verified non-gated)
            "is_chat_model": True,
            "use_4bit_quantization": True, # Enable 4-bit quantization for 16GB T4
            "use_8bit_quantization": False
@@ -29,7 +29,7 @@ LLM_CONFIG = {
            "latency_target": "<100ms",
            "is_chat_model": True,
            "use_4bit_quantization": True,
-            "fallback": "mistralai/Mistral-7B-Instruct-v0.2" # Non-gated fallback
+            "fallback": "microsoft/Phi-3-mini-4k-instruct" # Non-gated fallback (3.8B, verified non-gated)
        },
        "safety_checker": {
            "model_id": "Qwen/Qwen2.5-7B-Instruct", # Same model for all text tasks
@@ -38,7 +38,7 @@ LLM_CONFIG = {
            "purpose": "bias_detection",
            "is_chat_model": True,
            "use_4bit_quantization": True,
-            "fallback": "mistralai/Mistral-7B-Instruct-v0.2" # Non-gated fallback
+            "fallback": "microsoft/Phi-3-mini-4k-instruct" # Non-gated fallback (3.8B, verified non-gated)
        }
    },
    "routing_logic": {