AvtnshM committed
Commit 3bec4e3 · verified · Parent(s): 2570e31
Files changed (1)
  1. app.py +113 -80
app.py CHANGED
@@ -2,30 +2,37 @@ import gradio as gr
 import torch
 import torchaudio
 from transformers import (
-    AutoModelForCTC,
     AutoModelForSpeechSeq2Seq,
     AutoProcessor,
+    AutoModelForCTC,
     AutoModel,
 )
+import librosa
+import numpy as np
+from jiwer import wer, cer
+import time
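+# All third-party deps used below must be installed (e.g. via requirements.txt):
+# gradio, torch, torchaudio, transformers, librosa, numpy, jiwer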
 
-# -------------------------------
 # Model configurations
-# -------------------------------
 MODEL_CONFIGS = {
-    "Whisper Small (hi)": {
-        "repo": "openai/whisper-small",
+    "AudioX-North (Jivi AI)": {
+        "repo": "jiviai/audioX-north-v1",
         "model_type": "seq2seq",
+        "description": "Supports Hindi, Gujarati, Marathi",
     },
-    "IndicConformer 600M": {
+    "IndicConformer (AI4Bharat)": {
         "repo": "ai4bharat/indic-conformer-600m-multilingual",
-        "model_type": "ctc",  # but handled specially
+        "model_type": "ctc_rnnt",
+        "description": "Supports 22 Indian languages",
         "trust_remote_code": True,
     },
+    "MMS (Facebook)": {
+        "repo": "facebook/mms-1b-all",
+        "model_type": "ctc",
+        "description": "Supports over 1,400 languages (fine-tuning recommended)",
+    },
 }
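+# Sketch (not wired into the app): MMS is typically pointed at a per-language
+# adapter before decoding, e.g. for Hindi:
+#   processor.tokenizer.set_target_lang("hin")
+#   model.load_adapter("hin")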

-# -------------------------------
 # Load model and processor
-# -------------------------------
 def load_model_and_processor(model_name):
     config = MODEL_CONFIGS[model_name]
     repo = config["repo"]
@@ -33,89 +40,115 @@ def load_model_and_processor(model_name):
     trust_remote_code = config.get("trust_remote_code", False)
 
     try:
-        if "indic-conformer" in repo.lower():
+        if model_name == "IndicConformer (AI4Bharat)":
             model = AutoModel.from_pretrained(repo, trust_remote_code=True)
-            processor = None  # Not required
-            return model, processor, model_type
-        else:
+            # the remote-code model bundles its own preprocessing, so no
+            # processor is needed (as in the previous version)
+            processor = None
+        elif model_name == "MMS (Facebook)":
+            model = AutoModelForCTC.from_pretrained(repo)
+            processor = AutoProcessor.from_pretrained(repo)
+        else:  # AudioX-North
             processor = AutoProcessor.from_pretrained(repo, trust_remote_code=trust_remote_code)
             if model_type == "seq2seq":
                 model = AutoModelForSpeechSeq2Seq.from_pretrained(repo, trust_remote_code=trust_remote_code)
             else:
                 model = AutoModelForCTC.from_pretrained(repo, trust_remote_code=trust_remote_code)
-        return model, processor, model_type
+
+        return model, processor, model_type
     except Exception as e:
         return None, None, f"Error loading model: {str(e)}"
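+# e.g. load_model_and_processor("MMS (Facebook)") returns
+# (Wav2Vec2ForCTC, Wav2Vec2Processor, "ctc"), while any load failure yields
+# (None, None, "Error loading model: ...")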

-# -------------------------------
-# Transcription
-# -------------------------------
-def transcribe_audio(audio_file, model_name, reference_text):
-    model, processor, model_type = load_model_and_processor(model_name)
-    if model is None:
-        return f"⚠️ Failed to load {model_name}: {processor}", ""
-
-    # Load audio
-    speech_array, sampling_rate = torchaudio.load(audio_file)
-    if sampling_rate != 16000:
-        speech_array = torchaudio.transforms.Resample(sampling_rate, 16000)(speech_array)
-    speech_array = speech_array.squeeze().numpy()
-
-    # Special handling for IndicConformer
-    if "indic-conformer" in MODEL_CONFIGS[model_name]["repo"].lower():
-        with torch.no_grad():
-            transcription = model(torch.tensor(speech_array).unsqueeze(0), "hi", "ctc")
-        transcription = transcription[0] if isinstance(transcription, list) else transcription
-    else:
-        inputs = processor(speech_array, sampling_rate=16000, return_tensors="pt")
-        with torch.no_grad():
-            if model_type == "seq2seq":
-                generated_ids = model.generate(inputs["input_features"])
-                transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-            else:
-                logits = model(**inputs).logits
-                pred_ids = torch.argmax(logits, dim=-1)
-                transcription = processor.batch_decode(pred_ids)[0]
-
-    # Compute WER if reference given
-    wer_score = None
-    if reference_text.strip():
-        from jiwer import wer
-        wer_score = wer(reference_text, transcription)
-
-    result = f"📝 Transcription: {transcription}"
-    if wer_score is not None:
-        result += f"\n📊 WER vs reference: {wer_score:.2%}"
-
-    return result, transcription
-
-# -------------------------------
-# Gradio UI
-# -------------------------------
-with gr.Blocks() as demo:
-    gr.Markdown("## 🎙️ Indic ASR Comparison App")
-
-    with gr.Row():
-        audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio")
-        model_dropdown = gr.Dropdown(choices=list(MODEL_CONFIGS.keys()), value="Whisper Small (hi)", label="Select Model")
-
-    # ✅ Paste enabled in textbox
-    reference_text = gr.Textbox(
-        label="Reference Text (optional, paste supported)",
-        placeholder="Paste reference transcription here...",
-        lines=4,
-        interactive=True
-    )

-    transcribe_btn = gr.Button("Transcribe")
-    output_result = gr.Textbox(label="Result", lines=6)
-    raw_transcription = gr.Textbox(label="Raw Transcription", lines=4)
+
+# Compute metrics (WER, CER, RTF)
+def compute_metrics(reference, hypothesis, audio_duration, total_time):
+    if not reference or not hypothesis:
+        return None, None, None, None
+    try:
+        reference = reference.strip().lower()
+        hypothesis = hypothesis.strip().lower()
+        wer_score = wer(reference, hypothesis)
+        cer_score = cer(reference, hypothesis)
+        rtf = total_time / audio_duration if audio_duration > 0 else None
+        return wer_score, cer_score, rtf, total_time
+    except Exception:
+        return None, None, None, None
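+
+# Example: wer("the weather is nice", "the weather nice") = 1/4 = 0.25
+# (one deletion against a four-word reference). RTF = processing time /
+# audio duration: a 10 s clip transcribed in 2.5 s gives RTF 0.25, and
+# RTF < 1 means faster than real time.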
+
+
+# Main transcription function
+def transcribe_audio(audio_file, selected_models, reference_text=""):
+    if not audio_file:
+        return "Please upload an audio file."
+
+    results = []
+    try:
+        # Load and resample the audio once for all selected models;
+        # librosa returns mono float32 at the requested 16 kHz rate
+        audio, sr = librosa.load(audio_file, sr=16000)
+        audio_duration = len(audio) / sr
+
+        for model_name in selected_models:
+            model, processor, model_type = load_model_and_processor(model_name)
+            if model is None:
+                # on failure the loader returns the error string in the third slot
+                results.append(f"{model_name}: {model_type}")
+                continue
+
+            start_time = time.time()
+            with torch.no_grad():
+                if model_type == "ctc_rnnt":
+                    # IndicConformer's remote-code forward takes the raw waveform,
+                    # a language code, and a decoding strategy (the same call the
+                    # previous version of this app used)
+                    out = model(torch.tensor(audio).unsqueeze(0), "hi", "ctc")
+                    transcription = out[0] if isinstance(out, list) else out
+                elif model_type == "seq2seq":
+                    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+                    outputs = model.generate(inputs["input_features"])
+                    transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+                else:
+                    # CTC (MMS): greedy decode, i.e. per-frame argmax;
+                    # batch_decode then collapses repeats and CTC blanks
+                    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+                    logits = model(inputs["input_values"]).logits
+                    predicted_ids = torch.argmax(logits, dim=-1)
+                    transcription = processor.batch_decode(predicted_ids)[0]
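+
+            # The timer brackets inference only (model loading and audio
+            # decoding are excluded), so RTF reflects pure transcription speed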
+            total_time = time.time() - start_time
+
+            # Compute metrics (WER/CER need a reference; RTF and time do not)
+            wer_score, cer_score = "", ""
+            rtf = round(total_time / audio_duration, 3) if audio_duration > 0 else ""
+            total_time_tracked = round(total_time, 2)
+            if reference_text and transcription:
+                wer_score, cer_score, _, _ = compute_metrics(
+                    reference_text, transcription, audio_duration, total_time
+                )
+                wer_score = round(wer_score, 3) if wer_score is not None else ""
+                cer_score = round(cer_score, 3) if cer_score is not None else ""
 
-    transcribe_btn.click(
+            result = (
+                f"### {model_name}\n"
+                f"- **Transcription:** {transcription}\n"
+                f"- **WER:** {wer_score}\n"
+                f"- **CER:** {cer_score}\n"
+                f"- **RTF:** {rtf}\n"
+                f"- **Time Taken (s):** {total_time_tracked}\n"
+            )
+            results.append(result)
+
+        return "\n\n".join(results)
+    except Exception as e:
+        return f"Error during transcription: {str(e)}"
+
+
+# Gradio interface
+def create_interface():
+    model_choices = list(MODEL_CONFIGS.keys())
+    return gr.Interface(
         fn=transcribe_audio,
-        inputs=[audio_input, model_dropdown, reference_text],
-        outputs=[output_result, raw_transcription]
+        inputs=[
+            gr.Audio(type="filepath", label="Upload Audio File (16 kHz recommended)"),
+            gr.CheckboxGroup(choices=model_choices, label="Select Models", value=model_choices),
+            gr.Textbox(label="Reference Text (Optional, for WER/CER)", placeholder="Enter or paste ground-truth text here", lines=3),
+        ],
+        outputs=gr.Markdown(label="Results"),
+        title="Multilingual Speech-to-Text Benchmark",
+        description="Upload an audio file, select one or more models, and optionally provide reference text. The app reports WER, CER, RTF, and time taken for each model.",
+        allow_flagging="never",
     )
 
+
 if __name__ == "__main__":
-    demo.launch()
+    iface = create_interface()
+    iface.launch()
 
 