Text Generation
PEFT
English
security
vulnerability-detection
code-repair
zero-day
exploit-scanner
cybersecurity
sft
qlora
Instructions to use jacobmahon/zero-day-exploit-scanner-fixer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use jacobmahon/zero-day-exploit-scanner-fixer with PEFT:
The adapter uses PEFT task type `CAUSAL_LM` (see the `LoraConfig` in the training script below).
- Notebooks
- Google Colab
- Kaggle
"""
Zero-Day Exploit Scanner & Fixer Model
======================================

Fine-tunes Qwen2.5-Coder-7B-Instruct on vulnerability detection + repair data
using SFT with QLoRA (4-bit quantization).

Datasets: MegaVul (C/C++ CVE pairs) + TitanVul (multi-lang) + CleanVul (filtered)
Base model: Qwen/Qwen2.5-Coder-7B-Instruct
Method: QLoRA SFT with structured instruction format

References:
- R2Vul (arxiv:2504.04699): structured reasoning for vuln detection
- MSIVD (arxiv:2406.05892): multi-task instruction tuning
- SecRepair (arxiv:2401.03374): combined detection + repair
- SecureCode: QLoRA r=16, alpha=32, lr=2e-4, 3 epochs

Usage:
    pip install transformers trl torch datasets trackio accelerate peft bitsandbytes
    python train.py

Hardware: Requires GPU with 24GB+ VRAM (A10G, A100, etc.)
Estimated time: ~6-8 hours on A10G for 3 epochs on ~90K samples
"""
| import os | |
| import torch | |
| import trackio | |
| from datasets import load_dataset, concatenate_datasets, Dataset | |
| from peft import LoraConfig | |
| from trl import SFTTrainer, SFTConfig | |
| from transformers import BitsAndBytesConfig | |
# ============================================================================
# Configuration
# ============================================================================
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"  # base model fine-tuned below
OUTPUT_DIR = "./vuln-scanner-fixer"            # local checkpoint directory
HUB_MODEL_ID = "jacobmahon/zero-day-exploit-scanner-fixer"  # Hub repo pushed to at the end

# Training hyperparameters (paper-backed from SecureCode + SecRepair + MSIVD)
LEARNING_RATE = 2e-4  # Higher LR for LoRA (SecureCode: 2e-4)
NUM_EPOCHS = 3        # SecureCode: 3 epochs
BATCH_SIZE = 2        # Per device
GRAD_ACCUM = 8        # Effective batch = BATCH_SIZE * GRAD_ACCUM = 16
MAX_LENGTH = 2048     # Vuln functions can be long; longer sequences are truncated
WARMUP_STEPS = 100
LORA_R = 16           # SecureCode: rank=16
LORA_ALPHA = 32       # SecureCode: alpha=32 (LoRA scaling = alpha / r = 2)
LORA_DROPOUT = 0.05

# ============================================================================
# Trackio Monitoring
# ============================================================================
# Starts a trackio run; status events are logged under this project name
# (final event at the bottom of the script).
trackio.init(
    project="zero-day-exploit-scanner",
    name="qwen2.5-coder-7b-vulnfix-qlora",
)
# ============================================================================
# Dataset Preparation
# ============================================================================
print("=" * 60)
print("Loading datasets...")
print("=" * 60)

# 1. MegaVul - C/C++ vulnerability pairs with CVE/CWE labels (17K)
megavul = load_dataset("hitoshura25/megavul", split="train")
print(f"MegaVul: {len(megavul)} samples")

# 2. TitanVul - Multi-language aggregated vulnerability pairs (38K)
titanvul = load_dataset("yikun-li/TitanVul", split="train")
print(f"TitanVul: {len(titanvul)} samples")

# 3. CleanVul - Filtered high-quality vulnerability pairs
cleanvul = load_dataset("yikun-li/CleanVul", split="train")
print(f"CleanVul raw: {len(cleanvul)} samples")

# Filter CleanVul: score >= 1 means likely real vulnerability fix.
# The explicit None check comes first because some rows carry no score
# and `None >= 1` would raise a TypeError.
cleanvul = cleanvul.filter(lambda x: x["vulnerability_score"] is not None and x["vulnerability_score"] >= 1)
print(f"CleanVul filtered (score >= 1): {len(cleanvul)} samples")
| # ============================================================================ | |
| # Format datasets into unified instruction format | |
| # ============================================================================ | |
SYSTEM_PROMPT = """You are a world-class security expert specializing in zero-day vulnerability detection and remediation. When given code, you will:
1. SCAN: Determine if the code contains a security vulnerability
2. IDENTIFY: If vulnerable, identify the CWE type and CVE ID if known
3. EXPLAIN: Provide a clear explanation of the vulnerability mechanism, attack vector, and potential impact
4. FIX: Provide the corrected code that patches the vulnerability
Always respond in the following structured format:
## SCAN RESULT
[VULNERABLE / SAFE]
## VULNERABILITY DETAILS
- **CWE**: [CWE ID and name]
- **CVE**: [CVE ID if known, otherwise "N/A"]
- **Severity**: [CRITICAL / HIGH / MEDIUM / LOW]
## EXPLANATION
[Detailed explanation of the vulnerability]
## VULNERABLE LINES
[Specific lines or patterns that are vulnerable]
## FIXED CODE
```
[Corrected code]
```
## FIX EXPLANATION
[What was changed and why]"""


def format_megavul(example):
    """Convert one MegaVul row into the unified chat ``messages`` format.

    Args:
        example: Raw MegaVul row. Keys used (any may be missing or None):
            ``vulnerable_code``, ``fixed_code``, ``cwe_id``, ``cve_id``,
            ``cve_description``, ``commit_message``, ``language``.

    Returns:
        ``{"messages": [system, user, assistant]}`` on success, or
        ``{"messages": None}`` when either code side is absent (callers
        filter those rows out afterwards).
    """
    vuln_code = example.get("vulnerable_code")
    fixed_code = example.get("fixed_code")
    # ``or`` fallbacks also cover keys that are present but None/empty,
    # which ``.get(key, default)`` alone would pass through as None and
    # render as the literal string "None" in the training text.
    cwe_id = example.get("cwe_id") or "Unknown"
    cve_id = example.get("cve_id") or "N/A"
    cve_desc = example.get("cve_description") or ""
    commit_msg = example.get("commit_message") or ""
    language = example.get("language") or "C"
    if not vuln_code or not fixed_code:
        return {"messages": None}
    explanation = cve_desc or commit_msg or f"Security vulnerability identified as {cwe_id}"
    fix_explanation = commit_msg or "Applied security patch to remediate the vulnerability."
    # Section order follows SYSTEM_PROMPT: FIXED CODE comes *before* FIX
    # EXPLANATION. The previous version emitted them swapped, so training
    # targets contradicted the format the system prompt declares.
    # NOTE(review): SYSTEM_PROMPT also declares a "## VULNERABLE LINES"
    # section, but MegaVul rows carry no line-level info, so it is omitted.
    assistant_response = f"""## SCAN RESULT
VULNERABLE
## VULNERABILITY DETAILS
- **CWE**: {cwe_id}
- **CVE**: {cve_id}
- **Severity**: HIGH
## EXPLANATION
{explanation}
## FIXED CODE
```{language.lower()}
{fixed_code}
```
## FIX EXPLANATION
{fix_explanation}"""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Analyze the following {language} code for security vulnerabilities and provide a fix if needed:\n\n```{language.lower()}\n{vuln_code}\n```"},
        {"role": "assistant", "content": assistant_response},
    ]
    return {"messages": messages}
def format_titanvul(example):
    """Convert one TitanVul row into the unified chat ``messages`` format.

    Args:
        example: Raw TitanVul row. Keys used (any may be missing or None):
            ``func_before``, ``func_after``, ``cwe_id``, ``cve_id``,
            ``cve_description``, ``commit_message``, ``extension``.

    Returns:
        ``{"messages": [system, user, assistant]}`` on success, or
        ``{"messages": None}`` when either the pre- or post-patch function
        body is absent (callers filter those rows out afterwards).
    """
    func_before = example.get("func_before")
    func_after = example.get("func_after")
    # Keep the raw cwe_id (may be None) so the explanation fallback below
    # can distinguish "no CWE at all" from a real id.
    cwe_id = example.get("cwe_id")
    cve_id = example.get("cve_id")
    # ``or`` fallbacks also normalize present-but-None values to "".
    cve_desc = example.get("cve_description") or ""
    commit_msg = example.get("commit_message") or ""
    extension = example.get("extension") or ""
    if not func_before or not func_after:
        return {"messages": None}
    # Map the file extension to a human-readable language name; unknown
    # extensions fall through as-is, and a missing one becomes "code".
    lang_map = {"c": "C", "cpp": "C++", "h": "C", "java": "Java",
                "py": "Python", "js": "JavaScript", "ts": "TypeScript",
                "rb": "Ruby", "go": "Go", "rs": "Rust", "php": "PHP"}
    language = lang_map.get(extension, extension if extension else "code")
    explanation = cve_desc or commit_msg
    if not explanation:
        explanation = f"Security vulnerability identified as {cwe_id}" if cwe_id else "Security vulnerability detected in code"
    cwe_display = cwe_id if cwe_id else "Unknown"
    cve_display = cve_id if cve_id else "N/A"
    fix_explanation = commit_msg or "Applied security patch to remediate the vulnerability."
    # Section order follows SYSTEM_PROMPT: FIXED CODE comes *before* FIX
    # EXPLANATION (the previous version emitted them swapped).
    assistant_response = f"""## SCAN RESULT
VULNERABLE
## VULNERABILITY DETAILS
- **CWE**: {cwe_display}
- **CVE**: {cve_display}
- **Severity**: HIGH
## EXPLANATION
{explanation}
## FIXED CODE
```{language.lower()}
{func_after}
```
## FIX EXPLANATION
{fix_explanation}"""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Analyze the following {language} code for security vulnerabilities and provide a fix if needed:\n\n```{language.lower()}\n{func_before}\n```"},
        {"role": "assistant", "content": assistant_response},
    ]
    return {"messages": messages}
def format_cleanvul(example):
    """Convert one CleanVul row into the unified chat ``messages`` format.

    CleanVul quirks handled here: ``cwe_id`` may be a list, a scalar, or
    absent, and the commit message field is ``commit_msg`` (not
    ``commit_message`` as in the other datasets).

    Args:
        example: Raw CleanVul row. Keys used (any may be missing or None):
            ``func_before``, ``func_after``, ``cwe_id``, ``cve_id``,
            ``commit_msg``, ``extension``.

    Returns:
        ``{"messages": [system, user, assistant]}`` on success, or
        ``{"messages": None}`` when either the pre- or post-patch function
        body is absent (callers filter those rows out afterwards).
    """
    func_before = example.get("func_before")
    func_after = example.get("func_after")
    cwe_id = example.get("cwe_id")
    cve_id = example.get("cve_id")
    # ``or`` fallbacks also normalize present-but-None values to "".
    commit_msg = example.get("commit_msg") or ""
    extension = example.get("extension") or ""
    if not func_before or not func_after:
        return {"messages": None}
    # Map the file extension to a human-readable language name; unknown
    # extensions fall through as-is, and a missing one becomes "code".
    lang_map = {"c": "C", "cpp": "C++", "h": "C", "java": "Java",
                "py": "Python", "js": "JavaScript", "ts": "TypeScript",
                "rb": "Ruby", "go": "Go", "rs": "Rust", "php": "PHP"}
    language = lang_map.get(extension, extension if extension else "code")
    # cwe_id can be a list of ids; str() each entry so a stray non-string
    # (or None) element cannot crash the join.
    if isinstance(cwe_id, list) and len(cwe_id) > 0:
        cwe_display = ", ".join(str(c) for c in cwe_id)
    elif cwe_id:
        cwe_display = str(cwe_id)
    else:
        cwe_display = "Unknown"
    cve_display = cve_id if cve_id else "N/A"
    explanation = commit_msg if commit_msg else f"Security vulnerability ({cwe_display}) detected and patched."
    fix_explanation = commit_msg or "Applied security patch to remediate the vulnerability."
    # Section order follows SYSTEM_PROMPT: FIXED CODE comes *before* FIX
    # EXPLANATION (the previous version emitted them swapped).
    assistant_response = f"""## SCAN RESULT
VULNERABLE
## VULNERABILITY DETAILS
- **CWE**: {cwe_display}
- **CVE**: {cve_display}
- **Severity**: HIGH
## EXPLANATION
{explanation}
## FIXED CODE
```{language.lower()}
{func_after}
```
## FIX EXPLANATION
{fix_explanation}"""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Analyze the following {language} code for security vulnerabilities and provide a fix if needed:\n\n```{language.lower()}\n{func_before}\n```"},
        {"role": "assistant", "content": assistant_response},
    ]
    return {"messages": messages}
print("\nFormatting datasets into instruction format...")

# Run each dataset through its formatter, drop the rows the formatter
# rejected (messages is None), and collect the results for concatenation.
_formatted_parts = []
for _label, _raw, _formatter in (
    ("MegaVul", megavul, format_megavul),
    ("TitanVul", titanvul, format_titanvul),
    ("CleanVul", cleanvul, format_cleanvul),
):
    _part = _raw.map(_formatter, remove_columns=_raw.column_names, num_proc=4)
    _part = _part.filter(lambda x: x["messages"] is not None)
    print(f"{_label} formatted: {len(_part)} samples")
    _formatted_parts.append(_part)

# Combine all datasets into one shuffled pool.
combined = concatenate_datasets(_formatted_parts)
combined = combined.shuffle(seed=42)
print(f"\nTotal combined dataset: {len(combined)} samples")
| # ============================================================================ | |
| # Add SAFE code examples to reduce false positives | |
| # ============================================================================ | |
print("\nCreating safe code examples (negative samples)...")


def create_safe_sample(example):
    """Turn an already-patched function into a SAFE-labelled sample.

    The fixed (post-patch) code is presented with a SAFE verdict so the
    model learns not to flag remediated code, reducing false positives.
    Returns ``{"messages": None}`` when the patched code is absent.
    """
    patched_code = example.get("func_after")
    ext = example.get("extension", "")
    if not patched_code:
        return {"messages": None}
    # Same extension -> language mapping used by the vulnerable formatters.
    ext_to_language = {"c": "C", "cpp": "C++", "h": "C", "java": "Java",
                       "py": "Python", "js": "JavaScript", "ts": "TypeScript",
                       "rb": "Ruby", "go": "Go", "rs": "Rust", "php": "PHP"}
    language = ext_to_language.get(ext, ext or "code")
    assistant_response = """## SCAN RESULT
SAFE
## VULNERABILITY DETAILS
No security vulnerabilities detected in this code.
## EXPLANATION
The code follows secure coding practices. No known vulnerability patterns (buffer overflows, injection flaws, authentication bypasses, race conditions, etc.) were identified in this code segment."""
    user_prompt = (
        f"Analyze the following {language} code for security vulnerabilities "
        f"and provide a fix if needed:\n\n```{language.lower()}\n{patched_code}\n```"
    )
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": assistant_response},
        ]
    }
# Take subset of TitanVul fixed code as safe examples (~15% of vulnerable samples)
n_safe = min(12000, len(titanvul))
negatives = (
    titanvul.shuffle(seed=123)
    .select(range(n_safe))
    .map(create_safe_sample, remove_columns=titanvul.column_names, num_proc=4)
    .filter(lambda x: x["messages"] is not None)
)
print(f"Safe (negative) samples: {len(negatives)}")

# Combine all data
all_data = concatenate_datasets([combined, negatives]).shuffle(seed=42)
print(f"Total dataset (vuln + safe): {len(all_data)} samples")

# Create train/eval split (95/5)
split = all_data.train_test_split(test_size=0.05, seed=42)
train_dataset, eval_dataset = split["train"], split["test"]
print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
# ============================================================================
# Model + QLoRA Setup
# ============================================================================
print("\n" + "=" * 60)
print("Setting up model with QLoRA...")
print("=" * 60)

# 4-bit NF4 quantization with double quantization; matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,  # quantize the quantization constants too (saves memory)
    bnb_4bit_quant_type="nf4",       # NormalFloat4, the QLoRA-recommended 4-bit type
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# LoRA adapters on every attention and MLP projection (QLoRA-style
# all-linear targeting); base weights stay frozen.
peft_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",            # bias terms stay frozen, no bias adapters
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
# ============================================================================
# Training Configuration
# ============================================================================
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    # Training hyperparameters
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,  # effective batch = BATCH_SIZE * GRAD_ACCUM = 16
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_steps=WARMUP_STEPS,
    # Precision & memory
    bf16=True,
    gradient_checkpointing=True,  # recompute activations to cut memory
    max_length=MAX_LENGTH,        # sequences longer than this are truncated
    # Logging
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="none",  # no HF logging integrations; trackio is driven manually in this script
    # Evaluation
    eval_strategy="steps",
    eval_steps=500,
    # Saving
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,  # keep only the 3 newest checkpoints on disk
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # SFT specific
    packing=False,              # one sample per sequence; no example packing
    assistant_only_loss=True,   # mask loss on system/user turns, train on assistant turns only
    # Model loading kwargs — forwarded to from_pretrained when SFTTrainer
    # loads the model from the MODEL_NAME string below.
    model_init_kwargs={
        "quantization_config": bnb_config,
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "eager",
    },
)
# ============================================================================
# Initialize Trainer
# ============================================================================
print("Initializing SFTTrainer...")
# Passing the model name as a string makes SFTTrainer load it itself using
# training_args.model_init_kwargs (4-bit quantization, bf16, eager attention)
# and wrap it with the LoRA adapters from peft_config.
trainer = SFTTrainer(
    model=MODEL_NAME,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)
# Total vs. trainable counts — with LoRA only the adapter weights require grad.
print(f"Model parameters: {trainer.model.num_parameters():,}")
print(f"Trainable parameters: {sum(p.numel() for p in trainer.model.parameters() if p.requires_grad):,}")
# ============================================================================
# Train!
# ============================================================================
print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60)
# Full SFT loop; with load_best_model_at_end=True the checkpoint with the
# lowest eval_loss is restored when training finishes.
trainer.train()

# ============================================================================
# Save & Push to Hub
# ============================================================================
print("\n" + "=" * 60)
print("Saving model and pushing to Hub...")
print("=" * 60)
# With a PEFT-wrapped model this saves the trained adapter weights to
# OUTPUT_DIR, then uploads them to HUB_MODEL_ID.
trainer.save_model()
trainer.push_to_hub(commit_message="Train zero-day exploit scanner & fixer (QLoRA on Qwen2.5-Coder-7B)")
print("\n" + "=" * 60)
print(f"Model saved to Hub: https://huggingface.co/{HUB_MODEL_ID}")
print("=" * 60)
# Final trackio event marking a successful end-to-end run.
trackio.log({"status": "training_complete"})