diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..6aad062a9ead2d88a859379d5932cd3dc78a776b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+llamafactory/extras/__pycache__/constants.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index c4747101b23ead88704837d275618fe368198a4d..6cf72efb35b793371fac42a092678981ab5317a2 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,6 @@
---
-title: App
-emoji: 🏢
-colorFrom: blue
-colorTo: purple
+title: app
+app_file: webui.py
sdk: gradio
-sdk_version: 5.49.1
-app_file: app.py
-pinned: false
+sdk_version: 5.45.0
---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/api.py b/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..61215459ed91c6fa529a719cb9dac57223754d2e
--- /dev/null
+++ b/api.py
@@ -0,0 +1,33 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import uvicorn
+
+from llamafactory.api.app import create_app
+from llamafactory.chat import ChatModel
+
+
+def main():
+ chat_model = ChatModel()
+ app = create_app(chat_model)
+ api_host = os.getenv("API_HOST", "0.0.0.0")
+ api_port = int(os.getenv("API_PORT", "8000"))
+ print(f"Visit http://localhost:{api_port}/docs for API document.")
+ uvicorn.run(app, host=api_host, port=api_port)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llamafactory.egg-info/PKG-INFO b/llamafactory.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..b0a26490c345a22dc3e831d6325c1d3de71902c4
--- /dev/null
+++ b/llamafactory.egg-info/PKG-INFO
@@ -0,0 +1,1124 @@
+Metadata-Version: 2.4
+Name: llamafactory
+Version: 0.9.4.dev0
+Summary: Unified Efficient Fine-Tuning of 100+ LLMs
+Home-page: https://github.com/hiyouga/LLaMA-Factory
+Author: hiyouga
+Author-email: hiyouga@buaa.edu.cn
+License: Apache 2.0 License
+Keywords: AI,LLM,GPT,ChatGPT,Llama,Transformer,DeepSeek,Pytorch
+Classifier: Development Status :: 4 - Beta
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Education
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.9.0
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: transformers!=4.52.0,<=4.56.2,>=4.49.0; python_version < "3.10"
+Requires-Dist: transformers!=4.52.0,!=4.57.0,<=4.57.1,>=4.49.0; python_version >= "3.10"
+Requires-Dist: datasets<=4.0.0,>=2.16.0
+Requires-Dist: accelerate<=1.11.0,>=1.3.0
+Requires-Dist: peft<=0.17.1,>=0.14.0
+Requires-Dist: trl<=0.9.6,>=0.8.6
+Requires-Dist: gradio<=5.45.0,>=4.38.0
+Requires-Dist: matplotlib>=3.7.0
+Requires-Dist: tyro<0.9.0
+Requires-Dist: einops
+Requires-Dist: numpy<2.0.0
+Requires-Dist: pandas>=2.0.0
+Requires-Dist: scipy
+Requires-Dist: sentencepiece
+Requires-Dist: tiktoken
+Requires-Dist: modelscope>=1.14.0
+Requires-Dist: hf-transfer
+Requires-Dist: safetensors<=0.5.3
+Requires-Dist: fire
+Requires-Dist: omegaconf
+Requires-Dist: packaging
+Requires-Dist: protobuf
+Requires-Dist: pyyaml
+Requires-Dist: pydantic<=2.10.6
+Requires-Dist: uvicorn
+Requires-Dist: fastapi
+Requires-Dist: sse-starlette
+Requires-Dist: av
+Requires-Dist: librosa
+Requires-Dist: propcache!=0.4.0
+Provides-Extra: torch
+Requires-Dist: torch>=2.0.0; extra == "torch"
+Requires-Dist: torchvision>=0.15.0; extra == "torch"
+Provides-Extra: torch-npu
+Requires-Dist: torch==2.7.1; extra == "torch-npu"
+Requires-Dist: torch-npu==2.7.1; extra == "torch-npu"
+Requires-Dist: torchvision==0.22.1; extra == "torch-npu"
+Requires-Dist: decorator; extra == "torch-npu"
+Provides-Extra: metrics
+Requires-Dist: nltk; extra == "metrics"
+Requires-Dist: jieba; extra == "metrics"
+Requires-Dist: rouge-chinese; extra == "metrics"
+Provides-Extra: deepspeed
+Requires-Dist: deepspeed<=0.16.9,>=0.10.0; extra == "deepspeed"
+Provides-Extra: liger-kernel
+Requires-Dist: liger-kernel>=0.5.5; extra == "liger-kernel"
+Provides-Extra: bitsandbytes
+Requires-Dist: bitsandbytes>=0.39.0; extra == "bitsandbytes"
+Provides-Extra: hqq
+Requires-Dist: hqq; extra == "hqq"
+Provides-Extra: eetq
+Requires-Dist: eetq; extra == "eetq"
+Provides-Extra: gptq
+Requires-Dist: optimum>=1.24.0; extra == "gptq"
+Requires-Dist: gptqmodel>=2.0.0; extra == "gptq"
+Provides-Extra: aqlm
+Requires-Dist: aqlm[gpu]>=1.1.0; extra == "aqlm"
+Provides-Extra: vllm
+Requires-Dist: vllm<=0.11.0,>=0.4.3; extra == "vllm"
+Provides-Extra: sglang
+Requires-Dist: sglang[srt]>=0.4.5; extra == "sglang"
+Requires-Dist: transformers==4.51.1; extra == "sglang"
+Provides-Extra: galore
+Requires-Dist: galore-torch; extra == "galore"
+Provides-Extra: apollo
+Requires-Dist: apollo-torch; extra == "apollo"
+Provides-Extra: badam
+Requires-Dist: badam>=1.2.1; extra == "badam"
+Provides-Extra: adam-mini
+Requires-Dist: adam-mini; extra == "adam-mini"
+Provides-Extra: minicpm-v
+Requires-Dist: soundfile; extra == "minicpm-v"
+Requires-Dist: torchvision; extra == "minicpm-v"
+Requires-Dist: torchaudio; extra == "minicpm-v"
+Requires-Dist: vector_quantize_pytorch; extra == "minicpm-v"
+Requires-Dist: vocos; extra == "minicpm-v"
+Requires-Dist: msgpack; extra == "minicpm-v"
+Requires-Dist: referencing; extra == "minicpm-v"
+Requires-Dist: jsonschema_specifications; extra == "minicpm-v"
+Provides-Extra: openmind
+Requires-Dist: openmind; extra == "openmind"
+Provides-Extra: swanlab
+Requires-Dist: swanlab; extra == "swanlab"
+Provides-Extra: fp8
+Requires-Dist: torchao>=0.8.0; extra == "fp8"
+Requires-Dist: accelerate>=1.10.0; extra == "fp8"
+Provides-Extra: fp8-te
+Requires-Dist: transformer_engine[pytorch]>=2.0.0; extra == "fp8-te"
+Requires-Dist: accelerate>=1.10.0; extra == "fp8-te"
+Provides-Extra: fp8-all
+Requires-Dist: torchao>=0.8.0; extra == "fp8-all"
+Requires-Dist: transformer_engine[pytorch]>=2.0.0; extra == "fp8-all"
+Requires-Dist: accelerate>=1.10.0; extra == "fp8-all"
+Provides-Extra: dev
+Requires-Dist: pre-commit; extra == "dev"
+Requires-Dist: ruff; extra == "dev"
+Requires-Dist: pytest; extra == "dev"
+Requires-Dist: build; extra == "dev"
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: license-file
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
+
+
+
+[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
+[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
+[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
+[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
+[](https://pypi.org/project/llamafactory/)
+[](https://scholar.google.com/scholar?cites=12620864006390196564)
+[](https://hub.docker.com/r/hiyouga/llamafactory/tags)
+
+[](https://twitter.com/llamafactory_ai)
+[](https://discord.gg/rKfvV9r9FK)
+[](https://github.com/hiyouga/llamafactory-community)
+[](https://blog.llamafactory.net/en/)
+
+[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
+[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
+[](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
+[](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
+[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
+[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
+[](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
+
+### Used by [Amazon](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/), [NVIDIA](https://developer.nvidia.com/rtx/ai-toolkit), [Aliyun](https://help.aliyun.com/zh/pai/use-cases/fine-tune-a-llama-3-model-with-llama-factory), etc.
+
+
+
+### Supporters ❤️
+
+|
Warp, the agentic terminal for developers Available for MacOS, Linux, & Windows | |
+| ---- | ---- |
+
+----
+
+### Easily fine-tune 100+ large language models with zero-code [CLI](#quickstart) and [Web UI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
+
+
+
+
+
+👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
+
+\[ English | [中文](README_zh.md) \]
+
+**Fine-tuning a large language model can be easy as...**
+
+https://github.com/user-attachments/assets/3991a3a8-4276-4d30-9cab-4cb0c4b9b99e
+
+Start local training:
+- Please refer to [usage](#getting-started)
+
+Start cloud training:
+- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
+- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
+- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
+
+Read technical notes:
+- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
+- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
+- **Official Blog**: https://blog.llamafactory.net/en/
+- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
+
+> [!NOTE]
+> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
+
+## Table of Contents
+
+- [Features](#features)
+- [Blogs](#blogs)
+- [Changelog](#changelog)
+- [Supported Models](#supported-models)
+- [Supported Training Approaches](#supported-training-approaches)
+- [Provided Datasets](#provided-datasets)
+- [Requirement](#requirement)
+- [Getting Started](#getting-started)
+ - [Installation](#installation)
+ - [Data Preparation](#data-preparation)
+ - [Quickstart](#quickstart)
+ - [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
+ - [LLaMA Factory Online](#llama-factory-online)
+ - [Build Docker](#build-docker)
+ - [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
+ - [Download from ModelScope Hub](#download-from-modelscope-hub)
+ - [Download from Modelers Hub](#download-from-modelers-hub)
+ - [Use W&B Logger](#use-wb-logger)
+ - [Use SwanLab Logger](#use-swanlab-logger)
+- [Projects using LLaMA Factory](#projects-using-llama-factory)
+- [License](#license)
+- [Citation](#citation)
+- [Acknowledgement](#acknowledgement)
+
+## Features
+
+- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
+- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
+- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
+- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
+- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
+- **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
+- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
+- **Faster inference**: OpenAI-style API, Gradio UI and CLI with [vLLM worker](https://github.com/vllm-project/vllm) or [SGLang worker](https://github.com/sgl-project/sglang).
+
+### Day-N Support for Fine-Tuning Cutting-Edge Models
+
+| Support Date | Model Name |
+| ------------ | -------------------------------------------------------------------- |
+| Day 0 | Qwen3 / Qwen2.5-VL / Gemma 3 / GLM-4.1V / InternLM 3 / MiniCPM-o-2.6 |
+| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4 |
+
+## Blogs
+
+> [!TIP]
+> Now we have a dedicated blog for LLaMA Factory!
+>
+> Website: https://blog.llamafactory.net/en/
+
+- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
+- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
+- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
+- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
+- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
+
+All Blogs
+
+- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
+- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
+- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
+- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
+- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
+- [LLaMA Factory: Fine-tuning Llama3 for Role-Playing](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) (Chinese)
+
+
+
+## Changelog
+
+[25/10/26] We support Megatron-core training backend with [**mcore_adapter**](https://github.com/alibaba/ROLL/tree/main/mcore_adapter). See [PR #9237](https://github.com/hiyouga/LLaMA-Factory/pull/9237) to get started.
+
+[25/08/22] We supported **[OFT](https://arxiv.org/abs/2306.07280)** and **[OFTv2](https://arxiv.org/abs/2506.19847)**. See [examples](examples/README.md) for usage.
+
+[25/08/20] We supported fine-tuning the **[Intern-S1-mini](https://huggingface.co/internlm/Intern-S1-mini)** models. See [PR #8976](https://github.com/hiyouga/LLaMA-Factory/pull/8976) to get started.
+
+[25/08/06] We supported fine-tuning the **[GPT-OSS](https://github.com/openai/gpt-oss)** models. See [PR #8826](https://github.com/hiyouga/LLaMA-Factory/pull/8826) to get started.
+
+Full Changelog
+
+[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model.
+
+[25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.
+
+[25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR.
+
+[25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.
+
+[25/04/14] We supported fine-tuning the **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** and **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** models.
+
+[25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.
+
+[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
+
+[25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference.
+
+[25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model.
+
+[25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training.
+
+[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.
+
+[25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks.
+
+[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** models.
+
+[25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
+
+[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.
+
+[25/01/14] We supported fine-tuning the **[InternLM 3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
+
+[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
+
+[24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
+
+[24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset.
+
+[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
+
+[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
+
+[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
+
+[24/08/27] We supported **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
+
+[24/08/09] We supported **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
+
+[24/07/04] We supported [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
+
+[24/06/16] We supported **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
+
+[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
+
+[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
+
+[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `paligemma` template for chat completion.
+
+[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
+
+[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
+
+[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
+
+[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
+
+[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
+
+[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage.
+
+[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
+
+[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See [examples](examples/README.md) for usage.
+
+[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
+
+[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See [examples](examples/README.md) for usage.
+
+[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
+
+[24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. See [examples](examples/README.md) for usage.
+
+[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
+
+[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `use_dora: true` to activate DoRA training.
+
+[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See [examples](examples/README.md) for usage.
+
+[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
+
+[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall_en`.
+
+[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
+
+[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
+
+[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.
+
+[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
+
+[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `shift_attn: true` argument to enable shift short attention.
+
+[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [examples](examples/README.md) for usage.
+
+[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `flash_attn: fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
+
+[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `rope_scaling: linear` argument in training and `rope_scaling: dynamic` argument at inference to extrapolate the position embeddings.
+
+[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage.
+
+[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode.
+
+[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
+
+[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
+
+[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
+
+[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
+
+[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
+
+[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). See [examples](examples/README.md) for usage.
+
+
+
+> [!TIP]
+> If you cannot use the latest feature, please pull the latest code and install LLaMA-Factory again.
+
+## Supported Models
+
+| Model | Model size | Template |
+| ----------------------------------------------------------------- | -------------------------------- | -------------------- |
+| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
+| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
+| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
+| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
+| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
+| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
+| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 |
+| [GLM-4.1V](https://huggingface.co/zai-org) | 9B | glm4v |
+| [GLM-4.5/GLM-4.5V](https://huggingface.co/zai-org) | 106B/355B | glm4_moe/glm4v_moe |
+| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
+| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
+| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
+| [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
+| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
+| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
+| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
+| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
+| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
+| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
+| [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
+| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [Llama 4](https://huggingface.co/meta-llama) | 109B/402B | llama4 |
+| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
+| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
+| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
+| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
+| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
+| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
+| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
+| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
+| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
+| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
+| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
+| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
+| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
+| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
+| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
+| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
+| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
+| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
+| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
+
+> [!NOTE]
+> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
+>
+> If the model has both reasoning and non-reasoning versions, please use the `_nothink` suffix to distinguish between them. For example, `qwen3` and `qwen3_nothink`.
+>
+> Remember to use the **SAME** template in training and inference.
+>
+> \*: You should install the `transformers` from main branch and use `DISABLE_VERSION_CHECK=1` to skip version check.
+>
+> \*\*: You need to install a specific version of `transformers` to use the corresponding model.
+
+Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of models we supported.
+
+You also can add a custom chat template to [template.py](src/llamafactory/data/template.py).
+
+## Supported Training Approaches
+
+| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA | OFT | QOFT |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
+| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+
+> [!TIP]
+> The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
+
+## Provided Datasets
+
+Pre-training datasets
+
+- [Wiki Demo (en)](data/wiki_demo.txt)
+- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
+- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
+- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
+- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
+- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
+- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
+- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
+- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
+- [CCI3-HQ (zh)](https://huggingface.co/datasets/BAAI/CCI3-HQ)
+- [CCI3-Data (zh)](https://huggingface.co/datasets/BAAI/CCI3-Data)
+- [CCI4.0-M2-Base-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-Base-v1)
+- [CCI4.0-M2-CoT-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-CoT-v1)
+- [CCI4.0-M2-Extra-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-Extra-v1)
+- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
+- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
+
+
+
+Supervised fine-tuning datasets
+
+- [Identity (en&zh)](data/identity.json)
+- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
+- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
+- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
+- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
+- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
+- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
+- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
+- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+- [UltraChat (en)](https://github.com/thunlp/UltraChat)
+- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
+- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
+- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
+- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
+- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
+- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
+- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
+- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
+- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
+- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
+- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
+- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
+- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
+- [Infinity Instruct (zh)](https://huggingface.co/datasets/BAAI/Infinity-Instruct)
+- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
+- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
+- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
+- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
+- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
+- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
+- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
+- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
+- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
+- [OpenO1-SFT (en&zh)](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)
+- [Open-Thoughts (en)](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k)
+- [Open-R1-Math (en)](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k)
+- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
+- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
+- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
+- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
+- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
+- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
+- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
+- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
+- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
+- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
+- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
+- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
+
+
+
+Preference datasets
+
+- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
+- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+- [COIG-P (zh)](https://huggingface.co/datasets/m-a-p/COIG-P)
+- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
+- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
+- [RLAIF-V (en)](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset)
+- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
+- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
+- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
+
+
+
+Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
+
+```bash
+pip install "huggingface_hub<1.0.0"
+huggingface-cli login
+```
+
+## Requirement
+
+| Mandatory | Minimum | Recommend |
+| ------------ | ------- | --------- |
+| python | 3.9 | 3.10 |
+| torch | 2.0.0 | 2.6.0 |
+| torchvision | 0.15.0 | 0.21.0 |
+| transformers | 4.49.0 | 4.50.0 |
+| datasets | 2.16.0 | 3.2.0 |
+| accelerate | 0.34.0 | 1.2.1 |
+| peft | 0.14.0 | 0.15.1 |
+| trl | 0.8.6 | 0.9.6 |
+
+| Optional | Minimum | Recommend |
+| ------------ | ------- | --------- |
+| CUDA | 11.6 | 12.2 |
+| deepspeed | 0.10.0 | 0.16.4 |
+| bitsandbytes | 0.39.0 | 0.43.1 |
+| vllm | 0.4.3 | 0.8.2 |
+| flash-attn | 2.5.6 | 2.7.2 |
+
+### Hardware Requirement
+
+\* *estimated*
+
+| Method | Bits | 7B | 14B | 30B | 70B | `x`B |
+| ----------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
+| Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
+| Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
+| Freeze/LoRA/GaLore/APOLLO/BAdam/OFT | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
+| QLoRA / QOFT | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
+| QLoRA / QOFT | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
+| QLoRA / QOFT | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
+
+## Getting Started
+
+### Installation
+
+> [!IMPORTANT]
+> Installation is mandatory.
+
+#### Install from Source
+
+```bash
+git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
+cd LLaMA-Factory
+pip install -e ".[torch,metrics]" --no-build-isolation
+```
+
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev
+
+#### Install from Docker Image
+
+```bash
+docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
+```
+
+This image is built on Ubuntu 22.04 (x86\_64), CUDA 12.4, Python 3.11, PyTorch 2.6.0, and Flash-attn 2.7.4.
+
+Find the pre-built images: https://hub.docker.com/r/hiyouga/llamafactory/tags
+
+Please refer to [build docker](#build-docker) to build the image yourself.
+
+Setting up a virtual environment with uv
+
+Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):
+
+```bash
+uv sync --extra torch --extra metrics --prerelease=allow
+```
+
+Run LLaMA-Factory in the isolated environment:
+
+```bash
+uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
+```
+
+
+
+For Windows users
+
+#### Install PyTorch
+
+You need to manually install the GPU version of PyTorch on the Windows platform. Please refer to the [official website](https://pytorch.org/get-started/locally/) and the following command to install PyTorch with CUDA support:
+
+```bash
+pip uninstall torch torchvision torchaudio
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
+python -c "import torch; print(torch.cuda.is_available())"
+```
+
+If you see `True` then you have successfully installed PyTorch with CUDA support.
+
+Try `dataloader_num_workers: 0` if you encounter `Can't pickle local object` error.
+
+#### Install BitsAndBytes
+
+If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.2, please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
+
+```bash
+pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
+```
+
+#### Install Flash Attention-2
+
+To enable FlashAttention-2 on the Windows platform, please use the script from [flash-attention-windows-wheel](https://huggingface.co/lldacing/flash-attention-windows-wheel) to compile and install it by yourself.
+
+
+
+For Ascend NPU users
+
+To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+
+```bash
+# replace the url according to your CANN version and devices
+# install CANN Toolkit
+wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
+bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
+
+# install CANN Kernels
+wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
+bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
+
+# set env variables
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+```
+
+| Requirement | Minimum | Recommend |
+| ------------ | ------- | -------------- |
+| CANN | 8.0.RC1 | 8.0.0.alpha002 |
+| torch | 2.1.0 | 2.4.0 |
+| torch-npu | 2.1.0 | 2.4.0.post2 |
+| deepspeed | 0.13.2 | 0.13.2 |
+| vllm-ascend | - | 0.7.3 |
+
+Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
+
+If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
+
+Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
+
+#### Install BitsAndBytes
+
+To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
+
+1. Manually compile bitsandbytes: Refer to [the installation documentation](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU) for the NPU version of bitsandbytes to complete the compilation and installation. The compilation requires a cmake version of at least 3.22.1 and a g++ version of at least 12.x.
+
+```bash
+# Install bitsandbytes from source
+# Clone bitsandbytes repo, Ascend NPU backend is currently enabled on multi-backend-refactor branch
+git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git
+cd bitsandbytes/
+
+# Install dependencies
+pip install -r requirements-dev.txt
+
+# Install the dependencies for the compilation tools. Note that the commands for this step may vary depending on the operating system. The following are provided for reference
+apt-get install -y build-essential cmake
+
+# Compile & install
+cmake -DCOMPUTE_BACKEND=npu -S .
+make
+pip install .
+```
+
+2. Install transformers from the main branch.
+
+```bash
+git clone -b main https://github.com/huggingface/transformers.git
+cd transformers
+pip install .
+```
+
+3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml).
+
+
+
+### Data Preparation
+
+Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can use datasets on HuggingFace / ModelScope / Modelers hub, load the dataset in local disk, or specify a path to s3/gcs cloud storage.
+
+> [!NOTE]
+> Please update `data/dataset_info.json` to use your custom dataset.
+
+You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**, **[DataFlow](https://github.com/OpenDCAI/DataFlow)** and **[GraphGen](https://github.com/open-sciencelab/GraphGen)** to create synthetic data for fine-tuning.
+
+### Quickstart
+
+Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+```
+
+See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
+
+> [!TIP]
+> Use `llamafactory-cli help` to show help information.
+>
+> Read [FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614) first if you encounter any problems.
+
+### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
+
+```bash
+llamafactory-cli webui
+```
+
+### LLaMA Factory Online
+
+Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
+
+### Build Docker
+
+For CUDA users:
+
+```bash
+cd docker/docker-cuda/
+docker compose up -d
+docker compose exec llamafactory bash
+```
+
+For Ascend NPU users:
+
+```bash
+cd docker/docker-npu/
+docker compose up -d
+docker compose exec llamafactory bash
+```
+
+For AMD ROCm users:
+
+```bash
+cd docker/docker-rocm/
+docker compose up -d
+docker compose exec llamafactory bash
+```
+
+Build without Docker Compose
+
+For CUDA users:
+
+```bash
+docker build -f ./docker/docker-cuda/Dockerfile \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ --build-arg EXTRAS=metrics \
+ -t llamafactory:latest .
+
+docker run -dit --ipc=host --gpus=all \
+ -p 7860:7860 \
+ -p 8000:8000 \
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
+```
+
+For Ascend NPU users:
+
+```bash
+docker build -f ./docker/docker-npu/Dockerfile \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ --build-arg EXTRAS=torch-npu,metrics \
+ -t llamafactory:latest .
+
+docker run -dit --ipc=host \
+ -v /usr/local/dcmi:/usr/local/dcmi \
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+ -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
+ -p 7860:7860 \
+ -p 8000:8000 \
+ --device /dev/davinci0 \
+ --device /dev/davinci_manager \
+ --device /dev/devmm_svm \
+ --device /dev/hisi_hdc \
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
+```
+
+For AMD ROCm users:
+
+```bash
+docker build -f ./docker/docker-rocm/Dockerfile \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ --build-arg EXTRAS=metrics \
+ -t llamafactory:latest .
+
+docker run -dit --ipc=host \
+ -p 7860:7860 \
+ -p 8000:8000 \
+ --device /dev/kfd \
+ --device /dev/dri \
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
+```
+
+
+
+Use Docker volumes
+
+You can uncomment `VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]` in the Dockerfile to use data volumes.
+
+When building the Docker image, use `-v ./hf_cache:/root/.cache/huggingface` argument to mount the local directory to the container. The following data volumes are available.
+
+- `hf_cache`: Utilize Hugging Face cache on the host machine.
+- `shared_data`: The directionary to store datasets on the host machine.
+- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
+
+
+
+### Deploy with OpenAI-style API and vLLM
+
+```bash
+API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
+```
+
+> [!TIP]
+> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
+>
+> Examples: [Image understanding](scripts/api_example/test_image.py) | [Function calling](scripts/api_example/test_toolcall.py)
+
+### Download from ModelScope Hub
+
+If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
+
+```bash
+export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
+```
+
+Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
+
+### Download from Modelers Hub
+
+You can also use Modelers Hub to download models and datasets.
+
+```bash
+export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
+```
+
+Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
+
+### Use W&B Logger
+
+To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
+
+```yaml
+report_to: wandb
+run_name: test_run # optional
+```
+
+Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
+
+### Use SwanLab Logger
+
+To use [SwanLab](https://github.com/SwanHubX/SwanLab) for logging experimental results, you need to add the following arguments to yaml files.
+
+```yaml
+use_swanlab: true
+swanlab_run_name: test_run # optional
+```
+
+When launching training tasks, you can log in to SwanLab in three ways:
+
+1. Add `swanlab_api_key=` to the yaml file, and set it to your [API key](https://swanlab.cn/settings).
+2. Set the environment variable `SWANLAB_API_KEY` to your [API key](https://swanlab.cn/settings).
+3. Use the `swanlab login` command to complete the login.
+
+## Projects using LLaMA Factory
+
+If you have a project that should be incorporated, please contact via email or create a pull request.
+
+Click to show
+
+1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
+1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
+1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
+1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
+1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
+1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
+1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
+1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
+1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
+1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
+1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
+1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
+1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
+1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
+1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
+1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
+1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
+1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
+1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
+1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
+1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
+1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
+1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
+1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
+1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
+1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
+1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
+1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
+1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
+1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
+1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
+1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
+1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
+1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
+1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
+1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
+1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
+1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
+1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
+1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
+1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
+1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
+1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
+1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
+1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
+1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
+1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
+1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
+1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
+1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
+1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
+1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
+1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
+1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
+1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
+1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
+1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
+1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
+1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
+1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
+1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
+1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
+1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
+1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
+1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
+1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
+1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
+1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
+1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
+1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
+1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
+1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
+1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
+1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
+1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
+1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
+1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
+1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
+1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
+1. Zhang et al. CPsyCoun: A Report-based Multi-turn Dialogue Reconstruction and Evaluation Framework for Chinese Psychological Counseling. ACL 2024. [[paper]](https://aclanthology.org/2024.findings-acl.830.pdf)
+1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
+1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
+1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
+1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
+1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
+1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
+1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
+1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
+1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
+1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
+1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
+1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention.
+1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: An o1-like model fine-tuned by NovaSky AI with very small cost.
+1. **[WeClone](https://github.com/xming521/WeClone)**: One-stop solution for creating your digital avatar from chat logs.
+1. **[EmoLLM](https://github.com/SmartFlowAI/EmoLLM)**: A project about large language models (LLMs) and mental health.
+
+
+## License
+
+This repository is licensed under the [Apache-2.0 License](LICENSE).
+
+Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Llama 4](https://github.com/meta-llama/llama-models/blob/main/models/llama4/LICENSE) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+
+## Citation
+
+If this work is helpful, please kindly cite as:
+
+```bibtex
+@inproceedings{zheng2024llamafactory,
+ title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
+ author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
+ booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
+ address={Bangkok, Thailand},
+ publisher={Association for Computational Linguistics},
+ year={2024},
+ url={http://arxiv.org/abs/2403.13372}
+}
+```
+
+## Acknowledgement
+
+This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works.
+
+## Star History
+
+
diff --git a/llamafactory.egg-info/SOURCES.txt b/llamafactory.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fb4a46244940cfc2d52cfdfa74064bcb07bf2aa8
--- /dev/null
+++ b/llamafactory.egg-info/SOURCES.txt
@@ -0,0 +1,178 @@
+LICENSE
+MANIFEST.in
+README.md
+pyproject.toml
+requirements.txt
+setup.py
+src/llamafactory/__init__.py
+src/llamafactory/cli.py
+src/llamafactory/launcher.py
+src/llamafactory.egg-info/PKG-INFO
+src/llamafactory.egg-info/SOURCES.txt
+src/llamafactory.egg-info/dependency_links.txt
+src/llamafactory.egg-info/entry_points.txt
+src/llamafactory.egg-info/requires.txt
+src/llamafactory.egg-info/top_level.txt
+src/llamafactory/api/__init__.py
+src/llamafactory/api/app.py
+src/llamafactory/api/chat.py
+src/llamafactory/api/common.py
+src/llamafactory/api/protocol.py
+src/llamafactory/chat/__init__.py
+src/llamafactory/chat/base_engine.py
+src/llamafactory/chat/chat_model.py
+src/llamafactory/chat/hf_engine.py
+src/llamafactory/chat/kt_engine.py
+src/llamafactory/chat/sglang_engine.py
+src/llamafactory/chat/vllm_engine.py
+src/llamafactory/data/__init__.py
+src/llamafactory/data/collator.py
+src/llamafactory/data/converter.py
+src/llamafactory/data/data_utils.py
+src/llamafactory/data/formatter.py
+src/llamafactory/data/loader.py
+src/llamafactory/data/mm_plugin.py
+src/llamafactory/data/parser.py
+src/llamafactory/data/template.py
+src/llamafactory/data/tool_utils.py
+src/llamafactory/data/processor/__init__.py
+src/llamafactory/data/processor/feedback.py
+src/llamafactory/data/processor/pairwise.py
+src/llamafactory/data/processor/pretrain.py
+src/llamafactory/data/processor/processor_utils.py
+src/llamafactory/data/processor/supervised.py
+src/llamafactory/data/processor/unsupervised.py
+src/llamafactory/eval/__init__.py
+src/llamafactory/eval/evaluator.py
+src/llamafactory/eval/template.py
+src/llamafactory/extras/__init__.py
+src/llamafactory/extras/constants.py
+src/llamafactory/extras/env.py
+src/llamafactory/extras/logging.py
+src/llamafactory/extras/misc.py
+src/llamafactory/extras/packages.py
+src/llamafactory/extras/ploting.py
+src/llamafactory/hparams/__init__.py
+src/llamafactory/hparams/data_args.py
+src/llamafactory/hparams/evaluation_args.py
+src/llamafactory/hparams/finetuning_args.py
+src/llamafactory/hparams/generating_args.py
+src/llamafactory/hparams/model_args.py
+src/llamafactory/hparams/parser.py
+src/llamafactory/hparams/training_args.py
+src/llamafactory/model/__init__.py
+src/llamafactory/model/adapter.py
+src/llamafactory/model/loader.py
+src/llamafactory/model/patcher.py
+src/llamafactory/model/model_utils/__init__.py
+src/llamafactory/model/model_utils/attention.py
+src/llamafactory/model/model_utils/checkpointing.py
+src/llamafactory/model/model_utils/embedding.py
+src/llamafactory/model/model_utils/ktransformers.py
+src/llamafactory/model/model_utils/kv_cache.py
+src/llamafactory/model/model_utils/liger_kernel.py
+src/llamafactory/model/model_utils/longlora.py
+src/llamafactory/model/model_utils/misc.py
+src/llamafactory/model/model_utils/mod.py
+src/llamafactory/model/model_utils/moe.py
+src/llamafactory/model/model_utils/packing.py
+src/llamafactory/model/model_utils/quantization.py
+src/llamafactory/model/model_utils/rope.py
+src/llamafactory/model/model_utils/unsloth.py
+src/llamafactory/model/model_utils/valuehead.py
+src/llamafactory/model/model_utils/visual.py
+src/llamafactory/third_party/__init__.py
+src/llamafactory/third_party/muon/__init__.py
+src/llamafactory/third_party/muon/muon.py
+src/llamafactory/train/__init__.py
+src/llamafactory/train/callbacks.py
+src/llamafactory/train/fp8_utils.py
+src/llamafactory/train/test_utils.py
+src/llamafactory/train/trainer_utils.py
+src/llamafactory/train/tuner.py
+src/llamafactory/train/dpo/__init__.py
+src/llamafactory/train/dpo/trainer.py
+src/llamafactory/train/dpo/workflow.py
+src/llamafactory/train/ksft/__init__.py
+src/llamafactory/train/ksft/workflow.py
+src/llamafactory/train/kto/__init__.py
+src/llamafactory/train/kto/trainer.py
+src/llamafactory/train/kto/workflow.py
+src/llamafactory/train/mca/__init__.py
+src/llamafactory/train/mca/trainer.py
+src/llamafactory/train/mca/workflow.py
+src/llamafactory/train/ppo/__init__.py
+src/llamafactory/train/ppo/ppo_utils.py
+src/llamafactory/train/ppo/trainer.py
+src/llamafactory/train/ppo/workflow.py
+src/llamafactory/train/pt/__init__.py
+src/llamafactory/train/pt/trainer.py
+src/llamafactory/train/pt/workflow.py
+src/llamafactory/train/rm/__init__.py
+src/llamafactory/train/rm/metric.py
+src/llamafactory/train/rm/trainer.py
+src/llamafactory/train/rm/workflow.py
+src/llamafactory/train/sft/__init__.py
+src/llamafactory/train/sft/metric.py
+src/llamafactory/train/sft/trainer.py
+src/llamafactory/train/sft/workflow.py
+src/llamafactory/v1/__init__.py
+src/llamafactory/v1/launcher.py
+src/llamafactory/v1/config/__init__.py
+src/llamafactory/v1/config/data_args.py
+src/llamafactory/v1/config/model_args.py
+src/llamafactory/v1/config/parser.py
+src/llamafactory/v1/config/sample_args.py
+src/llamafactory/v1/config/training_args.py
+src/llamafactory/v1/core/__init__.py
+src/llamafactory/v1/core/base_trainer.py
+src/llamafactory/v1/core/chat_sampler.py
+src/llamafactory/v1/core/data_engine.py
+src/llamafactory/v1/core/model_engine.py
+src/llamafactory/v1/plugins/__init__.py
+src/llamafactory/v1/plugins/data_plugins/__init__.py
+src/llamafactory/v1/plugins/data_plugins/converter.py
+src/llamafactory/v1/plugins/data_plugins/loader.py
+src/llamafactory/v1/plugins/data_plugins/template.py
+src/llamafactory/v1/plugins/model_plugins/__init__.py
+src/llamafactory/v1/plugins/model_plugins/added_token.py
+src/llamafactory/v1/plugins/model_plugins/peft.py
+src/llamafactory/v1/plugins/model_plugins/kernels/__init__.py
+src/llamafactory/v1/plugins/model_plugins/kernels/constants.py
+src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
+src/llamafactory/v1/plugins/model_plugins/kernels/fa/__init__.py
+src/llamafactory/v1/plugins/model_plugins/kernels/mlp/__init__.py
+src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
+src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
+src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/__init__.py
+src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
+src/llamafactory/v1/plugins/model_plugins/kernels/rope/__init__.py
+src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
+src/llamafactory/v1/plugins/sampler_plugins/__init__.py
+src/llamafactory/v1/plugins/sampler_plugins/vllm.py
+src/llamafactory/v1/plugins/trainer_plugins/__init__.py
+src/llamafactory/v1/plugins/trainer_plugins/distributed/__init__.py
+src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py
+src/llamafactory/v1/trainers/__init__.py
+src/llamafactory/v1/trainers/dpo_trainer.py
+src/llamafactory/v1/trainers/rm_trainer.py
+src/llamafactory/v1/trainers/sft_trainer.py
+src/llamafactory/webui/__init__.py
+src/llamafactory/webui/chatter.py
+src/llamafactory/webui/common.py
+src/llamafactory/webui/control.py
+src/llamafactory/webui/css.py
+src/llamafactory/webui/engine.py
+src/llamafactory/webui/interface.py
+src/llamafactory/webui/locales.py
+src/llamafactory/webui/manager.py
+src/llamafactory/webui/runner.py
+src/llamafactory/webui/components/__init__.py
+src/llamafactory/webui/components/chatbot.py
+src/llamafactory/webui/components/data.py
+src/llamafactory/webui/components/eval.py
+src/llamafactory/webui/components/export.py
+src/llamafactory/webui/components/footer.py
+src/llamafactory/webui/components/infer.py
+src/llamafactory/webui/components/top.py
+src/llamafactory/webui/components/train.py
\ No newline at end of file
diff --git a/llamafactory.egg-info/dependency_links.txt b/llamafactory.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/llamafactory.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/llamafactory.egg-info/entry_points.txt b/llamafactory.egg-info/entry_points.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bcafd0771096e3e65f55808b2ead2909c6f6691b
--- /dev/null
+++ b/llamafactory.egg-info/entry_points.txt
@@ -0,0 +1,3 @@
+[console_scripts]
+llamafactory-cli = llamafactory.cli:main
+lmf = llamafactory.cli:main
diff --git a/llamafactory.egg-info/requires.txt b/llamafactory.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2fb93071d9a6df49004f3c31adc7c4e77bee4cd8
--- /dev/null
+++ b/llamafactory.egg-info/requires.txt
@@ -0,0 +1,125 @@
+datasets<=4.0.0,>=2.16.0
+accelerate<=1.11.0,>=1.3.0
+peft<=0.17.1,>=0.14.0
+trl<=0.9.6,>=0.8.6
+gradio<=5.45.0,>=4.38.0
+matplotlib>=3.7.0
+tyro<0.9.0
+einops
+numpy<2.0.0
+pandas>=2.0.0
+scipy
+sentencepiece
+tiktoken
+modelscope>=1.14.0
+hf-transfer
+safetensors<=0.5.3
+fire
+omegaconf
+packaging
+protobuf
+pyyaml
+pydantic<=2.10.6
+uvicorn
+fastapi
+sse-starlette
+av
+librosa
+propcache!=0.4.0
+
+[:python_version < "3.10"]
+transformers!=4.52.0,<=4.56.2,>=4.49.0
+
+[:python_version >= "3.10"]
+transformers!=4.52.0,!=4.57.0,<=4.57.1,>=4.49.0
+
+[adam-mini]
+adam-mini
+
+[apollo]
+apollo-torch
+
+[aqlm]
+aqlm[gpu]>=1.1.0
+
+[badam]
+badam>=1.2.1
+
+[bitsandbytes]
+bitsandbytes>=0.39.0
+
+[deepspeed]
+deepspeed<=0.16.9,>=0.10.0
+
+[dev]
+pre-commit
+ruff
+pytest
+build
+
+[eetq]
+eetq
+
+[fp8]
+torchao>=0.8.0
+accelerate>=1.10.0
+
+[fp8-all]
+torchao>=0.8.0
+transformer_engine[pytorch]>=2.0.0
+accelerate>=1.10.0
+
+[fp8-te]
+transformer_engine[pytorch]>=2.0.0
+accelerate>=1.10.0
+
+[galore]
+galore-torch
+
+[gptq]
+optimum>=1.24.0
+gptqmodel>=2.0.0
+
+[hqq]
+hqq
+
+[liger-kernel]
+liger-kernel>=0.5.5
+
+[metrics]
+nltk
+jieba
+rouge-chinese
+
+[minicpm_v]
+soundfile
+torchvision
+torchaudio
+vector_quantize_pytorch
+vocos
+msgpack
+referencing
+jsonschema_specifications
+
+[openmind]
+openmind
+
+[sglang]
+sglang[srt]>=0.4.5
+transformers==4.51.1
+
+[swanlab]
+swanlab
+
+[torch]
+torch>=2.0.0
+torchvision>=0.15.0
+
+[torch-npu]
+torch==2.7.1
+torch-npu==2.7.1
+torchvision==0.22.1
+decorator
+
+[vllm]
+vllm<=0.11.0,>=0.4.3
diff --git a/llamafactory.egg-info/top_level.txt b/llamafactory.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6670a28d28bf8b6497d5cca0e63d1e13c1aee55
--- /dev/null
+++ b/llamafactory.egg-info/top_level.txt
@@ -0,0 +1 @@
+llamafactory
diff --git a/llamafactory/__init__.py b/llamafactory/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1567ef572714881cc464db25d3da3d08a460963
--- /dev/null
+++ b/llamafactory/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Efficient fine-tuning of large language models.
+
+Level:
+ api, webui > chat, eval, train > data, model > hparams > extras
+
+Disable version checking: DISABLE_VERSION_CHECK=1
+Enable VRAM recording: RECORD_VRAM=1
+Force using torchrun: FORCE_TORCHRUN=1
+Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
+Use modelscope: USE_MODELSCOPE_HUB=1
+Use openmind: USE_OPENMIND_HUB=1
+"""
+
+from .extras.env import VERSION
+
+
+__version__ = VERSION
diff --git a/llamafactory/__pycache__/__init__.cpython-312.pyc b/llamafactory/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae97776ebf1e10b36d5f470d1c0de9c4f5b17797
Binary files /dev/null and b/llamafactory/__pycache__/__init__.cpython-312.pyc differ
diff --git a/llamafactory/__pycache__/cli.cpython-312.pyc b/llamafactory/__pycache__/cli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95db2f247b7b2fe6595a5f7cc774ce0f6f472c11
Binary files /dev/null and b/llamafactory/__pycache__/cli.cpython-312.pyc differ
diff --git a/llamafactory/__pycache__/launcher.cpython-312.pyc b/llamafactory/__pycache__/launcher.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28aab819e043e9a2500c50c92f460fa7d87533cd
Binary files /dev/null and b/llamafactory/__pycache__/launcher.cpython-312.pyc differ
diff --git a/llamafactory/api/__init__.py b/llamafactory/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llamafactory/api/app.py b/llamafactory/api/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0621d80b064f00970e8fb58909ec8656ba0fb6b
--- /dev/null
+++ b/llamafactory/api/app.py
@@ -0,0 +1,133 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from contextlib import asynccontextmanager
+from functools import partial
+from typing import Annotated, Optional
+
+from ..chat import ChatModel
+from ..extras.constants import EngineName
+from ..extras.misc import torch_gc
+from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
+from .chat import (
+ create_chat_completion_response,
+ create_score_evaluation_response,
+ create_stream_chat_completion_response,
+)
+from .protocol import (
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ ModelCard,
+ ModelList,
+ ScoreEvaluationRequest,
+ ScoreEvaluationResponse,
+)
+
+
+if is_fastapi_available():
+ from fastapi import Depends, FastAPI, HTTPException, status
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
+
+
+if is_starlette_available():
+ from sse_starlette import EventSourceResponse
+
+
+if is_uvicorn_available():
+ import uvicorn
+
+
+async def sweeper() -> None:
+ while True:
+ torch_gc()
+ await asyncio.sleep(300)
+
+
+@asynccontextmanager
+async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
+ if chat_model.engine.name == EngineName.HF:
+ asyncio.create_task(sweeper())
+
+ yield
+ torch_gc()
+
+
+def create_app(chat_model: "ChatModel") -> "FastAPI":
+ root_path = os.getenv("FASTAPI_ROOT_PATH", "")
+ app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+ api_key = os.getenv("API_KEY")
+ security = HTTPBearer(auto_error=False)
+
+ async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
+ if api_key and (auth is None or auth.credentials != api_key):
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
+
+ @app.get(
+ "/v1/models",
+ response_model=ModelList,
+ status_code=status.HTTP_200_OK,
+ dependencies=[Depends(verify_api_key)],
+ )
+ async def list_models():
+ model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
+ return ModelList(data=[model_card])
+
+ @app.post(
+ "/v1/chat/completions",
+ response_model=ChatCompletionResponse,
+ status_code=status.HTTP_200_OK,
+ dependencies=[Depends(verify_api_key)],
+ )
+ async def create_chat_completion(request: ChatCompletionRequest):
+ if not chat_model.engine.can_generate:
+ raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+ if request.stream:
+ generate = create_stream_chat_completion_response(request, chat_model)
+ return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
+ else:
+ return await create_chat_completion_response(request, chat_model)
+
+ @app.post(
+ "/v1/score/evaluation",
+ response_model=ScoreEvaluationResponse,
+ status_code=status.HTTP_200_OK,
+ dependencies=[Depends(verify_api_key)],
+ )
+ async def create_score_evaluation(request: ScoreEvaluationRequest):
+ if chat_model.engine.can_generate:
+ raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+ return await create_score_evaluation_response(request, chat_model)
+
+ return app
+
+
+def run_api() -> None:
+ chat_model = ChatModel()
+ app = create_app(chat_model)
+ api_host = os.getenv("API_HOST", "0.0.0.0")
+ api_port = int(os.getenv("API_PORT", "8000"))
+ print(f"Visit http://localhost:{api_port}/docs for API document.")
+ uvicorn.run(app, host=api_host, port=api_port)
diff --git a/llamafactory/api/chat.py b/llamafactory/api/chat.py
new file mode 100644
index 0000000000000000000000000000000000000000..93236c5ca865492f0c45e1f5ab56a389875350ea
--- /dev/null
+++ b/llamafactory/api/chat.py
@@ -0,0 +1,291 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import io
+import json
+import os
+import re
+import uuid
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING, Optional
+
+from ..data import Role as DataRole
+from ..extras import logging
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.misc import is_env_enabled
+from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
+from .common import check_lfi_path, check_ssrf_url, dictify, jsonify
+from .protocol import (
+ ChatCompletionMessage,
+ ChatCompletionResponse,
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseUsage,
+ ChatCompletionStreamResponse,
+ ChatCompletionStreamResponseChoice,
+ Finish,
+ Function,
+ FunctionCall,
+ Role,
+ ScoreEvaluationResponse,
+)
+
+
+if is_fastapi_available():
+ from fastapi import HTTPException, status
+
+
+if is_pillow_available():
+ from PIL import Image
+
+
+if is_requests_available():
+ import requests
+
+
+if TYPE_CHECKING:
+ from ..chat import ChatModel
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+ from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
+
+
+logger = logging.get_logger(__name__)
+ROLE_MAPPING = {
+ Role.USER: DataRole.USER.value,
+ Role.ASSISTANT: DataRole.ASSISTANT.value,
+ Role.SYSTEM: DataRole.SYSTEM.value,
+ Role.FUNCTION: DataRole.FUNCTION.value,
+ Role.TOOL: DataRole.OBSERVATION.value,
+}
+
+
+def _process_request(
+ request: "ChatCompletionRequest",
+) -> tuple[
+ list[dict[str, str]],
+ Optional[str],
+ Optional[str],
+ Optional[list["ImageInput"]],
+ Optional[list["VideoInput"]],
+ Optional[list["AudioInput"]],
+]:
+ if is_env_enabled("API_VERBOSE", "1"):
+ logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
+
+ if len(request.messages) == 0:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
+
+ if request.messages[0].role == Role.SYSTEM:
+ content = request.messages.pop(0).content
+ system = content[0].text if isinstance(content, list) else content
+ else:
+ system = None
+
+ if len(request.messages) % 2 == 0:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
+
+ input_messages = []
+ images, videos, audios = [], [], []
+ for i, message in enumerate(request.messages):
+ if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+ elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+
+ if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
+ tool_calls = [
+ {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
+ for tool_call in message.tool_calls
+ ]
+ content = json.dumps(tool_calls, ensure_ascii=False)
+ input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
+ elif isinstance(message.content, list):
+ text_content = ""
+ for input_item in message.content:
+ if input_item.type == "text":
+ text_content += input_item.text
+ elif input_item.type == "image_url":
+ text_content += IMAGE_PLACEHOLDER
+ image_url = input_item.image_url.url
+ if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
+ image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
+ elif os.path.isfile(image_url): # local file
+ check_lfi_path(image_url)
+ image_stream = open(image_url, "rb")
+ else: # web uri
+ check_ssrf_url(image_url)
+ image_stream = requests.get(image_url, stream=True).raw
+
+ images.append(Image.open(image_stream).convert("RGB"))
+ elif input_item.type == "video_url":
+ text_content += VIDEO_PLACEHOLDER
+ video_url = input_item.video_url.url
+ if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video
+ video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
+ elif os.path.isfile(video_url): # local file
+ check_lfi_path(video_url)
+ video_stream = video_url
+ else: # web uri
+ check_ssrf_url(video_url)
+ video_stream = requests.get(video_url, stream=True).raw
+
+ videos.append(video_stream)
+ elif input_item.type == "audio_url":
+ text_content += AUDIO_PLACEHOLDER
+ audio_url = input_item.audio_url.url
+ if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio
+ audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
+ elif os.path.isfile(audio_url): # local file
+ check_lfi_path(audio_url)
+ audio_stream = audio_url
+ else: # web uri
+ check_ssrf_url(audio_url)
+ audio_stream = requests.get(audio_url, stream=True).raw
+
+ audios.append(audio_stream)
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
+ )
+
+ input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
+ else:
+ input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
+
+ tool_list = request.tools
+ if isinstance(tool_list, list) and len(tool_list):
+ try:
+ tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
+ except json.JSONDecodeError:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
+ else:
+ tools = None
+
+ return input_messages, system, tools, images or None, videos or None, audios or None
+
+
+def _create_stream_chat_completion_chunk(
+ completion_id: str,
+ model: str,
+ delta: "ChatCompletionMessage",
+ index: Optional[int] = 0,
+ finish_reason: Optional["Finish"] = None,
+) -> str:
+ choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
+ chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
+ return jsonify(chunk)
+
+
+async def create_chat_completion_response(
+ request: "ChatCompletionRequest", chat_model: "ChatModel"
+) -> "ChatCompletionResponse":
+ completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+ input_messages, system, tools, images, videos, audios = _process_request(request)
+ responses = await chat_model.achat(
+ input_messages,
+ system,
+ tools,
+ images,
+ videos,
+ audios,
+ do_sample=request.do_sample,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ max_new_tokens=request.max_tokens,
+ num_return_sequences=request.n,
+ repetition_penalty=request.presence_penalty,
+ stop=request.stop,
+ )
+
+ prompt_length, response_length = 0, 0
+ choices = []
+ for i, response in enumerate(responses):
+ if tools:
+ result = chat_model.engine.template.extract_tool(response.response_text)
+ else:
+ result = response.response_text
+
+ if isinstance(result, list):
+ tool_calls = []
+ for tool in result:
+ function = Function(name=tool.name, arguments=tool.arguments)
+ tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
+
+ response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
+ finish_reason = Finish.TOOL
+ else:
+ response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
+ finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+
+ choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
+ prompt_length = response.prompt_length
+ response_length += response.response_length
+
+ usage = ChatCompletionResponseUsage(
+ prompt_tokens=prompt_length,
+ completion_tokens=response_length,
+ total_tokens=prompt_length + response_length,
+ )
+
+ return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
+
+
+async def create_stream_chat_completion_response(
+ request: "ChatCompletionRequest", chat_model: "ChatModel"
+) -> AsyncGenerator[str, None]:
+ completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+ input_messages, system, tools, images, videos, audios = _process_request(request)
+ if tools:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
+
+ if request.n > 1:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
+
+ yield _create_stream_chat_completion_chunk(
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
+ )
+ async for new_token in chat_model.astream_chat(
+ input_messages,
+ system,
+ tools,
+ images,
+ videos,
+ audios,
+ do_sample=request.do_sample,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ max_new_tokens=request.max_tokens,
+ repetition_penalty=request.presence_penalty,
+ stop=request.stop,
+ ):
+ if len(new_token) != 0:
+ yield _create_stream_chat_completion_chunk(
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
+ )
+
+ yield _create_stream_chat_completion_chunk(
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
+ )
+ yield "[DONE]"
+
+
+async def create_score_evaluation_response(
+ request: "ScoreEvaluationRequest", chat_model: "ChatModel"
+) -> "ScoreEvaluationResponse":
+ score_id = f"scoreval-{uuid.uuid4().hex}"
+ if len(request.messages) == 0:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+
+ scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
+ return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
diff --git a/llamafactory/api/common.py b/llamafactory/api/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b4e9602de7ebc10b4f15c68ad9167cb9d80d8ef
--- /dev/null
+++ b/llamafactory/api/common.py
@@ -0,0 +1,96 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ipaddress
+import json
+import os
+import socket
+from typing import TYPE_CHECKING, Any
+from urllib.parse import urlparse
+
+from ..extras.misc import is_env_enabled
+from ..extras.packages import is_fastapi_available
+
+
+if is_fastapi_available():
+ from fastapi import HTTPException, status
+
+
+if TYPE_CHECKING:
+ from pydantic import BaseModel
+
+
+SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media"))
+ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1")
+
+
+def dictify(data: "BaseModel") -> dict[str, Any]:
+ try: # pydantic v2
+ return data.model_dump(exclude_unset=True)
+ except AttributeError: # pydantic v1
+ return data.dict(exclude_unset=True)
+
+
+def jsonify(data: "BaseModel") -> str:
+ try: # pydantic v2
+ return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
+ except AttributeError: # pydantic v1
+ return data.json(exclude_unset=True, ensure_ascii=False)
+
+
+def check_lfi_path(path: str) -> None:
+ """Checks if a given path is vulnerable to LFI. Raises HTTPException if unsafe."""
+ if not ALLOW_LOCAL_FILES:
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.")
+
+ try:
+ os.makedirs(SAFE_MEDIA_PATH, exist_ok=True)
+ real_path = os.path.realpath(path)
+ safe_path = os.path.realpath(SAFE_MEDIA_PATH)
+
+ if not real_path.startswith(safe_path):
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory."
+ )
+ except Exception:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path.")
+
+
+def check_ssrf_url(url: str) -> None:
+ """Checks if a given URL is vulnerable to SSRF. Raises HTTPException if unsafe."""
+ try:
+ parsed_url = urlparse(url)
+ if parsed_url.scheme not in ["http", "https"]:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.")
+
+ hostname = parsed_url.hostname
+ if not hostname:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.")
+
+ ip_info = socket.getaddrinfo(hostname, parsed_url.port)
+ ip_address_str = ip_info[0][4][0]
+ ip = ipaddress.ip_address(ip_address_str)
+
+ if not ip.is_global:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Access to private or reserved IP addresses is not allowed.",
+ )
+
+ except socket.gaierror:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}"
+ )
+ except Exception as e:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}")
diff --git a/llamafactory/api/protocol.py b/llamafactory/api/protocol.py
new file mode 100644
index 0000000000000000000000000000000000000000..889d938e0b727ef1fca63d95fa20926c31830c52
--- /dev/null
+++ b/llamafactory/api/protocol.py
@@ -0,0 +1,157 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from enum import Enum, unique
+from typing import Any, Optional, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Literal
+
+
+@unique
+class Role(str, Enum):
+ USER = "user"
+ ASSISTANT = "assistant"
+ SYSTEM = "system"
+ FUNCTION = "function"
+ TOOL = "tool"
+
+
+@unique
+class Finish(str, Enum):
+ STOP = "stop"
+ LENGTH = "length"
+ TOOL = "tool_calls"
+
+
+class ModelCard(BaseModel):
+ id: str
+ object: Literal["model"] = "model"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ owned_by: Literal["owner"] = "owner"
+
+
+class ModelList(BaseModel):
+ object: Literal["list"] = "list"
+ data: list[ModelCard] = []
+
+
+class Function(BaseModel):
+ name: str
+ arguments: str
+
+
+class FunctionDefinition(BaseModel):
+ name: str
+ description: str
+ parameters: dict[str, Any]
+
+
+class FunctionAvailable(BaseModel):
+ type: Literal["function", "code_interpreter"] = "function"
+ function: Optional[FunctionDefinition] = None
+
+
+class FunctionCall(BaseModel):
+ id: str
+ type: Literal["function"] = "function"
+ function: Function
+
+
+class URL(BaseModel):
+ url: str
+ detail: Literal["auto", "low", "high"] = "auto"
+
+
+class MultimodalInputItem(BaseModel):
+ type: Literal["text", "image_url", "video_url", "audio_url"]
+ text: Optional[str] = None
+ image_url: Optional[URL] = None
+ video_url: Optional[URL] = None
+ audio_url: Optional[URL] = None
+
+
+class ChatMessage(BaseModel):
+ role: Role
+ content: Optional[Union[str, list[MultimodalInputItem]]] = None
+ tool_calls: Optional[list[FunctionCall]] = None
+
+
+class ChatCompletionMessage(BaseModel):
+ role: Optional[Role] = None
+ content: Optional[str] = None
+ tool_calls: Optional[list[FunctionCall]] = None
+
+
+class ChatCompletionRequest(BaseModel):
+ model: str
+ messages: list[ChatMessage]
+ tools: Optional[list[FunctionAvailable]] = None
+ do_sample: Optional[bool] = None
+ temperature: Optional[float] = None
+ top_p: Optional[float] = None
+ n: int = 1
+ presence_penalty: Optional[float] = None
+ max_tokens: Optional[int] = None
+ stop: Optional[Union[str, list[str]]] = None
+ stream: bool = False
+
+
+class ChatCompletionResponseChoice(BaseModel):
+ index: int
+ message: ChatCompletionMessage
+ finish_reason: Finish
+
+
+class ChatCompletionStreamResponseChoice(BaseModel):
+ index: int
+ delta: ChatCompletionMessage
+ finish_reason: Optional[Finish] = None
+
+
+class ChatCompletionResponseUsage(BaseModel):
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+
+
+class ChatCompletionResponse(BaseModel):
+ id: str
+ object: Literal["chat.completion"] = "chat.completion"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: list[ChatCompletionResponseChoice]
+ usage: ChatCompletionResponseUsage
+
+
+class ChatCompletionStreamResponse(BaseModel):
+ id: str
+ object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: list[ChatCompletionStreamResponseChoice]
+
+
+class ScoreEvaluationRequest(BaseModel):
+ model: str
+ messages: list[str]
+ max_length: Optional[int] = None
+
+
+class ScoreEvaluationResponse(BaseModel):
+ id: str
+ object: Literal["score.evaluation"] = "score.evaluation"
+ model: str
+ scores: list[float]
diff --git a/llamafactory/chat/__init__.py b/llamafactory/chat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..15d8b9ba2d77d6f300d59300da5a49abd3ed4e57
--- /dev/null
+++ b/llamafactory/chat/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_engine import BaseEngine
+from .chat_model import ChatModel
+
+
+__all__ = ["BaseEngine", "ChatModel"]
diff --git a/llamafactory/chat/__pycache__/__init__.cpython-312.pyc b/llamafactory/chat/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1435170474306d36918011ef6af44e36b8fb8e44
Binary files /dev/null and b/llamafactory/chat/__pycache__/__init__.cpython-312.pyc differ
diff --git a/llamafactory/chat/__pycache__/base_engine.cpython-312.pyc b/llamafactory/chat/__pycache__/base_engine.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..358bdcffd32512c3649dea4159ee05f027a37181
Binary files /dev/null and b/llamafactory/chat/__pycache__/base_engine.cpython-312.pyc differ
diff --git a/llamafactory/chat/__pycache__/chat_model.cpython-312.pyc b/llamafactory/chat/__pycache__/chat_model.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5abba6446be7ea5a10650a5e22a3dd0f1534446
Binary files /dev/null and b/llamafactory/chat/__pycache__/chat_model.cpython-312.pyc differ
diff --git a/llamafactory/chat/base_engine.py b/llamafactory/chat/base_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d497c1ae927f94f396c18833b18cdb894cbd59d
--- /dev/null
+++ b/llamafactory/chat/base_engine.py
@@ -0,0 +1,98 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+ from vllm import AsyncLLMEngine
+
+ from ..data import Template
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+ from ..extras.constants import EngineName
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+@dataclass
+class Response:
+ response_text: str
+ response_length: int
+ prompt_length: int
+ finish_reason: Literal["stop", "length"]
+
+
+class BaseEngine(ABC):
+ r"""Base class for inference engine of chat models.
+
+ Must implements async methods: chat(), stream_chat() and get_scores().
+ """
+
+ name: "EngineName"
+ model: Union["PreTrainedModel", "AsyncLLMEngine"]
+ tokenizer: "PreTrainedTokenizer"
+ can_generate: bool
+ template: "Template"
+ generating_args: dict[str, Any]
+
+ @abstractmethod
+ def __init__(
+ self,
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ ) -> None:
+ r"""Initialize an inference engine."""
+ ...
+
+ @abstractmethod
+ async def chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> list["Response"]:
+ r"""Get a list of responses of the chat model."""
+ ...
+
+ @abstractmethod
+ async def stream_chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> AsyncGenerator[str, None]:
+ r"""Get the response token-by-token of the chat model."""
+ ...
+
+ @abstractmethod
+ async def get_scores(
+ self,
+ batch_input: list[str],
+ **input_kwargs,
+ ) -> list[float]:
+ r"""Get a list of scores of the reward model."""
+ ...
diff --git a/llamafactory/chat/chat_model.py b/llamafactory/chat/chat_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb612f88d468d76f06eefa45b96c1bfa0351fa7c
--- /dev/null
+++ b/llamafactory/chat/chat_model.py
@@ -0,0 +1,210 @@
+# Copyright 2025 THUDM and the LlamaFactory team.
+#
+# This code is inspired by the THUDM's ChatGLM implementation.
+# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from collections.abc import AsyncGenerator, Generator
+from threading import Thread
+from typing import TYPE_CHECKING, Any, Optional
+
+from ..extras.constants import EngineName
+from ..extras.misc import torch_gc
+from ..hparams import get_infer_args
+
+
+if TYPE_CHECKING:
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+ from .base_engine import BaseEngine, Response
+
+
+def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
+ asyncio.set_event_loop(loop)
+ loop.run_forever()
+
+
+class ChatModel:
+ r"""General class for chat models. Backed by huggingface or vllm engines.
+
+ Supports both sync and async methods.
+ Sync methods: chat(), stream_chat() and get_scores().
+ Async methods: achat(), astream_chat() and aget_scores().
+ """
+
+ def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
+ model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+
+ if model_args.infer_backend == EngineName.HF:
+ from .hf_engine import HuggingfaceEngine
+
+ self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
+ elif model_args.infer_backend == EngineName.VLLM:
+ try:
+ from .vllm_engine import VllmEngine
+
+ self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+ except ImportError as e:
+ raise ImportError(
+ "vLLM not install, you may need to run `pip install vllm`\n"
+ "or try to use HuggingFace backend: --infer_backend huggingface"
+ ) from e
+ elif model_args.infer_backend == EngineName.SGLANG:
+ try:
+ from .sglang_engine import SGLangEngine
+
+ self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
+ except ImportError as e:
+ raise ImportError(
+ "SGLang not install, you may need to run `pip install sglang[all]`\n"
+ "or try to use HuggingFace backend: --infer_backend huggingface"
+ ) from e
+ elif model_args.infer_backend == EngineName.KT:
+ try:
+ from .kt_engine import KTransformersEngine
+
+ self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
+ except ImportError as e:
+ raise ImportError(
+ "KTransformers not install, you may need to run `pip install ktransformers`\n"
+ "or try to use HuggingFace backend: --infer_backend huggingface"
+ ) from e
+ else:
+ raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
+
+ self._loop = asyncio.new_event_loop()
+ self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
+ self._thread.start()
+
+ def chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> list["Response"]:
+ r"""Get a list of responses of the chat model."""
+ task = asyncio.run_coroutine_threadsafe(
+ self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
+ )
+ return task.result()
+
+ async def achat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> list["Response"]:
+ r"""Asynchronously get a list of responses of the chat model."""
+ return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
+
+ def stream_chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> Generator[str, None, None]:
+ r"""Get the response token-by-token of the chat model."""
+ generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
+ while True:
+ try:
+ task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
+ yield task.result()
+ except StopAsyncIteration:
+ break
+
+ async def astream_chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> AsyncGenerator[str, None]:
+ r"""Asynchronously get the response token-by-token of the chat model."""
+ async for new_token in self.engine.stream_chat(
+ messages, system, tools, images, videos, audios, **input_kwargs
+ ):
+ yield new_token
+
+ def get_scores(
+ self,
+ batch_input: list[str],
+ **input_kwargs,
+ ) -> list[float]:
+ r"""Get a list of scores of the reward model."""
+ task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
+ return task.result()
+
+ async def aget_scores(
+ self,
+ batch_input: list[str],
+ **input_kwargs,
+ ) -> list[float]:
+ r"""Asynchronously get a list of scores of the reward model."""
+ return await self.engine.get_scores(batch_input, **input_kwargs)
+
+
+def run_chat() -> None:
+ if os.name != "nt":
+ try:
+ import readline # noqa: F401
+ except ImportError:
+ print("Install `readline` for a better experience.")
+
+ chat_model = ChatModel()
+ messages = []
+ print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
+
+ while True:
+ try:
+ query = input("\nUser: ")
+ except UnicodeDecodeError:
+ print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
+ continue
+ except Exception:
+ raise
+
+ if query.strip() == "exit":
+ break
+
+ if query.strip() == "clear":
+ messages = []
+ torch_gc()
+ print("History has been removed.")
+ continue
+
+ messages.append({"role": "user", "content": query})
+ print("Assistant: ", end="", flush=True)
+
+ response = ""
+ for new_text in chat_model.stream_chat(messages):
+ print(new_text, end="", flush=True)
+ response += new_text
+ print()
+ messages.append({"role": "assistant", "content": response})
diff --git a/llamafactory/chat/hf_engine.py b/llamafactory/chat/hf_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..adaaaa872786446fde2eba4a2a8f32f7ec4cc462
--- /dev/null
+++ b/llamafactory/chat/hf_engine.py
@@ -0,0 +1,412 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from collections.abc import AsyncGenerator
+from threading import Thread
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+import torch
+from transformers import GenerationConfig, TextIteratorStreamer
+from typing_extensions import override
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
+from ..model import load_model, load_tokenizer
+from .base_engine import BaseEngine, Response
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+ from trl import PreTrainedModelWrapper
+
+ from ..data import Template
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+logger = logging.get_logger(__name__)
+
+
+class HuggingfaceEngine(BaseEngine):
+ def __init__(
+ self,
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ ) -> None:
+ self.name = EngineName.HF
+ self.can_generate = finetuning_args.stage == "sft"
+ tokenizer_module = load_tokenizer(model_args)
+ self.tokenizer = tokenizer_module["tokenizer"]
+ self.processor = tokenizer_module["processor"]
+ self.tokenizer.padding_side = "left" if self.can_generate else "right"
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+ self.model = load_model(
+ self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
+ ) # must after fixing tokenizer to resize vocab
+ self.generating_args = generating_args.to_dict()
+ try:
+ asyncio.get_event_loop()
+ except RuntimeError:
+ logger.warning_rank0_once("There is no current event loop, creating a new one.")
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
+
+ @staticmethod
+ def _process_args(
+ model: "PreTrainedModel",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ template: "Template",
+ generating_args: dict[str, Any],
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ input_kwargs: Optional[dict[str, Any]] = {},
+ ) -> tuple[dict[str, Any], int]:
+ mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
+ if images is not None:
+ mm_input_dict.update({"images": images, "imglens": [len(images)]})
+ if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+ messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+
+ if videos is not None:
+ mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+ if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+ messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+
+ if audios is not None:
+ mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
+ if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+ messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
+ messages = template.mm_plugin.process_messages(
+ messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
+ )
+ paired_messages = messages + [{"role": "assistant", "content": ""}]
+ prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
+ prompt_ids, _ = template.mm_plugin.process_token_ids(
+ prompt_ids,
+ None,
+ mm_input_dict["images"],
+ mm_input_dict["videos"],
+ mm_input_dict["audios"],
+ tokenizer,
+ processor,
+ )
+ prompt_length = len(prompt_ids)
+ inputs = torch.tensor([prompt_ids], device=model.device)
+ attention_mask = torch.ones_like(inputs, dtype=torch.long)
+
+ do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
+ temperature: Optional[float] = input_kwargs.pop("temperature", None)
+ top_p: Optional[float] = input_kwargs.pop("top_p", None)
+ top_k: Optional[float] = input_kwargs.pop("top_k", None)
+ num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+ repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+ length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+ skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+ stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
+
+ if stop is not None:
+ logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
+
+ generating_args = generating_args.copy()
+ generating_args.update(
+ dict(
+ do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+ temperature=temperature if temperature is not None else generating_args["temperature"],
+ top_p=top_p if top_p is not None else generating_args["top_p"],
+ top_k=top_k if top_k is not None else generating_args["top_k"],
+ num_return_sequences=num_return_sequences,
+ repetition_penalty=repetition_penalty
+ if repetition_penalty is not None
+ else generating_args["repetition_penalty"],
+ length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
+ skip_special_tokens=skip_special_tokens
+ if skip_special_tokens is not None
+ else generating_args["skip_special_tokens"],
+ eos_token_id=template.get_stop_token_ids(tokenizer),
+ pad_token_id=tokenizer.pad_token_id,
+ )
+ )
+
+ if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
+ generating_args["do_sample"] = True
+ generating_args["temperature"] = generating_args["temperature"] or 1.0
+
+ if not generating_args["temperature"]:
+ generating_args["do_sample"] = False
+
+ if not generating_args["do_sample"]:
+ generating_args.pop("temperature", None)
+ generating_args.pop("top_p", None)
+
+ if max_length:
+ generating_args.pop("max_new_tokens", None)
+ generating_args["max_length"] = max_length
+
+ if max_new_tokens:
+ generating_args.pop("max_length", None)
+ generating_args["max_new_tokens"] = max_new_tokens
+
+ gen_kwargs = dict(
+ inputs=inputs,
+ attention_mask=attention_mask,
+ generation_config=GenerationConfig(**generating_args),
+ )
+
+ mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
+ for key, value in mm_inputs.items():
+ if isinstance(value, list) and isinstance(value[0], torch.Tensor): # for pixtral inputs
+ value = torch.stack(value) # assume they have same sizes
+ elif (
+ isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
+ ): # for minicpmv inputs
+ value = torch.stack([torch.stack(v) for v in value])
+ elif not isinstance(value, torch.Tensor):
+ value = torch.tensor(value)
+
+ if torch.is_floating_point(value): # cast data dtype for paligemma
+ value = value.to(model.dtype)
+
+ if key == "second_per_grid_ts": # qwen2.5vl special case
+ gen_kwargs[key] = value.tolist()
+ else:
+ gen_kwargs[key] = value.to(model.device)
+
+ if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
+ gen_kwargs["input_ids"] = inputs
+ gen_kwargs["tokenizer"] = tokenizer
+ if "audio_feature_lens" in mm_inputs:
+ gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
+
+ gen_kwargs.pop("image_sizes", None)
+
+ return gen_kwargs, prompt_length
+
+ @staticmethod
+ @torch.inference_mode()
+ def _chat(
+ model: "PreTrainedModel",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ template: "Template",
+ generating_args: dict[str, Any],
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ input_kwargs: Optional[dict[str, Any]] = {},
+ ) -> list["Response"]:
+ gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
+ model,
+ tokenizer,
+ processor,
+ template,
+ generating_args,
+ messages,
+ system,
+ tools,
+ images,
+ videos,
+ audios,
+ input_kwargs,
+ )
+ generate_output = model.generate(**gen_kwargs)
+ if isinstance(generate_output, tuple):
+ generate_output = generate_output[1][0] # post-process the minicpm_o output
+
+ response_ids = generate_output[:, prompt_length:]
+ response = tokenizer.batch_decode(
+ response_ids,
+ skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
+ clean_up_tokenization_spaces=True,
+ )
+ results = []
+ for i in range(len(response)):
+ eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
+ response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
+ results.append(
+ Response(
+ response_text=response[i],
+ response_length=response_length,
+ prompt_length=prompt_length,
+ finish_reason="stop" if len(eos_index) else "length",
+ )
+ )
+
+ return results
+
+ @staticmethod
+ @torch.inference_mode()
+ def _stream_chat(
+ model: "PreTrainedModel",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ template: "Template",
+ generating_args: dict[str, Any],
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ input_kwargs: Optional[dict[str, Any]] = {},
+ ) -> Callable[[], str]:
+ gen_kwargs, _ = HuggingfaceEngine._process_args(
+ model,
+ tokenizer,
+ processor,
+ template,
+ generating_args,
+ messages,
+ system,
+ tools,
+ images,
+ videos,
+ audios,
+ input_kwargs,
+ )
+ streamer = TextIteratorStreamer(
+ tokenizer,
+ skip_prompt=True,
+ skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
+ )
+ gen_kwargs["streamer"] = streamer
+ thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
+ thread.start()
+
+ def stream():
+ try:
+ return streamer.__next__()
+ except StopIteration:
+ raise StopAsyncIteration()
+
+ return stream
+
+ @staticmethod
+ @torch.inference_mode()
+ def _get_scores(
+ model: "PreTrainedModelWrapper",
+ tokenizer: "PreTrainedTokenizer",
+ batch_input: list[str],
+ input_kwargs: Optional[dict[str, Any]] = {},
+ ) -> list[float]:
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
+ device = getattr(model.pretrained_model, "device", "cuda")
+ inputs: dict[str, torch.Tensor] = tokenizer(
+ batch_input,
+ padding=True,
+ truncation=True,
+ max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
+ return_tensors="pt",
+ add_special_tokens=False,
+ ).to(device)
+ values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
+ scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
+ return scores
+
+ @override
+ async def chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> list["Response"]:
+ if not self.can_generate:
+ raise ValueError("The current model does not support `chat`.")
+
+ input_args = (
+ self.model,
+ self.tokenizer,
+ self.processor,
+ self.template,
+ self.generating_args,
+ messages,
+ system,
+ tools,
+ images,
+ videos,
+ audios,
+ input_kwargs,
+ )
+ async with self.semaphore:
+ return await asyncio.to_thread(self._chat, *input_args)
+
+ @override
+ async def stream_chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> AsyncGenerator[str, None]:
+ if not self.can_generate:
+ raise ValueError("The current model does not support `stream_chat`.")
+
+ input_args = (
+ self.model,
+ self.tokenizer,
+ self.processor,
+ self.template,
+ self.generating_args,
+ messages,
+ system,
+ tools,
+ images,
+ videos,
+ audios,
+ input_kwargs,
+ )
+ async with self.semaphore:
+ stream = self._stream_chat(*input_args)
+ while True:
+ try:
+ yield await asyncio.to_thread(stream)
+ except StopAsyncIteration:
+ break
+
+ @override
+ async def get_scores(
+ self,
+ batch_input: list[str],
+ **input_kwargs,
+ ) -> list[float]:
+ if self.can_generate:
+ raise ValueError("Cannot get scores using an auto-regressive model.")
+
+ input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
+ async with self.semaphore:
+ return await asyncio.to_thread(self._get_scores, *input_args)
diff --git a/llamafactory/chat/kt_engine.py b/llamafactory/chat/kt_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bf3f4bb2b685ee971d538d29f0b6afa16956f2c
--- /dev/null
+++ b/llamafactory/chat/kt_engine.py
@@ -0,0 +1,284 @@
+# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+import platform
+from collections.abc import AsyncGenerator
+from threading import Thread
+from typing import TYPE_CHECKING, Any, Optional
+
+import torch
+from typing_extensions import override
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
+from ..extras.constants import EngineName
+from ..model import load_model, load_tokenizer
+from .base_engine import BaseEngine, Response
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer
+ from trl import PreTrainedModelWrapper
+
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+from ktransformers.server.config.config import Config
+from ktransformers.util.utils import (
+ get_compute_capability,
+ prefill_and_generate_capture,
+)
+from ktransformers.util.vendors import GPUVendor, device_manager
+
+
+logger = logging.get_logger(__name__)
+
+
+class KTransformersEngine(BaseEngine):
+ def __init__(
+ self,
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ ) -> None:
+ self.name = EngineName.KT
+ self.can_generate = finetuning_args.stage == "sft"
+
+ tok_mod = load_tokenizer(model_args)
+ self.tokenizer = tok_mod["tokenizer"]
+ self.tokenizer.padding_side = "left" if self.can_generate else "right"
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+
+ self.model = load_model(
+ self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
+ )
+
+ self.generating_args = generating_args.to_dict()
+ self.max_new_tokens = model_args.kt_maxlen
+ self.use_cuda_graph = model_args.kt_use_cuda_graph
+ self.mode = model_args.kt_mode
+ self.force_think = model_args.kt_force_think
+ self.chunk_size = model_args.chunk_size
+
+ try:
+ asyncio.get_event_loop()
+ except RuntimeError:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
+
+ @staticmethod
+ @torch.inference_mode()
+ def _get_scores(
+ model: "PreTrainedModelWrapper",
+ tokenizer: "PreTrainedTokenizer",
+ batch_input: list[str],
+ input_kwargs: Optional[dict[str, Any]] = {},
+ ) -> list[float]:
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
+ device = getattr(model.pretrained_model, "device", "cuda")
+ inputs = tokenizer(
+ batch_input,
+ padding=True,
+ truncation=True,
+ max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
+ return_tensors="pt",
+ add_special_tokens=False,
+ ).to(device)
+ values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
+ scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
+ return scores
+
+ async def _generate(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ **input_kwargs,
+ ) -> AsyncGenerator[str, None]:
+ paired = messages + [{"role": "assistant", "content": ""}]
+ prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools)
+ prompt_len = len(prompt_ids)
+
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+
+ if "max_new_tokens" in self.generating_args:
+ max_tokens = int(self.generating_args["max_new_tokens"])
+ elif "max_length" in self.generating_args:
+ gl = int(self.generating_args["max_length"])
+ max_tokens = gl - prompt_len if gl > prompt_len else 1
+ else:
+ max_tokens = self.max_new_tokens or 256
+
+ if max_length is not None:
+ max_tokens = max(max_length - prompt_len, 1)
+ if max_new_tokens is not None:
+ max_tokens = int(max_new_tokens)
+ max_tokens = max(1, int(max_tokens))
+
+ if self.mode == "long_context":
+ max_len_cfg = Config().long_context_config["max_seq_len"]
+ need = prompt_len + max_tokens
+ assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml"
+
+ device = next(self.model.parameters()).device
+ input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
+ if self.force_think:
+ think = torch.tensor(
+ [self.tokenizer.encode("\n", add_special_tokens=False)], dtype=torch.long, device=device
+ )
+ input_tensor = torch.cat([input_tensor, think], dim=1)
+
+ use_flashinfer = (
+ platform.system() != "Windows"
+ and getattr(self.model.config, "architectures", [""])[0]
+ in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
+ and flashinfer_enabled
+ and get_compute_capability() >= 8
+ and device_manager.gpu_vendor == GPUVendor.NVIDIA
+ )
+
+ def make_gen():
+ if use_flashinfer:
+ return prefill_and_generate_capture(
+ self.model,
+ self.tokenizer,
+ input_tensor,
+ max_tokens,
+ self.use_cuda_graph,
+ mode=self.mode,
+ force_think=self.force_think,
+ chunk_size=self.chunk_size,
+ use_flashinfer_mla=True,
+ num_heads=self.model.config.num_attention_heads,
+ head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
+ head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
+ q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
+ + getattr(self.model.config, "qk_nope_head_dim", 0),
+ echo_stream=False,
+ )
+ else:
+ return prefill_and_generate_capture(
+ self.model,
+ self.tokenizer,
+ input_tensor,
+ max_tokens,
+ self.use_cuda_graph,
+ mode=self.mode,
+ force_think=self.force_think,
+ chunk_size=self.chunk_size,
+ echo_stream=False,
+ )
+
+ loop = asyncio.get_running_loop()
+ q: asyncio.Queue[Optional[str]] = asyncio.Queue()
+
+ def producer():
+ try:
+ gen = make_gen()
+ if hasattr(gen, "__aiter__"):
+
+ async def drain_async():
+ async for t in gen:
+ loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
+
+ asyncio.run(drain_async())
+ elif hasattr(gen, "__iter__"):
+ for t in gen:
+ loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
+ else:
+ loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen))
+ finally:
+ loop.call_soon_threadsafe(q.put_nowait, None)
+
+ Thread(target=producer, daemon=True).start()
+
+ while True:
+ item = await q.get()
+ if item is None:
+ break
+ yield item
+
+ @override
+ async def chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> list["Response"]:
+ if not self.can_generate:
+ raise ValueError("The current model does not support `chat`.")
+ async with self.semaphore:
+ produced = ""
+ final_text = ""
+ async for t in self._generate(messages, system, tools, **input_kwargs):
+ delta = t
+ produced = produced + delta
+ if delta:
+ final_text += delta
+
+ prompt_ids, _ = self.template.encode_oneturn(
+ self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools
+ )
+ return [
+ Response(
+ response_text=final_text,
+ response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)),
+ prompt_length=len(prompt_ids),
+ finish_reason="stop",
+ )
+ ]
+
+ @override
+ async def stream_chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> AsyncGenerator[str, None]:
+ if not self.can_generate:
+ raise ValueError("The current model does not support `stream_chat`.")
+ async with self.semaphore:
+ produced = ""
+ async for t in self._generate(messages, system, tools, **input_kwargs):
+ delta = t[len(produced) :] if t.startswith(produced) else t
+ produced = t
+ if delta:
+ yield delta
+
+ @override
+ async def get_scores(
+ self,
+ batch_input: list[str],
+ **input_kwargs,
+ ) -> list[float]:
+ if self.can_generate:
+ raise ValueError("Cannot get scores using an auto-regressive model.")
+ args = (self.model, self.tokenizer, batch_input, input_kwargs)
+ async with self.semaphore:
+ return await asyncio.to_thread(self._get_scores, *args)
diff --git a/llamafactory/chat/sglang_engine.py b/llamafactory/chat/sglang_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1d2ead33823bc70d51cda59750d25580f972083
--- /dev/null
+++ b/llamafactory/chat/sglang_engine.py
@@ -0,0 +1,289 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import atexit
+import json
+from collections.abc import AsyncGenerator, AsyncIterator, Sequence
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import requests
+from typing_extensions import override
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
+from ..extras.misc import get_device_count, torch_gc
+from ..extras.packages import is_sglang_available
+from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+from ..model import load_config, load_tokenizer
+from ..model.model_utils.quantization import QuantizationMethod
+from .base_engine import BaseEngine, Response
+
+
+if is_sglang_available():
+ from sglang.utils import launch_server_cmd, terminate_process, wait_for_server # type: ignore
+
+
+if TYPE_CHECKING:
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+
+
+logger = logging.get_logger(__name__)
+
+
+class SGLangEngine(BaseEngine):
+ """Inference engine for SGLang models.
+
+ This class wraps the SGLang engine to provide a consistent interface for text generation
+ that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
+ better interaction and performance. The engine launches a server process and communicates
+ with it via HTTP requests.
+
+ For more details on the SGLang HTTP server approach, see:
+ https://docs.sglang.ai/backend/send_request.html
+ """
+
+ def __init__(
+ self,
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ ) -> None:
+ self.name = EngineName.SGLANG
+ self.model_args = model_args
+ config = load_config(model_args) # may download model from ms hub
+ if getattr(config, "quantization_config", None): # gptq models should use float16
+ quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
+ quant_method = quantization_config.get("quant_method", "")
+ if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
+ model_args.infer_dtype = "float16"
+
+ self.can_generate = finetuning_args.stage == "sft"
+ tokenizer_module = load_tokenizer(model_args)
+ self.tokenizer = tokenizer_module["tokenizer"]
+ self.processor = tokenizer_module["processor"]
+ self.tokenizer.padding_side = "left"
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+ self.template.mm_plugin.expand_mm_tokens = False # for sglang generate
+ self.generating_args = generating_args.to_dict()
+ if model_args.adapter_name_or_path is not None:
+ self.lora_request = True
+ else:
+ self.lora_request = False
+
+ launch_cmd = [
+ "python3 -m sglang.launch_server",
+ f"--model-path {model_args.model_name_or_path}",
+ f"--dtype {model_args.infer_dtype}",
+ f"--context-length {model_args.sglang_maxlen}",
+ f"--mem-fraction-static {model_args.sglang_mem_fraction}",
+ f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
+ f"--download-dir {model_args.cache_dir}",
+ "--log-level error",
+ ]
+ if self.lora_request:
+ launch_cmd.extend(
+ [
+ "--max-loras-per-batch 1",
+ f"--lora-backend {model_args.sglang_lora_backend}",
+ f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+ "--disable-radix-cache",
+ ]
+ )
+ launch_cmd = " ".join(launch_cmd)
+ logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
+ try:
+ torch_gc()
+ self.server_process, port = launch_server_cmd(launch_cmd)
+ self.base_url = f"http://localhost:{port}"
+ atexit.register(self._cleanup_server)
+
+ logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
+ wait_for_server(self.base_url, timeout=300)
+ logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")
+ try:
+ response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
+ if response.status_code == 200:
+ model_info = response.json()
+ logger.info(f"SGLang server model info: {model_info}")
+ except Exception as e:
+ logger.debug(f"Note: could not get model info: {str(e)}")
+
+ except Exception as e:
+ logger.error(f"Failed to start SGLang server: {str(e)}")
+ self._cleanup_server() # make sure to clean up any started process
+ raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")
+
+ def _cleanup_server(self):
+ r"""Clean up the server process when the engine is destroyed."""
+ if hasattr(self, "server_process") and self.server_process:
+ try:
+ logger.info("Terminating SGLang server process")
+ terminate_process(self.server_process)
+ logger.info("SGLang server process terminated")
+ except Exception as e:
+ logger.warning(f"Error terminating SGLang server: {str(e)}")
+
+ async def _generate(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> AsyncIterator[dict[str, Any]]:
+ if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+ messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+
+ if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+ messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+
+ if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+ messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
+ messages = self.template.mm_plugin.process_messages(
+ messages, images or [], videos or [], audios or [], self.processor
+ )
+ paired_messages = messages + [{"role": "assistant", "content": ""}]
+ prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
+ prompt_length = len(prompt_ids)
+
+ temperature: Optional[float] = input_kwargs.pop("temperature", None)
+ top_p: Optional[float] = input_kwargs.pop("top_p", None)
+ top_k: Optional[float] = input_kwargs.pop("top_k", None)
+ num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+ repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+ skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+ stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
+
+ if num_return_sequences != 1:
+ raise NotImplementedError("SGLang only supports n=1.")
+
+ if "max_new_tokens" in self.generating_args:
+ max_tokens = self.generating_args["max_new_tokens"]
+ elif "max_length" in self.generating_args:
+ if self.generating_args["max_length"] > prompt_length:
+ max_tokens = self.generating_args["max_length"] - prompt_length
+ else:
+ max_tokens = 1
+
+ if max_length:
+ max_tokens = max_length - prompt_length if max_length > prompt_length else 1
+
+ if max_new_tokens:
+ max_tokens = max_new_tokens
+
+ sampling_params = {
+ "temperature": temperature if temperature is not None else self.generating_args["temperature"],
+ "top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
+ "top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0
+ "stop": stop,
+ "stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
+ "max_new_tokens": max_tokens,
+ "repetition_penalty": (
+ repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
+ )
+ or 1.0, # repetition_penalty must > 0
+ "skip_special_tokens": skip_special_tokens
+ if skip_special_tokens is not None
+ else self.generating_args["skip_special_tokens"],
+ }
+
+ def stream_request():
+ json_data = {
+ "input_ids": prompt_ids,
+ "sampling_params": sampling_params,
+ "stream": True,
+ }
+ if self.lora_request:
+ json_data["lora_request"] = ["lora0"]
+ response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
+ if response.status_code != 200:
+ raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
+
+ for chunk in response.iter_lines(decode_unicode=False):
+ chunk = str(chunk.decode("utf-8"))
+ if chunk == "data: [DONE]":
+ break
+
+ if chunk and chunk.startswith("data:"):
+ yield json.loads(chunk[5:].strip("\n"))
+
+ return await asyncio.to_thread(stream_request)
+
+ @override
+ async def chat(
+ self,
+ messages: Sequence[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[Sequence["ImageInput"]] = None,
+ videos: Optional[Sequence["VideoInput"]] = None,
+ audios: Optional[Sequence["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> list["Response"]:
+ final_output = None
+ generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
+ for request_output in generator:
+ final_output = request_output
+
+ results = [
+ Response(
+ response_text=final_output["text"],
+ response_length=final_output["meta_info"]["completion_tokens"],
+ prompt_length=final_output["meta_info"]["prompt_tokens"],
+ finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
+ )
+ ]
+ return results
+
+ @override
+ async def stream_chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> AsyncGenerator[str, None]:
+ generated_text = ""
+ generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
+ for result in generator:
+ delta_text = result["text"][len(generated_text) :]
+ generated_text = result["text"]
+ yield delta_text
+
+ @override
+ async def get_scores(
+ self,
+ batch_input: list[str],
+ **input_kwargs,
+ ) -> list[float]:
+ raise NotImplementedError("SGLang engine does not support `get_scores`.")
+
+ def __del__(self):
+ r"""Ensure server is cleaned up when object is deleted."""
+ self._cleanup_server()
+ try:
+ atexit.unregister(self._cleanup_server)
+ except Exception:
+ pass
diff --git a/llamafactory/chat/vllm_engine.py b/llamafactory/chat/vllm_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..b33705279da13d398f8089b94c72b22d742f2c6f
--- /dev/null
+++ b/llamafactory/chat/vllm_engine.py
@@ -0,0 +1,263 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import uuid
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from typing_extensions import override
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
+from ..extras.misc import get_device_count
+from ..extras.packages import is_vllm_available
+from ..model import load_config, load_tokenizer
+from ..model.model_utils.quantization import QuantizationMethod
+from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
+from .base_engine import BaseEngine, Response
+
+
+if is_vllm_available():
+ from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
+ from vllm.lora.request import LoRARequest
+
+
+if TYPE_CHECKING:
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+logger = logging.get_logger(__name__)
+
+
+class VllmEngine(BaseEngine):
+ def __init__(
+ self,
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ ) -> None:
+ self.name = EngineName.VLLM
+ self.model_args = model_args
+ config = load_config(model_args) # may download model from ms hub
+ if getattr(config, "quantization_config", None): # gptq models should use float16
+ quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
+ quant_method = quantization_config.get("quant_method", "")
+ if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
+ model_args.infer_dtype = "float16"
+
+ self.can_generate = finetuning_args.stage == "sft"
+ tokenizer_module = load_tokenizer(model_args)
+ self.tokenizer = tokenizer_module["tokenizer"]
+ self.processor = tokenizer_module["processor"]
+ self.tokenizer.padding_side = "left"
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+ self.template.mm_plugin.expand_mm_tokens = False # for vllm generate
+ self.generating_args = generating_args.to_dict()
+
+ engine_args = {
+ "model": model_args.model_name_or_path,
+ "trust_remote_code": model_args.trust_remote_code,
+ "download_dir": model_args.cache_dir,
+ "dtype": model_args.infer_dtype,
+ "max_model_len": model_args.vllm_maxlen,
+ "tensor_parallel_size": get_device_count() or 1,
+ "gpu_memory_utilization": model_args.vllm_gpu_util,
+ "disable_log_stats": True,
+ "disable_log_requests": True,
+ "enforce_eager": model_args.vllm_enforce_eager,
+ "enable_lora": model_args.adapter_name_or_path is not None,
+ "max_lora_rank": model_args.vllm_max_lora_rank,
+ }
+ if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
+ engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
+
+ if isinstance(model_args.vllm_config, dict):
+ engine_args.update(model_args.vllm_config)
+
+ if getattr(config, "is_yi_vl_derived_model", None):
+ import vllm.model_executor.models.llava
+
+ logger.info_rank0("Detected Yi-VL model, applying projector patch.")
+ vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
+
+ self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
+ if model_args.adapter_name_or_path is not None:
+ self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
+ else:
+ self.lora_request = None
+
+ async def _generate(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> AsyncIterator["RequestOutput"]:
+ request_id = f"chatcmpl-{uuid.uuid4().hex}"
+ if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+ messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+
+ if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+ messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+
+ if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+ messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
+ messages = self.template.mm_plugin.process_messages(
+ messages, images or [], videos or [], audios or [], self.processor
+ )
+ paired_messages = messages + [{"role": "assistant", "content": ""}]
+ prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
+ prompt_length = len(prompt_ids)
+
+ temperature: Optional[float] = input_kwargs.pop("temperature", None)
+ top_p: Optional[float] = input_kwargs.pop("top_p", None)
+ top_k: Optional[float] = input_kwargs.pop("top_k", None)
+ num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+ repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+ length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+ skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+ stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
+
+ if length_penalty is not None:
+ logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
+
+ if "max_new_tokens" in self.generating_args:
+ max_tokens = self.generating_args["max_new_tokens"]
+ elif "max_length" in self.generating_args:
+ if self.generating_args["max_length"] > prompt_length:
+ max_tokens = self.generating_args["max_length"] - prompt_length
+ else:
+ max_tokens = 1
+
+ if max_length:
+ max_tokens = max_length - prompt_length if max_length > prompt_length else 1
+
+ if max_new_tokens:
+ max_tokens = max_new_tokens
+
+ sampling_params = SamplingParams(
+ n=num_return_sequences,
+ repetition_penalty=(
+ repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
+ )
+ or 1.0, # repetition_penalty must > 0
+ temperature=temperature if temperature is not None else self.generating_args["temperature"],
+ top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
+ top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0
+ stop=stop,
+ stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
+ max_tokens=max_tokens,
+ skip_special_tokens=skip_special_tokens
+ if skip_special_tokens is not None
+ else self.generating_args["skip_special_tokens"],
+ )
+
+ if images is not None: # add image features
+ multi_modal_data = {
+ "image": self.template.mm_plugin._regularize_images(
+ images,
+ image_max_pixels=self.model_args.image_max_pixels,
+ image_min_pixels=self.model_args.image_min_pixels,
+ )["images"]
+ }
+ elif videos is not None:
+ multi_modal_data = {
+ "video": self.template.mm_plugin._regularize_videos(
+ videos,
+ image_max_pixels=self.model_args.video_max_pixels,
+ image_min_pixels=self.model_args.video_min_pixels,
+ video_fps=self.model_args.video_fps,
+ video_maxlen=self.model_args.video_maxlen,
+ )["videos"]
+ }
+ elif audios is not None:
+ audio_data = self.template.mm_plugin._regularize_audios(
+ audios,
+ sampling_rate=self.model_args.audio_sampling_rate,
+ )
+ multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
+ else:
+ multi_modal_data = None
+
+ result_generator = self.model.generate(
+ {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+ sampling_params=sampling_params,
+ request_id=request_id,
+ lora_request=self.lora_request,
+ )
+ return result_generator
+
+ @override
+ async def chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> list["Response"]:
+ final_output = None
+ generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
+ async for request_output in generator:
+ final_output = request_output
+
+ results = []
+ for output in final_output.outputs:
+ results.append(
+ Response(
+ response_text=output.text,
+ response_length=len(output.token_ids),
+ prompt_length=len(final_output.prompt_token_ids),
+ finish_reason=output.finish_reason,
+ )
+ )
+
+ return results
+
+ @override
+ async def stream_chat(
+ self,
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ images: Optional[list["ImageInput"]] = None,
+ videos: Optional[list["VideoInput"]] = None,
+ audios: Optional[list["AudioInput"]] = None,
+ **input_kwargs,
+ ) -> AsyncGenerator[str, None]:
+ generated_text = ""
+ generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
+ async for result in generator:
+ delta_text = result.outputs[0].text[len(generated_text) :]
+ generated_text = result.outputs[0].text
+ yield delta_text
+
+ @override
+ async def get_scores(
+ self,
+ batch_input: list[str],
+ **input_kwargs,
+ ) -> list[float]:
+ raise NotImplementedError("vLLM engine does not support `get_scores`.")
diff --git a/llamafactory/cli.py b/llamafactory/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..d574bf1db543f5379f074e276898826234708037
--- /dev/null
+++ b/llamafactory/cli.py
@@ -0,0 +1,31 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def main():
+ from .extras.misc import is_env_enabled
+
+ if is_env_enabled("USE_V1"):
+ from .v1 import launcher
+ else:
+ from . import launcher
+
+ launcher.launch()
+
+
+if __name__ == "__main__":
+ from multiprocessing import freeze_support
+
+ freeze_support()
+ main()
diff --git a/llamafactory/data/__init__.py b/llamafactory/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..11c8c9fcecd10e736e240196fde98f833c9df3dc
--- /dev/null
+++ b/llamafactory/data/__init__.py
@@ -0,0 +1,37 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .collator import (
+ KTODataCollatorWithPadding,
+ MultiModalDataCollatorForSeq2Seq,
+ PairwiseDataCollatorWithPadding,
+ SFTDataCollatorWith4DAttentionMask,
+)
+from .data_utils import Role, split_dataset
+from .loader import get_dataset
+from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
+
+
+__all__ = [
+ "TEMPLATES",
+ "KTODataCollatorWithPadding",
+ "MultiModalDataCollatorForSeq2Seq",
+ "PairwiseDataCollatorWithPadding",
+ "Role",
+ "SFTDataCollatorWith4DAttentionMask",
+ "Template",
+ "get_dataset",
+ "get_template_and_fix_tokenizer",
+ "split_dataset",
+]
diff --git a/llamafactory/data/__pycache__/__init__.cpython-312.pyc b/llamafactory/data/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8aeb4c34a1136f093a1b21a6f280408dd1286477
Binary files /dev/null and b/llamafactory/data/__pycache__/__init__.cpython-312.pyc differ
diff --git a/llamafactory/data/__pycache__/collator.cpython-312.pyc b/llamafactory/data/__pycache__/collator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..912142ec8dd6bc455b401e5935f10e8190a6bdf3
Binary files /dev/null and b/llamafactory/data/__pycache__/collator.cpython-312.pyc differ
diff --git a/llamafactory/data/__pycache__/converter.cpython-312.pyc b/llamafactory/data/__pycache__/converter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b29a1575867106edc4f03a5b02406ad72871b1c5
Binary files /dev/null and b/llamafactory/data/__pycache__/converter.cpython-312.pyc differ
diff --git a/llamafactory/data/__pycache__/data_utils.cpython-312.pyc b/llamafactory/data/__pycache__/data_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4276f26997c27740fbc9ec5faaf3e99b3442d7eb
Binary files /dev/null and b/llamafactory/data/__pycache__/data_utils.cpython-312.pyc differ
diff --git a/llamafactory/data/__pycache__/formatter.cpython-312.pyc b/llamafactory/data/__pycache__/formatter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8f78d7f723f0fb3ea2ec501c517ee4dbdad4e49
Binary files /dev/null and b/llamafactory/data/__pycache__/formatter.cpython-312.pyc differ
diff --git a/llamafactory/data/__pycache__/loader.cpython-312.pyc b/llamafactory/data/__pycache__/loader.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8232883361bdf807fa8523fdbca05e1b992c535
Binary files /dev/null and b/llamafactory/data/__pycache__/loader.cpython-312.pyc differ
diff --git a/llamafactory/data/__pycache__/mm_plugin.cpython-312.pyc b/llamafactory/data/__pycache__/mm_plugin.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eff12f0ba2ac7b9f274742c73de5198970487bd5
Binary files /dev/null and b/llamafactory/data/__pycache__/mm_plugin.cpython-312.pyc differ
diff --git a/llamafactory/data/__pycache__/parser.cpython-312.pyc b/llamafactory/data/__pycache__/parser.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d60110a599edbc6e1a3ceac15438706301ab9a77
Binary files /dev/null and b/llamafactory/data/__pycache__/parser.cpython-312.pyc differ
diff --git a/llamafactory/data/__pycache__/template.cpython-312.pyc b/llamafactory/data/__pycache__/template.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df03dd6517d2b90e1d8c61583f02f7da39977c9d
Binary files /dev/null and b/llamafactory/data/__pycache__/template.cpython-312.pyc differ
diff --git a/llamafactory/data/__pycache__/tool_utils.cpython-312.pyc b/llamafactory/data/__pycache__/tool_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7989ecbc820bda2775d2bc6f82b90fcd446ea4c2
Binary files /dev/null and b/llamafactory/data/__pycache__/tool_utils.cpython-312.pyc differ
diff --git a/llamafactory/data/collator.py b/llamafactory/data/collator.py
new file mode 100644
index 0000000000000000000000000000000000000000..162f432c9e5bf195ed4c6a821eb36d279bf3bac4
--- /dev/null
+++ b/llamafactory/data/collator.py
@@ -0,0 +1,331 @@
+# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
+#
+# This code is inspired by the OpenAccess AI Collective's axolotl library.
+# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Literal, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from peft import PeftModel
+from transformers import DataCollatorForSeq2Seq
+
+from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.packages import is_pillow_available
+
+
+if is_pillow_available():
+ from PIL import Image
+
+
+if TYPE_CHECKING:
+ from transformers import ProcessorMixin
+
+ from .template import Template
+
+
+def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
+ r"""Expand 2d attention mask to 4d attention mask.
+
+ Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
+ handle packed sequences and transforms the mask to lower triangular form to prevent future peeking.
+
+ e.g.
+ ```python
+ # input
+ [[1, 1, 2, 2, 2, 0]]
+ # output
+ [
+ [
+ [
+ [o, x, x, x, x, x],
+ [o, o, x, x, x, x],
+ [x, x, o, x, x, x],
+ [x, x, o, o, x, x],
+ [x, x, o, o, o, x],
+ [x, x, x, x, x, x],
+ ]
+ ]
+ ]
+ ```
+ where `o` equals to `0.0`, `x` equals to `min_dtype`.
+ """
+ _, seq_len = attention_mask_with_indices.size()
+ min_dtype = torch.finfo(dtype).min
+ zero_tensor = torch.tensor(0, dtype=dtype)
+
+ # Create a non-padding mask.
+ non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
+ # Create indices for comparison.
+ indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2) # [bsz, 1, 1, seq_len]
+ indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3) # [bsz, 1, seq_len, 1]
+ # Create a lower triangular mask.
+ tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
+ attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
+ # Invert the attention mask.
+ attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
+ return attention_mask_4d
+
+
+@dataclass
+class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+ r"""Data collator that supports VLMs.
+
+ Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
+ """
+
+ template: Optional["Template"] = None
+ processor: Optional["ProcessorMixin"] = None
+
+ def __post_init__(self):
+ if self.template is None:
+ raise ValueError("Template is required for MultiModalDataCollator.")
+
+ if isinstance(self.model, PeftModel):
+ self.model = self.model.base_model.model
+
+ if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
+ self.get_rope_func = self.model.get_rope_index # transformers < 4.52.0 or qwen2.5 omni
+ elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
+ self.get_rope_func = self.model.model.get_rope_index # transformers >= 4.52.0
+ else:
+ self.get_rope_func = None
+
+ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+ batch_images, batch_videos, batch_audios = [], [], []
+ batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
+ for feature in features:
+ images = feature.pop("images", None) or []
+ videos = feature.pop("videos", None) or []
+ audios = feature.pop("audios", None) or []
+ batch_images.extend(images)
+ batch_videos.extend(videos)
+ batch_audios.extend(audios)
+ batch_imglens.append(len(images))
+ batch_vidlens.append(len(videos))
+ batch_audlens.append(len(audios))
+ batch_input_ids.append(feature["input_ids"])
+
+ fake_input_ids = []
+ if (
+ self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+ ): # avoid process hanging in zero3/fsdp case
+ fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
+ fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
+ fake_messages = self.template.mm_plugin.process_messages(
+ fake_messages, fake_images, [], [], self.processor
+ )
+ _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+ _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+ _fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
+ )
+ fake_input_ids.extend(_fake_input_ids)
+ batch_images = fake_images
+ batch_imglens[0] = 1
+
+ if (
+ self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
+ ): # avoid process hanging in zero3/fsdp case
+ fake_messages = [{"role": "user", "content": AUDIO_PLACEHOLDER}]
+ fake_audios = [np.zeros(1600)]
+ fake_messages = self.template.mm_plugin.process_messages(
+ fake_messages, [], [], fake_audios, self.processor
+ )
+ _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+ _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+ _fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
+ )
+ fake_input_ids.extend(_fake_input_ids)
+ batch_audios = fake_audios
+ batch_audlens[0] = 1
+
+ if len(fake_input_ids) != 0:
+ if self.tokenizer.padding_side == "right":
+ features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
+ features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
+ features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
+ else:
+ features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
+ features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
+ features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
+
+ batch_input_ids[0] = features[0]["input_ids"]
+
+ mm_inputs = self.template.mm_plugin.get_mm_inputs(
+ batch_images,
+ batch_videos,
+ batch_audios,
+ batch_imglens,
+ batch_vidlens,
+ batch_audlens,
+ batch_input_ids,
+ self.processor,
+ )
+ if "token_type_ids" in mm_inputs:
+ token_type_ids = mm_inputs.pop("token_type_ids")
+ for i, feature in enumerate(features):
+ feature["token_type_ids"] = token_type_ids[i]
+
+ features: dict[str, torch.Tensor] = super().__call__(features)
+
+ if self.get_rope_func is not None:
+ rope_index_kwargs = {
+ "input_ids": features["input_ids"],
+ "image_grid_thw": mm_inputs.get("image_grid_thw"),
+ "video_grid_thw": mm_inputs.get("video_grid_thw"),
+ "attention_mask": (features["attention_mask"] >= 1).float(),
+ }
+ if "second_per_grid_ts" in mm_inputs: # for qwen2vl
+ rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
+ elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
+ rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
+
+ if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
+ rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
+ feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
+ if feature_attention_mask is not None: # FIXME: need to get video image lengths
+ audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+ rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
+
+ features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
+ features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
+ dim=-1
+ ).unsqueeze(-1)
+ else: # for qwen vl
+ features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
+
+ if (
+ self.model is not None
+ and getattr(self.model.config, "model_type", None)
+ in [
+ "glm4v",
+ "Keye",
+ "qwen2_vl",
+ "qwen2_5_vl",
+ "qwen2_5_omni_thinker",
+ "qwen3_omni_moe_thinker",
+ "qwen3_vl",
+ "qwen3_vl_moe",
+ ]
+ and ("position_ids" not in features or features["position_ids"].dim() != 3)
+ ):
+ raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
+
+ if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
+ cross_attention_mask = mm_inputs.pop("cross_attention_mask")
+ seq_len = features["input_ids"].size(1)
+ orig_len = cross_attention_mask.size(1)
+ mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
+
+ features.update(mm_inputs)
+
+ if "image_bound" in features: # for minicpmv inputs
+ bsz, seq_length = features["input_ids"].shape
+ features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
+ return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}
+
+ return features
+
+
+@dataclass
+class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
+ r"""Data collator for 4d attention mask."""
+
+ block_diag_attn: bool = False
+ attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
+ compute_dtype: "torch.dtype" = torch.float32
+
+ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+ features = super().__call__(features)
+ if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
+ features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
+
+ for key, value in features.items(): # cast data dtype for paligemma
+ if torch.is_tensor(value) and torch.is_floating_point(value):
+ features[key] = value.to(self.compute_dtype)
+
+ return features
+
+
+@dataclass
+class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
+ r"""Data collator for pairwise data."""
+
+ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+ r"""Pad batched data to the longest sequence in the batch.
+
+ We generate 2 * n examples where the first n examples represent chosen examples and
+ the last n examples represent rejected examples.
+ """
+ concatenated_features = []
+ for key in ("chosen", "rejected"):
+ for feature in features:
+ target_feature = {
+ "input_ids": feature[f"{key}_input_ids"],
+ "attention_mask": feature[f"{key}_attention_mask"],
+ "labels": feature[f"{key}_labels"],
+ "images": feature["images"],
+ "videos": feature["videos"],
+ "audios": feature["audios"],
+ }
+ concatenated_features.append(target_feature)
+
+ return super().__call__(concatenated_features)
+
+
+@dataclass
+class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
+ r"""Data collator for KTO data."""
+
+ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+ target_features = []
+ kl_features = []
+ kto_tags = []
+ for feature in features:
+ target_feature = {
+ "input_ids": feature["input_ids"],
+ "attention_mask": feature["attention_mask"],
+ "labels": feature["labels"],
+ "images": feature["images"],
+ "videos": feature["videos"],
+ "audios": feature["audios"],
+ }
+ kl_feature = {
+ "input_ids": feature["kl_input_ids"],
+ "attention_mask": feature["kl_attention_mask"],
+ "labels": feature["kl_labels"],
+ "images": feature["images"],
+ "videos": feature["videos"],
+ "audios": feature["audios"],
+ }
+ target_features.append(target_feature)
+ kl_features.append(kl_feature)
+ kto_tags.append(feature["kto_tags"])
+
+ batch = super().__call__(target_features)
+ kl_batch = super().__call__(kl_features)
+ batch["kl_input_ids"] = kl_batch["input_ids"]
+ batch["kl_attention_mask"] = kl_batch["attention_mask"]
+ batch["kl_labels"] = kl_batch["labels"]
+ if "cross_attention_mask" in kl_batch: # for mllama inputs
+ batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
+
+ if "token_type_ids" in kl_batch:
+ batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
+
+ batch["kto_tags"] = torch.tensor(kto_tags)
+ return batch
diff --git a/llamafactory/data/converter.py b/llamafactory/data/converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac3735e648eee7117154c5301e8adae746a566d7
--- /dev/null
+++ b/llamafactory/data/converter.py
@@ -0,0 +1,425 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from ..extras import logging
+from .data_utils import Role
+
+
+if TYPE_CHECKING:
+ from datasets import Dataset, IterableDataset
+ from transformers import Seq2SeqTrainingArguments
+
+ from ..hparams import DataArguments
+ from .mm_plugin import AudioInput, ImageInput, VideoInput
+ from .parser import DatasetAttr
+
+ MediaType = Union[ImageInput, VideoInput, AudioInput]
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class DatasetConverter:
+ dataset_attr: "DatasetAttr"
+ data_args: "DataArguments"
+
+ def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
+ r"""Optionally concatenate media path to media dir when loading from local disk."""
+ if medias is None:
+ return None
+ elif not isinstance(medias, list):
+ medias = [medias]
+ elif len(medias) == 0:
+ return None
+ else:
+ medias = medias[:]
+
+ if self.dataset_attr.load_from in ["script", "file"]:
+ if isinstance(medias[0], str):
+ for i in range(len(medias)):
+ media_path = os.path.join(self.data_args.media_dir, medias[i])
+ if os.path.isfile(media_path):
+ medias[i] = media_path
+ else:
+ logger.warning_rank0_once(
+ f"Media {medias[i]} does not exist in `media_dir`. Use original path."
+ )
+ elif isinstance(medias[0], list): # for processed video frames
+ # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
+ for i in range(len(medias)):
+ for j in range(len(medias[i])):
+ media_path = os.path.join(self.data_args.media_dir, medias[i][j])
+ if os.path.isfile(media_path):
+ medias[i][j] = media_path
+ else:
+ logger.warning_rank0_once(
+ f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
+ )
+
+ return medias
+
+ @abstractmethod
+ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
+ r"""Convert a single example in the dataset to the standard format."""
+ ...
+
+
+@dataclass
+class AlpacaDatasetConverter(DatasetConverter):
+ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
+ prompt = []
+ if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
+ for old_prompt, old_response in example[self.dataset_attr.history]:
+ prompt.append({"role": Role.USER.value, "content": old_prompt})
+ prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
+
+ query = []
+ if self.dataset_attr.prompt and example[self.dataset_attr.prompt]:
+ query.append(example[self.dataset_attr.prompt])
+
+ if self.dataset_attr.query and example[self.dataset_attr.query]:
+ query.append(example[self.dataset_attr.query])
+
+ prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
+
+ if self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
+ response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
+ if example[self.dataset_attr.kto_tag]:
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+ else:
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+ elif (
+ self.dataset_attr.ranking
+ and isinstance(example[self.dataset_attr.chosen], str)
+ and isinstance(example[self.dataset_attr.rejected], str)
+ ): # pairwise example
+ response = [
+ {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.chosen]},
+ {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.rejected]},
+ ]
+ elif self.dataset_attr.response and isinstance(example[self.dataset_attr.response], str): # normal example
+ response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
+ else: # unsupervised
+ response = []
+
+ output = {
+ "_prompt": prompt,
+ "_response": response,
+ "_system": example[self.dataset_attr.system] if self.dataset_attr.system else "",
+ "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
+ "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
+ "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
+ "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
+ }
+ return output
+
+
+@dataclass
+class SharegptDatasetConverter(DatasetConverter):
+ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
+ tag_mapping = {
+ self.dataset_attr.user_tag: Role.USER.value,
+ self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
+ self.dataset_attr.observation_tag: Role.OBSERVATION.value,
+ self.dataset_attr.function_tag: Role.FUNCTION.value,
+ self.dataset_attr.system_tag: Role.SYSTEM.value,
+ }
+ odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag)
+ even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag)
+ accept_tags = (odd_tags, even_tags)
+ messages = example[self.dataset_attr.messages]
+ if (
+ self.dataset_attr.system_tag
+ and len(messages) != 0
+ and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
+ ):
+ system = messages[0][self.dataset_attr.content_tag]
+ messages = messages[1:]
+ else:
+ system = example[self.dataset_attr.system] if self.dataset_attr.system else ""
+
+ aligned_messages = []
+ broken_data = False
+ for turn_idx, message in enumerate(messages):
+ if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
+ logger.warning_rank0(f"Invalid role tag in {messages}.")
+ broken_data = True
+ break
+
+ aligned_messages.append(
+ {
+ "role": tag_mapping[message[self.dataset_attr.role_tag]],
+ "content": message[self.dataset_attr.content_tag],
+ }
+ )
+
+ if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
+ self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
+ ):
+ logger.warning_rank0(f"Invalid message count in {messages}.")
+ broken_data = True
+
+ if broken_data:
+ logger.warning_rank0("Skipping this abnormal example.")
+ prompt, response = [], []
+ elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
+ prompt = aligned_messages[:-1]
+ response = aligned_messages[-1:]
+ if example[self.dataset_attr.kto_tag]:
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+ else:
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+ elif (
+ self.dataset_attr.ranking
+ and isinstance(example[self.dataset_attr.chosen], dict)
+ and isinstance(example[self.dataset_attr.rejected], dict)
+ ): # pairwise example
+ chosen = example[self.dataset_attr.chosen]
+ rejected = example[self.dataset_attr.rejected]
+ if (
+ chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
+ or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
+ ):
+ logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
+ broken_data = True
+
+ prompt = aligned_messages
+ response = [
+ {
+ "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
+ "content": chosen[self.dataset_attr.content_tag],
+ },
+ {
+ "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
+ "content": rejected[self.dataset_attr.content_tag],
+ },
+ ]
+ else: # normal example
+ prompt = aligned_messages[:-1]
+ response = aligned_messages[-1:]
+
+ output = {
+ "_prompt": prompt,
+ "_response": response,
+ "_system": system,
+ "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
+ "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
+ "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
+ "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
+ }
+ return output
+
+
+@dataclass
+class OpenAIDatasetConverter(DatasetConverter):
+ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
+ tag_mapping = {
+ self.dataset_attr.user_tag: Role.USER.value,
+ self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
+ self.dataset_attr.observation_tag: Role.OBSERVATION.value,
+ self.dataset_attr.function_tag: Role.FUNCTION.value,
+ self.dataset_attr.system_tag: Role.SYSTEM.value,
+ }
+
+ messages = example[self.dataset_attr.messages]
+ if (
+ self.dataset_attr.system_tag
+ and len(messages) != 0
+ and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
+ ):
+ system = messages[0][self.dataset_attr.content_tag]
+ messages = messages[1:]
+ else:
+ system = example.get(self.dataset_attr.system, "") if self.dataset_attr.system else ""
+
+ aligned_messages = []
+ tool_responses = []
+ broken_data = False
+ for turn_idx, message in enumerate(messages):
+ role = message[self.dataset_attr.role_tag]
+ content = message[self.dataset_attr.content_tag]
+
+ if role in [self.dataset_attr.assistant_tag, self.dataset_attr.function_tag]:
+ if "tool_calls" in message and len(message["tool_calls"]) > 0:
+ tool_calls_list = [tool["function"] for tool in message["tool_calls"]]
+ content = json.dumps(tool_calls_list, ensure_ascii=False)
+ role = self.dataset_attr.function_tag
+
+ if role == self.dataset_attr.observation_tag:
+ tool_responses.append(content)
+ continue
+ elif len(tool_responses) > 0:
+ _content = "\n\n\n".join(tool_responses)
+ aligned_messages.append(
+ {
+ "role": Role.OBSERVATION.value,
+ "content": _content,
+ }
+ )
+ tool_responses = []
+
+ aligned_messages.append(
+ {
+ "role": tag_mapping[role],
+ "content": content,
+ }
+ )
+
+ odd_tags = (Role.USER.value, Role.OBSERVATION.value)
+ even_tags = (Role.ASSISTANT.value, Role.FUNCTION.value)
+ accept_tags = (odd_tags, even_tags)
+ for turn_idx, message in enumerate(aligned_messages):
+ if message["role"] not in accept_tags[turn_idx % 2]:
+ logger.warning_rank0(f"Invalid role tag in {messages}.")
+ broken_data = True
+ break
+
+ if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
+ self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
+ ):
+ logger.warning_rank0(f"Invalid message count in {messages}.")
+ broken_data = True
+
+ if broken_data:
+ logger.warning_rank0("Skipping this abnormal example.")
+ prompt, response = [], []
+ elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
+ prompt = aligned_messages[:-1]
+ response = aligned_messages[-1:]
+ if example[self.dataset_attr.kto_tag]:
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+ else:
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+ elif (
+ self.dataset_attr.ranking
+ and isinstance(example[self.dataset_attr.chosen], dict)
+ and isinstance(example[self.dataset_attr.rejected], dict)
+ ): # pairwise example
+ chosen = example[self.dataset_attr.chosen]
+ rejected = example[self.dataset_attr.rejected]
+ if (
+ chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
+ or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
+ ):
+ logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
+ broken_data = True
+
+ prompt = aligned_messages
+ response = [
+ {
+ "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
+ "content": chosen[self.dataset_attr.content_tag],
+ },
+ {
+ "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
+ "content": rejected[self.dataset_attr.content_tag],
+ },
+ ]
+ else: # normal example
+ prompt = aligned_messages[:-1]
+ response = aligned_messages[-1:]
+
+ tools = example.get(self.dataset_attr.tools, "") if self.dataset_attr.tools else ""
+ if isinstance(tools, dict) or isinstance(tools, list):
+ tools = json.dumps(tools, ensure_ascii=False)
+
+ short_system_prompt = "detailed thinking off"
+ if not system:
+ if not tools:
+ system = short_system_prompt
+ else:
+ pass
+ else:
+ if not tools:
+ if "detailed thinking on" in system or "detailed thinking off" in system:
+ pass
+ else:
+ system += "\n" + short_system_prompt
+ else:
+ system += "\n"
+
+ output = {
+ "_prompt": prompt,
+ "_response": response,
+ "_system": system,
+ "_tools": tools,
+ "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
+ "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
+ "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
+ }
+ return output
+
+
+DATASET_CONVERTERS = {
+ "alpaca": AlpacaDatasetConverter,
+ "sharegpt": SharegptDatasetConverter,
+ "openai": OpenAIDatasetConverter,
+}
+
+
+def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
+ r"""Register a new dataset converter."""
+ if name in DATASET_CONVERTERS:
+ raise ValueError(f"Dataset converter {name} already exists.")
+
+ DATASET_CONVERTERS[name] = dataset_converter
+
+
+def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
+ r"""Get a dataset converter."""
+ if name not in DATASET_CONVERTERS:
+ raise ValueError(f"Dataset converter {name} not found.")
+
+ return DATASET_CONVERTERS[name](dataset_attr, data_args)
+
+
+def align_dataset(
+ dataset: Union["Dataset", "IterableDataset"],
+ dataset_attr: "DatasetAttr",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+) -> Union["Dataset", "IterableDataset"]:
+ r"""Align the dataset to a specific format.
+
+ Aligned dataset:
+ _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
+ _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
+ _system: "..."
+ _tools: "..."
+ _images: []
+ _videos: []
+ _audios: []
+ """
+ column_names = list(next(iter(dataset)).keys())
+ kwargs = {}
+ if not data_args.streaming:
+ kwargs = dict(
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
+ desc="Converting format of dataset",
+ )
+
+ dataset_converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args)
+ return dataset.map(
+ dataset_converter,
+ batched=False,
+ remove_columns=column_names,
+ **kwargs,
+ )
diff --git a/llamafactory/data/data_utils.py b/llamafactory/data/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..14e261290a032466c87772915eb9f472e08d14fa
--- /dev/null
+++ b/llamafactory/data/data_utils.py
@@ -0,0 +1,190 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from enum import Enum, unique
+from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union
+
+import fsspec
+from datasets import DatasetDict, concatenate_datasets, interleave_datasets
+
+from ..extras import logging
+
+
+if TYPE_CHECKING:
+ from datasets import Dataset, IterableDataset
+
+ from ..hparams import DataArguments
+
+
+logger = logging.get_logger(__name__)
+
+
+SLOTS = list[Union[str, set[str], dict[str, str]]]
+
+
+@unique
+class Role(str, Enum):
+ USER = "user"
+ ASSISTANT = "assistant"
+ SYSTEM = "system"
+ FUNCTION = "function"
+ OBSERVATION = "observation"
+
+
+class DatasetModule(TypedDict):
+ train_dataset: Optional[Union["Dataset", "IterableDataset"]]
+ eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
+
+
+def merge_dataset(
+ all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
+) -> Union["Dataset", "IterableDataset"]:
+ r"""Merge multiple datasets to a unified dataset."""
+ if len(all_datasets) == 1:
+ return all_datasets[0]
+
+ elif data_args.mix_strategy == "concat":
+ if data_args.streaming:
+ logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
+
+ return concatenate_datasets(all_datasets)
+
+ elif data_args.mix_strategy.startswith("interleave"):
+ if not data_args.streaming:
+ logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
+
+ return interleave_datasets(
+ datasets=all_datasets,
+ probabilities=data_args.interleave_probs,
+ seed=seed,
+ stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
+ )
+
+ else:
+ raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
+
+
+def split_dataset(
+ dataset: Optional[Union["Dataset", "IterableDataset"]],
+ eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
+ data_args: "DataArguments",
+ seed: int,
+) -> "DatasetDict":
+ r"""Split the dataset and returns a dataset dict containing train set and validation set.
+
+ Support both map dataset and iterable dataset.
+ """
+ if eval_dataset is not None and data_args.val_size > 1e-6:
+ raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
+
+ dataset_dict = {}
+ if dataset is not None:
+ if data_args.streaming:
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
+
+ if data_args.val_size > 1e-6:
+ if data_args.streaming:
+ dataset_dict["validation"] = dataset.take(int(data_args.val_size))
+ dataset_dict["train"] = dataset.skip(int(data_args.val_size))
+ else:
+ val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
+ dataset_dict = dataset.train_test_split(test_size=val_size, seed=seed)
+ dataset = dataset.train_test_split(test_size=val_size, seed=seed)
+ dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
+ else:
+ dataset_dict["train"] = dataset
+
+ if eval_dataset is not None:
+ if isinstance(eval_dataset, dict):
+ dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+ else:
+ if data_args.streaming:
+ eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
+
+ dataset_dict["validation"] = eval_dataset
+
+ return DatasetDict(dataset_dict)
+
+
+def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
+ r"""Convert dataset or dataset dict to dataset module."""
+ dataset_module: DatasetModule = {}
+ if isinstance(dataset, DatasetDict): # dataset dict
+ if "train" in dataset:
+ dataset_module["train_dataset"] = dataset["train"]
+
+ if "validation" in dataset:
+ dataset_module["eval_dataset"] = dataset["validation"]
+ else:
+ eval_dataset = {}
+ for key in dataset.keys():
+ if key.startswith("validation_"):
+ eval_dataset[key[len("validation_") :]] = dataset[key]
+
+ if len(eval_dataset):
+ dataset_module["eval_dataset"] = eval_dataset
+
+ else: # single dataset
+ dataset_module["train_dataset"] = dataset
+
+ return dataset_module
+
+
+def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
+ r"""Set up a filesystem object based on the path protocol."""
+ storage_options = {"anon": anon} if anon else {}
+ if path.startswith("s3://"):
+ fs = fsspec.filesystem("s3", **storage_options)
+ elif path.startswith(("gs://", "gcs://")):
+ fs = fsspec.filesystem("gcs", **storage_options)
+ else:
+ raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")
+
+ if not fs.exists(path):
+ raise ValueError(f"Path does not exist: {path}.")
+
+ return fs
+
+
+def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
+ r"""Helper function to read JSON/JSONL files using fsspec."""
+ with fs.open(path, "r") as f:
+ if path.endswith(".jsonl"):
+ return [json.loads(line) for line in f if line.strip()]
+ else:
+ return json.load(f)
+
+
+def read_cloud_json(cloud_path: str) -> list[Any]:
+ r"""Read a JSON/JSONL file from cloud storage (S3 or GCS).
+
+ Args:
+ cloud_path: str
+ Cloud path in the format:
+ - 's3://bucket-name/file.json' for AWS S3
+ - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
+ """
+ try:
+ fs = setup_fs(cloud_path, anon=True) # try with anonymous access first
+ except Exception:
+ fs = setup_fs(cloud_path) # try again with credentials
+
+ # filter out non-JSON files
+ files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
+ files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
+ if not files:
+ raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
+
+ return sum([_read_json_with_fs(fs, file) for file in files], [])
diff --git a/llamafactory/data/formatter.py b/llamafactory/data/formatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b527a7ef0093ef1f7d462639780d5330023c4e9
--- /dev/null
+++ b/llamafactory/data/formatter.py
@@ -0,0 +1,145 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Optional, Union
+
+from typing_extensions import override
+
+from .data_utils import SLOTS
+from .tool_utils import FunctionCall, get_tool_utils
+
+
+@dataclass
+class Formatter(ABC):
+ slots: SLOTS = field(default_factory=list)
+ tool_format: Optional[str] = None
+
+ @abstractmethod
+ def apply(self, **kwargs) -> SLOTS:
+ r"""Forms a list of slots according to the inputs to encode."""
+ ...
+
+ def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+ r"""Extract a list of tuples from the response message if using tools.
+
+ Each tuple consists of function name and function arguments.
+ """
+ raise NotImplementedError
+
+
+@dataclass
+class EmptyFormatter(Formatter):
+ def __post_init__(self):
+ has_placeholder = False
+ for slot in filter(lambda s: isinstance(s, str), self.slots):
+ if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
+ has_placeholder = True
+
+ if has_placeholder:
+ raise ValueError("Empty formatter should not contain any placeholder.")
+
+ @override
+ def apply(self, **kwargs) -> SLOTS:
+ return self.slots
+
+
+@dataclass
+class StringFormatter(Formatter):
+ def __post_init__(self):
+ has_placeholder = False
+ for slot in filter(lambda s: isinstance(s, str), self.slots):
+ if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
+ has_placeholder = True
+
+ if not has_placeholder:
+ raise ValueError("A placeholder is required in the string formatter.")
+
+ @override
+ def apply(self, **kwargs) -> SLOTS:
+ elements = []
+ for slot in self.slots:
+ if isinstance(slot, str):
+ for name, value in kwargs.items():
+ if not isinstance(value, str):
+ raise RuntimeError(f"Expected a string, got {value}")
+
+ slot = slot.replace("{{" + name + "}}", value, 1)
+ elements.append(slot)
+ elif isinstance(slot, (dict, set)):
+ elements.append(slot)
+ else:
+ raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")
+
+ return elements
+
+
+@dataclass
+class FunctionFormatter(StringFormatter):
+ def __post_init__(self):
+ super().__post_init__()
+ self.tool_utils = get_tool_utils(self.tool_format)
+
+ @override
+ def apply(self, **kwargs) -> SLOTS:
+ content: str = kwargs.pop("content")
+ thought_words, thought = kwargs.pop("thought_words", None), None
+ if thought_words and len(thought_words) == 2:
+ regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
+ thought = re.search(regex, content)
+
+ if thought:
+ content = content.replace(thought.group(0), "")
+
+ functions: list[FunctionCall] = []
+ try:
+ tool_calls = json.loads(content)
+ if not isinstance(tool_calls, list): # parallel function call
+ tool_calls = [tool_calls]
+
+ for tool_call in tool_calls:
+ functions.append(
+ FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
+ )
+
+ except json.JSONDecodeError:
+ raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string
+
+ function_str = self.tool_utils.function_formatter(functions)
+ if thought:
+ function_str = thought.group(0) + function_str
+
+ return super().apply(content=function_str)
+
+
+@dataclass
+class ToolFormatter(Formatter):
+ def __post_init__(self):
+ self.tool_utils = get_tool_utils(self.tool_format)
+
+ @override
+ def apply(self, **kwargs) -> SLOTS:
+ content = kwargs.pop("content")
+ try:
+ tools = json.loads(content)
+ return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
+ except json.JSONDecodeError:
+ raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
+
+ @override
+ def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+ return self.tool_utils.tool_extractor(content)
diff --git a/llamafactory/data/loader.py b/llamafactory/data/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbb13455beb1172930afd26154db72bbe387df16
--- /dev/null
+++ b/llamafactory/data/loader.py
@@ -0,0 +1,334 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import TYPE_CHECKING, Literal, Optional, Union
+
+import numpy as np
+from datasets import Dataset, load_dataset, load_from_disk
+
+from ..extras import logging
+from ..extras.constants import FILEEXT2TYPE
+from ..extras.misc import check_version, has_tokenized_data
+from .converter import align_dataset
+from .data_utils import get_dataset_module, merge_dataset, read_cloud_json, split_dataset
+from .parser import get_dataset_list
+from .processor import (
+ FeedbackDatasetProcessor,
+ PackedSupervisedDatasetProcessor,
+ PairwiseDatasetProcessor,
+ PretrainDatasetProcessor,
+ SupervisedDatasetProcessor,
+ UnsupervisedDatasetProcessor,
+)
+
+
+if TYPE_CHECKING:
+ from datasets import Dataset, IterableDataset
+ from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
+
+ from ..hparams import DataArguments, ModelArguments
+ from .data_utils import DatasetModule
+ from .parser import DatasetAttr
+ from .processor import DatasetProcessor
+ from .template import Template
+
+
+logger = logging.get_logger(__name__)
+
+
+def _load_single_dataset(
+ dataset_attr: "DatasetAttr",
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+) -> Union["Dataset", "IterableDataset"]:
+ r"""Load a single dataset and aligns it to the standard format."""
+ logger.info_rank0(f"Loading dataset {dataset_attr}...")
+ data_path, data_name, data_dir, data_files = None, None, None, None
+ if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
+ data_path = dataset_attr.dataset_name
+ data_name = dataset_attr.subset
+ data_dir = dataset_attr.folder
+
+ elif dataset_attr.load_from == "script":
+ data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+ data_name = dataset_attr.subset
+ data_dir = dataset_attr.folder
+
+ elif dataset_attr.load_from == "cloud_file":
+ data_path = dataset_attr.dataset_name
+
+ elif dataset_attr.load_from == "file":
+ data_files = []
+ local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+ if os.path.isdir(local_path): # is directory
+ for file_name in os.listdir(local_path):
+ data_files.append(os.path.join(local_path, file_name))
+ elif os.path.isfile(local_path): # is file
+ data_files.append(local_path)
+ else:
+ raise ValueError(f"File {local_path} not found.")
+
+ data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
+ if data_path is None:
+ raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
+
+ if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
+ raise ValueError("File types should be identical.")
+ else:
+ raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
+
+ if dataset_attr.load_from == "ms_hub":
+ check_version("modelscope>=1.14.0", mandatory=True)
+ from modelscope import MsDataset # type: ignore
+ from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
+
+ cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
+ dataset = MsDataset.load(
+ dataset_name=data_path,
+ subset_name=data_name,
+ data_dir=data_dir,
+ data_files=data_files,
+ split=dataset_attr.split,
+ cache_dir=cache_dir,
+ token=model_args.ms_hub_token,
+ use_streaming=data_args.streaming,
+ )
+ if isinstance(dataset, MsDataset):
+ dataset = dataset.to_hf_dataset()
+
+ elif dataset_attr.load_from == "om_hub":
+ check_version("openmind>=0.8.0", mandatory=True)
+ from openmind import OmDataset # type: ignore
+ from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
+
+ cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
+ dataset = OmDataset.load_dataset(
+ path=data_path,
+ name=data_name,
+ data_dir=data_dir,
+ data_files=data_files,
+ split=dataset_attr.split,
+ cache_dir=cache_dir,
+ token=model_args.om_hub_token,
+ streaming=data_args.streaming,
+ )
+ elif dataset_attr.load_from == "cloud_file":
+ dataset = Dataset.from_list(read_cloud_json(data_path), split=dataset_attr.split)
+ else:
+ dataset = load_dataset(
+ path=data_path,
+ name=data_name,
+ data_dir=data_dir,
+ data_files=data_files,
+ split=dataset_attr.split,
+ cache_dir=model_args.cache_dir,
+ token=model_args.hf_hub_token,
+ num_proc=data_args.preprocessing_num_workers,
+ streaming=data_args.streaming and dataset_attr.load_from != "file",
+ )
+ if data_args.streaming and dataset_attr.load_from == "file":
+ dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
+
+ if dataset_attr.num_samples is not None and not data_args.streaming:
+ target_num = dataset_attr.num_samples
+ indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
+ target_num -= len(indexes)
+ if target_num > 0:
+ expand_indexes = np.random.choice(len(dataset), target_num)
+ indexes = np.concatenate((indexes, expand_indexes), axis=0)
+
+ assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
+ dataset = dataset.select(indexes)
+ logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
+
+ if data_args.max_samples is not None: # truncate dataset
+ max_samples = min(data_args.max_samples, len(dataset))
+ dataset = dataset.select(range(max_samples))
+
+ return align_dataset(dataset, dataset_attr, data_args, training_args)
+
+
+def _get_merged_dataset(
+ dataset_names: Optional[list[str]],
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+ return_dict: bool = False,
+) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
+ r"""Return the merged datasets in the standard format."""
+ if dataset_names is None:
+ return None
+
+ datasets = {}
+ for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)):
+ if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
+ raise ValueError("The dataset is not applicable in the current training stage.")
+
+ datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
+
+ if return_dict:
+ return datasets
+ else:
+ return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
+
+
+def _get_dataset_processor(
+ data_args: "DataArguments",
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ do_generate: bool = False,
+) -> "DatasetProcessor":
+ r"""Return the corresponding dataset processor."""
+ if stage == "pt":
+ dataset_processor_class = PretrainDatasetProcessor
+ elif stage == "sft" and not do_generate:
+ if data_args.packing:
+ if data_args.neat_packing: # hack datasets to have int32 attention mask
+ from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
+
+ def __init__(self, data, **kwargs):
+ return TypedSequence.__init__(
+ self,
+ data,
+ type=kwargs.pop("type", None),
+ try_type=kwargs.pop("try_type", None),
+ optimized_int_type=kwargs.pop("optimized_int_type", None),
+ )
+
+ OptimizedTypedSequence.__init__ = __init__
+ dataset_processor_class = PackedSupervisedDatasetProcessor
+ else:
+ dataset_processor_class = SupervisedDatasetProcessor
+
+ elif stage == "rm":
+ dataset_processor_class = PairwiseDatasetProcessor
+ elif stage == "kto":
+ dataset_processor_class = FeedbackDatasetProcessor
+ else:
+ dataset_processor_class = UnsupervisedDatasetProcessor
+
+ return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)
+
+
+def _get_preprocessed_dataset(
+ dataset: Optional[Union["Dataset", "IterableDataset"]],
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"] = None,
+ is_eval: bool = False,
+) -> Optional[Union["Dataset", "IterableDataset"]]:
+ r"""Preprocesses the dataset, including format checking and tokenization."""
+ if dataset is None:
+ return None
+
+ dataset_processor = _get_dataset_processor(
+ data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
+ )
+ column_names = list(next(iter(dataset)).keys())
+ kwargs = {}
+ if not data_args.streaming:
+ kwargs = dict(
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
+ desc="Running tokenizer on dataset",
+ )
+
+ dataset = dataset.map(
+ dataset_processor.preprocess_dataset,
+ batched=True,
+ batch_size=data_args.preprocessing_batch_size,
+ remove_columns=column_names,
+ **kwargs,
+ )
+
+ if training_args.should_log:
+ try:
+ print("eval example:" if is_eval else "training example:")
+ dataset_processor.print_data_example(next(iter(dataset)))
+ except StopIteration:
+ if stage == "pt":
+ raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
+ else:
+ raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
+
+ return dataset
+
+
+def get_dataset(
+ template: "Template",
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"] = None,
+) -> "DatasetModule":
+ r"""Get the train dataset and optionally gets the evaluation dataset."""
+ # Load tokenized dataset if path exists
+ if data_args.tokenized_path is not None:
+ if has_tokenized_data(data_args.tokenized_path):
+ logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
+ tokenized_data = load_from_disk(data_args.tokenized_path)
+ dataset_module = get_dataset_module(tokenized_data)
+ if data_args.streaming:
+ dataset_module["train_dataset"] = dataset_module["train_dataset"].to_iterable_dataset()
+
+ logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
+ return dataset_module
+
+ if data_args.streaming:
+ raise ValueError("Turn off `streaming` when saving dataset to disk.")
+
+ # Load and preprocess dataset
+ with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)):
+ dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
+ eval_dataset = _get_merged_dataset(
+ data_args.eval_dataset,
+ model_args,
+ data_args,
+ training_args,
+ stage,
+ return_dict=data_args.eval_on_each_dataset,
+ )
+
+ with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
+ dataset = _get_preprocessed_dataset(
+ dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
+ )
+ if isinstance(eval_dataset, dict):
+ for eval_name, eval_data in eval_dataset.items():
+ eval_dataset[eval_name] = _get_preprocessed_dataset(
+ eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+ )
+ else:
+ eval_dataset = _get_preprocessed_dataset(
+ eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+ )
+
+ dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+ if data_args.tokenized_path is not None: # save tokenized dataset to disk
+ if training_args.should_save:
+ dataset_dict.save_to_disk(data_args.tokenized_path)
+ logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
+ logger.info_rank0(f"Please launch the training with `tokenized_path: {data_args.tokenized_path}`.")
+
+ return get_dataset_module(dataset_dict)
diff --git a/llamafactory/data/mm_plugin.py b/llamafactory/data/mm_plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..91b801dc1105c31fe9fe8a76807f4126198556a2
--- /dev/null
+++ b/llamafactory/data/mm_plugin.py
@@ -0,0 +1,2082 @@
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+import os
+import re
+from copy import deepcopy
+from dataclasses import dataclass
+from io import BytesIO
+from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
+
+import numpy as np
+import torch
+from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
+from transformers.models.mllama.processing_mllama import (
+ convert_sparse_cross_attention_mask_to_dense,
+ get_cross_attention_token_mask,
+)
+from typing_extensions import NotRequired, override
+
+from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.packages import (
+ is_librosa_available,
+ is_pillow_available,
+ is_pyav_available,
+ is_transformers_version_greater_than,
+)
+
+
+if is_librosa_available():
+ import librosa
+
+
+if is_pillow_available():
+ from PIL import Image
+ from PIL.Image import Image as ImageObject
+
+
+if is_pyav_available():
+ import av
+
+
+if is_transformers_version_greater_than("4.52.0"):
+ from transformers.image_utils import make_flat_list_of_images
+ from transformers.video_utils import make_batched_videos
+else:
+ from transformers.image_utils import make_batched_videos, make_flat_list_of_images
+
+
+if TYPE_CHECKING:
+ from av.stream import Stream
+ from numpy.typing import NDArray
+ from transformers import PreTrainedTokenizer, ProcessorMixin
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+ from transformers.image_processing_utils import BaseImageProcessor
+ from transformers.video_processing_utils import BaseVideoProcessor
+
+ class EncodedImage(TypedDict):
+ path: Optional[str]
+ bytes: Optional[bytes]
+
+ ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
+ VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
+ AudioInput = Union[str, BinaryIO, NDArray]
+
+ class RegularizedImageOutput(TypedDict):
+ images: list[ImageObject]
+
+ class RegularizedVideoOutput(TypedDict):
+ videos: list[list[ImageObject]]
+ durations: list[float]
+ fps_per_video: NotRequired[list[float]]
+
+ class RegularizedAudioOutput(TypedDict):
+ audios: list[NDArray]
+ sampling_rates: list[float]
+
+ class MMProcessor(ProcessorMixin):
+ patch_size: int
+ image_seq_length: int
+ num_additional_image_tokens: int
+ vision_feature_select_strategy: Literal["default", "full"]
+
+ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
+ pass
+
+
+def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
+ r"""Get paligemma token type ids for computing loss.
+
+ It is slightly different with the original token type ids where the prompt part is 0.
+
+ Returns:
+ batch_token_type_ids: shape (batch_size, seq_length)
+
+ """
+ batch_token_type_ids = []
+ for imglen, seqlen in zip(imglens, seqlens):
+ image_seqlen = imglen * processor.image_seq_length
+ batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
+
+ return batch_token_type_ids
+
+
+def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
+ r"""Get gemma3 token type ids for computing loss.
+
+ Returns:
+ batch_token_type_ids: shape (batch_size, seq_length)
+
+ """
+ image_token_id: int = getattr(processor, "image_token_id")
+ batch_token_type_ids = []
+ for token_ids in batch_ids:
+ token_ids = np.array(token_ids)
+ token_type_ids = np.zeros_like(token_ids)
+ token_type_ids[token_ids == image_token_id] = 1
+ batch_token_type_ids.append(token_type_ids.tolist())
+
+ return batch_token_type_ids
+
+
+def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
+ r"""Make nested list of images."""
+ batch_images = []
+ for imglen in imglens:
+ batch_images.append(images[:imglen])
+ images = images[imglen:]
+
+ return batch_images
+
+
+def _check_video_is_nested_images(video: "VideoInput") -> bool:
+ r"""Check if the video is nested images."""
+ return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict, ImageObject)) for frame in video)
+
+
+@dataclass
+class MMPluginMixin:
+ image_token: Optional[str]
+ video_token: Optional[str]
+ audio_token: Optional[str]
+ expand_mm_tokens: bool = True
+
+ def _validate_input(
+ self,
+ processor: Optional["MMProcessor"],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ ) -> None:
+ r"""Validate if this model accepts the input modalities."""
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+ video_processor: BaseImageProcessor = getattr(
+ processor, "video_processor", getattr(processor, "image_processor", None)
+ )
+ feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
+ if len(images) != 0 and self.image_token is None:
+ raise ValueError(
+ "This model does not support image input. Please check whether the correct `template` is used."
+ )
+
+ if len(videos) != 0 and self.video_token is None:
+ raise ValueError(
+ "This model does not support video input. Please check whether the correct `template` is used."
+ )
+
+ if len(audios) != 0 and self.audio_token is None:
+ raise ValueError(
+ "This model does not support audio input. Please check whether the correct `template` is used."
+ )
+
+ if self.image_token is not None and processor is None:
+ raise ValueError("Processor was not found, please check and update your model file.")
+
+ if self.image_token is not None and image_processor is None:
+ raise ValueError("Image processor was not found, please check and update your model file.")
+
+ if self.video_token is not None and video_processor is None:
+ raise ValueError("Video processor was not found, please check and update your model file.")
+
+ if self.audio_token is not None and feature_extractor is None:
+ raise ValueError("Audio feature extractor was not found, please check and update your model file.")
+
+ def _validate_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ ):
+ r"""Validate if the number of images, videos and audios match the number of placeholders in messages."""
+ num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
+ for message in messages:
+ num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER)
+ num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER)
+ num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER)
+
+ if len(images) != num_image_tokens:
+ raise ValueError(
+ f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}."
+ )
+
+ if len(videos) != num_video_tokens:
+ raise ValueError(
+ f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}."
+ )
+
+ if len(audios) != num_audio_tokens:
+ raise ValueError(
+ f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}."
+ )
+
+ def _preprocess_image(
+ self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
+ ) -> "ImageObject":
+ r"""Pre-process a single image."""
+ if (image.width * image.height) > image_max_pixels:
+ resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
+ width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+ image = image.resize((width, height))
+
+ if (image.width * image.height) < image_min_pixels:
+ resize_factor = math.sqrt(image_min_pixels / (image.width * image.height))
+ width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+ image = image.resize((width, height))
+
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+
+ return image
+
+ def _get_video_sample_indices(
+ self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
+ ) -> list[int]:
+ r"""Compute video sample indices according to fps."""
+ total_frames = video_stream.frames
+ if total_frames == 0: # infinite video
+ return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
+
+ sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * video_fps))
+ sample_frames = min(total_frames, video_maxlen, sample_frames)
+ return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+
+ def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
+ r"""Regularize images to avoid error. Including reading and pre-processing."""
+ results = []
+ for image in images:
+ if isinstance(image, (str, BinaryIO)):
+ image = Image.open(image)
+ elif isinstance(image, bytes):
+ image = Image.open(BytesIO(image))
+ elif isinstance(image, dict):
+ if image["bytes"] is not None:
+ image = Image.open(BytesIO(image["bytes"]))
+ else:
+ image = Image.open(image["path"])
+
+ if not isinstance(image, ImageObject):
+ raise ValueError(f"Expect input is a list of images, but got {type(image)}.")
+
+ results.append(self._preprocess_image(image, **kwargs))
+
+ return {"images": results}
+
+ def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
+ r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
+ results = []
+ durations = []
+ for video in videos:
+ frames: list[ImageObject] = []
+ if _check_video_is_nested_images(video):
+ for frame in video:
+ if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
+ raise ValueError("Invalid image found in video frames.")
+ frames = video
+ durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+ else:
+ container = av.open(video, "r")
+ video_stream = next(stream for stream in container.streams if stream.type == "video")
+ sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+ container.seek(0)
+ for frame_idx, frame in enumerate(container.decode(video_stream)):
+ if frame_idx in sample_indices:
+ frames.append(frame.to_image())
+
+ if video_stream.duration is None:
+ durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+ else:
+ durations.append(float(video_stream.duration * video_stream.time_base))
+
+ frames = self._regularize_images(frames, **kwargs)["images"]
+ results.append(frames)
+
+ return {"videos": results, "durations": durations}
+
+ def _regularize_audios(
+ self, audios: list["AudioInput"], sampling_rate: float, **kwargs
+ ) -> "RegularizedAudioOutput":
+ r"""Regularizes audios to avoid error. Including reading and resampling."""
+ results, sampling_rates = [], []
+ for audio in audios:
+ if not isinstance(audio, np.ndarray):
+ audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
+
+ results.append(audio)
+ sampling_rates.append(sampling_rate)
+
+ return {"audios": results, "sampling_rates": sampling_rates}
+
+ def _get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: "MMProcessor",
+ imglens: Optional[list[int]] = None,
+ ) -> dict[str, "torch.Tensor"]:
+ r"""Process visual inputs.
+
+ Returns: (llava and paligemma)
+ pixel_values: tensor with shape (B, C, H, W)
+
+ Returns: (qwen2-vl)
+ pixel_values: tensor with shape (num_patches, patch_dim)
+ image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
+ where num_patches == torch.prod(image_grid_thw)
+
+ Returns: (mllama)
+ pixel_values: tensor with shape
+ (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
+ For example, (2, 1, 4, 3, 560, 560).
+ aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
+ aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
+ num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
+
+ """
+ mm_inputs = {}
+ if len(images) != 0:
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+ images = self._regularize_images(
+ images,
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+ )["images"]
+ if imglens is not None: # if imglens are provided, make batched images
+ images = _make_batched_images(images, imglens)
+
+ image_processor_kwargs = {}
+ if getattr(processor, "image_do_pan_and_scan", False): # gemma3 image processor
+ image_processor_kwargs.update(
+ {
+ "do_pan_and_scan": True,
+ "pan_and_scan_min_crop_size": 256,
+ "pan_and_scan_max_num_crops": 4,
+ "pan_and_scan_min_ratio_to_activate": 1.2,
+ }
+ )
+
+ mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
+
+ if len(videos) != 0:
+ video_processor: BaseImageProcessor = getattr(
+ processor, "video_processor", getattr(processor, "image_processor", None)
+ )
+ videos = self._regularize_videos(
+ videos,
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+ video_fps=getattr(processor, "video_fps", 2.0),
+ video_maxlen=getattr(processor, "video_maxlen", 128),
+ )["videos"]
+ if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
+ mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
+ else: # for llava_next_video
+ mm_inputs.update(video_processor(videos, return_tensors="pt"))
+
+ if len(audios) != 0:
+ feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
+ audios = self._regularize_audios(
+ audios,
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+ )["audios"]
+ mm_inputs.update(
+ feature_extractor(
+ audios,
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+ return_attention_mask=True,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ )
+ mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None) # prevent conflicts
+
+ return mm_inputs
+
+
+@dataclass
+class BasePlugin(MMPluginMixin):
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ r"""Pre-process input messages before tokenization for VLMs."""
+ self._validate_input(processor, images, videos, audios)
+ return messages
+
+ def process_token_ids(
+ self,
+ input_ids: list[int],
+ labels: Optional[list[int]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["MMProcessor"],
+ ) -> tuple[list[int], Optional[list[int]]]:
+ r"""Pre-process token ids after tokenization for VLMs."""
+ self._validate_input(processor, images, videos, audios)
+ return input_ids, labels
+
+ def get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ imglens: list[int],
+ vidlens: list[int],
+ audlens: list[int],
+ batch_ids: list[list[int]],
+ processor: Optional["MMProcessor"],
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+ r"""Build batched multimodal inputs for VLMs.
+
+ Arguments:
+ images: a list of image inputs, shape (num_images,)
+ videos: a list of video inputs, shape (num_videos,)
+ audios: a list of audio inputs, shape (num_audios,)
+ imglens: number of images in each sample, shape (batch_size,)
+ vidlens: number of videos in each sample, shape (batch_size,)
+ audlens: number of audios in each sample, shape (batch_size,)
+ batch_ids: token ids of input samples, shape (batch_size, seq_len)
+ processor: a processor for pre-processing images and videos
+
+ """
+ self._validate_input(processor, images, videos, audios)
+ return self._get_mm_inputs(images, videos, audios, processor)
+
+
+@dataclass
+class Gemma3Plugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens = 0
+ messages = deepcopy(messages)
+ boi_token: str = getattr(processor, "boi_token")
+ full_image_sequence: str = getattr(processor, "full_image_sequence")
+ image_str = full_image_sequence if self.expand_mm_tokens else boi_token
+
+ do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
+ if do_pan_and_scan:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ if do_pan_and_scan:
+ image_placeholder_str = (
+ "Here is the original image {{image}} and here are some crops to help you see better "
+ + " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens])
+ )
+ else:
+ image_placeholder_str = "{{image}}"
+
+ content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
+ num_image_tokens += 1
+
+ message["content"] = content.replace("{{image}}", image_str)
+
+ return messages
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ imglens: list[int],
+ vidlens: list[int],
+ audlens: list[int],
+ batch_ids: list[list[int]],
+ processor: Optional["MMProcessor"],
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+ self._validate_input(processor, images, videos, audios)
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ mm_inputs.pop("num_crops", None)
+ mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor)
+ return mm_inputs
+
+
+class Gemma3nPlugin(Gemma3Plugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ messages = deepcopy(messages)
+ boi_token: str = getattr(processor, "boi_token")
+ boa_token: str = getattr(processor, "boa_token")
+ full_image_sequence: str = getattr(processor, "full_image_sequence")
+ full_audio_sequence: str = getattr(processor, "full_audio_sequence")
+ image_str = full_image_sequence if self.expand_mm_tokens else boi_token
+ audio_str = full_audio_sequence if self.expand_mm_tokens else boa_token
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ content = content.replace(IMAGE_PLACEHOLDER, image_str, 1)
+
+ while AUDIO_PLACEHOLDER in content:
+ content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)
+
+ message["content"] = content
+
+ return messages
+
+
+@dataclass
+class InternVLPlugin(BasePlugin):
+ @override
+ def _get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: "ProcessorMixin",
+ **kwargs,
+ ) -> dict[str, "torch.Tensor"]:
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+ image_processor_kwargs = {}
+ if getattr(processor, "crop_to_patches", False):
+ image_processor_kwargs.update(
+ {
+ "crop_to_patches": True,
+ "max_patches": 12,
+ "min_patches": 1,
+ }
+ )
+
+ mm_inputs = {}
+ image_video_patches = []
+
+ if len(images) != 0:
+ images = self._regularize_images(
+ images,
+ image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+ )["images"]
+
+ if len(videos) != 0:
+ videos = self._regularize_videos(
+ videos,
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+ video_fps=getattr(processor, "video_fps", 2.0),
+ video_maxlen=getattr(processor, "video_maxlen", 128),
+ )["videos"]
+
+ if len(images) != 0:
+ images = make_flat_list_of_images(images)
+ image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs)
+ image_num_patches = image_inputs.pop("num_patches")
+ image_pixel_values = image_inputs.pop("pixel_values")
+ image_num_patches_indices = np.cumsum(image_num_patches)
+
+ if len(videos) != 0:
+ videos = make_batched_videos(videos)
+ num_frames_per_video = [len(video) for video in videos]
+ patch_indices = np.cumsum(num_frames_per_video)
+ image_processor_kwargs["crop_to_patches"] = False
+ video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs)
+ video_num_patches = video_inputs.pop("num_patches")
+ video_pixel_values = video_inputs.pop("pixel_values")
+ video_num_patches_indices = np.cumsum(video_num_patches)
+
+ # NOT SUPPORT IMAGE VIDEO INTERLEAVED
+ if len(images) != 0 and image_pixel_values is not None:
+ for i in range(len(images)):
+ start_index = image_num_patches_indices[i - 1] if i > 0 else 0
+ end_index = image_num_patches_indices[i]
+ image_video_patches.append(image_pixel_values[start_index:end_index])
+
+ if len(videos) != 0 and video_pixel_values is not None:
+ patch_indices_with_prefix = [0] + list(patch_indices)
+ for i in range(len(videos)):
+ current_patch_index = patch_indices_with_prefix[i]
+ end_patch_index = patch_indices_with_prefix[i + 1]
+ start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
+ end_index = video_num_patches_indices[end_patch_index - 1]
+ image_video_patches.append(video_pixel_values[start_index:end_index])
+
+ if len(images) != 0 or len(videos) != 0:
+ mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0)
+
+ if len(images) != 0:
+ mm_inputs.update({"image_num_patches": image_num_patches})
+
+ if len(videos) != 0:
+ mm_inputs.update({"video_patch_indices": patch_indices})
+ mm_inputs.update({"video_num_patches": video_num_patches})
+
+ return mm_inputs
+
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["ProcessorMixin"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens, num_video_tokens = 0, 0
+ image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
+ messages = deepcopy(messages)
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
+ image_pixel_patch_list = mm_inputs.get("image_num_patches") # pathes of images
+ video_num_patches = mm_inputs.get("video_num_patches") # all patches for frames of videos
+ video_patch_indices = mm_inputs.get("video_patch_indices") # num frames of per video
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ content = content.replace(
+ IMAGE_PLACEHOLDER,
+ f"{'' * image_seqlen * image_pixel_patch_list[num_image_tokens]}",
+ 1,
+ )
+ num_image_tokens += 1
+
+ while VIDEO_PLACEHOLDER in content:
+ current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
+ end_patch_index = video_patch_indices[num_video_tokens]
+ num_patches = list(video_num_patches[current_patch_index:end_patch_index])
+ video_replaced_prompt = "\n".join(
+ f"Frame{i + 1}: {'' * image_seqlen * num_patches[i]}"
+ for i in range(len(num_patches))
+ )
+ content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
+ num_video_tokens += 1
+
+ message["content"] = content
+
+ return messages
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ imglens: list[int],
+ vidlens: list[int],
+ audlens: list[int],
+ batch_ids: list[list[int]],
+ processor: Optional["ProcessorMixin"],
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+ self._validate_input(processor, images, videos, audios)
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ mm_inputs.pop("image_num_patches", None)
+ mm_inputs.pop("video_patch_indices", None)
+ mm_inputs.pop("video_num_patches", None)
+ return mm_inputs
+
+
+class KimiVLPlugin(BasePlugin):
+ @override
+ def process_messages(self, messages, images, videos, audios, processor):
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ image_grid_hws = mm_inputs.get("image_grid_hws", [])
+ else:
+ image_grid_hws = [None] * len(images)
+
+ num_image_tokens = 0
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+ merge_length = math.prod(image_processor.merge_kernel_size)
+ messages = deepcopy(messages)
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+ content = content.replace(
+ IMAGE_PLACEHOLDER,
+ f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>",
+ 1,
+ )
+ num_image_tokens += 1
+
+ message["content"] = content
+
+ return messages
+
+
+@dataclass
+class Llama4Plugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ if "pixel_values" in mm_inputs:
+ image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
+ num_patches_per_chunk = int(
+ (image_height // processor.patch_size)
+ * (image_width // processor.patch_size)
+ // processor.downsample_ratio
+ )
+ aspect_ratios = mm_inputs.pop("aspect_ratios")
+
+ num_image_tokens = 0
+ messages = deepcopy(messages)
+ for message in messages:
+ content = message["content"]
+ if self.expand_mm_tokens:
+ placeholder_count = content.count(IMAGE_PLACEHOLDER)
+ prompt_splits = content.split(IMAGE_PLACEHOLDER)
+ new_content = []
+ for local_image_index, split_part in enumerate(prompt_splits):
+ new_content.append(split_part)
+ if local_image_index < placeholder_count:
+ tokens_for_this_image = processor._prompt_split_image(
+ aspect_ratios[num_image_tokens], num_patches_per_chunk
+ )
+ num_image_tokens += 1
+ new_content.append(tokens_for_this_image)
+
+ content = "".join(new_content)
+ else:
+ content = content.replace(IMAGE_PLACEHOLDER, self.image_token)
+
+ message["content"] = content
+
+ return messages
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ imglens: list[int],
+ vidlens: list[int],
+ audlens: list[int],
+ batch_ids: list[list[int]],
+ processor: Optional["MMProcessor"],
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+ self._validate_input(processor, images, videos, audios)
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ mm_inputs.pop("aspect_ratios", None)
+ return mm_inputs
+
+
+@dataclass
+class LlavaPlugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ messages = deepcopy(messages)
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ if "pixel_values" in mm_inputs:
+ height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
+ image_seqlen = (height // processor.patch_size) * (
+ width // processor.patch_size
+ ) + processor.num_additional_image_tokens
+ if processor.vision_feature_select_strategy == "default":
+ image_seqlen -= 1
+ else:
+ image_seqlen = 1
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+
+ message["content"] = content.replace("{{image}}", self.image_token)
+
+ return messages
+
+
+@dataclass
+class LlavaNextPlugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens = 0
+ messages = deepcopy(messages)
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ if "pixel_values" in mm_inputs:
+ image_sizes = iter(mm_inputs["image_sizes"].tolist())
+ height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ if self.expand_mm_tokens:
+ orig_height, orig_width = next(image_sizes)
+ image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+ if processor.vision_feature_select_strategy == "default":
+ image_seqlen -= 1
+ else:
+ image_seqlen = 1
+
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+ num_image_tokens += 1
+
+ message["content"] = content.replace("{{image}}", self.image_token)
+
+ return messages
+
+
+@dataclass
+class LlavaNextVideoPlugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ messages = deepcopy(messages)
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ if "pixel_values" in mm_inputs:
+ image_sizes = iter(mm_inputs["image_sizes"].tolist())
+ height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ if self.expand_mm_tokens:
+ orig_height, orig_width = next(image_sizes)
+ image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+ if processor.vision_feature_select_strategy == "default":
+ image_seqlen -= 1
+ else:
+ image_seqlen = 1
+
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+
+ message["content"] = content.replace("{{image}}", self.image_token)
+
+ if self.expand_mm_tokens:
+ if "pixel_values_videos" in mm_inputs:
+ one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+ height, width = get_image_size(one_video[0])
+ num_frames = one_video.shape[0] # frame dim is always after batch dim
+ image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+ video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
+ else:
+ video_seqlen = 1
+
+ for message in messages:
+ content = message["content"]
+ while VIDEO_PLACEHOLDER in content:
+ content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+
+ message["content"] = content.replace("{{video}}", self.video_token)
+
+ return messages
+
+
+@dataclass
+class MiniCPMVPlugin(BasePlugin):
+ @override
+ def _get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: "MMProcessor",
+ **kwargs,
+ ) -> dict[str, "torch.Tensor"]:
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+ mm_inputs = {}
+ if len(images) != 0:
+ images = self._regularize_images(
+ images,
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+ )["images"]
+ if "valid_image_nums_ls" in kwargs:
+ valid_image_nums_ls = kwargs["valid_image_nums_ls"]
+ new_images = []
+ idx = 0
+ for valid_image_nums in valid_image_nums_ls:
+ new_images.append(images[idx : idx + valid_image_nums])
+ idx += valid_image_nums
+
+ images = new_images
+
+ image_inputs = image_processor(
+ images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
+ )
+ mm_inputs.update(image_inputs)
+
+ if len(videos) != 0:
+ videos = self._regularize_videos(
+ videos,
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+ video_fps=getattr(processor, "video_fps", 2.0),
+ video_maxlen=getattr(processor, "video_maxlen", 128),
+ )["videos"]
+ video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
+ mm_inputs.update(video_inputs)
+
+ if len(audios) != 0:
+ audios = self._regularize_audios(
+ audios,
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+ )["audios"]
+ if "valid_audio_nums_ls" in kwargs:
+ valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
+ audios_ls = []
+ idx = 0
+ for valid_audio_nums in valid_audio_nums_ls:
+ audios_ls.append(audios[idx : idx + valid_audio_nums])
+ idx += valid_audio_nums
+ else:
+ audios_ls = [audios]
+
+ audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
+ audios_ls,
+ chunk_input=True,
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+ )
+ audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
+ mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
+ if kwargs.get("ret_phs", False):
+ mm_inputs.update({"audio_phs": audio_phs})
+
+ return mm_inputs
+
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
+ messages = deepcopy(messages)
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+ mm_inputs, audio_inputs = {}, {}
+ if len(images) != 0 and len(videos) != 0:
+ raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
+
+ if len(videos) != 0:
+ max_slice_nums = 2
+ use_image_id = False
+ mm_inputs = self._get_mm_inputs([], videos, [], processor)
+ else:
+ max_slice_nums = image_processor.max_slice_nums
+ use_image_id = image_processor.use_image_id
+
+ for i, message in enumerate(messages):
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+ num_image_tokens += 1
+
+ while VIDEO_PLACEHOLDER in content:
+ video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
+ content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
+ num_video_tokens += 1
+
+ while AUDIO_PLACEHOLDER in content:
+ content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
+ num_audio_tokens += 1
+
+ message["content"] = content.replace("{{image}}", "(./)").replace(
+ "{{audio}}", "()"
+ )
+
+ if len(images):
+ mm_inputs = self._get_mm_inputs(images, [], [], processor)
+
+ if len(audios):
+ audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
+
+ if self.expand_mm_tokens and mm_inputs:
+ pattern = "(./)"
+ image_sizes = mm_inputs["image_sizes"]
+ idx = 0
+ for index, message in enumerate(messages):
+ text = message["content"]
+ image_tags = re.findall(pattern, text)
+ text_chunks = text.split(pattern)
+ final_text = ""
+ for i in range(len(image_tags)):
+ final_text = (
+ final_text
+ + text_chunks[i]
+ + image_processor.get_slice_image_placeholder(
+ image_sizes[0][idx], idx, max_slice_nums, use_image_id
+ )
+ )
+ idx += 1
+
+ final_text += text_chunks[-1]
+ messages[index]["content"] = final_text
+
+ if self.expand_mm_tokens and audio_inputs:
+ pattern = "()"
+ idx = 0
+ for index, message in enumerate(messages):
+ text = message["content"]
+ audio_tags = re.findall(pattern, text)
+ text_chunks = text.split(pattern)
+ final_text = ""
+ for i in range(len(audio_tags)):
+ audio_placeholder = audio_inputs["audio_phs"][0][idx]
+ final_text = final_text + text_chunks[i] + audio_placeholder
+ idx += 1
+
+ final_text += text_chunks[-1]
+ messages[index]["content"] = final_text
+
+ return messages
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ imglens: list[int],
+ vidlens: list[int],
+ audlens: list[int],
+ batch_ids: list[list[int]],
+ processor: Optional["MMProcessor"],
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+ self._validate_input(processor, images, videos, audios)
+ # image bound
+ image_bounds_list = []
+ valid_image_nums_ls = []
+ for i, input_ids in enumerate(batch_ids):
+ input_ids_ = torch.tensor(input_ids)
+ start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
+ input_ids_ == processor.tokenizer.slice_start_id
+ )
+ end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
+ image_start_tokens = torch.where(start_cond)[0]
+ image_start_tokens += 1
+ image_end_tokens = torch.where(end_cond)[0]
+ valid_image_nums_ls.append(imglens[i])
+ image_bounds = torch.hstack(
+ [
+ image_start_tokens.unsqueeze(-1),
+ image_end_tokens.unsqueeze(-1),
+ ]
+ )
+ image_bounds_list.append(image_bounds)
+
+ mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
+ if "tgt_sizes" not in mm_inputs:
+ dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
+ mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})
+
+ mm_inputs.update({"image_bound": image_bounds_list})
+
+ if len(audios) > 0:
+ # audio bound
+ audio_bounds_ls = []
+ spk_bounds_ls = []
+ valid_audio_nums_ls = []
+
+ for input_ids, audiolen in zip(batch_ids, audlens):
+ input_ids_ = torch.tensor(input_ids)
+ audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
+ audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
+ assert len(audio_start_idx) == len(audio_end_idx)
+ audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
+ audio_bounds_ls.append(audio_bounds)
+ valid_audio_nums_ls.append(audiolen)
+
+ spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
+ spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
+ assert len(spk_start_idx) == len(spk_end_idx)
+ spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
+ spk_bounds_ls.append(spk_bounds)
+
+ audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
+ mm_inputs.update(audio_inputs)
+ mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
+
+ return mm_inputs
+
+
+@dataclass
+class MllamaPlugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens = 0
+ messages = deepcopy(messages)
+ for message in messages:
+ content = message["content"]
+ num_image_tokens += content.count(IMAGE_PLACEHOLDER)
+ message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
+
+ return messages
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ imglens: list[int],
+ vidlens: list[int],
+ audlens: list[int],
+ batch_ids: list[list[int]],
+ processor: Optional["MMProcessor"],
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+ self._validate_input(processor, images, videos, audios)
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
+ if mm_inputs:
+ num_tiles = mm_inputs.pop("num_tiles")
+ image_token_id: int = getattr(processor, "image_token_id")
+ max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
+ cross_attention_token_mask = [
+ get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+ ]
+ mm_inputs["cross_attention_mask"] = torch.from_numpy(
+ convert_sparse_cross_attention_mask_to_dense(
+ cross_attention_token_mask,
+ num_tiles=num_tiles,
+ max_num_tiles=max_image_tiles,
+ length=max(len(input_ids) for input_ids in batch_ids),
+ )
+ ) # shape: (batch_size, length, max_num_images, max_num_tiles)
+
+ return mm_inputs
+
+
+@dataclass
+class PaliGemmaPlugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens = 0
+ messages = deepcopy(messages)
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ content = content.replace(IMAGE_PLACEHOLDER, "", 1)
+ num_image_tokens += 1
+
+ message["content"] = content
+
+ return messages
+
+ @override
+ def process_token_ids(
+ self,
+ input_ids: list[int],
+ labels: Optional[list[int]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["MMProcessor"],
+ ) -> tuple[list[int], Optional[list[int]]]:
+ self._validate_input(processor, images, videos, audios)
+ num_images = len(images)
+ image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
+ image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+ input_ids = [image_token_id] * num_images * image_seqlen + input_ids
+ if labels is not None:
+ labels = [IGNORE_INDEX] * num_images * image_seqlen + labels
+
+ return input_ids, labels
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ imglens: list[int],
+ vidlens: list[int],
+ audlens: list[int],
+ batch_ids: list[list[int]],
+ processor: Optional["MMProcessor"],
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+ self._validate_input(processor, images, videos, audios)
+ seqlens = [len(input_ids) for input_ids in batch_ids]
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
+ return mm_inputs
+
+
+@dataclass
+class PixtralPlugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ messages = deepcopy(messages)
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ if "pixel_values" in mm_inputs:
+ # BC for transformers < 4.49.0
+ if isinstance(mm_inputs["image_sizes"], list):
+ image_sizes = iter(mm_inputs["image_sizes"][0])
+ else:
+ image_sizes = iter(mm_inputs["image_sizes"].tolist())
+
+ image_break_token: str = getattr(processor, "image_break_token")
+ image_end_token: str = getattr(processor, "image_end_token")
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ if self.expand_mm_tokens:
+ patch_size = processor.patch_size * getattr(processor, "spatial_merge_size", 1)
+ height, width = next(image_sizes)
+ num_height_tokens = height // patch_size
+ num_width_tokens = width // patch_size
+ replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
+ replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
+ replace_tokens[-1] = image_end_token
+ replace_str = "".join(replace_tokens)
+ else:
+ replace_str = self.image_token
+
+ content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
+
+ message["content"] = content
+
+ return messages
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ imglens: list[int],
+ vidlens: list[int],
+ audlens: list[int],
+ batch_ids: list[list[int]],
+ processor: Optional["MMProcessor"],
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+ self._validate_input(processor, images, videos, audios)
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ # ref to this commit https://github.com/huggingface/transformers/pull/35122
+ # after transformers 4.49.0, the `image_sizes` is mandatory as an input parameter for Pixtral VisionEncoder forwarding.
+ # it can be passed into `LlavaConditionalGeneration` as a parameter.
+ if not is_transformers_version_greater_than("4.49.0"):
+ mm_inputs.pop("image_sizes", None)
+ return mm_inputs
+
+
+@dataclass
+class Qwen2AudioPlugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ bos_token: str = getattr(processor, "audio_bos_token")
+ eos_token: str = getattr(processor, "audio_eos_token")
+ messages = deepcopy(messages)
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs([], [], audios, processor)
+ if "feature_attention_mask" in mm_inputs:
+ audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
+
+ for message in messages:
+ content = message["content"]
+ while AUDIO_PLACEHOLDER in content:
+ if self.expand_mm_tokens:
+ audio_length = audio_lengths.pop(0)
+ input_length = (audio_length - 1) // 2 + 1
+ audio_seqlen = (input_length - 2) // 2 + 1
+ else:
+ audio_seqlen = 1
+
+ content = content.replace(
+ AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
+ )
+
+ message["content"] = content
+
+ return messages
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ imglens: list[int],
+ vidlens: list[int],
+ audlens: list[int],
+ batch_ids: list[list[int]],
+ processor: Optional["MMProcessor"],
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+ self._validate_input(processor, images, videos, audios)
+ return self._get_mm_inputs(images, videos, audios, processor)
+
+
+@dataclass
+class Qwen2VLPlugin(BasePlugin):
+ vision_bos_token: str = "<|vision_start|>"
+ vision_eos_token: str = "<|vision_end|>"
+
+ @override
+ def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+ image = super()._preprocess_image(image, **kwargs)
+ if min(image.width, image.height) < 28:
+ width, height = max(image.width, 28), max(image.height, 28)
+ image = image.resize((width, height))
+
+ if image.width / image.height > 200:
+ width, height = image.height * 180, image.height
+ image = image.resize((width, height))
+
+ if image.height / image.width > 200:
+ width, height = image.width, image.width * 180
+ image = image.resize((width, height))
+
+ return image
+
+ @override
+ def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
+ results, fps_per_video, durations = [], [], []
+ for video in videos:
+ frames: list[ImageObject] = []
+ if _check_video_is_nested_images(video):
+ for frame in video:
+ if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
+ raise ValueError("Invalid image found in video frames.")
+
+ frames = video
+ fps_per_video.append(kwargs.get("video_fps", 2.0))
+ durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+ else:
+ container = av.open(video, "r")
+ video_stream = next(stream for stream in container.streams if stream.type == "video")
+ sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+ container.seek(0)
+ for frame_idx, frame in enumerate(container.decode(video_stream)):
+ if frame_idx in sample_indices:
+ frames.append(frame.to_image())
+
+ if video_stream.duration is None:
+ fps_per_video.append(kwargs.get("video_fps", 2.0))
+ durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+ else:
+ fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
+ durations.append(float(video_stream.duration * video_stream.time_base))
+
+ if len(frames) % 2 != 0:
+ frames.append(frames[-1])
+
+ frames = self._regularize_images(frames, **kwargs)["images"]
+ results.append(frames)
+
+ return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
+
+ @override
+ def _get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: "MMProcessor",
+ ) -> dict[str, "torch.Tensor"]:
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+ video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
+ mm_inputs = {}
+ if len(images) != 0:
+ images = self._regularize_images(
+ images,
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+ )["images"]
+ mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+ if len(videos) != 0:
+ video_data = self._regularize_videos(
+ videos,
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+ video_fps=getattr(processor, "video_fps", 2.0),
+ video_maxlen=getattr(processor, "video_maxlen", 128),
+ )
+ mm_inputs.update(video_processor(videos=video_data["videos"], return_tensors="pt"))
+ temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+ if "second_per_grid_ts" in processor.model_input_names:
+ mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
+
+ return mm_inputs
+
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens, num_video_tokens = 0, 0
+ messages = deepcopy(messages)
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+
+ merge_length: int = getattr(image_processor, "merge_size") ** 2
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ image_grid_thw = mm_inputs.get("image_grid_thw", [])
+ video_grid_thw = mm_inputs.get("video_grid_thw", [])
+ else:
+ image_grid_thw = [None] * len(images)
+ video_grid_thw = [None] * len(videos)
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+ content = content.replace(
+ IMAGE_PLACEHOLDER,
+ f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
+ 1,
+ )
+ num_image_tokens += 1
+
+ while VIDEO_PLACEHOLDER in content:
+ video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+ content = content.replace(
+ VIDEO_PLACEHOLDER,
+ f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
+ 1,
+ )
+ num_video_tokens += 1
+
+ message["content"] = content
+
+ return messages
+
+
+@dataclass
+class Qwen3VLPlugin(Qwen2VLPlugin):
+ @override
+ def _get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: "MMProcessor",
+ ) -> dict[str, "torch.Tensor"]:
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+ video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
+ mm_inputs = {}
+ if len(images) != 0:
+ images = self._regularize_images(
+ images,
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+ )["images"]
+ mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+ if len(videos) != 0:
+ videos = self._regularize_videos(
+ videos,
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+ video_fps=getattr(processor, "video_fps", 2.0),
+ video_maxlen=getattr(processor, "video_maxlen", 128),
+ )
+ video_metadata = [
+ {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
+ for video, duration in zip(videos["videos"], videos["durations"])
+ ]
+ mm_inputs.update(
+ video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
+ )
+ temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+ if "second_per_grid_ts" in processor.model_input_names:
+ mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in videos["fps_per_video"]]
+
+ return mm_inputs
+
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens, num_video_tokens = 0, 0
+ messages = deepcopy(messages)
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+ video_processor: BaseImageProcessor = getattr(processor, "video_processor")
+
+ image_merge_length: int = getattr(image_processor, "merge_size") ** 2
+ video_merge_length: int = getattr(video_processor, "merge_size") ** 2
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ image_grid_thw = mm_inputs.get("image_grid_thw", [])
+ video_grid_thw = mm_inputs.get("video_grid_thw", [])
+ num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
+ video_metadata = mm_inputs.get("video_metadata", {})
+
+ else:
+ image_grid_thw = [None] * len(images)
+ video_grid_thw = [None] * len(videos)
+ num_frames = 0
+ timestamps = [0]
+
+ for idx, message in enumerate(messages):
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ image_seqlen = (
+ image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
+ )
+ content = content.replace(
+ IMAGE_PLACEHOLDER,
+ f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
+ 1,
+ )
+ num_image_tokens += 1
+
+ while VIDEO_PLACEHOLDER in content:
+ if self.expand_mm_tokens:
+ metadata = video_metadata[idx]
+ timestamps = processor._calculate_timestamps(
+ metadata.frames_indices,
+ metadata.fps,
+ video_processor.merge_size,
+ )
+ video_structure = ""
+ for frame_index in range(num_frames):
+ video_seqlen = (
+ video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
+ if self.expand_mm_tokens
+ else 1
+ )
+ timestamp_sec = timestamps[frame_index]
+ frame_structure = (
+ f"<{timestamp_sec:.1f} seconds>"
+ f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
+ )
+ video_structure += frame_structure
+ else:
+ video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
+
+ content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
+ num_video_tokens += 1
+
+ message["content"] = content
+
+ return messages
+
+
+@dataclass
+class GLM4VPlugin(Qwen2VLPlugin):
+ @override
+ def _get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: "MMProcessor",
+ ) -> dict[str, "torch.Tensor"]:
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+ video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
+ mm_inputs = {}
+ if len(images) != 0:
+ images = self._regularize_images(
+ images,
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+ )["images"]
+ mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+ if len(videos) != 0:
+ video_data = self._regularize_videos(
+ videos,
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+ video_fps=getattr(processor, "video_fps", 2.0),
+ video_maxlen=getattr(processor, "video_maxlen", 128),
+ )
+ # prepare video metadata
+ video_metadata = [
+ {"fps": 2, "duration": duration, "total_frames": len(video)}
+ for video, duration in zip(video_data["videos"], video_data["durations"])
+ ]
+ mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
+
+ return mm_inputs
+
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens, num_video_tokens = 0, 0
+ messages = deepcopy(messages)
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+
+ merge_length: int = getattr(image_processor, "merge_size") ** 2
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ image_grid_thw = mm_inputs.get("image_grid_thw", [])
+ video_grid_thw = mm_inputs.get("video_grid_thw", [])
+ num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
+ timestamps = mm_inputs.get("timestamps", [])
+
+ if hasattr(timestamps, "tolist"):
+ timestamps = timestamps.tolist()
+
+ if not timestamps:
+ timestamps_list = []
+ elif isinstance(timestamps[0], list):
+ timestamps_list = timestamps[0]
+ else:
+ timestamps_list = timestamps
+
+ unique_timestamps = timestamps_list.copy()
+ selected_timestamps = unique_timestamps[:num_frames]
+ while len(selected_timestamps) < num_frames:
+ selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
+
+ else:
+ image_grid_thw = [None] * len(images)
+ video_grid_thw = [None] * len(videos)
+ num_frames = 0
+ selected_timestamps = [0]
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+ content = content.replace(
+ IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * image_seqlen}<|end_of_image|>", 1
+ )
+ num_image_tokens += 1
+
+ while VIDEO_PLACEHOLDER in content:
+ video_structure = ""
+ for frame_index in range(num_frames):
+ video_seqlen = (
+ video_grid_thw[num_video_tokens][1:].prod() // merge_length if self.expand_mm_tokens else 1
+ )
+ timestamp_sec = selected_timestamps[frame_index]
+ frame_structure = (
+ f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}"
+ )
+ video_structure += frame_structure
+
+ if not self.expand_mm_tokens:
+ video_structure = self.video_token
+
+ content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1)
+ num_video_tokens += 1
+
+ message["content"] = content
+
+ return messages
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ imglens: list[int],
+ vidlens: list[int],
+ audlens: list[int],
+ batch_ids: list[list[int]],
+ processor: Optional["ProcessorMixin"],
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+ self._validate_input(processor, images, videos, audios)
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ mm_inputs.pop("timestamps", None)
+ return mm_inputs
+
+
+@dataclass
+class Qwen2OmniPlugin(Qwen2VLPlugin):
+ audio_bos_token: str = "<|audio_start|>"
+ audio_eos_token: str = "<|audio_end|>"
+
+ @override
+ def _get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: "MMProcessor",
+ ) -> dict[str, "torch.Tensor"]:
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+ video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
+ feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
+ mm_inputs = {}
+ if len(images) != 0:
+ images = self._regularize_images(
+ images,
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+ )["images"]
+ mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+ if len(videos) != 0:
+ video_dict = self._regularize_videos(
+ videos,
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+ video_fps=getattr(processor, "video_fps", 2.0),
+ video_maxlen=getattr(processor, "video_maxlen", 128),
+ )
+ mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt"))
+ temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+ mm_inputs["video_second_per_grid"] = torch.tensor(
+ [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
+ )
+
+ if len(audios) != 0:
+ audios = self._regularize_audios(
+ audios,
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+ )["audios"]
+ mm_inputs.update(
+ feature_extractor(
+ audios,
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+ return_attention_mask=True,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ )
+ mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
+
+ return mm_inputs
+
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
+ messages = deepcopy(messages)
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+
+ merge_length = processor.image_processor.merge_size**2
+ use_audio_in_video = getattr(processor, "use_audio_in_video", False)
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ image_grid_thw = mm_inputs.get("image_grid_thw", [])
+ video_grid_thw = mm_inputs.get("video_grid_thw", [])
+ if "feature_attention_mask" in mm_inputs:
+ if processor.__class__.__name__ == "Qwen3OmniMoeProcessor": # for qwen3omni
+ input_lengths = mm_inputs["feature_attention_mask"].sum(-1)
+ input_lengths_leave = input_lengths % 100
+ feature_lengths = (input_lengths_leave - 1) // 2 + 1
+ audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
+ else:
+ input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
+ audio_lengths = (input_lengths - 2) // 2 + 1
+ else:
+ mm_inputs = {}
+ image_grid_thw = [None] * len(images)
+ video_grid_thw = [None] * len(videos)
+ audio_lengths = [None] * len(audios)
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+ content = content.replace(
+ IMAGE_PLACEHOLDER,
+ f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
+ 1,
+ )
+ num_image_tokens += 1
+
+ if (
+ use_audio_in_video and len(audios) and len(videos)
+ ): # if use the audio of video # deal video token and audio token togather
+ if len(videos) != len(audios):
+ raise ValueError(
+ f"Number of videos ({len(videos)}) must match number of audios ({len(audios)}) when using audio in video."
+ )
+
+ while VIDEO_PLACEHOLDER in content:
+ video_pos = content.find(VIDEO_PLACEHOLDER)
+ audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
+ if audio_pos == -1 or audio_pos < video_pos:
+ raise ValueError(
+ f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
+ )
+
+ audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
+ video_t_index = (
+ torch.arange(video_grid_thw[num_video_tokens][0])
+ .view(-1, 1, 1)
+ .expand(
+ -1,
+ video_grid_thw[num_video_tokens][1] // image_processor.merge_size,
+ video_grid_thw[num_video_tokens][2] // image_processor.merge_size,
+ )
+ .flatten()
+ * mm_inputs["video_second_per_grid"][num_video_tokens]
+ * 25 # FIXME hardcode of position_id_per_seconds=25
+ ).long()
+ t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
+ video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
+ audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
+ placeholder_string = ""
+ placeholder_string += self.vision_bos_token + self.audio_bos_token
+ for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
+ video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
+ audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
+ if video_chunk_index is not None:
+ placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
+
+ if audio_chunk_index is not None:
+ placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
+
+ placeholder_string += self.audio_eos_token + self.vision_eos_token
+ content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
+ content = content.replace(AUDIO_PLACEHOLDER, "", 1)
+ num_audio_tokens += 1
+ num_video_tokens += 1
+ else:
+ while AUDIO_PLACEHOLDER in content:
+ audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
+ content = content.replace(
+ AUDIO_PLACEHOLDER,
+ f"{self.audio_bos_token}{self.audio_token * audio_seqlen}{self.audio_eos_token}",
+ 1,
+ )
+ num_audio_tokens += 1
+
+ while VIDEO_PLACEHOLDER in content:
+ video_seqlen = (
+ video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+ )
+ content = content.replace(
+ VIDEO_PLACEHOLDER,
+ f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
+ 1,
+ )
+ num_video_tokens += 1
+
+ message["content"] = content
+
+ return messages
+
+
+@dataclass
+class VideoLlavaPlugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens, num_video_tokens = 0, 0
+ messages = deepcopy(messages)
+ num_frames = 0
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ if "pixel_values_images" in mm_inputs:
+ height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))
+ num_frames = 1
+
+ if "pixel_values_videos" in mm_inputs:
+ one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
+ height, width = get_image_size(one_video[0])
+ num_frames = one_video.shape[0] # frame dim is always after batch dim
+
+ if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs:
+ image_seqlen = (height // processor.patch_size) * (
+ width // processor.patch_size
+ ) + processor.num_additional_image_tokens
+ video_seqlen = image_seqlen * num_frames
+ if processor.vision_feature_select_strategy == "default":
+ image_seqlen -= 1
+ else:
+ image_seqlen, video_seqlen = 1, 1
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+ num_image_tokens += 1
+
+ while VIDEO_PLACEHOLDER in content:
+ content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+ num_video_tokens += 1
+
+ content = content.replace("{{image}}", self.image_token)
+ message["content"] = content.replace("{{video}}", self.video_token)
+
+ return messages
+
+
+PLUGINS = {
+ "base": BasePlugin,
+ "gemma3": Gemma3Plugin,
+ "glm4v": GLM4VPlugin,
+ "gemma3n": Gemma3nPlugin,
+ "intern_vl": InternVLPlugin,
+ "kimi_vl": KimiVLPlugin,
+ "llama4": Llama4Plugin,
+ "llava": LlavaPlugin,
+ "llava_next": LlavaNextPlugin,
+ "llava_next_video": LlavaNextVideoPlugin,
+ "minicpm_v": MiniCPMVPlugin,
+ "mllama": MllamaPlugin,
+ "paligemma": PaliGemmaPlugin,
+ "pixtral": PixtralPlugin,
+ "qwen2_audio": Qwen2AudioPlugin,
+ "qwen2_omni": Qwen2OmniPlugin,
+ "qwen2_vl": Qwen2VLPlugin,
+ "qwen3_vl": Qwen3VLPlugin,
+ "video_llava": VideoLlavaPlugin,
+}
+
+
+def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
+ r"""Register a multimodal plugin."""
+ if name in PLUGINS:
+ raise ValueError(f"Multimodal plugin {name} already exists.")
+
+ PLUGINS[name] = plugin_class
+
+
+def get_mm_plugin(
+ name: str,
+ image_token: Optional[str] = None,
+ video_token: Optional[str] = None,
+ audio_token: Optional[str] = None,
+ **kwargs,
+) -> "BasePlugin":
+ r"""Get plugin for multimodal inputs."""
+ if name not in PLUGINS:
+ raise ValueError(f"Multimodal plugin `{name}` not found.")
+
+ return PLUGINS[name](image_token, video_token, audio_token, **kwargs)
diff --git a/llamafactory/data/parser.py b/llamafactory/data/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a865fd83dfe6221f312a2030cd08a6b38366cfe
--- /dev/null
+++ b/llamafactory/data/parser.py
@@ -0,0 +1,149 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from dataclasses import dataclass
+from typing import Any, Literal, Optional, Union
+
+from huggingface_hub import hf_hub_download
+
+from ..extras.constants import DATA_CONFIG
+from ..extras.misc import use_modelscope, use_openmind
+
+
+@dataclass
+class DatasetAttr:
+ r"""Dataset attributes."""
+
+ # basic configs
+ load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
+ dataset_name: str
+ formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
+ ranking: bool = False
+ # extra configs
+ subset: Optional[str] = None
+ split: str = "train"
+ folder: Optional[str] = None
+ num_samples: Optional[int] = None
+ # common columns
+ system: Optional[str] = None
+ tools: Optional[str] = None
+ images: Optional[str] = None
+ videos: Optional[str] = None
+ audios: Optional[str] = None
+ # dpo columns
+ chosen: Optional[str] = None
+ rejected: Optional[str] = None
+ kto_tag: Optional[str] = None
+ # alpaca columns
+ prompt: Optional[str] = "instruction"
+ query: Optional[str] = "input"
+ response: Optional[str] = "output"
+ history: Optional[str] = None
+ # sharegpt columns
+ messages: Optional[str] = "conversations"
+ # sharegpt tags
+ role_tag: Optional[str] = "from"
+ content_tag: Optional[str] = "value"
+ user_tag: Optional[str] = "human"
+ assistant_tag: Optional[str] = "gpt"
+ observation_tag: Optional[str] = "observation"
+ function_tag: Optional[str] = "function_call"
+ system_tag: Optional[str] = "system"
+
+ def __repr__(self) -> str:
+ return self.dataset_name
+
+ def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
+ setattr(self, key, obj.get(key, default))
+
+ def join(self, attr: dict[str, Any]) -> None:
+ self.set_attr("formatting", attr, default="alpaca")
+ self.set_attr("ranking", attr, default=False)
+ self.set_attr("subset", attr)
+ self.set_attr("split", attr, default="train")
+ self.set_attr("folder", attr)
+ self.set_attr("num_samples", attr)
+
+ if "columns" in attr:
+ column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
+ column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
+ for column_name in column_names:
+ self.set_attr(column_name, attr["columns"])
+
+ if "tags" in attr:
+ tag_names = ["role_tag", "content_tag"]
+ tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
+ for tag in tag_names:
+ self.set_attr(tag, attr["tags"])
+
+
+def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]:
+ r"""Get the attributes of the datasets."""
+ if dataset_names is None:
+ dataset_names = []
+
+ if isinstance(dataset_dir, dict):
+ dataset_info = dataset_dir
+ elif dataset_dir == "ONLINE":
+ dataset_info = None
+ else:
+ if dataset_dir.startswith("REMOTE:"):
+ config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
+ else:
+ config_path = os.path.join(dataset_dir, DATA_CONFIG)
+
+ try:
+ with open(config_path) as f:
+ dataset_info = json.load(f)
+ except Exception as err:
+ if len(dataset_names) != 0:
+ raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
+
+ dataset_info = None
+
+ dataset_list: list[DatasetAttr] = []
+ for name in dataset_names:
+ if dataset_info is None: # dataset_dir is ONLINE
+ load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
+ dataset_attr = DatasetAttr(load_from, dataset_name=name)
+ dataset_list.append(dataset_attr)
+ continue
+
+ if name not in dataset_info:
+ raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
+
+ has_hf_url = "hf_hub_url" in dataset_info[name]
+ has_ms_url = "ms_hub_url" in dataset_info[name]
+ has_om_url = "om_hub_url" in dataset_info[name]
+
+ if has_hf_url or has_ms_url or has_om_url:
+ if has_ms_url and (use_modelscope() or not has_hf_url):
+ dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
+ elif has_om_url and (use_openmind() or not has_hf_url):
+ dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
+ else:
+ dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
+ elif "script_url" in dataset_info[name]:
+ dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
+ elif "cloud_file_name" in dataset_info[name]:
+ dataset_attr = DatasetAttr("cloud_file", dataset_name=dataset_info[name]["cloud_file_name"])
+ else:
+ dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
+
+ dataset_attr.join(dataset_info[name])
+ dataset_list.append(dataset_attr)
+
+ return dataset_list
diff --git a/llamafactory/data/processor/__init__.py b/llamafactory/data/processor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..357ab7899f9eecbd29344482d109b89af274ea2e
--- /dev/null
+++ b/llamafactory/data/processor/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .feedback import FeedbackDatasetProcessor
+from .pairwise import PairwiseDatasetProcessor
+from .pretrain import PretrainDatasetProcessor
+from .processor_utils import DatasetProcessor
+from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor
+from .unsupervised import UnsupervisedDatasetProcessor
+
+
+__all__ = [
+ "DatasetProcessor",
+ "FeedbackDatasetProcessor",
+ "PackedSupervisedDatasetProcessor",
+ "PairwiseDatasetProcessor",
+ "PretrainDatasetProcessor",
+ "SupervisedDatasetProcessor",
+ "UnsupervisedDatasetProcessor",
+]
diff --git a/llamafactory/data/processor/__pycache__/__init__.cpython-312.pyc b/llamafactory/data/processor/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e85ff951b87bd8d7cd85626bac8829ab0a727db
Binary files /dev/null and b/llamafactory/data/processor/__pycache__/__init__.cpython-312.pyc differ
diff --git a/llamafactory/data/processor/__pycache__/feedback.cpython-312.pyc b/llamafactory/data/processor/__pycache__/feedback.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01a84077c2946d6287b73a6f0decf5ab36fe133f
Binary files /dev/null and b/llamafactory/data/processor/__pycache__/feedback.cpython-312.pyc differ
diff --git a/llamafactory/data/processor/__pycache__/pairwise.cpython-312.pyc b/llamafactory/data/processor/__pycache__/pairwise.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..448f8f1ced997cd41e862fe07863214e4be3f68d
Binary files /dev/null and b/llamafactory/data/processor/__pycache__/pairwise.cpython-312.pyc differ
diff --git a/llamafactory/data/processor/__pycache__/pretrain.cpython-312.pyc b/llamafactory/data/processor/__pycache__/pretrain.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85208325c8814f094b4466853fa293154e398576
Binary files /dev/null and b/llamafactory/data/processor/__pycache__/pretrain.cpython-312.pyc differ
diff --git a/llamafactory/data/processor/__pycache__/processor_utils.cpython-312.pyc b/llamafactory/data/processor/__pycache__/processor_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84248d8f528b3a457311a4d90c73173eef406822
Binary files /dev/null and b/llamafactory/data/processor/__pycache__/processor_utils.cpython-312.pyc differ
diff --git a/llamafactory/data/processor/__pycache__/supervised.cpython-312.pyc b/llamafactory/data/processor/__pycache__/supervised.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2619e0ba420f7ea44190d6d70bafb8b0245106d4
Binary files /dev/null and b/llamafactory/data/processor/__pycache__/supervised.cpython-312.pyc differ
diff --git a/llamafactory/data/processor/__pycache__/unsupervised.cpython-312.pyc b/llamafactory/data/processor/__pycache__/unsupervised.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..656d5b93ad6bf756fb5ba1cfd4a027954fa2e783
Binary files /dev/null and b/llamafactory/data/processor/__pycache__/unsupervised.cpython-312.pyc differ
diff --git a/llamafactory/data/processor/feedback.py b/llamafactory/data/processor/feedback.py
new file mode 100644
index 0000000000000000000000000000000000000000..871615b9266e501f25f68e84e4536c0d24617803
--- /dev/null
+++ b/llamafactory/data/processor/feedback.py
@@ -0,0 +1,129 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Optional
+
+from ...extras import logging
+from ...extras.constants import IGNORE_INDEX
+from .processor_utils import DatasetProcessor, infer_seqlen
+
+
+if TYPE_CHECKING:
+ from ..mm_plugin import AudioInput, ImageInput, VideoInput
+
+
+logger = logging.get_logger(__name__)
+
+
+class FeedbackDatasetProcessor(DatasetProcessor):
+ def _encode_data_example(
+ self,
+ prompt: list[dict[str, str]],
+ response: list[dict[str, str]],
+ kl_response: list[dict[str, str]],
+ system: Optional[str],
+ tools: Optional[str],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ ) -> tuple[list[int], list[int], list[int], list[int], bool]:
+ if response[0]["content"]: # desired example
+ kto_tag = True
+ messages = prompt + [response[0]]
+ else: # undesired example
+ kto_tag = False
+ messages = prompt + [response[1]]
+
+ if kl_response[0]["content"]:
+ kl_messages = prompt + [kl_response[0]]
+ else:
+ kl_messages = prompt + [kl_response[1]]
+
+ messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
+ kl_messages = self.template.mm_plugin.process_messages(kl_messages, images, videos, audios, self.processor)
+ prompt_ids, response_ids = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
+ kl_prompt_ids, kl_response_ids = self.template.encode_oneturn(self.tokenizer, kl_messages, system, tools)
+
+ if self.template.efficient_eos:
+ response_ids += [self.tokenizer.eos_token_id]
+ kl_response_ids += [self.tokenizer.eos_token_id]
+
+ prompt_ids, _ = self.template.mm_plugin.process_token_ids(
+ prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
+ )
+ kl_prompt_ids, _ = self.template.mm_plugin.process_token_ids(
+ kl_prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
+ )
+
+ source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), self.data_args.cutoff_len)
+ prompt_ids = prompt_ids[:source_len]
+ response_ids = response_ids[:target_len]
+ kl_source_len, kl_target_len = infer_seqlen(
+ len(kl_prompt_ids), len(kl_response_ids), self.data_args.cutoff_len
+ )
+ kl_prompt_ids = kl_prompt_ids[:kl_source_len]
+ kl_response_ids = kl_response_ids[:kl_target_len]
+
+ input_ids = prompt_ids + response_ids
+ labels = [IGNORE_INDEX] * source_len + response_ids
+ kl_input_ids = kl_prompt_ids + kl_response_ids
+ kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
+ return input_ids, labels, kl_input_ids, kl_labels, kto_tag
+
+ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+ # Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
+ kl_response = [examples["_response"][-1]] + examples["_response"][:-1]
+ model_inputs = defaultdict(list)
+ for i in range(len(examples["_prompt"])):
+ if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
+ logger.warning_rank0(
+ "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+ )
+ continue
+
+ input_ids, labels, kl_input_ids, kl_labels, kto_tag = self._encode_data_example(
+ prompt=examples["_prompt"][i],
+ response=examples["_response"][i],
+ kl_response=kl_response[i],
+ system=examples["_system"][i],
+ tools=examples["_tools"][i],
+ images=examples["_images"][i] or [],
+ videos=examples["_videos"][i] or [],
+ audios=examples["_audios"][i] or [],
+ )
+ model_inputs["input_ids"].append(input_ids)
+ model_inputs["attention_mask"].append([1] * len(input_ids))
+ model_inputs["labels"].append(labels)
+ model_inputs["kl_input_ids"].append(kl_input_ids)
+ model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
+ model_inputs["kl_labels"].append(kl_labels)
+ model_inputs["kto_tags"].append(kto_tag)
+ model_inputs["images"].append(examples["_images"][i])
+ model_inputs["videos"].append(examples["_videos"][i])
+ model_inputs["audios"].append(examples["_audios"][i])
+
+ desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
+ undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
+ if desirable_num == 0 or undesirable_num == 0:
+ logger.warning_rank0("Your dataset only has one preference type.")
+
+ return model_inputs
+
+ def print_data_example(self, example: dict[str, list[int]]) -> None:
+ valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
+ print("input_ids:\n{}".format(example["input_ids"]))
+ print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+ print("label_ids:\n{}".format(example["labels"]))
+ print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
diff --git a/llamafactory/data/processor/pairwise.py b/llamafactory/data/processor/pairwise.py
new file mode 100644
index 0000000000000000000000000000000000000000..94101deb8e75af73c1851720604994a11f2eb87d
--- /dev/null
+++ b/llamafactory/data/processor/pairwise.py
@@ -0,0 +1,118 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Optional
+
+from ...extras import logging
+from ...extras.constants import IGNORE_INDEX
+from .processor_utils import DatasetProcessor, infer_seqlen
+
+
+if TYPE_CHECKING:
+ from ..mm_plugin import AudioInput, ImageInput, VideoInput
+
+
+logger = logging.get_logger(__name__)
+
+
+class PairwiseDatasetProcessor(DatasetProcessor):
+ def _encode_data_example(
+ self,
+ prompt: list[dict[str, str]],
+ response: list[dict[str, str]],
+ system: Optional[str],
+ tools: Optional[str],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ ) -> tuple[list[int], list[int], list[int], list[int]]:
+ chosen_messages = self.template.mm_plugin.process_messages(
+ prompt + [response[0]], images, videos, audios, self.processor
+ )
+ rejected_messages = self.template.mm_plugin.process_messages(
+ prompt + [response[1]], images, videos, audios, self.processor
+ )
+ prompt_ids, chosen_ids = self.template.encode_oneturn(self.tokenizer, chosen_messages, system, tools)
+ _, rejected_ids = self.template.encode_oneturn(self.tokenizer, rejected_messages, system, tools)
+
+ if self.template.efficient_eos:
+ chosen_ids += [self.tokenizer.eos_token_id]
+ rejected_ids += [self.tokenizer.eos_token_id]
+
+ prompt_ids, _ = self.template.mm_plugin.process_token_ids(
+ prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
+ )
+ # consider the response is more important
+ source_len, target_len = infer_seqlen(
+ len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.data_args.cutoff_len
+ )
+ prompt_ids = prompt_ids[:source_len]
+ chosen_ids = chosen_ids[:target_len]
+ rejected_ids = rejected_ids[:target_len]
+
+ chosen_input_ids = prompt_ids + chosen_ids
+ chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
+ rejected_input_ids = prompt_ids + rejected_ids
+ rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
+ return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
+
+ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+ # build input pairs with format ` X`, `Y1 ` and `Y2 `
+ model_inputs = defaultdict(list)
+ for i in range(len(examples["_prompt"])):
+ if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
+ logger.warning_rank0(
+ "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+ )
+ continue
+
+ chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = self._encode_data_example(
+ prompt=examples["_prompt"][i],
+ response=examples["_response"][i],
+ system=examples["_system"][i],
+ tools=examples["_tools"][i],
+ images=examples["_images"][i] or [],
+ videos=examples["_videos"][i] or [],
+ audios=examples["_audios"][i] or [],
+ )
+ model_inputs["chosen_input_ids"].append(chosen_input_ids)
+ model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
+ model_inputs["chosen_labels"].append(chosen_labels)
+ model_inputs["rejected_input_ids"].append(rejected_input_ids)
+ model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
+ model_inputs["rejected_labels"].append(rejected_labels)
+ model_inputs["images"].append(examples["_images"][i])
+ model_inputs["videos"].append(examples["_videos"][i])
+ model_inputs["audios"].append(examples["_audios"][i])
+
+ return model_inputs
+
+ def print_data_example(self, example: dict[str, list[int]]) -> None:
+ valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
+ valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
+ print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
+ print(
+ "chosen_inputs:\n{}".format(self.tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))
+ )
+ print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
+ print(f"chosen_labels:\n{self.tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
+ print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
+ print(
+ "rejected_inputs:\n{}".format(
+ self.tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)
+ )
+ )
+ print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
+ print(f"rejected_labels:\n{self.tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
diff --git a/llamafactory/data/processor/pretrain.py b/llamafactory/data/processor/pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fa6b1ca58a8d59493cd4b43c51cb268080cc506
--- /dev/null
+++ b/llamafactory/data/processor/pretrain.py
@@ -0,0 +1,57 @@
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from itertools import chain
+from typing import Any
+
+from .processor_utils import DatasetProcessor
+
+
+@dataclass
+class PretrainDatasetProcessor(DatasetProcessor):
+ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+ # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
+ eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
+ text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
+
+ if not self.data_args.packing:
+ if getattr(self.tokenizer, "add_bos_token", False):
+ text_examples = [self.tokenizer.bos_token + example for example in text_examples]
+
+ result = self.tokenizer(
+ text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len
+ )
+ else:
+ tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False)
+ concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
+ total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+ block_size = self.data_args.cutoff_len
+ total_length = (total_length // block_size) * block_size
+ result = {
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+ for k, t in concatenated_examples.items()
+ }
+ if getattr(self.tokenizer, "add_bos_token", False):
+ for i in range(len(result["input_ids"])):
+ result["input_ids"][i][0] = self.tokenizer.bos_token_id
+
+ return result
+
+ def print_data_example(self, example: dict[str, list[int]]) -> None:
+ print("input_ids:\n{}".format(example["input_ids"]))
+ print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
diff --git a/llamafactory/data/processor/processor_utils.py b/llamafactory/data/processor/processor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..db44b19cf6fc84d6551fb7cce82283774ae72030
--- /dev/null
+++ b/llamafactory/data/processor/processor_utils.py
@@ -0,0 +1,88 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import bisect
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Optional
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer, ProcessorMixin
+
+ from ...hparams import DataArguments
+ from ..template import Template
+
+
+@dataclass
+class DatasetProcessor(ABC):
+ r"""A class for data processors."""
+
+ template: "Template"
+ tokenizer: "PreTrainedTokenizer"
+ processor: Optional["ProcessorMixin"]
+ data_args: "DataArguments"
+
+ @abstractmethod
+ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+ r"""Build model inputs from the examples."""
+ ...
+
+ @abstractmethod
+ def print_data_example(self, example: dict[str, list[int]]) -> None:
+ r"""Print a data example to stdout."""
+ ...
+
+
+def search_for_fit(numbers: list[int], capacity: int) -> int:
+ r"""Find the index of largest number that fits into the knapsack with the given capacity."""
+ index = bisect.bisect(numbers, capacity)
+ return -1 if index == 0 else (index - 1)
+
+
+def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
+ r"""Implement efficient greedy algorithm with binary search for the knapsack problem."""
+ numbers.sort() # sort numbers in ascending order for binary search
+ knapsacks = []
+
+ while numbers:
+ current_knapsack = []
+ remaining_capacity = capacity
+
+ while True:
+ index = search_for_fit(numbers, remaining_capacity)
+ if index == -1:
+ break # no more numbers fit in this knapsack
+
+ remaining_capacity -= numbers[index] # update the remaining capacity
+ current_knapsack.append(numbers.pop(index)) # add the number to knapsack
+
+ knapsacks.append(current_knapsack)
+
+ return knapsacks
+
+
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
+ r"""Compute the real sequence length after truncation by the cutoff_len."""
+ if target_len * 2 < cutoff_len: # truncate source
+ max_target_len = cutoff_len
+ elif source_len * 2 < cutoff_len: # truncate target
+ max_target_len = cutoff_len - source_len
+ else: # truncate both
+ max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+
+ new_target_len = min(max_target_len, target_len)
+ max_source_len = max(cutoff_len - new_target_len, 0)
+ new_source_len = min(max_source_len, source_len)
+ return new_source_len, new_target_len
diff --git a/llamafactory/data/processor/supervised.py b/llamafactory/data/processor/supervised.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5aba11b6535078f46bdf6aca743c6ae262e1fc6
--- /dev/null
+++ b/llamafactory/data/processor/supervised.py
@@ -0,0 +1,203 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Optional
+
+from ...extras import logging
+from ...extras.constants import IGNORE_INDEX
+from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen
+
+
+if TYPE_CHECKING:
+ from ..mm_plugin import AudioInput, ImageInput, VideoInput
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class SupervisedDatasetProcessor(DatasetProcessor):
+ def _encode_data_example(
+ self,
+ prompt: list[dict[str, str]],
+ response: list[dict[str, str]],
+ system: Optional[str],
+ tools: Optional[str],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ ) -> tuple[list[int], list[int]]:
+ messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
+ input_ids, labels = self.template.mm_plugin.process_token_ids(
+ [], [], images, videos, audios, self.tokenizer, self.processor
+ )
+ encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
+ total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
+ if self.data_args.mask_history:
+ encoded_pairs = encoded_pairs[::-1] # high priority for last turns
+
+ for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+ if total_length >= self.data_args.cutoff_len:
+ break
+
+ source_len, target_len = infer_seqlen(
+ len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length
+ )
+ source_ids = source_ids[:source_len]
+ target_ids = target_ids[:target_len]
+ total_length += source_len + target_len
+
+ if self.data_args.train_on_prompt:
+ source_label = source_ids
+ elif self.template.efficient_eos and turn_idx != 0:
+ source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
+ else:
+ source_label = [IGNORE_INDEX] * source_len
+
+ if self.data_args.mask_history and turn_idx != 0: # train on the last turn only
+ target_label = [IGNORE_INDEX] * target_len
+ else:
+ target_label = target_ids
+
+ if self.data_args.mask_history: # reversed sequences
+ input_ids = source_ids + target_ids + input_ids
+ labels = source_label + target_label + labels
+ else:
+ input_ids += source_ids + target_ids
+ labels += source_label + target_label
+
+ if self.template.efficient_eos:
+ input_ids += [self.tokenizer.eos_token_id]
+ labels += [self.tokenizer.eos_token_id]
+
+ return input_ids, labels
+
+ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+ # build inputs with format ` X Y ` and labels with format ` ... Y `
+ # for multiturn examples, we only mask the prompt part in each prompt-response pair.
+ model_inputs = defaultdict(list)
+ for i in range(len(examples["_prompt"])):
+ if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
+ logger.warning_rank0(
+ "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+ )
+ continue
+
+ input_ids, labels = self._encode_data_example(
+ prompt=examples["_prompt"][i],
+ response=examples["_response"][i],
+ system=examples["_system"][i],
+ tools=examples["_tools"][i],
+ images=examples["_images"][i] or [],
+ videos=examples["_videos"][i] or [],
+ audios=examples["_audios"][i] or [],
+ )
+ model_inputs["input_ids"].append(input_ids)
+ model_inputs["attention_mask"].append([1] * len(input_ids))
+ model_inputs["labels"].append(labels)
+ model_inputs["images"].append(examples["_images"][i])
+ model_inputs["videos"].append(examples["_videos"][i])
+ model_inputs["audios"].append(examples["_audios"][i])
+
+ return model_inputs
+
+ def print_data_example(self, example: dict[str, list[int]]) -> None:
+ valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
+ print("input_ids:\n{}".format(example["input_ids"]))
+ print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+ print("label_ids:\n{}".format(example["labels"]))
+ print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
+
+
+@dataclass
+class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
+ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+ # TODO: use `position_ids` to achieve packing
+ # build inputs with format ` X1 Y1 X2 Y2 `
+ # and labels with format ` ... Y1 ... Y2 `
+ valid_num = 0
+ batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
+ lengths = []
+ length2indexes = defaultdict(list)
+ for i in range(len(examples["_prompt"])):
+ if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
+ logger.warning_rank0(
+ "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+ )
+ continue
+
+ input_ids, labels = self._encode_data_example(
+ prompt=examples["_prompt"][i],
+ response=examples["_response"][i],
+ system=examples["_system"][i],
+ tools=examples["_tools"][i],
+ images=examples["_images"][i] or [],
+ videos=examples["_videos"][i] or [],
+ audios=examples["_audios"][i] or [],
+ )
+ length = len(input_ids)
+ if length > self.data_args.cutoff_len:
+ logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.")
+ else:
+ lengths.append(length)
+ length2indexes[length].append(valid_num)
+ batch_input_ids.append(input_ids)
+ batch_labels.append(labels)
+ batch_images.append(examples["_images"][i] or [])
+ batch_videos.append(examples["_videos"][i] or [])
+ batch_audios.append(examples["_audios"][i] or [])
+ valid_num += 1
+
+ model_inputs = defaultdict(list)
+ knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
+ for knapsack in knapsacks:
+ packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
+ packed_images, packed_videos, packed_audios = [], [], []
+ for i, length in enumerate(knapsack):
+ index = length2indexes[length].pop()
+ packed_input_ids += batch_input_ids[index]
+ packed_position_ids += list(range(len(batch_input_ids[index]))) # NOTE: pad_to_multiple_of ignore this
+ packed_labels += batch_labels[index]
+ packed_images += batch_images[index]
+ packed_videos += batch_videos[index]
+ packed_audios += batch_audios[index]
+ if self.data_args.neat_packing:
+ packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
+ else:
+ packed_attention_masks += [1] * len(batch_input_ids[index])
+
+ if len(packed_input_ids) < self.data_args.cutoff_len + 1: # avoid flash_attn drops attn mask
+ pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
+ packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
+ packed_position_ids += [0] * pad_length
+ packed_labels += [IGNORE_INDEX] * pad_length
+ if self.data_args.neat_packing:
+ packed_attention_masks += [0] * pad_length
+ else:
+ packed_attention_masks += [1] * pad_length # more efficient flash_attn
+
+ if len(packed_input_ids) != self.data_args.cutoff_len + 1:
+ raise ValueError("The length of packed example should be identical to the cutoff length.")
+
+ model_inputs["input_ids"].append(packed_input_ids)
+ model_inputs["attention_mask"].append(packed_attention_masks)
+ model_inputs["position_ids"].append(packed_position_ids)
+ model_inputs["labels"].append(packed_labels)
+ model_inputs["images"].append(packed_images or None)
+ model_inputs["videos"].append(packed_videos or None)
+ model_inputs["audios"].append(packed_audios or None)
+
+ return model_inputs
diff --git a/llamafactory/data/processor/unsupervised.py b/llamafactory/data/processor/unsupervised.py
new file mode 100644
index 0000000000000000000000000000000000000000..256174b6dd38696b5b180501102af40ff395d0a9
--- /dev/null
+++ b/llamafactory/data/processor/unsupervised.py
@@ -0,0 +1,91 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Optional
+
+from ...extras import logging
+from ..data_utils import Role
+from .processor_utils import DatasetProcessor, infer_seqlen
+
+
+if TYPE_CHECKING:
+ from ..mm_plugin import AudioInput, ImageInput, VideoInput
+
+
+logger = logging.get_logger(__name__)
+
+
+class UnsupervisedDatasetProcessor(DatasetProcessor):
+ def _encode_data_example(
+ self,
+ prompt: list[dict[str, str]],
+ response: list[dict[str, str]],
+ system: Optional[str],
+ tools: Optional[str],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ ) -> tuple[list[int], list[int]]:
+ if len(response) == 1:
+ messages = prompt + response
+ else:
+ messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
+
+ messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
+ input_ids, labels = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
+ if self.template.efficient_eos:
+ labels += [self.tokenizer.eos_token_id]
+
+ input_ids, _ = self.template.mm_plugin.process_token_ids(
+ input_ids, None, images, videos, audios, self.tokenizer, self.processor
+ )
+ source_len, target_len = infer_seqlen(len(input_ids), len(labels), self.data_args.cutoff_len)
+ input_ids = input_ids[:source_len]
+ labels = labels[:target_len]
+ return input_ids, labels
+
+ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+ # build inputs with format ` X` and labels with format `Y `
+ model_inputs = defaultdict(list)
+ for i in range(len(examples["_prompt"])):
+ if len(examples["_prompt"][i]) % 2 != 1:
+ logger.warning_rank0(
+ "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+ )
+ continue
+
+ input_ids, labels = self._encode_data_example(
+ prompt=examples["_prompt"][i],
+ response=examples["_response"][i],
+ system=examples["_system"][i],
+ tools=examples["_tools"][i],
+ images=examples["_images"][i] or [],
+ videos=examples["_videos"][i] or [],
+ audios=examples["_audios"][i] or [],
+ )
+ model_inputs["input_ids"].append(input_ids)
+ model_inputs["attention_mask"].append([1] * len(input_ids))
+ model_inputs["labels"].append(labels)
+ model_inputs["images"].append(examples["_images"][i])
+ model_inputs["videos"].append(examples["_videos"][i])
+ model_inputs["audios"].append(examples["_audios"][i])
+
+ return model_inputs
+
+ def print_data_example(self, example: dict[str, list[int]]) -> None:
+ print("input_ids:\n{}".format(example["input_ids"]))
+ print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+ print("label_ids:\n{}".format(example["labels"]))
+ print("labels:\n{}".format(self.tokenizer.decode(example["labels"], skip_special_tokens=False)))
diff --git a/llamafactory/data/template.py b/llamafactory/data/template.py
new file mode 100644
index 0000000000000000000000000000000000000000..56e32dd203cf753e6ffe486df291e2f71c29a11c
--- /dev/null
+++ b/llamafactory/data/template.py
@@ -0,0 +1,2209 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+from typing_extensions import override
+
+from ..extras import logging
+from .data_utils import Role
+from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+from .mm_plugin import get_mm_plugin
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer
+
+ from ..hparams import DataArguments
+ from .formatter import SLOTS, Formatter
+ from .mm_plugin import BasePlugin
+ from .tool_utils import FunctionCall
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class Template:
+ format_user: "Formatter"
+ format_assistant: "Formatter"
+ format_system: "Formatter"
+ format_function: "Formatter"
+ format_observation: "Formatter"
+ format_tools: "Formatter"
+ format_prefix: "Formatter"
+ default_system: str
+ stop_words: list[str]
+ thought_words: tuple[str, str]
+ efficient_eos: bool
+ replace_eos: bool
+ replace_jinja_template: bool
+ enable_thinking: Optional[bool]
+ mm_plugin: "BasePlugin"
+
+ def encode_oneturn(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ ) -> tuple[list[int], list[int]]:
+ r"""Return a single pair of token ids representing prompt and response respectively."""
+ encoded_messages = self._encode(tokenizer, messages, system, tools)
+ prompt_ids = []
+ for encoded_ids in encoded_messages[:-1]:
+ prompt_ids += encoded_ids
+
+ response_ids = encoded_messages[-1]
+ return prompt_ids, response_ids
+
+ def encode_multiturn(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ ) -> list[tuple[list[int], list[int]]]:
+ r"""Return multiple pairs of token ids representing prompts and responses respectively."""
+ encoded_messages = self._encode(tokenizer, messages, system, tools)
+ return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+ def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
+ r"""Extract tool message."""
+ return self.format_tools.extract(content)
+
+ def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+ r"""Return stop token ids."""
+ stop_token_ids = {tokenizer.eos_token_id}
+ for token in self.stop_words:
+ stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
+
+ return list(stop_token_ids)
+
+ def add_thought(self, content: str = "") -> str:
+ r"""Add empty thought to assistant message."""
+ return f"{self.thought_words[0]}{self.thought_words[1]}" + content
+
+ def remove_thought(self, content: str) -> str:
+ r"""Remove thought from assistant message."""
+ pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+ return re.sub(pattern, "", content).lstrip("\n")
+
+ def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+ r"""Get the token ids of thought words."""
+ return tokenizer.encode(self.add_thought(), add_special_tokens=False)
+
+ def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
+ r"""Convert elements to token ids."""
+ token_ids = []
+ for elem in elements:
+ if isinstance(elem, str):
+ if len(elem) != 0:
+ token_ids += tokenizer.encode(elem, add_special_tokens=False)
+ elif isinstance(elem, dict):
+ token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+ elif isinstance(elem, set):
+ if "bos_token" in elem and tokenizer.bos_token_id is not None:
+ token_ids += [tokenizer.bos_token_id]
+ elif "eos_token" in elem and tokenizer.eos_token_id is not None:
+ token_ids += [tokenizer.eos_token_id]
+ else:
+ raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
+
+ return token_ids
+
+ def _encode(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: list[dict[str, str]],
+ system: Optional[str],
+ tools: Optional[str],
+ ) -> list[list[int]]:
+ r"""Encode formatted inputs to pairs of token ids.
+
+ Turn 0: prefix + system + query resp
+ Turn t: query resp.
+ """
+ system = system or self.default_system
+ encoded_messages = []
+ for i, message in enumerate(messages):
+ elements = []
+
+ if i == 0:
+ elements += self.format_prefix.apply()
+ if system or tools:
+ tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+ elements += self.format_system.apply(content=(system + tool_text))
+
+ if message["role"] == Role.USER:
+ elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
+ elif message["role"] == Role.ASSISTANT:
+ elements += self.format_assistant.apply(content=message["content"])
+ elif message["role"] == Role.OBSERVATION:
+ elements += self.format_observation.apply(content=message["content"])
+ elif message["role"] == Role.FUNCTION:
+ elements += self.format_function.apply(content=message["content"], thought_words=self.thought_words)
+ else:
+ raise NotImplementedError("Unexpected role: {}".format(message["role"]))
+
+ encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+ return encoded_messages
+
+ @staticmethod
+ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
+ r"""Add or replace eos token to the tokenizer."""
+ if tokenizer.eos_token == eos_token:
+ return
+
+ is_added = tokenizer.eos_token_id is None
+ num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
+
+ if is_added:
+ logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
+ else:
+ logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")
+
+ if num_added_tokens > 0:
+ logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
+
+ def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
+ r"""Add eos token and pad token to the tokenizer."""
+ stop_words = self.stop_words
+ if self.replace_eos:
+ if not stop_words:
+ raise ValueError("Stop words are required to replace the EOS token.")
+
+ self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
+ stop_words = stop_words[1:]
+
+ if tokenizer.eos_token_id is None:
+ self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
+
+ if tokenizer.pad_token_id is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
+
+ if stop_words:
+ num_added_tokens = tokenizer.add_special_tokens(
+ dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
+ )
+ logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
+ if num_added_tokens > 0:
+ logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
+
+ @staticmethod
+ def _jinja_escape(content: str) -> str:
+ r"""Escape single quotes in content."""
+ return content.replace("'", r"\'")
+
+ @staticmethod
+ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
+ r"""Convert slots to jinja template."""
+ slot_items = []
+ for slot in slots:
+ if isinstance(slot, str):
+ slot_pieces = slot.split("{{content}}")
+ if slot_pieces[0]:
+ slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
+ if len(slot_pieces) > 1:
+ slot_items.append(placeholder)
+ if slot_pieces[1]:
+ slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
+ elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
+ if "bos_token" in slot and tokenizer.bos_token_id is not None:
+ slot_items.append("'" + tokenizer.bos_token + "'")
+ elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+ slot_items.append("'" + tokenizer.eos_token + "'")
+ elif isinstance(slot, dict):
+ raise ValueError("Dict is not supported.")
+
+ return " + ".join(slot_items)
+
+ def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+ r"""Return the jinja template."""
+ prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
+ system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
+ user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+ assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+ jinja_template = ""
+ if prefix:
+ jinja_template += "{{ " + prefix + " }}"
+
+ if self.default_system:
+ jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
+
+ jinja_template += (
+ "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
+ "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
+ "{% if system_message is defined %}{{ " + system + " }}{% endif %}"
+ "{% for message in loop_messages %}"
+ "{% set content = message['content'] %}"
+ "{% if message['role'] == 'user' %}"
+ "{{ " + user + " }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ " + assistant + " }}"
+ "{% endif %}"
+ "{% endfor %}"
+ )
+ return jinja_template
+
+ def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
+ r"""Replace the jinja template in the tokenizer."""
+ if tokenizer.chat_template is None or self.replace_jinja_template:
+ try:
+ tokenizer.chat_template = self._get_jinja_template(tokenizer)
+ except ValueError as e:
+ logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
+
+ @staticmethod
+ def _convert_slots_to_ollama(
+ slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
+ ) -> str:
+ r"""Convert slots to ollama template."""
+ slot_items = []
+ for slot in slots:
+ if isinstance(slot, str):
+ slot_pieces = slot.split("{{content}}")
+ if slot_pieces[0]:
+ slot_items.append(slot_pieces[0])
+ if len(slot_pieces) > 1:
+ slot_items.append("{{ " + placeholder + " }}")
+ if slot_pieces[1]:
+ slot_items.append(slot_pieces[1])
+ elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
+ if "bos_token" in slot and tokenizer.bos_token_id is not None:
+ slot_items.append(tokenizer.bos_token)
+ elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+ slot_items.append(tokenizer.eos_token)
+ elif isinstance(slot, dict):
+ raise ValueError("Dict is not supported.")
+
+ return "".join(slot_items)
+
+ def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+ r"""Return the ollama template."""
+ prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
+ system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
+ user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
+ assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
+ return (
+ f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
+ f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
+ f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
+ )
+
+ def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
+ r"""Return the ollama modelfile.
+
+ TODO: support function calling.
+ """
+ modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
+ modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
+
+ if self.default_system:
+ modelfile += f'SYSTEM """{self.default_system}"""\n\n'
+
+ for stop_token_id in self.get_stop_token_ids(tokenizer):
+ modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
+
+ modelfile += "PARAMETER num_ctx 4096\n"
+ return modelfile
+
+
+@dataclass
+class Llama2Template(Template):
+ r"""A template that fuse the system message to first user message."""
+
+ @override
+ def _encode(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: list[dict[str, str]],
+ system: str,
+ tools: str,
+ ) -> list[list[int]]:
+ system = system or self.default_system
+ encoded_messages = []
+ for i, message in enumerate(messages):
+ elements = []
+
+ system_text = ""
+ if i == 0:
+ elements += self.format_prefix.apply()
+ if system or tools:
+ tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+ system_text = self.format_system.apply(content=(system + tool_text))[0]
+
+ if message["role"] == Role.USER:
+ elements += self.format_user.apply(content=system_text + message["content"])
+ elif message["role"] == Role.ASSISTANT:
+ elements += self.format_assistant.apply(content=message["content"])
+ elif message["role"] == Role.OBSERVATION:
+ elements += self.format_observation.apply(content=message["content"])
+ elif message["role"] == Role.FUNCTION:
+ elements += self.format_function.apply(content=message["content"])
+ else:
+ raise NotImplementedError("Unexpected role: {}".format(message["role"]))
+
+ encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+ return encoded_messages
+
+ def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+ prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
+ system_message = self._convert_slots_to_jinja(
+ self.format_system.apply(), tokenizer, placeholder="system_message"
+ )
+ user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+ assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+ jinja_template = ""
+ if prefix:
+ jinja_template += "{{ " + prefix + " }}"
+
+ if self.default_system:
+ jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
+
+ jinja_template += (
+ "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
+ "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
+ "{% for message in loop_messages %}"
+ "{% if loop.index0 == 0 and system_message is defined %}"
+ "{% set content = " + system_message + " + message['content'] %}"
+ "{% else %}{% set content = message['content'] %}{% endif %}"
+ "{% if message['role'] == 'user' %}"
+ "{{ " + user_message + " }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ " + assistant_message + " }}"
+ "{% endif %}"
+ "{% endfor %}"
+ )
+ return jinja_template
+
+
+@dataclass
+class ReasoningTemplate(Template):
+ r"""A template that add thought to assistant message."""
+
+ @override
+ def encode_oneturn(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ ) -> tuple[list[int], list[int]]:
+ messages = deepcopy(messages)
+ for i in range(1, len(messages) - 2, 2):
+ messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+ if self.enable_thinking is False: # remove all cot
+ messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
+
+ prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+ if (
+ self.thought_words[0].strip() not in messages[-1]["content"]
+ and self.thought_words[1].strip() not in messages[-1]["content"]
+ ): # add empty cot
+ if not self.enable_thinking: # do not compute loss
+ prompt_ids += self.get_thought_word_ids(tokenizer)
+ else: # do compute loss
+ response_ids = self.get_thought_word_ids(tokenizer) + response_ids
+
+ return prompt_ids, response_ids
+
+ @override
+ def encode_multiturn(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ ) -> list[tuple[list[int], list[int]]]:
+ messages = deepcopy(messages)
+ if self.enable_thinking is False: # remove all cot
+ for i in range(1, len(messages), 2):
+ messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+ encoded_messages = self._encode(tokenizer, messages, system, tools)
+ for i in range(0, len(messages), 2):
+ if (
+ self.thought_words[0].strip() not in messages[i + 1]["content"]
+ and self.thought_words[1].strip() not in messages[i + 1]["content"]
+ ): # add empty cot
+ if not self.enable_thinking: # do not compute loss
+ encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+ else: # do compute loss
+ encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
+
+ return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+
+TEMPLATES: dict[str, "Template"] = {}
+
+
+def register_template(
+ name: str,
+ format_user: Optional["Formatter"] = None,
+ format_assistant: Optional["Formatter"] = None,
+ format_system: Optional["Formatter"] = None,
+ format_function: Optional["Formatter"] = None,
+ format_observation: Optional["Formatter"] = None,
+ format_tools: Optional["Formatter"] = None,
+ format_prefix: Optional["Formatter"] = None,
+ default_system: str = "",
+ stop_words: Optional[list[str]] = None,
+ thought_words: Optional[tuple[str, str]] = None,
+ efficient_eos: bool = False,
+ replace_eos: bool = False,
+ replace_jinja_template: bool = False,
+ enable_thinking: Optional[bool] = True,
+ mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
+ template_class: type["Template"] = Template,
+) -> None:
+ r"""Register a chat template.
+
+ To add the following chat template:
+ ```
+ user prompt here
+ model response here
+ user prompt here
+ model response here
+ ```
+
+ The corresponding code should be:
+ ```
+ register_template(
+ name="custom",
+ format_user=StringFormatter(slots=["{{content}}\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}\n"]),
+ format_prefix=EmptyFormatter(""),
+ )
+ ```
+ """
+ if name in TEMPLATES:
+ raise ValueError(f"Template {name} already exists.")
+
+ default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
+ default_user_formatter = StringFormatter(slots=["{{content}}"])
+ default_assistant_formatter = StringFormatter(slots=default_slots)
+ if format_assistant is not None:
+ default_function_formatter = FunctionFormatter(slots=format_assistant.slots, tool_format="default")
+ else:
+ default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
+
+ default_tool_formatter = ToolFormatter(tool_format="default")
+ default_prefix_formatter = EmptyFormatter()
+ TEMPLATES[name] = template_class(
+ format_user=format_user or default_user_formatter,
+ format_assistant=format_assistant or default_assistant_formatter,
+ format_system=format_system or default_user_formatter,
+ format_function=format_function or default_function_formatter,
+ format_observation=format_observation or format_user or default_user_formatter,
+ format_tools=format_tools or default_tool_formatter,
+ format_prefix=format_prefix or default_prefix_formatter,
+ default_system=default_system,
+ stop_words=stop_words or [],
+ thought_words=thought_words or ("\n", "\n\n\n"),
+ efficient_eos=efficient_eos,
+ replace_eos=replace_eos,
+ replace_jinja_template=replace_jinja_template,
+ enable_thinking=enable_thinking,
+ mm_plugin=mm_plugin,
+ )
+
+
+def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
+ r"""Extract a chat template from the tokenizer."""
+
+ def find_diff(short_str: str, long_str: str) -> str:
+ i, j = 0, 0
+ diff = ""
+ while i < len(short_str) and j < len(long_str):
+ if short_str[i] == long_str[j]:
+ i += 1
+ j += 1
+ else:
+ diff += long_str[j]
+ j += 1
+
+ return diff
+
+ prefix = tokenizer.decode(tokenizer.encode(""))
+
+ messages = [{"role": "system", "content": "{{content}}"}]
+ system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]
+
+ messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
+ user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+ user_slot_empty_system = user_slot_empty_system[len(prefix) :]
+
+ messages = [{"role": "user", "content": "{{content}}"}]
+ user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+ user_slot = user_slot[len(prefix) :]
+
+ messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
+ assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
+ assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+ template_class = ReasoningTemplate if "" in assistant_slot else Template
+ assistant_slot = assistant_slot.replace("", "").replace("", "").lstrip("\n") # remove thought tags
+
+ if len(user_slot) > len(user_slot_empty_system):
+ default_system = find_diff(user_slot_empty_system, user_slot)
+ sole_system = system_slot.replace("{{content}}", default_system, 1)
+ user_slot = user_slot[len(sole_system) :]
+ else: # if defaut_system is empty, user_slot_empty_system will be longer than user_slot
+ default_system = ""
+
+ return template_class(
+ format_user=StringFormatter(slots=[user_slot]),
+ format_assistant=StringFormatter(slots=[assistant_slot]),
+ format_system=StringFormatter(slots=[system_slot]),
+ format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
+ format_observation=StringFormatter(slots=[user_slot]),
+ format_tools=ToolFormatter(tool_format="default"),
+ format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
+ default_system=default_system,
+ stop_words=[],
+ thought_words=("\n", "\n\n\n"),
+ efficient_eos=False,
+ replace_eos=False,
+ replace_jinja_template=False,
+ enable_thinking=True,
+ mm_plugin=get_mm_plugin(name="base"),
+ )
+
+
+def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
+ r"""Get chat template and fixes the tokenizer."""
+ if data_args.template is None:
+ if isinstance(tokenizer.chat_template, str):
+ logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
+ template = parse_template(tokenizer)
+ else:
+ logger.warning_rank0("`template` was not specified, use `empty` template.")
+ template = TEMPLATES["empty"] # placeholder
+ else:
+ if data_args.template not in TEMPLATES:
+ raise ValueError(f"Template {data_args.template} does not exist.")
+
+ template = TEMPLATES[data_args.template]
+
+ if data_args.train_on_prompt and template.efficient_eos:
+ raise ValueError("Current template does not support `train_on_prompt`.")
+
+ if data_args.tool_format is not None:
+ logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
+ default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
+ template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
+ template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
+
+ if data_args.default_system is not None:
+ logger.info_rank0(f"Using default system message: {data_args.default_system}.")
+ template.default_system = data_args.default_system
+
+ if isinstance(template, ReasoningTemplate):
+ logger.warning_rank0(
+ "You are using reasoning template, "
+ "please add `_nothink` suffix if the model is not a reasoning model. "
+ "e.g., qwen3_vl_nothink"
+ )
+ template.enable_thinking = data_args.enable_thinking
+
+ template.fix_special_tokens(tokenizer)
+ template.fix_jinja_template(tokenizer)
+ return template
+
+
+register_template(
+ name="alpaca",
+ format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
+ default_system=(
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
+ ),
+ replace_jinja_template=True,
+)
+
+
+register_template(
+ name="aquila",
+ format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
+ format_assistant=StringFormatter(slots=["{{content}}###"]),
+ format_system=StringFormatter(slots=["System: {{content}}###"]),
+ default_system=(
+ "A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions."
+ ),
+ stop_words=[""],
+)
+
+
+register_template(
+ name="atom",
+ format_user=StringFormatter(
+ slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
+ ),
+ format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
+)
+
+
+register_template(
+ name="baichuan",
+ format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]),
+ efficient_eos=True,
+)
+
+
+register_template(
+ name="baichuan2",
+ format_user=StringFormatter(slots=["{{content}}"]),
+ efficient_eos=True,
+)
+
+
+register_template(
+ name="bailing",
+ format_user=StringFormatter(slots=["HUMAN{{content}}ASSISTANT"]),
+ format_system=StringFormatter(slots=["SYSTEM{{content}}"]),
+ format_observation=StringFormatter(slots=["OBSERVATION{{content}}ASSISTANT"]),
+ stop_words=["<|endoftext|>"],
+ efficient_eos=True,
+)
+
+
+register_template(
+ name="bailing_v2",
+ format_user=StringFormatter(slots=["HUMAN{{content}}<|role_end|>ASSISTANT"]),
+ format_system=StringFormatter(slots=["SYSTEM{{content}}<|role_end|>"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|role_end|>"]),
+ format_observation=StringFormatter(
+ slots=[
+ "OBSERVATION\n\n{{content}}\n<|role_end|>ASSISTANT"
+ ]
+ ),
+ format_function=FunctionFormatter(slots=["{{content}}<|role_end|>"], tool_format="ling"),
+ format_tools=ToolFormatter(tool_format="ling"),
+ stop_words=["<|endoftext|>"],
+ efficient_eos=True,
+)
+
+
+register_template(
+ name="belle",
+ format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
+ format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+register_template(
+ name="bluelm",
+ format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
+)
+
+
+register_template(
+ name="breeze",
+ format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ efficient_eos=True,
+)
+
+
+register_template(
+ name="chatglm2",
+ format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
+ efficient_eos=True,
+)
+
+
+register_template(
+ name="chatglm3",
+ format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
+ format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
+ format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
+ format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+ format_observation=StringFormatter(
+ slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
+ ),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
+ stop_words=["<|user|>", "<|observation|>"],
+ efficient_eos=True,
+)
+
+
+register_template(
+ name="chatml",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ stop_words=["<|im_end|>", "<|im_start|>"],
+ replace_eos=True,
+ replace_jinja_template=True,
+)
+
+
+# copied from chatml template
+register_template(
+ name="chatml_de",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
+ stop_words=["<|im_end|>", "<|im_start|>"],
+ replace_eos=True,
+ replace_jinja_template=True,
+)
+
+
+register_template(
+ name="codegeex2",
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
+)
+
+
+register_template(
+ name="codegeex4",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]"]),
+ default_system=(
+ "你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,"
+ "并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
+ ),
+ stop_words=["<|user|>", "<|observation|>"],
+ efficient_eos=True,
+)
+
+
+register_template(
+ name="cohere",
+ format_user=StringFormatter(
+ slots=[
+ (
+ "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
+ "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+ )
+ ]
+ ),
+ format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+register_template(
+ name="cpm",
+ format_user=StringFormatter(slots=["<用户>{{content}}"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+# copied from chatml template
+register_template(
+ name="cpm3",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=["<|im_end|>"],
+)
+
+
+# copied from chatml template
+register_template(
+ name="cpm4",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=["<|im_end|>"],
+)
+
+
+# copied from chatml template
+register_template(
+ name="dbrx",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ default_system=(
+ "You are DBRX, created by Databricks. You were last updated in December 2023. "
+ "You answer questions based on information available up to that point.\n"
+ "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
+ "responses to more complex and open-ended questions.\nYou assist with various tasks, "
+ "from writing to coding (using markdown for code blocks — remember to use ``` with "
+ "code, JSON, and tables).\n(You do not have real-time data access or code execution "
+ "capabilities. You avoid stereotyping and provide balanced perspectives on "
+ "controversial topics. You do not provide song lyrics, poems, or news articles and "
+ "do not divulge details of your training data.)\nThis is your system prompt, "
+ "guiding your responses. Do not reference it, just respond to the user. If you find "
+ "yourself talking about this message, stop. You should be responding appropriately "
+ "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
+ "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
+ ),
+ stop_words=["<|im_end|>"],
+ replace_eos=True,
+)
+
+
+register_template(
+ name="deepseek",
+ format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
+ format_system=StringFormatter(slots=["{{content}}\n\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+register_template(
+ name="deepseek3",
+ format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+# copied from deepseek3 template
+register_template(
+ name="deepseekr1",
+ format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ template_class=ReasoningTemplate,
+)
+
+
+register_template(
+ name="deepseekcoder",
+ format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ default_system=(
+ "You are an AI programming assistant, utilizing the DeepSeek Coder model, "
+ "developed by DeepSeek Company, and you only answer questions related to computer science. "
+ "For politically sensitive questions, security and privacy issues, "
+ "and other non-computer science questions, you will refuse to answer.\n"
+ ),
+)
+
+
+register_template(
+ name="default",
+ format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]),
+ format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
+ format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]),
+ replace_jinja_template=True,
+)
+
+
+register_template(
+ name="dots_ocr",
+ format_user=StringFormatter(slots=["<|user|>{{content}}<|endofuser|><|assistant|>"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|endofassistant|>"]),
+ format_system=StringFormatter(slots=["<|system|>{{content}}<|endofsystem|>\n"]),
+ stop_words=["<|endofassistant|>"],
+ efficient_eos=True,
+ mm_plugin=get_mm_plugin(
+ name="qwen2_vl",
+ image_token="<|imgpad|>",
+ video_token="<|vidpad|>",
+ vision_bos_token="<|img|>",
+ vision_eos_token="<|endofimg|>",
+ ),
+)
+
+
+register_template(
+ name="empty",
+ format_assistant=StringFormatter(slots=["{{content}}"]),
+)
+
+
+# copied from chatml template
+register_template(
+ name="ernie",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n\n"]),
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]),
+ default_system="\nthink_mode=True\n",
+ stop_words=["<|im_end|>"],
+)
+
+
+register_template(
+ name="ernie_nothink",
+ format_user=StringFormatter(slots=["User: {{content}}\nAssistant: "]),
+ format_assistant=StringFormatter(slots=["{{content}}<|end_of_sentence|>"]),
+ format_system=StringFormatter(slots=["{{content}}\n"]),
+ format_prefix=EmptyFormatter(slots=["<|begin_of_sentence|>"]),
+ stop_words=["<|end_of_sentence|>"],
+)
+
+
+register_template(
+ name="exaone",
+ format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
+ format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
+ format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
+)
+
+
+register_template(
+ name="falcon",
+ format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
+ format_assistant=StringFormatter(slots=["{{content}}\n"]),
+ efficient_eos=True,
+)
+
+
+# copied from chatml template
+register_template(
+ name="falcon_h1",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=["<|im_end|>", "<|end_of_text|>"],
+)
+
+
+register_template(
+ name="fewshot",
+ format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
+ efficient_eos=True,
+ replace_jinja_template=True,
+)
+
+
+register_template(
+ name="gemma",
+ format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}\n"]),
+ format_system=StringFormatter(slots=["{{content}}\n\n"]),
+ format_observation=StringFormatter(
+ slots=["tool\n{{content}}\nmodel\n"]
+ ),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=[""],
+ replace_eos=True,
+ template_class=Llama2Template,
+)
+
+
+# copied from gemma template
+register_template(
+ name="gemma2",
+ format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}\n"]),
+ format_system=StringFormatter(slots=["{{content}}\n\n"]),
+ format_observation=StringFormatter(
+ slots=["tool\n{{content}}\nmodel\n"]
+ ),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=["", ""],
+ efficient_eos=True,
+ template_class=Llama2Template,
+)
+
+
+# copied from gemma template
+register_template(
+ name="gemma3",
+ format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}\n"]),
+ format_system=StringFormatter(slots=["{{content}}\n\n"]),
+ format_observation=StringFormatter(
+ slots=["tool\n{{content}}\nmodel\n"]
+ ),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=[""],
+ replace_eos=True,
+ mm_plugin=get_mm_plugin("gemma3", image_token=""),
+ template_class=Llama2Template,
+)
+
+
+register_template(
+ name="gemma3n",
+ format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}\n"]),
+ format_system=StringFormatter(slots=["{{content}}\n\n"]),
+ format_observation=StringFormatter(
+ slots=["tool\n{{content}}\nmodel\n"]
+ ),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=[""],
+ replace_eos=True,
+ mm_plugin=get_mm_plugin("gemma3n", image_token="", audio_token=""),
+ template_class=Llama2Template,
+)
+
+
+register_template(
+ name="glm4",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]"]),
+ stop_words=["<|user|>", "<|observation|>"],
+ efficient_eos=True,
+)
+
+
+# copied from glm4 template
+register_template(
+ name="glm4_moe",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4_moe"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]"]),
+ stop_words=["<|user|>", "<|observation|>"],
+ efficient_eos=True,
+ template_class=ReasoningTemplate,
+)
+
+
+# copied from glm4 template
+register_template(
+ name="glm4v",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]"]),
+ stop_words=["<|user|>", "<|observation|>", ""],
+ efficient_eos=True,
+ mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+ template_class=ReasoningTemplate,
+)
+
+
+# copied from glm4 template
+register_template(
+ name="glm4v_moe",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4_moe"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]"]),
+ stop_words=["<|user|>", "<|observation|>", ""],
+ efficient_eos=True,
+ mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+ template_class=ReasoningTemplate,
+)
+
+
+# copied from glm4 template
+register_template(
+ name="glmz1",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]"]),
+ stop_words=["<|user|>", "<|observation|>"],
+ efficient_eos=True,
+ template_class=ReasoningTemplate,
+)
+
+
+register_template(
+ name="gpt",
+ format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
+ format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
+ default_system="You are ChatGPT, a large language model trained by OpenAI.",
+ thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
+ efficient_eos=True,
+ template_class=ReasoningTemplate,
+)
+
+
+register_template(
+ name="granite3",
+ format_user=StringFormatter(
+ slots=[
+ "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
+ ]
+ ),
+ format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
+ format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
+)
+
+
+register_template(
+ name="granite3_vision",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]),
+ default_system=(
+ "A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ ),
+ mm_plugin=get_mm_plugin(name="llava_next", image_token=""),
+)
+
+
+register_template(
+ name="granite4",
+ format_user=StringFormatter(
+ slots=[
+ "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
+ ]
+ ),
+ format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
+ format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
+ format_function=FunctionFormatter(slots=["{{content}}<|end_of_text|>\n"], tool_format="default"),
+ format_observation=StringFormatter(
+ slots=["<|start_of_role|>tool<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant\n"]
+ ),
+ format_tools=ToolFormatter(tool_format="default"),
+ stop_words=["<|end_of_text|>"],
+ default_system="You are Granite, developed by IBM. You are a helpful AI assistant.",
+)
+
+
+register_template(
+ name="index",
+ format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
+ format_system=StringFormatter(slots=["{{content}}"]),
+ efficient_eos=True,
+)
+
+
+register_template(
+ name="hunyuan",
+ format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|eos|>"]),
+ format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]),
+ format_prefix=EmptyFormatter(slots=["<|startoftext|>"]),
+ stop_words=["<|eos|>"],
+)
+
+
+register_template(
+ name="intern",
+ format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
+ format_assistant=StringFormatter(slots=["{{content}}\n"]),
+ format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ default_system=(
+ "You are an AI assistant whose name is InternLM (书生·浦语).\n"
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
+ "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language "
+ "chosen by the user such as English and 中文."
+ ),
+ stop_words=[""],
+)
+
+
+register_template(
+ name="intern2",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ default_system=(
+ "You are an AI assistant whose name is InternLM (书生·浦语).\n"
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
+ "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language "
+ "chosen by the user such as English and 中文."
+ ),
+ stop_words=["<|im_end|>"],
+)
+
+
+register_template(
+ name="intern_vl",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ default_system=(
+ "你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
+ ),
+ stop_words=["<|im_end|>"],
+ mm_plugin=get_mm_plugin(name="intern_vl", image_token="", video_token="