Werli committed on
Commit b0b0867 · verified · 1 Parent(s): 83782c8

Upload 6 files

Files changed (5)
  1. app.py +851 -823
  2. modules/classifyTags.py +429 -256
  3. modules/media_handler.py +212 -0
  4. modules/pixai.py +801 -810
  5. modules/video_processor.py +206 -0
app.py CHANGED
@@ -1,823 +1,851 @@
- import os, io, json, requests, spaces, argparse, traceback, tempfile, zipfile, re, ast, time
- import gradio as gr
- import numpy as np
- import huggingface_hub
- import onnxruntime as ort
- import pandas as pd
- from datetime import datetime, timezone
- from collections import defaultdict
- from PIL import Image, ImageOps
- from apscheduler.schedulers.background import BackgroundScheduler
- from modules.classifyTags import categorize_tags_output, generate_tags_json
- from modules.pixai import create_pixai_interface
- from modules.booru import create_booru_interface
-
- """ For GPU install all the requirements.txt and the following:
- pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
- pip install onnxruntime-gpu
- """
-
- """ It's recommended to create a venv:
- python -m venv venv
- venv\Scripts\activate
- pip install ...
- python app.py
- """
-
- TITLE = 'Multi-Tagger v1.4'
- DESCRIPTION = '\nMulti-Tagger is a versatile application for advanced image analysis and captioning. Supports <b>CUDA</b> and <b>CPU</b>.\n'
-
- SWINV2_MODEL_DSV3_REPO = 'SmilingWolf/wd-swinv2-tagger-v3'
- CONV_MODEL_DSV3_REPO = 'SmilingWolf/wd-convnext-tagger-v3'
- VIT_MODEL_DSV3_REPO = 'SmilingWolf/wd-vit-tagger-v3'
- VIT_LARGE_MODEL_DSV3_REPO = 'SmilingWolf/wd-vit-large-tagger-v3'
- EVA02_LARGE_MODEL_DSV3_REPO = 'SmilingWolf/wd-eva02-large-tagger-v3'
- MOAT_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-moat-tagger-v2'
- SWIN_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-swinv2-tagger-v2'
- CONV_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
- CONV2_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-convnextv2-tagger-v2'
- VIT_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-vit-tagger-v2'
- EVA02_LARGE_MODEL_IS_DSV1_REPO = 'deepghs/idolsankaku-eva02-large-tagger-v1'
- SWINV2_MODEL_IS_DSV1_REPO = 'deepghs/idolsankaku-swinv2-tagger-v1'
-
- # Global variables for model components (for memory management)
- CURRENT_MODEL = None
- CURRENT_MODEL_NAME = None
- CURRENT_TAGS_DF = None
- CURRENT_TAG_NAMES = None
- CURRENT_RATING_INDEXES = None
- CURRENT_GENERAL_INDEXES = None
- CURRENT_CHARACTER_INDEXES = None
- CURRENT_MODEL_TARGET_SIZE = None
-
- # Custom CSS for gallery styling
- css = """
- #custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;}
- #custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease, box-shadow 0.2s ease;}
- #custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);}
- #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: contain;margin: 0 auto;display: block;}
- #custom-gallery .thumbnail-item img.portrait {max-width: 100%;}
- #custom-gallery .thumbnail-item img.landscape {max-height: 100%;}
- .gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;}
- .thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;}
- #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;width: fit-content;margin: 0px auto;display: block;}
- """
-
- MODEL_FILENAME = 'model.onnx'
- LABEL_FILENAME = 'selected_tags.csv'
-
- class Timer:
-     """Utility class for measuring execution time of different operations"""
-
-     def __init__(self):
-         self.start_time = time.perf_counter()
-         self.checkpoints = [('Start', self.start_time)]
-
-     def checkpoint(self, label='Checkpoint'):
-         """Add a checkpoint with a label"""
-         now = time.perf_counter()
-         self.checkpoints.append((label, now))
-
-     def report(self, is_clear_checkpoints=True):
-         """Report time elapsed since last checkpoint"""
-         max_label_length = max(len(label) for (label, _) in self.checkpoints) if self.checkpoints else 0
-         prev_time = self.checkpoints[0][1] if self.checkpoints else self.start_time
-
-         for (label, curr_time) in self.checkpoints[1:]:
-             elapsed = curr_time - prev_time
-             print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
-             prev_time = curr_time
-
-         if is_clear_checkpoints:
-             self.checkpoints.clear()
-             self.checkpoint()
-
-     def report_all(self):
-         """Report all checkpoint times including total execution time"""
-         print('\n> Execution Time Report:')
-         max_label_length = max(len(label) for (label, _) in self.checkpoints) if len(self.checkpoints) > 0 else 0
-         prev_time = self.start_time
-
-         for (label, curr_time) in self.checkpoints[1:]:
-             elapsed = curr_time - prev_time
-             print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
-             prev_time = curr_time
-
-         total_time = self.checkpoints[-1][1] - self.start_time if self.checkpoints else 0
-         print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
-         self.checkpoints.clear()
-
-     def restart(self):
-         """Restart the timer"""
-         self.start_time = time.perf_counter()
-         self.checkpoints = [('Start', self.start_time)]
-
- def parse_args() -> argparse.Namespace:
-     """Parse command line arguments"""
-     parser = argparse.ArgumentParser()
-     parser.add_argument('--score-slider-step', type=float, default=0.05)
-     parser.add_argument('--score-general-threshold', type=float, default=0.35)
-     parser.add_argument('--score-character-threshold', type=float, default=0.85)
-     parser.add_argument('--share', action='store_true')
-     return parser.parse_args()
-
- def load_labels(dataframe) -> tuple:
-     """Load tag names and their category indexes from the dataframe"""
-     name_series = dataframe['name']
-     tag_names = name_series.tolist()
-
-     # Find indexes for different tag categories
-     rating_indexes = list(np.where(dataframe['category'] == 9)[0])
-     general_indexes = list(np.where(dataframe['category'] == 0)[0])
-     character_indexes = list(np.where(dataframe['category'] == 4)[0])
-
-     return tag_names, rating_indexes, general_indexes, character_indexes
-
- def mcut_threshold(probs):
-     """Calculate threshold using Maximum Change in second derivative (MCut) method"""
-     sorted_probs = probs[probs.argsort()[::-1]]
-     difs = sorted_probs[:-1] - sorted_probs[1:]
-     t = difs.argmax()
-     thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
-     return thresh
-
- def _download_model_files(model_repo):
-     """Download model files from HuggingFace Hub"""
-     csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
-     model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
-     return csv_path, model_path
-
- def create_optimized_ort_session(model_path):
-     """Create an optimized ONNX Runtime session with GPU support"""
-     # Configure session options for better performance
-     sess_options = ort.SessionOptions()
-     sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-     sess_options.intra_op_num_threads = 0 # Use all available cores
-     sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
-     sess_options.enable_mem_pattern = True
-     sess_options.enable_cpu_mem_arena = True
-
-     # Check available providers
-     available_providers = ort.get_available_providers()
-     print(f"Available ONNX Runtime providers: {available_providers}")
-
-     # Configure execution providers (prefer CUDA if available)
-     providers = []
-
-     # Use CUDA if available
-     if 'CUDAExecutionProvider' in available_providers:
-         providers.append('CUDAExecutionProvider')
-         print("Using CUDA provider for ONNX inference")
-     else:
-         print("CUDA provider not available, falling back to CPU")
-
-     # Always include CPU as fallback
-     providers.append('CPUExecutionProvider')
-
-     try:
-         session = ort.InferenceSession(model_path, sess_options, providers=providers)
-         print(f"Model loaded with providers: {session.get_providers()}")
-         return session
-     except Exception as e:
-         print(f"Failed to create ONNX session: {e}")
-         raise
-
- def _load_model_components_optimized(model_repo):
-     """Load and optimize model components"""
-     global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_TAG_NAMES
-     global CURRENT_RATING_INDEXES, CURRENT_GENERAL_INDEXES, CURRENT_CHARACTER_INDEXES, CURRENT_MODEL_TARGET_SIZE
-
-     # Only reload if model changed
-     if model_repo == CURRENT_MODEL_NAME and CURRENT_MODEL is not None:
-         return
-
-     # Download files
-     csv_path, model_path = _download_model_files(model_repo)
-
-     # Load optimized ONNX model
-     CURRENT_MODEL = create_optimized_ort_session(model_path)
-
-     # Load tags
-     tags_df = pd.read_csv(csv_path)
-     tag_names, rating_indexes, general_indexes, character_indexes = load_labels(tags_df)
-
-     # Store in global variables
-     CURRENT_TAGS_DF = tags_df
-     CURRENT_TAG_NAMES = tag_names
-     CURRENT_RATING_INDEXES = rating_indexes
-     CURRENT_GENERAL_INDEXES = general_indexes
-     CURRENT_CHARACTER_INDEXES = character_indexes
-
-     # Get model input size
-     _, height, width, _ = CURRENT_MODEL.get_inputs()[0].shape
-     CURRENT_MODEL_TARGET_SIZE = height
-     CURRENT_MODEL_NAME = model_repo
-
- def _raw_predict(image_array, model_session):
-     """Run raw prediction using the model session"""
-     input_name = model_session.get_inputs()[0].name
-     label_name = model_session.get_outputs()[0].name
-     preds = model_session.run([label_name], {input_name: image_array})[0]
-     return preds[0].astype(float)
-
- def unload_model():
-     """Explicitly unload the current model from memory"""
-     global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_TAG_NAMES
-     global CURRENT_RATING_INDEXES, CURRENT_GENERAL_INDEXES, CURRENT_CHARACTER_INDEXES, CURRENT_MODEL_TARGET_SIZE
-
-     # Delete the model session
-     if CURRENT_MODEL is not None:
-         del CURRENT_MODEL
-         CURRENT_MODEL = None
-
-     # Clear other large objects
-     CURRENT_TAGS_DF = None
-     CURRENT_TAG_NAMES = None
-     CURRENT_RATING_INDEXES = None
-     CURRENT_GENERAL_INDEXES = None
-     CURRENT_CHARACTER_INDEXES = None
-     CURRENT_MODEL_TARGET_SIZE = None
-     CURRENT_MODEL_NAME = None
-
-     # Force garbage collection
-     import gc
-     gc.collect()
-
-     # Clear CUDA cache if using GPU
-     try:
-         import torch
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-     except ImportError:
-         pass
-
- def cleanup_after_processing():
-     """Cleanup resources after processing"""
-     unload_model()
-
- class Predictor:
-     """Main predictor class for handling image tagging"""
-
-     def __init__(self):
-         self.model_components = None
-         self.last_loaded_repo = None
-
-     def load_model(self, model_repo):
-         """Load model if not already loaded"""
-         if model_repo == self.last_loaded_repo and self.model_components is not None:
-             return
-         _load_model_components_optimized(model_repo)
-         self.last_loaded_repo = model_repo
-
-     def prepare_image(self, path):
-         """Prepare image for model input"""
-         image = Image.open(path)
-         image = image.convert('RGBA')
-         target_size = CURRENT_MODEL_TARGET_SIZE
-
-         # Create white background and composite
-         canvas = Image.new('RGBA', image.size, (255, 255, 255))
-         canvas.alpha_composite(image)
-         image = canvas.convert('RGB')
-
-         # Pad to square
-         image_shape = image.size
-         max_dim = max(image_shape)
-         pad_left = (max_dim - image_shape[0]) // 2
-         pad_top = (max_dim - image_shape[1]) // 2
-         padded_image = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
-         padded_image.paste(image, (pad_left, pad_top))
-
-         # Resize if needed
-         if max_dim != target_size:
-             padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
-
-         # Convert to array and preprocess
-         image_array = np.asarray(padded_image, dtype=np.float32)
-         image_array = image_array[:, :, ::-1] # BGR to RGB
-         return np.expand_dims(image_array, axis=0)
-
-     def create_file(self, content: str, directory: str, fileName: str) -> str:
-         """Create a file with the given content"""
-         file_path = os.path.join(directory, fileName)
-         if fileName.endswith('.json'):
-             with open(file_path, 'w', encoding='utf-8') as file:
-                 file.write(content)
-         else:
-             with open(file_path, 'w+', encoding='utf-8') as file:
-                 file.write(content)
-         return file_path
-
-     def predict(self, gallery, model_repo, model_repo_2, general_thresh, general_mcut_enabled,
-                 character_thresh, character_mcut_enabled, characters_merge_enabled,
-                 additional_tags_prepend, additional_tags_append, tag_results, progress=gr.Progress()):
-         """Main prediction function for processing images"""
-         tag_results.clear()
-         gallery_len = len(gallery)
-         print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
-
-         timer = Timer()
-         progressRatio = 1
-         progressTotal = gallery_len + 1
-         current_progress = 0
-         txt_infos = []
-         output_dir = tempfile.mkdtemp()
-
-         if not os.path.exists(output_dir):
-             os.makedirs(output_dir)
-
-         # Load initial model
-         self.load_model(model_repo)
-         current_progress += progressRatio / progressTotal
-         progress(current_progress, desc='Initialize wd model finished')
-         timer.checkpoint("Initialize wd model")
-         timer.report()
-
-         name_counters = defaultdict(int)
-
-         for (idx, value) in enumerate(gallery):
-             try:
-                 # Handle duplicate filenames
-                 image_path = value[0]
-                 image_name = os.path.splitext(os.path.basename(image_path))[0]
-                 name_counters[image_name] += 1
-                 if name_counters[image_name] > 1:
-                     image_name = f"{image_name}_{name_counters[image_name]:02d}"
-
-                 # Prepare image
-                 image = self.prepare_image(image_path)
-                 print(f"Gallery {idx:02d}: Starting run first model ({model_repo})...")
-
-                 # Load and run first model
-                 self.load_model(model_repo)
-                 preds = _raw_predict(image, CURRENT_MODEL)
-                 labels = list(zip(CURRENT_TAG_NAMES, preds))
-
-                 # Process ratings
-                 ratings_names = [labels[i] for i in CURRENT_RATING_INDEXES]
-                 rating = dict(ratings_names)
-
-                 # Process general tags
-                 general_names = [labels[i] for i in CURRENT_GENERAL_INDEXES]
-                 if general_mcut_enabled:
-                     general_probs = np.array([x[1] for x in general_names])
-                     general_thresh_temp = mcut_threshold(general_probs)
-                 else:
-                     general_thresh_temp = general_thresh
-
-                 general_res = [x for x in general_names if x[1] > general_thresh_temp]
-                 general_res = dict(general_res)
-
-                 # Process character tags
-                 character_names = [labels[i] for i in CURRENT_CHARACTER_INDEXES]
-                 if character_mcut_enabled:
-                     character_probs = np.array([x[1] for x in character_names])
-                     character_thresh_temp = mcut_threshold(character_probs)
-                     character_thresh_temp = max(0.15, character_thresh_temp)
-                 else:
-                     character_thresh_temp = character_thresh
-
-                 character_res = [x for x in character_names if x[1] > character_thresh_temp]
-                 character_res = dict(character_res)
-                 character_list_1 = list(character_res.keys())
-
-                 # Sort general tags by confidence
-                 sorted_general_list_1 = sorted(general_res.items(), key=lambda x: x[1], reverse=True)
-                 sorted_general_list_1 = [x[0] for x in sorted_general_list_1]
-
-                 # Handle second model if provided
-                 if model_repo_2 and model_repo_2 != model_repo:
-                     print(f"Gallery {idx:02d}: Starting run second model ({model_repo_2})...")
-                     self.load_model(model_repo_2)
-                     preds_2 = _raw_predict(image, CURRENT_MODEL)
-                     labels_2 = list(zip(CURRENT_TAG_NAMES, preds_2))
-
-                     # Process general tags from second model
-                     general_names_2 = [labels_2[i] for i in CURRENT_GENERAL_INDEXES]
-                     if general_mcut_enabled:
-                         general_probs_2 = np.array([x[1] for x in general_names_2])
-                         general_thresh_temp_2 = mcut_threshold(general_probs_2)
-                     else:
-                         general_thresh_temp_2 = general_thresh
-
-                     general_res_2 = [x for x in general_names_2 if x[1] > general_thresh_temp_2]
-                     general_res_2 = dict(general_res_2)
-
-                     # Process character tags from second model
-                     character_names_2 = [labels_2[i] for i in CURRENT_CHARACTER_INDEXES]
-                     if character_mcut_enabled:
-                         character_probs_2 = np.array([x[1] for x in character_names_2])
-                         character_thresh_temp_2 = mcut_threshold(character_probs_2)
-                         character_thresh_temp_2 = max(0.15, character_thresh_temp_2)
-                     else:
-                         character_thresh_temp_2 = character_thresh
-
-                     character_res_2 = [x for x in character_names_2 if x[1] > character_thresh_temp_2]
-                     character_res_2 = dict(character_res_2)
-                     character_list_2 = list(character_res_2.keys())
-
-                     # Sort general tags from second model
-                     sorted_general_list_2 = sorted(general_res_2.items(), key=lambda x: x[1], reverse=True)
-                     sorted_general_list_2 = [x[0] for x in sorted_general_list_2]
-
-                     # Combine results from both models
-                     combined_character_list = list(set(character_list_1 + character_list_2))
-                     combined_general_list = list(set(sorted_general_list_1 + sorted_general_list_2))
-                 else:
-                     combined_character_list = character_list_1
-                     combined_general_list = sorted_general_list_1
-
-                 # Remove characters from general tags if merging is disabled
-                 if not characters_merge_enabled:
-                     combined_character_list = [item for item in combined_character_list
-                                                if item not in combined_general_list]
-
-                 # Handle additional tags
-                 prepend_list = [tag.strip() for tag in additional_tags_prepend.split(',') if tag.strip()]
-                 append_list = [tag.strip() for tag in additional_tags_append.split(',') if tag.strip()]
-
-                 # Avoid duplicates in prepend/append lists
-                 if prepend_list and append_list:
-                     append_list = [item for item in append_list if item not in prepend_list]
-
-                 # Remove prepended tags from main list
-                 if prepend_list:
-                     combined_general_list = [item for item in combined_general_list if item not in prepend_list]
-
-                 # Remove appended tags from main list
-                 if append_list:
-                     combined_general_list = [item for item in combined_general_list if item not in append_list]
-
-                 # Combine all tags
-                 combined_general_list = prepend_list + combined_general_list + append_list
-
-                 # Format output string
-                 sorted_general_strings = ', '.join(
-                     (combined_character_list if characters_merge_enabled else []) +
-                     combined_general_list
-                 ).replace('(', '\\(').replace(')', '\\)').replace('_', ' ')
-
-                 # Generate categorized output
-                 categorized_strings = categorize_tags_output(sorted_general_strings, character_res).replace('(', '\\(').replace(')', '\\)')
-                 categorized_json = generate_tags_json(sorted_general_strings, character_res)
-
-                 # Create output files
-                 txt_content = f"Output (string): {sorted_general_strings}\n\nCategorized Output: {categorized_strings}"
-                 txt_file = self.create_file(txt_content, output_dir, f"{image_name}_output.txt")
-                 txt_infos.append({'path': txt_file, 'name': f"{image_name}_output.txt"})
-
-                 # Save image copy
-                 image_path = value[0]
-                 image = Image.open(image_path)
-                 image.save(os.path.join(output_dir, f"{image_name}.png"), format='PNG')
-                 txt_infos.append({'path': os.path.join(output_dir, f"{image_name}.png"), 'name': f"{image_name}.png"})
-
-                 # Create tags text file
-                 txt_file = self.create_file(sorted_general_strings, output_dir, image_name + '.txt')
-                 # Create categorized tags file
-                 categorized_file = self.create_file(categorized_strings, output_dir, f"{image_name}_categorized.txt")
-                 txt_infos.append({'path': categorized_file, 'name': f"{image_name}_categorized.txt"})
-                 txt_infos.append({'path': txt_file, 'name': image_name + '.txt'})
-
-                 # Create JSON file
-                 json_content = json.dumps(categorized_json, indent=2, ensure_ascii=False)
-                 json_file = self.create_file(json_content, output_dir, f"{image_name}_categorized.json")
-                 txt_infos.append({'path': json_file, 'name': f"{image_name}_categorized.json"})
-
-                 # Store results
-                 tag_results[image_path] = {
-                     'strings': sorted_general_strings,
-                     'categorized_strings': categorized_strings,
-                     'categorized_json': categorized_json,
-                     'rating': rating,
-                     'character_res': character_res,
-                     'general_res': general_res
-                 }
-
-                 # Update progress
-                 current_progress += progressRatio / progressTotal
-                 progress(current_progress, desc=f"image{idx:02d}, predict finished")
-                 timer.checkpoint(f"image{idx:02d}, predict finished")
-                 timer.report()
-
-             except Exception as e:
-                 print(traceback.format_exc())
-                 print('Error predict: ' + str(e))
-
-         # Create download zip
-         download = []
-         if txt_infos is not None and len(txt_infos) > 0:
-             downloadZipPath = os.path.join(
-                 output_dir,
-                 'Multi-Tagger-' + datetime.now().strftime('%Y%m%d-%H%M%S') + '.zip'
-             )
-             with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
-                 for info in txt_infos:
-                     taggers_zip.write(info['path'], arcname=info['name'])
-             # If using GPU, model will auto unload after zip file creation
-             cleanup_after_processing() # Comment here to turn off this behavior
-             download.append(downloadZipPath)
-
-         progress(1, desc=f"Predict completed")
-         timer.report_all()
-         print('Predict is complete.')
-
-         # Return first image results as default
-         first_image_results = '', {}, {}, {}, '', {}
-         if gallery and len(gallery) > 0:
-             first_image_path = gallery[0][0]
-             if first_image_path in tag_results:
-                 first_result = tag_results[first_image_path]
-                 character_tags_formatted = ", ".join([name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
-                                                       for name in first_result['character_res'].keys()])
-                 first_image_results = (
-                     first_result['strings'],
-                     first_result['rating'],
-                     character_tags_formatted,
-                     first_result['general_res'],
-                     first_result.get('categorized_strings', ''),
-                     first_result.get('categorized_json', {})
-                 )
-
-
-         return (
-             download,
-             first_image_results[0],
-             first_image_results[1],
-             first_image_results[2],
-             first_image_results[3],
-             first_image_results[4],
-             first_image_results[5],
-             tag_results
-         )
-
- def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
-     # Return first image results if no selection
-     if not selected_state and gallery and len(gallery) > 0:
-         first_image_path = gallery[0][0]
-         if first_image_path in tag_results:
-             first_result = tag_results[first_image_path]
-             character_tags_formatted = ", ".join([name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
-                                                   for name in first_result['character_res'].keys()])
-             return (
-                 first_result['strings'],
-                 first_result['rating'],
-                 character_tags_formatted,
-                 first_result['general_res'],
-                 first_result.get('categorized_strings', ''),
-                 first_result.get('categorized_json', {})
-             )
-
-     if not selected_state:
-         return '', {}, '', {}, '', {}
-
-     # Get selected image path
-     selected_value = selected_state.value
-     image_path = None
-
-     if isinstance(selected_value, dict) and 'image' in selected_value:
-         image_path = selected_value['image']['path']
-     elif isinstance(selected_value, (list, tuple)) and len(selected_value) > 0:
-         image_path = selected_value[0]
-     else:
-         image_path = str(selected_value)
-
-     # Return stored results
-     if image_path in tag_results:
-         result = tag_results[image_path]
-
-         character_tags_formatted = ", ".join([name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
-                                               for name in result['character_res'].keys()])
-         return (
-             result['strings'],
-             result['rating'],
-             character_tags_formatted,
-             result['general_res'],
-             result.get('categorized_strings', ''),
-             result.get('categorized_json', {})
-         )
-
-     return '', {}, '', {}, '', {}
-
- def append_gallery(gallery: list, image: str):
-     """Add a single image to the gallery"""
-     if gallery is None:
-         gallery = []
-     if not image:
-         return gallery, None
-     gallery.append(image)
-     return gallery, None
-
- def extend_gallery(gallery: list, images):
-     """Add multiple images to the gallery"""
-     if gallery is None:
-         gallery = []
-     if not images:
-         return gallery
-     gallery.extend(images)
-     return gallery
-
- # Parse arguments and initialize predictor
- args = parse_args()
- predictor = Predictor()
- dropdown_list = [
-     EVA02_LARGE_MODEL_DSV3_REPO, VIT_LARGE_MODEL_DSV3_REPO, SWINV2_MODEL_DSV3_REPO,
-     CONV_MODEL_DSV3_REPO, VIT_MODEL_DSV3_REPO, MOAT_MODEL_DSV2_REPO,
-     SWIN_MODEL_DSV2_REPO, CONV_MODEL_DSV2_REPO, CONV2_MODEL_DSV2_REPO,
-     VIT_MODEL_DSV2_REPO, EVA02_LARGE_MODEL_IS_DSV1_REPO, SWINV2_MODEL_IS_DSV1_REPO
- ]
-
- def _restart_space():
-     """Restart the HuggingFace Space periodically for stability"""
-     HF_TOKEN = os.getenv('HF_TOKEN')
-     if not HF_TOKEN:
-         raise ValueError('HF_TOKEN environment variable is not set.')
-     huggingface_hub.HfApi().restart_space(
-         repo_id='Werli/Multi-Tagger',
-         token=HF_TOKEN,
-         factory_reboot=False
-     )
-
- # Setup scheduler for periodic restarts
- scheduler = BackgroundScheduler()
- restart_space_job = scheduler.add_job(_restart_space, 'interval', seconds=172800)
- scheduler.start()
- next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
- NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."
-
- with gr.Blocks(title=TITLE, css=css, theme='Werli/Purple-Crimson-Gradio-Theme', fill_width=True) as demo:
-     gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
-     gr.Markdown(value=f"<p style='text-align: center;'>{DESCRIPTION}</p>")
-
-     with gr.Tab(label='Waifu Diffusion'):
-         with gr.Row():
-             with gr.Column():
-
-                 with gr.Column(variant='panel'):
-                     image_input = gr.Image(
-                         label='Upload an Image or clicking paste from clipboard button',
-                         type='filepath',
-                         sources=['upload', 'clipboard'],
-                         height=150
-                     )
-                     with gr.Row():
-                         upload_button = gr.UploadButton(
-                             'Upload multiple images',
-                             file_types=['image'],
-                             file_count='multiple',
-                             size='sm'
-                         )
-                     gallery = gr.Gallery(
-                         columns=2,
-                         show_share_button=False,
-                         interactive=True,
-                         height='auto',
-                         label='Grid of images',
-                         preview=False,
-                         elem_id='custom-gallery'
-                     )
-                     submit = gr.Button(value='Analyze Images', variant='primary', size='lg')
-                 with gr.Column(variant='panel'):
-                     model_repo = gr.Dropdown(
-                         dropdown_list,
-                         value=EVA02_LARGE_MODEL_DSV3_REPO,
-                         label='1st Model'
-                     )
-                     PLUS = '+?'
-                     gr.Markdown(value=f"<p style='text-align: center;'>{PLUS}</p>")
-                     model_repo_2 = gr.Dropdown(
-                         [None] + dropdown_list,
-                         value=None,
-                         label='2nd Model (Optional)',
-                         info='Select another model for diversified results.'
-                     )
-
-                     with gr.Row():
-                         general_thresh = gr.Slider(
-                             0, 1,
-                             step=args.score_slider_step,
-                             value=args.score_general_threshold,
-                             label='General Tags Threshold',
-                             scale=3
-                         )
-                         general_mcut_enabled = gr.Checkbox(
-                             value=False,
-                             label='Use MCut threshold',
-                             scale=1
-                         )
-
-                     with gr.Row():
-                         character_thresh = gr.Slider(
-                             0, 1,
-                             step=args.score_slider_step,
-                             value=args.score_character_threshold,
-                             label='Character Tags Threshold',
-                             scale=3
-                         )
-                         character_mcut_enabled = gr.Checkbox(
-                             value=False,
-                             label='Use MCut threshold',
-                             scale=1
-                         )
-
-                     with gr.Row():
-                         characters_merge_enabled = gr.Checkbox(
-                             value=False,
-                             label='Merge characters into the string output',
-                             scale=1
-                         )
-
-                     with gr.Row():
-                         additional_tags_prepend = gr.Text(
-                             label='Prepend Additional tags (comma split)'
-                         )
-                         additional_tags_append = gr.Text(
-                             label='Append Additional tags (comma split)'
-                         )
-
-                     with gr.Row():
-                         clear = gr.ClearButton(
-                             components=[
-                                 gallery, model_repo, general_thresh, general_mcut_enabled,
-                                 character_thresh, character_mcut_enabled, characters_merge_enabled,
-                                 additional_tags_prepend, additional_tags_append
-                             ],
-                             variant='secondary',
-                             size='lg'
-                         )
-
-             with gr.Column(variant='panel'):
-                 download_file = gr.File(label='Download')
-                 character_res = gr.Textbox(
-                     label="Character tags",
-                     show_copy_button=True,
-                     lines=3
-                 )
-                 sorted_general_strings = gr.Textbox(
-                     label='Output',
-                     show_label=True,
-                     show_copy_button=True,
-                     lines=5
-                 )
-                 categorized_strings = gr.Textbox(
-                     label='Categorized',
-                     show_label=True,
-                     show_copy_button=True,
-                     lines=5
-                 )
-                 tags_json = gr.JSON(
-                     label='Categorized Tags (JSON)',
-                     visible=True
-                 )
-                 rating = gr.Label(label='Rating')
-                 general_res = gr.Textbox(
-                     label="General tags",
-                     show_copy_button=True,
-                     lines=3,
-                     visible=False # Temp
-                 )
-                 # State to store results
-                 tag_results = gr.State({})
-
-         # Event handlers
-         image_input.change(
-             append_gallery,
-             inputs=[gallery, image_input],
-             outputs=[gallery, image_input]
-         )
-
-         upload_button.upload(
-             extend_gallery,
-             inputs=[gallery, upload_button],
-             outputs=gallery
-         )
-
-         gallery.select(
-             get_selection_from_gallery,
-             inputs=[gallery, tag_results],
-             outputs=[sorted_general_strings, rating, character_res, general_res, categorized_strings, tags_json]
-         )
-
-         submit.click(
-             predictor.predict,
-             inputs=[
-                 gallery, model_repo, model_repo_2, general_thresh, general_mcut_enabled,
-                 character_thresh, character_mcut_enabled, characters_merge_enabled,
-                 additional_tags_prepend, additional_tags_append, tag_results
-             ],
-             outputs=[download_file, sorted_general_strings, rating, character_res, general_res, categorized_strings, tags_json, tag_results]
-         )
-
-         gr.Examples(
-             [['images/1girl.png', EVA02_LARGE_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
-             inputs=[image_input, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled]
-         )
-         gr.Markdown('[Based on SmilingWolf/wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger) <p style="text-align:right"><a href="https://huggingface.co/spaces/John6666/danbooru-tags-transformer-v2-with-wd-tagger-b">Prompt Enhancer</a></p>')
-     with gr.Tab("PixAI"):
-         pixai_interface = create_pixai_interface()
-     with gr.Tab("Booru Image Fetcher"):
-         booru_interface = create_booru_interface()
-
-     gr.Markdown(NEXT_RESTART)
-
- demo.queue(max_size=5).launch(show_error=True)
+ import os, io, json, requests, spaces, argparse, traceback, tempfile, zipfile, re, ast, time
+ import gradio as gr
+ import numpy as np
+ import huggingface_hub
+ import onnxruntime as ort
+ import pandas as pd
+ from datetime import datetime, timezone
+ from collections import defaultdict
+ from PIL import Image, ImageOps
+ from apscheduler.schedulers.background import BackgroundScheduler
+ from modules.classifyTags import categorize_tags_output, generate_tags_json, process_tags_for_misc
+ from modules.pixai import create_pixai_interface
+ from modules.booru import create_booru_interface
+ from modules.multi_comfy import create_multi_comfy
+ from modules.media_handler import handle_single_media_upload, handle_multiple_media_uploads
+
+ """ For GPU install all the requirements.txt and the following:
+ pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
+ pip install onnxruntime-gpu
+ """
+
+ """ It's recommended to create a venv:
+ python -m venv venv
+ venv\Scripts\activate
+ pip install ...
+ python app.py
+ """
+
+ TITLE = 'Multi-Tagger v1.4'
+ DESCRIPTION = '\nMulti-Tagger is a versatile application for advanced image analysis and captioning. Supports <b>CUDA</b> and <b>CPU</b>.\n'
+
+ SWINV2_MODEL_DSV3_REPO = 'SmilingWolf/wd-swinv2-tagger-v3'
+ CONV_MODEL_DSV3_REPO = 'SmilingWolf/wd-convnext-tagger-v3'
+ VIT_MODEL_DSV3_REPO = 'SmilingWolf/wd-vit-tagger-v3'
+ VIT_LARGE_MODEL_DSV3_REPO = 'SmilingWolf/wd-vit-large-tagger-v3'
+ EVA02_LARGE_MODEL_DSV3_REPO = 'SmilingWolf/wd-eva02-large-tagger-v3'
+ MOAT_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-moat-tagger-v2'
+ SWIN_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-swinv2-tagger-v2'
+ CONV_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
+ CONV2_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-convnextv2-tagger-v2'
+ VIT_MODEL_DSV2_REPO = 'SmilingWolf/wd-v1-4-vit-tagger-v2'
+ EVA02_LARGE_MODEL_IS_DSV1_REPO = 'deepghs/idolsankaku-eva02-large-tagger-v1'
+ SWINV2_MODEL_IS_DSV1_REPO = 'deepghs/idolsankaku-swinv2-tagger-v1'
+
+ # Global variables for model components (for memory management)
+ CURRENT_MODEL = None
+ CURRENT_MODEL_NAME = None
+ CURRENT_TAGS_DF = None
+ CURRENT_TAG_NAMES = None
+ CURRENT_RATING_INDEXES = None
+ CURRENT_GENERAL_INDEXES = None
+ CURRENT_CHARACTER_INDEXES = None
+ CURRENT_MODEL_TARGET_SIZE = None
+
+ # Custom CSS for gallery styling
+ css = """
+ #custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;}
+ #custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease, box-shadow 0.2s ease;}
+ #custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);}
+ #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: contain;margin: 0 auto;display: block;}
+ #custom-gallery .thumbnail-item img.portrait {max-width: 100%;}
+ #custom-gallery .thumbnail-item img.landscape {max-height: 100%;}
+ .gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;}
+ .thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;}
+ #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;width: fit-content;margin: 0px auto;display: block;}
+ """
+
+ MODEL_FILENAME = 'model.onnx'
+ LABEL_FILENAME = 'selected_tags.csv'
+
+ class Timer:
+     """Utility class for measuring execution time of different operations"""
+
+     def __init__(self):
+         self.start_time = time.perf_counter()
+         self.checkpoints = [('Start', self.start_time)]
+
+     def checkpoint(self, label='Checkpoint'):
+         """Add a checkpoint with a label"""
+         now = time.perf_counter()
+         self.checkpoints.append((label, now))
+
+     def report(self, is_clear_checkpoints=True):
+         """Report time elapsed since last checkpoint"""
+         max_label_length = max(len(label) for (label, _) in self.checkpoints) if self.checkpoints else 0
+         prev_time = self.checkpoints[0][1] if self.checkpoints else self.start_time
+
+         for (label, curr_time) in self.checkpoints[1:]:
+             elapsed = curr_time - prev_time
+             print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
+             prev_time = curr_time
+
+         if is_clear_checkpoints:
+             self.checkpoints.clear()
+             self.checkpoint()
+
+     def report_all(self):
+         """Report all checkpoint times including total execution time"""
+         print('\n> Execution Time Report:')
+         max_label_length = max(len(label) for (label, _) in self.checkpoints) if len(self.checkpoints) > 0 else 0
+         prev_time = self.start_time
+
+         for (label, curr_time) in self.checkpoints[1:]:
+             elapsed = curr_time - prev_time
+             print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
+             prev_time = curr_time
+
+         total_time = self.checkpoints[-1][1] - self.start_time if self.checkpoints else 0
+         print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
+         self.checkpoints.clear()
+
+     def restart(self):
+         """Restart the timer"""
+         self.start_time = time.perf_counter()
+         self.checkpoints = [('Start', self.start_time)]
+
+ def parse_args() -> argparse.Namespace:
+     """Parse command line arguments"""
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--score-slider-step', type=float, default=0.05)
+     parser.add_argument('--score-general-threshold', type=float, default=0.35)
+     parser.add_argument('--score-character-threshold', type=float, default=0.85)
+     parser.add_argument('--share', action='store_true')
+     return parser.parse_args()
+
+ def load_labels(dataframe) -> tuple:
+     """Load tag names and their category indexes from the dataframe"""
+     name_series = dataframe['name']
+     tag_names = name_series.tolist()
+
+     # Find indexes for different tag categories
+     rating_indexes = list(np.where(dataframe['category'] == 9)[0])
+     general_indexes = list(np.where(dataframe['category'] == 0)[0])
+     character_indexes = list(np.where(dataframe['category'] == 4)[0])
+
+     return tag_names, rating_indexes, general_indexes, character_indexes
+
+ def mcut_threshold(probs):
+     """Calculate threshold using Maximum Change in second derivative (MCut) method"""
+     sorted_probs = probs[probs.argsort()[::-1]]
+     difs = sorted_probs[:-1] - sorted_probs[1:]
+     t = difs.argmax()
+     thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
+     return thresh
+
+ def _download_model_files(model_repo):
+     """Download model files from HuggingFace Hub"""
+     csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
+     model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
+     return csv_path, model_path
+
+ def create_optimized_ort_session(model_path):
+     """Create an optimized ONNX Runtime session with GPU support"""
+     # Configure session options for better performance
+     sess_options = ort.SessionOptions()
+     sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+     sess_options.intra_op_num_threads = 0 # Use all available cores
+     sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
+     sess_options.enable_mem_pattern = True
+     sess_options.enable_cpu_mem_arena = True
+
+     # Check available providers
+     available_providers = ort.get_available_providers()
+     print(f"Available ONNX Runtime providers: {available_providers}")
+
+     # Configure execution providers (prefer CUDA if available)
+     providers = []
+
+     # Use CUDA if available
+     if 'CUDAExecutionProvider' in available_providers:
+         providers.append('CUDAExecutionProvider')
+         print("Using CUDA provider for ONNX inference")
+     else:
+         print("CUDA provider not available, falling back to CPU")
+
+     # Always include CPU as fallback
+     providers.append('CPUExecutionProvider')
+
+     try:
+         session = ort.InferenceSession(model_path, sess_options, providers=providers)
+         print(f"Model loaded with providers: {session.get_providers()}")
+         return session
+     except Exception as e:
+         print(f"Failed to create ONNX session: {e}")
+         raise
+
+ def _load_model_components_optimized(model_repo):
+     """Load and optimize model components"""
+     global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_TAG_NAMES
+     global CURRENT_RATING_INDEXES, CURRENT_GENERAL_INDEXES, CURRENT_CHARACTER_INDEXES, CURRENT_MODEL_TARGET_SIZE
+
+     # Only reload if model changed
+     if model_repo == CURRENT_MODEL_NAME and CURRENT_MODEL is not None:
+         return
+
+     # Download files
+     csv_path, model_path = _download_model_files(model_repo)
+
+     # Load optimized ONNX model
+     CURRENT_MODEL = create_optimized_ort_session(model_path)
+
+     # Load tags
+     tags_df = pd.read_csv(csv_path)
+     tag_names, rating_indexes, general_indexes, character_indexes = load_labels(tags_df)
+
+     # Store in global variables
+     CURRENT_TAGS_DF = tags_df
+     CURRENT_TAG_NAMES = tag_names
+     CURRENT_RATING_INDEXES = rating_indexes
+     CURRENT_GENERAL_INDEXES = general_indexes
+     CURRENT_CHARACTER_INDEXES = character_indexes
+
+     # Get model input size
+     _, height, width, _ = CURRENT_MODEL.get_inputs()[0].shape
+     CURRENT_MODEL_TARGET_SIZE = height
+     CURRENT_MODEL_NAME = model_repo
+
+ def _raw_predict(image_array, model_session):
+     """Run raw prediction using the model session"""
+     input_name = model_session.get_inputs()[0].name
+     label_name = model_session.get_outputs()[0].name
+     preds = model_session.run([label_name], {input_name: image_array})[0]
+     return preds[0].astype(float)
+
+ def unload_model():
+     """Explicitly unload the current model from memory"""
+     global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_TAG_NAMES
+     global CURRENT_RATING_INDEXES, CURRENT_GENERAL_INDEXES, CURRENT_CHARACTER_INDEXES, CURRENT_MODEL_TARGET_SIZE
+
+     # Delete the model session
+     if CURRENT_MODEL is not None:
+         del CURRENT_MODEL
+         CURRENT_MODEL = None
+
+     # Clear other large objects
+     CURRENT_TAGS_DF = None
+     CURRENT_TAG_NAMES = None
+     CURRENT_RATING_INDEXES = None
+     CURRENT_GENERAL_INDEXES = None
+     CURRENT_CHARACTER_INDEXES = None
+     CURRENT_MODEL_TARGET_SIZE = None
+     CURRENT_MODEL_NAME = None
+
+     # Force garbage collection
+     import gc
+     gc.collect()
+
+     # Clear CUDA cache if using GPU
+     try:
+         import torch
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+     except ImportError:
+         pass
+
+ def cleanup_after_processing():
+     """Cleanup resources after processing"""
+     unload_model()
+
+ class Predictor:
+     """Main predictor class for handling image tagging"""
+
+     def __init__(self):
+         self.model_components = None
+         self.last_loaded_repo = None
+
+     def load_model(self, model_repo):
+         """Load model if not already loaded"""
+         if model_repo == self.last_loaded_repo and self.model_components is not None:
+             return
+         _load_model_components_optimized(model_repo)
+         self.last_loaded_repo = model_repo
+
+     def prepare_image(self, path):
+         """Prepare image for model input"""
+         image = Image.open(path)
+         image = image.convert('RGBA')
+         target_size = CURRENT_MODEL_TARGET_SIZE
+
+         # Create white background and composite
+         canvas = Image.new('RGBA', image.size, (255, 255, 255))
+         canvas.alpha_composite(image)
+         image = canvas.convert('RGB')
+
+         # Pad to square
+         image_shape = image.size
+         max_dim = max(image_shape)
+         pad_left = (max_dim - image_shape[0]) // 2
+         pad_top = (max_dim - image_shape[1]) // 2
+         padded_image = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
+         padded_image.paste(image, (pad_left, pad_top))
+
+         # Resize if needed
+         if max_dim != target_size:
+             padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
+
+         # Convert to array and preprocess
+         image_array = np.asarray(padded_image, dtype=np.float32)
+         image_array = image_array[:, :, ::-1] # BGR to RGB
+         return np.expand_dims(image_array, axis=0)
+
+     def create_file(self, content: str, directory: str, fileName: str) -> str:
+         """Create a file with the given content"""
+         file_path = os.path.join(directory, fileName)
+         if fileName.endswith('.json'):
+             with open(file_path, 'w', encoding='utf-8') as file:
+                 file.write(content)
+         else:
+             with open(file_path, 'w+', encoding='utf-8') as file:
+                 file.write(content)
+         return file_path
+
+     def predict(self, gallery, model_repo, model_repo_2, general_thresh, general_mcut_enabled,
+                 character_thresh, character_mcut_enabled, characters_merge_enabled,
+                 additional_tags_prepend, additional_tags_append, tag_results, progress=gr.Progress()):
+         """Main prediction function for processing images"""
+         tag_results.clear()
+         gallery_len = len(gallery)
+         print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
+
+         timer = Timer()
+         progressRatio = 1
+         progressTotal = gallery_len + 1
+         current_progress = 0
+         txt_infos = []
+         output_dir = tempfile.mkdtemp()
+
+         if not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+
+         # Load initial model
+         self.load_model(model_repo)
+         current_progress += progressRatio / progressTotal
+         progress(current_progress, desc='Initialize wd model finished')
+         timer.checkpoint("Initialize wd model")
+         timer.report()
+
+         name_counters = defaultdict(int)
+
+         for (idx, value) in enumerate(gallery):
+             try:
+                 # Handle duplicate filenames
+                 image_path = value[0]
+                 image_name = os.path.splitext(os.path.basename(image_path))[0]
+                 name_counters[image_name] += 1
+                 if name_counters[image_name] > 1:
+                     image_name = f"{image_name}_{name_counters[image_name]:02d}"
+
+                 # Prepare image
+                 image = self.prepare_image(image_path)
+                 print(f"Gallery {idx:02d}: Starting run first model ({model_repo})...")
+
+                 # Load and run first model
+                 self.load_model(model_repo)
+                 preds = _raw_predict(image, CURRENT_MODEL)
+                 labels = list(zip(CURRENT_TAG_NAMES, preds))
+
+                 # Process ratings
+                 ratings_names = [labels[i] for i in CURRENT_RATING_INDEXES]
+                 rating = dict(ratings_names)
+
+                 # Process general tags
+                 general_names = [labels[i] for i in CURRENT_GENERAL_INDEXES]
+                 if general_mcut_enabled:
+                     general_probs = np.array([x[1] for x in general_names])
+                     general_thresh_temp = mcut_threshold(general_probs)
+                 else:
+                     general_thresh_temp = general_thresh
+
+                 general_res = [x for x in general_names if x[1] > general_thresh_temp]
+                 general_res = dict(general_res)
+
+                 # Process character tags
+                 character_names = [labels[i] for i in CURRENT_CHARACTER_INDEXES]
+                 if character_mcut_enabled:
+                     character_probs = np.array([x[1] for x in character_names])
+                     character_thresh_temp = mcut_threshold(character_probs)
+                     character_thresh_temp = max(0.15, character_thresh_temp)
+                 else:
+                     character_thresh_temp = character_thresh
+
+                 character_res = [x for x in character_names if x[1] > character_thresh_temp]
+                 character_res = dict(character_res)
+                 character_list_1 = list(character_res.keys())
+
+                 # Sort general tags by confidence
+                 sorted_general_list_1 = sorted(general_res.items(), key=lambda x: x[1], reverse=True)
+                 sorted_general_list_1 = [x[0] for x in sorted_general_list_1]
+
+                 # Handle second model if provided
+                 if model_repo_2 and model_repo_2 != model_repo:
+                     print(f"Gallery {idx:02d}: Starting run second model ({model_repo_2})...")
+                     self.load_model(model_repo_2)
+                     preds_2 = _raw_predict(image, CURRENT_MODEL)
+                     labels_2 = list(zip(CURRENT_TAG_NAMES, preds_2))
+
+                     # Process general tags from second model
+                     general_names_2 = [labels_2[i] for i in CURRENT_GENERAL_INDEXES]
+                     if general_mcut_enabled:
+                         general_probs_2 = np.array([x[1] for x in general_names_2])
+                         general_thresh_temp_2 = mcut_threshold(general_probs_2)
+                     else:
+                         general_thresh_temp_2 = general_thresh
+
+                     general_res_2 = [x for x in general_names_2 if x[1] > general_thresh_temp_2]
+                     general_res_2 = dict(general_res_2)
+
+                     # Process character tags from second model
+                     character_names_2 = [labels_2[i] for i in CURRENT_CHARACTER_INDEXES]
+                     if character_mcut_enabled:
+                         character_probs_2 = np.array([x[1] for x in character_names_2])
+                         character_thresh_temp_2 = mcut_threshold(character_probs_2)
+                         character_thresh_temp_2 = max(0.15, character_thresh_temp_2)
+                     else:
+                         character_thresh_temp_2 = character_thresh
+
+                     character_res_2 = [x for x in character_names_2 if x[1] > character_thresh_temp_2]
+                     character_res_2 = dict(character_res_2)
+                     character_list_2 = list(character_res_2.keys())
+
+                     # Sort general tags from second model
+                     sorted_general_list_2 = sorted(general_res_2.items(), key=lambda x: x[1], reverse=True)
+                     sorted_general_list_2 = [x[0] for x in sorted_general_list_2]
+
+                     # Combine results from both models
+                     combined_character_list = list(set(character_list_1 + character_list_2))
+                     combined_general_list = list(set(sorted_general_list_1 + sorted_general_list_2))
+                 else:
+                     combined_character_list = character_list_1
+                     combined_general_list = sorted_general_list_1
+
+                 # Remove characters from general tags if merging is disabled
+                 if not characters_merge_enabled:
+                     combined_character_list = [item for item in combined_character_list
+                                                if item not in combined_general_list]
+
+                 # Handle additional tags
+                 prepend_list = [tag.strip() for tag in additional_tags_prepend.split(',') if tag.strip()]
+                 append_list = [tag.strip() for tag in additional_tags_append.split(',') if tag.strip()]
+
+                 # Avoid duplicates in prepend/append lists
+                 if prepend_list and append_list:
+                     append_list = [item for item in append_list if item not in prepend_list]
+
+                 # Remove prepended tags from main list
+                 if prepend_list:
+                     combined_general_list = [item for item in combined_general_list if item not in prepend_list]
+
+                 # Remove appended tags from main list
+                 if append_list:
+                     combined_general_list = [item for item in combined_general_list if item not in append_list]
+
+                 # Combine all tags
+                 combined_general_list = prepend_list + combined_general_list + append_list
+
+                 # Format output string
+                 sorted_general_strings = ', '.join(
+                     (combined_character_list if characters_merge_enabled else []) +
+                     combined_general_list
+                 ).replace('(', '\\(').replace(')', '\\)').replace('_', ' ')
+
+                 # Generate categorized output
+                 categorized_strings = categorize_tags_output(sorted_general_strings, character_res).replace('(', '\\(').replace(')', '\\)')
+                 categorized_json = generate_tags_json(sorted_general_strings, character_res)
+
+                 # Create output files
+                 txt_content = f"Output (string): {sorted_general_strings}\n\nCategorized Output: {categorized_strings}"
+                 txt_file = self.create_file(txt_content, output_dir, f"{image_name}_output.txt")
+                 txt_infos.append({'path': txt_file, 'name': f"{image_name}_output.txt"})
+
+                 # Save image copy
+                 image_path = value[0]
+                 image = Image.open(image_path)
+                 image.save(os.path.join(output_dir, f"{image_name}.png"), format='PNG')
+                 txt_infos.append({'path': os.path.join(output_dir, f"{image_name}.png"), 'name': f"{image_name}.png"})
+
+                 # Create tags text file
+                 txt_file = self.create_file(sorted_general_strings, output_dir, image_name + '.txt')
+                 # Create categorized tags file
+                 categorized_file = self.create_file(categorized_strings, output_dir, f"{image_name}_categorized.txt")
+                 txt_infos.append({'path': categorized_file, 'name': f"{image_name}_categorized.txt"})
+                 txt_infos.append({'path': txt_file, 'name': image_name + '.txt'})
+
+                 # Create JSON file
+                 json_content = json.dumps(categorized_json, indent=2, ensure_ascii=False)
+                 json_file = self.create_file(json_content, output_dir, f"{image_name}_categorized.json")
+                 txt_infos.append({'path': json_file, 'name': f"{image_name}_categorized.json"})
+
+                 # Store results
+                 tag_results[image_path] = {
+                     'strings': sorted_general_strings,
+                     'categorized_strings': categorized_strings,
+                     'categorized_json': categorized_json,
+                     'rating': rating,
+                     'character_res': character_res,
+                     'general_res': general_res
+                 }
+
+                 # Update progress
+                 current_progress += progressRatio / progressTotal
+                 progress(current_progress, desc=f"image{idx:02d}, predict finished")
+                 timer.checkpoint(f"image{idx:02d}, predict finished")
+                 timer.report()
+
+             except Exception as e:
+                 print(traceback.format_exc())
+                 print('Error predict: ' + str(e))
+
+         # Create download zip
+         download = []
+         if txt_infos is not None and len(txt_infos) > 0:
+             downloadZipPath = os.path.join(
+                 output_dir,
+                 'Multi-Tagger-' + datetime.now().strftime('%Y%m%d-%H%M%S') + '.zip'
+             )
+             with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
+                 for info in txt_infos:
+                     taggers_zip.write(info['path'], arcname=info['name'])
+             # If using GPU, model will auto unload after zip file creation
+             cleanup_after_processing() # Comment here to turn off this behavior
+             download.append(downloadZipPath)
+
+         progress(1, desc=f"Predict completed")
+         timer.report_all()
+         print('Predict is complete.')
+
+         # Return first image results as default
+         first_image_results = '', {}, {}, {}, '', {}
+         if gallery and len(gallery) > 0:
+             first_image_path = gallery[0][0]
+             if first_image_path in tag_results:
+                 first_result = tag_results[first_image_path]
+                 character_tags_formatted = ", ".join([name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
+                                                       for name in first_result['character_res'].keys()])
+                 first_image_results = (
+                     first_result['strings'],
+                     first_result['rating'],
+                     character_tags_formatted,
+                     first_result['general_res'],
+                     first_result.get('categorized_strings', ''),
+                     first_result.get('categorized_json', {})
+                 )
+
+
+         return (
+             download,
+             first_image_results[0],
+             first_image_results[1],
+             first_image_results[2],
+             first_image_results[3],
+             first_image_results[4],
+             first_image_results[5],
+             tag_results
+         )
+
+ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
+     # Return first image results if no selection
+     if not selected_state and gallery and len(gallery) > 0:
+         first_image_path = gallery[0][0]
+         if first_image_path in tag_results:
+             first_result = tag_results[first_image_path]
+             character_tags_formatted = ", ".join([name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
+                                                   for name in first_result['character_res'].keys()])
+             return (
+                 first_result['strings'],
+                 first_result['rating'],
+                 character_tags_formatted,
+                 first_result['general_res'],
+                 first_result.get('categorized_strings', ''),
+                 first_result.get('categorized_json', {})
+             )
+
+     if not selected_state:
+         return '', {}, '', {}, '', {}
+
+     # Get selected image path
+     selected_value = selected_state.value
+     image_path = None
+
+     if isinstance(selected_value, dict) and 'image' in selected_value:
+         image_path = selected_value['image']['path']
+     elif isinstance(selected_value, (list, tuple)) and len(selected_value) > 0:
+         image_path = selected_value[0]
+     else:
+         image_path = str(selected_value)
+
+     # Return stored results
+     if image_path in tag_results:
+         result = tag_results[image_path]
+
+         character_tags_formatted = ", ".join([name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
+                                               for name in result['character_res'].keys()])
+         return (
+             result['strings'],
+             result['rating'],
+             character_tags_formatted,
+             result['general_res'],
+             result.get('categorized_strings', ''),
+             result.get('categorized_json', {})
+         )
+
+     return '', {}, '', {}, '', {}
+
+ def append_gallery(gallery: list, image: str):
+     """Add a single media file (image or video) to the gallery"""
+     return handle_single_media_upload(image, gallery)
+
+ def extend_gallery(gallery: list, images):
+     """Add multiple media files (images or videos) to the gallery"""
+     return handle_multiple_media_uploads(images, gallery)
+
+ # Parse arguments and initialize predictor
+ args = parse_args()
+ predictor = Predictor()
+ dropdown_list = [
+     EVA02_LARGE_MODEL_DSV3_REPO, VIT_LARGE_MODEL_DSV3_REPO, SWINV2_MODEL_DSV3_REPO,
+     CONV_MODEL_DSV3_REPO, VIT_MODEL_DSV3_REPO, MOAT_MODEL_DSV2_REPO,
+     SWIN_MODEL_DSV2_REPO, CONV_MODEL_DSV2_REPO, CONV2_MODEL_DSV2_REPO,
+     VIT_MODEL_DSV2_REPO, EVA02_LARGE_MODEL_IS_DSV1_REPO, SWINV2_MODEL_IS_DSV1_REPO
+ ]
+
+ def _restart_space():
+     """Restart the HuggingFace Space periodically for stability"""
+     HF_TOKEN = os.getenv('HF_TOKEN')
+     if not HF_TOKEN:
+         raise ValueError('HF_TOKEN environment variable is not set.')
+     huggingface_hub.HfApi().restart_space(
+         repo_id='Werli/Multi-Tagger',
+         token=HF_TOKEN,
+         factory_reboot=False
+     )
+
+ # Setup scheduler for periodic restarts
+ scheduler = BackgroundScheduler()
+ restart_space_job = scheduler.add_job(_restart_space, 'interval', seconds=172800)
+ scheduler.start()
+ next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
+ NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."
+
+
+ with gr.Blocks(title=TITLE, css=css, theme="Werli/Purple-Crimson-Gradio-Theme", fill_width=True) as demo:
+     gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
+     gr.Markdown(value=f"<p style='text-align: center;'>{DESCRIPTION}</p>")
+
+     with gr.Tab(label='Waifu Diffusion'):
+         with gr.Row():
+             with gr.Column():
+
+                 with gr.Column(variant='panel'):
+                     image_input = gr.Image(
+                         label='Upload an Image (or paste from clipboard)',
+                         type='filepath',
+                         sources=['upload', 'clipboard'],
654
+ height=150
655
+ )
656
+ with gr.Row():
657
+ upload_button = gr.UploadButton(
658
+ 'Upload multiple images or videos',
659
+ file_types=['image', 'video'],
660
+ file_count='multiple',
661
+ size='sm'
662
+ )
663
+ gallery = gr.Gallery(
664
+ columns=2,
665
+ show_share_button=False,
666
+ interactive=True,
667
+ height='auto',
668
+ label='Grid of images',
669
+ preview=False,
670
+ elem_id='custom-gallery'
671
+ )
672
+ submit = gr.Button(value='Analyze Images', variant='primary', size='lg')
673
+ with gr.Column(variant='panel'):
674
+ model_repo = gr.Dropdown(
675
+ dropdown_list,
676
+ value=EVA02_LARGE_MODEL_DSV3_REPO,
677
+ label='1st Model'
678
+ )
679
+ PLUS = '+?'
680
+ gr.Markdown(value=f"<p style='text-align: center;'>{PLUS}</p>")
681
+ model_repo_2 = gr.Dropdown(
682
+ [None] + dropdown_list,
683
+ value=None,
684
+ label='2nd Model (Optional)',
685
+ info='Select another model for diversified results.'
686
+ )
687
+
688
+ with gr.Row():
689
+ general_thresh = gr.Slider(
690
+ 0, 1,
691
+ step=args.score_slider_step,
692
+ value=args.score_general_threshold,
693
+ label='General Tags Threshold',
694
+ scale=3
695
+ )
696
+ general_mcut_enabled = gr.Checkbox(
697
+ value=False,
698
+ label='Use MCut threshold',
699
+ scale=1
700
+ )
701
+
702
+ with gr.Row():
703
+ character_thresh = gr.Slider(
704
+ 0, 1,
705
+ step=args.score_slider_step,
706
+ value=args.score_character_threshold,
707
+ label='Character Tags Threshold',
708
+ scale=3
709
+ )
710
+ character_mcut_enabled = gr.Checkbox(
711
+ value=False,
712
+ label='Use MCut threshold',
713
+ scale=1
714
+ )
715
+
716
+ with gr.Row():
717
+ characters_merge_enabled = gr.Checkbox(
718
+ value=False,
719
+ label='Merge characters into the string output',
720
+ scale=1
721
+ )
722
+
723
+ with gr.Row():
724
+ additional_tags_prepend = gr.Text(
725
+ label='Prepend Additional tags (comma split)'
726
+ )
727
+ additional_tags_append = gr.Text(
728
+ label='Append Additional tags (comma split)'
729
+ )
730
+
731
+ with gr.Row():
732
+ clear = gr.ClearButton(
733
+ components=[
734
+ gallery, model_repo, general_thresh, general_mcut_enabled,
735
+ character_thresh, character_mcut_enabled, characters_merge_enabled,
736
+ additional_tags_prepend, additional_tags_append
737
+ ],
738
+ variant='secondary',
739
+ size='lg'
740
+ )
741
+
742
+ with gr.Column(variant='panel'):
743
+ download_file = gr.File(label='Download')
744
+ character_res = gr.Textbox(
745
+ label="Character tags",
746
+ show_copy_button=True,
747
+ lines=3
748
+ )
749
+ sorted_general_strings = gr.Textbox(
750
+ label='Output',
751
+ show_label=True,
752
+ show_copy_button=True,
753
+ lines=5
754
+ )
755
+ categorized_strings = gr.Textbox(
756
+ label='Categorized',
757
+ show_label=True,
758
+ show_copy_button=True,
759
+ lines=5
760
+ )
761
+ tags_json = gr.JSON(
762
+ label='Categorized Tags (JSON)',
763
+ visible=True
764
+ )
765
+ rating = gr.Label(label='Rating')
766
+ general_res = gr.Textbox(
767
+ label="General tags",
768
+ show_copy_button=True,
769
+ lines=3,
770
+ visible=False # Temp
771
+ )
772
+ # State to store results
773
+ tag_results = gr.State({})
774
+
775
+ # Event handlers
776
+ image_input.change(
777
+ append_gallery,
778
+ inputs=[gallery, image_input],
779
+ outputs=[gallery, image_input]
780
+ )
781
+
782
+ upload_button.upload(
783
+ extend_gallery,
784
+ inputs=[gallery, upload_button],
785
+ outputs=gallery
786
+ )
787
+
788
+ gallery.select(
789
+ get_selection_from_gallery,
790
+ inputs=[gallery, tag_results],
791
+ outputs=[sorted_general_strings, rating, character_res, general_res, categorized_strings, tags_json]
792
+ )
793
+
794
+ submit.click(
795
+ predictor.predict,
796
+ inputs=[
797
+ gallery, model_repo, model_repo_2, general_thresh, general_mcut_enabled,
798
+ character_thresh, character_mcut_enabled, characters_merge_enabled,
799
+ additional_tags_prepend, additional_tags_append, tag_results
800
+ ],
801
+ outputs=[download_file, sorted_general_strings, rating, character_res, general_res, categorized_strings, tags_json, tag_results]
802
+ )
803
+
804
+ gr.Examples(
805
+ [['images/1girl.png', EVA02_LARGE_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
806
+ inputs=[image_input, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled]
807
+ )
808
+ gr.Markdown('[Based on SmilingWolf/wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger) <p style="text-align:right"><a href="https://huggingface.co/spaces/John6666/danbooru-tags-transformer-v2-with-wd-tagger-b">Prompt Enhancer</a></p>')
809
+ with gr.Tab("PixAI"):
810
+ pixai_interface = create_pixai_interface()
811
+ with gr.Tab("Booru Image Fetcher"):
812
+ booru_interface = create_booru_interface()
813
+ with gr.Tab("ComfyUI Extractor"):
814
+ comfy_interface = create_multi_comfy()
815
+ with gr.Tab(label="Misc"):
816
+ with gr.Row():
817
+ with gr.Column(variant="panel"):
818
+ tag_string = gr.Textbox(
819
+ label="Input Tags",
820
+ placeholder="1girl, cat, horns, blue hair, ...\nor\n? 1girl 1234567? cat 1234567? horns 1234567? blue hair 1234567? ...",
821
+ lines=4
822
+ )
823
+ submit_button = gr.Button(value="START", variant="primary", size="lg")
824
+ with gr.Column(variant="panel"):
825
+ cleaned_tags_output = gr.Textbox(
826
+ label="Cleaned Tags",
827
+ show_label=True,
828
+ show_copy_button=True,
829
+ lines=4,
830
+ info="Tags with ? and numbers removed, formatted with commas. Useful for clearing tags from Booru sites."
831
+ )
832
+ classify_tags_for_display = gr.Textbox(
833
+ label="Categorized (string)",
834
+ show_label=True,
835
+ show_copy_button=True,
836
+ lines=8,
837
+ info="Tags organized by categories"
838
+ )
839
+ generate_categorized_json = gr.JSON(
840
+ label="Categorized JSON (tags)"
841
+ )
842
+
843
+ # Fix the event handler to properly call the function
844
+ submit_button.click(
845
+ process_tags_for_misc,
846
+ inputs=[tag_string],
847
+ outputs=[cleaned_tags_output, classify_tags_for_display, generate_categorized_json]
848
+ )
849
+ gr.Markdown(NEXT_RESTART)
850
+
851
+ demo.queue(max_size=5).launch(show_error=True)
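
For reference, each entry in the `tag_results` state that `predict` fills and `get_selection_from_gallery` reads back is keyed by the gallery image path. A minimal sketch of the expected shape (the keys match the `tag_results[image_path] = {...}` assignment above; the path and values here are made up for illustration, not real model output):

    # Illustrative only: one tag_results entry as stored by predict().
    tag_results = {
        '/tmp/gradio/1girl.png': {
            'strings': '1girl, solo, long hair',             # flat tag string
            'categorized_strings': '1girl, long hair, solo', # reordered by category
            'categorized_json': {'Subject': ['1girl']},      # per-category dict
            'rating': {'general': 0.92},                     # label -> confidence
            'character_res': {},                             # character tag -> confidence
            'general_res': {'solo': 0.98, 'long_hair': 0.95} # general tag -> confidence
        }
    }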
modules/classifyTags.py CHANGED
@@ -1,9 +1,10 @@
 import re
 from collections import defaultdict
+from typing import List, Dict, Tuple, Optional, Set
 
 # Test: Define priority tags that should always come first
 PRIORITY_TAGS = [
-    '1girl', '2girls', '3girls', '4girls', '5girls', '6+girls', 'multiple_girls',
+    '1girl', '2girls', '3girls', '4girls', '5girls', '6+girls', 'multiple_girls', '1other',
     '1boy', '2boys', '3boys', '4boys', '5boys', '6+boys', 'multiple_boys',
     'male_focus', 'female_focus', 'other_focus'
 ]
@@ -45,293 +46,465 @@ categories = {
 'Others':['2007', '2008', '2009', '2010', '2011', '2012', '2013', '2014', '2015', '2016', '2017', '2018', '2019', '2020', '2021', '2022', '2023', '2024', 'artist', 'artist_name', 'artistic_error', 'asian', '(company)', 'character_name', 'content_rating', 'copyright', 'cover_page', 'dated', 'english_text', 'japan', 'layer', 'logo', 'name', 'numbered', 'page_number', 'pixiv_id', 'language', 'reference_sheet', 'signature', 'speech_bubble', 'subtitled', 'text', 'thank_you', 'typo', 'username', 'wallpaper', 'watermark', 'web_address', 'screwdriver', 'translated'],
 'Quality Tags':['masterpiece', '_quality', 'highres', 'absurdres', 'ultra-detailed', 'lowres']}
 
-# Build a trie for efficient prefix matching
-class TrieNode:
-    def __init__(self):
-        self.children = {}
-        self.category = None
-
-class TagTrie:
-    def __init__(self):
-        self.root = TrieNode()
-        self._build_trie()
-
-    def _build_trie(self):
-        for category, tags in categories.items():
-            for tag in tags:
-                node = self.root
-                for char in tag:
-                    if char not in node.children:
-                        node.children[char] = TrieNode()
-                    node = node.children[char]
-                node.category = category
-
-    def find_category(self, tag):
-        node = self.root
-        matched_category = None
-
-        # Try exact match first
-        for char in tag:
-            if char in node.children:
-                node = node.children[char]
-                if node.category:
-                    matched_category = node.category
-            else:
-                break
-
-        # If exact match found, return it
-        if matched_category and node.children == {}:
-            return matched_category
-
-        # If partial match found, check if it's a valid prefix
-        if matched_category:
-            return matched_category
-
-        # Try substring matching for longer than 3 characters
-        for i in range(len(tag)):
-            for j in range(i+4, len(tag)+1): # Only check substrings longer than 3 chars
-                substring = tag[i:j]
-                node = self.root
-                valid = True
-                for char in substring:
-                    if char in node.children:
-                        node = node.children[char]
-                    else:
-                        valid = False
-                        break
-                if valid and node.category:
-                    return node.category
-
-        return None
-
-tag_trie = TagTrie()
-
-def normalize_tag(tag):
-    """Normalize tag by converting spaces/hyphens to underscores"""
-    return re.sub(r'[-\s]+', '_', tag.strip())
-
-def classify_single_tag(tag):
-    """Classify a single tag into its category"""
-    normalized_tag = normalize_tag(tag)
-
-    # Try exact match through Trie lookup first
-    category = tag_trie.find_category(normalized_tag)
-
-    # If no match and has underscores, try parts
-    if not category and '_' in normalized_tag:
-        parts = normalized_tag.split('_')
-        for part in parts:
-            if len(part) > 3: # Only check parts longer than 3 characters
-                category = tag_trie.find_category(part)
-                if category:
-                    break
-
-    # Special handling for escaped parentheses
-    if not category and ('\\(' in normalized_tag or '\\)' in normalized_tag):
-        unescaped = normalized_tag.replace('\\(', '(').replace('\\)', ')')
-        category = tag_trie.find_category(unescaped)
-
-        if not category and '_' in unescaped:
-            parts = unescaped.split('_')
-            for part in parts:
-                if len(part) > 3:
-                    category = tag_trie.find_category(part)
-                    if category:
-                        break
-
-    return category if category else 'Uncategorized'
-
-def extract_priority_and_character_tags(tags_list, character_tags):
-    """
-    Extract priority tags and character tags from the tags list
-
-    Args:
-        tags_list (list): List of all tags
-        character_tags (dict): Dictionary of character tags with confidence scores
-
-    Returns:
-        tuple: (priority_tags, character_tag_names, remaining_tags)
-    """
-    priority_tags_found = []
-    character_tag_names = list(character_tags.keys()) if character_tags else []
-    remaining_tags = []
-
-    # Convert priority tags to set for faster lookup
-    priority_set = set(PRIORITY_TAGS)
-
-    for tag in tags_list:
-        if tag in priority_set:
-            priority_tags_found.append(tag)
-        elif tag in character_tag_names:
-            # Character tags are already handled separately
-            remaining_tags.append(tag)
-        else:
-            remaining_tags.append(tag)
-
-    return priority_tags_found, character_tag_names, remaining_tags
-
-def classify_tags_for_display(tag_string, character_tags=None):
-    """
-    Classify a string of tags and organize them by categories with priority ordering for display
-
-    Args:
-        tag_string (str): Comma-separated tags string
-        character_tags (dict): Dictionary of character tags with confidence scores
-
-    Returns:
-        str: Categorized and organized tags as a comma-separated string
-    """
-    if not tag_string:
-        return ""
-
-    # Split tags by common delimiters
-    delimiters = r'[,\n\r\.!?]+'
-    raw_tags = re.split(delimiters, tag_string)
-
-    # Clean and normalize tags
-    cleaned_tags = []
-    for tag in raw_tags:
-        tag = tag.strip()
-        if tag:
-            cleaned_tags.append(tag)
-
-    # Extract priority and character tags
-    priority_tags_found, character_tag_names, remaining_tags = extract_priority_and_character_tags(cleaned_tags, character_tags)
-
-    # Classify remaining tags
-    categorized = defaultdict(list)
-    uncategorized = []
-
-    for tag in remaining_tags:
-        # Skip character tags as they're already in their own list
-        if tag in character_tag_names:
-            continue
-
-        category = classify_single_tag(tag)
-        if category == 'Uncategorized':
-            uncategorized.append(tag)
-        else:
-            categorized[category].append(tag)
-
-    # Build result string with priority ordering
-    result_parts = []
-
-    # 1. Add priority subject tags first
-    result_parts.extend(priority_tags_found)
-
-    # 2. Add character tags next
-    result_parts.extend(character_tag_names)
-
-    # 3. Add categorized tags in category order
-    for category in categories.keys():
-        if category in categorized and categorized[category]:
-            result_parts.extend(categorized[category])
-
-    # 4. Add uncategorized tags at the end
-    result_parts.extend(uncategorized)
-
-    # Process tags: replace underscores with spaces and handle escaped characters
-    processed_tags = []
-    for tag in result_parts:
-        processed_tag = tag.replace('_', ' ').replace('\\(', '(').replace('\\)', ')')
-        processed_tags.append(processed_tag)
-
-    return ', '.join(processed_tags)
-
-def generate_categorized_json(tag_string, character_tags=None):
-    """
-    Generate JSON object organizing tags by categories
-
-    Args:
-        tag_string (str): Comma-separated tags string
-        character_tags (dict): Dictionary of character tags with confidence scores
-
-    Returns:
-        dict: JSON-compatible dictionary with categories as keys and tag lists as values
-    """
-    if not tag_string:
-        return {}
-
-    # Split tags by common delimiters
-    delimiters = r'[,\n\r\.!?]+'
-    raw_tags = re.split(delimiters, tag_string)
-
-    # Clean and normalize tags
-    cleaned_tags = []
-    for tag in raw_tags:
-        tag = tag.strip()
-        if tag:
-            cleaned_tags.append(tag)
-
-    # Extract priority and character tags
-    priority_tags_found, character_tag_names, remaining_tags = extract_priority_and_character_tags(cleaned_tags, character_tags)
-
-    # Classify remaining tags
-    categorized = defaultdict(list)
-    uncategorized = []
-
-    for tag in remaining_tags:
-        # Skip character tags as they're already in their own list
-        if tag in character_tag_names:
-            continue
-
-        category = classify_single_tag(tag)
-        if category == 'Uncategorized':
-            uncategorized.append(tag)
-        else:
-            # Store the original tag (with underscores) for JSON
-            categorized[category].append(tag)
-
-    # Build JSON result
-    json_result = {}
-
-    # Add special categories if they have content
-    if priority_tags_found:
-        # Process priority tags for display (replace underscores with spaces) # Replacement is not 100% necessary, but will do anyway
-        processed_priority = [tag.replace('_', ' ').replace('\\(', '(').replace('\\)', ')') for tag in priority_tags_found]
-        json_result['Subject'] = processed_priority
-
-    if character_tag_names:
-        # Process character tags for display
-        processed_characters = [tag.replace('_', ' ').replace('\\(', '(').replace('\\)', ')') for tag in character_tag_names]
-        json_result['Characters'] = processed_characters
-
-    # Add categorized tags (process for display)
-    for category, tags in categorized.items():
-        if tags:
-            processed_tags = [tag.replace('_', ' ').replace('\\(', '(').replace('\\)', ')') for tag in tags]
-            json_result[category] = processed_tags
-
-    # Add uncategorized tags if any
-    if uncategorized:
-        processed_uncategorized = [tag.replace('_', ' ').replace('\\(', '(').replace('\\)', ')') for tag in uncategorized]
-        json_result['Uncategorized'] = processed_uncategorized
-
-    return json_result
-
-def categorize_tags_output(tag_string, character_tags=None):
-    """
-    Main function to categorize tags output for display
-
-    Args:
-        tag_string (str): Raw tags string from the model
-        character_tags (dict): Dictionary of character tags with confidence scores
-
-    Returns:
-        str: Organized, categorized tags string
-    """
-    return classify_tags_for_display(tag_string, character_tags)
-
-def generate_tags_json(tag_string, character_tags=None):
-    """
-    Main function to generate categorized JSON
-
-    Args:
-        tag_string (str): Raw tags string from the model
-        character_tags (dict): Dictionary of character tags with confidence scores
-
-    Returns:
-        dict: JSON object with categorized tags
-    """
-    return generate_categorized_json(tag_string, character_tags)
+class TagCleaner:
+    """Handles tag cleaning and normalization operations."""
+
+    @staticmethod
+    def clean_raw_tags(tag_string: str) -> List[str]:
+        """
+        Clean raw tags by removing question marks, numbers, and normalizing format.
+
+        Args:
+            tag_string: Raw tags string with potential ? and numbers
+
+        Returns:
+            List of cleaned tags
+        """
+        if not tag_string:
+            return []
+
+        # Split by common delimiters
+        delimiters = r'[,\n\r\.!?]+'
+        raw_tags = re.split(delimiters, tag_string)
+
+        cleaned_tags = []
+
+        for tag in raw_tags:
+            tag = tag.strip()
+            if not tag:
+                continue
+
+            # Remove question marks
+            tag = tag.replace('?', '')
+            # Test
+            # Remove standalone numbers (4+ digit IDs) but keep numbers that are part of tags
+            # This removes "123456" but keeps "3girls", "1boy", and years like "2025".
+            # First, remove patterns like "tag 123456" or "123456 tag" (excluding years)
+            tag = re.sub(r'\s+(?!19\d{2}|20\d{2})\d{4,}\b|\b(?!19\d{2}|20\d{2})\d{4,}\s+', ' ', tag)
+            # Then remove any remaining standalone numbers (excluding years and 3-digit numbers)
+            tag = re.sub(r'\b(?!19\d{2}|20\d{2})\d{4,}\b', '', tag)
+            # Finally, remove any remaining 5+ digit numbers that might be attached to tags
+            tag = re.sub(r'\b\w+\d{5,}\b|\b\d{5,}\w+\b', lambda m: re.sub(r'\d{5,}', '', m.group()), tag)
+
+            # Clean up extra spaces
+            tag = re.sub(r'\s+', ' ', tag).strip()
+
+            if tag: # Only add if tag is not empty after cleaning
+                cleaned_tags.append(tag)
+
+        return cleaned_tags
+
+    @staticmethod
+    def normalize_tag(tag: str) -> str:
+        """
+        Normalize tag by converting spaces/hyphens to underscores.
+
+        Args:
+            tag: Raw tag string
+
+        Returns:
+            Normalized tag string
+        """
+        return re.sub(r'[-\s]+', '_', tag.strip())
+
+    @staticmethod
+    def format_tags_for_display(tags: List[str]) -> str:
+        """
+        Format tags as a comma-separated string for display.
+
+        Args:
+            tags: List of tags
+
+        Returns:
+            Comma-separated string
+        """
+        return ', '.join(tags)
+
+
+class CategoryMatcher:
+    """Optimized category matching using precomputed lookup tables."""
+
+    def __init__(self, categories_dict: Dict[str, List[str]]):
+        """Initialize with categories dictionary."""
+        self.categories = categories_dict
+        self._build_lookup_tables()
+
+    def _build_lookup_tables(self):
+        """Build efficient lookup tables for category matching."""
+        self.tag_to_category = {}
+        self.priority_set = set(PRIORITY_TAGS)
+
+        # Build direct lookup table
+        for category, tags in self.categories.items():
+            for tag in tags:
+                self.tag_to_category[tag] = category
+                # Also add escaped version for matching
+                if '(' in tag or ')' in tag:
+                    escaped_tag = tag.replace('(', '\\(').replace(')', '\\)')
+                    self.tag_to_category[escaped_tag] = category
+
+    def find_category(self, tag: str) -> Optional[str]:
+        """
+        Find the category for a given tag.
+
+        Args:
+            tag: Tag to categorize
+
+        Returns:
+            Category name or None if not found
+        """
+        # Try exact match first
+        if tag in self.tag_to_category:
+            return self.tag_to_category[tag]
+
+        # Try normalized version
+        normalized_tag = TagCleaner.normalize_tag(tag)
+        if normalized_tag in self.tag_to_category:
+            return self.tag_to_category[normalized_tag]
+
+        # Try partial matching for compound tags
+        if '_' in normalized_tag:
+            parts = normalized_tag.split('_')
+            for part in parts:
+                if len(part) > 3 and part in self.tag_to_category:
+                    return self.tag_to_category[part]
+
+        # Try substring matching for longer tags
+        for i in range(len(tag)):
+            for j in range(i + 4, len(tag) + 1): # Only check substrings longer than 3 chars
+                substring = tag[i:j]
+                if substring in self.tag_to_category:
+                    return self.tag_to_category[substring]
+
+        return None
+
+
+class TagClassifier:
+    """Main tag classification engine."""
+
+    def __init__(self, categories_dict: Dict[str, List[str]] = None):
+        """
+        Initialize the classifier.
+
+        Args:
+            categories_dict: Dictionary of categories and their tags
+        """
+        self.categories = categories_dict or categories
+        self.matcher = CategoryMatcher(self.categories)
+        self.priority_set = set(PRIORITY_TAGS)
+
+    def extract_special_tags(self, tags: List[str], character_tags: Optional[Dict] = None) -> Tuple[List[str], List[str], List[str]]:
+        """
+        Extract priority and character tags from the tags list.
+
+        Args:
+            tags: List of all tags
+            character_tags: Dictionary of character tags with confidence scores
+
+        Returns:
+            Tuple of (priority_tags, character_tag_names, remaining_tags)
+        """
+        priority_tags_found = []
+        character_tag_names = list(character_tags.keys()) if character_tags else []
+        remaining_tags = []
+
+        for tag in tags:
+            if tag in self.priority_set:
+                priority_tags_found.append(tag)
+            elif tag in character_tag_names:
+                remaining_tags.append(tag) # Character tags are handled separately
+            else:
+                remaining_tags.append(tag)
+
+        return priority_tags_found, character_tag_names, remaining_tags
+
+    def classify_tags(self, tags: List[str], character_tags: Optional[Dict] = None) -> Dict[str, List[str]]:
+        """
+        Classify tags into categories.
+
+        Args:
+            tags: List of tags to classify
+            character_tags: Dictionary of character tags with confidence scores
+
+        Returns:
+            Dictionary with categories as keys and tag lists as values
+        """
+        # Extract special tags first
+        priority_tags, character_tag_names, remaining_tags = self.extract_special_tags(tags, character_tags)
+
+        # Classify remaining tags
+        categorized = defaultdict(list)
+        uncategorized = []
+
+        for tag in remaining_tags:
+            # Skip character tags as they're handled separately
+            if tag in character_tag_names:
+                continue
+
+            category = self.matcher.find_category(tag)
+            if category:
+                categorized[category].append(tag)
+            else:
+                uncategorized.append(tag)
+
+        # Build result dictionary
+        result = {}
+
+        # Add special categories if they have content
+        if priority_tags:
+            result['Subject'] = priority_tags
+
+        if character_tag_names:
+            result['Characters'] = character_tag_names
+
+        # Add categorized tags
+        for category in self.categories.keys():
+            if category in categorized and categorized[category]:
+                result[category] = categorized[category]
+
+        # Add uncategorized tags if any
+        if uncategorized:
+            result['Uncategorized'] = uncategorized
+
+        return result
+
+    def get_ordered_tags_string(self, tags: List[str], character_tags: Optional[Dict] = None) -> str:
+        """
+        Get tags ordered by priority and categories as a string.
+
+        Args:
+            tags: List of tags to order
+            character_tags: Dictionary of character tags with confidence scores
+
+        Returns:
+            Ordered comma-separated string
+        """
+        # Extract special tags
+        priority_tags, character_tag_names, remaining_tags = self.extract_special_tags(tags, character_tags)
+
+        # Classify remaining tags
+        categorized = defaultdict(list)
+        uncategorized = []
+
+        for tag in remaining_tags:
+            if tag in character_tag_names:
+                continue
+
+            category = self.matcher.find_category(tag)
+            if category and category != 'Uncategorized':
+                categorized[category].append(tag)
+            else:
+                uncategorized.append(tag)
+
+        # Build ordered result
+        result_parts = []
+
+        # 1. Add priority subject tags first
+        result_parts.extend(priority_tags)
+
+        # 2. Add character tags next
+        result_parts.extend(character_tag_names)
+
+        # 3. Add categorized tags in category order
+        for category in self.categories.keys():
+            if category in categorized and categorized[category]:
+                result_parts.extend(categorized[category])
+
+        # 4. Add uncategorized tags at the end
+        result_parts.extend(uncategorized)
+
+        # Process tags for display
+        processed_tags = []
+        for tag in result_parts:
+            processed_tag = tag.replace('_', ' ').replace('\\(', '(').replace('\\)', ')')
+            processed_tags.append(processed_tag)
+
+        return ', '.join(processed_tags)
+
+
+class TagFormatter:
+    """Handles output formatting for different display types."""
+
+    @staticmethod
+    def format_for_display(categorized_tags: Dict[str, List[str]]) -> str:
+        """
+        Format categorized tags as a display string.
+
+        Args:
+            categorized_tags: Dictionary of categorized tags
+
+        Returns:
+            Formatted string for display
+        """
+        result_parts = []
+
+        # Order categories for display
+        display_order = ['Subject', 'Characters'] + [cat for cat in categories.keys() if cat not in ['Subject', 'Characters']] + ['Uncategorized']
+
+        for category in display_order:
+            if category in categorized_tags and categorized_tags[category]:
+                # Process tags for display
+                processed_tags = []
+                for tag in categorized_tags[category]:
+                    processed_tag = tag.replace('_', ' ').replace('\\(', '(').replace('\\)', ')')
+                    processed_tags.append(processed_tag)
+
+                result_parts.extend(processed_tags)
+
+        return ', '.join(result_parts)
+
+    @staticmethod
+    def format_for_json(categorized_tags: Dict[str, List[str]]) -> Dict[str, List[str]]:
+        """
+        Format categorized tags as JSON-compatible dictionary.
+
+        Args:
+            categorized_tags: Dictionary of categorized tags
+
+        Returns:
+            JSON-compatible dictionary
+        """
+        json_result = {}
+
+        for category, tags in categorized_tags.items():
+            if tags:
+                # Process tags for display
+                processed_tags = []
+                for tag in tags:
+                    processed_tag = tag.replace('_', ' ').replace('\\(', '(').replace('\\)', ')')
+                    processed_tags.append(processed_tag)
+
+                json_result[category] = processed_tags
+
+        return json_result
+
+
+# Global classifier instance
+_classifier = TagClassifier()
+_cleaner = TagCleaner()
+_formatter = TagFormatter()
+
+
+# Public API Functions
+def clean_tags(tag_string: str) -> List[str]:
+    """
+    Clean tags by removing question marks and numbers.
+
+    Args:
+        tag_string: Raw tags string with potential ? and numbers
+
+    Returns:
+        List of cleaned tags
+    """
+    return _cleaner.clean_raw_tags(tag_string)
+
+
+def clean_and_format_tags(tag_string: str) -> str:
+    """
+    Clean tags and format them as a comma-separated string.
+
+    Args:
+        tag_string: Raw tags string with potential ? and numbers
+
+    Returns:
+        Comma-separated cleaned tags
+    """
+    cleaned_tags = clean_tags(tag_string)
+    return _cleaner.format_tags_for_display(cleaned_tags)
+
+
+def categorize_tags_output(tag_string: str, character_tags: Optional[Dict] = None) -> str:
+    """
+    Main function to categorize tags output for display.
+
+    Args:
+        tag_string: Raw tags string from the model
+        character_tags: Dictionary of character tags with confidence scores
+
+    Returns:
+        Organized, categorized tags string
+    """
+    # Clean tags first
+    cleaned_tags = clean_tags(tag_string)
+
+    # Get ordered string
+    return _classifier.get_ordered_tags_string(cleaned_tags, character_tags)
+
+
+def generate_tags_json(tag_string: str, character_tags: Optional[Dict] = None) -> Dict[str, List[str]]:
+    """
+    Main function to generate categorized JSON.
+
+    Args:
+        tag_string: Raw tags string from the model
+        character_tags: Dictionary of character tags with confidence scores
+
+    Returns:
+        JSON object with categorized tags
+    """
+    # Clean tags first
+    cleaned_tags = clean_tags(tag_string)
+
+    # Classify tags
+    categorized = _classifier.classify_tags(cleaned_tags, character_tags)
+
+    # Format for JSON
+    return _formatter.format_for_json(categorized)
+
+
+def process_tags_for_misc(tag_string: str) -> Tuple[str, str, Dict[str, List[str]]]:
+    """
+    Process tags for the Misc tab - clean and categorize them.
+
+    Args:
+        tag_string: Raw tags string with potential ? and numbers
+
+    Returns:
+        Tuple of (cleaned_tags_string, categorized_string, categorized_json)
+    """
+    # Clean the tags first
+    cleaned_tags_string = clean_and_format_tags(tag_string)
+
+    # Then categorize the cleaned tags
+    categorized_string = categorize_tags_output(tag_string)
+    categorized_json = generate_tags_json(tag_string)
+
+    return cleaned_tags_string, categorized_string, categorized_json
+
+
+# Legacy compatibility functions
+def classify_tags_for_display(tag_string: str, character_tags: Optional[Dict] = None) -> str:
+    """Legacy function - use categorize_tags_output instead."""
+    return categorize_tags_output(tag_string, character_tags)
+
+
+def generate_categorized_json(tag_string: str, character_tags: Optional[Dict] = None) -> Dict[str, List[str]]:
+    """Legacy function - use generate_tags_json instead."""
+    return generate_tags_json(tag_string, character_tags)
+
+
+"""
+How to test:
+python -c "
+from modules.classifyTags import process_tags_for_misc, clean_tags
+
+# Test example
+test_input = 'tags here'
+
+print('Input:', test_input)
+print()
+
+# Test cleaning
+cleaned = clean_tags(test_input)
+print('Cleaned tags:', cleaned)
+print()
+
+# Test full processing
+cleaned_str, categorized_str, categorized_json = process_tags_for_misc(test_input)
+
+print('Cleaned output:', cleaned_str)
+print('Categorized output:', categorized_str)
+print('Categorized JSON:', categorized_json)
+"
+"""
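
The How-to-test snippet above uses a placeholder input; a concrete run of the same pipeline, with the expected results reasoned out from the cleaning regexes rather than captured from a live session:

    from modules.classifyTags import clean_tags, process_tags_for_misc

    # Booru-style paste: '?' separators and 7-digit post IDs around each tag.
    booru_input = '? 1girl 1234567? cat 1234567? blue hair 1234567'

    print(clean_tags(booru_input))
    # Expected: ['1girl', 'cat', 'blue hair'] - the '?' delimiters and long IDs
    # are stripped, while counts like '1girl' and years like '2024' survive.

    cleaned_str, categorized_str, categorized_json = process_tags_for_misc(booru_input)
    # cleaned_str is the comma-joined list; categorized_str/categorized_json
    # reorder the same tags through TagClassifier and TagFormatter.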
modules/media_handler.py ADDED
@@ -0,0 +1,212 @@
+import os
+import tempfile
+from typing import List, Union, Tuple, Optional
+from modules.video_processor import is_video_file, process_video_upload, SUPPORTED_VIDEO_FORMATS
+
+# Supported image formats
+SUPPORTED_IMAGE_FORMATS = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp', '.gif']
+
+def get_media_type(file_path: str) -> str:
+    """
+    Determine if a file is an image, video, or unsupported.
+
+    Args:
+        file_path: Path to the file
+
+    Returns:
+        'image', 'video', or 'unsupported'
+    """
+    if not file_path:
+        return 'unsupported'
+
+    _, ext = os.path.splitext(file_path.lower())
+
+    if ext in SUPPORTED_IMAGE_FORMATS:
+        return 'image'
+    elif ext in SUPPORTED_VIDEO_FORMATS:
+        return 'video'
+    else:
+        return 'unsupported'
+
+def is_supported_media(file_path: str) -> bool:
+    """Check if the file is a supported image or video format."""
+    return get_media_type(file_path) in ['image', 'video']
+
+def create_gallery_item(file_path: str) -> Optional[Tuple[str, str]]:
+    """
+    Create a gallery-compatible item from a media file path.
+
+    Args:
+        file_path: Path to the media file
+
+    Returns:
+        Tuple suitable for gallery (file_path, filename) or None if unsupported
+    """
+    if not os.path.exists(file_path):
+        return None
+
+    if not is_supported_media(file_path):
+        return None
+
+    filename = os.path.basename(file_path)
+    return (file_path, filename)
+
+def process_single_media_upload(file_path: str, max_video_duration: int = 30, frame_interval: int = 1) -> List[str]:
+    """
+    Process a single media file upload (image or video).
+
+    Args:
+        file_path: Path to the uploaded file
+        max_video_duration: Maximum duration to process for videos (seconds)
+        frame_interval: Interval between frames for videos (seconds)
+
+    Returns:
+        List of file paths to be added to gallery (images or extracted frames)
+    """
+    if not file_path or not os.path.exists(file_path):
+        return []
+
+    media_type = get_media_type(file_path)
+
+    if media_type == 'image':
+        # For images, just return the original path
+        return [file_path]
+    elif media_type == 'video':
+        # For videos, extract frames
+        frame_paths, _ = process_video_upload(file_path, max_video_duration, frame_interval)
+        return frame_paths
+    else:
+        # Unsupported format
+        return []
+
+def process_multiple_media_uploads(
+    file_paths: List[str],
+    max_video_duration: int = 30,
+    frame_interval: int = 1
+) -> List[str]:
+    """
+    Process multiple media file uploads.
+
+    Args:
+        file_paths: List of paths to uploaded files
+        max_video_duration: Maximum duration to process for videos (seconds)
+        frame_interval: Interval between frames for videos (seconds)
+
+    Returns:
+        List of file paths to be added to gallery (images and extracted frames)
+    """
+    all_paths = []
+
+    for file_path in file_paths:
+        processed_paths = process_single_media_upload(file_path, max_video_duration, frame_interval)
+        all_paths.extend(processed_paths)
+
+    return all_paths
+
+def handle_single_media_upload(file_path: str, gallery: List, max_video_duration: int = 30, frame_interval: int = 1) -> Tuple[List, Optional[str]]:
+    """
+    Handle a single media file upload and update gallery.
+
+    Args:
+        file_path: Path to the uploaded file
+        gallery: Current gallery list
+        max_video_duration: Maximum duration to process for videos (seconds)
+        frame_interval: Interval between frames for videos (seconds)
+
+    Returns:
+        Tuple of (updated_gallery, None) for Gradio compatibility
+    """
+    if gallery is None:
+        gallery = []
+
+    if not file_path:
+        return gallery, None
+
+    # Process the media file
+    processed_paths = process_single_media_upload(file_path, max_video_duration, frame_interval)
+
+    # Create gallery items and add to gallery
+    for path in processed_paths:
+        gallery_item = create_gallery_item(path)
+        if gallery_item:
+            gallery.append(gallery_item)
+
+    return gallery, None
+
+def handle_multiple_media_uploads(
+    file_paths: List,
+    gallery: List,
+    max_video_duration: int = 30,
+    frame_interval: int = 1
+) -> List:
+    """
+    Handle multiple media file uploads and update gallery.
+
+    Args:
+        file_paths: List of uploaded file paths
+        gallery: Current gallery list
+        max_video_duration: Maximum duration to process for videos (seconds)
+        frame_interval: Interval between frames for videos (seconds)
+
+    Returns:
+        Updated gallery list
+    """
+    if gallery is None:
+        gallery = []
+
+    if not file_paths:
+        return gallery
+
+    # Process all media files
+    processed_paths = process_multiple_media_uploads(file_paths, max_video_duration, frame_interval)
+
+    # Create gallery items and add to gallery
+    for path in processed_paths:
+        gallery_item = create_gallery_item(path)
+        if gallery_item:
+            gallery.append(gallery_item)
+
+    return gallery
+
+def get_supported_formats() -> dict:
+    """Get dictionary of supported file formats."""
+    return {
+        'images': SUPPORTED_IMAGE_FORMATS,
+        'videos': SUPPORTED_VIDEO_FORMATS,
+        'all': SUPPORTED_IMAGE_FORMATS + SUPPORTED_VIDEO_FORMATS
+    }
+
+def validate_media_files(file_paths: List[str]) -> Tuple[List[str], List[str]]:
+    """
+    Validate a list of media files.
+
+    Args:
+        file_paths: List of file paths to validate
+
+    Returns:
+        Tuple of (valid_files, invalid_files)
+    """
+    valid_files = []
+    invalid_files = []
+
+    for file_path in file_paths:
+        if is_supported_media(file_path):
+            valid_files.append(file_path)
+        else:
+            invalid_files.append(file_path)
+
+    return valid_files, invalid_files
+
+# Export functions
+__all__ = [
+    'get_media_type',
+    'is_supported_media',
+    'create_gallery_item',
+    'process_single_media_upload',
+    'process_multiple_media_uploads',
+    'handle_single_media_upload',
+    'handle_multiple_media_uploads',
+    'get_supported_formats',
+    'validate_media_files',
+    'SUPPORTED_IMAGE_FORMATS'
+]
+ ]
modules/pixai.py CHANGED
@@ -1,810 +1,801 @@
-import os, json, zipfile, tempfile, time, traceback
-import gradio as gr
-import pandas as pd
-import numpy as np
-import onnxruntime as ort
-from collections import defaultdict
-from typing import Union, Dict, Any, Tuple, List
-from PIL import Image
-from huggingface_hub import hf_hub_download
-from huggingface_hub.errors import EntryNotFoundError
-from datetime import datetime
-
-# Global variables for model components (for memory management)
-CURRENT_MODEL = None
-CURRENT_MODEL_NAME = None
-CURRENT_TAGS_DF = None
-CURRENT_D_IPS = None
-CURRENT_PREPROCESS_FUNC = None
-CURRENT_THRESHOLDS = None
-CURRENT_CATEGORY_NAMES = None
-
-css = """
-#custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;}
-#custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease, box-shadow 0.2s ease;}
-#custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);}
-#custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: contain;margin: 0 auto;display: block;}
-#custom-gallery .thumbnail-item img.portrait {max-width: 100%;}
-#custom-gallery .thumbnail-item img.landscape {max-height: 100%;}
-.gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;}
-.thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;}
-#custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;width: fit-content;margin: 0px auto;display: block;}
-"""
-
-def preprocess_on_gpu(img, device='cuda'):
-    """Preprocess image on GPU using PyTorch"""
-    import torch
-    import torchvision.transforms as transforms
-    # Convert PIL to tensor and move to GPU
-    transform = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor(), transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
-    # Move to GPU if available
-    tensor_img = transform(img).unsqueeze(0)
-    if torch.cuda.is_available():
-        tensor_img = tensor_img.to(device)
-    return tensor_img.cpu().numpy()
-
-class Timer: # Report the execution time & process
-    def __init__(self):
-        self.start_time = time.perf_counter()
-        self.checkpoints = [('Start', self.start_time)]
-
-    def checkpoint(self, label='Checkpoint'):
-        now = time.perf_counter()
-        self.checkpoints.append((label, now))
-
-    def report(self, is_clear_checkpoints=True):
-        max_label_length = max(len(label) for (label, _) in self.checkpoints) if self.checkpoints else 0
-        prev_time = self.checkpoints[0][1] if self.checkpoints else self.start_time
-
-        for (label, curr_time) in self.checkpoints[1:]:
-            elapsed = curr_time - prev_time
-            print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
-            prev_time = curr_time
-
-        if is_clear_checkpoints:
-            self.checkpoints.clear()
-            self.checkpoint()
-
-    def report_all(self):
-        print('\n> Execution Time Report:')
-        max_label_length = max(len(label) for (label, _) in self.checkpoints) if len(self.checkpoints) > 0 else 0
-        prev_time = self.start_time
-
-        for (label, curr_time) in self.checkpoints[1:]:
-            elapsed = curr_time - prev_time
-            print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
-            prev_time = curr_time
-
-        total_time = self.checkpoints[-1][1] - self.start_time if self.checkpoints else 0
-        print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n") # Performance tests
-        self.checkpoints.clear()
-
-    def restart(self):
-        self.start_time = time.perf_counter()
-        self.checkpoints = [('Start', self.start_time)]
-
-def _get_repo_id(model_name: str) -> str:
-    """Get the repository ID for the specified model name."""
-    if '/' in model_name:
-        return model_name
-    else:
-        return f'deepghs/pixai-tagger-{model_name}-onnx'
-
-def _download_model_files(model_name: str):
-    """Download all required model files."""
-    repo_id = _get_repo_id(model_name)
-
-    # Download the necessary files using hf_hub_download instead of local cache...
-    model_path = hf_hub_download(
-        repo_id=repo_id,
-        filename='model.onnx',
-        library_name="pixai-tagger"
-    )
-    tags_path = hf_hub_download(
-        repo_id=repo_id,
-        filename='selected_tags.csv',
-        library_name="pixai-tagger"
-    )
-    preprocess_path = hf_hub_download(
-        repo_id=repo_id,
-        filename='preprocess.json',
-        library_name="pixai-tagger"
-    )
-    try:
-        thresholds_path = hf_hub_download(
-            repo_id=repo_id,
-            filename='thresholds.csv',
-            library_name="pixai-tagger"
-        )
-    except EntryNotFoundError:
-        thresholds_path = None
-
-    return model_path, tags_path, preprocess_path, thresholds_path
-
-def create_optimized_ort_session(model_path):
-    """Create an optimized ONNX Runtime session with GPU support"""
-    # Test: Session options for better performance
-    sess_options = ort.SessionOptions()
-    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-    sess_options.intra_op_num_threads = 0 # Use all available cores
-    sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
-    sess_options.enable_mem_pattern = True
-    sess_options.enable_cpu_mem_arena = True
-
-    # Check available providers
-    available_providers = ort.get_available_providers()
-    print(f"Available ONNX Runtime providers: {available_providers}")
-
-    # Use appropriate execution providers (in order of preference)
-    providers = []
-
-    # Use CUDA if available
-    if 'CUDAExecutionProvider' in available_providers:
-        cuda_provider = ('CUDAExecutionProvider', {
-            'device_id': 0,
-            'arena_extend_strategy': 'kNextPowerOfTwo',
-            'gpu_mem_limit': 4 * 1024 * 1024 * 1024, # 4GB VRAM
-            'cudnn_conv_algo_search': 'EXHAUSTIVE',
-            'do_copy_in_default_stream': True,
-        })
-        providers.append(cuda_provider)
-        print("Using CUDA provider for ONNX inference")
-    else:
-        print("CUDA provider not available, falling back to CPU")
-
-    # Always include CPU as fallback (FOR HF)
-    providers.append('CPUExecutionProvider')
-
-    try:
-        session = ort.InferenceSession(model_path, sess_options, providers=providers)
-        print(f"Model loaded with providers: {session.get_providers()}")
-        return session
-    except Exception as e:
-        print(f"Failed to create ONNX session: {e}")
-        raise
-
-def _load_model_components_optimized(model_name: str):
-    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
-    global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES
-
-    # Only reload if model changed
-    if CURRENT_MODEL_NAME != model_name:
-        # Download files
-        model_path, tags_path, preprocess_path, thresholds_path = _download_model_files(model_name)
-
-        # Load optimized ONNX model
-        CURRENT_MODEL = create_optimized_ort_session(model_path)
-
-        # Load tags
-        CURRENT_TAGS_DF = pd.read_csv(tags_path)
-        CURRENT_D_IPS = {}
-
-        if 'ips' in CURRENT_TAGS_DF.columns:
-            CURRENT_TAGS_DF['ips'] = CURRENT_TAGS_DF['ips'].fillna('{}').map(json.loads)
-            for name, ips in zip(CURRENT_TAGS_DF['name'], CURRENT_TAGS_DF['ips']):
-                if ips:
-                    CURRENT_D_IPS[name] = ips
-
-        # Load preprocessing
-        with open(preprocess_path, 'r') as f:
-            data_ = json.load(f)
-        # Simple preprocessing function
-        def transform(img):
-            # Ensure image is in RGB mode
-            if img.mode != 'RGB':
-                img = img.convert('RGB')
-
-            # Resize to 448x448 <- Very important.
-            img = img.resize((448, 448), Image.Resampling.LANCZOS)
-
-            # Convert to numpy array and normalize
-            img_array = np.array(img).astype(np.float32)
-
-            # Normalize pixel values to [0, 1]
-            img_array = img_array / 255.0
-
-            # Normalize with ImageNet mean and std
-            mean = np.array([0.48145466, 0.4578275, 0.40821073]).astype(np.float32)
-            std = np.array([0.26862954, 0.26130258, 0.27577711]).astype(np.float32)
-            img_array = (img_array - mean) / std
-
-            # Transpose to (C, H, W)
-            img_array = np.transpose(img_array, (2, 0, 1))
-            return img_array
-
-        CURRENT_PREPROCESS_FUNC = transform
-
-        # Load thresholds
-        CURRENT_THRESHOLDS = {}
-        CURRENT_CATEGORY_NAMES = {}
-
-        if thresholds_path and os.path.exists(thresholds_path):
-            df_category_thresholds = pd.read_csv(thresholds_path, keep_default_na=False)
-            for item in df_category_thresholds.to_dict('records'):
-                if item['category'] not in CURRENT_THRESHOLDS:
-                    CURRENT_THRESHOLDS[item['category']] = item['threshold']
-                    CURRENT_CATEGORY_NAMES[item['category']] = item['name']
-        else:
-            # Default thresholds if file doesn't exist
-            CURRENT_THRESHOLDS = {0: 0.3, 4: 0.85, 9: 0.85}
-            CURRENT_CATEGORY_NAMES = {0: 'general', 4: 'character', 9: 'rating'}
-
-        CURRENT_MODEL_NAME = model_name
-
-    return (CURRENT_MODEL, CURRENT_TAGS_DF, CURRENT_D_IPS, CURRENT_PREPROCESS_FUNC,
-            CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES)
-
-def _raw_predict(image: Image.Image, model_name: str):
-    """Make a raw prediction with the PixAI tagger model."""
-    try:
-        # Ensure we have a PIL Image
-        if not isinstance(image, Image.Image):
-            raise ValueError("Input must be a PIL Image") # <-
-
-        # Load model components
-        model, _, _, preprocess_func, _, _ = _load_model_components_optimized(model_name)
-
-        # Preprocess image
-        input_tensor = preprocess_func(image)
-
-        # Add batch dimension
-        if len(input_tensor.shape) == 3:
-            input_tensor = np.expand_dims(input_tensor, axis=0)
-
-        # Run inference
-        output_names = [output.name for output in model.get_outputs()]
-        output_values = model.run(output_names, {'input': input_tensor.astype(np.float32)})
-
-        return {name: value[0] for name, value in zip(output_names, output_values)}
-
-    except Exception as e:
-        raise RuntimeError(f"Error processing image: {str(e)}")
-
-def get_pixai_tags(
-    image: Union[str, Image.Image],
-    model_name: str = 'deepghs/pixai-tagger-v0.9-onnx',
-    thresholds: Union[float, Dict[Any, float]] = None,
-    fmt='all'
-):
-    try:
-        # Load image if it's a path
-        if isinstance(image, str):
-            pil_image = Image.open(image)
-        elif isinstance(image, Image.Image):
-            pil_image = image
-        else:
-            raise ValueError("Image must be a file path or PIL Image")
-
-        # Load model components
-        _, df_tags, d_ips, _, default_thresholds, category_names = _load_model_components_optimized(model_name)
-
-        values = _raw_predict(pil_image, model_name)
-        prediction = values.get('prediction', np.array([]))
-
-        if prediction.size == 0:
-            raise RuntimeError("Model did not return valid predictions")
-
-        tags = {}
-
-        # Process tags by category
-        for category in sorted(set(df_tags['category'].tolist())):
-            mask = df_tags['category'] == category
-            tag_names = df_tags.loc[mask, 'name']
-            category_pred = prediction[mask]
-
-            # Determine threshold for this category
-            if isinstance(thresholds, float):
-                category_threshold = thresholds
-            elif isinstance(thresholds, dict) and \
-                    (category in thresholds or category_names.get(category, '') in thresholds):
-                if category in thresholds:
-                    category_threshold = thresholds[category]
-                elif category_names.get(category, '') in thresholds:
-                    category_threshold = thresholds[category_names[category]]
-                else:
-                    category_threshold = 0.85
-            else:
-                category_threshold = default_thresholds.get(category, 0.85)
-
-            # Apply threshold
-            pred_mask = category_pred >= category_threshold
-            filtered_tag_names = tag_names[pred_mask].tolist()
-            filtered_predictions = category_pred[pred_mask].tolist()
-
-            # Sort by confidence
-            cate_tags = dict(sorted(
-                zip(filtered_tag_names, filtered_predictions),
-                key=lambda x: (-x[1], x[0])
-            ))
-
-            category_name = category_names.get(category, f"category_{category}")
-            values[category_name] = cate_tags
-            tags.update(cate_tags)
-
-        values['tag'] = tags
-
-        # Handle IPs if available
-        if 'ips' in df_tags.columns:
-            ips_mapping, ips_counts = {}, defaultdict(int)
-            for tag, _ in tags.items():
-                if tag in d_ips:
-                    ips_mapping[tag] = d_ips[tag]
-                    for ip_name in d_ips[tag]:
-                        ips_counts[ip_name] += 1
-            values['ips_mapping'] = ips_mapping
-            values['ips_count'] = dict(ips_counts)
-            values['ips'] = [x for x, _ in sorted(ips_counts.items(), key=lambda x: (-x[1], x[0]))]
-
-        # Return based on format
-        if fmt == 'all':
-            # Return all available categories
-            available_categories = [category_names.get(cat, f"category_{cat}")
-                                    for cat in sorted(set(df_tags['category'].tolist()))]
-            return tuple(values.get(cat, {}) for cat in available_categories)
-        elif fmt in values:
-            return values[fmt]
-        else:
-            return values
-
-    except Exception as e:
-        raise RuntimeError(f"Error processing image: {str(e)}")
-
-def format_ips_output(ips_result, ips_mapping):
-    """Format IP detection output as a single string with proper escaping."""
-    if not ips_result and not ips_mapping:
-        return ""
-
-    # Format detected IPs
-    ips_list = []
-    if ips_result:
-        ips_list = [ip.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
-                    for ip in ips_result]
-
-    # Format character-to-IP mapping
-    mapping_list = []
-    if ips_mapping:
-        for char, ips in ips_mapping.items():
-            formatted_char = char.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
-            formatted_ips = [ip.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
-                             for ip in ips]
-            mapping_list.append(f"{formatted_char}: {', '.join(formatted_ips)}")
-
-    # Combine all into a single string
-    result_parts = []
-    if ips_list:
-        result_parts.append(", ".join(ips_list))
-    if mapping_list:
-        result_parts.extend(mapping_list)
-
-    return ", ".join(result_parts)
-
-def process_single_image(
-    image_path,
-    model_name="deepghs/pixai-tagger-v0.9-onnx", ###
-    general_threshold=0.3,
-    character_threshold=0.85,
-    progress=None,
-    idx=0,
-    total_images=1
-):
-    """Process a single image and return all formatted outputs."""
-    try:
-        if image_path is None:
-            return "", "", "", "", {}, {}
-
-        if progress:
-            progress((idx)/total_images, desc=f"Processing image {idx+1}/{total_images}")
-
-        # Load image from path
-        pil_image = Image.open(image_path)
-
-        # Set thresholds
-        thresholds = {
-            'general': general_threshold,
-            'character': character_threshold
-        }
-
-        # Get all tag categories
-        all_categories = get_pixai_tags(
-            pil_image, model_name, thresholds, fmt='all'
-        )
-
-        # Ensure we have at least 3 categories (general, character, rating)
-        while len(all_categories) < 3:
-            all_categories += ({},)
-
-        general_tags = all_categories[0] if len(all_categories) > 0 else {}
-        character_tags = all_categories[1] if len(all_categories) > 1 else {}
-        rating_tags = all_categories[2] if len(all_categories) > 2 else {}
-
-        # Get IP detection data
-        ips_result = get_pixai_tags(pil_image, model_name, thresholds, fmt='ips') or []
-        ips_mapping = get_pixai_tags(pil_image, model_name, thresholds, fmt='ips_mapping') or {}
-
-        # Format character tags (names only)
-        character_names = [name.replace("(", "\\(").replace(")", "\\)").replace("_", " ") # Replacement shouldn't be necessary here, but I'll do anyway
-                           for name in character_tags.keys()]
427
- character_output = ", ".join(character_names)
428
-
429
- # Format general tags (names only)
430
- general_names = [name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
431
- for name in general_tags.keys()]
432
- general_output = ", ".join(general_names)
433
-
434
- # Format IP detection output
435
- ips_output = format_ips_output(ips_result, ips_mapping)
436
-
437
- # Format combined tags (Character tags first, then General tags, then IP tags)
438
- combined_parts = []
439
- if character_names:
440
- combined_parts.append(", ".join(character_names))
441
- if general_names:
442
- combined_parts.append(", ".join(general_names))
443
- if ips_output:
444
- combined_parts.append(ips_output)
445
-
446
- combined_output = ", ".join(combined_parts)
447
-
448
- # Get detailed JSON data
449
- json_data = {
450
- "character_tags": character_tags,
451
- "general_tags": general_tags,
452
- "rating_tags": rating_tags,
453
- "ips_result": ips_result,
454
- "ips_mapping": ips_mapping
455
- }
456
-
457
- # Format rating as label-compatible dict
458
- rating_output = {k.replace("(", "\\(").replace(")", "\\)").replace("_", " "): v
459
- for k, v in rating_tags.items()}
460
-
461
- return (
462
- character_output, # Character tags
463
- general_output, # General tags
464
- ips_output, # IP Detection
465
- combined_output, # Combined tags
466
- json_data, # Detailed JSON
467
- rating_output # Rating <- Not working atm
468
- )
469
- except Exception as e:
470
- error_msg = f"Error: {str(e)}"
471
- # Return error message for all 6 outputs
472
- return error_msg, error_msg, error_msg, error_msg, {}, {} # 6
473
-
474
- """GPU"""
475
- def unload_model():
476
- """Explicitly unload the current model from memory"""
477
- global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
478
- global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES
479
- # Delete the model session
480
- if CURRENT_MODEL is not None:
481
- del CURRENT_MODEL
482
- CURRENT_MODEL = None
483
- # Clear other large objects
484
- CURRENT_TAGS_DF = None
485
- CURRENT_D_IPS = None
486
- CURRENT_PREPROCESS_FUNC = None
487
- CURRENT_THRESHOLDS = None
488
- CURRENT_CATEGORY_NAMES = None
489
- CURRENT_MODEL_NAME = None
490
- # Force garbage collection
491
- import gc
492
- gc.collect()
493
- # Clear CUDA cache if using GPU
494
- try:
495
- import torch
496
- if torch.cuda.is_available():
497
- torch.cuda.empty_cache()
498
- except ImportError:
499
- pass
500
- # print("Model unloaded and memory cleared")
501
- def cleanup_after_processing():
502
- unload_model()
503
-
504
- def process_gallery_images(
505
- gallery,
506
- model_name,
507
- general_threshold,
508
- character_threshold,
509
- progress=gr.Progress()
510
- ):
511
- """Process all images in the gallery and return results with download file."""
512
- if not gallery:
513
- return [], "", "", "", {}, {}, {}, None
514
-
515
- tag_results = {}
516
- txt_infos = []
517
- output_dir = tempfile.mkdtemp()
518
-
519
- if not os.path.exists(output_dir):
520
- os.makedirs(output_dir)
521
-
522
- total_images = len(gallery)
523
- timer = Timer()
524
-
525
- try:
526
- for idx, image_data in enumerate(gallery):
527
- try:
528
- image_path = image_data[0] if isinstance(image_data, (list, tuple)) else image_data
529
-
530
- # Process image
531
- results = process_single_image(
532
- image_path, model_name, general_threshold, character_threshold,
533
- progress, idx, total_images
534
- )
535
-
536
- # Store results
537
- tag_results[image_path] = {
538
- 'character_tags': results[0],
539
- 'general_tags': results[1],
540
- 'ips_detection': results[2],
541
- 'combined_tags': results[3],
542
- 'json_data': results[4],
543
- 'rating': results[5]
544
- }
545
-
546
- # Create output files with descriptive names
547
- image_name = os.path.splitext(os.path.basename(image_path))[0]
548
-
549
- # Save all output files with descriptive prefixes
550
- files_to_create = [
551
- (f"character_tags-{image_name}.txt", results[0]),
552
- (f"general_tags-{image_name}.txt", results[1]),
553
- (f"ips_detection-{image_name}.txt", results[2]),
554
- (f"combined_tags-{image_name}.txt", results[3]),
555
- (f"detailed_json-{image_name}.json", json.dumps(results[4], indent=4, ensure_ascii=False))
556
- ]
557
-
558
- for file_name, content in files_to_create:
559
- file_path = os.path.join(output_dir, file_name)
560
- with open(file_path, 'w', encoding='utf-8') as f:
561
- f.write(content if isinstance(content, str) else content)
562
- txt_infos.append({'path': file_path, 'name': file_name})
563
-
564
- # Copy original image
565
- original_image = Image.open(image_path)
566
- image_copy_path = os.path.join(output_dir, f"{image_name}{os.path.splitext(image_path)[1]}")
567
- original_image.save(image_copy_path)
568
- txt_infos.append({'path': image_copy_path, 'name': f"{image_name}{os.path.splitext(image_path)[1]}"})
569
-
570
- timer.checkpoint(f"image{idx:02d}, processed")
571
-
572
- except Exception as e:
573
- print(f"Error processing image {image_path}: {str(e)}")
574
- print(traceback.format_exc())
575
- continue
576
-
577
- # Create zip file
578
- download_zip_path = os.path.join(output_dir, f"Multi-Tagger-{datetime.now().strftime('%Y%m%d-%H%M%S')}.zip")
579
- with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
580
- for info in txt_infos:
581
- zipf.write(info['path'], arcname=info['name'])
582
- # If using GPU, model will auto unload after zip file creation
583
- cleanup_after_processing() # Comment here to turn off this behavior
584
-
585
- progress(1.0, desc="Processing complete")
586
- timer.report_all()
587
- print('Processing is complete.')
588
-
589
- # Return first image results as default if available even if we are tagging 1000+ images.
590
- first_image_results = ("", "", "", {}, {}, "") # 6
591
- if gallery and len(gallery) > 0:
592
- first_image_path = gallery[0][0] if isinstance(gallery[0], (list, tuple)) else gallery[0]
593
- if first_image_path in tag_results:
594
- result = tag_results[first_image_path]
595
- first_image_results = (
596
- result['character_tags'],
597
- result['general_tags'],
598
- result['combined_tags'],
599
- result['json_data'],
600
- result['rating'],
601
- result['ips_detection']
602
- )
603
-
604
- return tag_results, first_image_results[0], first_image_results[1], first_image_results[2], first_image_results[3], first_image_results[4], first_image_results[5], download_zip_path
605
-
606
- except Exception as e:
607
- print(f"Error in process_gallery_images: {str(e)}")
608
- print(traceback.format_exc())
609
- progress(1.0, desc="Processing failed")
610
- return {}, "", "", "", {}, {}, "", None
611
-
612
- def get_selection_from_gallery(gallery, tag_results, selected_state: gr.SelectData):
613
- """Handle gallery image selection and update UI with stored results."""
614
- if not selected_state or not tag_results:
615
- return "", "", "", {}, {}, ""
616
-
617
- # Get selected image path
618
- selected_value = selected_state.value
619
- if isinstance(selected_value, dict) and 'image' in selected_value:
620
- image_path = selected_value['image']['path']
621
- elif isinstance(selected_value, (list, tuple)) and len(selected_value) > 0:
622
- image_path = selected_value[0]
623
- else:
624
- image_path = str(selected_value)
625
-
626
- # Retrieve stored results
627
- if image_path in tag_results:
628
- result = tag_results[image_path]
629
- return (
630
- result['character_tags'],
631
- result['general_tags'],
632
- result['combined_tags'],
633
- result['json_data'],
634
- result['rating'],
635
- result['ips_detection']
636
- )
637
-
638
- # Return empty if not found
639
- return "", "", "", {}, {}, ""
640
-
641
- def append_gallery(gallery, image):
642
- """Add a single image to the gallery."""
643
- if gallery is None:
644
- gallery = []
645
- if not image:
646
- return gallery, None
647
- gallery.append(image)
648
- return gallery, None
649
-
650
- def extend_gallery(gallery, images):
651
- """Add multiple images to the gallery."""
652
- if gallery is None:
653
- gallery = []
654
- if not images:
655
- return gallery
656
- gallery.extend(images)
657
- return gallery
658
-
659
- def create_pixai_interface():
660
- """Create the PixAI Gradio interface"""
661
- with gr.Blocks(css=css, fill_width=True) as demo:
662
- # gr.Markdown("Upload anime-style images to extract tags using PixAI")
663
- # State to store results
664
- tag_results = gr.State({})
665
- selected_image = gr.Textbox(label='Selected Image', visible=False)
666
-
667
- with gr.Row():
668
- with gr.Column():
669
- # Image upload section
670
- with gr.Column(variant='panel'):
671
- image_input = gr.Image(
672
- label='Upload an Image or clicking paste from clipboard button',
673
- type='filepath',
674
- sources=['upload', 'clipboard'],
675
- height=150
676
- )
677
- with gr.Row():
678
- upload_button = gr.UploadButton(
679
- 'Upload multiple images',
680
- file_types=['image'],
681
- file_count='multiple',
682
- size='sm'
683
- )
684
- gallery = gr.Gallery(
685
- columns=2,
686
- show_share_button=False,
687
- interactive=True,
688
- height='auto',
689
- label='Grid of images',
690
- preview=False,
691
- elem_id='custom-gallery'
692
- )
693
- run_button = gr.Button("Analyze Images", variant="primary", size='lg')
694
- model_dropdown = gr.Dropdown(
695
- choices=["deepghs/pixai-tagger-v0.9-onnx"],
696
- value="deepghs/pixai-tagger-v0.9-onnx",
697
- label="Model"
698
- )
699
- # Threshold controls
700
- with gr.Row():
701
- general_threshold = gr.Slider(
702
- minimum=0.0, maximum=1.0, value=0.30, step=0.05,
703
- label="General Tags Threshold", scale=3
704
- )
705
- character_threshold = gr.Slider(
706
- minimum=0.0, maximum=1.0, value=0.85, step=0.05,
707
- label="Character Tags Threshold", scale=3
708
- )
709
-
710
- with gr.Row():
711
- clear = gr.ClearButton(
712
- components=[gallery, model_dropdown, general_threshold, character_threshold],
713
- variant='secondary',
714
- size='lg'
715
- )
716
- clear.add([tag_results])
717
- detailed_json_output = gr.JSON(label="Detailed JSON")
718
-
719
- with gr.Column(variant='panel'):
720
-
721
- download_file = gr.File(label="Download")
722
-
723
- # Output blocks
724
- character_tags_output = gr.Textbox(
725
- label="Character tags",
726
- show_copy_button=True,
727
- lines=3
728
- )
729
- general_tags_output = gr.Textbox(
730
- label="General tags",
731
- show_copy_button=True,
732
- lines=3
733
- )
734
- ips_detection_output = gr.Textbox(
735
- label="IPs Detection",
736
- show_copy_button=True,
737
- lines=5
738
- )
739
- combined_tags_output = gr.Textbox(
740
- label="Combined tags",
741
- show_copy_button=True,
742
- lines=6
743
- )
744
- rating_output = gr.Label(label="Rating")
745
-
746
- # Clear button targets
747
- clear.add([
748
- download_file,
749
- character_tags_output,
750
- general_tags_output,
751
- ips_detection_output,
752
- combined_tags_output,
753
- rating_output,
754
- detailed_json_output
755
- ])
756
-
757
- # Event handlers
758
- image_input.change(
759
- append_gallery,
760
- inputs=[gallery, image_input],
761
- outputs=[gallery, image_input]
762
- )
763
-
764
- upload_button.upload(
765
- extend_gallery,
766
- inputs=[gallery, upload_button],
767
- outputs=gallery
768
- )
769
-
770
- gallery.select(
771
- get_selection_from_gallery,
772
- inputs=[gallery, tag_results],
773
- outputs=[
774
- character_tags_output,
775
- general_tags_output,
776
- combined_tags_output,
777
- detailed_json_output,
778
- rating_output,
779
- ips_detection_output
780
- ]
781
- )
782
-
783
- run_button.click(
784
- process_gallery_images,
785
- inputs=[gallery, model_dropdown, general_threshold, character_threshold],
786
- outputs=[
787
- tag_results,
788
- character_tags_output,
789
- general_tags_output,
790
- combined_tags_output,
791
- detailed_json_output,
792
- rating_output,
793
- ips_detection_output,
794
- download_file
795
- ]
796
- )
797
-
798
- gr.Markdown('[Based on Source code for imgutils.tagging.pixai](https://dghs-imgutils.deepghs.org/main/_modules/imgutils/tagging/pixai.html) & [pixai-labs/pixai-tagger-demo](https://huggingface.co/spaces/pixai-labs/pixai-tagger-demo)')
799
-
800
- return demo
801
-
802
- # Export public API
803
- __all__ = [
804
- 'get_pixai_tags',
805
- 'process_single_image',
806
- 'process_gallery_images',
807
- 'create_pixai_interface',
808
- 'unload_model',
809
- 'cleanup_after_processing'
810
- ]
 
1
+ import os, json, zipfile, tempfile, time, traceback
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from collections import defaultdict
7
+ from typing import Union, Dict, Any, Tuple, List
8
+ from PIL import Image
9
+ from huggingface_hub import hf_hub_download
10
+ from huggingface_hub.errors import EntryNotFoundError
11
+ from datetime import datetime
12
+ from modules.media_handler import handle_single_media_upload, handle_multiple_media_uploads
13
+
14
+ # Global variables for model components (for memory management)
15
+ CURRENT_MODEL = None
16
+ CURRENT_MODEL_NAME = None
17
+ CURRENT_TAGS_DF = None
18
+ CURRENT_D_IPS = None
19
+ CURRENT_PREPROCESS_FUNC = None
20
+ CURRENT_THRESHOLDS = None
21
+ CURRENT_CATEGORY_NAMES = None
22
+
23
+ css = """
24
+ #custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;}
25
+ #custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease, box-shadow 0.2s ease;}
26
+ #custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);}
27
+ #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: contain;margin: 0 auto;display: block;}
28
+ #custom-gallery .thumbnail-item img.portrait {max-width: 100%;}
29
+ #custom-gallery .thumbnail-item img.landscape {max-height: 100%;}
30
+ .gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;}
31
+ .thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;}
32
+ #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;width: fit-content;margin: 0px auto;display: block;}
33
+ """
34
+
35
+ def preprocess_on_gpu(img, device='cuda'):
36
+ """Preprocess image on GPU using PyTorch"""
37
+ import torch
38
+ import torchvision.transforms as transforms
39
+ # Convert PIL to tensor and move to GPU
40
+ transform = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor(), transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
41
+ # Move to GPU if available
42
+ tensor_img = transform(img).unsqueeze(0)
43
+ if torch.cuda.is_available():
44
+ tensor_img = tensor_img.to(device)
45
+ return tensor_img.cpu().numpy()
46
+
47
+ class Timer: # Report the execution time & process
48
+ def __init__(self):
49
+ self.start_time = time.perf_counter()
50
+ self.checkpoints = [('Start', self.start_time)]
51
+
52
+ def checkpoint(self, label='Checkpoint'):
53
+ now = time.perf_counter()
54
+ self.checkpoints.append((label, now))
55
+
56
+ def report(self, is_clear_checkpoints=True):
57
+ max_label_length = max(len(label) for (label, _) in self.checkpoints) if self.checkpoints else 0
58
+ prev_time = self.checkpoints[0][1] if self.checkpoints else self.start_time
59
+
60
+ for (label, curr_time) in self.checkpoints[1:]:
61
+ elapsed = curr_time - prev_time
62
+ print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
63
+ prev_time = curr_time
64
+
65
+ if is_clear_checkpoints:
66
+ self.checkpoints.clear()
67
+ self.checkpoint()
68
+
69
+ def report_all(self):
70
+ print('\n> Execution Time Report:')
71
+ max_label_length = max(len(label) for (label, _) in self.checkpoints) if len(self.checkpoints) > 0 else 0
72
+ prev_time = self.start_time
73
+
74
+ for (label, curr_time) in self.checkpoints[1:]:
75
+ elapsed = curr_time - prev_time
76
+ print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
77
+ prev_time = curr_time
78
+
79
+ total_time = self.checkpoints[-1][1] - self.start_time if self.checkpoints else 0
80
+ print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n") # Performance tests
81
+ self.checkpoints.clear()
82
+
83
+ def restart(self):
84
+ self.start_time = time.perf_counter()
85
+ self.checkpoints = [('Start', self.start_time)]
86
+
87
+ def _get_repo_id(model_name: str) -> str:
88
+ """Get the repository ID for the specified model name."""
89
+ if '/' in model_name:
90
+ return model_name
91
+ else:
92
+ return f'deepghs/pixai-tagger-{model_name}-onnx'
93
+
94
+ def _download_model_files(model_name: str):
95
+ """Download all required model files."""
96
+ repo_id = _get_repo_id(model_name)
97
+
98
+ # Download the required files with hf_hub_download (huggingface_hub manages the cache)
99
+ model_path = hf_hub_download(
100
+ repo_id=repo_id,
101
+ filename='model.onnx',
102
+ library_name="pixai-tagger"
103
+ )
104
+ tags_path = hf_hub_download(
105
+ repo_id=repo_id,
106
+ filename='selected_tags.csv',
107
+ library_name="pixai-tagger"
108
+ )
109
+ preprocess_path = hf_hub_download(
110
+ repo_id=repo_id,
111
+ filename='preprocess.json',
112
+ library_name="pixai-tagger"
113
+ )
114
+ try:
115
+ thresholds_path = hf_hub_download(
116
+ repo_id=repo_id,
117
+ filename='thresholds.csv',
118
+ library_name="pixai-tagger"
119
+ )
120
+ except EntryNotFoundError:
121
+ thresholds_path = None
122
+
123
+ return model_path, tags_path, preprocess_path, thresholds_path
124
+
125
+ def create_optimized_ort_session(model_path):
126
+ """Create an optimized ONNX Runtime session with GPU support"""
127
+ # Test: Session options for better performance
128
+ sess_options = ort.SessionOptions()
129
+ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
130
+ sess_options.intra_op_num_threads = 0 # Use all available cores
131
+ sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
132
+ sess_options.enable_mem_pattern = True
133
+ sess_options.enable_cpu_mem_arena = True
134
+
135
+ # Check available providers
136
+ available_providers = ort.get_available_providers()
137
+ print(f"Available ONNX Runtime providers: {available_providers}")
138
+
139
+ # Use appropriate execution providers (in order of preference)
140
+ providers = []
141
+
142
+ # Use CUDA if available
143
+ if 'CUDAExecutionProvider' in available_providers:
144
+ cuda_provider = ('CUDAExecutionProvider', {
145
+ 'device_id': 0,
146
+ 'arena_extend_strategy': 'kNextPowerOfTwo',
147
+ 'gpu_mem_limit': 4 * 1024 * 1024 * 1024, # 4GB VRAM
148
+ 'cudnn_conv_algo_search': 'EXHAUSTIVE',
149
+ 'do_copy_in_default_stream': True,
150
+ })
151
+ providers.append(cuda_provider)
152
+ print("Using CUDA provider for ONNX inference")
153
+ else:
154
+ print("CUDA provider not available, falling back to CPU")
155
+
156
+ # Always include CPU as fallback (FOR HF)
157
+ providers.append('CPUExecutionProvider')
158
+
159
+ try:
160
+ session = ort.InferenceSession(model_path, sess_options, providers=providers)
161
+ print(f"Model loaded with providers: {session.get_providers()}")
162
+ return session
163
+ except Exception as e:
164
+ print(f"Failed to create ONNX session: {e}")
165
+ raise
166
+
167
+ def _load_model_components_optimized(model_name: str):
168
+ global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
169
+ global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES
170
+
171
+ # Only reload if model changed
172
+ if CURRENT_MODEL_NAME != model_name:
173
+ # Download files
174
+ model_path, tags_path, preprocess_path, thresholds_path = _download_model_files(model_name)
175
+
176
+ # Load optimized ONNX model
177
+ CURRENT_MODEL = create_optimized_ort_session(model_path)
178
+
179
+ # Load tags
180
+ CURRENT_TAGS_DF = pd.read_csv(tags_path)
181
+ CURRENT_D_IPS = {}
182
+
183
+ if 'ips' in CURRENT_TAGS_DF.columns:
184
+ CURRENT_TAGS_DF['ips'] = CURRENT_TAGS_DF['ips'].fillna('{}').map(json.loads)
185
+ for name, ips in zip(CURRENT_TAGS_DF['name'], CURRENT_TAGS_DF['ips']):
186
+ if ips:
187
+ CURRENT_D_IPS[name] = ips
188
+
189
+ # Load preprocessing
190
+ with open(preprocess_path, 'r') as f:
191
+ data_ = json.load(f)
192
+ # Simple preprocessing function
193
+ def transform(img):
194
+ # Ensure image is in RGB mode
195
+ if img.mode != 'RGB':
196
+ img = img.convert('RGB')
197
+
198
+ # Resize to 448x448 (the model's required input size)
199
+ img = img.resize((448, 448), Image.Resampling.LANCZOS)
200
+
201
+ # Convert to numpy array and normalize
202
+ img_array = np.array(img).astype(np.float32)
203
+
204
+ # Normalize pixel values to [0, 1]
205
+ img_array = img_array / 255.0
206
+
207
+ # Normalize with ImageNet mean and std
208
+ mean = np.array([0.48145466, 0.4578275, 0.40821073]).astype(np.float32)
209
+ std = np.array([0.26862954, 0.26130258, 0.27577711]).astype(np.float32)
210
+ img_array = (img_array - mean) / std
211
+
212
+ # Transpose to (C, H, W)
213
+ img_array = np.transpose(img_array, (2, 0, 1))
214
+ return img_array
215
+
216
+ CURRENT_PREPROCESS_FUNC = transform
217
+
218
+ # Load thresholds
219
+ CURRENT_THRESHOLDS = {}
220
+ CURRENT_CATEGORY_NAMES = {}
221
+
222
+ if thresholds_path and os.path.exists(thresholds_path):
223
+ df_category_thresholds = pd.read_csv(thresholds_path, keep_default_na=False)
224
+ for item in df_category_thresholds.to_dict('records'):
225
+ if item['category'] not in CURRENT_THRESHOLDS:
226
+ CURRENT_THRESHOLDS[item['category']] = item['threshold']
227
+ CURRENT_CATEGORY_NAMES[item['category']] = item['name']
228
+ else:
229
+ # Default thresholds if file doesn't exist
230
+ CURRENT_THRESHOLDS = {0: 0.3, 4: 0.85, 9: 0.85}
231
+ CURRENT_CATEGORY_NAMES = {0: 'general', 4: 'character', 9: 'rating'}
232
+
233
+ CURRENT_MODEL_NAME = model_name
234
+
235
+ return (CURRENT_MODEL, CURRENT_TAGS_DF, CURRENT_D_IPS, CURRENT_PREPROCESS_FUNC,
236
+ CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES)
237
+
238
+ def _raw_predict(image: Image.Image, model_name: str):
239
+ """Make a raw prediction with the PixAI tagger model."""
240
+ try:
241
+ # Ensure we have a PIL Image
242
+ if not isinstance(image, Image.Image):
243
+ raise ValueError("Input must be a PIL Image")
244
+
245
+ # Load model components
246
+ model, _, _, preprocess_func, _, _ = _load_model_components_optimized(model_name)
247
+
248
+ # Preprocess image
249
+ input_tensor = preprocess_func(image)
250
+
251
+ # Add batch dimension
252
+ if len(input_tensor.shape) == 3:
253
+ input_tensor = np.expand_dims(input_tensor, axis=0)
254
+
255
+ # Run inference
256
+ output_names = [output.name for output in model.get_outputs()]
257
+ output_values = model.run(output_names, {'input': input_tensor.astype(np.float32)})
258
+
259
+ return {name: value[0] for name, value in zip(output_names, output_values)}
260
+
261
+ except Exception as e:
262
+ raise RuntimeError(f"Error processing image: {str(e)}")
263
+
264
+ def get_pixai_tags(
265
+ image: Union[str, Image.Image],
266
+ model_name: str = 'deepghs/pixai-tagger-v0.9-onnx',
267
+ thresholds: Union[float, Dict[Any, float]] = None,
268
+ fmt='all'
269
+ ):
270
+ try:
271
+ # Load image if it's a path
272
+ if isinstance(image, str):
273
+ pil_image = Image.open(image)
274
+ elif isinstance(image, Image.Image):
275
+ pil_image = image
276
+ else:
277
+ raise ValueError("Image must be a file path or PIL Image")
278
+
279
+ # Load model components
280
+ _, df_tags, d_ips, _, default_thresholds, category_names = _load_model_components_optimized(model_name)
281
+
282
+ values = _raw_predict(pil_image, model_name)
283
+ prediction = values.get('prediction', np.array([]))
284
+
285
+ if prediction.size == 0:
286
+ raise RuntimeError("Model did not return valid predictions")
287
+
288
+ tags = {}
289
+
290
+ # Process tags by category
291
+ for category in sorted(set(df_tags['category'].tolist())):
292
+ mask = df_tags['category'] == category
293
+ tag_names = df_tags.loc[mask, 'name']
294
+ category_pred = prediction[mask]
295
+
296
+ # Determine threshold for this category
297
+ if isinstance(thresholds, float):
298
+ category_threshold = thresholds
299
+ elif isinstance(thresholds, dict) and \
300
+ (category in thresholds or category_names.get(category, '') in thresholds):
301
+ if category in thresholds:
302
+ category_threshold = thresholds[category]
303
+ elif category_names.get(category, '') in thresholds:
304
+ category_threshold = thresholds[category_names[category]]
305
+ else:
306
+ category_threshold = 0.85
307
+ else:
308
+ category_threshold = default_thresholds.get(category, 0.85)
309
+
310
+ # Apply threshold
311
+ pred_mask = category_pred >= category_threshold
312
+ filtered_tag_names = tag_names[pred_mask].tolist()
313
+ filtered_predictions = category_pred[pred_mask].tolist()
314
+
315
+ # Sort by confidence
316
+ cate_tags = dict(sorted(
317
+ zip(filtered_tag_names, filtered_predictions),
318
+ key=lambda x: (-x[1], x[0])
319
+ ))
320
+
321
+ category_name = category_names.get(category, f"category_{category}")
322
+ values[category_name] = cate_tags
323
+ tags.update(cate_tags)
324
+
325
+ values['tag'] = tags
326
+
327
+ # Handle IPs if available
328
+ if 'ips' in df_tags.columns:
329
+ ips_mapping, ips_counts = {}, defaultdict(int)
330
+ for tag, _ in tags.items():
331
+ if tag in d_ips:
332
+ ips_mapping[tag] = d_ips[tag]
333
+ for ip_name in d_ips[tag]:
334
+ ips_counts[ip_name] += 1
335
+ values['ips_mapping'] = ips_mapping
336
+ values['ips_count'] = dict(ips_counts)
337
+ values['ips'] = [x for x, _ in sorted(ips_counts.items(), key=lambda x: (-x[1], x[0]))]
338
+
339
+ # Return based on format
340
+ if fmt == 'all':
341
+ # Return all available categories
342
+ available_categories = [category_names.get(cat, f"category_{cat}")
343
+ for cat in sorted(set(df_tags['category'].tolist()))]
344
+ return tuple(values.get(cat, {}) for cat in available_categories)
345
+ elif fmt in values:
346
+ return values[fmt]
347
+ else:
348
+ return values
349
+
350
+ except Exception as e:
351
+ raise RuntimeError(f"Error processing image: {str(e)}")
352
+
353
+ def format_ips_output(ips_result, ips_mapping):
354
+ """Format IP detection output as a single string with proper escaping."""
355
+ if not ips_result and not ips_mapping:
356
+ return ""
357
+
358
+ # Format detected IPs
359
+ ips_list = []
360
+ if ips_result:
361
+ ips_list = [ip.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
362
+ for ip in ips_result]
363
+
364
+ # Format character-to-IP mapping
365
+ mapping_list = []
366
+ if ips_mapping:
367
+ for char, ips in ips_mapping.items():
368
+ formatted_char = char.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
369
+ formatted_ips = [ip.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
370
+ for ip in ips]
371
+ mapping_list.append(f"{formatted_char}: {', '.join(formatted_ips)}")
372
+
373
+ # Combine all into a single string
374
+ result_parts = []
375
+ if ips_list:
376
+ result_parts.append(", ".join(ips_list))
377
+ if mapping_list:
378
+ result_parts.extend(mapping_list)
379
+
380
+ return ", ".join(result_parts)
381
+
382
+ def process_single_image(
383
+ image_path,
384
+ model_name="deepghs/pixai-tagger-v0.9-onnx",
385
+ general_threshold=0.3,
386
+ character_threshold=0.85,
387
+ progress=None,
388
+ idx=0,
389
+ total_images=1
390
+ ):
391
+ """Process a single image and return all formatted outputs."""
392
+ try:
393
+ if image_path is None:
394
+ return "", "", "", "", {}, {}
395
+
396
+ if progress:
397
+ progress((idx)/total_images, desc=f"Processing image {idx+1}/{total_images}")
398
+
399
+ # Load image from path
400
+ pil_image = Image.open(image_path)
401
+
402
+ # Set thresholds
403
+ thresholds = {
404
+ 'general': general_threshold,
405
+ 'character': character_threshold
406
+ }
407
+
408
+ # Get all tag categories
409
+ all_categories = get_pixai_tags(
410
+ pil_image, model_name, thresholds, fmt='all'
411
+ )
412
+
413
+ # Ensure we have at least 3 categories (general, character, rating)
414
+ while len(all_categories) < 3:
415
+ all_categories += ({},)
416
+
417
+ general_tags = all_categories[0] if len(all_categories) > 0 else {}
418
+ character_tags = all_categories[1] if len(all_categories) > 1 else {}
419
+ rating_tags = all_categories[2] if len(all_categories) > 2 else {}
420
+
421
+ # Get IP detection data
422
+ ips_result = get_pixai_tags(pil_image, model_name, thresholds, fmt='ips') or []
423
+ ips_mapping = get_pixai_tags(pil_image, model_name, thresholds, fmt='ips_mapping') or {}
424
+
425
+ # Format character tags (names only)
426
+ character_names = [name.replace("(", "\\(").replace(")", "\\)").replace("_", " ") # Escaping shouldn't be necessary here, but apply it anyway for consistency
427
+ for name in character_tags.keys()]
428
+ character_output = ", ".join(character_names)
429
+
430
+ # Format general tags (names only)
431
+ general_names = [name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
432
+ for name in general_tags.keys()]
433
+ general_output = ", ".join(general_names)
434
+
435
+ # Format IP detection output
436
+ ips_output = format_ips_output(ips_result, ips_mapping)
437
+
438
+ # Format combined tags (Character tags first, then General tags, then IP tags)
439
+ combined_parts = []
440
+ if character_names:
441
+ combined_parts.append(", ".join(character_names))
442
+ if general_names:
443
+ combined_parts.append(", ".join(general_names))
444
+ if ips_output:
445
+ combined_parts.append(ips_output)
446
+
447
+ combined_output = ", ".join(combined_parts)
448
+
449
+ # Get detailed JSON data
450
+ json_data = {
451
+ "character_tags": character_tags,
452
+ "general_tags": general_tags,
453
+ "rating_tags": rating_tags,
454
+ "ips_result": ips_result,
455
+ "ips_mapping": ips_mapping
456
+ }
457
+
458
+ # Format rating as label-compatible dict
459
+ rating_output = {k.replace("(", "\\(").replace(")", "\\)").replace("_", " "): v
460
+ for k, v in rating_tags.items()}
461
+
462
+ return (
463
+ character_output, # Character tags
464
+ general_output, # General tags
465
+ ips_output, # IP Detection
466
+ combined_output, # Combined tags
467
+ json_data, # Detailed JSON
468
+ rating_output # Rating (not working at the moment)
469
+ )
470
+ except Exception as e:
471
+ error_msg = f"Error: {str(e)}"
472
+ # Return error message for all 6 outputs
473
+ return error_msg, error_msg, error_msg, error_msg, {}, {}
474
+
475
+ """GPU"""
476
+ def unload_model():
477
+ """Explicitly unload the current model from memory"""
478
+ global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
479
+ global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES
480
+ # Delete the model session
481
+ if CURRENT_MODEL is not None:
482
+ del CURRENT_MODEL
483
+ CURRENT_MODEL = None
484
+ # Clear other large objects
485
+ CURRENT_TAGS_DF = None
486
+ CURRENT_D_IPS = None
487
+ CURRENT_PREPROCESS_FUNC = None
488
+ CURRENT_THRESHOLDS = None
489
+ CURRENT_CATEGORY_NAMES = None
490
+ CURRENT_MODEL_NAME = None
491
+ # Force garbage collection
492
+ import gc
493
+ gc.collect()
494
+ # Clear CUDA cache if using GPU
495
+ try:
496
+ import torch
497
+ if torch.cuda.is_available():
498
+ torch.cuda.empty_cache()
499
+ except ImportError:
500
+ pass
501
+ # print("Model unloaded and memory cleared")
502
+ def cleanup_after_processing():
503
+ unload_model()
504
+
505
+ def process_gallery_images(
506
+ gallery,
507
+ model_name,
508
+ general_threshold,
509
+ character_threshold,
510
+ progress=gr.Progress()
511
+ ):
512
+ """Process all images in the gallery and return results with download file."""
513
+ if not gallery:
514
+ return {}, "", "", "", {}, {}, "", None
515
+
516
+ tag_results = {}
517
+ txt_infos = []
518
+ output_dir = tempfile.mkdtemp()
519
+
520
+ if not os.path.exists(output_dir):
521
+ os.makedirs(output_dir)
522
+
523
+ total_images = len(gallery)
524
+ timer = Timer()
525
+
526
+ try:
527
+ for idx, image_data in enumerate(gallery):
528
+ try:
529
+ image_path = image_data[0] if isinstance(image_data, (list, tuple)) else image_data
530
+
531
+ # Process image
532
+ results = process_single_image(
533
+ image_path, model_name, general_threshold, character_threshold,
534
+ progress, idx, total_images
535
+ )
536
+
537
+ # Store results
538
+ tag_results[image_path] = {
539
+ 'character_tags': results[0],
540
+ 'general_tags': results[1],
541
+ 'ips_detection': results[2],
542
+ 'combined_tags': results[3],
543
+ 'json_data': results[4],
544
+ 'rating': results[5]
545
+ }
546
+
547
+ # Create output files with descriptive names
548
+ image_name = os.path.splitext(os.path.basename(image_path))[0]
549
+
550
+ # Save all output files with descriptive prefixes
551
+ files_to_create = [
552
+ (f"character_tags-{image_name}.txt", results[0]),
553
+ (f"general_tags-{image_name}.txt", results[1]),
554
+ (f"ips_detection-{image_name}.txt", results[2]),
555
+ (f"combined_tags-{image_name}.txt", results[3]),
556
+ (f"detailed_json-{image_name}.json", json.dumps(results[4], indent=4, ensure_ascii=False))
557
+ ]
558
+
559
+ for file_name, content in files_to_create:
560
+ file_path = os.path.join(output_dir, file_name)
561
+ with open(file_path, 'w', encoding='utf-8') as f:
562
+ f.write(content if isinstance(content, str) else str(content))
563
+ txt_infos.append({'path': file_path, 'name': file_name})
564
+
565
+ # Copy original image
566
+ original_image = Image.open(image_path)
567
+ image_copy_path = os.path.join(output_dir, f"{image_name}{os.path.splitext(image_path)[1]}")
568
+ original_image.save(image_copy_path)
569
+ txt_infos.append({'path': image_copy_path, 'name': f"{image_name}{os.path.splitext(image_path)[1]}"})
570
+
571
+ timer.checkpoint(f"image{idx:02d}, processed")
572
+
573
+ except Exception as e:
574
+ print(f"Error processing image {image_path}: {str(e)}")
575
+ print(traceback.format_exc())
576
+ continue
577
+
578
+ # Create zip file
579
+ download_zip_path = os.path.join(output_dir, f"Multi-Tagger-{datetime.now().strftime('%Y%m%d-%H%M%S')}.zip")
580
+ with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
581
+ for info in txt_infos:
582
+ zipf.write(info['path'], arcname=info['name'])
583
+ # If using GPU, model will auto unload after zip file creation
584
+ cleanup_after_processing() # Comment out this call to disable auto-unloading
585
+
586
+ progress(1.0, desc="Processing complete")
587
+ timer.report_all()
588
+ print('Processing is complete.')
589
+
590
+ # Default the displayed outputs to the first image's results, even when tagging 1000+ images.
591
+ first_image_results = ("", "", "", {}, {}, "") # character, general, combined, json, rating, ips
592
+ if gallery and len(gallery) > 0:
593
+ first_image_path = gallery[0][0] if isinstance(gallery[0], (list, tuple)) else gallery[0]
594
+ if first_image_path in tag_results:
595
+ result = tag_results[first_image_path]
596
+ first_image_results = (
597
+ result['character_tags'],
598
+ result['general_tags'],
599
+ result['combined_tags'],
600
+ result['json_data'],
601
+ result['rating'],
602
+ result['ips_detection']
603
+ )
604
+
605
+ return tag_results, first_image_results[0], first_image_results[1], first_image_results[2], first_image_results[3], first_image_results[4], first_image_results[5], download_zip_path
606
+
607
+ except Exception as e:
608
+ print(f"Error in process_gallery_images: {str(e)}")
609
+ print(traceback.format_exc())
610
+ progress(1.0, desc="Processing failed")
611
+ return {}, "", "", "", {}, {}, "", None
612
+
613
+ def get_selection_from_gallery(gallery, tag_results, selected_state: gr.SelectData):
614
+ """Handle gallery image selection and update UI with stored results."""
615
+ if not selected_state or not tag_results:
616
+ return "", "", "", {}, {}, ""
617
+
618
+ # Get selected image path
619
+ selected_value = selected_state.value
620
+ if isinstance(selected_value, dict) and 'image' in selected_value:
621
+ image_path = selected_value['image']['path']
622
+ elif isinstance(selected_value, (list, tuple)) and len(selected_value) > 0:
623
+ image_path = selected_value[0]
624
+ else:
625
+ image_path = str(selected_value)
626
+
627
+ # Retrieve stored results
628
+ if image_path in tag_results:
629
+ result = tag_results[image_path]
630
+ return (
631
+ result['character_tags'],
632
+ result['general_tags'],
633
+ result['combined_tags'],
634
+ result['json_data'],
635
+ result['rating'],
636
+ result['ips_detection']
637
+ )
638
+
639
+ # Return empty if not found
640
+ return "", "", "", {}, {}, ""
641
+
642
+ def append_gallery(gallery, image):
643
+ """Add a single media file (image or video) to the gallery."""
644
+ return handle_single_media_upload(image, gallery)
645
+
646
+ def extend_gallery(gallery, images):
647
+ """Add multiple media files (images or videos) to the gallery."""
648
+ return handle_multiple_media_uploads(images, gallery)
649
+
650
+ def create_pixai_interface():
651
+ """Create the PixAI Gradio interface"""
652
+ with gr.Blocks(css=css, fill_width=True) as demo:
653
+ # gr.Markdown("Upload anime-style images to extract tags using PixAI")
654
+ # State to store results
655
+ tag_results = gr.State({})
656
+ selected_image = gr.Textbox(label='Selected Image', visible=False)
657
+
658
+ with gr.Row():
659
+ with gr.Column():
660
+ # Image upload section
661
+ with gr.Column(variant='panel'):
662
+ image_input = gr.Image(
663
+ label='Upload an Image (or paste from clipboard)',
664
+ type='filepath',
665
+ sources=['upload', 'clipboard'],
666
+ height=150
667
+ )
668
+ with gr.Row():
669
+ upload_button = gr.UploadButton(
670
+ 'Upload multiple images or videos',
671
+ file_types=['image', 'video'],
672
+ file_count='multiple',
673
+ size='sm'
674
+ )
675
+ gallery = gr.Gallery(
676
+ columns=2,
677
+ show_share_button=False,
678
+ interactive=True,
679
+ height='auto',
680
+ label='Grid of images',
681
+ preview=False,
682
+ elem_id='custom-gallery'
683
+ )
684
+ run_button = gr.Button("Analyze Images", variant="primary", size='lg')
685
+ model_dropdown = gr.Dropdown(
686
+ choices=["deepghs/pixai-tagger-v0.9-onnx"],
687
+ value="deepghs/pixai-tagger-v0.9-onnx",
688
+ label="Model"
689
+ )
690
+ # Threshold controls
691
+ with gr.Row():
692
+ general_threshold = gr.Slider(
693
+ minimum=0.0, maximum=1.0, value=0.30, step=0.05,
694
+ label="General Tags Threshold", scale=3
695
+ )
696
+ character_threshold = gr.Slider(
697
+ minimum=0.0, maximum=1.0, value=0.85, step=0.05,
698
+ label="Character Tags Threshold", scale=3
699
+ )
700
+
701
+ with gr.Row():
702
+ clear = gr.ClearButton(
703
+ components=[gallery, model_dropdown, general_threshold, character_threshold],
704
+ variant='secondary',
705
+ size='lg'
706
+ )
707
+ clear.add([tag_results])
708
+ detailed_json_output = gr.JSON(label="Detailed JSON")
709
+
710
+ with gr.Column(variant='panel'):
711
+
712
+ download_file = gr.File(label="Download")
713
+
714
+ # Output blocks
715
+ character_tags_output = gr.Textbox(
716
+ label="Character tags",
717
+ show_copy_button=True,
718
+ lines=3
719
+ )
720
+ general_tags_output = gr.Textbox(
721
+ label="General tags",
722
+ show_copy_button=True,
723
+ lines=3
724
+ )
725
+ ips_detection_output = gr.Textbox(
726
+ label="IPs Detection",
727
+ show_copy_button=True,
728
+ lines=5
729
+ )
730
+ combined_tags_output = gr.Textbox(
731
+ label="Combined tags",
732
+ show_copy_button=True,
733
+ lines=6
734
+ )
735
+ rating_output = gr.Label(label="Rating")
736
+
737
+ # Clear button targets
738
+ clear.add([
739
+ download_file,
740
+ character_tags_output,
741
+ general_tags_output,
742
+ ips_detection_output,
743
+ combined_tags_output,
744
+ rating_output,
745
+ detailed_json_output
746
+ ])
747
+
748
+ # Event handlers
749
+ image_input.change(
750
+ append_gallery,
751
+ inputs=[gallery, image_input],
752
+ outputs=[gallery, image_input]
753
+ )
754
+
755
+ upload_button.upload(
756
+ extend_gallery,
757
+ inputs=[gallery, upload_button],
758
+ outputs=gallery
759
+ )
760
+
761
+ gallery.select(
762
+ get_selection_from_gallery,
763
+ inputs=[gallery, tag_results],
764
+ outputs=[
765
+ character_tags_output,
766
+ general_tags_output,
767
+ combined_tags_output,
768
+ detailed_json_output,
769
+ rating_output,
770
+ ips_detection_output
771
+ ]
772
+ )
773
+
774
+ run_button.click(
775
+ process_gallery_images,
776
+ inputs=[gallery, model_dropdown, general_threshold, character_threshold],
777
+ outputs=[
778
+ tag_results,
779
+ character_tags_output,
780
+ general_tags_output,
781
+ combined_tags_output,
782
+ detailed_json_output,
783
+ rating_output,
784
+ ips_detection_output,
785
+ download_file
786
+ ]
787
+ )
788
+
789
+ gr.Markdown('[Based on Source code for imgutils.tagging.pixai](https://dghs-imgutils.deepghs.org/main/_modules/imgutils/tagging/pixai.html) & [pixai-labs/pixai-tagger-demo](https://huggingface.co/spaces/pixai-labs/pixai-tagger-demo)')
790
+
791
+ return demo
792
+
793
+ # Export public API
794
+ __all__ = [
795
+ 'get_pixai_tags',
796
+ 'process_single_image',
797
+ 'process_gallery_images',
798
+ 'create_pixai_interface',
799
+ 'unload_model',
800
+ 'cleanup_after_processing'
801
+ ]
 
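For reference, a minimal usage sketch of the public API exported above. This is an editorial illustration, not part of the commit: 'sample.png' is a placeholder path, and the thresholds simply mirror the UI defaults (0.30 general, 0.85 character).

from modules.pixai import get_pixai_tags, process_single_image, unload_model

# Per-category thresholds, keyed by category name ('general', 'character', ...)
thresholds = {'general': 0.30, 'character': 0.85}

# fmt='tag' returns one flat {tag: confidence} dict across all categories
tags = get_pixai_tags('sample.png', thresholds=thresholds, fmt='tag')
for name, score in sorted(tags.items(), key=lambda kv: -kv[1])[:10]:
    print(f"{name}: {score:.3f}")

# Or get the formatted strings the Gradio UI displays
character, general, ips, combined, json_data, rating = process_single_image('sample.png')
print(combined)

unload_model()  # release the ONNX session and cached tag data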
 
modules/video_processor.py ADDED
@@ -0,0 +1,206 @@
1
+ import os
2
+ import cv2
3
+ import tempfile
4
+ from typing import List, Tuple, Optional
5
+ from PIL import Image
6
+ import logging
7
+
8
+ # Configure logging
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Supported video formats
13
+ SUPPORTED_VIDEO_FORMATS = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv', '.m4v']
14
+
15
+ def is_video_file(file_path: str) -> bool:
16
+ """Check if the file is a supported video format."""
17
+ if not file_path:
18
+ return False
19
+ _, ext = os.path.splitext(file_path.lower())
20
+ return ext in SUPPORTED_VIDEO_FORMATS
21
+
22
+ def get_video_duration(video_path: str) -> float:
23
+ """Get the duration of a video in seconds."""
24
+ try:
25
+ cap = cv2.VideoCapture(video_path)
26
+ if not cap.isOpened():
27
+ logger.error(f"Could not open video: {video_path}")
28
+ return 0.0
29
+
30
+ fps = cap.get(cv2.CAP_PROP_FPS)
31
+ frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
32
+
33
+ if fps <= 0:
34
+ logger.warning(f"Invalid FPS for video {video_path}, using fallback method")
35
+ # Fallback: seek to the end of the video and read the elapsed time
38
+ cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1.0)
39
+ duration = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
40
+ else:
41
+ duration = frame_count / fps
42
+
43
+ cap.release()
44
+ return max(0.0, duration) # Ensure non-negative duration
45
+
46
+ except Exception as e:
47
+ logger.error(f"Error getting video duration for {video_path}: {str(e)}")
48
+ return 0.0
49
+
50
+ def extract_frames_from_video(
51
+ video_path: str,
52
+ max_duration: int = 30,
53
+ frame_interval: int = 1,
54
+ output_dir: Optional[str] = None
55
+ ) -> List[str]:
56
+ """
57
+ Extract frames from a video at specified intervals.
58
+
59
+ Args:
60
+ video_path: Path to the video file
61
+ max_duration: Maximum duration to process (seconds)
62
+ frame_interval: Interval between frames (seconds)
63
+ output_dir: Directory to save frames (creates temp if None)
64
+
65
+ Returns:
66
+ List of paths to extracted frame images
67
+ """
68
+ if not os.path.exists(video_path):
69
+ logger.error(f"Video file does not exist: {video_path}")
70
+ return []
71
+
72
+ if not is_video_file(video_path):
73
+ logger.error(f"Unsupported video format: {video_path}")
74
+ return []
75
+
76
+ # Create output directory if not provided
77
+ if output_dir is None:
78
+ output_dir = tempfile.mkdtemp(prefix="video_frames_")
79
+
80
+ try:
81
+ # Get video info
82
+ duration = get_video_duration(video_path)
83
+ logger.info(f"Video duration: {duration:.2f} seconds")
84
+
85
+ # Limit duration if necessary
86
+ process_duration = min(duration, max_duration)
87
+ logger.info(f"Processing {process_duration:.2f} seconds of video")
88
+
89
+ # Open video
90
+ cap = cv2.VideoCapture(video_path)
91
+ if not cap.isOpened():
92
+ logger.error(f"Could not open video: {video_path}")
93
+ return []
94
+
95
+ fps = cap.get(cv2.CAP_PROP_FPS)
96
+ if fps <= 0:
97
+ logger.error(f"Invalid FPS: {fps}")
98
+ cap.release()
99
+ return []
100
+
101
+ # Calculate frame positions
102
+ frame_positions = []
103
+ current_time = 0
104
+ while current_time < process_duration:
105
+ frame_number = int(current_time * fps)
106
+ frame_positions.append(frame_number)
107
+ current_time += frame_interval
108
+
109
+ logger.info(f"Extracting {len(frame_positions)} frames")
110
+
111
+ # Extract frames
112
+ frame_paths = []
113
+ video_name = os.path.splitext(os.path.basename(video_path))[0]
114
+
115
+ for i, frame_number in enumerate(frame_positions):
116
+ # Set position to desired frame
117
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
118
+
119
+ ret, frame = cap.read()
120
+ if not ret:
121
+ logger.warning(f"Could not read frame {frame_number}")
122
+ continue
123
+
124
+ # Convert BGR to RGB
125
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
126
+
127
+ # Convert to PIL Image
128
+ pil_image = Image.fromarray(frame_rgb)
129
+
130
+ # Save frame
131
+ frame_filename = f"{video_name}_frame_{i+1:03d}.png"
132
+ frame_path = os.path.join(output_dir, frame_filename)
133
+ pil_image.save(frame_path, 'PNG')
134
+ frame_paths.append(frame_path)
135
+
136
+ logger.debug(f"Saved frame {i+1}/{len(frame_positions)}: {frame_filename}")
137
+
138
+ cap.release()
139
+ logger.info(f"Successfully extracted {len(frame_paths)} frames from {video_path}")
140
+ return frame_paths
141
+
142
+ except Exception as e:
143
+ logger.error(f"Error extracting frames from {video_path}: {str(e)}")
144
+ return []
145
+
146
+ def process_video_upload(video_path: str, max_duration: int = 30, frame_interval: int = 1) -> Tuple[List[str], str]:
147
+ """
148
+ Process a video upload and extract frames.
149
+
150
+ Args:
151
+ video_path: Path to the uploaded video
152
+ max_duration: Maximum duration to process (seconds)
153
+ frame_interval: Interval between frames (seconds)
154
+
155
+ Returns:
156
+ Tuple of (list of frame paths, output directory)
157
+ """
158
+ # Create temporary directory for frames
159
+ output_dir = tempfile.mkdtemp(prefix="video_frames_")
160
+
161
+ # Extract frames
162
+ frame_paths = extract_frames_from_video(
163
+ video_path,
164
+ max_duration,
165
+ frame_interval,
166
+ output_dir
167
+ )
168
+
169
+ return frame_paths, output_dir
170
+
171
+ def get_video_info(video_path: str) -> dict:
172
+ """Get detailed information about a video file."""
173
+ try:
174
+ cap = cv2.VideoCapture(video_path)
175
+ if not cap.isOpened():
176
+ return {"error": "Could not open video"}
177
+
178
+ fps = cap.get(cv2.CAP_PROP_FPS)
179
+ frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
180
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
181
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
182
+ duration = frame_count / fps if fps > 0 else 0
183
+
184
+ cap.release()
185
+
186
+ return {
187
+ "duration": duration,
188
+ "fps": fps,
189
+ "frame_count": frame_count,
190
+ "width": width,
191
+ "height": height,
192
+ "resolution": f"{width}x{height}"
193
+ }
194
+
195
+ except Exception as e:
196
+ return {"error": str(e)}
197
+
198
+ # Export functions
199
+ __all__ = [
200
+ 'is_video_file',
201
+ 'get_video_duration',
202
+ 'extract_frames_from_video',
203
+ 'process_video_upload',
204
+ 'get_video_info',
205
+ 'SUPPORTED_VIDEO_FORMATS'
206
+ ]
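To close, a minimal end-to-end sketch (again editorial, with 'clip.mp4' as a placeholder path) showing how these helpers pair with the tagger: extract frames at the default 1-second interval, capped at 30 seconds, then tag each frame.

import os
from modules.video_processor import get_video_info, process_video_upload
from modules.pixai import process_single_image

info = get_video_info('clip.mp4')
if 'error' not in info:
    print(f"{info['resolution']} @ {info['fps']:.1f} fps, {info['duration']:.1f}s")

# Extract up to 30 frames (one per second), then tag each one
frame_paths, frames_dir = process_video_upload('clip.mp4', max_duration=30, frame_interval=1)
for path in frame_paths:
    _, _, _, combined, _, _ = process_single_image(path)
    print(os.path.basename(path), '->', combined)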