Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- README.md +3 -9
- api.py +33 -0
- llamafactory.egg-info/PKG-INFO +1124 -0
- llamafactory.egg-info/SOURCES.txt +178 -0
- llamafactory.egg-info/dependency_links.txt +1 -0
- llamafactory.egg-info/entry_points.txt +3 -0
- llamafactory.egg-info/requires.txt +125 -0
- llamafactory.egg-info/top_level.txt +1 -0
- llamafactory/__init__.py +31 -0
- llamafactory/__pycache__/__init__.cpython-312.pyc +0 -0
- llamafactory/__pycache__/cli.cpython-312.pyc +0 -0
- llamafactory/__pycache__/launcher.cpython-312.pyc +0 -0
- llamafactory/api/__init__.py +0 -0
- llamafactory/api/app.py +133 -0
- llamafactory/api/chat.py +291 -0
- llamafactory/api/common.py +96 -0
- llamafactory/api/protocol.py +157 -0
- llamafactory/chat/__init__.py +19 -0
- llamafactory/chat/__pycache__/__init__.cpython-312.pyc +0 -0
- llamafactory/chat/__pycache__/base_engine.cpython-312.pyc +0 -0
- llamafactory/chat/__pycache__/chat_model.cpython-312.pyc +0 -0
- llamafactory/chat/base_engine.py +98 -0
- llamafactory/chat/chat_model.py +210 -0
- llamafactory/chat/hf_engine.py +412 -0
- llamafactory/chat/kt_engine.py +284 -0
- llamafactory/chat/sglang_engine.py +289 -0
- llamafactory/chat/vllm_engine.py +263 -0
- llamafactory/cli.py +31 -0
- llamafactory/data/__init__.py +37 -0
- llamafactory/data/__pycache__/__init__.cpython-312.pyc +0 -0
- llamafactory/data/__pycache__/collator.cpython-312.pyc +0 -0
- llamafactory/data/__pycache__/converter.cpython-312.pyc +0 -0
- llamafactory/data/__pycache__/data_utils.cpython-312.pyc +0 -0
- llamafactory/data/__pycache__/formatter.cpython-312.pyc +0 -0
- llamafactory/data/__pycache__/loader.cpython-312.pyc +0 -0
- llamafactory/data/__pycache__/mm_plugin.cpython-312.pyc +0 -0
- llamafactory/data/__pycache__/parser.cpython-312.pyc +0 -0
- llamafactory/data/__pycache__/template.cpython-312.pyc +0 -0
- llamafactory/data/__pycache__/tool_utils.cpython-312.pyc +0 -0
- llamafactory/data/collator.py +331 -0
- llamafactory/data/converter.py +425 -0
- llamafactory/data/data_utils.py +190 -0
- llamafactory/data/formatter.py +145 -0
- llamafactory/data/loader.py +334 -0
- llamafactory/data/mm_plugin.py +2082 -0
- llamafactory/data/parser.py +149 -0
- llamafactory/data/processor/__init__.py +31 -0
- llamafactory/data/processor/__pycache__/__init__.cpython-312.pyc +0 -0
- llamafactory/data/processor/__pycache__/feedback.cpython-312.pyc +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
llamafactory/extras/__pycache__/constants.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,6 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
colorFrom: blue
|
| 5 |
-
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 5.
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: app
|
| 3 |
+
app_file: webui.py
|
|
|
|
|
|
|
| 4 |
sdk: gradio
|
| 5 |
+
sdk_version: 5.45.0
|
|
|
|
|
|
|
| 6 |
---
|
|
|
|
|
|
api.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 the LlamaFactory team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
import uvicorn
|
| 18 |
+
|
| 19 |
+
from llamafactory.api.app import create_app
|
| 20 |
+
from llamafactory.chat import ChatModel
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def main():
|
| 24 |
+
chat_model = ChatModel()
|
| 25 |
+
app = create_app(chat_model)
|
| 26 |
+
api_host = os.getenv("API_HOST", "0.0.0.0")
|
| 27 |
+
api_port = int(os.getenv("API_PORT", "8000"))
|
| 28 |
+
print(f"Visit http://localhost:{api_port}/docs for API document.")
|
| 29 |
+
uvicorn.run(app, host=api_host, port=api_port)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == "__main__":
|
| 33 |
+
main()
|
llamafactory.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,1124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: llamafactory
|
| 3 |
+
Version: 0.9.4.dev0
|
| 4 |
+
Summary: Unified Efficient Fine-Tuning of 100+ LLMs
|
| 5 |
+
Home-page: https://github.com/hiyouga/LLaMA-Factory
|
| 6 |
+
Author: hiyouga
|
| 7 |
+
Author-email: [email protected]
|
| 8 |
+
License: Apache 2.0 License
|
| 9 |
+
Keywords: AI,LLM,GPT,ChatGPT,Llama,Transformer,DeepSeek,Pytorch
|
| 10 |
+
Classifier: Development Status :: 4 - Beta
|
| 11 |
+
Classifier: Intended Audience :: Developers
|
| 12 |
+
Classifier: Intended Audience :: Education
|
| 13 |
+
Classifier: Intended Audience :: Science/Research
|
| 14 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 15 |
+
Classifier: Operating System :: OS Independent
|
| 16 |
+
Classifier: Programming Language :: Python :: 3
|
| 17 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 18 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 19 |
+
Classifier: Programming Language :: Python :: 3.11
|
| 20 |
+
Classifier: Programming Language :: Python :: 3.12
|
| 21 |
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
| 22 |
+
Requires-Python: >=3.9.0
|
| 23 |
+
Description-Content-Type: text/markdown
|
| 24 |
+
License-File: LICENSE
|
| 25 |
+
Requires-Dist: transformers!=4.52.0,<=4.56.2,>=4.49.0; python_version < "3.10"
|
| 26 |
+
Requires-Dist: transformers!=4.52.0,!=4.57.0,<=4.57.1,>=4.49.0; python_version >= "3.10"
|
| 27 |
+
Requires-Dist: datasets<=4.0.0,>=2.16.0
|
| 28 |
+
Requires-Dist: accelerate<=1.11.0,>=1.3.0
|
| 29 |
+
Requires-Dist: peft<=0.17.1,>=0.14.0
|
| 30 |
+
Requires-Dist: trl<=0.9.6,>=0.8.6
|
| 31 |
+
Requires-Dist: gradio<=5.45.0,>=4.38.0
|
| 32 |
+
Requires-Dist: matplotlib>=3.7.0
|
| 33 |
+
Requires-Dist: tyro<0.9.0
|
| 34 |
+
Requires-Dist: einops
|
| 35 |
+
Requires-Dist: numpy<2.0.0
|
| 36 |
+
Requires-Dist: pandas>=2.0.0
|
| 37 |
+
Requires-Dist: scipy
|
| 38 |
+
Requires-Dist: sentencepiece
|
| 39 |
+
Requires-Dist: tiktoken
|
| 40 |
+
Requires-Dist: modelscope>=1.14.0
|
| 41 |
+
Requires-Dist: hf-transfer
|
| 42 |
+
Requires-Dist: safetensors<=0.5.3
|
| 43 |
+
Requires-Dist: fire
|
| 44 |
+
Requires-Dist: omegaconf
|
| 45 |
+
Requires-Dist: packaging
|
| 46 |
+
Requires-Dist: protobuf
|
| 47 |
+
Requires-Dist: pyyaml
|
| 48 |
+
Requires-Dist: pydantic<=2.10.6
|
| 49 |
+
Requires-Dist: uvicorn
|
| 50 |
+
Requires-Dist: fastapi
|
| 51 |
+
Requires-Dist: sse-starlette
|
| 52 |
+
Requires-Dist: av
|
| 53 |
+
Requires-Dist: librosa
|
| 54 |
+
Requires-Dist: propcache!=0.4.0
|
| 55 |
+
Provides-Extra: torch
|
| 56 |
+
Requires-Dist: torch>=2.0.0; extra == "torch"
|
| 57 |
+
Requires-Dist: torchvision>=0.15.0; extra == "torch"
|
| 58 |
+
Provides-Extra: torch-npu
|
| 59 |
+
Requires-Dist: torch==2.7.1; extra == "torch-npu"
|
| 60 |
+
Requires-Dist: torch-npu==2.7.1; extra == "torch-npu"
|
| 61 |
+
Requires-Dist: torchvision==0.22.1; extra == "torch-npu"
|
| 62 |
+
Requires-Dist: decorator; extra == "torch-npu"
|
| 63 |
+
Provides-Extra: metrics
|
| 64 |
+
Requires-Dist: nltk; extra == "metrics"
|
| 65 |
+
Requires-Dist: jieba; extra == "metrics"
|
| 66 |
+
Requires-Dist: rouge-chinese; extra == "metrics"
|
| 67 |
+
Provides-Extra: deepspeed
|
| 68 |
+
Requires-Dist: deepspeed<=0.16.9,>=0.10.0; extra == "deepspeed"
|
| 69 |
+
Provides-Extra: liger-kernel
|
| 70 |
+
Requires-Dist: liger-kernel>=0.5.5; extra == "liger-kernel"
|
| 71 |
+
Provides-Extra: bitsandbytes
|
| 72 |
+
Requires-Dist: bitsandbytes>=0.39.0; extra == "bitsandbytes"
|
| 73 |
+
Provides-Extra: hqq
|
| 74 |
+
Requires-Dist: hqq; extra == "hqq"
|
| 75 |
+
Provides-Extra: eetq
|
| 76 |
+
Requires-Dist: eetq; extra == "eetq"
|
| 77 |
+
Provides-Extra: gptq
|
| 78 |
+
Requires-Dist: optimum>=1.24.0; extra == "gptq"
|
| 79 |
+
Requires-Dist: gptqmodel>=2.0.0; extra == "gptq"
|
| 80 |
+
Provides-Extra: aqlm
|
| 81 |
+
Requires-Dist: aqlm[gpu]>=1.1.0; extra == "aqlm"
|
| 82 |
+
Provides-Extra: vllm
|
| 83 |
+
Requires-Dist: vllm<=0.11.0,>=0.4.3; extra == "vllm"
|
| 84 |
+
Provides-Extra: sglang
|
| 85 |
+
Requires-Dist: sglang[srt]>=0.4.5; extra == "sglang"
|
| 86 |
+
Requires-Dist: transformers==4.51.1; extra == "sglang"
|
| 87 |
+
Provides-Extra: galore
|
| 88 |
+
Requires-Dist: galore-torch; extra == "galore"
|
| 89 |
+
Provides-Extra: apollo
|
| 90 |
+
Requires-Dist: apollo-torch; extra == "apollo"
|
| 91 |
+
Provides-Extra: badam
|
| 92 |
+
Requires-Dist: badam>=1.2.1; extra == "badam"
|
| 93 |
+
Provides-Extra: adam-mini
|
| 94 |
+
Requires-Dist: adam-mini; extra == "adam-mini"
|
| 95 |
+
Provides-Extra: minicpm-v
|
| 96 |
+
Requires-Dist: soundfile; extra == "minicpm-v"
|
| 97 |
+
Requires-Dist: torchvision; extra == "minicpm-v"
|
| 98 |
+
Requires-Dist: torchaudio; extra == "minicpm-v"
|
| 99 |
+
Requires-Dist: vector_quantize_pytorch; extra == "minicpm-v"
|
| 100 |
+
Requires-Dist: vocos; extra == "minicpm-v"
|
| 101 |
+
Requires-Dist: msgpack; extra == "minicpm-v"
|
| 102 |
+
Requires-Dist: referencing; extra == "minicpm-v"
|
| 103 |
+
Requires-Dist: jsonschema_specifications; extra == "minicpm-v"
|
| 104 |
+
Provides-Extra: openmind
|
| 105 |
+
Requires-Dist: openmind; extra == "openmind"
|
| 106 |
+
Provides-Extra: swanlab
|
| 107 |
+
Requires-Dist: swanlab; extra == "swanlab"
|
| 108 |
+
Provides-Extra: fp8
|
| 109 |
+
Requires-Dist: torchao>=0.8.0; extra == "fp8"
|
| 110 |
+
Requires-Dist: accelerate>=1.10.0; extra == "fp8"
|
| 111 |
+
Provides-Extra: fp8-te
|
| 112 |
+
Requires-Dist: transformer_engine[pytorch]>=2.0.0; extra == "fp8-te"
|
| 113 |
+
Requires-Dist: accelerate>=1.10.0; extra == "fp8-te"
|
| 114 |
+
Provides-Extra: fp8-all
|
| 115 |
+
Requires-Dist: torchao>=0.8.0; extra == "fp8-all"
|
| 116 |
+
Requires-Dist: transformer_engine[pytorch]>=2.0.0; extra == "fp8-all"
|
| 117 |
+
Requires-Dist: accelerate>=1.10.0; extra == "fp8-all"
|
| 118 |
+
Provides-Extra: dev
|
| 119 |
+
Requires-Dist: pre-commit; extra == "dev"
|
| 120 |
+
Requires-Dist: ruff; extra == "dev"
|
| 121 |
+
Requires-Dist: pytest; extra == "dev"
|
| 122 |
+
Requires-Dist: build; extra == "dev"
|
| 123 |
+
Dynamic: author
|
| 124 |
+
Dynamic: author-email
|
| 125 |
+
Dynamic: classifier
|
| 126 |
+
Dynamic: description
|
| 127 |
+
Dynamic: description-content-type
|
| 128 |
+
Dynamic: home-page
|
| 129 |
+
Dynamic: keywords
|
| 130 |
+
Dynamic: license
|
| 131 |
+
Dynamic: license-file
|
| 132 |
+
Dynamic: provides-extra
|
| 133 |
+
Dynamic: requires-dist
|
| 134 |
+
Dynamic: requires-python
|
| 135 |
+
Dynamic: summary
|
| 136 |
+
|
| 137 |
+

|
| 138 |
+
|
| 139 |
+
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
| 140 |
+
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
| 141 |
+
[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
|
| 142 |
+
[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
|
| 143 |
+
[](https://pypi.org/project/llamafactory/)
|
| 144 |
+
[](https://scholar.google.com/scholar?cites=12620864006390196564)
|
| 145 |
+
[](https://hub.docker.com/r/hiyouga/llamafactory/tags)
|
| 146 |
+
|
| 147 |
+
[](https://twitter.com/llamafactory_ai)
|
| 148 |
+
[](https://discord.gg/rKfvV9r9FK)
|
| 149 |
+
[](https://github.com/hiyouga/llamafactory-community)
|
| 150 |
+
[](https://blog.llamafactory.net/en/)
|
| 151 |
+
|
| 152 |
+
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
| 153 |
+
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
| 154 |
+
[](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
|
| 155 |
+
[](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
|
| 156 |
+
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
| 157 |
+
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
| 158 |
+
[](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
|
| 159 |
+
|
| 160 |
+
### Used by [Amazon](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/), [NVIDIA](https://developer.nvidia.com/rtx/ai-toolkit), [Aliyun](https://help.aliyun.com/zh/pai/use-cases/fine-tune-a-llama-3-model-with-llama-factory), etc.
|
| 161 |
+
|
| 162 |
+
<div align="center" markdown="1">
|
| 163 |
+
|
| 164 |
+
### Supporters ❤️
|
| 165 |
+
|
| 166 |
+
| <div style="text-align: center;"><a href="https://warp.dev/llama-factory"><img alt="Warp sponsorship" width="400" src="assets/sponsors/warp.jpg"></a><br><a href="https://warp.dev/llama-factory" style="font-size:larger;">Warp, the agentic terminal for developers</a><br><a href="https://warp.dev/llama-factory">Available for MacOS, Linux, & Windows</a> | <a href="https://serpapi.com"><img alt="SerpAPI sponsorship" width="250" src="assets/sponsors/serpapi.svg"> </a> |
|
| 167 |
+
| ---- | ---- |
|
| 168 |
+
|
| 169 |
+
----
|
| 170 |
+
|
| 171 |
+
### Easily fine-tune 100+ large language models with zero-code [CLI](#quickstart) and [Web UI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
|
| 172 |
+
|
| 173 |
+

|
| 174 |
+
|
| 175 |
+
</div>
|
| 176 |
+
|
| 177 |
+
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
|
| 178 |
+
|
| 179 |
+
\[ English | [中文](README_zh.md) \]
|
| 180 |
+
|
| 181 |
+
**Fine-tuning a large language model can be easy as...**
|
| 182 |
+
|
| 183 |
+
https://github.com/user-attachments/assets/3991a3a8-4276-4d30-9cab-4cb0c4b9b99e
|
| 184 |
+
|
| 185 |
+
Start local training:
|
| 186 |
+
- Please refer to [usage](#getting-started)
|
| 187 |
+
|
| 188 |
+
Start cloud training:
|
| 189 |
+
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
| 190 |
+
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
| 191 |
+
- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
|
| 192 |
+
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
|
| 193 |
+
|
| 194 |
+
Read technical notes:
|
| 195 |
+
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
|
| 196 |
+
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
|
| 197 |
+
- **Official Blog**: https://blog.llamafactory.net/en/
|
| 198 |
+
- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
|
| 199 |
+
|
| 200 |
+
> [!NOTE]
|
| 201 |
+
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
|
| 202 |
+
|
| 203 |
+
## Table of Contents
|
| 204 |
+
|
| 205 |
+
- [Features](#features)
|
| 206 |
+
- [Blogs](#blogs)
|
| 207 |
+
- [Changelog](#changelog)
|
| 208 |
+
- [Supported Models](#supported-models)
|
| 209 |
+
- [Supported Training Approaches](#supported-training-approaches)
|
| 210 |
+
- [Provided Datasets](#provided-datasets)
|
| 211 |
+
- [Requirement](#requirement)
|
| 212 |
+
- [Getting Started](#getting-started)
|
| 213 |
+
- [Installation](#installation)
|
| 214 |
+
- [Data Preparation](#data-preparation)
|
| 215 |
+
- [Quickstart](#quickstart)
|
| 216 |
+
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
|
| 217 |
+
- [LLaMA Factory Online](#llama-factory-online)
|
| 218 |
+
- [Build Docker](#build-docker)
|
| 219 |
+
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
|
| 220 |
+
- [Download from ModelScope Hub](#download-from-modelscope-hub)
|
| 221 |
+
- [Download from Modelers Hub](#download-from-modelers-hub)
|
| 222 |
+
- [Use W&B Logger](#use-wb-logger)
|
| 223 |
+
- [Use SwanLab Logger](#use-swanlab-logger)
|
| 224 |
+
- [Projects using LLaMA Factory](#projects-using-llama-factory)
|
| 225 |
+
- [License](#license)
|
| 226 |
+
- [Citation](#citation)
|
| 227 |
+
- [Acknowledgement](#acknowledgement)
|
| 228 |
+
|
| 229 |
+
## Features
|
| 230 |
+
|
| 231 |
+
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
|
| 232 |
+
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
| 233 |
+
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
| 234 |
+
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
|
| 235 |
+
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
|
| 236 |
+
- **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
|
| 237 |
+
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
|
| 238 |
+
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with [vLLM worker](https://github.com/vllm-project/vllm) or [SGLang worker](https://github.com/sgl-project/sglang).
|
| 239 |
+
|
| 240 |
+
### Day-N Support for Fine-Tuning Cutting-Edge Models
|
| 241 |
+
|
| 242 |
+
| Support Date | Model Name |
|
| 243 |
+
| ------------ | -------------------------------------------------------------------- |
|
| 244 |
+
| Day 0 | Qwen3 / Qwen2.5-VL / Gemma 3 / GLM-4.1V / InternLM 3 / MiniCPM-o-2.6 |
|
| 245 |
+
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4 |
|
| 246 |
+
|
| 247 |
+
## Blogs
|
| 248 |
+
|
| 249 |
+
> [!TIP]
|
| 250 |
+
> Now we have a dedicated blog for LLaMA Factory!
|
| 251 |
+
>
|
| 252 |
+
> Website: https://blog.llamafactory.net/en/
|
| 253 |
+
|
| 254 |
+
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
|
| 255 |
+
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
|
| 256 |
+
- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
|
| 257 |
+
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
|
| 258 |
+
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
|
| 259 |
+
|
| 260 |
+
<details><summary>All Blogs</summary>
|
| 261 |
+
|
| 262 |
+
- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
|
| 263 |
+
- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
|
| 264 |
+
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
|
| 265 |
+
- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
|
| 266 |
+
- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
|
| 267 |
+
- [LLaMA Factory: Fine-tuning Llama3 for Role-Playing](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) (Chinese)
|
| 268 |
+
|
| 269 |
+
</details>
|
| 270 |
+
|
| 271 |
+
## Changelog
|
| 272 |
+
|
| 273 |
+
[25/10/26] We support Megatron-core training backend with [**mcore_adapter**](https://github.com/alibaba/ROLL/tree/main/mcore_adapter). See [PR #9237](https://github.com/hiyouga/LLaMA-Factory/pull/9237) to get started.
|
| 274 |
+
|
| 275 |
+
[25/08/22] We supported **[OFT](https://arxiv.org/abs/2306.07280)** and **[OFTv2](https://arxiv.org/abs/2506.19847)**. See [examples](examples/README.md) for usage.
|
| 276 |
+
|
| 277 |
+
[25/08/20] We supported fine-tuning the **[Intern-S1-mini](https://huggingface.co/internlm/Intern-S1-mini)** models. See [PR #8976](https://github.com/hiyouga/LLaMA-Factory/pull/8976) to get started.
|
| 278 |
+
|
| 279 |
+
[25/08/06] We supported fine-tuning the **[GPT-OSS](https://github.com/openai/gpt-oss)** models. See [PR #8826](https://github.com/hiyouga/LLaMA-Factory/pull/8826) to get started.
|
| 280 |
+
|
| 281 |
+
<details><summary>Full Changelog</summary>
|
| 282 |
+
|
| 283 |
+
[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model.
|
| 284 |
+
|
| 285 |
+
[25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.
|
| 286 |
+
|
| 287 |
+
[25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR.
|
| 288 |
+
|
| 289 |
+
[25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.
|
| 290 |
+
|
| 291 |
+
[25/04/14] We supported fine-tuning the **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** and **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** models.
|
| 292 |
+
|
| 293 |
+
[25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.
|
| 294 |
+
|
| 295 |
+
[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
|
| 296 |
+
|
| 297 |
+
[25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference.
|
| 298 |
+
|
| 299 |
+
[25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model.
|
| 300 |
+
|
| 301 |
+
[25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training.
|
| 302 |
+
|
| 303 |
+
[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.
|
| 304 |
+
|
| 305 |
+
[25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks.
|
| 306 |
+
|
| 307 |
+
[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** models.
|
| 308 |
+
|
| 309 |
+
[25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
|
| 310 |
+
|
| 311 |
+
[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.
|
| 312 |
+
|
| 313 |
+
[25/01/14] We supported fine-tuning the **[InternLM 3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
|
| 314 |
+
|
| 315 |
+
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
|
| 316 |
+
|
| 317 |
+
[24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
|
| 318 |
+
|
| 319 |
+
[24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset.
|
| 320 |
+
|
| 321 |
+
[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
|
| 322 |
+
|
| 323 |
+
[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
|
| 324 |
+
|
| 325 |
+
[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
|
| 326 |
+
|
| 327 |
+
[24/08/27] We supported **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
|
| 328 |
+
|
| 329 |
+
[24/08/09] We supported **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
|
| 330 |
+
|
| 331 |
+
[24/07/04] We supported [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
|
| 332 |
+
|
| 333 |
+
[24/06/16] We supported **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
| 334 |
+
|
| 335 |
+
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
| 336 |
+
|
| 337 |
+
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
| 338 |
+
|
| 339 |
+
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `paligemma` template for chat completion.
|
| 340 |
+
|
| 341 |
+
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
| 342 |
+
|
| 343 |
+
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
|
| 344 |
+
|
| 345 |
+
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
|
| 346 |
+
|
| 347 |
+
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
|
| 348 |
+
|
| 349 |
+
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
|
| 350 |
+
|
| 351 |
+
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage.
|
| 352 |
+
|
| 353 |
+
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
|
| 354 |
+
|
| 355 |
+
[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See [examples](examples/README.md) for usage.
|
| 356 |
+
|
| 357 |
+
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
|
| 358 |
+
|
| 359 |
+
[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See [examples](examples/README.md) for usage.
|
| 360 |
+
|
| 361 |
+
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
|
| 362 |
+
|
| 363 |
+
[24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. See [examples](examples/README.md) for usage.
|
| 364 |
+
|
| 365 |
+
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
|
| 366 |
+
|
| 367 |
+
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `use_dora: true` to activate DoRA training.
|
| 368 |
+
|
| 369 |
+
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See [examples](examples/README.md) for usage.
|
| 370 |
+
|
| 371 |
+
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
|
| 372 |
+
|
| 373 |
+
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall_en`.
|
| 374 |
+
|
| 375 |
+
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
|
| 376 |
+
|
| 377 |
+
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
| 378 |
+
|
| 379 |
+
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.
|
| 380 |
+
|
| 381 |
+
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
|
| 382 |
+
|
| 383 |
+
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `shift_attn: true` argument to enable shift short attention.
|
| 384 |
+
|
| 385 |
+
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [examples](examples/README.md) for usage.
|
| 386 |
+
|
| 387 |
+
[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `flash_attn: fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
| 388 |
+
|
| 389 |
+
[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `rope_scaling: linear` argument in training and `rope_scaling: dynamic` argument at inference to extrapolate the position embeddings.
|
| 390 |
+
|
| 391 |
+
[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage.
|
| 392 |
+
|
| 393 |
+
[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode.
|
| 394 |
+
|
| 395 |
+
[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
|
| 396 |
+
|
| 397 |
+
[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
|
| 398 |
+
|
| 399 |
+
[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
|
| 400 |
+
|
| 401 |
+
[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
|
| 402 |
+
|
| 403 |
+
[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
|
| 404 |
+
|
| 405 |
+
[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). See [examples](examples/README.md) for usage.
|
| 406 |
+
|
| 407 |
+
</details>
|
| 408 |
+
|
| 409 |
+
> [!TIP]
|
| 410 |
+
> If you cannot use the latest feature, please pull the latest code and install LLaMA-Factory again.
|
| 411 |
+
|
| 412 |
+
## Supported Models
|
| 413 |
+
|
| 414 |
+
| Model | Model size | Template |
|
| 415 |
+
| ----------------------------------------------------------------- | -------------------------------- | -------------------- |
|
| 416 |
+
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
| 417 |
+
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
| 418 |
+
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
| 419 |
+
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
| 420 |
+
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
| 421 |
+
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
|
| 422 |
+
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
|
| 423 |
+
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
|
| 424 |
+
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
| 425 |
+
| [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
|
| 426 |
+
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
|
| 427 |
+
| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
|
| 428 |
+
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/zai-org) | 9B/32B | glm4/glmz1 |
|
| 429 |
+
| [GLM-4.1V](https://huggingface.co/zai-org) | 9B | glm4v |
|
| 430 |
+
| [GLM-4.5/GLM-4.5V](https://huggingface.co/zai-org) | 106B/355B | glm4_moe/glm4v_moe |
|
| 431 |
+
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
|
| 432 |
+
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
|
| 433 |
+
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
|
| 434 |
+
| [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
|
| 435 |
+
| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
|
| 436 |
+
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
|
| 437 |
+
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
|
| 438 |
+
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
|
| 439 |
+
| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
|
| 440 |
+
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
|
| 441 |
+
| [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
|
| 442 |
+
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
| 443 |
+
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
| 444 |
+
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
| 445 |
+
| [Llama 4](https://huggingface.co/meta-llama) | 109B/402B | llama4 |
|
| 446 |
+
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
|
| 447 |
+
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
| 448 |
+
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
| 449 |
+
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
| 450 |
+
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
|
| 451 |
+
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
|
| 452 |
+
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
|
| 453 |
+
| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
|
| 454 |
+
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
| 455 |
+
| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
|
| 456 |
+
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
| 457 |
+
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
| 458 |
+
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
| 459 |
+
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
|
| 460 |
+
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
|
| 461 |
+
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
|
| 462 |
+
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
| 463 |
+
| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
| 464 |
+
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
| 465 |
+
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
| 466 |
+
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
| 467 |
+
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
| 468 |
+
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
|
| 469 |
+
| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
|
| 470 |
+
| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
|
| 471 |
+
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
|
| 472 |
+
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| 473 |
+
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
|
| 474 |
+
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
| 475 |
+
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
| 476 |
+
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
| 477 |
+
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
| 478 |
+
|
| 479 |
+
> [!NOTE]
|
| 480 |
+
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
| 481 |
+
>
|
| 482 |
+
> If the model has both reasoning and non-reasoning versions, please use the `_nothink` suffix to distinguish between them. For example, `qwen3` and `qwen3_nothink`.
|
| 483 |
+
>
|
| 484 |
+
> Remember to use the **SAME** template in training and inference.
|
| 485 |
+
>
|
| 486 |
+
> \*: You should install the `transformers` from main branch and use `DISABLE_VERSION_CHECK=1` to skip version check.
|
| 487 |
+
>
|
| 488 |
+
> \*\*: You need to install a specific version of `transformers` to use the corresponding model.
|
| 489 |
+
|
| 490 |
+
Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of models we supported.
|
| 491 |
+
|
| 492 |
+
You also can add a custom chat template to [template.py](src/llamafactory/data/template.py).
|
| 493 |
+
|
| 494 |
+
## Supported Training Approaches
|
| 495 |
+
|
| 496 |
+
| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA | OFT | QOFT |
|
| 497 |
+
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
|
| 498 |
+
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 499 |
+
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 500 |
+
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 501 |
+
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 502 |
+
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 503 |
+
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 504 |
+
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 505 |
+
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 506 |
+
|
| 507 |
+
> [!TIP]
|
| 508 |
+
> The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
|
| 509 |
+
|
| 510 |
+
## Provided Datasets
|
| 511 |
+
|
| 512 |
+
<details><summary>Pre-training datasets</summary>
|
| 513 |
+
|
| 514 |
+
- [Wiki Demo (en)](data/wiki_demo.txt)
|
| 515 |
+
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
| 516 |
+
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
|
| 517 |
+
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
| 518 |
+
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
| 519 |
+
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
|
| 520 |
+
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
|
| 521 |
+
- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
|
| 522 |
+
- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
|
| 523 |
+
- [CCI3-HQ (zh)](https://huggingface.co/datasets/BAAI/CCI3-HQ)
|
| 524 |
+
- [CCI3-Data (zh)](https://huggingface.co/datasets/BAAI/CCI3-Data)
|
| 525 |
+
- [CCI4.0-M2-Base-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-Base-v1)
|
| 526 |
+
- [CCI4.0-M2-CoT-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-CoT-v1)
|
| 527 |
+
- [CCI4.0-M2-Extra-v1 (en&zh)](https://huggingface.co/datasets/BAAI/CCI4.0-M2-Extra-v1)
|
| 528 |
+
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
|
| 529 |
+
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
| 530 |
+
|
| 531 |
+
</details>
|
| 532 |
+
|
| 533 |
+
<details><summary>Supervised fine-tuning datasets</summary>
|
| 534 |
+
|
| 535 |
+
- [Identity (en&zh)](data/identity.json)
|
| 536 |
+
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
| 537 |
+
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
|
| 538 |
+
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
| 539 |
+
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
| 540 |
+
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
| 541 |
+
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
| 542 |
+
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
| 543 |
+
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
| 544 |
+
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
| 545 |
+
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
| 546 |
+
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
| 547 |
+
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
| 548 |
+
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
| 549 |
+
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
| 550 |
+
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
| 551 |
+
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
| 552 |
+
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
|
| 553 |
+
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
|
| 554 |
+
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
| 555 |
+
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
| 556 |
+
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
|
| 557 |
+
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
| 558 |
+
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
| 559 |
+
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
| 560 |
+
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
| 561 |
+
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
| 562 |
+
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
| 563 |
+
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
| 564 |
+
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
| 565 |
+
- [Infinity Instruct (zh)](https://huggingface.co/datasets/BAAI/Infinity-Instruct)
|
| 566 |
+
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
| 567 |
+
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
| 568 |
+
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
| 569 |
+
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
| 570 |
+
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
| 571 |
+
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
| 572 |
+
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
| 573 |
+
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
| 574 |
+
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
|
| 575 |
+
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
| 576 |
+
- [OpenO1-SFT (en&zh)](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)
|
| 577 |
+
- [Open-Thoughts (en)](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k)
|
| 578 |
+
- [Open-R1-Math (en)](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k)
|
| 579 |
+
- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
|
| 580 |
+
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
| 581 |
+
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
|
| 582 |
+
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
| 583 |
+
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
| 584 |
+
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
| 585 |
+
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
|
| 586 |
+
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
|
| 587 |
+
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
|
| 588 |
+
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
|
| 589 |
+
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
|
| 590 |
+
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
|
| 591 |
+
|
| 592 |
+
</details>
|
| 593 |
+
|
| 594 |
+
<details><summary>Preference datasets</summary>
|
| 595 |
+
|
| 596 |
+
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
| 597 |
+
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
| 598 |
+
- [COIG-P (zh)](https://huggingface.co/datasets/m-a-p/COIG-P)
|
| 599 |
+
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
|
| 600 |
+
- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
|
| 601 |
+
- [RLAIF-V (en)](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset)
|
| 602 |
+
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
| 603 |
+
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
| 604 |
+
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
| 605 |
+
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
| 606 |
+
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
|
| 607 |
+
|
| 608 |
+
</details>
|
| 609 |
+
|
| 610 |
+
Some datasets require you to confirm access before use, so we recommend logging in to your Hugging Face account with the following commands.
|
| 611 |
+
|
| 612 |
+
```bash
|
| 613 |
+
pip install "huggingface_hub<1.0.0"
|
| 614 |
+
huggingface-cli login
|
| 615 |
+
```
|
| 616 |
+
|
| 617 |
+
## Requirement
|
| 618 |
+
|
| 619 |
+
| Mandatory | Minimum | Recommend |
|
| 620 |
+
| ------------ | ------- | --------- |
|
| 621 |
+
| python | 3.9 | 3.10 |
|
| 622 |
+
| torch | 2.0.0 | 2.6.0 |
|
| 623 |
+
| torchvision | 0.15.0 | 0.21.0 |
|
| 624 |
+
| transformers | 4.49.0 | 4.50.0 |
|
| 625 |
+
| datasets | 2.16.0 | 3.2.0 |
|
| 626 |
+
| accelerate | 0.34.0 | 1.2.1 |
|
| 627 |
+
| peft | 0.14.0 | 0.15.1 |
|
| 628 |
+
| trl | 0.8.6 | 0.9.6 |
|
| 629 |
+
|
| 630 |
+
| Optional | Minimum | Recommend |
|
| 631 |
+
| ------------ | ------- | --------- |
|
| 632 |
+
| CUDA | 11.6 | 12.2 |
|
| 633 |
+
| deepspeed | 0.10.0 | 0.16.4 |
|
| 634 |
+
| bitsandbytes | 0.39.0 | 0.43.1 |
|
| 635 |
+
| vllm | 0.4.3 | 0.8.2 |
|
| 636 |
+
| flash-attn | 2.5.6 | 2.7.2 |
|
| 637 |
+
|
| 638 |
+
### Hardware Requirement
|
| 639 |
+
|
| 640 |
+
\* *estimated*
|
| 641 |
+
|
| 642 |
+
| Method | Bits | 7B | 14B | 30B | 70B | `x`B |
|
| 643 |
+
| ----------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
|
| 644 |
+
| Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
|
| 645 |
+
| Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
|
| 646 |
+
| Freeze/LoRA/GaLore/APOLLO/BAdam/OFT | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
|
| 647 |
+
| QLoRA / QOFT | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
|
| 648 |
+
| QLoRA / QOFT | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
|
| 649 |
+
| QLoRA / QOFT | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
|
| 650 |
+
|
| 651 |
+
## Getting Started
|
| 652 |
+
|
| 653 |
+
### Installation
|
| 654 |
+
|
| 655 |
+
> [!IMPORTANT]
|
| 656 |
+
> Installation is mandatory.
|
| 657 |
+
|
| 658 |
+
#### Install from Source
|
| 659 |
+
|
| 660 |
+
```bash
|
| 661 |
+
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
| 662 |
+
cd LLaMA-Factory
|
| 663 |
+
pip install -e ".[torch,metrics]" --no-build-isolation
|
| 664 |
+
```
|
| 665 |
+
|
| 666 |
+
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev
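For example, several extras can be combined in a single editable install; the exact set you need depends on your hardware and training setup (the combination below is only an illustration):

```bash
# illustrative combination of extras; pick the ones that match your setup
pip install -e ".[torch,metrics,deepspeed,bitsandbytes]" --no-build-isolation
```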
|
| 667 |
+
|
| 668 |
+
#### Install from Docker Image
|
| 669 |
+
|
| 670 |
+
```bash
|
| 671 |
+
docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
|
| 672 |
+
```
|
| 673 |
+
|
| 674 |
+
This image is built on Ubuntu 22.04 (x86\_64), CUDA 12.4, Python 3.11, PyTorch 2.6.0, and Flash-attn 2.7.4.
|
| 675 |
+
|
| 676 |
+
Find the pre-built images: https://hub.docker.com/r/hiyouga/llamafactory/tags
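To fetch an image ahead of time, you can pull it explicitly (the tag below is only an example; check the tags page for the current list):

```bash
docker pull hiyouga/llamafactory:latest
```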
|
| 677 |
+
|
| 678 |
+
Please refer to [build docker](#build-docker) to build the image yourself.
|
| 679 |
+
|
| 680 |
+
<details><summary>Setting up a virtual environment with <b>uv</b></summary>
|
| 681 |
+
|
| 682 |
+
Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):
|
| 683 |
+
|
| 684 |
+
```bash
|
| 685 |
+
uv sync --extra torch --extra metrics --prerelease=allow
|
| 686 |
+
```
|
| 687 |
+
|
| 688 |
+
Run LLaMA-Factory in the isolated environment:
|
| 689 |
+
|
| 690 |
+
```bash
|
| 691 |
+
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
| 692 |
+
```
|
| 693 |
+
|
| 694 |
+
</details>
|
| 695 |
+
|
| 696 |
+
<details><summary>For Windows users</summary>
|
| 697 |
+
|
| 698 |
+
#### Install PyTorch
|
| 699 |
+
|
| 700 |
+
You need to manually install the GPU version of PyTorch on the Windows platform. Please refer to the [official website](https://pytorch.org/get-started/locally/) or use the following commands to install PyTorch with CUDA support:
|
| 701 |
+
|
| 702 |
+
```bash
|
| 703 |
+
pip uninstall torch torchvision torchaudio
|
| 704 |
+
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
|
| 705 |
+
python -c "import torch; print(torch.cuda.is_available())"
|
| 706 |
+
```
|
| 707 |
+
|
| 708 |
+
If you see `True`, then you have successfully installed PyTorch with CUDA support.
|
| 709 |
+
|
| 710 |
+
Try setting `dataloader_num_workers: 0` if you encounter a `Can't pickle local object` error.
|
| 711 |
+
|
| 712 |
+
#### Install BitsAndBytes
|
| 713 |
+
|
| 714 |
+
If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library that supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
|
| 715 |
+
|
| 716 |
+
```bash
|
| 717 |
+
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
| 718 |
+
```
|
| 719 |
+
|
| 720 |
+
#### Install Flash Attention-2
|
| 721 |
+
|
| 722 |
+
To enable FlashAttention-2 on the Windows platform, please use the script from [flash-attention-windows-wheel](https://huggingface.co/lldacing/flash-attention-windows-wheel) to compile and install it yourself.
|
| 723 |
+
|
| 724 |
+
</details>
|
| 725 |
+
|
| 726 |
+
<details><summary>For Ascend NPU users</summary>
|
| 727 |
+
|
| 728 |
+
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
| 729 |
+
|
| 730 |
+
```bash
|
| 731 |
+
# replace the url according to your CANN version and devices
|
| 732 |
+
# install CANN Toolkit
|
| 733 |
+
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
|
| 734 |
+
bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
|
| 735 |
+
|
| 736 |
+
# install CANN Kernels
|
| 737 |
+
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
|
| 738 |
+
bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
|
| 739 |
+
|
| 740 |
+
# set env variables
|
| 741 |
+
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
| 742 |
+
```
|
| 743 |
+
|
| 744 |
+
| Requirement | Minimum | Recommend |
|
| 745 |
+
| ------------ | ------- | -------------- |
|
| 746 |
+
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
| 747 |
+
| torch | 2.1.0 | 2.4.0 |
|
| 748 |
+
| torch-npu | 2.1.0 | 2.4.0.post2 |
|
| 749 |
+
| deepspeed | 0.13.2 | 0.13.2 |
|
| 750 |
+
| vllm-ascend | - | 0.7.3 |
|
| 751 |
+
|
| 752 |
+
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
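A minimal sketch, assuming two NPU devices and the bundled LoRA SFT config:

```bash
# select the first two NPUs and launch LoRA fine-tuning
ASCEND_RT_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```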
|
| 753 |
+
|
| 754 |
+
If you cannot run model inference on NPU devices, try setting `do_sample: false` in the configuration.
|
| 755 |
+
|
| 756 |
+
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
| 757 |
+
|
| 758 |
+
#### Install BitsAndBytes
|
| 759 |
+
|
| 760 |
+
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
|
| 761 |
+
|
| 762 |
+
1. Manually compile bitsandbytes: Refer to [the installation documentation](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU) for the NPU version of bitsandbytes to complete the compilation and installation. The compilation requires a cmake version of at least 3.22.1 and a g++ version of at least 12.x.
|
| 763 |
+
|
| 764 |
+
```bash
|
| 765 |
+
# Install bitsandbytes from source
|
| 766 |
+
# Clone bitsandbytes repo, Ascend NPU backend is currently enabled on multi-backend-refactor branch
|
| 767 |
+
git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git
|
| 768 |
+
cd bitsandbytes/
|
| 769 |
+
|
| 770 |
+
# Install dependencies
|
| 771 |
+
pip install -r requirements-dev.txt
|
| 772 |
+
|
| 773 |
+
# Install the dependencies for the compilation tools. Note that the commands for this step may vary depending on the operating system. The following are provided for reference
|
| 774 |
+
apt-get install -y build-essential cmake
|
| 775 |
+
|
| 776 |
+
# Compile & install
|
| 777 |
+
cmake -DCOMPUTE_BACKEND=npu -S .
|
| 778 |
+
make
|
| 779 |
+
pip install .
|
| 780 |
+
```
|
| 781 |
+
|
| 782 |
+
2. Install transformers from the main branch.
|
| 783 |
+
|
| 784 |
+
```bash
|
| 785 |
+
git clone -b main https://github.com/huggingface/transformers.git
|
| 786 |
+
cd transformers
|
| 787 |
+
pip install .
|
| 788 |
+
```
|
| 789 |
+
|
| 790 |
+
3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml).
|
| 791 |
+
|
| 792 |
+
</details>
|
| 793 |
+
|
| 794 |
+
### Data Preparation
|
| 795 |
+
|
| 796 |
+
Please refer to [data/README.md](data/README.md) for details about the format of the dataset files. You can use datasets from the HuggingFace / ModelScope / Modelers hubs, load datasets from local disk, or specify a path to S3/GCS cloud storage.
|
| 797 |
+
|
| 798 |
+
> [!NOTE]
|
| 799 |
+
> Please update `data/dataset_info.json` to use your custom dataset.
|
| 800 |
+
|
| 801 |
+
You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**, **[DataFlow](https://github.com/OpenDCAI/DataFlow)** and **[GraphGen](https://github.com/open-sciencelab/GraphGen)** to create synthetic data for fine-tuning.
|
| 802 |
+
|
| 803 |
+
### Quickstart
|
| 804 |
+
|
| 805 |
+
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
|
| 806 |
+
|
| 807 |
+
```bash
|
| 808 |
+
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
| 809 |
+
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
| 810 |
+
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
| 811 |
+
```
|
| 812 |
+
|
| 813 |
+
See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
|
| 814 |
+
|
| 815 |
+
> [!TIP]
|
| 816 |
+
> Use `llamafactory-cli help` to show help information.
|
| 817 |
+
>
|
| 818 |
+
> Read [FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614) first if you encounter any problems.
|
| 819 |
+
|
| 820 |
+
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
|
| 821 |
+
|
| 822 |
+
```bash
|
| 823 |
+
llamafactory-cli webui
|
| 824 |
+
```
|
| 825 |
+
|
| 826 |
+
### LLaMA Factory Online
|
| 827 |
+
|
| 828 |
+
Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
|
| 829 |
+
|
| 830 |
+
### Build Docker
|
| 831 |
+
|
| 832 |
+
For CUDA users:
|
| 833 |
+
|
| 834 |
+
```bash
|
| 835 |
+
cd docker/docker-cuda/
|
| 836 |
+
docker compose up -d
|
| 837 |
+
docker compose exec llamafactory bash
|
| 838 |
+
```
|
| 839 |
+
|
| 840 |
+
For Ascend NPU users:
|
| 841 |
+
|
| 842 |
+
```bash
|
| 843 |
+
cd docker/docker-npu/
|
| 844 |
+
docker compose up -d
|
| 845 |
+
docker compose exec llamafactory bash
|
| 846 |
+
```
|
| 847 |
+
|
| 848 |
+
For AMD ROCm users:
|
| 849 |
+
|
| 850 |
+
```bash
|
| 851 |
+
cd docker/docker-rocm/
|
| 852 |
+
docker compose up -d
|
| 853 |
+
docker compose exec llamafactory bash
|
| 854 |
+
```
|
| 855 |
+
|
| 856 |
+
<details><summary>Build without Docker Compose</summary>
|
| 857 |
+
|
| 858 |
+
For CUDA users:
|
| 859 |
+
|
| 860 |
+
```bash
|
| 861 |
+
docker build -f ./docker/docker-cuda/Dockerfile \
|
| 862 |
+
--build-arg PIP_INDEX=https://pypi.org/simple \
|
| 863 |
+
--build-arg EXTRAS=metrics \
|
| 864 |
+
-t llamafactory:latest .
|
| 865 |
+
|
| 866 |
+
docker run -dit --ipc=host --gpus=all \
|
| 867 |
+
-p 7860:7860 \
|
| 868 |
+
-p 8000:8000 \
|
| 869 |
+
--name llamafactory \
|
| 870 |
+
llamafactory:latest
|
| 871 |
+
|
| 872 |
+
docker exec -it llamafactory bash
|
| 873 |
+
```
|
| 874 |
+
|
| 875 |
+
For Ascend NPU users:
|
| 876 |
+
|
| 877 |
+
```bash
|
| 878 |
+
docker build -f ./docker/docker-npu/Dockerfile \
|
| 879 |
+
--build-arg PIP_INDEX=https://pypi.org/simple \
|
| 880 |
+
--build-arg EXTRAS=torch-npu,metrics \
|
| 881 |
+
-t llamafactory:latest .
|
| 882 |
+
|
| 883 |
+
docker run -dit --ipc=host \
|
| 884 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
| 885 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
| 886 |
+
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
|
| 887 |
+
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
| 888 |
+
-p 7860:7860 \
|
| 889 |
+
-p 8000:8000 \
|
| 890 |
+
--device /dev/davinci0 \
|
| 891 |
+
--device /dev/davinci_manager \
|
| 892 |
+
--device /dev/devmm_svm \
|
| 893 |
+
--device /dev/hisi_hdc \
|
| 894 |
+
--name llamafactory \
|
| 895 |
+
llamafactory:latest
|
| 896 |
+
|
| 897 |
+
docker exec -it llamafactory bash
|
| 898 |
+
```
|
| 899 |
+
|
| 900 |
+
For AMD ROCm users:
|
| 901 |
+
|
| 902 |
+
```bash
|
| 903 |
+
docker build -f ./docker/docker-rocm/Dockerfile \
|
| 904 |
+
--build-arg PIP_INDEX=https://pypi.org/simple \
|
| 905 |
+
--build-arg EXTRAS=metrics \
|
| 906 |
+
-t llamafactory:latest .
|
| 907 |
+
|
| 908 |
+
docker run -dit --ipc=host \
|
| 909 |
+
-p 7860:7860 \
|
| 910 |
+
-p 8000:8000 \
|
| 911 |
+
--device /dev/kfd \
|
| 912 |
+
--device /dev/dri \
|
| 913 |
+
--name llamafactory \
|
| 914 |
+
llamafactory:latest
|
| 915 |
+
|
| 916 |
+
docker exec -it llamafactory bash
|
| 917 |
+
```
|
| 918 |
+
|
| 919 |
+
</details>
|
| 920 |
+
|
| 921 |
+
<details><summary>Use Docker volumes</summary>
|
| 922 |
+
|
| 923 |
+
You can uncomment `VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]` in the Dockerfile to use data volumes.
|
| 924 |
+
|
| 925 |
+
When starting the container, use the `-v ./hf_cache:/root/.cache/huggingface` argument to mount a local directory into the container. The following data volumes are available; a usage sketch follows the list.
|
| 926 |
+
|
| 927 |
+
- `hf_cache`: Utilize Hugging Face cache on the host machine.
|
| 928 |
+
- `shared_data`: The directory for storing datasets on the host machine.
|
| 929 |
+
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
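A minimal sketch of starting the CUDA container with all three volumes mounted (the host paths are placeholders):

```bash
docker run -dit --ipc=host --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./shared_data:/app/shared_data \
    -v ./output:/app/output \
    -p 7860:7860 \
    -p 8000:8000 \
    --name llamafactory \
    llamafactory:latest
```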
|
| 930 |
+
|
| 931 |
+
</details>
|
| 932 |
+
|
| 933 |
+
### Deploy with OpenAI-style API and vLLM
|
| 934 |
+
|
| 935 |
+
```bash
|
| 936 |
+
API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
|
| 937 |
+
```
|
| 938 |
+
|
| 939 |
+
> [!TIP]
|
| 940 |
+
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
|
| 941 |
+
>
|
| 942 |
+
> Examples: [Image understanding](scripts/api_example/test_image.py) | [Function calling](scripts/api_example/test_toolcall.py)
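A minimal request sketch, assuming the server is listening on `localhost:8000` and exposes the OpenAI-compatible chat completions route (the model name below is a placeholder):

```bash
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "llama3", "messages": [{"role": "user", "content": "Hello!"}]}'
```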
|
| 943 |
+
|
| 944 |
+
### Download from ModelScope Hub
|
| 945 |
+
|
| 946 |
+
If you have trouble downloading models and datasets from Hugging Face, you can use ModelScope.
|
| 947 |
+
|
| 948 |
+
```bash
|
| 949 |
+
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
| 950 |
+
```
|
| 951 |
+
|
| 952 |
+
Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
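A sketch of overriding `model_name_or_path` on the command line, assuming the bundled LoRA SFT config (alternatively, edit the yaml file directly):

```bash
USE_MODELSCOPE_HUB=1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \
    model_name_or_path=LLM-Research/Meta-Llama-3-8B-Instruct
```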
|
| 953 |
+
|
| 954 |
+
### Download from Modelers Hub
|
| 955 |
+
|
| 956 |
+
You can also use Modelers Hub to download models and datasets.
|
| 957 |
+
|
| 958 |
+
```bash
|
| 959 |
+
export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
|
| 960 |
+
```
|
| 961 |
+
|
| 962 |
+
Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
|
| 963 |
+
|
| 964 |
+
### Use W&B Logger
|
| 965 |
+
|
| 966 |
+
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
|
| 967 |
+
|
| 968 |
+
```yaml
|
| 969 |
+
report_to: wandb
|
| 970 |
+
run_name: test_run # optional
|
| 971 |
+
```
|
| 972 |
+
|
| 973 |
+
Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
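For example, the key can be passed as an environment variable when launching training (the key value below is a placeholder):

```bash
WANDB_API_KEY=xxxxxxxxxxxxxxxx llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```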
|
| 974 |
+
|
| 975 |
+
### Use SwanLab Logger
|
| 976 |
+
|
| 977 |
+
To use [SwanLab](https://github.com/SwanHubX/SwanLab) for logging experimental results, you need to add the following arguments to yaml files.
|
| 978 |
+
|
| 979 |
+
```yaml
|
| 980 |
+
use_swanlab: true
|
| 981 |
+
swanlab_run_name: test_run # optional
|
| 982 |
+
```
|
| 983 |
+
|
| 984 |
+
When launching training tasks, you can log in to SwanLab in three ways:
|
| 985 |
+
|
| 986 |
+
1. Add `swanlab_api_key=<your_api_key>` to the yaml file, and set it to your [API key](https://swanlab.cn/settings).
|
| 987 |
+
2. Set the environment variable `SWANLAB_API_KEY` to your [API key](https://swanlab.cn/settings).
|
| 988 |
+
3. Use the `swanlab login` command to complete the login.
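A minimal sketch of the second option, using a placeholder key:

```bash
export SWANLAB_API_KEY=xxxxxxxxxxxxxxxx
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```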
|
| 989 |
+
|
| 990 |
+
## Projects using LLaMA Factory
|
| 991 |
+
|
| 992 |
+
If you have a project that should be incorporated, please contact us via email or create a pull request.
|
| 993 |
+
|
| 994 |
+
<details><summary>Click to show</summary>
|
| 995 |
+
|
| 996 |
+
1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
|
| 997 |
+
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
|
| 998 |
+
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
|
| 999 |
+
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
|
| 1000 |
+
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
|
| 1001 |
+
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
|
| 1002 |
+
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
|
| 1003 |
+
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
| 1004 |
+
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
| 1005 |
+
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
| 1006 |
+
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
| 1007 |
+
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
| 1008 |
+
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
| 1009 |
+
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
| 1010 |
+
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
|
| 1011 |
+
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
|
| 1012 |
+
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
|
| 1013 |
+
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
|
| 1014 |
+
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
| 1015 |
+
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
| 1016 |
+
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
| 1017 |
+
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
|
| 1018 |
+
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
| 1019 |
+
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
| 1020 |
+
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
| 1021 |
+
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
|
| 1022 |
+
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
| 1023 |
+
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
| 1024 |
+
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
| 1025 |
+
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
|
| 1026 |
+
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
| 1027 |
+
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
|
| 1028 |
+
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
|
| 1029 |
+
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
|
| 1030 |
+
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
|
| 1031 |
+
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
|
| 1032 |
+
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
| 1033 |
+
1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
|
| 1034 |
+
1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
|
| 1035 |
+
1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
|
| 1036 |
+
1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
|
| 1037 |
+
1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
|
| 1038 |
+
1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
|
| 1039 |
+
1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
|
| 1040 |
+
1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
|
| 1041 |
+
1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
|
| 1042 |
+
1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
|
| 1043 |
+
1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
|
| 1044 |
+
1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
|
| 1045 |
+
1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
|
| 1046 |
+
1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
|
| 1047 |
+
1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
|
| 1048 |
+
1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
|
| 1049 |
+
1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
|
| 1050 |
+
1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
|
| 1051 |
+
1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
|
| 1052 |
+
1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
|
| 1053 |
+
1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
|
| 1054 |
+
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
|
| 1055 |
+
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
|
| 1056 |
+
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
|
| 1057 |
+
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
|
| 1058 |
+
1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
|
| 1059 |
+
1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
|
| 1060 |
+
1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
|
| 1061 |
+
1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
|
| 1062 |
+
1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
|
| 1063 |
+
1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
|
| 1064 |
+
1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
|
| 1065 |
+
1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
|
| 1066 |
+
1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
|
| 1067 |
+
1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
|
| 1068 |
+
1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
|
| 1069 |
+
1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
|
| 1070 |
+
1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
|
| 1071 |
+
1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
|
| 1072 |
+
1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
|
| 1073 |
+
1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
|
| 1074 |
+
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
|
| 1075 |
+
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
|
| 1076 |
+
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
|
| 1077 |
+
1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
|
| 1078 |
+
1. Zhang et al. CPsyCoun: A Report-based Multi-turn Dialogue Reconstruction and Evaluation Framework for Chinese Psychological Counseling. ACL 2024. [[paper]](https://aclanthology.org/2024.findings-acl.830.pdf)
|
| 1079 |
+
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
| 1080 |
+
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
| 1081 |
+
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
| 1082 |
+
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
| 1083 |
+
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
|
| 1084 |
+
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generating metadata for Stable Diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
| 1085 |
+
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
|
| 1086 |
+
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
|
| 1087 |
+
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
|
| 1088 |
+
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way to build multi-agent LLM applications, with support for model fine-tuning via LLaMA Factory.
|
| 1089 |
+
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
|
| 1090 |
+
1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention.
|
| 1091 |
+
1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: An o1-like model fine-tuned by NovaSky AI with very small cost.
|
| 1092 |
+
1. **[WeClone](https://github.com/xming521/WeClone)**: One-stop solution for creating your digital avatar from chat logs.
|
| 1093 |
+
1. **[EmoLLM](https://github.com/SmartFlowAI/EmoLLM)**: A project about large language models (LLMs) and mental health.
|
| 1094 |
+
</details>
|
| 1095 |
+
|
| 1096 |
+
## License
|
| 1097 |
+
|
| 1098 |
+
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
| 1099 |
+
|
| 1100 |
+
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Llama 4](https://github.com/meta-llama/llama-models/blob/main/models/llama4/LICENSE) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
| 1101 |
+
|
| 1102 |
+
## Citation
|
| 1103 |
+
|
| 1104 |
+
If this work is helpful, please kindly cite as:
|
| 1105 |
+
|
| 1106 |
+
```bibtex
|
| 1107 |
+
@inproceedings{zheng2024llamafactory,
|
| 1108 |
+
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
|
| 1109 |
+
author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
|
| 1110 |
+
booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
|
| 1111 |
+
address={Bangkok, Thailand},
|
| 1112 |
+
publisher={Association for Computational Linguistics},
|
| 1113 |
+
year={2024},
|
| 1114 |
+
url={http://arxiv.org/abs/2403.13372}
|
| 1115 |
+
}
|
| 1116 |
+
```
|
| 1117 |
+
|
| 1118 |
+
## Acknowledgement
|
| 1119 |
+
|
| 1120 |
+
This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful work.
|
| 1121 |
+
|
| 1122 |
+
## Star History
|
| 1123 |
+
|
| 1124 |
+

|
llamafactory.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,178 @@
|
| 1 |
+
LICENSE
|
| 2 |
+
MANIFEST.in
|
| 3 |
+
README.md
|
| 4 |
+
pyproject.toml
|
| 5 |
+
requirements.txt
|
| 6 |
+
setup.py
|
| 7 |
+
src/llamafactory/__init__.py
|
| 8 |
+
src/llamafactory/cli.py
|
| 9 |
+
src/llamafactory/launcher.py
|
| 10 |
+
src/llamafactory.egg-info/PKG-INFO
|
| 11 |
+
src/llamafactory.egg-info/SOURCES.txt
|
| 12 |
+
src/llamafactory.egg-info/dependency_links.txt
|
| 13 |
+
src/llamafactory.egg-info/entry_points.txt
|
| 14 |
+
src/llamafactory.egg-info/requires.txt
|
| 15 |
+
src/llamafactory.egg-info/top_level.txt
|
| 16 |
+
src/llamafactory/api/__init__.py
|
| 17 |
+
src/llamafactory/api/app.py
|
| 18 |
+
src/llamafactory/api/chat.py
|
| 19 |
+
src/llamafactory/api/common.py
|
| 20 |
+
src/llamafactory/api/protocol.py
|
| 21 |
+
src/llamafactory/chat/__init__.py
|
| 22 |
+
src/llamafactory/chat/base_engine.py
|
| 23 |
+
src/llamafactory/chat/chat_model.py
|
| 24 |
+
src/llamafactory/chat/hf_engine.py
|
| 25 |
+
src/llamafactory/chat/kt_engine.py
|
| 26 |
+
src/llamafactory/chat/sglang_engine.py
|
| 27 |
+
src/llamafactory/chat/vllm_engine.py
|
| 28 |
+
src/llamafactory/data/__init__.py
|
| 29 |
+
src/llamafactory/data/collator.py
|
| 30 |
+
src/llamafactory/data/converter.py
|
| 31 |
+
src/llamafactory/data/data_utils.py
|
| 32 |
+
src/llamafactory/data/formatter.py
|
| 33 |
+
src/llamafactory/data/loader.py
|
| 34 |
+
src/llamafactory/data/mm_plugin.py
|
| 35 |
+
src/llamafactory/data/parser.py
|
| 36 |
+
src/llamafactory/data/template.py
|
| 37 |
+
src/llamafactory/data/tool_utils.py
|
| 38 |
+
src/llamafactory/data/processor/__init__.py
|
| 39 |
+
src/llamafactory/data/processor/feedback.py
|
| 40 |
+
src/llamafactory/data/processor/pairwise.py
|
| 41 |
+
src/llamafactory/data/processor/pretrain.py
|
| 42 |
+
src/llamafactory/data/processor/processor_utils.py
|
| 43 |
+
src/llamafactory/data/processor/supervised.py
|
| 44 |
+
src/llamafactory/data/processor/unsupervised.py
|
| 45 |
+
src/llamafactory/eval/__init__.py
|
| 46 |
+
src/llamafactory/eval/evaluator.py
|
| 47 |
+
src/llamafactory/eval/template.py
|
| 48 |
+
src/llamafactory/extras/__init__.py
|
| 49 |
+
src/llamafactory/extras/constants.py
|
| 50 |
+
src/llamafactory/extras/env.py
|
| 51 |
+
src/llamafactory/extras/logging.py
|
| 52 |
+
src/llamafactory/extras/misc.py
|
| 53 |
+
src/llamafactory/extras/packages.py
|
| 54 |
+
src/llamafactory/extras/ploting.py
|
| 55 |
+
src/llamafactory/hparams/__init__.py
|
| 56 |
+
src/llamafactory/hparams/data_args.py
|
| 57 |
+
src/llamafactory/hparams/evaluation_args.py
|
| 58 |
+
src/llamafactory/hparams/finetuning_args.py
|
| 59 |
+
src/llamafactory/hparams/generating_args.py
|
| 60 |
+
src/llamafactory/hparams/model_args.py
|
| 61 |
+
src/llamafactory/hparams/parser.py
|
| 62 |
+
src/llamafactory/hparams/training_args.py
|
| 63 |
+
src/llamafactory/model/__init__.py
|
| 64 |
+
src/llamafactory/model/adapter.py
|
| 65 |
+
src/llamafactory/model/loader.py
|
| 66 |
+
src/llamafactory/model/patcher.py
|
| 67 |
+
src/llamafactory/model/model_utils/__init__.py
|
| 68 |
+
src/llamafactory/model/model_utils/attention.py
|
| 69 |
+
src/llamafactory/model/model_utils/checkpointing.py
|
| 70 |
+
src/llamafactory/model/model_utils/embedding.py
|
| 71 |
+
src/llamafactory/model/model_utils/ktransformers.py
|
| 72 |
+
src/llamafactory/model/model_utils/kv_cache.py
|
| 73 |
+
src/llamafactory/model/model_utils/liger_kernel.py
|
| 74 |
+
src/llamafactory/model/model_utils/longlora.py
|
| 75 |
+
src/llamafactory/model/model_utils/misc.py
|
| 76 |
+
src/llamafactory/model/model_utils/mod.py
|
| 77 |
+
src/llamafactory/model/model_utils/moe.py
|
| 78 |
+
src/llamafactory/model/model_utils/packing.py
|
| 79 |
+
src/llamafactory/model/model_utils/quantization.py
|
| 80 |
+
src/llamafactory/model/model_utils/rope.py
|
| 81 |
+
src/llamafactory/model/model_utils/unsloth.py
|
| 82 |
+
src/llamafactory/model/model_utils/valuehead.py
|
| 83 |
+
src/llamafactory/model/model_utils/visual.py
|
| 84 |
+
src/llamafactory/third_party/__init__.py
|
| 85 |
+
src/llamafactory/third_party/muon/__init__.py
|
| 86 |
+
src/llamafactory/third_party/muon/muon.py
|
| 87 |
+
src/llamafactory/train/__init__.py
|
| 88 |
+
src/llamafactory/train/callbacks.py
|
| 89 |
+
src/llamafactory/train/fp8_utils.py
|
| 90 |
+
src/llamafactory/train/test_utils.py
|
| 91 |
+
src/llamafactory/train/trainer_utils.py
|
| 92 |
+
src/llamafactory/train/tuner.py
|
| 93 |
+
src/llamafactory/train/dpo/__init__.py
|
| 94 |
+
src/llamafactory/train/dpo/trainer.py
|
| 95 |
+
src/llamafactory/train/dpo/workflow.py
|
| 96 |
+
src/llamafactory/train/ksft/__init__.py
|
| 97 |
+
src/llamafactory/train/ksft/workflow.py
|
| 98 |
+
src/llamafactory/train/kto/__init__.py
|
| 99 |
+
src/llamafactory/train/kto/trainer.py
|
| 100 |
+
src/llamafactory/train/kto/workflow.py
|
| 101 |
+
src/llamafactory/train/mca/__init__.py
|
| 102 |
+
src/llamafactory/train/mca/trainer.py
|
| 103 |
+
src/llamafactory/train/mca/workflow.py
|
| 104 |
+
src/llamafactory/train/ppo/__init__.py
|
| 105 |
+
src/llamafactory/train/ppo/ppo_utils.py
|
| 106 |
+
src/llamafactory/train/ppo/trainer.py
|
| 107 |
+
src/llamafactory/train/ppo/workflow.py
|
| 108 |
+
src/llamafactory/train/pt/__init__.py
|
| 109 |
+
src/llamafactory/train/pt/trainer.py
|
| 110 |
+
src/llamafactory/train/pt/workflow.py
|
| 111 |
+
src/llamafactory/train/rm/__init__.py
|
| 112 |
+
src/llamafactory/train/rm/metric.py
|
| 113 |
+
src/llamafactory/train/rm/trainer.py
|
| 114 |
+
src/llamafactory/train/rm/workflow.py
|
| 115 |
+
src/llamafactory/train/sft/__init__.py
|
| 116 |
+
src/llamafactory/train/sft/metric.py
|
| 117 |
+
src/llamafactory/train/sft/trainer.py
|
| 118 |
+
src/llamafactory/train/sft/workflow.py
|
| 119 |
+
src/llamafactory/v1/__init__.py
|
| 120 |
+
src/llamafactory/v1/launcher.py
|
| 121 |
+
src/llamafactory/v1/config/__init__.py
|
| 122 |
+
src/llamafactory/v1/config/data_args.py
|
| 123 |
+
src/llamafactory/v1/config/model_args.py
|
| 124 |
+
src/llamafactory/v1/config/parser.py
|
| 125 |
+
src/llamafactory/v1/config/sample_args.py
|
| 126 |
+
src/llamafactory/v1/config/training_args.py
|
| 127 |
+
src/llamafactory/v1/core/__init__.py
|
| 128 |
+
src/llamafactory/v1/core/base_trainer.py
|
| 129 |
+
src/llamafactory/v1/core/chat_sampler.py
|
| 130 |
+
src/llamafactory/v1/core/data_engine.py
|
| 131 |
+
src/llamafactory/v1/core/model_engine.py
|
| 132 |
+
src/llamafactory/v1/plugins/__init__.py
|
| 133 |
+
src/llamafactory/v1/plugins/data_plugins/__init__.py
|
| 134 |
+
src/llamafactory/v1/plugins/data_plugins/converter.py
|
| 135 |
+
src/llamafactory/v1/plugins/data_plugins/loader.py
|
| 136 |
+
src/llamafactory/v1/plugins/data_plugins/template.py
|
| 137 |
+
src/llamafactory/v1/plugins/model_plugins/__init__.py
|
| 138 |
+
src/llamafactory/v1/plugins/model_plugins/added_token.py
|
| 139 |
+
src/llamafactory/v1/plugins/model_plugins/peft.py
|
| 140 |
+
src/llamafactory/v1/plugins/model_plugins/kernels/__init__.py
|
| 141 |
+
src/llamafactory/v1/plugins/model_plugins/kernels/constants.py
|
| 142 |
+
src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
|
| 143 |
+
src/llamafactory/v1/plugins/model_plugins/kernels/fa/__init__.py
|
| 144 |
+
src/llamafactory/v1/plugins/model_plugins/kernels/mlp/__init__.py
|
| 145 |
+
src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
|
| 146 |
+
src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
|
| 147 |
+
src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/__init__.py
|
| 148 |
+
src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
|
| 149 |
+
src/llamafactory/v1/plugins/model_plugins/kernels/rope/__init__.py
|
| 150 |
+
src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
|
| 151 |
+
src/llamafactory/v1/plugins/sampler_plugins/__init__.py
|
| 152 |
+
src/llamafactory/v1/plugins/sampler_plugins/vllm.py
|
| 153 |
+
src/llamafactory/v1/plugins/trainer_plugins/__init__.py
|
| 154 |
+
src/llamafactory/v1/plugins/trainer_plugins/distributed/__init__.py
|
| 155 |
+
src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py
|
| 156 |
+
src/llamafactory/v1/trainers/__init__.py
|
| 157 |
+
src/llamafactory/v1/trainers/dpo_trainer.py
|
| 158 |
+
src/llamafactory/v1/trainers/rm_trainer.py
|
| 159 |
+
src/llamafactory/v1/trainers/sft_trainer.py
|
| 160 |
+
src/llamafactory/webui/__init__.py
|
| 161 |
+
src/llamafactory/webui/chatter.py
|
| 162 |
+
src/llamafactory/webui/common.py
|
| 163 |
+
src/llamafactory/webui/control.py
|
| 164 |
+
src/llamafactory/webui/css.py
|
| 165 |
+
src/llamafactory/webui/engine.py
|
| 166 |
+
src/llamafactory/webui/interface.py
|
| 167 |
+
src/llamafactory/webui/locales.py
|
| 168 |
+
src/llamafactory/webui/manager.py
|
| 169 |
+
src/llamafactory/webui/runner.py
|
| 170 |
+
src/llamafactory/webui/components/__init__.py
|
| 171 |
+
src/llamafactory/webui/components/chatbot.py
|
| 172 |
+
src/llamafactory/webui/components/data.py
|
| 173 |
+
src/llamafactory/webui/components/eval.py
|
| 174 |
+
src/llamafactory/webui/components/export.py
|
| 175 |
+
src/llamafactory/webui/components/footer.py
|
| 176 |
+
src/llamafactory/webui/components/infer.py
|
| 177 |
+
src/llamafactory/webui/components/top.py
|
| 178 |
+
src/llamafactory/webui/components/train.py
|
llamafactory.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
|
llamafactory.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
[console_scripts]
|
| 2 |
+
llamafactory-cli = llamafactory.cli:main
|
| 3 |
+
lmf = llamafactory.cli:main
|
llamafactory.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,125 @@
|
| 1 |
+
datasets<=4.0.0,>=2.16.0
|
| 2 |
+
accelerate<=1.11.0,>=1.3.0
|
| 3 |
+
peft<=0.17.1,>=0.14.0
|
| 4 |
+
trl<=0.9.6,>=0.8.6
|
| 5 |
+
gradio<=5.45.0,>=4.38.0
|
| 6 |
+
matplotlib>=3.7.0
|
| 7 |
+
tyro<0.9.0
|
| 8 |
+
einops
|
| 9 |
+
numpy<2.0.0
|
| 10 |
+
pandas>=2.0.0
|
| 11 |
+
scipy
|
| 12 |
+
sentencepiece
|
| 13 |
+
tiktoken
|
| 14 |
+
modelscope>=1.14.0
|
| 15 |
+
hf-transfer
|
| 16 |
+
safetensors<=0.5.3
|
| 17 |
+
fire
|
| 18 |
+
omegaconf
|
| 19 |
+
packaging
|
| 20 |
+
protobuf
|
| 21 |
+
pyyaml
|
| 22 |
+
pydantic<=2.10.6
|
| 23 |
+
uvicorn
|
| 24 |
+
fastapi
|
| 25 |
+
sse-starlette
|
| 26 |
+
av
|
| 27 |
+
librosa
|
| 28 |
+
propcache!=0.4.0
|
| 29 |
+
|
| 30 |
+
[:python_version < "3.10"]
|
| 31 |
+
transformers!=4.52.0,<=4.56.2,>=4.49.0
|
| 32 |
+
|
| 33 |
+
[:python_version >= "3.10"]
|
| 34 |
+
transformers!=4.52.0,!=4.57.0,<=4.57.1,>=4.49.0
|
| 35 |
+
|
| 36 |
+
[adam-mini]
|
| 37 |
+
adam-mini
|
| 38 |
+
|
| 39 |
+
[apollo]
|
| 40 |
+
apollo-torch
|
| 41 |
+
|
| 42 |
+
[aqlm]
|
| 43 |
+
aqlm[gpu]>=1.1.0
|
| 44 |
+
|
| 45 |
+
[badam]
|
| 46 |
+
badam>=1.2.1
|
| 47 |
+
|
| 48 |
+
[bitsandbytes]
|
| 49 |
+
bitsandbytes>=0.39.0
|
| 50 |
+
|
| 51 |
+
[deepspeed]
|
| 52 |
+
deepspeed<=0.16.9,>=0.10.0
|
| 53 |
+
|
| 54 |
+
[dev]
|
| 55 |
+
pre-commit
|
| 56 |
+
ruff
|
| 57 |
+
pytest
|
| 58 |
+
build
|
| 59 |
+
|
| 60 |
+
[eetq]
|
| 61 |
+
eetq
|
| 62 |
+
|
| 63 |
+
[fp8]
|
| 64 |
+
torchao>=0.8.0
|
| 65 |
+
accelerate>=1.10.0
|
| 66 |
+
|
| 67 |
+
[fp8-all]
|
| 68 |
+
torchao>=0.8.0
|
| 69 |
+
transformer_engine[pytorch]>=2.0.0
|
| 70 |
+
accelerate>=1.10.0
|
| 71 |
+
|
| 72 |
+
[fp8-te]
|
| 73 |
+
transformer_engine[pytorch]>=2.0.0
|
| 74 |
+
accelerate>=1.10.0
|
| 75 |
+
|
| 76 |
+
[galore]
|
| 77 |
+
galore-torch
|
| 78 |
+
|
| 79 |
+
[gptq]
|
| 80 |
+
optimum>=1.24.0
|
| 81 |
+
gptqmodel>=2.0.0
|
| 82 |
+
|
| 83 |
+
[hqq]
|
| 84 |
+
hqq
|
| 85 |
+
|
| 86 |
+
[liger-kernel]
|
| 87 |
+
liger-kernel>=0.5.5
|
| 88 |
+
|
| 89 |
+
[metrics]
|
| 90 |
+
nltk
|
| 91 |
+
jieba
|
| 92 |
+
rouge-chinese
|
| 93 |
+
|
| 94 |
+
[minicpm_v]
|
| 95 |
+
soundfile
|
| 96 |
+
torchvision
|
| 97 |
+
torchaudio
|
| 98 |
+
vector_quantize_pytorch
|
| 99 |
+
vocos
|
| 100 |
+
msgpack
|
| 101 |
+
referencing
|
| 102 |
+
jsonschema_specifications
|
| 103 |
+
|
| 104 |
+
[openmind]
|
| 105 |
+
openmind
|
| 106 |
+
|
| 107 |
+
[sglang]
|
| 108 |
+
sglang[srt]>=0.4.5
|
| 109 |
+
transformers==4.51.1
|
| 110 |
+
|
| 111 |
+
[swanlab]
|
| 112 |
+
swanlab
|
| 113 |
+
|
| 114 |
+
[torch]
|
| 115 |
+
torch>=2.0.0
|
| 116 |
+
torchvision>=0.15.0
|
| 117 |
+
|
| 118 |
+
[torch-npu]
|
| 119 |
+
torch==2.7.1
|
| 120 |
+
torch-npu==2.7.1
|
| 121 |
+
torchvision==0.22.1
|
| 122 |
+
decorator
|
| 123 |
+
|
| 124 |
+
[vllm]
|
| 125 |
+
vllm<=0.11.0,>=0.4.3
|
llamafactory.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
llamafactory
|
llamafactory/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
| 1 |
+
# Copyright 2025 the LlamaFactory team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
r"""Efficient fine-tuning of large language models.
|
| 16 |
+
|
| 17 |
+
Level:
|
| 18 |
+
api, webui > chat, eval, train > data, model > hparams > extras
|
| 19 |
+
|
| 20 |
+
Disable version checking: DISABLE_VERSION_CHECK=1
|
| 21 |
+
Enable VRAM recording: RECORD_VRAM=1
|
| 22 |
+
Force using torchrun: FORCE_TORCHRUN=1
|
| 23 |
+
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
|
| 24 |
+
Use modelscope: USE_MODELSCOPE_HUB=1
|
| 25 |
+
Use openmind: USE_OPENMIND_HUB=1
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from .extras.env import VERSION
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
__version__ = VERSION
|
llamafactory/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (612 Bytes)
|
|
|
llamafactory/__pycache__/cli.cpython-312.pyc
ADDED
|
Binary file (581 Bytes)
|
|
|
llamafactory/__pycache__/launcher.cpython-312.pyc
ADDED
|
Binary file (6.29 kB)
|
|
|
llamafactory/api/__init__.py
ADDED
|
File without changes
|
llamafactory/api/app.py
ADDED
@@ -0,0 +1,133 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from contextlib import asynccontextmanager
+from functools import partial
+from typing import Annotated, Optional
+
+from ..chat import ChatModel
+from ..extras.constants import EngineName
+from ..extras.misc import torch_gc
+from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
+from .chat import (
+    create_chat_completion_response,
+    create_score_evaluation_response,
+    create_stream_chat_completion_response,
+)
+from .protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ModelCard,
+    ModelList,
+    ScoreEvaluationRequest,
+    ScoreEvaluationResponse,
+)
+
+
+if is_fastapi_available():
+    from fastapi import Depends, FastAPI, HTTPException, status
+    from fastapi.middleware.cors import CORSMiddleware
+    from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
+
+
+if is_starlette_available():
+    from sse_starlette import EventSourceResponse
+
+
+if is_uvicorn_available():
+    import uvicorn
+
+
+async def sweeper() -> None:
+    while True:
+        torch_gc()
+        await asyncio.sleep(300)
+
+
+@asynccontextmanager
+async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
+    if chat_model.engine.name == EngineName.HF:
+        asyncio.create_task(sweeper())
+
+    yield
+    torch_gc()
+
+
+def create_app(chat_model: "ChatModel") -> "FastAPI":
+    root_path = os.getenv("FASTAPI_ROOT_PATH", "")
+    app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    api_key = os.getenv("API_KEY")
+    security = HTTPBearer(auto_error=False)
+
+    async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
+        if api_key and (auth is None or auth.credentials != api_key):
+            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
+
+    @app.get(
+        "/v1/models",
+        response_model=ModelList,
+        status_code=status.HTTP_200_OK,
+        dependencies=[Depends(verify_api_key)],
+    )
+    async def list_models():
+        model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
+        return ModelList(data=[model_card])
+
+    @app.post(
+        "/v1/chat/completions",
+        response_model=ChatCompletionResponse,
+        status_code=status.HTTP_200_OK,
+        dependencies=[Depends(verify_api_key)],
+    )
+    async def create_chat_completion(request: ChatCompletionRequest):
+        if not chat_model.engine.can_generate:
+            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+        if request.stream:
+            generate = create_stream_chat_completion_response(request, chat_model)
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
+        else:
+            return await create_chat_completion_response(request, chat_model)
+
+    @app.post(
+        "/v1/score/evaluation",
+        response_model=ScoreEvaluationResponse,
+        status_code=status.HTTP_200_OK,
+        dependencies=[Depends(verify_api_key)],
+    )
+    async def create_score_evaluation(request: ScoreEvaluationRequest):
+        if chat_model.engine.can_generate:
+            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+        return await create_score_evaluation_response(request, chat_model)
+
+    return app
+
+
+def run_api() -> None:
+    chat_model = ChatModel()
+    app = create_app(chat_model)
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
+    uvicorn.run(app, host=api_host, port=api_port)
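The routes above follow the OpenAI chat-completions shape; a minimal client sketch against run_api()'s defaults (API_HOST=0.0.0.0, API_PORT=8000, no API_KEY set), using the requests library purely for illustration:

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "gpt-3.5-turbo",  # echoed back; see API_MODEL_NAME in list_models()
            "messages": [{"role": "user", "content": "Hello"}],
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])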
llamafactory/api/chat.py
ADDED
@@ -0,0 +1,291 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import io
+import json
+import os
+import re
+import uuid
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING, Optional
+
+from ..data import Role as DataRole
+from ..extras import logging
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.misc import is_env_enabled
+from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
+from .common import check_lfi_path, check_ssrf_url, dictify, jsonify
+from .protocol import (
+    ChatCompletionMessage,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseUsage,
+    ChatCompletionStreamResponse,
+    ChatCompletionStreamResponseChoice,
+    Finish,
+    Function,
+    FunctionCall,
+    Role,
+    ScoreEvaluationResponse,
+)
+
+
+if is_fastapi_available():
+    from fastapi import HTTPException, status
+
+
+if is_pillow_available():
+    from PIL import Image
+
+
+if is_requests_available():
+    import requests
+
+
+if TYPE_CHECKING:
+    from ..chat import ChatModel
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+    from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
+
+
+logger = logging.get_logger(__name__)
+ROLE_MAPPING = {
+    Role.USER: DataRole.USER.value,
+    Role.ASSISTANT: DataRole.ASSISTANT.value,
+    Role.SYSTEM: DataRole.SYSTEM.value,
+    Role.FUNCTION: DataRole.FUNCTION.value,
+    Role.TOOL: DataRole.OBSERVATION.value,
+}
+
+
+def _process_request(
+    request: "ChatCompletionRequest",
+) -> tuple[
+    list[dict[str, str]],
+    Optional[str],
+    Optional[str],
+    Optional[list["ImageInput"]],
+    Optional[list["VideoInput"]],
+    Optional[list["AudioInput"]],
+]:
+    if is_env_enabled("API_VERBOSE", "1"):
+        logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
+
+    if len(request.messages) == 0:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
+
+    if request.messages[0].role == Role.SYSTEM:
+        content = request.messages.pop(0).content
+        system = content[0].text if isinstance(content, list) else content
+    else:
+        system = None
+
+    if len(request.messages) % 2 == 0:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
+
+    input_messages = []
+    images, videos, audios = [], [], []
+    for i, message in enumerate(request.messages):
+        if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+        elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+
+        if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
+            tool_calls = [
+                {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
+                for tool_call in message.tool_calls
+            ]
+            content = json.dumps(tool_calls, ensure_ascii=False)
+            input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
+        elif isinstance(message.content, list):
+            text_content = ""
+            for input_item in message.content:
+                if input_item.type == "text":
+                    text_content += input_item.text
+                elif input_item.type == "image_url":
+                    text_content += IMAGE_PLACEHOLDER
+                    image_url = input_item.image_url.url
+                    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
+                        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
+                    elif os.path.isfile(image_url):  # local file
+                        check_lfi_path(image_url)
+                        image_stream = open(image_url, "rb")
+                    else:  # web uri
+                        check_ssrf_url(image_url)
+                        image_stream = requests.get(image_url, stream=True).raw
+
+                    images.append(Image.open(image_stream).convert("RGB"))
+                elif input_item.type == "video_url":
+                    text_content += VIDEO_PLACEHOLDER
+                    video_url = input_item.video_url.url
+                    if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url):  # base64 video
+                        video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
+                    elif os.path.isfile(video_url):  # local file
+                        check_lfi_path(video_url)
+                        video_stream = video_url
+                    else:  # web uri
+                        check_ssrf_url(video_url)
+                        video_stream = requests.get(video_url, stream=True).raw
+
+                    videos.append(video_stream)
+                elif input_item.type == "audio_url":
+                    text_content += AUDIO_PLACEHOLDER
+                    audio_url = input_item.audio_url.url
+                    if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url):  # base64 audio
+                        audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
+                    elif os.path.isfile(audio_url):  # local file
+                        check_lfi_path(audio_url)
+                        audio_stream = audio_url
+                    else:  # web uri
+                        check_ssrf_url(audio_url)
+                        audio_stream = requests.get(audio_url, stream=True).raw
+
+                    audios.append(audio_stream)
+                else:
+                    raise HTTPException(
+                        status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
+                    )
+
+            input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
+        else:
+            input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
+
+    tool_list = request.tools
+    if isinstance(tool_list, list) and len(tool_list):
+        try:
+            tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
+        except json.JSONDecodeError:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
+    else:
+        tools = None
+
+    return input_messages, system, tools, images or None, videos or None, audios or None
+
+
+def _create_stream_chat_completion_chunk(
+    completion_id: str,
+    model: str,
+    delta: "ChatCompletionMessage",
+    index: Optional[int] = 0,
+    finish_reason: Optional["Finish"] = None,
+) -> str:
+    choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
+    chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
+    return jsonify(chunk)
+
+
+async def create_chat_completion_response(
+    request: "ChatCompletionRequest", chat_model: "ChatModel"
+) -> "ChatCompletionResponse":
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images, videos, audios = _process_request(request)
+    responses = await chat_model.achat(
+        input_messages,
+        system,
+        tools,
+        images,
+        videos,
+        audios,
+        do_sample=request.do_sample,
+        temperature=request.temperature,
+        top_p=request.top_p,
+        max_new_tokens=request.max_tokens,
+        num_return_sequences=request.n,
+        repetition_penalty=request.presence_penalty,
+        stop=request.stop,
+    )
+
+    prompt_length, response_length = 0, 0
+    choices = []
+    for i, response in enumerate(responses):
+        if tools:
+            result = chat_model.engine.template.extract_tool(response.response_text)
+        else:
+            result = response.response_text
+
+        if isinstance(result, list):
+            tool_calls = []
+            for tool in result:
+                function = Function(name=tool.name, arguments=tool.arguments)
+                tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
+
+            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
+            finish_reason = Finish.TOOL
+        else:
+            response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
+            finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+
+        choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
+        prompt_length = response.prompt_length
+        response_length += response.response_length
+
+    usage = ChatCompletionResponseUsage(
+        prompt_tokens=prompt_length,
+        completion_tokens=response_length,
+        total_tokens=prompt_length + response_length,
+    )
+
+    return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
+
+
+async def create_stream_chat_completion_response(
+    request: "ChatCompletionRequest", chat_model: "ChatModel"
+) -> AsyncGenerator[str, None]:
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images, videos, audios = _process_request(request)
+    if tools:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
+
+    if request.n > 1:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
+
+    yield _create_stream_chat_completion_chunk(
+        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
+    )
+    async for new_token in chat_model.astream_chat(
+        input_messages,
+        system,
+        tools,
+        images,
+        videos,
+        audios,
+        do_sample=request.do_sample,
+        temperature=request.temperature,
+        top_p=request.top_p,
+        max_new_tokens=request.max_tokens,
+        repetition_penalty=request.presence_penalty,
+        stop=request.stop,
+    ):
+        if len(new_token) != 0:
+            yield _create_stream_chat_completion_chunk(
+                completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
+            )
+
+    yield _create_stream_chat_completion_chunk(
+        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
+    )
+    yield "[DONE]"
+
+
+async def create_score_evaluation_response(
+    request: "ScoreEvaluationRequest", chat_model: "ChatModel"
+) -> "ScoreEvaluationResponse":
+    score_id = f"scoreval-{uuid.uuid4().hex}"
+    if len(request.messages) == 0:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+
+    scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
+    return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
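_process_request() accepts multimodal items as data URLs, local paths, or web URLs; a sketch of building the base64 form matched by the image regex above (the file name is hypothetical, field names follow protocol.py):

    import base64

    with open("cat.png", "rb") as f:
        b64 = base64.b64encode(f.read()).decode()

    content_item = {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{b64}"},
    }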
llamafactory/api/common.py
ADDED
@@ -0,0 +1,96 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ipaddress
+import json
+import os
+import socket
+from typing import TYPE_CHECKING, Any
+from urllib.parse import urlparse
+
+from ..extras.misc import is_env_enabled
+from ..extras.packages import is_fastapi_available
+
+
+if is_fastapi_available():
+    from fastapi import HTTPException, status
+
+
+if TYPE_CHECKING:
+    from pydantic import BaseModel
+
+
+SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media"))
+ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1")
+
+
+def dictify(data: "BaseModel") -> dict[str, Any]:
+    try:  # pydantic v2
+        return data.model_dump(exclude_unset=True)
+    except AttributeError:  # pydantic v1
+        return data.dict(exclude_unset=True)
+
+
+def jsonify(data: "BaseModel") -> str:
+    try:  # pydantic v2
+        return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
+    except AttributeError:  # pydantic v1
+        return data.json(exclude_unset=True, ensure_ascii=False)
+
+
+def check_lfi_path(path: str) -> None:
+    """Checks if a given path is vulnerable to LFI. Raises HTTPException if unsafe."""
+    if not ALLOW_LOCAL_FILES:
+        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.")
+
+    try:
+        os.makedirs(SAFE_MEDIA_PATH, exist_ok=True)
+        real_path = os.path.realpath(path)
+        safe_path = os.path.realpath(SAFE_MEDIA_PATH)
+
+        if not real_path.startswith(safe_path):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory."
+            )
+    except Exception:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path.")
+
+
+def check_ssrf_url(url: str) -> None:
+    """Checks if a given URL is vulnerable to SSRF. Raises HTTPException if unsafe."""
+    try:
+        parsed_url = urlparse(url)
+        if parsed_url.scheme not in ["http", "https"]:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.")
+
+        hostname = parsed_url.hostname
+        if not hostname:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.")
+
+        ip_info = socket.getaddrinfo(hostname, parsed_url.port)
+        ip_address_str = ip_info[0][4][0]
+        ip = ipaddress.ip_address(ip_address_str)
+
+        if not ip.is_global:
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail="Access to private or reserved IP addresses is not allowed.",
+            )
+
+    except socket.gaierror:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}"
+        )
+    except Exception as e:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}")
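check_ssrf_url() ultimately relies on ipaddress's is_global flag after resolving the hostname; a quick standard-library illustration of that test (addresses chosen only as examples):

    import ipaddress

    print(ipaddress.ip_address("10.0.0.1").is_global)       # False -> rejected (private)
    print(ipaddress.ip_address("127.0.0.1").is_global)      # False -> rejected (loopback)
    print(ipaddress.ip_address("93.184.216.34").is_global)  # True  -> allowed (public)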
llamafactory/api/protocol.py
ADDED
@@ -0,0 +1,157 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from enum import Enum, unique
+from typing import Any, Optional, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Literal
+
+
+@unique
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+    FUNCTION = "function"
+    TOOL = "tool"
+
+
+@unique
+class Finish(str, Enum):
+    STOP = "stop"
+    LENGTH = "length"
+    TOOL = "tool_calls"
+
+
+class ModelCard(BaseModel):
+    id: str
+    object: Literal["model"] = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: Literal["owner"] = "owner"
+
+
+class ModelList(BaseModel):
+    object: Literal["list"] = "list"
+    data: list[ModelCard] = []
+
+
+class Function(BaseModel):
+    name: str
+    arguments: str
+
+
+class FunctionDefinition(BaseModel):
+    name: str
+    description: str
+    parameters: dict[str, Any]
+
+
+class FunctionAvailable(BaseModel):
+    type: Literal["function", "code_interpreter"] = "function"
+    function: Optional[FunctionDefinition] = None
+
+
+class FunctionCall(BaseModel):
+    id: str
+    type: Literal["function"] = "function"
+    function: Function
+
+
+class URL(BaseModel):
+    url: str
+    detail: Literal["auto", "low", "high"] = "auto"
+
+
+class MultimodalInputItem(BaseModel):
+    type: Literal["text", "image_url", "video_url", "audio_url"]
+    text: Optional[str] = None
+    image_url: Optional[URL] = None
+    video_url: Optional[URL] = None
+    audio_url: Optional[URL] = None
+
+
+class ChatMessage(BaseModel):
+    role: Role
+    content: Optional[Union[str, list[MultimodalInputItem]]] = None
+    tool_calls: Optional[list[FunctionCall]] = None
+
+
+class ChatCompletionMessage(BaseModel):
+    role: Optional[Role] = None
+    content: Optional[str] = None
+    tool_calls: Optional[list[FunctionCall]] = None
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: list[ChatMessage]
+    tools: Optional[list[FunctionAvailable]] = None
+    do_sample: Optional[bool] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    n: int = 1
+    presence_penalty: Optional[float] = None
+    max_tokens: Optional[int] = None
+    stop: Optional[Union[str, list[str]]] = None
+    stream: bool = False
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatCompletionMessage
+    finish_reason: Finish
+
+
+class ChatCompletionStreamResponseChoice(BaseModel):
+    index: int
+    delta: ChatCompletionMessage
+    finish_reason: Optional[Finish] = None
+
+
+class ChatCompletionResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion"] = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: list[ChatCompletionResponseChoice]
+    usage: ChatCompletionResponseUsage
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: list[ChatCompletionStreamResponseChoice]
+
+
+class ScoreEvaluationRequest(BaseModel):
+    model: str
+    messages: list[str]
+    max_length: Optional[int] = None
+
+
+class ScoreEvaluationResponse(BaseModel):
+    id: str
+    object: Literal["score.evaluation"] = "score.evaluation"
+    model: str
+    scores: list[float]
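These pydantic models validate the JSON bodies handled by app.py; a small construction sketch (pydantic v2 API shown, import path assumes an installed llamafactory package):

    from llamafactory.api.protocol import ChatCompletionRequest, ChatMessage, Role

    req = ChatCompletionRequest(
        model="test-model",
        messages=[ChatMessage(role=Role.USER, content="Hi")],
        temperature=0.7,
        stream=False,
    )
    print(req.model_dump(exclude_unset=True))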
llamafactory/chat/__init__.py
ADDED
@@ -0,0 +1,19 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_engine import BaseEngine
+from .chat_model import ChatModel
+
+
+__all__ = ["BaseEngine", "ChatModel"]
llamafactory/chat/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (288 Bytes)
llamafactory/chat/__pycache__/base_engine.cpython-312.pyc
ADDED
Binary file (3.55 kB)
llamafactory/chat/__pycache__/chat_model.cpython-312.pyc
ADDED
Binary file (9.27 kB)
llamafactory/chat/base_engine.py
ADDED
@@ -0,0 +1,98 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+    from vllm import AsyncLLMEngine
+
+    from ..data import Template
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+    from ..extras.constants import EngineName
+    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+@dataclass
+class Response:
+    response_text: str
+    response_length: int
+    prompt_length: int
+    finish_reason: Literal["stop", "length"]
+
+
+class BaseEngine(ABC):
+    r"""Base class for inference engine of chat models.
+
+    Must implements async methods: chat(), stream_chat() and get_scores().
+    """
+
+    name: "EngineName"
+    model: Union["PreTrainedModel", "AsyncLLMEngine"]
+    tokenizer: "PreTrainedTokenizer"
+    can_generate: bool
+    template: "Template"
+    generating_args: dict[str, Any]
+
+    @abstractmethod
+    def __init__(
+        self,
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+    ) -> None:
+        r"""Initialize an inference engine."""
+        ...
+
+    @abstractmethod
+    async def chat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
+        ...
+
+    @abstractmethod
+    async def stream_chat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> AsyncGenerator[str, None]:
+        r"""Get the response token-by-token of the chat model."""
+        ...
+
+    @abstractmethod
+    async def get_scores(
+        self,
+        batch_input: list[str],
+        **input_kwargs,
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
+        ...
llamafactory/chat/chat_model.py
ADDED
@@ -0,0 +1,210 @@
+# Copyright 2025 THUDM and the LlamaFactory team.
+#
+# This code is inspired by the THUDM's ChatGLM implementation.
+# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from collections.abc import AsyncGenerator, Generator
+from threading import Thread
+from typing import TYPE_CHECKING, Any, Optional
+
+from ..extras.constants import EngineName
+from ..extras.misc import torch_gc
+from ..hparams import get_infer_args
+
+
+if TYPE_CHECKING:
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+    from .base_engine import BaseEngine, Response
+
+
+def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
+    asyncio.set_event_loop(loop)
+    loop.run_forever()
+
+
+class ChatModel:
+    r"""General class for chat models. Backed by huggingface or vllm engines.
+
+    Supports both sync and async methods.
+    Sync methods: chat(), stream_chat() and get_scores().
+    Async methods: achat(), astream_chat() and aget_scores().
+    """
+
+    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
+        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+
+        if model_args.infer_backend == EngineName.HF:
+            from .hf_engine import HuggingfaceEngine
+
+            self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
+        elif model_args.infer_backend == EngineName.VLLM:
+            try:
+                from .vllm_engine import VllmEngine
+
+                self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "vLLM not install, you may need to run `pip install vllm`\n"
+                    "or try to use HuggingFace backend: --infer_backend huggingface"
+                ) from e
+        elif model_args.infer_backend == EngineName.SGLANG:
+            try:
+                from .sglang_engine import SGLangEngine
+
+                self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "SGLang not install, you may need to run `pip install sglang[all]`\n"
+                    "or try to use HuggingFace backend: --infer_backend huggingface"
+                ) from e
+        elif model_args.infer_backend == EngineName.KT:
+            try:
+                from .kt_engine import KTransformersEngine
+
+                self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "KTransformers not install, you may need to run `pip install ktransformers`\n"
+                    "or try to use HuggingFace backend: --infer_backend huggingface"
+                ) from e
+        else:
+            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
+
+        self._loop = asyncio.new_event_loop()
+        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
+        self._thread.start()
+
+    def chat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
+        task = asyncio.run_coroutine_threadsafe(
+            self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
+        )
+        return task.result()
+
+    async def achat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> list["Response"]:
+        r"""Asynchronously get a list of responses of the chat model."""
+        return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
+
+    def stream_chat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> Generator[str, None, None]:
+        r"""Get the response token-by-token of the chat model."""
+        generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
+        while True:
+            try:
+                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
+                yield task.result()
+            except StopAsyncIteration:
+                break
+
+    async def astream_chat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> AsyncGenerator[str, None]:
+        r"""Asynchronously get the response token-by-token of the chat model."""
+        async for new_token in self.engine.stream_chat(
+            messages, system, tools, images, videos, audios, **input_kwargs
+        ):
+            yield new_token
+
+    def get_scores(
+        self,
+        batch_input: list[str],
+        **input_kwargs,
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
+        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
+        return task.result()
+
+    async def aget_scores(
+        self,
+        batch_input: list[str],
+        **input_kwargs,
+    ) -> list[float]:
+        r"""Asynchronously get a list of scores of the reward model."""
+        return await self.engine.get_scores(batch_input, **input_kwargs)
+
+
+def run_chat() -> None:
+    if os.name != "nt":
+        try:
+            import readline  # noqa: F401
+        except ImportError:
+            print("Install `readline` for a better experience.")
+
+    chat_model = ChatModel()
+    messages = []
+    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
+
+    while True:
+        try:
+            query = input("\nUser: ")
+        except UnicodeDecodeError:
+            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
+            continue
+        except Exception:
+            raise
+
+        if query.strip() == "exit":
+            break
+
+        if query.strip() == "clear":
+            messages = []
+            torch_gc()
+            print("History has been removed.")
+            continue
+
+        messages.append({"role": "user", "content": query})
+        print("Assistant: ", end="", flush=True)
+
+        response = ""
+        for new_text in chat_model.stream_chat(messages):
+            print(new_text, end="", flush=True)
+            response += new_text
+        print()
+        messages.append({"role": "assistant", "content": response})
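ChatModel wraps whichever engine infer_backend selects and exposes sync wrappers over a background event loop; a usage sketch (the argument keys are assumptions based on LLaMA-Factory's inference arguments, which are not part of this diff):

    from llamafactory.chat import ChatModel

    chat_model = ChatModel(args={
        "model_name_or_path": "path/to/model",  # hypothetical checkpoint path
        "template": "llama3",                   # hypothetical template name
    })
    responses = chat_model.chat([{"role": "user", "content": "Hello"}])
    print(responses[0].response_text)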
llamafactory/chat/hf_engine.py
ADDED
@@ -0,0 +1,412 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from collections.abc import AsyncGenerator
+from threading import Thread
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+import torch
+from transformers import GenerationConfig, TextIteratorStreamer
+from typing_extensions import override
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
+from ..model import load_model, load_tokenizer
+from .base_engine import BaseEngine, Response
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+    from trl import PreTrainedModelWrapper
+
+    from ..data import Template
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+logger = logging.get_logger(__name__)
+
+
+class HuggingfaceEngine(BaseEngine):
+    def __init__(
+        self,
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+    ) -> None:
+        self.name = EngineName.HF
+        self.can_generate = finetuning_args.stage == "sft"
+        tokenizer_module = load_tokenizer(model_args)
+        self.tokenizer = tokenizer_module["tokenizer"]
+        self.processor = tokenizer_module["processor"]
+        self.tokenizer.padding_side = "left" if self.can_generate else "right"
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+        self.model = load_model(
+            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
+        )  # must after fixing tokenizer to resize vocab
+        self.generating_args = generating_args.to_dict()
+        try:
+            asyncio.get_event_loop()
+        except RuntimeError:
+            logger.warning_rank0_once("There is no current event loop, creating a new one.")
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
+
+    @staticmethod
+    def _process_args(
+        model: "PreTrainedModel",
+        tokenizer: "PreTrainedTokenizer",
+        processor: Optional["ProcessorMixin"],
+        template: "Template",
+        generating_args: dict[str, Any],
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> tuple[dict[str, Any], int]:
+        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
+        if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
+            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+
+        if audios is not None:
+            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
+            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
+        messages = template.mm_plugin.process_messages(
+            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
+        )
+        paired_messages = messages + [{"role": "assistant", "content": ""}]
+        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
+        prompt_ids, _ = template.mm_plugin.process_token_ids(
+            prompt_ids,
+            None,
+            mm_input_dict["images"],
+            mm_input_dict["videos"],
+            mm_input_dict["audios"],
+            tokenizer,
+            processor,
+        )
+        prompt_length = len(prompt_ids)
+        inputs = torch.tensor([prompt_ids], device=model.device)
+        attention_mask = torch.ones_like(inputs, dtype=torch.long)
+
+        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
+        temperature: Optional[float] = input_kwargs.pop("temperature", None)
+        top_p: Optional[float] = input_kwargs.pop("top_p", None)
+        top_k: Optional[float] = input_kwargs.pop("top_k", None)
+        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
+        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
+
+        if stop is not None:
+            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
+
+        generating_args = generating_args.copy()
+        generating_args.update(
+            dict(
+                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+                temperature=temperature if temperature is not None else generating_args["temperature"],
+                top_p=top_p if top_p is not None else generating_args["top_p"],
+                top_k=top_k if top_k is not None else generating_args["top_k"],
+                num_return_sequences=num_return_sequences,
+                repetition_penalty=repetition_penalty
+                if repetition_penalty is not None
+                else generating_args["repetition_penalty"],
+                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
+                skip_special_tokens=skip_special_tokens
+                if skip_special_tokens is not None
+                else generating_args["skip_special_tokens"],
+                eos_token_id=template.get_stop_token_ids(tokenizer),
+                pad_token_id=tokenizer.pad_token_id,
+            )
+        )
+
+        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
+            generating_args["do_sample"] = True
+            generating_args["temperature"] = generating_args["temperature"] or 1.0
+
+        if not generating_args["temperature"]:
+            generating_args["do_sample"] = False
+
+        if not generating_args["do_sample"]:
+            generating_args.pop("temperature", None)
+            generating_args.pop("top_p", None)
+
+        if max_length:
+            generating_args.pop("max_new_tokens", None)
+            generating_args["max_length"] = max_length
+
+        if max_new_tokens:
+            generating_args.pop("max_length", None)
+            generating_args["max_new_tokens"] = max_new_tokens
+
+        gen_kwargs = dict(
+            inputs=inputs,
+            attention_mask=attention_mask,
+            generation_config=GenerationConfig(**generating_args),
+        )
+
+        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
+        for key, value in mm_inputs.items():
+            if isinstance(value, list) and isinstance(value[0], torch.Tensor):  # for pixtral inputs
+                value = torch.stack(value)  # assume they have same sizes
|
| 185 |
+
elif (
|
| 186 |
+
isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
|
| 187 |
+
): # for minicpmv inputs
|
| 188 |
+
value = torch.stack([torch.stack(v) for v in value])
|
| 189 |
+
elif not isinstance(value, torch.Tensor):
|
| 190 |
+
value = torch.tensor(value)
|
| 191 |
+
|
| 192 |
+
if torch.is_floating_point(value): # cast data dtype for paligemma
|
| 193 |
+
value = value.to(model.dtype)
|
| 194 |
+
|
| 195 |
+
if key == "second_per_grid_ts": # qwen2.5vl special case
|
| 196 |
+
gen_kwargs[key] = value.tolist()
|
| 197 |
+
else:
|
| 198 |
+
gen_kwargs[key] = value.to(model.device)
|
| 199 |
+
|
| 200 |
+
if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
|
| 201 |
+
gen_kwargs["input_ids"] = inputs
|
| 202 |
+
gen_kwargs["tokenizer"] = tokenizer
|
| 203 |
+
if "audio_feature_lens" in mm_inputs:
|
| 204 |
+
gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
|
| 205 |
+
|
| 206 |
+
gen_kwargs.pop("image_sizes", None)
|
| 207 |
+
|
| 208 |
+
return gen_kwargs, prompt_length
|
| 209 |
+
|
| 210 |
+
@staticmethod
|
| 211 |
+
@torch.inference_mode()
|
| 212 |
+
def _chat(
|
| 213 |
+
model: "PreTrainedModel",
|
| 214 |
+
tokenizer: "PreTrainedTokenizer",
|
| 215 |
+
processor: Optional["ProcessorMixin"],
|
| 216 |
+
template: "Template",
|
| 217 |
+
generating_args: dict[str, Any],
|
| 218 |
+
messages: list[dict[str, str]],
|
| 219 |
+
system: Optional[str] = None,
|
| 220 |
+
tools: Optional[str] = None,
|
| 221 |
+
images: Optional[list["ImageInput"]] = None,
|
| 222 |
+
videos: Optional[list["VideoInput"]] = None,
|
| 223 |
+
audios: Optional[list["AudioInput"]] = None,
|
| 224 |
+
input_kwargs: Optional[dict[str, Any]] = {},
|
| 225 |
+
) -> list["Response"]:
|
| 226 |
+
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
| 227 |
+
model,
|
| 228 |
+
tokenizer,
|
| 229 |
+
processor,
|
| 230 |
+
template,
|
| 231 |
+
generating_args,
|
| 232 |
+
messages,
|
| 233 |
+
system,
|
| 234 |
+
tools,
|
| 235 |
+
images,
|
| 236 |
+
videos,
|
| 237 |
+
audios,
|
| 238 |
+
input_kwargs,
|
| 239 |
+
)
|
| 240 |
+
generate_output = model.generate(**gen_kwargs)
|
| 241 |
+
if isinstance(generate_output, tuple):
|
| 242 |
+
generate_output = generate_output[1][0] # post-process the minicpm_o output
|
| 243 |
+
|
| 244 |
+
response_ids = generate_output[:, prompt_length:]
|
| 245 |
+
response = tokenizer.batch_decode(
|
| 246 |
+
response_ids,
|
| 247 |
+
skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
|
| 248 |
+
clean_up_tokenization_spaces=True,
|
| 249 |
+
)
|
| 250 |
+
results = []
|
| 251 |
+
for i in range(len(response)):
|
| 252 |
+
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
|
| 253 |
+
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
| 254 |
+
results.append(
|
| 255 |
+
Response(
|
| 256 |
+
response_text=response[i],
|
| 257 |
+
response_length=response_length,
|
| 258 |
+
prompt_length=prompt_length,
|
| 259 |
+
finish_reason="stop" if len(eos_index) else "length",
|
| 260 |
+
)
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
return results
|
| 264 |
+
|
| 265 |
+
@staticmethod
|
| 266 |
+
@torch.inference_mode()
|
| 267 |
+
def _stream_chat(
|
| 268 |
+
model: "PreTrainedModel",
|
| 269 |
+
tokenizer: "PreTrainedTokenizer",
|
| 270 |
+
processor: Optional["ProcessorMixin"],
|
| 271 |
+
template: "Template",
|
| 272 |
+
generating_args: dict[str, Any],
|
| 273 |
+
messages: list[dict[str, str]],
|
| 274 |
+
system: Optional[str] = None,
|
| 275 |
+
tools: Optional[str] = None,
|
| 276 |
+
images: Optional[list["ImageInput"]] = None,
|
| 277 |
+
videos: Optional[list["VideoInput"]] = None,
|
| 278 |
+
audios: Optional[list["AudioInput"]] = None,
|
| 279 |
+
input_kwargs: Optional[dict[str, Any]] = {},
|
| 280 |
+
) -> Callable[[], str]:
|
| 281 |
+
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
| 282 |
+
model,
|
| 283 |
+
tokenizer,
|
| 284 |
+
processor,
|
| 285 |
+
template,
|
| 286 |
+
generating_args,
|
| 287 |
+
messages,
|
| 288 |
+
system,
|
| 289 |
+
tools,
|
| 290 |
+
images,
|
| 291 |
+
videos,
|
| 292 |
+
audios,
|
| 293 |
+
input_kwargs,
|
| 294 |
+
)
|
| 295 |
+
streamer = TextIteratorStreamer(
|
| 296 |
+
tokenizer,
|
| 297 |
+
skip_prompt=True,
|
| 298 |
+
skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
|
| 299 |
+
)
|
| 300 |
+
gen_kwargs["streamer"] = streamer
|
| 301 |
+
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
|
| 302 |
+
thread.start()
|
| 303 |
+
|
| 304 |
+
def stream():
|
| 305 |
+
try:
|
| 306 |
+
return streamer.__next__()
|
| 307 |
+
except StopIteration:
|
| 308 |
+
raise StopAsyncIteration()
|
| 309 |
+
|
| 310 |
+
return stream
|
| 311 |
+
|
| 312 |
+
@staticmethod
|
| 313 |
+
@torch.inference_mode()
|
| 314 |
+
def _get_scores(
|
| 315 |
+
model: "PreTrainedModelWrapper",
|
| 316 |
+
tokenizer: "PreTrainedTokenizer",
|
| 317 |
+
batch_input: list[str],
|
| 318 |
+
input_kwargs: Optional[dict[str, Any]] = {},
|
| 319 |
+
) -> list[float]:
|
| 320 |
+
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
| 321 |
+
device = getattr(model.pretrained_model, "device", "cuda")
|
| 322 |
+
inputs: dict[str, torch.Tensor] = tokenizer(
|
| 323 |
+
batch_input,
|
| 324 |
+
padding=True,
|
| 325 |
+
truncation=True,
|
| 326 |
+
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
|
| 327 |
+
return_tensors="pt",
|
| 328 |
+
add_special_tokens=False,
|
| 329 |
+
).to(device)
|
| 330 |
+
values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
|
| 331 |
+
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
|
| 332 |
+
return scores
|
| 333 |
+
|
| 334 |
+
@override
|
| 335 |
+
async def chat(
|
| 336 |
+
self,
|
| 337 |
+
messages: list[dict[str, str]],
|
| 338 |
+
system: Optional[str] = None,
|
| 339 |
+
tools: Optional[str] = None,
|
| 340 |
+
images: Optional[list["ImageInput"]] = None,
|
| 341 |
+
videos: Optional[list["VideoInput"]] = None,
|
| 342 |
+
audios: Optional[list["AudioInput"]] = None,
|
| 343 |
+
**input_kwargs,
|
| 344 |
+
) -> list["Response"]:
|
| 345 |
+
if not self.can_generate:
|
| 346 |
+
raise ValueError("The current model does not support `chat`.")
|
| 347 |
+
|
| 348 |
+
input_args = (
|
| 349 |
+
self.model,
|
| 350 |
+
self.tokenizer,
|
| 351 |
+
self.processor,
|
| 352 |
+
self.template,
|
| 353 |
+
self.generating_args,
|
| 354 |
+
messages,
|
| 355 |
+
system,
|
| 356 |
+
tools,
|
| 357 |
+
images,
|
| 358 |
+
videos,
|
| 359 |
+
audios,
|
| 360 |
+
input_kwargs,
|
| 361 |
+
)
|
| 362 |
+
async with self.semaphore:
|
| 363 |
+
return await asyncio.to_thread(self._chat, *input_args)
|
| 364 |
+
|
| 365 |
+
@override
|
| 366 |
+
async def stream_chat(
|
| 367 |
+
self,
|
| 368 |
+
messages: list[dict[str, str]],
|
| 369 |
+
system: Optional[str] = None,
|
| 370 |
+
tools: Optional[str] = None,
|
| 371 |
+
images: Optional[list["ImageInput"]] = None,
|
| 372 |
+
videos: Optional[list["VideoInput"]] = None,
|
| 373 |
+
audios: Optional[list["AudioInput"]] = None,
|
| 374 |
+
**input_kwargs,
|
| 375 |
+
) -> AsyncGenerator[str, None]:
|
| 376 |
+
if not self.can_generate:
|
| 377 |
+
raise ValueError("The current model does not support `stream_chat`.")
|
| 378 |
+
|
| 379 |
+
input_args = (
|
| 380 |
+
self.model,
|
| 381 |
+
self.tokenizer,
|
| 382 |
+
self.processor,
|
| 383 |
+
self.template,
|
| 384 |
+
self.generating_args,
|
| 385 |
+
messages,
|
| 386 |
+
system,
|
| 387 |
+
tools,
|
| 388 |
+
images,
|
| 389 |
+
videos,
|
| 390 |
+
audios,
|
| 391 |
+
input_kwargs,
|
| 392 |
+
)
|
| 393 |
+
async with self.semaphore:
|
| 394 |
+
stream = self._stream_chat(*input_args)
|
| 395 |
+
while True:
|
| 396 |
+
try:
|
| 397 |
+
yield await asyncio.to_thread(stream)
|
| 398 |
+
except StopAsyncIteration:
|
| 399 |
+
break
|
| 400 |
+
|
| 401 |
+
@override
|
| 402 |
+
async def get_scores(
|
| 403 |
+
self,
|
| 404 |
+
batch_input: list[str],
|
| 405 |
+
**input_kwargs,
|
| 406 |
+
) -> list[float]:
|
| 407 |
+
if self.can_generate:
|
| 408 |
+
raise ValueError("Cannot get scores using an auto-regressive model.")
|
| 409 |
+
|
| 410 |
+
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
|
| 411 |
+
async with self.semaphore:
|
| 412 |
+
return await asyncio.to_thread(self._get_scores, *input_args)
|
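For orientation, a minimal usage sketch of the engine above follows. It is not part of the uploaded files; it assumes the four argument objects come from a `get_infer_args` helper in `llamafactory.hparams` that accepts a plain dict, and the model path and template name are placeholders.

# Illustrative sketch only -- not part of the diff.
# Assumes get_infer_args builds (ModelArguments, DataArguments, FinetuningArguments,
# GeneratingArguments) from a dict; the model path and template below are placeholders.
import asyncio

from llamafactory.chat.hf_engine import HuggingfaceEngine
from llamafactory.hparams import get_infer_args


async def main() -> None:
    model_args, data_args, finetuning_args, generating_args = get_infer_args(
        {"model_name_or_path": "path/to/model", "template": "default"}
    )
    engine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
    # chat() acquires the MAX_CONCURRENT semaphore and runs _chat in a worker thread.
    responses = await engine.chat([{"role": "user", "content": "Hello!"}], max_new_tokens=64)
    print(responses[0].response_text)


asyncio.run(main())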
llamafactory/chat/kt_engine.py ADDED
@@ -0,0 +1,284 @@
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import platform
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, Optional

import torch
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import EngineName
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer
    from trl import PreTrainedModelWrapper

    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.server.config.config import Config
from ktransformers.util.utils import (
    get_compute_capability,
    prefill_and_generate_capture,
)
from ktransformers.util.vendors import GPUVendor, device_manager


logger = logging.get_logger(__name__)


class KTransformersEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.KT
        self.can_generate = finetuning_args.stage == "sft"

        tok_mod = load_tokenizer(model_args)
        self.tokenizer = tok_mod["tokenizer"]
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)

        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )

        self.generating_args = generating_args.to_dict()
        self.max_new_tokens = model_args.kt_maxlen
        self.use_cuda_graph = model_args.kt_use_cuda_graph
        self.mode = model_args.kt_mode
        self.force_think = model_args.kt_force_think
        self.chunk_size = model_args.chunk_size

        try:
            asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: list[str],
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> list[float]:
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
        values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return scores

    async def _generate(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        paired = messages + [{"role": "assistant", "content": ""}]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools)
        prompt_len = len(prompt_ids)

        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)

        if "max_new_tokens" in self.generating_args:
            max_tokens = int(self.generating_args["max_new_tokens"])
        elif "max_length" in self.generating_args:
            gl = int(self.generating_args["max_length"])
            max_tokens = gl - prompt_len if gl > prompt_len else 1
        else:
            max_tokens = self.max_new_tokens or 256

        if max_length is not None:
            max_tokens = max(max_length - prompt_len, 1)
        if max_new_tokens is not None:
            max_tokens = int(max_new_tokens)
        max_tokens = max(1, int(max_tokens))

        if self.mode == "long_context":
            max_len_cfg = Config().long_context_config["max_seq_len"]
            need = prompt_len + max_tokens
            assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml"

        device = next(self.model.parameters()).device
        input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        if self.force_think:
            think = torch.tensor(
                [self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
            )
            input_tensor = torch.cat([input_tensor, think], dim=1)

        use_flashinfer = (
            platform.system() != "Windows"
            and getattr(self.model.config, "architectures", [""])[0]
            in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
            and flashinfer_enabled
            and get_compute_capability() >= 8
            and device_manager.gpu_vendor == GPUVendor.NVIDIA
        )

        def make_gen():
            if use_flashinfer:
                return prefill_and_generate_capture(
                    self.model,
                    self.tokenizer,
                    input_tensor,
                    max_tokens,
                    self.use_cuda_graph,
                    mode=self.mode,
                    force_think=self.force_think,
                    chunk_size=self.chunk_size,
                    use_flashinfer_mla=True,
                    num_heads=self.model.config.num_attention_heads,
                    head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
                    head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
                    q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
                    + getattr(self.model.config, "qk_nope_head_dim", 0),
                    echo_stream=False,
                )
            else:
                return prefill_and_generate_capture(
                    self.model,
                    self.tokenizer,
                    input_tensor,
                    max_tokens,
                    self.use_cuda_graph,
                    mode=self.mode,
                    force_think=self.force_think,
                    chunk_size=self.chunk_size,
                    echo_stream=False,
                )

        loop = asyncio.get_running_loop()
        q: asyncio.Queue[Optional[str]] = asyncio.Queue()

        def producer():
            try:
                gen = make_gen()
                if hasattr(gen, "__aiter__"):

                    async def drain_async():
                        async for t in gen:
                            loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))

                    asyncio.run(drain_async())
                elif hasattr(gen, "__iter__"):
                    for t in gen:
                        loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
                else:
                    loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen))
            finally:
                loop.call_soon_threadsafe(q.put_nowait, None)

        Thread(target=producer, daemon=True).start()

        while True:
            item = await q.get()
            if item is None:
                break
            yield item

    @override
    async def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")
        async with self.semaphore:
            produced = ""
            final_text = ""
            async for t in self._generate(messages, system, tools, **input_kwargs):
                delta = t
                produced = produced + delta
                if delta:
                    final_text += delta

            prompt_ids, _ = self.template.encode_oneturn(
                self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools
            )
            return [
                Response(
                    response_text=final_text,
                    response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)),
                    prompt_length=len(prompt_ids),
                    finish_reason="stop",
                )
            ]

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")
        async with self.semaphore:
            produced = ""
            async for t in self._generate(messages, system, tools, **input_kwargs):
                delta = t[len(produced) :] if t.startswith(produced) else t
                produced = t
                if delta:
                    yield delta

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")
        args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self.semaphore:
            return await asyncio.to_thread(self._get_scores, *args)
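A streaming consumption sketch for the engine above. It is not part of the diff and assumes ktransformers is installed and that `ModelArguments` carries the kt_* fields read in `__init__` (kt_maxlen, kt_use_cuda_graph, kt_mode, kt_force_think, chunk_size).

# Illustrative sketch only -- not part of the diff.
# Assumes a ktransformers-capable environment and placeholder model/template values.
import asyncio

from llamafactory.chat.kt_engine import KTransformersEngine
from llamafactory.hparams import get_infer_args


async def main() -> None:
    model_args, data_args, finetuning_args, generating_args = get_infer_args(
        {"model_name_or_path": "path/to/model", "template": "default"}
    )
    engine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
    # stream_chat yields text deltas as the producer thread drains the generation queue.
    async for delta in engine.stream_chat([{"role": "user", "content": "Hi"}]):
        print(delta, end="", flush=True)


asyncio.run(main())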
llamafactory/chat/sglang_engine.py ADDED
@@ -0,0 +1,289 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import atexit
import json
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union

import requests
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_sglang_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from .base_engine import BaseEngine, Response


if is_sglang_available():
    from sglang.utils import launch_server_cmd, terminate_process, wait_for_server  # type: ignore


if TYPE_CHECKING:
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class SGLangEngine(BaseEngine):
    """Inference engine for SGLang models.

    This class wraps the SGLang engine to provide a consistent interface for text generation
    that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
    better interaction and performance. The engine launches a server process and communicates
    with it via HTTP requests.

    For more details on the SGLang HTTP server approach, see:
    https://docs.sglang.ai/backend/send_request.html
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.SGLANG
        self.model_args = model_args
        config = load_config(model_args)  # may download model from ms hub
        if getattr(config, "quantization_config", None):  # gptq models should use float16
            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
        self.generating_args = generating_args.to_dict()
        if model_args.adapter_name_or_path is not None:
            self.lora_request = True
        else:
            self.lora_request = False

        launch_cmd = [
            "python3 -m sglang.launch_server",
            f"--model-path {model_args.model_name_or_path}",
            f"--dtype {model_args.infer_dtype}",
            f"--context-length {model_args.sglang_maxlen}",
            f"--mem-fraction-static {model_args.sglang_mem_fraction}",
            f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
            f"--download-dir {model_args.cache_dir}",
            "--log-level error",
        ]
        if self.lora_request:
            launch_cmd.extend(
                [
                    "--max-loras-per-batch 1",
                    f"--lora-backend {model_args.sglang_lora_backend}",
                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
                    "--disable-radix-cache",
                ]
            )
        launch_cmd = " ".join(launch_cmd)
        logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
        try:
            torch_gc()
            self.server_process, port = launch_server_cmd(launch_cmd)
            self.base_url = f"http://localhost:{port}"
            atexit.register(self._cleanup_server)

            logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
            wait_for_server(self.base_url, timeout=300)
            logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")
            try:
                response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
                if response.status_code == 200:
                    model_info = response.json()
                    logger.info(f"SGLang server model info: {model_info}")
            except Exception as e:
                logger.debug(f"Note: could not get model info: {str(e)}")

        except Exception as e:
            logger.error(f"Failed to start SGLang server: {str(e)}")
            self._cleanup_server()  # make sure to clean up any started process
            raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")

    def _cleanup_server(self):
        r"""Clean up the server process when the engine is destroyed."""
        if hasattr(self, "server_process") and self.server_process:
            try:
                logger.info("Terminating SGLang server process")
                terminate_process(self.server_process)
                logger.info("SGLang server process terminated")
            except Exception as e:
                logger.warning(f"Error terminating SGLang server: {str(e)}")

    async def _generate(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncIterator[dict[str, Any]]:
        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = self.template.mm_plugin.process_messages(
            messages, images or [], videos or [], audios or [], self.processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

        if num_return_sequences != 1:
            raise NotImplementedError("SGLang only supports n=1.")

        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            if self.generating_args["max_length"] > prompt_length:
                max_tokens = self.generating_args["max_length"] - prompt_length
            else:
                max_tokens = 1

        if max_length:
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = {
            "temperature": temperature if temperature is not None else self.generating_args["temperature"],
            "top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
            "top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must > 0
            "stop": stop,
            "stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
            "max_new_tokens": max_tokens,
            "repetition_penalty": (
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must > 0
            "skip_special_tokens": skip_special_tokens
            if skip_special_tokens is not None
            else self.generating_args["skip_special_tokens"],
        }

        def stream_request():
            json_data = {
                "input_ids": prompt_ids,
                "sampling_params": sampling_params,
                "stream": True,
            }
            if self.lora_request:
                json_data["lora_request"] = ["lora0"]
            response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
            if response.status_code != 200:
                raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")

            for chunk in response.iter_lines(decode_unicode=False):
                chunk = str(chunk.decode("utf-8"))
                if chunk == "data: [DONE]":
                    break

                if chunk and chunk.startswith("data:"):
                    yield json.loads(chunk[5:].strip("\n"))

        return await asyncio.to_thread(stream_request)

    @override
    async def chat(
        self,
        messages: Sequence[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        final_output = None
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        for request_output in generator:
            final_output = request_output

        results = [
            Response(
                response_text=final_output["text"],
                response_length=final_output["meta_info"]["completion_tokens"],
                prompt_length=final_output["meta_info"]["prompt_tokens"],
                finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
            )
        ]
        return results

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        for result in generator:
            delta_text = result["text"][len(generated_text) :]
            generated_text = result["text"]
            yield delta_text

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        raise NotImplementedError("SGLang engine does not support `get_scores`.")

    def __del__(self):
        r"""Ensure server is cleaned up when object is deleted."""
        self._cleanup_server()
        try:
            atexit.unregister(self._cleanup_server)
        except Exception:
            pass
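`_generate` above talks to the launched server over plain HTTP. The sketch below restates that request shape on its own, assuming an SGLang server is already listening at the given base URL; the token ids and sampling values are placeholders and it is not part of the diff.

# Illustrative sketch only -- not part of the diff.
# Mirrors the /generate call made by SGLangEngine._generate: the server streams
# "data: {...}" lines whose "text" field is the cumulative output so far.
import json

import requests

base_url = "http://localhost:30000"       # placeholder; the engine uses the launched port
payload = {
    "input_ids": [1, 2, 3],               # prompt token ids produced by the chat template
    "sampling_params": {"temperature": 0.7, "max_new_tokens": 32},
    "stream": True,
}
with requests.post(f"{base_url}/generate", json=payload, stream=True) as response:
    for line in response.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        if line == "data: [DONE]":
            break
        chunk = json.loads(line[5:].strip())
        print(chunk["text"])              # cumulative text, diffed into deltas by stream_chat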
llamafactory/chat/vllm_engine.py ADDED
@@ -0,0 +1,263 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Any, Optional, Union

from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response


if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
    from vllm.lora.request import LoRARequest


if TYPE_CHECKING:
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class VllmEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.VLLM
        self.model_args = model_args
        config = load_config(model_args)  # may download model from ms hub
        if getattr(config, "quantization_config", None):  # gptq models should use float16
            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.template.mm_plugin.expand_mm_tokens = False  # for vllm generate
        self.generating_args = generating_args.to_dict()

        engine_args = {
            "model": model_args.model_name_or_path,
            "trust_remote_code": model_args.trust_remote_code,
            "download_dir": model_args.cache_dir,
            "dtype": model_args.infer_dtype,
            "max_model_len": model_args.vllm_maxlen,
            "tensor_parallel_size": get_device_count() or 1,
            "gpu_memory_utilization": model_args.vllm_gpu_util,
            "disable_log_stats": True,
            "disable_log_requests": True,
            "enforce_eager": model_args.vllm_enforce_eager,
            "enable_lora": model_args.adapter_name_or_path is not None,
            "max_lora_rank": model_args.vllm_max_lora_rank,
        }
        if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}

        if isinstance(model_args.vllm_config, dict):
            engine_args.update(model_args.vllm_config)

        if getattr(config, "is_yi_vl_derived_model", None):
            import vllm.model_executor.models.llava

            logger.info_rank0("Detected Yi-VL model, applying projector patch.")
            vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

        self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
        if model_args.adapter_name_or_path is not None:
            self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
        else:
            self.lora_request = None

    async def _generate(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
        request_id = f"chatcmpl-{uuid.uuid4().hex}"
        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = self.template.mm_plugin.process_messages(
            messages, images or [], videos or [], audios or [], self.processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

        if length_penalty is not None:
            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")

        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            if self.generating_args["max_length"] > prompt_length:
                max_tokens = self.generating_args["max_length"] - prompt_length
            else:
                max_tokens = 1

        if max_length:
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = SamplingParams(
            n=num_return_sequences,
            repetition_penalty=(
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must > 0
            temperature=temperature if temperature is not None else self.generating_args["temperature"],
            top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
            top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must > 0
            stop=stop,
            stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
            max_tokens=max_tokens,
            skip_special_tokens=skip_special_tokens
            if skip_special_tokens is not None
            else self.generating_args["skip_special_tokens"],
        )

        if images is not None:  # add image features
            multi_modal_data = {
                "image": self.template.mm_plugin._regularize_images(
                    images,
                    image_max_pixels=self.model_args.image_max_pixels,
                    image_min_pixels=self.model_args.image_min_pixels,
                )["images"]
            }
        elif videos is not None:
            multi_modal_data = {
                "video": self.template.mm_plugin._regularize_videos(
                    videos,
                    image_max_pixels=self.model_args.video_max_pixels,
                    image_min_pixels=self.model_args.video_min_pixels,
                    video_fps=self.model_args.video_fps,
                    video_maxlen=self.model_args.video_maxlen,
                )["videos"]
            }
        elif audios is not None:
            audio_data = self.template.mm_plugin._regularize_audios(
                audios,
                sampling_rate=self.model_args.audio_sampling_rate,
            )
            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
        else:
            multi_modal_data = None

        result_generator = self.model.generate(
            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
            sampling_params=sampling_params,
            request_id=request_id,
            lora_request=self.lora_request,
        )
        return result_generator

    @override
    async def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        final_output = None
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        async for request_output in generator:
            final_output = request_output

        results = []
        for output in final_output.outputs:
            results.append(
                Response(
                    response_text=output.text,
                    response_length=len(output.token_ids),
                    prompt_length=len(final_output.prompt_token_ids),
                    finish_reason=output.finish_reason,
                )
            )

        return results

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        async for result in generator:
            delta_text = result.outputs[0].text[len(generated_text) :]
            generated_text = result.outputs[0].text
            yield delta_text

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        raise NotImplementedError("vLLM engine does not support `get_scores`.")
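Because `__init__` applies `engine_args.update(model_args.vllm_config)` last, extra vLLM engine fields can be passed through untouched. The sketch below illustrates that passthrough; it is not part of the diff, the argument values are placeholders, and it assumes `get_infer_args` accepts a dict-valued `vllm_config`.

# Illustrative sketch only -- not part of the diff.
# Placeholder model path/template; `swap_space` is forwarded verbatim to AsyncEngineArgs.
import asyncio

from llamafactory.chat.vllm_engine import VllmEngine
from llamafactory.hparams import get_infer_args


async def main() -> None:
    model_args, data_args, finetuning_args, generating_args = get_infer_args(
        {
            "model_name_or_path": "path/to/model",
            "template": "default",
            "vllm_config": {"swap_space": 4},  # merged into engine_args after the defaults
        }
    )
    engine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
    async for delta in engine.stream_chat([{"role": "user", "content": "Hi"}]):
        print(delta, end="", flush=True)


asyncio.run(main())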
llamafactory/cli.py
ADDED
@@ -0,0 +1,31 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def main():
    from .extras.misc import is_env_enabled

    if is_env_enabled("USE_V1"):
        from .v1 import launcher
    else:
        from . import launcher

    launcher.launch()


if __name__ == "__main__":
    from multiprocessing import freeze_support

    freeze_support()
    main()
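
An illustrative sketch of exercising the `USE_V1` toggle that `main()` above reads via `is_env_enabled` before choosing a launcher; the exact accepted values and CLI arguments are assumptions, only the module paths come from this diff:

```python
import os
import subprocess

# Any truthy value is assumed to route dispatch to llamafactory.v1.launcher.
env = dict(os.environ, USE_V1="1")
subprocess.run(["python", "-m", "llamafactory.cli"], env=env, check=False)
```
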
llamafactory/data/__init__.py
ADDED
@@ -0,0 +1,37 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .collator import (
    KTODataCollatorWithPadding,
    MultiModalDataCollatorForSeq2Seq,
    PairwiseDataCollatorWithPadding,
    SFTDataCollatorWith4DAttentionMask,
)
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer


__all__ = [
    "TEMPLATES",
    "KTODataCollatorWithPadding",
    "MultiModalDataCollatorForSeq2Seq",
    "PairwiseDataCollatorWithPadding",
    "Role",
    "SFTDataCollatorWith4DAttentionMask",
    "Template",
    "get_dataset",
    "get_template_and_fix_tokenizer",
    "split_dataset",
]
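
A brief import sketch of the public data API re-exported above (assumes the package from this diff is installed as `llamafactory`):

```python
from llamafactory.data import Role, get_dataset, get_template_and_fix_tokenizer

print(Role.USER.value)  # "user"
```
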
llamafactory/data/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (643 Bytes).

llamafactory/data/__pycache__/collator.cpython-312.pyc
ADDED
Binary file (15.2 kB).

llamafactory/data/__pycache__/converter.cpython-312.pyc
ADDED
Binary file (21.7 kB).

llamafactory/data/__pycache__/data_utils.cpython-312.pyc
ADDED
Binary file (8.54 kB).

llamafactory/data/__pycache__/formatter.cpython-312.pyc
ADDED
Binary file (7.89 kB).

llamafactory/data/__pycache__/loader.cpython-312.pyc
ADDED
Binary file (14.9 kB).

llamafactory/data/__pycache__/mm_plugin.cpython-312.pyc
ADDED
Binary file (91.5 kB).

llamafactory/data/__pycache__/parser.cpython-312.pyc
ADDED
Binary file (6.32 kB).

llamafactory/data/__pycache__/template.cpython-312.pyc
ADDED
Binary file (69.3 kB).

llamafactory/data/__pycache__/tool_utils.cpython-312.pyc
ADDED
Binary file (23.6 kB).
llamafactory/data/collator.py
ADDED
@@ -0,0 +1,331 @@
# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional

import numpy as np
import torch
import torch.nn.functional as F
from peft import PeftModel
from transformers import DataCollatorForSeq2Seq

from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.packages import is_pillow_available


if is_pillow_available():
    from PIL import Image


if TYPE_CHECKING:
    from transformers import ProcessorMixin

    from .template import Template


def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""Expand 2d attention mask to 4d attention mask.

    Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
    handle packed sequences and transforms the mask to lower triangular form to prevent future peeking.

    e.g.
    ```python
    # input
    [[1, 1, 2, 2, 2, 0]]
    # output
    [
        [
            [
                [o, x, x, x, x, x],
                [o, o, x, x, x, x],
                [x, x, o, x, x, x],
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],
            ]
        ]
    ]
    ```
    where `o` equals to `0.0`, `x` equals to `min_dtype`.
    """
    _, seq_len = attention_mask_with_indices.size()
    min_dtype = torch.finfo(dtype).min
    zero_tensor = torch.tensor(0, dtype=dtype)

    # Create a non-padding mask.
    non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
    # Create indices for comparison.
    indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
    indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
    # Create a lower triangular mask.
    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
    attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
    # Invert the attention mask.
    attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
    return attention_mask_4d


@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""Data collator that supports VLMs.

    Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
    """

    template: Optional["Template"] = None
    processor: Optional["ProcessorMixin"] = None

    def __post_init__(self):
        if self.template is None:
            raise ValueError("Template is required for MultiModalDataCollator.")

        if isinstance(self.model, PeftModel):
            self.model = self.model.base_model.model

        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
            self.get_rope_func = self.model.get_rope_index  # transformers < 4.52.0 or qwen2.5 omni
        elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
            self.get_rope_func = self.model.model.get_rope_index  # transformers >= 4.52.0
        else:
            self.get_rope_func = None

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        batch_images, batch_videos, batch_audios = [], [], []
        batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
        for feature in features:
            images = feature.pop("images", None) or []
            videos = feature.pop("videos", None) or []
            audios = feature.pop("audios", None) or []
            batch_images.extend(images)
            batch_videos.extend(videos)
            batch_audios.extend(audios)
            batch_imglens.append(len(images))
            batch_vidlens.append(len(videos))
            batch_audlens.append(len(audios))
            batch_input_ids.append(feature["input_ids"])

        fake_input_ids = []
        if (
            self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
        ):  # avoid process hanging in zero3/fsdp case
            fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
            fake_messages = self.template.mm_plugin.process_messages(
                fake_messages, fake_images, [], [], self.processor
            )
            _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
            _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
                _fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
            )
            fake_input_ids.extend(_fake_input_ids)
            batch_images = fake_images
            batch_imglens[0] = 1

        if (
            self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
        ):  # avoid process hanging in zero3/fsdp case
            fake_messages = [{"role": "user", "content": AUDIO_PLACEHOLDER}]
            fake_audios = [np.zeros(1600)]
            fake_messages = self.template.mm_plugin.process_messages(
                fake_messages, [], [], fake_audios, self.processor
            )
            _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
            _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
                _fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
            )
            fake_input_ids.extend(_fake_input_ids)
            batch_audios = fake_audios
            batch_audlens[0] = 1

        if len(fake_input_ids) != 0:
            if self.tokenizer.padding_side == "right":
                features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
                features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
                features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
            else:
                features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
                features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
                features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]

            batch_input_ids[0] = features[0]["input_ids"]

        mm_inputs = self.template.mm_plugin.get_mm_inputs(
            batch_images,
            batch_videos,
            batch_audios,
            batch_imglens,
            batch_vidlens,
            batch_audlens,
            batch_input_ids,
            self.processor,
        )
        if "token_type_ids" in mm_inputs:
            token_type_ids = mm_inputs.pop("token_type_ids")
            for i, feature in enumerate(features):
                feature["token_type_ids"] = token_type_ids[i]

        features: dict[str, torch.Tensor] = super().__call__(features)

        if self.get_rope_func is not None:
            rope_index_kwargs = {
                "input_ids": features["input_ids"],
                "image_grid_thw": mm_inputs.get("image_grid_thw"),
                "video_grid_thw": mm_inputs.get("video_grid_thw"),
                "attention_mask": (features["attention_mask"] >= 1).float(),
            }
            if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
            elif "video_second_per_grid" in mm_inputs:  # for qwen2.5 omni
                rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

            if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
                rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
                feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
                if feature_attention_mask is not None:  # FIXME: need to get video image lengths
                    audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
                    rope_index_kwargs["audio_seqlens"] = audio_feature_lengths  # prepare for input

                features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
                features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
                    dim=-1
                ).unsqueeze(-1)
            else:  # for qwen vl
                features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)

        if (
            self.model is not None
            and getattr(self.model.config, "model_type", None)
            in [
                "glm4v",
                "Keye",
                "qwen2_vl",
                "qwen2_5_vl",
                "qwen2_5_omni_thinker",
                "qwen3_omni_moe_thinker",
                "qwen3_vl",
                "qwen3_vl_moe",
            ]
            and ("position_ids" not in features or features["position_ids"].dim() != 3)
        ):
            raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")

        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
            seq_len = features["input_ids"].size(1)
            orig_len = cross_attention_mask.size(1)
            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))

        features.update(mm_inputs)

        if "image_bound" in features:  # for minicpmv inputs
            bsz, seq_length = features["input_ids"].shape
            features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
            return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}

        return features


@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for 4d attention mask."""

    block_diag_attn: bool = False
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    compute_dtype: "torch.dtype" = torch.float32

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        features = super().__call__(features)
        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

        for key, value in features.items():  # cast data dtype for paligemma
            if torch.is_tensor(value) and torch.is_floating_point(value):
                features[key] = value.to(self.compute_dtype)

        return features


@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for pairwise data."""

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        r"""Pad batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        concatenated_features = []
        for key in ("chosen", "rejected"):
            for feature in features:
                target_feature = {
                    "input_ids": feature[f"{key}_input_ids"],
                    "attention_mask": feature[f"{key}_attention_mask"],
                    "labels": feature[f"{key}_labels"],
                    "images": feature["images"],
                    "videos": feature["videos"],
                    "audios": feature["audios"],
                }
                concatenated_features.append(target_feature)

        return super().__call__(concatenated_features)


@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""Data collator for KTO data."""

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        target_features = []
        kl_features = []
        kto_tags = []
        for feature in features:
            target_feature = {
                "input_ids": feature["input_ids"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
                "images": feature["images"],
                "videos": feature["videos"],
                "audios": feature["audios"],
            }
            kl_feature = {
                "input_ids": feature["kl_input_ids"],
                "attention_mask": feature["kl_attention_mask"],
                "labels": feature["kl_labels"],
                "images": feature["images"],
                "videos": feature["videos"],
                "audios": feature["audios"],
            }
            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "cross_attention_mask" in kl_batch:  # for mllama inputs
            batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]

        if "token_type_ids" in kl_batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch
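
A minimal sketch showing `prepare_4d_attention_mask` above on the packed-sequence mask from its docstring (assumes the package from this diff and `torch` are installed; the input values are illustrative):

```python
import torch
from llamafactory.data.collator import prepare_4d_attention_mask

# Two packed sequences (indices 1 and 2) plus one padding position (0).
packed_mask = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(packed_mask, torch.float32)

print(mask_4d.shape)               # torch.Size([1, 1, 6, 6])
print((mask_4d == 0).int()[0, 0])  # 1s trace the block-diagonal causal pattern from the docstring
```
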
llamafactory/data/converter.py
ADDED
@@ -0,0 +1,425 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union

from ..extras import logging
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .mm_plugin import AudioInput, ImageInput, VideoInput
    from .parser import DatasetAttr

    MediaType = Union[ImageInput, VideoInput, AudioInput]


logger = logging.get_logger(__name__)


@dataclass
class DatasetConverter:
    dataset_attr: "DatasetAttr"
    data_args: "DataArguments"

    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
        r"""Optionally concatenate media path to media dir when loading from local disk."""
        if medias is None:
            return None
        elif not isinstance(medias, list):
            medias = [medias]
        elif len(medias) == 0:
            return None
        else:
            medias = medias[:]

        if self.dataset_attr.load_from in ["script", "file"]:
            if isinstance(medias[0], str):
                for i in range(len(medias)):
                    media_path = os.path.join(self.data_args.media_dir, medias[i])
                    if os.path.isfile(media_path):
                        medias[i] = media_path
                    else:
                        logger.warning_rank0_once(
                            f"Media {medias[i]} does not exist in `media_dir`. Use original path."
                        )
            elif isinstance(medias[0], list):  # for processed video frames
                # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
                for i in range(len(medias)):
                    for j in range(len(medias[i])):
                        media_path = os.path.join(self.data_args.media_dir, medias[i][j])
                        if os.path.isfile(media_path):
                            medias[i][j] = media_path
                        else:
                            logger.warning_rank0_once(
                                f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
                            )

        return medias

    @abstractmethod
    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
        r"""Convert a single example in the dataset to the standard format."""
        ...


@dataclass
class AlpacaDatasetConverter(DatasetConverter):
    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
        prompt = []
        if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
            for old_prompt, old_response in example[self.dataset_attr.history]:
                prompt.append({"role": Role.USER.value, "content": old_prompt})
                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

        query = []
        if self.dataset_attr.prompt and example[self.dataset_attr.prompt]:
            query.append(example[self.dataset_attr.prompt])

        if self.dataset_attr.query and example[self.dataset_attr.query]:
            query.append(example[self.dataset_attr.query])

        prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"

        if self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
            response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
            if example[self.dataset_attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            self.dataset_attr.ranking
            and isinstance(example[self.dataset_attr.chosen], str)
            and isinstance(example[self.dataset_attr.rejected], str)
        ):  # pairwise example
            response = [
                {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.chosen]},
                {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.rejected]},
            ]
        elif self.dataset_attr.response and isinstance(example[self.dataset_attr.response], str):  # normal example
            response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
        else:  # unsupervised
            response = []

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": example[self.dataset_attr.system] if self.dataset_attr.system else "",
            "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
        }
        return output


@dataclass
class SharegptDatasetConverter(DatasetConverter):
    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
        tag_mapping = {
            self.dataset_attr.user_tag: Role.USER.value,
            self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
            self.dataset_attr.observation_tag: Role.OBSERVATION.value,
            self.dataset_attr.function_tag: Role.FUNCTION.value,
            self.dataset_attr.system_tag: Role.SYSTEM.value,
        }
        odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag)
        even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag)
        accept_tags = (odd_tags, even_tags)
        messages = example[self.dataset_attr.messages]
        if (
            self.dataset_attr.system_tag
            and len(messages) != 0
            and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
        ):
            system = messages[0][self.dataset_attr.content_tag]
            messages = messages[1:]
        else:
            system = example[self.dataset_attr.system] if self.dataset_attr.system else ""

        aligned_messages = []
        broken_data = False
        for turn_idx, message in enumerate(messages):
            if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
                logger.warning_rank0(f"Invalid role tag in {messages}.")
                broken_data = True
                break

            aligned_messages.append(
                {
                    "role": tag_mapping[message[self.dataset_attr.role_tag]],
                    "content": message[self.dataset_attr.content_tag],
                }
            )

        if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
            self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
        ):
            logger.warning_rank0(f"Invalid message count in {messages}.")
            broken_data = True

        if broken_data:
            logger.warning_rank0("Skipping this abnormal example.")
            prompt, response = [], []
        elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]
            if example[self.dataset_attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            self.dataset_attr.ranking
            and isinstance(example[self.dataset_attr.chosen], dict)
            and isinstance(example[self.dataset_attr.rejected], dict)
        ):  # pairwise example
            chosen = example[self.dataset_attr.chosen]
            rejected = example[self.dataset_attr.rejected]
            if (
                chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
                or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
            ):
                logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
                broken_data = True

            prompt = aligned_messages
            response = [
                {
                    "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
                    "content": chosen[self.dataset_attr.content_tag],
                },
                {
                    "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
                    "content": rejected[self.dataset_attr.content_tag],
                },
            ]
        else:  # normal example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": system,
            "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
        }
        return output


@dataclass
class OpenAIDatasetConverter(DatasetConverter):
    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
        tag_mapping = {
            self.dataset_attr.user_tag: Role.USER.value,
            self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
            self.dataset_attr.observation_tag: Role.OBSERVATION.value,
            self.dataset_attr.function_tag: Role.FUNCTION.value,
            self.dataset_attr.system_tag: Role.SYSTEM.value,
        }

        messages = example[self.dataset_attr.messages]
        if (
            self.dataset_attr.system_tag
            and len(messages) != 0
            and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
        ):
            system = messages[0][self.dataset_attr.content_tag]
            messages = messages[1:]
        else:
            system = example.get(self.dataset_attr.system, "") if self.dataset_attr.system else ""

        aligned_messages = []
        tool_responses = []
        broken_data = False
        for turn_idx, message in enumerate(messages):
            role = message[self.dataset_attr.role_tag]
            content = message[self.dataset_attr.content_tag]

            if role in [self.dataset_attr.assistant_tag, self.dataset_attr.function_tag]:
                if "tool_calls" in message and len(message["tool_calls"]) > 0:
                    tool_calls_list = [tool["function"] for tool in message["tool_calls"]]
                    content = json.dumps(tool_calls_list, ensure_ascii=False)
                    role = self.dataset_attr.function_tag

            if role == self.dataset_attr.observation_tag:
                tool_responses.append(content)
                continue
            elif len(tool_responses) > 0:
                _content = "\n</tool_response>\n<tool_response>\n".join(tool_responses)
                aligned_messages.append(
                    {
                        "role": Role.OBSERVATION.value,
                        "content": _content,
                    }
                )
                tool_responses = []

            aligned_messages.append(
                {
                    "role": tag_mapping[role],
                    "content": content,
                }
            )

        odd_tags = (Role.USER.value, Role.OBSERVATION.value)
        even_tags = (Role.ASSISTANT.value, Role.FUNCTION.value)
        accept_tags = (odd_tags, even_tags)
        for turn_idx, message in enumerate(aligned_messages):
            if message["role"] not in accept_tags[turn_idx % 2]:
                logger.warning_rank0(f"Invalid role tag in {messages}.")
                broken_data = True
                break

        if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
            self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
        ):
            logger.warning_rank0(f"Invalid message count in {messages}.")
            broken_data = True

        if broken_data:
            logger.warning_rank0("Skipping this abnormal example.")
            prompt, response = [], []
        elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]
            if example[self.dataset_attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            self.dataset_attr.ranking
            and isinstance(example[self.dataset_attr.chosen], dict)
            and isinstance(example[self.dataset_attr.rejected], dict)
        ):  # pairwise example
            chosen = example[self.dataset_attr.chosen]
            rejected = example[self.dataset_attr.rejected]
            if (
                chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
                or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
            ):
                logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
                broken_data = True

            prompt = aligned_messages
            response = [
                {
                    "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
                    "content": chosen[self.dataset_attr.content_tag],
                },
                {
                    "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
                    "content": rejected[self.dataset_attr.content_tag],
                },
            ]
        else:  # normal example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]

        tools = example.get(self.dataset_attr.tools, "") if self.dataset_attr.tools else ""
        if isinstance(tools, dict) or isinstance(tools, list):
            tools = json.dumps(tools, ensure_ascii=False)

        short_system_prompt = "detailed thinking off"
        if not system:
            if not tools:
                system = short_system_prompt
            else:
                pass
        else:
            if not tools:
                if "detailed thinking on" in system or "detailed thinking off" in system:
                    pass
                else:
                    system += "\n" + short_system_prompt
            else:
                system += "\n"

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": system,
            "_tools": tools,
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
        }
        return output


DATASET_CONVERTERS = {
    "alpaca": AlpacaDatasetConverter,
    "sharegpt": SharegptDatasetConverter,
    "openai": OpenAIDatasetConverter,
}


def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
    r"""Register a new dataset converter."""
    if name in DATASET_CONVERTERS:
        raise ValueError(f"Dataset converter {name} already exists.")

    DATASET_CONVERTERS[name] = dataset_converter


def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
    r"""Get a dataset converter."""
    if name not in DATASET_CONVERTERS:
        raise ValueError(f"Dataset converter {name} not found.")

    return DATASET_CONVERTERS[name](dataset_attr, data_args)


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""Align the dataset to a specific format.

    Aligned dataset:
    _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
    _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
    _system: "..."
    _tools: "..."
    _images: []
    _videos: []
    _audios: []
    """
    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Converting format of dataset",
        )

    dataset_converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args)
    return dataset.map(
        dataset_converter,
        batched=False,
        remove_columns=column_names,
        **kwargs,
    )
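
All converters above emit the same aligned schema. A dependency-free, purely illustrative sketch of what an Alpaca-style record maps to (the record and its field names are hypothetical, not taken from a real dataset):

```python
# Hypothetical Alpaca-style record: instruction + optional input + output.
example = {"instruction": "Translate to French", "input": "Good morning", "output": "Bonjour"}

# Shape produced by AlpacaDatasetConverter: prompt and query joined with "\n",
# a single assistant response, and empty media/system/tools fields.
aligned = {
    "_prompt": [{"role": "user", "content": example["instruction"] + "\n" + example["input"]}],
    "_response": [{"role": "assistant", "content": example["output"]}],
    "_system": "",
    "_tools": "",
    "_images": None,
    "_videos": None,
    "_audios": None,
}
print(aligned["_prompt"][0]["content"])  # "Translate to French\nGood morning"
```
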
llamafactory/data/data_utils.py
ADDED
@@ -0,0 +1,190 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union

import fsspec
from datasets import DatasetDict, concatenate_datasets, interleave_datasets

from ..extras import logging


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset

    from ..hparams import DataArguments


logger = logging.get_logger(__name__)


SLOTS = list[Union[str, set[str], dict[str, str]]]


@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"


class DatasetModule(TypedDict):
    train_dataset: Optional[Union["Dataset", "IterableDataset"]]
    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]


def merge_dataset(
    all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
    r"""Merge multiple datasets to a unified dataset."""
    if len(all_datasets) == 1:
        return all_datasets[0]

    elif data_args.mix_strategy == "concat":
        if data_args.streaming:
            logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")

        return concatenate_datasets(all_datasets)

    elif data_args.mix_strategy.startswith("interleave"):
        if not data_args.streaming:
            logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")

        return interleave_datasets(
            datasets=all_datasets,
            probabilities=data_args.interleave_probs,
            seed=seed,
            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
        )

    else:
        raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")


def split_dataset(
    dataset: Optional[Union["Dataset", "IterableDataset"]],
    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
    data_args: "DataArguments",
    seed: int,
) -> "DatasetDict":
    r"""Split the dataset and returns a dataset dict containing train set and validation set.

    Support both map dataset and iterable dataset.
    """
    if eval_dataset is not None and data_args.val_size > 1e-6:
        raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")

    dataset_dict = {}
    if dataset is not None:
        if data_args.streaming:
            dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

        if data_args.val_size > 1e-6:
            if data_args.streaming:
                dataset_dict["validation"] = dataset.take(int(data_args.val_size))
                dataset_dict["train"] = dataset.skip(int(data_args.val_size))
            else:
                val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
                dataset = dataset.train_test_split(test_size=val_size, seed=seed)
                dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
        else:
            dataset_dict["train"] = dataset

    if eval_dataset is not None:
        if isinstance(eval_dataset, dict):
            dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
        else:
            if data_args.streaming:
                eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

            dataset_dict["validation"] = eval_dataset

    return DatasetDict(dataset_dict)


def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
    r"""Convert dataset or dataset dict to dataset module."""
    dataset_module: DatasetModule = {}
    if isinstance(dataset, DatasetDict):  # dataset dict
        if "train" in dataset:
            dataset_module["train_dataset"] = dataset["train"]

        if "validation" in dataset:
            dataset_module["eval_dataset"] = dataset["validation"]
        else:
            eval_dataset = {}
            for key in dataset.keys():
                if key.startswith("validation_"):
                    eval_dataset[key[len("validation_") :]] = dataset[key]

            if len(eval_dataset):
                dataset_module["eval_dataset"] = eval_dataset

    else:  # single dataset
        dataset_module["train_dataset"] = dataset

    return dataset_module


def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
    r"""Set up a filesystem object based on the path protocol."""
    storage_options = {"anon": anon} if anon else {}
    if path.startswith("s3://"):
        fs = fsspec.filesystem("s3", **storage_options)
    elif path.startswith(("gs://", "gcs://")):
        fs = fsspec.filesystem("gcs", **storage_options)
    else:
        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")

    if not fs.exists(path):
        raise ValueError(f"Path does not exist: {path}.")

    return fs


def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
    r"""Helper function to read JSON/JSONL files using fsspec."""
    with fs.open(path, "r") as f:
        if path.endswith(".jsonl"):
            return [json.loads(line) for line in f if line.strip()]
        else:
            return json.load(f)


def read_cloud_json(cloud_path: str) -> list[Any]:
    r"""Read a JSON/JSONL file from cloud storage (S3 or GCS).

    Args:
        cloud_path: str
            Cloud path in the format:
            - 's3://bucket-name/file.json' for AWS S3
            - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
    """
    try:
        fs = setup_fs(cloud_path, anon=True)  # try with anonymous access first
    except Exception:
        fs = setup_fs(cloud_path)  # try again with credentials

    # filter out non-JSON files
    files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
    files = list(filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files))
    if not files:
        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")

    return sum([_read_json_with_fs(fs, file) for file in files], [])
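
A small sketch of the `val_size` semantics handled by `split_dataset` above (a value at or below 1 is treated as a fraction, above 1 as an absolute count), using the `datasets` library directly on synthetic data:

```python
from datasets import Dataset

# Ten synthetic rows standing in for a preprocessed training set.
dataset = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})

val_size = 0.2  # fraction of the data; 2 would mean two held-out rows instead
split = dataset.train_test_split(test_size=val_size, seed=42)
dataset_dict = {"train": split["train"], "validation": split["test"]}
print(len(dataset_dict["train"]), len(dataset_dict["validation"]))  # 8 2
```
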
llamafactory/data/formatter.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 the LlamaFactory team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional, Union

from typing_extensions import override

from .data_utils import SLOTS
from .tool_utils import FunctionCall, get_tool_utils


@dataclass
class Formatter(ABC):
    slots: SLOTS = field(default_factory=list)
    tool_format: Optional[str] = None

    @abstractmethod
    def apply(self, **kwargs) -> SLOTS:
        r"""Forms a list of slots according to the inputs to encode."""
        ...

    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
        r"""Extract a list of tuples from the response message if using tools.

        Each tuple consists of function name and function arguments.
        """
        raise NotImplementedError


@dataclass
class EmptyFormatter(Formatter):
    def __post_init__(self):
        has_placeholder = False
        for slot in filter(lambda s: isinstance(s, str), self.slots):
            if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
                has_placeholder = True

        if has_placeholder:
            raise ValueError("Empty formatter should not contain any placeholder.")

    @override
    def apply(self, **kwargs) -> SLOTS:
        return self.slots


@dataclass
class StringFormatter(Formatter):
    def __post_init__(self):
        has_placeholder = False
        for slot in filter(lambda s: isinstance(s, str), self.slots):
            if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
                has_placeholder = True

        if not has_placeholder:
            raise ValueError("A placeholder is required in the string formatter.")

    @override
    def apply(self, **kwargs) -> SLOTS:
        elements = []
        for slot in self.slots:
            if isinstance(slot, str):
                for name, value in kwargs.items():
                    if not isinstance(value, str):
                        raise RuntimeError(f"Expected a string, got {value}")

                    slot = slot.replace("{{" + name + "}}", value, 1)
                elements.append(slot)
            elif isinstance(slot, (dict, set)):
                elements.append(slot)
            else:
                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")

        return elements


@dataclass
class FunctionFormatter(StringFormatter):
    def __post_init__(self):
        super().__post_init__()
        self.tool_utils = get_tool_utils(self.tool_format)

    @override
    def apply(self, **kwargs) -> SLOTS:
        content: str = kwargs.pop("content")
        thought_words, thought = kwargs.pop("thought_words", None), None
        if thought_words and len(thought_words) == 2:
            regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
            thought = re.search(regex, content)

        if thought:
            content = content.replace(thought.group(0), "")

        functions: list[FunctionCall] = []
        try:
            tool_calls = json.loads(content)
            if not isinstance(tool_calls, list):  # parallel function call
                tool_calls = [tool_calls]

            for tool_call in tool_calls:
                functions.append(
                    FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
                )

        except json.JSONDecodeError:
            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")  # flat string

        function_str = self.tool_utils.function_formatter(functions)
        if thought:
            function_str = thought.group(0) + function_str

        return super().apply(content=function_str)


@dataclass
class ToolFormatter(Formatter):
    def __post_init__(self):
        self.tool_utils = get_tool_utils(self.tool_format)

    @override
    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        try:
            tools = json.loads(content)
            return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
        except json.JSONDecodeError:
            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string

    @override
    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
        return self.tool_utils.tool_extractor(content)
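Editor's note: the formatter classes above are the building blocks that chat templates compose. The following is a minimal usage sketch, not part of the diff; the slot strings are invented for illustration and do not correspond to any particular template shipped with the package.

# Sketch: how EmptyFormatter and StringFormatter behave (assumed package layout as in this diff).
from llamafactory.data.formatter import EmptyFormatter, StringFormatter

# StringFormatter substitutes each {{placeholder}} with the keyword argument of the same name.
user_formatter = StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"])  # hypothetical slot
print(user_formatter.apply(content="What is the capital of France?"))
# -> ['<|user|>\nWhat is the capital of France?\n<|assistant|>\n']

# EmptyFormatter returns its slots untouched and rejects any placeholder at construction time.
stop_formatter = EmptyFormatter(slots=["</s>"])
print(stop_formatter.apply())
# -> ['</s>']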
llamafactory/data/loader.py
ADDED
@@ -0,0 +1,334 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import TYPE_CHECKING, Literal, Optional, Union

import numpy as np
from datasets import Dataset, load_dataset, load_from_disk

from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.misc import check_version, has_tokenized_data
from .converter import align_dataset
from .data_utils import get_dataset_module, merge_dataset, read_cloud_json, split_dataset
from .parser import get_dataset_list
from .processor import (
    FeedbackDatasetProcessor,
    PackedSupervisedDatasetProcessor,
    PairwiseDatasetProcessor,
    PretrainDatasetProcessor,
    SupervisedDatasetProcessor,
    UnsupervisedDatasetProcessor,
)


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

    from ..hparams import DataArguments, ModelArguments
    from .data_utils import DatasetModule
    from .parser import DatasetAttr
    from .processor import DatasetProcessor
    from .template import Template


logger = logging.get_logger(__name__)


def _load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""Load a single dataset and aligns it to the standard format."""
    logger.info_rank0(f"Loading dataset {dataset_attr}...")
    data_path, data_name, data_dir, data_files = None, None, None, None
    if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
        data_path = dataset_attr.dataset_name
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder

    elif dataset_attr.load_from == "script":
        data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder

    elif dataset_attr.load_from == "cloud_file":
        data_path = dataset_attr.dataset_name

    elif dataset_attr.load_from == "file":
        data_files = []
        local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        if os.path.isdir(local_path):  # is directory
            for file_name in os.listdir(local_path):
                data_files.append(os.path.join(local_path, file_name))
        elif os.path.isfile(local_path):  # is file
            data_files.append(local_path)
        else:
            raise ValueError(f"File {local_path} not found.")

        data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
        if data_path is None:
            raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))

        if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
            raise ValueError("File types should be identical.")
    else:
        raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

    if dataset_attr.load_from == "ms_hub":
        check_version("modelscope>=1.14.0", mandatory=True)
        from modelscope import MsDataset  # type: ignore
        from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore

        cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
        dataset = MsDataset.load(
            dataset_name=data_path,
            subset_name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=cache_dir,
            token=model_args.ms_hub_token,
            use_streaming=data_args.streaming,
        )
        if isinstance(dataset, MsDataset):
            dataset = dataset.to_hf_dataset()

    elif dataset_attr.load_from == "om_hub":
        check_version("openmind>=0.8.0", mandatory=True)
        from openmind import OmDataset  # type: ignore
        from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore

        cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
        dataset = OmDataset.load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=cache_dir,
            token=model_args.om_hub_token,
            streaming=data_args.streaming,
        )
    elif dataset_attr.load_from == "cloud_file":
        dataset = Dataset.from_list(read_cloud_json(data_path), split=dataset_attr.split)
    else:
        dataset = load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=model_args.cache_dir,
            token=model_args.hf_hub_token,
            num_proc=data_args.preprocessing_num_workers,
            streaming=data_args.streaming and dataset_attr.load_from != "file",
        )
        if data_args.streaming and dataset_attr.load_from == "file":
            dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)

    if dataset_attr.num_samples is not None and not data_args.streaming:
        target_num = dataset_attr.num_samples
        indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
        target_num -= len(indexes)
        if target_num > 0:
            expand_indexes = np.random.choice(len(dataset), target_num)
            indexes = np.concatenate((indexes, expand_indexes), axis=0)

        assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
        dataset = dataset.select(indexes)
        logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")

    if data_args.max_samples is not None:  # truncate dataset
        max_samples = min(data_args.max_samples, len(dataset))
        dataset = dataset.select(range(max_samples))

    return align_dataset(dataset, dataset_attr, data_args, training_args)


def _get_merged_dataset(
    dataset_names: Optional[list[str]],
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    return_dict: bool = False,
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
    r"""Return the merged datasets in the standard format."""
    if dataset_names is None:
        return None

    datasets = {}
    for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)):
        if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
            raise ValueError("The dataset is not applicable in the current training stage.")

        datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)

    if return_dict:
        return datasets
    else:
        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)


def _get_dataset_processor(
    data_args: "DataArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    do_generate: bool = False,
) -> "DatasetProcessor":
    r"""Return the corresponding dataset processor."""
    if stage == "pt":
        dataset_processor_class = PretrainDatasetProcessor
    elif stage == "sft" and not do_generate:
        if data_args.packing:
            if data_args.neat_packing:  # hack datasets to have int32 attention mask
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                def __init__(self, data, **kwargs):
                    return TypedSequence.__init__(
                        self,
                        data,
                        type=kwargs.pop("type", None),
                        try_type=kwargs.pop("try_type", None),
                        optimized_int_type=kwargs.pop("optimized_int_type", None),
                    )

                OptimizedTypedSequence.__init__ = __init__

            dataset_processor_class = PackedSupervisedDatasetProcessor
        else:
            dataset_processor_class = SupervisedDatasetProcessor

    elif stage == "rm":
        dataset_processor_class = PairwiseDatasetProcessor
    elif stage == "kto":
        dataset_processor_class = FeedbackDatasetProcessor
    else:
        dataset_processor_class = UnsupervisedDatasetProcessor

    return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)


def _get_preprocessed_dataset(
    dataset: Optional[Union["Dataset", "IterableDataset"]],
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"] = None,
    is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
    r"""Preprocesses the dataset, including format checking and tokenization."""
    if dataset is None:
        return None

    dataset_processor = _get_dataset_processor(
        data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
    )
    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Running tokenizer on dataset",
        )

    dataset = dataset.map(
        dataset_processor.preprocess_dataset,
        batched=True,
        batch_size=data_args.preprocessing_batch_size,
        remove_columns=column_names,
        **kwargs,
    )

    if training_args.should_log:
        try:
            print("eval example:" if is_eval else "training example:")
            dataset_processor.print_data_example(next(iter(dataset)))
        except StopIteration:
            if stage == "pt":
                raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
            else:
                raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")

    return dataset


def get_dataset(
    template: "Template",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
    r"""Get the train dataset and optionally gets the evaluation dataset."""
    # Load tokenized dataset if path exists
    if data_args.tokenized_path is not None:
        if has_tokenized_data(data_args.tokenized_path):
            logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
            tokenized_data = load_from_disk(data_args.tokenized_path)
            dataset_module = get_dataset_module(tokenized_data)
            if data_args.streaming:
                dataset_module["train_dataset"] = dataset_module["train_dataset"].to_iterable_dataset()

            logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
            return dataset_module

        if data_args.streaming:
            raise ValueError("Turn off `streaming` when saving dataset to disk.")

    # Load and preprocess dataset
    with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)):
        dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
        eval_dataset = _get_merged_dataset(
            data_args.eval_dataset,
            model_args,
            data_args,
            training_args,
            stage,
            return_dict=data_args.eval_on_each_dataset,
        )

    with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
        dataset = _get_preprocessed_dataset(
            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
        )
        if isinstance(eval_dataset, dict):
            for eval_name, eval_data in eval_dataset.items():
                eval_dataset[eval_name] = _get_preprocessed_dataset(
                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
                )
        else:
            eval_dataset = _get_preprocessed_dataset(
                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
            )

        dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
        if data_args.tokenized_path is not None:  # save tokenized dataset to disk
            if training_args.should_save:
                dataset_dict.save_to_disk(data_args.tokenized_path)
            logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
            logger.info_rank0(f"Please launch the training with `tokenized_path: {data_args.tokenized_path}`.")

    return get_dataset_module(dataset_dict)
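Editor's note: the `num_samples` branch in `_load_single_dataset` above first draws indices without replacement and only tops up with replacement when the requested count exceeds the dataset size, so every example is kept at least once before any example repeats. Below is a standalone sketch of just that sampling step in plain NumPy; the function name `sample_indices` is illustrative and is not part of the module.

import numpy as np

def sample_indices(dataset_len: int, num_samples: int) -> np.ndarray:
    """Mirror of the num_samples sampling step used in _load_single_dataset."""
    indexes = np.random.permutation(dataset_len)[:num_samples]  # without replacement first
    remaining = num_samples - len(indexes)
    if remaining > 0:  # more samples requested than available: top up with replacement
        indexes = np.concatenate((indexes, np.random.choice(dataset_len, remaining)), axis=0)
    return indexes

print(len(sample_indices(dataset_len=100, num_samples=30)))   # 30 distinct indices
print(len(sample_indices(dataset_len=100, num_samples=250)))  # 100 distinct indices plus 150 repeats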
llamafactory/data/mm_plugin.py
ADDED
@@ -0,0 +1,2082 @@
| 1 |
+
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
|
| 2 |
+
#
|
| 3 |
+
# This code is inspired by the HuggingFace's Transformers library.
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
import inspect
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import re
|
| 22 |
+
from copy import deepcopy
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from io import BytesIO
|
| 25 |
+
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
|
| 26 |
+
|
| 27 |
+
import numpy as np
|
| 28 |
+
import torch
|
| 29 |
+
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
|
| 30 |
+
from transformers.models.mllama.processing_mllama import (
|
| 31 |
+
convert_sparse_cross_attention_mask_to_dense,
|
| 32 |
+
get_cross_attention_token_mask,
|
| 33 |
+
)
|
| 34 |
+
from typing_extensions import NotRequired, override
|
| 35 |
+
|
| 36 |
+
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
| 37 |
+
from ..extras.packages import (
|
| 38 |
+
is_librosa_available,
|
| 39 |
+
is_pillow_available,
|
| 40 |
+
is_pyav_available,
|
| 41 |
+
is_transformers_version_greater_than,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if is_librosa_available():
|
| 46 |
+
import librosa
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if is_pillow_available():
|
| 50 |
+
from PIL import Image
|
| 51 |
+
from PIL.Image import Image as ImageObject
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if is_pyav_available():
|
| 55 |
+
import av
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if is_transformers_version_greater_than("4.52.0"):
|
| 59 |
+
from transformers.image_utils import make_flat_list_of_images
|
| 60 |
+
from transformers.video_utils import make_batched_videos
|
| 61 |
+
else:
|
| 62 |
+
from transformers.image_utils import make_batched_videos, make_flat_list_of_images
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if TYPE_CHECKING:
|
| 66 |
+
from av.stream import Stream
|
| 67 |
+
from numpy.typing import NDArray
|
| 68 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin
|
| 69 |
+
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
|
| 70 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
| 71 |
+
from transformers.video_processing_utils import BaseVideoProcessor
|
| 72 |
+
|
| 73 |
+
class EncodedImage(TypedDict):
|
| 74 |
+
path: Optional[str]
|
| 75 |
+
bytes: Optional[bytes]
|
| 76 |
+
|
| 77 |
+
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
|
| 78 |
+
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
|
| 79 |
+
AudioInput = Union[str, BinaryIO, NDArray]
|
| 80 |
+
|
| 81 |
+
class RegularizedImageOutput(TypedDict):
|
| 82 |
+
images: list[ImageObject]
|
| 83 |
+
|
| 84 |
+
class RegularizedVideoOutput(TypedDict):
|
| 85 |
+
videos: list[list[ImageObject]]
|
| 86 |
+
durations: list[float]
|
| 87 |
+
fps_per_video: NotRequired[list[float]]
|
| 88 |
+
|
| 89 |
+
class RegularizedAudioOutput(TypedDict):
|
| 90 |
+
audios: list[NDArray]
|
| 91 |
+
sampling_rates: list[float]
|
| 92 |
+
|
| 93 |
+
class MMProcessor(ProcessorMixin):
|
| 94 |
+
patch_size: int
|
| 95 |
+
image_seq_length: int
|
| 96 |
+
num_additional_image_tokens: int
|
| 97 |
+
vision_feature_select_strategy: Literal["default", "full"]
|
| 98 |
+
|
| 99 |
+
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
|
| 100 |
+
pass
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
|
| 104 |
+
r"""Get paligemma token type ids for computing loss.
|
| 105 |
+
|
| 106 |
+
It is slightly different with the original token type ids where the prompt part is 0.
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
batch_token_type_ids: shape (batch_size, seq_length)
|
| 110 |
+
|
| 111 |
+
"""
|
| 112 |
+
batch_token_type_ids = []
|
| 113 |
+
for imglen, seqlen in zip(imglens, seqlens):
|
| 114 |
+
image_seqlen = imglen * processor.image_seq_length
|
| 115 |
+
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
|
| 116 |
+
|
| 117 |
+
return batch_token_type_ids
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
|
| 121 |
+
r"""Get gemma3 token type ids for computing loss.
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
batch_token_type_ids: shape (batch_size, seq_length)
|
| 125 |
+
|
| 126 |
+
"""
|
| 127 |
+
image_token_id: int = getattr(processor, "image_token_id")
|
| 128 |
+
batch_token_type_ids = []
|
| 129 |
+
for token_ids in batch_ids:
|
| 130 |
+
token_ids = np.array(token_ids)
|
| 131 |
+
token_type_ids = np.zeros_like(token_ids)
|
| 132 |
+
token_type_ids[token_ids == image_token_id] = 1
|
| 133 |
+
batch_token_type_ids.append(token_type_ids.tolist())
|
| 134 |
+
|
| 135 |
+
return batch_token_type_ids
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
|
| 139 |
+
r"""Make nested list of images."""
|
| 140 |
+
batch_images = []
|
| 141 |
+
for imglen in imglens:
|
| 142 |
+
batch_images.append(images[:imglen])
|
| 143 |
+
images = images[imglen:]
|
| 144 |
+
|
| 145 |
+
return batch_images
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _check_video_is_nested_images(video: "VideoInput") -> bool:
|
| 149 |
+
r"""Check if the video is nested images."""
|
| 150 |
+
return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict, ImageObject)) for frame in video)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@dataclass
|
| 154 |
+
class MMPluginMixin:
|
| 155 |
+
image_token: Optional[str]
|
| 156 |
+
video_token: Optional[str]
|
| 157 |
+
audio_token: Optional[str]
|
| 158 |
+
expand_mm_tokens: bool = True
|
| 159 |
+
|
| 160 |
+
def _validate_input(
|
| 161 |
+
self,
|
| 162 |
+
processor: Optional["MMProcessor"],
|
| 163 |
+
images: list["ImageInput"],
|
| 164 |
+
videos: list["VideoInput"],
|
| 165 |
+
audios: list["AudioInput"],
|
| 166 |
+
) -> None:
|
| 167 |
+
r"""Validate if this model accepts the input modalities."""
|
| 168 |
+
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
| 169 |
+
video_processor: BaseImageProcessor = getattr(
|
| 170 |
+
processor, "video_processor", getattr(processor, "image_processor", None)
|
| 171 |
+
)
|
| 172 |
+
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
| 173 |
+
if len(images) != 0 and self.image_token is None:
|
| 174 |
+
raise ValueError(
|
| 175 |
+
"This model does not support image input. Please check whether the correct `template` is used."
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
if len(videos) != 0 and self.video_token is None:
|
| 179 |
+
raise ValueError(
|
| 180 |
+
"This model does not support video input. Please check whether the correct `template` is used."
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
if len(audios) != 0 and self.audio_token is None:
|
| 184 |
+
raise ValueError(
|
| 185 |
+
"This model does not support audio input. Please check whether the correct `template` is used."
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
if self.image_token is not None and processor is None:
|
| 189 |
+
raise ValueError("Processor was not found, please check and update your model file.")
|
| 190 |
+
|
| 191 |
+
if self.image_token is not None and image_processor is None:
|
| 192 |
+
raise ValueError("Image processor was not found, please check and update your model file.")
|
| 193 |
+
|
| 194 |
+
if self.video_token is not None and video_processor is None:
|
| 195 |
+
raise ValueError("Video processor was not found, please check and update your model file.")
|
| 196 |
+
|
| 197 |
+
if self.audio_token is not None and feature_extractor is None:
|
| 198 |
+
raise ValueError("Audio feature extractor was not found, please check and update your model file.")
|
| 199 |
+
|
| 200 |
+
def _validate_messages(
|
| 201 |
+
self,
|
| 202 |
+
messages: list[dict[str, str]],
|
| 203 |
+
images: list["ImageInput"],
|
| 204 |
+
videos: list["VideoInput"],
|
| 205 |
+
audios: list["AudioInput"],
|
| 206 |
+
):
|
| 207 |
+
r"""Validate if the number of images, videos and audios match the number of placeholders in messages."""
|
| 208 |
+
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
| 209 |
+
for message in messages:
|
| 210 |
+
num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER)
|
| 211 |
+
num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER)
|
| 212 |
+
num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER)
|
| 213 |
+
|
| 214 |
+
if len(images) != num_image_tokens:
|
| 215 |
+
raise ValueError(
|
| 216 |
+
f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}."
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
if len(videos) != num_video_tokens:
|
| 220 |
+
raise ValueError(
|
| 221 |
+
f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}."
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
if len(audios) != num_audio_tokens:
|
| 225 |
+
raise ValueError(
|
| 226 |
+
f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}."
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
def _preprocess_image(
|
| 230 |
+
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
|
| 231 |
+
) -> "ImageObject":
|
| 232 |
+
r"""Pre-process a single image."""
|
| 233 |
+
if (image.width * image.height) > image_max_pixels:
|
| 234 |
+
resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
|
| 235 |
+
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
| 236 |
+
image = image.resize((width, height))
|
| 237 |
+
|
| 238 |
+
if (image.width * image.height) < image_min_pixels:
|
| 239 |
+
resize_factor = math.sqrt(image_min_pixels / (image.width * image.height))
|
| 240 |
+
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
| 241 |
+
image = image.resize((width, height))
|
| 242 |
+
|
| 243 |
+
if image.mode != "RGB":
|
| 244 |
+
image = image.convert("RGB")
|
| 245 |
+
|
| 246 |
+
return image
|
| 247 |
+
|
| 248 |
+
def _get_video_sample_indices(
|
| 249 |
+
self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
|
| 250 |
+
) -> list[int]:
|
| 251 |
+
r"""Compute video sample indices according to fps."""
|
| 252 |
+
total_frames = video_stream.frames
|
| 253 |
+
if total_frames == 0: # infinite video
|
| 254 |
+
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
|
| 255 |
+
|
| 256 |
+
sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * video_fps))
|
| 257 |
+
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
| 258 |
+
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
| 259 |
+
|
| 260 |
+
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
|
| 261 |
+
r"""Regularize images to avoid error. Including reading and pre-processing."""
|
| 262 |
+
results = []
|
| 263 |
+
for image in images:
|
| 264 |
+
if isinstance(image, (str, BinaryIO)):
|
| 265 |
+
image = Image.open(image)
|
| 266 |
+
elif isinstance(image, bytes):
|
| 267 |
+
image = Image.open(BytesIO(image))
|
| 268 |
+
elif isinstance(image, dict):
|
| 269 |
+
if image["bytes"] is not None:
|
| 270 |
+
image = Image.open(BytesIO(image["bytes"]))
|
| 271 |
+
else:
|
| 272 |
+
image = Image.open(image["path"])
|
| 273 |
+
|
| 274 |
+
if not isinstance(image, ImageObject):
|
| 275 |
+
raise ValueError(f"Expect input is a list of images, but got {type(image)}.")
|
| 276 |
+
|
| 277 |
+
results.append(self._preprocess_image(image, **kwargs))
|
| 278 |
+
|
| 279 |
+
return {"images": results}
|
| 280 |
+
|
| 281 |
+
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
|
| 282 |
+
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
|
| 283 |
+
results = []
|
| 284 |
+
durations = []
|
| 285 |
+
for video in videos:
|
| 286 |
+
frames: list[ImageObject] = []
|
| 287 |
+
if _check_video_is_nested_images(video):
|
| 288 |
+
for frame in video:
|
| 289 |
+
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
|
| 290 |
+
raise ValueError("Invalid image found in video frames.")
|
| 291 |
+
frames = video
|
| 292 |
+
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
| 293 |
+
else:
|
| 294 |
+
container = av.open(video, "r")
|
| 295 |
+
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
| 296 |
+
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
| 297 |
+
container.seek(0)
|
| 298 |
+
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
| 299 |
+
if frame_idx in sample_indices:
|
| 300 |
+
frames.append(frame.to_image())
|
| 301 |
+
|
| 302 |
+
if video_stream.duration is None:
|
| 303 |
+
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
| 304 |
+
else:
|
| 305 |
+
durations.append(float(video_stream.duration * video_stream.time_base))
|
| 306 |
+
|
| 307 |
+
frames = self._regularize_images(frames, **kwargs)["images"]
|
| 308 |
+
results.append(frames)
|
| 309 |
+
|
| 310 |
+
return {"videos": results, "durations": durations}
|
| 311 |
+
|
| 312 |
+
def _regularize_audios(
|
| 313 |
+
self, audios: list["AudioInput"], sampling_rate: float, **kwargs
|
| 314 |
+
) -> "RegularizedAudioOutput":
|
| 315 |
+
r"""Regularizes audios to avoid error. Including reading and resampling."""
|
| 316 |
+
results, sampling_rates = [], []
|
| 317 |
+
for audio in audios:
|
| 318 |
+
if not isinstance(audio, np.ndarray):
|
| 319 |
+
audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
|
| 320 |
+
|
| 321 |
+
results.append(audio)
|
| 322 |
+
sampling_rates.append(sampling_rate)
|
| 323 |
+
|
| 324 |
+
return {"audios": results, "sampling_rates": sampling_rates}
|
| 325 |
+
|
| 326 |
+
def _get_mm_inputs(
|
| 327 |
+
self,
|
| 328 |
+
images: list["ImageInput"],
|
| 329 |
+
videos: list["VideoInput"],
|
| 330 |
+
audios: list["AudioInput"],
|
| 331 |
+
processor: "MMProcessor",
|
| 332 |
+
imglens: Optional[list[int]] = None,
|
| 333 |
+
) -> dict[str, "torch.Tensor"]:
|
| 334 |
+
r"""Process visual inputs.
|
| 335 |
+
|
| 336 |
+
Returns: (llava and paligemma)
|
| 337 |
+
pixel_values: tensor with shape (B, C, H, W)
|
| 338 |
+
|
| 339 |
+
Returns: (qwen2-vl)
|
| 340 |
+
pixel_values: tensor with shape (num_patches, patch_dim)
|
| 341 |
+
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
|
| 342 |
+
where num_patches == torch.prod(image_grid_thw)
|
| 343 |
+
|
| 344 |
+
Returns: (mllama)
|
| 345 |
+
pixel_values: tensor with shape
|
| 346 |
+
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
|
| 347 |
+
For example, (2, 1, 4, 3, 560, 560).
|
| 348 |
+
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
|
| 349 |
+
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
|
| 350 |
+
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
|
| 351 |
+
|
| 352 |
+
"""
|
| 353 |
+
mm_inputs = {}
|
| 354 |
+
if len(images) != 0:
|
| 355 |
+
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
| 356 |
+
images = self._regularize_images(
|
| 357 |
+
images,
|
| 358 |
+
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
| 359 |
+
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
| 360 |
+
)["images"]
|
| 361 |
+
if imglens is not None: # if imglens are provided, make batched images
|
| 362 |
+
images = _make_batched_images(images, imglens)
|
| 363 |
+
|
| 364 |
+
image_processor_kwargs = {}
|
| 365 |
+
if getattr(processor, "image_do_pan_and_scan", False): # gemma3 image processor
|
| 366 |
+
image_processor_kwargs.update(
|
| 367 |
+
{
|
| 368 |
+
"do_pan_and_scan": True,
|
| 369 |
+
"pan_and_scan_min_crop_size": 256,
|
| 370 |
+
"pan_and_scan_max_num_crops": 4,
|
| 371 |
+
"pan_and_scan_min_ratio_to_activate": 1.2,
|
| 372 |
+
}
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
|
| 376 |
+
|
| 377 |
+
if len(videos) != 0:
|
| 378 |
+
video_processor: BaseImageProcessor = getattr(
|
| 379 |
+
processor, "video_processor", getattr(processor, "image_processor", None)
|
| 380 |
+
)
|
| 381 |
+
videos = self._regularize_videos(
|
| 382 |
+
videos,
|
| 383 |
+
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
|
| 384 |
+
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
| 385 |
+
video_fps=getattr(processor, "video_fps", 2.0),
|
| 386 |
+
video_maxlen=getattr(processor, "video_maxlen", 128),
|
| 387 |
+
)["videos"]
|
| 388 |
+
if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
|
| 389 |
+
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
|
| 390 |
+
else: # for llava_next_video
|
| 391 |
+
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
| 392 |
+
|
| 393 |
+
if len(audios) != 0:
|
| 394 |
+
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
| 395 |
+
audios = self._regularize_audios(
|
| 396 |
+
audios,
|
| 397 |
+
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
| 398 |
+
)["audios"]
|
| 399 |
+
mm_inputs.update(
|
| 400 |
+
feature_extractor(
|
| 401 |
+
audios,
|
| 402 |
+
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
| 403 |
+
return_attention_mask=True,
|
| 404 |
+
padding="max_length",
|
| 405 |
+
return_tensors="pt",
|
| 406 |
+
)
|
| 407 |
+
)
|
| 408 |
+
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None) # prevent conflicts
|
| 409 |
+
|
| 410 |
+
return mm_inputs
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
@dataclass
|
| 414 |
+
class BasePlugin(MMPluginMixin):
|
| 415 |
+
def process_messages(
|
| 416 |
+
self,
|
| 417 |
+
messages: list[dict[str, str]],
|
| 418 |
+
images: list["ImageInput"],
|
| 419 |
+
videos: list["VideoInput"],
|
| 420 |
+
audios: list["AudioInput"],
|
| 421 |
+
processor: Optional["MMProcessor"],
|
| 422 |
+
) -> list[dict[str, str]]:
|
| 423 |
+
r"""Pre-process input messages before tokenization for VLMs."""
|
| 424 |
+
self._validate_input(processor, images, videos, audios)
|
| 425 |
+
return messages
|
| 426 |
+
|
| 427 |
+
def process_token_ids(
|
| 428 |
+
self,
|
| 429 |
+
input_ids: list[int],
|
| 430 |
+
labels: Optional[list[int]],
|
| 431 |
+
images: list["ImageInput"],
|
| 432 |
+
videos: list["VideoInput"],
|
| 433 |
+
audios: list["AudioInput"],
|
| 434 |
+
tokenizer: "PreTrainedTokenizer",
|
| 435 |
+
processor: Optional["MMProcessor"],
|
| 436 |
+
) -> tuple[list[int], Optional[list[int]]]:
|
| 437 |
+
r"""Pre-process token ids after tokenization for VLMs."""
|
| 438 |
+
self._validate_input(processor, images, videos, audios)
|
| 439 |
+
return input_ids, labels
|
| 440 |
+
|
| 441 |
+
def get_mm_inputs(
|
| 442 |
+
self,
|
| 443 |
+
images: list["ImageInput"],
|
| 444 |
+
videos: list["VideoInput"],
|
| 445 |
+
audios: list["AudioInput"],
|
| 446 |
+
imglens: list[int],
|
| 447 |
+
vidlens: list[int],
|
| 448 |
+
audlens: list[int],
|
| 449 |
+
batch_ids: list[list[int]],
|
| 450 |
+
processor: Optional["MMProcessor"],
|
| 451 |
+
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
| 452 |
+
r"""Build batched multimodal inputs for VLMs.
|
| 453 |
+
|
| 454 |
+
Arguments:
|
| 455 |
+
images: a list of image inputs, shape (num_images,)
|
| 456 |
+
videos: a list of video inputs, shape (num_videos,)
|
| 457 |
+
audios: a list of audio inputs, shape (num_audios,)
|
| 458 |
+
imglens: number of images in each sample, shape (batch_size,)
|
| 459 |
+
vidlens: number of videos in each sample, shape (batch_size,)
|
| 460 |
+
audlens: number of audios in each sample, shape (batch_size,)
|
| 461 |
+
batch_ids: token ids of input samples, shape (batch_size, seq_len)
|
| 462 |
+
processor: a processor for pre-processing images and videos
|
| 463 |
+
|
| 464 |
+
"""
|
| 465 |
+
self._validate_input(processor, images, videos, audios)
|
| 466 |
+
return self._get_mm_inputs(images, videos, audios, processor)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
@dataclass
|
| 470 |
+
class Gemma3Plugin(BasePlugin):
|
| 471 |
+
@override
|
| 472 |
+
def process_messages(
|
| 473 |
+
self,
|
| 474 |
+
messages: list[dict[str, str]],
|
| 475 |
+
images: list["ImageInput"],
|
| 476 |
+
videos: list["VideoInput"],
|
| 477 |
+
audios: list["AudioInput"],
|
| 478 |
+
processor: Optional["MMProcessor"],
|
| 479 |
+
) -> list[dict[str, str]]:
|
| 480 |
+
self._validate_input(processor, images, videos, audios)
|
| 481 |
+
self._validate_messages(messages, images, videos, audios)
|
| 482 |
+
num_image_tokens = 0
|
| 483 |
+
messages = deepcopy(messages)
|
| 484 |
+
boi_token: str = getattr(processor, "boi_token")
|
| 485 |
+
full_image_sequence: str = getattr(processor, "full_image_sequence")
|
| 486 |
+
image_str = full_image_sequence if self.expand_mm_tokens else boi_token
|
| 487 |
+
|
| 488 |
+
do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
|
| 489 |
+
if do_pan_and_scan:
|
| 490 |
+
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
| 491 |
+
|
| 492 |
+
for message in messages:
|
| 493 |
+
content = message["content"]
|
| 494 |
+
while IMAGE_PLACEHOLDER in content:
|
| 495 |
+
if do_pan_and_scan:
|
| 496 |
+
image_placeholder_str = (
|
| 497 |
+
"Here is the original image {{image}} and here are some crops to help you see better "
|
| 498 |
+
+ " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens])
|
| 499 |
+
)
|
| 500 |
+
else:
|
| 501 |
+
image_placeholder_str = "{{image}}"
|
| 502 |
+
|
| 503 |
+
content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
|
| 504 |
+
num_image_tokens += 1
|
| 505 |
+
|
| 506 |
+
message["content"] = content.replace("{{image}}", image_str)
|
| 507 |
+
|
| 508 |
+
return messages
|
| 509 |
+
|
| 510 |
+
@override
|
| 511 |
+
def get_mm_inputs(
|
| 512 |
+
self,
|
| 513 |
+
images: list["ImageInput"],
|
| 514 |
+
videos: list["VideoInput"],
|
| 515 |
+
audios: list["AudioInput"],
|
| 516 |
+
imglens: list[int],
|
| 517 |
+
vidlens: list[int],
|
| 518 |
+
audlens: list[int],
|
| 519 |
+
batch_ids: list[list[int]],
|
| 520 |
+
processor: Optional["MMProcessor"],
|
| 521 |
+
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
| 522 |
+
self._validate_input(processor, images, videos, audios)
|
| 523 |
+
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
| 524 |
+
mm_inputs.pop("num_crops", None)
|
| 525 |
+
mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor)
|
| 526 |
+
return mm_inputs
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
class Gemma3nPlugin(Gemma3Plugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)
        boi_token: str = getattr(processor, "boi_token")
        boa_token: str = getattr(processor, "boa_token")
        full_image_sequence: str = getattr(processor, "full_image_sequence")
        full_audio_sequence: str = getattr(processor, "full_audio_sequence")
        image_str = full_image_sequence if self.expand_mm_tokens else boi_token
        audio_str = full_audio_sequence if self.expand_mm_tokens else boa_token

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, image_str, 1)

            while AUDIO_PLACEHOLDER in content:
                content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)

            message["content"] = content

        return messages


@dataclass
class InternVLPlugin(BasePlugin):
    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "ProcessorMixin",
        **kwargs,
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        image_processor_kwargs = {}
        if getattr(processor, "crop_to_patches", False):
            image_processor_kwargs.update(
                {
                    "crop_to_patches": True,
                    "max_patches": 12,
                    "min_patches": 1,
                }
            )

        mm_inputs = {}
        image_video_patches = []

        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )["videos"]

        if len(images) != 0:
            images = make_flat_list_of_images(images)
            image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs)
            image_num_patches = image_inputs.pop("num_patches")
            image_pixel_values = image_inputs.pop("pixel_values")
            image_num_patches_indices = np.cumsum(image_num_patches)

        if len(videos) != 0:
            videos = make_batched_videos(videos)
            num_frames_per_video = [len(video) for video in videos]
            patch_indices = np.cumsum(num_frames_per_video)
            image_processor_kwargs["crop_to_patches"] = False
            video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs)
            video_num_patches = video_inputs.pop("num_patches")
            video_pixel_values = video_inputs.pop("pixel_values")
            video_num_patches_indices = np.cumsum(video_num_patches)

        # interleaved image-video inputs are not supported
        if len(images) != 0 and image_pixel_values is not None:
            for i in range(len(images)):
                start_index = image_num_patches_indices[i - 1] if i > 0 else 0
                end_index = image_num_patches_indices[i]
                image_video_patches.append(image_pixel_values[start_index:end_index])

        if len(videos) != 0 and video_pixel_values is not None:
            patch_indices_with_prefix = [0] + list(patch_indices)
            for i in range(len(videos)):
                current_patch_index = patch_indices_with_prefix[i]
                end_patch_index = patch_indices_with_prefix[i + 1]
                start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
                end_index = video_num_patches_indices[end_patch_index - 1]
                image_video_patches.append(video_pixel_values[start_index:end_index])

        if len(images) != 0 or len(videos) != 0:
            mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0)

        if len(images) != 0:
            mm_inputs.update({"image_num_patches": image_num_patches})

        if len(videos) != 0:
            mm_inputs.update({"video_patch_indices": patch_indices})
            mm_inputs.update({"video_num_patches": video_num_patches})

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)

        image_pixel_patch_list = mm_inputs.get("image_num_patches")  # patches per image
        video_num_patches = mm_inputs.get("video_num_patches")  # all patches for the frames of each video
        video_patch_indices = mm_inputs.get("video_patch_indices")  # number of frames per video

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(
                    IMAGE_PLACEHOLDER,
                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
                    1,
                )
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
                end_patch_index = video_patch_indices[num_video_tokens]
                num_patches = list(video_num_patches[current_patch_index:end_patch_index])
                video_replaced_prompt = "\n".join(
                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
                    for i in range(len(num_patches))
                )
                content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
                num_video_tokens += 1

            message["content"] = content

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["ProcessorMixin"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("image_num_patches", None)
        mm_inputs.pop("video_patch_indices", None)
        mm_inputs.pop("video_num_patches", None)
        return mm_inputs


class KimiVLPlugin(BasePlugin):
    @override
    def process_messages(self, messages, images, videos, audios, processor):
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            image_grid_hws = mm_inputs.get("image_grid_hws", [])
        else:
            image_grid_hws = [None] * len(images)

        num_image_tokens = 0
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        merge_length = math.prod(image_processor.merge_kernel_size)
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
                    IMAGE_PLACEHOLDER,
                    f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>",
                    1,
                )
                num_image_tokens += 1

            message["content"] = content

        return messages


@dataclass
class Llama4Plugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
                num_patches_per_chunk = int(
                    (image_height // processor.patch_size)
                    * (image_width // processor.patch_size)
                    // processor.downsample_ratio
                )
                aspect_ratios = mm_inputs.pop("aspect_ratios")

        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            if self.expand_mm_tokens:
                placeholder_count = content.count(IMAGE_PLACEHOLDER)
                prompt_splits = content.split(IMAGE_PLACEHOLDER)
                new_content = []
                for local_image_index, split_part in enumerate(prompt_splits):
                    new_content.append(split_part)
                    if local_image_index < placeholder_count:
                        tokens_for_this_image = processor._prompt_split_image(
                            aspect_ratios[num_image_tokens], num_patches_per_chunk
                        )
                        num_image_tokens += 1
                        new_content.append(tokens_for_this_image)

                content = "".join(new_content)
            else:
                content = content.replace(IMAGE_PLACEHOLDER, self.image_token)

            message["content"] = content

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("aspect_ratios", None)
        return mm_inputs


@dataclass
class LlavaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
                image_seqlen = (height // processor.patch_size) * (
                    width // processor.patch_size
                ) + processor.num_additional_image_tokens
                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1
        else:
            image_seqlen = 1

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        return messages


@dataclass
class LlavaNextPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1

                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", self.image_token)

        return messages


@dataclass
class LlavaNextVideoPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1

                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        if self.expand_mm_tokens:
            if "pixel_values_videos" in mm_inputs:
                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                height, width = get_image_size(one_video[0])
                num_frames = one_video.shape[0]  # frame dim is always after batch dim
                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
        else:
            video_seqlen = 1

        for message in messages:
            content = message["content"]
            while VIDEO_PLACEHOLDER in content:
                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)

            message["content"] = content.replace("{{video}}", self.video_token)

        return messages


@dataclass
class MiniCPMVPlugin(BasePlugin):
    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
        **kwargs,
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            if "valid_image_nums_ls" in kwargs:
                valid_image_nums_ls = kwargs["valid_image_nums_ls"]
                new_images = []
                idx = 0
                for valid_image_nums in valid_image_nums_ls:
                    new_images.append(images[idx : idx + valid_image_nums])
                    idx += valid_image_nums

                images = new_images

            image_inputs = image_processor(
                images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
            )
            mm_inputs.update(image_inputs)

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )["videos"]
            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
            mm_inputs.update(video_inputs)

        if len(audios) != 0:
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
            )["audios"]
            if "valid_audio_nums_ls" in kwargs:
                valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
                audios_ls = []
                idx = 0
                for valid_audio_nums in valid_audio_nums_ls:
                    audios_ls.append(audios[idx : idx + valid_audio_nums])
                    idx += valid_audio_nums
            else:
                audios_ls = [audios]

            audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
                audios_ls,
                chunk_input=True,
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
            )
            audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
            mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
            if kwargs.get("ret_phs", False):
                mm_inputs.update({"audio_phs": audio_phs})

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        mm_inputs, audio_inputs = {}, {}
        if len(images) != 0 and len(videos) != 0:
            raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")

        if len(videos) != 0:
            max_slice_nums = 2
            use_image_id = False
            mm_inputs = self._get_mm_inputs([], videos, [], processor)
        else:
            max_slice_nums = image_processor.max_slice_nums
            use_image_id = image_processor.use_image_id

        for i, message in enumerate(messages):
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
                content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                num_video_tokens += 1

            while AUDIO_PLACEHOLDER in content:
                content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                num_audio_tokens += 1

            message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
                "{{audio}}", "(<audio>./</audio>)"
            )

        if len(images):
            mm_inputs = self._get_mm_inputs(images, [], [], processor)

        if len(audios):
            audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)

        if self.expand_mm_tokens and mm_inputs:
            pattern = "(<image>./</image>)"
            image_sizes = mm_inputs["image_sizes"]
            idx = 0
            for index, message in enumerate(messages):
                text = message["content"]
                image_tags = re.findall(pattern, text)
                text_chunks = text.split(pattern)
                final_text = ""
                for i in range(len(image_tags)):
                    final_text = (
                        final_text
                        + text_chunks[i]
                        + image_processor.get_slice_image_placeholder(
                            image_sizes[0][idx], idx, max_slice_nums, use_image_id
                        )
                    )
                    idx += 1

                final_text += text_chunks[-1]
                messages[index]["content"] = final_text

        if self.expand_mm_tokens and audio_inputs:
            pattern = "(<audio>./</audio>)"
            idx = 0
            for index, message in enumerate(messages):
                text = message["content"]
                audio_tags = re.findall(pattern, text)
                text_chunks = text.split(pattern)
                final_text = ""
                for i in range(len(audio_tags)):
                    audio_placeholder = audio_inputs["audio_phs"][0][idx]
                    final_text = final_text + text_chunks[i] + audio_placeholder
                    idx += 1

                final_text += text_chunks[-1]
                messages[index]["content"] = final_text

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        # image bound
        image_bounds_list = []
        valid_image_nums_ls = []
        for i, input_ids in enumerate(batch_ids):
            input_ids_ = torch.tensor(input_ids)
            start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
                input_ids_ == processor.tokenizer.slice_start_id
            )
            end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
            image_start_tokens = torch.where(start_cond)[0]
            image_start_tokens += 1
            image_end_tokens = torch.where(end_cond)[0]
            valid_image_nums_ls.append(imglens[i])
            image_bounds = torch.hstack(
                [
                    image_start_tokens.unsqueeze(-1),
                    image_end_tokens.unsqueeze(-1),
                ]
            )
            image_bounds_list.append(image_bounds)

        mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
        if "tgt_sizes" not in mm_inputs:
            dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
            mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})

        mm_inputs.update({"image_bound": image_bounds_list})

        if len(audios) > 0:
            # audio bound
            audio_bounds_ls = []
            spk_bounds_ls = []
            valid_audio_nums_ls = []

            for input_ids, audiolen in zip(batch_ids, audlens):
                input_ids_ = torch.tensor(input_ids)
                audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
                audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
                assert len(audio_start_idx) == len(audio_end_idx)
                audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
                audio_bounds_ls.append(audio_bounds)
                valid_audio_nums_ls.append(audiolen)

                spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
                spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
                assert len(spk_start_idx) == len(spk_end_idx)
                spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
                spk_bounds_ls.append(spk_bounds)

            audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
            mm_inputs.update(audio_inputs)
            mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})

        return mm_inputs


@dataclass
class MllamaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
        if mm_inputs:
            num_tiles = mm_inputs.pop("num_tiles")
            image_token_id: int = getattr(processor, "image_token_id")
            max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
            cross_attention_token_mask = [
                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
            ]
            mm_inputs["cross_attention_mask"] = torch.from_numpy(
                convert_sparse_cross_attention_mask_to_dense(
                    cross_attention_token_mask,
                    num_tiles=num_tiles,
                    max_num_tiles=max_image_tiles,
                    length=max(len(input_ids) for input_ids in batch_ids),
                )
            )  # shape: (batch_size, length, max_num_images, max_num_tiles)

        return mm_inputs


@dataclass
class PaliGemmaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "", 1)
                num_image_tokens += 1

            message["content"] = content

        return messages

    @override
    def process_token_ids(
        self,
        input_ids: list[int],
        labels: Optional[list[int]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["MMProcessor"],
    ) -> tuple[list[int], Optional[list[int]]]:
        self._validate_input(processor, images, videos, audios)
        num_images = len(images)
        image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0  # skip mm token
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        input_ids = [image_token_id] * num_images * image_seqlen + input_ids
        if labels is not None:
            labels = [IGNORE_INDEX] * num_images * image_seqlen + labels

        return input_ids, labels

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        seqlens = [len(input_ids) for input_ids in batch_ids]
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
        return mm_inputs


@dataclass
class PixtralPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                # BC for transformers < 4.49.0
                if isinstance(mm_inputs["image_sizes"], list):
                    image_sizes = iter(mm_inputs["image_sizes"][0])
                else:
                    image_sizes = iter(mm_inputs["image_sizes"].tolist())

                image_break_token: str = getattr(processor, "image_break_token")
                image_end_token: str = getattr(processor, "image_end_token")

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    patch_size = processor.patch_size * getattr(processor, "spatial_merge_size", 1)
                    height, width = next(image_sizes)
                    num_height_tokens = height // patch_size
                    num_width_tokens = width // patch_size
                    replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
                    replace_tokens[-1] = image_end_token
                    replace_str = "".join(replace_tokens)
                else:
                    replace_str = self.image_token

                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)

            message["content"] = content

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        # ref: https://github.com/huggingface/transformers/pull/35122
        # since transformers 4.49.0, `image_sizes` is a mandatory input for the Pixtral vision encoder forward pass
        # and can be passed to `LlavaConditionalGeneration` as a parameter.
        if not is_transformers_version_greater_than("4.49.0"):
            mm_inputs.pop("image_sizes", None)
        return mm_inputs


@dataclass
class Qwen2AudioPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        bos_token: str = getattr(processor, "audio_bos_token")
        eos_token: str = getattr(processor, "audio_eos_token")
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs([], [], audios, processor)
            if "feature_attention_mask" in mm_inputs:
                audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()

        for message in messages:
            content = message["content"]
            while AUDIO_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    audio_length = audio_lengths.pop(0)
                    input_length = (audio_length - 1) // 2 + 1
                    audio_seqlen = (input_length - 2) // 2 + 1
                else:
                    audio_seqlen = 1

                content = content.replace(
                    AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
                )

            message["content"] = content

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class Qwen2VLPlugin(BasePlugin):
    vision_bos_token: str = "<|vision_start|>"
    vision_eos_token: str = "<|vision_end|>"

    @override
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        image = super()._preprocess_image(image, **kwargs)
        if min(image.width, image.height) < 28:
            width, height = max(image.width, 28), max(image.height, 28)
            image = image.resize((width, height))

        if image.width / image.height > 200:
            width, height = image.height * 180, image.height
            image = image.resize((width, height))

        if image.height / image.width > 200:
            width, height = image.width, image.width * 180
            image = image.resize((width, height))

        return image

    @override
    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
        results, fps_per_video, durations = [], [], []
        for video in videos:
            frames: list[ImageObject] = []
            if _check_video_is_nested_images(video):
                for frame in video:
                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
                        raise ValueError("Invalid image found in video frames.")

                frames = video
                fps_per_video.append(kwargs.get("video_fps", 2.0))
                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
            else:
                container = av.open(video, "r")
                video_stream = next(stream for stream in container.streams if stream.type == "video")
                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
                container.seek(0)
                for frame_idx, frame in enumerate(container.decode(video_stream)):
                    if frame_idx in sample_indices:
                        frames.append(frame.to_image())

                if video_stream.duration is None:
                    fps_per_video.append(kwargs.get("video_fps", 2.0))
                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
                else:
                    fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
                    durations.append(float(video_stream.duration * video_stream.time_base))

            if len(frames) % 2 != 0:
                frames.append(frames[-1])

            frames = self._regularize_images(frames, **kwargs)["images"]
            results.append(frames)

        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}

    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            video_data = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            mm_inputs.update(video_processor(videos=video_data["videos"], return_tensors="pt"))
            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
            if "second_per_grid_ts" in processor.model_input_names:
                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")

        merge_length: int = getattr(image_processor, "merge_size") ** 2
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            image_grid_thw = mm_inputs.get("image_grid_thw", [])
            video_grid_thw = mm_inputs.get("video_grid_thw", [])
        else:
            image_grid_thw = [None] * len(images)
            video_grid_thw = [None] * len(videos)

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
                    IMAGE_PLACEHOLDER,
                    f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
                    1,
                )
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
                    VIDEO_PLACEHOLDER,
                    f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
                    1,
                )
                num_video_tokens += 1

            message["content"] = content

        return messages


@dataclass
class Qwen3VLPlugin(Qwen2VLPlugin):
    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            video_metadata = [
                {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
                for video, duration in zip(videos["videos"], videos["durations"])
            ]
            mm_inputs.update(
                video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
            )
            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
            if "second_per_grid_ts" in processor.model_input_names:
                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in videos["fps_per_video"]]

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        video_processor: BaseImageProcessor = getattr(processor, "video_processor")

        image_merge_length: int = getattr(image_processor, "merge_size") ** 2
        video_merge_length: int = getattr(video_processor, "merge_size") ** 2
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            image_grid_thw = mm_inputs.get("image_grid_thw", [])
            video_grid_thw = mm_inputs.get("video_grid_thw", [])
            num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0  # hard code for now
            video_metadata = mm_inputs.get("video_metadata", {})
        else:
            image_grid_thw = [None] * len(images)
            video_grid_thw = [None] * len(videos)
            num_frames = 0
            timestamps = [0]

        for idx, message in enumerate(messages):
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                image_seqlen = (
                    image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
                )
                content = content.replace(
                    IMAGE_PLACEHOLDER,
                    f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
                    1,
                )
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    metadata = video_metadata[idx]
                    timestamps = processor._calculate_timestamps(
                        metadata.frames_indices,
                        metadata.fps,
                        video_processor.merge_size,
                    )
                    video_structure = ""
                    for frame_index in range(num_frames):
                        video_seqlen = (
                            video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
                            if self.expand_mm_tokens
                            else 1
                        )
                        timestamp_sec = timestamps[frame_index]
                        frame_structure = (
                            f"<{timestamp_sec:.1f} seconds>"
                            f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
                        )
                        video_structure += frame_structure
                else:
                    video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"

                content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
                num_video_tokens += 1

            message["content"] = content

        return messages


@dataclass
class GLM4VPlugin(Qwen2VLPlugin):
    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            video_data = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            # prepare video metadata
            video_metadata = [
                {"fps": 2, "duration": duration, "total_frames": len(video)}
                for video, duration in zip(video_data["videos"], video_data["durations"])
            ]
            mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")

        merge_length: int = getattr(image_processor, "merge_size") ** 2
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            image_grid_thw = mm_inputs.get("image_grid_thw", [])
            video_grid_thw = mm_inputs.get("video_grid_thw", [])
            num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0  # hard code for now
            timestamps = mm_inputs.get("timestamps", [])

            if hasattr(timestamps, "tolist"):
                timestamps = timestamps.tolist()

            if not timestamps:
                timestamps_list = []
            elif isinstance(timestamps[0], list):
                timestamps_list = timestamps[0]
            else:
                timestamps_list = timestamps

            unique_timestamps = timestamps_list.copy()
            selected_timestamps = unique_timestamps[:num_frames]
            while len(selected_timestamps) < num_frames:
                selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
        else:
            image_grid_thw = [None] * len(images)
            video_grid_thw = [None] * len(videos)
            num_frames = 0
            selected_timestamps = [0]

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
                    IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * image_seqlen}<|end_of_image|>", 1
                )
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                video_structure = ""
                for frame_index in range(num_frames):
                    video_seqlen = (
                        video_grid_thw[num_video_tokens][1:].prod() // merge_length if self.expand_mm_tokens else 1
                    )
                    timestamp_sec = selected_timestamps[frame_index]
                    frame_structure = (
                        f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}"
                    )
                    video_structure += frame_structure

                if not self.expand_mm_tokens:
                    video_structure = self.video_token

                content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1)
                num_video_tokens += 1

            message["content"] = content

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["ProcessorMixin"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("timestamps", None)
        return mm_inputs


@dataclass
class Qwen2OmniPlugin(Qwen2VLPlugin):
    audio_bos_token: str = "<|audio_start|>"
    audio_eos_token: str = "<|audio_end|>"

    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            video_dict = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt"))
            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
            mm_inputs["video_second_per_grid"] = torch.tensor(
                [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
            )

        if len(audios) != 0:
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
            )["audios"]
            mm_inputs.update(
                feature_extractor(
                    audios,
                    sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
                    return_attention_mask=True,
                    padding="max_length",
                    return_tensors="pt",
                )
            )
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)

        merge_length = processor.image_processor.merge_size**2
        use_audio_in_video = getattr(processor, "use_audio_in_video", False)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            image_grid_thw = mm_inputs.get("image_grid_thw", [])
            video_grid_thw = mm_inputs.get("video_grid_thw", [])
            if "feature_attention_mask" in mm_inputs:
                if processor.__class__.__name__ == "Qwen3OmniMoeProcessor":  # for qwen3omni
                    input_lengths = mm_inputs["feature_attention_mask"].sum(-1)
                    input_lengths_leave = input_lengths % 100
                    feature_lengths = (input_lengths_leave - 1) // 2 + 1
                    audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
                else:
                    input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
                    audio_lengths = (input_lengths - 2) // 2 + 1
|
| 1896 |
+
else:
|
| 1897 |
+
mm_inputs = {}
|
| 1898 |
+
image_grid_thw = [None] * len(images)
|
| 1899 |
+
video_grid_thw = [None] * len(videos)
|
| 1900 |
+
audio_lengths = [None] * len(audios)
|
| 1901 |
+
|
| 1902 |
+
for message in messages:
|
| 1903 |
+
content = message["content"]
|
| 1904 |
+
while IMAGE_PLACEHOLDER in content:
|
| 1905 |
+
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
| 1906 |
+
content = content.replace(
|
| 1907 |
+
IMAGE_PLACEHOLDER,
|
| 1908 |
+
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
|
| 1909 |
+
1,
|
| 1910 |
+
)
|
| 1911 |
+
num_image_tokens += 1
|
| 1912 |
+
|
| 1913 |
+
if (
|
| 1914 |
+
use_audio_in_video and len(audios) and len(videos)
|
| 1915 |
+
): # if use the audio of video # deal video token and audio token togather
|
| 1916 |
+
if len(videos) != len(audios):
|
| 1917 |
+
raise ValueError(
|
| 1918 |
+
f"Number of videos ({len(videos)}) must match number of audios ({len(audios)}) when using audio in video."
|
| 1919 |
+
)
|
| 1920 |
+
|
| 1921 |
+
while VIDEO_PLACEHOLDER in content:
|
| 1922 |
+
video_pos = content.find(VIDEO_PLACEHOLDER)
|
| 1923 |
+
audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
|
| 1924 |
+
if audio_pos == -1 or audio_pos < video_pos:
|
| 1925 |
+
raise ValueError(
|
| 1926 |
+
f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
|
| 1927 |
+
)
|
| 1928 |
+
|
| 1929 |
+
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
| 1930 |
+
video_t_index = (
|
| 1931 |
+
torch.arange(video_grid_thw[num_video_tokens][0])
|
| 1932 |
+
.view(-1, 1, 1)
|
| 1933 |
+
.expand(
|
| 1934 |
+
-1,
|
| 1935 |
+
video_grid_thw[num_video_tokens][1] // image_processor.merge_size,
|
| 1936 |
+
video_grid_thw[num_video_tokens][2] // image_processor.merge_size,
|
| 1937 |
+
)
|
| 1938 |
+
.flatten()
|
| 1939 |
+
* mm_inputs["video_second_per_grid"][num_video_tokens]
|
| 1940 |
+
* 25 # FIXME hardcode of position_id_per_seconds=25
|
| 1941 |
+
).long()
|
| 1942 |
+
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
|
| 1943 |
+
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
| 1944 |
+
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
| 1945 |
+
placeholder_string = ""
|
| 1946 |
+
placeholder_string += self.vision_bos_token + self.audio_bos_token
|
| 1947 |
+
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
|
| 1948 |
+
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
|
| 1949 |
+
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
|
| 1950 |
+
if video_chunk_index is not None:
|
| 1951 |
+
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
|
| 1952 |
+
|
| 1953 |
+
if audio_chunk_index is not None:
|
| 1954 |
+
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
|
| 1955 |
+
|
| 1956 |
+
placeholder_string += self.audio_eos_token + self.vision_eos_token
|
| 1957 |
+
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
|
| 1958 |
+
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
|
| 1959 |
+
num_audio_tokens += 1
|
| 1960 |
+
num_video_tokens += 1
|
| 1961 |
+
else:
|
| 1962 |
+
while AUDIO_PLACEHOLDER in content:
|
| 1963 |
+
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
|
| 1964 |
+
content = content.replace(
|
| 1965 |
+
AUDIO_PLACEHOLDER,
|
| 1966 |
+
f"{self.audio_bos_token}{self.audio_token * audio_seqlen}{self.audio_eos_token}",
|
| 1967 |
+
1,
|
| 1968 |
+
)
|
| 1969 |
+
num_audio_tokens += 1
|
| 1970 |
+
|
| 1971 |
+
while VIDEO_PLACEHOLDER in content:
|
| 1972 |
+
video_seqlen = (
|
| 1973 |
+
video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
| 1974 |
+
)
|
| 1975 |
+
content = content.replace(
|
| 1976 |
+
VIDEO_PLACEHOLDER,
|
| 1977 |
+
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
|
| 1978 |
+
1,
|
| 1979 |
+
)
|
| 1980 |
+
num_video_tokens += 1
|
| 1981 |
+
|
| 1982 |
+
message["content"] = content
|
| 1983 |
+
|
| 1984 |
+
return messages
|
| 1985 |
+
|
| 1986 |
+
|
| 1987 |
+
@dataclass
|
| 1988 |
+
class VideoLlavaPlugin(BasePlugin):
|
| 1989 |
+
@override
|
| 1990 |
+
def process_messages(
|
| 1991 |
+
self,
|
| 1992 |
+
messages: list[dict[str, str]],
|
| 1993 |
+
images: list["ImageInput"],
|
| 1994 |
+
videos: list["VideoInput"],
|
| 1995 |
+
audios: list["AudioInput"],
|
| 1996 |
+
processor: Optional["MMProcessor"],
|
| 1997 |
+
) -> list[dict[str, str]]:
|
| 1998 |
+
self._validate_input(processor, images, videos, audios)
|
| 1999 |
+
self._validate_messages(messages, images, videos, audios)
|
| 2000 |
+
num_image_tokens, num_video_tokens = 0, 0
|
| 2001 |
+
messages = deepcopy(messages)
|
| 2002 |
+
num_frames = 0
|
| 2003 |
+
if self.expand_mm_tokens:
|
| 2004 |
+
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
| 2005 |
+
if "pixel_values_images" in mm_inputs:
|
| 2006 |
+
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))
|
| 2007 |
+
num_frames = 1
|
| 2008 |
+
|
| 2009 |
+
if "pixel_values_videos" in mm_inputs:
|
| 2010 |
+
one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
|
| 2011 |
+
height, width = get_image_size(one_video[0])
|
| 2012 |
+
num_frames = one_video.shape[0] # frame dim is always after batch dim
|
| 2013 |
+
|
| 2014 |
+
if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs:
|
| 2015 |
+
image_seqlen = (height // processor.patch_size) * (
|
| 2016 |
+
width // processor.patch_size
|
| 2017 |
+
) + processor.num_additional_image_tokens
|
| 2018 |
+
video_seqlen = image_seqlen * num_frames
|
| 2019 |
+
if processor.vision_feature_select_strategy == "default":
|
| 2020 |
+
image_seqlen -= 1
|
| 2021 |
+
else:
|
| 2022 |
+
image_seqlen, video_seqlen = 1, 1
|
| 2023 |
+
|
| 2024 |
+
for message in messages:
|
| 2025 |
+
content = message["content"]
|
| 2026 |
+
while IMAGE_PLACEHOLDER in content:
|
| 2027 |
+
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
| 2028 |
+
num_image_tokens += 1
|
| 2029 |
+
|
| 2030 |
+
while VIDEO_PLACEHOLDER in content:
|
| 2031 |
+
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
|
| 2032 |
+
num_video_tokens += 1
|
| 2033 |
+
|
| 2034 |
+
content = content.replace("{{image}}", self.image_token)
|
| 2035 |
+
message["content"] = content.replace("{{video}}", self.video_token)
|
| 2036 |
+
|
| 2037 |
+
return messages
|
| 2038 |
+
|
| 2039 |
+
|
| 2040 |
+
PLUGINS = {
|
| 2041 |
+
"base": BasePlugin,
|
| 2042 |
+
"gemma3": Gemma3Plugin,
|
| 2043 |
+
"glm4v": GLM4VPlugin,
|
| 2044 |
+
"gemma3n": Gemma3nPlugin,
|
| 2045 |
+
"intern_vl": InternVLPlugin,
|
| 2046 |
+
"kimi_vl": KimiVLPlugin,
|
| 2047 |
+
"llama4": Llama4Plugin,
|
| 2048 |
+
"llava": LlavaPlugin,
|
| 2049 |
+
"llava_next": LlavaNextPlugin,
|
| 2050 |
+
"llava_next_video": LlavaNextVideoPlugin,
|
| 2051 |
+
"minicpm_v": MiniCPMVPlugin,
|
| 2052 |
+
"mllama": MllamaPlugin,
|
| 2053 |
+
"paligemma": PaliGemmaPlugin,
|
| 2054 |
+
"pixtral": PixtralPlugin,
|
| 2055 |
+
"qwen2_audio": Qwen2AudioPlugin,
|
| 2056 |
+
"qwen2_omni": Qwen2OmniPlugin,
|
| 2057 |
+
"qwen2_vl": Qwen2VLPlugin,
|
| 2058 |
+
"qwen3_vl": Qwen3VLPlugin,
|
| 2059 |
+
"video_llava": VideoLlavaPlugin,
|
| 2060 |
+
}
|
| 2061 |
+
|
| 2062 |
+
|
| 2063 |
+
def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
|
| 2064 |
+
r"""Register a multimodal plugin."""
|
| 2065 |
+
if name in PLUGINS:
|
| 2066 |
+
raise ValueError(f"Multimodal plugin {name} already exists.")
|
| 2067 |
+
|
| 2068 |
+
PLUGINS[name] = plugin_class
|
| 2069 |
+
|
| 2070 |
+
|
| 2071 |
+
def get_mm_plugin(
|
| 2072 |
+
name: str,
|
| 2073 |
+
image_token: Optional[str] = None,
|
| 2074 |
+
video_token: Optional[str] = None,
|
| 2075 |
+
audio_token: Optional[str] = None,
|
| 2076 |
+
**kwargs,
|
| 2077 |
+
) -> "BasePlugin":
|
| 2078 |
+
r"""Get plugin for multimodal inputs."""
|
| 2079 |
+
if name not in PLUGINS:
|
| 2080 |
+
raise ValueError(f"Multimodal plugin `{name}` not found.")
|
| 2081 |
+
|
| 2082 |
+
return PLUGINS[name](image_token, video_token, audio_token, **kwargs)
|
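The registry above is the public entry point of this module: register_mm_plugin adds a plugin class under a unique name, and get_mm_plugin instantiates a registered class with the placeholder tokens used by the chat template. A minimal usage sketch follows; the class name MyCustomPlugin and the token strings are illustrative assumptions, not part of the uploaded file.

# Hypothetical usage sketch of the registry defined above (illustrative only).
from llamafactory.data.mm_plugin import BasePlugin, get_mm_plugin, register_mm_plugin


class MyCustomPlugin(BasePlugin):  # assumed name; reuses the default BasePlugin behavior
    pass


register_mm_plugin("my_custom", MyCustomPlugin)  # raises ValueError if "my_custom" is already taken

# Instantiate a registered plugin with the placeholder tokens the template expects.
plugin = get_mm_plugin("my_custom", image_token="<image>", video_token="<video>", audio_token="<audio>")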
llamafactory/data/parser.py
ADDED
@@ -0,0 +1,149 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union

from huggingface_hub import hf_hub_download

from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope, use_openmind


@dataclass
class DatasetAttr:
    r"""Dataset attributes."""

    # basic configs
    load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
    dataset_name: str
    formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
    ranking: bool = False
    # extra configs
    subset: Optional[str] = None
    split: str = "train"
    folder: Optional[str] = None
    num_samples: Optional[int] = None
    # common columns
    system: Optional[str] = None
    tools: Optional[str] = None
    images: Optional[str] = None
    videos: Optional[str] = None
    audios: Optional[str] = None
    # dpo columns
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    kto_tag: Optional[str] = None
    # alpaca columns
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    # sharegpt columns
    messages: Optional[str] = "conversations"
    # sharegpt tags
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"

    def __repr__(self) -> str:
        return self.dataset_name

    def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
        setattr(self, key, obj.get(key, default))

    def join(self, attr: dict[str, Any]) -> None:
        self.set_attr("formatting", attr, default="alpaca")
        self.set_attr("ranking", attr, default=False)
        self.set_attr("subset", attr)
        self.set_attr("split", attr, default="train")
        self.set_attr("folder", attr)
        self.set_attr("num_samples", attr)

        if "columns" in attr:
            column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
            column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
            for column_name in column_names:
                self.set_attr(column_name, attr["columns"])

        if "tags" in attr:
            tag_names = ["role_tag", "content_tag"]
            tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
            for tag in tag_names:
                self.set_attr(tag, attr["tags"])


def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]:
    r"""Get the attributes of the datasets."""
    if dataset_names is None:
        dataset_names = []

    if isinstance(dataset_dir, dict):
        dataset_info = dataset_dir
    elif dataset_dir == "ONLINE":
        dataset_info = None
    else:
        if dataset_dir.startswith("REMOTE:"):
            config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
        else:
            config_path = os.path.join(dataset_dir, DATA_CONFIG)

        try:
            with open(config_path) as f:
                dataset_info = json.load(f)
        except Exception as err:
            if len(dataset_names) != 0:
                raise ValueError(f"Cannot open {config_path} due to {str(err)}.")

            dataset_info = None

    dataset_list: list[DatasetAttr] = []
    for name in dataset_names:
        if dataset_info is None:  # dataset_dir is ONLINE
            load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
            dataset_attr = DatasetAttr(load_from, dataset_name=name)
            dataset_list.append(dataset_attr)
            continue

        if name not in dataset_info:
            raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")

        has_hf_url = "hf_hub_url" in dataset_info[name]
        has_ms_url = "ms_hub_url" in dataset_info[name]
        has_om_url = "om_hub_url" in dataset_info[name]

        if has_hf_url or has_ms_url or has_om_url:
            if has_ms_url and (use_modelscope() or not has_hf_url):
                dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
            elif has_om_url and (use_openmind() or not has_hf_url):
                dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
            else:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
        elif "script_url" in dataset_info[name]:
            dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
        elif "cloud_file_name" in dataset_info[name]:
            dataset_attr = DatasetAttr("cloud_file", dataset_name=dataset_info[name]["cloud_file_name"])
        else:
            dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])

        dataset_attr.join(dataset_info[name])
        dataset_list.append(dataset_attr)

    return dataset_list
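For reference, get_dataset_list accepts either a directory containing the data config file or an already-loaded mapping. The sketch below passes a mapping directly; the dataset name, file name, and column mapping are made-up illustrations, not part of the uploaded file.

# Illustrative dataset_info mapping (assumed example only).
from llamafactory.data.parser import get_dataset_list

dataset_info = {
    "my_sharegpt_data": {
        "file_name": "my_sharegpt_data.json",  # resolved with load_from="file"
        "formatting": "sharegpt",
        "columns": {"messages": "conversations", "images": "images"},
        "tags": {"role_tag": "from", "content_tag": "value"},
    }
}

attrs = get_dataset_list(["my_sharegpt_data"], dataset_dir=dataset_info)
print(attrs[0].formatting)  # -> "sharegpt"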
llamafactory/data/processor/__init__.py
ADDED
@@ -0,0 +1,31 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor
from .processor_utils import DatasetProcessor
from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor
from .unsupervised import UnsupervisedDatasetProcessor


__all__ = [
    "DatasetProcessor",
    "FeedbackDatasetProcessor",
    "PackedSupervisedDatasetProcessor",
    "PairwiseDatasetProcessor",
    "PretrainDatasetProcessor",
    "SupervisedDatasetProcessor",
    "UnsupervisedDatasetProcessor",
]
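This __init__.py only re-exports the processor classes, so downstream code can import them from the subpackage directly. A one-line sketch (illustrative, not part of the uploaded file):

# Import two of the re-exported processor classes from the subpackage.
from llamafactory.data.processor import PairwiseDatasetProcessor, SupervisedDatasetProcessor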
llamafactory/data/processor/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (630 Bytes).

llamafactory/data/processor/__pycache__/feedback.cpython-312.pyc
ADDED
Binary file (7.39 kB).