lincarl committed (verified)
Commit b67858f · Parent: f57eada

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. README.md +3 -9
  3. api.py +33 -0
  4. llamafactory.egg-info/PKG-INFO +1124 -0
  5. llamafactory.egg-info/SOURCES.txt +178 -0
  6. llamafactory.egg-info/dependency_links.txt +1 -0
  7. llamafactory.egg-info/entry_points.txt +3 -0
  8. llamafactory.egg-info/requires.txt +125 -0
  9. llamafactory.egg-info/top_level.txt +1 -0
  10. llamafactory/__init__.py +31 -0
  11. llamafactory/__pycache__/__init__.cpython-312.pyc +0 -0
  12. llamafactory/__pycache__/cli.cpython-312.pyc +0 -0
  13. llamafactory/__pycache__/launcher.cpython-312.pyc +0 -0
  14. llamafactory/api/__init__.py +0 -0
  15. llamafactory/api/app.py +133 -0
  16. llamafactory/api/chat.py +291 -0
  17. llamafactory/api/common.py +96 -0
  18. llamafactory/api/protocol.py +157 -0
  19. llamafactory/chat/__init__.py +19 -0
  20. llamafactory/chat/__pycache__/__init__.cpython-312.pyc +0 -0
  21. llamafactory/chat/__pycache__/base_engine.cpython-312.pyc +0 -0
  22. llamafactory/chat/__pycache__/chat_model.cpython-312.pyc +0 -0
  23. llamafactory/chat/base_engine.py +98 -0
  24. llamafactory/chat/chat_model.py +210 -0
  25. llamafactory/chat/hf_engine.py +412 -0
  26. llamafactory/chat/kt_engine.py +284 -0
  27. llamafactory/chat/sglang_engine.py +289 -0
  28. llamafactory/chat/vllm_engine.py +263 -0
  29. llamafactory/cli.py +31 -0
  30. llamafactory/data/__init__.py +37 -0
  31. llamafactory/data/__pycache__/__init__.cpython-312.pyc +0 -0
  32. llamafactory/data/__pycache__/collator.cpython-312.pyc +0 -0
  33. llamafactory/data/__pycache__/converter.cpython-312.pyc +0 -0
  34. llamafactory/data/__pycache__/data_utils.cpython-312.pyc +0 -0
  35. llamafactory/data/__pycache__/formatter.cpython-312.pyc +0 -0
  36. llamafactory/data/__pycache__/loader.cpython-312.pyc +0 -0
  37. llamafactory/data/__pycache__/mm_plugin.cpython-312.pyc +0 -0
  38. llamafactory/data/__pycache__/parser.cpython-312.pyc +0 -0
  39. llamafactory/data/__pycache__/template.cpython-312.pyc +0 -0
  40. llamafactory/data/__pycache__/tool_utils.cpython-312.pyc +0 -0
  41. llamafactory/data/collator.py +331 -0
  42. llamafactory/data/converter.py +425 -0
  43. llamafactory/data/data_utils.py +190 -0
  44. llamafactory/data/formatter.py +145 -0
  45. llamafactory/data/loader.py +334 -0
  46. llamafactory/data/mm_plugin.py +2082 -0
  47. llamafactory/data/parser.py +149 -0
  48. llamafactory/data/processor/__init__.py +31 -0
  49. llamafactory/data/processor/__pycache__/__init__.cpython-312.pyc +0 -0
  50. llamafactory/data/processor/__pycache__/feedback.cpython-312.pyc +0 -0
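For reference, commits titled "Upload folder using huggingface_hub" are normally produced by the `upload_folder` helper of the `huggingface_hub` library. A minimal sketch of such an upload (the repo id, folder path, and `repo_type="space"` are placeholders inferred from the Gradio front matter in the README diff below, not values taken from this commit):

```python
# Hypothetical sketch of how an "Upload folder using huggingface_hub" commit is made.
# The repo id, folder path, and repo_type are placeholders, not taken from this commit.
from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_folder(
    folder_path="./LLaMA-Factory",  # local folder to push (placeholder)
    repo_id="user/app",             # placeholder repo id
    repo_type="space",              # the Gradio front matter suggests a Space
    commit_message="Upload folder using huggingface_hub",
)
```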
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ llamafactory/extras/__pycache__/constants.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: App
- emoji: 🏢
- colorFrom: blue
- colorTo: purple
+ title: app
+ app_file: webui.py
  sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
+ sdk_version: 5.45.0
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
api.py ADDED
@@ -0,0 +1,33 @@
+ # Copyright 2025 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+
+ import uvicorn
+
+ from llamafactory.api.app import create_app
+ from llamafactory.chat import ChatModel
+
+
+ def main():
+     chat_model = ChatModel()
+     app = create_app(chat_model)
+     api_host = os.getenv("API_HOST", "0.0.0.0")
+     api_port = int(os.getenv("API_PORT", "8000"))
+     print(f"Visit http://localhost:{api_port}/docs for API document.")
+     uvicorn.run(app, host=api_host, port=api_port)
+
+
+ if __name__ == "__main__":
+     main()
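`api.py` builds a `ChatModel`, wraps it with `create_app`, and serves it with uvicorn; per the bundled README (PKG-INFO) below, this exposes an OpenAI-style API. A minimal client sketch, assuming the default host/port above, the standard `/v1` route prefix, and the `openai` Python package; the model name and API key are placeholders:

```python
# Client sketch for the OpenAI-compatible server launched by api.py.
# Assumes the defaults above (port 8000) and the `openai` package; the model
# name and API key below are placeholders, not values from this repository.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",  # placeholder; use the model the server was launched with
    messages=[{"role": "user", "content": "Hello! What can you do?"}],
)
print(response.choices[0].message.content)
```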
llamafactory.egg-info/PKG-INFO ADDED
@@ -0,0 +1,1124 @@
1
+ Metadata-Version: 2.4
2
+ Name: llamafactory
3
+ Version: 0.9.4.dev0
4
+ Summary: Unified Efficient Fine-Tuning of 100+ LLMs
5
+ Home-page: https://github.com/hiyouga/LLaMA-Factory
6
+ Author: hiyouga
7
+ Author-email: [email protected]
8
+ License: Apache 2.0 License
9
+ Keywords: AI,LLM,GPT,ChatGPT,Llama,Transformer,DeepSeek,Pytorch
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Intended Audience :: Education
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: Apache Software License
15
+ Classifier: Operating System :: OS Independent
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.9
18
+ Classifier: Programming Language :: Python :: 3.10
19
+ Classifier: Programming Language :: Python :: 3.11
20
+ Classifier: Programming Language :: Python :: 3.12
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Requires-Python: >=3.9.0
23
+ Description-Content-Type: text/markdown
24
+ License-File: LICENSE
25
+ Requires-Dist: transformers!=4.52.0,<=4.56.2,>=4.49.0; python_version < "3.10"
26
+ Requires-Dist: transformers!=4.52.0,!=4.57.0,<=4.57.1,>=4.49.0; python_version >= "3.10"
27
+ Requires-Dist: datasets<=4.0.0,>=2.16.0
28
+ Requires-Dist: accelerate<=1.11.0,>=1.3.0
29
+ Requires-Dist: peft<=0.17.1,>=0.14.0
30
+ Requires-Dist: trl<=0.9.6,>=0.8.6
31
+ Requires-Dist: gradio<=5.45.0,>=4.38.0
32
+ Requires-Dist: matplotlib>=3.7.0
33
+ Requires-Dist: tyro<0.9.0
34
+ Requires-Dist: einops
35
+ Requires-Dist: numpy<2.0.0
36
+ Requires-Dist: pandas>=2.0.0
37
+ Requires-Dist: scipy
38
+ Requires-Dist: sentencepiece
39
+ Requires-Dist: tiktoken
40
+ Requires-Dist: modelscope>=1.14.0
41
+ Requires-Dist: hf-transfer
42
+ Requires-Dist: safetensors<=0.5.3
43
+ Requires-Dist: fire
44
+ Requires-Dist: omegaconf
45
+ Requires-Dist: packaging
46
+ Requires-Dist: protobuf
47
+ Requires-Dist: pyyaml
48
+ Requires-Dist: pydantic<=2.10.6
49
+ Requires-Dist: uvicorn
50
+ Requires-Dist: fastapi
51
+ Requires-Dist: sse-starlette
52
+ Requires-Dist: av
53
+ Requires-Dist: librosa
54
+ Requires-Dist: propcache!=0.4.0
55
+ Provides-Extra: torch
56
+ Requires-Dist: torch>=2.0.0; extra == "torch"
57
+ Requires-Dist: torchvision>=0.15.0; extra == "torch"
58
+ Provides-Extra: torch-npu
59
+ Requires-Dist: torch==2.7.1; extra == "torch-npu"
60
+ Requires-Dist: torch-npu==2.7.1; extra == "torch-npu"
61
+ Requires-Dist: torchvision==0.22.1; extra == "torch-npu"
62
+ Requires-Dist: decorator; extra == "torch-npu"
63
+ Provides-Extra: metrics
64
+ Requires-Dist: nltk; extra == "metrics"
65
+ Requires-Dist: jieba; extra == "metrics"
66
+ Requires-Dist: rouge-chinese; extra == "metrics"
67
+ Provides-Extra: deepspeed
68
+ Requires-Dist: deepspeed<=0.16.9,>=0.10.0; extra == "deepspeed"
69
+ Provides-Extra: liger-kernel
70
+ Requires-Dist: liger-kernel>=0.5.5; extra == "liger-kernel"
71
+ Provides-Extra: bitsandbytes
72
+ Requires-Dist: bitsandbytes>=0.39.0; extra == "bitsandbytes"
73
+ Provides-Extra: hqq
74
+ Requires-Dist: hqq; extra == "hqq"
75
+ Provides-Extra: eetq
76
+ Requires-Dist: eetq; extra == "eetq"
77
+ Provides-Extra: gptq
78
+ Requires-Dist: optimum>=1.24.0; extra == "gptq"
79
+ Requires-Dist: gptqmodel>=2.0.0; extra == "gptq"
80
+ Provides-Extra: aqlm
81
+ Requires-Dist: aqlm[gpu]>=1.1.0; extra == "aqlm"
82
+ Provides-Extra: vllm
83
+ Requires-Dist: vllm<=0.11.0,>=0.4.3; extra == "vllm"
84
+ Provides-Extra: sglang
85
+ Requires-Dist: sglang[srt]>=0.4.5; extra == "sglang"
86
+ Requires-Dist: transformers==4.51.1; extra == "sglang"
87
+ Provides-Extra: galore
88
+ Requires-Dist: galore-torch; extra == "galore"
89
+ Provides-Extra: apollo
90
+ Requires-Dist: apollo-torch; extra == "apollo"
91
+ Provides-Extra: badam
92
+ Requires-Dist: badam>=1.2.1; extra == "badam"
93
+ Provides-Extra: adam-mini
94
+ Requires-Dist: adam-mini; extra == "adam-mini"
95
+ Provides-Extra: minicpm-v
96
+ Requires-Dist: soundfile; extra == "minicpm-v"
97
+ Requires-Dist: torchvision; extra == "minicpm-v"
98
+ Requires-Dist: torchaudio; extra == "minicpm-v"
99
+ Requires-Dist: vector_quantize_pytorch; extra == "minicpm-v"
100
+ Requires-Dist: vocos; extra == "minicpm-v"
101
+ Requires-Dist: msgpack; extra == "minicpm-v"
102
+ Requires-Dist: referencing; extra == "minicpm-v"
103
+ Requires-Dist: jsonschema_specifications; extra == "minicpm-v"
104
+ Provides-Extra: openmind
105
+ Requires-Dist: openmind; extra == "openmind"
106
+ Provides-Extra: swanlab
107
+ Requires-Dist: swanlab; extra == "swanlab"
108
+ Provides-Extra: fp8
109
+ Requires-Dist: torchao>=0.8.0; extra == "fp8"
110
+ Requires-Dist: accelerate>=1.10.0; extra == "fp8"
111
+ Provides-Extra: fp8-te
112
+ Requires-Dist: transformer_engine[pytorch]>=2.0.0; extra == "fp8-te"
113
+ Requires-Dist: accelerate>=1.10.0; extra == "fp8-te"
114
+ Provides-Extra: fp8-all
115
+ Requires-Dist: torchao>=0.8.0; extra == "fp8-all"
116
+ Requires-Dist: transformer_engine[pytorch]>=2.0.0; extra == "fp8-all"
117
+ Requires-Dist: accelerate>=1.10.0; extra == "fp8-all"
118
+ Provides-Extra: dev
119
+ Requires-Dist: pre-commit; extra == "dev"
120
+ Requires-Dist: ruff; extra == "dev"
121
+ Requires-Dist: pytest; extra == "dev"
122
+ Requires-Dist: build; extra == "dev"
123
+ Dynamic: author
124
+ Dynamic: author-email
125
+ Dynamic: classifier
126
+ Dynamic: description
127
+ Dynamic: description-content-type
128
+ Dynamic: home-page
129
+ Dynamic: keywords
130
+ Dynamic: license
131
+ Dynamic: license-file
132
+ Dynamic: provides-extra
133
+ Dynamic: requires-dist
134
+ Dynamic: requires-python
135
+ Dynamic: summary
136
+
137
+ ![# LLaMA Factory](assets/logo.png)
138
+
139
+ [![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
140
+ [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
141
+ [![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
142
+ [![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
143
+ [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
144
+ [![Citation](https://img.shields.io/badge/citation-1000+-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
145
+ [![Docker Pulls](https://img.shields.io/docker/pulls/hiyouga/llamafactory)](https://hub.docker.com/r/hiyouga/llamafactory/tags)
146
+
147
+ [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
148
+ [![Discord](assets/thirdparty/discord.svg)](https://discord.gg/rKfvV9r9FK)
149
+ [![WeChat](https://img.shields.io/badge/WeChat-User%20Group-blue?logo=wechat)](https://github.com/hiyouga/llamafactory-community)
150
+ [![Blog](https://img.shields.io/badge/Hugo-Official%20Blog-blue?logo=hugo)](https://blog.llamafactory.net/en/)
151
+
152
+ [![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
153
+ [![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
154
+ [![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
155
+ [![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
156
+ [![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
157
+ [![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
158
+ [![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
159
+
160
+ ### Used by [Amazon](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/), [NVIDIA](https://developer.nvidia.com/rtx/ai-toolkit), [Aliyun](https://help.aliyun.com/zh/pai/use-cases/fine-tune-a-llama-3-model-with-llama-factory), etc.
161
+
162
+ <div align="center" markdown="1">
163
+
164
+ ### Supporters ❤️
165
+
166
+ | <div style="text-align: center;"><a href="https://warp.dev/llama-factory"><img alt="Warp sponsorship" width="400" src="assets/sponsors/warp.jpg"></a><br><a href="https://warp.dev/llama-factory" style="font-size:larger;">Warp, the agentic terminal for developers</a><br><a href="https://warp.dev/llama-factory">Available for MacOS, Linux, & Windows</a> | <a href="https://serpapi.com"><img alt="SerpAPI sponsorship" width="250" src="assets/sponsors/serpapi.svg"> </a> |
167
+ | ---- | ---- |
168
+
169
+ ----
170
+
171
+ ### Easily fine-tune 100+ large language models with zero-code [CLI](#quickstart) and [Web UI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
172
+
173
+ ![GitHub Trend](https://trendshift.io/api/badge/repositories/4535)
174
+
175
+ </div>
176
+
177
+ 👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
178
+
179
+ \[ English | [中文](README_zh.md) \]
180
+
181
+ **Fine-tuning a large language model can be as easy as...**
182
+
183
+ https://github.com/user-attachments/assets/3991a3a8-4276-4d30-9cab-4cb0c4b9b99e
184
+
185
+ Start local training:
186
+ - Please refer to [usage](#getting-started)
187
+
188
+ Start cloud training:
189
+ - **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
190
+ - **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
191
+ - **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
192
+ - **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
193
+
194
+ Read technical notes:
195
+ - **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
196
+ - **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
197
+ - **Official Blog**: https://blog.llamafactory.net/en/
198
+ - **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
199
+
200
+ > [!NOTE]
201
+ > Except for the above links, all other websites are unauthorized third-party websites. Please use them with caution.
202
+
203
+ ## Table of Contents
204
+
205
+ - [Features](#features)
206
+ - [Blogs](#blogs)
207
+ - [Changelog](#changelog)
208
+ - [Supported Models](#supported-models)
209
+ - [Supported Training Approaches](#supported-training-approaches)
210
+ - [Provided Datasets](#provided-datasets)
211
+ - [Requirement](#requirement)
212
+ - [Getting Started](#getting-started)
213
+ - [Installation](#installation)
214
+ - [Data Preparation](#data-preparation)
215
+ - [Quickstart](#quickstart)
216
+ - [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
217
+ - [LLaMA Factory Online](#llama-factory-online)
218
+ - [Build Docker](#build-docker)
219
+ - [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
220
+ - [Download from ModelScope Hub](#download-from-modelscope-hub)
221
+ - [Download from Modelers Hub](#download-from-modelers-hub)
222
+ - [Use W&B Logger](#use-wb-logger)
223
+ - [Use SwanLab Logger](#use-swanlab-logger)
224
+ - [Projects using LLaMA Factory](#projects-using-llama-factory)
225
+ - [License](#license)
226
+ - [Citation](#citation)
227
+ - [Acknowledgement](#acknowledgement)
228
+
229
+ ## Features
230
+
231
+ - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
232
+ - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
233
+ - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
234
+ - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
235
+ - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
236
+ - **Wide tasks**: Multi-turn dialogue, tool use, image understanding, visual grounding, video recognition, audio understanding, etc.
237
+ - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
238
+ - **Faster inference**: OpenAI-style API, Gradio UI and CLI with [vLLM worker](https://github.com/vllm-project/vllm) or [SGLang worker](https://github.com/sgl-project/sglang).
239
+
240
+ ### Day-N Support for Fine-Tuning Cutting-Edge Models
241
+
242
+ | Support Date | Model Name |
243
+ | ------------ | -------------------------------------------------------------------- |
244
+ | Day 0 | Qwen3 / Qwen2.5-VL / Gemma 3 / GLM-4.1V / InternLM 3 / MiniCPM-o-2.6 |
245
+ | Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4 |
246
+
247
+ ## Blogs
248
+
249
+ > [!TIP]
250
+ > Now we have a dedicated blog for LLaMA Factory!
251
+ >
252
+ > Website: https://blog.llamafactory.net/en/
253
+
254
+ - 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
255
+ - [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
256
+ - [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
257
+ - [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
258
+ - [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
259
+
260
+ <details><summary>All Blogs</summary>
261
+
262
+ - [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
263
+ - [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
264
+ - [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
265
+ - [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
266
+ - [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
267
+ - [LLaMA Factory: Fine-tuning Llama3 for Role-Playing](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) (Chinese)
268
+
269
+ </details>
270
+
271
+ ## Changelog
272
+
273
+ [25/10/26] We support Megatron-core training backend with [**mcore_adapter**](https://github.com/alibaba/ROLL/tree/main/mcore_adapter). See [PR #9237](https://github.com/hiyouga/LLaMA-Factory/pull/9237) to get started.
274
+
275
+ [25/08/22] We supported **[OFT](https://arxiv.org/abs/2306.07280)** and **[OFTv2](https://arxiv.org/abs/2506.19847)**. See [examples](examples/README.md) for usage.
276
+
277
+ [25/08/20] We supported fine-tuning the **[Intern-S1-mini](https://huggingface.co/internlm/Intern-S1-mini)** models. See [PR #8976](https://github.com/hiyouga/LLaMA-Factory/pull/8976) to get started.
278
+
279
+ [25/08/06] We supported fine-tuning the **[GPT-OSS](https://github.com/openai/gpt-oss)** models. See [PR #8826](https://github.com/hiyouga/LLaMA-Factory/pull/8826) to get started.
280
+
281
+ <details><summary>Full Changelog</summary>
282
+
283
+ [25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model.
284
+
285
+ [25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.
286
+
287
+ [25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR.
288
+
289
+ [25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.
290
+
291
+ [25/04/14] We supported fine-tuning the **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** and **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** models.
292
+
293
+ [25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.
294
+
295
+ [25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
296
+
297
+ [25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference.
298
+
299
+ [25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model.
300
+
301
+ [25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training.
302
+
303
+ [25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.
304
+
305
+ [25/02/05] We supported fine-tuning the **[Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** models on audio understanding tasks.
306
+
307
+ [25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** models.
308
+
309
+ [25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
310
+
311
+ [25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.
312
+
313
+ [25/01/14] We supported fine-tuning the **[InternLM 3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
314
+
315
+ [25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
316
+
317
+ [24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
318
+
319
+ [24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset.
320
+
321
+ [24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
322
+
323
+ [24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
324
+
325
+ [24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
326
+
327
+ [24/08/27] We supported **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
328
+
329
+ [24/08/09] We supported **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
330
+
331
+ [24/07/04] We supported [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
332
+
333
+ [24/06/16] We supported **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
334
+
335
+ [24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
336
+
337
+ [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
338
+
339
+ [24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models; you need to fine-tune them with the `paligemma` template for chat completion.
340
+
341
+ [24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
342
+
343
+ [24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
344
+
345
+ [24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
346
+
347
+ [24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
348
+
349
+ [24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
350
+
351
+ [24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage.
352
+
353
+ [24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
354
+
355
+ [24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See [examples](examples/README.md) for usage.
356
+
357
+ [24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
358
+
359
+ [24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See [examples](examples/README.md) for usage.
360
+
361
+ [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
362
+
363
+ [24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. See [examples](examples/README.md) for usage.
364
+
365
+ [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
366
+
367
+ [24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `use_dora: true` to activate DoRA training.
368
+
369
+ [24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See [examples](examples/README.md) for usage.
370
+
371
+ [24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
372
+
373
+ [24/01/18] We supported **agent tuning** for most models, equipping models with tool-using abilities by fine-tuning with `dataset: glaive_toolcall_en`.
374
+
375
+ [23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
376
+
377
+ [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
378
+
379
+ [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.
380
+
381
+ [23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
382
+
383
+ [23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `shift_attn: true` argument to enable shift short attention.
384
+
385
+ [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [examples](examples/README.md) for usage.
386
+
387
+ [23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `flash_attn: fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
388
+
389
+ [23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `rope_scaling: linear` argument in training and `rope_scaling: dynamic` argument at inference to extrapolate the position embeddings.
390
+
391
+ [23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage.
392
+
393
+ [23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode.
394
+
395
+ [23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
396
+
397
+ [23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
398
+
399
+ [23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
400
+
401
+ [23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
402
+
403
+ [23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
404
+
405
+ [23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). See [examples](examples/README.md) for usage.
406
+
407
+ </details>
408
+
409
+ > [!TIP]
410
+ > If you cannot use the latest features, please pull the latest code and reinstall LLaMA-Factory.
411
+
412
+ ## Supported Models
413
+
414
+ | Model | Model size | Template |
415
+ | ----------------------------------------------------------------- | -------------------------------- | -------------------- |
416
+ | [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
417
+ | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
418
+ | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
419
+ | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
420
+ | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
421
+ | [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
422
+ | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
423
+ | [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
424
+ | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
425
+ | [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
426
+ | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
427
+ | [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
428
+ | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 |
429
+ | [GLM-4.1V](https://huggingface.co/zai-org) | 9B | glm4v |
430
+ | [GLM-4.5/GLM-4.5V](https://huggingface.co/zai-org) | 106B/355B | glm4_moe/glm4v_moe |
431
+ | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
432
+ | [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
433
+ | [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
434
+ | [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
435
+ | [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
436
+ | [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
437
+ | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
438
+ | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
439
+ | [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
440
+ | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
441
+ | [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
442
+ | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
443
+ | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
444
+ | [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
445
+ | [Llama 4](https://huggingface.co/meta-llama) | 109B/402B | llama4 |
446
+ | [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
447
+ | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
448
+ | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
449
+ | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
450
+ | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
451
+ | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
452
+ | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
453
+ | [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
454
+ | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
455
+ | [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
456
+ | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
457
+ | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
458
+ | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
459
+ | [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
460
+ | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
461
+ | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
462
+ | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
463
+ | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
464
+ | [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
465
+ | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
466
+ | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
467
+ | [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
468
+ | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
469
+ | [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
470
+ | [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
471
+ | [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
472
+ | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
473
+ | [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
474
+ | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
475
+ | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
476
+ | [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
477
+ | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
478
+
479
+ > [!NOTE]
480
+ > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
481
+ >
482
+ > If the model has both reasoning and non-reasoning versions, please use the `_nothink` suffix to distinguish between them. For example, `qwen3` and `qwen3_nothink`.
483
+ >
484
+ > Remember to use the **SAME** template in training and inference.
485
+ >
486
+ > \*: You should install the `transformers` from main branch and use `DISABLE_VERSION_CHECK=1` to skip version check.
487
+ >
488
+ > \*\*: You need to install a specific version of `transformers` to use the corresponding model.
489
+
490
+ Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of models we support.
491
+
492
+ You can also add a custom chat template to [template.py](src/llamafactory/data/template.py).
493
+
494
+ ## Supported Training Approaches
495
+
496
+ | Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA | OFT | QOFT |
497
+ | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
498
+ | Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
499
+ | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
500
+ | Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
501
+ | PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
502
+ | DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
503
+ | KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
504
+ | ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
505
+ | SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
506
+
507
+ > [!TIP]
508
+ > The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
509
+
510
+ ## Provided Datasets
511
+
512
+ <details><summary>Pre-training datasets</summary>
513
+
514
+ - [Wiki Demo (en)](data/wiki_demo.txt)
515
+ - [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
516
+ - [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
517
+ - [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
518
+ - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
519
+ - [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
520
+ - [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
521
+ - [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
522
+ - [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
523
+ - [CCI3-HQ (zh)](https://huggingface.co/datasets/BAAI/CCI3-HQ)
524
+ - [CCI3-Data (zh)](https://huggingface.co/datasets/BAAI/CCI3-Data)
525
+ - [CCI4.0-M2-Base-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-Base-v1)
526
+ - [CCI4.0-M2-CoT-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-CoT-v1)
527
+ - [CCI4.0-M2-Extra-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-Extra-v1)
528
+ - [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
529
+ - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
530
+
531
+ </details>
532
+
533
+ <details><summary>Supervised fine-tuning datasets</summary>
534
+
535
+ - [Identity (en&zh)](data/identity.json)
536
+ - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
537
+ - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
538
+ - [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
539
+ - [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
540
+ - [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
541
+ - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
542
+ - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
543
+ - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
544
+ - [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
545
+ - [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
546
+ - [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
547
+ - [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
548
+ - [UltraChat (en)](https://github.com/thunlp/UltraChat)
549
+ - [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
550
+ - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
551
+ - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
552
+ - [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
553
+ - [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
554
+ - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
555
+ - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
556
+ - [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
557
+ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
558
+ - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
559
+ - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
560
+ - [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
561
+ - [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
562
+ - [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
563
+ - [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
564
+ - [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
565
+ - [Infinity Instruct (zh)](https://huggingface.co/datasets/BAAI/Infinity-Instruct)
566
+ - [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
567
+ - [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
568
+ - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
569
+ - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
570
+ - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
571
+ - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
572
+ - [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
573
+ - [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
574
+ - [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
575
+ - [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
576
+ - [OpenO1-SFT (en&zh)](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)
577
+ - [Open-Thoughts (en)](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k)
578
+ - [Open-R1-Math (en)](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k)
579
+ - [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
580
+ - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
581
+ - [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
582
+ - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
583
+ - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
584
+ - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
585
+ - [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
586
+ - [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
587
+ - [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
588
+ - [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
589
+ - [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
590
+ - [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
591
+
592
+ </details>
593
+
594
+ <details><summary>Preference datasets</summary>
595
+
596
+ - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
597
+ - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
598
+ - [COIG-P (zh)](https://huggingface.co/datasets/m-a-p/COIG-P)
599
+ - [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
600
+ - [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
601
+ - [RLAIF-V (en)](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset)
602
+ - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
603
+ - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
604
+ - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
605
+ - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
606
+ - [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
607
+
608
+ </details>
609
+
610
+ Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
611
+
612
+ ```bash
613
+ pip install "huggingface_hub<1.0.0"
614
+ huggingface-cli login
615
+ ```
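Equivalently, you can authenticate from Python with the `huggingface_hub` login helper (a small sketch; the token string is a placeholder):

```python
# Programmatic alternative to `huggingface-cli login`; the token is a placeholder.
from huggingface_hub import login

login(token="hf_xxx")  # or call login() without a token to be prompted interactively
```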
616
+
617
+ ## Requirement
618
+
619
+ | Mandatory | Minimum | Recommend |
620
+ | ------------ | ------- | --------- |
621
+ | python | 3.9 | 3.10 |
622
+ | torch | 2.0.0 | 2.6.0 |
623
+ | torchvision | 0.15.0 | 0.21.0 |
624
+ | transformers | 4.49.0 | 4.50.0 |
625
+ | datasets | 2.16.0 | 3.2.0 |
626
+ | accelerate | 0.34.0 | 1.2.1 |
627
+ | peft | 0.14.0 | 0.15.1 |
628
+ | trl | 0.8.6 | 0.9.6 |
629
+
630
+ | Optional | Minimum | Recommend |
631
+ | ------------ | ------- | --------- |
632
+ | CUDA | 11.6 | 12.2 |
633
+ | deepspeed | 0.10.0 | 0.16.4 |
634
+ | bitsandbytes | 0.39.0 | 0.43.1 |
635
+ | vllm | 0.4.3 | 0.8.2 |
636
+ | flash-attn | 2.5.6 | 2.7.2 |
637
+
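To check whether an existing environment meets these minimums, here is a small sketch using the standard library plus `packaging` (already a llamafactory dependency); the package list mirrors the mandatory table above:

```python
# Quick check that installed packages meet the "Minimum" column above.
# Uses importlib.metadata (stdlib) and packaging (a llamafactory dependency).
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

MINIMUMS = {
    "torch": "2.0.0",
    "transformers": "4.49.0",
    "datasets": "2.16.0",
    "accelerate": "0.34.0",
    "peft": "0.14.0",
    "trl": "0.8.6",
}

for name, minimum in MINIMUMS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (minimum {minimum})")
        continue
    status = "OK" if Version(installed) >= Version(minimum) else "too old"
    print(f"{name}: {installed} ({status}, minimum {minimum})")
```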
638
+ ### Hardware Requirement
639
+
640
+ \* *estimated*
641
+
642
+ | Method | Bits | 7B | 14B | 30B | 70B | `x`B |
643
+ | ----------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
644
+ | Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
645
+ | Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
646
+ | Freeze/LoRA/GaLore/APOLLO/BAdam/OFT | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
647
+ | QLoRA / QOFT | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
648
+ | QLoRA / QOFT | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
649
+ | QLoRA / QOFT | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
650
+
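The `x`B column is a rough rule of thumb, with `x` being the model size in billions of parameters. A small sketch of that arithmetic for sizes not listed in the table; note that the concrete cells above are hand-estimated and can include extra headroom beyond the formula:

```python
# Back-of-the-envelope GPU memory estimate from the `x`B column above,
# where params_b is the model size in billions of parameters.
FACTORS_GB_PER_B = {
    "full_bf16_or_fp16": 18,      # Full (`bf16` or `fp16`)
    "full_pure_bf16": 8,          # Full (`pure_bf16`)
    "freeze_lora_galore_oft": 2,  # Freeze/LoRA/GaLore/APOLLO/BAdam/OFT
    "qlora_8bit": 1,
    "qlora_4bit": 0.5,
    "qlora_2bit": 0.25,
}

def estimated_memory_gb(params_b: float, method: str) -> float:
    """Return the rough memory estimate in GB for a given method and model size."""
    return FACTORS_GB_PER_B[method] * params_b

# e.g. a 13B model with 4-bit QLoRA: about 6.5 GB by this rule of thumb
print(estimated_memory_gb(13, "qlora_4bit"))
```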
651
+ ## Getting Started
652
+
653
+ ### Installation
654
+
655
+ > [!IMPORTANT]
656
+ > Installation is mandatory.
657
+
658
+ #### Install from Source
659
+
660
+ ```bash
661
+ git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
662
+ cd LLaMA-Factory
663
+ pip install -e ".[torch,metrics]" --no-build-isolation
664
+ ```
665
+
666
+ Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev
667
+
668
+ #### Install from Docker Image
669
+
670
+ ```bash
671
+ docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
672
+ ```
673
+
674
+ This image is built on Ubuntu 22.04 (x86\_64), CUDA 12.4, Python 3.11, PyTorch 2.6.0, and Flash-attn 2.7.4.
675
+
676
+ Find the pre-built images: https://hub.docker.com/r/hiyouga/llamafactory/tags
677
+
678
+ Please refer to [build docker](#build-docker) to build the image yourself.
679
+
680
+ <details><summary>Setting up a virtual environment with <b>uv</b></summary>
681
+
682
+ Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):
683
+
684
+ ```bash
685
+ uv sync --extra torch --extra metrics --prerelease=allow
686
+ ```
687
+
688
+ Run LLaMA-Factory in the isolated environment:
689
+
690
+ ```bash
691
+ uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
692
+ ```
693
+
694
+ </details>
695
+
696
+ <details><summary>For Windows users</summary>
697
+
698
+ #### Install PyTorch
699
+
700
+ You need to manually install the GPU version of PyTorch on the Windows platform. Please refer to the [official website](https://pytorch.org/get-started/locally/) and the following commands to install PyTorch with CUDA support:
701
+
702
+ ```bash
703
+ pip uninstall torch torchvision torchaudio
704
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
705
+ python -c "import torch; print(torch.cuda.is_available())"
706
+ ```
707
+
708
+ If you see `True` then you have successfully installed PyTorch with CUDA support.
709
+
710
+ Try `dataloader_num_workers: 0` if you encounter `Can't pickle local object` error.
711
+
712
+ #### Install BitsAndBytes
713
+
714
+ If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
715
+
716
+ ```bash
717
+ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
718
+ ```
719
+
720
+ #### Install Flash Attention-2
721
+
722
+ To enable FlashAttention-2 on the Windows platform, please use the script from [flash-attention-windows-wheel](https://huggingface.co/lldacing/flash-attention-windows-wheel) to compile and install it by yourself.
723
+
724
+ </details>
725
+
726
+ <details><summary>For Ascend NPU users</summary>
727
+
728
+ To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
729
+
730
+ ```bash
731
+ # replace the url according to your CANN version and devices
732
+ # install CANN Toolkit
733
+ wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
734
+ bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
735
+
736
+ # install CANN Kernels
737
+ wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
738
+ bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
739
+
740
+ # set env variables
741
+ source /usr/local/Ascend/ascend-toolkit/set_env.sh
742
+ ```
743
+
744
+ | Requirement | Minimum | Recommend |
745
+ | ------------ | ------- | -------------- |
746
+ | CANN | 8.0.RC1 | 8.0.0.alpha002 |
747
+ | torch | 2.1.0 | 2.4.0 |
748
+ | torch-npu | 2.1.0 | 2.4.0.post2 |
749
+ | deepspeed | 0.13.2 | 0.13.2 |
750
+ | vllm-ascend | - | 0.7.3 |
751
+
752
+ Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
753
+
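+ For example (a sketch assuming NPU cards 0 and 1 are available on the host):
+
+ ```bash
+ ASCEND_RT_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+ ```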
754
+ If you cannot run inference on NPU devices, try setting `do_sample: false` in the configuration.
755
+
756
+ Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
757
+
758
+ #### Install BitsAndBytes
759
+
760
+ To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
761
+
762
+ 1. Manually compile bitsandbytes: Refer to [the installation documentation](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU) for the NPU version of bitsandbytes to complete the compilation and installation. The compilation requires a cmake version of at least 3.22.1 and a g++ version of at least 12.x.
763
+
764
+ ```bash
765
+ # Install bitsandbytes from source
766
+ # Clone bitsandbytes repo, Ascend NPU backend is currently enabled on multi-backend-refactor branch
767
+ git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git
768
+ cd bitsandbytes/
769
+
770
+ # Install dependencies
771
+ pip install -r requirements-dev.txt
772
+
773
+ # Install the dependencies for the compilation tools. Note that the commands for this step may vary depending on the operating system. The following are provided for reference
774
+ apt-get install -y build-essential cmake
775
+
776
+ # Compile & install
777
+ cmake -DCOMPUTE_BACKEND=npu -S .
778
+ make
779
+ pip install .
780
+ ```
781
+
782
+ 2. Install transformers from the main branch.
783
+
784
+ ```bash
785
+ git clone -b main https://github.com/huggingface/transformers.git
786
+ cd transformers
787
+ pip install .
788
+ ```
789
+
790
+ 3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml).
791
+
792
+ </details>
793
+
794
+ ### Data Preparation
795
+
796
+ Please refer to [data/README.md](data/README.md) for details on the format of dataset files. You can use datasets on the HuggingFace / ModelScope / Modelers hub, load datasets from local disk, or specify a path to S3/GCS cloud storage.
797
+
798
+ > [!NOTE]
799
+ > Please update `data/dataset_info.json` to use your custom dataset.
800
+
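+ Once a custom dataset is registered in `data/dataset_info.json`, you can point a training recipe at it either in the yaml file or via a command-line override (a sketch; `my_dataset` is a hypothetical dataset name):
+
+ ```bash
+ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml dataset=my_dataset
+ ```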
801
+ You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**, **[DataFlow](https://github.com/OpenDCAI/DataFlow)** and **[GraphGen](https://github.com/open-sciencelab/GraphGen)** to create synthetic data for fine-tuning.
802
+
803
+ ### Quickstart
804
+
805
+ Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
806
+
807
+ ```bash
808
+ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
809
+ llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
810
+ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
811
+ ```
812
+
813
+ See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
814
+
815
+ > [!TIP]
816
+ > Use `llamafactory-cli help` to show help information.
817
+ >
818
+ > Read [FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614) first if you encounter any problems.
819
+
820
+ ### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
821
+
822
+ ```bash
823
+ llamafactory-cli webui
824
+ ```
825
+
826
+ ### LLaMA Factory Online
827
+
828
+ Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
829
+
830
+ ### Build Docker
831
+
832
+ For CUDA users:
833
+
834
+ ```bash
835
+ cd docker/docker-cuda/
836
+ docker compose up -d
837
+ docker compose exec llamafactory bash
838
+ ```
839
+
840
+ For Ascend NPU users:
841
+
842
+ ```bash
843
+ cd docker/docker-npu/
844
+ docker compose up -d
845
+ docker compose exec llamafactory bash
846
+ ```
847
+
848
+ For AMD ROCm users:
849
+
850
+ ```bash
851
+ cd docker/docker-rocm/
852
+ docker compose up -d
853
+ docker compose exec llamafactory bash
854
+ ```
855
+
856
+ <details><summary>Build without Docker Compose</summary>
857
+
858
+ For CUDA users:
859
+
860
+ ```bash
861
+ docker build -f ./docker/docker-cuda/Dockerfile \
862
+ --build-arg PIP_INDEX=https://pypi.org/simple \
863
+ --build-arg EXTRAS=metrics \
864
+ -t llamafactory:latest .
865
+
866
+ docker run -dit --ipc=host --gpus=all \
867
+ -p 7860:7860 \
868
+ -p 8000:8000 \
869
+ --name llamafactory \
870
+ llamafactory:latest
871
+
872
+ docker exec -it llamafactory bash
873
+ ```
874
+
875
+ For Ascend NPU users:
876
+
877
+ ```bash
878
+ docker build -f ./docker/docker-npu/Dockerfile \
879
+ --build-arg PIP_INDEX=https://pypi.org/simple \
880
+ --build-arg EXTRAS=torch-npu,metrics \
881
+ -t llamafactory:latest .
882
+
883
+ docker run -dit --ipc=host \
884
+ -v /usr/local/dcmi:/usr/local/dcmi \
885
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
886
+ -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
887
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
888
+ -p 7860:7860 \
889
+ -p 8000:8000 \
890
+ --device /dev/davinci0 \
891
+ --device /dev/davinci_manager \
892
+ --device /dev/devmm_svm \
893
+ --device /dev/hisi_hdc \
894
+ --name llamafactory \
895
+ llamafactory:latest
896
+
897
+ docker exec -it llamafactory bash
898
+ ```
899
+
900
+ For AMD ROCm users:
901
+
902
+ ```bash
903
+ docker build -f ./docker/docker-rocm/Dockerfile \
904
+ --build-arg PIP_INDEX=https://pypi.org/simple \
905
+ --build-arg EXTRAS=metrics \
906
+ -t llamafactory:latest .
907
+
908
+ docker run -dit --ipc=host \
909
+ -p 7860:7860 \
910
+ -p 8000:8000 \
911
+ --device /dev/kfd \
912
+ --device /dev/dri \
913
+ --name llamafactory \
914
+ llamafactory:latest
915
+
916
+ docker exec -it llamafactory bash
917
+ ```
918
+
919
+ </details>
920
+
921
+ <details><summary>Use Docker volumes</summary>
922
+
923
+ You can uncomment `VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]` in the Dockerfile to use data volumes.
924
+
925
+ When starting the container, use the `-v ./hf_cache:/root/.cache/huggingface` argument to mount a local directory into the container. The following data volumes are available; an example run command is shown after the list.
926
+
927
+ - `hf_cache`: Utilize Hugging Face cache on the host machine.
928
+ - `shared_data`: The directory used to store datasets on the host machine.
929
+ - `output`: Set the export dir to this location so that the merged result can be accessed directly on the host machine.
930
+
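+ A sketch of a run command that mounts all three volumes (based on the CUDA example above; it assumes the host directories already exist):
+
+ ```bash
+ docker run -dit --ipc=host --gpus=all \
+     -v ./hf_cache:/root/.cache/huggingface \
+     -v ./shared_data:/app/shared_data \
+     -v ./output:/app/output \
+     -p 7860:7860 \
+     -p 8000:8000 \
+     --name llamafactory \
+     llamafactory:latest
+ ```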
931
+ </details>
932
+
933
+ ### Deploy with OpenAI-style API and vLLM
934
+
935
+ ```bash
936
+ API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
937
+ ```
938
+
939
+ > [!TIP]
940
+ > Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
941
+ >
942
+ > Examples: [Image understanding](scripts/api_example/test_image.py) | [Function calling](scripts/api_example/test_toolcall.py)
943
+
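+ Once the server is up, you can smoke-test the endpoint with curl (a sketch; the `model` field is arbitrary unless you set `API_MODEL_NAME`, and an `Authorization: Bearer <key>` header is needed only if `API_KEY` is set):
+
+ ```bash
+ curl http://localhost:8000/v1/chat/completions \
+     -H "Content-Type: application/json" \
+     -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello!"}]}'
+ ```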
944
+ ### Download from ModelScope Hub
945
+
946
+ If you have trouble downloading models and datasets from Hugging Face, you can use ModelScope instead.
947
+
948
+ ```bash
949
+ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
950
+ ```
951
+
952
+ Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
953
+
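+ A minimal sketch that overrides `model_name_or_path` on the command line (you can equally set it in the yaml file):
+
+ ```bash
+ export USE_MODELSCOPE_HUB=1
+ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml model_name_or_path=LLM-Research/Meta-Llama-3-8B-Instruct
+ ```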
954
+ ### Download from Modelers Hub
955
+
956
+ You can also use Modelers Hub to download models and datasets.
957
+
958
+ ```bash
959
+ export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
960
+ ```
961
+
962
+ Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
963
+
964
+ ### Use W&B Logger
965
+
966
+ To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to your yaml file.
967
+
968
+ ```yaml
969
+ report_to: wandb
970
+ run_name: test_run # optional
971
+ ```
972
+
973
+ Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
974
+
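+ For example (a sketch; it assumes the yaml file includes the `report_to: wandb` argument above, and the key value is a placeholder):
+
+ ```bash
+ WANDB_API_KEY=xxxxxxxxxxxxxxxx llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+ ```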
975
+ ### Use SwanLab Logger
976
+
977
+ To use [SwanLab](https://github.com/SwanHubX/SwanLab) for logging experimental results, you need to add the following arguments to your yaml file.
978
+
979
+ ```yaml
980
+ use_swanlab: true
981
+ swanlab_run_name: test_run # optional
982
+ ```
983
+
984
+ When launching training tasks, you can log in to SwanLab in three ways:
985
+
986
+ 1. Add `swanlab_api_key=<your_api_key>` to the yaml file, and set it to your [API key](https://swanlab.cn/settings).
987
+ 2. Set the environment variable `SWANLAB_API_KEY` to your [API key](https://swanlab.cn/settings).
988
+ 3. Use the `swanlab login` command to complete the login.
989
+
990
+ ## Projects using LLaMA Factory
991
+
992
+ If you have a project that should be included here, please contact us via email or create a pull request.
993
+
994
+ <details><summary>Click to show</summary>
995
+
996
+ 1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
997
+ 1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
998
+ 1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
999
+ 1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1000
+ 1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
1001
+ 1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
1002
+ 1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1003
+ 1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1004
+ 1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1005
+ 1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1006
+ 1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1007
+ 1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1008
+ 1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1009
+ 1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
1010
+ 1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1011
+ 1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1012
+ 1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
1013
+ 1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
1014
+ 1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1015
+ 1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1016
+ 1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1017
+ 1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
1018
+ 1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1019
+ 1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
1020
+ 1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1021
+ 1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1022
+ 1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1023
+ 1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1024
+ 1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
1025
+ 1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1026
+ 1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1027
+ 1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
1028
+ 1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
1029
+ 1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
1030
+ 1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
1031
+ 1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
1032
+ 1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
1033
+ 1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
1034
+ 1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
1035
+ 1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
1036
+ 1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
1037
+ 1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
1038
+ 1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
1039
+ 1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
1040
+ 1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
1041
+ 1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
1042
+ 1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
1043
+ 1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
1044
+ 1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
1045
+ 1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
1046
+ 1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
1047
+ 1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
1048
+ 1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
1049
+ 1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
1050
+ 1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
1051
+ 1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
1052
+ 1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
1053
+ 1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
1054
+ 1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
1055
+ 1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
1056
+ 1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
1057
+ 1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1058
+ 1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
1059
+ 1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
1060
+ 1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
1061
+ 1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
1062
+ 1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
1063
+ 1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
1064
+ 1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
1065
+ 1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
1066
+ 1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
1067
+ 1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
1068
+ 1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
1069
+ 1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
1070
+ 1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
1071
+ 1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
1072
+ 1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
1073
+ 1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
1074
+ 1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
1075
+ 1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
1076
+ 1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
1077
+ 1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
1078
+ 1. Zhang et al. CPsyCoun: A Report-based Multi-turn Dialogue Reconstruction and Evaluation Framework for Chinese Psychological Counseling. ACL 2024. [[paper]](https://aclanthology.org/2024.findings-acl.830.pdf)
1079
+ 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1080
+ 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1081
+ 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
1082
+ 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
1083
+ 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
1084
+ 1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generating metadata for Stable Diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1085
+ 1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
1086
+ 1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
1087
+ 1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
1088
+ 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
1089
+ 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
1090
+ 1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention.
1091
+ 1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: An o1-like model fine-tuned by NovaSky AI with very small cost.
1092
+ 1. **[WeClone](https://github.com/xming521/WeClone)**: One-stop solution for creating your digital avatar from chat logs.
1093
+ 1. **[EmoLLM](https://github.com/SmartFlowAI/EmoLLM)**: A project about large language models (LLMs) and mental health.
1094
+ </details>
1095
+
1096
+ ## License
1097
+
1098
+ This repository is licensed under the [Apache-2.0 License](LICENSE).
1099
+
1100
+ Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Llama 4](https://github.com/meta-llama/llama-models/blob/main/models/llama4/LICENSE) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
1101
+
1102
+ ## Citation
1103
+
1104
+ If this work is helpful, please kindly cite as:
1105
+
1106
+ ```bibtex
1107
+ @inproceedings{zheng2024llamafactory,
1108
+ title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
1109
+ author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
1110
+ booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
1111
+ address={Bangkok, Thailand},
1112
+ publisher={Association for Computational Linguistics},
1113
+ year={2024},
1114
+ url={http://arxiv.org/abs/2403.13372}
1115
+ }
1116
+ ```
1117
+
1118
+ ## Acknowledgement
1119
+
1120
+ This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful work.
1121
+
1122
+ ## Star History
1123
+
1124
+ ![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date)
llamafactory.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,178 @@
1
+ LICENSE
2
+ MANIFEST.in
3
+ README.md
4
+ pyproject.toml
5
+ requirements.txt
6
+ setup.py
7
+ src/llamafactory/__init__.py
8
+ src/llamafactory/cli.py
9
+ src/llamafactory/launcher.py
10
+ src/llamafactory.egg-info/PKG-INFO
11
+ src/llamafactory.egg-info/SOURCES.txt
12
+ src/llamafactory.egg-info/dependency_links.txt
13
+ src/llamafactory.egg-info/entry_points.txt
14
+ src/llamafactory.egg-info/requires.txt
15
+ src/llamafactory.egg-info/top_level.txt
16
+ src/llamafactory/api/__init__.py
17
+ src/llamafactory/api/app.py
18
+ src/llamafactory/api/chat.py
19
+ src/llamafactory/api/common.py
20
+ src/llamafactory/api/protocol.py
21
+ src/llamafactory/chat/__init__.py
22
+ src/llamafactory/chat/base_engine.py
23
+ src/llamafactory/chat/chat_model.py
24
+ src/llamafactory/chat/hf_engine.py
25
+ src/llamafactory/chat/kt_engine.py
26
+ src/llamafactory/chat/sglang_engine.py
27
+ src/llamafactory/chat/vllm_engine.py
28
+ src/llamafactory/data/__init__.py
29
+ src/llamafactory/data/collator.py
30
+ src/llamafactory/data/converter.py
31
+ src/llamafactory/data/data_utils.py
32
+ src/llamafactory/data/formatter.py
33
+ src/llamafactory/data/loader.py
34
+ src/llamafactory/data/mm_plugin.py
35
+ src/llamafactory/data/parser.py
36
+ src/llamafactory/data/template.py
37
+ src/llamafactory/data/tool_utils.py
38
+ src/llamafactory/data/processor/__init__.py
39
+ src/llamafactory/data/processor/feedback.py
40
+ src/llamafactory/data/processor/pairwise.py
41
+ src/llamafactory/data/processor/pretrain.py
42
+ src/llamafactory/data/processor/processor_utils.py
43
+ src/llamafactory/data/processor/supervised.py
44
+ src/llamafactory/data/processor/unsupervised.py
45
+ src/llamafactory/eval/__init__.py
46
+ src/llamafactory/eval/evaluator.py
47
+ src/llamafactory/eval/template.py
48
+ src/llamafactory/extras/__init__.py
49
+ src/llamafactory/extras/constants.py
50
+ src/llamafactory/extras/env.py
51
+ src/llamafactory/extras/logging.py
52
+ src/llamafactory/extras/misc.py
53
+ src/llamafactory/extras/packages.py
54
+ src/llamafactory/extras/ploting.py
55
+ src/llamafactory/hparams/__init__.py
56
+ src/llamafactory/hparams/data_args.py
57
+ src/llamafactory/hparams/evaluation_args.py
58
+ src/llamafactory/hparams/finetuning_args.py
59
+ src/llamafactory/hparams/generating_args.py
60
+ src/llamafactory/hparams/model_args.py
61
+ src/llamafactory/hparams/parser.py
62
+ src/llamafactory/hparams/training_args.py
63
+ src/llamafactory/model/__init__.py
64
+ src/llamafactory/model/adapter.py
65
+ src/llamafactory/model/loader.py
66
+ src/llamafactory/model/patcher.py
67
+ src/llamafactory/model/model_utils/__init__.py
68
+ src/llamafactory/model/model_utils/attention.py
69
+ src/llamafactory/model/model_utils/checkpointing.py
70
+ src/llamafactory/model/model_utils/embedding.py
71
+ src/llamafactory/model/model_utils/ktransformers.py
72
+ src/llamafactory/model/model_utils/kv_cache.py
73
+ src/llamafactory/model/model_utils/liger_kernel.py
74
+ src/llamafactory/model/model_utils/longlora.py
75
+ src/llamafactory/model/model_utils/misc.py
76
+ src/llamafactory/model/model_utils/mod.py
77
+ src/llamafactory/model/model_utils/moe.py
78
+ src/llamafactory/model/model_utils/packing.py
79
+ src/llamafactory/model/model_utils/quantization.py
80
+ src/llamafactory/model/model_utils/rope.py
81
+ src/llamafactory/model/model_utils/unsloth.py
82
+ src/llamafactory/model/model_utils/valuehead.py
83
+ src/llamafactory/model/model_utils/visual.py
84
+ src/llamafactory/third_party/__init__.py
85
+ src/llamafactory/third_party/muon/__init__.py
86
+ src/llamafactory/third_party/muon/muon.py
87
+ src/llamafactory/train/__init__.py
88
+ src/llamafactory/train/callbacks.py
89
+ src/llamafactory/train/fp8_utils.py
90
+ src/llamafactory/train/test_utils.py
91
+ src/llamafactory/train/trainer_utils.py
92
+ src/llamafactory/train/tuner.py
93
+ src/llamafactory/train/dpo/__init__.py
94
+ src/llamafactory/train/dpo/trainer.py
95
+ src/llamafactory/train/dpo/workflow.py
96
+ src/llamafactory/train/ksft/__init__.py
97
+ src/llamafactory/train/ksft/workflow.py
98
+ src/llamafactory/train/kto/__init__.py
99
+ src/llamafactory/train/kto/trainer.py
100
+ src/llamafactory/train/kto/workflow.py
101
+ src/llamafactory/train/mca/__init__.py
102
+ src/llamafactory/train/mca/trainer.py
103
+ src/llamafactory/train/mca/workflow.py
104
+ src/llamafactory/train/ppo/__init__.py
105
+ src/llamafactory/train/ppo/ppo_utils.py
106
+ src/llamafactory/train/ppo/trainer.py
107
+ src/llamafactory/train/ppo/workflow.py
108
+ src/llamafactory/train/pt/__init__.py
109
+ src/llamafactory/train/pt/trainer.py
110
+ src/llamafactory/train/pt/workflow.py
111
+ src/llamafactory/train/rm/__init__.py
112
+ src/llamafactory/train/rm/metric.py
113
+ src/llamafactory/train/rm/trainer.py
114
+ src/llamafactory/train/rm/workflow.py
115
+ src/llamafactory/train/sft/__init__.py
116
+ src/llamafactory/train/sft/metric.py
117
+ src/llamafactory/train/sft/trainer.py
118
+ src/llamafactory/train/sft/workflow.py
119
+ src/llamafactory/v1/__init__.py
120
+ src/llamafactory/v1/launcher.py
121
+ src/llamafactory/v1/config/__init__.py
122
+ src/llamafactory/v1/config/data_args.py
123
+ src/llamafactory/v1/config/model_args.py
124
+ src/llamafactory/v1/config/parser.py
125
+ src/llamafactory/v1/config/sample_args.py
126
+ src/llamafactory/v1/config/training_args.py
127
+ src/llamafactory/v1/core/__init__.py
128
+ src/llamafactory/v1/core/base_trainer.py
129
+ src/llamafactory/v1/core/chat_sampler.py
130
+ src/llamafactory/v1/core/data_engine.py
131
+ src/llamafactory/v1/core/model_engine.py
132
+ src/llamafactory/v1/plugins/__init__.py
133
+ src/llamafactory/v1/plugins/data_plugins/__init__.py
134
+ src/llamafactory/v1/plugins/data_plugins/converter.py
135
+ src/llamafactory/v1/plugins/data_plugins/loader.py
136
+ src/llamafactory/v1/plugins/data_plugins/template.py
137
+ src/llamafactory/v1/plugins/model_plugins/__init__.py
138
+ src/llamafactory/v1/plugins/model_plugins/added_token.py
139
+ src/llamafactory/v1/plugins/model_plugins/peft.py
140
+ src/llamafactory/v1/plugins/model_plugins/kernels/__init__.py
141
+ src/llamafactory/v1/plugins/model_plugins/kernels/constants.py
142
+ src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
143
+ src/llamafactory/v1/plugins/model_plugins/kernels/fa/__init__.py
144
+ src/llamafactory/v1/plugins/model_plugins/kernels/mlp/__init__.py
145
+ src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
146
+ src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
147
+ src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/__init__.py
148
+ src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
149
+ src/llamafactory/v1/plugins/model_plugins/kernels/rope/__init__.py
150
+ src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
151
+ src/llamafactory/v1/plugins/sampler_plugins/__init__.py
152
+ src/llamafactory/v1/plugins/sampler_plugins/vllm.py
153
+ src/llamafactory/v1/plugins/trainer_plugins/__init__.py
154
+ src/llamafactory/v1/plugins/trainer_plugins/distributed/__init__.py
155
+ src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py
156
+ src/llamafactory/v1/trainers/__init__.py
157
+ src/llamafactory/v1/trainers/dpo_trainer.py
158
+ src/llamafactory/v1/trainers/rm_trainer.py
159
+ src/llamafactory/v1/trainers/sft_trainer.py
160
+ src/llamafactory/webui/__init__.py
161
+ src/llamafactory/webui/chatter.py
162
+ src/llamafactory/webui/common.py
163
+ src/llamafactory/webui/control.py
164
+ src/llamafactory/webui/css.py
165
+ src/llamafactory/webui/engine.py
166
+ src/llamafactory/webui/interface.py
167
+ src/llamafactory/webui/locales.py
168
+ src/llamafactory/webui/manager.py
169
+ src/llamafactory/webui/runner.py
170
+ src/llamafactory/webui/components/__init__.py
171
+ src/llamafactory/webui/components/chatbot.py
172
+ src/llamafactory/webui/components/data.py
173
+ src/llamafactory/webui/components/eval.py
174
+ src/llamafactory/webui/components/export.py
175
+ src/llamafactory/webui/components/footer.py
176
+ src/llamafactory/webui/components/infer.py
177
+ src/llamafactory/webui/components/top.py
178
+ src/llamafactory/webui/components/train.py
llamafactory.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
llamafactory.egg-info/entry_points.txt ADDED
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ llamafactory-cli = llamafactory.cli:main
3
+ lmf = llamafactory.cli:main
llamafactory.egg-info/requires.txt ADDED
@@ -0,0 +1,125 @@
1
+ datasets<=4.0.0,>=2.16.0
2
+ accelerate<=1.11.0,>=1.3.0
3
+ peft<=0.17.1,>=0.14.0
4
+ trl<=0.9.6,>=0.8.6
5
+ gradio<=5.45.0,>=4.38.0
6
+ matplotlib>=3.7.0
7
+ tyro<0.9.0
8
+ einops
9
+ numpy<2.0.0
10
+ pandas>=2.0.0
11
+ scipy
12
+ sentencepiece
13
+ tiktoken
14
+ modelscope>=1.14.0
15
+ hf-transfer
16
+ safetensors<=0.5.3
17
+ fire
18
+ omegaconf
19
+ packaging
20
+ protobuf
21
+ pyyaml
22
+ pydantic<=2.10.6
23
+ uvicorn
24
+ fastapi
25
+ sse-starlette
26
+ av
27
+ librosa
28
+ propcache!=0.4.0
29
+
30
+ [:python_version < "3.10"]
31
+ transformers!=4.52.0,<=4.56.2,>=4.49.0
32
+
33
+ [:python_version >= "3.10"]
34
+ transformers!=4.52.0,!=4.57.0,<=4.57.1,>=4.49.0
35
+
36
+ [adam-mini]
37
+ adam-mini
38
+
39
+ [apollo]
40
+ apollo-torch
41
+
42
+ [aqlm]
43
+ aqlm[gpu]>=1.1.0
44
+
45
+ [badam]
46
+ badam>=1.2.1
47
+
48
+ [bitsandbytes]
49
+ bitsandbytes>=0.39.0
50
+
51
+ [deepspeed]
52
+ deepspeed<=0.16.9,>=0.10.0
53
+
54
+ [dev]
55
+ pre-commit
56
+ ruff
57
+ pytest
58
+ build
59
+
60
+ [eetq]
61
+ eetq
62
+
63
+ [fp8]
64
+ torchao>=0.8.0
65
+ accelerate>=1.10.0
66
+
67
+ [fp8-all]
68
+ torchao>=0.8.0
69
+ transformer_engine[pytorch]>=2.0.0
70
+ accelerate>=1.10.0
71
+
72
+ [fp8-te]
73
+ transformer_engine[pytorch]>=2.0.0
74
+ accelerate>=1.10.0
75
+
76
+ [galore]
77
+ galore-torch
78
+
79
+ [gptq]
80
+ optimum>=1.24.0
81
+ gptqmodel>=2.0.0
82
+
83
+ [hqq]
84
+ hqq
85
+
86
+ [liger-kernel]
87
+ liger-kernel>=0.5.5
88
+
89
+ [metrics]
90
+ nltk
91
+ jieba
92
+ rouge-chinese
93
+
94
+ [minicpm_v]
95
+ soundfile
96
+ torchvision
97
+ torchaudio
98
+ vector_quantize_pytorch
99
+ vocos
100
+ msgpack
101
+ referencing
102
+ jsonschema_specifications
103
+
104
+ [openmind]
105
+ openmind
106
+
107
+ [sglang]
108
+ sglang[srt]>=0.4.5
109
+ transformers==4.51.1
110
+
111
+ [swanlab]
112
+ swanlab
113
+
114
+ [torch]
115
+ torch>=2.0.0
116
+ torchvision>=0.15.0
117
+
118
+ [torch-npu]
119
+ torch==2.7.1
120
+ torch-npu==2.7.1
121
+ torchvision==0.22.1
122
+ decorator
123
+
124
+ [vllm]
125
+ vllm<=0.11.0,>=0.4.3
llamafactory.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ llamafactory
llamafactory/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ r"""Efficient fine-tuning of large language models.
16
+
17
+ Level:
18
+ api, webui > chat, eval, train > data, model > hparams > extras
19
+
20
+ Disable version checking: DISABLE_VERSION_CHECK=1
21
+ Enable VRAM recording: RECORD_VRAM=1
22
+ Force using torchrun: FORCE_TORCHRUN=1
23
+ Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
24
+ Use modelscope: USE_MODELSCOPE_HUB=1
25
+ Use openmind: USE_OPENMIND_HUB=1
26
+ """
27
+
28
+ from .extras.env import VERSION
29
+
30
+
31
+ __version__ = VERSION
llamafactory/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (612 Bytes).
llamafactory/__pycache__/cli.cpython-312.pyc ADDED
Binary file (581 Bytes).
llamafactory/__pycache__/launcher.cpython-312.pyc ADDED
Binary file (6.29 kB).
llamafactory/api/__init__.py ADDED
File without changes
llamafactory/api/app.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import os
17
+ from contextlib import asynccontextmanager
18
+ from functools import partial
19
+ from typing import Annotated, Optional
20
+
21
+ from ..chat import ChatModel
22
+ from ..extras.constants import EngineName
23
+ from ..extras.misc import torch_gc
24
+ from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
25
+ from .chat import (
26
+ create_chat_completion_response,
27
+ create_score_evaluation_response,
28
+ create_stream_chat_completion_response,
29
+ )
30
+ from .protocol import (
31
+ ChatCompletionRequest,
32
+ ChatCompletionResponse,
33
+ ModelCard,
34
+ ModelList,
35
+ ScoreEvaluationRequest,
36
+ ScoreEvaluationResponse,
37
+ )
38
+
39
+
40
+ if is_fastapi_available():
41
+ from fastapi import Depends, FastAPI, HTTPException, status
42
+ from fastapi.middleware.cors import CORSMiddleware
43
+ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
44
+
45
+
46
+ if is_starlette_available():
47
+ from sse_starlette import EventSourceResponse
48
+
49
+
50
+ if is_uvicorn_available():
51
+ import uvicorn
52
+
53
+
54
+ async def sweeper() -> None:
55
+ while True:
56
+ torch_gc()
57
+ await asyncio.sleep(300)
58
+
59
+
60
+ @asynccontextmanager
61
+ async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
62
+ if chat_model.engine.name == EngineName.HF:
63
+ asyncio.create_task(sweeper())
64
+
65
+ yield
66
+ torch_gc()
67
+
68
+
69
+ def create_app(chat_model: "ChatModel") -> "FastAPI":
70
+ root_path = os.getenv("FASTAPI_ROOT_PATH", "")
71
+ app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
72
+ app.add_middleware(
73
+ CORSMiddleware,
74
+ allow_origins=["*"],
75
+ allow_credentials=True,
76
+ allow_methods=["*"],
77
+ allow_headers=["*"],
78
+ )
79
+ api_key = os.getenv("API_KEY")
80
+ security = HTTPBearer(auto_error=False)
81
+
82
+ async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
83
+ if api_key and (auth is None or auth.credentials != api_key):
84
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
85
+
86
+ @app.get(
87
+ "/v1/models",
88
+ response_model=ModelList,
89
+ status_code=status.HTTP_200_OK,
90
+ dependencies=[Depends(verify_api_key)],
91
+ )
92
+ async def list_models():
93
+ model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
94
+ return ModelList(data=[model_card])
95
+
96
+ @app.post(
97
+ "/v1/chat/completions",
98
+ response_model=ChatCompletionResponse,
99
+ status_code=status.HTTP_200_OK,
100
+ dependencies=[Depends(verify_api_key)],
101
+ )
102
+ async def create_chat_completion(request: ChatCompletionRequest):
103
+ if not chat_model.engine.can_generate:
104
+ raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
105
+
106
+ if request.stream:
107
+ generate = create_stream_chat_completion_response(request, chat_model)
108
+ return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
109
+ else:
110
+ return await create_chat_completion_response(request, chat_model)
111
+
112
+ @app.post(
113
+ "/v1/score/evaluation",
114
+ response_model=ScoreEvaluationResponse,
115
+ status_code=status.HTTP_200_OK,
116
+ dependencies=[Depends(verify_api_key)],
117
+ )
118
+ async def create_score_evaluation(request: ScoreEvaluationRequest):
119
+ if chat_model.engine.can_generate:
120
+ raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
121
+
122
+ return await create_score_evaluation_response(request, chat_model)
123
+
124
+ return app
125
+
126
+
127
+ def run_api() -> None:
128
+ chat_model = ChatModel()
129
+ app = create_app(chat_model)
130
+ api_host = os.getenv("API_HOST", "0.0.0.0")
131
+ api_port = int(os.getenv("API_PORT", "8000"))
132
+ print(f"Visit http://localhost:{api_port}/docs for API document.")
133
+ uvicorn.run(app, host=api_host, port=api_port)
llamafactory/api/chat.py ADDED
@@ -0,0 +1,291 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import base64
16
+ import io
17
+ import json
18
+ import os
19
+ import re
20
+ import uuid
21
+ from collections.abc import AsyncGenerator
22
+ from typing import TYPE_CHECKING, Optional
23
+
24
+ from ..data import Role as DataRole
25
+ from ..extras import logging
26
+ from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
27
+ from ..extras.misc import is_env_enabled
28
+ from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
29
+ from .common import check_lfi_path, check_ssrf_url, dictify, jsonify
30
+ from .protocol import (
31
+ ChatCompletionMessage,
32
+ ChatCompletionResponse,
33
+ ChatCompletionResponseChoice,
34
+ ChatCompletionResponseUsage,
35
+ ChatCompletionStreamResponse,
36
+ ChatCompletionStreamResponseChoice,
37
+ Finish,
38
+ Function,
39
+ FunctionCall,
40
+ Role,
41
+ ScoreEvaluationResponse,
42
+ )
43
+
44
+
45
+ if is_fastapi_available():
46
+ from fastapi import HTTPException, status
47
+
48
+
49
+ if is_pillow_available():
50
+ from PIL import Image
51
+
52
+
53
+ if is_requests_available():
54
+ import requests
55
+
56
+
57
+ if TYPE_CHECKING:
58
+ from ..chat import ChatModel
59
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
60
+ from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
61
+
62
+
63
+ logger = logging.get_logger(__name__)
64
+ ROLE_MAPPING = {
65
+ Role.USER: DataRole.USER.value,
66
+ Role.ASSISTANT: DataRole.ASSISTANT.value,
67
+ Role.SYSTEM: DataRole.SYSTEM.value,
68
+ Role.FUNCTION: DataRole.FUNCTION.value,
69
+ Role.TOOL: DataRole.OBSERVATION.value,
70
+ }
71
+
72
+
73
+ def _process_request(
74
+ request: "ChatCompletionRequest",
75
+ ) -> tuple[
76
+ list[dict[str, str]],
77
+ Optional[str],
78
+ Optional[str],
79
+ Optional[list["ImageInput"]],
80
+ Optional[list["VideoInput"]],
81
+ Optional[list["AudioInput"]],
82
+ ]:
83
+ if is_env_enabled("API_VERBOSE", "1"):
84
+ logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
85
+
86
+ if len(request.messages) == 0:
87
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
88
+
89
+ if request.messages[0].role == Role.SYSTEM:
90
+ content = request.messages.pop(0).content
91
+ system = content[0].text if isinstance(content, list) else content
92
+ else:
93
+ system = None
94
+
95
+ if len(request.messages) % 2 == 0:
96
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
97
+
98
+ input_messages = []
99
+ images, videos, audios = [], [], []
100
+ for i, message in enumerate(request.messages):
101
+ if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
102
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
103
+ elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
104
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
105
+
106
+ if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
107
+ tool_calls = [
108
+ {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
109
+ for tool_call in message.tool_calls
110
+ ]
111
+ content = json.dumps(tool_calls, ensure_ascii=False)
112
+ input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
113
+ elif isinstance(message.content, list):
114
+ text_content = ""
115
+ for input_item in message.content:
116
+ if input_item.type == "text":
117
+ text_content += input_item.text
118
+ elif input_item.type == "image_url":
119
+ text_content += IMAGE_PLACEHOLDER
120
+ image_url = input_item.image_url.url
121
+ if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
122
+ image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
123
+ elif os.path.isfile(image_url): # local file
124
+ check_lfi_path(image_url)
125
+ image_stream = open(image_url, "rb")
126
+ else: # web uri
127
+ check_ssrf_url(image_url)
128
+ image_stream = requests.get(image_url, stream=True).raw
129
+
130
+ images.append(Image.open(image_stream).convert("RGB"))
131
+ elif input_item.type == "video_url":
132
+ text_content += VIDEO_PLACEHOLDER
133
+ video_url = input_item.video_url.url
134
+ if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video
135
+ video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
136
+ elif os.path.isfile(video_url): # local file
137
+ check_lfi_path(video_url)
138
+ video_stream = video_url
139
+ else: # web uri
140
+ check_ssrf_url(video_url)
141
+ video_stream = requests.get(video_url, stream=True).raw
142
+
143
+ videos.append(video_stream)
144
+ elif input_item.type == "audio_url":
145
+ text_content += AUDIO_PLACEHOLDER
146
+ audio_url = input_item.audio_url.url
147
+ if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio
148
+ audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
149
+ elif os.path.isfile(audio_url): # local file
150
+ check_lfi_path(audio_url)
151
+ audio_stream = audio_url
152
+ else: # web uri
153
+ check_ssrf_url(audio_url)
154
+ audio_stream = requests.get(audio_url, stream=True).raw
155
+
156
+ audios.append(audio_stream)
157
+ else:
158
+ raise HTTPException(
159
+ status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
160
+ )
161
+
162
+ input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
163
+ else:
164
+ input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
165
+
166
+ tool_list = request.tools
167
+ if isinstance(tool_list, list) and len(tool_list):
168
+ try:
169
+ tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
170
+ except json.JSONDecodeError:
171
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
172
+ else:
173
+ tools = None
174
+
175
+ return input_messages, system, tools, images or None, videos or None, audios or None
176
+
177
+
178
+ def _create_stream_chat_completion_chunk(
179
+ completion_id: str,
180
+ model: str,
181
+ delta: "ChatCompletionMessage",
182
+ index: Optional[int] = 0,
183
+ finish_reason: Optional["Finish"] = None,
184
+ ) -> str:
185
+ choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
186
+ chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
187
+ return jsonify(chunk)
188
+
189
+
190
+ async def create_chat_completion_response(
191
+ request: "ChatCompletionRequest", chat_model: "ChatModel"
192
+ ) -> "ChatCompletionResponse":
193
+ completion_id = f"chatcmpl-{uuid.uuid4().hex}"
194
+ input_messages, system, tools, images, videos, audios = _process_request(request)
195
+ responses = await chat_model.achat(
196
+ input_messages,
197
+ system,
198
+ tools,
199
+ images,
200
+ videos,
201
+ audios,
202
+ do_sample=request.do_sample,
203
+ temperature=request.temperature,
204
+ top_p=request.top_p,
205
+ max_new_tokens=request.max_tokens,
206
+ num_return_sequences=request.n,
207
+ repetition_penalty=request.presence_penalty,
208
+ stop=request.stop,
209
+ )
210
+
211
+ prompt_length, response_length = 0, 0
212
+ choices = []
213
+ for i, response in enumerate(responses):
214
+ if tools:
215
+ result = chat_model.engine.template.extract_tool(response.response_text)
216
+ else:
217
+ result = response.response_text
218
+
219
+ if isinstance(result, list):
220
+ tool_calls = []
221
+ for tool in result:
222
+ function = Function(name=tool.name, arguments=tool.arguments)
223
+ tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
224
+
225
+ response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
226
+ finish_reason = Finish.TOOL
227
+ else:
228
+ response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
229
+ finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
230
+
231
+ choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
232
+ prompt_length = response.prompt_length
233
+ response_length += response.response_length
234
+
235
+ usage = ChatCompletionResponseUsage(
236
+ prompt_tokens=prompt_length,
237
+ completion_tokens=response_length,
238
+ total_tokens=prompt_length + response_length,
239
+ )
240
+
241
+ return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
242
+
243
+
244
+ async def create_stream_chat_completion_response(
245
+ request: "ChatCompletionRequest", chat_model: "ChatModel"
246
+ ) -> AsyncGenerator[str, None]:
247
+ completion_id = f"chatcmpl-{uuid.uuid4().hex}"
248
+ input_messages, system, tools, images, videos, audios = _process_request(request)
249
+ if tools:
250
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
251
+
252
+ if request.n > 1:
253
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
254
+
255
+ yield _create_stream_chat_completion_chunk(
256
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
257
+ )
258
+ async for new_token in chat_model.astream_chat(
259
+ input_messages,
260
+ system,
261
+ tools,
262
+ images,
263
+ videos,
264
+ audios,
265
+ do_sample=request.do_sample,
266
+ temperature=request.temperature,
267
+ top_p=request.top_p,
268
+ max_new_tokens=request.max_tokens,
269
+ repetition_penalty=request.presence_penalty,
270
+ stop=request.stop,
271
+ ):
272
+ if len(new_token) != 0:
273
+ yield _create_stream_chat_completion_chunk(
274
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
275
+ )
276
+
277
+ yield _create_stream_chat_completion_chunk(
278
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
279
+ )
280
+ yield "[DONE]"
281
+
282
+
283
+ async def create_score_evaluation_response(
284
+ request: "ScoreEvaluationRequest", chat_model: "ChatModel"
285
+ ) -> "ScoreEvaluationResponse":
286
+ score_id = f"scoreval-{uuid.uuid4().hex}"
287
+ if len(request.messages) == 0:
288
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
289
+
290
+ scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
291
+ return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
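Taken together, _process_request and the response builders above give the API an OpenAI-compatible chat completion flow. A minimal client sketch follows; the host, port, route and model name are illustrative assumptions (the actual route is registered in app.py), and the request body mirrors the ChatCompletionRequest schema from protocol.py.

import requests

# Illustrative endpoint; host, port and route are assumptions, not defined in this file.
payload = {
    "model": "test",
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
    "max_tokens": 128,
    "stream": False,  # set True to consume the SSE chunks emitted by the streaming helpers above
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=120)
print(resp.json()["choices"][0]["message"]["content"])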
llamafactory/api/common.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import ipaddress
16
+ import json
17
+ import os
18
+ import socket
19
+ from typing import TYPE_CHECKING, Any
20
+ from urllib.parse import urlparse
21
+
22
+ from ..extras.misc import is_env_enabled
23
+ from ..extras.packages import is_fastapi_available
24
+
25
+
26
+ if is_fastapi_available():
27
+ from fastapi import HTTPException, status
28
+
29
+
30
+ if TYPE_CHECKING:
31
+ from pydantic import BaseModel
32
+
33
+
34
+ SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media"))
35
+ ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1")
36
+
37
+
38
+ def dictify(data: "BaseModel") -> dict[str, Any]:
39
+ try: # pydantic v2
40
+ return data.model_dump(exclude_unset=True)
41
+ except AttributeError: # pydantic v1
42
+ return data.dict(exclude_unset=True)
43
+
44
+
45
+ def jsonify(data: "BaseModel") -> str:
46
+ try: # pydantic v2
47
+ return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
48
+ except AttributeError: # pydantic v1
49
+ return data.json(exclude_unset=True, ensure_ascii=False)
50
+
51
+
52
+ def check_lfi_path(path: str) -> None:
53
+ """Checks if a given path is vulnerable to LFI. Raises HTTPException if unsafe."""
54
+ if not ALLOW_LOCAL_FILES:
55
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.")
56
+
57
+ try:
58
+ os.makedirs(SAFE_MEDIA_PATH, exist_ok=True)
59
+ real_path = os.path.realpath(path)
60
+ safe_path = os.path.realpath(SAFE_MEDIA_PATH)
61
+
62
+ if not real_path.startswith(safe_path):
63
+ raise HTTPException(
64
+ status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory."
65
+ )
66
+ except HTTPException:  # re-raise the explicit rejection above instead of masking it as a 400
+ raise
+ except Exception:
67
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path.")
68
+
69
+
70
+ def check_ssrf_url(url: str) -> None:
71
+ """Checks if a given URL is vulnerable to SSRF. Raises HTTPException if unsafe."""
72
+ try:
73
+ parsed_url = urlparse(url)
74
+ if parsed_url.scheme not in ["http", "https"]:
75
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.")
76
+
77
+ hostname = parsed_url.hostname
78
+ if not hostname:
79
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.")
80
+
81
+ ip_info = socket.getaddrinfo(hostname, parsed_url.port)
82
+ ip_address_str = ip_info[0][4][0]
83
+ ip = ipaddress.ip_address(ip_address_str)
84
+
85
+ if not ip.is_global:
86
+ raise HTTPException(
87
+ status_code=status.HTTP_403_FORBIDDEN,
88
+ detail="Access to private or reserved IP addresses is not allowed.",
89
+ )
90
+
91
+ except HTTPException:  # re-raise explicit rejections (bad scheme, private address) with their original status
+ raise
+ except socket.gaierror:
92
+ raise HTTPException(
93
+ status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}"
94
+ )
95
+ except Exception as e:
96
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}")
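A short sketch of how these guards behave (illustrative inputs; assumes FastAPI is installed and the package is importable as llamafactory):

from fastapi import HTTPException

from llamafactory.api.common import check_lfi_path, check_ssrf_url

# A URL that resolves to a private or reserved address is rejected before any request is made.
try:
    check_ssrf_url("http://127.0.0.1:8080/internal")
except HTTPException as exc:
    print(exc.status_code, exc.detail)

# A local path outside SAFE_MEDIA_PATH is rejected as well.
try:
    check_lfi_path("/etc/passwd")
except HTTPException as exc:
    print(exc.status_code, exc.detail)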
llamafactory/api/protocol.py ADDED
@@ -0,0 +1,157 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import time
16
+ from enum import Enum, unique
17
+ from typing import Any, Optional, Union
18
+
19
+ from pydantic import BaseModel, Field
20
+ from typing_extensions import Literal
21
+
22
+
23
+ @unique
24
+ class Role(str, Enum):
25
+ USER = "user"
26
+ ASSISTANT = "assistant"
27
+ SYSTEM = "system"
28
+ FUNCTION = "function"
29
+ TOOL = "tool"
30
+
31
+
32
+ @unique
33
+ class Finish(str, Enum):
34
+ STOP = "stop"
35
+ LENGTH = "length"
36
+ TOOL = "tool_calls"
37
+
38
+
39
+ class ModelCard(BaseModel):
40
+ id: str
41
+ object: Literal["model"] = "model"
42
+ created: int = Field(default_factory=lambda: int(time.time()))
43
+ owned_by: Literal["owner"] = "owner"
44
+
45
+
46
+ class ModelList(BaseModel):
47
+ object: Literal["list"] = "list"
48
+ data: list[ModelCard] = []
49
+
50
+
51
+ class Function(BaseModel):
52
+ name: str
53
+ arguments: str
54
+
55
+
56
+ class FunctionDefinition(BaseModel):
57
+ name: str
58
+ description: str
59
+ parameters: dict[str, Any]
60
+
61
+
62
+ class FunctionAvailable(BaseModel):
63
+ type: Literal["function", "code_interpreter"] = "function"
64
+ function: Optional[FunctionDefinition] = None
65
+
66
+
67
+ class FunctionCall(BaseModel):
68
+ id: str
69
+ type: Literal["function"] = "function"
70
+ function: Function
71
+
72
+
73
+ class URL(BaseModel):
74
+ url: str
75
+ detail: Literal["auto", "low", "high"] = "auto"
76
+
77
+
78
+ class MultimodalInputItem(BaseModel):
79
+ type: Literal["text", "image_url", "video_url", "audio_url"]
80
+ text: Optional[str] = None
81
+ image_url: Optional[URL] = None
82
+ video_url: Optional[URL] = None
83
+ audio_url: Optional[URL] = None
84
+
85
+
86
+ class ChatMessage(BaseModel):
87
+ role: Role
88
+ content: Optional[Union[str, list[MultimodalInputItem]]] = None
89
+ tool_calls: Optional[list[FunctionCall]] = None
90
+
91
+
92
+ class ChatCompletionMessage(BaseModel):
93
+ role: Optional[Role] = None
94
+ content: Optional[str] = None
95
+ tool_calls: Optional[list[FunctionCall]] = None
96
+
97
+
98
+ class ChatCompletionRequest(BaseModel):
99
+ model: str
100
+ messages: list[ChatMessage]
101
+ tools: Optional[list[FunctionAvailable]] = None
102
+ do_sample: Optional[bool] = None
103
+ temperature: Optional[float] = None
104
+ top_p: Optional[float] = None
105
+ n: int = 1
106
+ presence_penalty: Optional[float] = None
107
+ max_tokens: Optional[int] = None
108
+ stop: Optional[Union[str, list[str]]] = None
109
+ stream: bool = False
110
+
111
+
112
+ class ChatCompletionResponseChoice(BaseModel):
113
+ index: int
114
+ message: ChatCompletionMessage
115
+ finish_reason: Finish
116
+
117
+
118
+ class ChatCompletionStreamResponseChoice(BaseModel):
119
+ index: int
120
+ delta: ChatCompletionMessage
121
+ finish_reason: Optional[Finish] = None
122
+
123
+
124
+ class ChatCompletionResponseUsage(BaseModel):
125
+ prompt_tokens: int
126
+ completion_tokens: int
127
+ total_tokens: int
128
+
129
+
130
+ class ChatCompletionResponse(BaseModel):
131
+ id: str
132
+ object: Literal["chat.completion"] = "chat.completion"
133
+ created: int = Field(default_factory=lambda: int(time.time()))
134
+ model: str
135
+ choices: list[ChatCompletionResponseChoice]
136
+ usage: ChatCompletionResponseUsage
137
+
138
+
139
+ class ChatCompletionStreamResponse(BaseModel):
140
+ id: str
141
+ object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
142
+ created: int = Field(default_factory=lambda: int(time.time()))
143
+ model: str
144
+ choices: list[ChatCompletionStreamResponseChoice]
145
+
146
+
147
+ class ScoreEvaluationRequest(BaseModel):
148
+ model: str
149
+ messages: list[str]
150
+ max_length: Optional[int] = None
151
+
152
+
153
+ class ScoreEvaluationResponse(BaseModel):
154
+ id: str
155
+ object: Literal["score.evaluation"] = "score.evaluation"
156
+ model: str
157
+ scores: list[float]
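These pydantic models define the wire format used by the API above. A small sketch of building a request with a multimodal message (the base64 payload is truncated for illustration):

from llamafactory.api.protocol import URL, ChatCompletionRequest, ChatMessage, MultimodalInputItem, Role

request = ChatCompletionRequest(
    model="test",
    messages=[
        ChatMessage(
            role=Role.USER,
            content=[
                MultimodalInputItem(type="text", text="Describe this image."),
                MultimodalInputItem(type="image_url", image_url=URL(url="data:image/png;base64,...")),
            ],
        )
    ],
    temperature=0.5,
)
print(request.model_dump(exclude_unset=True))  # pydantic v2; use request.dict(exclude_unset=True) on v1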
llamafactory/chat/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .base_engine import BaseEngine
16
+ from .chat_model import ChatModel
17
+
18
+
19
+ __all__ = ["BaseEngine", "ChatModel"]
llamafactory/chat/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (288 Bytes).
 
llamafactory/chat/__pycache__/base_engine.cpython-312.pyc ADDED
Binary file (3.55 kB).
 
llamafactory/chat/__pycache__/chat_model.cpython-312.pyc ADDED
Binary file (9.27 kB).
 
llamafactory/chat/base_engine.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from collections.abc import AsyncGenerator
17
+ from dataclasses import dataclass
18
+ from typing import TYPE_CHECKING, Any, Literal, Optional, Union
19
+
20
+
21
+ if TYPE_CHECKING:
22
+ from transformers import PreTrainedModel, PreTrainedTokenizer
23
+ from vllm import AsyncLLMEngine
24
+
25
+ from ..data import Template
26
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
27
+ from ..extras.constants import EngineName
28
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
29
+
30
+
31
+ @dataclass
32
+ class Response:
33
+ response_text: str
34
+ response_length: int
35
+ prompt_length: int
36
+ finish_reason: Literal["stop", "length"]
37
+
38
+
39
+ class BaseEngine(ABC):
40
+ r"""Base class for inference engine of chat models.
41
+
42
+ Must implement the async methods: chat(), stream_chat() and get_scores().
43
+ """
44
+
45
+ name: "EngineName"
46
+ model: Union["PreTrainedModel", "AsyncLLMEngine"]
47
+ tokenizer: "PreTrainedTokenizer"
48
+ can_generate: bool
49
+ template: "Template"
50
+ generating_args: dict[str, Any]
51
+
52
+ @abstractmethod
53
+ def __init__(
54
+ self,
55
+ model_args: "ModelArguments",
56
+ data_args: "DataArguments",
57
+ finetuning_args: "FinetuningArguments",
58
+ generating_args: "GeneratingArguments",
59
+ ) -> None:
60
+ r"""Initialize an inference engine."""
61
+ ...
62
+
63
+ @abstractmethod
64
+ async def chat(
65
+ self,
66
+ messages: list[dict[str, str]],
67
+ system: Optional[str] = None,
68
+ tools: Optional[str] = None,
69
+ images: Optional[list["ImageInput"]] = None,
70
+ videos: Optional[list["VideoInput"]] = None,
71
+ audios: Optional[list["AudioInput"]] = None,
72
+ **input_kwargs,
73
+ ) -> list["Response"]:
74
+ r"""Get a list of responses of the chat model."""
75
+ ...
76
+
77
+ @abstractmethod
78
+ async def stream_chat(
79
+ self,
80
+ messages: list[dict[str, str]],
81
+ system: Optional[str] = None,
82
+ tools: Optional[str] = None,
83
+ images: Optional[list["ImageInput"]] = None,
84
+ videos: Optional[list["VideoInput"]] = None,
85
+ audios: Optional[list["AudioInput"]] = None,
86
+ **input_kwargs,
87
+ ) -> AsyncGenerator[str, None]:
88
+ r"""Get the response token-by-token of the chat model."""
89
+ ...
90
+
91
+ @abstractmethod
92
+ async def get_scores(
93
+ self,
94
+ batch_input: list[str],
95
+ **input_kwargs,
96
+ ) -> list[float]:
97
+ r"""Get a list of scores of the reward model."""
98
+ ...
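Every backend below (hf_engine.py, vllm_engine.py, sglang_engine.py, kt_engine.py) implements this interface. A toy stub showing the required shape, not a real backend:

import asyncio
from collections.abc import AsyncGenerator

from llamafactory.chat.base_engine import BaseEngine, Response


class EchoEngine(BaseEngine):
    """Illustrative engine that echoes the last user message."""

    def __init__(self, model_args, data_args, finetuning_args, generating_args) -> None:
        self.can_generate = True
        self.generating_args = {}

    async def chat(self, messages, system=None, tools=None, images=None, videos=None, audios=None, **kwargs) -> list[Response]:
        text = messages[-1]["content"]
        return [Response(response_text=text, response_length=len(text), prompt_length=0, finish_reason="stop")]

    async def stream_chat(self, messages, system=None, tools=None, images=None, videos=None, audios=None, **kwargs) -> AsyncGenerator[str, None]:
        for char in messages[-1]["content"]:
            yield char

    async def get_scores(self, batch_input, **kwargs) -> list[float]:
        return [0.0 for _ in batch_input]


engine = EchoEngine(None, None, None, None)
print(asyncio.run(engine.chat([{"role": "user", "content": "hi"}]))[0].response_text)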
llamafactory/chat/chat_model.py ADDED
@@ -0,0 +1,210 @@
1
+ # Copyright 2025 THUDM and the LlamaFactory team.
2
+ #
3
+ # This code is inspired by the THUDM's ChatGLM implementation.
4
+ # https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import asyncio
19
+ import os
20
+ from collections.abc import AsyncGenerator, Generator
21
+ from threading import Thread
22
+ from typing import TYPE_CHECKING, Any, Optional
23
+
24
+ from ..extras.constants import EngineName
25
+ from ..extras.misc import torch_gc
26
+ from ..hparams import get_infer_args
27
+
28
+
29
+ if TYPE_CHECKING:
30
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
31
+ from .base_engine import BaseEngine, Response
32
+
33
+
34
+ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
35
+ asyncio.set_event_loop(loop)
36
+ loop.run_forever()
37
+
38
+
39
+ class ChatModel:
40
+ r"""General class for chat models. Backed by huggingface or vllm engines.
41
+
42
+ Supports both sync and async methods.
43
+ Sync methods: chat(), stream_chat() and get_scores().
44
+ Async methods: achat(), astream_chat() and aget_scores().
45
+ """
46
+
47
+ def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
48
+ model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
49
+
50
+ if model_args.infer_backend == EngineName.HF:
51
+ from .hf_engine import HuggingfaceEngine
52
+
53
+ self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
54
+ elif model_args.infer_backend == EngineName.VLLM:
55
+ try:
56
+ from .vllm_engine import VllmEngine
57
+
58
+ self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
59
+ except ImportError as e:
60
+ raise ImportError(
61
+ "vLLM not install, you may need to run `pip install vllm`\n"
62
+ "or try to use HuggingFace backend: --infer_backend huggingface"
63
+ ) from e
64
+ elif model_args.infer_backend == EngineName.SGLANG:
65
+ try:
66
+ from .sglang_engine import SGLangEngine
67
+
68
+ self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
69
+ except ImportError as e:
70
+ raise ImportError(
71
+ "SGLang not install, you may need to run `pip install sglang[all]`\n"
72
+ "or try to use HuggingFace backend: --infer_backend huggingface"
73
+ ) from e
74
+ elif model_args.infer_backend == EngineName.KT:
75
+ try:
76
+ from .kt_engine import KTransformersEngine
77
+
78
+ self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
79
+ except ImportError as e:
80
+ raise ImportError(
81
+ "KTransformers not install, you may need to run `pip install ktransformers`\n"
82
+ "or try to use HuggingFace backend: --infer_backend huggingface"
83
+ ) from e
84
+ else:
85
+ raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
86
+
87
+ self._loop = asyncio.new_event_loop()
88
+ self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
89
+ self._thread.start()
90
+
91
+ def chat(
92
+ self,
93
+ messages: list[dict[str, str]],
94
+ system: Optional[str] = None,
95
+ tools: Optional[str] = None,
96
+ images: Optional[list["ImageInput"]] = None,
97
+ videos: Optional[list["VideoInput"]] = None,
98
+ audios: Optional[list["AudioInput"]] = None,
99
+ **input_kwargs,
100
+ ) -> list["Response"]:
101
+ r"""Get a list of responses of the chat model."""
102
+ task = asyncio.run_coroutine_threadsafe(
103
+ self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
104
+ )
105
+ return task.result()
106
+
107
+ async def achat(
108
+ self,
109
+ messages: list[dict[str, str]],
110
+ system: Optional[str] = None,
111
+ tools: Optional[str] = None,
112
+ images: Optional[list["ImageInput"]] = None,
113
+ videos: Optional[list["VideoInput"]] = None,
114
+ audios: Optional[list["AudioInput"]] = None,
115
+ **input_kwargs,
116
+ ) -> list["Response"]:
117
+ r"""Asynchronously get a list of responses of the chat model."""
118
+ return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
119
+
120
+ def stream_chat(
121
+ self,
122
+ messages: list[dict[str, str]],
123
+ system: Optional[str] = None,
124
+ tools: Optional[str] = None,
125
+ images: Optional[list["ImageInput"]] = None,
126
+ videos: Optional[list["VideoInput"]] = None,
127
+ audios: Optional[list["AudioInput"]] = None,
128
+ **input_kwargs,
129
+ ) -> Generator[str, None, None]:
130
+ r"""Get the response token-by-token of the chat model."""
131
+ generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
132
+ while True:
133
+ try:
134
+ task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
135
+ yield task.result()
136
+ except StopAsyncIteration:
137
+ break
138
+
139
+ async def astream_chat(
140
+ self,
141
+ messages: list[dict[str, str]],
142
+ system: Optional[str] = None,
143
+ tools: Optional[str] = None,
144
+ images: Optional[list["ImageInput"]] = None,
145
+ videos: Optional[list["VideoInput"]] = None,
146
+ audios: Optional[list["AudioInput"]] = None,
147
+ **input_kwargs,
148
+ ) -> AsyncGenerator[str, None]:
149
+ r"""Asynchronously get the response token-by-token of the chat model."""
150
+ async for new_token in self.engine.stream_chat(
151
+ messages, system, tools, images, videos, audios, **input_kwargs
152
+ ):
153
+ yield new_token
154
+
155
+ def get_scores(
156
+ self,
157
+ batch_input: list[str],
158
+ **input_kwargs,
159
+ ) -> list[float]:
160
+ r"""Get a list of scores of the reward model."""
161
+ task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
162
+ return task.result()
163
+
164
+ async def aget_scores(
165
+ self,
166
+ batch_input: list[str],
167
+ **input_kwargs,
168
+ ) -> list[float]:
169
+ r"""Asynchronously get a list of scores of the reward model."""
170
+ return await self.engine.get_scores(batch_input, **input_kwargs)
171
+
172
+
173
+ def run_chat() -> None:
174
+ if os.name != "nt":
175
+ try:
176
+ import readline # noqa: F401
177
+ except ImportError:
178
+ print("Install `readline` for a better experience.")
179
+
180
+ chat_model = ChatModel()
181
+ messages = []
182
+ print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
183
+
184
+ while True:
185
+ try:
186
+ query = input("\nUser: ")
187
+ except UnicodeDecodeError:
188
+ print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
189
+ continue
190
+ except Exception:
191
+ raise
192
+
193
+ if query.strip() == "exit":
194
+ break
195
+
196
+ if query.strip() == "clear":
197
+ messages = []
198
+ torch_gc()
199
+ print("History has been removed.")
200
+ continue
201
+
202
+ messages.append({"role": "user", "content": query})
203
+ print("Assistant: ", end="", flush=True)
204
+
205
+ response = ""
206
+ for new_text in chat_model.stream_chat(messages):
207
+ print(new_text, end="", flush=True)
208
+ response += new_text
209
+ print()
210
+ messages.append({"role": "assistant", "content": response})
llamafactory/chat/hf_engine.py ADDED
@@ -0,0 +1,412 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import os
17
+ from collections.abc import AsyncGenerator
18
+ from threading import Thread
19
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
20
+
21
+ import torch
22
+ from transformers import GenerationConfig, TextIteratorStreamer
23
+ from typing_extensions import override
24
+
25
+ from ..data import get_template_and_fix_tokenizer
26
+ from ..extras import logging
27
+ from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
28
+ from ..model import load_model, load_tokenizer
29
+ from .base_engine import BaseEngine, Response
30
+
31
+
32
+ if TYPE_CHECKING:
33
+ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
34
+ from trl import PreTrainedModelWrapper
35
+
36
+ from ..data import Template
37
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
38
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ class HuggingfaceEngine(BaseEngine):
45
+ def __init__(
46
+ self,
47
+ model_args: "ModelArguments",
48
+ data_args: "DataArguments",
49
+ finetuning_args: "FinetuningArguments",
50
+ generating_args: "GeneratingArguments",
51
+ ) -> None:
52
+ self.name = EngineName.HF
53
+ self.can_generate = finetuning_args.stage == "sft"
54
+ tokenizer_module = load_tokenizer(model_args)
55
+ self.tokenizer = tokenizer_module["tokenizer"]
56
+ self.processor = tokenizer_module["processor"]
57
+ self.tokenizer.padding_side = "left" if self.can_generate else "right"
58
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
59
+ self.model = load_model(
60
+ self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
61
+ ) # must be loaded after fixing the tokenizer to resize the vocab
62
+ self.generating_args = generating_args.to_dict()
63
+ try:
64
+ asyncio.get_event_loop()
65
+ except RuntimeError:
66
+ logger.warning_rank0_once("There is no current event loop, creating a new one.")
67
+ loop = asyncio.new_event_loop()
68
+ asyncio.set_event_loop(loop)
69
+
70
+ self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
71
+
72
+ @staticmethod
73
+ def _process_args(
74
+ model: "PreTrainedModel",
75
+ tokenizer: "PreTrainedTokenizer",
76
+ processor: Optional["ProcessorMixin"],
77
+ template: "Template",
78
+ generating_args: dict[str, Any],
79
+ messages: list[dict[str, str]],
80
+ system: Optional[str] = None,
81
+ tools: Optional[str] = None,
82
+ images: Optional[list["ImageInput"]] = None,
83
+ videos: Optional[list["VideoInput"]] = None,
84
+ audios: Optional[list["AudioInput"]] = None,
85
+ input_kwargs: Optional[dict[str, Any]] = {},
86
+ ) -> tuple[dict[str, Any], int]:
87
+ mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
88
+ if images is not None:
89
+ mm_input_dict.update({"images": images, "imglens": [len(images)]})
90
+ if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
91
+ messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
92
+
93
+ if videos is not None:
94
+ mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
95
+ if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
96
+ messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
97
+
98
+ if audios is not None:
99
+ mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
100
+ if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
101
+ messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
102
+
103
+ messages = template.mm_plugin.process_messages(
104
+ messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
105
+ )
106
+ paired_messages = messages + [{"role": "assistant", "content": ""}]
107
+ prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
108
+ prompt_ids, _ = template.mm_plugin.process_token_ids(
109
+ prompt_ids,
110
+ None,
111
+ mm_input_dict["images"],
112
+ mm_input_dict["videos"],
113
+ mm_input_dict["audios"],
114
+ tokenizer,
115
+ processor,
116
+ )
117
+ prompt_length = len(prompt_ids)
118
+ inputs = torch.tensor([prompt_ids], device=model.device)
119
+ attention_mask = torch.ones_like(inputs, dtype=torch.long)
120
+
121
+ do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
122
+ temperature: Optional[float] = input_kwargs.pop("temperature", None)
123
+ top_p: Optional[float] = input_kwargs.pop("top_p", None)
124
+ top_k: Optional[float] = input_kwargs.pop("top_k", None)
125
+ num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
126
+ repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
127
+ length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
128
+ skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
129
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
130
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
131
+ stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
132
+
133
+ if stop is not None:
134
+ logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
135
+
136
+ generating_args = generating_args.copy()
137
+ generating_args.update(
138
+ dict(
139
+ do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
140
+ temperature=temperature if temperature is not None else generating_args["temperature"],
141
+ top_p=top_p if top_p is not None else generating_args["top_p"],
142
+ top_k=top_k if top_k is not None else generating_args["top_k"],
143
+ num_return_sequences=num_return_sequences,
144
+ repetition_penalty=repetition_penalty
145
+ if repetition_penalty is not None
146
+ else generating_args["repetition_penalty"],
147
+ length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
148
+ skip_special_tokens=skip_special_tokens
149
+ if skip_special_tokens is not None
150
+ else generating_args["skip_special_tokens"],
151
+ eos_token_id=template.get_stop_token_ids(tokenizer),
152
+ pad_token_id=tokenizer.pad_token_id,
153
+ )
154
+ )
155
+
156
+ if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
157
+ generating_args["do_sample"] = True
158
+ generating_args["temperature"] = generating_args["temperature"] or 1.0
159
+
160
+ if not generating_args["temperature"]:
161
+ generating_args["do_sample"] = False
162
+
163
+ if not generating_args["do_sample"]:
164
+ generating_args.pop("temperature", None)
165
+ generating_args.pop("top_p", None)
166
+
167
+ if max_length:
168
+ generating_args.pop("max_new_tokens", None)
169
+ generating_args["max_length"] = max_length
170
+
171
+ if max_new_tokens:
172
+ generating_args.pop("max_length", None)
173
+ generating_args["max_new_tokens"] = max_new_tokens
174
+
175
+ gen_kwargs = dict(
176
+ inputs=inputs,
177
+ attention_mask=attention_mask,
178
+ generation_config=GenerationConfig(**generating_args),
179
+ )
180
+
181
+ mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
182
+ for key, value in mm_inputs.items():
183
+ if isinstance(value, list) and isinstance(value[0], torch.Tensor): # for pixtral inputs
184
+ value = torch.stack(value) # assume they have same sizes
185
+ elif (
186
+ isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
187
+ ): # for minicpmv inputs
188
+ value = torch.stack([torch.stack(v) for v in value])
189
+ elif not isinstance(value, torch.Tensor):
190
+ value = torch.tensor(value)
191
+
192
+ if torch.is_floating_point(value): # cast data dtype for paligemma
193
+ value = value.to(model.dtype)
194
+
195
+ if key == "second_per_grid_ts": # qwen2.5vl special case
196
+ gen_kwargs[key] = value.tolist()
197
+ else:
198
+ gen_kwargs[key] = value.to(model.device)
199
+
200
+ if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
201
+ gen_kwargs["input_ids"] = inputs
202
+ gen_kwargs["tokenizer"] = tokenizer
203
+ if "audio_feature_lens" in mm_inputs:
204
+ gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
205
+
206
+ gen_kwargs.pop("image_sizes", None)
207
+
208
+ return gen_kwargs, prompt_length
209
+
210
+ @staticmethod
211
+ @torch.inference_mode()
212
+ def _chat(
213
+ model: "PreTrainedModel",
214
+ tokenizer: "PreTrainedTokenizer",
215
+ processor: Optional["ProcessorMixin"],
216
+ template: "Template",
217
+ generating_args: dict[str, Any],
218
+ messages: list[dict[str, str]],
219
+ system: Optional[str] = None,
220
+ tools: Optional[str] = None,
221
+ images: Optional[list["ImageInput"]] = None,
222
+ videos: Optional[list["VideoInput"]] = None,
223
+ audios: Optional[list["AudioInput"]] = None,
224
+ input_kwargs: Optional[dict[str, Any]] = {},
225
+ ) -> list["Response"]:
226
+ gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
227
+ model,
228
+ tokenizer,
229
+ processor,
230
+ template,
231
+ generating_args,
232
+ messages,
233
+ system,
234
+ tools,
235
+ images,
236
+ videos,
237
+ audios,
238
+ input_kwargs,
239
+ )
240
+ generate_output = model.generate(**gen_kwargs)
241
+ if isinstance(generate_output, tuple):
242
+ generate_output = generate_output[1][0] # post-process the minicpm_o output
243
+
244
+ response_ids = generate_output[:, prompt_length:]
245
+ response = tokenizer.batch_decode(
246
+ response_ids,
247
+ skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
248
+ clean_up_tokenization_spaces=True,
249
+ )
250
+ results = []
251
+ for i in range(len(response)):
252
+ eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
253
+ response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
254
+ results.append(
255
+ Response(
256
+ response_text=response[i],
257
+ response_length=response_length,
258
+ prompt_length=prompt_length,
259
+ finish_reason="stop" if len(eos_index) else "length",
260
+ )
261
+ )
262
+
263
+ return results
264
+
265
+ @staticmethod
266
+ @torch.inference_mode()
267
+ def _stream_chat(
268
+ model: "PreTrainedModel",
269
+ tokenizer: "PreTrainedTokenizer",
270
+ processor: Optional["ProcessorMixin"],
271
+ template: "Template",
272
+ generating_args: dict[str, Any],
273
+ messages: list[dict[str, str]],
274
+ system: Optional[str] = None,
275
+ tools: Optional[str] = None,
276
+ images: Optional[list["ImageInput"]] = None,
277
+ videos: Optional[list["VideoInput"]] = None,
278
+ audios: Optional[list["AudioInput"]] = None,
279
+ input_kwargs: Optional[dict[str, Any]] = {},
280
+ ) -> Callable[[], str]:
281
+ gen_kwargs, _ = HuggingfaceEngine._process_args(
282
+ model,
283
+ tokenizer,
284
+ processor,
285
+ template,
286
+ generating_args,
287
+ messages,
288
+ system,
289
+ tools,
290
+ images,
291
+ videos,
292
+ audios,
293
+ input_kwargs,
294
+ )
295
+ streamer = TextIteratorStreamer(
296
+ tokenizer,
297
+ skip_prompt=True,
298
+ skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
299
+ )
300
+ gen_kwargs["streamer"] = streamer
301
+ thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
302
+ thread.start()
303
+
304
+ def stream():
305
+ try:
306
+ return streamer.__next__()
307
+ except StopIteration:
308
+ raise StopAsyncIteration()
309
+
310
+ return stream
311
+
312
+ @staticmethod
313
+ @torch.inference_mode()
314
+ def _get_scores(
315
+ model: "PreTrainedModelWrapper",
316
+ tokenizer: "PreTrainedTokenizer",
317
+ batch_input: list[str],
318
+ input_kwargs: Optional[dict[str, Any]] = {},
319
+ ) -> list[float]:
320
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
321
+ device = getattr(model.pretrained_model, "device", "cuda")
322
+ inputs: dict[str, torch.Tensor] = tokenizer(
323
+ batch_input,
324
+ padding=True,
325
+ truncation=True,
326
+ max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
327
+ return_tensors="pt",
328
+ add_special_tokens=False,
329
+ ).to(device)
330
+ values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
331
+ scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
332
+ return scores
333
+
334
+ @override
335
+ async def chat(
336
+ self,
337
+ messages: list[dict[str, str]],
338
+ system: Optional[str] = None,
339
+ tools: Optional[str] = None,
340
+ images: Optional[list["ImageInput"]] = None,
341
+ videos: Optional[list["VideoInput"]] = None,
342
+ audios: Optional[list["AudioInput"]] = None,
343
+ **input_kwargs,
344
+ ) -> list["Response"]:
345
+ if not self.can_generate:
346
+ raise ValueError("The current model does not support `chat`.")
347
+
348
+ input_args = (
349
+ self.model,
350
+ self.tokenizer,
351
+ self.processor,
352
+ self.template,
353
+ self.generating_args,
354
+ messages,
355
+ system,
356
+ tools,
357
+ images,
358
+ videos,
359
+ audios,
360
+ input_kwargs,
361
+ )
362
+ async with self.semaphore:
363
+ return await asyncio.to_thread(self._chat, *input_args)
364
+
365
+ @override
366
+ async def stream_chat(
367
+ self,
368
+ messages: list[dict[str, str]],
369
+ system: Optional[str] = None,
370
+ tools: Optional[str] = None,
371
+ images: Optional[list["ImageInput"]] = None,
372
+ videos: Optional[list["VideoInput"]] = None,
373
+ audios: Optional[list["AudioInput"]] = None,
374
+ **input_kwargs,
375
+ ) -> AsyncGenerator[str, None]:
376
+ if not self.can_generate:
377
+ raise ValueError("The current model does not support `stream_chat`.")
378
+
379
+ input_args = (
380
+ self.model,
381
+ self.tokenizer,
382
+ self.processor,
383
+ self.template,
384
+ self.generating_args,
385
+ messages,
386
+ system,
387
+ tools,
388
+ images,
389
+ videos,
390
+ audios,
391
+ input_kwargs,
392
+ )
393
+ async with self.semaphore:
394
+ stream = self._stream_chat(*input_args)
395
+ while True:
396
+ try:
397
+ yield await asyncio.to_thread(stream)
398
+ except StopAsyncIteration:
399
+ break
400
+
401
+ @override
402
+ async def get_scores(
403
+ self,
404
+ batch_input: list[str],
405
+ **input_kwargs,
406
+ ) -> list[float]:
407
+ if self.can_generate:
408
+ raise ValueError("Cannot get scores using an auto-regressive model.")
409
+
410
+ input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
411
+ async with self.semaphore:
412
+ return await asyncio.to_thread(self._get_scores, *input_args)
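The engine keeps model.generate synchronous and exposes it asynchronously by off-loading the call with asyncio.to_thread behind a semaphore sized by MAX_CONCURRENT. The same pattern in isolation, with a stand-in for the model call:

import asyncio
import os

semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))


def blocking_generate(prompt: str) -> str:
    return prompt.upper()  # stand-in for model.generate(**gen_kwargs)


async def achat(prompt: str) -> str:
    async with semaphore:  # limit concurrent generations
        return await asyncio.to_thread(blocking_generate, prompt)  # keep the event loop responsive


async def main() -> None:
    print(await asyncio.gather(*(achat(p) for p in ["a", "b", "c"])))


asyncio.run(main())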
llamafactory/chat/kt_engine.py ADDED
@@ -0,0 +1,284 @@
1
+ # Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import os
17
+ import platform
18
+ from collections.abc import AsyncGenerator
19
+ from threading import Thread
20
+ from typing import TYPE_CHECKING, Any, Optional
21
+
22
+ import torch
23
+ from typing_extensions import override
24
+
25
+ from ..data import get_template_and_fix_tokenizer
26
+ from ..extras import logging
27
+ from ..extras.constants import EngineName
28
+ from ..model import load_model, load_tokenizer
29
+ from .base_engine import BaseEngine, Response
30
+
31
+
32
+ if TYPE_CHECKING:
33
+ from transformers import PreTrainedTokenizer
34
+ from trl import PreTrainedModelWrapper
35
+
36
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
37
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
38
+
39
+ from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
40
+ from ktransformers.server.config.config import Config
41
+ from ktransformers.util.utils import (
42
+ get_compute_capability,
43
+ prefill_and_generate_capture,
44
+ )
45
+ from ktransformers.util.vendors import GPUVendor, device_manager
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class KTransformersEngine(BaseEngine):
52
+ def __init__(
53
+ self,
54
+ model_args: "ModelArguments",
55
+ data_args: "DataArguments",
56
+ finetuning_args: "FinetuningArguments",
57
+ generating_args: "GeneratingArguments",
58
+ ) -> None:
59
+ self.name = EngineName.KT
60
+ self.can_generate = finetuning_args.stage == "sft"
61
+
62
+ tok_mod = load_tokenizer(model_args)
63
+ self.tokenizer = tok_mod["tokenizer"]
64
+ self.tokenizer.padding_side = "left" if self.can_generate else "right"
65
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
66
+
67
+ self.model = load_model(
68
+ self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
69
+ )
70
+
71
+ self.generating_args = generating_args.to_dict()
72
+ self.max_new_tokens = model_args.kt_maxlen
73
+ self.use_cuda_graph = model_args.kt_use_cuda_graph
74
+ self.mode = model_args.kt_mode
75
+ self.force_think = model_args.kt_force_think
76
+ self.chunk_size = model_args.chunk_size
77
+
78
+ try:
79
+ asyncio.get_event_loop()
80
+ except RuntimeError:
81
+ loop = asyncio.new_event_loop()
82
+ asyncio.set_event_loop(loop)
83
+
84
+ self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
85
+
86
+ @staticmethod
87
+ @torch.inference_mode()
88
+ def _get_scores(
89
+ model: "PreTrainedModelWrapper",
90
+ tokenizer: "PreTrainedTokenizer",
91
+ batch_input: list[str],
92
+ input_kwargs: Optional[dict[str, Any]] = {},
93
+ ) -> list[float]:
94
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
95
+ device = getattr(model.pretrained_model, "device", "cuda")
96
+ inputs = tokenizer(
97
+ batch_input,
98
+ padding=True,
99
+ truncation=True,
100
+ max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
101
+ return_tensors="pt",
102
+ add_special_tokens=False,
103
+ ).to(device)
104
+ values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
105
+ scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
106
+ return scores
107
+
108
+ async def _generate(
109
+ self,
110
+ messages: list[dict[str, str]],
111
+ system: Optional[str] = None,
112
+ tools: Optional[str] = None,
113
+ **input_kwargs,
114
+ ) -> AsyncGenerator[str, None]:
115
+ paired = messages + [{"role": "assistant", "content": ""}]
116
+ prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools)
117
+ prompt_len = len(prompt_ids)
118
+
119
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
120
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
121
+
122
+ if "max_new_tokens" in self.generating_args:
123
+ max_tokens = int(self.generating_args["max_new_tokens"])
124
+ elif "max_length" in self.generating_args:
125
+ gl = int(self.generating_args["max_length"])
126
+ max_tokens = gl - prompt_len if gl > prompt_len else 1
127
+ else:
128
+ max_tokens = self.max_new_tokens or 256
129
+
130
+ if max_length is not None:
131
+ max_tokens = max(max_length - prompt_len, 1)
132
+ if max_new_tokens is not None:
133
+ max_tokens = int(max_new_tokens)
134
+ max_tokens = max(1, int(max_tokens))
135
+
136
+ if self.mode == "long_context":
137
+ max_len_cfg = Config().long_context_config["max_seq_len"]
138
+ need = prompt_len + max_tokens
139
+ assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml"
140
+
141
+ device = next(self.model.parameters()).device
142
+ input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
143
+ if self.force_think:
144
+ think = torch.tensor(
145
+ [self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
146
+ )
147
+ input_tensor = torch.cat([input_tensor, think], dim=1)
148
+
149
+ use_flashinfer = (
150
+ platform.system() != "Windows"
151
+ and getattr(self.model.config, "architectures", [""])[0]
152
+ in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
153
+ and flashinfer_enabled
154
+ and get_compute_capability() >= 8
155
+ and device_manager.gpu_vendor == GPUVendor.NVIDIA
156
+ )
157
+
158
+ def make_gen():
159
+ if use_flashinfer:
160
+ return prefill_and_generate_capture(
161
+ self.model,
162
+ self.tokenizer,
163
+ input_tensor,
164
+ max_tokens,
165
+ self.use_cuda_graph,
166
+ mode=self.mode,
167
+ force_think=self.force_think,
168
+ chunk_size=self.chunk_size,
169
+ use_flashinfer_mla=True,
170
+ num_heads=self.model.config.num_attention_heads,
171
+ head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
172
+ head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
173
+ q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
174
+ + getattr(self.model.config, "qk_nope_head_dim", 0),
175
+ echo_stream=False,
176
+ )
177
+ else:
178
+ return prefill_and_generate_capture(
179
+ self.model,
180
+ self.tokenizer,
181
+ input_tensor,
182
+ max_tokens,
183
+ self.use_cuda_graph,
184
+ mode=self.mode,
185
+ force_think=self.force_think,
186
+ chunk_size=self.chunk_size,
187
+ echo_stream=False,
188
+ )
189
+
190
+ loop = asyncio.get_running_loop()
191
+ q: asyncio.Queue[Optional[str]] = asyncio.Queue()
192
+
193
+ def producer():
194
+ try:
195
+ gen = make_gen()
196
+ if hasattr(gen, "__aiter__"):
197
+
198
+ async def drain_async():
199
+ async for t in gen:
200
+ loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
201
+
202
+ asyncio.run(drain_async())
203
+ elif hasattr(gen, "__iter__"):
204
+ for t in gen:
205
+ loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
206
+ else:
207
+ loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen))
208
+ finally:
209
+ loop.call_soon_threadsafe(q.put_nowait, None)
210
+
211
+ Thread(target=producer, daemon=True).start()
212
+
213
+ while True:
214
+ item = await q.get()
215
+ if item is None:
216
+ break
217
+ yield item
218
+
219
+ @override
220
+ async def chat(
221
+ self,
222
+ messages: list[dict[str, str]],
223
+ system: Optional[str] = None,
224
+ tools: Optional[str] = None,
225
+ images: Optional[list["ImageInput"]] = None,
226
+ videos: Optional[list["VideoInput"]] = None,
227
+ audios: Optional[list["AudioInput"]] = None,
228
+ **input_kwargs,
229
+ ) -> list["Response"]:
230
+ if not self.can_generate:
231
+ raise ValueError("The current model does not support `chat`.")
232
+ async with self.semaphore:
233
+ produced = ""
234
+ final_text = ""
235
+ async for t in self._generate(messages, system, tools, **input_kwargs):
236
+ delta = t
237
+ produced = produced + delta
238
+ if delta:
239
+ final_text += delta
240
+
241
+ prompt_ids, _ = self.template.encode_oneturn(
242
+ self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools
243
+ )
244
+ return [
245
+ Response(
246
+ response_text=final_text,
247
+ response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)),
248
+ prompt_length=len(prompt_ids),
249
+ finish_reason="stop",
250
+ )
251
+ ]
252
+
253
+ @override
254
+ async def stream_chat(
255
+ self,
256
+ messages: list[dict[str, str]],
257
+ system: Optional[str] = None,
258
+ tools: Optional[str] = None,
259
+ images: Optional[list["ImageInput"]] = None,
260
+ videos: Optional[list["VideoInput"]] = None,
261
+ audios: Optional[list["AudioInput"]] = None,
262
+ **input_kwargs,
263
+ ) -> AsyncGenerator[str, None]:
264
+ if not self.can_generate:
265
+ raise ValueError("The current model does not support `stream_chat`.")
266
+ async with self.semaphore:
267
+ produced = ""
268
+ async for t in self._generate(messages, system, tools, **input_kwargs):
269
+ delta = t[len(produced) :] if t.startswith(produced) else t
270
+ produced = t
271
+ if delta:
272
+ yield delta
273
+
274
+ @override
275
+ async def get_scores(
276
+ self,
277
+ batch_input: list[str],
278
+ **input_kwargs,
279
+ ) -> list[float]:
280
+ if self.can_generate:
281
+ raise ValueError("Cannot get scores using an auto-regressive model.")
282
+ args = (self.model, self.tokenizer, batch_input, input_kwargs)
283
+ async with self.semaphore:
284
+ return await asyncio.to_thread(self._get_scores, *args)
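The core of _generate above is a thread-to-asyncio bridge: a producer thread pushes tokens onto an asyncio.Queue via call_soon_threadsafe and signals completion with a None sentinel. The same pattern in isolation, with a toy iterator in place of prefill_and_generate_capture:

import asyncio
from threading import Thread
from typing import Optional


async def stream_from_thread(blocking_iter):
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue[Optional[str]] = asyncio.Queue()

    def producer() -> None:
        try:
            for token in blocking_iter:
                loop.call_soon_threadsafe(queue.put_nowait, token)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, None)  # sentinel: generation finished

    Thread(target=producer, daemon=True).start()
    while True:
        item = await queue.get()
        if item is None:
            break
        yield item


async def main() -> None:
    async for token in stream_from_thread(iter(["Hello", ", ", "world"])):
        print(token, end="")


asyncio.run(main())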
llamafactory/chat/sglang_engine.py ADDED
@@ -0,0 +1,289 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import atexit
17
+ import json
18
+ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
19
+ from typing import TYPE_CHECKING, Any, Optional, Union
20
+
21
+ import requests
22
+ from typing_extensions import override
23
+
24
+ from ..data import get_template_and_fix_tokenizer
25
+ from ..extras import logging
26
+ from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
27
+ from ..extras.misc import get_device_count, torch_gc
28
+ from ..extras.packages import is_sglang_available
29
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
30
+ from ..model import load_config, load_tokenizer
31
+ from ..model.model_utils.quantization import QuantizationMethod
32
+ from .base_engine import BaseEngine, Response
33
+
34
+
35
+ if is_sglang_available():
36
+ from sglang.utils import launch_server_cmd, terminate_process, wait_for_server # type: ignore
37
+
38
+
39
+ if TYPE_CHECKING:
40
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ class SGLangEngine(BaseEngine):
47
+ """Inference engine for SGLang models.
48
+
49
+ This class wraps the SGLang engine to provide a consistent interface for text generation
50
+ that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
51
+ better interaction and performance. The engine launches a server process and communicates
52
+ with it via HTTP requests.
53
+
54
+ For more details on the SGLang HTTP server approach, see:
55
+ https://docs.sglang.ai/backend/send_request.html
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ model_args: "ModelArguments",
61
+ data_args: "DataArguments",
62
+ finetuning_args: "FinetuningArguments",
63
+ generating_args: "GeneratingArguments",
64
+ ) -> None:
65
+ self.name = EngineName.SGLANG
66
+ self.model_args = model_args
67
+ config = load_config(model_args) # may download model from ms hub
68
+ if getattr(config, "quantization_config", None): # gptq models should use float16
69
+ quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
70
+ quant_method = quantization_config.get("quant_method", "")
71
+ if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
72
+ model_args.infer_dtype = "float16"
73
+
74
+ self.can_generate = finetuning_args.stage == "sft"
75
+ tokenizer_module = load_tokenizer(model_args)
76
+ self.tokenizer = tokenizer_module["tokenizer"]
77
+ self.processor = tokenizer_module["processor"]
78
+ self.tokenizer.padding_side = "left"
79
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
80
+ self.template.mm_plugin.expand_mm_tokens = False # for sglang generate
81
+ self.generating_args = generating_args.to_dict()
82
+ if model_args.adapter_name_or_path is not None:
83
+ self.lora_request = True
84
+ else:
85
+ self.lora_request = False
86
+
87
+ launch_cmd = [
88
+ "python3 -m sglang.launch_server",
89
+ f"--model-path {model_args.model_name_or_path}",
90
+ f"--dtype {model_args.infer_dtype}",
91
+ f"--context-length {model_args.sglang_maxlen}",
92
+ f"--mem-fraction-static {model_args.sglang_mem_fraction}",
93
+ f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
94
+ f"--download-dir {model_args.cache_dir}",
95
+ "--log-level error",
96
+ ]
97
+ if self.lora_request:
98
+ launch_cmd.extend(
99
+ [
100
+ "--max-loras-per-batch 1",
101
+ f"--lora-backend {model_args.sglang_lora_backend}",
102
+ f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
103
+ "--disable-radix-cache",
104
+ ]
105
+ )
106
+ launch_cmd = " ".join(launch_cmd)
107
+ logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
108
+ try:
109
+ torch_gc()
110
+ self.server_process, port = launch_server_cmd(launch_cmd)
111
+ self.base_url = f"http://localhost:{port}"
112
+ atexit.register(self._cleanup_server)
113
+
114
+ logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
115
+ wait_for_server(self.base_url, timeout=300)
116
+ logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")
117
+ try:
118
+ response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
119
+ if response.status_code == 200:
120
+ model_info = response.json()
121
+ logger.info(f"SGLang server model info: {model_info}")
122
+ except Exception as e:
123
+ logger.debug(f"Note: could not get model info: {str(e)}")
124
+
125
+ except Exception as e:
126
+ logger.error(f"Failed to start SGLang server: {str(e)}")
127
+ self._cleanup_server() # make sure to clean up any started process
128
+ raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")
129
+
130
+ def _cleanup_server(self):
131
+ r"""Clean up the server process when the engine is destroyed."""
132
+ if hasattr(self, "server_process") and self.server_process:
133
+ try:
134
+ logger.info("Terminating SGLang server process")
135
+ terminate_process(self.server_process)
136
+ logger.info("SGLang server process terminated")
137
+ except Exception as e:
138
+ logger.warning(f"Error terminating SGLang server: {str(e)}")
139
+
140
+ async def _generate(
141
+ self,
142
+ messages: list[dict[str, str]],
143
+ system: Optional[str] = None,
144
+ tools: Optional[str] = None,
145
+ images: Optional[list["ImageInput"]] = None,
146
+ videos: Optional[list["VideoInput"]] = None,
147
+ audios: Optional[list["AudioInput"]] = None,
148
+ **input_kwargs,
149
+ ) -> AsyncIterator[dict[str, Any]]:
150
+ if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
151
+ messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
152
+
153
+ if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
154
+ messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
155
+
156
+ if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
157
+ messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
158
+
159
+ messages = self.template.mm_plugin.process_messages(
160
+ messages, images or [], videos or [], audios or [], self.processor
161
+ )
162
+ paired_messages = messages + [{"role": "assistant", "content": ""}]
163
+ prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
164
+ prompt_length = len(prompt_ids)
165
+
166
+ temperature: Optional[float] = input_kwargs.pop("temperature", None)
167
+ top_p: Optional[float] = input_kwargs.pop("top_p", None)
168
+ top_k: Optional[float] = input_kwargs.pop("top_k", None)
169
+ num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
170
+ repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
171
+ skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
172
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
173
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
174
+ stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
175
+
176
+ if num_return_sequences != 1:
177
+ raise NotImplementedError("SGLang only supports n=1.")
178
+
179
+ if "max_new_tokens" in self.generating_args:
180
+ max_tokens = self.generating_args["max_new_tokens"]
181
+ elif "max_length" in self.generating_args:
182
+ if self.generating_args["max_length"] > prompt_length:
183
+ max_tokens = self.generating_args["max_length"] - prompt_length
184
+ else:
185
+ max_tokens = 1
186
+
187
+ if max_length:
188
+ max_tokens = max_length - prompt_length if max_length > prompt_length else 1
189
+
190
+ if max_new_tokens:
191
+ max_tokens = max_new_tokens
192
+
193
+ sampling_params = {
194
+ "temperature": temperature if temperature is not None else self.generating_args["temperature"],
195
+ "top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
196
+ "top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0
197
+ "stop": stop,
198
+ "stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
199
+ "max_new_tokens": max_tokens,
200
+ "repetition_penalty": (
201
+ repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
202
+ )
203
+ or 1.0, # repetition_penalty must be > 0
204
+ "skip_special_tokens": skip_special_tokens
205
+ if skip_special_tokens is not None
206
+ else self.generating_args["skip_special_tokens"],
207
+ }
208
+
209
+ def stream_request():
210
+ json_data = {
211
+ "input_ids": prompt_ids,
212
+ "sampling_params": sampling_params,
213
+ "stream": True,
214
+ }
215
+ if self.lora_request:
216
+ json_data["lora_request"] = ["lora0"]
217
+ response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
218
+ if response.status_code != 200:
219
+ raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
220
+
221
+ for chunk in response.iter_lines(decode_unicode=False):
222
+ chunk = str(chunk.decode("utf-8"))
223
+ if chunk == "data: [DONE]":
224
+ break
225
+
226
+ if chunk and chunk.startswith("data:"):
227
+ yield json.loads(chunk[5:].strip("\n"))
228
+
229
+ return await asyncio.to_thread(stream_request)
230
+
231
+ @override
232
+ async def chat(
233
+ self,
234
+ messages: Sequence[dict[str, str]],
235
+ system: Optional[str] = None,
236
+ tools: Optional[str] = None,
237
+ images: Optional[Sequence["ImageInput"]] = None,
238
+ videos: Optional[Sequence["VideoInput"]] = None,
239
+ audios: Optional[Sequence["AudioInput"]] = None,
240
+ **input_kwargs,
241
+ ) -> list["Response"]:
242
+ final_output = None
243
+ generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
244
+ for request_output in generator:
245
+ final_output = request_output
246
+
247
+ results = [
248
+ Response(
249
+ response_text=final_output["text"],
250
+ response_length=final_output["meta_info"]["completion_tokens"],
251
+ prompt_length=final_output["meta_info"]["prompt_tokens"],
252
+ finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
253
+ )
254
+ ]
255
+ return results
256
+
257
+ @override
258
+ async def stream_chat(
259
+ self,
260
+ messages: list[dict[str, str]],
261
+ system: Optional[str] = None,
262
+ tools: Optional[str] = None,
263
+ images: Optional[list["ImageInput"]] = None,
264
+ videos: Optional[list["VideoInput"]] = None,
265
+ audios: Optional[list["AudioInput"]] = None,
266
+ **input_kwargs,
267
+ ) -> AsyncGenerator[str, None]:
268
+ generated_text = ""
269
+ generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
270
+ for result in generator:
271
+ delta_text = result["text"][len(generated_text) :]
272
+ generated_text = result["text"]
273
+ yield delta_text
274
+
275
+ @override
276
+ async def get_scores(
277
+ self,
278
+ batch_input: list[str],
279
+ **input_kwargs,
280
+ ) -> list[float]:
281
+ raise NotImplementedError("SGLang engine does not support `get_scores`.")
282
+
283
+ def __del__(self):
284
+ r"""Ensure server is cleaned up when object is deleted."""
285
+ self._cleanup_server()
286
+ try:
287
+ atexit.unregister(self._cleanup_server)
288
+ except Exception:
289
+ pass
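
The `stream_request` helper above reads the server's SSE-style `/generate` stream line by line and stops at the `[DONE]` sentinel. A minimal standalone sketch of that parsing step, using a hypothetical list of raw chunks in place of `response.iter_lines()` (names here are illustrative, not part of the library):

```python
import json

# Hypothetical raw chunks, standing in for response.iter_lines() output.
raw_chunks = [
    b'data: {"text": "Hello", "meta_info": {"completion_tokens": 1}}',
    b'data: {"text": "Hello world", "meta_info": {"completion_tokens": 2}}',
    b"data: [DONE]",
]


def parse_sse_chunks(chunks):
    """Yield decoded JSON payloads until the [DONE] sentinel, mirroring stream_request."""
    for chunk in chunks:
        line = chunk.decode("utf-8")
        if line == "data: [DONE]":
            break
        if line and line.startswith("data:"):
            yield json.loads(line[5:].strip())


for payload in parse_sse_chunks(raw_chunks):
    print(payload["text"])  # "Hello", then "Hello world"
```
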
llamafactory/chat/vllm_engine.py ADDED
@@ -0,0 +1,263 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import uuid
16
+ from collections.abc import AsyncGenerator, AsyncIterator
17
+ from typing import TYPE_CHECKING, Any, Optional, Union
18
+
19
+ from typing_extensions import override
20
+
21
+ from ..data import get_template_and_fix_tokenizer
22
+ from ..extras import logging
23
+ from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
24
+ from ..extras.misc import get_device_count
25
+ from ..extras.packages import is_vllm_available
26
+ from ..model import load_config, load_tokenizer
27
+ from ..model.model_utils.quantization import QuantizationMethod
28
+ from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
29
+ from .base_engine import BaseEngine, Response
30
+
31
+
32
+ if is_vllm_available():
33
+ from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
34
+ from vllm.lora.request import LoRARequest
35
+
36
+
37
+ if TYPE_CHECKING:
38
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
39
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+
45
+ class VllmEngine(BaseEngine):
46
+ def __init__(
47
+ self,
48
+ model_args: "ModelArguments",
49
+ data_args: "DataArguments",
50
+ finetuning_args: "FinetuningArguments",
51
+ generating_args: "GeneratingArguments",
52
+ ) -> None:
53
+ self.name = EngineName.VLLM
54
+ self.model_args = model_args
55
+ config = load_config(model_args) # may download model from ms hub
56
+ if getattr(config, "quantization_config", None): # gptq models should use float16
57
+ quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
58
+ quant_method = quantization_config.get("quant_method", "")
59
+ if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
60
+ model_args.infer_dtype = "float16"
61
+
62
+ self.can_generate = finetuning_args.stage == "sft"
63
+ tokenizer_module = load_tokenizer(model_args)
64
+ self.tokenizer = tokenizer_module["tokenizer"]
65
+ self.processor = tokenizer_module["processor"]
66
+ self.tokenizer.padding_side = "left"
67
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
68
+ self.template.mm_plugin.expand_mm_tokens = False # for vllm generate
69
+ self.generating_args = generating_args.to_dict()
70
+
71
+ engine_args = {
72
+ "model": model_args.model_name_or_path,
73
+ "trust_remote_code": model_args.trust_remote_code,
74
+ "download_dir": model_args.cache_dir,
75
+ "dtype": model_args.infer_dtype,
76
+ "max_model_len": model_args.vllm_maxlen,
77
+ "tensor_parallel_size": get_device_count() or 1,
78
+ "gpu_memory_utilization": model_args.vllm_gpu_util,
79
+ "disable_log_stats": True,
80
+ "disable_log_requests": True,
81
+ "enforce_eager": model_args.vllm_enforce_eager,
82
+ "enable_lora": model_args.adapter_name_or_path is not None,
83
+ "max_lora_rank": model_args.vllm_max_lora_rank,
84
+ }
85
+ if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
86
+ engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
87
+
88
+ if isinstance(model_args.vllm_config, dict):
89
+ engine_args.update(model_args.vllm_config)
90
+
91
+ if getattr(config, "is_yi_vl_derived_model", None):
92
+ import vllm.model_executor.models.llava
93
+
94
+ logger.info_rank0("Detected Yi-VL model, applying projector patch.")
95
+ vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
96
+
97
+ self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
98
+ if model_args.adapter_name_or_path is not None:
99
+ self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
100
+ else:
101
+ self.lora_request = None
102
+
103
+ async def _generate(
104
+ self,
105
+ messages: list[dict[str, str]],
106
+ system: Optional[str] = None,
107
+ tools: Optional[str] = None,
108
+ images: Optional[list["ImageInput"]] = None,
109
+ videos: Optional[list["VideoInput"]] = None,
110
+ audios: Optional[list["AudioInput"]] = None,
111
+ **input_kwargs,
112
+ ) -> AsyncIterator["RequestOutput"]:
113
+ request_id = f"chatcmpl-{uuid.uuid4().hex}"
114
+ if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
115
+ messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
116
+
117
+ if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
118
+ messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
119
+
120
+ if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
121
+ messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
122
+
123
+ messages = self.template.mm_plugin.process_messages(
124
+ messages, images or [], videos or [], audios or [], self.processor
125
+ )
126
+ paired_messages = messages + [{"role": "assistant", "content": ""}]
127
+ prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
128
+ prompt_length = len(prompt_ids)
129
+
130
+ temperature: Optional[float] = input_kwargs.pop("temperature", None)
131
+ top_p: Optional[float] = input_kwargs.pop("top_p", None)
132
+ top_k: Optional[float] = input_kwargs.pop("top_k", None)
133
+ num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
134
+ repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
135
+ length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
136
+ skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
137
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
138
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
139
+ stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
140
+
141
+ if length_penalty is not None:
142
+ logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
143
+
144
+ if "max_new_tokens" in self.generating_args:
145
+ max_tokens = self.generating_args["max_new_tokens"]
146
+ elif "max_length" in self.generating_args:
147
+ if self.generating_args["max_length"] > prompt_length:
148
+ max_tokens = self.generating_args["max_length"] - prompt_length
149
+ else:
150
+ max_tokens = 1
151
+
152
+ if max_length:
153
+ max_tokens = max_length - prompt_length if max_length > prompt_length else 1
154
+
155
+ if max_new_tokens:
156
+ max_tokens = max_new_tokens
157
+
158
+ sampling_params = SamplingParams(
159
+ n=num_return_sequences,
160
+ repetition_penalty=(
161
+ repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
162
+ )
163
+ or 1.0, # repetition_penalty must be > 0
164
+ temperature=temperature if temperature is not None else self.generating_args["temperature"],
165
+ top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must be > 0
166
+ top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must be > 0
167
+ stop=stop,
168
+ stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
169
+ max_tokens=max_tokens,
170
+ skip_special_tokens=skip_special_tokens
171
+ if skip_special_tokens is not None
172
+ else self.generating_args["skip_special_tokens"],
173
+ )
174
+
175
+ if images is not None: # add image features
176
+ multi_modal_data = {
177
+ "image": self.template.mm_plugin._regularize_images(
178
+ images,
179
+ image_max_pixels=self.model_args.image_max_pixels,
180
+ image_min_pixels=self.model_args.image_min_pixels,
181
+ )["images"]
182
+ }
183
+ elif videos is not None:
184
+ multi_modal_data = {
185
+ "video": self.template.mm_plugin._regularize_videos(
186
+ videos,
187
+ image_max_pixels=self.model_args.video_max_pixels,
188
+ image_min_pixels=self.model_args.video_min_pixels,
189
+ video_fps=self.model_args.video_fps,
190
+ video_maxlen=self.model_args.video_maxlen,
191
+ )["videos"]
192
+ }
193
+ elif audios is not None:
194
+ audio_data = self.template.mm_plugin._regularize_audios(
195
+ audios,
196
+ sampling_rate=self.model_args.audio_sampling_rate,
197
+ )
198
+ multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
199
+ else:
200
+ multi_modal_data = None
201
+
202
+ result_generator = self.model.generate(
203
+ {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
204
+ sampling_params=sampling_params,
205
+ request_id=request_id,
206
+ lora_request=self.lora_request,
207
+ )
208
+ return result_generator
209
+
210
+ @override
211
+ async def chat(
212
+ self,
213
+ messages: list[dict[str, str]],
214
+ system: Optional[str] = None,
215
+ tools: Optional[str] = None,
216
+ images: Optional[list["ImageInput"]] = None,
217
+ videos: Optional[list["VideoInput"]] = None,
218
+ audios: Optional[list["AudioInput"]] = None,
219
+ **input_kwargs,
220
+ ) -> list["Response"]:
221
+ final_output = None
222
+ generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
223
+ async for request_output in generator:
224
+ final_output = request_output
225
+
226
+ results = []
227
+ for output in final_output.outputs:
228
+ results.append(
229
+ Response(
230
+ response_text=output.text,
231
+ response_length=len(output.token_ids),
232
+ prompt_length=len(final_output.prompt_token_ids),
233
+ finish_reason=output.finish_reason,
234
+ )
235
+ )
236
+
237
+ return results
238
+
239
+ @override
240
+ async def stream_chat(
241
+ self,
242
+ messages: list[dict[str, str]],
243
+ system: Optional[str] = None,
244
+ tools: Optional[str] = None,
245
+ images: Optional[list["ImageInput"]] = None,
246
+ videos: Optional[list["VideoInput"]] = None,
247
+ audios: Optional[list["AudioInput"]] = None,
248
+ **input_kwargs,
249
+ ) -> AsyncGenerator[str, None]:
250
+ generated_text = ""
251
+ generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
252
+ async for result in generator:
253
+ delta_text = result.outputs[0].text[len(generated_text) :]
254
+ generated_text = result.outputs[0].text
255
+ yield delta_text
256
+
257
+ @override
258
+ async def get_scores(
259
+ self,
260
+ batch_input: list[str],
261
+ **input_kwargs,
262
+ ) -> list[float]:
263
+ raise NotImplementedError("vLLM engine does not support `get_scores`.")
llamafactory/cli.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ def main():
17
+ from .extras.misc import is_env_enabled
18
+
19
+ if is_env_enabled("USE_V1"):
20
+ from .v1 import launcher
21
+ else:
22
+ from . import launcher
23
+
24
+ launcher.launch()
25
+
26
+
27
+ if __name__ == "__main__":
28
+ from multiprocessing import freeze_support
29
+
30
+ freeze_support()
31
+ main()
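
`main()` picks the launcher based on the `USE_V1` environment flag via `is_env_enabled` from `extras.misc`. The real helper may differ in detail; a rough, illustrative re-implementation of the idea:

```python
import os


def is_env_enabled(name: str, default: str = "0") -> bool:
    """Illustrative only: treat common truthy strings as enabled."""
    return os.getenv(name, default).lower() in ("1", "true", "y")


os.environ["USE_V1"] = "1"
print(is_env_enabled("USE_V1"))  # True -> the v1 launcher would be used
```
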
llamafactory/data/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .collator import (
16
+ KTODataCollatorWithPadding,
17
+ MultiModalDataCollatorForSeq2Seq,
18
+ PairwiseDataCollatorWithPadding,
19
+ SFTDataCollatorWith4DAttentionMask,
20
+ )
21
+ from .data_utils import Role, split_dataset
22
+ from .loader import get_dataset
23
+ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
24
+
25
+
26
+ __all__ = [
27
+ "TEMPLATES",
28
+ "KTODataCollatorWithPadding",
29
+ "MultiModalDataCollatorForSeq2Seq",
30
+ "PairwiseDataCollatorWithPadding",
31
+ "Role",
32
+ "SFTDataCollatorWith4DAttentionMask",
33
+ "Template",
34
+ "get_dataset",
35
+ "get_template_and_fix_tokenizer",
36
+ "split_dataset",
37
+ ]
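
A quick way to sanity-check the re-exported public API, assuming `llamafactory` and its dependencies are installed (the commented output is indicative):

```python
from llamafactory.data import TEMPLATES, Role

print(Role.USER.value)     # "user"
print(len(TEMPLATES) > 0)  # True once the built-in chat templates are registered
```
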
llamafactory/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (643 Bytes). View file
 
llamafactory/data/__pycache__/collator.cpython-312.pyc ADDED
Binary file (15.2 kB). View file
 
llamafactory/data/__pycache__/converter.cpython-312.pyc ADDED
Binary file (21.7 kB). View file
 
llamafactory/data/__pycache__/data_utils.cpython-312.pyc ADDED
Binary file (8.54 kB). View file
 
llamafactory/data/__pycache__/formatter.cpython-312.pyc ADDED
Binary file (7.89 kB). View file
 
llamafactory/data/__pycache__/loader.cpython-312.pyc ADDED
Binary file (14.9 kB). View file
 
llamafactory/data/__pycache__/mm_plugin.cpython-312.pyc ADDED
Binary file (91.5 kB). View file
 
llamafactory/data/__pycache__/parser.cpython-312.pyc ADDED
Binary file (6.32 kB). View file
 
llamafactory/data/__pycache__/template.cpython-312.pyc ADDED
Binary file (69.3 kB). View file
 
llamafactory/data/__pycache__/tool_utils.cpython-312.pyc ADDED
Binary file (23.6 kB). View file
 
llamafactory/data/collator.py ADDED
@@ -0,0 +1,331 @@
1
+ # Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
2
+ #
3
+ # This code is inspired by the OpenAccess AI Collective's axolotl library.
4
+ # https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from dataclasses import dataclass
19
+ from typing import TYPE_CHECKING, Any, Literal, Optional
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from peft import PeftModel
25
+ from transformers import DataCollatorForSeq2Seq
26
+
27
+ from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
28
+ from ..extras.packages import is_pillow_available
29
+
30
+
31
+ if is_pillow_available():
32
+ from PIL import Image
33
+
34
+
35
+ if TYPE_CHECKING:
36
+ from transformers import ProcessorMixin
37
+
38
+ from .template import Template
39
+
40
+
41
+ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
42
+ r"""Expand 2d attention mask to 4d attention mask.
43
+
44
+ Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
45
+ handle packed sequences, and transform the mask to lower triangular form to prevent future peeking.
46
+
47
+ e.g.
48
+ ```python
49
+ # input
50
+ [[1, 1, 2, 2, 2, 0]]
51
+ # output
52
+ [
53
+ [
54
+ [
55
+ [o, x, x, x, x, x],
56
+ [o, o, x, x, x, x],
57
+ [x, x, o, x, x, x],
58
+ [x, x, o, o, x, x],
59
+ [x, x, o, o, o, x],
60
+ [x, x, x, x, x, x],
61
+ ]
62
+ ]
63
+ ]
64
+ ```
65
+ where `o` equals `0.0` and `x` equals `min_dtype`.
66
+ """
67
+ _, seq_len = attention_mask_with_indices.size()
68
+ min_dtype = torch.finfo(dtype).min
69
+ zero_tensor = torch.tensor(0, dtype=dtype)
70
+
71
+ # Create a non-padding mask.
72
+ non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
73
+ # Create indices for comparison.
74
+ indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2) # [bsz, 1, 1, seq_len]
75
+ indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3) # [bsz, 1, seq_len, 1]
76
+ # Create a lower triangular mask.
77
+ tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
78
+ attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
79
+ # Invert the attention mask.
80
+ attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
81
+ return attention_mask_4d
82
+
83
+
84
+ @dataclass
85
+ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
86
+ r"""Data collator that supports VLMs.
87
+
88
+ Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
89
+ """
90
+
91
+ template: Optional["Template"] = None
92
+ processor: Optional["ProcessorMixin"] = None
93
+
94
+ def __post_init__(self):
95
+ if self.template is None:
96
+ raise ValueError("Template is required for MultiModalDataCollator.")
97
+
98
+ if isinstance(self.model, PeftModel):
99
+ self.model = self.model.base_model.model
100
+
101
+ if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
102
+ self.get_rope_func = self.model.get_rope_index # transformers < 4.52.0 or qwen2.5 omni
103
+ elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
104
+ self.get_rope_func = self.model.model.get_rope_index # transformers >= 4.52.0
105
+ else:
106
+ self.get_rope_func = None
107
+
108
+ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
109
+ batch_images, batch_videos, batch_audios = [], [], []
110
+ batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
111
+ for feature in features:
112
+ images = feature.pop("images", None) or []
113
+ videos = feature.pop("videos", None) or []
114
+ audios = feature.pop("audios", None) or []
115
+ batch_images.extend(images)
116
+ batch_videos.extend(videos)
117
+ batch_audios.extend(audios)
118
+ batch_imglens.append(len(images))
119
+ batch_vidlens.append(len(videos))
120
+ batch_audlens.append(len(audios))
121
+ batch_input_ids.append(feature["input_ids"])
122
+
123
+ fake_input_ids = []
124
+ if (
125
+ self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
126
+ ): # avoid process hanging in zero3/fsdp case
127
+ fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
128
+ fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
129
+ fake_messages = self.template.mm_plugin.process_messages(
130
+ fake_messages, fake_images, [], [], self.processor
131
+ )
132
+ _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
133
+ _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
134
+ _fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
135
+ )
136
+ fake_input_ids.extend(_fake_input_ids)
137
+ batch_images = fake_images
138
+ batch_imglens[0] = 1
139
+
140
+ if (
141
+ self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
142
+ ): # avoid process hanging in zero3/fsdp case
143
+ fake_messages = [{"role": "user", "content": AUDIO_PLACEHOLDER}]
144
+ fake_audios = [np.zeros(1600)]
145
+ fake_messages = self.template.mm_plugin.process_messages(
146
+ fake_messages, [], [], fake_audios, self.processor
147
+ )
148
+ _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
149
+ _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
150
+ _fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
151
+ )
152
+ fake_input_ids.extend(_fake_input_ids)
153
+ batch_audios = fake_audios
154
+ batch_audlens[0] = 1
155
+
156
+ if len(fake_input_ids) != 0:
157
+ if self.tokenizer.padding_side == "right":
158
+ features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
159
+ features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
160
+ features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
161
+ else:
162
+ features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
163
+ features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
164
+ features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
165
+
166
+ batch_input_ids[0] = features[0]["input_ids"]
167
+
168
+ mm_inputs = self.template.mm_plugin.get_mm_inputs(
169
+ batch_images,
170
+ batch_videos,
171
+ batch_audios,
172
+ batch_imglens,
173
+ batch_vidlens,
174
+ batch_audlens,
175
+ batch_input_ids,
176
+ self.processor,
177
+ )
178
+ if "token_type_ids" in mm_inputs:
179
+ token_type_ids = mm_inputs.pop("token_type_ids")
180
+ for i, feature in enumerate(features):
181
+ feature["token_type_ids"] = token_type_ids[i]
182
+
183
+ features: dict[str, torch.Tensor] = super().__call__(features)
184
+
185
+ if self.get_rope_func is not None:
186
+ rope_index_kwargs = {
187
+ "input_ids": features["input_ids"],
188
+ "image_grid_thw": mm_inputs.get("image_grid_thw"),
189
+ "video_grid_thw": mm_inputs.get("video_grid_thw"),
190
+ "attention_mask": (features["attention_mask"] >= 1).float(),
191
+ }
192
+ if "second_per_grid_ts" in mm_inputs: # for qwen2vl
193
+ rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
194
+ elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
195
+ rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
196
+
197
+ if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
198
+ rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
199
+ feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
200
+ if feature_attention_mask is not None: # FIXME: need to get video image lengths
201
+ audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
202
+ rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
203
+
204
+ features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
205
+ features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
206
+ dim=-1
207
+ ).unsqueeze(-1)
208
+ else: # for qwen vl
209
+ features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
210
+
211
+ if (
212
+ self.model is not None
213
+ and getattr(self.model.config, "model_type", None)
214
+ in [
215
+ "glm4v",
216
+ "Keye",
217
+ "qwen2_vl",
218
+ "qwen2_5_vl",
219
+ "qwen2_5_omni_thinker",
220
+ "qwen3_omni_moe_thinker",
221
+ "qwen3_vl",
222
+ "qwen3_vl_moe",
223
+ ]
224
+ and ("position_ids" not in features or features["position_ids"].dim() != 3)
225
+ ):
226
+ raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
227
+
228
+ if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
229
+ cross_attention_mask = mm_inputs.pop("cross_attention_mask")
230
+ seq_len = features["input_ids"].size(1)
231
+ orig_len = cross_attention_mask.size(1)
232
+ mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
233
+
234
+ features.update(mm_inputs)
235
+
236
+ if "image_bound" in features: # for minicpmv inputs
237
+ bsz, seq_length = features["input_ids"].shape
238
+ features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
239
+ return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}
240
+
241
+ return features
242
+
243
+
244
+ @dataclass
245
+ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
246
+ r"""Data collator for 4d attention mask."""
247
+
248
+ block_diag_attn: bool = False
249
+ attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
250
+ compute_dtype: "torch.dtype" = torch.float32
251
+
252
+ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
253
+ features = super().__call__(features)
254
+ if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
255
+ features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
256
+
257
+ for key, value in features.items(): # cast data dtype for paligemma
258
+ if torch.is_tensor(value) and torch.is_floating_point(value):
259
+ features[key] = value.to(self.compute_dtype)
260
+
261
+ return features
262
+
263
+
264
+ @dataclass
265
+ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
266
+ r"""Data collator for pairwise data."""
267
+
268
+ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
269
+ r"""Pad batched data to the longest sequence in the batch.
270
+
271
+ We generate 2 * n examples where the first n examples represent chosen examples and
272
+ the last n examples represent rejected examples.
273
+ """
274
+ concatenated_features = []
275
+ for key in ("chosen", "rejected"):
276
+ for feature in features:
277
+ target_feature = {
278
+ "input_ids": feature[f"{key}_input_ids"],
279
+ "attention_mask": feature[f"{key}_attention_mask"],
280
+ "labels": feature[f"{key}_labels"],
281
+ "images": feature["images"],
282
+ "videos": feature["videos"],
283
+ "audios": feature["audios"],
284
+ }
285
+ concatenated_features.append(target_feature)
286
+
287
+ return super().__call__(concatenated_features)
288
+
289
+
290
+ @dataclass
291
+ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
292
+ r"""Data collator for KTO data."""
293
+
294
+ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
295
+ target_features = []
296
+ kl_features = []
297
+ kto_tags = []
298
+ for feature in features:
299
+ target_feature = {
300
+ "input_ids": feature["input_ids"],
301
+ "attention_mask": feature["attention_mask"],
302
+ "labels": feature["labels"],
303
+ "images": feature["images"],
304
+ "videos": feature["videos"],
305
+ "audios": feature["audios"],
306
+ }
307
+ kl_feature = {
308
+ "input_ids": feature["kl_input_ids"],
309
+ "attention_mask": feature["kl_attention_mask"],
310
+ "labels": feature["kl_labels"],
311
+ "images": feature["images"],
312
+ "videos": feature["videos"],
313
+ "audios": feature["audios"],
314
+ }
315
+ target_features.append(target_feature)
316
+ kl_features.append(kl_feature)
317
+ kto_tags.append(feature["kto_tags"])
318
+
319
+ batch = super().__call__(target_features)
320
+ kl_batch = super().__call__(kl_features)
321
+ batch["kl_input_ids"] = kl_batch["input_ids"]
322
+ batch["kl_attention_mask"] = kl_batch["attention_mask"]
323
+ batch["kl_labels"] = kl_batch["labels"]
324
+ if "cross_attention_mask" in kl_batch: # for mllama inputs
325
+ batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
326
+
327
+ if "token_type_ids" in kl_batch:
328
+ batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
329
+
330
+ batch["kto_tags"] = torch.tensor(kto_tags)
331
+ return batch
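
The block-diagonal mask built by `prepare_4d_attention_mask` can be checked directly against the docstring example; a small usage sketch (assuming `torch` and the package are installed):

```python
import torch

from llamafactory.data.collator import prepare_4d_attention_mask

# Packed batch: two sequences (indices 1 and 2) followed by one padding position (0).
attention_mask = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(attention_mask, torch.float16)

print(mask_4d.shape)             # torch.Size([1, 1, 6, 6])
print((mask_4d == 0).squeeze())  # True where attention is allowed: causal within each packed sequence
```
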
llamafactory/data/converter.py ADDED
@@ -0,0 +1,425 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import os
16
+ from abc import abstractmethod
17
+ from dataclasses import dataclass
18
+ from typing import TYPE_CHECKING, Any, Optional, Union
19
+
20
+ from ..extras import logging
21
+ from .data_utils import Role
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from datasets import Dataset, IterableDataset
26
+ from transformers import Seq2SeqTrainingArguments
27
+
28
+ from ..hparams import DataArguments
29
+ from .mm_plugin import AudioInput, ImageInput, VideoInput
30
+ from .parser import DatasetAttr
31
+
32
+ MediaType = Union[ImageInput, VideoInput, AudioInput]
33
+
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+
38
+ @dataclass
39
+ class DatasetConverter:
40
+ dataset_attr: "DatasetAttr"
41
+ data_args: "DataArguments"
42
+
43
+ def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
44
+ r"""Optionally concatenate media path to media dir when loading from local disk."""
45
+ if medias is None:
46
+ return None
47
+ elif not isinstance(medias, list):
48
+ medias = [medias]
49
+ elif len(medias) == 0:
50
+ return None
51
+ else:
52
+ medias = medias[:]
53
+
54
+ if self.dataset_attr.load_from in ["script", "file"]:
55
+ if isinstance(medias[0], str):
56
+ for i in range(len(medias)):
57
+ media_path = os.path.join(self.data_args.media_dir, medias[i])
58
+ if os.path.isfile(media_path):
59
+ medias[i] = media_path
60
+ else:
61
+ logger.warning_rank0_once(
62
+ f"Media {medias[i]} does not exist in `media_dir`. Use original path."
63
+ )
64
+ elif isinstance(medias[0], list): # for processed video frames
65
+ # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
66
+ for i in range(len(medias)):
67
+ for j in range(len(medias[i])):
68
+ media_path = os.path.join(self.data_args.media_dir, medias[i][j])
69
+ if os.path.isfile(media_path):
70
+ medias[i][j] = media_path
71
+ else:
72
+ logger.warning_rank0_once(
73
+ f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
74
+ )
75
+
76
+ return medias
77
+
78
+ @abstractmethod
79
+ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
80
+ r"""Convert a single example in the dataset to the standard format."""
81
+ ...
82
+
83
+
84
+ @dataclass
85
+ class AlpacaDatasetConverter(DatasetConverter):
86
+ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
87
+ prompt = []
88
+ if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
89
+ for old_prompt, old_response in example[self.dataset_attr.history]:
90
+ prompt.append({"role": Role.USER.value, "content": old_prompt})
91
+ prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
92
+
93
+ query = []
94
+ if self.dataset_attr.prompt and example[self.dataset_attr.prompt]:
95
+ query.append(example[self.dataset_attr.prompt])
96
+
97
+ if self.dataset_attr.query and example[self.dataset_attr.query]:
98
+ query.append(example[self.dataset_attr.query])
99
+
100
+ prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
101
+
102
+ if self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
103
+ response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
104
+ if example[self.dataset_attr.kto_tag]:
105
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
106
+ else:
107
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
108
+ elif (
109
+ self.dataset_attr.ranking
110
+ and isinstance(example[self.dataset_attr.chosen], str)
111
+ and isinstance(example[self.dataset_attr.rejected], str)
112
+ ): # pairwise example
113
+ response = [
114
+ {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.chosen]},
115
+ {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.rejected]},
116
+ ]
117
+ elif self.dataset_attr.response and isinstance(example[self.dataset_attr.response], str): # normal example
118
+ response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
119
+ else: # unsupervised
120
+ response = []
121
+
122
+ output = {
123
+ "_prompt": prompt,
124
+ "_response": response,
125
+ "_system": example[self.dataset_attr.system] if self.dataset_attr.system else "",
126
+ "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
127
+ "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
128
+ "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
129
+ "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
130
+ }
131
+ return output
132
+
133
+
134
+ @dataclass
135
+ class SharegptDatasetConverter(DatasetConverter):
136
+ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
137
+ tag_mapping = {
138
+ self.dataset_attr.user_tag: Role.USER.value,
139
+ self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
140
+ self.dataset_attr.observation_tag: Role.OBSERVATION.value,
141
+ self.dataset_attr.function_tag: Role.FUNCTION.value,
142
+ self.dataset_attr.system_tag: Role.SYSTEM.value,
143
+ }
144
+ odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag)
145
+ even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag)
146
+ accept_tags = (odd_tags, even_tags)
147
+ messages = example[self.dataset_attr.messages]
148
+ if (
149
+ self.dataset_attr.system_tag
150
+ and len(messages) != 0
151
+ and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
152
+ ):
153
+ system = messages[0][self.dataset_attr.content_tag]
154
+ messages = messages[1:]
155
+ else:
156
+ system = example[self.dataset_attr.system] if self.dataset_attr.system else ""
157
+
158
+ aligned_messages = []
159
+ broken_data = False
160
+ for turn_idx, message in enumerate(messages):
161
+ if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
162
+ logger.warning_rank0(f"Invalid role tag in {messages}.")
163
+ broken_data = True
164
+ break
165
+
166
+ aligned_messages.append(
167
+ {
168
+ "role": tag_mapping[message[self.dataset_attr.role_tag]],
169
+ "content": message[self.dataset_attr.content_tag],
170
+ }
171
+ )
172
+
173
+ if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
174
+ self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
175
+ ):
176
+ logger.warning_rank0(f"Invalid message count in {messages}.")
177
+ broken_data = True
178
+
179
+ if broken_data:
180
+ logger.warning_rank0("Skipping this abnormal example.")
181
+ prompt, response = [], []
182
+ elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
183
+ prompt = aligned_messages[:-1]
184
+ response = aligned_messages[-1:]
185
+ if example[self.dataset_attr.kto_tag]:
186
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
187
+ else:
188
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
189
+ elif (
190
+ self.dataset_attr.ranking
191
+ and isinstance(example[self.dataset_attr.chosen], dict)
192
+ and isinstance(example[self.dataset_attr.rejected], dict)
193
+ ): # pairwise example
194
+ chosen = example[self.dataset_attr.chosen]
195
+ rejected = example[self.dataset_attr.rejected]
196
+ if (
197
+ chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
198
+ or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
199
+ ):
200
+ logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
201
+ broken_data = True
202
+
203
+ prompt = aligned_messages
204
+ response = [
205
+ {
206
+ "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
207
+ "content": chosen[self.dataset_attr.content_tag],
208
+ },
209
+ {
210
+ "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
211
+ "content": rejected[self.dataset_attr.content_tag],
212
+ },
213
+ ]
214
+ else: # normal example
215
+ prompt = aligned_messages[:-1]
216
+ response = aligned_messages[-1:]
217
+
218
+ output = {
219
+ "_prompt": prompt,
220
+ "_response": response,
221
+ "_system": system,
222
+ "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
223
+ "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
224
+ "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
225
+ "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
226
+ }
227
+ return output
228
+
229
+
230
+ @dataclass
231
+ class OpenAIDatasetConverter(DatasetConverter):
232
+ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
233
+ tag_mapping = {
234
+ self.dataset_attr.user_tag: Role.USER.value,
235
+ self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
236
+ self.dataset_attr.observation_tag: Role.OBSERVATION.value,
237
+ self.dataset_attr.function_tag: Role.FUNCTION.value,
238
+ self.dataset_attr.system_tag: Role.SYSTEM.value,
239
+ }
240
+
241
+ messages = example[self.dataset_attr.messages]
242
+ if (
243
+ self.dataset_attr.system_tag
244
+ and len(messages) != 0
245
+ and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
246
+ ):
247
+ system = messages[0][self.dataset_attr.content_tag]
248
+ messages = messages[1:]
249
+ else:
250
+ system = example.get(self.dataset_attr.system, "") if self.dataset_attr.system else ""
251
+
252
+ aligned_messages = []
253
+ tool_responses = []
254
+ broken_data = False
255
+ for turn_idx, message in enumerate(messages):
256
+ role = message[self.dataset_attr.role_tag]
257
+ content = message[self.dataset_attr.content_tag]
258
+
259
+ if role in [self.dataset_attr.assistant_tag, self.dataset_attr.function_tag]:
260
+ if "tool_calls" in message and len(message["tool_calls"]) > 0:
261
+ tool_calls_list = [tool["function"] for tool in message["tool_calls"]]
262
+ content = json.dumps(tool_calls_list, ensure_ascii=False)
263
+ role = self.dataset_attr.function_tag
264
+
265
+ if role == self.dataset_attr.observation_tag:
266
+ tool_responses.append(content)
267
+ continue
268
+ elif len(tool_responses) > 0:
269
+ _content = "\n</tool_response>\n<tool_response>\n".join(tool_responses)
270
+ aligned_messages.append(
271
+ {
272
+ "role": Role.OBSERVATION.value,
273
+ "content": _content,
274
+ }
275
+ )
276
+ tool_responses = []
277
+
278
+ aligned_messages.append(
279
+ {
280
+ "role": tag_mapping[role],
281
+ "content": content,
282
+ }
283
+ )
284
+
285
+ odd_tags = (Role.USER.value, Role.OBSERVATION.value)
286
+ even_tags = (Role.ASSISTANT.value, Role.FUNCTION.value)
287
+ accept_tags = (odd_tags, even_tags)
288
+ for turn_idx, message in enumerate(aligned_messages):
289
+ if message["role"] not in accept_tags[turn_idx % 2]:
290
+ logger.warning_rank0(f"Invalid role tag in {messages}.")
291
+ broken_data = True
292
+ break
293
+
294
+ if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
295
+ self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
296
+ ):
297
+ logger.warning_rank0(f"Invalid message count in {messages}.")
298
+ broken_data = True
299
+
300
+ if broken_data:
301
+ logger.warning_rank0("Skipping this abnormal example.")
302
+ prompt, response = [], []
303
+ elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
304
+ prompt = aligned_messages[:-1]
305
+ response = aligned_messages[-1:]
306
+ if example[self.dataset_attr.kto_tag]:
307
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
308
+ else:
309
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
310
+ elif (
311
+ self.dataset_attr.ranking
312
+ and isinstance(example[self.dataset_attr.chosen], dict)
313
+ and isinstance(example[self.dataset_attr.rejected], dict)
314
+ ): # pairwise example
315
+ chosen = example[self.dataset_attr.chosen]
316
+ rejected = example[self.dataset_attr.rejected]
317
+ if (
318
+ chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
319
+ or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
320
+ ):
321
+ logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
322
+ broken_data = True
323
+
324
+ prompt = aligned_messages
325
+ response = [
326
+ {
327
+ "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
328
+ "content": chosen[self.dataset_attr.content_tag],
329
+ },
330
+ {
331
+ "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
332
+ "content": rejected[self.dataset_attr.content_tag],
333
+ },
334
+ ]
335
+ else: # normal example
336
+ prompt = aligned_messages[:-1]
337
+ response = aligned_messages[-1:]
338
+
339
+ tools = example.get(self.dataset_attr.tools, "") if self.dataset_attr.tools else ""
340
+ if isinstance(tools, dict) or isinstance(tools, list):
341
+ tools = json.dumps(tools, ensure_ascii=False)
342
+
343
+ short_system_prompt = "detailed thinking off"
344
+ if not system:
345
+ if not tools:
346
+ system = short_system_prompt
347
+ else:
348
+ pass
349
+ else:
350
+ if not tools:
351
+ if "detailed thinking on" in system or "detailed thinking off" in system:
352
+ pass
353
+ else:
354
+ system += "\n" + short_system_prompt
355
+ else:
356
+ system += "\n"
357
+
358
+ output = {
359
+ "_prompt": prompt,
360
+ "_response": response,
361
+ "_system": system,
362
+ "_tools": tools,
363
+ "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
364
+ "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
365
+ "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
366
+ }
367
+ return output
368
+
369
+
370
+ DATASET_CONVERTERS = {
371
+ "alpaca": AlpacaDatasetConverter,
372
+ "sharegpt": SharegptDatasetConverter,
373
+ "openai": OpenAIDatasetConverter,
374
+ }
375
+
376
+
377
+ def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
378
+ r"""Register a new dataset converter."""
379
+ if name in DATASET_CONVERTERS:
380
+ raise ValueError(f"Dataset converter {name} already exists.")
381
+
382
+ DATASET_CONVERTERS[name] = dataset_converter
383
+
384
+
385
+ def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
386
+ r"""Get a dataset converter."""
387
+ if name not in DATASET_CONVERTERS:
388
+ raise ValueError(f"Dataset converter {name} not found.")
389
+
390
+ return DATASET_CONVERTERS[name](dataset_attr, data_args)
391
+
392
+
393
+ def align_dataset(
394
+ dataset: Union["Dataset", "IterableDataset"],
395
+ dataset_attr: "DatasetAttr",
396
+ data_args: "DataArguments",
397
+ training_args: "Seq2SeqTrainingArguments",
398
+ ) -> Union["Dataset", "IterableDataset"]:
399
+ r"""Align the dataset to a specific format.
400
+
401
+ Aligned dataset:
402
+ _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
403
+ _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
404
+ _system: "..."
405
+ _tools: "..."
406
+ _images: []
407
+ _videos: []
408
+ _audios: []
409
+ """
410
+ column_names = list(next(iter(dataset)).keys())
411
+ kwargs = {}
412
+ if not data_args.streaming:
413
+ kwargs = dict(
414
+ num_proc=data_args.preprocessing_num_workers,
415
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
416
+ desc="Converting format of dataset",
417
+ )
418
+
419
+ dataset_converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args)
420
+ return dataset.map(
421
+ dataset_converter,
422
+ batched=False,
423
+ remove_columns=column_names,
424
+ **kwargs,
425
+ )
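
For reference, this is roughly what `align_dataset` turns an alpaca-style record into, assuming the default column mapping (`instruction`/`input`/`output`); the literals below are illustrative and not produced by running the converter:

```python
alpaca_example = {
    "instruction": "Summarize the text.",
    "input": "LLaMA-Factory supports several dataset formats.",
    "output": "It supports the alpaca, sharegpt and openai formats.",
}

# Expected aligned record: prompt and query are joined with a newline,
# the answer becomes a single assistant turn, and unused fields stay empty.
aligned_example = {
    "_prompt": [
        {
            "role": "user",
            "content": "Summarize the text.\nLLaMA-Factory supports several dataset formats.",
        }
    ],
    "_response": [
        {"role": "assistant", "content": "It supports the alpaca, sharegpt and openai formats."}
    ],
    "_system": "",
    "_tools": "",
    "_images": None,
    "_videos": None,
    "_audios": None,
}
```
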
llamafactory/data/data_utils.py ADDED
@@ -0,0 +1,190 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ from enum import Enum, unique
17
+ from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union
18
+
19
+ import fsspec
20
+ from datasets import DatasetDict, concatenate_datasets, interleave_datasets
21
+
22
+ from ..extras import logging
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from datasets import Dataset, IterableDataset
27
+
28
+ from ..hparams import DataArguments
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ SLOTS = list[Union[str, set[str], dict[str, str]]]
35
+
36
+
37
+ @unique
38
+ class Role(str, Enum):
39
+ USER = "user"
40
+ ASSISTANT = "assistant"
41
+ SYSTEM = "system"
42
+ FUNCTION = "function"
43
+ OBSERVATION = "observation"
44
+
45
+
46
+ class DatasetModule(TypedDict):
47
+ train_dataset: Optional[Union["Dataset", "IterableDataset"]]
48
+ eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
49
+
50
+
51
+ def merge_dataset(
52
+ all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
53
+ ) -> Union["Dataset", "IterableDataset"]:
54
+ r"""Merge multiple datasets to a unified dataset."""
55
+ if len(all_datasets) == 1:
56
+ return all_datasets[0]
57
+
58
+ elif data_args.mix_strategy == "concat":
59
+ if data_args.streaming:
60
+ logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
61
+
62
+ return concatenate_datasets(all_datasets)
63
+
64
+ elif data_args.mix_strategy.startswith("interleave"):
65
+ if not data_args.streaming:
66
+ logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
67
+
68
+ return interleave_datasets(
69
+ datasets=all_datasets,
70
+ probabilities=data_args.interleave_probs,
71
+ seed=seed,
72
+ stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
73
+ )
74
+
75
+ else:
76
+ raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
77
+
78
+
79
+ def split_dataset(
80
+ dataset: Optional[Union["Dataset", "IterableDataset"]],
81
+ eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
82
+ data_args: "DataArguments",
83
+ seed: int,
84
+ ) -> "DatasetDict":
85
+ r"""Split the dataset and return a dataset dict containing the train set and the validation set.
86
+
87
+ Support both map dataset and iterable dataset.
88
+ """
89
+ if eval_dataset is not None and data_args.val_size > 1e-6:
90
+ raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
91
+
92
+ dataset_dict = {}
93
+ if dataset is not None:
94
+ if data_args.streaming:
95
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
96
+
97
+ if data_args.val_size > 1e-6:
98
+ if data_args.streaming:
99
+ dataset_dict["validation"] = dataset.take(int(data_args.val_size))
100
+ dataset_dict["train"] = dataset.skip(int(data_args.val_size))
101
+ else:
102
+ val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
103
+ dataset_dict = dataset.train_test_split(test_size=val_size, seed=seed)
104
+ dataset = dataset.train_test_split(test_size=val_size, seed=seed)
105
+ dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
106
+ else:
107
+ dataset_dict["train"] = dataset
108
+
109
+ if eval_dataset is not None:
110
+ if isinstance(eval_dataset, dict):
111
+ dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
112
+ else:
113
+ if data_args.streaming:
114
+ eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
115
+
116
+ dataset_dict["validation"] = eval_dataset
117
+
118
+ return DatasetDict(dataset_dict)
119
+
120
+
121
+ def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
122
+ r"""Convert dataset or dataset dict to dataset module."""
123
+ dataset_module: DatasetModule = {}
124
+ if isinstance(dataset, DatasetDict): # dataset dict
125
+ if "train" in dataset:
126
+ dataset_module["train_dataset"] = dataset["train"]
127
+
128
+ if "validation" in dataset:
129
+ dataset_module["eval_dataset"] = dataset["validation"]
130
+ else:
131
+ eval_dataset = {}
132
+ for key in dataset.keys():
133
+ if key.startswith("validation_"):
134
+ eval_dataset[key[len("validation_") :]] = dataset[key]
135
+
136
+ if len(eval_dataset):
137
+ dataset_module["eval_dataset"] = eval_dataset
138
+
139
+ else: # single dataset
140
+ dataset_module["train_dataset"] = dataset
141
+
142
+ return dataset_module
143
+
144
+
145
+ def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
146
+ r"""Set up a filesystem object based on the path protocol."""
147
+ storage_options = {"anon": anon} if anon else {}
148
+ if path.startswith("s3://"):
149
+ fs = fsspec.filesystem("s3", **storage_options)
150
+ elif path.startswith(("gs://", "gcs://")):
151
+ fs = fsspec.filesystem("gcs", **storage_options)
152
+ else:
153
+ raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")
154
+
155
+ if not fs.exists(path):
156
+ raise ValueError(f"Path does not exist: {path}.")
157
+
158
+ return fs
159
+
160
+
161
+ def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
162
+ r"""Helper function to read JSON/JSONL files using fsspec."""
163
+ with fs.open(path, "r") as f:
164
+ if path.endswith(".jsonl"):
165
+ return [json.loads(line) for line in f if line.strip()]
166
+ else:
167
+ return json.load(f)
168
+
169
+
170
+ def read_cloud_json(cloud_path: str) -> list[Any]:
171
+ r"""Read a JSON/JSONL file from cloud storage (S3 or GCS).
172
+
173
+ Args:
174
+ cloud_path: str
175
+ Cloud path in the format:
176
+ - 's3://bucket-name/file.json' for AWS S3
177
+ - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
178
+ """
179
+ try:
180
+ fs = setup_fs(cloud_path, anon=True) # try with anonymous access first
181
+ except Exception:
182
+ fs = setup_fs(cloud_path) # try again with credentials
183
+
184
+ # filter out non-JSON files
185
+ files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
186
+ files = [file for file in files if file.endswith((".json", ".jsonl"))]  # materialize so the empty check below works
187
+ if not files:
188
+ raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
189
+
190
+ return sum([_read_json_with_fs(fs, file) for file in files], [])
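A minimal sketch of how `split_dataset` behaves for a map-style dataset, assuming the package is importable as `llamafactory` and stubbing `DataArguments` with only the attributes this helper reads (`streaming`, `val_size`, `buffer_size`):

```python
from types import SimpleNamespace

from datasets import Dataset

from llamafactory.data.data_utils import split_dataset

# Stand-in for DataArguments carrying only the fields split_dataset touches.
data_args = SimpleNamespace(streaming=False, val_size=0.1, buffer_size=16384)
toy = Dataset.from_list([{"text": f"sample {i}"} for i in range(100)])

dataset_dict = split_dataset(toy, eval_dataset=None, data_args=data_args, seed=42)
print({name: len(split) for name, split in dataset_dict.items()})  # {'train': 90, 'validation': 10}
```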
llamafactory/data/formatter.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import re
17
+ from abc import ABC, abstractmethod
18
+ from dataclasses import dataclass, field
19
+ from typing import Optional, Union
20
+
21
+ from typing_extensions import override
22
+
23
+ from .data_utils import SLOTS
24
+ from .tool_utils import FunctionCall, get_tool_utils
25
+
26
+
27
+ @dataclass
28
+ class Formatter(ABC):
29
+ slots: SLOTS = field(default_factory=list)
30
+ tool_format: Optional[str] = None
31
+
32
+ @abstractmethod
33
+ def apply(self, **kwargs) -> SLOTS:
34
+ r"""Forms a list of slots according to the inputs to encode."""
35
+ ...
36
+
37
+ def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
38
+ r"""Extract a list of tuples from the response message if using tools.
39
+
40
+ Each tuple consists of function name and function arguments.
41
+ """
42
+ raise NotImplementedError
43
+
44
+
45
+ @dataclass
46
+ class EmptyFormatter(Formatter):
47
+ def __post_init__(self):
48
+ has_placeholder = False
49
+ for slot in filter(lambda s: isinstance(s, str), self.slots):
50
+ if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
51
+ has_placeholder = True
52
+
53
+ if has_placeholder:
54
+ raise ValueError("Empty formatter should not contain any placeholder.")
55
+
56
+ @override
57
+ def apply(self, **kwargs) -> SLOTS:
58
+ return self.slots
59
+
60
+
61
+ @dataclass
62
+ class StringFormatter(Formatter):
63
+ def __post_init__(self):
64
+ has_placeholder = False
65
+ for slot in filter(lambda s: isinstance(s, str), self.slots):
66
+ if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
67
+ has_placeholder = True
68
+
69
+ if not has_placeholder:
70
+ raise ValueError("A placeholder is required in the string formatter.")
71
+
72
+ @override
73
+ def apply(self, **kwargs) -> SLOTS:
74
+ elements = []
75
+ for slot in self.slots:
76
+ if isinstance(slot, str):
77
+ for name, value in kwargs.items():
78
+ if not isinstance(value, str):
79
+ raise RuntimeError(f"Expected a string, got {value}")
80
+
81
+ slot = slot.replace("{{" + name + "}}", value, 1)
82
+ elements.append(slot)
83
+ elif isinstance(slot, (dict, set)):
84
+ elements.append(slot)
85
+ else:
86
+ raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")
87
+
88
+ return elements
89
+
90
+
91
+ @dataclass
92
+ class FunctionFormatter(StringFormatter):
93
+ def __post_init__(self):
94
+ super().__post_init__()
95
+ self.tool_utils = get_tool_utils(self.tool_format)
96
+
97
+ @override
98
+ def apply(self, **kwargs) -> SLOTS:
99
+ content: str = kwargs.pop("content")
100
+ thought_words, thought = kwargs.pop("thought_words", None), None
101
+ if thought_words and len(thought_words) == 2:
102
+ regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
103
+ thought = re.search(regex, content)
104
+
105
+ if thought:
106
+ content = content.replace(thought.group(0), "")
107
+
108
+ functions: list[FunctionCall] = []
109
+ try:
110
+ tool_calls = json.loads(content)
111
+ if not isinstance(tool_calls, list): # parallel function call
112
+ tool_calls = [tool_calls]
113
+
114
+ for tool_call in tool_calls:
115
+ functions.append(
116
+ FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
117
+ )
118
+
119
+ except json.JSONDecodeError:
120
+ raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string
121
+
122
+ function_str = self.tool_utils.function_formatter(functions)
123
+ if thought:
124
+ function_str = thought.group(0) + function_str
125
+
126
+ return super().apply(content=function_str)
127
+
128
+
129
+ @dataclass
130
+ class ToolFormatter(Formatter):
131
+ def __post_init__(self):
132
+ self.tool_utils = get_tool_utils(self.tool_format)
133
+
134
+ @override
135
+ def apply(self, **kwargs) -> SLOTS:
136
+ content = kwargs.pop("content")
137
+ try:
138
+ tools = json.loads(content)
139
+ return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
140
+ except json.JSONDecodeError:
141
+ raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
142
+
143
+ @override
144
+ def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
145
+ return self.tool_utils.tool_extractor(content)
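As a quick sanity check on the slot logic above, a hedged sketch of `StringFormatter` substituting its `{{content}}` placeholder (the chat-style slot string here is arbitrary, not a real template shipped with the package):

```python
from llamafactory.data.formatter import StringFormatter

formatter = StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n"])
print(formatter.apply(content="What is LoRA?"))
# ['<|user|>\nWhat is LoRA?<|end|>\n']
```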
llamafactory/data/loader.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from typing import TYPE_CHECKING, Literal, Optional, Union
17
+
18
+ import numpy as np
19
+ from datasets import Dataset, load_dataset, load_from_disk
20
+
21
+ from ..extras import logging
22
+ from ..extras.constants import FILEEXT2TYPE
23
+ from ..extras.misc import check_version, has_tokenized_data
24
+ from .converter import align_dataset
25
+ from .data_utils import get_dataset_module, merge_dataset, read_cloud_json, split_dataset
26
+ from .parser import get_dataset_list
27
+ from .processor import (
28
+ FeedbackDatasetProcessor,
29
+ PackedSupervisedDatasetProcessor,
30
+ PairwiseDatasetProcessor,
31
+ PretrainDatasetProcessor,
32
+ SupervisedDatasetProcessor,
33
+ UnsupervisedDatasetProcessor,
34
+ )
35
+
36
+
37
+ if TYPE_CHECKING:
38
+ from datasets import Dataset, IterableDataset
39
+ from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
40
+
41
+ from ..hparams import DataArguments, ModelArguments
42
+ from .data_utils import DatasetModule
43
+ from .parser import DatasetAttr
44
+ from .processor import DatasetProcessor
45
+ from .template import Template
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ def _load_single_dataset(
52
+ dataset_attr: "DatasetAttr",
53
+ model_args: "ModelArguments",
54
+ data_args: "DataArguments",
55
+ training_args: "Seq2SeqTrainingArguments",
56
+ ) -> Union["Dataset", "IterableDataset"]:
57
+ r"""Load a single dataset and aligns it to the standard format."""
58
+ logger.info_rank0(f"Loading dataset {dataset_attr}...")
59
+ data_path, data_name, data_dir, data_files = None, None, None, None
60
+ if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
61
+ data_path = dataset_attr.dataset_name
62
+ data_name = dataset_attr.subset
63
+ data_dir = dataset_attr.folder
64
+
65
+ elif dataset_attr.load_from == "script":
66
+ data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
67
+ data_name = dataset_attr.subset
68
+ data_dir = dataset_attr.folder
69
+
70
+ elif dataset_attr.load_from == "cloud_file":
71
+ data_path = dataset_attr.dataset_name
72
+
73
+ elif dataset_attr.load_from == "file":
74
+ data_files = []
75
+ local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
76
+ if os.path.isdir(local_path): # is directory
77
+ for file_name in os.listdir(local_path):
78
+ data_files.append(os.path.join(local_path, file_name))
79
+ elif os.path.isfile(local_path): # is file
80
+ data_files.append(local_path)
81
+ else:
82
+ raise ValueError(f"File {local_path} not found.")
83
+
84
+ data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
85
+ if data_path is None:
86
+ raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
87
+
88
+ if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
89
+ raise ValueError("File types should be identical.")
90
+ else:
91
+ raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
92
+
93
+ if dataset_attr.load_from == "ms_hub":
94
+ check_version("modelscope>=1.14.0", mandatory=True)
95
+ from modelscope import MsDataset # type: ignore
96
+ from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
97
+
98
+ cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
99
+ dataset = MsDataset.load(
100
+ dataset_name=data_path,
101
+ subset_name=data_name,
102
+ data_dir=data_dir,
103
+ data_files=data_files,
104
+ split=dataset_attr.split,
105
+ cache_dir=cache_dir,
106
+ token=model_args.ms_hub_token,
107
+ use_streaming=data_args.streaming,
108
+ )
109
+ if isinstance(dataset, MsDataset):
110
+ dataset = dataset.to_hf_dataset()
111
+
112
+ elif dataset_attr.load_from == "om_hub":
113
+ check_version("openmind>=0.8.0", mandatory=True)
114
+ from openmind import OmDataset # type: ignore
115
+ from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
116
+
117
+ cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
118
+ dataset = OmDataset.load_dataset(
119
+ path=data_path,
120
+ name=data_name,
121
+ data_dir=data_dir,
122
+ data_files=data_files,
123
+ split=dataset_attr.split,
124
+ cache_dir=cache_dir,
125
+ token=model_args.om_hub_token,
126
+ streaming=data_args.streaming,
127
+ )
128
+ elif dataset_attr.load_from == "cloud_file":
129
+ dataset = Dataset.from_list(read_cloud_json(data_path), split=dataset_attr.split)
130
+ else:
131
+ dataset = load_dataset(
132
+ path=data_path,
133
+ name=data_name,
134
+ data_dir=data_dir,
135
+ data_files=data_files,
136
+ split=dataset_attr.split,
137
+ cache_dir=model_args.cache_dir,
138
+ token=model_args.hf_hub_token,
139
+ num_proc=data_args.preprocessing_num_workers,
140
+ streaming=data_args.streaming and dataset_attr.load_from != "file",
141
+ )
142
+ if data_args.streaming and dataset_attr.load_from == "file":
143
+ dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
144
+
145
+ if dataset_attr.num_samples is not None and not data_args.streaming:
146
+ target_num = dataset_attr.num_samples
147
+ indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
148
+ target_num -= len(indexes)
149
+ if target_num > 0:
150
+ expand_indexes = np.random.choice(len(dataset), target_num)
151
+ indexes = np.concatenate((indexes, expand_indexes), axis=0)
152
+
153
+ assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
154
+ dataset = dataset.select(indexes)
155
+ logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
156
+
157
+ if data_args.max_samples is not None: # truncate dataset
158
+ max_samples = min(data_args.max_samples, len(dataset))
159
+ dataset = dataset.select(range(max_samples))
160
+
161
+ return align_dataset(dataset, dataset_attr, data_args, training_args)
162
+
163
+
164
+ def _get_merged_dataset(
165
+ dataset_names: Optional[list[str]],
166
+ model_args: "ModelArguments",
167
+ data_args: "DataArguments",
168
+ training_args: "Seq2SeqTrainingArguments",
169
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
170
+ return_dict: bool = False,
171
+ ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
172
+ r"""Return the merged datasets in the standard format."""
173
+ if dataset_names is None:
174
+ return None
175
+
176
+ datasets = {}
177
+ for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)):
178
+ if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
179
+ raise ValueError("The dataset is not applicable in the current training stage.")
180
+
181
+ datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
182
+
183
+ if return_dict:
184
+ return datasets
185
+ else:
186
+ return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
187
+
188
+
189
+ def _get_dataset_processor(
190
+ data_args: "DataArguments",
191
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
192
+ template: "Template",
193
+ tokenizer: "PreTrainedTokenizer",
194
+ processor: Optional["ProcessorMixin"],
195
+ do_generate: bool = False,
196
+ ) -> "DatasetProcessor":
197
+ r"""Return the corresponding dataset processor."""
198
+ if stage == "pt":
199
+ dataset_processor_class = PretrainDatasetProcessor
200
+ elif stage == "sft" and not do_generate:
201
+ if data_args.packing:
202
+ if data_args.neat_packing: # hack datasets to have int32 attention mask
203
+ from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
204
+
205
+ def __init__(self, data, **kwargs):
206
+ return TypedSequence.__init__(
207
+ self,
208
+ data,
209
+ type=kwargs.pop("type", None),
210
+ try_type=kwargs.pop("try_type", None),
211
+ optimized_int_type=kwargs.pop("optimized_int_type", None),
212
+ )
213
+
214
+ OptimizedTypedSequence.__init__ = __init__
215
+ dataset_processor_class = PackedSupervisedDatasetProcessor
216
+ else:
217
+ dataset_processor_class = SupervisedDatasetProcessor
218
+
219
+ elif stage == "rm":
220
+ dataset_processor_class = PairwiseDatasetProcessor
221
+ elif stage == "kto":
222
+ dataset_processor_class = FeedbackDatasetProcessor
223
+ else:
224
+ dataset_processor_class = UnsupervisedDatasetProcessor
225
+
226
+ return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)
227
+
228
+
229
+ def _get_preprocessed_dataset(
230
+ dataset: Optional[Union["Dataset", "IterableDataset"]],
231
+ data_args: "DataArguments",
232
+ training_args: "Seq2SeqTrainingArguments",
233
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
234
+ template: "Template",
235
+ tokenizer: "PreTrainedTokenizer",
236
+ processor: Optional["ProcessorMixin"] = None,
237
+ is_eval: bool = False,
238
+ ) -> Optional[Union["Dataset", "IterableDataset"]]:
239
+ r"""Preprocesses the dataset, including format checking and tokenization."""
240
+ if dataset is None:
241
+ return None
242
+
243
+ dataset_processor = _get_dataset_processor(
244
+ data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
245
+ )
246
+ column_names = list(next(iter(dataset)).keys())
247
+ kwargs = {}
248
+ if not data_args.streaming:
249
+ kwargs = dict(
250
+ num_proc=data_args.preprocessing_num_workers,
251
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
252
+ desc="Running tokenizer on dataset",
253
+ )
254
+
255
+ dataset = dataset.map(
256
+ dataset_processor.preprocess_dataset,
257
+ batched=True,
258
+ batch_size=data_args.preprocessing_batch_size,
259
+ remove_columns=column_names,
260
+ **kwargs,
261
+ )
262
+
263
+ if training_args.should_log:
264
+ try:
265
+ print("eval example:" if is_eval else "training example:")
266
+ dataset_processor.print_data_example(next(iter(dataset)))
267
+ except StopIteration:
268
+ if stage == "pt":
269
+ raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
270
+ else:
271
+ raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
272
+
273
+ return dataset
274
+
275
+
276
+ def get_dataset(
277
+ template: "Template",
278
+ model_args: "ModelArguments",
279
+ data_args: "DataArguments",
280
+ training_args: "Seq2SeqTrainingArguments",
281
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
282
+ tokenizer: "PreTrainedTokenizer",
283
+ processor: Optional["ProcessorMixin"] = None,
284
+ ) -> "DatasetModule":
285
+ r"""Get the train dataset and optionally gets the evaluation dataset."""
286
+ # Load tokenized dataset if path exists
287
+ if data_args.tokenized_path is not None:
288
+ if has_tokenized_data(data_args.tokenized_path):
289
+ logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
290
+ tokenized_data = load_from_disk(data_args.tokenized_path)
291
+ dataset_module = get_dataset_module(tokenized_data)
292
+ if data_args.streaming:
293
+ dataset_module["train_dataset"] = dataset_module["train_dataset"].to_iterable_dataset()
294
+
295
+ logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
296
+ return dataset_module
297
+
298
+ if data_args.streaming:
299
+ raise ValueError("Turn off `streaming` when saving dataset to disk.")
300
+
301
+ # Load and preprocess dataset
302
+ with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)):
303
+ dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
304
+ eval_dataset = _get_merged_dataset(
305
+ data_args.eval_dataset,
306
+ model_args,
307
+ data_args,
308
+ training_args,
309
+ stage,
310
+ return_dict=data_args.eval_on_each_dataset,
311
+ )
312
+
313
+ with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
314
+ dataset = _get_preprocessed_dataset(
315
+ dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
316
+ )
317
+ if isinstance(eval_dataset, dict):
318
+ for eval_name, eval_data in eval_dataset.items():
319
+ eval_dataset[eval_name] = _get_preprocessed_dataset(
320
+ eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
321
+ )
322
+ else:
323
+ eval_dataset = _get_preprocessed_dataset(
324
+ eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
325
+ )
326
+
327
+ dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
328
+ if data_args.tokenized_path is not None: # save tokenized dataset to disk
329
+ if training_args.should_save:
330
+ dataset_dict.save_to_disk(data_args.tokenized_path)
331
+ logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
332
+ logger.info_rank0(f"Please launch the training with `tokenized_path: {data_args.tokenized_path}`.")
333
+
334
+ return get_dataset_module(dataset_dict)
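The `num_samples` handling in `_load_single_dataset` above first draws indices without replacement and only pads with replacement when the dataset is smaller than the requested size. A standalone sketch of that idea (the function name and seeding below are illustrative, not part of the package):

```python
import numpy as np

def sample_indices(dataset_len: int, target_num: int, seed: int = 0) -> np.ndarray:
    """Pick `target_num` indices, reusing samples only if the dataset is too small."""
    rng = np.random.default_rng(seed)
    indexes = rng.permutation(dataset_len)[:target_num]  # each sample at most once
    remaining = target_num - len(indexes)
    if remaining > 0:  # dataset smaller than target: sample the rest with replacement
        indexes = np.concatenate((indexes, rng.choice(dataset_len, remaining)))
    return indexes

print(len(sample_indices(dataset_len=10, target_num=25)))  # 25
```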
llamafactory/data/mm_plugin.py ADDED
@@ -0,0 +1,2082 @@
1
+ # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
2
+ #
3
+ # This code is inspired by the HuggingFace's Transformers library.
4
+ # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import inspect
19
+ import math
20
+ import os
21
+ import re
22
+ from copy import deepcopy
23
+ from dataclasses import dataclass
24
+ from io import BytesIO
25
+ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
26
+
27
+ import numpy as np
28
+ import torch
29
+ from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
30
+ from transformers.models.mllama.processing_mllama import (
31
+ convert_sparse_cross_attention_mask_to_dense,
32
+ get_cross_attention_token_mask,
33
+ )
34
+ from typing_extensions import NotRequired, override
35
+
36
+ from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
37
+ from ..extras.packages import (
38
+ is_librosa_available,
39
+ is_pillow_available,
40
+ is_pyav_available,
41
+ is_transformers_version_greater_than,
42
+ )
43
+
44
+
45
+ if is_librosa_available():
46
+ import librosa
47
+
48
+
49
+ if is_pillow_available():
50
+ from PIL import Image
51
+ from PIL.Image import Image as ImageObject
52
+
53
+
54
+ if is_pyav_available():
55
+ import av
56
+
57
+
58
+ if is_transformers_version_greater_than("4.52.0"):
59
+ from transformers.image_utils import make_flat_list_of_images
60
+ from transformers.video_utils import make_batched_videos
61
+ else:
62
+ from transformers.image_utils import make_batched_videos, make_flat_list_of_images
63
+
64
+
65
+ if TYPE_CHECKING:
66
+ from av.stream import Stream
67
+ from numpy.typing import NDArray
68
+ from transformers import PreTrainedTokenizer, ProcessorMixin
69
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
70
+ from transformers.image_processing_utils import BaseImageProcessor
71
+ from transformers.video_processing_utils import BaseVideoProcessor
72
+
73
+ class EncodedImage(TypedDict):
74
+ path: Optional[str]
75
+ bytes: Optional[bytes]
76
+
77
+ ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
78
+ VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
79
+ AudioInput = Union[str, BinaryIO, NDArray]
80
+
81
+ class RegularizedImageOutput(TypedDict):
82
+ images: list[ImageObject]
83
+
84
+ class RegularizedVideoOutput(TypedDict):
85
+ videos: list[list[ImageObject]]
86
+ durations: list[float]
87
+ fps_per_video: NotRequired[list[float]]
88
+
89
+ class RegularizedAudioOutput(TypedDict):
90
+ audios: list[NDArray]
91
+ sampling_rates: list[float]
92
+
93
+ class MMProcessor(ProcessorMixin):
94
+ patch_size: int
95
+ image_seq_length: int
96
+ num_additional_image_tokens: int
97
+ vision_feature_select_strategy: Literal["default", "full"]
98
+
99
+ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
100
+ pass
101
+
102
+
103
+ def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
104
+ r"""Get paligemma token type ids for computing loss.
105
+
106
+ It is slightly different with the original token type ids where the prompt part is 0.
107
+
108
+ Returns:
109
+ batch_token_type_ids: shape (batch_size, seq_length)
110
+
111
+ """
112
+ batch_token_type_ids = []
113
+ for imglen, seqlen in zip(imglens, seqlens):
114
+ image_seqlen = imglen * processor.image_seq_length
115
+ batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
116
+
117
+ return batch_token_type_ids
118
+
119
+
120
+ def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
121
+ r"""Get gemma3 token type ids for computing loss.
122
+
123
+ Returns:
124
+ batch_token_type_ids: shape (batch_size, seq_length)
125
+
126
+ """
127
+ image_token_id: int = getattr(processor, "image_token_id")
128
+ batch_token_type_ids = []
129
+ for token_ids in batch_ids:
130
+ token_ids = np.array(token_ids)
131
+ token_type_ids = np.zeros_like(token_ids)
132
+ token_type_ids[token_ids == image_token_id] = 1
133
+ batch_token_type_ids.append(token_type_ids.tolist())
134
+
135
+ return batch_token_type_ids
136
+
137
+
138
+ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
139
+ r"""Make nested list of images."""
140
+ batch_images = []
141
+ for imglen in imglens:
142
+ batch_images.append(images[:imglen])
143
+ images = images[imglen:]
144
+
145
+ return batch_images
146
+
147
+
148
+ def _check_video_is_nested_images(video: "VideoInput") -> bool:
149
+ r"""Check if the video is nested images."""
150
+ return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict, ImageObject)) for frame in video)
151
+
152
+
153
+ @dataclass
154
+ class MMPluginMixin:
155
+ image_token: Optional[str]
156
+ video_token: Optional[str]
157
+ audio_token: Optional[str]
158
+ expand_mm_tokens: bool = True
159
+
160
+ def _validate_input(
161
+ self,
162
+ processor: Optional["MMProcessor"],
163
+ images: list["ImageInput"],
164
+ videos: list["VideoInput"],
165
+ audios: list["AudioInput"],
166
+ ) -> None:
167
+ r"""Validate if this model accepts the input modalities."""
168
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
169
+ video_processor: BaseImageProcessor = getattr(
170
+ processor, "video_processor", getattr(processor, "image_processor", None)
171
+ )
172
+ feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
173
+ if len(images) != 0 and self.image_token is None:
174
+ raise ValueError(
175
+ "This model does not support image input. Please check whether the correct `template` is used."
176
+ )
177
+
178
+ if len(videos) != 0 and self.video_token is None:
179
+ raise ValueError(
180
+ "This model does not support video input. Please check whether the correct `template` is used."
181
+ )
182
+
183
+ if len(audios) != 0 and self.audio_token is None:
184
+ raise ValueError(
185
+ "This model does not support audio input. Please check whether the correct `template` is used."
186
+ )
187
+
188
+ if self.image_token is not None and processor is None:
189
+ raise ValueError("Processor was not found, please check and update your model file.")
190
+
191
+ if self.image_token is not None and image_processor is None:
192
+ raise ValueError("Image processor was not found, please check and update your model file.")
193
+
194
+ if self.video_token is not None and video_processor is None:
195
+ raise ValueError("Video processor was not found, please check and update your model file.")
196
+
197
+ if self.audio_token is not None and feature_extractor is None:
198
+ raise ValueError("Audio feature extractor was not found, please check and update your model file.")
199
+
200
+ def _validate_messages(
201
+ self,
202
+ messages: list[dict[str, str]],
203
+ images: list["ImageInput"],
204
+ videos: list["VideoInput"],
205
+ audios: list["AudioInput"],
206
+ ):
207
+ r"""Validate if the number of images, videos and audios match the number of placeholders in messages."""
208
+ num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
209
+ for message in messages:
210
+ num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER)
211
+ num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER)
212
+ num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER)
213
+
214
+ if len(images) != num_image_tokens:
215
+ raise ValueError(
216
+ f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}."
217
+ )
218
+
219
+ if len(videos) != num_video_tokens:
220
+ raise ValueError(
221
+ f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}."
222
+ )
223
+
224
+ if len(audios) != num_audio_tokens:
225
+ raise ValueError(
226
+ f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}."
227
+ )
228
+
229
+ def _preprocess_image(
230
+ self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
231
+ ) -> "ImageObject":
232
+ r"""Pre-process a single image."""
233
+ if (image.width * image.height) > image_max_pixels:
234
+ resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
235
+ width, height = int(image.width * resize_factor), int(image.height * resize_factor)
236
+ image = image.resize((width, height))
237
+
238
+ if (image.width * image.height) < image_min_pixels:
239
+ resize_factor = math.sqrt(image_min_pixels / (image.width * image.height))
240
+ width, height = int(image.width * resize_factor), int(image.height * resize_factor)
241
+ image = image.resize((width, height))
242
+
243
+ if image.mode != "RGB":
244
+ image = image.convert("RGB")
245
+
246
+ return image
247
+
248
+ def _get_video_sample_indices(
249
+ self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
250
+ ) -> list[int]:
251
+ r"""Compute video sample indices according to fps."""
252
+ total_frames = video_stream.frames
253
+ if total_frames == 0: # infinite video
254
+ return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
255
+
256
+ sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * video_fps))
257
+ sample_frames = min(total_frames, video_maxlen, sample_frames)
258
+ return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
259
+
260
+ def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
261
+ r"""Regularize images to avoid errors, including reading and pre-processing."""
262
+ results = []
263
+ for image in images:
264
+ if isinstance(image, (str, BinaryIO)):
265
+ image = Image.open(image)
266
+ elif isinstance(image, bytes):
267
+ image = Image.open(BytesIO(image))
268
+ elif isinstance(image, dict):
269
+ if image["bytes"] is not None:
270
+ image = Image.open(BytesIO(image["bytes"]))
271
+ else:
272
+ image = Image.open(image["path"])
273
+
274
+ if not isinstance(image, ImageObject):
275
+ raise ValueError(f"Expect input is a list of images, but got {type(image)}.")
276
+
277
+ results.append(self._preprocess_image(image, **kwargs))
278
+
279
+ return {"images": results}
280
+
281
+ def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
282
+ r"""Regularize videos to avoid errors, including reading, resizing and converting."""
283
+ results = []
284
+ durations = []
285
+ for video in videos:
286
+ frames: list[ImageObject] = []
287
+ if _check_video_is_nested_images(video):
288
+ for frame in video:
289
+ if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
290
+ raise ValueError("Invalid image found in video frames.")
291
+ frames = video
292
+ durations.append(len(frames) / kwargs.get("video_fps", 2.0))
293
+ else:
294
+ container = av.open(video, "r")
295
+ video_stream = next(stream for stream in container.streams if stream.type == "video")
296
+ sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
297
+ container.seek(0)
298
+ for frame_idx, frame in enumerate(container.decode(video_stream)):
299
+ if frame_idx in sample_indices:
300
+ frames.append(frame.to_image())
301
+
302
+ if video_stream.duration is None:
303
+ durations.append(len(frames) / kwargs.get("video_fps", 2.0))
304
+ else:
305
+ durations.append(float(video_stream.duration * video_stream.time_base))
306
+
307
+ frames = self._regularize_images(frames, **kwargs)["images"]
308
+ results.append(frames)
309
+
310
+ return {"videos": results, "durations": durations}
311
+
312
+ def _regularize_audios(
313
+ self, audios: list["AudioInput"], sampling_rate: float, **kwargs
314
+ ) -> "RegularizedAudioOutput":
315
+ r"""Regularize audios to avoid errors, including reading and resampling."""
316
+ results, sampling_rates = [], []
317
+ for audio in audios:
318
+ if not isinstance(audio, np.ndarray):
319
+ audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
320
+
321
+ results.append(audio)
322
+ sampling_rates.append(sampling_rate)
323
+
324
+ return {"audios": results, "sampling_rates": sampling_rates}
325
+
326
+ def _get_mm_inputs(
327
+ self,
328
+ images: list["ImageInput"],
329
+ videos: list["VideoInput"],
330
+ audios: list["AudioInput"],
331
+ processor: "MMProcessor",
332
+ imglens: Optional[list[int]] = None,
333
+ ) -> dict[str, "torch.Tensor"]:
334
+ r"""Process visual inputs.
335
+
336
+ Returns: (llava and paligemma)
337
+ pixel_values: tensor with shape (B, C, H, W)
338
+
339
+ Returns: (qwen2-vl)
340
+ pixel_values: tensor with shape (num_patches, patch_dim)
341
+ image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, height, width
342
+ where num_patches == torch.prod(image_grid_thw)
343
+
344
+ Returns: (mllama)
345
+ pixel_values: tensor with shape
346
+ (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
347
+ For example, (2, 1, 4, 3, 560, 560).
348
+ aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
349
+ aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
350
+ num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
351
+
352
+ """
353
+ mm_inputs = {}
354
+ if len(images) != 0:
355
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
356
+ images = self._regularize_images(
357
+ images,
358
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
359
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
360
+ )["images"]
361
+ if imglens is not None: # if imglens are provided, make batched images
362
+ images = _make_batched_images(images, imglens)
363
+
364
+ image_processor_kwargs = {}
365
+ if getattr(processor, "image_do_pan_and_scan", False): # gemma3 image processor
366
+ image_processor_kwargs.update(
367
+ {
368
+ "do_pan_and_scan": True,
369
+ "pan_and_scan_min_crop_size": 256,
370
+ "pan_and_scan_max_num_crops": 4,
371
+ "pan_and_scan_min_ratio_to_activate": 1.2,
372
+ }
373
+ )
374
+
375
+ mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
376
+
377
+ if len(videos) != 0:
378
+ video_processor: BaseImageProcessor = getattr(
379
+ processor, "video_processor", getattr(processor, "image_processor", None)
380
+ )
381
+ videos = self._regularize_videos(
382
+ videos,
383
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
384
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
385
+ video_fps=getattr(processor, "video_fps", 2.0),
386
+ video_maxlen=getattr(processor, "video_maxlen", 128),
387
+ )["videos"]
388
+ if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
389
+ mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
390
+ else: # for llava_next_video
391
+ mm_inputs.update(video_processor(videos, return_tensors="pt"))
392
+
393
+ if len(audios) != 0:
394
+ feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
395
+ audios = self._regularize_audios(
396
+ audios,
397
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
398
+ )["audios"]
399
+ mm_inputs.update(
400
+ feature_extractor(
401
+ audios,
402
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
403
+ return_attention_mask=True,
404
+ padding="max_length",
405
+ return_tensors="pt",
406
+ )
407
+ )
408
+ mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None) # prevent conflicts
409
+
410
+ return mm_inputs
411
+
412
+
413
+ @dataclass
414
+ class BasePlugin(MMPluginMixin):
415
+ def process_messages(
416
+ self,
417
+ messages: list[dict[str, str]],
418
+ images: list["ImageInput"],
419
+ videos: list["VideoInput"],
420
+ audios: list["AudioInput"],
421
+ processor: Optional["MMProcessor"],
422
+ ) -> list[dict[str, str]]:
423
+ r"""Pre-process input messages before tokenization for VLMs."""
424
+ self._validate_input(processor, images, videos, audios)
425
+ return messages
426
+
427
+ def process_token_ids(
428
+ self,
429
+ input_ids: list[int],
430
+ labels: Optional[list[int]],
431
+ images: list["ImageInput"],
432
+ videos: list["VideoInput"],
433
+ audios: list["AudioInput"],
434
+ tokenizer: "PreTrainedTokenizer",
435
+ processor: Optional["MMProcessor"],
436
+ ) -> tuple[list[int], Optional[list[int]]]:
437
+ r"""Pre-process token ids after tokenization for VLMs."""
438
+ self._validate_input(processor, images, videos, audios)
439
+ return input_ids, labels
440
+
441
+ def get_mm_inputs(
442
+ self,
443
+ images: list["ImageInput"],
444
+ videos: list["VideoInput"],
445
+ audios: list["AudioInput"],
446
+ imglens: list[int],
447
+ vidlens: list[int],
448
+ audlens: list[int],
449
+ batch_ids: list[list[int]],
450
+ processor: Optional["MMProcessor"],
451
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
452
+ r"""Build batched multimodal inputs for VLMs.
453
+
454
+ Arguments:
455
+ images: a list of image inputs, shape (num_images,)
456
+ videos: a list of video inputs, shape (num_videos,)
457
+ audios: a list of audio inputs, shape (num_audios,)
458
+ imglens: number of images in each sample, shape (batch_size,)
459
+ vidlens: number of videos in each sample, shape (batch_size,)
460
+ audlens: number of audios in each sample, shape (batch_size,)
461
+ batch_ids: token ids of input samples, shape (batch_size, seq_len)
462
+ processor: a processor for pre-processing images and videos
463
+
464
+ """
465
+ self._validate_input(processor, images, videos, audios)
466
+ return self._get_mm_inputs(images, videos, audios, processor)
467
+
468
+
469
+ @dataclass
470
+ class Gemma3Plugin(BasePlugin):
471
+ @override
472
+ def process_messages(
473
+ self,
474
+ messages: list[dict[str, str]],
475
+ images: list["ImageInput"],
476
+ videos: list["VideoInput"],
477
+ audios: list["AudioInput"],
478
+ processor: Optional["MMProcessor"],
479
+ ) -> list[dict[str, str]]:
480
+ self._validate_input(processor, images, videos, audios)
481
+ self._validate_messages(messages, images, videos, audios)
482
+ num_image_tokens = 0
483
+ messages = deepcopy(messages)
484
+ boi_token: str = getattr(processor, "boi_token")
485
+ full_image_sequence: str = getattr(processor, "full_image_sequence")
486
+ image_str = full_image_sequence if self.expand_mm_tokens else boi_token
487
+
488
+ do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
489
+ if do_pan_and_scan:
490
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
491
+
492
+ for message in messages:
493
+ content = message["content"]
494
+ while IMAGE_PLACEHOLDER in content:
495
+ if do_pan_and_scan:
496
+ image_placeholder_str = (
497
+ "Here is the original image {{image}} and here are some crops to help you see better "
498
+ + " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens])
499
+ )
500
+ else:
501
+ image_placeholder_str = "{{image}}"
502
+
503
+ content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
504
+ num_image_tokens += 1
505
+
506
+ message["content"] = content.replace("{{image}}", image_str)
507
+
508
+ return messages
509
+
510
+ @override
511
+ def get_mm_inputs(
512
+ self,
513
+ images: list["ImageInput"],
514
+ videos: list["VideoInput"],
515
+ audios: list["AudioInput"],
516
+ imglens: list[int],
517
+ vidlens: list[int],
518
+ audlens: list[int],
519
+ batch_ids: list[list[int]],
520
+ processor: Optional["MMProcessor"],
521
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
522
+ self._validate_input(processor, images, videos, audios)
523
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
524
+ mm_inputs.pop("num_crops", None)
525
+ mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor)
526
+ return mm_inputs
527
+
528
+
529
+ class Gemma3nPlugin(Gemma3Plugin):
530
+ @override
531
+ def process_messages(
532
+ self,
533
+ messages: list[dict[str, str]],
534
+ images: list["ImageInput"],
535
+ videos: list["VideoInput"],
536
+ audios: list["AudioInput"],
537
+ processor: Optional["MMProcessor"],
538
+ ) -> list[dict[str, str]]:
539
+ self._validate_input(processor, images, videos, audios)
540
+ self._validate_messages(messages, images, videos, audios)
541
+ messages = deepcopy(messages)
542
+ boi_token: str = getattr(processor, "boi_token")
543
+ boa_token: str = getattr(processor, "boa_token")
544
+ full_image_sequence: str = getattr(processor, "full_image_sequence")
545
+ full_audio_sequence: str = getattr(processor, "full_audio_sequence")
546
+ image_str = full_image_sequence if self.expand_mm_tokens else boi_token
547
+ audio_str = full_audio_sequence if self.expand_mm_tokens else boa_token
548
+
549
+ for message in messages:
550
+ content = message["content"]
551
+ while IMAGE_PLACEHOLDER in content:
552
+ content = content.replace(IMAGE_PLACEHOLDER, image_str, 1)
553
+
554
+ while AUDIO_PLACEHOLDER in content:
555
+ content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)
556
+
557
+ message["content"] = content
558
+
559
+ return messages
560
+
561
+
562
+ @dataclass
563
+ class InternVLPlugin(BasePlugin):
564
+ @override
565
+ def _get_mm_inputs(
566
+ self,
567
+ images: list["ImageInput"],
568
+ videos: list["VideoInput"],
569
+ audios: list["AudioInput"],
570
+ processor: "ProcessorMixin",
571
+ **kwargs,
572
+ ) -> dict[str, "torch.Tensor"]:
573
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
574
+ image_processor_kwargs = {}
575
+ if getattr(processor, "crop_to_patches", False):
576
+ image_processor_kwargs.update(
577
+ {
578
+ "crop_to_patches": True,
579
+ "max_patches": 12,
580
+ "min_patches": 1,
581
+ }
582
+ )
583
+
584
+ mm_inputs = {}
585
+ image_video_patches = []
586
+
587
+ if len(images) != 0:
588
+ images = self._regularize_images(
589
+ images,
590
+ image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
591
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
592
+ )["images"]
593
+
594
+ if len(videos) != 0:
595
+ videos = self._regularize_videos(
596
+ videos,
597
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
598
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
599
+ video_fps=getattr(processor, "video_fps", 2.0),
600
+ video_maxlen=getattr(processor, "video_maxlen", 128),
601
+ )["videos"]
602
+
603
+ if len(images) != 0:
604
+ images = make_flat_list_of_images(images)
605
+ image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs)
606
+ image_num_patches = image_inputs.pop("num_patches")
607
+ image_pixel_values = image_inputs.pop("pixel_values")
608
+ image_num_patches_indices = np.cumsum(image_num_patches)
609
+
610
+ if len(videos) != 0:
611
+ videos = make_batched_videos(videos)
612
+ num_frames_per_video = [len(video) for video in videos]
613
+ patch_indices = np.cumsum(num_frames_per_video)
614
+ image_processor_kwargs["crop_to_patches"] = False
615
+ video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs)
616
+ video_num_patches = video_inputs.pop("num_patches")
617
+ video_pixel_values = video_inputs.pop("pixel_values")
618
+ video_num_patches_indices = np.cumsum(video_num_patches)
619
+
620
+ # NOTE: interleaved image and video inputs are not supported
621
+ if len(images) != 0 and image_pixel_values is not None:
622
+ for i in range(len(images)):
623
+ start_index = image_num_patches_indices[i - 1] if i > 0 else 0
624
+ end_index = image_num_patches_indices[i]
625
+ image_video_patches.append(image_pixel_values[start_index:end_index])
626
+
627
+ if len(videos) != 0 and video_pixel_values is not None:
628
+ patch_indices_with_prefix = [0] + list(patch_indices)
629
+ for i in range(len(videos)):
630
+ current_patch_index = patch_indices_with_prefix[i]
631
+ end_patch_index = patch_indices_with_prefix[i + 1]
632
+ start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
633
+ end_index = video_num_patches_indices[end_patch_index - 1]
634
+ image_video_patches.append(video_pixel_values[start_index:end_index])
635
+
636
+ if len(images) != 0 or len(videos) != 0:
637
+ mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0)
638
+
639
+ if len(images) != 0:
640
+ mm_inputs.update({"image_num_patches": image_num_patches})
641
+
642
+ if len(videos) != 0:
643
+ mm_inputs.update({"video_patch_indices": patch_indices})
644
+ mm_inputs.update({"video_num_patches": video_num_patches})
645
+
646
+ return mm_inputs
647
+
648
+ @override
649
+ def process_messages(
650
+ self,
651
+ messages: list[dict[str, str]],
652
+ images: list["ImageInput"],
653
+ videos: list["VideoInput"],
654
+ audios: list["AudioInput"],
655
+ processor: Optional["ProcessorMixin"],
656
+ ) -> list[dict[str, str]]:
657
+ self._validate_input(processor, images, videos, audios)
658
+ self._validate_messages(messages, images, videos, audios)
659
+ num_image_tokens, num_video_tokens = 0, 0
660
+ image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
661
+ messages = deepcopy(messages)
662
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
663
+
664
+ image_pixel_patch_list = mm_inputs.get("image_num_patches")  # number of patches per image
666
+ video_num_patches = mm_inputs.get("video_num_patches")  # num patches for each video frame
667
+ video_patch_indices = mm_inputs.get("video_patch_indices")  # cumulative frame counts per video
667
+
668
+ for message in messages:
669
+ content = message["content"]
670
+ while IMAGE_PLACEHOLDER in content:
671
+ content = content.replace(
672
+ IMAGE_PLACEHOLDER,
673
+ f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
674
+ 1,
675
+ )
676
+ num_image_tokens += 1
677
+
678
+ while VIDEO_PLACEHOLDER in content:
679
+ current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
680
+ end_patch_index = video_patch_indices[num_video_tokens]
681
+ num_patches = list(video_num_patches[current_patch_index:end_patch_index])
682
+ video_replaced_prompt = "\n".join(
683
+ f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
684
+ for i in range(len(num_patches))
685
+ )
686
+ content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
687
+ num_video_tokens += 1
688
+
689
+ message["content"] = content
690
+
691
+ return messages
692
+
693
+ @override
694
+ def get_mm_inputs(
695
+ self,
696
+ images: list["ImageInput"],
697
+ videos: list["VideoInput"],
698
+ audios: list["AudioInput"],
699
+ imglens: list[int],
700
+ vidlens: list[int],
701
+ audlens: list[int],
702
+ batch_ids: list[list[int]],
703
+ processor: Optional["ProcessorMixin"],
704
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
705
+ self._validate_input(processor, images, videos, audios)
706
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
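+ # drop the patch bookkeeping keys below; they were only needed to expand placeholder tokens in process_messages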
707
+ mm_inputs.pop("image_num_patches", None)
708
+ mm_inputs.pop("video_patch_indices", None)
709
+ mm_inputs.pop("video_num_patches", None)
710
+ return mm_inputs
711
+
712
+
713
+ class KimiVLPlugin(BasePlugin):
714
+ @override
715
+ def process_messages(self, messages, images, videos, audios, processor):
716
+ self._validate_input(processor, images, videos, audios)
717
+ self._validate_messages(messages, images, videos, audios)
718
+ if self.expand_mm_tokens:
719
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
720
+ image_grid_hws = mm_inputs.get("image_grid_hws", [])
721
+ else:
722
+ image_grid_hws = [None] * len(images)
723
+
724
+ num_image_tokens = 0
725
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
726
+ merge_length = math.prod(image_processor.merge_kernel_size)
727
+ messages = deepcopy(messages)
728
+ for message in messages:
729
+ content = message["content"]
730
+ while IMAGE_PLACEHOLDER in content:
731
+ image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
732
+ content = content.replace(
733
+ IMAGE_PLACEHOLDER,
734
+ f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>",
735
+ 1,
736
+ )
737
+ num_image_tokens += 1
738
+
739
+ message["content"] = content
740
+
741
+ return messages
742
+
743
+
744
+ @dataclass
745
+ class Llama4Plugin(BasePlugin):
746
+ @override
747
+ def process_messages(
748
+ self,
749
+ messages: list[dict[str, str]],
750
+ images: list["ImageInput"],
751
+ videos: list["VideoInput"],
752
+ audios: list["AudioInput"],
753
+ processor: Optional["MMProcessor"],
754
+ ) -> list[dict[str, str]]:
755
+ self._validate_input(processor, images, videos, audios)
756
+ self._validate_messages(messages, images, videos, audios)
757
+ if self.expand_mm_tokens:
758
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
759
+ if "pixel_values" in mm_inputs:
760
+ image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
761
+ num_patches_per_chunk = int(
762
+ (image_height // processor.patch_size)
763
+ * (image_width // processor.patch_size)
764
+ // processor.downsample_ratio
765
+ )
766
+ aspect_ratios = mm_inputs.pop("aspect_ratios")
767
+
768
+ num_image_tokens = 0
769
+ messages = deepcopy(messages)
770
+ for message in messages:
771
+ content = message["content"]
772
+ if self.expand_mm_tokens:
773
+ placeholder_count = content.count(IMAGE_PLACEHOLDER)
774
+ prompt_splits = content.split(IMAGE_PLACEHOLDER)
775
+ new_content = []
776
+ for local_image_index, split_part in enumerate(prompt_splits):
777
+ new_content.append(split_part)
778
+ if local_image_index < placeholder_count:
779
+ tokens_for_this_image = processor._prompt_split_image(
780
+ aspect_ratios[num_image_tokens], num_patches_per_chunk
781
+ )
782
+ num_image_tokens += 1
783
+ new_content.append(tokens_for_this_image)
784
+
785
+ content = "".join(new_content)
786
+ else:
787
+ content = content.replace(IMAGE_PLACEHOLDER, self.image_token)
788
+
789
+ message["content"] = content
790
+
791
+ return messages
792
+
793
+ @override
794
+ def get_mm_inputs(
795
+ self,
796
+ images: list["ImageInput"],
797
+ videos: list["VideoInput"],
798
+ audios: list["AudioInput"],
799
+ imglens: list[int],
800
+ vidlens: list[int],
801
+ audlens: list[int],
802
+ batch_ids: list[list[int]],
803
+ processor: Optional["MMProcessor"],
804
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
805
+ self._validate_input(processor, images, videos, audios)
806
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
807
+ mm_inputs.pop("aspect_ratios", None)
808
+ return mm_inputs
809
+
810
+
811
+ @dataclass
812
+ class LlavaPlugin(BasePlugin):
813
+ @override
814
+ def process_messages(
815
+ self,
816
+ messages: list[dict[str, str]],
817
+ images: list["ImageInput"],
818
+ videos: list["VideoInput"],
819
+ audios: list["AudioInput"],
820
+ processor: Optional["MMProcessor"],
821
+ ) -> list[dict[str, str]]:
822
+ self._validate_input(processor, images, videos, audios)
823
+ self._validate_messages(messages, images, videos, audios)
824
+ messages = deepcopy(messages)
825
+ if self.expand_mm_tokens:
826
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
827
+ if "pixel_values" in mm_inputs:
828
+ height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
829
+ image_seqlen = (height // processor.patch_size) * (
830
+ width // processor.patch_size
831
+ ) + processor.num_additional_image_tokens
832
+ if processor.vision_feature_select_strategy == "default":
833
+ image_seqlen -= 1
834
+ else:
835
+ image_seqlen = 1
836
+
837
+ for message in messages:
838
+ content = message["content"]
839
+ while IMAGE_PLACEHOLDER in content:
840
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
841
+
842
+ message["content"] = content.replace("{{image}}", self.image_token)
843
+
844
+ return messages
845
+
846
+
847
+ @dataclass
848
+ class LlavaNextPlugin(BasePlugin):
849
+ @override
850
+ def process_messages(
851
+ self,
852
+ messages: list[dict[str, str]],
853
+ images: list["ImageInput"],
854
+ videos: list["VideoInput"],
855
+ audios: list["AudioInput"],
856
+ processor: Optional["MMProcessor"],
857
+ ) -> list[dict[str, str]]:
858
+ self._validate_input(processor, images, videos, audios)
859
+ self._validate_messages(messages, images, videos, audios)
860
+ num_image_tokens = 0
861
+ messages = deepcopy(messages)
862
+ if self.expand_mm_tokens:
863
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
864
+ if "pixel_values" in mm_inputs:
865
+ image_sizes = iter(mm_inputs["image_sizes"].tolist())
866
+ height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
867
+
868
+ for message in messages:
869
+ content = message["content"]
870
+ while IMAGE_PLACEHOLDER in content:
871
+ if self.expand_mm_tokens:
872
+ orig_height, orig_width = next(image_sizes)
873
+ image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
874
+ if processor.vision_feature_select_strategy == "default":
875
+ image_seqlen -= 1
876
+ else:
877
+ image_seqlen = 1
878
+
879
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
880
+ num_image_tokens += 1
881
+
882
+ message["content"] = content.replace("{{image}}", self.image_token)
883
+
884
+ return messages
885
+
886
+
887
+ @dataclass
888
+ class LlavaNextVideoPlugin(BasePlugin):
889
+ @override
890
+ def process_messages(
891
+ self,
892
+ messages: list[dict[str, str]],
893
+ images: list["ImageInput"],
894
+ videos: list["VideoInput"],
895
+ audios: list["AudioInput"],
896
+ processor: Optional["MMProcessor"],
897
+ ) -> list[dict[str, str]]:
898
+ self._validate_input(processor, images, videos, audios)
899
+ self._validate_messages(messages, images, videos, audios)
900
+ messages = deepcopy(messages)
901
+ if self.expand_mm_tokens:
902
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
903
+ if "pixel_values" in mm_inputs:
904
+ image_sizes = iter(mm_inputs["image_sizes"].tolist())
905
+ height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
906
+
907
+ for message in messages:
908
+ content = message["content"]
909
+ while IMAGE_PLACEHOLDER in content:
910
+ if self.expand_mm_tokens:
911
+ orig_height, orig_width = next(image_sizes)
912
+ image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
913
+ if processor.vision_feature_select_strategy == "default":
914
+ image_seqlen -= 1
915
+ else:
916
+ image_seqlen = 1
917
+
918
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
919
+
920
+ message["content"] = content.replace("{{image}}", self.image_token)
921
+
922
+ if self.expand_mm_tokens:
923
+ if "pixel_values_videos" in mm_inputs:
924
+ one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
925
+ height, width = get_image_size(one_video[0])
926
+ num_frames = one_video.shape[0] # frame dim is always after batch dim
927
+ image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
928
+ video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
929
+ else:
930
+ video_seqlen = 1
931
+
932
+ for message in messages:
933
+ content = message["content"]
934
+ while VIDEO_PLACEHOLDER in content:
935
+ content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
936
+
937
+ message["content"] = content.replace("{{video}}", self.video_token)
938
+
939
+ return messages
940
+
941
+
942
+ @dataclass
943
+ class MiniCPMVPlugin(BasePlugin):
944
+ @override
945
+ def _get_mm_inputs(
946
+ self,
947
+ images: list["ImageInput"],
948
+ videos: list["VideoInput"],
949
+ audios: list["AudioInput"],
950
+ processor: "MMProcessor",
951
+ **kwargs,
952
+ ) -> dict[str, "torch.Tensor"]:
953
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
954
+ mm_inputs = {}
955
+ if len(images) != 0:
956
+ images = self._regularize_images(
957
+ images,
958
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
959
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
960
+ )["images"]
961
+ if "valid_image_nums_ls" in kwargs:
962
+ valid_image_nums_ls = kwargs["valid_image_nums_ls"]
963
+ new_images = []
964
+ idx = 0
965
+ for valid_image_nums in valid_image_nums_ls:
966
+ new_images.append(images[idx : idx + valid_image_nums])
967
+ idx += valid_image_nums
968
+
969
+ images = new_images
970
+
971
+ image_inputs = image_processor(
972
+ images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
973
+ )
974
+ mm_inputs.update(image_inputs)
975
+
976
+ if len(videos) != 0:
977
+ videos = self._regularize_videos(
978
+ videos,
979
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
980
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
981
+ video_fps=getattr(processor, "video_fps", 2.0),
982
+ video_maxlen=getattr(processor, "video_maxlen", 128),
983
+ )["videos"]
984
+ video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
985
+ mm_inputs.update(video_inputs)
986
+
987
+ if len(audios) != 0:
988
+ audios = self._regularize_audios(
989
+ audios,
990
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
991
+ )["audios"]
992
+ if "valid_audio_nums_ls" in kwargs:
993
+ valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
994
+ audios_ls = []
995
+ idx = 0
996
+ for valid_audio_nums in valid_audio_nums_ls:
997
+ audios_ls.append(audios[idx : idx + valid_audio_nums])
998
+ idx += valid_audio_nums
999
+ else:
1000
+ audios_ls = [audios]
1001
+
1002
+ audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
1003
+ audios_ls,
1004
+ chunk_input=True,
1005
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
1006
+ )
1007
+ audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
1008
+ mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
1009
+ if kwargs.get("ret_phs", False):
1010
+ mm_inputs.update({"audio_phs": audio_phs})
1011
+
1012
+ return mm_inputs
1013
+
1014
+ @override
1015
+ def process_messages(
1016
+ self,
1017
+ messages: list[dict[str, str]],
1018
+ images: list["ImageInput"],
1019
+ videos: list["VideoInput"],
1020
+ audios: list["AudioInput"],
1021
+ processor: Optional["MMProcessor"],
1022
+ ) -> list[dict[str, str]]:
1023
+ self._validate_input(processor, images, videos, audios)
1024
+ self._validate_messages(messages, images, videos, audios)
1025
+ num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
1026
+ messages = deepcopy(messages)
1027
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
1028
+ mm_inputs, audio_inputs = {}, {}
1029
+ if len(images) != 0 and len(videos) != 0:
1030
+ raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
1031
+
1032
+ if len(videos) != 0:
1033
+ max_slice_nums = 2
1034
+ use_image_id = False
1035
+ mm_inputs = self._get_mm_inputs([], videos, [], processor)
1036
+ else:
1037
+ max_slice_nums = image_processor.max_slice_nums
1038
+ use_image_id = image_processor.use_image_id
1039
+
1040
+ for i, message in enumerate(messages):
1041
+ content = message["content"]
1042
+ while IMAGE_PLACEHOLDER in content:
1043
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
1044
+ num_image_tokens += 1
1045
+
1046
+ while VIDEO_PLACEHOLDER in content:
1047
+ video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
1048
+ content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
1049
+ num_video_tokens += 1
1050
+
1051
+ while AUDIO_PLACEHOLDER in content:
1052
+ content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
1053
+ num_audio_tokens += 1
1054
+
1055
+ message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
1056
+ "{{audio}}", "(<audio>./</audio>)"
1057
+ )
1058
+
1059
+ if len(images):
1060
+ mm_inputs = self._get_mm_inputs(images, [], [], processor)
1061
+
1062
+ if len(audios):
1063
+ audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
1064
+
1065
+ if self.expand_mm_tokens and mm_inputs:
1066
+ pattern = "(<image>./</image>)"
1067
+ image_sizes = mm_inputs["image_sizes"]
1068
+ idx = 0
1069
+ for index, message in enumerate(messages):
1070
+ text = message["content"]
1071
+ image_tags = re.findall(pattern, text)
1072
+ text_chunks = text.split(pattern)
1073
+ final_text = ""
1074
+ for i in range(len(image_tags)):
1075
+ final_text = (
1076
+ final_text
1077
+ + text_chunks[i]
1078
+ + image_processor.get_slice_image_placeholder(
1079
+ image_sizes[0][idx], idx, max_slice_nums, use_image_id
1080
+ )
1081
+ )
1082
+ idx += 1
1083
+
1084
+ final_text += text_chunks[-1]
1085
+ messages[index]["content"] = final_text
1086
+
1087
+ if self.expand_mm_tokens and audio_inputs:
1088
+ pattern = "(<audio>./</audio>)"
1089
+ idx = 0
1090
+ for index, message in enumerate(messages):
1091
+ text = message["content"]
1092
+ audio_tags = re.findall(pattern, text)
1093
+ text_chunks = text.split(pattern)
1094
+ final_text = ""
1095
+ for i in range(len(audio_tags)):
1096
+ audio_placeholder = audio_inputs["audio_phs"][0][idx]
1097
+ final_text = final_text + text_chunks[i] + audio_placeholder
1098
+ idx += 1
1099
+
1100
+ final_text += text_chunks[-1]
1101
+ messages[index]["content"] = final_text
1102
+
1103
+ return messages
1104
+
1105
+ @override
1106
+ def get_mm_inputs(
1107
+ self,
1108
+ images: list["ImageInput"],
1109
+ videos: list["VideoInput"],
1110
+ audios: list["AudioInput"],
1111
+ imglens: list[int],
1112
+ vidlens: list[int],
1113
+ audlens: list[int],
1114
+ batch_ids: list[list[int]],
1115
+ processor: Optional["MMProcessor"],
1116
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
1117
+ self._validate_input(processor, images, videos, audios)
1118
+ # image bound
1119
+ image_bounds_list = []
1120
+ valid_image_nums_ls = []
1121
+ for i, input_ids in enumerate(batch_ids):
1122
+ input_ids_ = torch.tensor(input_ids)
1123
+ start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
1124
+ input_ids_ == processor.tokenizer.slice_start_id
1125
+ )
1126
+ end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
1127
+ image_start_tokens = torch.where(start_cond)[0]
1128
+ image_start_tokens += 1
1129
+ image_end_tokens = torch.where(end_cond)[0]
1130
+ valid_image_nums_ls.append(imglens[i])
1131
+ image_bounds = torch.hstack(
1132
+ [
1133
+ image_start_tokens.unsqueeze(-1),
1134
+ image_end_tokens.unsqueeze(-1),
1135
+ ]
1136
+ )
1137
+ image_bounds_list.append(image_bounds)
1138
+
1139
+ mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
1140
+ if "tgt_sizes" not in mm_inputs:
1141
+ dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
1142
+ mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})
1143
+
1144
+ mm_inputs.update({"image_bound": image_bounds_list})
1145
+
1146
+ if len(audios) > 0:
1147
+ # audio bound
1148
+ audio_bounds_ls = []
1149
+ spk_bounds_ls = []
1150
+ valid_audio_nums_ls = []
1151
+
1152
+ for input_ids, audiolen in zip(batch_ids, audlens):
1153
+ input_ids_ = torch.tensor(input_ids)
1154
+ audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
1155
+ audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
1156
+ assert len(audio_start_idx) == len(audio_end_idx)
1157
+ audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
1158
+ audio_bounds_ls.append(audio_bounds)
1159
+ valid_audio_nums_ls.append(audiolen)
1160
+
1161
+ spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
1162
+ spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
1163
+ assert len(spk_start_idx) == len(spk_end_idx)
1164
+ spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
1165
+ spk_bounds_ls.append(spk_bounds)
1166
+
1167
+ audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
1168
+ mm_inputs.update(audio_inputs)
1169
+ mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
1170
+
1171
+ return mm_inputs
1172
+
1173
+
1174
+ @dataclass
1175
+ class MllamaPlugin(BasePlugin):
1176
+ @override
1177
+ def process_messages(
1178
+ self,
1179
+ messages: list[dict[str, str]],
1180
+ images: list["ImageInput"],
1181
+ videos: list["VideoInput"],
1182
+ audios: list["AudioInput"],
1183
+ processor: Optional["MMProcessor"],
1184
+ ) -> list[dict[str, str]]:
1185
+ self._validate_input(processor, images, videos, audios)
1186
+ self._validate_messages(messages, images, videos, audios)
1187
+ num_image_tokens = 0
1188
+ messages = deepcopy(messages)
1189
+ for message in messages:
1190
+ content = message["content"]
1191
+ num_image_tokens += content.count(IMAGE_PLACEHOLDER)
1192
+ message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
1193
+
1194
+ return messages
1195
+
1196
+ @override
1197
+ def get_mm_inputs(
1198
+ self,
1199
+ images: list["ImageInput"],
1200
+ videos: list["VideoInput"],
1201
+ audios: list["AudioInput"],
1202
+ imglens: list[int],
1203
+ vidlens: list[int],
1204
+ audlens: list[int],
1205
+ batch_ids: list[list[int]],
1206
+ processor: Optional["MMProcessor"],
1207
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
1208
+ self._validate_input(processor, images, videos, audios)
1209
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
1210
+ if mm_inputs:
1211
+ num_tiles = mm_inputs.pop("num_tiles")
1212
+ image_token_id: int = getattr(processor, "image_token_id")
1213
+ max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
1214
+ cross_attention_token_mask = [
1215
+ get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
1216
+ ]
1217
+ mm_inputs["cross_attention_mask"] = torch.from_numpy(
1218
+ convert_sparse_cross_attention_mask_to_dense(
1219
+ cross_attention_token_mask,
1220
+ num_tiles=num_tiles,
1221
+ max_num_tiles=max_image_tiles,
1222
+ length=max(len(input_ids) for input_ids in batch_ids),
1223
+ )
1224
+ ) # shape: (batch_size, length, max_num_images, max_num_tiles)
1225
+
1226
+ return mm_inputs
1227
+
1228
+
1229
+ @dataclass
1230
+ class PaliGemmaPlugin(BasePlugin):
1231
+ @override
1232
+ def process_messages(
1233
+ self,
1234
+ messages: list[dict[str, str]],
1235
+ images: list["ImageInput"],
1236
+ videos: list["VideoInput"],
1237
+ audios: list["AudioInput"],
1238
+ processor: Optional["MMProcessor"],
1239
+ ) -> list[dict[str, str]]:
1240
+ self._validate_input(processor, images, videos, audios)
1241
+ self._validate_messages(messages, images, videos, audios)
1242
+ num_image_tokens = 0
1243
+ messages = deepcopy(messages)
1244
+ for message in messages:
1245
+ content = message["content"]
1246
+ while IMAGE_PLACEHOLDER in content:
1247
+ content = content.replace(IMAGE_PLACEHOLDER, "", 1)
1248
+ num_image_tokens += 1
1249
+
1250
+ message["content"] = content
1251
+
1252
+ return messages
1253
+
1254
+ @override
1255
+ def process_token_ids(
1256
+ self,
1257
+ input_ids: list[int],
1258
+ labels: Optional[list[int]],
1259
+ images: list["ImageInput"],
1260
+ videos: list["VideoInput"],
1261
+ audios: list["AudioInput"],
1262
+ tokenizer: "PreTrainedTokenizer",
1263
+ processor: Optional["MMProcessor"],
1264
+ ) -> tuple[list[int], Optional[list[int]]]:
1265
+ self._validate_input(processor, images, videos, audios)
1266
+ num_images = len(images)
1267
+ image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
1268
+ image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
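+ # PaliGemma prepends the image tokens to the prompt; their labels are masked with IGNORE_INDEX below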
1269
+ input_ids = [image_token_id] * num_images * image_seqlen + input_ids
1270
+ if labels is not None:
1271
+ labels = [IGNORE_INDEX] * num_images * image_seqlen + labels
1272
+
1273
+ return input_ids, labels
1274
+
1275
+ @override
1276
+ def get_mm_inputs(
1277
+ self,
1278
+ images: list["ImageInput"],
1279
+ videos: list["VideoInput"],
1280
+ audios: list["AudioInput"],
1281
+ imglens: list[int],
1282
+ vidlens: list[int],
1283
+ audlens: list[int],
1284
+ batch_ids: list[list[int]],
1285
+ processor: Optional["MMProcessor"],
1286
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
1287
+ self._validate_input(processor, images, videos, audios)
1288
+ seqlens = [len(input_ids) for input_ids in batch_ids]
1289
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
1290
+ mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
1291
+ return mm_inputs
1292
+
1293
+
1294
+ @dataclass
1295
+ class PixtralPlugin(BasePlugin):
1296
+ @override
1297
+ def process_messages(
1298
+ self,
1299
+ messages: list[dict[str, str]],
1300
+ images: list["ImageInput"],
1301
+ videos: list["VideoInput"],
1302
+ audios: list["AudioInput"],
1303
+ processor: Optional["MMProcessor"],
1304
+ ) -> list[dict[str, str]]:
1305
+ self._validate_input(processor, images, videos, audios)
1306
+ self._validate_messages(messages, images, videos, audios)
1307
+ messages = deepcopy(messages)
1308
+ if self.expand_mm_tokens:
1309
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
1310
+ if "pixel_values" in mm_inputs:
1311
+ # BC for transformers < 4.49.0
1312
+ if isinstance(mm_inputs["image_sizes"], list):
1313
+ image_sizes = iter(mm_inputs["image_sizes"][0])
1314
+ else:
1315
+ image_sizes = iter(mm_inputs["image_sizes"].tolist())
1316
+
1317
+ image_break_token: str = getattr(processor, "image_break_token")
1318
+ image_end_token: str = getattr(processor, "image_end_token")
1319
+
1320
+ for message in messages:
1321
+ content = message["content"]
1322
+ while IMAGE_PLACEHOLDER in content:
1323
+ if self.expand_mm_tokens:
1324
+ patch_size = processor.patch_size * getattr(processor, "spatial_merge_size", 1)
1325
+ height, width = next(image_sizes)
1326
+ num_height_tokens = height // patch_size
1327
+ num_width_tokens = width // patch_size
1328
+ replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
1329
+ replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
1330
+ replace_tokens[-1] = image_end_token
1331
+ replace_str = "".join(replace_tokens)
1332
+ else:
1333
+ replace_str = self.image_token
1334
+
1335
+ content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
1336
+
1337
+ message["content"] = content
1338
+
1339
+ return messages
1340
+
1341
+ @override
1342
+ def get_mm_inputs(
1343
+ self,
1344
+ images: list["ImageInput"],
1345
+ videos: list["VideoInput"],
1346
+ audios: list["AudioInput"],
1347
+ imglens: list[int],
1348
+ vidlens: list[int],
1349
+ audlens: list[int],
1350
+ batch_ids: list[list[int]],
1351
+ processor: Optional["MMProcessor"],
1352
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
1353
+ self._validate_input(processor, images, videos, audios)
1354
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
1355
+ # see this commit: https://github.com/huggingface/transformers/pull/35122
1356
+ # since transformers 4.49.0, `image_sizes` is a mandatory input parameter for the Pixtral vision encoder's forward pass.
1357
+ # it can be passed into `LlavaConditionalGeneration` as a parameter.
1358
+ if not is_transformers_version_greater_than("4.49.0"):
1359
+ mm_inputs.pop("image_sizes", None)
1360
+ return mm_inputs
1361
+
1362
+
1363
+ @dataclass
1364
+ class Qwen2AudioPlugin(BasePlugin):
1365
+ @override
1366
+ def process_messages(
1367
+ self,
1368
+ messages: list[dict[str, str]],
1369
+ images: list["ImageInput"],
1370
+ videos: list["VideoInput"],
1371
+ audios: list["AudioInput"],
1372
+ processor: Optional["MMProcessor"],
1373
+ ) -> list[dict[str, str]]:
1374
+ self._validate_input(processor, images, videos, audios)
1375
+ self._validate_messages(messages, images, videos, audios)
1376
+ bos_token: str = getattr(processor, "audio_bos_token")
1377
+ eos_token: str = getattr(processor, "audio_eos_token")
1378
+ messages = deepcopy(messages)
1379
+ if self.expand_mm_tokens:
1380
+ mm_inputs = self._get_mm_inputs([], [], audios, processor)
1381
+ if "feature_attention_mask" in mm_inputs:
1382
+ audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
1383
+
1384
+ for message in messages:
1385
+ content = message["content"]
1386
+ while AUDIO_PLACEHOLDER in content:
1387
+ if self.expand_mm_tokens:
1388
+ audio_length = audio_lengths.pop(0)
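+ # two successive halvings of the raw feature length give the number of audio tokens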
1389
+ input_length = (audio_length - 1) // 2 + 1
1390
+ audio_seqlen = (input_length - 2) // 2 + 1
1391
+ else:
1392
+ audio_seqlen = 1
1393
+
1394
+ content = content.replace(
1395
+ AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
1396
+ )
1397
+
1398
+ message["content"] = content
1399
+
1400
+ return messages
1401
+
1402
+ @override
1403
+ def get_mm_inputs(
1404
+ self,
1405
+ images: list["ImageInput"],
1406
+ videos: list["VideoInput"],
1407
+ audios: list["AudioInput"],
1408
+ imglens: list[int],
1409
+ vidlens: list[int],
1410
+ audlens: list[int],
1411
+ batch_ids: list[list[int]],
1412
+ processor: Optional["MMProcessor"],
1413
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
1414
+ self._validate_input(processor, images, videos, audios)
1415
+ return self._get_mm_inputs(images, videos, audios, processor)
1416
+
1417
+
1418
+ @dataclass
1419
+ class Qwen2VLPlugin(BasePlugin):
1420
+ vision_bos_token: str = "<|vision_start|>"
1421
+ vision_eos_token: str = "<|vision_end|>"
1422
+
1423
+ @override
1424
+ def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
1425
+ image = super()._preprocess_image(image, **kwargs)
1426
+ if min(image.width, image.height) < 28:
1427
+ width, height = max(image.width, 28), max(image.height, 28)
1428
+ image = image.resize((width, height))
1429
+
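+ # clamp extreme aspect ratios (wider or taller than 200:1) down to roughly 180:1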
1430
+ if image.width / image.height > 200:
1431
+ width, height = image.height * 180, image.height
1432
+ image = image.resize((width, height))
1433
+
1434
+ if image.height / image.width > 200:
1435
+ width, height = image.width, image.width * 180
1436
+ image = image.resize((width, height))
1437
+
1438
+ return image
1439
+
1440
+ @override
1441
+ def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
1442
+ results, fps_per_video, durations = [], [], []
1443
+ for video in videos:
1444
+ frames: list[ImageObject] = []
1445
+ if _check_video_is_nested_images(video):
1446
+ for frame in video:
1447
+ if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
1448
+ raise ValueError("Invalid image found in video frames.")
1449
+
1450
+ frames = video
1451
+ fps_per_video.append(kwargs.get("video_fps", 2.0))
1452
+ durations.append(len(frames) / kwargs.get("video_fps", 2.0))
1453
+ else:
1454
+ container = av.open(video, "r")
1455
+ video_stream = next(stream for stream in container.streams if stream.type == "video")
1456
+ sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
1457
+ container.seek(0)
1458
+ for frame_idx, frame in enumerate(container.decode(video_stream)):
1459
+ if frame_idx in sample_indices:
1460
+ frames.append(frame.to_image())
1461
+
1462
+ if video_stream.duration is None:
1463
+ fps_per_video.append(kwargs.get("video_fps", 2.0))
1464
+ durations.append(len(frames) / kwargs.get("video_fps", 2.0))
1465
+ else:
1466
+ fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
1467
+ durations.append(float(video_stream.duration * video_stream.time_base))
1468
+
1469
+ if len(frames) % 2 != 0:
1470
+ frames.append(frames[-1])
1471
+
1472
+ frames = self._regularize_images(frames, **kwargs)["images"]
1473
+ results.append(frames)
1474
+
1475
+ return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
1476
+
1477
+ @override
1478
+ def _get_mm_inputs(
1479
+ self,
1480
+ images: list["ImageInput"],
1481
+ videos: list["VideoInput"],
1482
+ audios: list["AudioInput"],
1483
+ processor: "MMProcessor",
1484
+ ) -> dict[str, "torch.Tensor"]:
1485
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
1486
+ video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
1487
+ mm_inputs = {}
1488
+ if len(images) != 0:
1489
+ images = self._regularize_images(
1490
+ images,
1491
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
1492
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
1493
+ )["images"]
1494
+ mm_inputs.update(image_processor(images, return_tensors="pt"))
1495
+
1496
+ if len(videos) != 0:
1497
+ video_data = self._regularize_videos(
1498
+ videos,
1499
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
1500
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
1501
+ video_fps=getattr(processor, "video_fps", 2.0),
1502
+ video_maxlen=getattr(processor, "video_maxlen", 128),
1503
+ )
1504
+ mm_inputs.update(video_processor(videos=video_data["videos"], return_tensors="pt"))
1505
+ temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
1506
+ if "second_per_grid_ts" in processor.model_input_names:
1507
+ mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
1508
+
1509
+ return mm_inputs
1510
+
1511
+ @override
1512
+ def process_messages(
1513
+ self,
1514
+ messages: list[dict[str, str]],
1515
+ images: list["ImageInput"],
1516
+ videos: list["VideoInput"],
1517
+ audios: list["AudioInput"],
1518
+ processor: Optional["MMProcessor"],
1519
+ ) -> list[dict[str, str]]:
1520
+ self._validate_input(processor, images, videos, audios)
1521
+ self._validate_messages(messages, images, videos, audios)
1522
+ num_image_tokens, num_video_tokens = 0, 0
1523
+ messages = deepcopy(messages)
1524
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
1525
+
1526
+ merge_length: int = getattr(image_processor, "merge_size") ** 2
1527
+ if self.expand_mm_tokens:
1528
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
1529
+ image_grid_thw = mm_inputs.get("image_grid_thw", [])
1530
+ video_grid_thw = mm_inputs.get("video_grid_thw", [])
1531
+ else:
1532
+ image_grid_thw = [None] * len(images)
1533
+ video_grid_thw = [None] * len(videos)
1534
+
1535
+ for message in messages:
1536
+ content = message["content"]
1537
+ while IMAGE_PLACEHOLDER in content:
1538
+ image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
1539
+ content = content.replace(
1540
+ IMAGE_PLACEHOLDER,
1541
+ f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1542
+ 1,
1543
+ )
1544
+ num_image_tokens += 1
1545
+
1546
+ while VIDEO_PLACEHOLDER in content:
1547
+ video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
1548
+ content = content.replace(
1549
+ VIDEO_PLACEHOLDER,
1550
+ f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
1551
+ 1,
1552
+ )
1553
+ num_video_tokens += 1
1554
+
1555
+ message["content"] = content
1556
+
1557
+ return messages
1558
+
1559
+
1560
+ @dataclass
1561
+ class Qwen3VLPlugin(Qwen2VLPlugin):
1562
+ @override
1563
+ def _get_mm_inputs(
1564
+ self,
1565
+ images: list["ImageInput"],
1566
+ videos: list["VideoInput"],
1567
+ audios: list["AudioInput"],
1568
+ processor: "MMProcessor",
1569
+ ) -> dict[str, "torch.Tensor"]:
1570
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
1571
+ video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
1572
+ mm_inputs = {}
1573
+ if len(images) != 0:
1574
+ images = self._regularize_images(
1575
+ images,
1576
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
1577
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
1578
+ )["images"]
1579
+ mm_inputs.update(image_processor(images, return_tensors="pt"))
1580
+
1581
+ if len(videos) != 0:
1582
+ videos = self._regularize_videos(
1583
+ videos,
1584
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
1585
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
1586
+ video_fps=getattr(processor, "video_fps", 2.0),
1587
+ video_maxlen=getattr(processor, "video_maxlen", 128),
1588
+ )
1589
+ video_metadata = [
1590
+ {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
1591
+ for video, duration in zip(videos["videos"], videos["durations"])
1592
+ ]
1593
+ mm_inputs.update(
1594
+ video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
1595
+ )
1596
+ temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
1597
+ if "second_per_grid_ts" in processor.model_input_names:
1598
+ mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in videos["fps_per_video"]]
1599
+
1600
+ return mm_inputs
1601
+
1602
+ @override
1603
+ def process_messages(
1604
+ self,
1605
+ messages: list[dict[str, str]],
1606
+ images: list["ImageInput"],
1607
+ videos: list["VideoInput"],
1608
+ audios: list["AudioInput"],
1609
+ processor: Optional["MMProcessor"],
1610
+ ) -> list[dict[str, str]]:
1611
+ self._validate_input(processor, images, videos, audios)
1612
+ self._validate_messages(messages, images, videos, audios)
1613
+ num_image_tokens, num_video_tokens = 0, 0
1614
+ messages = deepcopy(messages)
1615
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
1616
+ video_processor: BaseImageProcessor = getattr(processor, "video_processor")
1617
+
1618
+ image_merge_length: int = getattr(image_processor, "merge_size") ** 2
1619
+ video_merge_length: int = getattr(video_processor, "merge_size") ** 2
1620
+ if self.expand_mm_tokens:
1621
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
1622
+ image_grid_thw = mm_inputs.get("image_grid_thw", [])
1623
+ video_grid_thw = mm_inputs.get("video_grid_thw", [])
1624
+ num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard-coded to the first video for now
1625
+ video_metadata = mm_inputs.get("video_metadata", {})
1626
+
1627
+ else:
1628
+ image_grid_thw = [None] * len(images)
1629
+ video_grid_thw = [None] * len(videos)
1630
+ num_frames = 0
1631
+ timestamps = [0]
1632
+
1633
+ for idx, message in enumerate(messages):
1634
+ content = message["content"]
1635
+ while IMAGE_PLACEHOLDER in content:
1636
+ image_seqlen = (
1637
+ image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
1638
+ )
1639
+ content = content.replace(
1640
+ IMAGE_PLACEHOLDER,
1641
+ f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1642
+ 1,
1643
+ )
1644
+ num_image_tokens += 1
1645
+
1646
+ while VIDEO_PLACEHOLDER in content:
1647
+ if self.expand_mm_tokens:
1648
+ metadata = video_metadata[idx]
1649
+ timestamps = processor._calculate_timestamps(
1650
+ metadata.frames_indices,
1651
+ metadata.fps,
1652
+ video_processor.merge_size,
1653
+ )
1654
+ video_structure = ""
1655
+ for frame_index in range(num_frames):
1656
+ video_seqlen = (
1657
+ video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
1658
+ if self.expand_mm_tokens
1659
+ else 1
1660
+ )
1661
+ timestamp_sec = timestamps[frame_index]
1662
+ frame_structure = (
1663
+ f"<{timestamp_sec:.1f} seconds>"
1664
+ f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
1665
+ )
1666
+ video_structure += frame_structure
1667
+ else:
1668
+ video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
1669
+
1670
+ content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
1671
+ num_video_tokens += 1
1672
+
1673
+ message["content"] = content
1674
+
1675
+ return messages
1676
+
1677
+
1678
+ @dataclass
1679
+ class GLM4VPlugin(Qwen2VLPlugin):
1680
+ @override
1681
+ def _get_mm_inputs(
1682
+ self,
1683
+ images: list["ImageInput"],
1684
+ videos: list["VideoInput"],
1685
+ audios: list["AudioInput"],
1686
+ processor: "MMProcessor",
1687
+ ) -> dict[str, "torch.Tensor"]:
1688
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
1689
+ video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
1690
+ mm_inputs = {}
1691
+ if len(images) != 0:
1692
+ images = self._regularize_images(
1693
+ images,
1694
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
1695
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
1696
+ )["images"]
1697
+ mm_inputs.update(image_processor(images, return_tensors="pt"))
1698
+
1699
+ if len(videos) != 0:
1700
+ video_data = self._regularize_videos(
1701
+ videos,
1702
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
1703
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
1704
+ video_fps=getattr(processor, "video_fps", 2.0),
1705
+ video_maxlen=getattr(processor, "video_maxlen", 128),
1706
+ )
1707
+ # prepare video metadata
1708
+ video_metadata = [
1709
+ {"fps": 2, "duration": duration, "total_frames": len(video)}
1710
+ for video, duration in zip(video_data["videos"], video_data["durations"])
1711
+ ]
1712
+ mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
1713
+
1714
+ return mm_inputs
1715
+
1716
+ @override
1717
+ def process_messages(
1718
+ self,
1719
+ messages: list[dict[str, str]],
1720
+ images: list["ImageInput"],
1721
+ videos: list["VideoInput"],
1722
+ audios: list["AudioInput"],
1723
+ processor: Optional["MMProcessor"],
1724
+ ) -> list[dict[str, str]]:
1725
+ self._validate_input(processor, images, videos, audios)
1726
+ self._validate_messages(messages, images, videos, audios)
1727
+ num_image_tokens, num_video_tokens = 0, 0
1728
+ messages = deepcopy(messages)
1729
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
1730
+
1731
+ merge_length: int = getattr(image_processor, "merge_size") ** 2
1732
+ if self.expand_mm_tokens:
1733
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
1734
+ image_grid_thw = mm_inputs.get("image_grid_thw", [])
1735
+ video_grid_thw = mm_inputs.get("video_grid_thw", [])
1736
+ num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard-coded to the first video for now
1737
+ timestamps = mm_inputs.get("timestamps", [])
1738
+
1739
+ if hasattr(timestamps, "tolist"):
1740
+ timestamps = timestamps.tolist()
1741
+
1742
+ if not timestamps:
1743
+ timestamps_list = []
1744
+ elif isinstance(timestamps[0], list):
1745
+ timestamps_list = timestamps[0]
1746
+ else:
1747
+ timestamps_list = timestamps
1748
+
1749
+ unique_timestamps = timestamps_list.copy()
1750
+ selected_timestamps = unique_timestamps[:num_frames]
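+ # pad with the last timestamp (or 0) so every sampled frame gets a timestamp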
1751
+ while len(selected_timestamps) < num_frames:
1752
+ selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
1753
+
1754
+ else:
1755
+ image_grid_thw = [None] * len(images)
1756
+ video_grid_thw = [None] * len(videos)
1757
+ num_frames = 0
1758
+ selected_timestamps = [0]
1759
+
1760
+ for message in messages:
1761
+ content = message["content"]
1762
+ while IMAGE_PLACEHOLDER in content:
1763
+ image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
1764
+ content = content.replace(
1765
+ IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * image_seqlen}<|end_of_image|>", 1
1766
+ )
1767
+ num_image_tokens += 1
1768
+
1769
+ while VIDEO_PLACEHOLDER in content:
1770
+ video_structure = ""
1771
+ for frame_index in range(num_frames):
1772
+ video_seqlen = (
1773
+ video_grid_thw[num_video_tokens][1:].prod() // merge_length if self.expand_mm_tokens else 1
1774
+ )
1775
+ timestamp_sec = selected_timestamps[frame_index]
1776
+ frame_structure = (
1777
+ f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}"
1778
+ )
1779
+ video_structure += frame_structure
1780
+
1781
+ if not self.expand_mm_tokens:
1782
+ video_structure = self.video_token
1783
+
1784
+ content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1)
1785
+ num_video_tokens += 1
1786
+
1787
+ message["content"] = content
1788
+
1789
+ return messages
1790
+
1791
+ @override
1792
+ def get_mm_inputs(
1793
+ self,
1794
+ images: list["ImageInput"],
1795
+ videos: list["VideoInput"],
1796
+ audios: list["AudioInput"],
1797
+ imglens: list[int],
1798
+ vidlens: list[int],
1799
+ audlens: list[int],
1800
+ batch_ids: list[list[int]],
1801
+ processor: Optional["ProcessorMixin"],
1802
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
1803
+ self._validate_input(processor, images, videos, audios)
1804
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
1805
+ mm_inputs.pop("timestamps", None)
1806
+ return mm_inputs
1807
+
1808
+
1809
+ @dataclass
1810
+ class Qwen2OmniPlugin(Qwen2VLPlugin):
1811
+ audio_bos_token: str = "<|audio_start|>"
1812
+ audio_eos_token: str = "<|audio_end|>"
1813
+
1814
+ @override
1815
+ def _get_mm_inputs(
1816
+ self,
1817
+ images: list["ImageInput"],
1818
+ videos: list["VideoInput"],
1819
+ audios: list["AudioInput"],
1820
+ processor: "MMProcessor",
1821
+ ) -> dict[str, "torch.Tensor"]:
1822
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
1823
+ video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
1824
+ feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
1825
+ mm_inputs = {}
1826
+ if len(images) != 0:
1827
+ images = self._regularize_images(
1828
+ images,
1829
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
1830
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
1831
+ )["images"]
1832
+ mm_inputs.update(image_processor(images, return_tensors="pt"))
1833
+
1834
+ if len(videos) != 0:
1835
+ video_dict = self._regularize_videos(
1836
+ videos,
1837
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
1838
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
1839
+ video_fps=getattr(processor, "video_fps", 2.0),
1840
+ video_maxlen=getattr(processor, "video_maxlen", 128),
1841
+ )
1842
+ mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt"))
1843
+ temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
1844
+ mm_inputs["video_second_per_grid"] = torch.tensor(
1845
+ [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
1846
+ )
1847
+
1848
+ if len(audios) != 0:
1849
+ audios = self._regularize_audios(
1850
+ audios,
1851
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
1852
+ )["audios"]
1853
+ mm_inputs.update(
1854
+ feature_extractor(
1855
+ audios,
1856
+ sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
1857
+ return_attention_mask=True,
1858
+ padding="max_length",
1859
+ return_tensors="pt",
1860
+ )
1861
+ )
1862
+ mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
1863
+
1864
+ return mm_inputs
1865
+
1866
+ @override
1867
+ def process_messages(
1868
+ self,
1869
+ messages: list[dict[str, str]],
1870
+ images: list["ImageInput"],
1871
+ videos: list["VideoInput"],
1872
+ audios: list["AudioInput"],
1873
+ processor: Optional["MMProcessor"],
1874
+ ) -> list[dict[str, str]]:
1875
+ self._validate_input(processor, images, videos, audios)
1876
+ self._validate_messages(messages, images, videos, audios)
1877
+ num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
1878
+ messages = deepcopy(messages)
1879
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
1880
+
1881
+ merge_length = processor.image_processor.merge_size**2
1882
+ use_audio_in_video = getattr(processor, "use_audio_in_video", False)
1883
+ if self.expand_mm_tokens:
1884
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
1885
+ image_grid_thw = mm_inputs.get("image_grid_thw", [])
1886
+ video_grid_thw = mm_inputs.get("video_grid_thw", [])
1887
+ if "feature_attention_mask" in mm_inputs:
1888
+ if processor.__class__.__name__ == "Qwen3OmniMoeProcessor": # for qwen3omni
1889
+ input_lengths = mm_inputs["feature_attention_mask"].sum(-1)
1890
+ input_lengths_leave = input_lengths % 100
1891
+ feature_lengths = (input_lengths_leave - 1) // 2 + 1
1892
+ audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
1893
+ else:
1894
+ input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
1895
+ audio_lengths = (input_lengths - 2) // 2 + 1
1896
+ else:
1897
+ mm_inputs = {}
1898
+ image_grid_thw = [None] * len(images)
1899
+ video_grid_thw = [None] * len(videos)
1900
+ audio_lengths = [None] * len(audios)
1901
+
1902
+ for message in messages:
1903
+ content = message["content"]
1904
+ while IMAGE_PLACEHOLDER in content:
1905
+ image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
1906
+ content = content.replace(
1907
+ IMAGE_PLACEHOLDER,
1908
+ f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
1909
+ 1,
1910
+ )
1911
+ num_image_tokens += 1
1912
+
1913
+ if (
1914
+ use_audio_in_video and len(audios) and len(videos)
1915
+ ): # when using the video's audio track, handle video and audio tokens together
1916
+ if len(videos) != len(audios):
1917
+ raise ValueError(
1918
+ f"Number of videos ({len(videos)}) must match number of audios ({len(audios)}) when using audio in video."
1919
+ )
1920
+
1921
+ while VIDEO_PLACEHOLDER in content:
1922
+ video_pos = content.find(VIDEO_PLACEHOLDER)
1923
+ audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
1924
+ if audio_pos == -1 or audio_pos < video_pos:
1925
+ raise ValueError(
1926
+ f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
1927
+ )
1928
+
1929
+ audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
1930
+ video_t_index = (
1931
+ torch.arange(video_grid_thw[num_video_tokens][0])
1932
+ .view(-1, 1, 1)
1933
+ .expand(
1934
+ -1,
1935
+ video_grid_thw[num_video_tokens][1] // image_processor.merge_size,
1936
+ video_grid_thw[num_video_tokens][2] // image_processor.merge_size,
1937
+ )
1938
+ .flatten()
1939
+ * mm_inputs["video_second_per_grid"][num_video_tokens]
1940
+ * 25 # FIXME: hard-coded position_id_per_seconds=25
1941
+ ).long()
1942
+ t_ntoken_per_chunk = 50 # FIXME: hard-coded as 25 * 2
1943
+ video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
1944
+ audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
1945
+ placeholder_string = ""
1946
+ placeholder_string += self.vision_bos_token + self.audio_bos_token
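+ # interleave video and audio tokens chunk by chunk along the shared time axis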
1947
+ for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
1948
+ video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
1949
+ audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
1950
+ if video_chunk_index is not None:
1951
+ placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
1952
+
1953
+ if audio_chunk_index is not None:
1954
+ placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
1955
+
1956
+ placeholder_string += self.audio_eos_token + self.vision_eos_token
1957
+ content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
1958
+ content = content.replace(AUDIO_PLACEHOLDER, "", 1)
1959
+ num_audio_tokens += 1
1960
+ num_video_tokens += 1
1961
+ else:
1962
+ while AUDIO_PLACEHOLDER in content:
1963
+ audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
1964
+ content = content.replace(
1965
+ AUDIO_PLACEHOLDER,
1966
+ f"{self.audio_bos_token}{self.audio_token * audio_seqlen}{self.audio_eos_token}",
1967
+ 1,
1968
+ )
1969
+ num_audio_tokens += 1
1970
+
1971
+ while VIDEO_PLACEHOLDER in content:
1972
+ video_seqlen = (
1973
+ video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
1974
+ )
1975
+ content = content.replace(
1976
+ VIDEO_PLACEHOLDER,
1977
+ f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
1978
+ 1,
1979
+ )
1980
+ num_video_tokens += 1
1981
+
1982
+ message["content"] = content
1983
+
1984
+ return messages
1985
+
1986
+
1987
+ @dataclass
1988
+ class VideoLlavaPlugin(BasePlugin):
1989
+ @override
1990
+ def process_messages(
1991
+ self,
1992
+ messages: list[dict[str, str]],
1993
+ images: list["ImageInput"],
1994
+ videos: list["VideoInput"],
1995
+ audios: list["AudioInput"],
1996
+ processor: Optional["MMProcessor"],
1997
+ ) -> list[dict[str, str]]:
1998
+ self._validate_input(processor, images, videos, audios)
1999
+ self._validate_messages(messages, images, videos, audios)
2000
+ num_image_tokens, num_video_tokens = 0, 0
2001
+ messages = deepcopy(messages)
2002
+ num_frames = 0
2003
+ if self.expand_mm_tokens:
2004
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
2005
+ if "pixel_values_images" in mm_inputs:
2006
+ height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))
2007
+ num_frames = 1
2008
+
2009
+ if "pixel_values_videos" in mm_inputs:
2010
+ one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
2011
+ height, width = get_image_size(one_video[0])
2012
+ num_frames = one_video.shape[0] # frame dim is always after batch dim
2013
+
2014
+ if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs:
2015
+ image_seqlen = (height // processor.patch_size) * (
2016
+ width // processor.patch_size
2017
+ ) + processor.num_additional_image_tokens
2018
+ video_seqlen = image_seqlen * num_frames
2019
+ if processor.vision_feature_select_strategy == "default":
2020
+ image_seqlen -= 1
2021
+ else:
2022
+ image_seqlen, video_seqlen = 1, 1
2023
+
2024
+ for message in messages:
2025
+ content = message["content"]
2026
+ while IMAGE_PLACEHOLDER in content:
2027
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
2028
+ num_image_tokens += 1
2029
+
2030
+ while VIDEO_PLACEHOLDER in content:
2031
+ content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
2032
+ num_video_tokens += 1
2033
+
2034
+ content = content.replace("{{image}}", self.image_token)
2035
+ message["content"] = content.replace("{{video}}", self.video_token)
2036
+
2037
+ return messages
2038
+
2039
+
2040
+ PLUGINS = {
2041
+ "base": BasePlugin,
2042
+ "gemma3": Gemma3Plugin,
2043
+ "glm4v": GLM4VPlugin,
2044
+ "gemma3n": Gemma3nPlugin,
2045
+ "intern_vl": InternVLPlugin,
2046
+ "kimi_vl": KimiVLPlugin,
2047
+ "llama4": Llama4Plugin,
2048
+ "llava": LlavaPlugin,
2049
+ "llava_next": LlavaNextPlugin,
2050
+ "llava_next_video": LlavaNextVideoPlugin,
2051
+ "minicpm_v": MiniCPMVPlugin,
2052
+ "mllama": MllamaPlugin,
2053
+ "paligemma": PaliGemmaPlugin,
2054
+ "pixtral": PixtralPlugin,
2055
+ "qwen2_audio": Qwen2AudioPlugin,
2056
+ "qwen2_omni": Qwen2OmniPlugin,
2057
+ "qwen2_vl": Qwen2VLPlugin,
2058
+ "qwen3_vl": Qwen3VLPlugin,
2059
+ "video_llava": VideoLlavaPlugin,
2060
+ }
2061
+
2062
+
2063
+ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
2064
+ r"""Register a multimodal plugin."""
2065
+ if name in PLUGINS:
2066
+ raise ValueError(f"Multimodal plugin {name} already exists.")
2067
+
2068
+ PLUGINS[name] = plugin_class
2069
+
2070
+
2071
+ def get_mm_plugin(
2072
+ name: str,
2073
+ image_token: Optional[str] = None,
2074
+ video_token: Optional[str] = None,
2075
+ audio_token: Optional[str] = None,
2076
+ **kwargs,
2077
+ ) -> "BasePlugin":
2078
+ r"""Get plugin for multimodal inputs."""
2079
+ if name not in PLUGINS:
2080
+ raise ValueError(f"Multimodal plugin `{name}` not found.")
2081
+
2082
+ return PLUGINS[name](image_token, video_token, audio_token, **kwargs)
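For context, a minimal sketch of how these two helpers fit together; the plugin name `my_plugin` and the token string below are hypothetical and not part of this upload:

    from llamafactory.data.mm_plugin import BasePlugin, get_mm_plugin, register_mm_plugin

    class MyPlugin(BasePlugin):
        # hypothetical plugin that simply reuses the default BasePlugin behavior
        pass

    register_mm_plugin("my_plugin", MyPlugin)  # raises ValueError if the name is already registered
    plugin = get_mm_plugin("my_plugin", image_token="<image>")  # -> a MyPlugin instance holding the token strings
    # downstream preprocessing can then call plugin.process_messages(...) and plugin.get_mm_inputs(...)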
llamafactory/data/parser.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from dataclasses import dataclass
18
+ from typing import Any, Literal, Optional, Union
19
+
20
+ from huggingface_hub import hf_hub_download
21
+
22
+ from ..extras.constants import DATA_CONFIG
23
+ from ..extras.misc import use_modelscope, use_openmind
24
+
25
+
26
+ @dataclass
27
+ class DatasetAttr:
28
+ r"""Dataset attributes."""
29
+
30
+ # basic configs
31
+ load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "cloud_file", "file"]
+     dataset_name: str
+     formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
+     ranking: bool = False
+     # extra configs
+     subset: Optional[str] = None
+     split: str = "train"
+     folder: Optional[str] = None
+     num_samples: Optional[int] = None
+     # common columns
+     system: Optional[str] = None
+     tools: Optional[str] = None
+     images: Optional[str] = None
+     videos: Optional[str] = None
+     audios: Optional[str] = None
+     # dpo columns
+     chosen: Optional[str] = None
+     rejected: Optional[str] = None
+     kto_tag: Optional[str] = None
+     # alpaca columns
+     prompt: Optional[str] = "instruction"
+     query: Optional[str] = "input"
+     response: Optional[str] = "output"
+     history: Optional[str] = None
+     # sharegpt columns
+     messages: Optional[str] = "conversations"
+     # sharegpt tags
+     role_tag: Optional[str] = "from"
+     content_tag: Optional[str] = "value"
+     user_tag: Optional[str] = "human"
+     assistant_tag: Optional[str] = "gpt"
+     observation_tag: Optional[str] = "observation"
+     function_tag: Optional[str] = "function_call"
+     system_tag: Optional[str] = "system"
+
+     def __repr__(self) -> str:
+         return self.dataset_name
+
+     def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
+         setattr(self, key, obj.get(key, default))
+
+     def join(self, attr: dict[str, Any]) -> None:
+         self.set_attr("formatting", attr, default="alpaca")
+         self.set_attr("ranking", attr, default=False)
+         self.set_attr("subset", attr)
+         self.set_attr("split", attr, default="train")
+         self.set_attr("folder", attr)
+         self.set_attr("num_samples", attr)
+
+         if "columns" in attr:
+             column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
+             column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
+             for column_name in column_names:
+                 self.set_attr(column_name, attr["columns"])
+
+         if "tags" in attr:
+             tag_names = ["role_tag", "content_tag"]
+             tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
+             for tag in tag_names:
+                 self.set_attr(tag, attr["tags"])
+
+
+ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]:
+     r"""Get the attributes of the datasets."""
+     if dataset_names is None:
+         dataset_names = []
+
+     if isinstance(dataset_dir, dict):
+         dataset_info = dataset_dir
+     elif dataset_dir == "ONLINE":
+         dataset_info = None
+     else:
+         if dataset_dir.startswith("REMOTE:"):
+             config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
+         else:
+             config_path = os.path.join(dataset_dir, DATA_CONFIG)
+
+         try:
+             with open(config_path) as f:
+                 dataset_info = json.load(f)
+         except Exception as err:
+             if len(dataset_names) != 0:
+                 raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
+
+             dataset_info = None
+
+     dataset_list: list[DatasetAttr] = []
+     for name in dataset_names:
+         if dataset_info is None:  # dataset_dir is ONLINE
+             load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
+             dataset_attr = DatasetAttr(load_from, dataset_name=name)
+             dataset_list.append(dataset_attr)
+             continue
+
+         if name not in dataset_info:
+             raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
+
+         has_hf_url = "hf_hub_url" in dataset_info[name]
+         has_ms_url = "ms_hub_url" in dataset_info[name]
+         has_om_url = "om_hub_url" in dataset_info[name]
+
+         if has_hf_url or has_ms_url or has_om_url:
+             if has_ms_url and (use_modelscope() or not has_hf_url):
+                 dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
+             elif has_om_url and (use_openmind() or not has_hf_url):
+                 dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
+             else:
+                 dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
+         elif "script_url" in dataset_info[name]:
+             dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
+         elif "cloud_file_name" in dataset_info[name]:
+             dataset_attr = DatasetAttr("cloud_file", dataset_name=dataset_info[name]["cloud_file_name"])
+         else:
+             dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
+
+         dataset_attr.join(dataset_info[name])
+         dataset_list.append(dataset_attr)
+
+     return dataset_list
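As a reading aid, a small sketch of how a dataset_info entry flows through get_dataset_list. The dataset name, file name, and column mapping below are made up for illustration; passing a dict as dataset_dir bypasses reading DATA_CONFIG from disk.

from llamafactory.data.parser import get_dataset_list

# Hypothetical entry describing a local alpaca-style file.
dataset_info = {
    "my_sft_data": {
        "file_name": "my_sft_data.json",
        "formatting": "alpaca",
        "columns": {"prompt": "instruction", "query": "input", "response": "output"},
    }
}

attrs = get_dataset_list(["my_sft_data"], dataset_info)
print(attrs[0].load_from)  # "file" (no hub/script/cloud keys are present)
print(attrs[0].prompt)     # "instruction" (filled in by DatasetAttr.join)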
llamafactory/data/processor/__init__.py ADDED
@@ -0,0 +1,31 @@
+ # Copyright 2025 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .feedback import FeedbackDatasetProcessor
+ from .pairwise import PairwiseDatasetProcessor
+ from .pretrain import PretrainDatasetProcessor
+ from .processor_utils import DatasetProcessor
+ from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor
+ from .unsupervised import UnsupervisedDatasetProcessor
+
+
+ __all__ = [
+     "DatasetProcessor",
+     "FeedbackDatasetProcessor",
+     "PackedSupervisedDatasetProcessor",
+     "PairwiseDatasetProcessor",
+     "PretrainDatasetProcessor",
+     "SupervisedDatasetProcessor",
+     "UnsupervisedDatasetProcessor",
+ ]
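This package only re-exports the concrete processors so callers can import them from one place; which processor is instantiated, and with which arguments, is decided elsewhere (e.g. by the data loader). A trivial sketch, assuming the concrete classes subclass the exported DatasetProcessor base:

from llamafactory.data.processor import DatasetProcessor, SupervisedDatasetProcessor

# Assumption: concrete processors derive from the DatasetProcessor base exported
# from processor_utils; their constructor signatures are not shown in this diff.
assert issubclass(SupervisedDatasetProcessor, DatasetProcessor)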
llamafactory/data/processor/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (630 Bytes).
 
llamafactory/data/processor/__pycache__/feedback.cpython-312.pyc ADDED
Binary file (7.39 kB).