# jacobmahon's picture
# Add training script for zero-day exploit scanner & fixer model
# e664669 verified
"""
Zero-Day Exploit Scanner & Fixer Model
=======================================
Fine-tunes Qwen2.5-Coder-7B-Instruct on vulnerability detection + repair data
using SFT with QLoRA (4-bit quantization).
Datasets: MegaVul (C/C++ CVE pairs) + TitanVul (multi-lang) + CleanVul (filtered)
Base model: Qwen/Qwen2.5-Coder-7B-Instruct
Method: QLoRA SFT with structured instruction format
References:
- R2Vul (arxiv:2504.04699): structured reasoning for vuln detection
- MSIVD (arxiv:2406.05892): multi-task instruction tuning
- SecRepair (arxiv:2401.03374): combined detection + repair
- SecureCode: QLoRA r=16, alpha=32, lr=2e-4, 3 epochs
Usage:
pip install transformers trl torch datasets trackio accelerate peft bitsandbytes
python train.py
Hardware: Requires GPU with 24GB+ VRAM (A10G, A100, etc.)
Estimated time: ~6-8 hours on A10G for 3 epochs on ~90K samples
"""
import os
import torch
import trackio
from datasets import load_dataset, concatenate_datasets, Dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import BitsAndBytesConfig
# ============================================================================
# Configuration
# ============================================================================
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"  # base model to fine-tune
OUTPUT_DIR = "./vuln-scanner-fixer"  # local checkpoint/output directory
HUB_MODEL_ID = "jacobmahon/zero-day-exploit-scanner-fixer"  # Hub repo to push to
# Training hyperparameters (paper-backed from SecureCode + SecRepair + MSIVD)
LEARNING_RATE = 2e-4 # Higher LR for LoRA (SecureCode: 2e-4)
NUM_EPOCHS = 3 # SecureCode: 3 epochs
BATCH_SIZE = 2 # Per device
GRAD_ACCUM = 8 # Effective batch = 16 (BATCH_SIZE * GRAD_ACCUM)
MAX_LENGTH = 2048 # Vuln functions can be long
WARMUP_STEPS = 100  # linear warmup before the cosine schedule below
LORA_R = 16 # SecureCode: rank=16
LORA_ALPHA = 32 # SecureCode: alpha=32
LORA_DROPOUT = 0.05
# ============================================================================
# Trackio Monitoring
# ============================================================================
# Start an experiment-tracking run; the final status is logged at the end of
# the script via trackio.log().
trackio.init(
    project="zero-day-exploit-scanner",
    name="qwen2.5-coder-7b-vulnfix-qlora",
)
# ============================================================================
# Dataset Preparation
# ============================================================================
print("=" * 60)
print("Loading datasets...")
print("=" * 60)
# 1. MegaVul - C/C++ vulnerability pairs with CVE/CWE labels (17K)
megavul = load_dataset("hitoshura25/megavul", split="train")
print(f"MegaVul: {len(megavul)} samples")
# 2. TitanVul - Multi-language aggregated vulnerability pairs (38K)
titanvul = load_dataset("yikun-li/TitanVul", split="train")
print(f"TitanVul: {len(titanvul)} samples")
# 3. CleanVul - Filtered high-quality vulnerability pairs
cleanvul = load_dataset("yikun-li/CleanVul", split="train")
print(f"CleanVul raw: {len(cleanvul)} samples")
# Filter CleanVul: score >= 1 means likely real vulnerability fix
# (rows with a null score are unlabeled — presumably unscored — and are
# dropped along with anything scored below 1; TODO confirm score semantics
# against the CleanVul dataset card)
cleanvul = cleanvul.filter(lambda x: x["vulnerability_score"] is not None and x["vulnerability_score"] >= 1)
print(f"CleanVul filtered (score >= 1): {len(cleanvul)} samples")
# ============================================================================
# Format datasets into unified instruction format
# ============================================================================
# Shared system prompt for every training sample. It fixes the structured
# response schema the model is trained to emit: scan verdict -> details ->
# explanation -> vulnerable lines -> fixed code -> fix explanation.
SYSTEM_PROMPT = """You are a world-class security expert specializing in zero-day vulnerability detection and remediation. When given code, you will:
1. SCAN: Determine if the code contains a security vulnerability
2. IDENTIFY: If vulnerable, identify the CWE type and CVE ID if known
3. EXPLAIN: Provide a clear explanation of the vulnerability mechanism, attack vector, and potential impact
4. FIX: Provide the corrected code that patches the vulnerability
Always respond in the following structured format:
## SCAN RESULT
[VULNERABLE / SAFE]
## VULNERABILITY DETAILS
- **CWE**: [CWE ID and name]
- **CVE**: [CVE ID if known, otherwise "N/A"]
- **Severity**: [CRITICAL / HIGH / MEDIUM / LOW]
## EXPLANATION
[Detailed explanation of the vulnerability]
## VULNERABLE LINES
[Specific lines or patterns that are vulnerable]
## FIXED CODE
```
[Corrected code]
```
## FIX EXPLANATION
[What was changed and why]"""
def format_megavul(example):
    """Format a MegaVul row into the chat instruction format.

    Args:
        example: Raw MegaVul row. Uses vulnerable_code, fixed_code, cwe_id,
            cve_id, cve_description, commit_message and language — any of
            which may be missing or null.

    Returns:
        {"messages": [system, user, assistant]} on success, or
        {"messages": None} when either code side is absent (filtered out
        downstream).
    """
    vuln_code = example.get("vulnerable_code")
    fixed_code = example.get("fixed_code")
    if not vuln_code or not fixed_code:
        return {"messages": None}
    # `.get(key, default)` still returns None when the field exists but is
    # null, so coalesce with `or` to avoid emitting the literal "None" in
    # the training target (and a crash on language.lower()).
    cwe_id = example.get("cwe_id") or "Unknown"
    cve_id = example.get("cve_id") or "N/A"
    cve_desc = example.get("cve_description") or ""
    commit_msg = example.get("commit_message") or ""
    language = example.get("language") or "C"
    # Prefer the CVE description; fall back to commit message, then a stub.
    explanation = cve_desc if cve_desc else commit_msg
    if not explanation:
        explanation = f"Security vulnerability identified as {cwe_id}"
    # Severity is not present in the source rows, so it is a fixed
    # placeholder. Section order matches SYSTEM_PROMPT (FIXED CODE before
    # FIX EXPLANATION).
    assistant_response = f"""## SCAN RESULT
VULNERABLE
## VULNERABILITY DETAILS
- **CWE**: {cwe_id}
- **CVE**: {cve_id}
- **Severity**: HIGH
## EXPLANATION
{explanation}
## FIXED CODE
```{language.lower()}
{fixed_code}
```
## FIX EXPLANATION
{commit_msg if commit_msg else "Applied security patch to remediate the vulnerability."}"""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Analyze the following {language} code for security vulnerabilities and provide a fix if needed:\n\n```{language.lower()}\n{vuln_code}\n```"},
        {"role": "assistant", "content": assistant_response},
    ]
    return {"messages": messages}
def format_titanvul(example):
    """Format a TitanVul row into the chat instruction format.

    Args:
        example: Raw TitanVul row. Uses func_before/func_after (code pre/post
            patch), cwe_id, cve_id, cve_description, commit_message and
            extension (file extension, used to infer the language name).

    Returns:
        {"messages": [system, user, assistant]} on success, or
        {"messages": None} when either code side is absent (filtered out
        downstream).
    """
    func_before = example.get("func_before")
    func_after = example.get("func_after")
    if not func_before or not func_after:
        return {"messages": None}
    # Coalesce null/empty metadata explicitly; `.get()` defaults do not
    # cover fields that are present but null.
    cwe_display = example.get("cwe_id") or "Unknown"
    cve_display = example.get("cve_id") or "N/A"
    cve_desc = example.get("cve_description") or ""
    commit_msg = example.get("commit_message") or ""
    extension = example.get("extension", "")
    lang_map = {"c": "C", "cpp": "C++", "h": "C", "java": "Java",
                "py": "Python", "js": "JavaScript", "ts": "TypeScript",
                "rb": "Ruby", "go": "Go", "rs": "Rust", "php": "PHP"}
    # Unknown extensions fall through verbatim; empty/null becomes "code".
    language = lang_map.get(extension, extension if extension else "code")
    explanation = cve_desc if cve_desc else commit_msg
    if not explanation:
        explanation = (f"Security vulnerability identified as {cwe_display}"
                       if cwe_display != "Unknown"
                       else "Security vulnerability detected in code")
    # Severity is a fixed placeholder (not in the source rows). Section order
    # matches SYSTEM_PROMPT (FIXED CODE before FIX EXPLANATION).
    assistant_response = f"""## SCAN RESULT
VULNERABLE
## VULNERABILITY DETAILS
- **CWE**: {cwe_display}
- **CVE**: {cve_display}
- **Severity**: HIGH
## EXPLANATION
{explanation}
## FIXED CODE
```{language.lower()}
{func_after}
```
## FIX EXPLANATION
{commit_msg if commit_msg else "Applied security patch to remediate the vulnerability."}"""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Analyze the following {language} code for security vulnerabilities and provide a fix if needed:\n\n```{language.lower()}\n{func_before}\n```"},
        {"role": "assistant", "content": assistant_response},
    ]
    return {"messages": messages}
def format_cleanvul(example):
    """Format a CleanVul row into the chat instruction format.

    Args:
        example: Raw CleanVul row. Uses func_before/func_after, cwe_id
            (may be a list, a scalar, or null), cve_id, commit_msg and
            extension (file extension, used to infer the language name).

    Returns:
        {"messages": [system, user, assistant]} on success, or
        {"messages": None} when either code side is absent (filtered out
        downstream).
    """
    func_before = example.get("func_before")
    func_after = example.get("func_after")
    if not func_before or not func_after:
        return {"messages": None}
    cwe_id = example.get("cwe_id")
    cve_display = example.get("cve_id") or "N/A"
    commit_msg = example.get("commit_msg") or ""
    extension = example.get("extension", "")
    lang_map = {"c": "C", "cpp": "C++", "h": "C", "java": "Java",
                "py": "Python", "js": "JavaScript", "ts": "TypeScript",
                "rb": "Ruby", "go": "Go", "rs": "Rust", "php": "PHP"}
    # Unknown extensions fall through verbatim; empty/null becomes "code".
    language = lang_map.get(extension, extension if extension else "code")
    # cwe_id may be a list of tags; stringify each element so a non-string
    # entry cannot crash the join.
    if isinstance(cwe_id, list) and len(cwe_id) > 0:
        cwe_display = ", ".join(str(c) for c in cwe_id)
    elif cwe_id:
        cwe_display = str(cwe_id)
    else:
        cwe_display = "Unknown"
    explanation = commit_msg if commit_msg else f"Security vulnerability ({cwe_display}) detected and patched."
    # Severity is a fixed placeholder (not in the source rows). Section order
    # matches SYSTEM_PROMPT (FIXED CODE before FIX EXPLANATION).
    assistant_response = f"""## SCAN RESULT
VULNERABLE
## VULNERABILITY DETAILS
- **CWE**: {cwe_display}
- **CVE**: {cve_display}
- **Severity**: HIGH
## EXPLANATION
{explanation}
## FIXED CODE
```{language.lower()}
{func_after}
```
## FIX EXPLANATION
{commit_msg if commit_msg else "Applied security patch to remediate the vulnerability."}"""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Analyze the following {language} code for security vulnerabilities and provide a fix if needed:\n\n```{language.lower()}\n{func_before}\n```"},
        {"role": "assistant", "content": assistant_response},
    ]
    return {"messages": messages}
print("\nFormatting datasets into instruction format...")
# Map each raw dataset through its formatter, then drop the rows the
# formatter rejected (messages == None, i.e. missing before/after code).
megavul_formatted = megavul.map(format_megavul, remove_columns=megavul.column_names, num_proc=4)
megavul_formatted = megavul_formatted.filter(lambda x: x["messages"] is not None)
print(f"MegaVul formatted: {len(megavul_formatted)} samples")
titanvul_formatted = titanvul.map(format_titanvul, remove_columns=titanvul.column_names, num_proc=4)
titanvul_formatted = titanvul_formatted.filter(lambda x: x["messages"] is not None)
print(f"TitanVul formatted: {len(titanvul_formatted)} samples")
cleanvul_formatted = cleanvul.map(format_cleanvul, remove_columns=cleanvul.column_names, num_proc=4)
cleanvul_formatted = cleanvul_formatted.filter(lambda x: x["messages"] is not None)
print(f"CleanVul formatted: {len(cleanvul_formatted)} samples")
# Combine all datasets
combined = concatenate_datasets([megavul_formatted, titanvul_formatted, cleanvul_formatted])
combined = combined.shuffle(seed=42)
print(f"\nTotal combined dataset: {len(combined)} samples")
# ============================================================================
# Add SAFE code examples to reduce false positives
# ============================================================================
print("\nCreating safe code examples (negative samples)...")
def create_safe_sample(example):
    """Turn a patched (post-fix) function into a SAFE negative sample.

    Emits the same chat triple as the vulnerable formatters, but with a
    fixed "SAFE" verdict, so the model also learns not to flag clean code.
    Returns {"messages": None} when the patched function body is absent.
    """
    patched = example.get("func_after")
    if not patched:
        return {"messages": None}
    ext = example.get("extension", "")
    ext_to_lang = {"c": "C", "cpp": "C++", "h": "C", "java": "Java",
                   "py": "Python", "js": "JavaScript", "ts": "TypeScript",
                   "rb": "Ruby", "go": "Go", "rs": "Rust", "php": "PHP"}
    # Unknown extensions pass through verbatim; empty/null becomes "code".
    lang = ext_to_lang.get(ext) or (ext or "code")
    verdict = """## SCAN RESULT
SAFE
## VULNERABILITY DETAILS
No security vulnerabilities detected in this code.
## EXPLANATION
The code follows secure coding practices. No known vulnerability patterns (buffer overflows, injection flaws, authentication bypasses, race conditions, etc.) were identified in this code segment."""
    prompt = (
        f"Analyze the following {lang} code for security vulnerabilities "
        f"and provide a fix if needed:\n\n```{lang.lower()}\n{patched}\n```"
    )
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": verdict},
        ]
    }
# Take subset of TitanVul fixed code as safe examples (~15% of vulnerable samples)
# NOTE(review): a different seed (123) is used here so the safe subset is not
# correlated with the main shuffle below.
safe_subset = titanvul.shuffle(seed=123).select(range(min(12000, len(titanvul))))
safe_formatted = safe_subset.map(create_safe_sample, remove_columns=titanvul.column_names, num_proc=4)
safe_formatted = safe_formatted.filter(lambda x: x["messages"] is not None)
print(f"Safe (negative) samples: {len(safe_formatted)}")
# Combine all data
all_data = concatenate_datasets([combined, safe_formatted]).shuffle(seed=42)
print(f"Total dataset (vuln + safe): {len(all_data)} samples")
# Create train/eval split (95/5)
split = all_data.train_test_split(test_size=0.05, seed=42)
train_dataset = split["train"]
eval_dataset = split["test"]
print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
# ============================================================================
# Model + QLoRA Setup
# ============================================================================
print("\n" + "=" * 60)
print("Setting up model with QLoRA...")
print("=" * 60)
# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# LoRA adapters on the attention and MLP projection layers (Qwen2-style
# module names); r/alpha/dropout come from the constants above.
peft_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
# ============================================================================
# Training Configuration
# ============================================================================
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    # Training hyperparameters
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_steps=WARMUP_STEPS,
    # Precision & memory
    bf16=True,
    gradient_checkpointing=True,  # trade compute for VRAM
    max_length=MAX_LENGTH,  # sequences truncated beyond this
    # Logging
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="none",  # built-in trackers off; trackio is used directly above
    # Evaluation
    eval_strategy="steps",
    eval_steps=500,
    # Saving — keep 3 checkpoints, restore best eval_loss at the end
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # SFT specific
    packing=False,
    assistant_only_loss=True,  # compute loss only on assistant turns
    # Model loading kwargs — forwarded by SFTTrainer when loading the model
    # from the MODEL_NAME string
    model_init_kwargs={
        "quantization_config": bnb_config,
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "eager",
    },
)
# ============================================================================
# Initialize Trainer
# ============================================================================
print("Initializing SFTTrainer...")
# Passing the model name as a string lets SFTTrainer load it with
# model_init_kwargs (4-bit quantization) and wrap it with the LoRA adapter.
trainer = SFTTrainer(
    model=MODEL_NAME,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)
# Report total vs trainable (LoRA) parameter counts as a sanity check.
print(f"Model parameters: {trainer.model.num_parameters():,}")
print(f"Trainable parameters: {sum(p.numel() for p in trainer.model.parameters() if p.requires_grad):,}")
# ============================================================================
# Train!
# ============================================================================
print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60)
trainer.train()
# ============================================================================
# Save & Push to Hub
# ============================================================================
print("\n" + "=" * 60)
print("Saving model and pushing to Hub...")
print("=" * 60)
# save_model() writes the trained (adapter) weights to OUTPUT_DIR;
# push_to_hub() uploads them to HUB_MODEL_ID with this commit message.
trainer.save_model()
trainer.push_to_hub(commit_message="Train zero-day exploit scanner & fixer (QLoRA on Qwen2.5-Coder-7B)")
print("\n" + "=" * 60)
print(f"Model saved to Hub: https://huggingface.co/{HUB_MODEL_ID}")
print("=" * 60)
# Mark the trackio run as finished successfully.
trackio.log({"status": "training_complete"})