# app.py
import gradio as gr
import torch
import torchaudio
from transformers import (
    pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq,
    AutoImageProcessor, AutoModelForObjectDetection,
    BlipForQuestionAnswering, BlipProcessor, CLIPModel, CLIPProcessor,
    VitsModel, AutoTokenizer
)
from PIL import Image, ImageDraw
import requests
import numpy as np
import soundfile as sf
from gtts import gTTS
import tempfile
import os
from sentence_transformers import SentenceTransformer
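# Note: not every import above is used directly below; the third-party packages
# imported here (gradio, torch, transformers, torchaudio, Pillow, soundfile,
# gTTS, sentence-transformers, etc.) are assumed to be listed in the Space's
# requirements.txt.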
# Model initialization (lazy loading)
models = {}

def load_audio_model(model_name):
    if model_name not in models:
        if model_name == "whisper":
            models[model_name] = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-small"
            )
        elif model_name == "wav2vec2":
            models[model_name] = pipeline(
                "automatic-speech-recognition",
                model="bond005/wav2vec2-large-ru-golos"
            )
        elif model_name == "audio_classifier":
            models[model_name] = pipeline(
                "audio-classification",
                model="MIT/ast-finetuned-audioset-10-10-0.4593"
            )
        elif model_name == "emotion_classifier":
            models[model_name] = pipeline(
                "audio-classification",
                model="superb/hubert-large-superb-er"
            )
    return models[model_name]
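# Each pipeline is created on first use and kept in the `models` dict, so
# repeated requests reuse the already-loaded weights; load_image_model below
# follows the same pattern.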
def load_image_model(model_name):
    if model_name not in models:
        if model_name == "object_detection":
            models[model_name] = pipeline("object-detection", model="facebook/detr-resnet-50")
        elif model_name == "segmentation":
            models[model_name] = pipeline("image-segmentation", model="nvidia/segformer-b0-finetuned-ade-512-512")
        elif model_name == "captioning":
            models[model_name] = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
        elif model_name == "vqa":
            models[model_name] = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
        elif model_name == "clip":
            models[model_name] = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            models[f"{model_name}_processor"] = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return models[model_name]
# Audio processing functions
def audio_classification(audio_file, model_type):
    classifier = load_audio_model(model_type)
    results = classifier(audio_file)
    output = "Топ-5 предсказаний:\n"
    for i, result in enumerate(results[:5]):
        output += f"{i+1}. {result['label']}: {result['score']:.4f}\n"
    return output

def speech_recognition(audio_file, model_type):
    asr_pipeline = load_audio_model(model_type)
    if model_type == "whisper":
        result = asr_pipeline(audio_file, generate_kwargs={"language": "russian"})
    else:
        result = asr_pipeline(audio_file)
    return result['text']
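# For Whisper, generate_kwargs={"language": "russian"} pins decoding to Russian
# instead of relying on automatic language detection.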
def text_to_speech(text, model_type):
    if model_type == "silero":
        # Silero TTS
        model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models',
                                  model='silero_tts',
                                  language='ru',
                                  speaker='ru_v3')
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            model.save_wav(text=text, speaker='aidar', sample_rate=48000, audio_path=f.name)
            return f.name
    elif model_type == "gtts":
        # Google TTS (gTTS produces MP3, so use an .mp3 temp file)
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
            tts = gTTS(text=text, lang='ru')
            tts.save(f.name)
            return f.name
    elif model_type == "mms":
        # Facebook MMS TTS
        model = VitsModel.from_pretrained("facebook/mms-tts-rus")
        tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            output = model(**inputs).waveform
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate)
            return f.name
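# Note: unlike the cached pipelines above, the Silero and MMS models are loaded
# on every call; caching them in `models` as well would cut per-request latency.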
# Image processing functions
def object_detection(image):
    detector = load_image_model("object_detection")
    results = detector(image)
    # Draw bounding boxes
    draw = ImageDraw.Draw(image)
    for result in results:
        box = result['box']
        label = result['label']
        score = result['score']
        draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']],
                       outline='red', width=3)
        draw.text((box['xmin'], box['ymin']),
                  f"{label}: {score:.2f}", fill='red')
    return image

def image_segmentation(image):
    segmenter = load_image_model("segmentation")
    results = segmenter(image)
    # Return the first segmentation mask
    return results[0]['mask']
def image_captioning(image):
    captioner = load_image_model("captioning")
    result = captioner(image)
    return result[0]['generated_text']

def visual_question_answering(image, question):
    vqa_pipeline = load_image_model("vqa")
    cleaned_question = (question or "").strip()
    result = vqa_pipeline(
        image=image,
        question=cleaned_question,
        truncation=True,  # keep text within ViLT max sequence length (40)
        max_length=40,
    )
    return f"{result[0]['answer']} (confidence: {result[0]['score']:.3f})"
def zero_shot_classification(image, classes):
    model = load_image_model("clip")
    processor = models["clip_processor"]
    class_list = [cls.strip() for cls in classes.split(",")]
    inputs = processor(text=class_list, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)
    result = "Zero-Shot Classification Results:\n"
    for i, cls in enumerate(class_list):
        result += f"{cls}: {probs[0][i].item():.4f}\n"
    return result
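# logits_per_image holds one CLIP similarity score per candidate label; the
# softmax turns those scores into probabilities over the supplied classes.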
def image_retrieval(images, query):
    if not images or not query:
        # Two outputs are expected, so pair the message with an empty image
        return "Пожалуйста, загрузите изображения и введите запрос", None
    # Use CLIP for the search
    model = load_image_model("clip")
    processor = models["clip_processor"]
    # Process all images
    if isinstance(images, tuple):
        images = list(images)
    normalized_images = []
    for item in images:
        # Gallery may return (image, caption); keep only the image
        if isinstance(item, (list, tuple)) and item:
            normalized_images.append(item[0])
        else:
            normalized_images.append(item)
    image_inputs = processor(images=normalized_images, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_embeddings = model.get_image_features(**image_inputs)
        image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
    # Process the text query
    text_inputs = processor(text=[query], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embeddings = model.get_text_features(**text_inputs)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
    # Compute similarities
    similarities = (image_embeddings @ text_embeddings.T)
    # Find the best-matching image
    best_idx = similarities.argmax().item()
    best_score = similarities[best_idx].item()
    return f"Лучшее изображение: #{best_idx + 1} (схожесть: {best_score:.4f})", normalized_images[best_idx]
# Build the Gradio interface
with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 Мультимодальные AI модели")
    gr.Markdown("Демонстрация различных задач компьютерного зрения и обработки звука с использованием Hugging Face Transformers")

    with gr.Tab("🎵 Классификация аудио"):
        gr.Markdown("## Audio Classification")
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Загрузите аудиофайл", type="filepath")
                audio_model_dropdown = gr.Dropdown(
                    choices=["audio_classifier", "emotion_classifier"],
                    label="Выберите модель",
                    value="audio_classifier",
                    info="audio_classifier - общая классификация, emotion_classifier - эмоции в речи"
                )
                classify_btn = gr.Button("Классифицировать")
            with gr.Column():
                audio_output = gr.Textbox(label="Результаты классификации", lines=10)
        classify_btn.click(
            fn=audio_classification,
            inputs=[audio_input, audio_model_dropdown],
            outputs=audio_output
        )
| with gr.Tab("🗣️ Распознавание речи"): | |
| gr.Markdown("## Automatic Speech Recognition (ASR)") | |
| with gr.Row(): | |
| with gr.Column(): | |
| asr_audio_input = gr.Audio(label="Загрузите аудио с речью", type="filepath") | |
| asr_model_dropdown = gr.Dropdown( | |
| choices=["whisper", "wav2vec2"], | |
| label="Выберите модель", | |
| value="whisper", | |
| info="whisper - многоязычная, wav2vec2 - специализированная для русского" | |
| ) | |
| transcribe_btn = gr.Button("Транскрибировать") | |
| with gr.Column(): | |
| asr_output = gr.Textbox(label="Транскрипция", lines=5) | |
| transcribe_btn.click( | |
| fn=speech_recognition, | |
| inputs=[asr_audio_input, asr_model_dropdown], | |
| outputs=asr_output | |
| ) | |
| with gr.Tab("🔊 Синтез речи"): | |
| gr.Markdown("## Text-to-Speech (TTS)") | |
| with gr.Row(): | |
| with gr.Column(): | |
| tts_text_input = gr.Textbox( | |
| label="Введите текст для синтеза", | |
| placeholder="Введите текст на русском языке...", | |
| lines=3 | |
| ) | |
| tts_model_dropdown = gr.Dropdown( | |
| choices=["silero", "gtts", "mms"], | |
| label="Выберите модель", | |
| value="silero", | |
| info="silero - высокое качество, gtts - Google TTS, mms - Facebook MMS" | |
| ) | |
| synthesize_btn = gr.Button("Синтезировать речь") | |
| with gr.Column(): | |
| tts_output = gr.Audio(label="Синтезированная речь") | |
| synthesize_btn.click( | |
| fn=text_to_speech, | |
| inputs=[tts_text_input, tts_model_dropdown], | |
| outputs=tts_output | |
| ) | |
| with gr.Tab("📦 Детекция объектов"): | |
| gr.Markdown("## Object Detection") | |
| with gr.Row(): | |
| with gr.Column(): | |
| obj_detection_input = gr.Image(label="Загрузите изображение", type="pil") | |
| detect_btn = gr.Button("Обнаружить объекты") | |
| with gr.Column(): | |
| obj_detection_output = gr.Image(label="Результат детекции") | |
| detect_btn.click( | |
| fn=object_detection, | |
| inputs=obj_detection_input, | |
| outputs=obj_detection_output | |
| ) | |
| with gr.Tab("🎨 Сегментация"): | |
| gr.Markdown("## Image Segmentation") | |
| with gr.Row(): | |
| with gr.Column(): | |
| seg_input = gr.Image(label="Загрузите изображение", type="pil") | |
| segment_btn = gr.Button("Сегментировать") | |
| with gr.Column(): | |
| seg_output = gr.Image(label="Маска сегментации") | |
| segment_btn.click( | |
| fn=image_segmentation, | |
| inputs=seg_input, | |
| outputs=seg_output | |
| ) | |
| with gr.Tab("📝 Описание изображений"): | |
| gr.Markdown("## Image Captioning") | |
| with gr.Row(): | |
| with gr.Column(): | |
| caption_input = gr.Image(label="Загрузите изображение", type="pil") | |
| caption_btn = gr.Button("Сгенерировать описание") | |
| with gr.Column(): | |
| caption_output = gr.Textbox(label="Описание изображения", lines=3) | |
| caption_btn.click( | |
| fn=image_captioning, | |
| inputs=caption_input, | |
| outputs=caption_output | |
| ) | |
| with gr.Tab("❓ Визуальные вопросы"): | |
| gr.Markdown("## Visual Question Answering") | |
| with gr.Row(): | |
| with gr.Column(): | |
| vqa_image_input = gr.Image(label="Загрузите изображение", type="pil") | |
| vqa_question_input = gr.Textbox( | |
| label="Вопрос об изображении", | |
| placeholder="Что происходит на этом изображении?", | |
| lines=2 | |
| ) | |
| vqa_btn = gr.Button("Ответить на вопрос") | |
| with gr.Column(): | |
| vqa_output = gr.Textbox(label="Ответ", lines=3) | |
| vqa_btn.click( | |
| fn=visual_question_answering, | |
| inputs=[vqa_image_input, vqa_question_input], | |
| outputs=vqa_output | |
| ) | |
| with gr.Tab("🎯 Zero-Shot классификация"): | |
| gr.Markdown("## Zero-Shot Image Classification") | |
| with gr.Row(): | |
| with gr.Column(): | |
| zs_image_input = gr.Image(label="Загрузите изображение", type="pil") | |
| zs_classes_input = gr.Textbox( | |
| label="Классы для классификации (через запятую)", | |
| placeholder="человек, машина, дерево, здание, животное", | |
| lines=2 | |
| ) | |
| zs_classify_btn = gr.Button("Классифицировать") | |
| with gr.Column(): | |
| zs_output = gr.Textbox(label="Результаты классификации", lines=10) | |
| zs_classify_btn.click( | |
| fn=zero_shot_classification, | |
| inputs=[zs_image_input, zs_classes_input], | |
| outputs=zs_output | |
| ) | |
| with gr.Tab("🔍 Поиск изображений"): | |
| gr.Markdown("## Image Retrieval") | |
| with gr.Row(): | |
| with gr.Column(): | |
| retrieval_images_input = gr.Gallery( | |
| label="Загрузите изображения для поиска", | |
| type="pil" | |
| ) | |
| retrieval_query_input = gr.Textbox( | |
| label="Текстовый запрос", | |
| placeholder="описание того, что вы ищете...", | |
| lines=2 | |
| ) | |
| retrieval_btn = gr.Button("Найти изображение") | |
| with gr.Column(): | |
| retrieval_output_text = gr.Textbox(label="Результат поиска") | |
| retrieval_output_image = gr.Image(label="Найденное изображение") | |
| retrieval_btn.click( | |
| fn=image_retrieval, | |
| inputs=[retrieval_images_input, retrieval_query_input], | |
| outputs=[retrieval_output_text, retrieval_output_image] | |
| ) | |
| gr.Markdown("---") | |
| gr.Markdown("### 📊 Поддерживаемые задачи:") | |
| gr.Markdown(""" | |
| - **🎵 Аудио**: Классификация, распознавание речи, синтез речи | |
| - **👁️ Компьютерное зрение**: Детекция объектов, сегментация, описание изображений | |
| - **🤖 Мультимодальные**: Визуальные вопросы, zero-shot классификация, поиск по изображениям | |
| """) | |
| if __name__ == "__main__": | |
| demo.launch(share=True) | |
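# Note: when the app runs on a Hugging Face Space it is already served publicly,
# so share=True mainly matters for local launches.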