diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..6aad062a9ead2d88a859379d5932cd3dc78a776b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +llamafactory/extras/__pycache__/constants.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index c4747101b23ead88704837d275618fe368198a4d..6cf72efb35b793371fac42a092678981ab5317a2 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,6 @@ --- -title: App -emoji: 🏢 -colorFrom: blue -colorTo: purple +title: app +app_file: webui.py sdk: gradio -sdk_version: 5.49.1 -app_file: app.py -pinned: false +sdk_version: 5.45.0 --- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/api.py b/api.py new file mode 100644 index 0000000000000000000000000000000000000000..61215459ed91c6fa529a719cb9dac57223754d2e --- /dev/null +++ b/api.py @@ -0,0 +1,33 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import uvicorn + +from llamafactory.api.app import create_app +from llamafactory.chat import ChatModel + + +def main(): + chat_model = ChatModel() + app = create_app(chat_model) + api_host = os.getenv("API_HOST", "0.0.0.0") + api_port = int(os.getenv("API_PORT", "8000")) + print(f"Visit http://localhost:{api_port}/docs for API document.") + uvicorn.run(app, host=api_host, port=api_port) + + +if __name__ == "__main__": + main() diff --git a/llamafactory.egg-info/PKG-INFO b/llamafactory.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..b0a26490c345a22dc3e831d6325c1d3de71902c4 --- /dev/null +++ b/llamafactory.egg-info/PKG-INFO @@ -0,0 +1,1124 @@ +Metadata-Version: 2.4 +Name: llamafactory +Version: 0.9.4.dev0 +Summary: Unified Efficient Fine-Tuning of 100+ LLMs +Home-page: https://github.com/hiyouga/LLaMA-Factory +Author: hiyouga +Author-email: hiyouga@buaa.edu.cn +License: Apache 2.0 License +Keywords: AI,LLM,GPT,ChatGPT,Llama,Transformer,DeepSeek,Pytorch +Classifier: Development Status :: 4 - Beta +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Education +Classifier: Intended Audience :: Science/Research +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Requires-Python: >=3.9.0 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: transformers!=4.52.0,<=4.56.2,>=4.49.0; python_version < "3.10" +Requires-Dist: 
transformers!=4.52.0,!=4.57.0,<=4.57.1,>=4.49.0; python_version >= "3.10" +Requires-Dist: datasets<=4.0.0,>=2.16.0 +Requires-Dist: accelerate<=1.11.0,>=1.3.0 +Requires-Dist: peft<=0.17.1,>=0.14.0 +Requires-Dist: trl<=0.9.6,>=0.8.6 +Requires-Dist: gradio<=5.45.0,>=4.38.0 +Requires-Dist: matplotlib>=3.7.0 +Requires-Dist: tyro<0.9.0 +Requires-Dist: einops +Requires-Dist: numpy<2.0.0 +Requires-Dist: pandas>=2.0.0 +Requires-Dist: scipy +Requires-Dist: sentencepiece +Requires-Dist: tiktoken +Requires-Dist: modelscope>=1.14.0 +Requires-Dist: hf-transfer +Requires-Dist: safetensors<=0.5.3 +Requires-Dist: fire +Requires-Dist: omegaconf +Requires-Dist: packaging +Requires-Dist: protobuf +Requires-Dist: pyyaml +Requires-Dist: pydantic<=2.10.6 +Requires-Dist: uvicorn +Requires-Dist: fastapi +Requires-Dist: sse-starlette +Requires-Dist: av +Requires-Dist: librosa +Requires-Dist: propcache!=0.4.0 +Provides-Extra: torch +Requires-Dist: torch>=2.0.0; extra == "torch" +Requires-Dist: torchvision>=0.15.0; extra == "torch" +Provides-Extra: torch-npu +Requires-Dist: torch==2.7.1; extra == "torch-npu" +Requires-Dist: torch-npu==2.7.1; extra == "torch-npu" +Requires-Dist: torchvision==0.22.1; extra == "torch-npu" +Requires-Dist: decorator; extra == "torch-npu" +Provides-Extra: metrics +Requires-Dist: nltk; extra == "metrics" +Requires-Dist: jieba; extra == "metrics" +Requires-Dist: rouge-chinese; extra == "metrics" +Provides-Extra: deepspeed +Requires-Dist: deepspeed<=0.16.9,>=0.10.0; extra == "deepspeed" +Provides-Extra: liger-kernel +Requires-Dist: liger-kernel>=0.5.5; extra == "liger-kernel" +Provides-Extra: bitsandbytes +Requires-Dist: bitsandbytes>=0.39.0; extra == "bitsandbytes" +Provides-Extra: hqq +Requires-Dist: hqq; extra == "hqq" +Provides-Extra: eetq +Requires-Dist: eetq; extra == "eetq" +Provides-Extra: gptq +Requires-Dist: optimum>=1.24.0; extra == "gptq" +Requires-Dist: gptqmodel>=2.0.0; extra == "gptq" +Provides-Extra: aqlm +Requires-Dist: aqlm[gpu]>=1.1.0; extra == "aqlm" +Provides-Extra: vllm +Requires-Dist: vllm<=0.11.0,>=0.4.3; extra == "vllm" +Provides-Extra: sglang +Requires-Dist: sglang[srt]>=0.4.5; extra == "sglang" +Requires-Dist: transformers==4.51.1; extra == "sglang" +Provides-Extra: galore +Requires-Dist: galore-torch; extra == "galore" +Provides-Extra: apollo +Requires-Dist: apollo-torch; extra == "apollo" +Provides-Extra: badam +Requires-Dist: badam>=1.2.1; extra == "badam" +Provides-Extra: adam-mini +Requires-Dist: adam-mini; extra == "adam-mini" +Provides-Extra: minicpm-v +Requires-Dist: soundfile; extra == "minicpm-v" +Requires-Dist: torchvision; extra == "minicpm-v" +Requires-Dist: torchaudio; extra == "minicpm-v" +Requires-Dist: vector_quantize_pytorch; extra == "minicpm-v" +Requires-Dist: vocos; extra == "minicpm-v" +Requires-Dist: msgpack; extra == "minicpm-v" +Requires-Dist: referencing; extra == "minicpm-v" +Requires-Dist: jsonschema_specifications; extra == "minicpm-v" +Provides-Extra: openmind +Requires-Dist: openmind; extra == "openmind" +Provides-Extra: swanlab +Requires-Dist: swanlab; extra == "swanlab" +Provides-Extra: fp8 +Requires-Dist: torchao>=0.8.0; extra == "fp8" +Requires-Dist: accelerate>=1.10.0; extra == "fp8" +Provides-Extra: fp8-te +Requires-Dist: transformer_engine[pytorch]>=2.0.0; extra == "fp8-te" +Requires-Dist: accelerate>=1.10.0; extra == "fp8-te" +Provides-Extra: fp8-all +Requires-Dist: torchao>=0.8.0; extra == "fp8-all" +Requires-Dist: transformer_engine[pytorch]>=2.0.0; extra == "fp8-all" +Requires-Dist: accelerate>=1.10.0; extra == "fp8-all" 
+Provides-Extra: dev +Requires-Dist: pre-commit; extra == "dev" +Requires-Dist: ruff; extra == "dev" +Requires-Dist: pytest; extra == "dev" +Requires-Dist: build; extra == "dev" +Dynamic: author +Dynamic: author-email +Dynamic: classifier +Dynamic: description +Dynamic: description-content-type +Dynamic: home-page +Dynamic: keywords +Dynamic: license +Dynamic: license-file +Dynamic: provides-extra +Dynamic: requires-dist +Dynamic: requires-python +Dynamic: summary + +![# LLaMA Factory](assets/logo.png) + +[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers) +[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) +[![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors) +[![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml) +[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) +[![Citation](https://img.shields.io/badge/citation-1000+-green)](https://scholar.google.com/scholar?cites=12620864006390196564) +[![Docker Pulls](https://img.shields.io/docker/pulls/hiyouga/llamafactory)](https://hub.docker.com/r/hiyouga/llamafactory/tags) + +[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) +[![Discord](assets/thirdparty/discord.svg)](https://discord.gg/rKfvV9r9FK) +[![WeChat](https://img.shields.io/badge/WeChat-User%20Group-blue?logo=wechat)](https://github.com/hiyouga/llamafactory-community) +[![Blog](https://img.shields.io/badge/Hugo-Official%20Blog-blue?logo=hugo)](https://blog.llamafactory.net/en/) + +[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing) +[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) +[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory) +[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory) +[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board) +[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board) +[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47) + +### Used by [Amazon](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/), [NVIDIA](https://developer.nvidia.com/rtx/ai-toolkit), [Aliyun](https://help.aliyun.com/zh/pai/use-cases/fine-tune-a-llama-3-model-with-llama-factory), etc. + +
+ +### Supporters ❤️ + +|
Warp sponsorship
Warp, the agentic terminal for developers
Available for MacOS, Linux, & Windows | SerpAPI sponsorship | +| ---- | ---- | + +---- + +### Easily fine-tune 100+ large language models with zero-code [CLI](#quickstart) and [Web UI](#fine-tuning-with-llama-board-gui-powered-by-gradio) + +![GitHub Trend](https://trendshift.io/api/badge/repositories/4535) + +
+ +👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group. + +\[ English | [中文](README_zh.md) \] + +**Fine-tuning a large language model can be easy as...** + +https://github.com/user-attachments/assets/3991a3a8-4276-4d30-9cab-4cb0c4b9b99e + +Start local training: +- Please refer to [usage](#getting-started) + +Start cloud training: +- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing +- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory +- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory +- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory + +Read technical notes: +- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/ +- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html +- **Official Blog**: https://blog.llamafactory.net/en/ +- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory + +> [!NOTE] +> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them. + +## Table of Contents + +- [Features](#features) +- [Blogs](#blogs) +- [Changelog](#changelog) +- [Supported Models](#supported-models) +- [Supported Training Approaches](#supported-training-approaches) +- [Provided Datasets](#provided-datasets) +- [Requirement](#requirement) +- [Getting Started](#getting-started) + - [Installation](#installation) + - [Data Preparation](#data-preparation) + - [Quickstart](#quickstart) + - [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio) + - [LLaMA Factory Online](#llama-factory-online) + - [Build Docker](#build-docker) + - [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm) + - [Download from ModelScope Hub](#download-from-modelscope-hub) + - [Download from Modelers Hub](#download-from-modelers-hub) + - [Use W&B Logger](#use-wb-logger) + - [Use SwanLab Logger](#use-swanlab-logger) +- [Projects using LLaMA Factory](#projects-using-llama-factory) +- [License](#license) +- [Citation](#citation) +- [Acknowledgement](#acknowledgement) + +## Features + +- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc. +- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. +- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. +- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA. 
+- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA. +- **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc. +- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc. +- **Faster inference**: OpenAI-style API, Gradio UI and CLI with [vLLM worker](https://github.com/vllm-project/vllm) or [SGLang worker](https://github.com/sgl-project/sglang). + +### Day-N Support for Fine-Tuning Cutting-Edge Models + +| Support Date | Model Name | +| ------------ | -------------------------------------------------------------------- | +| Day 0 | Qwen3 / Qwen2.5-VL / Gemma 3 / GLM-4.1V / InternLM 3 / MiniCPM-o-2.6 | +| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4 | + +## Blogs + +> [!TIP] +> Now we have a dedicated blog for LLaMA Factory! +> +> Website: https://blog.llamafactory.net/en/ + +- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English) +- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese) +- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese) +- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese) +- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English) + +
All Blogs + +- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese) +- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese) +- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese) +- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese) +- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese) +- [LLaMA Factory: Fine-tuning Llama3 for Role-Playing](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) (Chinese) + +
+ +## Changelog + +[25/10/26] We support Megatron-core training backend with [**mcore_adapter**](https://github.com/alibaba/ROLL/tree/main/mcore_adapter). See [PR #9237](https://github.com/hiyouga/LLaMA-Factory/pull/9237) to get started. + +[25/08/22] We supported **[OFT](https://arxiv.org/abs/2306.07280)** and **[OFTv2](https://arxiv.org/abs/2506.19847)**. See [examples](examples/README.md) for usage. + +[25/08/20] We supported fine-tuning the **[Intern-S1-mini](https://huggingface.co/internlm/Intern-S1-mini)** models. See [PR #8976](https://github.com/hiyouga/LLaMA-Factory/pull/8976) to get started. + +[25/08/06] We supported fine-tuning the **[GPT-OSS](https://github.com/openai/gpt-oss)** models. See [PR #8826](https://github.com/hiyouga/LLaMA-Factory/pull/8826) to get started. + +
Full Changelog + +[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model. + +[25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family. + +[25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR. + +[25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started. + +[25/04/14] We supported fine-tuning the **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** and **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** models. + +[25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started. + +[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started. + +[25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference. + +[25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model. + +[25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training. + +[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage. + +[25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks. + +[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** models. + +[25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage. + +[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR. + +[25/01/14] We supported fine-tuning the **[InternLM 3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR. + +[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model. + +[24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details. + +[24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset. + +[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage. + +[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models. 
+ +[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR. + +[24/08/27] We supported **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training. + +[24/08/09] We supported **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR. + +[24/07/04] We supported [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR. + +[24/06/16] We supported **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage. + +[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models. + +[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage. + +[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `paligemma` template for chat completion. + +[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage. + +[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details. + +[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage. + +[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details. + +[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage. + +[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage. + +[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison). + +[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See [examples](examples/README.md) for usage. + +[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv! + +[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See [examples](examples/README.md) for usage. + +[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage. + +[24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. 
See [examples](examples/README.md) for usage. + +[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed. + +[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `use_dora: true` to activate DoRA training. + +[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See [examples](examples/README.md) for usage. + +[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details. + +[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall_en`. + +[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details. + +[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement). + +[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage. + +[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune. + +[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `shift_attn: true` argument to enable shift short attention. + +[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [examples](examples/README.md) for usage. + +[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `flash_attn: fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs. + +[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `rope_scaling: linear` argument in training and `rope_scaling: dynamic` argument at inference to extrapolate the position embeddings. + +[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage. + +[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode. + +[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details. + +[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development. + +[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. 
Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested. + +[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets; see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details. + +[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI](https://platform.openai.com/docs/api-reference/chat) format, so you can plug the fine-tuned model into **arbitrary ChatGPT-based applications**. + +[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). See [examples](examples/README.md) for usage. + +
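+The [23/06/22] entry above refers to the same OpenAI-compatible server exposed by `api.py` in this change. A minimal client sketch, assuming the official `openai` Python package is installed and the server is running locally on the default `API_PORT=8000`; the model id below is only a placeholder:
+
+```python
+# Minimal sketch: query a locally running LLaMA-Factory OpenAI-style API server.
+# Start the server first, e.g. `API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml`.
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")  # any string unless the server enforces a key
+
+# List whatever model id(s) the server reports, then use one of them below.
+print([model.id for model in client.models.list().data])
+
+response = client.chat.completions.create(
+    model="default",  # placeholder: replace with an id printed above
+    messages=[{"role": "user", "content": "Explain LoRA in one sentence."}],
+)
+print(response.choices[0].message.content)
+```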
+ +> [!TIP] +> If you cannot use the latest feature, please pull the latest code and install LLaMA-Factory again. + +## Supported Models + +| Model | Model size | Template | +| ----------------------------------------------------------------- | -------------------------------- | -------------------- | +| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | +| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | +| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | +| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | +| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | +| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 | +| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 | +| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink | +| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | +| [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 | +| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 | +| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n | +| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 | +| [GLM-4.1V](https://huggingface.co/zai-org) | 9B | glm4v | +| [GLM-4.5/GLM-4.5V](https://huggingface.co/zai-org) | 106B/355B | glm4_moe/glm4v_moe | +| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | +| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt | +| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | +| [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 | +| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan | +| [Index](https://huggingface.co/IndexTeam) | 1.9B | index | +| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 | +| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl | +| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 | +| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl | +| [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 | +| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | +| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | +| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 | +| [Llama 4](https://huggingface.co/meta-llama) | 109B/402B | llama4 | +| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama | +| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava | +| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next | +| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video | +| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo | +| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 | +| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v | +| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral | +| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | +| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small | +| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | +| 
[PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma | +| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | +| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi | +| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small | +| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 | +| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral | +| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | +| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink | +| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio | +| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni | +| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni | +| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl | +| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl | +| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder | +| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 | +| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | +| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 | +| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | +| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi | +| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | +| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | + +> [!NOTE] +> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models. +> +> If the model has both reasoning and non-reasoning versions, please use the `_nothink` suffix to distinguish between them. For example, `qwen3` and `qwen3_nothink`. +> +> Remember to use the **SAME** template in training and inference. +> +> \*: You should install the `transformers` from main branch and use `DISABLE_VERSION_CHECK=1` to skip version check. +> +> \*\*: You need to install a specific version of `transformers` to use the corresponding model. + +Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of models we supported. + +You also can add a custom chat template to [template.py](src/llamafactory/data/template.py). 
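+The template notes above are easier to see with a concrete render. The sketch below uses the generic `transformers` chat-template API purely as an illustration; it is not LLaMA-Factory's own template registry, and the model id is just an example of a chat model:
+
+```python
+# Illustration: an "instruct/chat" model ships a chat template, and the prompt format used
+# for training must match the one used at inference time.
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example chat model
+messages = [
+    {"role": "user", "content": "What does the `template` argument control?"},
+    {"role": "assistant", "content": "How conversations are rendered into training and inference prompts."},
+]
+# The rendered string is what a matching `template` setting reproduces during fine-tuning.
+print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False))
+```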
+ +## Supported Training Approaches + +| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA | OFT | QOFT | +| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | +| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | + +> [!TIP] +> The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html). + +## Provided Datasets + +
Pre-training datasets + +- [Wiki Demo (en)](data/wiki_demo.txt) +- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) +- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2) +- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220) +- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) +- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile) +- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B) +- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb) +- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) +- [CCI3-HQ (zh)](https://huggingface.co/datasets/BAAI/CCI3-HQ) +- [CCI3-Data (zh)](https://huggingface.co/datasets/BAAI/CCI3-Data) +- [CCI4.0-M2-Base-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-Base-v1) +- [CCI4.0-M2-CoT-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-CoT-v1) +- [CCI4.0-M2-Extra-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-Extra-v1) +- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack) +- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) + +
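+For a quick look at the plain-text format these corpora use, the bundled `data/wiki_demo.txt` from the list above can be streamed with the `datasets` library. This is only an illustration; LLaMA-Factory loads the corpus itself when `streaming: true` is set in a training config:
+
+```python
+# Peek at the first few lines of the bundled plain-text pre-training demo corpus.
+# Assumes the command is run from the repository root.
+from datasets import load_dataset
+
+stream = load_dataset("text", data_files="data/wiki_demo.txt", split="train", streaming=True)
+for index, sample in enumerate(stream):
+    print(sample["text"][:120])
+    if index >= 2:
+        break
+```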
+ +
Supervised fine-tuning datasets + +- [Identity (en&zh)](data/identity.json) +- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) +- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3) +- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) +- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2) +- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima) +- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) +- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) +- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) +- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) +- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) +- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) +- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) +- [UltraChat (en)](https://github.com/thunlp/UltraChat) +- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus) +- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) +- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) +- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca) +- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca) +- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) +- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) +- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa) +- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) +- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) +- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) +- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data) +- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen) +- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k) +- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4) +- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) +- [Infinity Instruct (zh)](https://huggingface.co/datasets/BAAI/Infinity-Instruct) +- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct) +- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) +- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k) +- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia) +- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction) +- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo) +- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2) +- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered) +- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1) +- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub) +- [OpenO1-SFT (en&zh)](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT) +- [Open-Thoughts (en)](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) +- [Open-R1-Math 
(en)](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k) +- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT) +- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k) +- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions) +- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de) +- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de) +- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de) +- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de) +- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de) +- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de) +- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de) +- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de) +- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de) + +
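+Most of the supervised datasets above follow the alpaca- or sharegpt-style schemas described in [data/README.md](data/README.md). A minimal sketch of writing a custom alpaca-style file; the field names mirror the bundled demo data, so double-check data/README.md before relying on them:
+
+```python
+# Sketch: create a tiny custom SFT dataset in the alpaca-style layout used by the demo files.
+import json
+
+records = [
+    {
+        "instruction": "Summarize the following sentence.",
+        "input": "LLaMA Factory unifies efficient fine-tuning of 100+ language models.",
+        "output": "A unified toolkit for efficient LLM fine-tuning.",
+    },
+]
+with open("data/my_sft.json", "w", encoding="utf-8") as f:
+    json.dump(records, f, ensure_ascii=False, indent=2)
+# Register the new file in data/dataset_info.json before training (see data/README.md).
+```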
+ +
Preference datasets + +- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k) +- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) +- [COIG-P (zh)](https://huggingface.co/datasets/m-a-p/COIG-P) +- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset) +- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback) +- [RLAIF-V (en)](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) +- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs) +- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) +- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) +- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de) +- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k) + +
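+Preference data pairs one prompt with a preferred and a rejected response. A rough sketch of such a record follows; the field names are illustrative only, since each format (alpaca/sharegpt) has its own schema in [data/README.md](data/README.md):
+
+```python
+# Sketch only: the general shape of a preference pair used for DPO/KTO-style training.
+import json
+
+pair = {
+    "prompt": "Write a polite decline to a meeting invitation.",  # illustrative field names
+    "chosen": "Thank you for the invitation, but I am unable to attend.",
+    "rejected": "No.",
+}
+with open("data/my_preference.json", "w", encoding="utf-8") as f:
+    json.dump([pair], f, ensure_ascii=False, indent=2)
+# The exact keys LLaMA-Factory expects depend on the dataset format declared in dataset_info.json.
+```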
+ +Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands. + +```bash +pip install "huggingface_hub<1.0.0" +huggingface-cli login +``` + +## Requirement + +| Mandatory | Minimum | Recommend | +| ------------ | ------- | --------- | +| python | 3.9 | 3.10 | +| torch | 2.0.0 | 2.6.0 | +| torchvision | 0.15.0 | 0.21.0 | +| transformers | 4.49.0 | 4.50.0 | +| datasets | 2.16.0 | 3.2.0 | +| accelerate | 0.34.0 | 1.2.1 | +| peft | 0.14.0 | 0.15.1 | +| trl | 0.8.6 | 0.9.6 | + +| Optional | Minimum | Recommend | +| ------------ | ------- | --------- | +| CUDA | 11.6 | 12.2 | +| deepspeed | 0.10.0 | 0.16.4 | +| bitsandbytes | 0.39.0 | 0.43.1 | +| vllm | 0.4.3 | 0.8.2 | +| flash-attn | 2.5.6 | 2.7.2 | + +### Hardware Requirement + +\* *estimated* + +| Method | Bits | 7B | 14B | 30B | 70B | `x`B | +| ----------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- | +| Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB | +| Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB | +| Freeze/LoRA/GaLore/APOLLO/BAdam/OFT | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB | +| QLoRA / QOFT | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB | +| QLoRA / QOFT | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB | +| QLoRA / QOFT | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB | + +## Getting Started + +### Installation + +> [!IMPORTANT] +> Installation is mandatory. + +#### Install from Source + +```bash +git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git +cd LLaMA-Factory +pip install -e ".[torch,metrics]" --no-build-isolation +``` + +Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev + +#### Install from Docker Image + +```bash +docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest +``` + +This image is built on Ubuntu 22.04 (x86\_64), CUDA 12.4, Python 3.11, PyTorch 2.6.0, and Flash-attn 2.7.4. + +Find the pre-built images: https://hub.docker.com/r/hiyouga/llamafactory/tags + +Please refer to [build docker](#build-docker) to build the image yourself. + +
Setting up a virtual environment with uv + +Create an isolated Python environment with [uv](https://github.com/astral-sh/uv): + +```bash +uv sync --extra torch --extra metrics --prerelease=allow +``` + +Run LLaMA-Factory in the isolated environment: + +```bash +uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml +``` + +
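+Looking back at the hardware requirement table, the `x`B column gives a quick rule of thumb for GPU memory. A rough sketch of that rule; the coefficients are read straight off the table, and actual usage also depends on sequence length, batch size and optimizer state:
+
+```python
+# Ballpark GPU memory (GB) per the `x`B column of the hardware requirement table above.
+GB_PER_BILLION_PARAMS = {
+    "full_amp": 18.0,        # Full (bf16 or fp16)
+    "full_pure_bf16": 8.0,   # Full (pure_bf16)
+    "freeze_lora_oft": 2.0,  # Freeze / LoRA / GaLore / APOLLO / BAdam / OFT
+    "qlora_8bit": 1.0,
+    "qlora_4bit": 0.5,
+    "qlora_2bit": 0.25,
+}
+
+def estimated_gb(params_in_billions: float, method: str) -> float:
+    """Return a rough memory estimate in GB for the given method."""
+    return GB_PER_BILLION_PARAMS[method] * params_in_billions
+
+print(estimated_gb(8, "qlora_4bit"))  # ~4 GB for an 8B model under the x/2 rule
+```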
+ +
For Windows users + +#### Install PyTorch + +You need to manually install the GPU version of PyTorch on the Windows platform. Please refer to the [official website](https://pytorch.org/get-started/locally/) and the following command to install PyTorch with CUDA support: + +```bash +pip uninstall torch torchvision torchaudio +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 +python -c "import torch; print(torch.cuda.is_available())" +``` + +If you see `True`, you have successfully installed PyTorch with CUDA support. + +Try `dataloader_num_workers: 0` if you encounter a `Can't pickle local object` error. + +#### Install BitsAndBytes + +To enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version. + +```bash +pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl +``` + +#### Install Flash Attention-2 + +To enable FlashAttention-2 on the Windows platform, please use the script from [flash-attention-windows-wheel](https://huggingface.co/lldacing/flash-attention-windows-wheel) to compile and install it yourself. + +
+ +
For Ascend NPU users + +To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands: + +```bash +# replace the url according to your CANN version and devices +# install CANN Toolkit +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run +bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install + +# install CANN Kernels +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run +bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install + +# set env variables +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +| Requirement | Minimum | Recommend | +| ------------ | ------- | -------------- | +| CANN | 8.0.RC1 | 8.0.0.alpha002 | +| torch | 2.1.0 | 2.4.0 | +| torch-npu | 2.1.0 | 2.4.0.post2 | +| deepspeed | 0.13.2 | 0.13.2 | +| vllm-ascend | - | 0.7.3 | + +Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use. + +If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations. + +Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html) + +#### Install BitsAndBytes + +To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps: + +1. Manually compile bitsandbytes: Refer to [the installation documentation](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU) for the NPU version of bitsandbytes to complete the compilation and installation. The compilation requires a cmake version of at least 3.22.1 and a g++ version of at least 12.x. + +```bash +# Install bitsandbytes from source +# Clone bitsandbytes repo, Ascend NPU backend is currently enabled on multi-backend-refactor branch +git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git +cd bitsandbytes/ + +# Install dependencies +pip install -r requirements-dev.txt + +# Install the dependencies for the compilation tools. Note that the commands for this step may vary depending on the operating system. The following are provided for reference +apt-get install -y build-essential cmake + +# Compile & install +cmake -DCOMPUTE_BACKEND=npu -S . +make +pip install . +``` + +2. Install transformers from the main branch. + +```bash +git clone -b main https://github.com/huggingface/transformers.git +cd transformers +pip install . +``` + +3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml). + +
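+Whichever installation path you use, a quick check of the core packages against the requirement table above can save a failed run later; a minimal sketch:
+
+```python
+# Print the versions of the packages listed in the requirement table, plus CUDA availability.
+import importlib.metadata as metadata
+
+for package in ("torch", "transformers", "datasets", "accelerate", "peft", "trl"):
+    try:
+        print(f"{package:>12}: {metadata.version(package)}")
+    except metadata.PackageNotFoundError:
+        print(f"{package:>12}: not installed")
+
+import torch
+
+print("CUDA available:", torch.cuda.is_available())
+```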
+ +### Data Preparation + +Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can use datasets on HuggingFace / ModelScope / Modelers hub, load the dataset in local disk, or specify a path to s3/gcs cloud storage. + +> [!NOTE] +> Please update `data/dataset_info.json` to use your custom dataset. + +You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**, **[DataFlow](https://github.com/OpenDCAI/DataFlow)** and **[GraphGen](https://github.com/open-sciencelab/GraphGen)** to create synthetic data for fine-tuning. + +### Quickstart + +Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively. + +```bash +llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml +llamafactory-cli chat examples/inference/llama3_lora_sft.yaml +llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml +``` + +See [examples/README.md](examples/README.md) for advanced usage (including distributed training). + +> [!TIP] +> Use `llamafactory-cli help` to show help information. +> +> Read [FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614) first if you encounter any problems. + +### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio)) + +```bash +llamafactory-cli webui +``` + +### LLaMA Factory Online + +Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory). + +### Build Docker + +For CUDA users: + +```bash +cd docker/docker-cuda/ +docker compose up -d +docker compose exec llamafactory bash +``` + +For Ascend NPU users: + +```bash +cd docker/docker-npu/ +docker compose up -d +docker compose exec llamafactory bash +``` + +For AMD ROCm users: + +```bash +cd docker/docker-rocm/ +docker compose up -d +docker compose exec llamafactory bash +``` + +
Build without Docker Compose + +For CUDA users: + +```bash +docker build -f ./docker/docker-cuda/Dockerfile \ + --build-arg PIP_INDEX=https://pypi.org/simple \ + --build-arg EXTRAS=metrics \ + -t llamafactory:latest . + +docker run -dit --ipc=host --gpus=all \ + -p 7860:7860 \ + -p 8000:8000 \ + --name llamafactory \ + llamafactory:latest + +docker exec -it llamafactory bash +``` + +For Ascend NPU users: + +```bash +docker build -f ./docker/docker-npu/Dockerfile \ + --build-arg PIP_INDEX=https://pypi.org/simple \ + --build-arg EXTRAS=torch-npu,metrics \ + -t llamafactory:latest . + +docker run -dit --ipc=host \ + -v /usr/local/dcmi:/usr/local/dcmi \ + -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v /etc/ascend_install.info:/etc/ascend_install.info \ + -p 7860:7860 \ + -p 8000:8000 \ + --device /dev/davinci0 \ + --device /dev/davinci_manager \ + --device /dev/devmm_svm \ + --device /dev/hisi_hdc \ + --name llamafactory \ + llamafactory:latest + +docker exec -it llamafactory bash +``` + +For AMD ROCm users: + +```bash +docker build -f ./docker/docker-rocm/Dockerfile \ + --build-arg PIP_INDEX=https://pypi.org/simple \ + --build-arg EXTRAS=metrics \ + -t llamafactory:latest . + +docker run -dit --ipc=host \ + -p 7860:7860 \ + -p 8000:8000 \ + --device /dev/kfd \ + --device /dev/dri \ + --name llamafactory \ + llamafactory:latest + +docker exec -it llamafactory bash +``` + +
+ +
Use Docker volumes + +You can uncomment `VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]` in the Dockerfile to use data volumes. + +When starting the container, use the `-v ./hf_cache:/root/.cache/huggingface` argument of `docker run` to mount a local directory into the container. The following data volumes are available. + +- `hf_cache`: Utilize the Hugging Face cache on the host machine. +- `shared_data`: The directory to store datasets on the host machine. +- `output`: Set the export directory to this location so that the merged result can be accessed directly on the host machine. + +
+ +### Deploy with OpenAI-style API and vLLM + +```bash +API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true +``` + +> [!TIP] +> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document. +> +> Examples: [Image understanding](scripts/api_example/test_image.py) | [Function calling](scripts/api_example/test_toolcall.py) + +### Download from ModelScope Hub + +If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope. + +```bash +export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows +``` + +Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`. + +### Download from Modelers Hub + +You can also use Modelers Hub to download models and datasets. + +```bash +export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows +``` + +Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`. + +### Use W&B Logger + +To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files. + +```yaml +report_to: wandb +run_name: test_run # optional +``` + +Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account. + +### Use SwanLab Logger + +To use [SwanLab](https://github.com/SwanHubX/SwanLab) for logging experimental results, you need to add the following arguments to yaml files. + +```yaml +use_swanlab: true +swanlab_run_name: test_run # optional +``` + +When launching training tasks, you can log in to SwanLab in three ways: + +1. Add `swanlab_api_key=` to the yaml file, and set it to your [API key](https://swanlab.cn/settings). +2. Set the environment variable `SWANLAB_API_KEY` to your [API key](https://swanlab.cn/settings). +3. Use the `swanlab login` command to complete the login. + +## Projects using LLaMA Factory + +If you have a project that should be incorporated, please contact via email or create a pull request. + +
Click to show + +1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223) +1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092) +1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526) +1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816) +1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710) +1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319) +1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286) +1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904) +1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625) +1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176) +1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187) +1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746) +1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801) +1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809) +1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819) +1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204) +1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714) +1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043) +1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333) +1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419) +1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228) +1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073) +1. Zhang et al. 
EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541) +1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246) +1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008) +1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443) +1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604) +1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827) +1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167) +1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316) +1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084) +1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836) +1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581) +1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215) +1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621) +1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140) +1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585) +1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760) +1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378) +1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055) +1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739) +1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816) +1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215) +1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30) +1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380) +1. Wang and Song. 
MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106) +1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136) +1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496) +1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688) +1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955) +1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973) +1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115) +1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815) +1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099) +1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173) +1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074) +1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408) +1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546) +1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695) +1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233) +1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069) +1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25) +1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949) +1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365) +1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470) +1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129) +1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044) +1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756) +1. Inouye et al. 
Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/) +1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561) +1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637) +1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535) +1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705) +1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137) +1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf) +1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11) +1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23) +1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693) +1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168) +1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/) +1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072) +1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611) +1. Zhang et al. CPsyCoun: A Report-based Multi-turn Dialogue Reconstruction and Evaluation Framework for Chinese Psychological Counseling. ACL 2024. [[paper]](https://aclanthology.org/2024.findings-acl.830.pdf) +1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B. +1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge. +1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B. +1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B. +1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods. +1. 
**[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt) +1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B. +1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models. +1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX. +1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory. +1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357) +1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention. +1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: An o1-like model fine-tuned by NovaSky AI with very small cost. +1. **[WeClone](https://github.com/xming521/WeClone)**: One-stop solution for creating your digital avatar from chat logs. +1. **[EmoLLM](https://github.com/SmartFlowAI/EmoLLM)**: A project about large language models (LLMs) and mental health. +
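Returning to the OpenAI-style API section above: the server started with `llamafactory-cli api` serves the OpenAI-compatible `/v1/chat/completions` route implemented in `llamafactory/api/app.py` (included later in this diff). The sketch below is one way to query it with `curl`; the `model` field is required by the request schema but is only echoed back in the response, and the `Authorization` header only matters if the server was started with the `API_KEY` environment variable set:

```bash
# assumes the server from the "Deploy with OpenAI-style API and vLLM" section is listening on port 8000
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer ${API_KEY:-dummy}" \
    -d '{
        "model": "llama3",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": false
    }'
```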
+ +## License + +This repository is licensed under the [Apache-2.0 License](LICENSE). + +Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Llama 4](https://github.com/meta-llama/llama-models/blob/main/models/llama4/LICENSE) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) + +## Citation + +If this work is helpful, please kindly cite as: + +```bibtex +@inproceedings{zheng2024llamafactory, + title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models}, + author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma}, + booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)}, + address={Bangkok, Thailand}, + publisher={Association for Computational Linguistics}, + year={2024}, + url={http://arxiv.org/abs/2403.13372} +} +``` + +## Acknowledgement + +This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works. 
+ +## Star History + +![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date) diff --git a/llamafactory.egg-info/SOURCES.txt b/llamafactory.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..fb4a46244940cfc2d52cfdfa74064bcb07bf2aa8 --- /dev/null +++ b/llamafactory.egg-info/SOURCES.txt @@ -0,0 +1,178 @@ +LICENSE +MANIFEST.in +README.md +pyproject.toml +requirements.txt +setup.py +src/llamafactory/__init__.py +src/llamafactory/cli.py +src/llamafactory/launcher.py +src/llamafactory.egg-info/PKG-INFO +src/llamafactory.egg-info/SOURCES.txt +src/llamafactory.egg-info/dependency_links.txt +src/llamafactory.egg-info/entry_points.txt +src/llamafactory.egg-info/requires.txt +src/llamafactory.egg-info/top_level.txt +src/llamafactory/api/__init__.py +src/llamafactory/api/app.py +src/llamafactory/api/chat.py +src/llamafactory/api/common.py +src/llamafactory/api/protocol.py +src/llamafactory/chat/__init__.py +src/llamafactory/chat/base_engine.py +src/llamafactory/chat/chat_model.py +src/llamafactory/chat/hf_engine.py +src/llamafactory/chat/kt_engine.py +src/llamafactory/chat/sglang_engine.py +src/llamafactory/chat/vllm_engine.py +src/llamafactory/data/__init__.py +src/llamafactory/data/collator.py +src/llamafactory/data/converter.py +src/llamafactory/data/data_utils.py +src/llamafactory/data/formatter.py +src/llamafactory/data/loader.py +src/llamafactory/data/mm_plugin.py +src/llamafactory/data/parser.py +src/llamafactory/data/template.py +src/llamafactory/data/tool_utils.py +src/llamafactory/data/processor/__init__.py +src/llamafactory/data/processor/feedback.py +src/llamafactory/data/processor/pairwise.py +src/llamafactory/data/processor/pretrain.py +src/llamafactory/data/processor/processor_utils.py +src/llamafactory/data/processor/supervised.py +src/llamafactory/data/processor/unsupervised.py +src/llamafactory/eval/__init__.py +src/llamafactory/eval/evaluator.py +src/llamafactory/eval/template.py +src/llamafactory/extras/__init__.py +src/llamafactory/extras/constants.py +src/llamafactory/extras/env.py +src/llamafactory/extras/logging.py +src/llamafactory/extras/misc.py +src/llamafactory/extras/packages.py +src/llamafactory/extras/ploting.py +src/llamafactory/hparams/__init__.py +src/llamafactory/hparams/data_args.py +src/llamafactory/hparams/evaluation_args.py +src/llamafactory/hparams/finetuning_args.py +src/llamafactory/hparams/generating_args.py +src/llamafactory/hparams/model_args.py +src/llamafactory/hparams/parser.py +src/llamafactory/hparams/training_args.py +src/llamafactory/model/__init__.py +src/llamafactory/model/adapter.py +src/llamafactory/model/loader.py +src/llamafactory/model/patcher.py +src/llamafactory/model/model_utils/__init__.py +src/llamafactory/model/model_utils/attention.py +src/llamafactory/model/model_utils/checkpointing.py +src/llamafactory/model/model_utils/embedding.py +src/llamafactory/model/model_utils/ktransformers.py +src/llamafactory/model/model_utils/kv_cache.py +src/llamafactory/model/model_utils/liger_kernel.py +src/llamafactory/model/model_utils/longlora.py +src/llamafactory/model/model_utils/misc.py +src/llamafactory/model/model_utils/mod.py +src/llamafactory/model/model_utils/moe.py +src/llamafactory/model/model_utils/packing.py +src/llamafactory/model/model_utils/quantization.py +src/llamafactory/model/model_utils/rope.py +src/llamafactory/model/model_utils/unsloth.py +src/llamafactory/model/model_utils/valuehead.py +src/llamafactory/model/model_utils/visual.py 
+src/llamafactory/third_party/__init__.py +src/llamafactory/third_party/muon/__init__.py +src/llamafactory/third_party/muon/muon.py +src/llamafactory/train/__init__.py +src/llamafactory/train/callbacks.py +src/llamafactory/train/fp8_utils.py +src/llamafactory/train/test_utils.py +src/llamafactory/train/trainer_utils.py +src/llamafactory/train/tuner.py +src/llamafactory/train/dpo/__init__.py +src/llamafactory/train/dpo/trainer.py +src/llamafactory/train/dpo/workflow.py +src/llamafactory/train/ksft/__init__.py +src/llamafactory/train/ksft/workflow.py +src/llamafactory/train/kto/__init__.py +src/llamafactory/train/kto/trainer.py +src/llamafactory/train/kto/workflow.py +src/llamafactory/train/mca/__init__.py +src/llamafactory/train/mca/trainer.py +src/llamafactory/train/mca/workflow.py +src/llamafactory/train/ppo/__init__.py +src/llamafactory/train/ppo/ppo_utils.py +src/llamafactory/train/ppo/trainer.py +src/llamafactory/train/ppo/workflow.py +src/llamafactory/train/pt/__init__.py +src/llamafactory/train/pt/trainer.py +src/llamafactory/train/pt/workflow.py +src/llamafactory/train/rm/__init__.py +src/llamafactory/train/rm/metric.py +src/llamafactory/train/rm/trainer.py +src/llamafactory/train/rm/workflow.py +src/llamafactory/train/sft/__init__.py +src/llamafactory/train/sft/metric.py +src/llamafactory/train/sft/trainer.py +src/llamafactory/train/sft/workflow.py +src/llamafactory/v1/__init__.py +src/llamafactory/v1/launcher.py +src/llamafactory/v1/config/__init__.py +src/llamafactory/v1/config/data_args.py +src/llamafactory/v1/config/model_args.py +src/llamafactory/v1/config/parser.py +src/llamafactory/v1/config/sample_args.py +src/llamafactory/v1/config/training_args.py +src/llamafactory/v1/core/__init__.py +src/llamafactory/v1/core/base_trainer.py +src/llamafactory/v1/core/chat_sampler.py +src/llamafactory/v1/core/data_engine.py +src/llamafactory/v1/core/model_engine.py +src/llamafactory/v1/plugins/__init__.py +src/llamafactory/v1/plugins/data_plugins/__init__.py +src/llamafactory/v1/plugins/data_plugins/converter.py +src/llamafactory/v1/plugins/data_plugins/loader.py +src/llamafactory/v1/plugins/data_plugins/template.py +src/llamafactory/v1/plugins/model_plugins/__init__.py +src/llamafactory/v1/plugins/model_plugins/added_token.py +src/llamafactory/v1/plugins/model_plugins/peft.py +src/llamafactory/v1/plugins/model_plugins/kernels/__init__.py +src/llamafactory/v1/plugins/model_plugins/kernels/constants.py +src/llamafactory/v1/plugins/model_plugins/kernels/registry.py +src/llamafactory/v1/plugins/model_plugins/kernels/fa/__init__.py +src/llamafactory/v1/plugins/model_plugins/kernels/mlp/__init__.py +src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py +src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py +src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/__init__.py +src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py +src/llamafactory/v1/plugins/model_plugins/kernels/rope/__init__.py +src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py +src/llamafactory/v1/plugins/sampler_plugins/__init__.py +src/llamafactory/v1/plugins/sampler_plugins/vllm.py +src/llamafactory/v1/plugins/trainer_plugins/__init__.py +src/llamafactory/v1/plugins/trainer_plugins/distributed/__init__.py +src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py +src/llamafactory/v1/trainers/__init__.py +src/llamafactory/v1/trainers/dpo_trainer.py +src/llamafactory/v1/trainers/rm_trainer.py +src/llamafactory/v1/trainers/sft_trainer.py 
+src/llamafactory/webui/__init__.py +src/llamafactory/webui/chatter.py +src/llamafactory/webui/common.py +src/llamafactory/webui/control.py +src/llamafactory/webui/css.py +src/llamafactory/webui/engine.py +src/llamafactory/webui/interface.py +src/llamafactory/webui/locales.py +src/llamafactory/webui/manager.py +src/llamafactory/webui/runner.py +src/llamafactory/webui/components/__init__.py +src/llamafactory/webui/components/chatbot.py +src/llamafactory/webui/components/data.py +src/llamafactory/webui/components/eval.py +src/llamafactory/webui/components/export.py +src/llamafactory/webui/components/footer.py +src/llamafactory/webui/components/infer.py +src/llamafactory/webui/components/top.py +src/llamafactory/webui/components/train.py \ No newline at end of file diff --git a/llamafactory.egg-info/dependency_links.txt b/llamafactory.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/llamafactory.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/llamafactory.egg-info/entry_points.txt b/llamafactory.egg-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..bcafd0771096e3e65f55808b2ead2909c6f6691b --- /dev/null +++ b/llamafactory.egg-info/entry_points.txt @@ -0,0 +1,3 @@ +[console_scripts] +llamafactory-cli = llamafactory.cli:main +lmf = llamafactory.cli:main diff --git a/llamafactory.egg-info/requires.txt b/llamafactory.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..2fb93071d9a6df49004f3c31adc7c4e77bee4cd8 --- /dev/null +++ b/llamafactory.egg-info/requires.txt @@ -0,0 +1,125 @@ +datasets<=4.0.0,>=2.16.0 +accelerate<=1.11.0,>=1.3.0 +peft<=0.17.1,>=0.14.0 +trl<=0.9.6,>=0.8.6 +gradio<=5.45.0,>=4.38.0 +matplotlib>=3.7.0 +tyro<0.9.0 +einops +numpy<2.0.0 +pandas>=2.0.0 +scipy +sentencepiece +tiktoken +modelscope>=1.14.0 +hf-transfer +safetensors<=0.5.3 +fire +omegaconf +packaging +protobuf +pyyaml +pydantic<=2.10.6 +uvicorn +fastapi +sse-starlette +av +librosa +propcache!=0.4.0 + +[:python_version < "3.10"] +transformers!=4.52.0,<=4.56.2,>=4.49.0 + +[:python_version >= "3.10"] +transformers!=4.52.0,!=4.57.0,<=4.57.1,>=4.49.0 + +[adam-mini] +adam-mini + +[apollo] +apollo-torch + +[aqlm] +aqlm[gpu]>=1.1.0 + +[badam] +badam>=1.2.1 + +[bitsandbytes] +bitsandbytes>=0.39.0 + +[deepspeed] +deepspeed<=0.16.9,>=0.10.0 + +[dev] +pre-commit +ruff +pytest +build + +[eetq] +eetq + +[fp8] +torchao>=0.8.0 +accelerate>=1.10.0 + +[fp8-all] +torchao>=0.8.0 +transformer_engine[pytorch]>=2.0.0 +accelerate>=1.10.0 + +[fp8-te] +transformer_engine[pytorch]>=2.0.0 +accelerate>=1.10.0 + +[galore] +galore-torch + +[gptq] +optimum>=1.24.0 +gptqmodel>=2.0.0 + +[hqq] +hqq + +[liger-kernel] +liger-kernel>=0.5.5 + +[metrics] +nltk +jieba +rouge-chinese + +[minicpm_v] +soundfile +torchvision +torchaudio +vector_quantize_pytorch +vocos +msgpack +referencing +jsonschema_specifications + +[openmind] +openmind + +[sglang] +sglang[srt]>=0.4.5 +transformers==4.51.1 + +[swanlab] +swanlab + +[torch] +torch>=2.0.0 +torchvision>=0.15.0 + +[torch-npu] +torch==2.7.1 +torch-npu==2.7.1 +torchvision==0.22.1 +decorator + +[vllm] +vllm<=0.11.0,>=0.4.3 diff --git a/llamafactory.egg-info/top_level.txt b/llamafactory.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..d6670a28d28bf8b6497d5cca0e63d1e13c1aee55 --- /dev/null +++ b/llamafactory.egg-info/top_level.txt @@ -0,0 +1 @@ +llamafactory diff --git 
a/llamafactory/__init__.py b/llamafactory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1567ef572714881cc464db25d3da3d08a460963 --- /dev/null +++ b/llamafactory/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Efficient fine-tuning of large language models. + +Level: + api, webui > chat, eval, train > data, model > hparams > extras + +Disable version checking: DISABLE_VERSION_CHECK=1 +Enable VRAM recording: RECORD_VRAM=1 +Force using torchrun: FORCE_TORCHRUN=1 +Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN +Use modelscope: USE_MODELSCOPE_HUB=1 +Use openmind: USE_OPENMIND_HUB=1 +""" + +from .extras.env import VERSION + + +__version__ = VERSION diff --git a/llamafactory/__pycache__/__init__.cpython-312.pyc b/llamafactory/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae97776ebf1e10b36d5f470d1c0de9c4f5b17797 Binary files /dev/null and b/llamafactory/__pycache__/__init__.cpython-312.pyc differ diff --git a/llamafactory/__pycache__/cli.cpython-312.pyc b/llamafactory/__pycache__/cli.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95db2f247b7b2fe6595a5f7cc774ce0f6f472c11 Binary files /dev/null and b/llamafactory/__pycache__/cli.cpython-312.pyc differ diff --git a/llamafactory/__pycache__/launcher.cpython-312.pyc b/llamafactory/__pycache__/launcher.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28aab819e043e9a2500c50c92f460fa7d87533cd Binary files /dev/null and b/llamafactory/__pycache__/launcher.cpython-312.pyc differ diff --git a/llamafactory/api/__init__.py b/llamafactory/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llamafactory/api/app.py b/llamafactory/api/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e0621d80b064f00970e8fb58909ec8656ba0fb6b --- /dev/null +++ b/llamafactory/api/app.py @@ -0,0 +1,133 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import os +from contextlib import asynccontextmanager +from functools import partial +from typing import Annotated, Optional + +from ..chat import ChatModel +from ..extras.constants import EngineName +from ..extras.misc import torch_gc +from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available +from .chat import ( + create_chat_completion_response, + create_score_evaluation_response, + create_stream_chat_completion_response, +) +from .protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ModelCard, + ModelList, + ScoreEvaluationRequest, + ScoreEvaluationResponse, +) + + +if is_fastapi_available(): + from fastapi import Depends, FastAPI, HTTPException, status + from fastapi.middleware.cors import CORSMiddleware + from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + + +if is_starlette_available(): + from sse_starlette import EventSourceResponse + + +if is_uvicorn_available(): + import uvicorn + + +async def sweeper() -> None: + while True: + torch_gc() + await asyncio.sleep(300) + + +@asynccontextmanager +async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory + if chat_model.engine.name == EngineName.HF: + asyncio.create_task(sweeper()) + + yield + torch_gc() + + +def create_app(chat_model: "ChatModel") -> "FastAPI": + root_path = os.getenv("FASTAPI_ROOT_PATH", "") + app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + api_key = os.getenv("API_KEY") + security = HTTPBearer(auto_error=False) + + async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): + if api_key and (auth is None or auth.credentials != api_key): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.") + + @app.get( + "/v1/models", + response_model=ModelList, + status_code=status.HTTP_200_OK, + dependencies=[Depends(verify_api_key)], + ) + async def list_models(): + model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo")) + return ModelList(data=[model_card]) + + @app.post( + "/v1/chat/completions", + response_model=ChatCompletionResponse, + status_code=status.HTTP_200_OK, + dependencies=[Depends(verify_api_key)], + ) + async def create_chat_completion(request: ChatCompletionRequest): + if not chat_model.engine.can_generate: + raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") + + if request.stream: + generate = create_stream_chat_completion_response(request, chat_model) + return EventSourceResponse(generate, media_type="text/event-stream", sep="\n") + else: + return await create_chat_completion_response(request, chat_model) + + @app.post( + "/v1/score/evaluation", + response_model=ScoreEvaluationResponse, + status_code=status.HTTP_200_OK, + dependencies=[Depends(verify_api_key)], + ) + async def create_score_evaluation(request: ScoreEvaluationRequest): + if chat_model.engine.can_generate: + raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") + + return await create_score_evaluation_response(request, chat_model) + + return app + + +def run_api() -> None: + chat_model = ChatModel() + app = create_app(chat_model) + api_host = os.getenv("API_HOST", "0.0.0.0") + api_port = int(os.getenv("API_PORT", "8000")) + print(f"Visit http://localhost:{api_port}/docs for API 
document.") + uvicorn.run(app, host=api_host, port=api_port) diff --git a/llamafactory/api/chat.py b/llamafactory/api/chat.py new file mode 100644 index 0000000000000000000000000000000000000000..93236c5ca865492f0c45e1f5ab56a389875350ea --- /dev/null +++ b/llamafactory/api/chat.py @@ -0,0 +1,291 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import io +import json +import os +import re +import uuid +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Optional + +from ..data import Role as DataRole +from ..extras import logging +from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER +from ..extras.misc import is_env_enabled +from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available +from .common import check_lfi_path, check_ssrf_url, dictify, jsonify +from .protocol import ( + ChatCompletionMessage, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseUsage, + ChatCompletionStreamResponse, + ChatCompletionStreamResponseChoice, + Finish, + Function, + FunctionCall, + Role, + ScoreEvaluationResponse, +) + + +if is_fastapi_available(): + from fastapi import HTTPException, status + + +if is_pillow_available(): + from PIL import Image + + +if is_requests_available(): + import requests + + +if TYPE_CHECKING: + from ..chat import ChatModel + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from .protocol import ChatCompletionRequest, ScoreEvaluationRequest + + +logger = logging.get_logger(__name__) +ROLE_MAPPING = { + Role.USER: DataRole.USER.value, + Role.ASSISTANT: DataRole.ASSISTANT.value, + Role.SYSTEM: DataRole.SYSTEM.value, + Role.FUNCTION: DataRole.FUNCTION.value, + Role.TOOL: DataRole.OBSERVATION.value, +} + + +def _process_request( + request: "ChatCompletionRequest", +) -> tuple[ + list[dict[str, str]], + Optional[str], + Optional[str], + Optional[list["ImageInput"]], + Optional[list["VideoInput"]], + Optional[list["AudioInput"]], +]: + if is_env_enabled("API_VERBOSE", "1"): + logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}") + + if len(request.messages) == 0: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length") + + if request.messages[0].role == Role.SYSTEM: + content = request.messages.pop(0).content + system = content[0].text if isinstance(content, list) else content + else: + system = None + + if len(request.messages) % 2 == 0: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") + + input_messages = [] + images, videos, audios = [], [], [] + for i, message in enumerate(request.messages): + if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") + elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]: + raise 
HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") + + if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls): + tool_calls = [ + {"name": tool_call.function.name, "arguments": tool_call.function.arguments} + for tool_call in message.tool_calls + ] + content = json.dumps(tool_calls, ensure_ascii=False) + input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content}) + elif isinstance(message.content, list): + text_content = "" + for input_item in message.content: + if input_item.type == "text": + text_content += input_item.text + elif input_item.type == "image_url": + text_content += IMAGE_PLACEHOLDER + image_url = input_item.image_url.url + if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image + image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1])) + elif os.path.isfile(image_url): # local file + check_lfi_path(image_url) + image_stream = open(image_url, "rb") + else: # web uri + check_ssrf_url(image_url) + image_stream = requests.get(image_url, stream=True).raw + + images.append(Image.open(image_stream).convert("RGB")) + elif input_item.type == "video_url": + text_content += VIDEO_PLACEHOLDER + video_url = input_item.video_url.url + if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video + video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1])) + elif os.path.isfile(video_url): # local file + check_lfi_path(video_url) + video_stream = video_url + else: # web uri + check_ssrf_url(video_url) + video_stream = requests.get(video_url, stream=True).raw + + videos.append(video_stream) + elif input_item.type == "audio_url": + text_content += AUDIO_PLACEHOLDER + audio_url = input_item.audio_url.url + if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio + audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1])) + elif os.path.isfile(audio_url): # local file + check_lfi_path(audio_url) + audio_stream = audio_url + else: # web uri + check_ssrf_url(audio_url) + audio_stream = requests.get(audio_url, stream=True).raw + + audios.append(audio_stream) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}." 
+ ) + + input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content}) + else: + input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content}) + + tool_list = request.tools + if isinstance(tool_list, list) and len(tool_list): + try: + tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False) + except json.JSONDecodeError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") + else: + tools = None + + return input_messages, system, tools, images or None, videos or None, audios or None + + +def _create_stream_chat_completion_chunk( + completion_id: str, + model: str, + delta: "ChatCompletionMessage", + index: Optional[int] = 0, + finish_reason: Optional["Finish"] = None, +) -> str: + choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason) + chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data]) + return jsonify(chunk) + + +async def create_chat_completion_response( + request: "ChatCompletionRequest", chat_model: "ChatModel" +) -> "ChatCompletionResponse": + completion_id = f"chatcmpl-{uuid.uuid4().hex}" + input_messages, system, tools, images, videos, audios = _process_request(request) + responses = await chat_model.achat( + input_messages, + system, + tools, + images, + videos, + audios, + do_sample=request.do_sample, + temperature=request.temperature, + top_p=request.top_p, + max_new_tokens=request.max_tokens, + num_return_sequences=request.n, + repetition_penalty=request.presence_penalty, + stop=request.stop, + ) + + prompt_length, response_length = 0, 0 + choices = [] + for i, response in enumerate(responses): + if tools: + result = chat_model.engine.template.extract_tool(response.response_text) + else: + result = response.response_text + + if isinstance(result, list): + tool_calls = [] + for tool in result: + function = Function(name=tool.name, arguments=tool.arguments) + tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function)) + + response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls) + finish_reason = Finish.TOOL + else: + response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result) + finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH + + choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)) + prompt_length = response.prompt_length + response_length += response.response_length + + usage = ChatCompletionResponseUsage( + prompt_tokens=prompt_length, + completion_tokens=response_length, + total_tokens=prompt_length + response_length, + ) + + return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage) + + +async def create_stream_chat_completion_response( + request: "ChatCompletionRequest", chat_model: "ChatModel" +) -> AsyncGenerator[str, None]: + completion_id = f"chatcmpl-{uuid.uuid4().hex}" + input_messages, system, tools, images, videos, audios = _process_request(request) + if tools: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") + + if request.n > 1: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.") + + yield _create_stream_chat_completion_chunk( + completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="") + ) + async for new_token in 
chat_model.astream_chat( + input_messages, + system, + tools, + images, + videos, + audios, + do_sample=request.do_sample, + temperature=request.temperature, + top_p=request.top_p, + max_new_tokens=request.max_tokens, + repetition_penalty=request.presence_penalty, + stop=request.stop, + ): + if len(new_token) != 0: + yield _create_stream_chat_completion_chunk( + completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token) + ) + + yield _create_stream_chat_completion_chunk( + completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP + ) + yield "[DONE]" + + +async def create_score_evaluation_response( + request: "ScoreEvaluationRequest", chat_model: "ChatModel" +) -> "ScoreEvaluationResponse": + score_id = f"scoreval-{uuid.uuid4().hex}" + if len(request.messages) == 0: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") + + scores = await chat_model.aget_scores(request.messages, max_length=request.max_length) + return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores) diff --git a/llamafactory/api/common.py b/llamafactory/api/common.py new file mode 100644 index 0000000000000000000000000000000000000000..7b4e9602de7ebc10b4f15c68ad9167cb9d80d8ef --- /dev/null +++ b/llamafactory/api/common.py @@ -0,0 +1,96 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ipaddress +import json +import os +import socket +from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse + +from ..extras.misc import is_env_enabled +from ..extras.packages import is_fastapi_available + + +if is_fastapi_available(): + from fastapi import HTTPException, status + + +if TYPE_CHECKING: + from pydantic import BaseModel + + +SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media")) +ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1") + + +def dictify(data: "BaseModel") -> dict[str, Any]: + try: # pydantic v2 + return data.model_dump(exclude_unset=True) + except AttributeError: # pydantic v1 + return data.dict(exclude_unset=True) + + +def jsonify(data: "BaseModel") -> str: + try: # pydantic v2 + return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) + except AttributeError: # pydantic v1 + return data.json(exclude_unset=True, ensure_ascii=False) + + +def check_lfi_path(path: str) -> None: + """Checks if a given path is vulnerable to LFI. Raises HTTPException if unsafe.""" + if not ALLOW_LOCAL_FILES: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.") + + try: + os.makedirs(SAFE_MEDIA_PATH, exist_ok=True) + real_path = os.path.realpath(path) + safe_path = os.path.realpath(SAFE_MEDIA_PATH) + + if not real_path.startswith(safe_path): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory." 
+ ) + except Exception: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path.") + + +def check_ssrf_url(url: str) -> None: + """Checks if a given URL is vulnerable to SSRF. Raises HTTPException if unsafe.""" + try: + parsed_url = urlparse(url) + if parsed_url.scheme not in ["http", "https"]: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.") + + hostname = parsed_url.hostname + if not hostname: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.") + + ip_info = socket.getaddrinfo(hostname, parsed_url.port) + ip_address_str = ip_info[0][4][0] + ip = ipaddress.ip_address(ip_address_str) + + if not ip.is_global: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access to private or reserved IP addresses is not allowed.", + ) + + except socket.gaierror: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}" + ) + except Exception as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}") diff --git a/llamafactory/api/protocol.py b/llamafactory/api/protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..889d938e0b727ef1fca63d95fa20926c31830c52 --- /dev/null +++ b/llamafactory/api/protocol.py @@ -0,0 +1,157 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +from enum import Enum, unique +from typing import Any, Optional, Union + +from pydantic import BaseModel, Field +from typing_extensions import Literal + + +@unique +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + FUNCTION = "function" + TOOL = "tool" + + +@unique +class Finish(str, Enum): + STOP = "stop" + LENGTH = "length" + TOOL = "tool_calls" + + +class ModelCard(BaseModel): + id: str + object: Literal["model"] = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: Literal["owner"] = "owner" + + +class ModelList(BaseModel): + object: Literal["list"] = "list" + data: list[ModelCard] = [] + + +class Function(BaseModel): + name: str + arguments: str + + +class FunctionDefinition(BaseModel): + name: str + description: str + parameters: dict[str, Any] + + +class FunctionAvailable(BaseModel): + type: Literal["function", "code_interpreter"] = "function" + function: Optional[FunctionDefinition] = None + + +class FunctionCall(BaseModel): + id: str + type: Literal["function"] = "function" + function: Function + + +class URL(BaseModel): + url: str + detail: Literal["auto", "low", "high"] = "auto" + + +class MultimodalInputItem(BaseModel): + type: Literal["text", "image_url", "video_url", "audio_url"] + text: Optional[str] = None + image_url: Optional[URL] = None + video_url: Optional[URL] = None + audio_url: Optional[URL] = None + + +class ChatMessage(BaseModel): + role: Role + content: Optional[Union[str, list[MultimodalInputItem]]] = None + tool_calls: Optional[list[FunctionCall]] = None + + +class ChatCompletionMessage(BaseModel): + role: Optional[Role] = None + content: Optional[str] = None + tool_calls: Optional[list[FunctionCall]] = None + + +class ChatCompletionRequest(BaseModel): + model: str + messages: list[ChatMessage] + tools: Optional[list[FunctionAvailable]] = None + do_sample: Optional[bool] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + n: int = 1 + presence_penalty: Optional[float] = None + max_tokens: Optional[int] = None + stop: Optional[Union[str, list[str]]] = None + stream: bool = False + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatCompletionMessage + finish_reason: Finish + + +class ChatCompletionStreamResponseChoice(BaseModel): + index: int + delta: ChatCompletionMessage + finish_reason: Optional[Finish] = None + + +class ChatCompletionResponseUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatCompletionResponse(BaseModel): + id: str + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[ChatCompletionResponseChoice] + usage: ChatCompletionResponseUsage + + +class ChatCompletionStreamResponse(BaseModel): + id: str + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[ChatCompletionStreamResponseChoice] + + +class ScoreEvaluationRequest(BaseModel): + model: str + messages: list[str] + max_length: Optional[int] = None + + +class ScoreEvaluationResponse(BaseModel): + id: str + object: Literal["score.evaluation"] = "score.evaluation" + model: str + scores: list[float] diff --git a/llamafactory/chat/__init__.py b/llamafactory/chat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15d8b9ba2d77d6f300d59300da5a49abd3ed4e57 --- /dev/null +++ 
b/llamafactory/chat/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base_engine import BaseEngine +from .chat_model import ChatModel + + +__all__ = ["BaseEngine", "ChatModel"] diff --git a/llamafactory/chat/__pycache__/__init__.cpython-312.pyc b/llamafactory/chat/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1435170474306d36918011ef6af44e36b8fb8e44 Binary files /dev/null and b/llamafactory/chat/__pycache__/__init__.cpython-312.pyc differ diff --git a/llamafactory/chat/__pycache__/base_engine.cpython-312.pyc b/llamafactory/chat/__pycache__/base_engine.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..358bdcffd32512c3649dea4159ee05f027a37181 Binary files /dev/null and b/llamafactory/chat/__pycache__/base_engine.cpython-312.pyc differ diff --git a/llamafactory/chat/__pycache__/chat_model.cpython-312.pyc b/llamafactory/chat/__pycache__/chat_model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5abba6446be7ea5a10650a5e22a3dd0f1534446 Binary files /dev/null and b/llamafactory/chat/__pycache__/chat_model.cpython-312.pyc differ diff --git a/llamafactory/chat/base_engine.py b/llamafactory/chat/base_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..6d497c1ae927f94f396c18833b18cdb894cbd59d --- /dev/null +++ b/llamafactory/chat/base_engine.py @@ -0,0 +1,98 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + from vllm import AsyncLLMEngine + + from ..data import Template + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from ..extras.constants import EngineName + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +@dataclass +class Response: + response_text: str + response_length: int + prompt_length: int + finish_reason: Literal["stop", "length"] + + +class BaseEngine(ABC): + r"""Base class for inference engine of chat models. + + Must implements async methods: chat(), stream_chat() and get_scores(). 
+ """ + + name: "EngineName" + model: Union["PreTrainedModel", "AsyncLLMEngine"] + tokenizer: "PreTrainedTokenizer" + can_generate: bool + template: "Template" + generating_args: dict[str, Any] + + @abstractmethod + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + r"""Initialize an inference engine.""" + ... + + @abstractmethod + async def chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + r"""Get a list of responses of the chat model.""" + ... + + @abstractmethod + async def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + r"""Get the response token-by-token of the chat model.""" + ... + + @abstractmethod + async def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + r"""Get a list of scores of the reward model.""" + ... diff --git a/llamafactory/chat/chat_model.py b/llamafactory/chat/chat_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cb612f88d468d76f06eefa45b96c1bfa0351fa7c --- /dev/null +++ b/llamafactory/chat/chat_model.py @@ -0,0 +1,210 @@ +# Copyright 2025 THUDM and the LlamaFactory team. +# +# This code is inspired by the THUDM's ChatGLM implementation. +# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +from collections.abc import AsyncGenerator, Generator +from threading import Thread +from typing import TYPE_CHECKING, Any, Optional + +from ..extras.constants import EngineName +from ..extras.misc import torch_gc +from ..hparams import get_infer_args + + +if TYPE_CHECKING: + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from .base_engine import BaseEngine, Response + + +def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + +class ChatModel: + r"""General class for chat models. Backed by huggingface or vllm engines. + + Supports both sync and async methods. + Sync methods: chat(), stream_chat() and get_scores(). + Async methods: achat(), astream_chat() and aget_scores(). 
+ """ + + def __init__(self, args: Optional[dict[str, Any]] = None) -> None: + model_args, data_args, finetuning_args, generating_args = get_infer_args(args) + + if model_args.infer_backend == EngineName.HF: + from .hf_engine import HuggingfaceEngine + + self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args) + elif model_args.infer_backend == EngineName.VLLM: + try: + from .vllm_engine import VllmEngine + + self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args) + except ImportError as e: + raise ImportError( + "vLLM not install, you may need to run `pip install vllm`\n" + "or try to use HuggingFace backend: --infer_backend huggingface" + ) from e + elif model_args.infer_backend == EngineName.SGLANG: + try: + from .sglang_engine import SGLangEngine + + self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args) + except ImportError as e: + raise ImportError( + "SGLang not install, you may need to run `pip install sglang[all]`\n" + "or try to use HuggingFace backend: --infer_backend huggingface" + ) from e + elif model_args.infer_backend == EngineName.KT: + try: + from .kt_engine import KTransformersEngine + + self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args) + except ImportError as e: + raise ImportError( + "KTransformers not install, you may need to run `pip install ktransformers`\n" + "or try to use HuggingFace backend: --infer_backend huggingface" + ) from e + else: + raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}") + + self._loop = asyncio.new_event_loop() + self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True) + self._thread.start() + + def chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + r"""Get a list of responses of the chat model.""" + task = asyncio.run_coroutine_threadsafe( + self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop + ) + return task.result() + + async def achat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + r"""Asynchronously get a list of responses of the chat model.""" + return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs) + + def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> Generator[str, None, None]: + r"""Get the response token-by-token of the chat model.""" + generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs) + while True: + try: + task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop) + yield task.result() + except StopAsyncIteration: + break + + async def astream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + 
images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + r"""Asynchronously get the response token-by-token of the chat model.""" + async for new_token in self.engine.stream_chat( + messages, system, tools, images, videos, audios, **input_kwargs + ): + yield new_token + + def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + r"""Get a list of scores of the reward model.""" + task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop) + return task.result() + + async def aget_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + r"""Asynchronously get a list of scores of the reward model.""" + return await self.engine.get_scores(batch_input, **input_kwargs) + + +def run_chat() -> None: + if os.name != "nt": + try: + import readline # noqa: F401 + except ImportError: + print("Install `readline` for a better experience.") + + chat_model = ChatModel() + messages = [] + print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") + + while True: + try: + query = input("\nUser: ") + except UnicodeDecodeError: + print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.") + continue + except Exception: + raise + + if query.strip() == "exit": + break + + if query.strip() == "clear": + messages = [] + torch_gc() + print("History has been removed.") + continue + + messages.append({"role": "user", "content": query}) + print("Assistant: ", end="", flush=True) + + response = "" + for new_text in chat_model.stream_chat(messages): + print(new_text, end="", flush=True) + response += new_text + print() + messages.append({"role": "assistant", "content": response}) diff --git a/llamafactory/chat/hf_engine.py b/llamafactory/chat/hf_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..adaaaa872786446fde2eba4a2a8f32f7ec4cc462 --- /dev/null +++ b/llamafactory/chat/hf_engine.py @@ -0,0 +1,412 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
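Taken together, the `ChatModel` wrapper above exposes both blocking and streaming entry points over whichever backend engine gets selected, with the async work driven on a daemon event-loop thread. A minimal usage sketch, assuming placeholder argument values (the model path and template name below are illustrative, not defaults shipped with this patch):

```python
from llamafactory.chat import ChatModel

# Placeholder inference arguments; any model/template combination accepted by
# get_infer_args() works here.
chat_model = ChatModel(
    {
        "model_name_or_path": "path/to/your/model",
        "template": "llama3",            # assumed template name
        "infer_backend": "huggingface",  # or e.g. "vllm"
    }
)

messages = [{"role": "user", "content": "Hello!"}]

# Blocking call: returns a list of Response objects.
print(chat_model.chat(messages)[0].response_text)

# Streaming call: yields text deltas produced on the background event loop.
for new_text in chat_model.stream_chat(messages):
    print(new_text, end="", flush=True)
```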
+ +import asyncio +import os +from collections.abc import AsyncGenerator +from threading import Thread +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +import torch +from transformers import GenerationConfig, TextIteratorStreamer +from typing_extensions import override + +from ..data import get_template_and_fix_tokenizer +from ..extras import logging +from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName +from ..model import load_model, load_tokenizer +from .base_engine import BaseEngine, Response + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin + from trl import PreTrainedModelWrapper + + from ..data import Template + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +logger = logging.get_logger(__name__) + + +class HuggingfaceEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.name = EngineName.HF + self.can_generate = finetuning_args.stage == "sft" + tokenizer_module = load_tokenizer(model_args) + self.tokenizer = tokenizer_module["tokenizer"] + self.processor = tokenizer_module["processor"] + self.tokenizer.padding_side = "left" if self.can_generate else "right" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args) + self.model = load_model( + self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) + ) # must after fixing tokenizer to resize vocab + self.generating_args = generating_args.to_dict() + try: + asyncio.get_event_loop() + except RuntimeError: + logger.warning_rank0_once("There is no current event loop, creating a new one.") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1"))) + + @staticmethod + def _process_args( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + template: "Template", + generating_args: dict[str, Any], + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> tuple[dict[str, Any], int]: + mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]} + if images is not None: + mm_input_dict.update({"images": images, "imglens": [len(images)]}) + if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] + + if videos is not None: + mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]}) + if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] + + if audios is not None: + mm_input_dict.update({"audios": audios, "audlens": [len(audios)]}) + if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"] + + messages = template.mm_plugin.process_messages( + messages, 
mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor + ) + paired_messages = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools) + prompt_ids, _ = template.mm_plugin.process_token_ids( + prompt_ids, + None, + mm_input_dict["images"], + mm_input_dict["videos"], + mm_input_dict["audios"], + tokenizer, + processor, + ) + prompt_length = len(prompt_ids) + inputs = torch.tensor([prompt_ids], device=model.device) + attention_mask = torch.ones_like(inputs, dtype=torch.long) + + do_sample: Optional[bool] = input_kwargs.pop("do_sample", None) + temperature: Optional[float] = input_kwargs.pop("temperature", None) + top_p: Optional[float] = input_kwargs.pop("top_p", None) + top_k: Optional[float] = input_kwargs.pop("top_k", None) + num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1) + repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None) + length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None) + skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None) + max_length: Optional[int] = input_kwargs.pop("max_length", None) + max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) + stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None) + + if stop is not None: + logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.") + + generating_args = generating_args.copy() + generating_args.update( + dict( + do_sample=do_sample if do_sample is not None else generating_args["do_sample"], + temperature=temperature if temperature is not None else generating_args["temperature"], + top_p=top_p if top_p is not None else generating_args["top_p"], + top_k=top_k if top_k is not None else generating_args["top_k"], + num_return_sequences=num_return_sequences, + repetition_penalty=repetition_penalty + if repetition_penalty is not None + else generating_args["repetition_penalty"], + length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"], + skip_special_tokens=skip_special_tokens + if skip_special_tokens is not None + else generating_args["skip_special_tokens"], + eos_token_id=template.get_stop_token_ids(tokenizer), + pad_token_id=tokenizer.pad_token_id, + ) + ) + + if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0 + generating_args["do_sample"] = True + generating_args["temperature"] = generating_args["temperature"] or 1.0 + + if not generating_args["temperature"]: + generating_args["do_sample"] = False + + if not generating_args["do_sample"]: + generating_args.pop("temperature", None) + generating_args.pop("top_p", None) + + if max_length: + generating_args.pop("max_new_tokens", None) + generating_args["max_length"] = max_length + + if max_new_tokens: + generating_args.pop("max_length", None) + generating_args["max_new_tokens"] = max_new_tokens + + gen_kwargs = dict( + inputs=inputs, + attention_mask=attention_mask, + generation_config=GenerationConfig(**generating_args), + ) + + mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor) + for key, value in mm_inputs.items(): + if isinstance(value, list) and isinstance(value[0], torch.Tensor): # for pixtral inputs + value = torch.stack(value) # assume they have same sizes + elif ( + isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], 
torch.Tensor) + ): # for minicpmv inputs + value = torch.stack([torch.stack(v) for v in value]) + elif not isinstance(value, torch.Tensor): + value = torch.tensor(value) + + if torch.is_floating_point(value): # cast data dtype for paligemma + value = value.to(model.dtype) + + if key == "second_per_grid_ts": # qwen2.5vl special case + gen_kwargs[key] = value.tolist() + else: + gen_kwargs[key] = value.to(model.device) + + if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]: + gen_kwargs["input_ids"] = inputs + gen_kwargs["tokenizer"] = tokenizer + if "audio_feature_lens" in mm_inputs: + gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"] + + gen_kwargs.pop("image_sizes", None) + + return gen_kwargs, prompt_length + + @staticmethod + @torch.inference_mode() + def _chat( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + template: "Template", + generating_args: dict[str, Any], + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> list["Response"]: + gen_kwargs, prompt_length = HuggingfaceEngine._process_args( + model, + tokenizer, + processor, + template, + generating_args, + messages, + system, + tools, + images, + videos, + audios, + input_kwargs, + ) + generate_output = model.generate(**gen_kwargs) + if isinstance(generate_output, tuple): + generate_output = generate_output[1][0] # post-process the minicpm_o output + + response_ids = generate_output[:, prompt_length:] + response = tokenizer.batch_decode( + response_ids, + skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True), + clean_up_tokenization_spaces=True, + ) + results = [] + for i in range(len(response)): + eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero() + response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i]) + results.append( + Response( + response_text=response[i], + response_length=response_length, + prompt_length=prompt_length, + finish_reason="stop" if len(eos_index) else "length", + ) + ) + + return results + + @staticmethod + @torch.inference_mode() + def _stream_chat( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + template: "Template", + generating_args: dict[str, Any], + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> Callable[[], str]: + gen_kwargs, _ = HuggingfaceEngine._process_args( + model, + tokenizer, + processor, + template, + generating_args, + messages, + system, + tools, + images, + videos, + audios, + input_kwargs, + ) + streamer = TextIteratorStreamer( + tokenizer, + skip_prompt=True, + skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True), + ) + gen_kwargs["streamer"] = streamer + thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True) + thread.start() + + def stream(): + try: + return streamer.__next__() + except StopIteration: + raise StopAsyncIteration() + + return stream + + @staticmethod + @torch.inference_mode() + def _get_scores( + model: 
"PreTrainedModelWrapper", + tokenizer: "PreTrainedTokenizer", + batch_input: list[str], + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> list[float]: + max_length: Optional[int] = input_kwargs.pop("max_length", None) + device = getattr(model.pretrained_model, "device", "cuda") + inputs: dict[str, torch.Tensor] = tokenizer( + batch_input, + padding=True, + truncation=True, + max_length=max_length or getattr(model.config, "max_position_embeddings", 1024), + return_tensors="pt", + add_special_tokens=False, + ).to(device) + values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1] + scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1)) + return scores + + @override + async def chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + if not self.can_generate: + raise ValueError("The current model does not support `chat`.") + + input_args = ( + self.model, + self.tokenizer, + self.processor, + self.template, + self.generating_args, + messages, + system, + tools, + images, + videos, + audios, + input_kwargs, + ) + async with self.semaphore: + return await asyncio.to_thread(self._chat, *input_args) + + @override + async def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + if not self.can_generate: + raise ValueError("The current model does not support `stream_chat`.") + + input_args = ( + self.model, + self.tokenizer, + self.processor, + self.template, + self.generating_args, + messages, + system, + tools, + images, + videos, + audios, + input_kwargs, + ) + async with self.semaphore: + stream = self._stream_chat(*input_args) + while True: + try: + yield await asyncio.to_thread(stream) + except StopAsyncIteration: + break + + @override + async def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + if self.can_generate: + raise ValueError("Cannot get scores using an auto-regressive model.") + + input_args = (self.model, self.tokenizer, batch_input, input_kwargs) + async with self.semaphore: + return await asyncio.to_thread(self._get_scores, *input_args) diff --git a/llamafactory/chat/kt_engine.py b/llamafactory/chat/kt_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf3f4bb2b685ee971d538d29f0b6afa16956f2c --- /dev/null +++ b/llamafactory/chat/kt_engine.py @@ -0,0 +1,284 @@ +# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import os +import platform +from collections.abc import AsyncGenerator +from threading import Thread +from typing import TYPE_CHECKING, Any, Optional + +import torch +from typing_extensions import override + +from ..data import get_template_and_fix_tokenizer +from ..extras import logging +from ..extras.constants import EngineName +from ..model import load_model, load_tokenizer +from .base_engine import BaseEngine, Response + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + from trl import PreTrainedModelWrapper + + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + +from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled +from ktransformers.server.config.config import Config +from ktransformers.util.utils import ( + get_compute_capability, + prefill_and_generate_capture, +) +from ktransformers.util.vendors import GPUVendor, device_manager + + +logger = logging.get_logger(__name__) + + +class KTransformersEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.name = EngineName.KT + self.can_generate = finetuning_args.stage == "sft" + + tok_mod = load_tokenizer(model_args) + self.tokenizer = tok_mod["tokenizer"] + self.tokenizer.padding_side = "left" if self.can_generate else "right" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args) + + self.model = load_model( + self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) + ) + + self.generating_args = generating_args.to_dict() + self.max_new_tokens = model_args.kt_maxlen + self.use_cuda_graph = model_args.kt_use_cuda_graph + self.mode = model_args.kt_mode + self.force_think = model_args.kt_force_think + self.chunk_size = model_args.chunk_size + + try: + asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1"))) + + @staticmethod + @torch.inference_mode() + def _get_scores( + model: "PreTrainedModelWrapper", + tokenizer: "PreTrainedTokenizer", + batch_input: list[str], + input_kwargs: Optional[dict[str, Any]] = {}, + ) -> list[float]: + max_length: Optional[int] = input_kwargs.pop("max_length", None) + device = getattr(model.pretrained_model, "device", "cuda") + inputs = tokenizer( + batch_input, + padding=True, + truncation=True, + max_length=max_length or getattr(model.config, "max_position_embeddings", 1024), + return_tensors="pt", + add_special_tokens=False, + ).to(device) + values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1] + scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1)) + return scores + + async def _generate( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + paired = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools) + prompt_len = len(prompt_ids) + + max_length: Optional[int] = input_kwargs.pop("max_length", None) + max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) + + if "max_new_tokens" in self.generating_args: + max_tokens = 
int(self.generating_args["max_new_tokens"]) + elif "max_length" in self.generating_args: + gl = int(self.generating_args["max_length"]) + max_tokens = gl - prompt_len if gl > prompt_len else 1 + else: + max_tokens = self.max_new_tokens or 256 + + if max_length is not None: + max_tokens = max(max_length - prompt_len, 1) + if max_new_tokens is not None: + max_tokens = int(max_new_tokens) + max_tokens = max(1, int(max_tokens)) + + if self.mode == "long_context": + max_len_cfg = Config().long_context_config["max_seq_len"] + need = prompt_len + max_tokens + assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml" + + device = next(self.model.parameters()).device + input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device) + if self.force_think: + think = torch.tensor( + [self.tokenizer.encode("\n", add_special_tokens=False)], dtype=torch.long, device=device + ) + input_tensor = torch.cat([input_tensor, think], dim=1) + + use_flashinfer = ( + platform.system() != "Windows" + and getattr(self.model.config, "architectures", [""])[0] + in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"} + and flashinfer_enabled + and get_compute_capability() >= 8 + and device_manager.gpu_vendor == GPUVendor.NVIDIA + ) + + def make_gen(): + if use_flashinfer: + return prefill_and_generate_capture( + self.model, + self.tokenizer, + input_tensor, + max_tokens, + self.use_cuda_graph, + mode=self.mode, + force_think=self.force_think, + chunk_size=self.chunk_size, + use_flashinfer_mla=True, + num_heads=self.model.config.num_attention_heads, + head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0), + head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0), + q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0) + + getattr(self.model.config, "qk_nope_head_dim", 0), + echo_stream=False, + ) + else: + return prefill_and_generate_capture( + self.model, + self.tokenizer, + input_tensor, + max_tokens, + self.use_cuda_graph, + mode=self.mode, + force_think=self.force_think, + chunk_size=self.chunk_size, + echo_stream=False, + ) + + loop = asyncio.get_running_loop() + q: asyncio.Queue[Optional[str]] = asyncio.Queue() + + def producer(): + try: + gen = make_gen() + if hasattr(gen, "__aiter__"): + + async def drain_async(): + async for t in gen: + loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t)) + + asyncio.run(drain_async()) + elif hasattr(gen, "__iter__"): + for t in gen: + loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t)) + else: + loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen)) + finally: + loop.call_soon_threadsafe(q.put_nowait, None) + + Thread(target=producer, daemon=True).start() + + while True: + item = await q.get() + if item is None: + break + yield item + + @override + async def chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + if not self.can_generate: + raise ValueError("The current model does not support `chat`.") + async with self.semaphore: + produced = "" + final_text = "" + async for t in self._generate(messages, system, tools, **input_kwargs): + delta = t + produced = produced + delta + if delta: + final_text += delta + + prompt_ids, _ = self.template.encode_oneturn( + self.tokenizer, messages + [{"role": 
"assistant", "content": ""}], system, tools + ) + return [ + Response( + response_text=final_text, + response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)), + prompt_length=len(prompt_ids), + finish_reason="stop", + ) + ] + + @override + async def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + if not self.can_generate: + raise ValueError("The current model does not support `stream_chat`.") + async with self.semaphore: + produced = "" + async for t in self._generate(messages, system, tools, **input_kwargs): + delta = t[len(produced) :] if t.startswith(produced) else t + produced = t + if delta: + yield delta + + @override + async def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + if self.can_generate: + raise ValueError("Cannot get scores using an auto-regressive model.") + args = (self.model, self.tokenizer, batch_input, input_kwargs) + async with self.semaphore: + return await asyncio.to_thread(self._get_scores, *args) diff --git a/llamafactory/chat/sglang_engine.py b/llamafactory/chat/sglang_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d2ead33823bc70d51cda59750d25580f972083 --- /dev/null +++ b/llamafactory/chat/sglang_engine.py @@ -0,0 +1,289 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import atexit +import json +from collections.abc import AsyncGenerator, AsyncIterator, Sequence +from typing import TYPE_CHECKING, Any, Optional, Union + +import requests +from typing_extensions import override + +from ..data import get_template_and_fix_tokenizer +from ..extras import logging +from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName +from ..extras.misc import get_device_count, torch_gc +from ..extras.packages import is_sglang_available +from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments +from ..model import load_config, load_tokenizer +from ..model.model_utils.quantization import QuantizationMethod +from .base_engine import BaseEngine, Response + + +if is_sglang_available(): + from sglang.utils import launch_server_cmd, terminate_process, wait_for_server # type: ignore + + +if TYPE_CHECKING: + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + + +logger = logging.get_logger(__name__) + + +class SGLangEngine(BaseEngine): + """Inference engine for SGLang models. + + This class wraps the SGLang engine to provide a consistent interface for text generation + that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for + better interaction and performance. The engine launches a server process and communicates + with it via HTTP requests. 
+ + For more details on the SGLang HTTP server approach, see: + https://docs.sglang.ai/backend/send_request.html + """ + + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.name = EngineName.SGLANG + self.model_args = model_args + config = load_config(model_args) # may download model from ms hub + if getattr(config, "quantization_config", None): # gptq models should use float16 + quantization_config: dict[str, Any] = getattr(config, "quantization_config", None) + quant_method = quantization_config.get("quant_method", "") + if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto": + model_args.infer_dtype = "float16" + + self.can_generate = finetuning_args.stage == "sft" + tokenizer_module = load_tokenizer(model_args) + self.tokenizer = tokenizer_module["tokenizer"] + self.processor = tokenizer_module["processor"] + self.tokenizer.padding_side = "left" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args) + self.template.mm_plugin.expand_mm_tokens = False # for sglang generate + self.generating_args = generating_args.to_dict() + if model_args.adapter_name_or_path is not None: + self.lora_request = True + else: + self.lora_request = False + + launch_cmd = [ + "python3 -m sglang.launch_server", + f"--model-path {model_args.model_name_or_path}", + f"--dtype {model_args.infer_dtype}", + f"--context-length {model_args.sglang_maxlen}", + f"--mem-fraction-static {model_args.sglang_mem_fraction}", + f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}", + f"--download-dir {model_args.cache_dir}", + "--log-level error", + ] + if self.lora_request: + launch_cmd.extend( + [ + "--max-loras-per-batch 1", + f"--lora-backend {model_args.sglang_lora_backend}", + f"--lora-paths lora0={model_args.adapter_name_or_path[0]}", + "--disable-radix-cache", + ] + ) + launch_cmd = " ".join(launch_cmd) + logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}") + try: + torch_gc() + self.server_process, port = launch_server_cmd(launch_cmd) + self.base_url = f"http://localhost:{port}" + atexit.register(self._cleanup_server) + + logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}") + wait_for_server(self.base_url, timeout=300) + logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}") + try: + response = requests.get(f"{self.base_url}/get_model_info", timeout=5) + if response.status_code == 200: + model_info = response.json() + logger.info(f"SGLang server model info: {model_info}") + except Exception as e: + logger.debug(f"Note: could not get model info: {str(e)}") + + except Exception as e: + logger.error(f"Failed to start SGLang server: {str(e)}") + self._cleanup_server() # make sure to clean up any started process + raise RuntimeError(f"SGLang server initialization failed: {str(e)}.") + + def _cleanup_server(self): + r"""Clean up the server process when the engine is destroyed.""" + if hasattr(self, "server_process") and self.server_process: + try: + logger.info("Terminating SGLang server process") + terminate_process(self.server_process) + logger.info("SGLang server process terminated") + except Exception as e: + logger.warning(f"Error terminating SGLang server: {str(e)}") + + async def _generate( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: 
Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncIterator[dict[str, Any]]: + if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] + + if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] + + if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"] + + messages = self.template.mm_plugin.process_messages( + messages, images or [], videos or [], audios or [], self.processor + ) + paired_messages = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools) + prompt_length = len(prompt_ids) + + temperature: Optional[float] = input_kwargs.pop("temperature", None) + top_p: Optional[float] = input_kwargs.pop("top_p", None) + top_k: Optional[float] = input_kwargs.pop("top_k", None) + num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1) + repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None) + skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None) + max_length: Optional[int] = input_kwargs.pop("max_length", None) + max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) + stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None) + + if num_return_sequences != 1: + raise NotImplementedError("SGLang only supports n=1.") + + if "max_new_tokens" in self.generating_args: + max_tokens = self.generating_args["max_new_tokens"] + elif "max_length" in self.generating_args: + if self.generating_args["max_length"] > prompt_length: + max_tokens = self.generating_args["max_length"] - prompt_length + else: + max_tokens = 1 + + if max_length: + max_tokens = max_length - prompt_length if max_length > prompt_length else 1 + + if max_new_tokens: + max_tokens = max_new_tokens + + sampling_params = { + "temperature": temperature if temperature is not None else self.generating_args["temperature"], + "top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0 + "top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0 + "stop": stop, + "stop_token_ids": self.template.get_stop_token_ids(self.tokenizer), + "max_new_tokens": max_tokens, + "repetition_penalty": ( + repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"] + ) + or 1.0, # repetition_penalty must > 0 + "skip_special_tokens": skip_special_tokens + if skip_special_tokens is not None + else self.generating_args["skip_special_tokens"], + } + + def stream_request(): + json_data = { + "input_ids": prompt_ids, + "sampling_params": sampling_params, + "stream": True, + } + if self.lora_request: + json_data["lora_request"] = ["lora0"] + response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True) + if response.status_code != 200: + raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}") + + for chunk in response.iter_lines(decode_unicode=False): + chunk = str(chunk.decode("utf-8")) + if chunk == "data: [DONE]": + break + + 
if chunk and chunk.startswith("data:"): + yield json.loads(chunk[5:].strip("\n")) + + return await asyncio.to_thread(stream_request) + + @override + async def chat( + self, + messages: Sequence[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[Sequence["ImageInput"]] = None, + videos: Optional[Sequence["VideoInput"]] = None, + audios: Optional[Sequence["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + final_output = None + generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) + for request_output in generator: + final_output = request_output + + results = [ + Response( + response_text=final_output["text"], + response_length=final_output["meta_info"]["completion_tokens"], + prompt_length=final_output["meta_info"]["prompt_tokens"], + finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length", + ) + ] + return results + + @override + async def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + generated_text = "" + generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) + for result in generator: + delta_text = result["text"][len(generated_text) :] + generated_text = result["text"] + yield delta_text + + @override + async def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + raise NotImplementedError("SGLang engine does not support `get_scores`.") + + def __del__(self): + r"""Ensure server is cleaned up when object is deleted.""" + self._cleanup_server() + try: + atexit.unregister(self._cleanup_server) + except Exception: + pass diff --git a/llamafactory/chat/vllm_engine.py b/llamafactory/chat/vllm_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..b33705279da13d398f8089b94c72b22d742f2c6f --- /dev/null +++ b/llamafactory/chat/vllm_engine.py @@ -0,0 +1,263 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
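Because the SGLang backend talks to a locally launched server over HTTP, its streaming path boils down to parsing `data: ...` lines and re-deriving per-chunk deltas from the cumulative text. A rough client-side sketch mirroring `SGLangEngine._generate` and `stream_chat` (the request body and the `text` field follow the engine code above; the helper names are illustrative):

```python
import json

import requests


def stream_generate(base_url: str, prompt_ids: list[int], sampling_params: dict):
    """Yield decoded generation states from an SGLang-style /generate endpoint.

    Assumes the server emits lines of the form `data: {...}` and terminates the
    stream with `data: [DONE]`, which is what SGLangEngine._generate consumes.
    """
    payload = {"input_ids": prompt_ids, "sampling_params": sampling_params, "stream": True}
    with requests.post(f"{base_url}/generate", json=payload, stream=True) as response:
        response.raise_for_status()
        for raw_line in response.iter_lines(decode_unicode=False):
            line = raw_line.decode("utf-8")
            if line == "data: [DONE]":
                break

            if line.startswith("data:"):
                yield json.loads(line[5:].strip())


def iter_deltas(states):
    """Turn cumulative `text` fields into per-chunk deltas, as stream_chat does."""
    emitted = ""
    for state in states:
        text = state["text"]
        yield text[len(emitted):]
        emitted = text
```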
+ +import uuid +from collections.abc import AsyncGenerator, AsyncIterator +from typing import TYPE_CHECKING, Any, Optional, Union + +from typing_extensions import override + +from ..data import get_template_and_fix_tokenizer +from ..extras import logging +from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName +from ..extras.misc import get_device_count +from ..extras.packages import is_vllm_available +from ..model import load_config, load_tokenizer +from ..model.model_utils.quantization import QuantizationMethod +from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM +from .base_engine import BaseEngine, Response + + +if is_vllm_available(): + from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams + from vllm.lora.request import LoRARequest + + +if TYPE_CHECKING: + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +logger = logging.get_logger(__name__) + + +class VllmEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.name = EngineName.VLLM + self.model_args = model_args + config = load_config(model_args) # may download model from ms hub + if getattr(config, "quantization_config", None): # gptq models should use float16 + quantization_config: dict[str, Any] = getattr(config, "quantization_config", None) + quant_method = quantization_config.get("quant_method", "") + if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto": + model_args.infer_dtype = "float16" + + self.can_generate = finetuning_args.stage == "sft" + tokenizer_module = load_tokenizer(model_args) + self.tokenizer = tokenizer_module["tokenizer"] + self.processor = tokenizer_module["processor"] + self.tokenizer.padding_side = "left" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args) + self.template.mm_plugin.expand_mm_tokens = False # for vllm generate + self.generating_args = generating_args.to_dict() + + engine_args = { + "model": model_args.model_name_or_path, + "trust_remote_code": model_args.trust_remote_code, + "download_dir": model_args.cache_dir, + "dtype": model_args.infer_dtype, + "max_model_len": model_args.vllm_maxlen, + "tensor_parallel_size": get_device_count() or 1, + "gpu_memory_utilization": model_args.vllm_gpu_util, + "disable_log_stats": True, + "disable_log_requests": True, + "enforce_eager": model_args.vllm_enforce_eager, + "enable_lora": model_args.adapter_name_or_path is not None, + "max_lora_rank": model_args.vllm_max_lora_rank, + } + if self.template.mm_plugin.__class__.__name__ != "BasePlugin": + engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2} + + if isinstance(model_args.vllm_config, dict): + engine_args.update(model_args.vllm_config) + + if getattr(config, "is_yi_vl_derived_model", None): + import vllm.model_executor.models.llava + + logger.info_rank0("Detected Yi-VL model, applying projector patch.") + vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM + + self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args)) + if model_args.adapter_name_or_path is not None: + self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0]) + else: + self.lora_request = None + + async def _generate( + 
self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncIterator["RequestOutput"]: + request_id = f"chatcmpl-{uuid.uuid4().hex}" + if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] + + if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] + + if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"] + + messages = self.template.mm_plugin.process_messages( + messages, images or [], videos or [], audios or [], self.processor + ) + paired_messages = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools) + prompt_length = len(prompt_ids) + + temperature: Optional[float] = input_kwargs.pop("temperature", None) + top_p: Optional[float] = input_kwargs.pop("top_p", None) + top_k: Optional[float] = input_kwargs.pop("top_k", None) + num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1) + repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None) + length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None) + skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None) + max_length: Optional[int] = input_kwargs.pop("max_length", None) + max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) + stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None) + + if length_penalty is not None: + logger.warning_rank0("Length penalty is not supported by the vllm engine yet.") + + if "max_new_tokens" in self.generating_args: + max_tokens = self.generating_args["max_new_tokens"] + elif "max_length" in self.generating_args: + if self.generating_args["max_length"] > prompt_length: + max_tokens = self.generating_args["max_length"] - prompt_length + else: + max_tokens = 1 + + if max_length: + max_tokens = max_length - prompt_length if max_length > prompt_length else 1 + + if max_new_tokens: + max_tokens = max_new_tokens + + sampling_params = SamplingParams( + n=num_return_sequences, + repetition_penalty=( + repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"] + ) + or 1.0, # repetition_penalty must > 0 + temperature=temperature if temperature is not None else self.generating_args["temperature"], + top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0 + top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0 + stop=stop, + stop_token_ids=self.template.get_stop_token_ids(self.tokenizer), + max_tokens=max_tokens, + skip_special_tokens=skip_special_tokens + if skip_special_tokens is not None + else self.generating_args["skip_special_tokens"], + ) + + if images is not None: # add image features + multi_modal_data = { + "image": self.template.mm_plugin._regularize_images( + images, + image_max_pixels=self.model_args.image_max_pixels, + image_min_pixels=self.model_args.image_min_pixels, + )["images"] + } + 
elif videos is not None: + multi_modal_data = { + "video": self.template.mm_plugin._regularize_videos( + videos, + image_max_pixels=self.model_args.video_max_pixels, + image_min_pixels=self.model_args.video_min_pixels, + video_fps=self.model_args.video_fps, + video_maxlen=self.model_args.video_maxlen, + )["videos"] + } + elif audios is not None: + audio_data = self.template.mm_plugin._regularize_audios( + audios, + sampling_rate=self.model_args.audio_sampling_rate, + ) + multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])} + else: + multi_modal_data = None + + result_generator = self.model.generate( + {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data}, + sampling_params=sampling_params, + request_id=request_id, + lora_request=self.lora_request, + ) + return result_generator + + @override + async def chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> list["Response"]: + final_output = None + generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) + async for request_output in generator: + final_output = request_output + + results = [] + for output in final_output.outputs: + results.append( + Response( + response_text=output.text, + response_length=len(output.token_ids), + prompt_length=len(final_output.prompt_token_ids), + finish_reason=output.finish_reason, + ) + ) + + return results + + @override + async def stream_chat( + self, + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + generated_text = "" + generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs) + async for result in generator: + delta_text = result.outputs[0].text[len(generated_text) :] + generated_text = result.outputs[0].text + yield delta_text + + @override + async def get_scores( + self, + batch_input: list[str], + **input_kwargs, + ) -> list[float]: + raise NotImplementedError("vLLM engine does not support `get_scores`.") diff --git a/llamafactory/cli.py b/llamafactory/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..d574bf1db543f5379f074e276898826234708037 --- /dev/null +++ b/llamafactory/cli.py @@ -0,0 +1,31 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def main(): + from .extras.misc import is_env_enabled + + if is_env_enabled("USE_V1"): + from .v1 import launcher + else: + from . 
import launcher + + launcher.launch() + + +if __name__ == "__main__": + from multiprocessing import freeze_support + + freeze_support() + main() diff --git a/llamafactory/data/__init__.py b/llamafactory/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..11c8c9fcecd10e736e240196fde98f833c9df3dc --- /dev/null +++ b/llamafactory/data/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .collator import ( + KTODataCollatorWithPadding, + MultiModalDataCollatorForSeq2Seq, + PairwiseDataCollatorWithPadding, + SFTDataCollatorWith4DAttentionMask, +) +from .data_utils import Role, split_dataset +from .loader import get_dataset +from .template import TEMPLATES, Template, get_template_and_fix_tokenizer + + +__all__ = [ + "TEMPLATES", + "KTODataCollatorWithPadding", + "MultiModalDataCollatorForSeq2Seq", + "PairwiseDataCollatorWithPadding", + "Role", + "SFTDataCollatorWith4DAttentionMask", + "Template", + "get_dataset", + "get_template_and_fix_tokenizer", + "split_dataset", +] diff --git a/llamafactory/data/__pycache__/__init__.cpython-312.pyc b/llamafactory/data/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8aeb4c34a1136f093a1b21a6f280408dd1286477 Binary files /dev/null and b/llamafactory/data/__pycache__/__init__.cpython-312.pyc differ diff --git a/llamafactory/data/__pycache__/collator.cpython-312.pyc b/llamafactory/data/__pycache__/collator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..912142ec8dd6bc455b401e5935f10e8190a6bdf3 Binary files /dev/null and b/llamafactory/data/__pycache__/collator.cpython-312.pyc differ diff --git a/llamafactory/data/__pycache__/converter.cpython-312.pyc b/llamafactory/data/__pycache__/converter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b29a1575867106edc4f03a5b02406ad72871b1c5 Binary files /dev/null and b/llamafactory/data/__pycache__/converter.cpython-312.pyc differ diff --git a/llamafactory/data/__pycache__/data_utils.cpython-312.pyc b/llamafactory/data/__pycache__/data_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4276f26997c27740fbc9ec5faaf3e99b3442d7eb Binary files /dev/null and b/llamafactory/data/__pycache__/data_utils.cpython-312.pyc differ diff --git a/llamafactory/data/__pycache__/formatter.cpython-312.pyc b/llamafactory/data/__pycache__/formatter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8f78d7f723f0fb3ea2ec501c517ee4dbdad4e49 Binary files /dev/null and b/llamafactory/data/__pycache__/formatter.cpython-312.pyc differ diff --git a/llamafactory/data/__pycache__/loader.cpython-312.pyc b/llamafactory/data/__pycache__/loader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8232883361bdf807fa8523fdbca05e1b992c535 Binary files /dev/null and b/llamafactory/data/__pycache__/loader.cpython-312.pyc differ 
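One piece of logic repeated across the vLLM, SGLang, and KTransformers engines above is how the per-request token budget is resolved: an explicit `max_new_tokens` wins, then an explicit `max_length` measured against the prompt, then the configured generating arguments. A condensed sketch of that precedence (the final fallback value is an assumption here; the engines differ slightly when nothing is configured):

```python
from typing import Optional


def resolve_max_tokens(
    generating_args: dict,
    prompt_length: int,
    max_length: Optional[int] = None,
    max_new_tokens: Optional[int] = None,
) -> int:
    """Precedence used when sizing a generation request:
    per-request max_new_tokens > per-request max_length > configured defaults."""
    if "max_new_tokens" in generating_args:
        max_tokens = int(generating_args["max_new_tokens"])
    elif "max_length" in generating_args:
        configured = int(generating_args["max_length"])
        max_tokens = configured - prompt_length if configured > prompt_length else 1
    else:
        max_tokens = 1  # assumed floor; not taken from the engines above

    if max_length:
        max_tokens = max_length - prompt_length if max_length > prompt_length else 1

    if max_new_tokens:
        max_tokens = int(max_new_tokens)

    return max(1, max_tokens)


# With a configured budget of 512 new tokens and a 100-token prompt:
assert resolve_max_tokens({"max_new_tokens": 512}, prompt_length=100) == 512
# An explicit per-request max_new_tokens overrides a configured max_length:
assert resolve_max_tokens({"max_length": 120}, prompt_length=100, max_new_tokens=64) == 64
```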
diff --git a/llamafactory/data/__pycache__/mm_plugin.cpython-312.pyc b/llamafactory/data/__pycache__/mm_plugin.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eff12f0ba2ac7b9f274742c73de5198970487bd5 Binary files /dev/null and b/llamafactory/data/__pycache__/mm_plugin.cpython-312.pyc differ diff --git a/llamafactory/data/__pycache__/parser.cpython-312.pyc b/llamafactory/data/__pycache__/parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d60110a599edbc6e1a3ceac15438706301ab9a77 Binary files /dev/null and b/llamafactory/data/__pycache__/parser.cpython-312.pyc differ diff --git a/llamafactory/data/__pycache__/template.cpython-312.pyc b/llamafactory/data/__pycache__/template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df03dd6517d2b90e1d8c61583f02f7da39977c9d Binary files /dev/null and b/llamafactory/data/__pycache__/template.cpython-312.pyc differ diff --git a/llamafactory/data/__pycache__/tool_utils.cpython-312.pyc b/llamafactory/data/__pycache__/tool_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7989ecbc820bda2775d2bc6f82b90fcd446ea4c2 Binary files /dev/null and b/llamafactory/data/__pycache__/tool_utils.cpython-312.pyc differ diff --git a/llamafactory/data/collator.py b/llamafactory/data/collator.py new file mode 100644 index 0000000000000000000000000000000000000000..162f432c9e5bf195ed4c6a821eb36d279bf3bac4 --- /dev/null +++ b/llamafactory/data/collator.py @@ -0,0 +1,331 @@ +# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team. +# +# This code is inspired by the OpenAccess AI Collective's axolotl library. +# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from peft import PeftModel +from transformers import DataCollatorForSeq2Seq + +from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER +from ..extras.packages import is_pillow_available + + +if is_pillow_available(): + from PIL import Image + + +if TYPE_CHECKING: + from transformers import ProcessorMixin + + from .template import Template + + +def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor": + r"""Expand 2d attention mask to 4d attention mask. + + Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len), + handle packed sequences and transforms the mask to lower triangular form to prevent future peeking. + + e.g. 
+ ```python + # input + [[1, 1, 2, 2, 2, 0]] + # output + [ + [ + [ + [o, x, x, x, x, x], + [o, o, x, x, x, x], + [x, x, o, x, x, x], + [x, x, o, o, x, x], + [x, x, o, o, o, x], + [x, x, x, x, x, x], + ] + ] + ] + ``` + where `o` equals to `0.0`, `x` equals to `min_dtype`. + """ + _, seq_len = attention_mask_with_indices.size() + min_dtype = torch.finfo(dtype).min + zero_tensor = torch.tensor(0, dtype=dtype) + + # Create a non-padding mask. + non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2) + # Create indices for comparison. + indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2) # [bsz, 1, 1, seq_len] + indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3) # [bsz, 1, seq_len, 1] + # Create a lower triangular mask. + tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool)) + attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask + # Invert the attention mask. + attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype) + return attention_mask_4d + + +@dataclass +class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): + r"""Data collator that supports VLMs. + + Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios. + """ + + template: Optional["Template"] = None + processor: Optional["ProcessorMixin"] = None + + def __post_init__(self): + if self.template is None: + raise ValueError("Template is required for MultiModalDataCollator.") + + if isinstance(self.model, PeftModel): + self.model = self.model.base_model.model + + if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope + self.get_rope_func = self.model.get_rope_index # transformers < 4.52.0 or qwen2.5 omni + elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"): + self.get_rope_func = self.model.model.get_rope_index # transformers >= 4.52.0 + else: + self.get_rope_func = None + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: + batch_images, batch_videos, batch_audios = [], [], [] + batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], [] + for feature in features: + images = feature.pop("images", None) or [] + videos = feature.pop("videos", None) or [] + audios = feature.pop("audios", None) or [] + batch_images.extend(images) + batch_videos.extend(videos) + batch_audios.extend(audios) + batch_imglens.append(len(images)) + batch_vidlens.append(len(videos)) + batch_audlens.append(len(audios)) + batch_input_ids.append(feature["input_ids"]) + + fake_input_ids = [] + if ( + self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0 + ): # avoid process hanging in zero3/fsdp case + fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}] + fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))] + fake_messages = self.template.mm_plugin.process_messages( + fake_messages, fake_images, [], [], self.processor + ) + _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False) + _fake_input_ids, _ = self.template.mm_plugin.process_token_ids( + _fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor + ) + fake_input_ids.extend(_fake_input_ids) + batch_images = fake_images + batch_imglens[0] = 1 + + if ( + self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0 + ): # avoid process hanging in zero3/fsdp case + 
fake_messages = [{"role": "user", "content": AUDIO_PLACEHOLDER}] + fake_audios = [np.zeros(1600)] + fake_messages = self.template.mm_plugin.process_messages( + fake_messages, [], [], fake_audios, self.processor + ) + _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False) + _fake_input_ids, _ = self.template.mm_plugin.process_token_ids( + _fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor + ) + fake_input_ids.extend(_fake_input_ids) + batch_audios = fake_audios + batch_audlens[0] = 1 + + if len(fake_input_ids) != 0: + if self.tokenizer.padding_side == "right": + features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids + features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids) + features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids) + else: + features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"] + features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"] + features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"] + + batch_input_ids[0] = features[0]["input_ids"] + + mm_inputs = self.template.mm_plugin.get_mm_inputs( + batch_images, + batch_videos, + batch_audios, + batch_imglens, + batch_vidlens, + batch_audlens, + batch_input_ids, + self.processor, + ) + if "token_type_ids" in mm_inputs: + token_type_ids = mm_inputs.pop("token_type_ids") + for i, feature in enumerate(features): + feature["token_type_ids"] = token_type_ids[i] + + features: dict[str, torch.Tensor] = super().__call__(features) + + if self.get_rope_func is not None: + rope_index_kwargs = { + "input_ids": features["input_ids"], + "image_grid_thw": mm_inputs.get("image_grid_thw"), + "video_grid_thw": mm_inputs.get("video_grid_thw"), + "attention_mask": (features["attention_mask"] >= 1).float(), + } + if "second_per_grid_ts" in mm_inputs: # for qwen2vl + rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts") + elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni + rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid") + + if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]: + rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False) + feature_attention_mask = mm_inputs.get("feature_attention_mask", None) + if feature_attention_mask is not None: # FIXME: need to get video image lengths + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input + + features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs) + features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum( + dim=-1 + ).unsqueeze(-1) + else: # for qwen vl + features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs) + + if ( + self.model is not None + and getattr(self.model.config, "model_type", None) + in [ + "glm4v", + "Keye", + "qwen2_vl", + "qwen2_5_vl", + "qwen2_5_omni_thinker", + "qwen3_omni_moe_thinker", + "qwen3_vl", + "qwen3_vl_moe", + ] + and ("position_ids" not in features or features["position_ids"].dim() != 3) + ): + raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.") + + if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled + cross_attention_mask = 
mm_inputs.pop("cross_attention_mask") + seq_len = features["input_ids"].size(1) + orig_len = cross_attention_mask.size(1) + mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len)) + + features.update(mm_inputs) + + if "image_bound" in features: # for minicpmv inputs + bsz, seq_length = features["input_ids"].shape + features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1) + return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]} + + return features + + +@dataclass +class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): + r"""Data collator for 4d attention mask.""" + + block_diag_attn: bool = False + attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager" + compute_dtype: "torch.dtype" = torch.float32 + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: + features = super().__call__(features) + if self.block_diag_attn and self.attn_implementation != "flash_attention_2": + features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) + + for key, value in features.items(): # cast data dtype for paligemma + if torch.is_tensor(value) and torch.is_floating_point(value): + features[key] = value.to(self.compute_dtype) + + return features + + +@dataclass +class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): + r"""Data collator for pairwise data.""" + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: + r"""Pad batched data to the longest sequence in the batch. + + We generate 2 * n examples where the first n examples represent chosen examples and + the last n examples represent rejected examples. + """ + concatenated_features = [] + for key in ("chosen", "rejected"): + for feature in features: + target_feature = { + "input_ids": feature[f"{key}_input_ids"], + "attention_mask": feature[f"{key}_attention_mask"], + "labels": feature[f"{key}_labels"], + "images": feature["images"], + "videos": feature["videos"], + "audios": feature["audios"], + } + concatenated_features.append(target_feature) + + return super().__call__(concatenated_features) + + +@dataclass +class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): + r"""Data collator for KTO data.""" + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: + target_features = [] + kl_features = [] + kto_tags = [] + for feature in features: + target_feature = { + "input_ids": feature["input_ids"], + "attention_mask": feature["attention_mask"], + "labels": feature["labels"], + "images": feature["images"], + "videos": feature["videos"], + "audios": feature["audios"], + } + kl_feature = { + "input_ids": feature["kl_input_ids"], + "attention_mask": feature["kl_attention_mask"], + "labels": feature["kl_labels"], + "images": feature["images"], + "videos": feature["videos"], + "audios": feature["audios"], + } + target_features.append(target_feature) + kl_features.append(kl_feature) + kto_tags.append(feature["kto_tags"]) + + batch = super().__call__(target_features) + kl_batch = super().__call__(kl_features) + batch["kl_input_ids"] = kl_batch["input_ids"] + batch["kl_attention_mask"] = kl_batch["attention_mask"] + batch["kl_labels"] = kl_batch["labels"] + if "cross_attention_mask" in kl_batch: # for mllama inputs + batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"] + + if "token_type_ids" in kl_batch: + batch["kl_token_type_ids"] = 
kl_batch["token_type_ids"] + + batch["kto_tags"] = torch.tensor(kto_tags) + return batch diff --git a/llamafactory/data/converter.py b/llamafactory/data/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3735e648eee7117154c5301e8adae746a566d7 --- /dev/null +++ b/llamafactory/data/converter.py @@ -0,0 +1,425 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +from abc import abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Union + +from ..extras import logging +from .data_utils import Role + + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import Seq2SeqTrainingArguments + + from ..hparams import DataArguments + from .mm_plugin import AudioInput, ImageInput, VideoInput + from .parser import DatasetAttr + + MediaType = Union[ImageInput, VideoInput, AudioInput] + + +logger = logging.get_logger(__name__) + + +@dataclass +class DatasetConverter: + dataset_attr: "DatasetAttr" + data_args: "DataArguments" + + def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]: + r"""Optionally concatenate media path to media dir when loading from local disk.""" + if medias is None: + return None + elif not isinstance(medias, list): + medias = [medias] + elif len(medias) == 0: + return None + else: + medias = medias[:] + + if self.dataset_attr.load_from in ["script", "file"]: + if isinstance(medias[0], str): + for i in range(len(medias)): + media_path = os.path.join(self.data_args.media_dir, medias[i]) + if os.path.isfile(media_path): + medias[i] = media_path + else: + logger.warning_rank0_once( + f"Media {medias[i]} does not exist in `media_dir`. Use original path." + ) + elif isinstance(medias[0], list): # for processed video frames + # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]] + for i in range(len(medias)): + for j in range(len(medias[i])): + media_path = os.path.join(self.data_args.media_dir, medias[i][j]) + if os.path.isfile(media_path): + medias[i][j] = media_path + else: + logger.warning_rank0_once( + f"Media {medias[i][j]} does not exist in `media_dir`. Use original path." + ) + + return medias + + @abstractmethod + def __call__(self, example: dict[str, Any]) -> dict[str, Any]: + r"""Convert a single example in the dataset to the standard format.""" + ... 
+ + +@dataclass +class AlpacaDatasetConverter(DatasetConverter): + def __call__(self, example: dict[str, Any]) -> dict[str, Any]: + prompt = [] + if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list): + for old_prompt, old_response in example[self.dataset_attr.history]: + prompt.append({"role": Role.USER.value, "content": old_prompt}) + prompt.append({"role": Role.ASSISTANT.value, "content": old_response}) + + query = [] + if self.dataset_attr.prompt and example[self.dataset_attr.prompt]: + query.append(example[self.dataset_attr.prompt]) + + if self.dataset_attr.query and example[self.dataset_attr.query]: + query.append(example[self.dataset_attr.query]) + + prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery" + + if self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example + response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}] + if example[self.dataset_attr.kto_tag]: + response = response + [{"role": Role.ASSISTANT.value, "content": ""}] + else: + response = [{"role": Role.ASSISTANT.value, "content": ""}] + response + elif ( + self.dataset_attr.ranking + and isinstance(example[self.dataset_attr.chosen], str) + and isinstance(example[self.dataset_attr.rejected], str) + ): # pairwise example + response = [ + {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.chosen]}, + {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.rejected]}, + ] + elif self.dataset_attr.response and isinstance(example[self.dataset_attr.response], str): # normal example + response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}] + else: # unsupervised + response = [] + + output = { + "_prompt": prompt, + "_response": response, + "_system": example[self.dataset_attr.system] if self.dataset_attr.system else "", + "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "", + "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None, + "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None, + "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None, + } + return output + + +@dataclass +class SharegptDatasetConverter(DatasetConverter): + def __call__(self, example: dict[str, Any]) -> dict[str, Any]: + tag_mapping = { + self.dataset_attr.user_tag: Role.USER.value, + self.dataset_attr.assistant_tag: Role.ASSISTANT.value, + self.dataset_attr.observation_tag: Role.OBSERVATION.value, + self.dataset_attr.function_tag: Role.FUNCTION.value, + self.dataset_attr.system_tag: Role.SYSTEM.value, + } + odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag) + even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag) + accept_tags = (odd_tags, even_tags) + messages = example[self.dataset_attr.messages] + if ( + self.dataset_attr.system_tag + and len(messages) != 0 + and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag + ): + system = messages[0][self.dataset_attr.content_tag] + messages = messages[1:] + else: + system = example[self.dataset_attr.system] if self.dataset_attr.system else "" + + aligned_messages = [] + broken_data = False + for turn_idx, message in enumerate(messages): + if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: + logger.warning_rank0(f"Invalid role tag in {messages}.") + 
broken_data = True + break + + aligned_messages.append( + { + "role": tag_mapping[message[self.dataset_attr.role_tag]], + "content": message[self.dataset_attr.content_tag], + } + ) + + if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( + self.dataset_attr.ranking and len(aligned_messages) % 2 == 0 + ): + logger.warning_rank0(f"Invalid message count in {messages}.") + broken_data = True + + if broken_data: + logger.warning_rank0("Skipping this abnormal example.") + prompt, response = [], [] + elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + if example[self.dataset_attr.kto_tag]: + response = response + [{"role": Role.ASSISTANT.value, "content": ""}] + else: + response = [{"role": Role.ASSISTANT.value, "content": ""}] + response + elif ( + self.dataset_attr.ranking + and isinstance(example[self.dataset_attr.chosen], dict) + and isinstance(example[self.dataset_attr.rejected], dict) + ): # pairwise example + chosen = example[self.dataset_attr.chosen] + rejected = example[self.dataset_attr.rejected] + if ( + chosen[self.dataset_attr.role_tag] not in accept_tags[-1] + or rejected[self.dataset_attr.role_tag] not in accept_tags[-1] + ): + logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.") + broken_data = True + + prompt = aligned_messages + response = [ + { + "role": tag_mapping[chosen[self.dataset_attr.role_tag]], + "content": chosen[self.dataset_attr.content_tag], + }, + { + "role": tag_mapping[rejected[self.dataset_attr.role_tag]], + "content": rejected[self.dataset_attr.content_tag], + }, + ] + else: # normal example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + + output = { + "_prompt": prompt, + "_response": response, + "_system": system, + "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "", + "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None, + "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None, + "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None, + } + return output + + +@dataclass +class OpenAIDatasetConverter(DatasetConverter): + def __call__(self, example: dict[str, Any]) -> dict[str, Any]: + tag_mapping = { + self.dataset_attr.user_tag: Role.USER.value, + self.dataset_attr.assistant_tag: Role.ASSISTANT.value, + self.dataset_attr.observation_tag: Role.OBSERVATION.value, + self.dataset_attr.function_tag: Role.FUNCTION.value, + self.dataset_attr.system_tag: Role.SYSTEM.value, + } + + messages = example[self.dataset_attr.messages] + if ( + self.dataset_attr.system_tag + and len(messages) != 0 + and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag + ): + system = messages[0][self.dataset_attr.content_tag] + messages = messages[1:] + else: + system = example.get(self.dataset_attr.system, "") if self.dataset_attr.system else "" + + aligned_messages = [] + tool_responses = [] + broken_data = False + for turn_idx, message in enumerate(messages): + role = message[self.dataset_attr.role_tag] + content = message[self.dataset_attr.content_tag] + + if role in [self.dataset_attr.assistant_tag, self.dataset_attr.function_tag]: + if "tool_calls" in message and len(message["tool_calls"]) > 0: + tool_calls_list = [tool["function"] for tool in message["tool_calls"]] + content = json.dumps(tool_calls_list, 
ensure_ascii=False) + role = self.dataset_attr.function_tag + + if role == self.dataset_attr.observation_tag: + tool_responses.append(content) + continue + elif len(tool_responses) > 0: + _content = "\n\n\n".join(tool_responses) + aligned_messages.append( + { + "role": Role.OBSERVATION.value, + "content": _content, + } + ) + tool_responses = [] + + aligned_messages.append( + { + "role": tag_mapping[role], + "content": content, + } + ) + + odd_tags = (Role.USER.value, Role.OBSERVATION.value) + even_tags = (Role.ASSISTANT.value, Role.FUNCTION.value) + accept_tags = (odd_tags, even_tags) + for turn_idx, message in enumerate(aligned_messages): + if message["role"] not in accept_tags[turn_idx % 2]: + logger.warning_rank0(f"Invalid role tag in {messages}.") + broken_data = True + break + + if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( + self.dataset_attr.ranking and len(aligned_messages) % 2 == 0 + ): + logger.warning_rank0(f"Invalid message count in {messages}.") + broken_data = True + + if broken_data: + logger.warning_rank0("Skipping this abnormal example.") + prompt, response = [], [] + elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + if example[self.dataset_attr.kto_tag]: + response = response + [{"role": Role.ASSISTANT.value, "content": ""}] + else: + response = [{"role": Role.ASSISTANT.value, "content": ""}] + response + elif ( + self.dataset_attr.ranking + and isinstance(example[self.dataset_attr.chosen], dict) + and isinstance(example[self.dataset_attr.rejected], dict) + ): # pairwise example + chosen = example[self.dataset_attr.chosen] + rejected = example[self.dataset_attr.rejected] + if ( + chosen[self.dataset_attr.role_tag] not in accept_tags[-1] + or rejected[self.dataset_attr.role_tag] not in accept_tags[-1] + ): + logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.") + broken_data = True + + prompt = aligned_messages + response = [ + { + "role": tag_mapping[chosen[self.dataset_attr.role_tag]], + "content": chosen[self.dataset_attr.content_tag], + }, + { + "role": tag_mapping[rejected[self.dataset_attr.role_tag]], + "content": rejected[self.dataset_attr.content_tag], + }, + ] + else: # normal example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + + tools = example.get(self.dataset_attr.tools, "") if self.dataset_attr.tools else "" + if isinstance(tools, dict) or isinstance(tools, list): + tools = json.dumps(tools, ensure_ascii=False) + + short_system_prompt = "detailed thinking off" + if not system: + if not tools: + system = short_system_prompt + else: + pass + else: + if not tools: + if "detailed thinking on" in system or "detailed thinking off" in system: + pass + else: + system += "\n" + short_system_prompt + else: + system += "\n" + + output = { + "_prompt": prompt, + "_response": response, + "_system": system, + "_tools": tools, + "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None, + "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None, + "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None, + } + return output + + +DATASET_CONVERTERS = { + "alpaca": AlpacaDatasetConverter, + "sharegpt": SharegptDatasetConverter, + "openai": OpenAIDatasetConverter, +} + + +def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) 
-> None: + r"""Register a new dataset converter.""" + if name in DATASET_CONVERTERS: + raise ValueError(f"Dataset converter {name} already exists.") + + DATASET_CONVERTERS[name] = dataset_converter + + +def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter": + r"""Get a dataset converter.""" + if name not in DATASET_CONVERTERS: + raise ValueError(f"Dataset converter {name} not found.") + + return DATASET_CONVERTERS[name](dataset_attr, data_args) + + +def align_dataset( + dataset: Union["Dataset", "IterableDataset"], + dataset_attr: "DatasetAttr", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", +) -> Union["Dataset", "IterableDataset"]: + r"""Align the dataset to a specific format. + + Aligned dataset: + _prompt: [{"role": "user", "content": "..."}] * (2T - 1) + _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) + _system: "..." + _tools: "..." + _images: [] + _videos: [] + _audios: [] + """ + column_names = list(next(iter(dataset)).keys()) + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), + desc="Converting format of dataset", + ) + + dataset_converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args) + return dataset.map( + dataset_converter, + batched=False, + remove_columns=column_names, + **kwargs, + ) diff --git a/llamafactory/data/data_utils.py b/llamafactory/data/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..14e261290a032466c87772915eb9f472e08d14fa --- /dev/null +++ b/llamafactory/data/data_utils.py @@ -0,0 +1,190 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from enum import Enum, unique +from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union + +import fsspec +from datasets import DatasetDict, concatenate_datasets, interleave_datasets + +from ..extras import logging + + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + + from ..hparams import DataArguments + + +logger = logging.get_logger(__name__) + + +SLOTS = list[Union[str, set[str], dict[str, str]]] + + +@unique +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + FUNCTION = "function" + OBSERVATION = "observation" + + +class DatasetModule(TypedDict): + train_dataset: Optional[Union["Dataset", "IterableDataset"]] + eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]] + + +def merge_dataset( + all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int +) -> Union["Dataset", "IterableDataset"]: + r"""Merge multiple datasets to a unified dataset.""" + if len(all_datasets) == 1: + return all_datasets[0] + + elif data_args.mix_strategy == "concat": + if data_args.streaming: + logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.") + + return concatenate_datasets(all_datasets) + + elif data_args.mix_strategy.startswith("interleave"): + if not data_args.streaming: + logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.") + + return interleave_datasets( + datasets=all_datasets, + probabilities=data_args.interleave_probs, + seed=seed, + stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted", + ) + + else: + raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.") + + +def split_dataset( + dataset: Optional[Union["Dataset", "IterableDataset"]], + eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]], + data_args: "DataArguments", + seed: int, +) -> "DatasetDict": + r"""Split the dataset and returns a dataset dict containing train set and validation set. + + Support both map dataset and iterable dataset. 
+ """ + if eval_dataset is not None and data_args.val_size > 1e-6: + raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.") + + dataset_dict = {} + if dataset is not None: + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed) + + if data_args.val_size > 1e-6: + if data_args.streaming: + dataset_dict["validation"] = dataset.take(int(data_args.val_size)) + dataset_dict["train"] = dataset.skip(int(data_args.val_size)) + else: + val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size + dataset_dict = dataset.train_test_split(test_size=val_size, seed=seed) + dataset = dataset.train_test_split(test_size=val_size, seed=seed) + dataset_dict = {"train": dataset["train"], "validation": dataset["test"]} + else: + dataset_dict["train"] = dataset + + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()}) + else: + if data_args.streaming: + eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed) + + dataset_dict["validation"] = eval_dataset + + return DatasetDict(dataset_dict) + + +def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule": + r"""Convert dataset or dataset dict to dataset module.""" + dataset_module: DatasetModule = {} + if isinstance(dataset, DatasetDict): # dataset dict + if "train" in dataset: + dataset_module["train_dataset"] = dataset["train"] + + if "validation" in dataset: + dataset_module["eval_dataset"] = dataset["validation"] + else: + eval_dataset = {} + for key in dataset.keys(): + if key.startswith("validation_"): + eval_dataset[key[len("validation_") :]] = dataset[key] + + if len(eval_dataset): + dataset_module["eval_dataset"] = eval_dataset + + else: # single dataset + dataset_module["train_dataset"] = dataset + + return dataset_module + + +def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem": + r"""Set up a filesystem object based on the path protocol.""" + storage_options = {"anon": anon} if anon else {} + if path.startswith("s3://"): + fs = fsspec.filesystem("s3", **storage_options) + elif path.startswith(("gs://", "gcs://")): + fs = fsspec.filesystem("gcs", **storage_options) + else: + raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.") + + if not fs.exists(path): + raise ValueError(f"Path does not exist: {path}.") + + return fs + + +def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]: + r"""Helper function to read JSON/JSONL files using fsspec.""" + with fs.open(path, "r") as f: + if path.endswith(".jsonl"): + return [json.loads(line) for line in f if line.strip()] + else: + return json.load(f) + + +def read_cloud_json(cloud_path: str) -> list[Any]: + r"""Read a JSON/JSONL file from cloud storage (S3 or GCS). 
+ + Args: + cloud_path: str + Cloud path in the format: + - 's3://bucket-name/file.json' for AWS S3 + - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage + """ + try: + fs = setup_fs(cloud_path, anon=True) # try with anonymous access first + except Exception: + fs = setup_fs(cloud_path) # try again with credentials + + # filter out non-JSON files + files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path] + files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files) + if not files: + raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.") + + return sum([_read_json_with_fs(fs, file) for file in files], []) diff --git a/llamafactory/data/formatter.py b/llamafactory/data/formatter.py new file mode 100644 index 0000000000000000000000000000000000000000..9b527a7ef0093ef1f7d462639780d5330023c4e9 --- /dev/null +++ b/llamafactory/data/formatter.py @@ -0,0 +1,145 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Optional, Union + +from typing_extensions import override + +from .data_utils import SLOTS +from .tool_utils import FunctionCall, get_tool_utils + + +@dataclass +class Formatter(ABC): + slots: SLOTS = field(default_factory=list) + tool_format: Optional[str] = None + + @abstractmethod + def apply(self, **kwargs) -> SLOTS: + r"""Forms a list of slots according to the inputs to encode.""" + ... + + def extract(self, content: str) -> Union[str, list["FunctionCall"]]: + r"""Extract a list of tuples from the response message if using tools. + + Each tuple consists of function name and function arguments. 
+ """ + raise NotImplementedError + + +@dataclass +class EmptyFormatter(Formatter): + def __post_init__(self): + has_placeholder = False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot): + has_placeholder = True + + if has_placeholder: + raise ValueError("Empty formatter should not contain any placeholder.") + + @override + def apply(self, **kwargs) -> SLOTS: + return self.slots + + +@dataclass +class StringFormatter(Formatter): + def __post_init__(self): + has_placeholder = False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot): + has_placeholder = True + + if not has_placeholder: + raise ValueError("A placeholder is required in the string formatter.") + + @override + def apply(self, **kwargs) -> SLOTS: + elements = [] + for slot in self.slots: + if isinstance(slot, str): + for name, value in kwargs.items(): + if not isinstance(value, str): + raise RuntimeError(f"Expected a string, got {value}") + + slot = slot.replace("{{" + name + "}}", value, 1) + elements.append(slot) + elif isinstance(slot, (dict, set)): + elements.append(slot) + else: + raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.") + + return elements + + +@dataclass +class FunctionFormatter(StringFormatter): + def __post_init__(self): + super().__post_init__() + self.tool_utils = get_tool_utils(self.tool_format) + + @override + def apply(self, **kwargs) -> SLOTS: + content: str = kwargs.pop("content") + thought_words, thought = kwargs.pop("thought_words", None), None + if thought_words and len(thought_words) == 2: + regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL) + thought = re.search(regex, content) + + if thought: + content = content.replace(thought.group(0), "") + + functions: list[FunctionCall] = [] + try: + tool_calls = json.loads(content) + if not isinstance(tool_calls, list): # parallel function call + tool_calls = [tool_calls] + + for tool_call in tool_calls: + functions.append( + FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)) + ) + + except json.JSONDecodeError: + raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string + + function_str = self.tool_utils.function_formatter(functions) + if thought: + function_str = thought.group(0) + function_str + + return super().apply(content=function_str) + + +@dataclass +class ToolFormatter(Formatter): + def __post_init__(self): + self.tool_utils = get_tool_utils(self.tool_format) + + @override + def apply(self, **kwargs) -> SLOTS: + content = kwargs.pop("content") + try: + tools = json.loads(content) + return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""] + except json.JSONDecodeError: + raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string + + @override + def extract(self, content: str) -> Union[str, list["FunctionCall"]]: + return self.tool_utils.tool_extractor(content) diff --git a/llamafactory/data/loader.py b/llamafactory/data/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb13455beb1172930afd26154db72bbe387df16 --- /dev/null +++ b/llamafactory/data/loader.py @@ -0,0 +1,334 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import TYPE_CHECKING, Literal, Optional, Union + +import numpy as np +from datasets import Dataset, load_dataset, load_from_disk + +from ..extras import logging +from ..extras.constants import FILEEXT2TYPE +from ..extras.misc import check_version, has_tokenized_data +from .converter import align_dataset +from .data_utils import get_dataset_module, merge_dataset, read_cloud_json, split_dataset +from .parser import get_dataset_list +from .processor import ( + FeedbackDatasetProcessor, + PackedSupervisedDatasetProcessor, + PairwiseDatasetProcessor, + PretrainDatasetProcessor, + SupervisedDatasetProcessor, + UnsupervisedDatasetProcessor, +) + + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments + + from ..hparams import DataArguments, ModelArguments + from .data_utils import DatasetModule + from .parser import DatasetAttr + from .processor import DatasetProcessor + from .template import Template + + +logger = logging.get_logger(__name__) + + +def _load_single_dataset( + dataset_attr: "DatasetAttr", + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", +) -> Union["Dataset", "IterableDataset"]: + r"""Load a single dataset and aligns it to the standard format.""" + logger.info_rank0(f"Loading dataset {dataset_attr}...") + data_path, data_name, data_dir, data_files = None, None, None, None + if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]: + data_path = dataset_attr.dataset_name + data_name = dataset_attr.subset + data_dir = dataset_attr.folder + + elif dataset_attr.load_from == "script": + data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + data_name = dataset_attr.subset + data_dir = dataset_attr.folder + + elif dataset_attr.load_from == "cloud_file": + data_path = dataset_attr.dataset_name + + elif dataset_attr.load_from == "file": + data_files = [] + local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + if os.path.isdir(local_path): # is directory + for file_name in os.listdir(local_path): + data_files.append(os.path.join(local_path, file_name)) + elif os.path.isfile(local_path): # is file + data_files.append(local_path) + else: + raise ValueError(f"File {local_path} not found.") + + data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None) + if data_path is None: + raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys()))) + + if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files): + raise ValueError("File types should be identical.") + else: + raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.") + + if dataset_attr.load_from == "ms_hub": + check_version("modelscope>=1.14.0", mandatory=True) + from modelscope import MsDataset # type: ignore + from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore + + cache_dir = model_args.cache_dir or MS_DATASETS_CACHE + dataset = MsDataset.load( + dataset_name=data_path, + 
subset_name=data_name, + data_dir=data_dir, + data_files=data_files, + split=dataset_attr.split, + cache_dir=cache_dir, + token=model_args.ms_hub_token, + use_streaming=data_args.streaming, + ) + if isinstance(dataset, MsDataset): + dataset = dataset.to_hf_dataset() + + elif dataset_attr.load_from == "om_hub": + check_version("openmind>=0.8.0", mandatory=True) + from openmind import OmDataset # type: ignore + from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore + + cache_dir = model_args.cache_dir or OM_DATASETS_CACHE + dataset = OmDataset.load_dataset( + path=data_path, + name=data_name, + data_dir=data_dir, + data_files=data_files, + split=dataset_attr.split, + cache_dir=cache_dir, + token=model_args.om_hub_token, + streaming=data_args.streaming, + ) + elif dataset_attr.load_from == "cloud_file": + dataset = Dataset.from_list(read_cloud_json(data_path), split=dataset_attr.split) + else: + dataset = load_dataset( + path=data_path, + name=data_name, + data_dir=data_dir, + data_files=data_files, + split=dataset_attr.split, + cache_dir=model_args.cache_dir, + token=model_args.hf_hub_token, + num_proc=data_args.preprocessing_num_workers, + streaming=data_args.streaming and dataset_attr.load_from != "file", + ) + if data_args.streaming and dataset_attr.load_from == "file": + dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers) + + if dataset_attr.num_samples is not None and not data_args.streaming: + target_num = dataset_attr.num_samples + indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included + target_num -= len(indexes) + if target_num > 0: + expand_indexes = np.random.choice(len(dataset), target_num) + indexes = np.concatenate((indexes, expand_indexes), axis=0) + + assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched." 
+ dataset = dataset.select(indexes) + logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.") + + if data_args.max_samples is not None: # truncate dataset + max_samples = min(data_args.max_samples, len(dataset)) + dataset = dataset.select(range(max_samples)) + + return align_dataset(dataset, dataset_attr, data_args, training_args) + + +def _get_merged_dataset( + dataset_names: Optional[list[str]], + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], + return_dict: bool = False, +) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]: + r"""Return the merged datasets in the standard format.""" + if dataset_names is None: + return None + + datasets = {} + for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)): + if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): + raise ValueError("The dataset is not applicable in the current training stage.") + + datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args) + + if return_dict: + return datasets + else: + return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed) + + +def _get_dataset_processor( + data_args: "DataArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + do_generate: bool = False, +) -> "DatasetProcessor": + r"""Return the corresponding dataset processor.""" + if stage == "pt": + dataset_processor_class = PretrainDatasetProcessor + elif stage == "sft" and not do_generate: + if data_args.packing: + if data_args.neat_packing: # hack datasets to have int32 attention mask + from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence + + def __init__(self, data, **kwargs): + return TypedSequence.__init__( + self, + data, + type=kwargs.pop("type", None), + try_type=kwargs.pop("try_type", None), + optimized_int_type=kwargs.pop("optimized_int_type", None), + ) + + OptimizedTypedSequence.__init__ = __init__ + dataset_processor_class = PackedSupervisedDatasetProcessor + else: + dataset_processor_class = SupervisedDatasetProcessor + + elif stage == "rm": + dataset_processor_class = PairwiseDatasetProcessor + elif stage == "kto": + dataset_processor_class = FeedbackDatasetProcessor + else: + dataset_processor_class = UnsupervisedDatasetProcessor + + return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args) + + +def _get_preprocessed_dataset( + dataset: Optional[Union["Dataset", "IterableDataset"]], + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"] = None, + is_eval: bool = False, +) -> Optional[Union["Dataset", "IterableDataset"]]: + r"""Preprocesses the dataset, including format checking and tokenization.""" + if dataset is None: + return None + + dataset_processor = _get_dataset_processor( + data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval) + ) + column_names = list(next(iter(dataset)).keys()) + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + 
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), + desc="Running tokenizer on dataset", + ) + + dataset = dataset.map( + dataset_processor.preprocess_dataset, + batched=True, + batch_size=data_args.preprocessing_batch_size, + remove_columns=column_names, + **kwargs, + ) + + if training_args.should_log: + try: + print("eval example:" if is_eval else "training example:") + dataset_processor.print_data_example(next(iter(dataset))) + except StopIteration: + if stage == "pt": + raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.") + else: + raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") + + return dataset + + +def get_dataset( + template: "Template", + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"] = None, +) -> "DatasetModule": + r"""Get the train dataset and optionally gets the evaluation dataset.""" + # Load tokenized dataset if path exists + if data_args.tokenized_path is not None: + if has_tokenized_data(data_args.tokenized_path): + logger.warning_rank0("Loading dataset from disk will ignore other data arguments.") + tokenized_data = load_from_disk(data_args.tokenized_path) + dataset_module = get_dataset_module(tokenized_data) + if data_args.streaming: + dataset_module["train_dataset"] = dataset_module["train_dataset"].to_iterable_dataset() + + logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.") + return dataset_module + + if data_args.streaming: + raise ValueError("Turn off `streaming` when saving dataset to disk.") + + # Load and preprocess dataset + with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)): + dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage) + eval_dataset = _get_merged_dataset( + data_args.eval_dataset, + model_args, + data_args, + training_args, + stage, + return_dict=data_args.eval_on_each_dataset, + ) + + with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)): + dataset = _get_preprocessed_dataset( + dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False + ) + if isinstance(eval_dataset, dict): + for eval_name, eval_data in eval_dataset.items(): + eval_dataset[eval_name] = _get_preprocessed_dataset( + eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True + ) + else: + eval_dataset = _get_preprocessed_dataset( + eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True + ) + + dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed) + if data_args.tokenized_path is not None: # save tokenized dataset to disk + if training_args.should_save: + dataset_dict.save_to_disk(data_args.tokenized_path) + logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.") + logger.info_rank0(f"Please launch the training with `tokenized_path: {data_args.tokenized_path}`.") + + return get_dataset_module(dataset_dict) diff --git a/llamafactory/data/mm_plugin.py b/llamafactory/data/mm_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..91b801dc1105c31fe9fe8a76807f4126198556a2 --- /dev/null +++ b/llamafactory/data/mm_plugin.py @@ 
-0,0 +1,2082 @@ +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +import os +import re +from copy import deepcopy +from dataclasses import dataclass +from io import BytesIO +from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union + +import numpy as np +import torch +from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array +from transformers.models.mllama.processing_mllama import ( + convert_sparse_cross_attention_mask_to_dense, + get_cross_attention_token_mask, +) +from typing_extensions import NotRequired, override + +from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER +from ..extras.packages import ( + is_librosa_available, + is_pillow_available, + is_pyav_available, + is_transformers_version_greater_than, +) + + +if is_librosa_available(): + import librosa + + +if is_pillow_available(): + from PIL import Image + from PIL.Image import Image as ImageObject + + +if is_pyav_available(): + import av + + +if is_transformers_version_greater_than("4.52.0"): + from transformers.image_utils import make_flat_list_of_images + from transformers.video_utils import make_batched_videos +else: + from transformers.image_utils import make_batched_videos, make_flat_list_of_images + + +if TYPE_CHECKING: + from av.stream import Stream + from numpy.typing import NDArray + from transformers import PreTrainedTokenizer, ProcessorMixin + from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor + from transformers.image_processing_utils import BaseImageProcessor + from transformers.video_processing_utils import BaseVideoProcessor + + class EncodedImage(TypedDict): + path: Optional[str] + bytes: Optional[bytes] + + ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject] + VideoInput = Union[str, BinaryIO, list[list[ImageInput]]] + AudioInput = Union[str, BinaryIO, NDArray] + + class RegularizedImageOutput(TypedDict): + images: list[ImageObject] + + class RegularizedVideoOutput(TypedDict): + videos: list[list[ImageObject]] + durations: list[float] + fps_per_video: NotRequired[list[float]] + + class RegularizedAudioOutput(TypedDict): + audios: list[NDArray] + sampling_rates: list[float] + + class MMProcessor(ProcessorMixin): + patch_size: int + image_seq_length: int + num_additional_image_tokens: int + vision_feature_select_strategy: Literal["default", "full"] + + def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int: + pass + + +def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]: + r"""Get paligemma token type ids for computing loss. 
+ + It is slightly different with the original token type ids where the prompt part is 0. + + Returns: + batch_token_type_ids: shape (batch_size, seq_length) + + """ + batch_token_type_ids = [] + for imglen, seqlen in zip(imglens, seqlens): + image_seqlen = imglen * processor.image_seq_length + batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen)) + + return batch_token_type_ids + + +def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"): + r"""Get gemma3 token type ids for computing loss. + + Returns: + batch_token_type_ids: shape (batch_size, seq_length) + + """ + image_token_id: int = getattr(processor, "image_token_id") + batch_token_type_ids = [] + for token_ids in batch_ids: + token_ids = np.array(token_ids) + token_type_ids = np.zeros_like(token_ids) + token_type_ids[token_ids == image_token_id] = 1 + batch_token_type_ids.append(token_type_ids.tolist()) + + return batch_token_type_ids + + +def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]: + r"""Make nested list of images.""" + batch_images = [] + for imglen in imglens: + batch_images.append(images[:imglen]) + images = images[imglen:] + + return batch_images + + +def _check_video_is_nested_images(video: "VideoInput") -> bool: + r"""Check if the video is nested images.""" + return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict, ImageObject)) for frame in video) + + +@dataclass +class MMPluginMixin: + image_token: Optional[str] + video_token: Optional[str] + audio_token: Optional[str] + expand_mm_tokens: bool = True + + def _validate_input( + self, + processor: Optional["MMProcessor"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + ) -> None: + r"""Validate if this model accepts the input modalities.""" + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + video_processor: BaseImageProcessor = getattr( + processor, "video_processor", getattr(processor, "image_processor", None) + ) + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) + if len(images) != 0 and self.image_token is None: + raise ValueError( + "This model does not support image input. Please check whether the correct `template` is used." + ) + + if len(videos) != 0 and self.video_token is None: + raise ValueError( + "This model does not support video input. Please check whether the correct `template` is used." + ) + + if len(audios) != 0 and self.audio_token is None: + raise ValueError( + "This model does not support audio input. Please check whether the correct `template` is used." 
+ ) + + if self.image_token is not None and processor is None: + raise ValueError("Processor was not found, please check and update your model file.") + + if self.image_token is not None and image_processor is None: + raise ValueError("Image processor was not found, please check and update your model file.") + + if self.video_token is not None and video_processor is None: + raise ValueError("Video processor was not found, please check and update your model file.") + + if self.audio_token is not None and feature_extractor is None: + raise ValueError("Audio feature extractor was not found, please check and update your model file.") + + def _validate_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + ): + r"""Validate if the number of images, videos and audios match the number of placeholders in messages.""" + num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 + for message in messages: + num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER) + num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER) + num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER) + + if len(images) != num_image_tokens: + raise ValueError( + f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}." + ) + + if len(videos) != num_video_tokens: + raise ValueError( + f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}." + ) + + if len(audios) != num_audio_tokens: + raise ValueError( + f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}." + ) + + def _preprocess_image( + self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs + ) -> "ImageObject": + r"""Pre-process a single image.""" + if (image.width * image.height) > image_max_pixels: + resize_factor = math.sqrt(image_max_pixels / (image.width * image.height)) + width, height = int(image.width * resize_factor), int(image.height * resize_factor) + image = image.resize((width, height)) + + if (image.width * image.height) < image_min_pixels: + resize_factor = math.sqrt(image_min_pixels / (image.width * image.height)) + width, height = int(image.width * resize_factor), int(image.height * resize_factor) + image = image.resize((width, height)) + + if image.mode != "RGB": + image = image.convert("RGB") + + return image + + def _get_video_sample_indices( + self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs + ) -> list[int]: + r"""Compute video sample indices according to fps.""" + total_frames = video_stream.frames + if total_frames == 0: # infinite video + return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32) + + sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)) + sample_frames = min(total_frames, video_maxlen, sample_frames) + return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) + + def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput": + r"""Regularize images to avoid error. 
Including reading and pre-processing.""" + results = [] + for image in images: + if isinstance(image, (str, BinaryIO)): + image = Image.open(image) + elif isinstance(image, bytes): + image = Image.open(BytesIO(image)) + elif isinstance(image, dict): + if image["bytes"] is not None: + image = Image.open(BytesIO(image["bytes"])) + else: + image = Image.open(image["path"]) + + if not isinstance(image, ImageObject): + raise ValueError(f"Expect input is a list of images, but got {type(image)}.") + + results.append(self._preprocess_image(image, **kwargs)) + + return {"images": results} + + def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput": + r"""Regularizes videos to avoid error. Including reading, resizing and converting.""" + results = [] + durations = [] + for video in videos: + frames: list[ImageObject] = [] + if _check_video_is_nested_images(video): + for frame in video: + if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame): + raise ValueError("Invalid image found in video frames.") + frames = video + durations.append(len(frames) / kwargs.get("video_fps", 2.0)) + else: + container = av.open(video, "r") + video_stream = next(stream for stream in container.streams if stream.type == "video") + sample_indices = self._get_video_sample_indices(video_stream, **kwargs) + container.seek(0) + for frame_idx, frame in enumerate(container.decode(video_stream)): + if frame_idx in sample_indices: + frames.append(frame.to_image()) + + if video_stream.duration is None: + durations.append(len(frames) / kwargs.get("video_fps", 2.0)) + else: + durations.append(float(video_stream.duration * video_stream.time_base)) + + frames = self._regularize_images(frames, **kwargs)["images"] + results.append(frames) + + return {"videos": results, "durations": durations} + + def _regularize_audios( + self, audios: list["AudioInput"], sampling_rate: float, **kwargs + ) -> "RegularizedAudioOutput": + r"""Regularizes audios to avoid error. Including reading and resampling.""" + results, sampling_rates = [], [] + for audio in audios: + if not isinstance(audio, np.ndarray): + audio, sampling_rate = librosa.load(audio, sr=sampling_rate) + + results.append(audio) + sampling_rates.append(sampling_rate) + + return {"audios": results, "sampling_rates": sampling_rates} + + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + imglens: Optional[list[int]] = None, + ) -> dict[str, "torch.Tensor"]: + r"""Process visual inputs. + + Returns: (llava and paligemma) + pixel_values: tensor with shape (B, C, H, W) + + Returns: (qwen2-vl) + pixel_values: tensor with shape (num_patches, patch_dim) + image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height + where num_patches == torch.prod(image_grid_thw) + + Returns: (mllama) + pixel_values: tensor with shape + (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width) + For example, (2, 1, 4, 3, 560, 560). + aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1). + aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4). + num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1). 
+ + """ + mm_inputs = {} + if len(images) != 0: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + if imglens is not None: # if imglens are provided, make batched images + images = _make_batched_images(images, imglens) + + image_processor_kwargs = {} + if getattr(processor, "image_do_pan_and_scan", False): # gemma3 image processor + image_processor_kwargs.update( + { + "do_pan_and_scan": True, + "pan_and_scan_min_crop_size": 256, + "pan_and_scan_max_num_crops": 4, + "pan_and_scan_min_ratio_to_activate": 1.2, + } + ) + + mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs)) + + if len(videos) != 0: + video_processor: BaseImageProcessor = getattr( + processor, "video_processor", getattr(processor, "image_processor", None) + ) + videos = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + )["videos"] + if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava + mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt")) + else: # for llava_next_video + mm_inputs.update(video_processor(videos, return_tensors="pt")) + + if len(audios) != 0: + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) + audios = self._regularize_audios( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + )["audios"] + mm_inputs.update( + feature_extractor( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + return_attention_mask=True, + padding="max_length", + return_tensors="pt", + ) + ) + mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None) # prevent conflicts + + return mm_inputs + + +@dataclass +class BasePlugin(MMPluginMixin): + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + r"""Pre-process input messages before tokenization for VLMs.""" + self._validate_input(processor, images, videos, audios) + return messages + + def process_token_ids( + self, + input_ids: list[int], + labels: Optional[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + tokenizer: "PreTrainedTokenizer", + processor: Optional["MMProcessor"], + ) -> tuple[list[int], Optional[list[int]]]: + r"""Pre-process token ids after tokenization for VLMs.""" + self._validate_input(processor, images, videos, audios) + return input_ids, labels + + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + r"""Build batched multimodal inputs for VLMs. 
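+
+        The base implementation simply validates the inputs and delegates feature extraction to `_get_mm_inputs`.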
+ + Arguments: + images: a list of image inputs, shape (num_images,) + videos: a list of video inputs, shape (num_videos,) + audios: a list of audio inputs, shape (num_audios,) + imglens: number of images in each sample, shape (batch_size,) + vidlens: number of videos in each sample, shape (batch_size,) + audlens: number of audios in each sample, shape (batch_size,) + batch_ids: token ids of input samples, shape (batch_size, seq_len) + processor: a processor for pre-processing images and videos + + """ + self._validate_input(processor, images, videos, audios) + return self._get_mm_inputs(images, videos, audios, processor) + + +@dataclass +class Gemma3Plugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens = 0 + messages = deepcopy(messages) + boi_token: str = getattr(processor, "boi_token") + full_image_sequence: str = getattr(processor, "full_image_sequence") + image_str = full_image_sequence if self.expand_mm_tokens else boi_token + + do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False) + if do_pan_and_scan: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + if do_pan_and_scan: + image_placeholder_str = ( + "Here is the original image {{image}} and here are some crops to help you see better " + + " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens]) + ) + else: + image_placeholder_str = "{{image}}" + + content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1) + num_image_tokens += 1 + + message["content"] = content.replace("{{image}}", image_str) + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + mm_inputs.pop("num_crops", None) + mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor) + return mm_inputs + + +class Gemma3nPlugin(Gemma3Plugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + messages = deepcopy(messages) + boi_token: str = getattr(processor, "boi_token") + boa_token: str = getattr(processor, "boa_token") + full_image_sequence: str = getattr(processor, "full_image_sequence") + full_audio_sequence: str = getattr(processor, "full_audio_sequence") + image_str = full_image_sequence if self.expand_mm_tokens else boi_token + audio_str = full_audio_sequence if self.expand_mm_tokens else boa_token + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, image_str, 1) 
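+                # Each placeholder is swapped for the processor's full token sequence,
+                # or for the bare boi/boa marker when expand_mm_tokens is disabled.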
+ + while AUDIO_PLACEHOLDER in content: + content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1) + + message["content"] = content + + return messages + + +@dataclass +class InternVLPlugin(BasePlugin): + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "ProcessorMixin", + **kwargs, + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + image_processor_kwargs = {} + if getattr(processor, "crop_to_patches", False): + image_processor_kwargs.update( + { + "crop_to_patches": True, + "max_patches": 12, + "min_patches": 1, + } + ) + + mm_inputs = {} + image_video_patches = [] + + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + + if len(videos) != 0: + videos = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + )["videos"] + + if len(images) != 0: + images = make_flat_list_of_images(images) + image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs) + image_num_patches = image_inputs.pop("num_patches") + image_pixel_values = image_inputs.pop("pixel_values") + image_num_patches_indices = np.cumsum(image_num_patches) + + if len(videos) != 0: + videos = make_batched_videos(videos) + num_frames_per_video = [len(video) for video in videos] + patch_indices = np.cumsum(num_frames_per_video) + image_processor_kwargs["crop_to_patches"] = False + video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs) + video_num_patches = video_inputs.pop("num_patches") + video_pixel_values = video_inputs.pop("pixel_values") + video_num_patches_indices = np.cumsum(video_num_patches) + + # NOT SUPPORT IMAGE VIDEO INTERLEAVED + if len(images) != 0 and image_pixel_values is not None: + for i in range(len(images)): + start_index = image_num_patches_indices[i - 1] if i > 0 else 0 + end_index = image_num_patches_indices[i] + image_video_patches.append(image_pixel_values[start_index:end_index]) + + if len(videos) != 0 and video_pixel_values is not None: + patch_indices_with_prefix = [0] + list(patch_indices) + for i in range(len(videos)): + current_patch_index = patch_indices_with_prefix[i] + end_patch_index = patch_indices_with_prefix[i + 1] + start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0 + end_index = video_num_patches_indices[end_patch_index - 1] + image_video_patches.append(video_pixel_values[start_index:end_index]) + + if len(images) != 0 or len(videos) != 0: + mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0) + + if len(images) != 0: + mm_inputs.update({"image_num_patches": image_num_patches}) + + if len(videos) != 0: + mm_inputs.update({"video_patch_indices": patch_indices}) + mm_inputs.update({"video_num_patches": video_num_patches}) + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["ProcessorMixin"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + 
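+        # _validate_input checks that the template supports these modalities;
+        # _validate_messages checks that the placeholder counts match the provided media.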
self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens = 0, 0 + image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1 + messages = deepcopy(messages) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + + image_pixel_patch_list = mm_inputs.get("image_num_patches") # pathes of images + video_num_patches = mm_inputs.get("video_num_patches") # all patches for frames of videos + video_patch_indices = mm_inputs.get("video_patch_indices") # num frames of per video + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace( + IMAGE_PLACEHOLDER, + f"{'' * image_seqlen * image_pixel_patch_list[num_image_tokens]}", + 1, + ) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0 + end_patch_index = video_patch_indices[num_video_tokens] + num_patches = list(video_num_patches[current_patch_index:end_patch_index]) + video_replaced_prompt = "\n".join( + f"Frame{i + 1}: {'' * image_seqlen * num_patches[i]}" + for i in range(len(num_patches)) + ) + content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1) + num_video_tokens += 1 + + message["content"] = content + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["ProcessorMixin"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + mm_inputs.pop("image_num_patches", None) + mm_inputs.pop("video_patch_indices", None) + mm_inputs.pop("video_num_patches", None) + return mm_inputs + + +class KimiVLPlugin(BasePlugin): + @override + def process_messages(self, messages, images, videos, audios, processor): + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_hws = mm_inputs.get("image_grid_hws", []) + else: + image_grid_hws = [None] * len(images) + + num_image_tokens = 0 + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + merge_length = math.prod(image_processor.merge_kernel_size) + messages = deepcopy(messages) + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + content = content.replace( + IMAGE_PLACEHOLDER, + f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>", + 1, + ) + num_image_tokens += 1 + + message["content"] = content + + return messages + + +@dataclass +class Llama4Plugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: 
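+                # Work out how many patch tokens each image chunk expands to, based on the
+                # processed image size, the patch size and the downsample ratio.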
+ image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:] + num_patches_per_chunk = int( + (image_height // processor.patch_size) + * (image_width // processor.patch_size) + // processor.downsample_ratio + ) + aspect_ratios = mm_inputs.pop("aspect_ratios") + + num_image_tokens = 0 + messages = deepcopy(messages) + for message in messages: + content = message["content"] + if self.expand_mm_tokens: + placeholder_count = content.count(IMAGE_PLACEHOLDER) + prompt_splits = content.split(IMAGE_PLACEHOLDER) + new_content = [] + for local_image_index, split_part in enumerate(prompt_splits): + new_content.append(split_part) + if local_image_index < placeholder_count: + tokens_for_this_image = processor._prompt_split_image( + aspect_ratios[num_image_tokens], num_patches_per_chunk + ) + num_image_tokens += 1 + new_content.append(tokens_for_this_image) + + content = "".join(new_content) + else: + content = content.replace(IMAGE_PLACEHOLDER, self.image_token) + + message["content"] = content + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + mm_inputs.pop("aspect_ratios", None) + return mm_inputs + + +@dataclass +class LlavaPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0])) + image_seqlen = (height // processor.patch_size) * ( + width // processor.patch_size + ) + processor.num_additional_image_tokens + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + else: + image_seqlen = 1 + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + + message["content"] = content.replace("{{image}}", self.image_token) + + return messages + + +@dataclass +class LlavaNextPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens = 0 + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + image_sizes = iter(mm_inputs["image_sizes"].tolist()) + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + if self.expand_mm_tokens: + orig_height, 
orig_width = next(image_sizes) + image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + else: + image_seqlen = 1 + + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + num_image_tokens += 1 + + message["content"] = content.replace("{{image}}", self.image_token) + + return messages + + +@dataclass +class LlavaNextVideoPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + image_sizes = iter(mm_inputs["image_sizes"].tolist()) + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + if self.expand_mm_tokens: + orig_height, orig_width = next(image_sizes) + image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + else: + image_seqlen = 1 + + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + + message["content"] = content.replace("{{image}}", self.image_token) + + if self.expand_mm_tokens: + if "pixel_values_videos" in mm_inputs: + one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) + height, width = get_image_size(one_video[0]) + num_frames = one_video.shape[0] # frame dim is always after batch dim + image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer + else: + video_seqlen = 1 + + for message in messages: + content = message["content"] + while VIDEO_PLACEHOLDER in content: + content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) + + message["content"] = content.replace("{{video}}", self.video_token) + + return messages + + +@dataclass +class MiniCPMVPlugin(BasePlugin): + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + **kwargs, + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + if "valid_image_nums_ls" in kwargs: + valid_image_nums_ls = kwargs["valid_image_nums_ls"] + new_images = [] + idx = 0 + for valid_image_nums in valid_image_nums_ls: + new_images.append(images[idx : idx + valid_image_nums]) + idx += valid_image_nums + + images = new_images + + image_inputs = image_processor( + images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt" + ) + mm_inputs.update(image_inputs) + + if len(videos) != 0: + videos = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + 
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + )["videos"] + video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt") + mm_inputs.update(video_inputs) + + if len(audios) != 0: + audios = self._regularize_audios( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + )["audios"] + if "valid_audio_nums_ls" in kwargs: + valid_audio_nums_ls = kwargs["valid_audio_nums_ls"] + audios_ls = [] + idx = 0 + for valid_audio_nums in valid_audio_nums_ls: + audios_ls.append(audios[idx : idx + valid_audio_nums]) + idx += valid_audio_nums + else: + audios_ls = [audios] + + audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract( + audios_ls, + chunk_input=True, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + ) + audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens] + mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}) + if kwargs.get("ret_phs", False): + mm_inputs.update({"audio_phs": audio_phs}) + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 + messages = deepcopy(messages) + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + mm_inputs, audio_inputs = {}, {} + if len(images) != 0 and len(videos) != 0: + raise ValueError("MiniCPM-V model does not support input images and videos at the same time.") + + if len(videos) != 0: + max_slice_nums = 2 + use_image_id = False + mm_inputs = self._get_mm_inputs([], videos, [], processor) + else: + max_slice_nums = image_processor.max_slice_nums + use_image_id = image_processor.use_image_id + + for i, message in enumerate(messages): + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1 + content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1) + num_video_tokens += 1 + + while AUDIO_PLACEHOLDER in content: + content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1) + num_audio_tokens += 1 + + message["content"] = content.replace("{{image}}", "(./)").replace( + "{{audio}}", "()" + ) + + if len(images): + mm_inputs = self._get_mm_inputs(images, [], [], processor) + + if len(audios): + audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True) + + if self.expand_mm_tokens and mm_inputs: + pattern = "(./)" + image_sizes = mm_inputs["image_sizes"] + idx = 0 + for index, message in enumerate(messages): + text = message["content"] + image_tags = re.findall(pattern, text) + text_chunks = text.split(pattern) + final_text = "" + for i in range(len(image_tags)): + final_text = ( + final_text + + text_chunks[i] + + image_processor.get_slice_image_placeholder( + image_sizes[0][idx], idx, max_slice_nums, use_image_id + ) + ) + idx += 1 + + final_text += text_chunks[-1] + 
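+                # Write the rewritten text back: every image tag has been replaced by the
+                # processor's slice-aware image placeholder string.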
messages[index]["content"] = final_text + + if self.expand_mm_tokens and audio_inputs: + pattern = "()" + idx = 0 + for index, message in enumerate(messages): + text = message["content"] + audio_tags = re.findall(pattern, text) + text_chunks = text.split(pattern) + final_text = "" + for i in range(len(audio_tags)): + audio_placeholder = audio_inputs["audio_phs"][0][idx] + final_text = final_text + text_chunks[i] + audio_placeholder + idx += 1 + + final_text += text_chunks[-1] + messages[index]["content"] = final_text + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + # image bound + image_bounds_list = [] + valid_image_nums_ls = [] + for i, input_ids in enumerate(batch_ids): + input_ids_ = torch.tensor(input_ids) + start_cond = (input_ids_ == processor.tokenizer.im_start_id) | ( + input_ids_ == processor.tokenizer.slice_start_id + ) + end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id) + image_start_tokens = torch.where(start_cond)[0] + image_start_tokens += 1 + image_end_tokens = torch.where(end_cond)[0] + valid_image_nums_ls.append(imglens[i]) + image_bounds = torch.hstack( + [ + image_start_tokens.unsqueeze(-1), + image_end_tokens.unsqueeze(-1), + ] + ) + image_bounds_list.append(image_bounds) + + mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls) + if "tgt_sizes" not in mm_inputs: + dummy_data = [torch.empty(0) for _ in range(len(batch_ids))] + mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data}) + + mm_inputs.update({"image_bound": image_bounds_list}) + + if len(audios) > 0: + # audio bound + audio_bounds_ls = [] + spk_bounds_ls = [] + valid_audio_nums_ls = [] + + for input_ids, audiolen in zip(batch_ids, audlens): + input_ids_ = torch.tensor(input_ids) + audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0] + audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0] + assert len(audio_start_idx) == len(audio_end_idx) + audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)]) + audio_bounds_ls.append(audio_bounds) + valid_audio_nums_ls.append(audiolen) + + spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0] + spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0] + assert len(spk_start_idx) == len(spk_end_idx) + spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)]) + spk_bounds_ls.append(spk_bounds) + + audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls) + mm_inputs.update(audio_inputs) + mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls}) + + return mm_inputs + + +@dataclass +class MllamaPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, 
videos, audios) + num_image_tokens = 0 + messages = deepcopy(messages) + for message in messages: + content = message["content"] + num_image_tokens += content.count(IMAGE_PLACEHOLDER) + message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token) + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens) + if mm_inputs: + num_tiles = mm_inputs.pop("num_tiles") + image_token_id: int = getattr(processor, "image_token_id") + max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles") + cross_attention_token_mask = [ + get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids + ] + mm_inputs["cross_attention_mask"] = torch.from_numpy( + convert_sparse_cross_attention_mask_to_dense( + cross_attention_token_mask, + num_tiles=num_tiles, + max_num_tiles=max_image_tiles, + length=max(len(input_ids) for input_ids in batch_ids), + ) + ) # shape: (batch_size, length, max_num_images, max_num_tiles) + + return mm_inputs + + +@dataclass +class PaliGemmaPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens = 0 + messages = deepcopy(messages) + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, "", 1) + num_image_tokens += 1 + + message["content"] = content + + return messages + + @override + def process_token_ids( + self, + input_ids: list[int], + labels: Optional[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + tokenizer: "PreTrainedTokenizer", + processor: Optional["MMProcessor"], + ) -> tuple[list[int], Optional[list[int]]]: + self._validate_input(processor, images, videos, audios) + num_images = len(images) + image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token + image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + input_ids = [image_token_id] * num_images * image_seqlen + input_ids + if labels is not None: + labels = [IGNORE_INDEX] * num_images * image_seqlen + labels + + return input_ids, labels + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + seqlens = [len(input_ids) for input_ids in batch_ids] + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor) + return mm_inputs + + +@dataclass +class PixtralPlugin(BasePlugin): + @override + def process_messages( + self, + messages: 
list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + # BC for transformers < 4.49.0 + if isinstance(mm_inputs["image_sizes"], list): + image_sizes = iter(mm_inputs["image_sizes"][0]) + else: + image_sizes = iter(mm_inputs["image_sizes"].tolist()) + + image_break_token: str = getattr(processor, "image_break_token") + image_end_token: str = getattr(processor, "image_end_token") + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + if self.expand_mm_tokens: + patch_size = processor.patch_size * getattr(processor, "spatial_merge_size", 1) + height, width = next(image_sizes) + num_height_tokens = height // patch_size + num_width_tokens = width // patch_size + replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens + replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list + replace_tokens[-1] = image_end_token + replace_str = "".join(replace_tokens) + else: + replace_str = self.image_token + + content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1) + + message["content"] = content + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + # ref to this commit https://github.com/huggingface/transformers/pull/35122 + # after transformers 4.49.0, the `image_sizes` is mandatory as an input parameter for Pixtral VisionEncoder forwarding. + # it can be passed into `LlavaConditionalGeneration` as a parameter. 
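+        # on older transformers versions the key is dropped here to stay backward compatible.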
+ if not is_transformers_version_greater_than("4.49.0"): + mm_inputs.pop("image_sizes", None) + return mm_inputs + + +@dataclass +class Qwen2AudioPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + bos_token: str = getattr(processor, "audio_bos_token") + eos_token: str = getattr(processor, "audio_eos_token") + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs([], [], audios, processor) + if "feature_attention_mask" in mm_inputs: + audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist() + + for message in messages: + content = message["content"] + while AUDIO_PLACEHOLDER in content: + if self.expand_mm_tokens: + audio_length = audio_lengths.pop(0) + input_length = (audio_length - 1) // 2 + 1 + audio_seqlen = (input_length - 2) // 2 + 1 + else: + audio_seqlen = 1 + + content = content.replace( + AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1 + ) + + message["content"] = content + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + return self._get_mm_inputs(images, videos, audios, processor) + + +@dataclass +class Qwen2VLPlugin(BasePlugin): + vision_bos_token: str = "<|vision_start|>" + vision_eos_token: str = "<|vision_end|>" + + @override + def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": + image = super()._preprocess_image(image, **kwargs) + if min(image.width, image.height) < 28: + width, height = max(image.width, 28), max(image.height, 28) + image = image.resize((width, height)) + + if image.width / image.height > 200: + width, height = image.height * 180, image.height + image = image.resize((width, height)) + + if image.height / image.width > 200: + width, height = image.width, image.width * 180 + image = image.resize((width, height)) + + return image + + @override + def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput": + results, fps_per_video, durations = [], [], [] + for video in videos: + frames: list[ImageObject] = [] + if _check_video_is_nested_images(video): + for frame in video: + if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame): + raise ValueError("Invalid image found in video frames.") + + frames = video + fps_per_video.append(kwargs.get("video_fps", 2.0)) + durations.append(len(frames) / kwargs.get("video_fps", 2.0)) + else: + container = av.open(video, "r") + video_stream = next(stream for stream in container.streams if stream.type == "video") + sample_indices = self._get_video_sample_indices(video_stream, **kwargs) + container.seek(0) + for frame_idx, frame in enumerate(container.decode(video_stream)): + if frame_idx in sample_indices: + frames.append(frame.to_image()) + + if video_stream.duration is None: + fps_per_video.append(kwargs.get("video_fps", 2.0)) + durations.append(len(frames) / kwargs.get("video_fps", 2.0)) 
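+                    # the stream reports no duration, so fall back to the requested sampling fps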
+ else: + fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base)) + durations.append(float(video_stream.duration * video_stream.time_base)) + + if len(frames) % 2 != 0: + frames.append(frames[-1]) + + frames = self._regularize_images(frames, **kwargs)["images"] + results.append(frames) + + return {"videos": results, "fps_per_video": fps_per_video, "durations": durations} + + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None) + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + mm_inputs.update(image_processor(images, return_tensors="pt")) + + if len(videos) != 0: + video_data = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + ) + mm_inputs.update(video_processor(videos=video_data["videos"], return_tensors="pt")) + temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) + if "second_per_grid_ts" in processor.model_input_names: + mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]] + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens = 0, 0 + messages = deepcopy(messages) + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + + merge_length: int = getattr(image_processor, "merge_size") ** 2 + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_thw = mm_inputs.get("image_grid_thw", []) + video_grid_thw = mm_inputs.get("video_grid_thw", []) + else: + image_grid_thw = [None] * len(images) + video_grid_thw = [None] * len(videos) + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + content = content.replace( + IMAGE_PLACEHOLDER, + f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}", + 1, + ) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + content = content.replace( + VIDEO_PLACEHOLDER, + f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}", + 1, + ) + num_video_tokens += 1 + + message["content"] = content + + return messages + + +@dataclass +class Qwen3VLPlugin(Qwen2VLPlugin): + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + 
processor: "MMProcessor", + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + video_processor: BaseImageProcessor = getattr(processor, "video_processor", None) + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + mm_inputs.update(image_processor(images, return_tensors="pt")) + + if len(videos) != 0: + videos = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + ) + video_metadata = [ + {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)} + for video, duration in zip(videos["videos"], videos["durations"]) + ] + mm_inputs.update( + video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True) + ) + temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) + if "second_per_grid_ts" in processor.model_input_names: + mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in videos["fps_per_video"]] + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens = 0, 0 + messages = deepcopy(messages) + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + video_processor: BaseImageProcessor = getattr(processor, "video_processor") + + image_merge_length: int = getattr(image_processor, "merge_size") ** 2 + video_merge_length: int = getattr(video_processor, "merge_size") ** 2 + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_thw = mm_inputs.get("image_grid_thw", []) + video_grid_thw = mm_inputs.get("video_grid_thw", []) + num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now + video_metadata = mm_inputs.get("video_metadata", {}) + + else: + image_grid_thw = [None] * len(images) + video_grid_thw = [None] * len(videos) + num_frames = 0 + timestamps = [0] + + for idx, message in enumerate(messages): + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_seqlen = ( + image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1 + ) + content = content.replace( + IMAGE_PLACEHOLDER, + f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}", + 1, + ) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + if self.expand_mm_tokens: + metadata = video_metadata[idx] + timestamps = processor._calculate_timestamps( + metadata.frames_indices, + metadata.fps, + video_processor.merge_size, + ) + video_structure = "" + for frame_index in range(num_frames): + video_seqlen = ( + video_grid_thw[num_video_tokens][1:].prod() // video_merge_length + if self.expand_mm_tokens + else 1 + ) + timestamp_sec = timestamps[frame_index] + frame_structure = ( + f"<{timestamp_sec:.1f} seconds>" + 
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}" + ) + video_structure += frame_structure + else: + video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}" + + content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1) + num_video_tokens += 1 + + message["content"] = content + + return messages + + +@dataclass +class GLM4VPlugin(Qwen2VLPlugin): + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + video_processor: BaseImageProcessor = getattr(processor, "video_processor", None) + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + mm_inputs.update(image_processor(images, return_tensors="pt")) + + if len(videos) != 0: + video_data = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + ) + # prepare video metadata + video_metadata = [ + {"fps": 2, "duration": duration, "total_frames": len(video)} + for video, duration in zip(video_data["videos"], video_data["durations"]) + ] + mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata)) + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens = 0, 0 + messages = deepcopy(messages) + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + + merge_length: int = getattr(image_processor, "merge_size") ** 2 + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_thw = mm_inputs.get("image_grid_thw", []) + video_grid_thw = mm_inputs.get("video_grid_thw", []) + num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now + timestamps = mm_inputs.get("timestamps", []) + + if hasattr(timestamps, "tolist"): + timestamps = timestamps.tolist() + + if not timestamps: + timestamps_list = [] + elif isinstance(timestamps[0], list): + timestamps_list = timestamps[0] + else: + timestamps_list = timestamps + + unique_timestamps = timestamps_list.copy() + selected_timestamps = unique_timestamps[:num_frames] + while len(selected_timestamps) < num_frames: + selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) + + else: + image_grid_thw = [None] * len(images) + video_grid_thw = [None] * len(videos) + num_frames = 0 + selected_timestamps = [0] + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + content = content.replace( + IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * 
image_seqlen}<|end_of_image|>", 1 + ) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + video_structure = "" + for frame_index in range(num_frames): + video_seqlen = ( + video_grid_thw[num_video_tokens][1:].prod() // merge_length if self.expand_mm_tokens else 1 + ) + timestamp_sec = selected_timestamps[frame_index] + frame_structure = ( + f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}" + ) + video_structure += frame_structure + + if not self.expand_mm_tokens: + video_structure = self.video_token + + content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1) + num_video_tokens += 1 + + message["content"] = content + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["ProcessorMixin"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + mm_inputs.pop("timestamps", None) + return mm_inputs + + +@dataclass +class Qwen2OmniPlugin(Qwen2VLPlugin): + audio_bos_token: str = "<|audio_start|>" + audio_eos_token: str = "<|audio_end|>" + + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None) + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + mm_inputs.update(image_processor(images, return_tensors="pt")) + + if len(videos) != 0: + video_dict = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + ) + mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt")) + temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) + mm_inputs["video_second_per_grid"] = torch.tensor( + [temporal_patch_size / fps for fps in video_dict["fps_per_video"]] + ) + + if len(audios) != 0: + audios = self._regularize_audios( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + )["audios"] + mm_inputs.update( + feature_extractor( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + return_attention_mask=True, + padding="max_length", + return_tensors="pt", + ) + ) + mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + 
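+        # When use_audio_in_video is enabled, video and audio tokens are interleaved chunk by
+        # chunk further below; otherwise each modality placeholder is expanded on its own.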
self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 + messages = deepcopy(messages) + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + + merge_length = processor.image_processor.merge_size**2 + use_audio_in_video = getattr(processor, "use_audio_in_video", False) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_thw = mm_inputs.get("image_grid_thw", []) + video_grid_thw = mm_inputs.get("video_grid_thw", []) + if "feature_attention_mask" in mm_inputs: + if processor.__class__.__name__ == "Qwen3OmniMoeProcessor": # for qwen3omni + input_lengths = mm_inputs["feature_attention_mask"].sum(-1) + input_lengths_leave = input_lengths % 100 + feature_lengths = (input_lengths_leave - 1) // 2 + 1 + audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + else: + input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 + audio_lengths = (input_lengths - 2) // 2 + 1 + else: + mm_inputs = {} + image_grid_thw = [None] * len(images) + video_grid_thw = [None] * len(videos) + audio_lengths = [None] * len(audios) + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + content = content.replace( + IMAGE_PLACEHOLDER, + f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}", + 1, + ) + num_image_tokens += 1 + + if ( + use_audio_in_video and len(audios) and len(videos) + ): # if use the audio of video # deal video token and audio token togather + if len(videos) != len(audios): + raise ValueError( + f"Number of videos ({len(videos)}) must match number of audios ({len(audios)}) when using audio in video." + ) + + while VIDEO_PLACEHOLDER in content: + video_pos = content.find(VIDEO_PLACEHOLDER) + audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos) + if audio_pos == -1 or audio_pos < video_pos: + raise ValueError( + f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video." 
+ ) + + audio_t_index = torch.arange(audio_lengths[num_audio_tokens]) + video_t_index = ( + torch.arange(video_grid_thw[num_video_tokens][0]) + .view(-1, 1, 1) + .expand( + -1, + video_grid_thw[num_video_tokens][1] // image_processor.merge_size, + video_grid_thw[num_video_tokens][2] // image_processor.merge_size, + ) + .flatten() + * mm_inputs["video_second_per_grid"][num_video_tokens] + * 25 # FIXME hardcode of position_id_per_seconds=25 + ).long() + t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2] + video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk) + audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk) + placeholder_string = "" + placeholder_string += self.vision_bos_token + self.audio_bos_token + for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))): + video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None + audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None + if video_chunk_index is not None: + placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0]) + + if audio_chunk_index is not None: + placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0]) + + placeholder_string += self.audio_eos_token + self.vision_eos_token + content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1) + content = content.replace(AUDIO_PLACEHOLDER, "", 1) + num_audio_tokens += 1 + num_video_tokens += 1 + else: + while AUDIO_PLACEHOLDER in content: + audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1 + content = content.replace( + AUDIO_PLACEHOLDER, + f"{self.audio_bos_token}{self.audio_token * audio_seqlen}{self.audio_eos_token}", + 1, + ) + num_audio_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + video_seqlen = ( + video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + ) + content = content.replace( + VIDEO_PLACEHOLDER, + f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}", + 1, + ) + num_video_tokens += 1 + + message["content"] = content + + return messages + + +@dataclass +class VideoLlavaPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens = 0, 0 + messages = deepcopy(messages) + num_frames = 0 + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values_images" in mm_inputs: + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0])) + num_frames = 1 + + if "pixel_values_videos" in mm_inputs: + one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0]) + height, width = get_image_size(one_video[0]) + num_frames = one_video.shape[0] # frame dim is always after batch dim + + if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs: + image_seqlen = (height // processor.patch_size) * ( + width // processor.patch_size + ) + processor.num_additional_image_tokens + video_seqlen = image_seqlen * num_frames + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + else: + image_seqlen, video_seqlen = 1, 1 + + for message in 
messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) + num_video_tokens += 1 + + content = content.replace("{{image}}", self.image_token) + message["content"] = content.replace("{{video}}", self.video_token) + + return messages + + +PLUGINS = { + "base": BasePlugin, + "gemma3": Gemma3Plugin, + "glm4v": GLM4VPlugin, + "gemma3n": Gemma3nPlugin, + "intern_vl": InternVLPlugin, + "kimi_vl": KimiVLPlugin, + "llama4": Llama4Plugin, + "llava": LlavaPlugin, + "llava_next": LlavaNextPlugin, + "llava_next_video": LlavaNextVideoPlugin, + "minicpm_v": MiniCPMVPlugin, + "mllama": MllamaPlugin, + "paligemma": PaliGemmaPlugin, + "pixtral": PixtralPlugin, + "qwen2_audio": Qwen2AudioPlugin, + "qwen2_omni": Qwen2OmniPlugin, + "qwen2_vl": Qwen2VLPlugin, + "qwen3_vl": Qwen3VLPlugin, + "video_llava": VideoLlavaPlugin, +} + + +def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None: + r"""Register a multimodal plugin.""" + if name in PLUGINS: + raise ValueError(f"Multimodal plugin {name} already exists.") + + PLUGINS[name] = plugin_class + + +def get_mm_plugin( + name: str, + image_token: Optional[str] = None, + video_token: Optional[str] = None, + audio_token: Optional[str] = None, + **kwargs, +) -> "BasePlugin": + r"""Get plugin for multimodal inputs.""" + if name not in PLUGINS: + raise ValueError(f"Multimodal plugin `{name}` not found.") + + return PLUGINS[name](image_token, video_token, audio_token, **kwargs) diff --git a/llamafactory/data/parser.py b/llamafactory/data/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..3a865fd83dfe6221f312a2030cd08a6b38366cfe --- /dev/null +++ b/llamafactory/data/parser.py @@ -0,0 +1,149 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
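# Illustrative usage sketch for the multimodal plugin registry defined above; the plugin name
# "echo" and the token strings are hypothetical placeholders, while real templates pass their
# model's own special tokens when calling get_mm_plugin (see template.py).
from llamafactory.data.mm_plugin import BasePlugin, get_mm_plugin, register_mm_plugin


class EchoPlugin(BasePlugin):
    """Hypothetical plugin that keeps the base behaviour unchanged."""


register_mm_plugin("echo", EchoPlugin)  # raises ValueError if the name is already registered
plugin = get_mm_plugin("echo", image_token="<image>", video_token="<video>", audio_token="<audio>")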
+ +import json +import os +from dataclasses import dataclass +from typing import Any, Literal, Optional, Union + +from huggingface_hub import hf_hub_download + +from ..extras.constants import DATA_CONFIG +from ..extras.misc import use_modelscope, use_openmind + + +@dataclass +class DatasetAttr: + r"""Dataset attributes.""" + + # basic configs + load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"] + dataset_name: str + formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca" + ranking: bool = False + # extra configs + subset: Optional[str] = None + split: str = "train" + folder: Optional[str] = None + num_samples: Optional[int] = None + # common columns + system: Optional[str] = None + tools: Optional[str] = None + images: Optional[str] = None + videos: Optional[str] = None + audios: Optional[str] = None + # dpo columns + chosen: Optional[str] = None + rejected: Optional[str] = None + kto_tag: Optional[str] = None + # alpaca columns + prompt: Optional[str] = "instruction" + query: Optional[str] = "input" + response: Optional[str] = "output" + history: Optional[str] = None + # sharegpt columns + messages: Optional[str] = "conversations" + # sharegpt tags + role_tag: Optional[str] = "from" + content_tag: Optional[str] = "value" + user_tag: Optional[str] = "human" + assistant_tag: Optional[str] = "gpt" + observation_tag: Optional[str] = "observation" + function_tag: Optional[str] = "function_call" + system_tag: Optional[str] = "system" + + def __repr__(self) -> str: + return self.dataset_name + + def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None: + setattr(self, key, obj.get(key, default)) + + def join(self, attr: dict[str, Any]) -> None: + self.set_attr("formatting", attr, default="alpaca") + self.set_attr("ranking", attr, default=False) + self.set_attr("subset", attr) + self.set_attr("split", attr, default="train") + self.set_attr("folder", attr) + self.set_attr("num_samples", attr) + + if "columns" in attr: + column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"] + column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"] + for column_name in column_names: + self.set_attr(column_name, attr["columns"]) + + if "tags" in attr: + tag_names = ["role_tag", "content_tag"] + tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"] + for tag in tag_names: + self.set_attr(tag, attr["tags"]) + + +def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]: + r"""Get the attributes of the datasets.""" + if dataset_names is None: + dataset_names = [] + + if isinstance(dataset_dir, dict): + dataset_info = dataset_dir + elif dataset_dir == "ONLINE": + dataset_info = None + else: + if dataset_dir.startswith("REMOTE:"): + config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset") + else: + config_path = os.path.join(dataset_dir, DATA_CONFIG) + + try: + with open(config_path) as f: + dataset_info = json.load(f) + except Exception as err: + if len(dataset_names) != 0: + raise ValueError(f"Cannot open {config_path} due to {str(err)}.") + + dataset_info = None + + dataset_list: list[DatasetAttr] = [] + for name in dataset_names: + if dataset_info is None: # dataset_dir is ONLINE + load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub" + dataset_attr = DatasetAttr(load_from, dataset_name=name) + dataset_list.append(dataset_attr) + continue + + 
if name not in dataset_info: + raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.") + + has_hf_url = "hf_hub_url" in dataset_info[name] + has_ms_url = "ms_hub_url" in dataset_info[name] + has_om_url = "om_hub_url" in dataset_info[name] + + if has_hf_url or has_ms_url or has_om_url: + if has_ms_url and (use_modelscope() or not has_hf_url): + dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]) + elif has_om_url and (use_openmind() or not has_hf_url): + dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"]) + else: + dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) + elif "script_url" in dataset_info[name]: + dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) + elif "cloud_file_name" in dataset_info[name]: + dataset_attr = DatasetAttr("cloud_file", dataset_name=dataset_info[name]["cloud_file_name"]) + else: + dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) + + dataset_attr.join(dataset_info[name]) + dataset_list.append(dataset_attr) + + return dataset_list diff --git a/llamafactory/data/processor/__init__.py b/llamafactory/data/processor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..357ab7899f9eecbd29344482d109b89af274ea2e --- /dev/null +++ b/llamafactory/data/processor/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
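# Illustrative sketch of how get_dataset_list (defined above in parser.py) resolves a local
# dataset; the dataset name, file name, and column mapping are hypothetical placeholders.
# Passing a dict as `dataset_dir` bypasses reading the DATA_CONFIG file from disk.
from llamafactory.data.parser import get_dataset_list

dataset_info = {
    "demo_sft": {
        "file_name": "demo_sft.json",
        "formatting": "alpaca",
        "columns": {"prompt": "instruction", "query": "input", "response": "output"},
    }
}
attrs = get_dataset_list(["demo_sft"], dataset_info)
print(attrs[0].load_from, attrs[0].dataset_name)  # -> file demo_sft.json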
+ +from .feedback import FeedbackDatasetProcessor +from .pairwise import PairwiseDatasetProcessor +from .pretrain import PretrainDatasetProcessor +from .processor_utils import DatasetProcessor +from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor +from .unsupervised import UnsupervisedDatasetProcessor + + +__all__ = [ + "DatasetProcessor", + "FeedbackDatasetProcessor", + "PackedSupervisedDatasetProcessor", + "PairwiseDatasetProcessor", + "PretrainDatasetProcessor", + "SupervisedDatasetProcessor", + "UnsupervisedDatasetProcessor", +] diff --git a/llamafactory/data/processor/__pycache__/__init__.cpython-312.pyc b/llamafactory/data/processor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e85ff951b87bd8d7cd85626bac8829ab0a727db Binary files /dev/null and b/llamafactory/data/processor/__pycache__/__init__.cpython-312.pyc differ diff --git a/llamafactory/data/processor/__pycache__/feedback.cpython-312.pyc b/llamafactory/data/processor/__pycache__/feedback.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01a84077c2946d6287b73a6f0decf5ab36fe133f Binary files /dev/null and b/llamafactory/data/processor/__pycache__/feedback.cpython-312.pyc differ diff --git a/llamafactory/data/processor/__pycache__/pairwise.cpython-312.pyc b/llamafactory/data/processor/__pycache__/pairwise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..448f8f1ced997cd41e862fe07863214e4be3f68d Binary files /dev/null and b/llamafactory/data/processor/__pycache__/pairwise.cpython-312.pyc differ diff --git a/llamafactory/data/processor/__pycache__/pretrain.cpython-312.pyc b/llamafactory/data/processor/__pycache__/pretrain.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85208325c8814f094b4466853fa293154e398576 Binary files /dev/null and b/llamafactory/data/processor/__pycache__/pretrain.cpython-312.pyc differ diff --git a/llamafactory/data/processor/__pycache__/processor_utils.cpython-312.pyc b/llamafactory/data/processor/__pycache__/processor_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84248d8f528b3a457311a4d90c73173eef406822 Binary files /dev/null and b/llamafactory/data/processor/__pycache__/processor_utils.cpython-312.pyc differ diff --git a/llamafactory/data/processor/__pycache__/supervised.cpython-312.pyc b/llamafactory/data/processor/__pycache__/supervised.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2619e0ba420f7ea44190d6d70bafb8b0245106d4 Binary files /dev/null and b/llamafactory/data/processor/__pycache__/supervised.cpython-312.pyc differ diff --git a/llamafactory/data/processor/__pycache__/unsupervised.cpython-312.pyc b/llamafactory/data/processor/__pycache__/unsupervised.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..656d5b93ad6bf756fb5ba1cfd4a027954fa2e783 Binary files /dev/null and b/llamafactory/data/processor/__pycache__/unsupervised.cpython-312.pyc differ diff --git a/llamafactory/data/processor/feedback.py b/llamafactory/data/processor/feedback.py new file mode 100644 index 0000000000000000000000000000000000000000..871615b9266e501f25f68e84e4536c0d24617803 --- /dev/null +++ b/llamafactory/data/processor/feedback.py @@ -0,0 +1,129 @@ +# Copyright 2025 the LlamaFactory team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Optional + +from ...extras import logging +from ...extras.constants import IGNORE_INDEX +from .processor_utils import DatasetProcessor, infer_seqlen + + +if TYPE_CHECKING: + from ..mm_plugin import AudioInput, ImageInput, VideoInput + + +logger = logging.get_logger(__name__) + + +class FeedbackDatasetProcessor(DatasetProcessor): + def _encode_data_example( + self, + prompt: list[dict[str, str]], + response: list[dict[str, str]], + kl_response: list[dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + ) -> tuple[list[int], list[int], list[int], list[int], bool]: + if response[0]["content"]: # desired example + kto_tag = True + messages = prompt + [response[0]] + else: # undesired example + kto_tag = False + messages = prompt + [response[1]] + + if kl_response[0]["content"]: + kl_messages = prompt + [kl_response[0]] + else: + kl_messages = prompt + [kl_response[1]] + + messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor) + kl_messages = self.template.mm_plugin.process_messages(kl_messages, images, videos, audios, self.processor) + prompt_ids, response_ids = self.template.encode_oneturn(self.tokenizer, messages, system, tools) + kl_prompt_ids, kl_response_ids = self.template.encode_oneturn(self.tokenizer, kl_messages, system, tools) + + if self.template.efficient_eos: + response_ids += [self.tokenizer.eos_token_id] + kl_response_ids += [self.tokenizer.eos_token_id] + + prompt_ids, _ = self.template.mm_plugin.process_token_ids( + prompt_ids, None, images, videos, audios, self.tokenizer, self.processor + ) + kl_prompt_ids, _ = self.template.mm_plugin.process_token_ids( + kl_prompt_ids, None, images, videos, audios, self.tokenizer, self.processor + ) + + source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), self.data_args.cutoff_len) + prompt_ids = prompt_ids[:source_len] + response_ids = response_ids[:target_len] + kl_source_len, kl_target_len = infer_seqlen( + len(kl_prompt_ids), len(kl_response_ids), self.data_args.cutoff_len + ) + kl_prompt_ids = kl_prompt_ids[:kl_source_len] + kl_response_ids = kl_response_ids[:kl_target_len] + + input_ids = prompt_ids + response_ids + labels = [IGNORE_INDEX] * source_len + response_ids + kl_input_ids = kl_prompt_ids + kl_response_ids + kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids + return input_ids, labels, kl_input_ids, kl_labels, kto_tag + + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: + # Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions. 
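# Illustrative rotation (hypothetical batch): with responses [R0, R1, R2], the line below
# yields kl_response = [R2, R0, R1], so example i borrows example i-1's completion for the KL term.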
+ kl_response = [examples["_response"][-1]] + examples["_response"][:-1] + model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + input_ids, labels, kl_input_ids, kl_labels, kto_tag = self._encode_data_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + kl_response=kl_response[i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + audios=examples["_audios"][i] or [], + ) + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["kl_input_ids"].append(kl_input_ids) + model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) + model_inputs["kl_labels"].append(kl_labels) + model_inputs["kto_tags"].append(kto_tag) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + model_inputs["audios"].append(examples["_audios"][i]) + + desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) + undesirable_num = len(model_inputs["kto_tags"]) - desirable_num + if desirable_num == 0 or undesirable_num == 0: + logger.warning_rank0("Your dataset only has one preference type.") + + return model_inputs + + def print_data_example(self, example: dict[str, list[int]]) -> None: + valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}") diff --git a/llamafactory/data/processor/pairwise.py b/llamafactory/data/processor/pairwise.py new file mode 100644 index 0000000000000000000000000000000000000000..94101deb8e75af73c1851720604994a11f2eb87d --- /dev/null +++ b/llamafactory/data/processor/pairwise.py @@ -0,0 +1,118 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
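# Note on the KTO example layout handled above (illustrative): a desirable example arrives with
# response = [answer, empty] (kto_tag=True) and an undesirable one with response = [empty, answer]
# (kto_tag=False); preprocess_dataset warns when the processed examples contain only one of the two types.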
+ +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Optional + +from ...extras import logging +from ...extras.constants import IGNORE_INDEX +from .processor_utils import DatasetProcessor, infer_seqlen + + +if TYPE_CHECKING: + from ..mm_plugin import AudioInput, ImageInput, VideoInput + + +logger = logging.get_logger(__name__) + + +class PairwiseDatasetProcessor(DatasetProcessor): + def _encode_data_example( + self, + prompt: list[dict[str, str]], + response: list[dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + ) -> tuple[list[int], list[int], list[int], list[int]]: + chosen_messages = self.template.mm_plugin.process_messages( + prompt + [response[0]], images, videos, audios, self.processor + ) + rejected_messages = self.template.mm_plugin.process_messages( + prompt + [response[1]], images, videos, audios, self.processor + ) + prompt_ids, chosen_ids = self.template.encode_oneturn(self.tokenizer, chosen_messages, system, tools) + _, rejected_ids = self.template.encode_oneturn(self.tokenizer, rejected_messages, system, tools) + + if self.template.efficient_eos: + chosen_ids += [self.tokenizer.eos_token_id] + rejected_ids += [self.tokenizer.eos_token_id] + + prompt_ids, _ = self.template.mm_plugin.process_token_ids( + prompt_ids, None, images, videos, audios, self.tokenizer, self.processor + ) + # consider the response is more important + source_len, target_len = infer_seqlen( + len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.data_args.cutoff_len + ) + prompt_ids = prompt_ids[:source_len] + chosen_ids = chosen_ids[:target_len] + rejected_ids = rejected_ids[:target_len] + + chosen_input_ids = prompt_ids + chosen_ids + chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids + rejected_input_ids = prompt_ids + rejected_ids + rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids + return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels + + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: + # build input pairs with format ` X`, `Y1 ` and `Y2 ` + model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = self._encode_data_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + audios=examples["_audios"][i] or [], + ) + model_inputs["chosen_input_ids"].append(chosen_input_ids) + model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids)) + model_inputs["chosen_labels"].append(chosen_labels) + model_inputs["rejected_input_ids"].append(rejected_input_ids) + model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) + model_inputs["rejected_labels"].append(rejected_labels) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + model_inputs["audios"].append(examples["_audios"][i]) + + return model_inputs + + def print_data_example(self, example: dict[str, list[int]]) -> None: + valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, 
example["chosen_labels"])) + valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"])) + print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) + print( + "chosen_inputs:\n{}".format(self.tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)) + ) + print("chosen_label_ids:\n{}".format(example["chosen_labels"])) + print(f"chosen_labels:\n{self.tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}") + print("rejected_input_ids:\n{}".format(example["rejected_input_ids"])) + print( + "rejected_inputs:\n{}".format( + self.tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False) + ) + ) + print("rejected_label_ids:\n{}".format(example["rejected_labels"])) + print(f"rejected_labels:\n{self.tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}") diff --git a/llamafactory/data/processor/pretrain.py b/llamafactory/data/processor/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..3fa6b1ca58a8d59493cd4b43c51cb268080cc506 --- /dev/null +++ b/llamafactory/data/processor/pretrain.py @@ -0,0 +1,57 @@ +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from itertools import chain +from typing import Any + +from .processor_utils import DatasetProcessor + + +@dataclass +class PretrainDatasetProcessor(DatasetProcessor): + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: + # build grouped texts with format `X1 X2 X3 ...` if packing is enabled + eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token + text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]] + + if not self.data_args.packing: + if getattr(self.tokenizer, "add_bos_token", False): + text_examples = [self.tokenizer.bos_token + example for example in text_examples] + + result = self.tokenizer( + text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len + ) + else: + tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False) + concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} + total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) + block_size = self.data_args.cutoff_len + total_length = (total_length // block_size) * block_size + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + if getattr(self.tokenizer, "add_bos_token", False): + for i in range(len(result["input_ids"])): + result["input_ids"][i][0] = self.tokenizer.bos_token_id + + return result + + def print_data_example(self, example: dict[str, list[int]]) -> None: + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) diff --git a/llamafactory/data/processor/processor_utils.py b/llamafactory/data/processor/processor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db44b19cf6fc84d6551fb7cce82283774ae72030 --- /dev/null +++ b/llamafactory/data/processor/processor_utils.py @@ -0,0 +1,88 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import bisect +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, ProcessorMixin + + from ...hparams import DataArguments + from ..template import Template + + +@dataclass +class DatasetProcessor(ABC): + r"""A class for data processors.""" + + template: "Template" + tokenizer: "PreTrainedTokenizer" + processor: Optional["ProcessorMixin"] + data_args: "DataArguments" + + @abstractmethod + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: + r"""Build model inputs from the examples.""" + ... + + @abstractmethod + def print_data_example(self, example: dict[str, list[int]]) -> None: + r"""Print a data example to stdout.""" + ... 
+ + +def search_for_fit(numbers: list[int], capacity: int) -> int: + r"""Find the index of largest number that fits into the knapsack with the given capacity.""" + index = bisect.bisect(numbers, capacity) + return -1 if index == 0 else (index - 1) + + +def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]: + r"""Implement efficient greedy algorithm with binary search for the knapsack problem.""" + numbers.sort() # sort numbers in ascending order for binary search + knapsacks = [] + + while numbers: + current_knapsack = [] + remaining_capacity = capacity + + while True: + index = search_for_fit(numbers, remaining_capacity) + if index == -1: + break # no more numbers fit in this knapsack + + remaining_capacity -= numbers[index] # update the remaining capacity + current_knapsack.append(numbers.pop(index)) # add the number to knapsack + + knapsacks.append(current_knapsack) + + return knapsacks + + +def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]: + r"""Compute the real sequence length after truncation by the cutoff_len.""" + if target_len * 2 < cutoff_len: # truncate source + max_target_len = cutoff_len + elif source_len * 2 < cutoff_len: # truncate target + max_target_len = cutoff_len - source_len + else: # truncate both + max_target_len = int(cutoff_len * (target_len / (source_len + target_len))) + + new_target_len = min(max_target_len, target_len) + max_source_len = max(cutoff_len - new_target_len, 0) + new_source_len = min(max_source_len, source_len) + return new_source_len, new_target_len diff --git a/llamafactory/data/processor/supervised.py b/llamafactory/data/processor/supervised.py new file mode 100644 index 0000000000000000000000000000000000000000..b5aba11b6535078f46bdf6aca743c6ae262e1fc6 --- /dev/null +++ b/llamafactory/data/processor/supervised.py @@ -0,0 +1,203 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
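# Illustrative sketch of the processor_utils helpers defined above (hypothetical numbers):
# greedy_knapsack groups example lengths into bins of at most `capacity` tokens for packing,
# and infer_seqlen splits a cutoff budget between a prompt and a response.
from llamafactory.data.processor.processor_utils import greedy_knapsack, infer_seqlen

print(greedy_knapsack([3, 5, 7, 9], capacity=10))  # [[9], [7, 3], [5]]
print(infer_seqlen(100, 30, cutoff_len=64))        # (34, 30): short response, truncate the prompt
print(infer_seqlen(100, 100, cutoff_len=64))       # (32, 32): truncate both proportionally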
+ +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +from ...extras import logging +from ...extras.constants import IGNORE_INDEX +from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen + + +if TYPE_CHECKING: + from ..mm_plugin import AudioInput, ImageInput, VideoInput + + +logger = logging.get_logger(__name__) + + +@dataclass +class SupervisedDatasetProcessor(DatasetProcessor): + def _encode_data_example( + self, + prompt: list[dict[str, str]], + response: list[dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + ) -> tuple[list[int], list[int]]: + messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor) + input_ids, labels = self.template.mm_plugin.process_token_ids( + [], [], images, videos, audios, self.tokenizer, self.processor + ) + encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools) + total_length = len(input_ids) + (1 if self.template.efficient_eos else 0) + if self.data_args.mask_history: + encoded_pairs = encoded_pairs[::-1] # high priority for last turns + + for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): + if total_length >= self.data_args.cutoff_len: + break + + source_len, target_len = infer_seqlen( + len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length + ) + source_ids = source_ids[:source_len] + target_ids = target_ids[:target_len] + total_length += source_len + target_len + + if self.data_args.train_on_prompt: + source_label = source_ids + elif self.template.efficient_eos and turn_idx != 0: + source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1) + else: + source_label = [IGNORE_INDEX] * source_len + + if self.data_args.mask_history and turn_idx != 0: # train on the last turn only + target_label = [IGNORE_INDEX] * target_len + else: + target_label = target_ids + + if self.data_args.mask_history: # reversed sequences + input_ids = source_ids + target_ids + input_ids + labels = source_label + target_label + labels + else: + input_ids += source_ids + target_ids + labels += source_label + target_label + + if self.template.efficient_eos: + input_ids += [self.tokenizer.eos_token_id] + labels += [self.tokenizer.eos_token_id] + + return input_ids, labels + + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: + # build inputs with format ` X Y ` and labels with format ` ... Y ` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. 
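# Illustrative layout (hypothetical two-turn chat, before truncation): for turns (q1, a1), (q2, a2)
# the encoding above yields input_ids = q1 + a1 + q2 + a2 and
# labels = [IGNORE_INDEX]*len(q1) + a1 + [IGNORE_INDEX]*len(q2) + a2,
# unless train_on_prompt or mask_history changes which spans remain supervised.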
+ model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + input_ids, labels = self._encode_data_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + audios=examples["_audios"][i] or [], + ) + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + model_inputs["audios"].append(examples["_audios"][i]) + + return model_inputs + + def print_data_example(self, example: dict[str, list[int]]) -> None: + valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}") + + +@dataclass +class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: + # TODO: use `position_ids` to achieve packing + # build inputs with format ` X1 Y1 X2 Y2 ` + # and labels with format ` ... Y1 ... Y2 ` + valid_num = 0 + batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], [] + lengths = [] + length2indexes = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + input_ids, labels = self._encode_data_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + audios=examples["_audios"][i] or [], + ) + length = len(input_ids) + if length > self.data_args.cutoff_len: + logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.") + else: + lengths.append(length) + length2indexes[length].append(valid_num) + batch_input_ids.append(input_ids) + batch_labels.append(labels) + batch_images.append(examples["_images"][i] or []) + batch_videos.append(examples["_videos"][i] or []) + batch_audios.append(examples["_audios"][i] or []) + valid_num += 1 + + model_inputs = defaultdict(list) + knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len) + for knapsack in knapsacks: + packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], [] + packed_images, packed_videos, packed_audios = [], [], [] + for i, length in enumerate(knapsack): + index = length2indexes[length].pop() + packed_input_ids += batch_input_ids[index] + packed_position_ids += list(range(len(batch_input_ids[index]))) # NOTE: pad_to_multiple_of ignore this + packed_labels += batch_labels[index] + packed_images += batch_images[index] + packed_videos += batch_videos[index] + packed_audios += 
batch_audios[index] + if self.data_args.neat_packing: + packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 + else: + packed_attention_masks += [1] * len(batch_input_ids[index]) + + if len(packed_input_ids) < self.data_args.cutoff_len + 1: # avoid flash_attn drops attn mask + pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1 + packed_input_ids += [self.tokenizer.pad_token_id] * pad_length + packed_position_ids += [0] * pad_length + packed_labels += [IGNORE_INDEX] * pad_length + if self.data_args.neat_packing: + packed_attention_masks += [0] * pad_length + else: + packed_attention_masks += [1] * pad_length # more efficient flash_attn + + if len(packed_input_ids) != self.data_args.cutoff_len + 1: + raise ValueError("The length of packed example should be identical to the cutoff length.") + + model_inputs["input_ids"].append(packed_input_ids) + model_inputs["attention_mask"].append(packed_attention_masks) + model_inputs["position_ids"].append(packed_position_ids) + model_inputs["labels"].append(packed_labels) + model_inputs["images"].append(packed_images or None) + model_inputs["videos"].append(packed_videos or None) + model_inputs["audios"].append(packed_audios or None) + + return model_inputs diff --git a/llamafactory/data/processor/unsupervised.py b/llamafactory/data/processor/unsupervised.py new file mode 100644 index 0000000000000000000000000000000000000000..256174b6dd38696b5b180501102af40ff395d0a9 --- /dev/null +++ b/llamafactory/data/processor/unsupervised.py @@ -0,0 +1,91 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
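# Illustrative attention-mask layout for the packed processor above (hypothetical cutoff_len=6,
# two packed sequences of lengths 3 and 2, padded to cutoff_len + 1): neat_packing produces
# [1, 1, 1, 2, 2, 0, 0] (one segment id per example, zeros on padding), while the default
# packing produces [1, 1, 1, 1, 1, 1, 1].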
+ +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Optional + +from ...extras import logging +from ..data_utils import Role +from .processor_utils import DatasetProcessor, infer_seqlen + + +if TYPE_CHECKING: + from ..mm_plugin import AudioInput, ImageInput, VideoInput + + +logger = logging.get_logger(__name__) + + +class UnsupervisedDatasetProcessor(DatasetProcessor): + def _encode_data_example( + self, + prompt: list[dict[str, str]], + response: list[dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + ) -> tuple[list[int], list[int]]: + if len(response) == 1: + messages = prompt + response + else: + messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}] + + messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor) + input_ids, labels = self.template.encode_oneturn(self.tokenizer, messages, system, tools) + if self.template.efficient_eos: + labels += [self.tokenizer.eos_token_id] + + input_ids, _ = self.template.mm_plugin.process_token_ids( + input_ids, None, images, videos, audios, self.tokenizer, self.processor + ) + source_len, target_len = infer_seqlen(len(input_ids), len(labels), self.data_args.cutoff_len) + input_ids = input_ids[:source_len] + labels = labels[:target_len] + return input_ids, labels + + def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]: + # build inputs with format ` X` and labels with format `Y ` + model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1: + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + input_ids, labels = self._encode_data_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + audios=examples["_audios"][i] or [], + ) + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + model_inputs["audios"].append(examples["_audios"][i]) + + return model_inputs + + def print_data_example(self, example: dict[str, list[int]]) -> None: + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print("labels:\n{}".format(self.tokenizer.decode(example["labels"], skip_special_tokens=False))) diff --git a/llamafactory/data/template.py b/llamafactory/data/template.py new file mode 100644 index 0000000000000000000000000000000000000000..56e32dd203cf753e6ffe486df291e2f71c29a11c --- /dev/null +++ b/llamafactory/data/template.py @@ -0,0 +1,2209 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from copy import deepcopy +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +from typing_extensions import override + +from ..extras import logging +from .data_utils import Role +from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter +from .mm_plugin import get_mm_plugin + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from ..hparams import DataArguments + from .formatter import SLOTS, Formatter + from .mm_plugin import BasePlugin + from .tool_utils import FunctionCall + + +logger = logging.get_logger(__name__) + + +@dataclass +class Template: + format_user: "Formatter" + format_assistant: "Formatter" + format_system: "Formatter" + format_function: "Formatter" + format_observation: "Formatter" + format_tools: "Formatter" + format_prefix: "Formatter" + default_system: str + stop_words: list[str] + thought_words: tuple[str, str] + efficient_eos: bool + replace_eos: bool + replace_jinja_template: bool + enable_thinking: Optional[bool] + mm_plugin: "BasePlugin" + + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> tuple[list[int], list[int]]: + r"""Return a single pair of token ids representing prompt and response respectively.""" + encoded_messages = self._encode(tokenizer, messages, system, tools) + prompt_ids = [] + for encoded_ids in encoded_messages[:-1]: + prompt_ids += encoded_ids + + response_ids = encoded_messages[-1] + return prompt_ids, response_ids + + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> list[tuple[list[int], list[int]]]: + r"""Return multiple pairs of token ids representing prompts and responses respectively.""" + encoded_messages = self._encode(tokenizer, messages, system, tools) + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + + def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]: + r"""Extract tool message.""" + return self.format_tools.extract(content) + + def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]: + r"""Return stop token ids.""" + stop_token_ids = {tokenizer.eos_token_id} + for token in self.stop_words: + stop_token_ids.add(tokenizer.convert_tokens_to_ids(token)) + + return list(stop_token_ids) + + def add_thought(self, content: str = "") -> str: + r"""Add empty thought to assistant message.""" + return f"{self.thought_words[0]}{self.thought_words[1]}" + content + + def remove_thought(self, content: str) -> str: + r"""Remove thought from assistant message.""" + pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL) + return re.sub(pattern, "", content).lstrip("\n") + + def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]: + r"""Get the token ids of thought words.""" + return tokenizer.encode(self.add_thought(), 
add_special_tokens=False) + + def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]: + r"""Convert elements to token ids.""" + token_ids = [] + for elem in elements: + if isinstance(elem, str): + if len(elem) != 0: + token_ids += tokenizer.encode(elem, add_special_tokens=False) + elif isinstance(elem, dict): + token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))] + elif isinstance(elem, set): + if "bos_token" in elem and tokenizer.bos_token_id is not None: + token_ids += [tokenizer.bos_token_id] + elif "eos_token" in elem and tokenizer.eos_token_id is not None: + token_ids += [tokenizer.eos_token_id] + else: + raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}") + + return token_ids + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str], + tools: Optional[str], + ) -> list[list[int]]: + r"""Encode formatted inputs to pairs of token ids. + + Turn 0: prefix + system + query resp + Turn t: query resp. + """ + system = system or self.default_system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + elements += self.format_system.apply(content=(system + tool_text)) + + if message["role"] == Role.USER: + elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) + elif message["role"] == Role.ASSISTANT: + elements += self.format_assistant.apply(content=message["content"]) + elif message["role"] == Role.OBSERVATION: + elements += self.format_observation.apply(content=message["content"]) + elif message["role"] == Role.FUNCTION: + elements += self.format_function.apply(content=message["content"], thought_words=self.thought_words) + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return encoded_messages + + @staticmethod + def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None: + r"""Add or replace eos token to the tokenizer.""" + if tokenizer.eos_token == eos_token: + return + + is_added = tokenizer.eos_token_id is None + num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) + + if is_added: + logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.") + else: + logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.") + + if num_added_tokens > 0: + logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.") + + def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None: + r"""Add eos token and pad token to the tokenizer.""" + stop_words = self.stop_words + if self.replace_eos: + if not stop_words: + raise ValueError("Stop words are required to replace the EOS token.") + + self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0]) + stop_words = stop_words[1:] + + if tokenizer.eos_token_id is None: + self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>") + + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + logger.info_rank0(f"Add pad token: {tokenizer.pad_token}") + + if stop_words: + num_added_tokens = tokenizer.add_special_tokens( + dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False + ) + logger.info_rank0("Add {} to stop 
words.".format(",".join(stop_words))) + if num_added_tokens > 0: + logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.") + + @staticmethod + def _jinja_escape(content: str) -> str: + r"""Escape single quotes in content.""" + return content.replace("'", r"\'") + + @staticmethod + def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str: + r"""Convert slots to jinja template.""" + slot_items = [] + for slot in slots: + if isinstance(slot, str): + slot_pieces = slot.split("{{content}}") + if slot_pieces[0]: + slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'") + if len(slot_pieces) > 1: + slot_items.append(placeholder) + if slot_pieces[1]: + slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'") + elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced + if "bos_token" in slot and tokenizer.bos_token_id is not None: + slot_items.append("'" + tokenizer.bos_token + "'") + elif "eos_token" in slot and tokenizer.eos_token_id is not None: + slot_items.append("'" + tokenizer.eos_token + "'") + elif isinstance(slot, dict): + raise ValueError("Dict is not supported.") + + return " + ".join(slot_items) + + def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str: + r"""Return the jinja template.""" + prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer) + system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message") + user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer) + assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer) + jinja_template = "" + if prefix: + jinja_template += "{{ " + prefix + " }}" + + if self.default_system: + jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}" + + jinja_template += ( + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}" + "{% if system_message is defined %}{{ " + system + " }}{% endif %}" + "{% for message in loop_messages %}" + "{% set content = message['content'] %}" + "{% if message['role'] == 'user' %}" + "{{ " + user + " }}" + "{% elif message['role'] == 'assistant' %}" + "{{ " + assistant + " }}" + "{% endif %}" + "{% endfor %}" + ) + return jinja_template + + def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None: + r"""Replace the jinja template in the tokenizer.""" + if tokenizer.chat_template is None or self.replace_jinja_template: + try: + tokenizer.chat_template = self._get_jinja_template(tokenizer) + except ValueError as e: + logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.") + + @staticmethod + def _convert_slots_to_ollama( + slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content" + ) -> str: + r"""Convert slots to ollama template.""" + slot_items = [] + for slot in slots: + if isinstance(slot, str): + slot_pieces = slot.split("{{content}}") + if slot_pieces[0]: + slot_items.append(slot_pieces[0]) + if len(slot_pieces) > 1: + slot_items.append("{{ " + placeholder + " }}") + if slot_pieces[1]: + slot_items.append(slot_pieces[1]) + elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced + if "bos_token" in slot and tokenizer.bos_token_id is not None: + slot_items.append(tokenizer.bos_token) + elif "eos_token" in 
slot and tokenizer.eos_token_id is not None: + slot_items.append(tokenizer.eos_token) + elif isinstance(slot, dict): + raise ValueError("Dict is not supported.") + + return "".join(slot_items) + + def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str: + r"""Return the ollama template.""" + prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer) + system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System") + user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content") + assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content") + return ( + f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}" + f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}""" + f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}""" + ) + + def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str: + r"""Return the ollama modelfile. + + TODO: support function calling. + """ + modelfile = "# ollama modelfile auto-generated by llamafactory\n\n" + modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n' + + if self.default_system: + modelfile += f'SYSTEM """{self.default_system}"""\n\n' + + for stop_token_id in self.get_stop_token_ids(tokenizer): + modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n' + + modelfile += "PARAMETER num_ctx 4096\n" + return modelfile + + +@dataclass +class Llama2Template(Template): + r"""A template that fuse the system message to first user message.""" + + @override + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: str, + tools: str, + ) -> list[list[int]]: + system = system or self.default_system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + + system_text = "" + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + system_text = self.format_system.apply(content=(system + tool_text))[0] + + if message["role"] == Role.USER: + elements += self.format_user.apply(content=system_text + message["content"]) + elif message["role"] == Role.ASSISTANT: + elements += self.format_assistant.apply(content=message["content"]) + elif message["role"] == Role.OBSERVATION: + elements += self.format_observation.apply(content=message["content"]) + elif message["role"] == Role.FUNCTION: + elements += self.format_function.apply(content=message["content"]) + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return encoded_messages + + def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str: + prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer) + system_message = self._convert_slots_to_jinja( + self.format_system.apply(), tokenizer, placeholder="system_message" + ) + user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer) + assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer) + jinja_template = "" + if prefix: + jinja_template += "{{ " + prefix + " }}" + + if self.default_system: + jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}" + + jinja_template += ( + "{% if messages[0]['role'] == 'system' %}{% set 
loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}" + "{% for message in loop_messages %}" + "{% if loop.index0 == 0 and system_message is defined %}" + "{% set content = " + system_message + " + message['content'] %}" + "{% else %}{% set content = message['content'] %}{% endif %}" + "{% if message['role'] == 'user' %}" + "{{ " + user_message + " }}" + "{% elif message['role'] == 'assistant' %}" + "{{ " + assistant_message + " }}" + "{% endif %}" + "{% endfor %}" + ) + return jinja_template + + +@dataclass +class ReasoningTemplate(Template): + r"""A template that add thought to assistant message.""" + + @override + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> tuple[list[int], list[int]]: + messages = deepcopy(messages) + for i in range(1, len(messages) - 2, 2): + messages[i]["content"] = self.remove_thought(messages[i]["content"]) + + if self.enable_thinking is False: # remove all cot + messages[-1]["content"] = self.remove_thought(messages[-1]["content"]) + + prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools) + if ( + self.thought_words[0].strip() not in messages[-1]["content"] + and self.thought_words[1].strip() not in messages[-1]["content"] + ): # add empty cot + if not self.enable_thinking: # do not compute loss + prompt_ids += self.get_thought_word_ids(tokenizer) + else: # do compute loss + response_ids = self.get_thought_word_ids(tokenizer) + response_ids + + return prompt_ids, response_ids + + @override + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: list[dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> list[tuple[list[int], list[int]]]: + messages = deepcopy(messages) + if self.enable_thinking is False: # remove all cot + for i in range(1, len(messages), 2): + messages[i]["content"] = self.remove_thought(messages[i]["content"]) + + encoded_messages = self._encode(tokenizer, messages, system, tools) + for i in range(0, len(messages), 2): + if ( + self.thought_words[0].strip() not in messages[i + 1]["content"] + and self.thought_words[1].strip() not in messages[i + 1]["content"] + ): # add empty cot + if not self.enable_thinking: # do not compute loss + encoded_messages[i] += self.get_thought_word_ids(tokenizer) + else: # do compute loss + encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1] + + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + + +TEMPLATES: dict[str, "Template"] = {} + + +def register_template( + name: str, + format_user: Optional["Formatter"] = None, + format_assistant: Optional["Formatter"] = None, + format_system: Optional["Formatter"] = None, + format_function: Optional["Formatter"] = None, + format_observation: Optional["Formatter"] = None, + format_tools: Optional["Formatter"] = None, + format_prefix: Optional["Formatter"] = None, + default_system: str = "", + stop_words: Optional[list[str]] = None, + thought_words: Optional[tuple[str, str]] = None, + efficient_eos: bool = False, + replace_eos: bool = False, + replace_jinja_template: bool = False, + enable_thinking: Optional[bool] = True, + mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), + template_class: type["Template"] = Template, +) -> None: + r"""Register a chat template. 
+ + To add the following chat template: + ``` + user prompt here + model response here + user prompt here + model response here + ``` + + The corresponding code should be: + ``` + register_template( + name="custom", + format_user=StringFormatter(slots=["{{content}}\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_prefix=EmptyFormatter(""), + ) + ``` + """ + if name in TEMPLATES: + raise ValueError(f"Template {name} already exists.") + + default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}] + default_user_formatter = StringFormatter(slots=["{{content}}"]) + default_assistant_formatter = StringFormatter(slots=default_slots) + if format_assistant is not None: + default_function_formatter = FunctionFormatter(slots=format_assistant.slots, tool_format="default") + else: + default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default") + + default_tool_formatter = ToolFormatter(tool_format="default") + default_prefix_formatter = EmptyFormatter() + TEMPLATES[name] = template_class( + format_user=format_user or default_user_formatter, + format_assistant=format_assistant or default_assistant_formatter, + format_system=format_system or default_user_formatter, + format_function=format_function or default_function_formatter, + format_observation=format_observation or format_user or default_user_formatter, + format_tools=format_tools or default_tool_formatter, + format_prefix=format_prefix or default_prefix_formatter, + default_system=default_system, + stop_words=stop_words or [], + thought_words=thought_words or ("\n", "\n\n\n"), + efficient_eos=efficient_eos, + replace_eos=replace_eos, + replace_jinja_template=replace_jinja_template, + enable_thinking=enable_thinking, + mm_plugin=mm_plugin, + ) + + +def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": + r"""Extract a chat template from the tokenizer.""" + + def find_diff(short_str: str, long_str: str) -> str: + i, j = 0, 0 + diff = "" + while i < len(short_str) and j < len(long_str): + if short_str[i] == long_str[j]: + i += 1 + j += 1 + else: + diff += long_str[j] + j += 1 + + return diff + + prefix = tokenizer.decode(tokenizer.encode("")) + + messages = [{"role": "system", "content": "{{content}}"}] + system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :] + + messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}] + user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + user_slot_empty_system = user_slot_empty_system[len(prefix) :] + + messages = [{"role": "user", "content": "{{content}}"}] + user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + user_slot = user_slot[len(prefix) :] + + messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}] + assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False) + assistant_slot = assistant_slot[len(prefix) + len(user_slot) :] + template_class = ReasoningTemplate if "" in assistant_slot else Template + assistant_slot = assistant_slot.replace("", "").replace("", "").lstrip("\n") # remove thought tags + + if len(user_slot) > len(user_slot_empty_system): + default_system = find_diff(user_slot_empty_system, user_slot) + sole_system = system_slot.replace("{{content}}", default_system, 1) + user_slot = user_slot[len(sole_system) :] + else: 
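+        # The branch above covers chat templates that bake a default system
+        # prompt into every conversation: rendering a lone user message then
+        # produces extra text compared with rendering after an explicitly empty
+        # system message, and find_diff() recovers that default system string.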
# if defaut_system is empty, user_slot_empty_system will be longer than user_slot + default_system = "" + + return template_class( + format_user=StringFormatter(slots=[user_slot]), + format_assistant=StringFormatter(slots=[assistant_slot]), + format_system=StringFormatter(slots=[system_slot]), + format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"), + format_observation=StringFormatter(slots=[user_slot]), + format_tools=ToolFormatter(tool_format="default"), + format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(), + default_system=default_system, + stop_words=[], + thought_words=("\n", "\n\n\n"), + efficient_eos=False, + replace_eos=False, + replace_jinja_template=False, + enable_thinking=True, + mm_plugin=get_mm_plugin(name="base"), + ) + + +def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template": + r"""Get chat template and fixes the tokenizer.""" + if data_args.template is None: + if isinstance(tokenizer.chat_template, str): + logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.") + template = parse_template(tokenizer) + else: + logger.warning_rank0("`template` was not specified, use `empty` template.") + template = TEMPLATES["empty"] # placeholder + else: + if data_args.template not in TEMPLATES: + raise ValueError(f"Template {data_args.template} does not exist.") + + template = TEMPLATES[data_args.template] + + if data_args.train_on_prompt and template.efficient_eos: + raise ValueError("Current template does not support `train_on_prompt`.") + + if data_args.tool_format is not None: + logger.info_rank0(f"Using tool format: {data_args.tool_format}.") + default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}] + template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format) + template.format_tools = ToolFormatter(tool_format=data_args.tool_format) + + if data_args.default_system is not None: + logger.info_rank0(f"Using default system message: {data_args.default_system}.") + template.default_system = data_args.default_system + + if isinstance(template, ReasoningTemplate): + logger.warning_rank0( + "You are using reasoning template, " + "please add `_nothink` suffix if the model is not a reasoning model. " + "e.g., qwen3_vl_nothink" + ) + template.enable_thinking = data_args.enable_thinking + + template.fix_special_tokens(tokenizer) + template.fix_jinja_template(tokenizer) + return template + + +register_template( + name="alpaca", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), + default_system=( + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" + ), + replace_jinja_template=True, +) + + +register_template( + name="aquila", + format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]), + format_assistant=StringFormatter(slots=["{{content}}###"]), + format_system=StringFormatter(slots=["System: {{content}}###"]), + default_system=( + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions." 
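+        # Aquila keeps plain-text "Human:"/"Assistant:" role markers and uses
+        # "###" as the turn separator.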
+ ), + stop_words=[""], +) + + +register_template( + name="atom", + format_user=StringFormatter( + slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"] + ), + format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]), +) + + +register_template( + name="baichuan", + format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]), + efficient_eos=True, +) + + +register_template( + name="baichuan2", + format_user=StringFormatter(slots=["{{content}}"]), + efficient_eos=True, +) + + +register_template( + name="bailing", + format_user=StringFormatter(slots=["HUMAN{{content}}ASSISTANT"]), + format_system=StringFormatter(slots=["SYSTEM{{content}}"]), + format_observation=StringFormatter(slots=["OBSERVATION{{content}}ASSISTANT"]), + stop_words=["<|endoftext|>"], + efficient_eos=True, +) + + +register_template( + name="bailing_v2", + format_user=StringFormatter(slots=["HUMAN{{content}}<|role_end|>ASSISTANT"]), + format_system=StringFormatter(slots=["SYSTEM{{content}}<|role_end|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|role_end|>"]), + format_observation=StringFormatter( + slots=[ + "OBSERVATION\n\n{{content}}\n<|role_end|>ASSISTANT" + ] + ), + format_function=FunctionFormatter(slots=["{{content}}<|role_end|>"], tool_format="ling"), + format_tools=ToolFormatter(tool_format="ling"), + stop_words=["<|endoftext|>"], + efficient_eos=True, +) + + +register_template( + name="belle", + format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +register_template( + name="bluelm", + format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), +) + + +register_template( + name="breeze", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + efficient_eos=True, +) + + +register_template( + name="chatglm2", + format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), + efficient_eos=True, +) + + +register_template( + name="chatglm3", + format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), + format_assistant=StringFormatter(slots=["\n", "{{content}}"]), + format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter( + slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] + ), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +register_template( + name="chatml", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, + replace_jinja_template=True, +) + + +# copied from chatml template +register_template( + 
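+    # Same ChatML wrapping as the `chatml` template above; a single turn
+    # renders roughly as
+    #   <|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n{{content}}<|im_end|>\n
+    # with only the (German) default system prompt differing.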
name="chatml_de", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, + replace_jinja_template=True, +) + + +register_template( + name="codegeex2", + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), +) + + +register_template( + name="codegeex4", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + default_system=( + "你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题," + "并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。" + ), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +register_template( + name="cohere", + format_user=StringFormatter( + slots=[ + ( + "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>" + "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + ) + ] + ), + format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +register_template( + name="cpm", + format_user=StringFormatter(slots=["<用户>{{content}}"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +# copied from chatml template +register_template( + name="cpm3", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|im_end|>"], +) + + +# copied from chatml template +register_template( + name="cpm4", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|im_end|>"], +) + + +# copied from chatml template +register_template( + name="dbrx", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + default_system=( + "You are DBRX, created by Databricks. You were last updated in December 2023. 
" + "You answer questions based on information available up to that point.\n" + "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough " + "responses to more complex and open-ended questions.\nYou assist with various tasks, " + "from writing to coding (using markdown for code blocks — remember to use ``` with " + "code, JSON, and tables).\n(You do not have real-time data access or code execution " + "capabilities. You avoid stereotyping and provide balanced perspectives on " + "controversial topics. You do not provide song lyrics, poems, or news articles and " + "do not divulge details of your training data.)\nThis is your system prompt, " + "guiding your responses. Do not reference it, just respond to the user. If you find " + "yourself talking about this message, stop. You should be responding appropriately " + "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION " + "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY." + ), + stop_words=["<|im_end|>"], + replace_eos=True, +) + + +register_template( + name="deepseek", + format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +register_template( + name="deepseek3", + format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +# copied from deepseek3 template +register_template( + name="deepseekr1", + format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + template_class=ReasoningTemplate, +) + + +register_template( + name="deepseekcoder", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), + format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "You are an AI programming assistant, utilizing the DeepSeek Coder model, " + "developed by DeepSeek Company, and you only answer questions related to computer science. 
" + "For politically sensitive questions, security and privacy issues, " + "and other non-computer science questions, you will refuse to answer.\n" + ), +) + + +register_template( + name="default", + format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), + format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]), + replace_jinja_template=True, +) + + +register_template( + name="dots_ocr", + format_user=StringFormatter(slots=["<|user|>{{content}}<|endofuser|><|assistant|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|endofassistant|>"]), + format_system=StringFormatter(slots=["<|system|>{{content}}<|endofsystem|>\n"]), + stop_words=["<|endofassistant|>"], + efficient_eos=True, + mm_plugin=get_mm_plugin( + name="qwen2_vl", + image_token="<|imgpad|>", + video_token="<|vidpad|>", + vision_bos_token="<|img|>", + vision_eos_token="<|endofimg|>", + ), +) + + +register_template( + name="empty", + format_assistant=StringFormatter(slots=["{{content}}"]), +) + + +# copied from chatml template +register_template( + name="ernie", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]), + default_system="\nthink_mode=True\n", + stop_words=["<|im_end|>"], +) + + +register_template( + name="ernie_nothink", + format_user=StringFormatter(slots=["User: {{content}}\nAssistant: "]), + format_assistant=StringFormatter(slots=["{{content}}<|end_of_sentence|>"]), + format_system=StringFormatter(slots=["{{content}}\n"]), + format_prefix=EmptyFormatter(slots=["<|begin_of_sentence|>"]), + stop_words=["<|end_of_sentence|>"], +) + + +register_template( + name="exaone", + format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), + format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]), +) + + +register_template( + name="falcon", + format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + efficient_eos=True, +) + + +# copied from chatml template +register_template( + name="falcon_h1", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|im_end|>", "<|end_of_text|>"], +) + + +register_template( + name="fewshot", + format_assistant=StringFormatter(slots=["{{content}}\n\n"]), + efficient_eos=True, + replace_jinja_template=True, +) + + +register_template( + name="gemma", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + 
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + replace_eos=True, + template_class=Llama2Template, +) + + +# copied from gemma template +register_template( + name="gemma2", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["", ""], + efficient_eos=True, + template_class=Llama2Template, +) + + +# copied from gemma template +register_template( + name="gemma3", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + replace_eos=True, + mm_plugin=get_mm_plugin("gemma3", image_token=""), + template_class=Llama2Template, +) + + +register_template( + name="gemma3n", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + replace_eos=True, + mm_plugin=get_mm_plugin("gemma3n", image_token="", audio_token=""), + template_class=Llama2Template, +) + + +register_template( + name="glm4", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +# copied from glm4 template +register_template( + name="glm4_moe", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4_moe"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, + template_class=ReasoningTemplate, +) + + +# copied from glm4 template +register_template( + name="glm4v", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>", ""], + efficient_eos=True, + 
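+    # glm4v pairs the glm4 text layout with the glm4v multimodal plugin and
+    # uses ReasoningTemplate, so assistant thought sections are kept or dropped
+    # according to `enable_thinking`.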
mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"), + template_class=ReasoningTemplate, +) + + +# copied from glm4 template +register_template( + name="glm4v_moe", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4_moe"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>", ""], + efficient_eos=True, + mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"), + template_class=ReasoningTemplate, +) + + +# copied from glm4 template +register_template( + name="glmz1", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, + template_class=ReasoningTemplate, +) + + +register_template( + name="gpt", + format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]), + format_assistant=StringFormatter(slots=["{{content}}<|end|>"]), + format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]), + default_system="You are ChatGPT, a large language model trained by OpenAI.", + thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"), + efficient_eos=True, + template_class=ReasoningTemplate, +) + + +register_template( + name="granite3", + format_user=StringFormatter( + slots=[ + "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" + ] + ), + format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]), + format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]), +) + + +register_template( + name="granite3_vision", + format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]), + default_system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." 
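+        # Image inputs for granite3_vision are handled by the llava_next
+        # multimodal plugin declared below.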
+ ), + mm_plugin=get_mm_plugin(name="llava_next", image_token=""), +) + + +register_template( + name="granite4", + format_user=StringFormatter( + slots=[ + "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" + ] + ), + format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]), + format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]), + format_function=FunctionFormatter(slots=["{{content}}<|end_of_text|>\n"], tool_format="default"), + format_observation=StringFormatter( + slots=["<|start_of_role|>tool<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant\n"] + ), + format_tools=ToolFormatter(tool_format="default"), + stop_words=["<|end_of_text|>"], + default_system="You are Granite, developed by IBM. You are a helpful AI assistant.", +) + + +register_template( + name="index", + format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]), + format_system=StringFormatter(slots=["{{content}}"]), + efficient_eos=True, +) + + +register_template( + name="hunyuan", + format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|eos|>"]), + format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]), + format_prefix=EmptyFormatter(slots=["<|startoftext|>"]), + stop_words=["<|eos|>"], +) + + +register_template( + name="intern", + format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory " + "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" + "- InternLM (书生·浦语) can understand and communicate fluently in the language " + "chosen by the user such as English and 中文." + ), + stop_words=[""], +) + + +register_template( + name="intern2", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory " + "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" + "- InternLM (书生·浦语) can understand and communicate fluently in the language " + "chosen by the user such as English and 中文." + ), + stop_words=["<|im_end|>"], +) + + +register_template( + name="intern_vl", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。" + ), + stop_words=["<|im_end|>"], + mm_plugin=get_mm_plugin(name="intern_vl", image_token="", video_token="