Spaces:
Running on Zero
Running on Zero
| import streamlit as st | |
| import os | |
| import random | |
| import librosa | |
| import librosa.display | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import re | |
| from collections import defaultdict | |
# Page configuration
st.set_page_config(
    page_title="多音频对比与频谱可视化工具 (v5.2 - 最终稳定版)",
    layout="wide",
)
| # --- 辅助函数 --- | |
def get_safe_prefix(prefix_input):
    """Return the first non-blank prefix of a comma-separated string.

    Each comma-separated piece is stripped of surrounding whitespace; the
    first piece that is not empty after stripping is returned. If every
    piece is blank, the empty string is returned.
    """
    for piece in prefix_input.split(','):
        cleaned = piece.strip()
        if cleaned:
            return cleaned
    return ""
def get_prefix_list(prefix_input):
    """Return every non-blank prefix of a comma-separated string, in order.

    Pieces are stripped of surrounding whitespace and pieces that end up
    empty are dropped.
    """
    stripped_pieces = map(str.strip, prefix_input.split(','))
    return list(filter(None, stripped_pieces))
def get_universal_ids_regex(folder_path, prefix_list, regex_pattern, extensions=('.wav', '.mp3', '.flac')):
    """Collect the universal IDs (named group ``x``) of audio files in a folder.

    For every file in ``folder_path`` whose lower-cased extension is in
    ``extensions``, the first entry of ``prefix_list`` that the base name
    starts with is stripped (at most one prefix is removed), and
    ``regex_pattern`` is matched against the start of the remainder. The
    non-empty value of the named group ``x`` is added to the result.

    Args:
        folder_path: Directory to scan (not recursive).
        prefix_list: Candidate file-name prefixes, tried in order.
        regex_pattern: Regular expression that must define a group named ``x``.
        extensions: Accepted lower-case extensions, including the leading dot.

    Returns:
        A set of ID strings. The set is empty when the folder does not exist
        or cannot be listed, when the pattern is invalid, or when the pattern
        has no group named ``x``.
    """
    if not os.path.isdir(folder_path):
        return set()

    try:
        compiled_regex = re.compile(regex_pattern)
    except (re.error, TypeError):
        # Invalid (or non-string) pattern typed by the user: nothing can match.
        return set()

    # Whether the pattern defines group "x" does not depend on the file, so
    # check it once instead of once per match.
    if 'x' not in compiled_regex.groupindex:
        return set()

    try:
        filenames = os.listdir(folder_path)
    except OSError:
        # Folder vanished or is unreadable; treat it like an empty folder.
        return set()

    ids = set()
    for filename in filenames:
        base, ext = os.path.splitext(filename)
        if ext.lower() not in extensions:
            continue
        for prefix in prefix_list:
            if base.startswith(prefix):
                base = base[len(prefix):]
                break
        match = compiled_regex.match(base)
        if match and match.group('x'):
            ids.add(match.group('x'))
    return ids
def find_matched_ids(output_path, mix_path, tar_path, out_pfx_str, mix_pfx_str, tar_pfx_str, out_pat, mix_pat, tar_pat):
    """Return the sorted universal IDs (x) found in all three folders.

    Each folder is scanned with its own comma-separated prefix string and
    its own regex pattern; the result is the intersection of the three ID
    sets as a sorted list of strings.

    NOTE(review): nothing in this function caches its result, although the
    UI code refers to a cache — confirm whether a caching decorator was
    intended here.
    """
    sources = (
        (output_path, out_pfx_str, out_pat),
        (mix_path, mix_pfx_str, mix_pat),
        (tar_path, tar_pfx_str, tar_pat),
    )
    id_sets = [
        get_universal_ids_regex(folder, get_prefix_list(prefixes), pattern)
        for folder, prefixes, pattern in sources
    ]
    # Only IDs present in every folder can be compared side by side.
    return sorted(set.intersection(*id_sets))
def generate_spectrogram(audio_path, title):
    """Render a mel spectrogram of an audio file as a Matplotlib figure.

    Args:
        audio_path: Path of the audio file to load.
        title: Label appended to the plot title. The title text itself is
            English because CJK titles rendered as garbled glyphs.

    Returns:
        The Matplotlib figure, or ``None`` when the file is missing or any
        step fails. Failures other than a missing file are also reported
        through ``st.error``. The caller is responsible for closing the
        returned figure.
    """
    fig = None
    try:
        y, sr = librosa.load(audio_path, sr=None)
        S = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=2048, hop_length=512)
        S_dB = librosa.power_to_db(S, ref=np.max)
        fig, ax = plt.subplots(figsize=(10, 4))
        librosa.display.specshow(S_dB, sr=sr, x_axis='time', y_axis='mel', ax=ax, cmap='viridis')
        ax.set(title=f'Mel Spectrogram: {title}', xlabel='Time', ylabel='Mel Freq')
        ax.tick_params(labelsize=8)
        return fig
    except FileNotFoundError:
        # Close a half-built figure so it does not stay registered in pyplot.
        if fig is not None:
            plt.close(fig)
        return None
    except Exception as e:
        # Same clean-up as above: a figure created before the failure would
        # otherwise accumulate across reruns.
        if fig is not None:
            plt.close(fig)
        st.error(f"处理音频文件失败: {str(e)}")
        return None
# --- Streamlit session-state initialisation ---
# Seed each key only when it is not already present in the session state.
_SESSION_DEFAULTS = {
    'output_path': "",
    'mixture_path': "",
    'target_path': "",
    'output_pattern': r"(?P<x>\d+)_DT(?P<y>\d+)",
    'mix_pattern': r"(?P<x>\d+)_DT(?P<y>\d+)",
    'tar_pattern': r"(?P<x>\d+)",
    'output_prefixes': "",
    'mix_prefixes': "",
    'tar_prefixes': "",
    'separator': "_DT",
    'selected_y': "0",
    'available_x_ids': [],
    'selected_x_id': None,
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
# --- Main UI ---
st.title("🎼 多音频对比与频谱可视化工具 (v5.2 - 最终稳定版)")

# 1. Folder inputs: one text box per folder, laid out in three columns.
st.header("1. 输入文件夹路径")
col_out, col_mix, col_tar = st.columns(3)
_folder_inputs = (
    (col_out, "output_path", "Output 文件夹路径"),
    (col_mix, "mixture_path", "Mixture 文件夹路径"),
    (col_tar, "target_path", "Target 文件夹路径"),
)
for _column, _path_key, _label in _folder_inputs:
    with _column:
        # The current value is fed back in as the default and the returned
        # text is stored under the same session-state key.
        st.session_state[_path_key] = st.text_input(_label, st.session_state[_path_key])
# 2. Pattern and prefix configuration
st.header("2. 配置文件名匹配模式")
st.markdown("请为每种文件类型配置**独立**的模式和前缀。模式必须包含 `(?P<x>...)`。")

# The help texts contain regex examples with backslashes, so they are raw
# strings: "\d" in a normal literal is an invalid escape sequence.

# --- Output configuration ---
col_out_p, col_out_r = st.columns(2)
with col_out_p:
    st.session_state.output_prefixes = st.text_input("Output 前缀", st.session_state.output_prefixes)
with col_out_r:
    st.session_state.output_pattern = st.text_input(
        "Output 模式 (提取 x)",
        st.session_state.output_pattern,
        key='out_pat',
        help=r"例如: `(?P<x>\d+)_DT\d+`",
    )

# --- Mixture configuration ---
col_mix_p, col_mix_r = st.columns(2)
with col_mix_p:
    st.session_state.mix_prefixes = st.text_input("Mixture 前缀", st.session_state.mix_prefixes)
with col_mix_r:
    st.session_state.mix_pattern = st.text_input(
        "Mixture 模式 (提取 x)",
        st.session_state.mix_pattern,
        key='mix_pat',
        help=r"例如: `(?P<x>\d+)_DT\d+`",
    )

# --- Target configuration (the Target pattern is configured here as well) ---
col_tar_p, col_tar_r = st.columns(2)
with col_tar_p:
    st.session_state.tar_prefixes = st.text_input("Target 前缀", st.session_state.tar_prefixes)
with col_tar_r:
    st.session_state.tar_pattern = st.text_input(
        "Target 模式 (提取 x)",
        st.session_state.tar_pattern,
        key='tar_pat',
        help=r"例如: `(?P<x>\d+)`",
    )
# 3. Separator, minor version (y) and the ID refresh button
st.header("3. 核心版本号与 ID 加载")
col_sep, col_y, col_btn_y = st.columns([1, 1, 2])

with col_sep:
    st.session_state.separator = st.text_input(
        "X和Y之间的分隔符",
        st.session_state.separator,
        help="例如:`_DT`,`_V_` 等。**注意:修改此项后必须相应修改第2部分的模式!**",
    )

with col_y:
    st.session_state.selected_y = st.text_input(
        "目标小版本号 (y)",
        st.session_state.selected_y,
    )

with col_btn_y:
    # Blank line written above the button inside this column.
    st.write(" ")
    refresh_clicked = st.button(
        "加载/刷新通用ID列表 (x)",
        help="清除缓存,根据新的模式和前缀重新匹配通用大序号 (x)",
    )
    if refresh_clicked:
        # Drop cached data and the current selection, then rerun the script.
        st.cache_data.clear()
        st.session_state.selected_x_id = None
        st.rerun()
# 4. Load the list of universal IDs (x) once all three folder paths are set
_all_paths_given = all((
    st.session_state.output_path,
    st.session_state.mixture_path,
    st.session_state.target_path,
))
if _all_paths_given:
    matched_x_ids = find_matched_ids(
        st.session_state.output_path,
        st.session_state.mixture_path,
        st.session_state.target_path,
        st.session_state.output_prefixes,
        st.session_state.mix_prefixes,
        st.session_state.tar_prefixes,
        st.session_state.output_pattern,
        st.session_state.mix_pattern,
        st.session_state.tar_pattern,
    )
    st.session_state.available_x_ids = matched_x_ids

    if matched_x_ids:
        st.success(f"成功找到 {len(matched_x_ids)} 个匹配的通用 ID (x)。")
    else:
        st.warning("在三个文件夹中未找到匹配的通用音频 ID (x)。请检查路径、文件格式或正则表达式模式。")

    # Keep the selection valid: fall back to the first ID, or to None when
    # nothing matched.
    if st.session_state.selected_x_id not in matched_x_ids:
        st.session_state.selected_x_id = matched_x_ids[0] if matched_x_ids else None
# 5. Choose a universal ID (x), either manually or at random
if st.session_state.available_x_ids:
    st.header("4. 选择音频 ID (x)")
    col_select, col_random = st.columns([3, 1])

    with col_select:
        id_options = st.session_state.available_x_ids
        current_id = st.session_state.selected_x_id
        try:
            default_index = id_options.index(current_id)
        except ValueError:
            # Current selection is not among the options: preselect the first.
            default_index = 0
        new_selected_x_id = st.selectbox(
            "手动选择一个通用音频 ID (x)",
            id_options,
            index=default_index,
        )
        if new_selected_x_id and new_selected_x_id != current_id:
            st.session_state.selected_x_id = new_selected_x_id
            st.rerun()

    with col_random:
        # Blank line written above the button inside this column.
        st.write(" ")
        if st.button("随机抽取 ID (x)"):
            st.session_state.selected_x_id = random.choice(st.session_state.available_x_ids)
            st.rerun()
# 6. Show the three audio files (player + mel spectrogram) for the selected ID
if st.session_state.selected_x_id:
    selected_x_id = st.session_state.selected_x_id
    selected_y = st.session_state.selected_y
    separator = st.session_state.separator
    st.header(f"5. 展示结果:ID(x) - {selected_x_id}, 版本(y) - {selected_y}")

    # Supported extensions (tried in this order) and the MIME type passed to
    # the audio player for each. MP3's MIME type is audio/mpeg.
    audio_mime_types = {'.wav': 'audio/wav', '.flac': 'audio/flac', '.mp3': 'audio/mpeg'}

    def find_full_path(folder, base_name):
        """Return the path of ``base_name`` plus a supported extension, or None.

        Lower-case extensions are probed directly first. If none exists, the
        folder is listed and an entry with the same base name and a supported
        extension in any letter case is accepted, matching the
        case-insensitive extension check used when IDs are collected.
        """
        for ext in audio_mime_types:
            full_path = os.path.join(folder, base_name + ext)
            if os.path.exists(full_path):
                return full_path
        try:
            entries = sorted(os.listdir(folder))
        except OSError:
            # Missing or unreadable folder: nothing to find.
            return None
        for entry in entries:
            stem, ext = os.path.splitext(entry)
            if stem == base_name and ext.lower() in audio_mime_types:
                return os.path.join(folder, entry)
        return None

    def resolve_audio_file(folder, prefixes_input, name_without_prefix):
        """Locate a file by trying every configured prefix, then no prefix.

        ID collection strips whichever configured prefix a file name starts
        with (or none), so the same candidates are tried here instead of only
        the first prefix. Returns ``(path_or_None, attempted_names)`` where
        ``attempted_names`` is a comma-separated string for the warning text.
        """
        candidate_prefixes = [*get_prefix_list(prefixes_input), ""]
        base_names = [f"{prefix}{name_without_prefix}" for prefix in candidate_prefixes]
        for base_name in base_names:
            full_path = find_full_path(folder, base_name)
            if full_path:
                return full_path, ", ".join(base_names)
        return None, ", ".join(base_names)

    # Output and Mixture files are named [prefix]x[separator]y.
    versioned_name = f"{selected_x_id}{separator}{selected_y}"
    output_file_path, output_base_name = resolve_audio_file(
        st.session_state.output_path, st.session_state.output_prefixes, versioned_name
    )
    mixture_file_path, mixture_base_name = resolve_audio_file(
        st.session_state.mixture_path, st.session_state.mix_prefixes, versioned_name
    )
    # Target files are named [prefix]x (no separator / version part).
    target_file_path, target_base_name = resolve_audio_file(
        st.session_state.target_path, st.session_state.tar_prefixes, f"{selected_x_id}"
    )

    # --- Three-column layout ---
    col_out, col_mix, col_tar = st.columns(3)

    def display_audio_col(col, title, file_path, base_name):
        """Render one column: file name, audio player and mel spectrogram.

        When ``file_path`` is falsy a warning naming ``base_name`` is shown
        instead.
        """
        with col:
            st.subheader(title)
            if file_path:
                st.markdown(f"**File:** `{os.path.basename(file_path)}`")
                try:
                    ext = os.path.splitext(file_path)[1].lower()
                    st.audio(file_path, format=audio_mime_types.get(ext, 'audio/wav'))
                except Exception as e:
                    # A playback failure must not prevent the spectrogram.
                    st.error(f"Playback failed for {title}: {str(e)}")
                fig = generate_spectrogram(file_path, title)
                if fig:
                    st.pyplot(fig)
                    # Release the figure once it has been rendered.
                    plt.close(fig)
            else:
                st.warning(f"File not found. Attempted base name: `{base_name}`")

    # 1. Output column
    display_audio_col(col_out, "Output (Result)", output_file_path, output_base_name)
    # 2. Mixture column
    display_audio_col(col_mix, "Mixture (Original)", mixture_file_path, mixture_base_name)
    # 3. Target column
    display_audio_col(col_tar, "Target (Ground Truth)", target_file_path, target_base_name)