Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| from dataclasses import dataclass | |
| from accelerate import PartialState | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser | |
| from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format | |
| from kto_dataset_processor import process_feel_dataset, SupportedLanguages | |
| from datetime import datetime | |
| import wandb | |
| from enum import Enum | |
| from typing import Optional | |
| from pathlib import Path | |
| # PEFT library: attach and load adapters | |
| from peft import get_peft_model, PeftModel | |
| #################################### | |
| # CONFIGURATION | |
| #################################### | |
@dataclass
class ScriptArguments:
    """
    Configuration for the script.

    Declared as a dataclass so that ``__post_init__`` actually runs (validating
    ``language``) and so that ``process_dataset_func`` is stored as a plain
    instance attribute. As an undecorated class attribute, the function would
    be turned into a bound method when read through an instance, and the
    instance itself would be passed as its first argument.
    """
    # Builds the dataset; main() calls it with the language name and indexes
    # the result with "train" and "test".
    process_dataset_func: callable = process_feel_dataset
    # Not read anywhere in this script.
    checkpoint_path: Optional[str] = None
    # When True, main() pushes the trained adapter to the Hugging Face Hub.
    push_to_hub: bool = True
    language: str = "English"  # Default to English; must be a SupportedLanguages value

    def __post_init__(self):
        """Validate the language after initialization.

        Raises:
            ValueError: if ``language`` is not a value of ``SupportedLanguages``;
                the message lists every supported language.
        """
        try:
            # This will raise ValueError if language is not in the enum
            SupportedLanguages(self.language)
        except ValueError:
            supported_langs = "\n- ".join([lang.value for lang in SupportedLanguages])
            raise ValueError(
                f"Invalid language: '{self.language}'\n"
                f"Supported languages are:\n- {supported_langs}"
            ) from None
@dataclass
class ModelArguments(ModelConfig):
    """
    Configuration for the model.

    Declared as a dataclass so the overrides below become the field defaults.
    ModelConfig is itself a dataclass. Without the decorator, calling
    ``ModelArguments()`` runs ModelConfig's generated ``__init__``, which
    assigns ModelConfig's own defaults (e.g. ``use_peft=False``) as instance
    attributes. Those shadow the class attributes below, so LoRA would be
    silently disabled.
    """
    # Hub id of the base model. Used by load_model_and_tokenizer and to build
    # the output/adapter paths. This is separate from
    # ModelConfig.model_name_or_path.
    model_name: str = "CohereForAI/aya-expanse-8b"
    # LoRA settings consumed by trl.get_peft_config.
    use_peft: bool = True
    lora_target_modules: str = "all-linear"
    lora_r: int = 16
    lora_alpha: int = 16
    # Forwarded to from_pretrained for both the model and the tokenizer.
    trust_remote_code: bool = True
@dataclass
class TrainingArguments(KTOConfig):
    """
    Configuration for the KTO trainer.

    Declared as a dataclass so the overrides below become real field defaults.
    KTOConfig is itself a dataclass. Without the decorator, its generated
    ``__init__`` would assign KTOConfig's own defaults as instance attributes,
    silently shadowing these class attributes. Also, in transformers versions
    where ``output_dir`` is a required field, ``TrainingArguments()`` would
    fail outright.
    """
    # Evaluated once, at import time. "/" in the model id is replaced with "_"
    # (as get_adapter_path does), so each run gets one flat directory instead
    # of being nested under the organization name.
    output_dir: str = (
        f"kto_{ModelArguments.model_name.replace('/', '_')}_"
        f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    )
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 4
    learning_rate: float = 5e-7
    lr_scheduler_type: str = "cosine"
    gradient_accumulation_steps: int = 1
    logging_steps: int = 10
    # NOTE(review): eval_steps only takes effect when eval_strategy is "steps".
    # eval_strategy is not set here, so periodic evaluation during training is
    # presumably off; main() evaluates once after training. Confirm intent.
    eval_steps: int = 500
    warmup_ratio: float = 0.1
    # NOTE(review): load_model_and_tokenizer loads weights as torch.float16
    # while bf16 mixed precision is enabled here. Confirm the intended precision.
    bf16: bool = True
    logging_first_step: bool = True
# Module-level configuration objects read by main(). The tuple assignment
# still evaluates the constructors left to right: script, training, model.
script_args, training_args, model_args = (
    ScriptArguments(),
    TrainingArguments(),
    ModelArguments(),
)
| #################################### | |
| # HELPER FUNCTIONS | |
| #################################### | |
def load_model_and_tokenizer(model_args, *, torch_dtype=torch.float16):
    """
    Load the base causal-LM model and its tokenizer from the Hugging Face Hub.

    Args:
        model_args: Object providing ``model_name`` (Hub id) and
            ``trust_remote_code``.
        torch_dtype: dtype for the model weights. The default is
            ``torch.float16``, the value previously hard-coded. Pass e.g.
            ``torch.bfloat16`` to match bf16 mixed-precision training.

    Returns:
        ``(model, tokenizer)``. When the tokenizer has no chat template, both
        are the objects returned by TRL's ``setup_chat_format``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch_dtype,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code,
    )
    # Batched training needs a pad token; fall back to EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Only install a chat format when the tokenizer lacks a template.
    # setup_chat_format returns an updated model as well as an updated
    # tokenizer, so both are rebound.
    if not getattr(tokenizer, "chat_template", None):
        model, tokenizer = setup_chat_format(model, tokenizer)
    return model, tokenizer
def get_adapter_path(model_name: str, language: str, timestamp: Optional[str] = None) -> Path:
    """
    Build the standardized on-disk location of a LoRA adapter.

    Format: ``adapters/{model_name}/{language}/version_{timestamp}``. Every "/"
    in ``model_name`` is replaced with "_", so a Hub id such as "org/model"
    becomes a single directory level.

    Args:
        model_name: Hub id of the base model.
        language: Language name, used as a directory level.
        timestamp: Version timestamp. If falsy (``None`` or ``""``), the base
            language directory is returned instead of a specific version.

    Returns:
        A path relative to the current working directory.
    """
    clean_model_name = model_name.replace('/', '_')
    base_path = Path("adapters") / clean_model_name / language
    if timestamp:
        return base_path / f"version_{timestamp}"
    return base_path
def load_latest_adapter(model, model_name: str, language: str) -> tuple[Optional[PeftModel], Optional[str]]:
    """
    Attach the most recently saved adapter for a model/language to *model*.

    Args:
        model: Base model to load the adapter onto.
        model_name: Hub id of the base model, used to locate the adapter folder.
        language: Language whose adapters are searched.

    Returns:
        ``(peft_model, timestamp)`` for the newest usable adapter, loaded as
        trainable, or ``(None, None)`` when there is none.
    """
    adapter_base = get_adapter_path(model_name, language)
    if not adapter_base.exists():
        return None, None
    # Only directories holding a saved PEFT adapter (adapter_config.json,
    # written by save_pretrained) are candidates. A stray file or an
    # incomplete version directory therefore cannot make loading fail.
    candidates = [
        path
        for path in adapter_base.glob("version_*")
        if (path / "adapter_config.json").is_file()
    ]
    if not candidates:
        return None, None
    # Version timestamps are zero-padded and most-significant-field first
    # (the '%Y-%m-%d_%H-%M-%S' format used by main()), so the
    # lexicographically largest name is the newest.
    latest_version = max(candidates, key=lambda path: path.name)
    timestamp = latest_version.name.removeprefix("version_")
    model = PeftModel.from_pretrained(model, latest_version, is_trainable=True)
    return model, timestamp
| #################################### | |
| # MAIN LOGIC | |
| #################################### | |
def _log_eval_metrics(metrics):
    """Log evaluation metrics to the active wandb run.

    ``Trainer.evaluate`` prefixes every metric key with ``eval_`` (e.g.
    ``eval_loss``, ``eval_rewards/chosen``). The prefix is stripped so the
    logged keys read ``loss``, ``rewards/chosen``, and so on. All returned
    metrics are forwarded, rather than a fixed list of names that may be
    absent from evaluation output.
    """
    wandb.log({key.removeprefix("eval_"): value for key, value in metrics.items()})


def _write_metadata(adapter_path, training_timestamp, model_name, language):
    """Write ``metadata.json`` describing the adapter saved in *adapter_path*."""
    import json

    metadata = {
        "training_timestamp": training_timestamp,
        "model_name": model_name,
        "language": language,
    }
    (adapter_path / "metadata.json").write_text(
        json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8"
    )


def _hub_repo_id(model_name, language):
    """Return a valid Hugging Face Hub repo id for this model/language adapter.

    Hub repo ids must have the form ``namespace/name``, with exactly one "/"
    and a name made of letters, digits, "-", "_" and ".". A multi-segment path
    such as ``feel-fl/adapters/<model>/<language>`` is not accepted. The model
    id and language are therefore folded into a single sanitized repo name.
    """
    import re

    name = f"adapters-{model_name.replace('/', '_')}-{language}"
    return "feel-fl/" + re.sub(r"[^A-Za-z0-9._-]", "_", name)


def main():
    """
    Train a KTO LoRA adapter for ``script_args.language`` and save it.

    If an adapter for this model and language already exists, training resumes
    from the newest one. Otherwise a new LoRA adapter is created from
    ``model_args``. After training and evaluation, the adapter is written to
    ``adapters/<model>/<language>/version_<timestamp>`` together with a
    ``metadata.json``. If ``script_args.push_to_hub`` is set, it is also pushed
    to the Hugging Face Hub.

    Raises:
        ValueError: if no adapter exists yet and ``model_args`` yields no LoRA
            config (e.g. ``use_peft`` is False).
    """
    # Initialize wandb for logging
    wandb.init(project="kto")
    # Taken once at the start; it names the saved adapter version.
    training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    print("Loading base model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(model_args)
    print("Model and tokenizer loaded.")

    # Load existing adapter or create new one
    loaded_model, previous_timestamp = load_latest_adapter(
        model,
        model_args.model_name,
        script_args.language,
    )
    if loaded_model is not None:
        model = loaded_model
        print(f"Loaded existing adapter trained at {previous_timestamp}")
    else:
        peft_config = get_peft_config(model_args)
        if peft_config is None:
            raise ValueError(
                "model_args produced no LoRA config (is use_peft enabled?); "
                "cannot create an adapter to train."
            )
        model = get_peft_model(model, peft_config)
        print("Initialized new adapter")

    # -----------------------------
    # Data Preparation and Training
    # -----------------------------
    print("Processing dataset...")
    dataset = script_args.process_dataset_func(script_args.language)
    print("Dataset processed.")

    print("Initializing trainer...")
    # `model` is already a PeftModel in both branches, so neither peft_config
    # nor ref_model is passed:
    # - Given a peft_config, KTOTrainer would merge the existing adapter into
    #   the base weights and attach a fresh one. On the resume path no
    #   peft_config even exists.
    # - Without ref_model, a PEFT model's reference log-probs come from the
    #   same model with the adapter disabled, i.e. the same base weights a
    #   second copy would hold, without loading the model twice.
    trainer = KTOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,
    )

    # Training
    print("Starting training...")
    trainer.train()
    print("Training completed.")

    # Evaluation
    print("Evaluating model...")
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    _log_eval_metrics(metrics)

    # Save the adapter and its metadata side by side.
    adapter_path = get_adapter_path(
        model_args.model_name,
        script_args.language,
        training_timestamp,
    )
    adapter_path.mkdir(parents=True, exist_ok=True)
    print(f"Saving adapter to: {adapter_path}")
    model.save_pretrained(adapter_path)
    _write_metadata(
        adapter_path,
        training_timestamp,
        model_args.model_name,
        script_args.language,
    )

    if script_args.push_to_hub:
        repo_id = _hub_repo_id(model_args.model_name, script_args.language)
        print(f"Pushing adapter to Hugging Face Hub at {repo_id}...")
        model.push_to_hub(repo_id=repo_id)

    print("Process completed.")
    # Finish wandb run
    wandb.finish()


if __name__ == "__main__":
    main()