# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn from diffusers.utils import deprecate, logging from safetensors.torch import load_file from diffusers.loaders import AttnProcsLayers from utils.extract_conditions import compute_melody, compute_melody_v2, compute_dynamics, extract_melody_one_hot, evaluate_f1_rhythm from madmom.features.downbeats import DBNDownBeatTrackingProcessor,RNNDownBeatProcessor import numpy as np import matplotlib.pyplot as plt import os from utils.stable_audio_dataset_utils import load_audio_file logger = logging.get_logger(__name__) # pylint: disable=invalid-name import soundfile as sf TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HF_HUB_TOKEN") # For zero initialized 1D CNN in the attention processor def zero_module(module): for p in module.parameters(): nn.init.zeros_(p) return module # Original attention processor for class StableAudioAttnProcessor2_0(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. """ def __init__(self): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) def apply_partial_rotary_emb( self, x: torch.Tensor, freqs_cis: Tuple[torch.Tensor], ) -> torch.Tensor: from diffusers.models.embeddings import apply_rotary_emb rot_dim = freqs_cis[0].shape[-1] x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:] x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2) out = torch.cat((x_rotated, x_unrotated), dim=-1) return out def __call__( self, attn, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: from diffusers.models.embeddings import apply_rotary_emb residual = hidden_states input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) head_dim = query.shape[-1] // attn.heads kv_heads = key.shape[-1] // head_dim query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. heads_per_kv_head = attn.heads // kv_heads key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Apply RoPE if needed if rotary_emb is not None: query_dtype = query.dtype key_dtype = key.dtype query = query.to(torch.float32) key = key.to(torch.float32) rot_dim = rotary_emb[0].shape[-1] query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) query = torch.cat((query_rotated, query_unrotated), dim=-1) if not attn.is_cross_attention: key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) key = torch.cat((key_rotated, key_unrotated), dim=-1) query = query.to(query_dtype) key = key.to(key_dtype) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) # print("hidden_states", hidden_states.shape) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states # The attention processor used in MuseControlLite, using 1 decoupled cross-attention layer class StableAudioAttnProcessor2_0_rotary(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. """ def __init__(self, layer_id, hidden_size, name, cross_attention_dim=None, num_tokens=4, scale=1.0): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) super().__init__() from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding self.layer_id = layer_id self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.num_tokens = num_tokens self.scale = scale self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.name = name self.conv_out = zero_module(nn.Conv1d(1536,1536,kernel_size=1, padding=0, bias=False)) self.rotary_emb = LlamaRotaryEmbedding(dim = 64) self.to_k_ip.weight.requires_grad = True self.to_v_ip.weight.requires_grad = True self.conv_out.weight.requires_grad = True def rotate_half(self, x): x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2) x1, x2 = x.unbind(-1) return torch.cat((-x2, x1), dim=-1) def __call__( self, attn, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states_con: Optional[torch.Tensor] = None, encoder_hidden_states_audio: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: from diffusers.models.embeddings import apply_rotary_emb residual = hidden_states input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) # The original cross attention in Stable-audio ############################################################### query = attn.to_q(hidden_states) ip_hidden_states = encoder_hidden_states_con key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) head_dim = query.shape[-1] // attn.heads kv_heads = key.shape[-1] // head_dim query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. heads_per_kv_head = attn.heads // kv_heads key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) ############################################################### # The decupled cross attention in used in MuseControlLite, to deal with additional conditions ############################################################### ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) ip_key = ip_key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) ip_key_length = ip_key.shape[2] ip_value = ip_value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. heads_per_kv_head = attn.heads // kv_heads ip_key = torch.repeat_interleave(ip_key, heads_per_kv_head, dim=1) ip_value = torch.repeat_interleave(ip_value, heads_per_kv_head, dim=1) ip_value_length = ip_value.shape[2] seq_len_query = query.shape[2] # Generate position_ids for query, keys, values position_ids_query = torch.arange(seq_len_query, dtype=torch.long, device=query.device) * (ip_key_length / seq_len_query) position_ids_query = position_ids_query.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_query] position_ids_key = torch.arange(ip_key_length, dtype=torch.long, device=key.device) position_ids_key = position_ids_key.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key] position_ids_value = torch.arange(ip_value_length, dtype=torch.long, device=value.device) position_ids_value = position_ids_value.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key] # Rotate query, keys, values cos, sin = self.rotary_emb(query, position_ids_query) query_pos = (query * cos.unsqueeze(1)) + (self.rotate_half(query) * sin.unsqueeze(1)) cos, sin = self.rotary_emb(ip_key, position_ids_key) ip_key = (ip_key * cos.unsqueeze(1)) + (self.rotate_half(ip_key) * sin.unsqueeze(1)) cos, sin = self.rotary_emb(ip_value, position_ids_value) ip_value = (ip_value * cos.unsqueeze(1)) + (self.rotate_half(ip_value) * sin.unsqueeze(1)) ip_hidden_states = F.scaled_dot_product_attention( query_pos, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.to(query.dtype) ip_hidden_states = ip_hidden_states.transpose(1, 2) ip_hidden_states = self.conv_out(ip_hidden_states) ip_hidden_states = ip_hidden_states.transpose(1, 2) ############################################################### # Combine the output of the two cross-attention layers hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states # The attention processor used in MuseControlLite, using 2 decoupled cross-attention layer. It needs further examination, don't use it now. class StableAudioAttnProcessor2_0_rotary_double(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. """ def __init__(self, layer_id, hidden_size, name, cross_attention_dim=None, num_tokens=4, scale=1.0): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) super().__init__() from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.num_tokens = num_tokens self.layer_id = layer_id self.scale = scale self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_k_ip_audio = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip_audio = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.name = name self.conv_out = zero_module(nn.Conv1d(1536,1536,kernel_size=1, padding=0, bias=False)) self.conv_out_audio = zero_module(nn.Conv1d(1536,1536,kernel_size=1, padding=0, bias=False)) self.rotary_emb = LlamaRotaryEmbedding(64) self.to_k_ip.weight.requires_grad = True self.to_v_ip.weight.requires_grad = True self.conv_out.weight.requires_grad = True # Below is for copying the weight of the original weight to the decoupled cross-attention def rotate_half(self, x): x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2) x1, x2 = x.unbind(-1) return torch.cat((-x2, x1), dim=-1) def __call__( self, attn, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states_con: Optional[torch.Tensor] = None, encoder_hidden_states_audio: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: from diffusers.models.embeddings import apply_rotary_emb residual = hidden_states input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) # The original cross attention in Stable-audio ############################################################### query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) head_dim = query.shape[-1] // attn.heads kv_heads = key.shape[-1] // head_dim query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. heads_per_kv_head = attn.heads // kv_heads key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # if self.layer_id == "0": # hidden_states_sliced = hidden_states[:,1:,:] # # Create a tensor of zeros with shape (bs, 1, 768) # bs, _, dim2 = hidden_states_sliced.shape # zeros = torch.zeros(bs, 1, dim2).cuda() # # Concatenate the zero tensor along the second dimension (dim=1) # hidden_states_sliced = torch.cat((hidden_states_sliced, zeros), dim=1) # query_sliced = attn.to_q(hidden_states_sliced) # query_sliced = query_sliced.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # query = query_sliced ip_hidden_states = encoder_hidden_states_con ip_hidden_states_audio = encoder_hidden_states_audio ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) ip_key = ip_key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) ip_key_length = ip_key.shape[2] ip_value = ip_value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) ip_key_audio = self.to_k_ip_audio(ip_hidden_states_audio) ip_value_audio = self.to_v_ip_audio(ip_hidden_states_audio) ip_key_audio = ip_key_audio.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) ip_key_audio_length = ip_key_audio.shape[2] ip_value_audio = ip_value_audio.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. heads_per_kv_head = attn.heads // kv_heads ip_key = torch.repeat_interleave(ip_key, heads_per_kv_head, dim=1) ip_value = torch.repeat_interleave(ip_value, heads_per_kv_head, dim=1) ip_key_audio = torch.repeat_interleave(ip_key_audio, heads_per_kv_head, dim=1) ip_value_audio = torch.repeat_interleave(ip_value_audio, heads_per_kv_head, dim=1) ip_value_length = ip_value.shape[2] seq_len_query = query.shape[2] ip_value_audio_length = ip_value_audio.shape[2] position_ids_query = torch.arange(seq_len_query, dtype=torch.long, device=query.device) * (ip_key_length / seq_len_query) position_ids_query = position_ids_query.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_query] # Generate position_ids for keys position_ids_key = torch.arange(ip_key_length, dtype=torch.long, device=key.device) position_ids_key = position_ids_key.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key] position_ids_value = torch.arange(ip_value_length, dtype=torch.long, device=value.device) position_ids_value = position_ids_value.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key] # Generate position_ids for keys position_ids_query_audio = torch.arange(seq_len_query, dtype=torch.long, device=query.device) * (ip_key_audio_length / seq_len_query) position_ids_query_audio = position_ids_query_audio.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_query] position_ids_key_audio = torch.arange(ip_key_audio_length, dtype=torch.long, device=key.device) position_ids_key_audio = position_ids_key_audio.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key] position_ids_value_audio = torch.arange(ip_value_audio_length, dtype=torch.long, device=value.device) position_ids_value_audio = position_ids_value_audio.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key] cos, sin = self.rotary_emb(query, position_ids_query) cos_audio, sin_audio = self.rotary_emb(query, position_ids_query_audio) query_pos = (query * cos.unsqueeze(1)) + (self.rotate_half(query) * sin.unsqueeze(1)) query_pos_audio = (query * cos_audio.unsqueeze(1)) + (self.rotate_half(query) * sin_audio.unsqueeze(1)) cos, sin = self.rotary_emb(ip_key, position_ids_key) cos_audio, sin_audio = self.rotary_emb(ip_key_audio, position_ids_key_audio) ip_key = (ip_key * cos.unsqueeze(1)) + (self.rotate_half(ip_key) * sin.unsqueeze(1)) ip_key_audio = (ip_key_audio * cos_audio.unsqueeze(1)) + (self.rotate_half(ip_key_audio) * sin_audio.unsqueeze(1)) cos, sin = self.rotary_emb(ip_value, position_ids_value) cos_audio, sin_audio = self.rotary_emb(ip_value_audio, position_ids_value_audio) ip_value = (ip_value * cos.unsqueeze(1)) + (self.rotate_half(ip_value) * sin.unsqueeze(1)) ip_value_audio = (ip_value_audio * cos_audio.unsqueeze(1)) + (self.rotate_half(ip_value_audio) * sin_audio.unsqueeze(1)) with torch.amp.autocast(device_type='cuda'): ip_hidden_states = F.scaled_dot_product_attention( query_pos, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) with torch.amp.autocast(device_type='cuda'): ip_hidden_states_audio = F.scaled_dot_product_attention( query_pos_audio, ip_key_audio, ip_value_audio, attn_mask=None, dropout_p=0.0, is_causal=False ) ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.to(query.dtype) ip_hidden_states = ip_hidden_states.transpose(1, 2) ip_hidden_states_audio = ip_hidden_states_audio.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) ip_hidden_states_audio = ip_hidden_states_audio.to(query.dtype) ip_hidden_states_audio = ip_hidden_states_audio.transpose(1, 2) with torch.amp.autocast(device_type='cuda'): ip_hidden_states = self.conv_out(ip_hidden_states) ip_hidden_states = ip_hidden_states.transpose(1, 2) with torch.amp.autocast(device_type='cuda'): ip_hidden_states_audio = self.conv_out_audio(ip_hidden_states_audio) ip_hidden_states_audio = ip_hidden_states_audio.transpose(1, 2) # Combine the tensors hidden_states = hidden_states + self.scale * ip_hidden_states + ip_hidden_states_audio # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states def setup_MuseControlLite(config, weight_dtype, transformer_ckpt): """ Setup AP-adapter pipeline with attention processors and load checkpoints. Args: config: Configuration dictionary weight_dtype: Weight data type for the pipeline transformer_ckpt: Path to transformer checkpoint Returns: tuple: (pipe, transformer) - Configured pipeline and transformer """ if 'audio' in config['condition_type'] and len(config['condition_type'])!=1: from pipeline.stable_audio_multi_cfg_pipe_audio import StableAudioPipeline attn_processor = StableAudioAttnProcessor2_0_rotary_double audio_state_dict = load_file(config["audio_transformer_ckpt"], device="cpu") else: from pipeline.stable_audio_multi_cfg_pipe import StableAudioPipeline attn_processor = StableAudioAttnProcessor2_0_rotary pipe = StableAudioPipeline.from_pretrained( "stabilityai/stable-audio-open-1.0", torch_dtype=weight_dtype, token=TOKEN ) pipe.scheduler.config.sigma_max = config["sigma_max"] pipe.scheduler.config.sigma_min = config["sigma_min"] transformer = pipe.transformer attn_procs = {} for name in transformer.attn_processors.keys(): if name.endswith("attn1.processor"): attn_procs[name] = StableAudioAttnProcessor2_0() else: attn_procs[name] = attn_processor( layer_id = name.split(".")[1], hidden_size=768, name=name, cross_attention_dim=768, scale=config['ap_scale'], ).to("cuda", dtype=torch.float) if transformer_ckpt is not None: state_dict = load_file(transformer_ckpt, device="cuda") for name, processor in attn_procs.items(): if isinstance(processor, attn_processor): weight_name_v = name + ".to_v_ip.weight" weight_name_k = name + ".to_k_ip.weight" conv_out_weight = name + ".conv_out.weight" processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].to(torch.float32)) processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].to(torch.float32)) processor.conv_out.weight = torch.nn.Parameter(state_dict[conv_out_weight].to(torch.float32)) if attn_processor == StableAudioAttnProcessor2_0_rotary_double: audio_weight_name_v = name + ".to_v_ip.weight" audio_weight_name_k = name + ".to_k_ip.weight" audio_conv_out_weight = name + ".conv_out.weight" processor.to_v_ip_audio.weight = torch.nn.Parameter(audio_state_dict[audio_weight_name_v].to(torch.float32)) processor.to_k_ip_audio.weight = torch.nn.Parameter(audio_state_dict[audio_weight_name_k].to(torch.float32)) processor.conv_out_audio.weight = torch.nn.Parameter(audio_state_dict[audio_conv_out_weight].to(torch.float32)) transformer.set_attn_processor(attn_procs) class _Wrapper(AttnProcsLayers): def forward(self, *args, **kwargs): return pipe.transformer(*args, **kwargs) transformer = _Wrapper(pipe.transformer.attn_processors) return pipe def initialize_condition_extractors(config): """ Initialize condition extractors based on configuration. Args: config: Configuration dictionary containing condition types and checkpoint paths Returns: tuple: (condition_extractors, transformer_ckpt, extractor_ckpt) """ condition_extractors = {} extractor_ckpt = {} from utils.feature_extractor import dynamics_extractor, rhythm_extractor, melody_extractor_mono, melody_extractor_stereo, melody_extractor_full_mono, melody_extractor_full_stereo, dynamics_extractor_full_stereo if not ("rhythm" in config['condition_type'] or "dynamics" in config['condition_type']): if "melody_stereo" in config['condition_type']: transformer_ckpt = config['transformer_ckpt_melody_stero'] extractor_ckpt = config['extractor_ckpt_melody_stero'] print(f"using model: {transformer_ckpt}, {extractor_ckpt}") melody_conditoner = melody_extractor_full_stereo().cuda().float() condition_extractors["melody"] = melody_conditoner elif "melody_mono" in config['condition_type']: transformer_ckpt = config['transformer_ckpt_melody_mono'] extractor_ckpt = config['extractor_ckpt_melody_mono'] print(f"using model: {transformer_ckpt}, {extractor_ckpt}") melody_conditoner = melody_extractor_full_mono().cuda().float() condition_extractors["melody"] = melody_conditoner elif "audio" in config['condition_type']: transformer_ckpt = config['audio_transformer_ckpt'] print(f"using model: {transformer_ckpt}") else: dynamics_conditoner = dynamics_extractor().cuda().float() condition_extractors["dynamics"] = dynamics_conditoner rhythm_conditoner = rhythm_extractor().cuda().float() condition_extractors["rhythm"] = rhythm_conditoner melody_conditoner = melody_extractor_mono().cuda().float() condition_extractors["melody"] = melody_conditoner transformer_ckpt = config['transformer_ckpt_musical'] extractor_ckpt = config['extractor_ckpt_musical'] print(f"using model: {transformer_ckpt}, {extractor_ckpt}") for conditioner_type, ckpt_path in extractor_ckpt.items(): state_dict = load_file(ckpt_path, device="cpu") condition_extractors[conditioner_type].load_state_dict(state_dict) condition_extractors[conditioner_type].eval() return condition_extractors, transformer_ckpt def evaluate_and_plot_results(audio_file, gen_file_path, output_dir, i): """ Evaluate and plot results comparing original and generated audio. Args: audio_file (str): Path to the original audio file gen_file_path (str): Path to the generated audio file output_dir (str): Directory to save the plot i (int): Index for naming the output file Returns: tuple: (dynamics_score, rhythm_score, melody_score) """ dynamics_condition = compute_dynamics(audio_file) gen_dynamics = compute_dynamics(gen_file_path) min_len_dynamics = min(gen_dynamics.shape[0], dynamics_condition.shape[0]) pearson_corr = np.corrcoef(gen_dynamics[:min_len_dynamics], dynamics_condition[:min_len_dynamics])[0, 1] print("pearson_corr", pearson_corr) melody_condition = extract_melody_one_hot(audio_file) gen_melody = extract_melody_one_hot(gen_file_path) min_len_melody = min(gen_melody.shape[1], melody_condition.shape[1]) matches = ((gen_melody[:, :min_len_melody] == melody_condition[:, :min_len_melody]) & (gen_melody[:, :min_len_melody] == 1)).sum() accuracy = matches / min_len_melody print("melody accuracy", accuracy) # Adjust layout to avoid overlap processor = RNNDownBeatProcessor() original_path = os.path.join(output_dir, f"original_{i}.wav") input_probabilities = processor(original_path) generated_probabilities = processor(gen_file_path) hmm_processor = DBNDownBeatTrackingProcessor(beats_per_bar=[3,4], fps=100) input_timestamps = hmm_processor(input_probabilities) generated_timestamps = hmm_processor(generated_probabilities) precision, recall, f1 = evaluate_f1_rhythm(input_timestamps, generated_timestamps) # Output results print(f"F1 Score: {f1:.2f}") # Plotting frame_rate = 100 # Frames per second input_time_axis = np.linspace(0, len(input_probabilities) / frame_rate, len(input_probabilities)) generate_time_axis = np.linspace(0, len(generated_probabilities) / frame_rate, len(generated_probabilities)) fig, axes = plt.subplots(2, 3, figsize=(18, 10)) # Adjust figsize as needed # ---------------------------- # Subplot (0,0): Dynamics Plot ax = axes[0, 0] ax.plot(dynamics_condition[:min_len_dynamics].squeeze(), linewidth=1, label='Dynamics condition') ax.set_title('Dynamics') ax.set_xlabel('Time Frame') ax.set_ylabel('Dynamics (dB)') ax.legend(fontsize=8) ax.grid(True) # ---------------------------- # Subplot (0,0): Dynamics Plot ax = axes[1, 0] ax.plot(gen_dynamics[:min_len_dynamics].squeeze(), linewidth=1, label='Generated Dynamics') ax.set_title('Dynamics') ax.set_xlabel('Time Frame') ax.set_ylabel('Dynamics (dB)') ax.legend(fontsize=8) ax.grid(True) # ---------------------------- # Subplot (0,2): Melody Condition (Chromagram) ax = axes[0, 1] im2 = ax.imshow(melody_condition[:, :min_len_melody], aspect='auto', origin='lower', interpolation='nearest', cmap='plasma') ax.set_title('Melody Condition') ax.set_xlabel('Time') ax.set_ylabel('Chroma Features') # ---------------------------- # Subplot (0,1): Generated Melody (Chromagram) ax = axes[1, 1] im1 = ax.imshow(gen_melody[:, :min_len_melody], aspect='auto', origin='lower', interpolation='nearest', cmap='viridis') ax.set_title('Generated Melody') ax.set_xlabel('Time') ax.set_ylabel('Chroma Features') # ---------------------------- # Subplot (1,0): Rhythm Input Probabilities ax = axes[0, 2] ax.plot(input_time_axis, input_probabilities, label="Input Beat Probability") ax.plot(input_time_axis, input_probabilities, label="Input Downbeat Probability", alpha=0.8) ax.set_title('Rhythm: Input') ax.set_xlabel('Time (s)') ax.set_ylabel('Probability') ax.legend() ax.grid(True) # ---------------------------- # Subplot (1,1): Rhythm Generated Probabilities ax = axes[1, 2] ax.plot(generate_time_axis, generated_probabilities, color='orange', label="Generated Beat Probability") ax.plot(generate_time_axis, generated_probabilities, alpha=0.8, color='red', label="Generated Downbeat Probability") ax.set_title('Rhythm: Generated') ax.set_xlabel('Time (s)') ax.set_ylabel('Probability') ax.legend() ax.grid(True) # Adjust layout and save the combined image plt.tight_layout() combined_path = os.path.join(output_dir, f"combined_{i}.png") plt.savefig(combined_path) plt.close() print(f"Combined plot saved to {combined_path}") return pearson_corr, f1, accuracy def process_musical_conditions(config, audio_file, condition_extractors, output_dir, i, weight_dtype, MuseControlLite): """ Process and extract musical conditions (dynamics, rhythm, melody) from audio file. Args: config: Configuration dictionary audio_file: Path to the audio file condition_extractors: Dictionary of condition extractors output_dir: Output directory path i: Index for file naming weight_dtype: Weight data type for torch tensors MuseControlLite: The MuseControlLite model instance audio_mask_start: Start index for audio mask audio_mask_end: End index for audio mask musical_attribute_mask_start: Start index for musical attribute mask musical_attribute_mask_end: End index for musical attribute mask Returns: tuple: (final_condition, extracted_condition, final_condition_audio) """ total_seconds = 2097152/44100 use_audio_mask = False use_musical_attribute_mask = False if (config["audio_mask_start_seconds"] and config["audio_mask_end_seconds"]) != 0 and "audio" in config["condition_type"]: use_audio_mask = True audio_mask_start = int(config["audio_mask_start_seconds"] / total_seconds * 1024) # 1024 is the latent length for 2097152/44100 seconds audio_mask_end = int(config["audio_mask_end_seconds"] / total_seconds * 1024) print( f"using mask for 'audio' from " f"{config['audio_mask_start_seconds']}~{config['audio_mask_end_seconds']}" ) if (config["musical_attribute_mask_start_seconds"] and config["musical_attribute_mask_end_seconds"]) != 0: use_musical_attribute_mask = True musical_attribute_mask_start = int(config["musical_attribute_mask_start_seconds"] / total_seconds * 1024) musical_attribute_mask_end = int(config["musical_attribute_mask_end_seconds"] / total_seconds * 1024) masked_types = [t for t in config['condition_type'] if t != 'audio'] print( f"using mask for {', '.join(masked_types)} " f"from {config['musical_attribute_mask_start_seconds']}~" f"{config['musical_attribute_mask_end_seconds']}" ) if "dynamics" in config["condition_type"]: dynamics_condition = compute_dynamics(audio_file) dynamics_condition = torch.from_numpy(dynamics_condition).cuda() dynamics_condition = dynamics_condition.unsqueeze(0).unsqueeze(0) print("dynamics_condition", dynamics_condition.shape) extracted_dynamics_condition = condition_extractors["dynamics"](dynamics_condition.to(torch.float32)) masked_extracted_dynamics_condition = torch.zeros_like(extracted_dynamics_condition) extracted_dynamics_condition = F.interpolate(extracted_dynamics_condition, size=1024, mode='linear', align_corners=False) masked_extracted_dynamics_condition = F.interpolate(masked_extracted_dynamics_condition, size=1024, mode='linear', align_corners=False) else: extracted_dynamics_condition = torch.zeros((1, 192, 1024), device="cuda") masked_extracted_dynamics_condition = extracted_dynamics_condition if "rhythm" in config["condition_type"]: rnn_processor = RNNDownBeatProcessor() wave = load_audio_file(audio_file) if wave is not None: original_path = os.path.join(output_dir, f"original_{i}.wav") sf.write(original_path, wave.T.float().cpu().numpy(), 44100) rhythm_curve = rnn_processor(original_path) rhythm_condition = torch.from_numpy(rhythm_curve).cuda() rhythm_condition = rhythm_condition.transpose(0,1).unsqueeze(0) print("rhythm_condition", rhythm_condition.shape) extracted_rhythm_condition = condition_extractors["rhythm"](rhythm_condition.to(torch.float32)) masked_extracted_rhythm_condition = torch.zeros_like(extracted_rhythm_condition) extracted_rhythm_condition = F.interpolate(extracted_rhythm_condition, size=1024, mode='linear', align_corners=False) masked_extracted_rhythm_condition = F.interpolate(masked_extracted_rhythm_condition, size=1024, mode='linear', align_corners=False) else: extracted_rhythm_condition = torch.zeros((1, 192, 1024), device="cuda") masked_extracted_rhythm_condition = extracted_rhythm_condition else: extracted_rhythm_condition = torch.zeros((1, 192, 1024), device="cuda") masked_extracted_rhythm_condition = extracted_rhythm_condition if "melody_mono" in config["condition_type"]: melody_condition = compute_melody(audio_file) melody_condition = torch.from_numpy(melody_condition).cuda().unsqueeze(0) print("melody_condition", melody_condition.shape) extracted_melody_condition = condition_extractors["melody"](melody_condition.to(torch.float32)) masked_extracted_melody_condition = torch.zeros_like(extracted_melody_condition) extracted_melody_condition = F.interpolate(extracted_melody_condition, size=1024, mode='linear', align_corners=False) masked_extracted_melody_condition = F.interpolate(masked_extracted_melody_condition, size=1024, mode='linear', align_corners=False) elif "melody_stereo" in config["condition_type"]: melody_condition = compute_melody_v2(audio_file) melody_condition = torch.from_numpy(melody_condition).cuda().unsqueeze(0) print("melody_condition", melody_condition.shape) extracted_melody_condition = condition_extractors["melody"](melody_condition) masked_extracted_melody_condition = torch.zeros_like(extracted_melody_condition) extracted_melody_condition = F.interpolate(extracted_melody_condition, size=1024, mode='linear', align_corners=False) masked_extracted_melody_condition = F.interpolate(masked_extracted_melody_condition, size=1024, mode='linear', align_corners=False) else: if not ("rhythm" in config['condition_type'] or "dynamics" in config['condition_type']): extracted_melody_condition = torch.zeros((1, 768, 1024), device="cuda") else: extracted_melody_condition = torch.zeros((1, 192, 1024), device="cuda") masked_extracted_melody_condition = extracted_melody_condition # Use multiple cfg if not ("rhythm" in config['condition_type'] or "dynamics" in config['condition_type']): extracted_condition = extracted_melody_condition final_condition = torch.concat((masked_extracted_melody_condition, masked_extracted_melody_condition, extracted_melody_condition), dim=0) else: extracted_blank_condition = torch.zeros((1, 192, 1024), device="cuda") extracted_condition = torch.concat((extracted_rhythm_condition, extracted_dynamics_condition, extracted_melody_condition, extracted_blank_condition), dim=1) masked_extracted_condition = torch.concat((masked_extracted_rhythm_condition, masked_extracted_dynamics_condition, masked_extracted_melody_condition, extracted_blank_condition), dim=1) final_condition = torch.concat((masked_extracted_condition, masked_extracted_condition, extracted_condition), dim=0) if "audio" in config["condition_type"]: desired_repeats = 768 // 64 # Number of repeats needed audio = load_audio_file(audio_file) if audio is not None: audio_condition = MuseControlLite.vae.encode(audio.unsqueeze(0).to(weight_dtype).cuda()).latent_dist.sample() extracted_audio_condition = audio_condition.repeat_interleave(desired_repeats, dim=1).float() pad_len = 1024 - extracted_audio_condition.shape[-1] if pad_len > 0: # Pad on the right side (last dimension) extracted_audio_condition = F.pad(extracted_audio_condition, (0, pad_len)) masked_extracted_audio_condition = torch.zeros_like(extracted_audio_condition) if len(config["condition_type"]) == 1: final_condition = torch.concat((masked_extracted_audio_condition, masked_extracted_audio_condition, extracted_audio_condition), dim=0) else: final_condition_audio = torch.concat((masked_extracted_audio_condition, masked_extracted_audio_condition, masked_extracted_audio_condition, extracted_audio_condition), dim=0) final_condition = torch.concat((final_condition, extracted_condition), dim=0) final_condition_audio = final_condition_audio.transpose(1, 2) else: final_condition_audio = None final_condition = final_condition.transpose(1, 2) if "audio" in config["condition_type"] and len(config["condition_type"])==1: final_condition[:,audio_mask_start:audio_mask_end,:] = 0 if use_audio_mask: config["guidance_scale_con"] = config["guidance_scale_audio"] elif "audio" in config["condition_type"] and len(config["condition_type"])!=1 and use_audio_mask: final_condition[:,:audio_mask_start,:] = 0 final_condition[:,audio_mask_end:,:] = 0 if 'final_condition_audio' in locals() and final_condition_audio is not None: final_condition_audio[:,audio_mask_start:audio_mask_end,:] = 0 elif use_musical_attribute_mask: final_condition[:,musical_attribute_mask_start:musical_attribute_mask_end,:] = 0 if 'final_condition_audio' in locals() and final_condition_audio is not None: final_condition_audio[:,:musical_attribute_mask_start,:] = 0 final_condition_audio[:,musical_attribute_mask_end:,:] = 0 return final_condition, final_condition_audio if 'final_condition_audio' in locals() else None