AvtnshM committed
Commit eba4a14 · verified
1 parent: 20a9846
Files changed (1):
  1. app.py +114 -30
app.py CHANGED
@@ -41,8 +41,22 @@ def load_model_and_processor(model_name):
 
    try:
        if model_name == "IndicConformer (AI4Bharat)":
-            model = AutoModel.from_pretrained(repo, trust_remote_code=True)
-            processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
+            # Use the working method for AI4Bharat model
+            print(f"Loading {model_name}...")
+            try:
+                model = AutoModel.from_pretrained(
+                    repo,
+                    trust_remote_code=True,
+                    torch_dtype=torch.float32,
+                    low_cpu_mem_usage=True
+                )
+            except Exception as e1:
+                print(f"Primary loading failed, trying fallback: {e1}")
+                model = AutoModel.from_pretrained(repo, trust_remote_code=True)
+
+            # AI4Bharat doesn't use a traditional processor
+            processor = None
+            return model, processor, model_type
        elif model_name == "MMS (Facebook)":
            model = AutoModelForCTC.from_pretrained(repo)
            processor = AutoProcessor.from_pretrained(repo)
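
Note: the fallback pattern above is easy to exercise in isolation when debugging model loads. A minimal sketch, not part of app.py; `load_with_fallback` is a hypothetical helper and `repo` stands for whatever checkpoint id `MODEL_CONFIGS` maps this entry to:

```python
import torch
from transformers import AutoModel

def load_with_fallback(repo: str):
    """Try an explicit-dtype, low-memory load first; fall back to defaults."""
    try:
        return AutoModel.from_pretrained(
            repo,
            trust_remote_code=True,     # the AI4Bharat checkpoint ships custom code
            torch_dtype=torch.float32,  # pin dtype for CPU inference
            low_cpu_mem_usage=True,     # reduce peak RAM while loading weights
        )
    except Exception as err:
        print(f"Primary loading failed, trying fallback: {err}")
        return AutoModel.from_pretrained(repo, trust_remote_code=True)
```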
@@ -74,10 +88,10 @@ def compute_metrics(reference, hypothesis, audio_duration, total_time):
# Main transcription function
def transcribe_audio(audio_file, selected_models, reference_text=""):
    if not audio_file:
-        return "Please upload an audio file.", []
+        return "Please upload an audio file.", [], ""

    if not selected_models:
-        return "Please select at least one model.", []
+        return "Please select at least one model.", [], ""

    table_data = []
    try:
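
Note: the early returns grow a third element so every exit path matches the three Gradio outputs wired up later in this commit; returning two values against three declared outputs fails at event time. Illustration only, `handler_stub` is hypothetical:

```python
def handler_stub(audio_file, selected_models, reference_text=""):
    # Gradio unpacks the return into the (Markdown, Dataframe, Textbox) targets,
    # so the arity must match the outputs list passed to .click()/.submit().
    if not audio_file:
        return "Please upload an audio file.", [], ""
    if not selected_models:
        return "Please select at least one model.", [], ""
    return "**Done**", [["model", "text", "-", "-", "-", "-"]], "plain-text report"
```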
@@ -98,25 +112,45 @@ def transcribe_audio(audio_file, selected_models, reference_text=""):
                ])
                continue

-            inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
-
            start_time = time.time()
-            with torch.no_grad():
-                if model_type == "seq2seq":
-                    input_features = inputs["input_features"]
-                    outputs = model.generate(input_features)
-                    transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
-                else:  # CTC or RNNT
-                    input_values = inputs["input_values"]
-                    logits = model(input_values).logits
-                    predicted_ids = torch.argmax(logits, dim=-1)
-                    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+            # Handle different model types
+            try:
+                if model_name == "IndicConformer (AI4Bharat)":
+                    # Use AI4Bharat specific processing
+                    wav = torch.from_numpy(audio).unsqueeze(0)  # Add batch dimension
+                    if torch.max(torch.abs(wav)) > 0:
+                        wav = wav / torch.max(torch.abs(wav))  # Normalize
+
+                    with torch.no_grad():
+                        # Default to Hindi and RNNT for AI4Bharat
+                        transcription = model(wav, "hi", "rnnt")
+                        if isinstance(transcription, list):
+                            transcription = transcription[0] if transcription else ""
+                        transcription = str(transcription).strip()
+                else:
+                    # Standard processing for other models
+                    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+
+                    with torch.no_grad():
+                        if model_type == "seq2seq":
+                            input_features = inputs["input_features"]
+                            outputs = model.generate(input_features)
+                            transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+                        else:  # CTC or RNNT
+                            input_values = inputs["input_values"]
+                            logits = model(input_values).logits
+                            predicted_ids = torch.argmax(logits, dim=-1)
+                            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+            except Exception as e:
+                transcription = f"Processing error: {str(e)}"

            total_time = time.time() - start_time

            # Compute metrics
            wer_score, cer_score, rtf = "-", "-", "-"
-            if reference_text and transcription:
+            if reference_text and transcription and not transcription.startswith("Processing error"):
                wer_val, cer_val, rtf_val, _ = compute_metrics(
                    reference_text, transcription, audio_duration, total_time
                )
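
Note: the AI4Bharat branch skips the HF processor and feeds a batched, peak-normalized float tensor straight to the model; the `model(wav, "hi", "rnnt")` signature (language code, decoder type) comes from the checkpoint's remote code. The preprocessing step in isolation, as a sketch with a hypothetical `prepare_wav` helper:

```python
import numpy as np
import torch

def prepare_wav(audio: np.ndarray) -> torch.Tensor:
    """Batch and peak-normalize a mono waveform for the AI4Bharat call above."""
    wav = torch.from_numpy(audio).unsqueeze(0)  # (num_samples,) -> (1, num_samples)
    peak = torch.max(torch.abs(wav))
    if peak > 0:          # skip silent clips to avoid dividing by zero
        wav = wav / peak  # scale into [-1.0, 1.0]
    return wav
```

Recording "Processing error: ..." as the transcription instead of re-raising is what lets the benchmark continue across the remaining models rather than aborting the whole run.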
@@ -140,15 +174,36 @@ def transcribe_audio(audio_file, selected_models, reference_text=""):
        if reference_text:
            summary += f"**Reference Text:** {reference_text[:100]}{'...' if len(reference_text) > 100 else ''}\n"

-        return summary, table_data
+        # Create copyable text output
+        copyable_text = "SPEECH-TO-TEXT BENCHMARK RESULTS\n" + "="*50 + "\n\n"
+        copyable_text += f"Audio Duration: {audio_duration:.2f}s\n"
+        copyable_text += f"Models Tested: {len(selected_models)}\n"
+        if reference_text:
+            copyable_text += f"Reference Text: {reference_text}\n"
+        copyable_text += "\n" + "-"*50 + "\n\n"
+
+        for i, row in enumerate(table_data):
+            copyable_text += f"MODEL {i+1}: {row[0]}\n"
+            copyable_text += f"Transcription: {row[1]}\n"
+            copyable_text += f"WER: {row[2]}\n"
+            copyable_text += f"CER: {row[3]}\n"
+            copyable_text += f"RTF: {row[4]}\n"
+            copyable_text += f"Time Taken: {row[5]}\n"
+            copyable_text += "\n" + "-"*30 + "\n\n"
+
+        return summary, table_data, copyable_text
    except Exception as e:
-        return f"Error during transcription: {str(e)}", []
+        error_msg = f"Error during transcription: {str(e)}"
+        return error_msg, [], error_msg

# Create Gradio interface with blocks for better control
def create_interface():
    model_choices = list(MODEL_CONFIGS.keys())

-    with gr.Blocks(title="Multilingual Speech-to-Text Benchmark") as iface:
+    with gr.Blocks(title="Multilingual Speech-to-Text Benchmark", css="""
+    .paste-button { margin: 5px 0; }
+    .copy-area { font-family: monospace; font-size: 12px; }
+    """) as iface:
        gr.Markdown("""
        # Multilingual Speech-to-Text Benchmark
        Upload an audio file, select one or more models, and optionally provide reference text.
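
Note: `compute_metrics` is defined earlier in app.py and untouched by this commit. For orientation, a plausible sketch of the `(wer, cer, rtf, _)` tuple unpacked above, assuming the jiwer package:

```python
import jiwer

def compute_metrics(reference: str, hypothesis: str,
                    audio_duration: float, total_time: float):
    wer = jiwer.wer(reference, hypothesis)  # word error rate
    cer = jiwer.cer(reference, hypothesis)  # character error rate
    # Real-time factor: processing time over audio length (< 1 means faster
    # than real time).
    rtf = total_time / audio_duration if audio_duration > 0 else 0.0
    return wer, cer, rtf, total_time
```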
@@ -167,14 +222,20 @@ def create_interface():
                value=[model_choices[0]],  # Default to first model
                interactive=True
            )
-            reference_input = gr.Textbox(
-                label="Reference Text (Optional for WER/CER)",
-                placeholder="Enter or paste ground truth text here",
-                lines=8,
-                interactive=True,
-                max_lines=20
-            )
-            submit_btn = gr.Button("Transcribe", variant="primary", size="lg")
+
+            # Enhanced reference text input with paste functionality
+            with gr.Group():
+                gr.Markdown("### Reference Text (Optional for WER/CER)")
+                reference_input = gr.Textbox(
+                    placeholder="Enter or paste ground truth text here...",
+                    lines=8,
+                    max_lines=20,
+                    show_copy_button=True,
+                    interactive=True,
+                    elem_classes="paste-area"
+                )
+
+            submit_btn = gr.Button("🚀 Transcribe", variant="primary", size="lg")

        with gr.Column(scale=2):
            summary_output = gr.Markdown(label="Summary", value="Upload an audio file and select models to begin...")
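
Note: the Textbox sets `elem_classes="paste-area"`, but the Blocks CSS added in this commit only defines `.paste-button` and `.copy-area`, so no rule actually targets the reference box. A minimal standalone check of the pattern, with the class name aligned to the CSS:

```python
import gradio as gr

# Sketch only: class name matched to the CSS rule, unlike the diff above.
with gr.Blocks(css=".copy-area { font-family: monospace; font-size: 12px; }") as demo:
    with gr.Group():
        gr.Markdown("### Reference Text (Optional for WER/CER)")
        gr.Textbox(lines=8, show_copy_button=True, elem_classes="copy-area")

# demo.launch()  # uncomment to verify the styling locally
```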
@@ -187,20 +248,43 @@ def create_interface():
                wrap=True,
                column_widths=[150, 400, 80, 80, 80, 100]
            )
+
+            # Copyable results section
+            with gr.Group():
+                gr.Markdown("### 📋 Copy Results")
+                copyable_output = gr.Textbox(
+                    label="Copy-Paste Friendly Results",
+                    lines=15,
+                    max_lines=30,
+                    show_copy_button=True,
+                    interactive=False,
+                    elem_classes="copy-area",
+                    placeholder="Results will appear here in copy-paste friendly format..."
+                )

        # Connect the function
        submit_btn.click(
            fn=transcribe_audio,
            inputs=[audio_input, model_selection, reference_input],
-            outputs=[summary_output, results_table]
+            outputs=[summary_output, results_table, copyable_output]
        )

        # Also allow triggering on Enter in reference text
        reference_input.submit(
            fn=transcribe_audio,
            inputs=[audio_input, model_selection, reference_input],
-            outputs=[summary_output, results_table]
+            outputs=[summary_output, results_table, copyable_output]
        )
+
+        # Add example and instructions
+        gr.Markdown("""
+        ---
+        ### 💡 Tips:
+        - **Reference Text**: Paste your ground truth text to calculate WER/CER metrics
+        - **Copy Results**: Use the copy button in the results section to copy formatted results
+        - **AI4Bharat Model**: Automatically uses Hindi language with RNNT decoding
+        - **Supported Formats**: WAV, MP3, FLAC, M4A (16kHz recommended for best results)
+        """)

    return iface
 
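Note: the `.click()` and `.submit()` events are wired identically; one way to keep them from drifting apart, sketched here with the component names from this diff, is to share the event kwargs:

```python
# Sketch: define the handler wiring once and reuse it for both events.
event_kwargs = dict(
    fn=transcribe_audio,
    inputs=[audio_input, model_selection, reference_input],
    outputs=[summary_output, results_table, copyable_output],
)
submit_btn.click(**event_kwargs)
reference_input.submit(**event_kwargs)
```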