"""
Authors: Zexu Pan, Shengkui Zhao
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import soundfile as sf
import librosa
import tempfile
import os
import subprocess
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download
import numpy as np
import ffmpeg
class SpeechModel:
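    """Base class handling device selection, checkpoint loading, and
    audio/video input for running the extraction model at inference time."""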
def __init__(self, args):
if torch.cuda.is_available():
print('GPU is found and used!')
self.device = torch.device('cuda')
else:
# If no GPU is detected, use the CPU
args.use_cuda = 0
self.device = torch.device('cpu')
self.args = args
self.model = None
self.name = None
self.data = {}
def get_free_gpu(self):
try:
# Run nvidia-smi to query GPU memory usage and free memory
result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.free', '--format=csv,nounits,noheader'], stdout=subprocess.PIPE)
gpu_info = result.stdout.decode('utf-8').strip().split('\n')
free_gpu = None
max_free_memory = 0
for i, info in enumerate(gpu_info):
used, free = map(int, info.split(','))
if free > max_free_memory:
max_free_memory = free
free_gpu = i
return free_gpu
except Exception as e:
print(f"Error finding free GPU: {e}")
return None
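    # Note (illustrative): the index returned by get_free_gpu() is not used
    # elsewhere in this file; it could be passed to torch.device, e.g.
    # torch.device(f'cuda:{free_gpu}'), to pin the model to the least-loaded GPU.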
def load_model(self):
checkpoint_path = hf_hub_download(repo_id=f"alibabasglab/{self.args.model_name}", filename="last_checkpoint.pt")
# Load the checkpoint file into memory (map_location ensures compatibility with different devices)
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
# Load the model's state dictionary (weights and biases) into the current model
if 'model' in checkpoint:
pretrained_model = checkpoint['model']
else:
pretrained_model = checkpoint
state = self.model.state_dict()
for key in state.keys():
if key in pretrained_model and state[key].shape == pretrained_model[key].shape:
state[key] = pretrained_model[key]
elif key.replace('module.', '') in pretrained_model and state[key].shape == pretrained_model[key.replace('module.', '')].shape:
state[key] = pretrained_model[key.replace('module.', '')]
elif 'module.'+key in pretrained_model and state[key].shape == pretrained_model['module.'+key].shape:
state[key] = pretrained_model['module.'+key]
            else:
                raise NameError(f'{key} not loaded')
        self.model.load_state_dict(state)
        print('Successfully loaded model weights for decoding')
def process(self, file_path, text):
orig_audio = self.load_data(file_path)
text = [text]
        with torch.no_grad():
            # Process long inputs in fixed-size chunks (160000 samples = 10 s at 16 kHz)
            # to keep memory usage bounded.
            chunk_size = 160000
            if orig_audio.shape[0] > chunk_size:
                output_audio = torch.zeros(1, orig_audio.shape[0])
                for itr in range(orig_audio.shape[0] // chunk_size):
                    output_audio[:, chunk_size*itr:chunk_size*(itr+1)] = self.model(orig_audio[chunk_size*itr:chunk_size*(itr+1)], text, self.device)
                # Re-run the last chunk_size samples so the trailing remainder is also covered
                output_audio[:, -chunk_size:] = self.model(orig_audio[-chunk_size:], text, self.device)
            else:
                output_audio = self.model(orig_audio, text, self.device)
            output_audio = output_audio.detach().squeeze().cpu().numpy()
        # The residual is whatever remains of the mixture after removing the extracted audio
        residual_audio = orig_audio - output_audio
        return orig_audio, output_audio, residual_audio
    def _audioread(self, path, sampling_rate):
        data, fs = sf.read(path, dtype='float32')
        # Keep only the first channel, handling both (samples, channels)
        # and (channels, samples) layouts.
        if len(data.shape) > 1:
            if data.shape[0] > data.shape[1]:
                data = data[:, 0]
            else:
                data = data[0, :]
        # Resample to the target sampling rate if necessary
        if fs != sampling_rate:
            data = librosa.resample(data, orig_sr=fs, target_sr=sampling_rate)
        # Normalize to [-1, 1] if the signal exceeds full scale
        max_val = np.max(np.abs(data))
        if max_val > 1:
            data /= max_val
        return data
    def _videoread(self, path, sampling_rate):
        try:
            # Use ffmpeg to extract the audio track as raw 16-bit PCM
            # (raw PCM avoids the WAV header ending up in the sample buffer)
            out, _ = (
                ffmpeg
                .input(path)
                .output('pipe:', format='s16le', acodec='pcm_s16le', ar=sampling_rate, ac=1)
                .run(capture_stdout=True, capture_stderr=True)
            )
            # Convert the raw bytes to float32 samples in [-1, 1]
            audio_data = np.frombuffer(out, dtype=np.int16).astype(np.float32) / 32768.0
            return audio_data
        except ffmpeg.Error as e:
            print(f"Error loading audio from video: {e.stderr.decode()}")
            return None
def load_data(self, file_path):
"""
Detect whether the file is audio or video, then process it.
- Audio: Load using `soundfile`.
- Video: Extract audio and resample to 16 kHz.
"""
# Check if the file exists
if not os.path.isfile(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
# Supported audio and video extensions
audio_extensions = ['.wav', '.flac', '.mp3', '.ogg', '.mat']
video_extensions = ['.mp4', '.mkv', '.avi', '.mov', '.webm']
_, ext = os.path.splitext(file_path)
ext = ext.lower()
if ext in audio_extensions:
# Handle audio files
print(f"Processing audio file: {file_path}")
data = self._audioread(file_path, self.args.sampling_rate)
return data
elif ext in video_extensions:
# Handle video files
print(f"Processing video file: {file_path}")
data = self._videoread(file_path, self.args.sampling_rate)
return data
else:
raise ValueError(f"Unsupported file type: {file_path}")
class select_network(nn.Module):
    def __init__(self, args):
        super(select_network, self).__init__()
        self.args = args
        # Separation backbone
        from models.tflocoformer.tflocoformer_separator import TFLocoformer
        self.sep_network = TFLocoformer(args)
        print(f'{args.model_name} running.')
        # Frozen T5 text encoder for the text reference
        from transformers import AutoTokenizer, T5EncoderModel
        model_path = snapshot_download(repo_id="alibabasglab/t5-base")
        model_path = os.path.join(model_path, "t5-base")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
        self.text_encoder = T5EncoderModel.from_pretrained(model_path)
        for param in self.text_encoder.parameters():
            param.requires_grad = False
        # Frozen BEATs encoder that provides audio-domain features from the mixture
        from models.beats.BEATs import BEATs, BEATsConfig
        model_path = snapshot_download(repo_id="alibabasglab/beats")
        model_path = os.path.join(model_path, "BEATs_iter3_plus_AS2M.pt")
        checkpoint = torch.load(model_path, map_location='cpu')
        cfg = BEATsConfig(checkpoint['cfg'])
        self.BEATs_model = BEATs(cfg)
        self.BEATs_model.load_state_dict(checkpoint['model'])
        self.BEATs_model.eval()
        for param in self.BEATs_model.parameters():
            param.requires_grad = False
    def forward(self, mixture, t_ref, device):
        mixture = torch.tensor(mixture).to(device)
        mixture = mixture.unsqueeze(0)
        # Encode the text reference with the frozen T5 encoder
        text_input = self.tokenizer(t_ref, return_tensors="pt", truncation=True, padding="longest")
        text_input_ids = text_input["input_ids"].to(device)
        text_attention_mask = text_input["attention_mask"].to(device)
        text_len = torch.sum(text_attention_mask, dim=1)
        text_embedding = self.text_encoder(input_ids=text_input_ids, attention_mask=text_attention_mask).last_hidden_state
        t_ref = (text_embedding.clone().detach(), text_attention_mask.clone().detach(), text_len.clone().detach())
        # Extract audio-domain features from the mixture with the frozen BEATs encoder
        with torch.no_grad():
            padding_mask = torch.zeros_like(mixture).bool()
            a_ref = self.BEATs_model.extract_features(mixture, padding_mask=padding_mask)[0]
            a_ref = a_ref.transpose(1, 2)
        return self.forward_step(mixture, t_ref, a_ref.clone().detach())
    def forward_step(self, mixture, t_ref, a_ref):
        return self.sep_network(mixture, t_ref, a_ref)
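
# Expected inputs to select_network.forward, as used by SpeechModel.process:
# `mixture` is a 1-D waveform (numpy array or tensor), `t_ref` is a list
# containing one text prompt, and `device` is the torch device to run on.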
class network_wrapper(SpeechModel):
def __init__(self, args):
# Initialize the parent SpeechModel class
super(network_wrapper, self).__init__(args)
        # Build the separation network (TFLocoformer queried by text and audio references)
self.model = select_network(args)
# Load pre-trained model checkpoint
self.load_model()
# Move model to the appropriate device (GPU/CPU)
self.model.to(self.device)
# Set the model to evaluation mode (no gradient calculation)
self.model.eval()
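
# A minimal usage sketch (illustrative only, not part of the original pipeline):
# `args` is assumed to carry `model_name` and `sampling_rate` plus whatever
# fields TFLocoformer expects; the file path and text prompt are placeholders.
#
# if __name__ == '__main__':
#     from argparse import Namespace
#     args = Namespace(model_name='...', sampling_rate=16000)  # plus model-specific fields
#     model = network_wrapper(args)
#     mixture, extracted, residual = model.process('example.wav', 'a person speaking')
#     sf.write('extracted.wav', extracted, args.sampling_rate)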