import numpy as np
import torch
import torchaudio
import librosa
import scipy.signal as signal
from scipy.signal import savgol_filter
from torchaudio import transforms as T


def compute_melody_v2(stereo_audio: str) -> np.ndarray:
    """
    Args:
        stereo_audio: path to a stereo audio file. After loading, channel 0 is
            the left channel and channel 1 is the right channel; the sample
            rate is read from the file itself.

    Returns:
        c: np.ndarray of shape (8, T_frames), where each column holds
            [L1, R1, L2, R2, L3, R3, L4, R4] (interleaved per frame), and
            every value is in {1, 2, ..., 128}, corresponding to a CQT
            frequency bin.
    """
    audio_np, sr = librosa.load(stereo_audio, sr=None, mono=False)
    if audio_np.ndim == 1:
        audio_np = np.expand_dims(audio_np, 0)
    audio = torch.as_tensor(audio_np, dtype=torch.float32)

    # 1. Compute a 128-bin CQT for the left and right channels separately;
    #    each cqt_db has shape (128, T_frames).
    cqt_left = compute_music_represent(audio[0], sr)   # shape: (128, T_frames)
    cqt_right = compute_music_represent(audio[1], sr)  # shape: (128, T_frames)

    # 2. Number of time frames.
    #    Note: for librosa.cqt output, the frame count is the second dimension.
    T_frames = cqt_left.shape[1]

    # 3. Pre-allocate the output matrix c, dtype int, shape (8, T_frames).
    c = np.zeros((8, T_frames), dtype=np.int32)

    # 4. Process frame by frame: take the top-4 of the 128 bins.
    for j in range(T_frames):
        # 4.1 CQT energies (in dB) of the current frame for both channels.
        col_L = cqt_left[:, j]   # shape: (128,)
        col_R = cqt_right[:, j]  # shape: (128,)

        # 4.2 np.argsort sorts ascending, so take the last 4 indices and
        #     reverse them to get the top-4 bins in descending order.
        idx4_L = np.argsort(col_L)[-4:][::-1]  # 0-based, length 4
        idx4_R = np.argsort(col_R)[-4:][::-1]  # 0-based, length 4

        # 4.3 Convert to 1-based indices (values are defined to lie in {1, ..., 128}).
        idx4_L = idx4_L + 1  # now in 1..128
        idx4_R = idx4_R + 1

        # 4.4 Interleave into column j of c:
        #     c[:, j] = [L1, R1, L2, R2, L3, R3, L4, R4]
        for k in range(4):
            c[2 * k, j] = idx4_L[k]
            c[2 * k + 1, j] = idx4_R[k]

    return c[:, :4097]


def compute_music_represent(audio, sr):
    # High-pass filter around middle C, then compute a 128-bin CQT in dB.
    filter_y = torchaudio.functional.highpass_biquad(audio, sr, 261.6)
    fmin = librosa.midi_to_hz(0)
    cqt_spec = librosa.cqt(y=filter_y.numpy(), fmin=fmin, sr=sr, n_bins=128,
                           bins_per_octave=12, hop_length=512)
    cqt_db = librosa.amplitude_to_db(np.abs(cqt_spec), ref=np.max)
    return cqt_db
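# ---------------------------------------------------------------------------
# Illustrative usage sketch for compute_melody_v2 (not part of the original
# pipeline).  "example_stereo.wav" is a hypothetical placeholder path; the
# frame count depends on the input length and the hop_length of 512.
# ---------------------------------------------------------------------------
def _example_compute_melody_v2(path="example_stereo.wav"):
    # `path` is assumed to point to an existing stereo file (hypothetical here).
    bins = compute_melody_v2(path)
    # Each column interleaves the 1-based top-4 CQT bins of both channels:
    # [L1, R1, L2, R2, L3, R3, L4, R4], values in 1..128.
    print("melody bins:", bins.shape, bins.dtype)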
""" # Parse shapes num_channels, num_bins, num_frames = cqt_db.shape # Initialize an output array that combines both channels # and has zeros everywhere initially combined = np.zeros((num_bins, num_frames), dtype=cqt_db.dtype) for ch in range(num_channels): for t in range(num_frames): # Find the top 4 pitch bins for this channel at frame t # argsort sorts ascending; we take the last 4 indices for top 4 top4_indices = np.argsort(cqt_db[ch, :, t])[-4:] # Copy their values into the combined array # We add to it in case there's overlap between channels combined[top4_indices, t] = 1 return combined def compute_melody(input_audio): # Initialize parameters sample_rate = 44100 # Load audio file wav_np, sr = librosa.load(input_audio, sr=None, mono=False) if wav_np.ndim == 1: wav_np = np.expand_dims(wav_np, 0) wav = torch.as_tensor(wav_np, dtype=torch.float32) if sr != sample_rate: resample = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate) wav = resample(wav) # Truncate or pad the audio to 2097152 samples target_length = 2097152 if wav.size(1) > target_length: # Truncate the audio if it is longer than the target length wav = wav[:, :target_length] elif wav.size(1) < target_length: # Pad the audio with zeros if it is shorter than the target length padding = target_length - wav.size(1) wav = torch.cat([wav, torch.zeros(wav.size(0), padding)], dim=1) melody = compute_music_represent(wav, 44100) melody = keep_top4_pitches_per_channel(melody) return melody def compute_dynamics(audio_file, hop_length=160, target_sample_rate=44100, cut=True): """ Compute the dynamics curve for a given audio file. Args: audio_file (str): Path to the audio file. window_length (int): Length of FFT window for computing the spectrogram. hop_length (int): Number of samples between successive frames. smoothing_window (int): Length of the Savitzky-Golay filter window. polyorder (int): Polynomial order of the Savitzky-Golay filter. Returns: dynamics_curve (numpy.ndarray): The computed dynamic values in dB. """ # Load audio file waveform_np, original_sample_rate = librosa.load(audio_file, sr=None, mono=False) if waveform_np.ndim == 1: waveform_np = np.expand_dims(waveform_np, 0) waveform = torch.as_tensor(waveform_np, dtype=torch.float32) if original_sample_rate != target_sample_rate: resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sample_rate) waveform = resampler(waveform) if cut: waveform = waveform[:, :2097152] # Ensure waveform has a single channel (e.g., select the first channel if multi-channel) waveform = waveform.mean(dim=0, keepdim=True) # Mix all channels into one waveform = waveform.clamp(-1, 1).numpy() S = np.abs(librosa.stft(waveform, n_fft=1024, hop_length=hop_length)) mel_filter_bank = librosa.filters.mel(sr=target_sample_rate, n_fft=1024, n_mels=64, fmin=0, fmax=8000) S = np.dot(mel_filter_bank, S) energy = np.sum(S**2, axis=0) dynamics_db = np.clip(energy, 1e-6, None) dynamics_db = librosa.amplitude_to_db(energy, ref=np.max).squeeze(0) smoothed_dynamics = savgol_filter(dynamics_db, window_length=279, polyorder=1) # print(smoothed_dynamics.shape) return smoothed_dynamics def extract_melody_one_hot(audio_path, sr=44100, cutoff=261.2, win_length=2048, hop_length=256): """ Extract a one-hot chromagram-based melody from an audio file (mono). Parameters: ----------- audio_path : str Path to the input audio file. sr : int Target sample rate to resample the audio (default: 44100). 
def extract_melody_one_hot(audio_path, sr=44100, cutoff=261.2, win_length=2048, hop_length=256):
    """
    Extract a one-hot chromagram-based melody from an audio file (mono).

    Parameters
    ----------
    audio_path : str
        Path to the input audio file.
    sr : int
        Target sample rate to resample the audio (default: 44100).
    cutoff : float
        High-pass filter cutoff frequency in Hz (default: 261.2, roughly middle C).
    win_length : int
        STFT window length for the chromagram (default: 2048).
    hop_length : int
        STFT hop length for the chromagram (default: 256).

    Returns
    -------
    one_hot_chroma : np.ndarray, shape=(12, n_frames)
        One-hot chromagram of the most prominent pitch class per frame.
    """
    # ---------------------------------------------------------
    # 1. Load audio (librosa => shape: (channels, samples))
    # ---------------------------------------------------------
    audio_np, in_sr = librosa.load(audio_path, sr=None, mono=False)
    if audio_np.ndim == 1:
        audio_np = np.expand_dims(audio_np, 0)
    audio = torch.as_tensor(audio_np, dtype=torch.float32)

    # Convert to mono by averaging channels: shape => (samples,)
    audio_mono = audio.mean(dim=0)

    # Resample if necessary
    if in_sr != sr:
        resample_tf = T.Resample(orig_freq=in_sr, new_freq=sr)
        audio_mono = resample_tf(audio_mono)

    # Convert torch.Tensor => NumPy array: shape (samples,)
    y = audio_mono.numpy()

    # ---------------------------------------------------------
    # 2. Design & apply a high-pass filter (Butterworth, order=2)
    # ---------------------------------------------------------
    nyquist = 0.5 * sr
    norm_cutoff = cutoff / nyquist
    b, a = signal.butter(N=2, Wn=norm_cutoff, btype='high', analog=False)
    # filtfilt expects shape (n_samples,) for 1D input
    y_hp = signal.filtfilt(b, a, y)

    # ---------------------------------------------------------
    # 3. Compute the chromagram (librosa => shape: (12, n_frames))
    # ---------------------------------------------------------
    chroma = librosa.feature.chroma_stft(
        y=y_hp,
        sr=sr,
        n_fft=win_length,  # n_fft must be >= win_length
        win_length=win_length,
        hop_length=hop_length
    )

    # ---------------------------------------------------------
    # 4. Convert chromagram to one-hot via argmax along pitch classes
    # ---------------------------------------------------------
    # pitch_class_idx => shape=(n_frames,)
    pitch_class_idx = np.argmax(chroma, axis=0)

    # Make a zero array of the same shape => (12, n_frames)
    one_hot_chroma = np.zeros_like(chroma)

    # For each frame (column in chroma), set the argmax row to 1
    one_hot_chroma[pitch_class_idx, np.arange(chroma.shape[1])] = 1.0

    return one_hot_chroma
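# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption): decoding the one-hot chromagram back
# into pitch-class names per frame.  "example_song.wav" is a hypothetical
# path; the row order follows librosa's chroma convention (C, C#, ..., B).
# ---------------------------------------------------------------------------
def _example_one_hot_to_pitch_names(path="example_song.wav"):
    pitch_names = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
    one_hot = extract_melody_one_hot(path)                # shape (12, n_frames)
    frame_classes = np.argmax(one_hot, axis=0)            # shape (n_frames,)
    print([pitch_names[i] for i in frame_classes[:10]])   # first 10 frames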
def evaluate_f1_rhythm(input_timestamps, generated_timestamps, tolerance=0.07):
    """
    Evaluates precision, recall, and F1-score for beat/downbeat timestamp alignment.

    Args:
        input_timestamps (ndarray): 2D array of shape [n, 2], where column 0
            contains the reference timestamps.
        generated_timestamps (ndarray): 2D array of shape [m, 2], where column 0
            contains the predicted timestamps.
        tolerance (float): Alignment tolerance in seconds (default: 70 ms).

    Returns:
        tuple: (precision, recall, f1)
    """
    # Extract and sort timestamps
    input_timestamps = np.asarray(input_timestamps)
    generated_timestamps = np.asarray(generated_timestamps)

    # Only the first column (the timestamps) is needed
    if input_timestamps.size > 0:
        input_timestamps = input_timestamps[:, 0]
        input_timestamps.sort()
    else:
        input_timestamps = np.array([])

    if generated_timestamps.size > 0:
        generated_timestamps = generated_timestamps[:, 0]
        generated_timestamps.sort()
    else:
        generated_timestamps = np.array([])

    # Handle empty cases

    # Case 1: Both are empty.  One could argue everything is correct since
    # there is nothing to detect, but returning all zeros is a common convention.
    if len(input_timestamps) == 0 and len(generated_timestamps) == 0:
        return 0.0, 0.0, 0.0

    # Case 2: No ground-truth timestamps, but predictions exist.
    # All predictions are false positives => tp=0, fp=len(generated_timestamps),
    # so precision=0; recall is undefined (tp+fn=0) and is treated as 0.
    if len(input_timestamps) == 0 and len(generated_timestamps) > 0:
        return 0.0, 0.0, 0.0

    # Case 3: Ground-truth timestamps exist, but no predictions.
    # Every reference timestamp is a false negative => tp=0, fn=len(input_timestamps),
    # so recall=0; precision is undefined (tp+fp=0) and is treated as 0.
    if len(input_timestamps) > 0 and len(generated_timestamps) == 0:
        return 0.0, 0.0, 0.0

    # If we get here, both arrays are non-empty
    tp = 0
    fp = 0

    # Track matched ground-truth timestamps
    matched_inputs = np.zeros(len(input_timestamps), dtype=bool)

    for gen_ts in generated_timestamps:
        # Absolute differences to each reference timestamp
        diffs = np.abs(input_timestamps - gen_ts)

        # Index of the closest reference timestamp
        min_diff_idx = np.argmin(diffs)

        # Count a true positive only if the closest reference is within
        # tolerance and has not been matched yet
        if diffs[min_diff_idx] < tolerance and not matched_inputs[min_diff_idx]:
            tp += 1
            matched_inputs[min_diff_idx] = True
        else:
            fp += 1  # no suitable match found, or the closest was already matched

    # Remaining unmatched reference timestamps are false negatives
    fn = np.sum(~matched_inputs)

    # Compute precision, recall, f1
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    return precision, recall, f1
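# ---------------------------------------------------------------------------
# Self-contained example for evaluate_f1_rhythm using synthetic timestamps
# (illustrative values only; column 1 is a dummy label column, since only
# column 0 is read by the function).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    reference = np.array([[0.50, 1], [1.00, 1], [1.50, 1], [2.00, 1]])
    predicted = np.array([[0.52, 1], [1.08, 1], [1.49, 1]])  # 1.08 is >70 ms off
    precision, recall, f1 = evaluate_f1_rhythm(reference, predicted, tolerance=0.07)
    # Two predictions fall within the 70 ms tolerance, one does not:
    # precision = 2/3, recall = 2/4, f1 = 2 * (2/3 * 0.5) / (2/3 + 0.5) ~= 0.571
    print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")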