Spaces:
Running
Running
| import os | |
| import json | |
| import gradio as gr | |
| import huggingface_hub | |
| import numpy as np | |
| import onnxruntime as rt | |
| import pandas as pd | |
| from PIL import Image | |
| from huggingface_hub import login | |
| from translator import translate_texts | |
# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Optional Hugging Face token from the environment; "" means it is not set.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if not HF_TOKEN:
    print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
else:
    # Log in once at import time with the configured token.
    login(token=HF_TOKEN)
# ------------------------------------------------------------------
# Tagger class (instantiated once, globally)
# ------------------------------------------------------------------
class Tagger:
    """ONNX image tagger backed by ``MODEL_REPO`` on the Hugging Face Hub.

    Construction downloads the label CSV and the ONNX model. It raises
    ``RuntimeError`` if that fails.

    Attributes:
        hf_token: Token forwarded to ``hf_hub_download``. It is ``None`` when
            ``HF_TOKEN`` is empty.
        tag_names: Tag name for each model output index, taken from the CSV
            ``name`` column.
        categories: Maps ``"rating"`` / ``"general"`` / ``"character"`` to the
            arrays of output indices whose CSV ``category`` is 9 / 0 / 4.
        model: ``onnxruntime.InferenceSession`` for the downloaded model.
        input_size: Side length of the square model input.
    """

    def __init__(self):
        # An empty HF_TOKEN means "not configured". Pass None rather than ""
        # so huggingface_hub treats the request as having no explicit token
        # instead of sending an empty credential.
        self.hf_token = HF_TOKEN or None
        self.tag_names = []
        self.categories = {}
        self.model = None
        self.input_size = 0
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Download the labels and model, then populate the lookup attributes.

        Raises:
            RuntimeError: On any download, parsing or session-creation error.
                The original exception is chained as ``__cause__``.
        """
        try:
            # `resume_download` is deprecated in huggingface_hub (interrupted
            # downloads resume automatically), so it is no longer passed.
            label_path = huggingface_hub.hf_hub_download(
                MODEL_REPO, LABEL_FILENAME, token=self.hf_token
            )
            model_path = huggingface_hub.hf_hub_download(
                MODEL_REPO, MODEL_FILENAME, token=self.hf_token
            )
            tags_df = pd.read_csv(label_path)
            self.tag_names = tags_df["name"].tolist()
            self.categories = {
                "rating": np.where(tags_df["category"] == 9)[0],
                "general": np.where(tags_df["category"] == 0)[0],
                "character": np.where(tags_df["category"] == 4)[0],
            }
            self.model = rt.InferenceSession(model_path)
            # predict() feeds a (1, size, size, 3) batch, so dimension 1 of the
            # input shape is the square side length.
            self.input_size = self.model.get_inputs()[0].shape[1]
            print("✅ 模型和标签加载成功")
        except Exception as e:
            print(f"❌ 模型或标签加载失败: {e}")
            raise RuntimeError(f"模型初始化失败: {e}") from e

    def _preprocess(self, img: Image.Image) -> np.ndarray:
        """Convert a PIL image into the model's input array.

        The image is flattened to RGB, with transparency composited onto
        white. It is then centred on a white square canvas and resized to
        ``input_size`` with bicubic resampling when needed.

        Returns:
            A float32 BGR array of shape (input_size, input_size, 3) with
            values in 0-255 (no normalisation).

        Raises:
            ValueError: If *img* is None.
        """
        if img is None:
            raise ValueError("输入图像不能为空")
        if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
            # Composite alpha onto white. A bare convert("RGB") would drop the
            # alpha channel and expose whatever colour is stored under
            # transparent pixels (often black), unlike the white padding below.
            rgba = img.convert("RGBA")
            flattened = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            flattened.alpha_composite(rgba)
            img = flattened.convert("RGB")
        elif img.mode != "RGB":
            img = img.convert("RGB")
        size = max(img.size)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
        return np.array(canvas)[:, :, ::-1].astype(np.float32)  # RGB -> BGR

    def _select_tags(self, scores, indices, threshold):
        """Return ``{tag: score}`` for *indices*, sorted by descending score.

        Only entries whose score is strictly greater than *threshold* are
        kept; every entry is kept when *threshold* is None. Underscores in
        tag names are replaced with spaces.
        """
        picked = {}
        for idx in indices:
            if threshold is None or scores[idx] > threshold:
                picked[self.tag_names[idx].replace("_", " ")] = float(scores[idx])
        return dict(sorted(picked.items(), key=lambda kv: kv[1], reverse=True))

    def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
        """Tag *img* and group the scores by category.

        Args:
            img: Input image.
            gen_th: A general tag is kept only if its score exceeds this value.
            char_th: A character tag is kept only if its score exceeds this
                value.

        Returns:
            A ``(res, tag_order)`` tuple.

            - ``res`` maps ``"ratings"`` / ``"general"`` / ``"characters"`` to
              ``{tag: score}`` dicts sorted by descending score. All rating
              tags are kept.
            - ``tag_order`` maps the same keys to the tag names in that sorted
              order.

        Raises:
            RuntimeError: If the model is not loaded.
            ValueError: If *img* is None.
        """
        if self.model is None:
            raise RuntimeError("模型未成功加载,无法进行预测。")
        input_name = self.model.get_inputs()[0].name
        batch = self._preprocess(img)[None, ...]
        scores = self.model.run(None, {input_name: batch})[0][0]

        res = {
            "ratings": self._select_tags(scores, self.categories["rating"], None),
            "general": self._select_tags(scores, self.categories["general"], gen_th),
            "characters": self._select_tags(scores, self.categories["character"], char_th),
        }
        tag_order = {key: list(tags) for key, tags in res.items()}
        return res, tag_order
# Global Tagger instance, created once at import time. If initialisation
# raises RuntimeError, the error is printed and tagger_instance stays None.
# This lets the UI still start, and the analyze handler checks for None and
# reports the problem to the user.
tagger_instance = None
try:
    tagger_instance = Tagger()
except RuntimeError as init_error:
    print(f"应用启动时Tagger初始化失败: {init_error}")
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
# Stylesheet passed to gr.Blocks(css=...). It defines:
#   - .label-container: the scrollable list (max 300px high) that wraps the
#     tag rows.
#   - .tag-item: one flex row per tag (name + translation on the left, score
#     on the right).
#   - .tag-en / .tag-zh / .tag-score: the clickable English tag, its
#     translation, and its confidence value.
#   - .btn-analyze-container: vertical spacing, applied to the analyze button
#     via elem_classes.
custom_css = """
.label-container {
max-height: 300px;
overflow-y: auto;
border: 1px solid #ddd;
padding: 10px;
border-radius: 5px;
background-color: #f9f9f9;
}
.tag-item {
display: flex;
justify-content: space-between;
align-items: center;
margin: 2px 0;
padding: 2px 5px;
border-radius: 3px;
background-color: #fff;
transition: background-color 0.2s;
}
.tag-item:hover {
background-color: #f0f0f0;
}
.tag-en {
font-weight: bold;
color: #333;
cursor: pointer; /* Indicates clickable */
}
.tag-zh {
color: #666;
margin-left: 10px;
}
.tag-score {
color: #999;
font-size: 0.9em;
}
.btn-analyze-container { /* Custom class for analyze button container */
margin-top: 15px;
margin-bottom: 15px;
}
"""
| _js_functions = """ | |
| function copyToClipboard(text) { | |
| console.log('copyToClipboard function was called.'); | |
| console.log('Received text:', text); | |
| // 如果 text 未定义或为 null | |
| if (typeof text === 'undefined' || text === null) { | |
| console.warn('copyToClipboard was called with undefined or null text. Aborting this specific copy operation.'); | |
| return; | |
| } | |
| navigator.clipboard.writeText(text).then(() => { | |
| // console.log('Tag copied to clipboard: ' + text); | |
| const feedback = document.createElement('div'); | |
| // 确保 text 是字符串类型,再进行 substring 操作 | |
| let displayText = String(text); | |
| displayText = displayText.substring(0, 30) + (displayText.length > 30 ? '...' : ''); | |
| feedback.textContent = '已复制: ' + displayText; | |
| feedback.style.position = 'fixed'; | |
| feedback.style.bottom = '20px'; | |
| feedback.style.left = '50%'; | |
| feedback.style.transform = 'translateX(-50%)'; | |
| feedback.style.backgroundColor = '#4CAF50'; | |
| feedback.style.color = 'white'; | |
| feedback.style.padding = '10px 20px'; | |
| feedback.style.borderRadius = '5px'; | |
| feedback.style.zIndex = '10000'; | |
| feedback.style.transition = 'opacity 0.5s ease-out'; | |
| document.body.appendChild(feedback); | |
| setTimeout(() => { | |
| feedback.style.opacity = '0'; | |
| setTimeout(() => { | |
| if (document.body.contains(feedback)) { // 确保元素还在DOM中 | |
| document.body.removeChild(feedback); | |
| } | |
| }, 500); | |
| }, 1500); | |
| }).catch(err => { | |
| console.error('Failed to copy tag. Error:', err, 'Attempted to copy text:', text); | |
| const errorFeedback = document.createElement('div'); | |
| errorFeedback.textContent = '复制操作失败!'; | |
| errorFeedback.style.position = 'fixed'; | |
| errorFeedback.style.bottom = '20px'; | |
| errorFeedback.style.left = '50%'; | |
| errorFeedback.style.transform = 'translateX(-50%)'; | |
| errorFeedback.style.backgroundColor = '#D32F2F'; | |
| errorFeedback.style.color = 'white'; | |
| errorFeedback.style.padding = '10px 20px'; | |
| errorFeedback.style.borderRadius = '5px'; | |
| errorFeedback.style.zIndex = '10000'; | |
| errorFeedback.style.transition = 'opacity 0.5s ease-out'; | |
| document.body.appendChild(errorFeedback); | |
| setTimeout(() => { | |
| errorFeedback.style.opacity = '0'; | |
| setTimeout(() => { | |
| if (document.body.contains(errorFeedback)) { | |
| document.body.removeChild(errorFeedback); | |
| } | |
| }, 500); | |
| }, 2500); | |
| }); | |
| } | |
| """ | |
# Page layout.
#   - Left column: image input, analyze button, threshold settings and
#     summary options.
#   - Right column: one tab per tag category, plus the combined summary
#     textbox.
# NOTE(review): the source's indentation was lost. The nesting below is a
# reconstruction; confirm it against the deployed app.
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")
    # Per-session state written by the analyze handler: the tag results, the
    # per-category translations, and the per-category tag order. The summary
    # handlers read the first two.
    state_res = gr.State({})
    state_translations_dict = gr.State({})
    state_tag_categories_for_translation = gr.State({})
    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="上传图片", height=300)
            btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
            with gr.Accordion("⚙️ 高级设置", open=False):
                # Thresholds are strict lower bounds on each tag's score.
                gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值", info="越高 → 标签更少更准")
                char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值", info="推荐保持较高阈值")
                # Read only when the analyze button is clicked (it is not wired
                # to a change handler), so toggling it applies on the next run.
                show_tag_scores = gr.Checkbox(True, label="在列表中显示标签置信度")
            with gr.Accordion("📊 标签汇总设置", open=True):
                # Changing any of these re-renders the summary from stored state.
                gr.Markdown("选择要包含在下方汇总文本框中的标签类别:")
                with gr.Row():
                    sum_general = gr.Checkbox(True, label="通用标签", min_width=50)
                    sum_char = gr.Checkbox(True, label="角色标签", min_width=50)
                    sum_rating = gr.Checkbox(False, label="评分标签", min_width=50)
                sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签之间的分隔符")
                sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")
            # Status line; hidden until the analyze handler makes it visible.
            processing_info = gr.Markdown("", visible=False)
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"):
                    gr.Markdown("<p style='color:gray; font-size:small;'>提示:角色标签推测基于截至2024年2月的数据。</p>")
                    out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("⭐ 评分标签"):
                    out_rating = gr.HTML(label="Rating Tags")
            gr.Markdown("### 标签汇总结果")
            out_summary = gr.Textbox(
                label="标签汇总",
                placeholder="分析完成后,此处将显示汇总的英文标签...",
                lines=5,
                show_copy_button=True
            )
# ----------------- Helper functions -----------------
def format_tags_html(tags_dict, translations_list, category_name, show_scores=True, show_translation_in_list=True):
    """Render one tag category as a list of clickable HTML rows.

    Clicking a tag name calls the page-level ``copyToClipboard`` JS function
    with the exact tag text.

    Args:
        tags_dict: Mapping of tag name to confidence score. Rows follow the
            mapping's iteration order.
        translations_list: Translations aligned by position with the keys of
            ``tags_dict``. A value that is not a list is treated as empty.
        category_name: Category identifier. Currently unused; kept for
            call-site compatibility.
        show_scores: Whether to append each score, formatted to 3 decimals.
        show_translation_in_list: Whether to show each tag's non-empty
            translation in parentheses.

    Returns:
        An HTML string, or ``"<p>暂无标签</p>"`` when ``tags_dict`` is empty.
    """
    from html import escape as html_escape

    if not tags_dict:
        return "<p>暂无标签</p>"
    if not isinstance(translations_list, list):
        translations_list = []

    parts = ['<div class="label-container">']
    for i, (tag, score) in enumerate(tags_dict.items()):
        # json.dumps produces a valid JS string literal (quotes and
        # backslashes escaped, e.g. for the tag "\m/"). html_escape then makes
        # that literal safe inside the double-quoted onclick attribute; the
        # browser decodes the entities back before running the JS.
        js_arg = html_escape(json.dumps(tag), quote=True)
        display_html = (
            f'<span class="tag-en" onclick="copyToClipboard({js_arg})">'
            f'{html_escape(tag)}</span>'
        )
        if show_translation_in_list and i < len(translations_list) and translations_list[i]:
            # Translations come from an external service; escape them too.
            display_html += f'<span class="tag-zh">({html_escape(str(translations_list[i]))})</span>'
        parts.append(f'<div class="tag-item"><div>{display_html}</div>')
        if show_scores:
            parts.append(f'<span class="tag-score">{score:.3f}</span>')
        parts.append('</div>')
    parts.append('</div>')
    return "".join(parts)
def generate_summary_text_content(
    current_res, current_translations_dict,
    s_gen, s_char, s_rat, s_sep_type, s_show_zh
):
    """Build the copyable summary text from an analysis result.

    Args:
        current_res: Dict keyed by "general" / "characters" / "ratings", each
            mapping tag name to score.
        current_translations_dict: Same keys, each a list of translations
            aligned by position with that category's tags.
        s_gen, s_char, s_rat: Whether to include general / character / rating
            tags. Categories are emitted in that order.
        s_sep_type: "逗号" (comma), "换行" (newline) or "空格" (space). Any
            other value falls back to comma.
        s_show_zh: Whether to append each non-empty translation as
            ``tag/*translation*/``.

    Returns:
        Within a category, tags are joined by the chosen separator.
        Categories are separated by a blank line, or by a single newline when
        the separator is itself a newline. A user-facing hint is returned
        instead when there is no result, no category is selected, or the
        selection contains no tags.
    """
    if not current_res:
        return "请先分析图像或选择要汇总的标签类别。"

    sep = {"逗号": ", ", "换行": "\n", "空格": " "}.get(s_sep_type, ", ")
    wanted = [
        key
        for key, enabled in (("general", s_gen), ("characters", s_char), ("ratings", s_rat))
        if enabled
    ]
    if not wanted:
        return "请至少选择一个标签类别进行汇总。"

    sections = []
    for key in wanted:
        tags_by_score = current_res.get(key)
        if not tags_by_score:
            continue
        zh_list = current_translations_dict.get(key, [])
        rendered = []
        for pos, tag in enumerate(tags_by_score.keys()):
            if s_show_zh and pos < len(zh_list) and zh_list[pos]:
                rendered.append(f"{tag}/*{zh_list[pos]}*/")
            else:
                rendered.append(tag)
        if rendered:
            sections.append(sep.join(rendered))

    # The section joiner only matters with two or more sections. Use a blank
    # line between categories, or a plain newline when tags are already
    # newline-separated.
    section_joiner = "\n" if sep == "\n" else "\n\n"
    summary = section_joiner.join(sections)
    return summary if summary else "选定的类别中没有找到标签。"
# Category keys in the order they are translated, rendered and returned.
_TAG_CATEGORY_ORDER = ("general", "characters", "ratings")


def _idle_outputs(status, summary_placeholder):
    """Build the 9-value output tuple for an early exit.

    The button is re-enabled, *status* is shown, the three tag panes and the
    summary are cleared, and all three states are reset.
    """
    return (
        gr.update(interactive=True, value="🚀 开始分析"),
        gr.update(visible=True, value=status),
        "", "", "",  # general / character / rating HTML
        gr.update(value="", placeholder=summary_placeholder),
        {}, {}, {},  # result / translations / tag-order states
    )


def _translate_tag_categories(tag_order):
    """Translate every tag in *tag_order* and regroup the results by category.

    All tags go to ``translate_texts`` in one batch (general, characters,
    ratings), and the flat result is sliced back per category. Translation is
    best-effort: if the call raises or returns nothing, the missing entries
    are simply empty, so the English tags can still be shown.
    """
    flat_tags = [tag for cat in _TAG_CATEGORY_ORDER for tag in tag_order.get(cat, [])]
    flat_translations = []
    if flat_tags:
        try:
            flat_translations = list(
                translate_texts(flat_tags, src_lang="auto", tgt_lang="zh") or []
            )
        except Exception as exc:  # translation is auxiliary; don't fail tagging
            print(f"⚠️ 翻译失败,仅显示英文标签: {exc}")
            flat_translations = []
    by_category = {}
    offset = 0
    for cat in _TAG_CATEGORY_ORDER:
        count = len(tag_order.get(cat, []))
        by_category[cat] = flat_translations[offset:offset + count]
        offset += count
    return by_category


def process_image_and_generate_outputs(
    img, g_th, c_th, s_scores,  # Main inputs
    s_gen, s_char, s_rat, s_sep, s_zh_in_sum
):
    """Analyze *img* and stream UI updates (a generator used as a Gradio handler).

    Every yield is a 9-tuple matching the ``btn.click`` outputs:

    1. analyze button
    2. status markdown
    3. general HTML
    4. character HTML
    5. rating HTML
    6. summary textbox
    7. result state
    8. translations state
    9. tag-order state

    The first yield locks the button and shows placeholders. The final yield
    carries either the results or an error message.

    Args:
        img: Uploaded PIL image, or None if nothing is uploaded.
        g_th, c_th: General / character tag thresholds passed to the tagger.
        s_scores: Whether the tag lists show confidence scores.
        s_gen, s_char, s_rat, s_sep, s_zh_in_sum: Summary options, forwarded
            to ``generate_summary_text_content``.
    """
    if img is None:
        yield _idle_outputs("❌ 请先上传图片。", "请先上传图片并开始分析...")
        return
    if tagger_instance is None:
        yield _idle_outputs("❌ 分析器未成功初始化,请检查控制台错误。", "分析器初始化失败...")
        return

    loading_html = "<p>分析中...</p>"
    yield (
        gr.update(interactive=False, value="🔄 处理中..."),
        gr.update(visible=True, value="🔄 正在分析图像,请稍候..."),
        loading_html, loading_html, loading_html,
        gr.update(value="分析中,请稍候..."),
        {}, {}, {},  # clear states while the new analysis runs
    )

    try:
        res, tag_order = tagger_instance.predict(img, g_th, c_th)
        translations = _translate_tag_categories(tag_order)
        panes = {
            cat: format_tags_html(res.get(cat, {}), translations.get(cat, []), cat, s_scores, True)
            for cat in _TAG_CATEGORY_ORDER
        }
        summary_text = generate_summary_text_content(
            res, translations,
            s_gen, s_char, s_rat, s_sep, s_zh_in_sum
        )
    except Exception as e:
        import traceback
        print(f"处理时发生错误: {e}\n{traceback.format_exc()}")
        yield (
            gr.update(interactive=True, value="🚀 开始分析"),
            gr.update(visible=True, value=f"❌ 处理失败: {str(e)}"),
            "<p>处理出错</p>", "<p>处理出错</p>", "<p>处理出错</p>",
            gr.update(value=f"错误: {str(e)}", placeholder="分析失败..."),
            {}, {}, {},
        )
        return

    yield (
        gr.update(interactive=True, value="🚀 开始分析"),
        gr.update(visible=True, value="✅ 分析完成!"),
        panes["general"],
        panes["characters"],
        panes["ratings"],
        gr.update(value=summary_text),
        res,
        translations,
        tag_order,
    )
def update_summary_display(
    s_gen, s_char, s_rat, s_sep, s_zh_in_sum,
    current_res_from_state, current_translations_from_state
):
    """Recompute the summary textbox after a summary option changes.

    Reuses the last analysis kept in Gradio state. Before any analysis has
    completed, it clears the box and shows a hint placeholder instead.
    """
    if current_res_from_state:
        refreshed = generate_summary_text_content(
            current_res_from_state, current_translations_from_state,
            s_gen, s_char, s_rat, s_sep, s_zh_in_sum,
        )
        return gr.update(value=refreshed)
    return gr.update(placeholder="请先完成一次图像分析以生成汇总。", value="")
# Event wiring. Re-entering the `demo` Blocks keeps this block valid at
# module level, since Gradio event listeners must be attached inside an
# active Blocks context.
with demo:
    # Analyze button: run the tagger and fill every output pane and state.
    btn.click(
        process_image_and_generate_outputs,
        inputs=[
            img_in, gen_slider, char_slider, show_tag_scores,
            sum_general, sum_char, sum_rating, sum_sep, sum_show_zh,
        ],
        outputs=[
            btn, processing_info,
            out_general, out_char, out_rating,
            out_summary,
            state_res, state_translations_dict, state_tag_categories_for_translation,
        ],
    )

    # Any summary option change re-renders only the summary, from stored state.
    summary_controls = [sum_general, sum_char, sum_rating, sum_sep, sum_show_zh]
    for control in summary_controls:
        control.change(
            fn=update_summary_display,
            inputs=[*summary_controls, state_res, state_translations_dict],
            outputs=[out_summary],
        )
if __name__ == "__main__":
    # A missing tagger is not fatal: launch anyway so the page can report it.
    tagger_ready = tagger_instance is not None
    if not tagger_ready:
        print("CRITICAL: Tagger 未能初始化,应用功能将受限。请检查之前的错误信息。")
    # Listen on all interfaces on port 7860.
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860)
    demo.launch(**launch_kwargs)