| """ | |
| Authors: Zexu Pan, Shengkui Zhao | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.cuda.amp import autocast, GradScaler | |
| import soundfile as sf | |
| import librosa | |
| import tempfile | |
| import os | |
| import subprocess | |
| from tqdm import tqdm | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub import snapshot_download | |
| import numpy as np | |
| import ffmpeg | |


class SpeechModel:
    def __init__(self, args):
        if torch.cuda.is_available():
            print('GPU is found and used!')
            self.device = torch.device('cuda')
        else:
            # If no GPU is detected, use the CPU
            args.use_cuda = 0
            self.device = torch.device('cpu')
        self.args = args
        self.model = None
        self.name = None
        self.data = {}

    def get_free_gpu(self):
        try:
            # Run nvidia-smi to query GPU memory usage and free memory
            result = subprocess.run(
                ['nvidia-smi', '--query-gpu=memory.used,memory.free', '--format=csv,nounits,noheader'],
                stdout=subprocess.PIPE)
            gpu_info = result.stdout.decode('utf-8').strip().split('\n')
            free_gpu = None
            max_free_memory = 0
            for i, info in enumerate(gpu_info):
                used, free = map(int, info.split(','))
                if free > max_free_memory:
                    max_free_memory = free
                    free_gpu = i
            return free_gpu
        except Exception as e:
            print(f"Error finding free GPU: {e}")
            return None

    def load_model(self):
        checkpoint_path = hf_hub_download(repo_id=f"alibabasglab/{self.args.model_name}", filename="last_checkpoint.pt")
        # Load the checkpoint file into memory (map_location ensures compatibility with different devices)
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        # Load the model's state dictionary (weights and biases) into the current model
        if 'model' in checkpoint:
            pretrained_model = checkpoint['model']
        else:
            pretrained_model = checkpoint
        state = self.model.state_dict()
        for key in state.keys():
            if key in pretrained_model and state[key].shape == pretrained_model[key].shape:
                state[key] = pretrained_model[key]
            elif key.replace('module.', '') in pretrained_model and state[key].shape == pretrained_model[key.replace('module.', '')].shape:
                state[key] = pretrained_model[key.replace('module.', '')]
            elif 'module.' + key in pretrained_model and state[key].shape == pretrained_model['module.' + key].shape:
                state[key] = pretrained_model['module.' + key]
            else:
                raise NameError(f'{key} not loaded')
        self.model.load_state_dict(state)
        print('Successfully loaded model weights for decoding')

    def process(self, file_path, text):
        orig_audio = self.load_data(file_path)
        text = [text]
        with torch.no_grad():
            chunk_size = 160000  # 240000
            print(orig_audio.shape)
            if orig_audio.shape[0] > chunk_size:
                # Long inputs are processed in fixed-size chunks to bound memory usage
                output_audio = torch.zeros(1, orig_audio.shape[0])
                for itr in range(0, orig_audio.shape[0] // chunk_size):
                    output_audio[:, chunk_size * itr:chunk_size * (itr + 1)] = self.model(
                        orig_audio[chunk_size * itr:chunk_size * (itr + 1)], text, self.device)
                # Cover the trailing remainder by re-running the model on the last full-size window
                output_audio[:, -chunk_size:] = self.model(orig_audio[-chunk_size:], text, self.device)
            else:
                output_audio = self.model(orig_audio, text, self.device)
        output_audio = output_audio.detach().squeeze().cpu().numpy()
        # residual_audio = residual_audio.detach().squeeze().cpu().numpy()
        residual_audio = orig_audio - output_audio
        return orig_audio, output_audio, residual_audio

    def _audioread(self, path, sampling_rate):
        data, fs = sf.read(path, dtype='float32')
        if len(data.shape) > 1:
            # Keep only the first channel of multi-channel recordings
            if data.shape[0] > data.shape[1]:
                data = data[:, 0]
            else:
                data = data[0, :]
        if fs != sampling_rate:
            data = librosa.resample(data, orig_sr=fs, target_sr=sampling_rate)
        max_val = np.max(np.abs(data))
        if max_val > 1:
            data /= max_val
        return data

    def _videoread(self, path, sampling_rate):
        try:
            # Use ffmpeg to extract the audio track as raw 16-bit PCM (mono, resampled)
            process = (
                ffmpeg
                .input(path)
                .output('pipe:', format='s16le', acodec='pcm_s16le', ar=sampling_rate, ac=1)
                .run(capture_stdout=True, capture_stderr=True)
            )
            # Read the audio samples from the raw output
            audio_data = np.frombuffer(process[0], dtype=np.int16)
            # Convert to float32 and scale to [-1, 1]
            audio_data = audio_data.astype(np.float32) / 32768.0
            max_val = np.max(np.abs(audio_data))
            if max_val > 1:
                audio_data /= max_val
            return audio_data
        except ffmpeg.Error as e:
            print(f"Error loading audio from video: {e.stderr.decode()}")
            return None

    def load_data(self, file_path):
        """
        Detect whether the file is audio or video, then process it.
        - Audio: Load using `soundfile`.
        - Video: Extract the audio track and resample to the target sampling rate.
        """
        # Check if the file exists
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        # Supported audio and video extensions
        audio_extensions = ['.wav', '.flac', '.mp3', '.ogg', '.mat']
        video_extensions = ['.mp4', '.mkv', '.avi', '.mov', '.webm']
        _, ext = os.path.splitext(file_path)
        ext = ext.lower()
        if ext in audio_extensions:
            # Handle audio files
            print(f"Processing audio file: {file_path}")
            data = self._audioread(file_path, self.args.sampling_rate)
            return data
        elif ext in video_extensions:
            # Handle video files
            print(f"Processing video file: {file_path}")
            data = self._videoread(file_path, self.args.sampling_rate)
            return data
        else:
            raise ValueError(f"Unsupported file type: {file_path}")


class select_network(nn.Module):
    def __init__(self, args):
        super(select_network, self).__init__()
        self.args = args

        # Separation backbone
        from models.tflocoformer.tflocoformer_separator import TFLocoformer
        self.sep_network = TFLocoformer(args)
        print(f'{args.model_name} running.')

        # Frozen T5 text encoder for embedding the text query
        from transformers import AutoTokenizer, T5EncoderModel
        model_path = snapshot_download(repo_id="alibabasglab/t5-base")
        model_path = os.path.join(model_path, "t5-base")
        # model_path = hf_hub_download(repo_id="alibabasglab/t5-base", filename="t5-base")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
        self.text_encoder = T5EncoderModel.from_pretrained(model_path)
        # os.environ["TOKENIZERS_PARALLELISM"] = "false"
        for param in self.text_encoder.parameters():
            param.requires_grad = False

        # Frozen BEATs audio encoder for extracting acoustic reference features
        from models.beats.BEATs import BEATs, BEATsConfig
        model_path = snapshot_download(repo_id="alibabasglab/beats")
        model_path = os.path.join(model_path, "BEATs_iter3_plus_AS2M.pt")
        checkpoint = torch.load(model_path)
        cfg = BEATsConfig(checkpoint['cfg'])
        self.BEATs_model = BEATs(cfg)
        self.BEATs_model.load_state_dict(checkpoint['model'])
        self.BEATs_model.eval()
        for param in self.BEATs_model.parameters():
            param.requires_grad = False

    def forward(self, mixture, t_ref, device):
        mixture = torch.tensor(mixture).to(device)
        mixture = mixture.unsqueeze(0)

        # Encode the text query with the frozen T5 encoder
        text_input = self.tokenizer(t_ref, return_tensors="pt", truncation=True, padding="longest")
        text_input_ids = text_input["input_ids"].to(device)
        text_attention_mask = text_input["attention_mask"].to(device)
        text_len = torch.sum(text_attention_mask, dim=1)
        text_embedding = self.text_encoder(input_ids=text_input_ids, attention_mask=text_attention_mask).last_hidden_state
        t_ref = (text_embedding.clone().detach(), text_attention_mask.clone().detach(), text_len.clone().detach())

        # Extract acoustic reference features from the mixture with the frozen BEATs encoder
        with torch.no_grad():
            padding_mask = torch.zeros_like(mixture).bool()
            a_ref = self.BEATs_model.extract_features(mixture, padding_mask=padding_mask)[0]
        a_ref = a_ref.transpose(1, 2)
        return self.forward_step(mixture, t_ref, a_ref.clone().detach())

    def forward_step(self, mixture, t_ref, a_ref):
        return self.sep_network(mixture, t_ref, a_ref)


class network_wrapper(SpeechModel):
    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(network_wrapper, self).__init__(args)
        # Build the separation network (TF-Locoformer backbone with T5 text and BEATs audio encoders)
        self.model = select_network(args)
        # Load pre-trained model checkpoint
        self.load_model()
        # Move model to the appropriate device (GPU/CPU)
        self.model.to(self.device)
        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()
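

# The block below is a minimal usage sketch added for illustration; it is not part of
# the original file. It assumes `args` only needs the fields referenced above
# (model_name, sampling_rate, use_cuda); the real TFLocoformer config very likely
# expects additional fields, and the model name and file path here are placeholders.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        model_name="ClearSep",   # hypothetical repo suffix under alibabasglab/
        sampling_rate=16000,     # target rate used by _audioread/_videoread
        use_cuda=1,
    )
    wrapper = network_wrapper(args)
    # Separate the sound described by the text query from the input recording
    orig, extracted, residual = wrapper.process("example.wav", "a dog barking")
    sf.write("extracted.wav", extracted, args.sampling_rate)
    sf.write("residual.wav", residual, args.sampling_rate)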