video_image / app.py
PatrickRedStar's picture
123
e4b559c
# app.py
import gradio as gr
import torch
import torchaudio
from transformers import (
pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq,
AutoImageProcessor, AutoModelForObjectDetection,
BlipForQuestionAnswering, BlipProcessor, CLIPModel, CLIPProcessor,
VitsModel, AutoTokenizer
)
from PIL import Image, ImageDraw
import requests
import numpy as np
import soundfile as sf
from gtts import gTTS
import tempfile
import os
from sentence_transformers import SentenceTransformer
# Инициализация моделей (ленивая загрузка)
models = {}
def load_audio_model(model_name):
if model_name not in models:
if model_name == "whisper":
models[model_name] = pipeline(
"automatic-speech-recognition",
model="openai/whisper-small"
)
elif model_name == "wav2vec2":
models[model_name] = pipeline(
"automatic-speech-recognition",
model="bond005/wav2vec2-large-ru-golos"
)
elif model_name == "audio_classifier":
models[model_name] = pipeline(
"audio-classification",
model="MIT/ast-finetuned-audioset-10-10-0.4593"
)
elif model_name == "emotion_classifier":
models[model_name] = pipeline(
"audio-classification",
model="superb/hubert-large-superb-er"
)
return models[model_name]
def load_image_model(model_name):
if model_name not in models:
if model_name == "object_detection":
models[model_name] = pipeline("object-detection", model="facebook/detr-resnet-50")
elif model_name == "segmentation":
models[model_name] = pipeline("image-segmentation", model="nvidia/segformer-b0-finetuned-ade-512-512")
elif model_name == "captioning":
models[model_name] = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
elif model_name == "vqa":
models[model_name] = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
elif model_name == "clip":
models[model_name] = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
models[f"{model_name}_processor"] = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
return models[model_name]
# Функции для обработки аудио
def audio_classification(audio_file, model_type):
classifier = load_audio_model(model_type)
results = classifier(audio_file)
output = "Топ-5 предсказаний:\n"
for i, result in enumerate(results[:5]):
output += f"{i+1}. {result['label']}: {result['score']:.4f}\n"
return output
def speech_recognition(audio_file, model_type):
asr_pipeline = load_audio_model(model_type)
if model_type == "whisper":
result = asr_pipeline(audio_file, generate_kwargs={"language": "russian"})
else:
result = asr_pipeline(audio_file)
return result['text']
def text_to_speech(text, model_type):
if model_type == "silero":
# Silero TTS
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models',
model='silero_tts',
language='ru',
speaker='ru_v3')
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
model.save_wav(text=text, speaker='aidar', sample_rate=48000, audio_path=f.name)
return f.name
elif model_type == "gtts":
# Google TTS
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
tts = gTTS(text=text, lang='ru')
tts.save(f.name)
return f.name
elif model_type == "mms":
# Facebook MMS TTS
model = VitsModel.from_pretrained("facebook/mms-tts-rus")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
output = model(**inputs).waveform
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate)
return f.name
# Функции для обработки изображений
def object_detection(image):
detector = load_image_model("object_detection")
results = detector(image)
# Рисуем bounding boxes
draw = ImageDraw.Draw(image)
for result in results:
box = result['box']
label = result['label']
score = result['score']
draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']],
outline='red', width=3)
draw.text((box['xmin'], box['ymin']),
f"{label}: {score:.2f}", fill='red')
return image
def image_segmentation(image):
segmenter = load_image_model("segmentation")
results = segmenter(image)
# Возвращаем первую маску сегментации
return results[0]['mask']
def image_captioning(image):
captioner = load_image_model("captioning")
result = captioner(image)
return result[0]['generated_text']
def visual_question_answering(image, question):
vqa_pipeline = load_image_model("vqa")
cleaned_question = (question or "").strip()
result = vqa_pipeline(
image=image,
question=cleaned_question,
truncation=True, # keep text within ViLT max sequence length (40)
max_length=40,
)
return f"{result[0]['answer']} (confidence: {result[0]['score']:.3f})"
def zero_shot_classification(image, classes):
model = load_image_model("clip")
processor = models["clip_processor"]
class_list = [cls.strip() for cls in classes.split(",")]
inputs = processor(text=class_list, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
result = "Zero-Shot Classification Results:\n"
for i, cls in enumerate(class_list):
result += f"{cls}: {probs[0][i].item():.4f}\n"
return result
def image_retrieval(images, query):
if not images or not query:
return "Пожалуйста, загрузите изображения и введите запрос"
# Используем CLIP для поиска
model = load_image_model("clip")
processor = models["clip_processor"]
# Обрабатываем все изображения
if isinstance(images, tuple):
images = list(images)
normalized_images = []
for item in images:
# Gallery может вернуть (image, caption); берем только картинку
if isinstance(item, (list, tuple)) and item:
normalized_images.append(item[0])
else:
normalized_images.append(item)
image_inputs = processor(images=normalized_images, return_tensors="pt", padding=True)
with torch.no_grad():
image_embeddings = model.get_image_features(**image_inputs)
image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
# Обрабатываем текстовый запрос
text_inputs = processor(text=[query], return_tensors="pt", padding=True)
with torch.no_grad():
text_embeddings = model.get_text_features(**text_inputs)
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
# Вычисляем схожести
similarities = (image_embeddings @ text_embeddings.T)
# Находим лучшее изображение
best_idx = similarities.argmax().item()
best_score = similarities[best_idx].item()
return f"Лучшее изображение: #{best_idx + 1} (схожесть: {best_score:.4f})", normalized_images[best_idx]
# Создаем интерфейс Gradio
with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🎯 Мультимодальные AI модели")
gr.Markdown("Демонстрация различных задач компьютерного зрения и обработки звука с использованием Hugging Face Transformers")
with gr.Tab("🎵 Классификация аудио"):
gr.Markdown("## Zero-Shot Audio Classification")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(label="Загрузите аудиофайл", type="filepath")
audio_model_dropdown = gr.Dropdown(
choices=["audio_classifier", "emotion_classifier"],
label="Выберите модель",
value="audio_classifier",
info="audio_classifier - общая классификация, emotion_classifier - эмоции в речи"
)
classify_btn = gr.Button("Классифицировать")
with gr.Column():
audio_output = gr.Textbox(label="Результаты классификации", lines=10)
classify_btn.click(
fn=audio_classification,
inputs=[audio_input, audio_model_dropdown],
outputs=audio_output
)
with gr.Tab("🗣️ Распознавание речи"):
gr.Markdown("## Automatic Speech Recognition (ASR)")
with gr.Row():
with gr.Column():
asr_audio_input = gr.Audio(label="Загрузите аудио с речью", type="filepath")
asr_model_dropdown = gr.Dropdown(
choices=["whisper", "wav2vec2"],
label="Выберите модель",
value="whisper",
info="whisper - многоязычная, wav2vec2 - специализированная для русского"
)
transcribe_btn = gr.Button("Транскрибировать")
with gr.Column():
asr_output = gr.Textbox(label="Транскрипция", lines=5)
transcribe_btn.click(
fn=speech_recognition,
inputs=[asr_audio_input, asr_model_dropdown],
outputs=asr_output
)
with gr.Tab("🔊 Синтез речи"):
gr.Markdown("## Text-to-Speech (TTS)")
with gr.Row():
with gr.Column():
tts_text_input = gr.Textbox(
label="Введите текст для синтеза",
placeholder="Введите текст на русском языке...",
lines=3
)
tts_model_dropdown = gr.Dropdown(
choices=["silero", "gtts", "mms"],
label="Выберите модель",
value="silero",
info="silero - высокое качество, gtts - Google TTS, mms - Facebook MMS"
)
synthesize_btn = gr.Button("Синтезировать речь")
with gr.Column():
tts_output = gr.Audio(label="Синтезированная речь")
synthesize_btn.click(
fn=text_to_speech,
inputs=[tts_text_input, tts_model_dropdown],
outputs=tts_output
)
with gr.Tab("📦 Детекция объектов"):
gr.Markdown("## Object Detection")
with gr.Row():
with gr.Column():
obj_detection_input = gr.Image(label="Загрузите изображение", type="pil")
detect_btn = gr.Button("Обнаружить объекты")
with gr.Column():
obj_detection_output = gr.Image(label="Результат детекции")
detect_btn.click(
fn=object_detection,
inputs=obj_detection_input,
outputs=obj_detection_output
)
with gr.Tab("🎨 Сегментация"):
gr.Markdown("## Image Segmentation")
with gr.Row():
with gr.Column():
seg_input = gr.Image(label="Загрузите изображение", type="pil")
segment_btn = gr.Button("Сегментировать")
with gr.Column():
seg_output = gr.Image(label="Маска сегментации")
segment_btn.click(
fn=image_segmentation,
inputs=seg_input,
outputs=seg_output
)
with gr.Tab("📝 Описание изображений"):
gr.Markdown("## Image Captioning")
with gr.Row():
with gr.Column():
caption_input = gr.Image(label="Загрузите изображение", type="pil")
caption_btn = gr.Button("Сгенерировать описание")
with gr.Column():
caption_output = gr.Textbox(label="Описание изображения", lines=3)
caption_btn.click(
fn=image_captioning,
inputs=caption_input,
outputs=caption_output
)
with gr.Tab("❓ Визуальные вопросы"):
gr.Markdown("## Visual Question Answering")
with gr.Row():
with gr.Column():
vqa_image_input = gr.Image(label="Загрузите изображение", type="pil")
vqa_question_input = gr.Textbox(
label="Вопрос об изображении",
placeholder="Что происходит на этом изображении?",
lines=2
)
vqa_btn = gr.Button("Ответить на вопрос")
with gr.Column():
vqa_output = gr.Textbox(label="Ответ", lines=3)
vqa_btn.click(
fn=visual_question_answering,
inputs=[vqa_image_input, vqa_question_input],
outputs=vqa_output
)
with gr.Tab("🎯 Zero-Shot классификация"):
gr.Markdown("## Zero-Shot Image Classification")
with gr.Row():
with gr.Column():
zs_image_input = gr.Image(label="Загрузите изображение", type="pil")
zs_classes_input = gr.Textbox(
label="Классы для классификации (через запятую)",
placeholder="человек, машина, дерево, здание, животное",
lines=2
)
zs_classify_btn = gr.Button("Классифицировать")
with gr.Column():
zs_output = gr.Textbox(label="Результаты классификации", lines=10)
zs_classify_btn.click(
fn=zero_shot_classification,
inputs=[zs_image_input, zs_classes_input],
outputs=zs_output
)
with gr.Tab("🔍 Поиск изображений"):
gr.Markdown("## Image Retrieval")
with gr.Row():
with gr.Column():
retrieval_images_input = gr.Gallery(
label="Загрузите изображения для поиска",
type="pil"
)
retrieval_query_input = gr.Textbox(
label="Текстовый запрос",
placeholder="описание того, что вы ищете...",
lines=2
)
retrieval_btn = gr.Button("Найти изображение")
with gr.Column():
retrieval_output_text = gr.Textbox(label="Результат поиска")
retrieval_output_image = gr.Image(label="Найденное изображение")
retrieval_btn.click(
fn=image_retrieval,
inputs=[retrieval_images_input, retrieval_query_input],
outputs=[retrieval_output_text, retrieval_output_image]
)
gr.Markdown("---")
gr.Markdown("### 📊 Поддерживаемые задачи:")
gr.Markdown("""
- **🎵 Аудио**: Классификация, распознавание речи, синтез речи
- **👁️ Компьютерное зрение**: Детекция объектов, сегментация, описание изображений
- **🤖 Мультимодальные**: Визуальные вопросы, zero-shot классификация, поиск по изображениям
""")
if __name__ == "__main__":
demo.launch(share=True)