Restore public repo before mistaken code upload
Restores files to revision 1a8e9e3ad2130f55c880af8ace85b9af0d0c329f and removes files mistakenly uploaded from unirl_opensource.
- .gitignore +0 -6
- README.md +0 -117
- assets/rl_datasets/README.md +0 -23
- rewards_services/api_services/editreward_scorer_service/README.md +0 -35
- rewards_services/api_services/editreward_scorer_service/app.py +0 -94
- rewards_services/api_services/editreward_scorer_service/editreward_scorer.py +0 -65
- rewards_services/api_services/editreward_scorer_service/gunicorn.conf.py +0 -34
- rewards_services/api_services/editreward_scorer_service/requirements.txt +0 -18
- rewards_services/api_services/editreward_scorer_service/run.sh +0 -13
- scripts/train/deepspeed/zero3.json +0 -39
- scripts/train/edit_grpo.sh +0 -77
- unimodel/qwenkontext/fluxkontext_pipeline.py +1 -565
- unirl/__init__.py +0 -2
- unirl/reward_evaluator/__init__.py +0 -4
- unirl/reward_evaluator/reward_evaluator.py +0 -71
- unirl/train_edit.py +0 -265
- unirl/trainer/__init__.py +0 -4
- unirl/trainer/edit_grpo_trainer.py +0 -623
.gitignore
CHANGED
|
@@ -145,11 +145,5 @@ outputs/
|
|
| 145 |
wandb/
|
| 146 |
|
| 147 |
assets/large_rl_datasets/
|
| 148 |
-
assets/rl_datasets/*.parquet
|
| 149 |
-
assets/rl_datasets/*.jsonl
|
| 150 |
-
!assets/rl_datasets/README.md
|
| 151 |
|
| 152 |
utils/parquet_cache/
|
| 153 |
-
|
| 154 |
-
rewards_services/api_services/editreward_scorer_service/.venv/
|
| 155 |
-
rewards_services/api_services/editreward_scorer_service/EditReward/
|
|
|
|
| 145 |
wandb/
|
| 146 |
|
| 147 |
assets/large_rl_datasets/
|
| 148 |
|
| 149 |
utils/parquet_cache/
|
README.md
CHANGED
|
@@ -1,17 +1,3 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
library_name: diffusers
|
| 4 |
-
tags:
|
| 5 |
-
- reinforcement-learning
|
| 6 |
-
- image-generation
|
| 7 |
-
- image-editing
|
| 8 |
-
- prompt-optimization
|
| 9 |
-
- flux
|
| 10 |
-
- qwen
|
| 11 |
-
datasets:
|
| 12 |
-
- wangfuyun/PrompRL
|
| 13 |
-
---
|
| 14 |
-
|
| 15 |
<p align="center">
|
| 16 |
<img src="assets/logo.png" width="30%"><br>
|
| 17 |
PromptRL
|
|
@@ -41,109 +27,6 @@ pip install flash-attn==2.7.4.post1 --no-build-isolation
|
|
| 41 |
# bash gen.sh
|
| 42 |
```
|
| 43 |
|
| 44 |
-
<details>
|
| 45 |
-
<summary><b>Training The Edit Model And Running EditReward</b></summary>
|
| 46 |
-
|
| 47 |
-
<br>
|
| 48 |
-
|
| 49 |
-
**Scope**
|
| 50 |
-
|
| 51 |
-
This release keeps only the edit RL path. The trainer is `unirl/trainer/edit_grpo_trainer.py`, which jointly optimizes the Qwen-VL prompt refiner and the FLUX.1-Kontext transformer. The VAE, text encoders, and vision encoder stay frozen.
|
| 52 |
-
|
| 53 |
-
The partial-refinement setting is preserved: with the default `NUM_GENERATIONS=8` and `NUM_SKIP_REFINEMENT=2`, each source image produces six edits from Qwen-refined prompts and two edits from the original prompt.
|
| 54 |
-
|
| 55 |
-
Relevant files:
|
| 56 |
-
|
| 57 |
-
- `unirl/train_edit.py`: CLI entry point for Qwen-Kontext edit GRPO.
|
| 58 |
-
- `unirl/reward_evaluator/reward_evaluator.py`: EditReward HTTP client used by training.
|
| 59 |
-
- `rewards_services/api_services/editreward_scorer_service`: EditReward service wrapper.
|
| 60 |
-
- `scripts/train/edit_grpo.sh`: launch script with environment-variable configuration.
|
| 61 |
-
|
| 62 |
-
**Dataset**
|
| 63 |
-
|
| 64 |
-
By default, `scripts/train/edit_grpo.sh` loads:
|
| 65 |
-
|
| 66 |
-
```text
|
| 67 |
-
https://huggingface.co/wangfuyun/PrompRL/resolve/main/data/omni_edit_train_50k.parquet
|
| 68 |
-
```
|
| 69 |
-
|
| 70 |
-
You can override it with `PROMPTS_FILE`. The loader also accepts the Hugging Face web URL form with `/blob/main/`; it is converted to the downloadable `/resolve/main/` URL automatically.
|
| 71 |
-
|
| 72 |
-
The dataset should be a `.parquet` or `.jsonl` file with:
|
| 73 |
-
|
| 74 |
-
| Column | Description |
|
| 75 |
-
| --- | --- |
|
| 76 |
-
| `image` | Source image before editing. For jsonl this can be an image path. |
|
| 77 |
-
| `prompt` | Edit instruction. |
|
| 78 |
-
|
| 79 |
-
Optional columns are `caption` and `target_caption`. For other column names, set `IMAGE_COLUMN` and `PROMPT_COLUMN`.
|
| 80 |
-
|
| 81 |
-
**1. Start EditReward**
|
| 82 |
-
|
| 83 |
-
```bash
|
| 84 |
-
cd rewards_services/api_services/editreward_scorer_service
|
| 85 |
-
python -m venv .venv
|
| 86 |
-
source .venv/bin/activate
|
| 87 |
-
pip install --upgrade pip
|
| 88 |
-
pip install torch torchvision torchaudio
|
| 89 |
-
pip install -r requirements.txt
|
| 90 |
-
|
| 91 |
-
git clone https://github.com/TIGER-AI-Lab/EditReward.git
|
| 92 |
-
huggingface-cli download TIGER-Lab/EditReward-MiMo-VL-7B-SFT-2508 \
|
| 93 |
-
--local-dir EditReward/EditReward-MiMo-VL-7B-SFT-2508
|
| 94 |
-
|
| 95 |
-
export EDITREWARD_CUDA_DEVICES=0,1
|
| 96 |
-
export EDITREWARD_WORKERS=2
|
| 97 |
-
export EDITREWARD_PORT=18088
|
| 98 |
-
bash run.sh
|
| 99 |
-
```
|
| 100 |
-
|
| 101 |
-
If the EditReward repo or checkpoint is stored elsewhere:
|
| 102 |
-
|
| 103 |
-
```bash
|
| 104 |
-
export EDITREWARD_REPO_DIR=/path/to/EditReward
|
| 105 |
-
export EDITREWARD_CHECKPOINT_PATH=/path/to/EditReward-MiMo-VL-7B-SFT-2508
|
| 106 |
-
```
|
| 107 |
-
|
| 108 |
-
**2. Launch Training**
|
| 109 |
-
|
| 110 |
-
From the repository root:
|
| 111 |
-
|
| 112 |
-
```bash
|
| 113 |
-
export MODEL_NAME_OR_PATH=/path/to/qwenkontext/checkpoint
|
| 114 |
-
# Optional. Defaults to the PromptRL OmniEdit 50k parquet on Hugging Face.
|
| 115 |
-
export PROMPTS_FILE=https://huggingface.co/wangfuyun/PrompRL/blob/main/data/omni_edit_train_50k.parquet
|
| 116 |
-
export EDITREWARD_URL=http://127.0.0.1:18088/
|
| 117 |
-
export CUDA_VISIBLE_DEVICES=2,3,4,5,6,7
|
| 118 |
-
export NPROC_PER_NODE=6
|
| 119 |
-
export RUN_NAME=qwenkontext-editreward
|
| 120 |
-
|
| 121 |
-
bash scripts/train/edit_grpo.sh
|
| 122 |
-
```
|
| 123 |
-
|
| 124 |
-
Common options:
|
| 125 |
-
|
| 126 |
-
```bash
|
| 127 |
-
export NUM_GENERATIONS=8
|
| 128 |
-
export NUM_SKIP_REFINEMENT=2
|
| 129 |
-
export NUM_SDE=4
|
| 130 |
-
export PER_DEVICE_TRAIN_BATCH_SIZE=1
|
| 131 |
-
export DIT_LEARNING_RATE=2e-7
|
| 132 |
-
export LLM_LEARNING_RATE=3e-7
|
| 133 |
-
export BETA=1e-2
|
| 134 |
-
export IMAGE_COLUMN=image
|
| 135 |
-
export PROMPT_COLUMN=prompt
|
| 136 |
-
export REPORT_TO=wandb
|
| 137 |
-
```
|
| 138 |
-
|
| 139 |
-
Training logs sample source/edited images under:
|
| 140 |
-
|
| 141 |
-
```text
|
| 142 |
-
outputs/rl/kontext/$RUN_NAME/training_samples/
|
| 143 |
-
```
|
| 144 |
-
|
| 145 |
-
</details>
|
| 146 |
-
|
| 147 |
## Qualitative Results
|
| 148 |
|
| 149 |
### Text-to-Image Generation
|
| 1 |
<p align="center">
|
| 2 |
<img src="assets/logo.png" width="30%"><br>
|
| 3 |
PromptRL
|
|
|
|
| 27 |
# bash gen.sh
|
| 28 |
```
|
| 29 |
|
| 30 |
## Qualitative Results
|
| 31 |
|
| 32 |
### Text-to-Image Generation
|
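Note on the removed README text: it says the loader accepts the Hugging Face web URL form with `/blob/main/` and converts it to the downloadable `/resolve/main/` form. The loader itself is not part of this diff view, so the following is only a minimal sketch of such a conversion. The function name `to_resolve_url` is hypothetical, and the sketch assumes only `huggingface.co` URLs should be rewritten.

```python
from urllib.parse import urlsplit, urlunsplit


def to_resolve_url(url: str) -> str:
    """Rewrite a Hugging Face web URL (/blob/<rev>/) to its download form (/resolve/<rev>/).

    Hypothetical helper: the real loader removed in this commit may differ.
    """
    parts = urlsplit(url)
    if parts.netloc != "huggingface.co":
        # Local paths and other hosts are returned unchanged.
        return url
    # Replace only the first "/blob/" segment so later path parts are untouched.
    path = parts.path.replace("/blob/", "/resolve/", 1)
    return urlunsplit((parts.scheme, parts.netloc, path, parts.query, parts.fragment))


if __name__ == "__main__":
    print(to_resolve_url(
        "https://huggingface.co/wangfuyun/PrompRL/blob/main/data/omni_edit_train_50k.parquet"
    ))
```

URLs that already use `/resolve/` pass through unchanged, so the helper is safe to apply to any `PROMPTS_FILE` value.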
assets/rl_datasets/README.md
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
# Edit Training Dataset Schema
|
| 2 |
-
|
| 3 |
-
The training script defaults to:
|
| 4 |
-
|
| 5 |
-
```text
|
| 6 |
-
https://huggingface.co/wangfuyun/PrompRL/resolve/main/data/omni_edit_train_50k.parquet
|
| 7 |
-
```
|
| 8 |
-
|
| 9 |
-
Use a `.parquet` or `.jsonl` file with at least:
|
| 10 |
-
|
| 11 |
-
| Column | Type | Description |
|
| 12 |
-
| --- | --- | --- |
|
| 13 |
-
| `image` | PIL image, image bytes, or image path | Source image before editing. |
|
| 14 |
-
| `prompt` | string | Edit instruction used by FLUX.1-Kontext and EditReward. |
|
| 15 |
-
|
| 16 |
-
Optional columns:
|
| 17 |
-
|
| 18 |
-
| Column | Type | Description |
|
| 19 |
-
| --- | --- | --- |
|
| 20 |
-
| `caption` | string | Source-image caption, kept for logging or downstream reward extensions. |
|
| 21 |
-
| `target_caption` | string | Target edited-image caption, kept for logging or downstream reward extensions. |
|
| 22 |
-
|
| 23 |
-
If your dataset uses different column names, pass `IMAGE_COLUMN=...` and `PROMPT_COLUMN=...` to `scripts/train/edit_grpo.sh`.
|
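Note on the removed schema README: it requires an `image` and a `prompt` column, with `IMAGE_COLUMN` and `PROMPT_COLUMN` as overrides. A quick pre-flight check for a `.jsonl` file could look like the sketch below. The helper names are hypothetical and not part of the removed code, and the check only looks for presence, not for whether the image path is readable.

```python
import json
from typing import Dict, Iterable, List


def missing_columns(row: Dict[str, object], image_column: str = "image",
                    prompt_column: str = "prompt") -> List[str]:
    """Return the required column names that are absent or empty in one record."""
    return [name for name in (image_column, prompt_column) if row.get(name) in (None, "")]


def check_jsonl(lines: Iterable[str], image_column: str = "image",
                prompt_column: str = "prompt") -> List[int]:
    """Return the 1-based line numbers of records missing a required column."""
    bad = []
    for number, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # skip blank lines rather than failing on them
        if missing_columns(json.loads(line), image_column, prompt_column):
            bad.append(number)
    return bad
```

Passing `image_column` and `prompt_column` mirrors the two environment variables, so the same check works for datasets with renamed columns.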
rewards_services/api_services/editreward_scorer_service/README.md
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
# EditReward Scorer Service
|
| 2 |
-
|
| 3 |
-
This service exposes EditReward over HTTP for edit GRPO training. It accepts a pickled payload with source images, edited images, and edit instructions, then returns `{"scores": [...]}`.
|
| 4 |
-
|
| 5 |
-
## Setup
|
| 6 |
-
|
| 7 |
-
```bash
|
| 8 |
-
cd rewards_services/api_services/editreward_scorer_service
|
| 9 |
-
python -m venv .venv
|
| 10 |
-
source .venv/bin/activate
|
| 11 |
-
pip install --upgrade pip
|
| 12 |
-
pip install torch torchvision torchaudio
|
| 13 |
-
pip install -r requirements.txt
|
| 14 |
-
pip install flash-attn --no-build-isolation # optional, recommended when your CUDA/PyTorch build supports it
|
| 15 |
-
|
| 16 |
-
git clone https://github.com/TIGER-AI-Lab/EditReward.git
|
| 17 |
-
huggingface-cli download TIGER-Lab/EditReward-MiMo-VL-7B-SFT-2508 \
|
| 18 |
-
--local-dir EditReward/EditReward-MiMo-VL-7B-SFT-2508
|
| 19 |
-
```
|
| 20 |
-
|
| 21 |
-
If the repository or checkpoint lives elsewhere, set:
|
| 22 |
-
|
| 23 |
-
```bash
|
| 24 |
-
export EDITREWARD_REPO_DIR=/path/to/EditReward
|
| 25 |
-
export EDITREWARD_CHECKPOINT_PATH=/path/to/EditReward-MiMo-VL-7B-SFT-2508
|
| 26 |
-
```
|
| 27 |
-
|
| 28 |
-
## Run
|
| 29 |
-
|
| 30 |
-
```bash
|
| 31 |
-
export EDITREWARD_PORT=18088
|
| 32 |
-
export EDITREWARD_CUDA_DEVICES=0,1
|
| 33 |
-
export EDITREWARD_WORKERS=2
|
| 34 |
-
bash run.sh
|
| 35 |
-
```
|
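Note on the payload format: the removed service README describes a pickled payload with source images, edited images, and edit instructions. Going by the keys that the removed `app.py` reads (`images.source`, `images.edited`, `prompts`), a request body could be assembled roughly as follows. This is a sketch, not the removed client in `unirl/reward_evaluator/reward_evaluator.py`, which is not shown here. `build_payload` is a hypothetical name, and the image entries are assumed to be already-encoded image bytes (for example PNG).

```python
import pickle
from typing import List


def build_payload(source_images: List[bytes], edited_images: List[bytes],
                  prompts: List[str]) -> bytes:
    """Serialize one scoring request in the shape the removed app.py unpacks."""
    if not (len(source_images) == len(edited_images) == len(prompts)):
        # The service rejects mismatched lengths, so fail early on the client side.
        raise ValueError("source images, edited images, and prompts must have the same length")
    payload = {
        "images": {"source": source_images, "edited": edited_images},
        "prompts": prompts,
    }
    return pickle.dumps(payload)
```

The resulting bytes would be sent as the raw body of a `POST` to the service root, for example `http://127.0.0.1:18088/` with the defaults above.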
rewards_services/api_services/editreward_scorer_service/app.py
DELETED
|
@@ -1,94 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import pickle
|
| 3 |
-
import traceback
|
| 4 |
-
from io import BytesIO
|
| 5 |
-
from typing import Any, Dict, List
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
from flask import Blueprint, Flask, request
|
| 9 |
-
from PIL import Image
|
| 10 |
-
|
| 11 |
-
from editreward_scorer import EditRewardScorer
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
INFERENCE_FN = None
|
| 15 |
-
root = Blueprint("root", __name__)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def _deserialize_images(images_bytes: List[bytes]) -> List[Image.Image]:
|
| 19 |
-
return [Image.open(BytesIO(data)).convert("RGB") for data in images_bytes]
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def _service_config() -> Dict[str, Any]:
|
| 23 |
-
repo_dir = os.getenv("EDITREWARD_REPO_DIR", os.path.join(os.path.dirname(__file__), "EditReward"))
|
| 24 |
-
return {
|
| 25 |
-
"repo_dir": repo_dir,
|
| 26 |
-
"config_path": os.getenv(
|
| 27 |
-
"EDITREWARD_CONFIG_PATH",
|
| 28 |
-
os.path.join(repo_dir, "EditReward", "config", "EditReward-MiMo-VL-7B-SFT-2508.yaml"),
|
| 29 |
-
),
|
| 30 |
-
"checkpoint_path": os.getenv(
|
| 31 |
-
"EDITREWARD_CHECKPOINT_PATH",
|
| 32 |
-
os.path.join(repo_dir, "EditReward-MiMo-VL-7B-SFT-2508"),
|
| 33 |
-
),
|
| 34 |
-
"reward_dim": os.getenv("EDITREWARD_DIM", "overall_detail"),
|
| 35 |
-
"rm_head_type": os.getenv("EDITREWARD_HEAD_TYPE", "ranknet_multi_head"),
|
| 36 |
-
}
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def create_app():
|
| 40 |
-
global INFERENCE_FN
|
| 41 |
-
config = _service_config()
|
| 42 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 43 |
-
print(f"Loading EditReward scorer on {device} from {config['checkpoint_path']}...")
|
| 44 |
-
INFERENCE_FN = EditRewardScorer(
|
| 45 |
-
repo_dir=config["repo_dir"],
|
| 46 |
-
config_path=config["config_path"],
|
| 47 |
-
checkpoint_path=config["checkpoint_path"],
|
| 48 |
-
reward_dim=config["reward_dim"],
|
| 49 |
-
rm_head_type=config["rm_head_type"],
|
| 50 |
-
device=device,
|
| 51 |
-
)
|
| 52 |
-
INFERENCE_FN.eval()
|
| 53 |
-
print("EditReward scorer loaded.")
|
| 54 |
-
|
| 55 |
-
app = Flask(__name__)
|
| 56 |
-
app.register_blueprint(root)
|
| 57 |
-
return app
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
@root.route("/", methods=["GET"])
|
| 61 |
-
def healthcheck():
|
| 62 |
-
return {"status": "ok", "service": "editreward"}, 200
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
@root.route("/", methods=["POST"])
|
| 66 |
-
def inference():
|
| 67 |
-
try:
|
| 68 |
-
payload = pickle.loads(request.get_data())
|
| 69 |
-
images = payload["images"]
|
| 70 |
-
prompts = payload.get("prompts", [])
|
| 71 |
-
source_images = _deserialize_images(images.get("source", []))
|
| 72 |
-
edited_images = _deserialize_images(images.get("edited", []))
|
| 73 |
-
|
| 74 |
-
if len(source_images) != len(edited_images) or len(source_images) != len(prompts):
|
| 75 |
-
raise ValueError(
|
| 76 |
-
"Mismatched EditReward inputs: "
|
| 77 |
-
f"{len(source_images)} source images, {len(edited_images)} edited images, {len(prompts)} prompts."
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
with torch.no_grad():
|
| 81 |
-
scores = INFERENCE_FN(prompts, source_images, edited_images)
|
| 82 |
-
return pickle.dumps({"scores": [float(score) for score in scores]}), 200
|
| 83 |
-
except Exception:
|
| 84 |
-
error_message = traceback.format_exc()
|
| 85 |
-
print(f"EditReward service error:\n{error_message}")
|
| 86 |
-
return pickle.dumps({"error": error_message}), 500
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
if __name__ == "__main__":
|
| 90 |
-
port = int(os.getenv("EDITREWARD_PORT", "18088"))
|
| 91 |
-
host = os.getenv("EDITREWARD_HOST", "127.0.0.1")
|
| 92 |
-
app = create_app()
|
| 93 |
-
app.run(host=host, port=port, debug=False)
|
| 94 |
-
|
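Note on the response format: the removed `app.py` answers with a pickled `{"scores": [...]}` and status 200 on success, or a pickled `{"error": <traceback>}` and status 500 on failure. A caller could decode both cases along these lines. `decode_scores` is a hypothetical helper, and because it unpickles the body it should only be pointed at a service you run yourself.

```python
import pickle
from typing import List


def decode_scores(status_code: int, body: bytes, expected: int) -> List[float]:
    """Turn one service reply into a list of floats, raising on any failure."""
    reply = pickle.loads(body)  # only safe for replies from a trusted service
    if status_code != 200 or "error" in reply:
        detail = reply.get("error", "no traceback returned")
        raise RuntimeError(f"EditReward service failed ({status_code}): {detail}")
    scores = [float(score) for score in reply["scores"]]
    if len(scores) != expected:
        raise ValueError(f"expected {expected} scores, got {len(scores)}")
    return scores
```

Checking the count on the client side guards against pairing a score with the wrong edited image when a batch is truncated.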
rewards_services/api_services/editreward_scorer_service/editreward_scorer.py
DELETED
|
@@ -1,65 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import shutil
|
| 3 |
-
import sys
|
| 4 |
-
import tempfile
|
| 5 |
-
from typing import List
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
from PIL import Image
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class EditRewardScorer(torch.nn.Module):
|
| 12 |
-
def __init__(
|
| 13 |
-
self,
|
| 14 |
-
repo_dir: str,
|
| 15 |
-
config_path: str,
|
| 16 |
-
checkpoint_path: str,
|
| 17 |
-
reward_dim: str = "overall_detail",
|
| 18 |
-
rm_head_type: str = "ranknet_multi_head",
|
| 19 |
-
device: str = "cuda",
|
| 20 |
-
):
|
| 21 |
-
super().__init__()
|
| 22 |
-
if not os.path.isdir(repo_dir):
|
| 23 |
-
raise FileNotFoundError(
|
| 24 |
-
f"EditReward repository not found at {repo_dir}. "
|
| 25 |
-
"Clone https://github.com/TIGER-AI-Lab/EditReward.git or set EDITREWARD_REPO_DIR."
|
| 26 |
-
)
|
| 27 |
-
sys.path.insert(0, repo_dir)
|
| 28 |
-
from EditReward import EditRewardInferencer
|
| 29 |
-
|
| 30 |
-
self.inferencer = EditRewardInferencer(
|
| 31 |
-
config_path=config_path,
|
| 32 |
-
checkpoint_path=checkpoint_path,
|
| 33 |
-
device=device,
|
| 34 |
-
reward_dim=reward_dim,
|
| 35 |
-
rm_head_type=rm_head_type,
|
| 36 |
-
)
|
| 37 |
-
self.device = device
|
| 38 |
-
self.eval()
|
| 39 |
-
|
| 40 |
-
@torch.no_grad()
|
| 41 |
-
def __call__(self, prompts: List[str], source_images: List[Image.Image], edited_images: List[Image.Image]):
|
| 42 |
-
if not (len(prompts) == len(source_images) == len(edited_images)):
|
| 43 |
-
raise ValueError("prompts, source_images, and edited_images must have the same length.")
|
| 44 |
-
|
| 45 |
-
temp_dir = tempfile.mkdtemp(prefix="editreward_")
|
| 46 |
-
try:
|
| 47 |
-
source_paths = []
|
| 48 |
-
edited_paths = []
|
| 49 |
-
for index, (source_image, edited_image) in enumerate(zip(source_images, edited_images)):
|
| 50 |
-
source_path = os.path.join(temp_dir, f"source_{index}.png")
|
| 51 |
-
edited_path = os.path.join(temp_dir, f"edited_{index}.png")
|
| 52 |
-
source_image.convert("RGB").save(source_path)
|
| 53 |
-
edited_image.convert("RGB").save(edited_path)
|
| 54 |
-
source_paths.append(source_path)
|
| 55 |
-
edited_paths.append(edited_path)
|
| 56 |
-
|
| 57 |
-
rewards = self.inferencer.reward(
|
| 58 |
-
prompts=prompts,
|
| 59 |
-
image_src=source_paths,
|
| 60 |
-
image_paths=edited_paths,
|
| 61 |
-
)
|
| 62 |
-
return [reward[0].item() if hasattr(reward[0], "item") else float(reward[0]) for reward in rewards]
|
| 63 |
-
finally:
|
| 64 |
-
shutil.rmtree(temp_dir, ignore_errors=True)
|
| 65 |
-
|
rewards_services/api_services/editreward_scorer_service/gunicorn.conf.py
DELETED
|
@@ -1,34 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
bind = f"{os.getenv('EDITREWARD_HOST', '127.0.0.1')}:{os.getenv('EDITREWARD_PORT', '18088')}"
|
| 6 |
-
workers = int(os.getenv("EDITREWARD_WORKERS", os.getenv("EDITREWARD_NUM_DEVICES", "1")))
|
| 7 |
-
worker_class = "sync"
|
| 8 |
-
timeout = int(os.getenv("EDITREWARD_TIMEOUT", "600"))
|
| 9 |
-
|
| 10 |
-
_raw_devices = os.getenv("EDITREWARD_CUDA_DEVICES") or os.getenv("CUDA_VISIBLE_DEVICES") or ""
|
| 11 |
-
CUDA_DEVICES = [device.strip() for device in _raw_devices.split(",") if device.strip()]
|
| 12 |
-
USED_DEVICES = set()
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def pre_fork(server, worker):
|
| 16 |
-
if not CUDA_DEVICES:
|
| 17 |
-
return
|
| 18 |
-
available = [device for device in CUDA_DEVICES if device not in USED_DEVICES]
|
| 19 |
-
worker.cuda_device = available[0] if available else CUDA_DEVICES[len(USED_DEVICES) % len(CUDA_DEVICES)]
|
| 20 |
-
USED_DEVICES.add(worker.cuda_device)
|
| 21 |
-
print(f"Worker {worker.pid} assigned CUDA_VISIBLE_DEVICES={worker.cuda_device}", file=sys.stderr)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def post_fork(server, worker):
|
| 25 |
-
cuda_device = getattr(worker, "cuda_device", None)
|
| 26 |
-
if cuda_device is not None:
|
| 27 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def child_exit(server, worker):
|
| 31 |
-
cuda_device = getattr(worker, "cuda_device", None)
|
| 32 |
-
if cuda_device is not None:
|
| 33 |
-
USED_DEVICES.discard(cuda_device)
|
| 34 |
-
|
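Note on worker placement: the removed `gunicorn.conf.py` gives each worker the first unused device from `EDITREWARD_CUDA_DEVICES`, wraps around once every device is taken, and frees a device when its worker exits. The selection rule can be exercised on its own without gunicorn. The function names below are hypothetical, and the sketch copies only the selection logic, not the gunicorn hooks.

```python
from typing import List, Set


def parse_devices(raw: str) -> List[str]:
    """Split a comma-separated device string, dropping empty entries."""
    return [device.strip() for device in raw.split(",") if device.strip()]


def pick_device(cuda_devices: List[str], used: Set[str]) -> str:
    """Pick the first free device, or wrap around when all are in use."""
    available = [device for device in cuda_devices if device not in used]
    device = available[0] if available else cuda_devices[len(used) % len(cuda_devices)]
    used.add(device)
    return device


if __name__ == "__main__":
    devices = parse_devices("0,1")
    used: Set[str] = set()
    # Three workers on two devices: the third one shares a device.
    print([pick_device(devices, used) for _ in range(3)])
```

With more workers than devices, several workers end up on the same GPU, so `EDITREWARD_WORKERS` is best kept at or below the number of listed devices unless each GPU has room for more than one copy of the scorer.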
rewards_services/api_services/editreward_scorer_service/requirements.txt
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
flask
|
| 2 |
-
gunicorn
|
| 3 |
-
datasets
|
| 4 |
-
huggingface_hub
|
| 5 |
-
pillow
|
| 6 |
-
openai
|
| 7 |
-
megfile
|
| 8 |
-
sentencepiece
|
| 9 |
-
deepspeed
|
| 10 |
-
fire
|
| 11 |
-
omegaconf
|
| 12 |
-
matplotlib
|
| 13 |
-
peft
|
| 14 |
-
trl==0.8.6
|
| 15 |
-
tensorboard
|
| 16 |
-
scipy
|
| 17 |
-
transformers==4.56.1
|
| 18 |
-
accelerate
|
rewards_services/api_services/editreward_scorer_service/run.sh
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env bash
|
| 2 |
-
set -euo pipefail
|
| 3 |
-
|
| 4 |
-
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
-
cd "$SCRIPT_DIR"
|
| 6 |
-
|
| 7 |
-
export EDITREWARD_REPO_DIR="${EDITREWARD_REPO_DIR:-$SCRIPT_DIR/EditReward}"
|
| 8 |
-
export EDITREWARD_PORT="${EDITREWARD_PORT:-18088}"
|
| 9 |
-
export EDITREWARD_HOST="${EDITREWARD_HOST:-127.0.0.1}"
|
| 10 |
-
export EDITREWARD_CUDA_DEVICES="${EDITREWARD_CUDA_DEVICES:-0,1}"
|
| 11 |
-
export EDITREWARD_WORKERS="${EDITREWARD_WORKERS:-${EDITREWARD_NUM_DEVICES:-2}}"
|
| 12 |
-
|
| 13 |
-
python -m gunicorn -c gunicorn.conf.py "app:create_app()"
|
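Note on defaults: the removed `run.sh` resolves `EDITREWARD_WORKERS` through a nested fallback, `${EDITREWARD_WORKERS:-${EDITREWARD_NUM_DEVICES:-2}}`, where `:-` treats an empty variable the same as an unset one. Because the script always exports the result, the fallback of `1` in `gunicorn.conf.py` only applies when gunicorn is started without `run.sh`. The same lookup order expressed in Python, as a sketch with a hypothetical helper name:

```python
from typing import Iterable, Mapping


def first_set(env: Mapping[str, str], names: Iterable[str], default: str) -> str:
    """Return the first non-empty variable among names, else the default."""
    for name in names:
        value = env.get(name)
        if value:  # like bash ${VAR:-fallback}: unset and empty both fall through
            return value
    return default


if __name__ == "__main__":
    import os

    workers = first_set(os.environ, ["EDITREWARD_WORKERS", "EDITREWARD_NUM_DEVICES"], "2")
    print(f"workers={workers}")
```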
scripts/train/deepspeed/zero3.json
DELETED
|
@@ -1,39 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"fp16": {
|
| 3 |
-
"enabled": false,
|
| 4 |
-
"loss_scale": 0,
|
| 5 |
-
"loss_scale_window": 1000,
|
| 6 |
-
"initial_scale_power": 16,
|
| 7 |
-
"hysteresis": 2,
|
| 8 |
-
"min_loss_scale": 1
|
| 9 |
-
},
|
| 10 |
-
"bf16": {
|
| 11 |
-
"enabled": true
|
| 12 |
-
},
|
| 13 |
-
"zero_optimization": {
|
| 14 |
-
"stage": 3,
|
| 15 |
-
"offload_optimizer": {
|
| 16 |
-
"device": "none",
|
| 17 |
-
"pin_memory": true
|
| 18 |
-
},
|
| 19 |
-
"offload_param": {
|
| 20 |
-
"device": "none",
|
| 21 |
-
"pin_memory": true
|
| 22 |
-
},
|
| 23 |
-
"overlap_comm": true,
|
| 24 |
-
"contiguous_gradients": true,
|
| 25 |
-
"sub_group_size": 1000000000.0,
|
| 26 |
-
"reduce_bucket_size": "auto",
|
| 27 |
-
"stage3_prefetch_bucket_size": "auto",
|
| 28 |
-
"stage3_param_persistence_threshold": "auto",
|
| 29 |
-
"stage3_max_live_parameters": 1000000000.0,
|
| 30 |
-
"stage3_max_reuse_distance": 1000000000.0,
|
| 31 |
-
"stage3_gather_16bit_weights_on_model_save": true
|
| 32 |
-
},
|
| 33 |
-
"gradient_accumulation_steps": "auto",
|
| 34 |
-
"steps_per_print": 100,
|
| 35 |
-
"train_batch_size": "auto",
|
| 36 |
-
"train_micro_batch_size_per_gpu": "auto",
|
| 37 |
-
"wall_clock_breakdown": false
|
| 38 |
-
}
|
| 39 |
-
|
scripts/train/edit_grpo.sh
DELETED
|
@@ -1,77 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env bash
|
| 2 |
-
set -euo pipefail
|
| 3 |
-
|
| 4 |
-
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
-
REPO_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)"
|
| 6 |
-
cd "$REPO_DIR"
|
| 7 |
-
|
| 8 |
-
: "${MODEL_NAME_OR_PATH:?Set MODEL_NAME_OR_PATH to a pretrained Qwen-Kontext checkpoint.}"
|
| 9 |
-
|
| 10 |
-
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-2,3,4,5,6,7}"
|
| 11 |
-
NPROC_PER_NODE="${NPROC_PER_NODE:-6}"
|
| 12 |
-
MASTER_ADDR="${MASTER_ADDR:-localhost}"
|
| 13 |
-
MASTER_PORT="${MASTER_PORT:-25000}"
|
| 14 |
-
RUN_NAME="${RUN_NAME:-qwenkontext-edit-grpo}"
|
| 15 |
-
OUTPUT_DIR="${OUTPUT_DIR:-outputs/rl/kontext/$RUN_NAME}"
|
| 16 |
-
DEEPSPEED_CONFIG="${DEEPSPEED_CONFIG:-scripts/train/deepspeed/zero3.json}"
|
| 17 |
-
EDITREWARD_URL="${EDITREWARD_URL:-http://127.0.0.1:18088/}"
|
| 18 |
-
PROMPTS_FILE="${PROMPTS_FILE:-https://huggingface.co/wangfuyun/PrompRL/resolve/main/data/omni_edit_train_50k.parquet}"
|
| 19 |
-
REPORT_TO="${REPORT_TO:-none}"
|
| 20 |
-
if [[ -n "${WANDB_PROJECT:-}" && "$REPORT_TO" == "none" ]]; then
|
| 21 |
-
REPORT_TO="wandb"
|
| 22 |
-
fi
|
| 23 |
-
|
| 24 |
-
TORCHRUN_ARGS=(
|
| 25 |
-
--nproc_per_node="$NPROC_PER_NODE"
|
| 26 |
-
--nnodes="${NNODES:-1}"
|
| 27 |
-
--node_rank="${NODE_RANK:-0}"
|
| 28 |
-
--master_addr="$MASTER_ADDR"
|
| 29 |
-
--master_port="$MASTER_PORT"
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
TRAIN_ARGS=(
|
| 33 |
-
-m unirl.train_edit
|
| 34 |
-
--reward_funcs editreward format
|
| 35 |
-
--deepspeed "$DEEPSPEED_CONFIG"
|
| 36 |
-
--output_dir "$OUTPUT_DIR"
|
| 37 |
-
--model_name_or_path "$MODEL_NAME_OR_PATH"
|
| 38 |
-
--prompts_file "$PROMPTS_FILE"
|
| 39 |
-
--image_column "${IMAGE_COLUMN:-image}"
|
| 40 |
-
--prompt_column "${PROMPT_COLUMN:-prompt}"
|
| 41 |
-
--editreward_url "$EDITREWARD_URL"
|
| 42 |
-
--max_prompt_length "${MAX_PROMPT_LENGTH:-8192}"
|
| 43 |
-
--max_completion_length "${MAX_COMPLETION_LENGTH:-512}"
|
| 44 |
-
--num_generations "${NUM_GENERATIONS:-8}"
|
| 45 |
-
--num_skip_refinement "${NUM_SKIP_REFINEMENT:-2}"
|
| 46 |
-
--num_sde "${NUM_SDE:-4}"
|
| 47 |
-
--per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE:-1}"
|
| 48 |
-
--gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-1}"
|
| 49 |
-
--logging_steps "${LOGGING_STEPS:-1}"
|
| 50 |
-
--learning_rate "${LEARNING_RATE:-3e-7}"
|
| 51 |
-
--bf16 "${BF16:-true}"
|
| 52 |
-
--report_to "$REPORT_TO"
|
| 53 |
-
--gradient_checkpointing "${GRADIENT_CHECKPOINTING:-true}"
|
| 54 |
-
--attn_implementation "${ATTN_IMPLEMENTATION:-flash_attention_2}"
|
| 55 |
-
--max_pixels "${MAX_PIXELS:-200704}"
|
| 56 |
-
--min_pixels "${MIN_PIXELS:-200704}"
|
| 57 |
-
--image_size "${IMAGE_SIZE:-512}"
|
| 58 |
-
--save_total_limit "${SAVE_TOTAL_LIMIT:-4}"
|
| 59 |
-
--save_strategy "${SAVE_STRATEGY:-steps}"
|
| 60 |
-
--save_steps "${SAVE_STEPS:-100}"
|
| 61 |
-
--beta "${BETA:-1e-2}"
|
| 62 |
-
--num_train_epochs "${NUM_TRAIN_EPOCHS:-10}"
|
| 63 |
-
--run_name "$RUN_NAME"
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
if [[ -n "${DATASET_CACHE_DIR:-}" ]]; then
|
| 67 |
-
TRAIN_ARGS+=(--dataset_cache_dir "$DATASET_CACHE_DIR")
|
| 68 |
-
fi
|
| 69 |
-
|
| 70 |
-
export PROMPTRL_EDIT_GUIDANCE_SCALE="${PROMPTRL_EDIT_GUIDANCE_SCALE:-${EDIT_GUIDANCE_SCALE:-2.5}}"
|
| 71 |
-
export PROMPTRL_EDIT_NUM_INFERENCE_STEPS="${PROMPTRL_EDIT_NUM_INFERENCE_STEPS:-${EDIT_NUM_INFERENCE_STEPS:-8}}"
|
| 72 |
-
export PROMPTRL_EDIT_HEIGHT="${PROMPTRL_EDIT_HEIGHT:-${EDIT_HEIGHT:-1024}}"
|
| 73 |
-
export PROMPTRL_EDIT_WIDTH="${PROMPTRL_EDIT_WIDTH:-${EDIT_WIDTH:-1024}}"
|
| 74 |
-
export DIT_LEARNING_RATE="${DIT_LEARNING_RATE:-2e-7}"
|
| 75 |
-
export LLM_LEARNING_RATE="${LLM_LEARNING_RATE:-3e-7}"
|
| 76 |
-
|
| 77 |
-
torchrun "${TORCHRUN_ARGS[@]}" "${TRAIN_ARGS[@]}"
|
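Note on the partial-refinement defaults: the removed launch script passes `--num_generations 8` and `--num_skip_refinement 2`, which the removed README describes as six edits from Qwen-refined prompts and two edits from the original prompt per source image. The arithmetic is simple enough to check in a few lines. The function name is hypothetical, and the range check is an assumption about what the trainer would accept.

```python
from typing import Tuple


def refinement_split(num_generations: int = 8, num_skip_refinement: int = 2) -> Tuple[int, int]:
    """Return (edits from refined prompts, edits from the original prompt)."""
    if not 0 <= num_skip_refinement <= num_generations:
        # Assumed constraint: you cannot skip more generations than exist.
        raise ValueError("num_skip_refinement must be between 0 and num_generations")
    return num_generations - num_skip_refinement, num_skip_refinement


if __name__ == "__main__":
    refined, original = refinement_split()
    print(f"{refined} refined-prompt edits, {original} original-prompt edits")
```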
unimodel/qwenkontext/fluxkontext_pipeline.py
CHANGED
|
@@ -14,10 +14,9 @@
|
|
| 14 |
# limitations under the License.
|
| 15 |
|
| 16 |
import inspect
|
| 17 |
-
from typing import Any, Callable, Dict, List, Optional, Union
|
| 18 |
|
| 19 |
import numpy as np
|
| 20 |
-
import math
|
| 21 |
import torch
|
| 22 |
from transformers import (
|
| 23 |
CLIPImageProcessor,
|
|
@@ -1160,566 +1159,3 @@ class FluxKontextPipeline(
|
|
| 1160 |
return (image,)
|
| 1161 |
|
| 1162 |
return FluxPipelineOutput(images=image)
|
| 1163 |
-
|
| 1164 |
-
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
# This method should be added to the FluxKontextPipeline class
|
| 1168 |
-
def sde_sampling(
|
| 1169 |
-
self,
|
| 1170 |
-
image: Optional[PipelineImageInput] = None,
|
| 1171 |
-
prompt: Union[str, List[str]] = None,
|
| 1172 |
-
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 1173 |
-
negative_prompt: Union[str, List[str]] = None,
|
| 1174 |
-
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 1175 |
-
true_cfg_scale: float = 1.0,
|
| 1176 |
-
height: Optional[int] = None,
|
| 1177 |
-
width: Optional[int] = None,
|
| 1178 |
-
num_inference_steps: int = 28,
|
| 1179 |
-
sigmas: Optional[List[float]] = None,
|
| 1180 |
-
guidance_scale: float = 3.5,
|
| 1181 |
-
num_images_per_prompt: Optional[int] = 1,
|
| 1182 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 1183 |
-
latents: Optional[torch.FloatTensor] = None,
|
| 1184 |
-
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 1185 |
-
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 1186 |
-
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 1187 |
-
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 1188 |
-
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 1189 |
-
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 1190 |
-
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 1191 |
-
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 1192 |
-
output_type: Optional[str] = "pil",
|
| 1193 |
-
return_dict: bool = True,
|
| 1194 |
-
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1195 |
-
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 1196 |
-
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 1197 |
-
max_sequence_length: int = 512,
|
| 1198 |
-
max_area: int = 1024**2,
|
| 1199 |
-
num_sde: int = None,
|
| 1200 |
-
_auto_resize: bool = True,
|
| 1201 |
-
):
|
| 1202 |
-
r"""
|
| 1203 |
-
SDE sampling function for FLUX Kontext pipeline with log probability tracking.
|
| 1204 |
-
|
| 1205 |
-
This method performs stochastic differential equation (SDE) based sampling while
|
| 1206 |
-
tracking log probabilities at each step. Useful for training and analysis purposes.
|
| 1207 |
-
|
| 1208 |
-
Args:
|
| 1209 |
-
image: Input image for image-to-image generation
|
| 1210 |
-
prompt: Text prompt(s) to guide generation
|
| 1211 |
-
prompt_2: Secondary text prompt for text_encoder_2
|
| 1212 |
-
negative_prompt: Negative text prompt(s)
|
| 1213 |
-
negative_prompt_2: Secondary negative prompt
|
| 1214 |
-
true_cfg_scale: Classifier-free guidance scale (when > 1.0)
|
| 1215 |
-
height: Output height in pixels
|
| 1216 |
-
width: Output width in pixels
|
| 1217 |
-
num_inference_steps: Number of denoising steps
|
| 1218 |
-
sigmas: Custom sigma schedule
|
| 1219 |
-
guidance_scale: Embedded guidance scale
|
| 1220 |
-
num_images_per_prompt: Number of images per prompt
|
| 1221 |
-
generator: Random number generator(s)
|
| 1222 |
-
latents: Pre-generated latents
|
| 1223 |
-
prompt_embeds: Pre-generated prompt embeddings
|
| 1224 |
-
pooled_prompt_embeds: Pre-generated pooled embeddings
|
| 1225 |
-
ip_adapter_image: IP-Adapter input image(s)
|
| 1226 |
-
ip_adapter_image_embeds: Pre-generated IP-Adapter embeddings
|
| 1227 |
-
negative_ip_adapter_image: Negative IP-Adapter image(s)
|
| 1228 |
-
negative_ip_adapter_image_embeds: Negative IP-Adapter embeddings
|
| 1229 |
-
negative_prompt_embeds: Pre-generated negative embeddings
|
| 1230 |
-
negative_pooled_prompt_embeds: Pre-generated negative pooled embeddings
|
| 1231 |
-
output_type: Output format ("pil" or "latent")
|
| 1232 |
-
return_dict: Whether to return dict or tuple
|
| 1233 |
-
joint_attention_kwargs: Additional attention parameters
|
| 1234 |
-
callback_on_step_end: Callback function after each step
|
| 1235 |
-
callback_on_step_end_tensor_inputs: Tensors to pass to callback
|
| 1236 |
-
max_sequence_length: Maximum prompt sequence length
|
| 1237 |
-
max_area: Maximum output area in pixels
|
| 1238 |
-
_auto_resize: Whether to auto-resize to preferred resolutions
|
| 1239 |
-
|
| 1240 |
-
Returns:
|
| 1241 |
-
Tuple of (images, prev_latents, log_probs, pred_latents, timesteps, batched_states)
|
| 1242 |
-
"""
|
| 1243 |
-
|
| 1244 |
-
height = height or self.default_sample_size * self.vae_scale_factor
|
| 1245 |
-
width = width or self.default_sample_size * self.vae_scale_factor
|
| 1246 |
-
|
| 1247 |
-
original_height, original_width = height, width
|
| 1248 |
-
aspect_ratio = width / height
|
| 1249 |
-
|
| 1250 |
-
width = round((max_area * aspect_ratio) ** 0.5)
|
| 1251 |
-
height = round((max_area / aspect_ratio) ** 0.5)
|
| 1252 |
-
|
| 1253 |
-
multiple_of = self.vae_scale_factor * 2
|
| 1254 |
-
width = width // multiple_of * multiple_of
|
| 1255 |
-
height = height // multiple_of * multiple_of
|
| 1256 |
-
|
| 1257 |
-
if height != original_height or width != original_width:
|
| 1258 |
-
logger.warning(
|
| 1259 |
-
f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
|
| 1260 |
-
)
|
| 1261 |
-
|
| 1262 |
-
# 1. Check inputs
|
| 1263 |
-
self.check_inputs(
|
| 1264 |
-
prompt,
|
| 1265 |
-
prompt_2,
|
| 1266 |
-
height,
|
| 1267 |
-
width,
|
| 1268 |
-
negative_prompt=negative_prompt,
|
| 1269 |
-
negative_prompt_2=negative_prompt_2,
|
| 1270 |
-
prompt_embeds=prompt_embeds,
|
| 1271 |
-
negative_prompt_embeds=negative_prompt_embeds,
|
| 1272 |
-
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 1273 |
-
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 1274 |
-
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 1275 |
-
max_sequence_length=max_sequence_length,
|
| 1276 |
-
)
|
| 1277 |
-
|
| 1278 |
-
self._guidance_scale = guidance_scale
|
| 1279 |
-
self._joint_attention_kwargs = joint_attention_kwargs
|
| 1280 |
-
self._current_timestep = None
|
| 1281 |
-
self._interrupt = False
|
| 1282 |
-
|
| 1283 |
-
# 2. Define call parameters
|
| 1284 |
-
if prompt is not None and isinstance(prompt, str):
|
| 1285 |
-
batch_size = 1
|
| 1286 |
-
elif prompt is not None and isinstance(prompt, list):
|
| 1287 |
-
batch_size = len(prompt)
|
| 1288 |
-
else:
|
| 1289 |
-
batch_size = prompt_embeds.shape[0]
|
| 1290 |
-
|
| 1291 |
-
device = self._execution_device
|
| 1292 |
-
|
| 1293 |
-
lora_scale = (
|
| 1294 |
-
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 1295 |
-
)
|
| 1296 |
-
has_neg_prompt = negative_prompt is not None or (
|
| 1297 |
-
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
| 1298 |
-
)
|
| 1299 |
-
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 1300 |
-
|
| 1301 |
-
# Encode prompts
|
| 1302 |
-
(
|
| 1303 |
-
prompt_embeds,
|
| 1304 |
-
pooled_prompt_embeds,
|
| 1305 |
-
text_ids,
|
| 1306 |
-
) = self.encode_prompt(
|
| 1307 |
-
prompt=prompt,
|
| 1308 |
-
prompt_2=prompt_2,
|
| 1309 |
-
prompt_embeds=prompt_embeds,
|
| 1310 |
-
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 1311 |
-
device=device,
|
| 1312 |
-
num_images_per_prompt=num_images_per_prompt,
|
| 1313 |
-
max_sequence_length=max_sequence_length,
|
| 1314 |
-
lora_scale=lora_scale,
|
| 1315 |
-
)
|
| 1316 |
-
|
| 1317 |
-
if do_true_cfg:
|
| 1318 |
-
(
|
| 1319 |
-
negative_prompt_embeds,
|
| 1320 |
-
negative_pooled_prompt_embeds,
|
| 1321 |
-
negative_text_ids,
|
| 1322 |
-
) = self.encode_prompt(
|
| 1323 |
-
prompt=negative_prompt,
|
| 1324 |
-
prompt_2=negative_prompt_2,
|
| 1325 |
-
prompt_embeds=negative_prompt_embeds,
|
| 1326 |
-
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 1327 |
-
device=device,
|
| 1328 |
-
num_images_per_prompt=num_images_per_prompt,
|
| 1329 |
-
max_sequence_length=max_sequence_length,
|
| 1330 |
-
lora_scale=lora_scale,
|
| 1331 |
-
)
|
| 1332 |
-
|
| 1333 |
-
# 3. Preprocess image
|
| 1334 |
-
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
| 1335 |
-
from diffusers.pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS
|
| 1336 |
-
|
| 1337 |
-
img = image[0] if isinstance(image, list) else image
|
| 1338 |
-
image_height, image_width = self.image_processor.get_default_height_width(img)
|
| 1339 |
-
aspect_ratio = image_width / image_height
|
| 1340 |
-
if _auto_resize:
|
| 1341 |
-
_, image_width, image_height = min(
|
| 1342 |
-
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
|
| 1343 |
-
)
|
| 1344 |
-
image_width = image_width // multiple_of * multiple_of
|
| 1345 |
-
image_height = image_height // multiple_of * multiple_of
|
| 1346 |
-
image = self.image_processor.resize(image, image_height, image_width)
|
| 1347 |
-
image = self.image_processor.preprocess(image, image_height, image_width)
|
| 1348 |
-
|
| 1349 |
-
# 4. Prepare latent variables
|
| 1350 |
-
num_channels_latents = self.transformer.config.in_channels // 4
|
| 1351 |
-
latents, image_latents, latent_ids, image_ids = self.prepare_latents(
|
| 1352 |
-
image,
|
| 1353 |
-
batch_size * num_images_per_prompt,
|
| 1354 |
-
num_channels_latents,
|
| 1355 |
-
height,
|
| 1356 |
-
width,
|
| 1357 |
-
prompt_embeds.dtype,
|
| 1358 |
-
device,
|
| 1359 |
-
generator,
|
| 1360 |
-
latents,
|
| 1361 |
-
)
|
| 1362 |
-
|
| 1363 |
-
if image_ids is not None:
|
| 1364 |
-
latent_ids = torch.cat([latent_ids, image_ids], dim=0)
|
| 1365 |
-
|
| 1366 |
-
# 5. Prepare timesteps
|
| 1367 |
-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 1368 |
-
image_seq_len = latents.shape[1]
|
| 1369 |
-
|
| 1370 |
-
from diffusers.pipelines.flux.pipeline_flux_kontext import calculate_shift, retrieve_timesteps
|
| 1371 |
-
|
| 1372 |
-
mu = calculate_shift(
|
| 1373 |
-
image_seq_len,
|
| 1374 |
-
self.scheduler.config.get("base_image_seq_len", 256),
|
| 1375 |
-
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 1376 |
-
self.scheduler.config.get("base_shift", 0.5),
|
| 1377 |
-
self.scheduler.config.get("max_shift", 1.15),
|
| 1378 |
-
)
|
| 1379 |
-
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1380 |
-
self.scheduler,
|
| 1381 |
-
num_inference_steps,
|
| 1382 |
-
device,
|
| 1383 |
-
sigmas=sigmas,
|
| 1384 |
-
mu=mu,
|
| 1385 |
-
)
|
| 1386 |
-
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1387 |
-
self._num_timesteps = len(timesteps)
|
| 1388 |
-
|
| 1389 |
-
# Handle guidance
|
| 1390 |
-
if self.transformer.config.guidance_embeds:
|
| 1391 |
-
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 1392 |
-
guidance = guidance.expand(latents.shape[0])
|
| 1393 |
-
else:
|
| 1394 |
-
guidance = None
|
| 1395 |
-
|
| 1396 |
-
# Handle IP-Adapter images
|
| 1397 |
-
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
| 1398 |
-
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
| 1399 |
-
):
|
| 1400 |
-
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 1401 |
-
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 1402 |
-
|
| 1403 |
-
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
| 1404 |
-
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
| 1405 |
-
):
|
| 1406 |
-
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 1407 |
-
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 1408 |
-
|
| 1409 |
-
if self.joint_attention_kwargs is None:
|
| 1410 |
-
self._joint_attention_kwargs = {}
|
| 1411 |
-
|
| 1412 |
-
image_embeds = None
|
| 1413 |
-
negative_image_embeds = None
|
| 1414 |
-
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 1415 |
-
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1416 |
-
ip_adapter_image,
|
| 1417 |
-
ip_adapter_image_embeds,
|
| 1418 |
-
device,
|
| 1419 |
-
batch_size * num_images_per_prompt,
|
| 1420 |
-
)
|
| 1421 |
-
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
| 1422 |
-
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1423 |
-
negative_ip_adapter_image,
|
| 1424 |
-
negative_ip_adapter_image_embeds,
|
| 1425 |
-
device,
|
| 1426 |
-
batch_size * num_images_per_prompt,
|
| 1427 |
-
)
|
| 1428 |
-
|
| 1429 |
-
# 6. SDE Denoising loop with state tracking
|
| 1430 |
-
prev_latents = []
|
| 1431 |
-
pred_latents = []
|
| 1432 |
-
states = {
|
| 1433 |
-
"timestep": [],
|
| 1434 |
-
"guidance": [],
|
| 1435 |
-
"pooled_projections": [],
|
| 1436 |
-
"encoder_hidden_states": [],
|
| 1437 |
-
"txt_ids": None,
|
| 1438 |
-
"img_ids": None,
|
| 1439 |
-
}
|
| 1440 |
-
log_probs = []
|
| 1441 |
-
ts = []
|
| 1442 |
-
|
| 1443 |
-
states["txt_ids"] = text_ids if text_ids is not None else None
|
| 1444 |
-
states["img_ids"] = latent_ids if latent_ids is not None else None
|
| 1445 |
-
|
| 1446 |
-
if num_sde is None:
|
| 1447 |
-
num_sde = num_inference_steps
|
| 1448 |
-
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1449 |
-
for i, t in enumerate(timesteps):
|
| 1450 |
-
if self.interrupt:
|
| 1451 |
-
continue
|
| 1452 |
-
|
| 1453 |
-
self._current_timestep = t
|
| 1454 |
-
if image_embeds is not None:
|
| 1455 |
-
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
| 1456 |
-
|
| 1457 |
-
# Prepare model input
|
| 1458 |
-
latent_model_input = latents
|
| 1459 |
-
if image_latents is not None:
|
| 1460 |
-
latent_model_input = torch.cat([latents, image_latents], dim=1)
|
| 1461 |
-
|
| 1462 |
-
timestep = (t.expand(latents.shape[0]) / 1000.).to(latents.dtype)
|
| 1463 |
-
|
| 1464 |
-
|
| 1465 |
-
if i < num_sde:
|
| 1466 |
-
# Store states
|
| 1467 |
-
states["timestep"].append(timestep.unsqueeze(1))
|
| 1468 |
-
states["guidance"].append(guidance.unsqueeze(1) if torch.is_tensor(guidance) else guidance)
|
| 1469 |
-
states["pooled_projections"].append(pooled_prompt_embeds.unsqueeze(1) if pooled_prompt_embeds is not None else None)
|
| 1470 |
-
states["encoder_hidden_states"].append(prompt_embeds.unsqueeze(1) if prompt_embeds is not None else None)
|
| 1471 |
-
|
| 1472 |
-
ts.append(t.expand(latents.shape[0]).unsqueeze(1))
|
| 1473 |
-
# prev_latents.append(latents.detach().clone().unsqueeze(1))
|
| 1474 |
-
prev_latents.append(latent_model_input.detach().clone().unsqueeze(1))
|
| 1475 |
-
|
| 1476 |
-
# Forward pass
|
| 1477 |
-
noise_pred = self.transformer(
|
| 1478 |
-
hidden_states=latent_model_input,
|
| 1479 |
-
timestep=timestep,
|
| 1480 |
-
guidance=guidance,
|
| 1481 |
-
pooled_projections=pooled_prompt_embeds,
|
| 1482 |
-
encoder_hidden_states=prompt_embeds,
|
| 1483 |
-
txt_ids=text_ids,
|
| 1484 |
-
img_ids=latent_ids,
|
| 1485 |
-
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1486 |
-
return_dict=False,
|
| 1487 |
-
)[0]
|
| 1488 |
-
noise_pred = noise_pred[:, :latents.size(1)]
|
| 1489 |
-
|
| 1490 |
-
# Apply true CFG if needed
|
| 1491 |
-
if do_true_cfg:
|
| 1492 |
-
if negative_image_embeds is not None:
|
| 1493 |
-
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
| 1494 |
-
|
| 1495 |
-
neg_latent_model_input = latents
|
| 1496 |
-
if image_latents is not None:
|
| 1497 |
-
neg_latent_model_input = torch.cat([latents, image_latents], dim=1)
|
| 1498 |
-
|
| 1499 |
-
neg_noise_pred = self.transformer(
|
| 1500 |
-
hidden_states=neg_latent_model_input,
|
| 1501 |
-
timestep=timestep,
|
| 1502 |
-
guidance=guidance,
|
| 1503 |
-
pooled_projections=negative_pooled_prompt_embeds,
|
| 1504 |
-
encoder_hidden_states=negative_prompt_embeds,
|
| 1505 |
-
txt_ids=negative_text_ids,
|
| 1506 |
-
img_ids=latent_ids,
|
| 1507 |
-
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1508 |
-
return_dict=False,
|
| 1509 |
-
)[0]
|
| 1510 |
-
neg_noise_pred = neg_noise_pred[:, :latents.size(1)]
|
| 1511 |
-
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 1512 |
-
|
| 1513 |
-
if i < num_sde:
|
| 1514 |
-
# SDE step with log probability
|
| 1515 |
-
latents_dtype = latents.dtype
|
| 1516 |
-
latents, log_prob, prev_latents_mean, std_dev = sde_step_with_logprob(
|
| 1517 |
-
self.scheduler,
|
| 1518 |
-
noise_pred.float(),
|
| 1519 |
-
t.expand(latents.shape[0]),
|
| 1520 |
-
latents.float()
|
| 1521 |
-
)
|
| 1522 |
-
|
| 1523 |
-
log_probs.append(log_prob.detach().clone().unsqueeze(1))
|
| 1524 |
-
pred_latents.append(latents.detach().clone().unsqueeze(1))
|
| 1525 |
-
|
| 1526 |
-
else:
|
| 1527 |
-
# Standard scheduler step
|
| 1528 |
-
latents_dtype = latents.dtype
|
| 1529 |
-
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 1530 |
-
|
| 1531 |
-
|
| 1532 |
-
if latents.dtype != latents_dtype:
|
| 1533 |
-
latents = latents.to(latents_dtype)
|
| 1534 |
-
|
| 1535 |
-
if callback_on_step_end is not None:
|
| 1536 |
-
callback_kwargs = {}
|
| 1537 |
-
for k in callback_on_step_end_tensor_inputs:
|
| 1538 |
-
callback_kwargs[k] = locals()[k]
|
| 1539 |
-
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1540 |
-
latents = callback_outputs.pop("latents", latents)
|
| 1541 |
-
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1542 |
-
|
| 1543 |
-
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1544 |
-
progress_bar.update()
|
| 1545 |
-
|
| 1546 |
-
if XLA_AVAILABLE:
|
| 1547 |
-
xm.mark_step()
|
| 1548 |
-
|
| 1549 |
-
self._current_timestep = None
|
| 1550 |
-
|
| 1551 |
-
# Decode latents to images
|
| 1552 |
-
if output_type == "latent":
|
| 1553 |
-
image = latents
|
| 1554 |
-
else:
|
| 1555 |
-
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 1556 |
-
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 1557 |
-
image = self.vae.decode(latents, return_dict=False)[0]
|
| 1558 |
-
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 1559 |
-
|
| 1560 |
-
# Batch states for output
|
| 1561 |
-
batched_states = {}
|
| 1562 |
-
for key, value_list in states.items():
|
| 1563 |
-
if value_list is None or len(value_list) == 0:
|
| 1564 |
-
batched_states[key] = None
|
| 1565 |
-
continue
|
| 1566 |
-
if isinstance(value_list, list) and value_list[0] is None:
|
| 1567 |
-
batched_states[key] = None
|
| 1568 |
-
continue
|
| 1569 |
-
if isinstance(value_list, list):
|
| 1570 |
-
concatenated = torch.cat(value_list, dim=1)
|
| 1571 |
-
if len(concatenated.shape) <= 2:
|
| 1572 |
-
batched_states[key] = concatenated.view(-1)
|
| 1573 |
-
else:
|
| 1574 |
-
batched_states[key] = concatenated.view(-1, *concatenated.shape[2:])
|
| 1575 |
-
else:
|
| 1576 |
-
batched_states[key] = value_list
|
| 1577 |
-
|
| 1578 |
-
# Reshape outputs
|
| 1579 |
-
prev_latents = torch.cat(prev_latents, dim=1)
|
| 1580 |
-
log_probs = torch.cat(log_probs, dim=1)
|
| 1581 |
-
pred_latents = torch.cat(pred_latents, dim=1)
|
| 1582 |
-
ts = torch.cat(ts, dim=1)
|
| 1583 |
-
|
| 1584 |
-
prev_latents = prev_latents.view(prev_latents.shape[0] * prev_latents.shape[1], *prev_latents.shape[2:])
|
| 1585 |
-
log_probs = log_probs.view(log_probs.shape[0] * log_probs.shape[1], *log_probs.shape[2:])
|
| 1586 |
-
pred_latents = pred_latents.view(pred_latents.shape[0] * pred_latents.shape[1], *pred_latents.shape[2:])
|
| 1587 |
-
ts = ts.view(-1)
|
| 1588 |
-
|
| 1589 |
-
# Offload models
|
| 1590 |
-
self.maybe_free_model_hooks()
|
| 1591 |
-
|
| 1592 |
-
return (image, prev_latents, log_probs, pred_latents, ts, batched_states)
|
| 1593 |
-
|
| 1594 |
-
|
| 1595 |
-
def sde_step_with_logprob(
|
| 1596 |
-
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 1597 |
-
model_output: torch.FloatTensor,
|
| 1598 |
-
timestep: Union[float, torch.FloatTensor],
|
| 1599 |
-
sample: torch.FloatTensor,
|
| 1600 |
-
prev_sample: Optional[torch.FloatTensor] = None,
|
| 1601 |
-
generator: Optional[torch.Generator] = None,
|
| 1602 |
-
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1603 |
-
"""
|
| 1604 |
-
Predict the sample from the previous timestep by reversing the SDE with log probability tracking.
|
| 1605 |
-
|
| 1606 |
-
Args:
|
| 1607 |
-
scheduler: The FlowMatchEulerDiscreteScheduler instance
|
| 1608 |
-
model_output: The direct output from learned flow model
|
| 1609 |
-
timestep: The current discrete timestep in the diffusion chain
|
| 1610 |
-
sample: A current instance of a sample created by the diffusion process
|
| 1611 |
-
prev_sample: Optional pre-computed previous sample
|
| 1612 |
-
generator: A random number generator
|
| 1613 |
-
|
| 1614 |
-
Returns:
|
| 1615 |
-
Tuple of (prev_sample, log_prob, prev_sample_mean, std_dev)
|
| 1616 |
-
"""
|
| 1617 |
-
step_index = [scheduler.index_for_timestep(t) for t in timestep]
|
| 1618 |
-
prev_step_index = [step + 1 for step in step_index]
|
| 1619 |
-
sigma = scheduler.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
|
| 1620 |
-
sigma_prev = scheduler.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
|
| 1621 |
-
sigma_max = scheduler.sigmas[1].item()
|
| 1622 |
-
dt = sigma_prev - sigma
|
| 1623 |
-
|
| 1624 |
-
std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * 0.8
|
| 1625 |
-
|
| 1626 |
-
# SDE formulation
|
| 1627 |
-
prev_sample_mean = (
|
| 1628 |
-
sample * (1 + std_dev_t**2 / (2 * sigma) * dt) +
|
| 1629 |
-
model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
|
| 1630 |
-
)
|
| 1631 |
-
|
| 1632 |
-
if prev_sample is not None and generator is not None:
|
| 1633 |
-
raise ValueError(
|
| 1634 |
-
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
|
| 1635 |
-
" `prev_sample` stays `None`."
|
| 1636 |
-
)
|
| 1637 |
-
|
| 1638 |
-
if prev_sample is None:
|
| 1639 |
-
variance_noise = randn_tensor(
|
| 1640 |
-
model_output.shape,
|
| 1641 |
-
generator=generator,
|
| 1642 |
-
device=model_output.device,
|
| 1643 |
-
dtype=model_output.dtype,
|
| 1644 |
-
)
|
| 1645 |
-
prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise
|
| 1646 |
-
|
| 1647 |
-
# Calculate log probability
|
| 1648 |
-
variance = (std_dev_t * torch.sqrt(-1 * dt)) ** 2
|
| 1649 |
-
log_prob = (
|
| 1650 |
-
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * variance)
|
| 1651 |
-
- torch.log(torch.sqrt(variance))
|
| 1652 |
-
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
|
| 1653 |
-
)
|
| 1654 |
-
|
| 1655 |
-
# Mean along all but batch dimension
|
| 1656 |
-
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
|
| 1657 |
-
|
| 1658 |
-
return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1 * dt)
|
| 1659 |
-
|
| 1660 |
-
|
| 1661 |
-
def sde_step_with_logprob_simple(
|
| 1662 |
-
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 1663 |
-
model_output: torch.FloatTensor,
|
| 1664 |
-
timestep: Union[float, torch.FloatTensor],
|
| 1665 |
-
sample: torch.FloatTensor,
|
| 1666 |
-
prev_sample: Optional[torch.FloatTensor] = None,
|
| 1667 |
-
generator: Optional[torch.Generator] = None,
|
| 1668 |
-
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1669 |
-
"""
|
| 1670 |
-
Simplified SDE step with log probability tracking using eta parameter.
|
| 1671 |
-
|
| 1672 |
-
Args:
|
| 1673 |
-
scheduler: The FlowMatchEulerDiscreteScheduler instance
|
| 1674 |
-
model_output: The direct output from learned flow model
|
| 1675 |
-
timestep: The current discrete timestep in the diffusion chain
|
| 1676 |
-
sample: A current instance of a sample created by the diffusion process
|
| 1677 |
-
prev_sample: Optional pre-computed previous sample
|
| 1678 |
-
generator: A random number generator
|
| 1679 |
-
|
| 1680 |
-
Returns:
|
| 1681 |
-
Tuple of (prev_sample, log_prob, prev_sample_mean, std_dev)
|
| 1682 |
-
"""
|
| 1683 |
-
step_index = [scheduler.index_for_timestep(t) for t in timestep]
|
| 1684 |
-
prev_step_index = [step + 1 for step in step_index]
|
| 1685 |
-
sigma = scheduler.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
|
| 1686 |
-
sigma_prev = scheduler.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
|
| 1687 |
-
sigma_max = scheduler.sigmas[1].item()
|
| 1688 |
-
dt = sigma_prev - sigma
|
| 1689 |
-
|
| 1690 |
-
eta = 0.5
|
| 1691 |
-
Dt = -dt * eta
|
| 1692 |
-
|
| 1693 |
-
prev_sample_mean = (
|
| 1694 |
-
sample * (1 - Dt / (1 - torch.where(sigma == 1, sigma_max, sigma))) +
|
| 1695 |
-
model_output * (dt - Dt)
|
| 1696 |
-
)
|
| 1697 |
-
|
| 1698 |
-
std_dev_t = torch.sqrt(2 * Dt * (sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))))
|
| 1699 |
-
|
| 1700 |
-
if prev_sample is not None and generator is not None:
|
| 1701 |
-
raise ValueError(
|
| 1702 |
-
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
|
| 1703 |
-
" `prev_sample` stays `None`."
|
| 1704 |
-
)
|
| 1705 |
-
|
| 1706 |
-
if prev_sample is None:
|
| 1707 |
-
variance_noise = randn_tensor(
|
| 1708 |
-
model_output.shape,
|
| 1709 |
-
generator=generator,
|
| 1710 |
-
device=model_output.device,
|
| 1711 |
-
dtype=model_output.dtype,
|
| 1712 |
-
)
|
| 1713 |
-
prev_sample = prev_sample_mean + std_dev_t * variance_noise
|
| 1714 |
-
|
| 1715 |
-
# Calculate log probability
|
| 1716 |
-
log_prob = (
|
| 1717 |
-
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
|
| 1718 |
-
- torch.log(std_dev_t)
|
| 1719 |
-
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
|
| 1720 |
-
)
|
| 1721 |
-
|
| 1722 |
-
# Mean along all but batch dimension
|
| 1723 |
-
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
|
| 1724 |
-
|
| 1725 |
-
return prev_sample, log_prob, prev_sample_mean, std_dev_t
|
|
|
|
| 14 |
# limitations under the License.
|
| 15 |
|
| 16 |
import inspect
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 18 |
|
| 19 |
import numpy as np
|
|
|
|
| 20 |
import torch
|
| 21 |
from transformers import (
|
| 22 |
CLIPImageProcessor,
|
|
|
|
| 1159 |
return (image,)
|
| 1160 |
|
| 1161 |
return FluxPipelineOutput(images=image)
|
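The removed `__call__` body above rescales the requested `height` and `width` to a target pixel area and then floors both sides to a multiple of `self.vae_scale_factor * 2`. A minimal standalone sketch of that arithmetic follows. The function name `fit_resolution` is hypothetical, and the values `max_area = 1024**2` and `multiple_of = 16` in the usage lines are assumptions chosen for illustration, not values confirmed by this diff.

```python
from typing import Tuple


def fit_resolution(height: int, width: int, max_area: int, multiple_of: int) -> Tuple[int, int]:
    """Rescale (height, width) to roughly `max_area` pixels, keeping the aspect ratio.

    Hypothetical helper mirroring the arithmetic in the removed pipeline code.
    """
    aspect_ratio = width / height

    # Solve w * h = max_area subject to w / h = aspect_ratio.
    new_width = round((max_area * aspect_ratio) ** 0.5)
    new_height = round((max_area / aspect_ratio) ** 0.5)

    # Floor each side to a multiple of `multiple_of` (assumed 16 in the usage below).
    new_width = new_width // multiple_of * multiple_of
    new_height = new_height // multiple_of * multiple_of
    return new_height, new_width


if __name__ == "__main__":
    # A 512x768 (height x width) request scaled to about one megapixel.
    print(fit_resolution(512, 768, max_area=1024**2, multiple_of=16))
```

Because both sides are floored, the resulting area can be slightly below `max_area`; the removed code logs a warning whenever the adjusted size differs from the requested one.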
|
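The removed `sde_step_with_logprob` function computes a per-element Gaussian log density of the sampled latent around `prev_sample_mean`, with step standard deviation `std_dev_t * sqrt(-dt)`, and then averages it over all non-batch dimensions. The scalar sketch below reproduces that arithmetic with the standard library only. The function names and the `noise_level` parameter are hypothetical (the diff hard-codes `0.8`), and this is an illustration of the formulas shown in the diff rather than a replacement for the tensor code.

```python
import math
from typing import Sequence


def sde_step_std(sigma: float, sigma_prev: float, sigma_max: float, noise_level: float = 0.8) -> float:
    """Standard deviation of one SDE step, following the removed code.

    `sigma_max` replaces `sigma` in the denominator when sigma == 1,
    which avoids a division by zero at the first step.
    """
    denom_sigma = sigma_max if sigma == 1 else sigma
    std_dev_t = math.sqrt(sigma / (1 - denom_sigma)) * noise_level
    dt = sigma_prev - sigma  # negative, because sigmas decrease during sampling
    return std_dev_t * math.sqrt(-dt)


def gaussian_log_prob(x: float, mean: float, std: float) -> float:
    """Log density of N(mean, std**2) evaluated at x."""
    variance = std ** 2
    return (
        -((x - mean) ** 2) / (2 * variance)
        - math.log(std)
        - math.log(math.sqrt(2 * math.pi))
    )


def mean_log_prob(xs: Sequence[float], means: Sequence[float], std: float) -> float:
    """Average the per-element log densities, as the diff does over non-batch dims."""
    values = [gaussian_log_prob(x, m, std) for x, m in zip(xs, means)]
    return sum(values) / len(values)


if __name__ == "__main__":
    std = sde_step_std(sigma=0.5, sigma_prev=0.25, sigma_max=0.9)
    print(std, mean_log_prob([0.1, -0.2], [0.0, 0.0], std))
```

Averaging instead of summing keeps the log probability on a scale that does not grow with the latent size, at the cost of no longer being the exact joint log density.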
unirl/__init__.py
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
"""PromptRL training package."""
|
| 2 |
-
|
|
|
|
|
|
|
|
|
unirl/reward_evaluator/__init__.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
from .reward_evaluator import RewardEvaluatorClient
|
| 2 |
-
|
| 3 |
-
__all__ = ["RewardEvaluatorClient"]
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unirl/reward_evaluator/reward_evaluator.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
import pickle
|
| 2 |
-
from io import BytesIO
|
| 3 |
-
from typing import Any, Dict, List, Mapping, Optional, Union
|
| 4 |
-
|
| 5 |
-
import requests
|
| 6 |
-
from PIL import Image
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
DEFAULT_EDITREWARD_URL = "http://127.0.0.1:18088/"
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def _serialize_image(image: Image.Image) -> bytes:
|
| 13 |
-
buffer = BytesIO()
|
| 14 |
-
if image.mode != "RGB":
|
| 15 |
-
image = image.convert("RGB")
|
| 16 |
-
image.save(buffer, format="PNG")
|
| 17 |
-
return buffer.getvalue()
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def _serialize_images(
|
| 21 |
-
images: Union[List[Image.Image], Mapping[str, List[Image.Image]]],
|
| 22 |
-
) -> Union[List[bytes], Dict[str, List[bytes]]]:
|
| 23 |
-
if isinstance(images, Mapping):
|
| 24 |
-
return {key: [_serialize_image(image) for image in value] for key, value in images.items()}
|
| 25 |
-
return [_serialize_image(image) for image in images]
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def _create_payload(
|
| 29 |
-
images: Union[List[Image.Image], Mapping[str, List[Image.Image]]],
|
| 30 |
-
prompts: List[str],
|
| 31 |
-
metadata: Optional[Dict[str, Any]] = None,
|
| 32 |
-
) -> bytes:
|
| 33 |
-
return pickle.dumps(
|
| 34 |
-
{
|
| 35 |
-
"images": _serialize_images(images),
|
| 36 |
-
"prompts": prompts,
|
| 37 |
-
"metadata": metadata or {},
|
| 38 |
-
}
|
| 39 |
-
)
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class RewardEvaluatorClient:
|
| 43 |
-
"""HTTP client for the EditReward scorer service."""
|
| 44 |
-
|
| 45 |
-
def __init__(self, editreward_url: str = DEFAULT_EDITREWARD_URL, timeout: int = 600):
|
| 46 |
-
self.editreward_url = editreward_url
|
| 47 |
-
self.timeout = timeout
|
| 48 |
-
|
| 49 |
-
def evaluate_editreward(
|
| 50 |
-
self,
|
| 51 |
-
source_images: List[Image.Image],
|
| 52 |
-
edited_images: List[Image.Image],
|
| 53 |
-
prompts: List[str],
|
| 54 |
-
) -> Dict[str, Any]:
|
| 55 |
-
if not (len(source_images) == len(edited_images) == len(prompts)):
|
| 56 |
-
raise ValueError(
|
| 57 |
-
"EditReward inputs must have equal lengths: "
|
| 58 |
-
f"{len(source_images)} source images, {len(edited_images)} edited images, {len(prompts)} prompts."
|
| 59 |
-
)
|
| 60 |
-
|
| 61 |
-
payload = _create_payload(
|
| 62 |
-
{"source": source_images, "edited": edited_images},
|
| 63 |
-
prompts,
|
| 64 |
-
)
|
| 65 |
-
response = requests.post(self.editreward_url, data=payload, timeout=self.timeout)
|
| 66 |
-
response.raise_for_status()
|
| 67 |
-
result = pickle.loads(response.content)
|
| 68 |
-
if isinstance(result, dict) and "error" in result:
|
| 69 |
-
raise RuntimeError(f"EditReward service returned an error: {result['error']}")
|
| 70 |
-
return result
|
| 71 |
-
|
|
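The removed `RewardEvaluatorClient` validates that all inputs have equal length, pickles a dictionary with `images`, `prompts`, and `metadata` keys, and unpickles the service response. The sketch below shows only the payload shape and the length check, using raw `bytes` as stand-ins for PNG-encoded images so that it runs without PIL or a running service. The function name `build_payload` is hypothetical. Note that `pickle.loads` can execute arbitrary code when given untrusted data, so this wire format is reasonable only between trusted endpoints such as the local default URL in the diff.

```python
import pickle
from typing import Any, Dict, List, Optional


def build_payload(
    source: List[bytes],
    edited: List[bytes],
    prompts: List[str],
    metadata: Optional[Dict[str, Any]] = None,
) -> bytes:
    """Build a pickled request body shaped like the one in the removed client.

    `source` and `edited` are assumed to be already-serialized image bytes.
    """
    if not (len(source) == len(edited) == len(prompts)):
        raise ValueError(
            "Inputs must have equal lengths: "
            f"{len(source)} source images, {len(edited)} edited images, {len(prompts)} prompts."
        )
    return pickle.dumps(
        {
            "images": {"source": source, "edited": edited},
            "prompts": prompts,
            "metadata": metadata or {},
        }
    )


if __name__ == "__main__":
    body = build_payload([b"src-png"], [b"out-png"], ["make the sky red"])
    # Only unpickle data that comes from a trusted peer.
    print(pickle.loads(body)["prompts"])
```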
unirl/train_edit.py
DELETED
|
@@ -1,265 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import os
|
| 3 |
-
import re
|
| 4 |
-
from dataclasses import dataclass, field
|
| 5 |
-
from io import BytesIO
|
| 6 |
-
from typing import Any, Dict, List, Optional
|
| 7 |
-
|
| 8 |
-
from datasets import load_dataset
|
| 9 |
-
from PIL import Image, ImageOps
|
| 10 |
-
from torch.utils.data import Dataset
|
| 11 |
-
from transformers.trainer_utils import get_last_checkpoint
|
| 12 |
-
from trl import GRPOConfig, ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
| 13 |
-
|
| 14 |
-
from .reward_evaluator import RewardEvaluatorClient
|
| 15 |
-
from .trainer import QwenKontextEditGRPOTrainer
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
DEFAULT_EDIT_DATASET = "https://huggingface.co/wangfuyun/PrompRL/resolve/main/data/omni_edit_train_50k.parquet"
|
| 19 |
-
|
| 20 |
-
EDIT_QUESTION_TEMPLATE = """Please provide an enhanced prompt for the following image editing prompt.
|
| 21 |
-
Ensure the revised prompt is clear, specific, and includes detailed instructions to achieve the desired outcome while maintaining the original intent.
|
| 22 |
-
Original prompt: {Question}. Directly provide the improved prompt in <answer> </answer> tags."""
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
@dataclass
|
| 26 |
-
class EditGRPOScriptArguments(ScriptArguments):
|
| 27 |
-
reward_funcs: List[str] = field(
|
| 28 |
-
default_factory=lambda: ["editreward", "format"],
|
| 29 |
-
metadata={"help": "Reward functions to use. Edit training supports only: editreward, format."},
|
| 30 |
-
)
|
| 31 |
-
prompts_file: str = field(
|
| 32 |
-
default=DEFAULT_EDIT_DATASET,
|
| 33 |
-
metadata={"help": "Path or URL to a .parquet or .jsonl edit-training dataset."},
|
| 34 |
-
)
|
| 35 |
-
image_column: str = field(default="image", metadata={"help": "Dataset column containing the source image."})
|
| 36 |
-
prompt_column: str = field(default="prompt", metadata={"help": "Dataset column containing the edit instruction."})
|
| 37 |
-
caption_column: Optional[str] = field(default="caption", metadata={"help": "Optional source-caption column."})
|
| 38 |
-
target_caption_column: Optional[str] = field(
|
| 39 |
-
default="target_caption",
|
| 40 |
-
metadata={"help": "Optional target-caption column."},
|
| 41 |
-
)
|
| 42 |
-
image_size: int = field(default=512, metadata={"help": "Center-cropped source image size used for editing."})
|
| 43 |
-
dataset_cache_dir: Optional[str] = field(
|
| 44 |
-
default=None,
|
| 45 |
-
metadata={"help": "Optional Hugging Face datasets cache dir. Defaults to HF_DATASETS_CACHE."},
|
| 46 |
-
)
|
| 47 |
-
editreward_url: str = field(
|
| 48 |
-
default="http://127.0.0.1:18088/",
|
| 49 |
-
metadata={"help": "HTTP URL of the EditReward scorer service."},
|
| 50 |
-
)
|
| 51 |
-
processor_name_or_path: str = field(
|
| 52 |
-
default="Qwen/Qwen2.5-VL-3B-Instruct",
|
| 53 |
-
metadata={"help": "Processor used for Qwen2.5-VL chat formatting and image preprocessing."},
|
| 54 |
-
)
|
| 55 |
-
max_pixels: int = field(default=200704, metadata={"help": "Maximum pixels passed to the Qwen-VL processor."})
|
| 56 |
-
min_pixels: int = field(default=200704, metadata={"help": "Minimum pixels passed to the Qwen-VL processor."})
|
| 57 |
-
num_skip_refinement: int = field(
|
| 58 |
-
default=2,
|
| 59 |
-
metadata={"help": "Generations per input that use the original edit prompt instead of a Qwen-refined prompt."},
|
| 60 |
-
)
|
| 61 |
-
num_sde: int = field(
|
| 62 |
-
default=4,
|
| 63 |
-
metadata={"help": "Number of FLUX denoising steps sampled with SDE log-prob tracking for diffusion GRPO."},
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
class EditPromptDataset(Dataset):
|
| 68 |
-
"""Loads image-edit instructions from parquet or jsonl files."""
|
| 69 |
-
|
| 70 |
-
def __init__(
|
| 71 |
-
self,
|
| 72 |
-
prompts_file: str,
|
| 73 |
-
question_template: str,
|
| 74 |
-
image_column: str = "image",
|
| 75 |
-
prompt_column: str = "prompt",
|
| 76 |
-
caption_column: Optional[str] = "caption",
|
| 77 |
-
target_caption_column: Optional[str] = "target_caption",
|
| 78 |
-
image_size: int = 512,
|
| 79 |
-
cache_dir: Optional[str] = None,
|
| 80 |
-
):
|
| 81 |
-
self.prompts_file = normalize_data_path(prompts_file)
|
| 82 |
-
self.question_template = question_template
|
| 83 |
-
self.image_column = image_column
|
| 84 |
-
self.prompt_column = prompt_column
|
| 85 |
-
self.caption_column = caption_column
|
| 86 |
-
self.target_caption_column = target_caption_column
|
| 87 |
-
self.image_size = image_size
|
| 88 |
-
self.base_dir = (
|
| 89 |
-
os.path.dirname(os.path.abspath(self.prompts_file))
|
| 90 |
-
if not is_remote_path(self.prompts_file)
|
| 91 |
-
else os.getcwd()
|
| 92 |
-
)
|
| 93 |
-
|
| 94 |
-
if self.prompts_file.endswith(".parquet"):
|
| 95 |
-
self.records = load_dataset(
|
| 96 |
-
"parquet",
|
| 97 |
-
data_files={"train": self.prompts_file},
|
| 98 |
-
split="train",
|
| 99 |
-
cache_dir=cache_dir or os.getenv("HF_DATASETS_CACHE"),
|
| 100 |
-
)
|
| 101 |
-
elif self.prompts_file.endswith(".jsonl") or self.prompts_file.endswith(".json"):
|
| 102 |
-
if is_remote_path(self.prompts_file):
|
| 103 |
-
self.records = load_dataset(
|
| 104 |
-
"json",
|
| 105 |
-
data_files={"train": self.prompts_file},
|
| 106 |
-
split="train",
|
| 107 |
-
cache_dir=cache_dir or os.getenv("HF_DATASETS_CACHE"),
|
| 108 |
-
)
|
| 109 |
-
else:
|
| 110 |
-
with open(self.prompts_file, "r", encoding="utf-8") as file:
|
| 111 |
-
self.records = [json.loads(line) for line in file if line.strip()]
|
| 112 |
-
else:
|
| 113 |
-
raise ValueError("Edit training datasets must be .parquet or .jsonl files.")
|
| 114 |
-
|
| 115 |
-
if len(self.records) == 0:
|
| 116 |
-
raise ValueError(f"No training records found in {prompts_file}.")
|
| 117 |
-
|
| 118 |
-
def __len__(self) -> int:
|
| 119 |
-
return len(self.records)
|
| 120 |
-
|
| 121 |
-
def __getitem__(self, index: int) -> Dict[str, Any]:
|
| 122 |
-
item = self.records[index]
|
| 123 |
-
instruction = self._read_text(item, self.prompt_column)
|
| 124 |
-
image = self._read_image(item, self.image_column)
|
| 125 |
-
formatted_prompt = self.question_template.format(Question=instruction)
|
| 126 |
-
|
| 127 |
-
return {
|
| 128 |
-
"image": image,
|
| 129 |
-
"caption": self._read_optional_text(item, self.caption_column),
|
| 130 |
-
"target_caption": self._read_optional_text(item, self.target_caption_column),
|
| 131 |
-
"editing_instruction": instruction,
|
| 132 |
-
"prompt": [
|
| 133 |
-
{
|
| 134 |
-
"role": "user",
|
| 135 |
-
"content": [
|
| 136 |
-
{"type": "image"},
|
| 137 |
-
{"type": "text", "text": formatted_prompt},
|
| 138 |
-
],
|
| 139 |
-
}
|
| 140 |
-
],
|
| 141 |
-
}
|
| 142 |
-
|
| 143 |
-
def _read_text(self, item: Dict[str, Any], column: str) -> str:
|
| 144 |
-
if column not in item:
|
| 145 |
-
raise KeyError(f"Missing required column '{column}' in {self.prompts_file}.")
|
| 146 |
-
value = item[column]
|
| 147 |
-
if isinstance(value, list):
|
| 148 |
-
value = value[-1] if value else ""
|
| 149 |
-
if value is None or not str(value).strip():
|
| 150 |
-
raise ValueError(f"Empty edit instruction in column '{column}'.")
|
| 151 |
-
return str(value).strip()
|
| 152 |
-
|
| 153 |
-
def _read_optional_text(self, item: Dict[str, Any], column: Optional[str]) -> str:
|
| 154 |
-
if not column or column not in item or item[column] is None:
|
| 155 |
-
return ""
|
| 156 |
-
value = item[column]
|
| 157 |
-
if isinstance(value, list):
|
| 158 |
-
value = value[-1] if value else ""
|
| 159 |
-
return str(value).strip()
|
| 160 |
-
|
| 161 |
-
def _read_image(self, item: Dict[str, Any], column: str) -> Image.Image:
|
| 162 |
-
if column not in item:
|
| 163 |
-
raise KeyError(f"Missing required image column '{column}' in {self.prompts_file}.")
|
| 164 |
-
image = self._coerce_image(item[column])
|
| 165 |
-
if self.image_size > 0:
|
| 166 |
-
image = ImageOps.fit(image, (self.image_size, self.image_size), method=Image.Resampling.BICUBIC)
|
| 167 |
-
return image.convert("RGB")
|
| 168 |
-
|
| 169 |
-
def _coerce_image(self, value: Any) -> Image.Image:
|
| 170 |
-
if isinstance(value, Image.Image):
|
| 171 |
-
return value.convert("RGB")
|
| 172 |
-
if isinstance(value, str):
|
| 173 |
-
image_path = value if os.path.isabs(value) else os.path.join(self.base_dir, value)
|
| 174 |
-
return Image.open(image_path).convert("RGB")
|
| 175 |
-
if isinstance(value, bytes):
|
| 176 |
-
return Image.open(BytesIO(value)).convert("RGB")
|
| 177 |
-
if isinstance(value, dict):
|
| 178 |
-
if value.get("bytes") is not None:
|
| 179 |
-
return Image.open(BytesIO(value["bytes"])).convert("RGB")
|
| 180 |
-
if value.get("path") is not None:
|
| 181 |
-
image_path = value["path"] if os.path.isabs(value["path"]) else os.path.join(self.base_dir, value["path"])
|
| 182 |
-
return Image.open(image_path).convert("RGB")
|
| 183 |
-
raise TypeError(f"Unsupported image value type: {type(value)!r}")
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
def is_remote_path(path: str) -> bool:
|
| 187 |
-
return path.startswith(("http://", "https://", "hf://"))
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
def normalize_data_path(path: str) -> str:
|
| 191 |
-
if path.startswith("hf://"):
|
| 192 |
-
parts = path[len("hf://") :].split("/", 2)
|
| 193 |
-
if len(parts) != 3:
|
| 194 |
-
raise ValueError("hf:// dataset paths must look like hf://owner/repo/path/to/file.parquet")
|
| 195 |
-
repo_id = f"{parts[0]}/{parts[1]}"
|
| 196 |
-
file_path = parts[2]
|
| 197 |
-
return f"https://huggingface.co/{repo_id}/resolve/main/{file_path}"
|
| 198 |
-
|
| 199 |
-
if "huggingface.co/" in path and "/blob/" in path:
|
| 200 |
-
return path.replace("/blob/", "/resolve/", 1)
|
| 201 |
-
return path
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
def format_reward(completions: List[str]) -> List[float]:
|
| 205 |
-
pattern = re.compile(r"<answer>.*?</answer>", re.DOTALL)
|
| 206 |
-
return [1.0 if pattern.search(completion) else 0.0 for completion in completions]
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
def build_editreward_func(editreward_url: str):
|
| 210 |
-
reward_client = RewardEvaluatorClient(editreward_url=editreward_url)
|
| 211 |
-
|
| 212 |
-
def editreward(source_images, edited_images, prompts):
|
| 213 |
-
return reward_client.evaluate_editreward(source_images, edited_images, prompts)
|
| 214 |
-
|
| 215 |
-
return editreward
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
def main(script_args: EditGRPOScriptArguments, training_args: GRPOConfig, model_args: ModelConfig) -> None:
|
| 219 |
-
supported_rewards = {"editreward", "format"}
|
| 220 |
-
unsupported_rewards = sorted(set(script_args.reward_funcs) - supported_rewards)
|
| 221 |
-
if unsupported_rewards:
|
| 222 |
-
raise ValueError(f"Edit training supports only {sorted(supported_rewards)}, got {unsupported_rewards}.")
|
| 223 |
-
|
| 224 |
-
reward_registry = {
|
| 225 |
-
"editreward": build_editreward_func(script_args.editreward_url),
|
| 226 |
-
"format": format_reward,
|
| 227 |
-
}
|
| 228 |
-
reward_funcs = [(name, None, reward_registry[name]) for name in script_args.reward_funcs]
|
| 229 |
-
|
| 230 |
-
train_dataset = EditPromptDataset(
|
| 231 |
-
prompts_file=script_args.prompts_file,
|
| 232 |
-
question_template=EDIT_QUESTION_TEMPLATE,
|
| 233 |
-
image_column=script_args.image_column,
|
| 234 |
-
prompt_column=script_args.prompt_column,
|
| 235 |
-
caption_column=script_args.caption_column,
|
| 236 |
-
target_caption_column=script_args.target_caption_column,
|
| 237 |
-
image_size=script_args.image_size,
|
| 238 |
-
cache_dir=script_args.dataset_cache_dir,
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
trainer = QwenKontextEditGRPOTrainer(
|
| 242 |
-
model=model_args.model_name_or_path,
|
| 243 |
-
reward_funcs=reward_funcs,
|
| 244 |
-
args=training_args,
|
| 245 |
-
train_dataset=train_dataset,
|
| 246 |
-
peft_config=get_peft_config(model_args),
|
| 247 |
-
max_pixels=script_args.max_pixels,
|
| 248 |
-
min_pixels=script_args.min_pixels,
|
| 249 |
-
processor_name_or_path=script_args.processor_name_or_path,
|
| 250 |
-
attn_implementation=model_args.attn_implementation,
|
| 251 |
-
num_skip_refinement=script_args.num_skip_refinement,
|
| 252 |
-
num_sde=script_args.num_sde,
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
checkpoint = get_last_checkpoint(training_args.output_dir) if os.path.isdir(training_args.output_dir) else None
|
| 256 |
-
trainer.train(resume_from_checkpoint=checkpoint)
|
| 257 |
-
trainer.save_model(training_args.output_dir)
|
| 258 |
-
if training_args.push_to_hub:
|
| 259 |
-
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
if __name__ == "__main__":
|
| 263 |
-
parser = TrlParser((EditGRPOScriptArguments, GRPOConfig, ModelConfig))
|
| 264 |
-
parsed_script_args, parsed_training_args, parsed_model_args = parser.parse_args_and_config()
|
| 265 |
-
main(parsed_script_args, parsed_training_args, parsed_model_args)
|
unirl/trainer/__init__.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
from .edit_grpo_trainer import QwenKontextEditGRPOTrainer
|
| 2 |
-
|
| 3 |
-
__all__ = ["QwenKontextEditGRPOTrainer"]
|
| 4 |
-
unirl/trainer/edit_grpo_trainer.py
DELETED
|
@@ -1,623 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from collections import defaultdict
|
| 3 |
-
from datetime import datetime
|
| 4 |
-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
import transformers
|
| 9 |
-
from accelerate.utils import DistributedType
|
| 10 |
-
from datasets import Dataset, IterableDataset
|
| 11 |
-
from packaging import version
|
| 12 |
-
from PIL import Image
|
| 13 |
-
from transformers import AutoProcessor, GenerationConfig, PreTrainedModel, Trainer, TrainerCallback
|
| 14 |
-
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 15 |
-
from transformers.utils import is_peft_available
|
| 16 |
-
from trl.data_utils import maybe_apply_chat_template
|
| 17 |
-
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
|
| 18 |
-
from trl.trainer.grpo_config import GRPOConfig
|
| 19 |
-
|
| 20 |
-
from unimodel.qwenkontext.fluxkontext_pipeline import sde_step_with_logprob
|
| 21 |
-
from unimodel.qwenkontext.qwenkontext_inference import QwenKontextForInferenceLM
|
| 22 |
-
|
| 23 |
-
if is_peft_available():
|
| 24 |
-
from peft import PeftConfig, get_peft_model
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
RewardFunc = Callable[..., Union[List[float], Dict[str, Any]]]
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def compute_log_prob(
|
| 31 |
-
model_pred: torch.Tensor,
|
| 32 |
-
scheduler,
|
| 33 |
-
prev_latents: torch.Tensor,
|
| 34 |
-
pred_latents: torch.Tensor,
|
| 35 |
-
timesteps: torch.Tensor,
|
| 36 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 37 |
-
return sde_step_with_logprob(
|
| 38 |
-
scheduler,
|
| 39 |
-
model_pred.float(),
|
| 40 |
-
timesteps,
|
| 41 |
-
prev_latents.float(),
|
| 42 |
-
pred_latents.float(),
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class QwenKontextEditGRPOTrainer(Trainer):
|
| 47 |
-
"""Joint GRPO trainer for Qwen prompt refinement and FLUX.1-Kontext edit generation."""
|
| 48 |
-
|
| 49 |
-
def __init__(
|
| 50 |
-
self,
|
| 51 |
-
model: Union[str, PreTrainedModel],
|
| 52 |
-
reward_funcs: List[Tuple[str, Optional[Any], RewardFunc]],
|
| 53 |
-
args: Optional[GRPOConfig] = None,
|
| 54 |
-
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 55 |
-
eval_dataset: Optional[Union[Dataset, IterableDataset, Dict[str, Union[Dataset, IterableDataset]]]] = None,
|
| 56 |
-
processing_class: Optional[Any] = None,
|
| 57 |
-
callbacks: Optional[List[TrainerCallback]] = None,
|
| 58 |
-
optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
| 59 |
-
peft_config: Optional["PeftConfig"] = None,
|
| 60 |
-
max_pixels: int = 200704,
|
| 61 |
-
min_pixels: int = 200704,
|
| 62 |
-
processor_name_or_path: str = "Qwen/Qwen2.5-VL-3B-Instruct",
|
| 63 |
-
attn_implementation: str = "flash_attention_2",
|
| 64 |
-
num_skip_refinement: int = 2,
|
| 65 |
-
num_sde: int = 4,
|
| 66 |
-
):
|
| 67 |
-
if args is None:
|
| 68 |
-
model_name = model if isinstance(model, str) else model.config._name_or_path
|
| 69 |
-
args = GRPOConfig(f"{os.path.basename(model_name)}-edit-joint-grpo")
|
| 70 |
-
|
| 71 |
-
model_init_kwargs = args.model_init_kwargs or {}
|
| 72 |
-
model_init_kwargs["attn_implementation"] = attn_implementation
|
| 73 |
-
model_init_kwargs["use_cache"] = False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
| 74 |
-
|
| 75 |
-
if isinstance(model, str):
|
| 76 |
-
self.model_id = model
|
| 77 |
-
model = self._load_model(model, model_init_kwargs)
|
| 78 |
-
else:
|
| 79 |
-
self.model_id = model.config._name_or_path
|
| 80 |
-
if args.model_init_kwargs is not None:
|
| 81 |
-
raise ValueError("model_init_kwargs can only be used when model is a path.")
|
| 82 |
-
|
| 83 |
-
if peft_config is not None:
|
| 84 |
-
model = get_peft_model(model, peft_config)
|
| 85 |
-
|
| 86 |
-
self._configure_trainable_parameters(model)
|
| 87 |
-
self.ref_model = self._create_reference_model(model, model_init_kwargs)
|
| 88 |
-
self.scheduler = model.get_model().diffusion_expert.scheduler
|
| 89 |
-
|
| 90 |
-
if processing_class is None:
|
| 91 |
-
processing_class = self._create_processor(processor_name_or_path, max_pixels, min_pixels)
|
| 92 |
-
self.processing_class = processing_class
|
| 93 |
-
self.reward_funcs = reward_funcs
|
| 94 |
-
self.max_prompt_length = args.max_prompt_length
|
| 95 |
-
self.num_generations = args.num_generations
|
| 96 |
-
self.beta = args.beta
|
| 97 |
-
self.num_sde = num_sde
|
| 98 |
-
|
| 99 |
-
if not 0 <= num_skip_refinement < self.num_generations:
|
| 100 |
-
raise ValueError(
|
| 101 |
-
f"num_skip_refinement must be in [0, num_generations), got {num_skip_refinement} "
|
| 102 |
-
f"for num_generations={self.num_generations}."
|
| 103 |
-
)
|
| 104 |
-
self.num_skip_refinement = num_skip_refinement
|
| 105 |
-
self.num_refined = self.num_generations - num_skip_refinement
|
| 106 |
-
|
| 107 |
-
self.generation_config = GenerationConfig(
|
| 108 |
-
max_new_tokens=args.max_completion_length or 256,
|
| 109 |
-
do_sample=True,
|
| 110 |
-
temperature=1.0,
|
| 111 |
-
num_return_sequences=1,
|
| 112 |
-
pad_token_id=processing_class.pad_token_id,
|
| 113 |
-
eos_token_id=processing_class.eos_token_id,
|
| 114 |
-
)
|
| 115 |
-
model.generation_config = self.generation_config
|
| 116 |
-
self.ref_model.generation_config = self.generation_config
|
| 117 |
-
if hasattr(model, "warnings_issued"):
|
| 118 |
-
model.warnings_issued["estimate_tokens"] = True
|
| 119 |
-
|
| 120 |
-
self._metrics = defaultdict(list)
|
| 121 |
-
|
| 122 |
-
def data_collator(features):
|
| 123 |
-
return features
|
| 124 |
-
|
| 125 |
-
super().__init__(
|
| 126 |
-
model=model,
|
| 127 |
-
args=args,
|
| 128 |
-
data_collator=data_collator,
|
| 129 |
-
train_dataset=train_dataset,
|
| 130 |
-
eval_dataset=eval_dataset,
|
| 131 |
-
processing_class=processing_class,
|
| 132 |
-
callbacks=callbacks,
|
| 133 |
-
optimizers=optimizers,
|
| 134 |
-
)
|
| 135 |
-
self.model_accepts_loss_kwargs = False
|
| 136 |
-
|
| 137 |
-
if self.is_deepspeed_enabled and is_deepspeed_zero3_enabled():
|
| 138 |
-
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
| 139 |
-
else:
|
| 140 |
-
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
| 141 |
-
|
| 142 |
-
self.diffusion_generation_config = self._get_diffusion_config()
|
| 143 |
-
self.start_time = datetime.now().strftime("%Y-%m-%d_%H-%M")
|
| 144 |
-
self.log_dir = os.path.join(args.output_dir, "training_samples", self.start_time)
|
| 145 |
-
os.makedirs(self.log_dir, exist_ok=True)
|
| 146 |
-
|
| 147 |
-
def _load_model(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> PreTrainedModel:
|
| 148 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 149 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
| 150 |
-
model_init_kwargs["torch_dtype"] = getattr(torch, torch_dtype)
|
| 151 |
-
if "qwenkontext" not in model_id.lower():
|
| 152 |
-
raise ValueError("Edit joint training expects a Qwen-Kontext checkpoint path.")
|
| 153 |
-
return QwenKontextForInferenceLM.from_pretrained(model_id, **model_init_kwargs)
|
| 154 |
-
|
| 155 |
-
def _create_reference_model(self, model: PreTrainedModel, model_init_kwargs: Dict[str, Any]) -> PreTrainedModel:
|
| 156 |
-
if is_deepspeed_zero3_enabled():
|
| 157 |
-
ref_model = self._load_model(self.model_id, model_init_kwargs)
|
| 158 |
-
else:
|
| 159 |
-
ref_model = create_reference_model(model)
|
| 160 |
-
for parameter in ref_model.parameters():
|
| 161 |
-
parameter.requires_grad = False
|
| 162 |
-
return ref_model
|
| 163 |
-
|
| 164 |
-
def _configure_trainable_parameters(self, model: PreTrainedModel) -> None:
|
| 165 |
-
try:
|
| 166 |
-
model.get_model().diffusion_expert.enable_vae_slicing()
|
| 167 |
-
except AttributeError:
|
| 168 |
-
try:
|
| 169 |
-
model.get_model().diffusion_expert.vae.enable_slicing()
|
| 170 |
-
except AttributeError:
|
| 171 |
-
pass
|
| 172 |
-
|
| 173 |
-
for parameter in model.parameters():
|
| 174 |
-
parameter.requires_grad = False
|
| 175 |
-
for parameter in model.get_model().parameters():
|
| 176 |
-
parameter.requires_grad = True
|
| 177 |
-
for parameter in model.lm_head.parameters():
|
| 178 |
-
parameter.requires_grad = True
|
| 179 |
-
|
| 180 |
-
if hasattr(model, "visual"):
|
| 181 |
-
for parameter in model.visual.parameters():
|
| 182 |
-
parameter.requires_grad = False
|
| 183 |
-
|
| 184 |
-
for component_name in ("visual", "vae", "text_encoder", "text_encoder_2", "text_encoder_3"):
|
| 185 |
-
component = getattr(model.get_model(), component_name, None)
|
| 186 |
-
if component is not None:
|
| 187 |
-
for parameter in component.parameters():
|
| 188 |
-
parameter.requires_grad = False
|
| 189 |
-
|
| 190 |
-
transformer = getattr(model.get_model(), "transformer", None)
|
| 191 |
-
if transformer is None:
|
| 192 |
-
raise ValueError("Qwen-Kontext model does not expose a FLUX transformer.")
|
| 193 |
-
for parameter in transformer.parameters():
|
| 194 |
-
parameter.requires_grad = True
|
| 195 |
-
|
| 196 |
-
def _create_processor(self, processor_name_or_path: str, max_pixels: int, min_pixels: int) -> AutoProcessor:
|
| 197 |
-
processor = AutoProcessor.from_pretrained(processor_name_or_path)
|
| 198 |
-
processor.pad_token_id = processor.tokenizer.pad_token_id
|
| 199 |
-
processor.eos_token_id = processor.tokenizer.eos_token_id
|
| 200 |
-
processor.image_processor.max_pixels = max_pixels
|
| 201 |
-
processor.image_processor.min_pixels = min_pixels
|
| 202 |
-
return processor
|
| 203 |
-
|
| 204 |
-
def _get_diffusion_config(self) -> Dict[str, Any]:
|
| 205 |
-
device_text = str(self.accelerator.device)
|
| 206 |
-
device_id = int(device_text.split(":")[-1]) if ":" in device_text else 0
|
| 207 |
-
return {
|
| 208 |
-
"guidance_scale": float(os.getenv("PROMPTRL_EDIT_GUIDANCE_SCALE", "2.5")),
|
| 209 |
-
"num_inference_steps": int(os.getenv("PROMPTRL_EDIT_NUM_INFERENCE_STEPS", "8")),
|
| 210 |
-
"num_images_per_prompt": 1,
|
| 211 |
-
"generator": torch.manual_seed(42 + device_id),
|
| 212 |
-
"height": int(os.getenv("PROMPTRL_EDIT_HEIGHT", "1024")),
|
| 213 |
-
"width": int(os.getenv("PROMPTRL_EDIT_WIDTH", "1024")),
|
| 214 |
-
"num_sde": self.num_sde,
|
| 215 |
-
}
|
| 216 |
-
|
| 217 |
-
def _set_signature_columns_if_needed(self):
|
| 218 |
-
if self._signature_columns is None:
|
| 219 |
-
self._signature_columns = ["prompt"]
|
| 220 |
-
|
| 221 |
-
def create_optimizer(self):
|
| 222 |
-
if self.optimizer is not None:
|
| 223 |
-
return self.optimizer
|
| 224 |
-
|
| 225 |
-
optimizer_kwargs = {
|
| 226 |
-
"betas": (self.args.adam_beta1, self.args.adam_beta2),
|
| 227 |
-
"eps": self.args.adam_epsilon,
|
| 228 |
-
"weight_decay": self.args.weight_decay,
|
| 229 |
-
}
|
| 230 |
-
dit_lr = float(os.getenv("DIT_LEARNING_RATE", os.getenv("PROMPTRL_DIT_LR", "2e-7")))
|
| 231 |
-
llm_lr = float(os.getenv("LLM_LEARNING_RATE", os.getenv("PROMPTRL_LLM_LR", "3e-7")))
|
| 232 |
-
|
| 233 |
-
dit_params = [
|
| 234 |
-
parameter for parameter in self.model.get_model().transformer.parameters() if parameter.requires_grad
|
| 235 |
-
]
|
| 236 |
-
dit_param_ids = {id(parameter) for parameter in dit_params}
|
| 237 |
-
llm_params = [
|
| 238 |
-
parameter
|
| 239 |
-
for parameter in self.model.parameters()
|
| 240 |
-
if parameter.requires_grad and id(parameter) not in dit_param_ids
|
| 241 |
-
]
|
| 242 |
-
|
| 243 |
-
param_groups = []
|
| 244 |
-
if dit_params:
|
| 245 |
-
param_groups.append({"params": dit_params, "lr": dit_lr})
|
| 246 |
-
if llm_params:
|
| 247 |
-
param_groups.append({"params": llm_params, "lr": llm_lr})
|
| 248 |
-
if not param_groups:
|
| 249 |
-
raise ValueError("No trainable parameters were found for edit joint GRPO training.")
|
| 250 |
-
|
| 251 |
-
self.optimizer = torch.optim.AdamW(param_groups, **optimizer_kwargs)
|
| 252 |
-
return self.optimizer
|
| 253 |
-
|
| 254 |
-
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None):
|
| 255 |
-
model.eval()
|
| 256 |
-
self.ref_model.eval()
|
| 257 |
-
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
|
| 258 |
-
self.optimizer.train()
|
| 259 |
-
|
| 260 |
-
inputs = self._prepare_inputs(inputs)
|
| 261 |
-
|
| 262 |
-
def loss_update(loss: torch.Tensor, scale_factor: float = 1.0) -> None:
|
| 263 |
-
if self.args.n_gpu > 1:
|
| 264 |
-
loss = loss.mean()
|
| 265 |
-
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
|
| 266 |
-
loss = loss / self.args.gradient_accumulation_steps
|
| 267 |
-
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
|
| 268 |
-
loss = loss / scale_factor
|
| 269 |
-
model.backward(loss)
|
| 270 |
-
else:
|
| 271 |
-
self.accelerator.backward(loss / scale_factor)
|
| 272 |
-
|
| 273 |
-
with self.compute_loss_context_manager():
|
| 274 |
-
generations = self.generate_samples(model, inputs)
|
| 275 |
-
torch.cuda.empty_cache()
|
| 276 |
-
|
| 277 |
-
if self.num_refined > 0:
|
| 278 |
-
cot_loss = self.cot_loss_computation(
|
| 279 |
-
model,
|
| 280 |
-
generations["prompt_completion_ids"],
|
| 281 |
-
generations["completion_ids"],
|
| 282 |
-
generations["prompt_length"],
|
| 283 |
-
generations["advantages_refined"],
|
| 284 |
-
generations["prompt_inputs"],
|
| 285 |
-
)
|
| 286 |
-
loss_update(cot_loss, 1.0)
|
| 287 |
-
else:
|
| 288 |
-
cot_loss = torch.tensor(0.0, device=self.accelerator.device)
|
| 289 |
-
|
| 290 |
-
diff_advantages = generations["advantages"].repeat_interleave(self.num_sde, dim=0)
|
| 291 |
-
total_len = diff_advantages.shape[0]
|
| 292 |
-
diff_loss_values = []
|
| 293 |
-
diff_kl_values = []
|
| 294 |
-
diffusion_batch_size = int(os.getenv("PROMPTRL_DIFFUSION_LOSS_BATCH_SIZE", "4"))
|
| 295 |
-
|
| 296 |
-
for idx in range(0, total_len, diffusion_batch_size):
|
| 297 |
-
batched_states_slice = {}
|
| 298 |
-
for key, value in generations["batched_states"].items():
|
| 299 |
-
if key in {"img_ids", "txt_ids"}:
|
| 300 |
-
batched_states_slice[key] = value
|
| 301 |
-
elif value is None:
|
| 302 |
-
batched_states_slice[key] = None
|
| 303 |
-
else:
|
| 304 |
-
batched_states_slice[key] = value[idx : idx + diffusion_batch_size]
|
| 305 |
-
|
| 306 |
-
diff_loss, diff_kl = self.diffusion_loss_computation(
|
| 307 |
-
generations["prev_latents"][idx : idx + diffusion_batch_size],
|
| 308 |
-
generations["diff_sampling_log_probs"][idx : idx + diffusion_batch_size],
|
| 309 |
-
generations["pred_latents"][idx : idx + diffusion_batch_size],
|
| 310 |
-
generations["ts"][idx : idx + diffusion_batch_size],
|
| 311 |
-
batched_states_slice,
|
| 312 |
-
diff_advantages[idx : idx + diffusion_batch_size],
|
| 313 |
-
)
|
| 314 |
-
loss_update(diff_loss, max(1.0, float(total_len / diffusion_batch_size)))
|
| 315 |
-
diff_loss_values.append(diff_loss.detach())
|
| 316 |
-
diff_kl_values.append(diff_kl.detach())
|
| 317 |
-
|
| 318 |
-
diff_loss = torch.stack(diff_loss_values).mean()
|
| 319 |
-
diff_kl = torch.stack(diff_kl_values).mean()
|
| 320 |
-
loss = diff_loss + cot_loss.detach()
|
| 321 |
-
|
| 322 |
-
if self.args.torch_empty_cache_steps is not None and self.state.global_step % self.args.torch_empty_cache_steps == 0:
|
| 323 |
-
torch.cuda.empty_cache()
|
| 324 |
-
|
| 325 |
-
if hasattr(model, "step") and callable(model.step):
|
| 326 |
-
model.step()
|
| 327 |
-
|
| 328 |
-
self._metrics["diff_kl"].append(self.accelerator.gather_for_metrics(diff_kl).mean().item())
|
| 329 |
-
self._metrics["diff_loss"].append(self.accelerator.gather_for_metrics(diff_loss).mean().item())
|
| 330 |
-
torch.cuda.empty_cache()
|
| 331 |
-
return loss.detach()
|
| 332 |
-
|
| 333 |
-
def generate_samples(self, model: nn.Module, inputs: List[Dict]) -> Dict[str, Any]:
|
| 334 |
-
source_images = [example["image"] for example in inputs]
|
| 335 |
-
batch_size = len(inputs)
|
| 336 |
-
prompt_inputs = None
|
| 337 |
-
prompt_completion_ids = None
|
| 338 |
-
completion_ids = None
|
| 339 |
-
prompt_length = 0
|
| 340 |
-
completions_refined: List[str] = []
|
| 341 |
-
refined_prompts: List[str] = []
|
| 342 |
-
|
| 343 |
-
if self.num_refined > 0:
|
| 344 |
-
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
|
| 345 |
-
prompt_inputs = self.processing_class(
|
| 346 |
-
images=[image for image in source_images for _ in range(self.num_refined)],
|
| 347 |
-
text=[prompt for prompt in prompts_text for _ in range(self.num_refined)],
|
| 348 |
-
return_tensors="pt",
|
| 349 |
-
padding=True,
|
| 350 |
-
padding_side="left",
|
| 351 |
-
add_special_tokens=False,
|
| 352 |
-
)
|
| 353 |
-
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
| 354 |
-
if self.max_prompt_length is not None:
|
| 355 |
-
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
|
| 356 |
-
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
|
| 357 |
-
|
| 358 |
-
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
| 359 |
-
with torch.no_grad():
|
| 360 |
-
prompt_completion_ids = unwrapped_model.generate(
|
| 361 |
-
**prompt_inputs,
|
| 362 |
-
generation_config=self.generation_config,
|
| 363 |
-
)
|
| 364 |
-
|
| 365 |
-
prompt_length = prompt_inputs["input_ids"].size(1)
|
| 366 |
-
completion_ids = prompt_completion_ids[:, prompt_length:]
|
| 367 |
-
completions_refined = self.processing_class.tokenizer.batch_decode(
|
| 368 |
-
completion_ids,
|
| 369 |
-
skip_special_tokens=True,
|
| 370 |
-
)
|
| 371 |
-
refined_prompts = [self.model.extract_thinking_content(completion) for completion in completions_refined]
|
| 372 |
-
|
| 373 |
-
original_prompts = [
|
| 374 |
-
example["editing_instruction"]
|
| 375 |
-
for example in inputs
|
| 376 |
-
for _ in range(self.num_skip_refinement)
|
| 377 |
-
]
|
| 378 |
-
all_prompts: List[str] = []
|
| 379 |
-
for batch_idx in range(batch_size):
|
| 380 |
-
refined_start = batch_idx * self.num_refined
|
| 381 |
-
refined_end = refined_start + self.num_refined
|
| 382 |
-
all_prompts.extend(refined_prompts[refined_start:refined_end])
|
| 383 |
-
|
| 384 |
-
original_start = batch_idx * self.num_skip_refinement
|
| 385 |
-
original_end = original_start + self.num_skip_refinement
|
| 386 |
-
all_prompts.extend(original_prompts[original_start:original_end])
|
| 387 |
-
|
| 388 |
-
all_source_images = [image for image in source_images for _ in range(self.num_generations)]
|
| 389 |
-
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
| 390 |
-
with torch.no_grad():
|
| 391 |
-
(
|
| 392 |
-
edited_images,
|
| 393 |
-
prev_latents,
|
| 394 |
-
diff_sampling_log_probs,
|
| 395 |
-
pred_latents,
|
| 396 |
-
timesteps,
|
| 397 |
-
batched_states,
|
| 398 |
-
) = unwrapped_model.generate_image(
|
| 399 |
-
images=all_source_images,
|
| 400 |
-
texts=all_prompts,
|
| 401 |
-
diffusion_kwargs=self.diffusion_generation_config,
|
| 402 |
-
sde_sampling=True,
|
| 403 |
-
)
|
| 404 |
-
|
| 405 |
-
rewards, rewards_per_func = self.compute_rewards(inputs, edited_images, completions_refined)
|
| 406 |
-
advantages = self.compute_advantages(rewards)
|
| 407 |
-
advantages_refined = (
|
| 408 |
-
advantages.view(batch_size, self.num_generations)[:, : self.num_refined].flatten()
|
| 409 |
-
if self.num_refined > 0
|
| 410 |
-
else torch.tensor([], device=advantages.device)
|
| 411 |
-
)
|
| 412 |
-
|
| 413 |
-
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
|
| 414 |
-
for index, (func_name, _, _) in enumerate(self.reward_funcs):
|
| 415 |
-
self._metrics[f"reward/{func_name}"].append(
|
| 416 |
-
self.accelerator.gather_for_metrics(rewards_per_func[:, index]).mean().item()
|
| 417 |
-
)
|
| 418 |
-
self._log_samples(source_images, edited_images, all_prompts, advantages)
|
| 419 |
-
|
| 420 |
-
return {
|
| 421 |
-
"images": edited_images,
|
| 422 |
-
"prev_latents": prev_latents,
|
| 423 |
-
"diff_sampling_log_probs": diff_sampling_log_probs,
|
| 424 |
-
"pred_latents": pred_latents,
|
| 425 |
-
"batched_states": batched_states,
|
| 426 |
-
"prompt_length": prompt_length,
|
| 427 |
-
"completion_ids": completion_ids,
|
| 428 |
-
"prompt_completion_ids": prompt_completion_ids,
|
| 429 |
-
"prompt_inputs": prompt_inputs,
|
| 430 |
-
"advantages": advantages,
|
| 431 |
-
"advantages_refined": advantages_refined,
|
| 432 |
-
"ts": timesteps,
|
| 433 |
-
}
|
| 434 |
-
|
| 435 |
-
def compute_rewards(
|
| 436 |
-
self,
|
| 437 |
-
inputs: List[Dict],
|
| 438 |
-
edited_images: List[Image.Image],
|
| 439 |
-
completions_refined: List[str],
|
| 440 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 441 |
-
device = self.accelerator.device
|
| 442 |
-
rewards_per_func = torch.zeros(len(edited_images), len(self.reward_funcs), device=device)
|
| 443 |
-
batch_size = len(inputs)
|
| 444 |
-
|
| 445 |
-
for index, (func_name, _, reward_func) in enumerate(self.reward_funcs):
|
| 446 |
-
if func_name == "format":
|
| 447 |
-
refined_scores = torch.tensor(reward_func(completions_refined), device=device, dtype=torch.float32)
|
| 448 |
-
for batch_idx in range(batch_size):
|
| 449 |
-
start = batch_idx * self.num_generations
|
| 450 |
-
refined_start = batch_idx * self.num_refined
|
| 451 |
-
refined_end = refined_start + self.num_refined
|
| 452 |
-
batch_refined_scores = refined_scores[refined_start:refined_end]
|
| 453 |
-
rewards_per_func[start : start + self.num_refined, index] = batch_refined_scores
|
| 454 |
-
rewards_per_func[start + self.num_refined : start + self.num_generations, index] = (
|
| 455 |
-
batch_refined_scores.mean() if len(batch_refined_scores) else 0.0
|
| 456 |
-
)
|
| 457 |
-
elif func_name == "editreward":
|
| 458 |
-
source_images = [example["image"] for example in inputs for _ in range(self.num_generations)]
|
| 459 |
-
prompts = [example["editing_instruction"] for example in inputs for _ in range(self.num_generations)]
|
| 460 |
-
rewards_per_func[:, index] = torch.tensor(
|
| 461 |
-
reward_func(source_images, edited_images, prompts)["scores"],
|
| 462 |
-
device=device,
|
| 463 |
-
dtype=torch.float32,
|
| 464 |
-
)
|
| 465 |
-
else:
|
| 466 |
-
raise ValueError(f"Unsupported reward function for edit joint training: {func_name}")
|
| 467 |
-
|
| 468 |
-
return rewards_per_func.sum(dim=1), rewards_per_func
|
| 469 |
-
|
| 470 |
-
def compute_advantages(self, rewards: torch.Tensor) -> torch.Tensor:
|
| 471 |
-
grouped_rewards = rewards.view(-1, self.num_generations)
|
| 472 |
-
mean = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
|
| 473 |
-
std = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(self.num_generations, dim=0)
|
| 474 |
-
return torch.clamp((rewards - mean) / (std + 1e-4), -5, 5)
|
| 475 |
-
|
| 476 |
-
def cot_loss_computation(
|
| 477 |
-
self,
|
| 478 |
-
model: nn.Module,
|
| 479 |
-
input_ids: torch.Tensor,
|
| 480 |
-
completion_ids: torch.Tensor,
|
| 481 |
-
prompt_length: int,
|
| 482 |
-
advantages: torch.Tensor,
|
| 483 |
-
prompt_inputs: Dict[str, torch.Tensor],
|
| 484 |
-
) -> torch.Tensor:
|
| 485 |
-
image_kwargs = {
|
| 486 |
-
key: value for key, value in prompt_inputs.items() if key not in {"input_ids", "attention_mask"}
|
| 487 |
-
}
|
| 488 |
-
per_token_logps = self._get_per_token_logps(model, input_ids, image_kwargs)[:, prompt_length - 1 :]
|
| 489 |
-
with torch.inference_mode():
|
| 490 |
-
ref_per_token_logps = self._get_per_token_logps(self.ref_model, input_ids, image_kwargs)[:, prompt_length - 1 :]
|
| 491 |
-
|
| 492 |
-
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
| 493 |
-
completion_mask = self._completion_mask(completion_ids)
|
| 494 |
-
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
|
| 495 |
-
per_token_loss = -(per_token_loss - 0.01 * per_token_kl)
|
| 496 |
-
cot_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1).clamp_min(1)).mean()
|
| 497 |
-
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1).clamp_min(1)).mean()
|
| 498 |
-
|
| 499 |
-
self._metrics["completion_length"].append(
|
| 500 |
-
self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
|
| 501 |
-
)
|
| 502 |
-
self._metrics["cot_kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
| 503 |
-
self._metrics["cot_loss"].append(self.accelerator.gather_for_metrics(cot_loss).mean().item())
|
| 504 |
-
return cot_loss
|
| 505 |
-
|
| 506 |
-
def _get_per_token_logps(
|
| 507 |
-
self,
|
| 508 |
-
model: nn.Module,
|
| 509 |
-
input_ids: torch.Tensor,
|
| 510 |
-
image_kwargs: Dict[str, torch.Tensor],
|
| 511 |
-
) -> torch.Tensor:
|
| 512 |
-
logits = model(input_ids, **image_kwargs).logits[:, :-1, :]
|
| 513 |
-
target_ids = input_ids[:, 1:]
|
| 514 |
-
per_token_logps = []
|
| 515 |
-
for logits_row, target_ids_row in zip(logits, target_ids):
|
| 516 |
-
log_probs = logits_row.log_softmax(dim=-1)
|
| 517 |
-
per_token_logps.append(torch.gather(log_probs, dim=1, index=target_ids_row.unsqueeze(1)).squeeze(1))
|
| 518 |
-
return torch.stack(per_token_logps)
|
| 519 |
-
|
| 520 |
-
def _completion_mask(self, completion_ids: torch.Tensor) -> torch.Tensor:
|
| 521 |
-
is_eos = completion_ids == self.processing_class.eos_token_id
|
| 522 |
-
device = completion_ids.device
|
| 523 |
-
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
| 524 |
-
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
| 525 |
-
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
| 526 |
-
return (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
| 527 |
-
|
| 528 |
-
def diffusion_loss_computation(
|
| 529 |
-
self,
|
| 530 |
-
prev_latents: torch.Tensor,
|
| 531 |
-
diff_sampling_log_probs: torch.Tensor,
|
| 532 |
-
pred_latents: torch.Tensor,
|
| 533 |
-
timesteps: torch.Tensor,
|
| 534 |
-
batched_states: Dict[str, torch.Tensor],
|
| 535 |
-
advantages: torch.Tensor,
|
| 536 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 537 |
-
model_pred = self.model.get_model().transformer(
|
| 538 |
-
hidden_states=prev_latents.to(self.model.device),
|
| 539 |
-
**batched_states,
|
| 540 |
-
joint_attention_kwargs={},
|
| 541 |
-
return_dict=False,
|
| 542 |
-
)[0][:, : pred_latents.size(1)]
|
| 543 |
-
|
| 544 |
-
with torch.no_grad():
|
| 545 |
-
ref_model_pred = self.ref_model.get_model().transformer(
|
| 546 |
-
hidden_states=prev_latents.to(self.model.device),
|
| 547 |
-
**batched_states,
|
| 548 |
-
joint_attention_kwargs={},
|
| 549 |
-
return_dict=False,
|
| 550 |
-
)[0][:, : pred_latents.size(1)]
|
| 551 |
-
|
| 552 |
-
_, log_prob, prev_sample_mean, std_dev_t = compute_log_prob(
|
| 553 |
-
model_pred,
|
| 554 |
-
self.scheduler,
|
| 555 |
-
prev_latents[:, : pred_latents.size(1)],
|
| 556 |
-
pred_latents,
|
| 557 |
-
timesteps,
|
| 558 |
-
)
|
| 559 |
-
_, _, ref_prev_sample_mean, ref_std_dev_t = compute_log_prob(
|
| 560 |
-
ref_model_pred,
|
| 561 |
-
self.scheduler,
|
| 562 |
-
prev_latents[:, : pred_latents.size(1)],
|
| 563 |
-
pred_latents,
|
| 564 |
-
timesteps,
|
| 565 |
-
)
|
| 566 |
-
if not torch.equal(std_dev_t, ref_std_dev_t):
|
| 567 |
-
raise RuntimeError("Current and reference SDE std-dev tensors diverged.")
|
| 568 |
-
|
| 569 |
-
kl = ((prev_sample_mean - ref_prev_sample_mean) ** 2 / (2 * std_dev_t**2)).mean(
|
| 570 |
-
dim=tuple(range(1, prev_sample_mean.ndim))
|
| 571 |
-
)
|
| 572 |
-
ratio = torch.exp(log_prob - diff_sampling_log_probs)
|
| 573 |
-
unclipped_loss = -advantages * ratio
|
| 574 |
-
clipped_loss = -advantages * torch.clamp(ratio, 1.0 - 1e-4, 1.0 + 1e-4)
|
| 575 |
-
diff_loss = torch.maximum(unclipped_loss, clipped_loss).mean() + self.beta * kl.mean()
|
| 576 |
-
return diff_loss, kl
|
| 577 |
-
|
| 578 |
-
def _log_samples(
|
| 579 |
-
self,
|
| 580 |
-
source_images: List[Image.Image],
|
| 581 |
-
edited_images: List[Image.Image],
|
| 582 |
-
prompts: List[str],
|
| 583 |
-
advantages: torch.Tensor,
|
| 584 |
-
) -> None:
|
| 585 |
-
global_step = self.state.global_step
|
| 586 |
-
if global_step % 10 != 0 or not edited_images:
|
| 587 |
-
return
|
| 588 |
-
|
| 589 |
-
device_id = str(self.accelerator.device).replace(":", "")
|
| 590 |
-
text_content = []
|
| 591 |
-
for batch_idx in range(len(source_images)):
|
| 592 |
-
for gen_idx in range(self.num_generations):
|
| 593 |
-
overall_idx = batch_idx * self.num_generations + gen_idx
|
| 594 |
-
status = "REFINED" if gen_idx < self.num_refined else "ORIGINAL"
|
| 595 |
-
text_content.append(f"[{status}] Generation {gen_idx}: {prompts[overall_idx]}")
|
| 596 |
-
text_content.append("")
|
| 597 |
-
|
| 598 |
-
txt_path = os.path.join(self.log_dir, f"step_{global_step}_{device_id}.txt")
|
| 599 |
-
if not os.path.exists(txt_path):
|
| 600 |
-
with open(txt_path, "w", encoding="utf-8") as file:
|
| 601 |
-
file.write("\n".join(text_content))
|
| 602 |
-
|
| 603 |
-
for batch_idx, source_image in enumerate(source_images):
|
| 604 |
-
source_image.save(os.path.join(self.log_dir, f"step_{global_step}_{device_id}_batch{batch_idx}_source.jpg"))
|
| 605 |
-
for gen_idx in range(self.num_generations):
|
| 606 |
-
overall_idx = batch_idx * self.num_generations + gen_idx
|
| 607 |
-
prefix = "refined" if gen_idx < self.num_refined else "original"
|
| 608 |
-
edited_images[overall_idx].save(
|
| 609 |
-
os.path.join(
|
| 610 |
-
self.log_dir,
|
| 611 |
-
f"step_{global_step}_{device_id}_batch{batch_idx}_{prefix}_gen{gen_idx}_{advantages[overall_idx].item():.5f}.jpg",
|
| 612 |
-
)
|
| 613 |
-
)
|
| 614 |
-
|
| 615 |
-
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
| 616 |
-
metrics = {key: sum(value) / len(value) for key, value in self._metrics.items() if value}
|
| 617 |
-
logs = {**logs, **metrics}
|
| 618 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 619 |
-
super().log(logs, start_time)
|
| 620 |
-
else:
|
| 621 |
-
super().log(logs)
|
| 622 |
-
self._metrics.clear()
|
| 623 |
-
|
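For readers following the removed `compute_advantages` method, the group-wise normalization it performs can be sketched without torch. This is a minimal illustration, not the repository's code: the function name `group_normalized_advantages` and the parameters `eps` and `clip` are hypothetical, and their defaults (`1e-4` and `5`) only mirror the constants visible in the removed lines. It assumes rewards are laid out contiguously, `num_generations` per input, as in the `rewards.view(-1, self.num_generations)` call.

```python
import math
from typing import List


def group_normalized_advantages(
    rewards: List[float],
    num_generations: int,
    eps: float = 1e-4,   # assumed default, mirrors the 1e-4 in the removed code
    clip: float = 5.0,   # assumed default, mirrors the clamp to [-5, 5]
) -> List[float]:
    """Normalize each reward against the other samples generated for the same input."""
    if num_generations <= 0 or len(rewards) % num_generations != 0:
        raise ValueError("rewards must contain a whole number of groups")

    advantages: List[float] = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start : start + num_generations]
        mean = sum(group) / num_generations
        # Population (biased) standard deviation, matching std(unbiased=False).
        std = math.sqrt(sum((r - mean) ** 2 for r in group) / num_generations)
        for reward in group:
            value = (reward - mean) / (std + eps)
            # Clamp so a single outlier cannot dominate the update.
            advantages.append(max(-clip, min(clip, value)))
    return advantages


if __name__ == "__main__":
    # Two inputs, two generations each; the second group has identical rewards.
    print(group_normalized_advantages([1.0, 3.0, 2.0, 2.0], num_generations=2))
```

When every sample in a group receives the same reward, the standard deviation is zero and all advantages in that group come out as zero, so the group contributes no policy gradient signal; the `eps` term only prevents a division by zero.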