# app.py import gradio as gr import torch import torchaudio from transformers import ( pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, AutoImageProcessor, AutoModelForObjectDetection, BlipForQuestionAnswering, BlipProcessor, CLIPModel, CLIPProcessor, VitsModel, AutoTokenizer ) from PIL import Image, ImageDraw import requests import numpy as np import soundfile as sf from gtts import gTTS import tempfile import os from sentence_transformers import SentenceTransformer # Инициализация моделей (ленивая загрузка) models = {} def load_audio_model(model_name): if model_name not in models: if model_name == "whisper": models[model_name] = pipeline( "automatic-speech-recognition", model="openai/whisper-small" ) elif model_name == "wav2vec2": models[model_name] = pipeline( "automatic-speech-recognition", model="bond005/wav2vec2-large-ru-golos" ) elif model_name == "audio_classifier": models[model_name] = pipeline( "audio-classification", model="MIT/ast-finetuned-audioset-10-10-0.4593" ) elif model_name == "emotion_classifier": models[model_name] = pipeline( "audio-classification", model="superb/hubert-large-superb-er" ) return models[model_name] def load_image_model(model_name): if model_name not in models: if model_name == "object_detection": models[model_name] = pipeline("object-detection", model="facebook/detr-resnet-50") elif model_name == "segmentation": models[model_name] = pipeline("image-segmentation", model="nvidia/segformer-b0-finetuned-ade-512-512") elif model_name == "captioning": models[model_name] = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base") elif model_name == "vqa": models[model_name] = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa") elif model_name == "clip": models[model_name] = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") models[f"{model_name}_processor"] = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") return models[model_name] # Функции для обработки аудио def audio_classification(audio_file, model_type): classifier = load_audio_model(model_type) results = classifier(audio_file) output = "Топ-5 предсказаний:\n" for i, result in enumerate(results[:5]): output += f"{i+1}. {result['label']}: {result['score']:.4f}\n" return output def speech_recognition(audio_file, model_type): asr_pipeline = load_audio_model(model_type) if model_type == "whisper": result = asr_pipeline(audio_file, generate_kwargs={"language": "russian"}) else: result = asr_pipeline(audio_file) return result['text'] def text_to_speech(text, model_type): if model_type == "silero": # Silero TTS model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language='ru', speaker='ru_v3') with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: model.save_wav(text=text, speaker='aidar', sample_rate=48000, audio_path=f.name) return f.name elif model_type == "gtts": # Google TTS with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: tts = gTTS(text=text, lang='ru') tts.save(f.name) return f.name elif model_type == "mms": # Facebook MMS TTS model = VitsModel.from_pretrained("facebook/mms-tts-rus") tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus") inputs = tokenizer(text, return_tensors="pt") with torch.no_grad(): output = model(**inputs).waveform with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate) return f.name # Функции для обработки изображений def object_detection(image): detector = load_image_model("object_detection") results = detector(image) # Рисуем bounding boxes draw = ImageDraw.Draw(image) for result in results: box = result['box'] label = result['label'] score = result['score'] draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']], outline='red', width=3) draw.text((box['xmin'], box['ymin']), f"{label}: {score:.2f}", fill='red') return image def image_segmentation(image): segmenter = load_image_model("segmentation") results = segmenter(image) # Возвращаем первую маску сегментации return results[0]['mask'] def image_captioning(image): captioner = load_image_model("captioning") result = captioner(image) return result[0]['generated_text'] def visual_question_answering(image, question): vqa_pipeline = load_image_model("vqa") cleaned_question = (question or "").strip() result = vqa_pipeline( image=image, question=cleaned_question, truncation=True, # keep text within ViLT max sequence length (40) max_length=40, ) return f"{result[0]['answer']} (confidence: {result[0]['score']:.3f})" def zero_shot_classification(image, classes): model = load_image_model("clip") processor = models["clip_processor"] class_list = [cls.strip() for cls in classes.split(",")] inputs = processor(text=class_list, images=image, return_tensors="pt", padding=True) with torch.no_grad(): outputs = model(**inputs) logits_per_image = outputs.logits_per_image probs = logits_per_image.softmax(dim=1) result = "Zero-Shot Classification Results:\n" for i, cls in enumerate(class_list): result += f"{cls}: {probs[0][i].item():.4f}\n" return result def image_retrieval(images, query): if not images or not query: return "Пожалуйста, загрузите изображения и введите запрос" # Используем CLIP для поиска model = load_image_model("clip") processor = models["clip_processor"] # Обрабатываем все изображения if isinstance(images, tuple): images = list(images) normalized_images = [] for item in images: # Gallery может вернуть (image, caption); берем только картинку if isinstance(item, (list, tuple)) and item: normalized_images.append(item[0]) else: normalized_images.append(item) image_inputs = processor(images=normalized_images, return_tensors="pt", padding=True) with torch.no_grad(): image_embeddings = model.get_image_features(**image_inputs) image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True) # Обрабатываем текстовый запрос text_inputs = processor(text=[query], return_tensors="pt", padding=True) with torch.no_grad(): text_embeddings = model.get_text_features(**text_inputs) text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) # Вычисляем схожести similarities = (image_embeddings @ text_embeddings.T) # Находим лучшее изображение best_idx = similarities.argmax().item() best_score = similarities[best_idx].item() return f"Лучшее изображение: #{best_idx + 1} (схожесть: {best_score:.4f})", normalized_images[best_idx] # Создаем интерфейс Gradio with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo: gr.Markdown("# 🎯 Мультимодальные AI модели") gr.Markdown("Демонстрация различных задач компьютерного зрения и обработки звука с использованием Hugging Face Transformers") with gr.Tab("🎵 Классификация аудио"): gr.Markdown("## Zero-Shot Audio Classification") with gr.Row(): with gr.Column(): audio_input = gr.Audio(label="Загрузите аудиофайл", type="filepath") audio_model_dropdown = gr.Dropdown( choices=["audio_classifier", "emotion_classifier"], label="Выберите модель", value="audio_classifier", info="audio_classifier - общая классификация, emotion_classifier - эмоции в речи" ) classify_btn = gr.Button("Классифицировать") with gr.Column(): audio_output = gr.Textbox(label="Результаты классификации", lines=10) classify_btn.click( fn=audio_classification, inputs=[audio_input, audio_model_dropdown], outputs=audio_output ) with gr.Tab("🗣️ Распознавание речи"): gr.Markdown("## Automatic Speech Recognition (ASR)") with gr.Row(): with gr.Column(): asr_audio_input = gr.Audio(label="Загрузите аудио с речью", type="filepath") asr_model_dropdown = gr.Dropdown( choices=["whisper", "wav2vec2"], label="Выберите модель", value="whisper", info="whisper - многоязычная, wav2vec2 - специализированная для русского" ) transcribe_btn = gr.Button("Транскрибировать") with gr.Column(): asr_output = gr.Textbox(label="Транскрипция", lines=5) transcribe_btn.click( fn=speech_recognition, inputs=[asr_audio_input, asr_model_dropdown], outputs=asr_output ) with gr.Tab("🔊 Синтез речи"): gr.Markdown("## Text-to-Speech (TTS)") with gr.Row(): with gr.Column(): tts_text_input = gr.Textbox( label="Введите текст для синтеза", placeholder="Введите текст на русском языке...", lines=3 ) tts_model_dropdown = gr.Dropdown( choices=["silero", "gtts", "mms"], label="Выберите модель", value="silero", info="silero - высокое качество, gtts - Google TTS, mms - Facebook MMS" ) synthesize_btn = gr.Button("Синтезировать речь") with gr.Column(): tts_output = gr.Audio(label="Синтезированная речь") synthesize_btn.click( fn=text_to_speech, inputs=[tts_text_input, tts_model_dropdown], outputs=tts_output ) with gr.Tab("📦 Детекция объектов"): gr.Markdown("## Object Detection") with gr.Row(): with gr.Column(): obj_detection_input = gr.Image(label="Загрузите изображение", type="pil") detect_btn = gr.Button("Обнаружить объекты") with gr.Column(): obj_detection_output = gr.Image(label="Результат детекции") detect_btn.click( fn=object_detection, inputs=obj_detection_input, outputs=obj_detection_output ) with gr.Tab("🎨 Сегментация"): gr.Markdown("## Image Segmentation") with gr.Row(): with gr.Column(): seg_input = gr.Image(label="Загрузите изображение", type="pil") segment_btn = gr.Button("Сегментировать") with gr.Column(): seg_output = gr.Image(label="Маска сегментации") segment_btn.click( fn=image_segmentation, inputs=seg_input, outputs=seg_output ) with gr.Tab("📝 Описание изображений"): gr.Markdown("## Image Captioning") with gr.Row(): with gr.Column(): caption_input = gr.Image(label="Загрузите изображение", type="pil") caption_btn = gr.Button("Сгенерировать описание") with gr.Column(): caption_output = gr.Textbox(label="Описание изображения", lines=3) caption_btn.click( fn=image_captioning, inputs=caption_input, outputs=caption_output ) with gr.Tab("❓ Визуальные вопросы"): gr.Markdown("## Visual Question Answering") with gr.Row(): with gr.Column(): vqa_image_input = gr.Image(label="Загрузите изображение", type="pil") vqa_question_input = gr.Textbox( label="Вопрос об изображении", placeholder="Что происходит на этом изображении?", lines=2 ) vqa_btn = gr.Button("Ответить на вопрос") with gr.Column(): vqa_output = gr.Textbox(label="Ответ", lines=3) vqa_btn.click( fn=visual_question_answering, inputs=[vqa_image_input, vqa_question_input], outputs=vqa_output ) with gr.Tab("🎯 Zero-Shot классификация"): gr.Markdown("## Zero-Shot Image Classification") with gr.Row(): with gr.Column(): zs_image_input = gr.Image(label="Загрузите изображение", type="pil") zs_classes_input = gr.Textbox( label="Классы для классификации (через запятую)", placeholder="человек, машина, дерево, здание, животное", lines=2 ) zs_classify_btn = gr.Button("Классифицировать") with gr.Column(): zs_output = gr.Textbox(label="Результаты классификации", lines=10) zs_classify_btn.click( fn=zero_shot_classification, inputs=[zs_image_input, zs_classes_input], outputs=zs_output ) with gr.Tab("🔍 Поиск изображений"): gr.Markdown("## Image Retrieval") with gr.Row(): with gr.Column(): retrieval_images_input = gr.Gallery( label="Загрузите изображения для поиска", type="pil" ) retrieval_query_input = gr.Textbox( label="Текстовый запрос", placeholder="описание того, что вы ищете...", lines=2 ) retrieval_btn = gr.Button("Найти изображение") with gr.Column(): retrieval_output_text = gr.Textbox(label="Результат поиска") retrieval_output_image = gr.Image(label="Найденное изображение") retrieval_btn.click( fn=image_retrieval, inputs=[retrieval_images_input, retrieval_query_input], outputs=[retrieval_output_text, retrieval_output_image] ) gr.Markdown("---") gr.Markdown("### 📊 Поддерживаемые задачи:") gr.Markdown(""" - **🎵 Аудио**: Классификация, распознавание речи, синтез речи - **👁️ Компьютерное зрение**: Детекция объектов, сегментация, описание изображений - **🤖 Мультимодальные**: Визуальные вопросы, zero-shot классификация, поиск по изображениям """) if __name__ == "__main__": demo.launch(share=True)