Safetensors
wangfuyun commited on
Commit
ff175fa
·
verified ·
1 Parent(s): e478a8d

Restore public repo before mistaken code upload

Browse files

Restores files to revision 1a8e9e3ad2130f55c880af8ace85b9af0d0c329f and removes files mistakenly uploaded from unirl_opensource.

.gitignore CHANGED
@@ -145,11 +145,5 @@ outputs/
145
  wandb/
146
 
147
  assets/large_rl_datasets/
148
- assets/rl_datasets/*.parquet
149
- assets/rl_datasets/*.jsonl
150
- !assets/rl_datasets/README.md
151
 
152
  utils/parquet_cache/
153
-
154
- rewards_services/api_services/editreward_scorer_service/.venv/
155
- rewards_services/api_services/editreward_scorer_service/EditReward/
 
145
  wandb/
146
 
147
  assets/large_rl_datasets/
 
 
 
148
 
149
  utils/parquet_cache/
 
 
 
README.md CHANGED
@@ -1,17 +1,3 @@
1
- ---
2
- license: apache-2.0
3
- library_name: diffusers
4
- tags:
5
- - reinforcement-learning
6
- - image-generation
7
- - image-editing
8
- - prompt-optimization
9
- - flux
10
- - qwen
11
- datasets:
12
- - wangfuyun/PrompRL
13
- ---
14
-
15
  <p align="center">
16
  <img src="assets/logo.png" width="30%"><br>
17
  PromptRL
@@ -41,109 +27,6 @@ pip install flash-attn==2.7.4.post1 --no-build-isolation
41
  # bash gen.sh
42
  ```
43
 
44
- <details>
45
- <summary><b>Training The Edit Model And Running EditReward</b></summary>
46
-
47
- <br>
48
-
49
- **Scope**
50
-
51
- This release keeps only the edit RL path. The trainer is `unirl/trainer/edit_grpo_trainer.py`, which jointly optimizes the Qwen-VL prompt refiner and the FLUX.1-Kontext transformer. The VAE, text encoders, and vision encoder stay frozen.
52
-
53
- The partial-refinement setting is preserved: with the default `NUM_GENERATIONS=8` and `NUM_SKIP_REFINEMENT=2`, each source image produces six edits from Qwen-refined prompts and two edits from the original prompt.
54
-
55
- Relevant files:
56
-
57
- - `unirl/train_edit.py`: CLI entry point for Qwen-Kontext edit GRPO.
58
- - `unirl/reward_evaluator/reward_evaluator.py`: EditReward HTTP client used by training.
59
- - `rewards_services/api_services/editreward_scorer_service`: EditReward service wrapper.
60
- - `scripts/train/edit_grpo.sh`: launch script with environment-variable configuration.
61
-
62
- **Dataset**
63
-
64
- By default, `scripts/train/edit_grpo.sh` loads:
65
-
66
- ```text
67
- https://huggingface.co/wangfuyun/PrompRL/resolve/main/data/omni_edit_train_50k.parquet
68
- ```
69
-
70
- You can override it with `PROMPTS_FILE`. The loader also accepts the Hugging Face web URL form with `/blob/main/`; it is converted to the downloadable `/resolve/main/` URL automatically.
71
-
72
- The dataset should be a `.parquet` or `.jsonl` file with:
73
-
74
- | Column | Description |
75
- | --- | --- |
76
- | `image` | Source image before editing. For jsonl this can be an image path. |
77
- | `prompt` | Edit instruction. |
78
-
79
- Optional columns are `caption` and `target_caption`. For other column names, set `IMAGE_COLUMN` and `PROMPT_COLUMN`.
80
-
81
- **1. Start EditReward**
82
-
83
- ```bash
84
- cd rewards_services/api_services/editreward_scorer_service
85
- python -m venv .venv
86
- source .venv/bin/activate
87
- pip install --upgrade pip
88
- pip install torch torchvision torchaudio
89
- pip install -r requirements.txt
90
-
91
- git clone https://github.com/TIGER-AI-Lab/EditReward.git
92
- huggingface-cli download TIGER-Lab/EditReward-MiMo-VL-7B-SFT-2508 \
93
- --local-dir EditReward/EditReward-MiMo-VL-7B-SFT-2508
94
-
95
- export EDITREWARD_CUDA_DEVICES=0,1
96
- export EDITREWARD_WORKERS=2
97
- export EDITREWARD_PORT=18088
98
- bash run.sh
99
- ```
100
-
101
- If the EditReward repo or checkpoint is stored elsewhere:
102
-
103
- ```bash
104
- export EDITREWARD_REPO_DIR=/path/to/EditReward
105
- export EDITREWARD_CHECKPOINT_PATH=/path/to/EditReward-MiMo-VL-7B-SFT-2508
106
- ```
107
-
108
- **2. Launch Training**
109
-
110
- From the repository root:
111
-
112
- ```bash
113
- export MODEL_NAME_OR_PATH=/path/to/qwenkontext/checkpoint
114
- # Optional. Defaults to the PromptRL OmniEdit 50k parquet on Hugging Face.
115
- export PROMPTS_FILE=https://huggingface.co/wangfuyun/PrompRL/blob/main/data/omni_edit_train_50k.parquet
116
- export EDITREWARD_URL=http://127.0.0.1:18088/
117
- export CUDA_VISIBLE_DEVICES=2,3,4,5,6,7
118
- export NPROC_PER_NODE=6
119
- export RUN_NAME=qwenkontext-editreward
120
-
121
- bash scripts/train/edit_grpo.sh
122
- ```
123
-
124
- Common options:
125
-
126
- ```bash
127
- export NUM_GENERATIONS=8
128
- export NUM_SKIP_REFINEMENT=2
129
- export NUM_SDE=4
130
- export PER_DEVICE_TRAIN_BATCH_SIZE=1
131
- export DIT_LEARNING_RATE=2e-7
132
- export LLM_LEARNING_RATE=3e-7
133
- export BETA=1e-2
134
- export IMAGE_COLUMN=image
135
- export PROMPT_COLUMN=prompt
136
- export REPORT_TO=wandb
137
- ```
138
-
139
- Training logs sample source/edited images under:
140
-
141
- ```text
142
- outputs/rl/kontext/$RUN_NAME/training_samples/
143
- ```
144
-
145
- </details>
146
-
147
  ## Qualitative Results
148
 
149
  ### Text-to-Image Generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  <p align="center">
2
  <img src="assets/logo.png" width="30%"><br>
3
  PromptRL
 
27
  # bash gen.sh
28
  ```
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  ## Qualitative Results
31
 
32
  ### Text-to-Image Generation
assets/rl_datasets/README.md DELETED
@@ -1,23 +0,0 @@
1
- # Edit Training Dataset Schema
2
-
3
- The training script defaults to:
4
-
5
- ```text
6
- https://huggingface.co/wangfuyun/PrompRL/resolve/main/data/omni_edit_train_50k.parquet
7
- ```
8
-
9
- Use a `.parquet` or `.jsonl` file with at least:
10
-
11
- | Column | Type | Description |
12
- | --- | --- | --- |
13
- | `image` | PIL image, image bytes, or image path | Source image before editing. |
14
- | `prompt` | string | Edit instruction used by FLUX.1-Kontext and EditReward. |
15
-
16
- Optional columns:
17
-
18
- | Column | Type | Description |
19
- | --- | --- | --- |
20
- | `caption` | string | Source-image caption, kept for logging or downstream reward extensions. |
21
- | `target_caption` | string | Target edited-image caption, kept for logging or downstream reward extensions. |
22
-
23
- If your dataset uses different column names, pass `IMAGE_COLUMN=...` and `PROMPT_COLUMN=...` to `scripts/train/edit_grpo.sh`.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rewards_services/api_services/editreward_scorer_service/README.md DELETED
@@ -1,35 +0,0 @@
1
- # EditReward Scorer Service
2
-
3
- This service exposes EditReward over HTTP for edit GRPO training. It accepts a pickled payload with source images, edited images, and edit instructions, then returns `{"scores": [...]}`.
4
-
5
- ## Setup
6
-
7
- ```bash
8
- cd rewards_services/api_services/editreward_scorer_service
9
- python -m venv .venv
10
- source .venv/bin/activate
11
- pip install --upgrade pip
12
- pip install torch torchvision torchaudio
13
- pip install -r requirements.txt
14
- pip install flash-attn --no-build-isolation # optional, recommended when your CUDA/PyTorch build supports it
15
-
16
- git clone https://github.com/TIGER-AI-Lab/EditReward.git
17
- huggingface-cli download TIGER-Lab/EditReward-MiMo-VL-7B-SFT-2508 \
18
- --local-dir EditReward/EditReward-MiMo-VL-7B-SFT-2508
19
- ```
20
-
21
- If the repository or checkpoint lives elsewhere, set:
22
-
23
- ```bash
24
- export EDITREWARD_REPO_DIR=/path/to/EditReward
25
- export EDITREWARD_CHECKPOINT_PATH=/path/to/EditReward-MiMo-VL-7B-SFT-2508
26
- ```
27
-
28
- ## Run
29
-
30
- ```bash
31
- export EDITREWARD_PORT=18088
32
- export EDITREWARD_CUDA_DEVICES=0,1
33
- export EDITREWARD_WORKERS=2
34
- bash run.sh
35
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rewards_services/api_services/editreward_scorer_service/app.py DELETED
@@ -1,94 +0,0 @@
1
- import os
2
- import pickle
3
- import traceback
4
- from io import BytesIO
5
- from typing import Any, Dict, List
6
-
7
- import torch
8
- from flask import Blueprint, Flask, request
9
- from PIL import Image
10
-
11
- from editreward_scorer import EditRewardScorer
12
-
13
-
14
- INFERENCE_FN = None
15
- root = Blueprint("root", __name__)
16
-
17
-
18
- def _deserialize_images(images_bytes: List[bytes]) -> List[Image.Image]:
19
- return [Image.open(BytesIO(data)).convert("RGB") for data in images_bytes]
20
-
21
-
22
- def _service_config() -> Dict[str, Any]:
23
- repo_dir = os.getenv("EDITREWARD_REPO_DIR", os.path.join(os.path.dirname(__file__), "EditReward"))
24
- return {
25
- "repo_dir": repo_dir,
26
- "config_path": os.getenv(
27
- "EDITREWARD_CONFIG_PATH",
28
- os.path.join(repo_dir, "EditReward", "config", "EditReward-MiMo-VL-7B-SFT-2508.yaml"),
29
- ),
30
- "checkpoint_path": os.getenv(
31
- "EDITREWARD_CHECKPOINT_PATH",
32
- os.path.join(repo_dir, "EditReward-MiMo-VL-7B-SFT-2508"),
33
- ),
34
- "reward_dim": os.getenv("EDITREWARD_DIM", "overall_detail"),
35
- "rm_head_type": os.getenv("EDITREWARD_HEAD_TYPE", "ranknet_multi_head"),
36
- }
37
-
38
-
39
- def create_app():
40
- global INFERENCE_FN
41
- config = _service_config()
42
- device = "cuda" if torch.cuda.is_available() else "cpu"
43
- print(f"Loading EditReward scorer on {device} from {config['checkpoint_path']}...")
44
- INFERENCE_FN = EditRewardScorer(
45
- repo_dir=config["repo_dir"],
46
- config_path=config["config_path"],
47
- checkpoint_path=config["checkpoint_path"],
48
- reward_dim=config["reward_dim"],
49
- rm_head_type=config["rm_head_type"],
50
- device=device,
51
- )
52
- INFERENCE_FN.eval()
53
- print("EditReward scorer loaded.")
54
-
55
- app = Flask(__name__)
56
- app.register_blueprint(root)
57
- return app
58
-
59
-
60
- @root.route("/", methods=["GET"])
61
- def healthcheck():
62
- return {"status": "ok", "service": "editreward"}, 200
63
-
64
-
65
- @root.route("/", methods=["POST"])
66
- def inference():
67
- try:
68
- payload = pickle.loads(request.get_data())
69
- images = payload["images"]
70
- prompts = payload.get("prompts", [])
71
- source_images = _deserialize_images(images.get("source", []))
72
- edited_images = _deserialize_images(images.get("edited", []))
73
-
74
- if len(source_images) != len(edited_images) or len(source_images) != len(prompts):
75
- raise ValueError(
76
- "Mismatched EditReward inputs: "
77
- f"{len(source_images)} source images, {len(edited_images)} edited images, {len(prompts)} prompts."
78
- )
79
-
80
- with torch.no_grad():
81
- scores = INFERENCE_FN(prompts, source_images, edited_images)
82
- return pickle.dumps({"scores": [float(score) for score in scores]}), 200
83
- except Exception:
84
- error_message = traceback.format_exc()
85
- print(f"EditReward service error:\n{error_message}")
86
- return pickle.dumps({"error": error_message}), 500
87
-
88
-
89
- if __name__ == "__main__":
90
- port = int(os.getenv("EDITREWARD_PORT", "18088"))
91
- host = os.getenv("EDITREWARD_HOST", "127.0.0.1")
92
- app = create_app()
93
- app.run(host=host, port=port, debug=False)
94
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rewards_services/api_services/editreward_scorer_service/editreward_scorer.py DELETED
@@ -1,65 +0,0 @@
1
- import os
2
- import shutil
3
- import sys
4
- import tempfile
5
- from typing import List
6
-
7
- import torch
8
- from PIL import Image
9
-
10
-
11
- class EditRewardScorer(torch.nn.Module):
12
- def __init__(
13
- self,
14
- repo_dir: str,
15
- config_path: str,
16
- checkpoint_path: str,
17
- reward_dim: str = "overall_detail",
18
- rm_head_type: str = "ranknet_multi_head",
19
- device: str = "cuda",
20
- ):
21
- super().__init__()
22
- if not os.path.isdir(repo_dir):
23
- raise FileNotFoundError(
24
- f"EditReward repository not found at {repo_dir}. "
25
- "Clone https://github.com/TIGER-AI-Lab/EditReward.git or set EDITREWARD_REPO_DIR."
26
- )
27
- sys.path.insert(0, repo_dir)
28
- from EditReward import EditRewardInferencer
29
-
30
- self.inferencer = EditRewardInferencer(
31
- config_path=config_path,
32
- checkpoint_path=checkpoint_path,
33
- device=device,
34
- reward_dim=reward_dim,
35
- rm_head_type=rm_head_type,
36
- )
37
- self.device = device
38
- self.eval()
39
-
40
- @torch.no_grad()
41
- def __call__(self, prompts: List[str], source_images: List[Image.Image], edited_images: List[Image.Image]):
42
- if not (len(prompts) == len(source_images) == len(edited_images)):
43
- raise ValueError("prompts, source_images, and edited_images must have the same length.")
44
-
45
- temp_dir = tempfile.mkdtemp(prefix="editreward_")
46
- try:
47
- source_paths = []
48
- edited_paths = []
49
- for index, (source_image, edited_image) in enumerate(zip(source_images, edited_images)):
50
- source_path = os.path.join(temp_dir, f"source_{index}.png")
51
- edited_path = os.path.join(temp_dir, f"edited_{index}.png")
52
- source_image.convert("RGB").save(source_path)
53
- edited_image.convert("RGB").save(edited_path)
54
- source_paths.append(source_path)
55
- edited_paths.append(edited_path)
56
-
57
- rewards = self.inferencer.reward(
58
- prompts=prompts,
59
- image_src=source_paths,
60
- image_paths=edited_paths,
61
- )
62
- return [reward[0].item() if hasattr(reward[0], "item") else float(reward[0]) for reward in rewards]
63
- finally:
64
- shutil.rmtree(temp_dir, ignore_errors=True)
65
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rewards_services/api_services/editreward_scorer_service/gunicorn.conf.py DELETED
@@ -1,34 +0,0 @@
1
- import os
2
- import sys
3
-
4
-
5
- bind = f"{os.getenv('EDITREWARD_HOST', '127.0.0.1')}:{os.getenv('EDITREWARD_PORT', '18088')}"
6
- workers = int(os.getenv("EDITREWARD_WORKERS", os.getenv("EDITREWARD_NUM_DEVICES", "1")))
7
- worker_class = "sync"
8
- timeout = int(os.getenv("EDITREWARD_TIMEOUT", "600"))
9
-
10
- _raw_devices = os.getenv("EDITREWARD_CUDA_DEVICES") or os.getenv("CUDA_VISIBLE_DEVICES") or ""
11
- CUDA_DEVICES = [device.strip() for device in _raw_devices.split(",") if device.strip()]
12
- USED_DEVICES = set()
13
-
14
-
15
- def pre_fork(server, worker):
16
- if not CUDA_DEVICES:
17
- return
18
- available = [device for device in CUDA_DEVICES if device not in USED_DEVICES]
19
- worker.cuda_device = available[0] if available else CUDA_DEVICES[len(USED_DEVICES) % len(CUDA_DEVICES)]
20
- USED_DEVICES.add(worker.cuda_device)
21
- print(f"Worker {worker.pid} assigned CUDA_VISIBLE_DEVICES={worker.cuda_device}", file=sys.stderr)
22
-
23
-
24
- def post_fork(server, worker):
25
- cuda_device = getattr(worker, "cuda_device", None)
26
- if cuda_device is not None:
27
- os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
28
-
29
-
30
- def child_exit(server, worker):
31
- cuda_device = getattr(worker, "cuda_device", None)
32
- if cuda_device is not None:
33
- USED_DEVICES.discard(cuda_device)
34
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rewards_services/api_services/editreward_scorer_service/requirements.txt DELETED
@@ -1,18 +0,0 @@
1
- flask
2
- gunicorn
3
- datasets
4
- huggingface_hub
5
- pillow
6
- openai
7
- megfile
8
- sentencepiece
9
- deepspeed
10
- fire
11
- omegaconf
12
- matplotlib
13
- peft
14
- trl==0.8.6
15
- tensorboard
16
- scipy
17
- transformers==4.56.1
18
- accelerate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rewards_services/api_services/editreward_scorer_service/run.sh DELETED
@@ -1,13 +0,0 @@
1
- #!/usr/bin/env bash
2
- set -euo pipefail
3
-
4
- SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
- cd "$SCRIPT_DIR"
6
-
7
- export EDITREWARD_REPO_DIR="${EDITREWARD_REPO_DIR:-$SCRIPT_DIR/EditReward}"
8
- export EDITREWARD_PORT="${EDITREWARD_PORT:-18088}"
9
- export EDITREWARD_HOST="${EDITREWARD_HOST:-127.0.0.1}"
10
- export EDITREWARD_CUDA_DEVICES="${EDITREWARD_CUDA_DEVICES:-0,1}"
11
- export EDITREWARD_WORKERS="${EDITREWARD_WORKERS:-${EDITREWARD_NUM_DEVICES:-2}}"
12
-
13
- python -m gunicorn -c gunicorn.conf.py "app:create_app()"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/train/deepspeed/zero3.json DELETED
@@ -1,39 +0,0 @@
1
- {
2
- "fp16": {
3
- "enabled": false,
4
- "loss_scale": 0,
5
- "loss_scale_window": 1000,
6
- "initial_scale_power": 16,
7
- "hysteresis": 2,
8
- "min_loss_scale": 1
9
- },
10
- "bf16": {
11
- "enabled": true
12
- },
13
- "zero_optimization": {
14
- "stage": 3,
15
- "offload_optimizer": {
16
- "device": "none",
17
- "pin_memory": true
18
- },
19
- "offload_param": {
20
- "device": "none",
21
- "pin_memory": true
22
- },
23
- "overlap_comm": true,
24
- "contiguous_gradients": true,
25
- "sub_group_size": 1000000000.0,
26
- "reduce_bucket_size": "auto",
27
- "stage3_prefetch_bucket_size": "auto",
28
- "stage3_param_persistence_threshold": "auto",
29
- "stage3_max_live_parameters": 1000000000.0,
30
- "stage3_max_reuse_distance": 1000000000.0,
31
- "stage3_gather_16bit_weights_on_model_save": true
32
- },
33
- "gradient_accumulation_steps": "auto",
34
- "steps_per_print": 100,
35
- "train_batch_size": "auto",
36
- "train_micro_batch_size_per_gpu": "auto",
37
- "wall_clock_breakdown": false
38
- }
39
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/train/edit_grpo.sh DELETED
@@ -1,77 +0,0 @@
1
- #!/usr/bin/env bash
2
- set -euo pipefail
3
-
4
- SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
- REPO_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)"
6
- cd "$REPO_DIR"
7
-
8
- : "${MODEL_NAME_OR_PATH:?Set MODEL_NAME_OR_PATH to a pretrained Qwen-Kontext checkpoint.}"
9
-
10
- export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-2,3,4,5,6,7}"
11
- NPROC_PER_NODE="${NPROC_PER_NODE:-6}"
12
- MASTER_ADDR="${MASTER_ADDR:-localhost}"
13
- MASTER_PORT="${MASTER_PORT:-25000}"
14
- RUN_NAME="${RUN_NAME:-qwenkontext-edit-grpo}"
15
- OUTPUT_DIR="${OUTPUT_DIR:-outputs/rl/kontext/$RUN_NAME}"
16
- DEEPSPEED_CONFIG="${DEEPSPEED_CONFIG:-scripts/train/deepspeed/zero3.json}"
17
- EDITREWARD_URL="${EDITREWARD_URL:-http://127.0.0.1:18088/}"
18
- PROMPTS_FILE="${PROMPTS_FILE:-https://huggingface.co/wangfuyun/PrompRL/resolve/main/data/omni_edit_train_50k.parquet}"
19
- REPORT_TO="${REPORT_TO:-none}"
20
- if [[ -n "${WANDB_PROJECT:-}" && "$REPORT_TO" == "none" ]]; then
21
- REPORT_TO="wandb"
22
- fi
23
-
24
- TORCHRUN_ARGS=(
25
- --nproc_per_node="$NPROC_PER_NODE"
26
- --nnodes="${NNODES:-1}"
27
- --node_rank="${NODE_RANK:-0}"
28
- --master_addr="$MASTER_ADDR"
29
- --master_port="$MASTER_PORT"
30
- )
31
-
32
- TRAIN_ARGS=(
33
- -m unirl.train_edit
34
- --reward_funcs editreward format
35
- --deepspeed "$DEEPSPEED_CONFIG"
36
- --output_dir "$OUTPUT_DIR"
37
- --model_name_or_path "$MODEL_NAME_OR_PATH"
38
- --prompts_file "$PROMPTS_FILE"
39
- --image_column "${IMAGE_COLUMN:-image}"
40
- --prompt_column "${PROMPT_COLUMN:-prompt}"
41
- --editreward_url "$EDITREWARD_URL"
42
- --max_prompt_length "${MAX_PROMPT_LENGTH:-8192}"
43
- --max_completion_length "${MAX_COMPLETION_LENGTH:-512}"
44
- --num_generations "${NUM_GENERATIONS:-8}"
45
- --num_skip_refinement "${NUM_SKIP_REFINEMENT:-2}"
46
- --num_sde "${NUM_SDE:-4}"
47
- --per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE:-1}"
48
- --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-1}"
49
- --logging_steps "${LOGGING_STEPS:-1}"
50
- --learning_rate "${LEARNING_RATE:-3e-7}"
51
- --bf16 "${BF16:-true}"
52
- --report_to "$REPORT_TO"
53
- --gradient_checkpointing "${GRADIENT_CHECKPOINTING:-true}"
54
- --attn_implementation "${ATTN_IMPLEMENTATION:-flash_attention_2}"
55
- --max_pixels "${MAX_PIXELS:-200704}"
56
- --min_pixels "${MIN_PIXELS:-200704}"
57
- --image_size "${IMAGE_SIZE:-512}"
58
- --save_total_limit "${SAVE_TOTAL_LIMIT:-4}"
59
- --save_strategy "${SAVE_STRATEGY:-steps}"
60
- --save_steps "${SAVE_STEPS:-100}"
61
- --beta "${BETA:-1e-2}"
62
- --num_train_epochs "${NUM_TRAIN_EPOCHS:-10}"
63
- --run_name "$RUN_NAME"
64
- )
65
-
66
- if [[ -n "${DATASET_CACHE_DIR:-}" ]]; then
67
- TRAIN_ARGS+=(--dataset_cache_dir "$DATASET_CACHE_DIR")
68
- fi
69
-
70
- export PROMPTRL_EDIT_GUIDANCE_SCALE="${PROMPTRL_EDIT_GUIDANCE_SCALE:-${EDIT_GUIDANCE_SCALE:-2.5}}"
71
- export PROMPTRL_EDIT_NUM_INFERENCE_STEPS="${PROMPTRL_EDIT_NUM_INFERENCE_STEPS:-${EDIT_NUM_INFERENCE_STEPS:-8}}"
72
- export PROMPTRL_EDIT_HEIGHT="${PROMPTRL_EDIT_HEIGHT:-${EDIT_HEIGHT:-1024}}"
73
- export PROMPTRL_EDIT_WIDTH="${PROMPTRL_EDIT_WIDTH:-${EDIT_WIDTH:-1024}}"
74
- export DIT_LEARNING_RATE="${DIT_LEARNING_RATE:-2e-7}"
75
- export LLM_LEARNING_RATE="${LLM_LEARNING_RATE:-3e-7}"
76
-
77
- torchrun "${TORCHRUN_ARGS[@]}" "${TRAIN_ARGS[@]}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
unimodel/qwenkontext/fluxkontext_pipeline.py CHANGED
@@ -14,10 +14,9 @@
14
  # limitations under the License.
15
 
16
  import inspect
17
- from typing import Any, Callable, Dict, List, Optional, Union, Tuple
18
 
19
  import numpy as np
20
- import math
21
  import torch
22
  from transformers import (
23
  CLIPImageProcessor,
@@ -1160,566 +1159,3 @@ class FluxKontextPipeline(
1160
  return (image,)
1161
 
1162
  return FluxPipelineOutput(images=image)
1163
-
1164
-
1165
-
1166
-
1167
- # This method should be added to the FluxKontextPipeline class
1168
- def sde_sampling(
1169
- self,
1170
- image: Optional[PipelineImageInput] = None,
1171
- prompt: Union[str, List[str]] = None,
1172
- prompt_2: Optional[Union[str, List[str]]] = None,
1173
- negative_prompt: Union[str, List[str]] = None,
1174
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
1175
- true_cfg_scale: float = 1.0,
1176
- height: Optional[int] = None,
1177
- width: Optional[int] = None,
1178
- num_inference_steps: int = 28,
1179
- sigmas: Optional[List[float]] = None,
1180
- guidance_scale: float = 3.5,
1181
- num_images_per_prompt: Optional[int] = 1,
1182
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1183
- latents: Optional[torch.FloatTensor] = None,
1184
- prompt_embeds: Optional[torch.FloatTensor] = None,
1185
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1186
- ip_adapter_image: Optional[PipelineImageInput] = None,
1187
- ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1188
- negative_ip_adapter_image: Optional[PipelineImageInput] = None,
1189
- negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1190
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1191
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1192
- output_type: Optional[str] = "pil",
1193
- return_dict: bool = True,
1194
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1195
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1196
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1197
- max_sequence_length: int = 512,
1198
- max_area: int = 1024**2,
1199
- num_sde: int = None,
1200
- _auto_resize: bool = True,
1201
- ):
1202
- r"""
1203
- SDE sampling function for FLUX Kontext pipeline with log probability tracking.
1204
-
1205
- This method performs stochastic differential equation (SDE) based sampling while
1206
- tracking log probabilities at each step. Useful for training and analysis purposes.
1207
-
1208
- Args:
1209
- image: Input image for image-to-image generation
1210
- prompt: Text prompt(s) to guide generation
1211
- prompt_2: Secondary text prompt for text_encoder_2
1212
- negative_prompt: Negative text prompt(s)
1213
- negative_prompt_2: Secondary negative prompt
1214
- true_cfg_scale: Classifier-free guidance scale (when > 1.0)
1215
- height: Output height in pixels
1216
- width: Output width in pixels
1217
- num_inference_steps: Number of denoising steps
1218
- sigmas: Custom sigma schedule
1219
- guidance_scale: Embedded guidance scale
1220
- num_images_per_prompt: Number of images per prompt
1221
- generator: Random number generator(s)
1222
- latents: Pre-generated latents
1223
- prompt_embeds: Pre-generated prompt embeddings
1224
- pooled_prompt_embeds: Pre-generated pooled embeddings
1225
- ip_adapter_image: IP-Adapter input image(s)
1226
- ip_adapter_image_embeds: Pre-generated IP-Adapter embeddings
1227
- negative_ip_adapter_image: Negative IP-Adapter image(s)
1228
- negative_ip_adapter_image_embeds: Negative IP-Adapter embeddings
1229
- negative_prompt_embeds: Pre-generated negative embeddings
1230
- negative_pooled_prompt_embeds: Pre-generated negative pooled embeddings
1231
- output_type: Output format ("pil" or "latent")
1232
- return_dict: Whether to return dict or tuple
1233
- joint_attention_kwargs: Additional attention parameters
1234
- callback_on_step_end: Callback function after each step
1235
- callback_on_step_end_tensor_inputs: Tensors to pass to callback
1236
- max_sequence_length: Maximum prompt sequence length
1237
- max_area: Maximum output area in pixels
1238
- _auto_resize: Whether to auto-resize to preferred resolutions
1239
-
1240
- Returns:
1241
- Tuple of (images, prev_latents, log_probs, pred_latents, timesteps, batched_states)
1242
- """
1243
-
1244
- height = height or self.default_sample_size * self.vae_scale_factor
1245
- width = width or self.default_sample_size * self.vae_scale_factor
1246
-
1247
- original_height, original_width = height, width
1248
- aspect_ratio = width / height
1249
-
1250
- width = round((max_area * aspect_ratio) ** 0.5)
1251
- height = round((max_area / aspect_ratio) ** 0.5)
1252
-
1253
- multiple_of = self.vae_scale_factor * 2
1254
- width = width // multiple_of * multiple_of
1255
- height = height // multiple_of * multiple_of
1256
-
1257
- if height != original_height or width != original_width:
1258
- logger.warning(
1259
- f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
1260
- )
1261
-
1262
- # 1. Check inputs
1263
- self.check_inputs(
1264
- prompt,
1265
- prompt_2,
1266
- height,
1267
- width,
1268
- negative_prompt=negative_prompt,
1269
- negative_prompt_2=negative_prompt_2,
1270
- prompt_embeds=prompt_embeds,
1271
- negative_prompt_embeds=negative_prompt_embeds,
1272
- pooled_prompt_embeds=pooled_prompt_embeds,
1273
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1274
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1275
- max_sequence_length=max_sequence_length,
1276
- )
1277
-
1278
- self._guidance_scale = guidance_scale
1279
- self._joint_attention_kwargs = joint_attention_kwargs
1280
- self._current_timestep = None
1281
- self._interrupt = False
1282
-
1283
- # 2. Define call parameters
1284
- if prompt is not None and isinstance(prompt, str):
1285
- batch_size = 1
1286
- elif prompt is not None and isinstance(prompt, list):
1287
- batch_size = len(prompt)
1288
- else:
1289
- batch_size = prompt_embeds.shape[0]
1290
-
1291
- device = self._execution_device
1292
-
1293
- lora_scale = (
1294
- self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1295
- )
1296
- has_neg_prompt = negative_prompt is not None or (
1297
- negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
1298
- )
1299
- do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
1300
-
1301
- # Encode prompts
1302
- (
1303
- prompt_embeds,
1304
- pooled_prompt_embeds,
1305
- text_ids,
1306
- ) = self.encode_prompt(
1307
- prompt=prompt,
1308
- prompt_2=prompt_2,
1309
- prompt_embeds=prompt_embeds,
1310
- pooled_prompt_embeds=pooled_prompt_embeds,
1311
- device=device,
1312
- num_images_per_prompt=num_images_per_prompt,
1313
- max_sequence_length=max_sequence_length,
1314
- lora_scale=lora_scale,
1315
- )
1316
-
1317
- if do_true_cfg:
1318
- (
1319
- negative_prompt_embeds,
1320
- negative_pooled_prompt_embeds,
1321
- negative_text_ids,
1322
- ) = self.encode_prompt(
1323
- prompt=negative_prompt,
1324
- prompt_2=negative_prompt_2,
1325
- prompt_embeds=negative_prompt_embeds,
1326
- pooled_prompt_embeds=negative_pooled_prompt_embeds,
1327
- device=device,
1328
- num_images_per_prompt=num_images_per_prompt,
1329
- max_sequence_length=max_sequence_length,
1330
- lora_scale=lora_scale,
1331
- )
1332
-
1333
- # 3. Preprocess image
1334
- if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
1335
- from diffusers.pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS
1336
-
1337
- img = image[0] if isinstance(image, list) else image
1338
- image_height, image_width = self.image_processor.get_default_height_width(img)
1339
- aspect_ratio = image_width / image_height
1340
- if _auto_resize:
1341
- _, image_width, image_height = min(
1342
- (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
1343
- )
1344
- image_width = image_width // multiple_of * multiple_of
1345
- image_height = image_height // multiple_of * multiple_of
1346
- image = self.image_processor.resize(image, image_height, image_width)
1347
- image = self.image_processor.preprocess(image, image_height, image_width)
1348
-
1349
- # 4. Prepare latent variables
1350
- num_channels_latents = self.transformer.config.in_channels // 4
1351
- latents, image_latents, latent_ids, image_ids = self.prepare_latents(
1352
- image,
1353
- batch_size * num_images_per_prompt,
1354
- num_channels_latents,
1355
- height,
1356
- width,
1357
- prompt_embeds.dtype,
1358
- device,
1359
- generator,
1360
- latents,
1361
- )
1362
-
1363
- if image_ids is not None:
1364
- latent_ids = torch.cat([latent_ids, image_ids], dim=0)
1365
-
1366
- # 5. Prepare timesteps
1367
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
1368
- image_seq_len = latents.shape[1]
1369
-
1370
- from diffusers.pipelines.flux.pipeline_flux_kontext import calculate_shift, retrieve_timesteps
1371
-
1372
- mu = calculate_shift(
1373
- image_seq_len,
1374
- self.scheduler.config.get("base_image_seq_len", 256),
1375
- self.scheduler.config.get("max_image_seq_len", 4096),
1376
- self.scheduler.config.get("base_shift", 0.5),
1377
- self.scheduler.config.get("max_shift", 1.15),
1378
- )
1379
- timesteps, num_inference_steps = retrieve_timesteps(
1380
- self.scheduler,
1381
- num_inference_steps,
1382
- device,
1383
- sigmas=sigmas,
1384
- mu=mu,
1385
- )
1386
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1387
- self._num_timesteps = len(timesteps)
1388
-
1389
- # Handle guidance
1390
- if self.transformer.config.guidance_embeds:
1391
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1392
- guidance = guidance.expand(latents.shape[0])
1393
- else:
1394
- guidance = None
1395
-
1396
- # Handle IP-Adapter images
1397
- if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
1398
- negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
1399
- ):
1400
- negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1401
- negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1402
-
1403
- elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
1404
- negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
1405
- ):
1406
- ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1407
- ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1408
-
1409
- if self.joint_attention_kwargs is None:
1410
- self._joint_attention_kwargs = {}
1411
-
1412
- image_embeds = None
1413
- negative_image_embeds = None
1414
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1415
- image_embeds = self.prepare_ip_adapter_image_embeds(
1416
- ip_adapter_image,
1417
- ip_adapter_image_embeds,
1418
- device,
1419
- batch_size * num_images_per_prompt,
1420
- )
1421
- if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
1422
- negative_image_embeds = self.prepare_ip_adapter_image_embeds(
1423
- negative_ip_adapter_image,
1424
- negative_ip_adapter_image_embeds,
1425
- device,
1426
- batch_size * num_images_per_prompt,
1427
- )
1428
-
1429
- # 6. SDE Denoising loop with state tracking
1430
- prev_latents = []
1431
- pred_latents = []
1432
- states = {
1433
- "timestep": [],
1434
- "guidance": [],
1435
- "pooled_projections": [],
1436
- "encoder_hidden_states": [],
1437
- "txt_ids": None,
1438
- "img_ids": None,
1439
- }
1440
- log_probs = []
1441
- ts = []
1442
-
1443
- states["txt_ids"] = text_ids if text_ids is not None else None
1444
- states["img_ids"] = latent_ids if latent_ids is not None else None
1445
-
1446
- if num_sde is None:
1447
- num_sde = num_inference_steps
1448
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1449
- for i, t in enumerate(timesteps):
1450
- if self.interrupt:
1451
- continue
1452
-
1453
- self._current_timestep = t
1454
- if image_embeds is not None:
1455
- self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
1456
-
1457
- # Prepare model input
1458
- latent_model_input = latents
1459
- if image_latents is not None:
1460
- latent_model_input = torch.cat([latents, image_latents], dim=1)
1461
-
1462
- timestep = (t.expand(latents.shape[0]) / 1000.).to(latents.dtype)
1463
-
1464
-
1465
- if i < num_sde:
1466
- # Store states
1467
- states["timestep"].append(timestep.unsqueeze(1))
1468
- states["guidance"].append(guidance.unsqueeze(1) if torch.is_tensor(guidance) else guidance)
1469
- states["pooled_projections"].append(pooled_prompt_embeds.unsqueeze(1) if pooled_prompt_embeds is not None else None)
1470
- states["encoder_hidden_states"].append(prompt_embeds.unsqueeze(1) if prompt_embeds is not None else None)
1471
-
1472
- ts.append(t.expand(latents.shape[0]).unsqueeze(1))
1473
- # prev_latents.append(latents.detach().clone().unsqueeze(1))
1474
- prev_latents.append(latent_model_input.detach().clone().unsqueeze(1))
1475
-
1476
- # Forward pass
1477
- noise_pred = self.transformer(
1478
- hidden_states=latent_model_input,
1479
- timestep=timestep,
1480
- guidance=guidance,
1481
- pooled_projections=pooled_prompt_embeds,
1482
- encoder_hidden_states=prompt_embeds,
1483
- txt_ids=text_ids,
1484
- img_ids=latent_ids,
1485
- joint_attention_kwargs=self.joint_attention_kwargs,
1486
- return_dict=False,
1487
- )[0]
1488
- noise_pred = noise_pred[:, :latents.size(1)]
1489
-
1490
- # Apply true CFG if needed
1491
- if do_true_cfg:
1492
- if negative_image_embeds is not None:
1493
- self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1494
-
1495
- neg_latent_model_input = latents
1496
- if image_latents is not None:
1497
- neg_latent_model_input = torch.cat([latents, image_latents], dim=1)
1498
-
1499
- neg_noise_pred = self.transformer(
1500
- hidden_states=neg_latent_model_input,
1501
- timestep=timestep,
1502
- guidance=guidance,
1503
- pooled_projections=negative_pooled_prompt_embeds,
1504
- encoder_hidden_states=negative_prompt_embeds,
1505
- txt_ids=negative_text_ids,
1506
- img_ids=latent_ids,
1507
- joint_attention_kwargs=self.joint_attention_kwargs,
1508
- return_dict=False,
1509
- )[0]
1510
- neg_noise_pred = neg_noise_pred[:, :latents.size(1)]
1511
- noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1512
-
1513
- if i < num_sde:
1514
- # SDE step with log probability
1515
- latents_dtype = latents.dtype
1516
- latents, log_prob, prev_latents_mean, std_dev = sde_step_with_logprob(
1517
- self.scheduler,
1518
- noise_pred.float(),
1519
- t.expand(latents.shape[0]),
1520
- latents.float()
1521
- )
1522
-
1523
- log_probs.append(log_prob.detach().clone().unsqueeze(1))
1524
- pred_latents.append(latents.detach().clone().unsqueeze(1))
1525
-
1526
- else:
1527
- # Standard scheduler step
1528
- latents_dtype = latents.dtype
1529
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1530
-
1531
-
1532
- if latents.dtype != latents_dtype:
1533
- latents = latents.to(latents_dtype)
1534
-
1535
- if callback_on_step_end is not None:
1536
- callback_kwargs = {}
1537
- for k in callback_on_step_end_tensor_inputs:
1538
- callback_kwargs[k] = locals()[k]
1539
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1540
- latents = callback_outputs.pop("latents", latents)
1541
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1542
-
1543
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1544
- progress_bar.update()
1545
-
1546
- if XLA_AVAILABLE:
1547
- xm.mark_step()
1548
-
1549
- self._current_timestep = None
1550
-
1551
- # Decode latents to images
1552
- if output_type == "latent":
1553
- image = latents
1554
- else:
1555
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1556
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1557
- image = self.vae.decode(latents, return_dict=False)[0]
1558
- image = self.image_processor.postprocess(image, output_type=output_type)
1559
-
1560
- # Batch states for output
1561
- batched_states = {}
1562
- for key, value_list in states.items():
1563
- if value_list is None or len(value_list) == 0:
1564
- batched_states[key] = None
1565
- continue
1566
- if isinstance(value_list, list) and value_list[0] is None:
1567
- batched_states[key] = None
1568
- continue
1569
- if isinstance(value_list, list):
1570
- concatenated = torch.cat(value_list, dim=1)
1571
- if len(concatenated.shape) <= 2:
1572
- batched_states[key] = concatenated.view(-1)
1573
- else:
1574
- batched_states[key] = concatenated.view(-1, *concatenated.shape[2:])
1575
- else:
1576
- batched_states[key] = value_list
1577
-
1578
- # Reshape outputs
1579
- prev_latents = torch.cat(prev_latents, dim=1)
1580
- log_probs = torch.cat(log_probs, dim=1)
1581
- pred_latents = torch.cat(pred_latents, dim=1)
1582
- ts = torch.cat(ts, dim=1)
1583
-
1584
- prev_latents = prev_latents.view(prev_latents.shape[0] * prev_latents.shape[1], *prev_latents.shape[2:])
1585
- log_probs = log_probs.view(log_probs.shape[0] * log_probs.shape[1], *log_probs.shape[2:])
1586
- pred_latents = pred_latents.view(pred_latents.shape[0] * pred_latents.shape[1], *pred_latents.shape[2:])
1587
- ts = ts.view(-1)
1588
-
1589
- # Offload models
1590
- self.maybe_free_model_hooks()
1591
-
1592
- return (image, prev_latents, log_probs, pred_latents, ts, batched_states)
1593
-
1594
-
1595
- def sde_step_with_logprob(
1596
- scheduler: FlowMatchEulerDiscreteScheduler,
1597
- model_output: torch.FloatTensor,
1598
- timestep: Union[float, torch.FloatTensor],
1599
- sample: torch.FloatTensor,
1600
- prev_sample: Optional[torch.FloatTensor] = None,
1601
- generator: Optional[torch.Generator] = None,
1602
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1603
- """
1604
- Predict the sample from the previous timestep by reversing the SDE with log probability tracking.
1605
-
1606
- Args:
1607
- scheduler: The FlowMatchEulerDiscreteScheduler instance
1608
- model_output: The direct output from learned flow model
1609
- timestep: The current discrete timestep in the diffusion chain
1610
- sample: A current instance of a sample created by the diffusion process
1611
- prev_sample: Optional pre-computed previous sample
1612
- generator: A random number generator
1613
-
1614
- Returns:
1615
- Tuple of (prev_sample, log_prob, prev_sample_mean, std_dev)
1616
- """
1617
- step_index = [scheduler.index_for_timestep(t) for t in timestep]
1618
- prev_step_index = [step + 1 for step in step_index]
1619
- sigma = scheduler.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
1620
- sigma_prev = scheduler.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
1621
- sigma_max = scheduler.sigmas[1].item()
1622
- dt = sigma_prev - sigma
1623
-
1624
- std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * 0.8
1625
-
1626
- # SDE formulation
1627
- prev_sample_mean = (
1628
- sample * (1 + std_dev_t**2 / (2 * sigma) * dt) +
1629
- model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
1630
- )
1631
-
1632
- if prev_sample is not None and generator is not None:
1633
- raise ValueError(
1634
- "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
1635
- " `prev_sample` stays `None`."
1636
- )
1637
-
1638
- if prev_sample is None:
1639
- variance_noise = randn_tensor(
1640
- model_output.shape,
1641
- generator=generator,
1642
- device=model_output.device,
1643
- dtype=model_output.dtype,
1644
- )
1645
- prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise
1646
-
1647
- # Calculate log probability
1648
- variance = (std_dev_t * torch.sqrt(-1 * dt)) ** 2
1649
- log_prob = (
1650
- -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * variance)
1651
- - torch.log(torch.sqrt(variance))
1652
- - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
1653
- )
1654
-
1655
- # Mean along all but batch dimension
1656
- log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
1657
-
1658
- return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1 * dt)
1659
-
1660
-
1661
- def sde_step_with_logprob_simple(
1662
- scheduler: FlowMatchEulerDiscreteScheduler,
1663
- model_output: torch.FloatTensor,
1664
- timestep: Union[float, torch.FloatTensor],
1665
- sample: torch.FloatTensor,
1666
- prev_sample: Optional[torch.FloatTensor] = None,
1667
- generator: Optional[torch.Generator] = None,
1668
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1669
- """
1670
- Simplified SDE step with log probability tracking using eta parameter.
1671
-
1672
- Args:
1673
- scheduler: The FlowMatchEulerDiscreteScheduler instance
1674
- model_output: The direct output from learned flow model
1675
- timestep: The current discrete timestep in the diffusion chain
1676
- sample: A current instance of a sample created by the diffusion process
1677
- prev_sample: Optional pre-computed previous sample
1678
- generator: A random number generator
1679
-
1680
- Returns:
1681
- Tuple of (prev_sample, log_prob, prev_sample_mean, std_dev)
1682
- """
1683
- step_index = [scheduler.index_for_timestep(t) for t in timestep]
1684
- prev_step_index = [step + 1 for step in step_index]
1685
- sigma = scheduler.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
1686
- sigma_prev = scheduler.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
1687
- sigma_max = scheduler.sigmas[1].item()
1688
- dt = sigma_prev - sigma
1689
-
1690
- eta = 0.5
1691
- Dt = -dt * eta
1692
-
1693
- prev_sample_mean = (
1694
- sample * (1 - Dt / (1 - torch.where(sigma == 1, sigma_max, sigma))) +
1695
- model_output * (dt - Dt)
1696
- )
1697
-
1698
- std_dev_t = torch.sqrt(2 * Dt * (sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))))
1699
-
1700
- if prev_sample is not None and generator is not None:
1701
- raise ValueError(
1702
- "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
1703
- " `prev_sample` stays `None`."
1704
- )
1705
-
1706
- if prev_sample is None:
1707
- variance_noise = randn_tensor(
1708
- model_output.shape,
1709
- generator=generator,
1710
- device=model_output.device,
1711
- dtype=model_output.dtype,
1712
- )
1713
- prev_sample = prev_sample_mean + std_dev_t * variance_noise
1714
-
1715
- # Calculate log probability
1716
- log_prob = (
1717
- -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
1718
- - torch.log(std_dev_t)
1719
- - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
1720
- )
1721
-
1722
- # Mean along all but batch dimension
1723
- log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
1724
-
1725
- return prev_sample, log_prob, prev_sample_mean, std_dev_t
 
14
  # limitations under the License.
15
 
16
  import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
 
19
  import numpy as np
 
20
  import torch
21
  from transformers import (
22
  CLIPImageProcessor,
 
1159
  return (image,)
1160
 
1161
  return FluxPipelineOutput(images=image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
unirl/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- """PromptRL training package."""
2
-
 
 
 
unirl/reward_evaluator/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- from .reward_evaluator import RewardEvaluatorClient
2
-
3
- __all__ = ["RewardEvaluatorClient"]
4
-
 
 
 
 
 
unirl/reward_evaluator/reward_evaluator.py DELETED
@@ -1,71 +0,0 @@
1
- import pickle
2
- from io import BytesIO
3
- from typing import Any, Dict, List, Mapping, Optional, Union
4
-
5
- import requests
6
- from PIL import Image
7
-
8
-
9
- DEFAULT_EDITREWARD_URL = "http://127.0.0.1:18088/"
10
-
11
-
12
- def _serialize_image(image: Image.Image) -> bytes:
13
- buffer = BytesIO()
14
- if image.mode != "RGB":
15
- image = image.convert("RGB")
16
- image.save(buffer, format="PNG")
17
- return buffer.getvalue()
18
-
19
-
20
- def _serialize_images(
21
- images: Union[List[Image.Image], Mapping[str, List[Image.Image]]],
22
- ) -> Union[List[bytes], Dict[str, List[bytes]]]:
23
- if isinstance(images, Mapping):
24
- return {key: [_serialize_image(image) for image in value] for key, value in images.items()}
25
- return [_serialize_image(image) for image in images]
26
-
27
-
28
- def _create_payload(
29
- images: Union[List[Image.Image], Mapping[str, List[Image.Image]]],
30
- prompts: List[str],
31
- metadata: Optional[Dict[str, Any]] = None,
32
- ) -> bytes:
33
- return pickle.dumps(
34
- {
35
- "images": _serialize_images(images),
36
- "prompts": prompts,
37
- "metadata": metadata or {},
38
- }
39
- )
40
-
41
-
42
- class RewardEvaluatorClient:
43
- """HTTP client for the EditReward scorer service."""
44
-
45
- def __init__(self, editreward_url: str = DEFAULT_EDITREWARD_URL, timeout: int = 600):
46
- self.editreward_url = editreward_url
47
- self.timeout = timeout
48
-
49
- def evaluate_editreward(
50
- self,
51
- source_images: List[Image.Image],
52
- edited_images: List[Image.Image],
53
- prompts: List[str],
54
- ) -> Dict[str, Any]:
55
- if not (len(source_images) == len(edited_images) == len(prompts)):
56
- raise ValueError(
57
- "EditReward inputs must have equal lengths: "
58
- f"{len(source_images)} source images, {len(edited_images)} edited images, {len(prompts)} prompts."
59
- )
60
-
61
- payload = _create_payload(
62
- {"source": source_images, "edited": edited_images},
63
- prompts,
64
- )
65
- response = requests.post(self.editreward_url, data=payload, timeout=self.timeout)
66
- response.raise_for_status()
67
- result = pickle.loads(response.content)
68
- if isinstance(result, dict) and "error" in result:
69
- raise RuntimeError(f"EditReward service returned an error: {result['error']}")
70
- return result
71
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
unirl/train_edit.py DELETED
@@ -1,265 +0,0 @@
1
- import json
2
- import os
3
- import re
4
- from dataclasses import dataclass, field
5
- from io import BytesIO
6
- from typing import Any, Dict, List, Optional
7
-
8
- from datasets import load_dataset
9
- from PIL import Image, ImageOps
10
- from torch.utils.data import Dataset
11
- from transformers.trainer_utils import get_last_checkpoint
12
- from trl import GRPOConfig, ModelConfig, ScriptArguments, TrlParser, get_peft_config
13
-
14
- from .reward_evaluator import RewardEvaluatorClient
15
- from .trainer import QwenKontextEditGRPOTrainer
16
-
17
-
18
- DEFAULT_EDIT_DATASET = "https://huggingface.co/wangfuyun/PrompRL/resolve/main/data/omni_edit_train_50k.parquet"
19
-
20
- EDIT_QUESTION_TEMPLATE = """Please provide an enhanced prompt for the following image editing prompt.
21
- Ensure the revised prompt is clear, specific, and includes detailed instructions to achieve the desired outcome while maintaining the original intent.
22
- Original prompt: {Question}. Directly provide the improved prompt in <answer> </answer> tags."""
23
-
24
-
25
- @dataclass
26
- class EditGRPOScriptArguments(ScriptArguments):
27
- reward_funcs: List[str] = field(
28
- default_factory=lambda: ["editreward", "format"],
29
- metadata={"help": "Reward functions to use. Edit training supports only: editreward, format."},
30
- )
31
- prompts_file: str = field(
32
- default=DEFAULT_EDIT_DATASET,
33
- metadata={"help": "Path or URL to a .parquet or .jsonl edit-training dataset."},
34
- )
35
- image_column: str = field(default="image", metadata={"help": "Dataset column containing the source image."})
36
- prompt_column: str = field(default="prompt", metadata={"help": "Dataset column containing the edit instruction."})
37
- caption_column: Optional[str] = field(default="caption", metadata={"help": "Optional source-caption column."})
38
- target_caption_column: Optional[str] = field(
39
- default="target_caption",
40
- metadata={"help": "Optional target-caption column."},
41
- )
42
- image_size: int = field(default=512, metadata={"help": "Center-cropped source image size used for editing."})
43
- dataset_cache_dir: Optional[str] = field(
44
- default=None,
45
- metadata={"help": "Optional Hugging Face datasets cache dir. Defaults to HF_DATASETS_CACHE."},
46
- )
47
- editreward_url: str = field(
48
- default="http://127.0.0.1:18088/",
49
- metadata={"help": "HTTP URL of the EditReward scorer service."},
50
- )
51
- processor_name_or_path: str = field(
52
- default="Qwen/Qwen2.5-VL-3B-Instruct",
53
- metadata={"help": "Processor used for Qwen2.5-VL chat formatting and image preprocessing."},
54
- )
55
- max_pixels: int = field(default=200704, metadata={"help": "Maximum pixels passed to the Qwen-VL processor."})
56
- min_pixels: int = field(default=200704, metadata={"help": "Minimum pixels passed to the Qwen-VL processor."})
57
- num_skip_refinement: int = field(
58
- default=2,
59
- metadata={"help": "Generations per input that use the original edit prompt instead of a Qwen-refined prompt."},
60
- )
61
- num_sde: int = field(
62
- default=4,
63
- metadata={"help": "Number of FLUX denoising steps sampled with SDE log-prob tracking for diffusion GRPO."},
64
- )
65
-
66
-
67
- class EditPromptDataset(Dataset):
68
- """Loads image-edit instructions from parquet or jsonl files."""
69
-
70
- def __init__(
71
- self,
72
- prompts_file: str,
73
- question_template: str,
74
- image_column: str = "image",
75
- prompt_column: str = "prompt",
76
- caption_column: Optional[str] = "caption",
77
- target_caption_column: Optional[str] = "target_caption",
78
- image_size: int = 512,
79
- cache_dir: Optional[str] = None,
80
- ):
81
- self.prompts_file = normalize_data_path(prompts_file)
82
- self.question_template = question_template
83
- self.image_column = image_column
84
- self.prompt_column = prompt_column
85
- self.caption_column = caption_column
86
- self.target_caption_column = target_caption_column
87
- self.image_size = image_size
88
- self.base_dir = (
89
- os.path.dirname(os.path.abspath(self.prompts_file))
90
- if not is_remote_path(self.prompts_file)
91
- else os.getcwd()
92
- )
93
-
94
- if self.prompts_file.endswith(".parquet"):
95
- self.records = load_dataset(
96
- "parquet",
97
- data_files={"train": self.prompts_file},
98
- split="train",
99
- cache_dir=cache_dir or os.getenv("HF_DATASETS_CACHE"),
100
- )
101
- elif self.prompts_file.endswith(".jsonl") or self.prompts_file.endswith(".json"):
102
- if is_remote_path(self.prompts_file):
103
- self.records = load_dataset(
104
- "json",
105
- data_files={"train": self.prompts_file},
106
- split="train",
107
- cache_dir=cache_dir or os.getenv("HF_DATASETS_CACHE"),
108
- )
109
- else:
110
- with open(self.prompts_file, "r", encoding="utf-8") as file:
111
- self.records = [json.loads(line) for line in file if line.strip()]
112
- else:
113
- raise ValueError("Edit training datasets must be .parquet or .jsonl files.")
114
-
115
- if len(self.records) == 0:
116
- raise ValueError(f"No training records found in {prompts_file}.")
117
-
118
- def __len__(self) -> int:
119
- return len(self.records)
120
-
121
- def __getitem__(self, index: int) -> Dict[str, Any]:
122
- item = self.records[index]
123
- instruction = self._read_text(item, self.prompt_column)
124
- image = self._read_image(item, self.image_column)
125
- formatted_prompt = self.question_template.format(Question=instruction)
126
-
127
- return {
128
- "image": image,
129
- "caption": self._read_optional_text(item, self.caption_column),
130
- "target_caption": self._read_optional_text(item, self.target_caption_column),
131
- "editing_instruction": instruction,
132
- "prompt": [
133
- {
134
- "role": "user",
135
- "content": [
136
- {"type": "image"},
137
- {"type": "text", "text": formatted_prompt},
138
- ],
139
- }
140
- ],
141
- }
142
-
143
- def _read_text(self, item: Dict[str, Any], column: str) -> str:
144
- if column not in item:
145
- raise KeyError(f"Missing required column '{column}' in {self.prompts_file}.")
146
- value = item[column]
147
- if isinstance(value, list):
148
- value = value[-1] if value else ""
149
- if value is None or not str(value).strip():
150
- raise ValueError(f"Empty edit instruction in column '{column}'.")
151
- return str(value).strip()
152
-
153
- def _read_optional_text(self, item: Dict[str, Any], column: Optional[str]) -> str:
154
- if not column or column not in item or item[column] is None:
155
- return ""
156
- value = item[column]
157
- if isinstance(value, list):
158
- value = value[-1] if value else ""
159
- return str(value).strip()
160
-
161
- def _read_image(self, item: Dict[str, Any], column: str) -> Image.Image:
162
- if column not in item:
163
- raise KeyError(f"Missing required image column '{column}' in {self.prompts_file}.")
164
- image = self._coerce_image(item[column])
165
- if self.image_size > 0:
166
- image = ImageOps.fit(image, (self.image_size, self.image_size), method=Image.Resampling.BICUBIC)
167
- return image.convert("RGB")
168
-
169
- def _coerce_image(self, value: Any) -> Image.Image:
170
- if isinstance(value, Image.Image):
171
- return value.convert("RGB")
172
- if isinstance(value, str):
173
- image_path = value if os.path.isabs(value) else os.path.join(self.base_dir, value)
174
- return Image.open(image_path).convert("RGB")
175
- if isinstance(value, bytes):
176
- return Image.open(BytesIO(value)).convert("RGB")
177
- if isinstance(value, dict):
178
- if value.get("bytes") is not None:
179
- return Image.open(BytesIO(value["bytes"])).convert("RGB")
180
- if value.get("path") is not None:
181
- image_path = value["path"] if os.path.isabs(value["path"]) else os.path.join(self.base_dir, value["path"])
182
- return Image.open(image_path).convert("RGB")
183
- raise TypeError(f"Unsupported image value type: {type(value)!r}")
184
-
185
-
186
- def is_remote_path(path: str) -> bool:
187
- return path.startswith(("http://", "https://", "hf://"))
188
-
189
-
190
- def normalize_data_path(path: str) -> str:
191
- if path.startswith("hf://"):
192
- parts = path[len("hf://") :].split("/", 2)
193
- if len(parts) != 3:
194
- raise ValueError("hf:// dataset paths must look like hf://owner/repo/path/to/file.parquet")
195
- repo_id = f"{parts[0]}/{parts[1]}"
196
- file_path = parts[2]
197
- return f"https://huggingface.co/{repo_id}/resolve/main/{file_path}"
198
-
199
- if "huggingface.co/" in path and "/blob/" in path:
200
- return path.replace("/blob/", "/resolve/", 1)
201
- return path
202
-
203
-
204
- def format_reward(completions: List[str]) -> List[float]:
205
- pattern = re.compile(r"<answer>.*?</answer>", re.DOTALL)
206
- return [1.0 if pattern.search(completion) else 0.0 for completion in completions]
207
-
208
-
209
- def build_editreward_func(editreward_url: str):
210
- reward_client = RewardEvaluatorClient(editreward_url=editreward_url)
211
-
212
- def editreward(source_images, edited_images, prompts):
213
- return reward_client.evaluate_editreward(source_images, edited_images, prompts)
214
-
215
- return editreward
216
-
217
-
218
- def main(script_args: EditGRPOScriptArguments, training_args: GRPOConfig, model_args: ModelConfig) -> None:
219
- supported_rewards = {"editreward", "format"}
220
- unsupported_rewards = sorted(set(script_args.reward_funcs) - supported_rewards)
221
- if unsupported_rewards:
222
- raise ValueError(f"Edit training supports only {sorted(supported_rewards)}, got {unsupported_rewards}.")
223
-
224
- reward_registry = {
225
- "editreward": build_editreward_func(script_args.editreward_url),
226
- "format": format_reward,
227
- }
228
- reward_funcs = [(name, None, reward_registry[name]) for name in script_args.reward_funcs]
229
-
230
- train_dataset = EditPromptDataset(
231
- prompts_file=script_args.prompts_file,
232
- question_template=EDIT_QUESTION_TEMPLATE,
233
- image_column=script_args.image_column,
234
- prompt_column=script_args.prompt_column,
235
- caption_column=script_args.caption_column,
236
- target_caption_column=script_args.target_caption_column,
237
- image_size=script_args.image_size,
238
- cache_dir=script_args.dataset_cache_dir,
239
- )
240
-
241
- trainer = QwenKontextEditGRPOTrainer(
242
- model=model_args.model_name_or_path,
243
- reward_funcs=reward_funcs,
244
- args=training_args,
245
- train_dataset=train_dataset,
246
- peft_config=get_peft_config(model_args),
247
- max_pixels=script_args.max_pixels,
248
- min_pixels=script_args.min_pixels,
249
- processor_name_or_path=script_args.processor_name_or_path,
250
- attn_implementation=model_args.attn_implementation,
251
- num_skip_refinement=script_args.num_skip_refinement,
252
- num_sde=script_args.num_sde,
253
- )
254
-
255
- checkpoint = get_last_checkpoint(training_args.output_dir) if os.path.isdir(training_args.output_dir) else None
256
- trainer.train(resume_from_checkpoint=checkpoint)
257
- trainer.save_model(training_args.output_dir)
258
- if training_args.push_to_hub:
259
- trainer.push_to_hub(dataset_name=script_args.dataset_name)
260
-
261
-
262
- if __name__ == "__main__":
263
- parser = TrlParser((EditGRPOScriptArguments, GRPOConfig, ModelConfig))
264
- parsed_script_args, parsed_training_args, parsed_model_args = parser.parse_args_and_config()
265
- main(parsed_script_args, parsed_training_args, parsed_model_args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
unirl/trainer/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- from .edit_grpo_trainer import QwenKontextEditGRPOTrainer
2
-
3
- __all__ = ["QwenKontextEditGRPOTrainer"]
4
-
 
 
 
 
 
unirl/trainer/edit_grpo_trainer.py DELETED
@@ -1,623 +0,0 @@
1
- import os
2
- from collections import defaultdict
3
- from datetime import datetime
4
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
-
6
- import torch
7
- import torch.nn as nn
8
- import transformers
9
- from accelerate.utils import DistributedType
10
- from datasets import Dataset, IterableDataset
11
- from packaging import version
12
- from PIL import Image
13
- from transformers import AutoProcessor, GenerationConfig, PreTrainedModel, Trainer, TrainerCallback
14
- from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
15
- from transformers.utils import is_peft_available
16
- from trl.data_utils import maybe_apply_chat_template
17
- from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
18
- from trl.trainer.grpo_config import GRPOConfig
19
-
20
- from unimodel.qwenkontext.fluxkontext_pipeline import sde_step_with_logprob
21
- from unimodel.qwenkontext.qwenkontext_inference import QwenKontextForInferenceLM
22
-
23
- if is_peft_available():
24
- from peft import PeftConfig, get_peft_model
25
-
26
-
27
- RewardFunc = Callable[..., Union[List[float], Dict[str, Any]]]
28
-
29
-
30
- def compute_log_prob(
31
- model_pred: torch.Tensor,
32
- scheduler,
33
- prev_latents: torch.Tensor,
34
- pred_latents: torch.Tensor,
35
- timesteps: torch.Tensor,
36
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
37
- return sde_step_with_logprob(
38
- scheduler,
39
- model_pred.float(),
40
- timesteps,
41
- prev_latents.float(),
42
- pred_latents.float(),
43
- )
44
-
45
-
46
- class QwenKontextEditGRPOTrainer(Trainer):
47
- """Joint GRPO trainer for Qwen prompt refinement and FLUX.1-Kontext edit generation."""
48
-
49
- def __init__(
50
- self,
51
- model: Union[str, PreTrainedModel],
52
- reward_funcs: List[Tuple[str, Optional[Any], RewardFunc]],
53
- args: Optional[GRPOConfig] = None,
54
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
55
- eval_dataset: Optional[Union[Dataset, IterableDataset, Dict[str, Union[Dataset, IterableDataset]]]] = None,
56
- processing_class: Optional[Any] = None,
57
- callbacks: Optional[List[TrainerCallback]] = None,
58
- optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
59
- peft_config: Optional["PeftConfig"] = None,
60
- max_pixels: int = 200704,
61
- min_pixels: int = 200704,
62
- processor_name_or_path: str = "Qwen/Qwen2.5-VL-3B-Instruct",
63
- attn_implementation: str = "flash_attention_2",
64
- num_skip_refinement: int = 2,
65
- num_sde: int = 4,
66
- ):
67
- if args is None:
68
- model_name = model if isinstance(model, str) else model.config._name_or_path
69
- args = GRPOConfig(f"{os.path.basename(model_name)}-edit-joint-grpo")
70
-
71
- model_init_kwargs = args.model_init_kwargs or {}
72
- model_init_kwargs["attn_implementation"] = attn_implementation
73
- model_init_kwargs["use_cache"] = False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
74
-
75
- if isinstance(model, str):
76
- self.model_id = model
77
- model = self._load_model(model, model_init_kwargs)
78
- else:
79
- self.model_id = model.config._name_or_path
80
- if args.model_init_kwargs is not None:
81
- raise ValueError("model_init_kwargs can only be used when model is a path.")
82
-
83
- if peft_config is not None:
84
- model = get_peft_model(model, peft_config)
85
-
86
- self._configure_trainable_parameters(model)
87
- self.ref_model = self._create_reference_model(model, model_init_kwargs)
88
- self.scheduler = model.get_model().diffusion_expert.scheduler
89
-
90
- if processing_class is None:
91
- processing_class = self._create_processor(processor_name_or_path, max_pixels, min_pixels)
92
- self.processing_class = processing_class
93
- self.reward_funcs = reward_funcs
94
- self.max_prompt_length = args.max_prompt_length
95
- self.num_generations = args.num_generations
96
- self.beta = args.beta
97
- self.num_sde = num_sde
98
-
99
- if not 0 <= num_skip_refinement < self.num_generations:
100
- raise ValueError(
101
- f"num_skip_refinement must be in [0, num_generations), got {num_skip_refinement} "
102
- f"for num_generations={self.num_generations}."
103
- )
104
- self.num_skip_refinement = num_skip_refinement
105
- self.num_refined = self.num_generations - num_skip_refinement
106
-
107
- self.generation_config = GenerationConfig(
108
- max_new_tokens=args.max_completion_length or 256,
109
- do_sample=True,
110
- temperature=1.0,
111
- num_return_sequences=1,
112
- pad_token_id=processing_class.pad_token_id,
113
- eos_token_id=processing_class.eos_token_id,
114
- )
115
- model.generation_config = self.generation_config
116
- self.ref_model.generation_config = self.generation_config
117
- if hasattr(model, "warnings_issued"):
118
- model.warnings_issued["estimate_tokens"] = True
119
-
120
- self._metrics = defaultdict(list)
121
-
122
- def data_collator(features):
123
- return features
124
-
125
- super().__init__(
126
- model=model,
127
- args=args,
128
- data_collator=data_collator,
129
- train_dataset=train_dataset,
130
- eval_dataset=eval_dataset,
131
- processing_class=processing_class,
132
- callbacks=callbacks,
133
- optimizers=optimizers,
134
- )
135
- self.model_accepts_loss_kwargs = False
136
-
137
- if self.is_deepspeed_enabled and is_deepspeed_zero3_enabled():
138
- self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
139
- else:
140
- self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
141
-
142
- self.diffusion_generation_config = self._get_diffusion_config()
143
- self.start_time = datetime.now().strftime("%Y-%m-%d_%H-%M")
144
- self.log_dir = os.path.join(args.output_dir, "training_samples", self.start_time)
145
- os.makedirs(self.log_dir, exist_ok=True)
146
-
147
- def _load_model(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> PreTrainedModel:
148
- torch_dtype = model_init_kwargs.get("torch_dtype")
149
- if isinstance(torch_dtype, str) and torch_dtype != "auto":
150
- model_init_kwargs["torch_dtype"] = getattr(torch, torch_dtype)
151
- if "qwenkontext" not in model_id.lower():
152
- raise ValueError("Edit joint training expects a Qwen-Kontext checkpoint path.")
153
- return QwenKontextForInferenceLM.from_pretrained(model_id, **model_init_kwargs)
154
-
155
- def _create_reference_model(self, model: PreTrainedModel, model_init_kwargs: Dict[str, Any]) -> PreTrainedModel:
156
- if is_deepspeed_zero3_enabled():
157
- ref_model = self._load_model(self.model_id, model_init_kwargs)
158
- else:
159
- ref_model = create_reference_model(model)
160
- for parameter in ref_model.parameters():
161
- parameter.requires_grad = False
162
- return ref_model
163
-
164
- def _configure_trainable_parameters(self, model: PreTrainedModel) -> None:
165
- try:
166
- model.get_model().diffusion_expert.enable_vae_slicing()
167
- except AttributeError:
168
- try:
169
- model.get_model().diffusion_expert.vae.enable_slicing()
170
- except AttributeError:
171
- pass
172
-
173
- for parameter in model.parameters():
174
- parameter.requires_grad = False
175
- for parameter in model.get_model().parameters():
176
- parameter.requires_grad = True
177
- for parameter in model.lm_head.parameters():
178
- parameter.requires_grad = True
179
-
180
- if hasattr(model, "visual"):
181
- for parameter in model.visual.parameters():
182
- parameter.requires_grad = False
183
-
184
- for component_name in ("visual", "vae", "text_encoder", "text_encoder_2", "text_encoder_3"):
185
- component = getattr(model.get_model(), component_name, None)
186
- if component is not None:
187
- for parameter in component.parameters():
188
- parameter.requires_grad = False
189
-
190
- transformer = getattr(model.get_model(), "transformer", None)
191
- if transformer is None:
192
- raise ValueError("Qwen-Kontext model does not expose a FLUX transformer.")
193
- for parameter in transformer.parameters():
194
- parameter.requires_grad = True
195
-
196
- def _create_processor(self, processor_name_or_path: str, max_pixels: int, min_pixels: int) -> AutoProcessor:
197
- processor = AutoProcessor.from_pretrained(processor_name_or_path)
198
- processor.pad_token_id = processor.tokenizer.pad_token_id
199
- processor.eos_token_id = processor.tokenizer.eos_token_id
200
- processor.image_processor.max_pixels = max_pixels
201
- processor.image_processor.min_pixels = min_pixels
202
- return processor
203
-
204
- def _get_diffusion_config(self) -> Dict[str, Any]:
205
- device_text = str(self.accelerator.device)
206
- device_id = int(device_text.split(":")[-1]) if ":" in device_text else 0
207
- return {
208
- "guidance_scale": float(os.getenv("PROMPTRL_EDIT_GUIDANCE_SCALE", "2.5")),
209
- "num_inference_steps": int(os.getenv("PROMPTRL_EDIT_NUM_INFERENCE_STEPS", "8")),
210
- "num_images_per_prompt": 1,
211
- "generator": torch.manual_seed(42 + device_id),
212
- "height": int(os.getenv("PROMPTRL_EDIT_HEIGHT", "1024")),
213
- "width": int(os.getenv("PROMPTRL_EDIT_WIDTH", "1024")),
214
- "num_sde": self.num_sde,
215
- }
216
-
217
- def _set_signature_columns_if_needed(self):
218
- if self._signature_columns is None:
219
- self._signature_columns = ["prompt"]
220
-
221
- def create_optimizer(self):
222
- if self.optimizer is not None:
223
- return self.optimizer
224
-
225
- optimizer_kwargs = {
226
- "betas": (self.args.adam_beta1, self.args.adam_beta2),
227
- "eps": self.args.adam_epsilon,
228
- "weight_decay": self.args.weight_decay,
229
- }
230
- dit_lr = float(os.getenv("DIT_LEARNING_RATE", os.getenv("PROMPTRL_DIT_LR", "2e-7")))
231
- llm_lr = float(os.getenv("LLM_LEARNING_RATE", os.getenv("PROMPTRL_LLM_LR", "3e-7")))
232
-
233
- dit_params = [
234
- parameter for parameter in self.model.get_model().transformer.parameters() if parameter.requires_grad
235
- ]
236
- dit_param_ids = {id(parameter) for parameter in dit_params}
237
- llm_params = [
238
- parameter
239
- for parameter in self.model.parameters()
240
- if parameter.requires_grad and id(parameter) not in dit_param_ids
241
- ]
242
-
243
- param_groups = []
244
- if dit_params:
245
- param_groups.append({"params": dit_params, "lr": dit_lr})
246
- if llm_params:
247
- param_groups.append({"params": llm_params, "lr": llm_lr})
248
- if not param_groups:
249
- raise ValueError("No trainable parameters were found for edit joint GRPO training.")
250
-
251
- self.optimizer = torch.optim.AdamW(param_groups, **optimizer_kwargs)
252
- return self.optimizer
253
-
254
- def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None):
255
- model.eval()
256
- self.ref_model.eval()
257
- if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
258
- self.optimizer.train()
259
-
260
- inputs = self._prepare_inputs(inputs)
261
-
262
- def loss_update(loss: torch.Tensor, scale_factor: float = 1.0) -> None:
263
- if self.args.n_gpu > 1:
264
- loss = loss.mean()
265
- if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
266
- loss = loss / self.args.gradient_accumulation_steps
267
- if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
268
- loss = loss / scale_factor
269
- model.backward(loss)
270
- else:
271
- self.accelerator.backward(loss / scale_factor)
272
-
273
- with self.compute_loss_context_manager():
274
- generations = self.generate_samples(model, inputs)
275
- torch.cuda.empty_cache()
276
-
277
- if self.num_refined > 0:
278
- cot_loss = self.cot_loss_computation(
279
- model,
280
- generations["prompt_completion_ids"],
281
- generations["completion_ids"],
282
- generations["prompt_length"],
283
- generations["advantages_refined"],
284
- generations["prompt_inputs"],
285
- )
286
- loss_update(cot_loss, 1.0)
287
- else:
288
- cot_loss = torch.tensor(0.0, device=self.accelerator.device)
289
-
290
- diff_advantages = generations["advantages"].repeat_interleave(self.num_sde, dim=0)
291
- total_len = diff_advantages.shape[0]
292
- diff_loss_values = []
293
- diff_kl_values = []
294
- diffusion_batch_size = int(os.getenv("PROMPTRL_DIFFUSION_LOSS_BATCH_SIZE", "4"))
295
-
296
- for idx in range(0, total_len, diffusion_batch_size):
297
- batched_states_slice = {}
298
- for key, value in generations["batched_states"].items():
299
- if key in {"img_ids", "txt_ids"}:
300
- batched_states_slice[key] = value
301
- elif value is None:
302
- batched_states_slice[key] = None
303
- else:
304
- batched_states_slice[key] = value[idx : idx + diffusion_batch_size]
305
-
306
- diff_loss, diff_kl = self.diffusion_loss_computation(
307
- generations["prev_latents"][idx : idx + diffusion_batch_size],
308
- generations["diff_sampling_log_probs"][idx : idx + diffusion_batch_size],
309
- generations["pred_latents"][idx : idx + diffusion_batch_size],
310
- generations["ts"][idx : idx + diffusion_batch_size],
311
- batched_states_slice,
312
- diff_advantages[idx : idx + diffusion_batch_size],
313
- )
314
- loss_update(diff_loss, max(1.0, float(total_len / diffusion_batch_size)))
315
- diff_loss_values.append(diff_loss.detach())
316
- diff_kl_values.append(diff_kl.detach())
317
-
318
- diff_loss = torch.stack(diff_loss_values).mean()
319
- diff_kl = torch.stack(diff_kl_values).mean()
320
- loss = diff_loss + cot_loss.detach()
321
-
322
- if self.args.torch_empty_cache_steps is not None and self.state.global_step % self.args.torch_empty_cache_steps == 0:
323
- torch.cuda.empty_cache()
324
-
325
- if hasattr(model, "step") and callable(model.step):
326
- model.step()
327
-
328
- self._metrics["diff_kl"].append(self.accelerator.gather_for_metrics(diff_kl).mean().item())
329
- self._metrics["diff_loss"].append(self.accelerator.gather_for_metrics(diff_loss).mean().item())
330
- torch.cuda.empty_cache()
331
- return loss.detach()
332
-
333
- def generate_samples(self, model: nn.Module, inputs: List[Dict]) -> Dict[str, Any]:
334
- source_images = [example["image"] for example in inputs]
335
- batch_size = len(inputs)
336
- prompt_inputs = None
337
- prompt_completion_ids = None
338
- completion_ids = None
339
- prompt_length = 0
340
- completions_refined: List[str] = []
341
- refined_prompts: List[str] = []
342
-
343
- if self.num_refined > 0:
344
- prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
345
- prompt_inputs = self.processing_class(
346
- images=[image for image in source_images for _ in range(self.num_refined)],
347
- text=[prompt for prompt in prompts_text for _ in range(self.num_refined)],
348
- return_tensors="pt",
349
- padding=True,
350
- padding_side="left",
351
- add_special_tokens=False,
352
- )
353
- prompt_inputs = super()._prepare_inputs(prompt_inputs)
354
- if self.max_prompt_length is not None:
355
- prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
356
- prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
357
-
358
- with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
359
- with torch.no_grad():
360
- prompt_completion_ids = unwrapped_model.generate(
361
- **prompt_inputs,
362
- generation_config=self.generation_config,
363
- )
364
-
365
- prompt_length = prompt_inputs["input_ids"].size(1)
366
- completion_ids = prompt_completion_ids[:, prompt_length:]
367
- completions_refined = self.processing_class.tokenizer.batch_decode(
368
- completion_ids,
369
- skip_special_tokens=True,
370
- )
371
- refined_prompts = [self.model.extract_thinking_content(completion) for completion in completions_refined]
372
-
373
- original_prompts = [
374
- example["editing_instruction"]
375
- for example in inputs
376
- for _ in range(self.num_skip_refinement)
377
- ]
378
- all_prompts: List[str] = []
379
- for batch_idx in range(batch_size):
380
- refined_start = batch_idx * self.num_refined
381
- refined_end = refined_start + self.num_refined
382
- all_prompts.extend(refined_prompts[refined_start:refined_end])
383
-
384
- original_start = batch_idx * self.num_skip_refinement
385
- original_end = original_start + self.num_skip_refinement
386
- all_prompts.extend(original_prompts[original_start:original_end])
387
-
388
- all_source_images = [image for image in source_images for _ in range(self.num_generations)]
389
- with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
390
- with torch.no_grad():
391
- (
392
- edited_images,
393
- prev_latents,
394
- diff_sampling_log_probs,
395
- pred_latents,
396
- timesteps,
397
- batched_states,
398
- ) = unwrapped_model.generate_image(
399
- images=all_source_images,
400
- texts=all_prompts,
401
- diffusion_kwargs=self.diffusion_generation_config,
402
- sde_sampling=True,
403
- )
404
-
405
- rewards, rewards_per_func = self.compute_rewards(inputs, edited_images, completions_refined)
406
- advantages = self.compute_advantages(rewards)
407
- advantages_refined = (
408
- advantages.view(batch_size, self.num_generations)[:, : self.num_refined].flatten()
409
- if self.num_refined > 0
410
- else torch.tensor([], device=advantages.device)
411
- )
412
-
413
- self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
414
- for index, (func_name, _, _) in enumerate(self.reward_funcs):
415
- self._metrics[f"reward/{func_name}"].append(
416
- self.accelerator.gather_for_metrics(rewards_per_func[:, index]).mean().item()
417
- )
418
- self._log_samples(source_images, edited_images, all_prompts, advantages)
419
-
420
- return {
421
- "images": edited_images,
422
- "prev_latents": prev_latents,
423
- "diff_sampling_log_probs": diff_sampling_log_probs,
424
- "pred_latents": pred_latents,
425
- "batched_states": batched_states,
426
- "prompt_length": prompt_length,
427
- "completion_ids": completion_ids,
428
- "prompt_completion_ids": prompt_completion_ids,
429
- "prompt_inputs": prompt_inputs,
430
- "advantages": advantages,
431
- "advantages_refined": advantages_refined,
432
- "ts": timesteps,
433
- }
434
-
435
- def compute_rewards(
436
- self,
437
- inputs: List[Dict],
438
- edited_images: List[Image.Image],
439
- completions_refined: List[str],
440
- ) -> Tuple[torch.Tensor, torch.Tensor]:
441
- device = self.accelerator.device
442
- rewards_per_func = torch.zeros(len(edited_images), len(self.reward_funcs), device=device)
443
- batch_size = len(inputs)
444
-
445
- for index, (func_name, _, reward_func) in enumerate(self.reward_funcs):
446
- if func_name == "format":
447
- refined_scores = torch.tensor(reward_func(completions_refined), device=device, dtype=torch.float32)
448
- for batch_idx in range(batch_size):
449
- start = batch_idx * self.num_generations
450
- refined_start = batch_idx * self.num_refined
451
- refined_end = refined_start + self.num_refined
452
- batch_refined_scores = refined_scores[refined_start:refined_end]
453
- rewards_per_func[start : start + self.num_refined, index] = batch_refined_scores
454
- rewards_per_func[start + self.num_refined : start + self.num_generations, index] = (
455
- batch_refined_scores.mean() if len(batch_refined_scores) else 0.0
456
- )
457
- elif func_name == "editreward":
458
- source_images = [example["image"] for example in inputs for _ in range(self.num_generations)]
459
- prompts = [example["editing_instruction"] for example in inputs for _ in range(self.num_generations)]
460
- rewards_per_func[:, index] = torch.tensor(
461
- reward_func(source_images, edited_images, prompts)["scores"],
462
- device=device,
463
- dtype=torch.float32,
464
- )
465
- else:
466
- raise ValueError(f"Unsupported reward function for edit joint training: {func_name}")
467
-
468
- return rewards_per_func.sum(dim=1), rewards_per_func
469
-
470
- def compute_advantages(self, rewards: torch.Tensor) -> torch.Tensor:
471
- grouped_rewards = rewards.view(-1, self.num_generations)
472
- mean = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
473
- std = grouped_rewards.std(dim=1, unbiased=False).repeat_interleave(self.num_generations, dim=0)
474
- return torch.clamp((rewards - mean) / (std + 1e-4), -5, 5)
475
-
476
- def cot_loss_computation(
477
- self,
478
- model: nn.Module,
479
- input_ids: torch.Tensor,
480
- completion_ids: torch.Tensor,
481
- prompt_length: int,
482
- advantages: torch.Tensor,
483
- prompt_inputs: Dict[str, torch.Tensor],
484
- ) -> torch.Tensor:
485
- image_kwargs = {
486
- key: value for key, value in prompt_inputs.items() if key not in {"input_ids", "attention_mask"}
487
- }
488
- per_token_logps = self._get_per_token_logps(model, input_ids, image_kwargs)[:, prompt_length - 1 :]
489
- with torch.inference_mode():
490
- ref_per_token_logps = self._get_per_token_logps(self.ref_model, input_ids, image_kwargs)[:, prompt_length - 1 :]
491
-
492
- per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
493
- completion_mask = self._completion_mask(completion_ids)
494
- per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
495
- per_token_loss = -(per_token_loss - 0.01 * per_token_kl)
496
- cot_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1).clamp_min(1)).mean()
497
- mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1).clamp_min(1)).mean()
498
-
499
- self._metrics["completion_length"].append(
500
- self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
501
- )
502
- self._metrics["cot_kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
503
- self._metrics["cot_loss"].append(self.accelerator.gather_for_metrics(cot_loss).mean().item())
504
- return cot_loss
505
-
506
- def _get_per_token_logps(
507
- self,
508
- model: nn.Module,
509
- input_ids: torch.Tensor,
510
- image_kwargs: Dict[str, torch.Tensor],
511
- ) -> torch.Tensor:
512
- logits = model(input_ids, **image_kwargs).logits[:, :-1, :]
513
- target_ids = input_ids[:, 1:]
514
- per_token_logps = []
515
- for logits_row, target_ids_row in zip(logits, target_ids):
516
- log_probs = logits_row.log_softmax(dim=-1)
517
- per_token_logps.append(torch.gather(log_probs, dim=1, index=target_ids_row.unsqueeze(1)).squeeze(1))
518
- return torch.stack(per_token_logps)
519
-
520
- def _completion_mask(self, completion_ids: torch.Tensor) -> torch.Tensor:
521
- is_eos = completion_ids == self.processing_class.eos_token_id
522
- device = completion_ids.device
523
- eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
524
- eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
525
- sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
526
- return (sequence_indices <= eos_idx.unsqueeze(1)).int()
527
-
528
- def diffusion_loss_computation(
529
- self,
530
- prev_latents: torch.Tensor,
531
- diff_sampling_log_probs: torch.Tensor,
532
- pred_latents: torch.Tensor,
533
- timesteps: torch.Tensor,
534
- batched_states: Dict[str, torch.Tensor],
535
- advantages: torch.Tensor,
536
- ) -> Tuple[torch.Tensor, torch.Tensor]:
537
- model_pred = self.model.get_model().transformer(
538
- hidden_states=prev_latents.to(self.model.device),
539
- **batched_states,
540
- joint_attention_kwargs={},
541
- return_dict=False,
542
- )[0][:, : pred_latents.size(1)]
543
-
544
- with torch.no_grad():
545
- ref_model_pred = self.ref_model.get_model().transformer(
546
- hidden_states=prev_latents.to(self.model.device),
547
- **batched_states,
548
- joint_attention_kwargs={},
549
- return_dict=False,
550
- )[0][:, : pred_latents.size(1)]
551
-
552
- _, log_prob, prev_sample_mean, std_dev_t = compute_log_prob(
553
- model_pred,
554
- self.scheduler,
555
- prev_latents[:, : pred_latents.size(1)],
556
- pred_latents,
557
- timesteps,
558
- )
559
- _, _, ref_prev_sample_mean, ref_std_dev_t = compute_log_prob(
560
- ref_model_pred,
561
- self.scheduler,
562
- prev_latents[:, : pred_latents.size(1)],
563
- pred_latents,
564
- timesteps,
565
- )
566
- if not torch.equal(std_dev_t, ref_std_dev_t):
567
- raise RuntimeError("Current and reference SDE std-dev tensors diverged.")
568
-
569
- kl = ((prev_sample_mean - ref_prev_sample_mean) ** 2 / (2 * std_dev_t**2)).mean(
570
- dim=tuple(range(1, prev_sample_mean.ndim))
571
- )
572
- ratio = torch.exp(log_prob - diff_sampling_log_probs)
573
- unclipped_loss = -advantages * ratio
574
- clipped_loss = -advantages * torch.clamp(ratio, 1.0 - 1e-4, 1.0 + 1e-4)
575
- diff_loss = torch.maximum(unclipped_loss, clipped_loss).mean() + self.beta * kl.mean()
576
- return diff_loss, kl
577
-
578
- def _log_samples(
579
- self,
580
- source_images: List[Image.Image],
581
- edited_images: List[Image.Image],
582
- prompts: List[str],
583
- advantages: torch.Tensor,
584
- ) -> None:
585
- global_step = self.state.global_step
586
- if global_step % 10 != 0 or not edited_images:
587
- return
588
-
589
- device_id = str(self.accelerator.device).replace(":", "")
590
- text_content = []
591
- for batch_idx in range(len(source_images)):
592
- for gen_idx in range(self.num_generations):
593
- overall_idx = batch_idx * self.num_generations + gen_idx
594
- status = "REFINED" if gen_idx < self.num_refined else "ORIGINAL"
595
- text_content.append(f"[{status}] Generation {gen_idx}: {prompts[overall_idx]}")
596
- text_content.append("")
597
-
598
- txt_path = os.path.join(self.log_dir, f"step_{global_step}_{device_id}.txt")
599
- if not os.path.exists(txt_path):
600
- with open(txt_path, "w", encoding="utf-8") as file:
601
- file.write("\n".join(text_content))
602
-
603
- for batch_idx, source_image in enumerate(source_images):
604
- source_image.save(os.path.join(self.log_dir, f"step_{global_step}_{device_id}_batch{batch_idx}_source.jpg"))
605
- for gen_idx in range(self.num_generations):
606
- overall_idx = batch_idx * self.num_generations + gen_idx
607
- prefix = "refined" if gen_idx < self.num_refined else "original"
608
- edited_images[overall_idx].save(
609
- os.path.join(
610
- self.log_dir,
611
- f"step_{global_step}_{device_id}_batch{batch_idx}_{prefix}_gen{gen_idx}_{advantages[overall_idx].item():.5f}.jpg",
612
- )
613
- )
614
-
615
- def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
616
- metrics = {key: sum(value) / len(value) for key, value in self._metrics.items() if value}
617
- logs = {**logs, **metrics}
618
- if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
619
- super().log(logs, start_time)
620
- else:
621
- super().log(logs)
622
- self._metrics.clear()
623
-