""" Zero-Day Exploit Scanner & Fixer Model ======================================= Fine-tunes Qwen2.5-Coder-7B-Instruct on vulnerability detection + repair data using SFT with QLoRA (4-bit quantization). Datasets: MegaVul (C/C++ CVE pairs) + TitanVul (multi-lang) + CleanVul (filtered) Base model: Qwen/Qwen2.5-Coder-7B-Instruct Method: QLoRA SFT with structured instruction format References: - R2Vul (arxiv:2504.04699): structured reasoning for vuln detection - MSIVD (arxiv:2406.05892): multi-task instruction tuning - SecRepair (arxiv:2401.03374): combined detection + repair - SecureCode: QLoRA r=16, alpha=32, lr=2e-4, 3 epochs Usage: pip install transformers trl torch datasets trackio accelerate peft bitsandbytes python train.py Hardware: Requires GPU with 24GB+ VRAM (A10G, A100, etc.) Estimated time: ~6-8 hours on A10G for 3 epochs on ~90K samples """ import os import torch import trackio from datasets import load_dataset, concatenate_datasets, Dataset from peft import LoraConfig from trl import SFTTrainer, SFTConfig from transformers import BitsAndBytesConfig # ============================================================================ # Configuration # ============================================================================ MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct" OUTPUT_DIR = "./vuln-scanner-fixer" HUB_MODEL_ID = "jacobmahon/zero-day-exploit-scanner-fixer" # Training hyperparameters (paper-backed from SecureCode + SecRepair + MSIVD) LEARNING_RATE = 2e-4 # Higher LR for LoRA (SecureCode: 2e-4) NUM_EPOCHS = 3 # SecureCode: 3 epochs BATCH_SIZE = 2 # Per device GRAD_ACCUM = 8 # Effective batch = 16 MAX_LENGTH = 2048 # Vuln functions can be long WARMUP_STEPS = 100 LORA_R = 16 # SecureCode: rank=16 LORA_ALPHA = 32 # SecureCode: alpha=32 LORA_DROPOUT = 0.05 # ============================================================================ # Trackio Monitoring # ============================================================================ trackio.init( 
project="zero-day-exploit-scanner", name="qwen2.5-coder-7b-vulnfix-qlora", ) # ============================================================================ # Dataset Preparation # ============================================================================ print("=" * 60) print("Loading datasets...") print("=" * 60) # 1. MegaVul - C/C++ vulnerability pairs with CVE/CWE labels (17K) megavul = load_dataset("hitoshura25/megavul", split="train") print(f"MegaVul: {len(megavul)} samples") # 2. TitanVul - Multi-language aggregated vulnerability pairs (38K) titanvul = load_dataset("yikun-li/TitanVul", split="train") print(f"TitanVul: {len(titanvul)} samples") # 3. CleanVul - Filtered high-quality vulnerability pairs cleanvul = load_dataset("yikun-li/CleanVul", split="train") print(f"CleanVul raw: {len(cleanvul)} samples") # Filter CleanVul: score >= 1 means likely real vulnerability fix cleanvul = cleanvul.filter(lambda x: x["vulnerability_score"] is not None and x["vulnerability_score"] >= 1) print(f"CleanVul filtered (score >= 1): {len(cleanvul)} samples") # ============================================================================ # Format datasets into unified instruction format # ============================================================================ SYSTEM_PROMPT = """You are a world-class security expert specializing in zero-day vulnerability detection and remediation. When given code, you will: 1. SCAN: Determine if the code contains a security vulnerability 2. IDENTIFY: If vulnerable, identify the CWE type and CVE ID if known 3. EXPLAIN: Provide a clear explanation of the vulnerability mechanism, attack vector, and potential impact 4. 
FIX: Provide the corrected code that patches the vulnerability Always respond in the following structured format: ## SCAN RESULT [VULNERABLE / SAFE] ## VULNERABILITY DETAILS - **CWE**: [CWE ID and name] - **CVE**: [CVE ID if known, otherwise "N/A"] - **Severity**: [CRITICAL / HIGH / MEDIUM / LOW] ## EXPLANATION [Detailed explanation of the vulnerability] ## VULNERABLE LINES [Specific lines or patterns that are vulnerable] ## FIXED CODE ``` [Corrected code] ``` ## FIX EXPLANATION [What was changed and why]""" def format_megavul(example): """Format MegaVul samples into instruction format.""" vuln_code = example.get("vulnerable_code") fixed_code = example.get("fixed_code") cwe_id = example.get("cwe_id", "Unknown") cve_id = example.get("cve_id", "N/A") cve_desc = example.get("cve_description", "") commit_msg = example.get("commit_message", "") language = example.get("language", "C") if not vuln_code or not fixed_code: return {"messages": None} explanation = cve_desc if cve_desc else commit_msg if not explanation: explanation = f"Security vulnerability identified as {cwe_id}" assistant_response = f"""## SCAN RESULT VULNERABLE ## VULNERABILITY DETAILS - **CWE**: {cwe_id} - **CVE**: {cve_id} - **Severity**: HIGH ## EXPLANATION {explanation} ## FIX EXPLANATION {commit_msg if commit_msg else "Applied security patch to remediate the vulnerability."} ## FIXED CODE ```{language.lower()} {fixed_code} ```""" messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": f"Analyze the following {language} code for security vulnerabilities and provide a fix if needed:\n\n```{language.lower()}\n{vuln_code}\n```"}, {"role": "assistant", "content": assistant_response}, ] return {"messages": messages} def format_titanvul(example): """Format TitanVul samples into instruction format.""" func_before = example.get("func_before") func_after = example.get("func_after") cwe_id = example.get("cwe_id", "Unknown") cve_id = example.get("cve_id", "N/A") cve_desc = 
example.get("cve_description", "") commit_msg = example.get("commit_message", "") extension = example.get("extension", "") if not func_before or not func_after: return {"messages": None} lang_map = {"c": "C", "cpp": "C++", "h": "C", "java": "Java", "py": "Python", "js": "JavaScript", "ts": "TypeScript", "rb": "Ruby", "go": "Go", "rs": "Rust", "php": "PHP"} language = lang_map.get(extension, extension if extension else "code") explanation = cve_desc if cve_desc else commit_msg if not explanation: explanation = f"Security vulnerability identified as {cwe_id}" if cwe_id else "Security vulnerability detected in code" cwe_display = cwe_id if cwe_id else "Unknown" cve_display = cve_id if cve_id else "N/A" assistant_response = f"""## SCAN RESULT VULNERABLE ## VULNERABILITY DETAILS - **CWE**: {cwe_display} - **CVE**: {cve_display} - **Severity**: HIGH ## EXPLANATION {explanation} ## FIX EXPLANATION {commit_msg if commit_msg else "Applied security patch to remediate the vulnerability."} ## FIXED CODE ```{language.lower()} {func_after} ```""" messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": f"Analyze the following {language} code for security vulnerabilities and provide a fix if needed:\n\n```{language.lower()}\n{func_before}\n```"}, {"role": "assistant", "content": assistant_response}, ] return {"messages": messages} def format_cleanvul(example): """Format CleanVul samples into instruction format.""" func_before = example.get("func_before") func_after = example.get("func_after") cwe_id = example.get("cwe_id") cve_id = example.get("cve_id", "N/A") commit_msg = example.get("commit_msg", "") extension = example.get("extension", "") if not func_before or not func_after: return {"messages": None} lang_map = {"c": "C", "cpp": "C++", "h": "C", "java": "Java", "py": "Python", "js": "JavaScript", "ts": "TypeScript", "rb": "Ruby", "go": "Go", "rs": "Rust", "php": "PHP"} language = lang_map.get(extension, extension if extension else "code") if 
isinstance(cwe_id, list) and len(cwe_id) > 0: cwe_display = ", ".join(cwe_id) elif cwe_id: cwe_display = str(cwe_id) else: cwe_display = "Unknown" cve_display = cve_id if cve_id else "N/A" explanation = commit_msg if commit_msg else f"Security vulnerability ({cwe_display}) detected and patched." assistant_response = f"""## SCAN RESULT VULNERABLE ## VULNERABILITY DETAILS - **CWE**: {cwe_display} - **CVE**: {cve_display} - **Severity**: HIGH ## EXPLANATION {explanation} ## FIX EXPLANATION {commit_msg if commit_msg else "Applied security patch to remediate the vulnerability."} ## FIXED CODE ```{language.lower()} {func_after} ```""" messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": f"Analyze the following {language} code for security vulnerabilities and provide a fix if needed:\n\n```{language.lower()}\n{func_before}\n```"}, {"role": "assistant", "content": assistant_response}, ] return {"messages": messages} print("\nFormatting datasets into instruction format...") megavul_formatted = megavul.map(format_megavul, remove_columns=megavul.column_names, num_proc=4) megavul_formatted = megavul_formatted.filter(lambda x: x["messages"] is not None) print(f"MegaVul formatted: {len(megavul_formatted)} samples") titanvul_formatted = titanvul.map(format_titanvul, remove_columns=titanvul.column_names, num_proc=4) titanvul_formatted = titanvul_formatted.filter(lambda x: x["messages"] is not None) print(f"TitanVul formatted: {len(titanvul_formatted)} samples") cleanvul_formatted = cleanvul.map(format_cleanvul, remove_columns=cleanvul.column_names, num_proc=4) cleanvul_formatted = cleanvul_formatted.filter(lambda x: x["messages"] is not None) print(f"CleanVul formatted: {len(cleanvul_formatted)} samples") # Combine all datasets combined = concatenate_datasets([megavul_formatted, titanvul_formatted, cleanvul_formatted]) combined = combined.shuffle(seed=42) print(f"\nTotal combined dataset: {len(combined)} samples") # 
# ============================================================================
# Add SAFE code examples to reduce false positives
# ============================================================================
print("\nCreating safe code examples (negative samples)...")


def create_safe_sample(example):
    """Turn a patched (post-fix) function into a SAFE-labelled example.

    Training on the fixed versions of known-vulnerable code as negatives
    teaches the model not to flag already-remediated code, which reduces
    false positives. Returns {"messages": None} when the fixed code is
    missing so the sample can be filtered out downstream.
    """
    patched = example.get("func_after")
    ext = example.get("extension", "")
    if not patched:
        return {"messages": None}
    ext_to_lang = {
        "c": "C", "cpp": "C++", "h": "C", "java": "Java", "py": "Python",
        "js": "JavaScript", "ts": "TypeScript", "rb": "Ruby", "go": "Go",
        "rs": "Rust", "php": "PHP",
    }
    language = ext_to_lang.get(ext, ext if ext else "code")
    safe_response = """## SCAN RESULT
SAFE

## VULNERABILITY DETAILS
No security vulnerabilities detected in this code.

## EXPLANATION
The code follows secure coding practices. No known vulnerability patterns (buffer overflows, injection flaws, authentication bypasses, race conditions, etc.) were identified in this code segment."""
    user_prompt = (
        f"Analyze the following {language} code for security vulnerabilities "
        f"and provide a fix if needed:\n\n```{language.lower()}\n{patched}\n```"
    )
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": safe_response},
    ]
    return {"messages": chat}


# Take subset of TitanVul fixed code as safe examples (~15% of vulnerable samples)
n_safe = min(12000, len(titanvul))
safe_pool = titanvul.shuffle(seed=123).select(range(n_safe))
safe_formatted = safe_pool.map(create_safe_sample, remove_columns=titanvul.column_names, num_proc=4)
safe_formatted = safe_formatted.filter(lambda x: x["messages"] is not None)
print(f"Safe (negative) samples: {len(safe_formatted)}")

# Merge vulnerable + safe examples into one shuffled pool.
all_data = concatenate_datasets([combined, safe_formatted]).shuffle(seed=42)
print(f"Total dataset (vuln + safe): {len(all_data)} samples")

# Create train/eval split (95/5)
split = all_data.train_test_split(test_size=0.05, seed=42)
train_dataset = split["train"]
eval_dataset = split["test"]
print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

# ============================================================================
# Model + QLoRA Setup
# ============================================================================
print("\n" + "=" * 60)
print("Setting up model with QLoRA...")
print("=" * 60)

# 4-bit NF4 weights with double quantization; compute runs in bfloat16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# LoRA adapters on every attention + MLP projection (standard QLoRA recipe).
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

# ============================================================================
# Training Configuration
# ============================================================================
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    # Training hyperparameters
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_steps=WARMUP_STEPS,
    # Precision & memory
    bf16=True,
    gradient_checkpointing=True,
    max_length=MAX_LENGTH,
    # Logging
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="none",
    # Evaluation
    eval_strategy="steps",
    eval_steps=500,
    # Saving: keep 3 checkpoints, restore the one with the best eval loss.
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # SFT specific: loss only on assistant turns; no sequence packing.
    packing=False,
    assistant_only_loss=True,
    # Kwargs forwarded to from_pretrained when SFTTrainer loads the model.
    model_init_kwargs={
        "quantization_config": quant_config,
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "eager",
    },
)

# ============================================================================
# Initialize Trainer
# ============================================================================
print("Initializing SFTTrainer...")
# Passing the model name as a string lets SFTTrainer handle loading with
# model_init_kwargs (quantization, dtype) and wrap it with the LoRA config.
trainer = SFTTrainer(
    model=MODEL_NAME,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)

total_params = trainer.model.num_parameters()
trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
print(f"Model parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# ============================================================================
# Train!
# ============================================================================


def _banner(message):
    """Print *message* framed by 60-char separators (same bytes as the
    original inline prints)."""
    print("\n" + "=" * 60)
    print(message)
    print("=" * 60)


_banner("Starting training...")
trainer.train()

# ============================================================================
# Save & Push to Hub
# ============================================================================
# Persist the adapter weights locally, then publish them to the Hub.
_banner("Saving model and pushing to Hub...")
trainer.save_model()
trainer.push_to_hub(commit_message="Train zero-day exploit scanner & fixer (QLoRA on Qwen2.5-Coder-7B)")

_banner(f"Model saved to Hub: https://huggingface.co/{HUB_MODEL_ID}")

trackio.log({"status": "training_complete"})