Spaces:
Running
on
Zero
Running
on
Zero
Add audio utilities and track sample audio with LFS
Browse files- .gitattributes +2 -0
- MuseControlLite_setup.py +874 -0
- README.md +35 -1
- app.py +424 -0
- config_inference.py +134 -0
- melody_condition_audio/322_piano.mp3 +3 -0
- melody_condition_audio/49_piano.mp3 +3 -0
- melody_condition_audio/57_jazz.mp3 +3 -0
- melody_condition_audio/610_bass.mp3 +3 -0
- melody_condition_audio/703_mideast.mp3 +3 -0
- melody_condition_audio/785_piano.mp3 +3 -0
- melody_condition_audio/933_string.mp3 +3 -0
- pipeline/stable_audio_multi_cfg_pipe.py +772 -0
- pipeline/stable_audio_multi_cfg_pipe_audio.py +783 -0
- requirements.txt +13 -0
- utils/extract_conditions.py +301 -0
- utils/feature_extractor.py +173 -0
- utils/stable_audio_dataset_utils.py +129 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
MuseControlLite_setup.py
ADDED
|
@@ -0,0 +1,874 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Callable, List, Optional, Tuple, Union
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from torch import nn
|
| 18 |
+
from diffusers.utils import deprecate, logging
|
| 19 |
+
from safetensors.torch import load_file
|
| 20 |
+
from diffusers.loaders import AttnProcsLayers
|
| 21 |
+
from utils.extract_conditions import compute_melody, compute_melody_v2, compute_dynamics, extract_melody_one_hot, evaluate_f1_rhythm
|
| 22 |
+
from madmom.features.downbeats import DBNDownBeatTrackingProcessor,RNNDownBeatProcessor
|
| 23 |
+
import numpy as np
|
| 24 |
+
import matplotlib.pyplot as plt
|
| 25 |
+
import os
|
| 26 |
+
from utils.stable_audio_dataset_utils import load_audio_file
|
| 27 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 28 |
+
import soundfile as sf
|
| 29 |
+
|
| 30 |
+
# For zero initialized 1D CNN in the attention processor
|
| 31 |
+
def zero_module(module):
|
| 32 |
+
for p in module.parameters():
|
| 33 |
+
nn.init.zeros_(p)
|
| 34 |
+
return module
|
| 35 |
+
|
| 36 |
+
# Original attention processor for
|
| 37 |
+
class StableAudioAttnProcessor2_0(torch.nn.Module):
|
| 38 |
+
r"""
|
| 39 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
| 40 |
+
used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self):
|
| 44 |
+
super().__init__()
|
| 45 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 46 |
+
raise ImportError(
|
| 47 |
+
"StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
| 48 |
+
)
|
| 49 |
+
def apply_partial_rotary_emb(
|
| 50 |
+
self,
|
| 51 |
+
x: torch.Tensor,
|
| 52 |
+
freqs_cis: Tuple[torch.Tensor],
|
| 53 |
+
) -> torch.Tensor:
|
| 54 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 55 |
+
|
| 56 |
+
rot_dim = freqs_cis[0].shape[-1]
|
| 57 |
+
x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
|
| 58 |
+
|
| 59 |
+
x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
|
| 60 |
+
|
| 61 |
+
out = torch.cat((x_rotated, x_unrotated), dim=-1)
|
| 62 |
+
return out
|
| 63 |
+
|
| 64 |
+
def __call__(
|
| 65 |
+
self,
|
| 66 |
+
attn,
|
| 67 |
+
hidden_states: torch.Tensor,
|
| 68 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 69 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 70 |
+
rotary_emb: Optional[torch.Tensor] = None,
|
| 71 |
+
) -> torch.Tensor:
|
| 72 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 73 |
+
|
| 74 |
+
residual = hidden_states
|
| 75 |
+
|
| 76 |
+
input_ndim = hidden_states.ndim
|
| 77 |
+
|
| 78 |
+
if input_ndim == 4:
|
| 79 |
+
batch_size, channel, height, width = hidden_states.shape
|
| 80 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
| 81 |
+
|
| 82 |
+
batch_size, sequence_length, _ = (
|
| 83 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 84 |
+
)
|
| 85 |
+
if attention_mask is not None:
|
| 86 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 87 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
| 88 |
+
# (batch, heads, source_length, target_length)
|
| 89 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
| 90 |
+
|
| 91 |
+
query = attn.to_q(hidden_states)
|
| 92 |
+
|
| 93 |
+
if encoder_hidden_states is None:
|
| 94 |
+
encoder_hidden_states = hidden_states
|
| 95 |
+
elif attn.norm_cross:
|
| 96 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
| 97 |
+
|
| 98 |
+
key = attn.to_k(encoder_hidden_states)
|
| 99 |
+
value = attn.to_v(encoder_hidden_states)
|
| 100 |
+
head_dim = query.shape[-1] // attn.heads
|
| 101 |
+
kv_heads = key.shape[-1] // head_dim
|
| 102 |
+
|
| 103 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 104 |
+
|
| 105 |
+
key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 106 |
+
value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 107 |
+
|
| 108 |
+
if kv_heads != attn.heads:
|
| 109 |
+
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
|
| 110 |
+
heads_per_kv_head = attn.heads // kv_heads
|
| 111 |
+
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
|
| 112 |
+
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
|
| 113 |
+
|
| 114 |
+
if attn.norm_q is not None:
|
| 115 |
+
query = attn.norm_q(query)
|
| 116 |
+
if attn.norm_k is not None:
|
| 117 |
+
key = attn.norm_k(key)
|
| 118 |
+
|
| 119 |
+
# Apply RoPE if needed
|
| 120 |
+
if rotary_emb is not None:
|
| 121 |
+
query_dtype = query.dtype
|
| 122 |
+
key_dtype = key.dtype
|
| 123 |
+
query = query.to(torch.float32)
|
| 124 |
+
key = key.to(torch.float32)
|
| 125 |
+
|
| 126 |
+
rot_dim = rotary_emb[0].shape[-1]
|
| 127 |
+
query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
|
| 128 |
+
query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
|
| 129 |
+
|
| 130 |
+
query = torch.cat((query_rotated, query_unrotated), dim=-1)
|
| 131 |
+
|
| 132 |
+
if not attn.is_cross_attention:
|
| 133 |
+
key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
|
| 134 |
+
key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
|
| 135 |
+
|
| 136 |
+
key = torch.cat((key_rotated, key_unrotated), dim=-1)
|
| 137 |
+
|
| 138 |
+
query = query.to(query_dtype)
|
| 139 |
+
key = key.to(key_dtype)
|
| 140 |
+
|
| 141 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
| 142 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
| 143 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 144 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 145 |
+
)
|
| 146 |
+
# print("hidden_states", hidden_states.shape)
|
| 147 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 148 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 149 |
+
|
| 150 |
+
# linear proj
|
| 151 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 152 |
+
# dropout
|
| 153 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 154 |
+
|
| 155 |
+
if input_ndim == 4:
|
| 156 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
| 157 |
+
|
| 158 |
+
if attn.residual_connection:
|
| 159 |
+
hidden_states = hidden_states + residual
|
| 160 |
+
|
| 161 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
| 162 |
+
|
| 163 |
+
return hidden_states
|
| 164 |
+
|
| 165 |
+
# The attention processor used in MuseControlLite, using 1 decoupled cross-attention layer
|
| 166 |
+
class StableAudioAttnProcessor2_0_rotary(torch.nn.Module):
|
| 167 |
+
r"""
|
| 168 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
| 169 |
+
used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
|
| 170 |
+
"""
|
| 171 |
+
def __init__(self, layer_id, hidden_size, name, cross_attention_dim=None, num_tokens=4, scale=1.0):
|
| 172 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 173 |
+
raise ImportError(
|
| 174 |
+
"StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
| 175 |
+
)
|
| 176 |
+
super().__init__()
|
| 177 |
+
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
| 178 |
+
self.layer_id = layer_id
|
| 179 |
+
self.hidden_size = hidden_size
|
| 180 |
+
self.cross_attention_dim = cross_attention_dim
|
| 181 |
+
self.num_tokens = num_tokens
|
| 182 |
+
self.scale = scale
|
| 183 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
| 184 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
| 185 |
+
self.name = name
|
| 186 |
+
self.conv_out = zero_module(nn.Conv1d(1536,1536,kernel_size=1, padding=0, bias=False))
|
| 187 |
+
self.rotary_emb = LlamaRotaryEmbedding(dim = 64)
|
| 188 |
+
self.to_k_ip.weight.requires_grad = True
|
| 189 |
+
self.to_v_ip.weight.requires_grad = True
|
| 190 |
+
self.conv_out.weight.requires_grad = True
|
| 191 |
+
def rotate_half(self, x):
|
| 192 |
+
x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
|
| 193 |
+
x1, x2 = x.unbind(-1)
|
| 194 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def __call__(
|
| 198 |
+
self,
|
| 199 |
+
attn,
|
| 200 |
+
hidden_states: torch.Tensor,
|
| 201 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 202 |
+
encoder_hidden_states_con: Optional[torch.Tensor] = None,
|
| 203 |
+
encoder_hidden_states_audio: Optional[torch.Tensor] = None,
|
| 204 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 205 |
+
rotary_emb: Optional[torch.Tensor] = None,
|
| 206 |
+
) -> torch.Tensor:
|
| 207 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 208 |
+
|
| 209 |
+
residual = hidden_states
|
| 210 |
+
|
| 211 |
+
input_ndim = hidden_states.ndim
|
| 212 |
+
|
| 213 |
+
if input_ndim == 4:
|
| 214 |
+
batch_size, channel, height, width = hidden_states.shape
|
| 215 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
| 216 |
+
|
| 217 |
+
batch_size, sequence_length, _ = (
|
| 218 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 219 |
+
)
|
| 220 |
+
if attention_mask is not None:
|
| 221 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 222 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
| 223 |
+
# (batch, heads, source_length, target_length)
|
| 224 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
| 225 |
+
|
| 226 |
+
# The original cross attention in Stable-audio
|
| 227 |
+
###############################################################
|
| 228 |
+
query = attn.to_q(hidden_states)
|
| 229 |
+
ip_hidden_states = encoder_hidden_states_con
|
| 230 |
+
key = attn.to_k(encoder_hidden_states)
|
| 231 |
+
value = attn.to_v(encoder_hidden_states)
|
| 232 |
+
head_dim = query.shape[-1] // attn.heads
|
| 233 |
+
kv_heads = key.shape[-1] // head_dim
|
| 234 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 235 |
+
key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 236 |
+
value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 237 |
+
|
| 238 |
+
if kv_heads != attn.heads:
|
| 239 |
+
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
|
| 240 |
+
heads_per_kv_head = attn.heads // kv_heads
|
| 241 |
+
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
|
| 242 |
+
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
|
| 243 |
+
if attn.norm_q is not None:
|
| 244 |
+
query = attn.norm_q(query)
|
| 245 |
+
if attn.norm_k is not None:
|
| 246 |
+
key = attn.norm_k(key)
|
| 247 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
| 248 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 249 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 250 |
+
)
|
| 251 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 252 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 253 |
+
###############################################################
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# The decupled cross attention in used in MuseControlLite, to deal with additional conditions
|
| 257 |
+
###############################################################
|
| 258 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
| 259 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
| 260 |
+
ip_key = ip_key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 261 |
+
ip_key_length = ip_key.shape[2]
|
| 262 |
+
ip_value = ip_value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 263 |
+
if kv_heads != attn.heads:
|
| 264 |
+
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
|
| 265 |
+
heads_per_kv_head = attn.heads // kv_heads
|
| 266 |
+
ip_key = torch.repeat_interleave(ip_key, heads_per_kv_head, dim=1)
|
| 267 |
+
ip_value = torch.repeat_interleave(ip_value, heads_per_kv_head, dim=1)
|
| 268 |
+
ip_value_length = ip_value.shape[2]
|
| 269 |
+
seq_len_query = query.shape[2]
|
| 270 |
+
|
| 271 |
+
# Generate position_ids for query, keys, values
|
| 272 |
+
position_ids_query = torch.arange(seq_len_query, dtype=torch.long, device=query.device) * (ip_key_length / seq_len_query)
|
| 273 |
+
position_ids_query = position_ids_query.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_query]
|
| 274 |
+
position_ids_key = torch.arange(ip_key_length, dtype=torch.long, device=key.device)
|
| 275 |
+
position_ids_key = position_ids_key.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
|
| 276 |
+
position_ids_value = torch.arange(ip_value_length, dtype=torch.long, device=value.device)
|
| 277 |
+
position_ids_value = position_ids_value.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
|
| 278 |
+
|
| 279 |
+
# Rotate query, keys, values
|
| 280 |
+
cos, sin = self.rotary_emb(query, position_ids_query)
|
| 281 |
+
query_pos = (query * cos.unsqueeze(1)) + (self.rotate_half(query) * sin.unsqueeze(1))
|
| 282 |
+
cos, sin = self.rotary_emb(ip_key, position_ids_key)
|
| 283 |
+
ip_key = (ip_key * cos.unsqueeze(1)) + (self.rotate_half(ip_key) * sin.unsqueeze(1))
|
| 284 |
+
cos, sin = self.rotary_emb(ip_value, position_ids_value)
|
| 285 |
+
ip_value = (ip_value * cos.unsqueeze(1)) + (self.rotate_half(ip_value) * sin.unsqueeze(1))
|
| 286 |
+
|
| 287 |
+
ip_hidden_states = F.scaled_dot_product_attention(
|
| 288 |
+
query_pos, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
| 289 |
+
)
|
| 290 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 291 |
+
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
| 292 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2)
|
| 293 |
+
ip_hidden_states = self.conv_out(ip_hidden_states)
|
| 294 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2)
|
| 295 |
+
###############################################################
|
| 296 |
+
|
| 297 |
+
# Combine the output of the two cross-attention layers
|
| 298 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
| 299 |
+
# linear proj
|
| 300 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 301 |
+
# dropout
|
| 302 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 303 |
+
|
| 304 |
+
if input_ndim == 4:
|
| 305 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
| 306 |
+
|
| 307 |
+
if attn.residual_connection:
|
| 308 |
+
hidden_states = hidden_states + residual
|
| 309 |
+
|
| 310 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
| 311 |
+
|
| 312 |
+
return hidden_states
|
| 313 |
+
# The attention processor used in MuseControlLite, using 2 decoupled cross-attention layer. It needs further examination, don't use it now.
|
| 314 |
+
class StableAudioAttnProcessor2_0_rotary_double(torch.nn.Module):
|
| 315 |
+
r"""
|
| 316 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
| 317 |
+
used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
|
| 318 |
+
"""
|
| 319 |
+
def __init__(self, layer_id, hidden_size, name, cross_attention_dim=None, num_tokens=4, scale=1.0):
|
| 320 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 321 |
+
raise ImportError(
|
| 322 |
+
"StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
| 323 |
+
)
|
| 324 |
+
super().__init__()
|
| 325 |
+
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
| 326 |
+
self.hidden_size = hidden_size
|
| 327 |
+
self.cross_attention_dim = cross_attention_dim
|
| 328 |
+
self.num_tokens = num_tokens
|
| 329 |
+
self.layer_id = layer_id
|
| 330 |
+
self.scale = scale
|
| 331 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
| 332 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
| 333 |
+
self.to_k_ip_audio = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
| 334 |
+
self.to_v_ip_audio = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
| 335 |
+
self.name = name
|
| 336 |
+
self.conv_out = zero_module(nn.Conv1d(1536,1536,kernel_size=1, padding=0, bias=False))
|
| 337 |
+
self.conv_out_audio = zero_module(nn.Conv1d(1536,1536,kernel_size=1, padding=0, bias=False))
|
| 338 |
+
self.rotary_emb = LlamaRotaryEmbedding(64)
|
| 339 |
+
self.to_k_ip.weight.requires_grad = True
|
| 340 |
+
self.to_v_ip.weight.requires_grad = True
|
| 341 |
+
self.conv_out.weight.requires_grad = True
|
| 342 |
+
# Below is for copying the weight of the original weight to the decoupled cross-attention
|
| 343 |
+
def rotate_half(self, x):
|
| 344 |
+
x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
|
| 345 |
+
x1, x2 = x.unbind(-1)
|
| 346 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def __call__(
|
| 350 |
+
self,
|
| 351 |
+
attn,
|
| 352 |
+
hidden_states: torch.Tensor,
|
| 353 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 354 |
+
encoder_hidden_states_con: Optional[torch.Tensor] = None,
|
| 355 |
+
encoder_hidden_states_audio: Optional[torch.Tensor] = None,
|
| 356 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 357 |
+
) -> torch.Tensor:
|
| 358 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 359 |
+
|
| 360 |
+
residual = hidden_states
|
| 361 |
+
|
| 362 |
+
input_ndim = hidden_states.ndim
|
| 363 |
+
|
| 364 |
+
if input_ndim == 4:
|
| 365 |
+
batch_size, channel, height, width = hidden_states.shape
|
| 366 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
| 367 |
+
|
| 368 |
+
batch_size, sequence_length, _ = (
|
| 369 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 370 |
+
)
|
| 371 |
+
if attention_mask is not None:
|
| 372 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 373 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
| 374 |
+
# (batch, heads, source_length, target_length)
|
| 375 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
| 376 |
+
|
| 377 |
+
# The original cross attention in Stable-audio
|
| 378 |
+
###############################################################
|
| 379 |
+
query = attn.to_q(hidden_states)
|
| 380 |
+
key = attn.to_k(encoder_hidden_states)
|
| 381 |
+
value = attn.to_v(encoder_hidden_states)
|
| 382 |
+
head_dim = query.shape[-1] // attn.heads
|
| 383 |
+
kv_heads = key.shape[-1] // head_dim
|
| 384 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 385 |
+
key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 386 |
+
value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 387 |
+
|
| 388 |
+
if kv_heads != attn.heads:
|
| 389 |
+
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
|
| 390 |
+
heads_per_kv_head = attn.heads // kv_heads
|
| 391 |
+
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
|
| 392 |
+
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
|
| 393 |
+
|
| 394 |
+
if attn.norm_q is not None:
|
| 395 |
+
query = attn.norm_q(query)
|
| 396 |
+
if attn.norm_k is not None:
|
| 397 |
+
key = attn.norm_k(key)
|
| 398 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
| 399 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 400 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 401 |
+
)
|
| 402 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 403 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 404 |
+
# if self.layer_id == "0":
|
| 405 |
+
# hidden_states_sliced = hidden_states[:,1:,:]
|
| 406 |
+
# # Create a tensor of zeros with shape (bs, 1, 768)
|
| 407 |
+
# bs, _, dim2 = hidden_states_sliced.shape
|
| 408 |
+
# zeros = torch.zeros(bs, 1, dim2).cuda()
|
| 409 |
+
# # Concatenate the zero tensor along the second dimension (dim=1)
|
| 410 |
+
# hidden_states_sliced = torch.cat((hidden_states_sliced, zeros), dim=1)
|
| 411 |
+
# query_sliced = attn.to_q(hidden_states_sliced)
|
| 412 |
+
# query_sliced = query_sliced.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 413 |
+
# query = query_sliced
|
| 414 |
+
ip_hidden_states = encoder_hidden_states_con
|
| 415 |
+
ip_hidden_states_audio = encoder_hidden_states_audio
|
| 416 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
| 417 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
| 418 |
+
ip_key = ip_key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 419 |
+
ip_key_length = ip_key.shape[2]
|
| 420 |
+
ip_value = ip_value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 421 |
+
ip_key_audio = self.to_k_ip_audio(ip_hidden_states_audio)
|
| 422 |
+
ip_value_audio = self.to_v_ip_audio(ip_hidden_states_audio)
|
| 423 |
+
ip_key_audio = ip_key_audio.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 424 |
+
ip_key_audio_length = ip_key_audio.shape[2]
|
| 425 |
+
ip_value_audio = ip_value_audio.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
| 426 |
+
|
| 427 |
+
if kv_heads != attn.heads:
|
| 428 |
+
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
|
| 429 |
+
heads_per_kv_head = attn.heads // kv_heads
|
| 430 |
+
ip_key = torch.repeat_interleave(ip_key, heads_per_kv_head, dim=1)
|
| 431 |
+
ip_value = torch.repeat_interleave(ip_value, heads_per_kv_head, dim=1)
|
| 432 |
+
ip_key_audio = torch.repeat_interleave(ip_key_audio, heads_per_kv_head, dim=1)
|
| 433 |
+
ip_value_audio = torch.repeat_interleave(ip_value_audio, heads_per_kv_head, dim=1)
|
| 434 |
+
|
| 435 |
+
ip_value_length = ip_value.shape[2]
|
| 436 |
+
seq_len_query = query.shape[2]
|
| 437 |
+
ip_value_audio_length = ip_value_audio.shape[2]
|
| 438 |
+
|
| 439 |
+
position_ids_query = torch.arange(seq_len_query, dtype=torch.long, device=query.device) * (ip_key_length / seq_len_query)
|
| 440 |
+
position_ids_query = position_ids_query.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_query]
|
| 441 |
+
|
| 442 |
+
# Generate position_ids for keys
|
| 443 |
+
position_ids_key = torch.arange(ip_key_length, dtype=torch.long, device=key.device)
|
| 444 |
+
position_ids_key = position_ids_key.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
|
| 445 |
+
position_ids_value = torch.arange(ip_value_length, dtype=torch.long, device=value.device)
|
| 446 |
+
position_ids_value = position_ids_value.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
|
| 447 |
+
# Generate position_ids for keys
|
| 448 |
+
position_ids_query_audio = torch.arange(seq_len_query, dtype=torch.long, device=query.device) * (ip_key_audio_length / seq_len_query)
|
| 449 |
+
position_ids_query_audio = position_ids_query_audio.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_query]
|
| 450 |
+
position_ids_key_audio = torch.arange(ip_key_audio_length, dtype=torch.long, device=key.device)
|
| 451 |
+
position_ids_key_audio = position_ids_key_audio.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
|
| 452 |
+
position_ids_value_audio = torch.arange(ip_value_audio_length, dtype=torch.long, device=value.device)
|
| 453 |
+
position_ids_value_audio = position_ids_value_audio.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
|
| 454 |
+
cos, sin = self.rotary_emb(query, position_ids_query)
|
| 455 |
+
cos_audio, sin_audio = self.rotary_emb(query, position_ids_query_audio)
|
| 456 |
+
query_pos = (query * cos.unsqueeze(1)) + (self.rotate_half(query) * sin.unsqueeze(1))
|
| 457 |
+
query_pos_audio = (query * cos_audio.unsqueeze(1)) + (self.rotate_half(query) * sin_audio.unsqueeze(1))
|
| 458 |
+
|
| 459 |
+
cos, sin = self.rotary_emb(ip_key, position_ids_key)
|
| 460 |
+
cos_audio, sin_audio = self.rotary_emb(ip_key_audio, position_ids_key_audio)
|
| 461 |
+
ip_key = (ip_key * cos.unsqueeze(1)) + (self.rotate_half(ip_key) * sin.unsqueeze(1))
|
| 462 |
+
ip_key_audio = (ip_key_audio * cos_audio.unsqueeze(1)) + (self.rotate_half(ip_key_audio) * sin_audio.unsqueeze(1))
|
| 463 |
+
|
| 464 |
+
cos, sin = self.rotary_emb(ip_value, position_ids_value)
|
| 465 |
+
cos_audio, sin_audio = self.rotary_emb(ip_value_audio, position_ids_value_audio)
|
| 466 |
+
ip_value = (ip_value * cos.unsqueeze(1)) + (self.rotate_half(ip_value) * sin.unsqueeze(1))
|
| 467 |
+
ip_value_audio = (ip_value_audio * cos_audio.unsqueeze(1)) + (self.rotate_half(ip_value_audio) * sin_audio.unsqueeze(1))
|
| 468 |
+
|
| 469 |
+
with torch.amp.autocast(device_type='cuda'):
|
| 470 |
+
ip_hidden_states = F.scaled_dot_product_attention(
|
| 471 |
+
query_pos, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
| 472 |
+
)
|
| 473 |
+
with torch.amp.autocast(device_type='cuda'):
|
| 474 |
+
ip_hidden_states_audio = F.scaled_dot_product_attention(
|
| 475 |
+
query_pos_audio, ip_key_audio, ip_value_audio, attn_mask=None, dropout_p=0.0, is_causal=False
|
| 476 |
+
)
|
| 477 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 478 |
+
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
| 479 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2)
|
| 480 |
+
|
| 481 |
+
ip_hidden_states_audio = ip_hidden_states_audio.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 482 |
+
ip_hidden_states_audio = ip_hidden_states_audio.to(query.dtype)
|
| 483 |
+
ip_hidden_states_audio = ip_hidden_states_audio.transpose(1, 2)
|
| 484 |
+
|
| 485 |
+
with torch.amp.autocast(device_type='cuda'):
|
| 486 |
+
ip_hidden_states = self.conv_out(ip_hidden_states)
|
| 487 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2)
|
| 488 |
+
|
| 489 |
+
with torch.amp.autocast(device_type='cuda'):
|
| 490 |
+
ip_hidden_states_audio = self.conv_out_audio(ip_hidden_states_audio)
|
| 491 |
+
ip_hidden_states_audio = ip_hidden_states_audio.transpose(1, 2)
|
| 492 |
+
|
| 493 |
+
# Combine the tensors
|
| 494 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states + ip_hidden_states_audio
|
| 495 |
+
|
| 496 |
+
# linear proj
|
| 497 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 498 |
+
# dropout
|
| 499 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 500 |
+
|
| 501 |
+
if input_ndim == 4:
|
| 502 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
| 503 |
+
|
| 504 |
+
if attn.residual_connection:
|
| 505 |
+
hidden_states = hidden_states + residual
|
| 506 |
+
|
| 507 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
| 508 |
+
|
| 509 |
+
return hidden_states
|
| 510 |
+
def setup_MuseControlLite(config, weight_dtype, transformer_ckpt):
|
| 511 |
+
"""
|
| 512 |
+
Setup AP-adapter pipeline with attention processors and load checkpoints.
|
| 513 |
+
|
| 514 |
+
Args:
|
| 515 |
+
config: Configuration dictionary
|
| 516 |
+
weight_dtype: Weight data type for the pipeline
|
| 517 |
+
transformer_ckpt: Path to transformer checkpoint
|
| 518 |
+
Returns:
|
| 519 |
+
tuple: (pipe, transformer) - Configured pipeline and transformer
|
| 520 |
+
"""
|
| 521 |
+
if 'audio' in config['condition_type'] and len(config['condition_type'])!=1:
|
| 522 |
+
from pipeline.stable_audio_multi_cfg_pipe_audio import StableAudioPipeline
|
| 523 |
+
attn_processor = StableAudioAttnProcessor2_0_rotary_double
|
| 524 |
+
audio_state_dict = load_file(config["audio_transformer_ckpt"], device="cpu")
|
| 525 |
+
else:
|
| 526 |
+
from pipeline.stable_audio_multi_cfg_pipe import StableAudioPipeline
|
| 527 |
+
attn_processor = StableAudioAttnProcessor2_0_rotary
|
| 528 |
+
pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", torch_dtype=weight_dtype)
|
| 529 |
+
pipe.scheduler.config.sigma_max = config["sigma_max"]
|
| 530 |
+
pipe.scheduler.config.sigma_min = config["sigma_min"]
|
| 531 |
+
transformer = pipe.transformer
|
| 532 |
+
attn_procs = {}
|
| 533 |
+
for name in transformer.attn_processors.keys():
|
| 534 |
+
if name.endswith("attn1.processor"):
|
| 535 |
+
attn_procs[name] = StableAudioAttnProcessor2_0()
|
| 536 |
+
else:
|
| 537 |
+
attn_procs[name] = attn_processor(
|
| 538 |
+
layer_id = name.split(".")[1],
|
| 539 |
+
hidden_size=768,
|
| 540 |
+
name=name,
|
| 541 |
+
cross_attention_dim=768,
|
| 542 |
+
scale=config['ap_scale'],
|
| 543 |
+
).to("cuda", dtype=torch.float)
|
| 544 |
+
if transformer_ckpt is not None:
|
| 545 |
+
state_dict = load_file(transformer_ckpt, device="cuda")
|
| 546 |
+
for name, processor in attn_procs.items():
|
| 547 |
+
if isinstance(processor, attn_processor):
|
| 548 |
+
weight_name_v = name + ".to_v_ip.weight"
|
| 549 |
+
weight_name_k = name + ".to_k_ip.weight"
|
| 550 |
+
conv_out_weight = name + ".conv_out.weight"
|
| 551 |
+
processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].to(torch.float32))
|
| 552 |
+
processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].to(torch.float32))
|
| 553 |
+
processor.conv_out.weight = torch.nn.Parameter(state_dict[conv_out_weight].to(torch.float32))
|
| 554 |
+
if attn_processor == StableAudioAttnProcessor2_0_rotary_double:
|
| 555 |
+
audio_weight_name_v = name + ".to_v_ip.weight"
|
| 556 |
+
audio_weight_name_k = name + ".to_k_ip.weight"
|
| 557 |
+
audio_conv_out_weight = name + ".conv_out.weight"
|
| 558 |
+
processor.to_v_ip_audio.weight = torch.nn.Parameter(audio_state_dict[audio_weight_name_v].to(torch.float32))
|
| 559 |
+
processor.to_k_ip_audio.weight = torch.nn.Parameter(audio_state_dict[audio_weight_name_k].to(torch.float32))
|
| 560 |
+
processor.conv_out_audio.weight = torch.nn.Parameter(audio_state_dict[audio_conv_out_weight].to(torch.float32))
|
| 561 |
+
transformer.set_attn_processor(attn_procs)
|
| 562 |
+
class _Wrapper(AttnProcsLayers):
|
| 563 |
+
def forward(self, *args, **kwargs):
|
| 564 |
+
return pipe.transformer(*args, **kwargs)
|
| 565 |
+
transformer = _Wrapper(pipe.transformer.attn_processors)
|
| 566 |
+
|
| 567 |
+
return pipe
|
| 568 |
+
def initialize_condition_extractors(config):
|
| 569 |
+
"""
|
| 570 |
+
Initialize condition extractors based on configuration.
|
| 571 |
+
|
| 572 |
+
Args:
|
| 573 |
+
config: Configuration dictionary containing condition types and checkpoint paths
|
| 574 |
+
|
| 575 |
+
Returns:
|
| 576 |
+
tuple: (condition_extractors, transformer_ckpt, extractor_ckpt)
|
| 577 |
+
"""
|
| 578 |
+
condition_extractors = {}
|
| 579 |
+
extractor_ckpt = {}
|
| 580 |
+
from utils.feature_extractor import dynamics_extractor, rhythm_extractor, melody_extractor_mono, melody_extractor_stereo, melody_extractor_full_mono, melody_extractor_full_stereo, dynamics_extractor_full_stereo
|
| 581 |
+
if not ("rhythm" in config['condition_type'] or "dynamics" in config['condition_type']):
|
| 582 |
+
if "melody_stereo" in config['condition_type']:
|
| 583 |
+
transformer_ckpt = config['transformer_ckpt_melody_stero']
|
| 584 |
+
extractor_ckpt = config['extractor_ckpt_melody_stero']
|
| 585 |
+
print(f"using model: {transformer_ckpt}, {extractor_ckpt}")
|
| 586 |
+
melody_conditoner = melody_extractor_full_stereo().cuda().float()
|
| 587 |
+
condition_extractors["melody"] = melody_conditoner
|
| 588 |
+
elif "melody_mono" in config['condition_type']:
|
| 589 |
+
transformer_ckpt = config['transformer_ckpt_melody_mono']
|
| 590 |
+
extractor_ckpt = config['extractor_ckpt_melody_mono']
|
| 591 |
+
print(f"using model: {transformer_ckpt}, {extractor_ckpt}")
|
| 592 |
+
melody_conditoner = melody_extractor_full_mono().cuda().float()
|
| 593 |
+
condition_extractors["melody"] = melody_conditoner
|
| 594 |
+
elif "audio" in config['condition_type']:
|
| 595 |
+
transformer_ckpt = config['audio_transformer_ckpt']
|
| 596 |
+
print(f"using model: {transformer_ckpt}")
|
| 597 |
+
else:
|
| 598 |
+
dynamics_conditoner = dynamics_extractor().cuda().float()
|
| 599 |
+
condition_extractors["dynamics"] = dynamics_conditoner
|
| 600 |
+
rhythm_conditoner = rhythm_extractor().cuda().float()
|
| 601 |
+
condition_extractors["rhythm"] = rhythm_conditoner
|
| 602 |
+
melody_conditoner = melody_extractor_mono().cuda().float()
|
| 603 |
+
condition_extractors["melody"] = melody_conditoner
|
| 604 |
+
transformer_ckpt = config['transformer_ckpt_musical']
|
| 605 |
+
extractor_ckpt = config['extractor_ckpt_musical']
|
| 606 |
+
print(f"using model: {transformer_ckpt}, {extractor_ckpt}")
|
| 607 |
+
|
| 608 |
+
for conditioner_type, ckpt_path in extractor_ckpt.items():
|
| 609 |
+
state_dict = load_file(ckpt_path, device="cpu")
|
| 610 |
+
condition_extractors[conditioner_type].load_state_dict(state_dict)
|
| 611 |
+
condition_extractors[conditioner_type].eval()
|
| 612 |
+
|
| 613 |
+
return condition_extractors, transformer_ckpt
|
| 614 |
+
def evaluate_and_plot_results(audio_file, gen_file_path, output_dir, i):
|
| 615 |
+
"""
|
| 616 |
+
Evaluate and plot results comparing original and generated audio.
|
| 617 |
+
|
| 618 |
+
Args:
|
| 619 |
+
audio_file (str): Path to the original audio file
|
| 620 |
+
gen_file_path (str): Path to the generated audio file
|
| 621 |
+
output_dir (str): Directory to save the plot
|
| 622 |
+
i (int): Index for naming the output file
|
| 623 |
+
|
| 624 |
+
Returns:
|
| 625 |
+
tuple: (dynamics_score, rhythm_score, melody_score)
|
| 626 |
+
"""
|
| 627 |
+
|
| 628 |
+
dynamics_condition = compute_dynamics(audio_file)
|
| 629 |
+
gen_dynamics = compute_dynamics(gen_file_path)
|
| 630 |
+
min_len_dynamics = min(gen_dynamics.shape[0], dynamics_condition.shape[0])
|
| 631 |
+
pearson_corr = np.corrcoef(gen_dynamics[:min_len_dynamics], dynamics_condition[:min_len_dynamics])[0, 1]
|
| 632 |
+
print("pearson_corr", pearson_corr)
|
| 633 |
+
|
| 634 |
+
melody_condition = extract_melody_one_hot(audio_file)
|
| 635 |
+
gen_melody = extract_melody_one_hot(gen_file_path)
|
| 636 |
+
min_len_melody = min(gen_melody.shape[1], melody_condition.shape[1])
|
| 637 |
+
matches = ((gen_melody[:, :min_len_melody] == melody_condition[:, :min_len_melody]) & (gen_melody[:, :min_len_melody] == 1)).sum()
|
| 638 |
+
accuracy = matches / min_len_melody
|
| 639 |
+
print("melody accuracy", accuracy)
|
| 640 |
+
|
| 641 |
+
# Adjust layout to avoid overlap
|
| 642 |
+
processor = RNNDownBeatProcessor()
|
| 643 |
+
original_path = os.path.join(output_dir, f"original_{i}.wav")
|
| 644 |
+
input_probabilities = processor(original_path)
|
| 645 |
+
generated_probabilities = processor(gen_file_path)
|
| 646 |
+
hmm_processor = DBNDownBeatTrackingProcessor(beats_per_bar=[3,4], fps=100)
|
| 647 |
+
input_timestamps = hmm_processor(input_probabilities)
|
| 648 |
+
generated_timestamps = hmm_processor(generated_probabilities)
|
| 649 |
+
precision, recall, f1 = evaluate_f1_rhythm(input_timestamps, generated_timestamps)
|
| 650 |
+
# Output results
|
| 651 |
+
print(f"F1 Score: {f1:.2f}")
|
| 652 |
+
|
| 653 |
+
# Plotting
|
| 654 |
+
frame_rate = 100 # Frames per second
|
| 655 |
+
input_time_axis = np.linspace(0, len(input_probabilities) / frame_rate, len(input_probabilities))
|
| 656 |
+
generate_time_axis = np.linspace(0, len(generated_probabilities) / frame_rate, len(generated_probabilities))
|
| 657 |
+
fig, axes = plt.subplots(2, 3, figsize=(18, 10)) # Adjust figsize as needed
|
| 658 |
+
|
| 659 |
+
# ----------------------------
|
| 660 |
+
# Subplot (0,0): Dynamics Plot
|
| 661 |
+
ax = axes[0, 0]
|
| 662 |
+
ax.plot(dynamics_condition[:min_len_dynamics].squeeze(), linewidth=1, label='Dynamics condition')
|
| 663 |
+
ax.set_title('Dynamics')
|
| 664 |
+
ax.set_xlabel('Time Frame')
|
| 665 |
+
ax.set_ylabel('Dynamics (dB)')
|
| 666 |
+
ax.legend(fontsize=8)
|
| 667 |
+
ax.grid(True)
|
| 668 |
+
# ----------------------------
|
| 669 |
+
# Subplot (0,0): Dynamics Plot
|
| 670 |
+
ax = axes[1, 0]
|
| 671 |
+
ax.plot(gen_dynamics[:min_len_dynamics].squeeze(), linewidth=1, label='Generated Dynamics')
|
| 672 |
+
ax.set_title('Dynamics')
|
| 673 |
+
ax.set_xlabel('Time Frame')
|
| 674 |
+
ax.set_ylabel('Dynamics (dB)')
|
| 675 |
+
ax.legend(fontsize=8)
|
| 676 |
+
ax.grid(True)
|
| 677 |
+
|
| 678 |
+
# ----------------------------
|
| 679 |
+
# Subplot (0,2): Melody Condition (Chromagram)
|
| 680 |
+
ax = axes[0, 1]
|
| 681 |
+
im2 = ax.imshow(melody_condition[:, :min_len_melody], aspect='auto', origin='lower',
|
| 682 |
+
interpolation='nearest', cmap='plasma')
|
| 683 |
+
ax.set_title('Melody Condition')
|
| 684 |
+
ax.set_xlabel('Time')
|
| 685 |
+
ax.set_ylabel('Chroma Features')
|
| 686 |
+
|
| 687 |
+
# ----------------------------
|
| 688 |
+
# Subplot (0,1): Generated Melody (Chromagram)
|
| 689 |
+
ax = axes[1, 1]
|
| 690 |
+
im1 = ax.imshow(gen_melody[:, :min_len_melody], aspect='auto', origin='lower',
|
| 691 |
+
interpolation='nearest', cmap='viridis')
|
| 692 |
+
ax.set_title('Generated Melody')
|
| 693 |
+
ax.set_xlabel('Time')
|
| 694 |
+
ax.set_ylabel('Chroma Features')
|
| 695 |
+
|
| 696 |
+
# ----------------------------
|
| 697 |
+
# Subplot (1,0): Rhythm Input Probabilities
|
| 698 |
+
ax = axes[0, 2]
|
| 699 |
+
ax.plot(input_time_axis, input_probabilities,
|
| 700 |
+
label="Input Beat Probability")
|
| 701 |
+
ax.plot(input_time_axis, input_probabilities,
|
| 702 |
+
label="Input Downbeat Probability", alpha=0.8)
|
| 703 |
+
ax.set_title('Rhythm: Input')
|
| 704 |
+
ax.set_xlabel('Time (s)')
|
| 705 |
+
ax.set_ylabel('Probability')
|
| 706 |
+
ax.legend()
|
| 707 |
+
ax.grid(True)
|
| 708 |
+
|
| 709 |
+
# ----------------------------
|
| 710 |
+
# Subplot (1,1): Rhythm Generated Probabilities
|
| 711 |
+
ax = axes[1, 2]
|
| 712 |
+
ax.plot(generate_time_axis, generated_probabilities,
|
| 713 |
+
color='orange', label="Generated Beat Probability")
|
| 714 |
+
ax.plot(generate_time_axis, generated_probabilities,
|
| 715 |
+
alpha=0.8, color='red', label="Generated Downbeat Probability")
|
| 716 |
+
ax.set_title('Rhythm: Generated')
|
| 717 |
+
ax.set_xlabel('Time (s)')
|
| 718 |
+
ax.set_ylabel('Probability')
|
| 719 |
+
ax.legend()
|
| 720 |
+
ax.grid(True)
|
| 721 |
+
|
| 722 |
+
# Adjust layout and save the combined image
|
| 723 |
+
plt.tight_layout()
|
| 724 |
+
combined_path = os.path.join(output_dir, f"combined_{i}.png")
|
| 725 |
+
plt.savefig(combined_path)
|
| 726 |
+
plt.close()
|
| 727 |
+
|
| 728 |
+
print(f"Combined plot saved to {combined_path}")
|
| 729 |
+
|
| 730 |
+
return pearson_corr, f1, accuracy
|
| 731 |
+
|
| 732 |
+
def process_musical_conditions(config, audio_file, condition_extractors, output_dir, i, weight_dtype, MuseControlLite):
|
| 733 |
+
"""
|
| 734 |
+
Process and extract musical conditions (dynamics, rhythm, melody) from audio file.
|
| 735 |
+
|
| 736 |
+
Args:
|
| 737 |
+
config: Configuration dictionary
|
| 738 |
+
audio_file: Path to the audio file
|
| 739 |
+
condition_extractors: Dictionary of condition extractors
|
| 740 |
+
output_dir: Output directory path
|
| 741 |
+
i: Index for file naming
|
| 742 |
+
weight_dtype: Weight data type for torch tensors
|
| 743 |
+
MuseControlLite: The MuseControlLite model instance
|
| 744 |
+
audio_mask_start: Start index for audio mask
|
| 745 |
+
audio_mask_end: End index for audio mask
|
| 746 |
+
musical_attribute_mask_start: Start index for musical attribute mask
|
| 747 |
+
musical_attribute_mask_end: End index for musical attribute mask
|
| 748 |
+
|
| 749 |
+
Returns:
|
| 750 |
+
tuple: (final_condition, extracted_condition, final_condition_audio)
|
| 751 |
+
"""
|
| 752 |
+
total_seconds = 2097152/44100
|
| 753 |
+
use_audio_mask = False
|
| 754 |
+
use_musical_attribute_mask = False
|
| 755 |
+
if (config["audio_mask_start_seconds"] and config["audio_mask_end_seconds"]) != 0 and "audio" in config["condition_type"]:
|
| 756 |
+
use_audio_mask = True
|
| 757 |
+
audio_mask_start = int(config["audio_mask_start_seconds"] / total_seconds * 1024) # 1024 is the latent length for 2097152/44100 seconds
|
| 758 |
+
audio_mask_end = int(config["audio_mask_end_seconds"] / total_seconds * 1024)
|
| 759 |
+
print(
|
| 760 |
+
f"using mask for 'audio' from "
|
| 761 |
+
f"{config['audio_mask_start_seconds']}~{config['audio_mask_end_seconds']}"
|
| 762 |
+
)
|
| 763 |
+
if (config["musical_attribute_mask_start_seconds"] and config["musical_attribute_mask_end_seconds"]) != 0:
|
| 764 |
+
use_musical_attribute_mask = True
|
| 765 |
+
musical_attribute_mask_start = int(config["musical_attribute_mask_start_seconds"] / total_seconds * 1024)
|
| 766 |
+
musical_attribute_mask_end = int(config["musical_attribute_mask_end_seconds"] / total_seconds * 1024)
|
| 767 |
+
masked_types = [t for t in config['condition_type'] if t != 'audio']
|
| 768 |
+
print(
|
| 769 |
+
f"using mask for {', '.join(masked_types)} "
|
| 770 |
+
f"from {config['musical_attribute_mask_start_seconds']}~"
|
| 771 |
+
f"{config['musical_attribute_mask_end_seconds']}"
|
| 772 |
+
)
|
| 773 |
+
if "dynamics" in config["condition_type"]:
|
| 774 |
+
dynamics_condition = compute_dynamics(audio_file)
|
| 775 |
+
dynamics_condition = torch.from_numpy(dynamics_condition).cuda()
|
| 776 |
+
dynamics_condition = dynamics_condition.unsqueeze(0).unsqueeze(0)
|
| 777 |
+
print("dynamics_condition", dynamics_condition.shape)
|
| 778 |
+
extracted_dynamics_condition = condition_extractors["dynamics"](dynamics_condition.to(torch.float32))
|
| 779 |
+
masked_extracted_dynamics_condition = torch.zeros_like(extracted_dynamics_condition)
|
| 780 |
+
extracted_dynamics_condition = F.interpolate(extracted_dynamics_condition, size=1024, mode='linear', align_corners=False)
|
| 781 |
+
masked_extracted_dynamics_condition = F.interpolate(masked_extracted_dynamics_condition, size=1024, mode='linear', align_corners=False)
|
| 782 |
+
else:
|
| 783 |
+
extracted_dynamics_condition = torch.zeros((1, 192, 1024), device="cuda")
|
| 784 |
+
masked_extracted_dynamics_condition = extracted_dynamics_condition
|
| 785 |
+
if "rhythm" in config["condition_type"]:
|
| 786 |
+
rnn_processor = RNNDownBeatProcessor()
|
| 787 |
+
wave = load_audio_file(audio_file)
|
| 788 |
+
if wave is not None:
|
| 789 |
+
original_path = os.path.join(output_dir, f"original_{i}.wav")
|
| 790 |
+
sf.write(original_path, wave.T.float().cpu().numpy(), 44100)
|
| 791 |
+
rhythm_curve = rnn_processor(original_path)
|
| 792 |
+
rhythm_condition = torch.from_numpy(rhythm_curve).cuda()
|
| 793 |
+
rhythm_condition = rhythm_condition.transpose(0,1).unsqueeze(0)
|
| 794 |
+
print("rhythm_condition", rhythm_condition.shape)
|
| 795 |
+
extracted_rhythm_condition = condition_extractors["rhythm"](rhythm_condition.to(torch.float32))
|
| 796 |
+
masked_extracted_rhythm_condition = torch.zeros_like(extracted_rhythm_condition)
|
| 797 |
+
extracted_rhythm_condition = F.interpolate(extracted_rhythm_condition, size=1024, mode='linear', align_corners=False)
|
| 798 |
+
masked_extracted_rhythm_condition = F.interpolate(masked_extracted_rhythm_condition, size=1024, mode='linear', align_corners=False)
|
| 799 |
+
else:
|
| 800 |
+
extracted_rhythm_condition = torch.zeros((1, 192, 1024), device="cuda")
|
| 801 |
+
masked_extracted_rhythm_condition = extracted_rhythm_condition
|
| 802 |
+
else:
|
| 803 |
+
extracted_rhythm_condition = torch.zeros((1, 192, 1024), device="cuda")
|
| 804 |
+
masked_extracted_rhythm_condition = extracted_rhythm_condition
|
| 805 |
+
|
| 806 |
+
if "melody_mono" in config["condition_type"]:
|
| 807 |
+
melody_condition = compute_melody(audio_file)
|
| 808 |
+
melody_condition = torch.from_numpy(melody_condition).cuda().unsqueeze(0)
|
| 809 |
+
print("melody_condition", melody_condition.shape)
|
| 810 |
+
extracted_melody_condition = condition_extractors["melody"](melody_condition.to(torch.float32))
|
| 811 |
+
masked_extracted_melody_condition = torch.zeros_like(extracted_melody_condition)
|
| 812 |
+
extracted_melody_condition = F.interpolate(extracted_melody_condition, size=1024, mode='linear', align_corners=False)
|
| 813 |
+
masked_extracted_melody_condition = F.interpolate(masked_extracted_melody_condition, size=1024, mode='linear', align_corners=False)
|
| 814 |
+
elif "melody_stereo" in config["condition_type"]:
|
| 815 |
+
melody_condition = compute_melody_v2(audio_file)
|
| 816 |
+
melody_condition = torch.from_numpy(melody_condition).cuda().unsqueeze(0)
|
| 817 |
+
print("melody_condition", melody_condition.shape)
|
| 818 |
+
extracted_melody_condition = condition_extractors["melody"](melody_condition)
|
| 819 |
+
masked_extracted_melody_condition = torch.zeros_like(extracted_melody_condition)
|
| 820 |
+
extracted_melody_condition = F.interpolate(extracted_melody_condition, size=1024, mode='linear', align_corners=False)
|
| 821 |
+
masked_extracted_melody_condition = F.interpolate(masked_extracted_melody_condition, size=1024, mode='linear', align_corners=False)
|
| 822 |
+
else:
|
| 823 |
+
if not ("rhythm" in config['condition_type'] or "dynamics" in config['condition_type']):
|
| 824 |
+
extracted_melody_condition = torch.zeros((1, 768, 1024), device="cuda")
|
| 825 |
+
else:
|
| 826 |
+
extracted_melody_condition = torch.zeros((1, 192, 1024), device="cuda")
|
| 827 |
+
masked_extracted_melody_condition = extracted_melody_condition
|
| 828 |
+
|
| 829 |
+
# Use multiple cfg
|
| 830 |
+
if not ("rhythm" in config['condition_type'] or "dynamics" in config['condition_type']):
|
| 831 |
+
extracted_condition = extracted_melody_condition
|
| 832 |
+
final_condition = torch.concat((masked_extracted_melody_condition, masked_extracted_melody_condition, extracted_melody_condition), dim=0)
|
| 833 |
+
else:
|
| 834 |
+
extracted_blank_condition = torch.zeros((1, 192, 1024), device="cuda")
|
| 835 |
+
extracted_condition = torch.concat((extracted_rhythm_condition, extracted_dynamics_condition, extracted_melody_condition, extracted_blank_condition), dim=1)
|
| 836 |
+
masked_extracted_condition = torch.concat((masked_extracted_rhythm_condition, masked_extracted_dynamics_condition, masked_extracted_melody_condition, extracted_blank_condition), dim=1)
|
| 837 |
+
final_condition = torch.concat((masked_extracted_condition, masked_extracted_condition, extracted_condition), dim=0)
|
| 838 |
+
if "audio" in config["condition_type"]:
|
| 839 |
+
desired_repeats = 768 // 64 # Number of repeats needed
|
| 840 |
+
audio = load_audio_file(audio_file)
|
| 841 |
+
if audio is not None:
|
| 842 |
+
audio_condition = MuseControlLite.vae.encode(audio.unsqueeze(0).to(weight_dtype).cuda()).latent_dist.sample()
|
| 843 |
+
extracted_audio_condition = audio_condition.repeat_interleave(desired_repeats, dim=1).float()
|
| 844 |
+
pad_len = 1024 - extracted_audio_condition.shape[-1]
|
| 845 |
+
if pad_len > 0:
|
| 846 |
+
# Pad on the right side (last dimension)
|
| 847 |
+
extracted_audio_condition = F.pad(extracted_audio_condition, (0, pad_len))
|
| 848 |
+
masked_extracted_audio_condition = torch.zeros_like(extracted_audio_condition)
|
| 849 |
+
if len(config["condition_type"]) == 1:
|
| 850 |
+
final_condition = torch.concat((masked_extracted_audio_condition, masked_extracted_audio_condition, extracted_audio_condition), dim=0)
|
| 851 |
+
else:
|
| 852 |
+
final_condition_audio = torch.concat((masked_extracted_audio_condition, masked_extracted_audio_condition, masked_extracted_audio_condition, extracted_audio_condition), dim=0)
|
| 853 |
+
final_condition = torch.concat((final_condition, extracted_condition), dim=0)
|
| 854 |
+
final_condition_audio = final_condition_audio.transpose(1, 2)
|
| 855 |
+
else:
|
| 856 |
+
final_condition_audio = None
|
| 857 |
+
final_condition = final_condition.transpose(1, 2)
|
| 858 |
+
if "audio" in config["condition_type"] and len(config["condition_type"])==1:
|
| 859 |
+
final_condition[:,audio_mask_start:audio_mask_end,:] = 0
|
| 860 |
+
if use_audio_mask:
|
| 861 |
+
config["guidance_scale_con"] = config["guidance_scale_audio"]
|
| 862 |
+
elif "audio" in config["condition_type"] and len(config["condition_type"])!=1 and use_audio_mask:
|
| 863 |
+
final_condition[:,:audio_mask_start,:] = 0
|
| 864 |
+
final_condition[:,audio_mask_end:,:] = 0
|
| 865 |
+
if 'final_condition_audio' in locals() and final_condition_audio is not None:
|
| 866 |
+
final_condition_audio[:,audio_mask_start:audio_mask_end,:] = 0
|
| 867 |
+
elif use_musical_attribute_mask:
|
| 868 |
+
final_condition[:,musical_attribute_mask_start:musical_attribute_mask_end,:] = 0
|
| 869 |
+
if 'final_condition_audio' in locals() and final_condition_audio is not None:
|
| 870 |
+
final_condition_audio[:,:musical_attribute_mask_start,:] = 0
|
| 871 |
+
final_condition_audio[:,musical_attribute_mask_end:,:] = 0
|
| 872 |
+
|
| 873 |
+
return final_condition, final_condition_audio if 'final_condition_audio' in locals() else None
|
| 874 |
+
|
README.md
CHANGED
|
@@ -11,4 +11,38 @@ license: mit
|
|
| 11 |
short_description: Inference for Stable-Audio-Open with more controls
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
short_description: Inference for Stable-Audio-Open with more controls
|
| 12 |
---
|
| 13 |
|
| 14 |
+
## MuseControlLite (Space)
|
| 15 |
+
|
| 16 |
+
Gradio UI for MuseControlLite adapters on top of `stabilityai/stable-audio-open-1.0`.
|
| 17 |
+
|
| 18 |
+
### Requirements
|
| 19 |
+
- **GPU Space** is required for generation (fp16 by default).
|
| 20 |
+
- A Hugging Face token with access to `stabilityai/stable-audio-open-1.0` (set as a Space secret, e.g., `HF_TOKEN`).
|
| 21 |
+
|
| 22 |
+
### What happens on startup
|
| 23 |
+
1) Installs Python deps from `requirements.txt` (includes `gradio`, `gdown`, `diffusers` fork).
|
| 24 |
+
2) Downloads MuseControlLite checkpoints with
|
| 25 |
+
`gdown 1Q9B333jcq1czA11JKTbM-DHANJ8YqGbP --folder`
|
| 26 |
+
into `checkpoints/` if they are missing.
|
| 27 |
+
|
| 28 |
+
### Using the Space
|
| 29 |
+
1) Provide a text prompt.
|
| 30 |
+
2) Upload a 47.5s (or longer) audio file when using MuseControlLite conditions.
|
| 31 |
+
3) Select condition types (`melody_stereo`, `melody_mono`, `dynamics`, `rhythm`, `audio`) and adjust guidance/scales if needed.
|
| 32 |
+
4) Click **Generate**. Output is a single 47.5s WAV plus a short status summary.
|
| 33 |
+
|
| 34 |
+
### Tips
|
| 35 |
+
- `melody_stereo` cannot be combined with `dynamics`, `rhythm`, or `melody_mono`.
|
| 36 |
+
- For audio in/out-painting, use the audio condition with the masking sliders.
|
| 37 |
+
- Default examples are preloaded in the UI for quick tests.
|
| 38 |
+
|
| 39 |
+
### Local run (optional)
|
| 40 |
+
```bash
|
| 41 |
+
pip install -r requirements.txt
|
| 42 |
+
gdown 1Q9B333jcq1czA11JKTbM-DHANJ8YqGbP --folder
|
| 43 |
+
huggingface-cli login
|
| 44 |
+
python app.py
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### Acknowledgments
|
| 48 |
+
- Original repository: https://github.com/fundwotsai2001/MuseControlLite
|
app.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import os
|
| 3 |
+
import subprocess
|
| 4 |
+
import time
|
| 5 |
+
from typing import Dict, List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import soundfile as sf
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from MuseControlLite_setup import initialize_condition_extractors, process_musical_conditions, setup_MuseControlLite
|
| 12 |
+
from config_inference import get_config
|
| 13 |
+
|
| 14 |
+
# Stable Audio uses fixed-length 47.5s chunks (2097152 / 44100)
|
| 15 |
+
TOTAL_AUDIO_SECONDS = 2097152 / 44100
|
| 16 |
+
DEFAULT_CONFIG = get_config()
|
| 17 |
+
DEFAULT_PROMPT = DEFAULT_CONFIG["text"][0] if DEFAULT_CONFIG.get("text") else ""
|
| 18 |
+
OUTPUT_ROOT = os.path.join(DEFAULT_CONFIG["output_dir"], "gradio_runs")
|
| 19 |
+
CONDITION_CHOICES = ["melody_stereo", "melody_mono", "dynamics", "rhythm", "audio"]
|
| 20 |
+
CHECKPOINT_EXPECTED = [
|
| 21 |
+
"./checkpoints/woSDD-all/model_3.safetensors",
|
| 22 |
+
"./checkpoints/woSDD-all/model_1.safetensors",
|
| 23 |
+
"./checkpoints/woSDD-all/model_2.safetensors",
|
| 24 |
+
"./checkpoints/woSDD-all/model.safetensors",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
os.makedirs(OUTPUT_ROOT, exist_ok=True)
|
| 28 |
+
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(DEFAULT_CONFIG.get("GPU_id", "0")))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def ensure_checkpoints() -> None:
|
| 32 |
+
"""Download checkpoints with gdown if they are missing."""
|
| 33 |
+
if all(os.path.exists(path) for path in CHECKPOINT_EXPECTED):
|
| 34 |
+
return
|
| 35 |
+
os.makedirs("checkpoints", exist_ok=True)
|
| 36 |
+
try:
|
| 37 |
+
subprocess.run(
|
| 38 |
+
["gdown", "1Q9B333jcq1czA11JKTbM-DHANJ8YqGbP", "--folder"],
|
| 39 |
+
check=True,
|
| 40 |
+
)
|
| 41 |
+
except Exception as exc: # pylint: disable=broad-except
|
| 42 |
+
# Do not crash the space on startup; inference will surface an error later if checkpoints are missing.
|
| 43 |
+
print(f"[warn] Checkpoint download failed: {exc}")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
ensure_checkpoints()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ModelCache:
|
| 50 |
+
"""Lazy loader for heavy pipelines and condition extractors."""
|
| 51 |
+
|
| 52 |
+
def __init__(self) -> None:
|
| 53 |
+
self.cache: Dict[Tuple, Dict] = {}
|
| 54 |
+
|
| 55 |
+
def get(self, config: Dict) -> Dict:
|
| 56 |
+
key = (
|
| 57 |
+
tuple(sorted(config["condition_type"])),
|
| 58 |
+
config["weight_dtype"],
|
| 59 |
+
float(config["ap_scale"]),
|
| 60 |
+
config["apadapter"],
|
| 61 |
+
)
|
| 62 |
+
if key in self.cache:
|
| 63 |
+
return self.cache[key]
|
| 64 |
+
|
| 65 |
+
weight_dtype = torch.float16 if config["weight_dtype"] == "fp16" else torch.float32
|
| 66 |
+
if config["apadapter"]:
|
| 67 |
+
condition_extractors, transformer_ckpt = initialize_condition_extractors(config)
|
| 68 |
+
pipe = setup_MuseControlLite(config, weight_dtype, transformer_ckpt).to("cuda")
|
| 69 |
+
payload = {
|
| 70 |
+
"pipe": pipe,
|
| 71 |
+
"condition_extractors": condition_extractors,
|
| 72 |
+
"weight_dtype": weight_dtype,
|
| 73 |
+
"mode": "musecontrol",
|
| 74 |
+
}
|
| 75 |
+
else:
|
| 76 |
+
from diffusers import StableAudioPipeline
|
| 77 |
+
|
| 78 |
+
pipe = StableAudioPipeline.from_pretrained(
|
| 79 |
+
"stabilityai/stable-audio-open-1.0",
|
| 80 |
+
torch_dtype=weight_dtype,
|
| 81 |
+
).to("cuda")
|
| 82 |
+
payload = {"pipe": pipe, "condition_extractors": None, "weight_dtype": weight_dtype, "mode": "vanilla"}
|
| 83 |
+
self.cache[key] = payload
|
| 84 |
+
return payload
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
model_cache = ModelCache()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _build_base_config() -> Dict:
|
| 91 |
+
return copy.deepcopy(DEFAULT_CONFIG)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _create_run_dir() -> str:
|
| 95 |
+
run_dir = os.path.join(OUTPUT_ROOT, f"run_{int(time.time() * 1000)}")
|
| 96 |
+
os.makedirs(run_dir, exist_ok=True)
|
| 97 |
+
return run_dir
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _seed_to_generator(seed: Optional[float]) -> Optional[torch.Generator]:
|
| 101 |
+
if seed is None or seed == "":
|
| 102 |
+
return None
|
| 103 |
+
try:
|
| 104 |
+
seed_int = int(seed)
|
| 105 |
+
except (TypeError, ValueError):
|
| 106 |
+
return None
|
| 107 |
+
generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
|
| 108 |
+
return generator.manual_seed(seed_int)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _validate_condition_choices(condition_type: Optional[List[str]]) -> List[str]:
|
| 112 |
+
condition_type = condition_type or []
|
| 113 |
+
if "melody_stereo" in condition_type and any(
|
| 114 |
+
choice in condition_type for choice in ("dynamics", "rhythm", "melody_mono")
|
| 115 |
+
):
|
| 116 |
+
raise gr.Error("`melody_stereo` cannot be combined with dynamics, rhythm, or melody_mono.")
|
| 117 |
+
return condition_type
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def run_inference(
|
| 121 |
+
prompt_text: str,
|
| 122 |
+
condition_audio: Optional[str],
|
| 123 |
+
condition_type: Optional[List[str]],
|
| 124 |
+
use_musecontrol: bool,
|
| 125 |
+
no_text: bool,
|
| 126 |
+
negative_text_prompt: str,
|
| 127 |
+
guidance_scale_text: float,
|
| 128 |
+
guidance_scale_con: float,
|
| 129 |
+
guidance_scale_audio: float,
|
| 130 |
+
denoise_step: int,
|
| 131 |
+
weight_dtype: str,
|
| 132 |
+
ap_scale: float,
|
| 133 |
+
sigma_min: float,
|
| 134 |
+
sigma_max: float,
|
| 135 |
+
audio_mask_start: float,
|
| 136 |
+
audio_mask_end: float,
|
| 137 |
+
musical_mask_start: float,
|
| 138 |
+
musical_mask_end: float,
|
| 139 |
+
seed: Optional[float],
|
| 140 |
+
):
|
| 141 |
+
if not torch.cuda.is_available():
|
| 142 |
+
raise gr.Error("This Space has no GPU attached. Please run locally with a GPU or duplicate to a GPU Space.")
|
| 143 |
+
|
| 144 |
+
condition_type = _validate_condition_choices(condition_type)
|
| 145 |
+
config = _build_base_config()
|
| 146 |
+
config.update(
|
| 147 |
+
{
|
| 148 |
+
"text": [prompt_text or ""],
|
| 149 |
+
"audio_files": [condition_audio or ""],
|
| 150 |
+
"apadapter": use_musecontrol,
|
| 151 |
+
"no_text": bool(no_text),
|
| 152 |
+
"negative_text_prompt": negative_text_prompt or "",
|
| 153 |
+
"guidance_scale_text": float(guidance_scale_text),
|
| 154 |
+
"guidance_scale_con": float(guidance_scale_con),
|
| 155 |
+
"guidance_scale_audio": float(guidance_scale_audio),
|
| 156 |
+
"denoise_step": int(denoise_step),
|
| 157 |
+
"weight_dtype": weight_dtype,
|
| 158 |
+
"ap_scale": float(ap_scale),
|
| 159 |
+
"sigma_min": float(sigma_min),
|
| 160 |
+
"sigma_max": float(sigma_max),
|
| 161 |
+
"audio_mask_start_seconds": float(audio_mask_start or 0),
|
| 162 |
+
"audio_mask_end_seconds": float(audio_mask_end or 0),
|
| 163 |
+
"musical_attribute_mask_start_seconds": float(musical_mask_start or 0),
|
| 164 |
+
"musical_attribute_mask_end_seconds": float(musical_mask_end or 0),
|
| 165 |
+
"show_result_and_plt": False,
|
| 166 |
+
}
|
| 167 |
+
)
|
| 168 |
+
config["condition_type"] = condition_type
|
| 169 |
+
if config["apadapter"]:
|
| 170 |
+
if not condition_type:
|
| 171 |
+
raise gr.Error("Select at least one condition type when using MuseControlLite.")
|
| 172 |
+
if not condition_audio:
|
| 173 |
+
raise gr.Error("Upload an audio file for conditioning.")
|
| 174 |
+
if not os.path.exists(condition_audio):
|
| 175 |
+
raise gr.Error("Condition audio file not found.")
|
| 176 |
+
|
| 177 |
+
run_dir = _create_run_dir()
|
| 178 |
+
config["output_dir"] = run_dir
|
| 179 |
+
generator = _seed_to_generator(seed)
|
| 180 |
+
|
| 181 |
+
try:
|
| 182 |
+
models = model_cache.get(config)
|
| 183 |
+
pipe = models["pipe"]
|
| 184 |
+
pipe.scheduler.config.sigma_min = config["sigma_min"]
|
| 185 |
+
pipe.scheduler.config.sigma_max = config["sigma_max"]
|
| 186 |
+
prompt_for_model = "" if config["no_text"] else (prompt_text or "")
|
| 187 |
+
|
| 188 |
+
with torch.no_grad():
|
| 189 |
+
if config["apadapter"]:
|
| 190 |
+
final_condition, final_condition_audio = process_musical_conditions(
|
| 191 |
+
config, condition_audio, models["condition_extractors"], run_dir, 0, models["weight_dtype"], pipe
|
| 192 |
+
)
|
| 193 |
+
waveform = pipe(
|
| 194 |
+
extracted_condition=final_condition,
|
| 195 |
+
extracted_condition_audio=final_condition_audio,
|
| 196 |
+
prompt=prompt_for_model,
|
| 197 |
+
negative_prompt=config["negative_text_prompt"],
|
| 198 |
+
num_inference_steps=config["denoise_step"],
|
| 199 |
+
guidance_scale_text=config["guidance_scale_text"],
|
| 200 |
+
guidance_scale_con=config["guidance_scale_con"],
|
| 201 |
+
guidance_scale_audio=config["guidance_scale_audio"],
|
| 202 |
+
num_waveforms_per_prompt=1,
|
| 203 |
+
audio_end_in_s=TOTAL_AUDIO_SECONDS,
|
| 204 |
+
generator=generator,
|
| 205 |
+
).audios
|
| 206 |
+
output = waveform[0].T.float().cpu().numpy()
|
| 207 |
+
sr = pipe.vae.sampling_rate
|
| 208 |
+
else:
|
| 209 |
+
audio = pipe(
|
| 210 |
+
prompt=prompt_for_model,
|
| 211 |
+
negative_prompt=config["negative_text_prompt"],
|
| 212 |
+
num_inference_steps=config["denoise_step"],
|
| 213 |
+
guidance_scale=config["guidance_scale_text"],
|
| 214 |
+
num_waveforms_per_prompt=1,
|
| 215 |
+
audio_end_in_s=TOTAL_AUDIO_SECONDS,
|
| 216 |
+
generator=generator,
|
| 217 |
+
).audios
|
| 218 |
+
output = audio[0].T.float().cpu().numpy()
|
| 219 |
+
sr = pipe.vae.sampling_rate
|
| 220 |
+
|
| 221 |
+
generated_path = os.path.join(run_dir, "generated.wav")
|
| 222 |
+
sf.write(generated_path, output, sr)
|
| 223 |
+
|
| 224 |
+
status_lines = [
|
| 225 |
+
f"Run directory: `{run_dir}`",
|
| 226 |
+
f"Mode: {'MuseControlLite' if config['apadapter'] else 'Stable Audio base'}",
|
| 227 |
+
f"Condition type: {', '.join(condition_type) if condition_type else 'text only'}",
|
| 228 |
+
f"Dtype: {config['weight_dtype']}, steps: {config['denoise_step']}, sigma [{config['sigma_min']}, {config['sigma_max']}]",
|
| 229 |
+
]
|
| 230 |
+
if config["apadapter"]:
|
| 231 |
+
status_lines.append(
|
| 232 |
+
f"Guidance (text/cond/audio): {config['guidance_scale_text']}/{config['guidance_scale_con']}/{config['guidance_scale_audio']}"
|
| 233 |
+
)
|
| 234 |
+
if generator is not None:
|
| 235 |
+
status_lines.append(f"Seed: {int(seed)}")
|
| 236 |
+
|
| 237 |
+
status_md = "\n".join(f"- {line}" for line in status_lines)
|
| 238 |
+
return generated_path, status_md
|
| 239 |
+
except gr.Error:
|
| 240 |
+
raise
|
| 241 |
+
except Exception as err: # pylint: disable=broad-except
|
| 242 |
+
raise gr.Error(f"Generation failed: {err}") from err
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
EXAMPLES = [
|
| 246 |
+
[
|
| 247 |
+
"Electronic music that has a constant melody throughout with accompanying instruments used to supplement the melody which can be heard in possibly a casual setting",
|
| 248 |
+
"melody_condition_audio/49_piano.mp3",
|
| 249 |
+
["melody_stereo"],
|
| 250 |
+
True,
|
| 251 |
+
False,
|
| 252 |
+
"",
|
| 253 |
+
7.0,
|
| 254 |
+
1.5,
|
| 255 |
+
1.0,
|
| 256 |
+
50,
|
| 257 |
+
"fp16",
|
| 258 |
+
1.0,
|
| 259 |
+
0.3,
|
| 260 |
+
500,
|
| 261 |
+
0,
|
| 262 |
+
0,
|
| 263 |
+
0,
|
| 264 |
+
0,
|
| 265 |
+
42,
|
| 266 |
+
],
|
| 267 |
+
[
|
| 268 |
+
"fast and fun beat-based indie pop to set a protagonist-gets-good-at-x movie montage to.",
|
| 269 |
+
"melody_condition_audio/610_bass.mp3",
|
| 270 |
+
["melody_mono", "dynamics", "rhythm"],
|
| 271 |
+
True,
|
| 272 |
+
False,
|
| 273 |
+
"",
|
| 274 |
+
7.0,
|
| 275 |
+
1.5,
|
| 276 |
+
1.0,
|
| 277 |
+
50,
|
| 278 |
+
"fp16",
|
| 279 |
+
1.0,
|
| 280 |
+
0.3,
|
| 281 |
+
500,
|
| 282 |
+
0,
|
| 283 |
+
0,
|
| 284 |
+
0,
|
| 285 |
+
0,
|
| 286 |
+
7,
|
| 287 |
+
],
|
| 288 |
+
]
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def build_interface() -> gr.Blocks:
|
| 292 |
+
with gr.Blocks(title="MuseControlLite") as demo:
|
| 293 |
+
gr.Markdown(
|
| 294 |
+
"""
|
| 295 |
+
## MuseControlLite demo
|
| 296 |
+
UI for MuseControlLite (47.5s generations). This Space downloads checkpoints on startup with gdown and expects a GPU runtime; duplicate to a GPU Space or run locally for actual generation.
|
| 297 |
+
"""
|
| 298 |
+
)
|
| 299 |
+
with gr.Row():
|
| 300 |
+
prompt = gr.Textbox(label="Text prompt", lines=3, value=DEFAULT_PROMPT)
|
| 301 |
+
use_musecontrol = gr.Checkbox(label="Use MuseControlLite adapters", value=True)
|
| 302 |
+
no_text = gr.Checkbox(label="Ignore text prompt (audio-only guidance)", value=False)
|
| 303 |
+
|
| 304 |
+
condition_audio = gr.Audio(
|
| 305 |
+
label="Condition audio (required for MuseControlLite)", type="filepath", sources=["upload", "microphone"]
|
| 306 |
+
)
|
| 307 |
+
condition_type = gr.CheckboxGroup(
|
| 308 |
+
CONDITION_CHOICES, label="Condition types", value=DEFAULT_CONFIG.get("condition_type", [])
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
with gr.Accordion("Advanced controls", open=False):
|
| 312 |
+
negative_prompt = gr.Textbox(label="Negative prompt", lines=2, value=DEFAULT_CONFIG.get("negative_text_prompt", ""))
|
| 313 |
+
with gr.Row():
|
| 314 |
+
guidance_scale_text = gr.Slider(
|
| 315 |
+
minimum=0.0,
|
| 316 |
+
maximum=12.0,
|
| 317 |
+
value=DEFAULT_CONFIG["guidance_scale_text"],
|
| 318 |
+
step=0.1,
|
| 319 |
+
label="Guidance scale (text)",
|
| 320 |
+
)
|
| 321 |
+
guidance_scale_con = gr.Slider(
|
| 322 |
+
minimum=0.0,
|
| 323 |
+
maximum=5.0,
|
| 324 |
+
value=DEFAULT_CONFIG["guidance_scale_con"],
|
| 325 |
+
step=0.1,
|
| 326 |
+
label="Guidance scale (conditions)",
|
| 327 |
+
)
|
| 328 |
+
guidance_scale_audio = gr.Slider(
|
| 329 |
+
minimum=0.0,
|
| 330 |
+
maximum=5.0,
|
| 331 |
+
value=DEFAULT_CONFIG["guidance_scale_audio"],
|
| 332 |
+
step=0.1,
|
| 333 |
+
label="Guidance scale (audio)",
|
| 334 |
+
)
|
| 335 |
+
with gr.Row():
|
| 336 |
+
denoise_step = gr.Slider(
|
| 337 |
+
minimum=10, maximum=100, value=DEFAULT_CONFIG["denoise_step"], step=1, label="Denoising steps"
|
| 338 |
+
)
|
| 339 |
+
weight_dtype = gr.Radio(["fp16", "fp32"], value=DEFAULT_CONFIG["weight_dtype"], label="Weight dtype")
|
| 340 |
+
ap_scale = gr.Slider(
|
| 341 |
+
minimum=0.5, maximum=2.0, value=DEFAULT_CONFIG["ap_scale"], step=0.05, label="AP scale"
|
| 342 |
+
)
|
| 343 |
+
with gr.Row():
|
| 344 |
+
sigma_min = gr.Slider(
|
| 345 |
+
minimum=0.1, maximum=5.0, value=DEFAULT_CONFIG["sigma_min"], step=0.05, label="Scheduler sigma min"
|
| 346 |
+
)
|
| 347 |
+
sigma_max = gr.Slider(
|
| 348 |
+
minimum=50, maximum=700, value=DEFAULT_CONFIG["sigma_max"], step=1, label="Scheduler sigma max"
|
| 349 |
+
)
|
| 350 |
+
seed = gr.Number(label="Seed (optional)", precision=0)
|
| 351 |
+
with gr.Row():
|
| 352 |
+
audio_mask_start = gr.Number(
|
| 353 |
+
label="Audio mask start (s)", value=DEFAULT_CONFIG["audio_mask_start_seconds"]
|
| 354 |
+
)
|
| 355 |
+
audio_mask_end = gr.Number(label="Audio mask end (s)", value=DEFAULT_CONFIG["audio_mask_end_seconds"])
|
| 356 |
+
with gr.Row():
|
| 357 |
+
musical_mask_start = gr.Number(
|
| 358 |
+
label="Musical attribute mask start (s)", value=DEFAULT_CONFIG["musical_attribute_mask_start_seconds"]
|
| 359 |
+
)
|
| 360 |
+
musical_mask_end = gr.Number(
|
| 361 |
+
label="Musical attribute mask end (s)", value=DEFAULT_CONFIG["musical_attribute_mask_end_seconds"]
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
generate_btn = gr.Button("Generate", variant="primary")
|
| 365 |
+
generated_audio = gr.Audio(label="Generated audio", type="filepath")
|
| 366 |
+
status = gr.Markdown(label="Run details")
|
| 367 |
+
|
| 368 |
+
generate_btn.click(
|
| 369 |
+
fn=run_inference,
|
| 370 |
+
inputs=[
|
| 371 |
+
prompt,
|
| 372 |
+
condition_audio,
|
| 373 |
+
condition_type,
|
| 374 |
+
use_musecontrol,
|
| 375 |
+
no_text,
|
| 376 |
+
negative_prompt,
|
| 377 |
+
guidance_scale_text,
|
| 378 |
+
guidance_scale_con,
|
| 379 |
+
guidance_scale_audio,
|
| 380 |
+
denoise_step,
|
| 381 |
+
weight_dtype,
|
| 382 |
+
ap_scale,
|
| 383 |
+
sigma_min,
|
| 384 |
+
sigma_max,
|
| 385 |
+
audio_mask_start,
|
| 386 |
+
audio_mask_end,
|
| 387 |
+
musical_mask_start,
|
| 388 |
+
musical_mask_end,
|
| 389 |
+
seed,
|
| 390 |
+
],
|
| 391 |
+
outputs=[generated_audio, status],
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
gr.Examples(
|
| 395 |
+
examples=EXAMPLES,
|
| 396 |
+
inputs=[
|
| 397 |
+
prompt,
|
| 398 |
+
condition_audio,
|
| 399 |
+
condition_type,
|
| 400 |
+
use_musecontrol,
|
| 401 |
+
no_text,
|
| 402 |
+
negative_prompt,
|
| 403 |
+
guidance_scale_text,
|
| 404 |
+
guidance_scale_con,
|
| 405 |
+
guidance_scale_audio,
|
| 406 |
+
denoise_step,
|
| 407 |
+
weight_dtype,
|
| 408 |
+
ap_scale,
|
| 409 |
+
sigma_min,
|
| 410 |
+
sigma_max,
|
| 411 |
+
audio_mask_start,
|
| 412 |
+
audio_mask_end,
|
| 413 |
+
musical_mask_start,
|
| 414 |
+
musical_mask_end,
|
| 415 |
+
seed,
|
| 416 |
+
],
|
| 417 |
+
label="Quick start examples (click to populate the form)",
|
| 418 |
+
)
|
| 419 |
+
return demo
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
if __name__ == "__main__":
|
| 423 |
+
demo = build_interface()
|
| 424 |
+
demo.launch()
|
config_inference.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def get_config():
|
| 2 |
+
return {
|
| 3 |
+
"condition_type": ["melody_stereo"], # you can choose any combinations in the two sets: ["dynamics", "rhythm", "melody_mono", "audio"], ["melody_stereo", "audio"]
|
| 4 |
+
# When using audio, is recommend to use empty string "" as prompt
|
| 5 |
+
"output_dir": "./generated_audio/output",
|
| 6 |
+
|
| 7 |
+
"GPU_id": "0",
|
| 8 |
+
|
| 9 |
+
"apadapter": True, # True for MuseControlLite, False for original Stable-audio
|
| 10 |
+
|
| 11 |
+
"ap_scale": 1.0, # recommend 1.0 for MuseControlLite, other values are not tested
|
| 12 |
+
|
| 13 |
+
"guidance_scale_text": 7.0,
|
| 14 |
+
|
| 15 |
+
"guidance_scale_con": 1.5, # The separated guidance for Musical attribute condition
|
| 16 |
+
|
| 17 |
+
"guidance_scale_audio": 1.0,
|
| 18 |
+
|
| 19 |
+
"denoise_step": 50,
|
| 20 |
+
|
| 21 |
+
"sigma_min": 0.3, # sigma_min and sigma_max are for the scheduler.
|
| 22 |
+
|
| 23 |
+
"sigma_max": 500, # Note that if sigma_max is too large or too small, the "audio condition generation" will be bad.
|
| 24 |
+
|
| 25 |
+
"weight_dtype": "fp16", # fp16 and fp32 sounds quiet the same.
|
| 26 |
+
|
| 27 |
+
"negative_text_prompt": "",
|
| 28 |
+
|
| 29 |
+
###############
|
| 30 |
+
|
| 31 |
+
"audio_mask_start_seconds": 14, # Apply mask to musical attributes choose only one mask to use, it automatically generates a complemetary mask to the other condition
|
| 32 |
+
|
| 33 |
+
"audio_mask_end_seconds": 47,
|
| 34 |
+
|
| 35 |
+
"musical_attribute_mask_start_seconds": 0, # 'Apply mask to audio condition, choose only one mask to use, it automatically generates a complemetary mask to the other condition'
|
| 36 |
+
|
| 37 |
+
"musical_attribute_mask_end_seconds": 0,
|
| 38 |
+
|
| 39 |
+
###############
|
| 40 |
+
|
| 41 |
+
"no_text": False, # Optional, set to true if no text prompt is needed (possible for audio inpainting or outpainting)
|
| 42 |
+
|
| 43 |
+
"show_result_and_plt": True,
|
| 44 |
+
|
| 45 |
+
"audio_files": [
|
| 46 |
+
"melody_condition_audio/49_piano.mp3",
|
| 47 |
+
"melody_condition_audio/49_piano.mp3",
|
| 48 |
+
"melody_condition_audio/49_piano.mp3",
|
| 49 |
+
"melody_condition_audio/322_piano.mp3",
|
| 50 |
+
"melody_condition_audio/322_piano.mp3",
|
| 51 |
+
"melody_condition_audio/322_piano.mp3",
|
| 52 |
+
"melody_condition_audio/610_bass.mp3",
|
| 53 |
+
"melody_condition_audio/610_bass.mp3",
|
| 54 |
+
"melody_condition_audio/785_piano.mp3",
|
| 55 |
+
"melody_condition_audio/785_piano.mp3",
|
| 56 |
+
"melody_condition_audio/933_string.mp3",
|
| 57 |
+
"melody_condition_audio/933_string.mp3",
|
| 58 |
+
"melody_condition_audio/6_uke_12.wav",
|
| 59 |
+
"melody_condition_audio/6_uke_12.wav",
|
| 60 |
+
"melody_condition_audio/57_jazz.mp3",
|
| 61 |
+
"melody_condition_audio/703_mideast.mp3",
|
| 62 |
+
|
| 63 |
+
],
|
| 64 |
+
# "audio_files": [
|
| 65 |
+
# "SDD_nosinging/SDD_audio/34/1004034.mp3",
|
| 66 |
+
# "original_15s/original_9.wav",
|
| 67 |
+
# "original_15s/original_10.wav",
|
| 68 |
+
# "original_15s/original_11.wav",
|
| 69 |
+
# "original_15s/original_15.wav",
|
| 70 |
+
# "original_15s/original_16.wav",
|
| 71 |
+
# "original_15s/original_21.wav",
|
| 72 |
+
# "original_15s/original_25.wav",
|
| 73 |
+
# ],
|
| 74 |
+
|
| 75 |
+
"text": [
|
| 76 |
+
"Electronic music that has a constant melody throughout with accompanying instruments used to supplement the melody which can be heard in possibly a casual setting",
|
| 77 |
+
"A heartfelt, warm acoustic guitar performance, evoking a sense of tenderness and deep emotion, with a melody that truly resonates and touches the heart.",
|
| 78 |
+
"A vibrant MIDI electronic composition with a hopeful and optimistic vibe.",
|
| 79 |
+
"This track composed of electronic instruments gives a sense of opening and clearness.",
|
| 80 |
+
"This track composed of electronic instruments gives a sense of opening and clearness.",
|
| 81 |
+
"Hopeful instrumental with guitar being the lead and tabla used for percussion in the middle giving a feeling of going somewhere with positive outlook.",
|
| 82 |
+
"A string ensemble opens the track with legato, melancholic melodies. The violins and violas play beautifully, while the cellos and bass provide harmonic support for the moving passages. The overall feel is deeply melancholic, with an emotionally stirring performance that remains harmonious and a sense of clearness.",
|
| 83 |
+
"An exceptionally harmonious string performance with a lively tempo in the first half, transitioning to a gentle and beautiful melody in the second half. It creates a warm and comforting atmosphere, featuring cellos and bass providing a solid foundation, while violins and violas showcase the main theme, all without any noise, resulting in a cohesive and serene sound.",
|
| 84 |
+
"Pop solo piano instrumental song. Simple harmony and emotional theme. Makes you feel nostalgic and wanting a cup of warm tea sitting on the couch while holding the person you love.",
|
| 85 |
+
"A whimsical string arrangement with rich layers, featuring violins as the main melody, accompanied by violas and cellos. The light, playful melody blends harmoniously, creating a sense of clarity.",
|
| 86 |
+
"An instrumental piece primarily featuring acoustic guitar, with a lively and nimble feel. The melody is bright, delivering an overall sense of joy.",
|
| 87 |
+
"A joyful saxophone performance that is smooth and cohesive, accompanied by cello. The first half features a relaxed tempo, while the second half picks up with an upbeat rhythm, creating a lively and energetic atmosphere. The overall sound is harmonious and clear, evoking feelings of happiness and vitality.",
|
| 88 |
+
"A cheerful piano performance with a smooth and flowing rhythm, evoking feelings of joy and vitality.",
|
| 89 |
+
"An instrumental piece primarily featuring piano, with a lively rhythm and cheerful melodies that evoke a sense of joyful childhood playfulness. The melodies are clear and bright.",
|
| 90 |
+
"fast and fun beat-based indie pop to set a protagonist-gets-good-at-x movie montage to.",
|
| 91 |
+
"A lively 70s style British pop song featuring drums, electric guitars, and synth violin. The instruments blend harmoniously, creating a dynamic, clean sound without any noise or clutter.",
|
| 92 |
+
"A soothing acoustic guitar song that evokes nostalgia, featuring intricate fingerpicking. The melody is both sacred and mysterious, with a rich texture."
|
| 93 |
+
],
|
| 94 |
+
|
| 95 |
+
########## adapters avilable ############
|
| 96 |
+
# We trained 4 set of adapters:
|
| 97 |
+
# 1. with conditions ["melody_mono", "dynamics", "rhythm"]
|
| 98 |
+
# 2. with conditions ["melody_mono"]
|
| 99 |
+
# 3. with conditions ["melody_stereo"]
|
| 100 |
+
# 3. with conditions ["audio"]
|
| 101 |
+
# MuseControlLite_inference_all.py will automaticaly choose the most suitable model according to the condition type:
|
| 102 |
+
###############
|
| 103 |
+
# Works for condition ["dynamics", "rhythm", "melody_mono"]
|
| 104 |
+
"transformer_ckpt_musical": "./checkpoints/woSDD-all/model_3.safetensors",
|
| 105 |
+
|
| 106 |
+
"extractor_ckpt_musical": {
|
| 107 |
+
"dynamics": "./checkpoints/woSDD-all/model_1.safetensors",
|
| 108 |
+
"melody": "./checkpoints/woSDD-all/model.safetensors",
|
| 109 |
+
"rhythm": "./checkpoints/woSDD-all/model_2.safetensors",
|
| 110 |
+
},
|
| 111 |
+
###############
|
| 112 |
+
|
| 113 |
+
# Works for ['audio], it works without a feature extractor, and could cooperate with other adapters
|
| 114 |
+
#################
|
| 115 |
+
"audio_transformer_ckpt": "./checkpoints/70000_Audio/model.safetensors",
|
| 116 |
+
|
| 117 |
+
# Specialized for ['melody_stereo']
|
| 118 |
+
###############
|
| 119 |
+
"transformer_ckpt_melody_stero": "./checkpoints/70000_Melody_stereo/model_1.safetensors",
|
| 120 |
+
|
| 121 |
+
"extractor_ckpt_melody_stero": {
|
| 122 |
+
"melody": "./checkpoints/70000_Melody_stereo/model.safetensors",
|
| 123 |
+
},
|
| 124 |
+
###############
|
| 125 |
+
|
| 126 |
+
# Specialized for ['melody_mono']
|
| 127 |
+
###############
|
| 128 |
+
"transformer_ckpt_melody_mono": "./checkpoints/40000_Melody_mono/model_1.safetensors",
|
| 129 |
+
|
| 130 |
+
"extractor_ckpt_melody_mono": {
|
| 131 |
+
"melody": "./checkpoints/40000_Melody_mono/model.safetensors",
|
| 132 |
+
},
|
| 133 |
+
###############
|
| 134 |
+
}
|
melody_condition_audio/322_piano.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:698e40e4067efa3b181ea367ec8b0bc76b651cc0ca9bee329a3833565f35a800
|
| 3 |
+
size 915798
|
melody_condition_audio/49_piano.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b00df95a62c91e33c71a4ee312fb84883b3ff58cadb66de8582055ce89d72636
|
| 3 |
+
size 1106827
|
melody_condition_audio/57_jazz.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e9a4cf5f07b40270500ea05e3c756f1d02817c1c6cdd07724e7c102b33d71d2f
|
| 3 |
+
size 1101758
|
melody_condition_audio/610_bass.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c5ead8df05aa7cd33f193c315691cfb6bc8f23bc651f6c4947e3685ab11503bb
|
| 3 |
+
size 1133359
|
melody_condition_audio/703_mideast.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8db8aaa68b2d749fcf6d4e0cfc1ccd7a41fe4e0e942f0c0dbab2033d5eca07b1
|
| 3 |
+
size 1143212
|
melody_condition_audio/785_piano.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e8cef236b1723ced1fb55c6a5f28205a36babbd8add95553a508a88519d74f7
|
| 3 |
+
size 1110813
|
melody_condition_audio/933_string.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0f686104a4849b96dea470db2067df342d9e1265b30e254da5af85d83edcac45
|
| 3 |
+
size 1097973
|
pipeline/stable_audio_multi_cfg_pipe.py
ADDED
|
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from typing import Callable, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import (
|
| 21 |
+
T5EncoderModel,
|
| 22 |
+
T5Tokenizer,
|
| 23 |
+
T5TokenizerFast,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from diffusers.models import AutoencoderOobleck, StableAudioDiTModel
|
| 27 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 28 |
+
from diffusers.schedulers import EDMDPMSolverMultistepScheduler
|
| 29 |
+
from diffusers.utils import (
|
| 30 |
+
logging,
|
| 31 |
+
replace_example_docstring,
|
| 32 |
+
)
|
| 33 |
+
import numpy as np
|
| 34 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 35 |
+
from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
| 36 |
+
from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel
|
| 37 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 38 |
+
|
| 39 |
+
def check_and_print_non_float32_parameters(model):
|
| 40 |
+
non_float32_params = []
|
| 41 |
+
for name, param in model.named_parameters():
|
| 42 |
+
if param.dtype != torch.float32:
|
| 43 |
+
non_float32_params.append((name, param.dtype))
|
| 44 |
+
|
| 45 |
+
if non_float32_params:
|
| 46 |
+
print("Not all parameters are in float32!")
|
| 47 |
+
print("The following parameters are not in float32:")
|
| 48 |
+
for name, dtype in non_float32_params:
|
| 49 |
+
print(f"Parameter: {name}, Data Type: {dtype}")
|
| 50 |
+
else:
|
| 51 |
+
print("All parameters are in float32.")
|
| 52 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 53 |
+
|
| 54 |
+
EXAMPLE_DOC_STRING = """
|
| 55 |
+
Examples:
|
| 56 |
+
```py
|
| 57 |
+
>>> import scipy
|
| 58 |
+
>>> import torch
|
| 59 |
+
>>> import soundfile as sf
|
| 60 |
+
>>> from diffusers import StableAudioPipeline
|
| 61 |
+
|
| 62 |
+
>>> repo_id = "stabilityai/stable-audio-open-1.0"
|
| 63 |
+
>>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
|
| 64 |
+
>>> pipe = pipe.to("cuda")
|
| 65 |
+
|
| 66 |
+
>>> # define the prompts
|
| 67 |
+
>>> prompt = "The sound of a hammer hitting a wooden surface."
|
| 68 |
+
>>> negative_prompt = "Low quality."
|
| 69 |
+
|
| 70 |
+
>>> # set the seed for generator
|
| 71 |
+
>>> generator = torch.Generator("cuda").manual_seed(0)
|
| 72 |
+
|
| 73 |
+
>>> # run the generation
|
| 74 |
+
>>> audio = pipe(
|
| 75 |
+
... prompt,
|
| 76 |
+
... negative_prompt=negative_prompt,
|
| 77 |
+
... num_inference_steps=200,
|
| 78 |
+
... audio_end_in_s=10.0,
|
| 79 |
+
... num_waveforms_per_prompt=3,
|
| 80 |
+
... generator=generator,
|
| 81 |
+
... ).audios
|
| 82 |
+
|
| 83 |
+
>>> output = audio[0].T.float().cpu().numpy()
|
| 84 |
+
>>> sf.write("hammer.wav", output, pipe.vae.sampling_rate)
|
| 85 |
+
```
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class StableAudioPipeline(DiffusionPipeline):
|
| 90 |
+
r"""
|
| 91 |
+
Pipeline for text-to-audio generation using StableAudio.
|
| 92 |
+
|
| 93 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 94 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
vae ([`AutoencoderOobleck`]):
|
| 98 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 99 |
+
text_encoder ([`~transformers.T5EncoderModel`]):
|
| 100 |
+
Frozen text-encoder. StableAudio uses the encoder of
|
| 101 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
| 102 |
+
[google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant.
|
| 103 |
+
projection_model ([`StableAudioProjectionModel`]):
|
| 104 |
+
A trained model used to linearly project the hidden-states from the text encoder model and the start and
|
| 105 |
+
end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to
|
| 106 |
+
give the input to the transformer model.
|
| 107 |
+
tokenizer ([`~transformers.T5Tokenizer`]):
|
| 108 |
+
Tokenizer to tokenize text for the frozen text-encoder.
|
| 109 |
+
transformer ([`StableAudioDiTModel`]):
|
| 110 |
+
A `StableAudioDiTModel` to denoise the encoded audio latents.
|
| 111 |
+
scheduler ([`EDMDPMSolverMultistepScheduler`]):
|
| 112 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded audio latents.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae"
|
| 116 |
+
|
| 117 |
+
def __init__(
|
| 118 |
+
self,
|
| 119 |
+
vae: AutoencoderOobleck,
|
| 120 |
+
text_encoder: T5EncoderModel,
|
| 121 |
+
projection_model: StableAudioProjectionModel,
|
| 122 |
+
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
|
| 123 |
+
transformer: StableAudioDiTModel,
|
| 124 |
+
scheduler: EDMDPMSolverMultistepScheduler,
|
| 125 |
+
):
|
| 126 |
+
super().__init__()
|
| 127 |
+
|
| 128 |
+
self.register_modules(
|
| 129 |
+
vae=vae,
|
| 130 |
+
text_encoder=text_encoder,
|
| 131 |
+
projection_model=projection_model,
|
| 132 |
+
tokenizer=tokenizer,
|
| 133 |
+
transformer=transformer,
|
| 134 |
+
scheduler=scheduler,
|
| 135 |
+
)
|
| 136 |
+
self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2
|
| 137 |
+
|
| 138 |
+
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
|
| 139 |
+
def enable_vae_slicing(self):
|
| 140 |
+
r"""
|
| 141 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 142 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 143 |
+
"""
|
| 144 |
+
self.vae.enable_slicing()
|
| 145 |
+
|
| 146 |
+
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
|
| 147 |
+
def disable_vae_slicing(self):
|
| 148 |
+
r"""
|
| 149 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 150 |
+
computing decoding in one step.
|
| 151 |
+
"""
|
| 152 |
+
self.vae.disable_slicing()
|
| 153 |
+
|
| 154 |
+
def encode_prompt(
|
| 155 |
+
self,
|
| 156 |
+
prompt,
|
| 157 |
+
device,
|
| 158 |
+
do_classifier_free_guidance,
|
| 159 |
+
negative_prompt=None,
|
| 160 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 161 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 162 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 163 |
+
negative_attention_mask: Optional[torch.LongTensor] = None,
|
| 164 |
+
):
|
| 165 |
+
if prompt is not None and isinstance(prompt, str):
|
| 166 |
+
batch_size = 1
|
| 167 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 168 |
+
batch_size = len(prompt)
|
| 169 |
+
else:
|
| 170 |
+
batch_size = prompt_embeds.shape[0]
|
| 171 |
+
|
| 172 |
+
if prompt_embeds is None:
|
| 173 |
+
# 1. Tokenize text
|
| 174 |
+
text_inputs = self.tokenizer(
|
| 175 |
+
prompt,
|
| 176 |
+
padding="max_length",
|
| 177 |
+
max_length=self.tokenizer.model_max_length,
|
| 178 |
+
truncation=True,
|
| 179 |
+
return_tensors="pt",
|
| 180 |
+
)
|
| 181 |
+
text_input_ids = text_inputs.input_ids
|
| 182 |
+
attention_mask = text_inputs.attention_mask
|
| 183 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 184 |
+
|
| 185 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
| 186 |
+
text_input_ids, untruncated_ids
|
| 187 |
+
):
|
| 188 |
+
removed_text = self.tokenizer.batch_decode(
|
| 189 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
| 190 |
+
)
|
| 191 |
+
# logger.warning(
|
| 192 |
+
# f"The following part of your input was truncated because {self.text_encoder.config.model_type} can "
|
| 193 |
+
# f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
|
| 194 |
+
# )
|
| 195 |
+
|
| 196 |
+
text_input_ids = text_input_ids.to(device)
|
| 197 |
+
attention_mask = attention_mask.to(device)
|
| 198 |
+
|
| 199 |
+
# 2. Text encoder forward
|
| 200 |
+
self.text_encoder.eval()
|
| 201 |
+
prompt_embeds = self.text_encoder(
|
| 202 |
+
text_input_ids,
|
| 203 |
+
attention_mask=attention_mask,
|
| 204 |
+
)
|
| 205 |
+
prompt_embeds = prompt_embeds[0]
|
| 206 |
+
|
| 207 |
+
if do_classifier_free_guidance and negative_prompt is not None:
|
| 208 |
+
uncond_tokens: List[str]
|
| 209 |
+
if type(prompt) is not type(negative_prompt):
|
| 210 |
+
raise TypeError(
|
| 211 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 212 |
+
f" {type(prompt)}."
|
| 213 |
+
)
|
| 214 |
+
elif isinstance(negative_prompt, str):
|
| 215 |
+
uncond_tokens = [negative_prompt]
|
| 216 |
+
elif batch_size != len(negative_prompt):
|
| 217 |
+
raise ValueError(
|
| 218 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 219 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 220 |
+
" the batch size of `prompt`."
|
| 221 |
+
)
|
| 222 |
+
else:
|
| 223 |
+
uncond_tokens = negative_prompt
|
| 224 |
+
|
| 225 |
+
# 1. Tokenize text
|
| 226 |
+
uncond_input = self.tokenizer(
|
| 227 |
+
uncond_tokens,
|
| 228 |
+
padding="max_length",
|
| 229 |
+
max_length=self.tokenizer.model_max_length,
|
| 230 |
+
truncation=True,
|
| 231 |
+
return_tensors="pt",
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
uncond_input_ids = uncond_input.input_ids.to(device)
|
| 235 |
+
negative_attention_mask = uncond_input.attention_mask.to(device)
|
| 236 |
+
|
| 237 |
+
# 2. Text encoder forward
|
| 238 |
+
self.text_encoder.eval()
|
| 239 |
+
negative_prompt_embeds = self.text_encoder(
|
| 240 |
+
uncond_input_ids,
|
| 241 |
+
attention_mask=negative_attention_mask,
|
| 242 |
+
)
|
| 243 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
| 244 |
+
|
| 245 |
+
if negative_attention_mask is not None:
|
| 246 |
+
# set the masked tokens to the null embed
|
| 247 |
+
negative_prompt_embeds = torch.where(
|
| 248 |
+
negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
# 3. Project prompt_embeds and negative_prompt_embeds
|
| 252 |
+
if do_classifier_free_guidance and negative_prompt_embeds is not None:
|
| 253 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 254 |
+
# Here we concatenate the negative and text embeddings into a single batch
|
| 255 |
+
# to avoid doing two forward passes
|
| 256 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])
|
| 257 |
+
if attention_mask is not None and negative_attention_mask is None:
|
| 258 |
+
negative_attention_mask = torch.ones_like(attention_mask)
|
| 259 |
+
elif attention_mask is None and negative_attention_mask is not None:
|
| 260 |
+
attention_mask = torch.ones_like(negative_attention_mask)
|
| 261 |
+
if attention_mask is not None:
|
| 262 |
+
attention_mask = torch.cat([negative_attention_mask, attention_mask, attention_mask])
|
| 263 |
+
|
| 264 |
+
prompt_embeds = self.projection_model(
|
| 265 |
+
text_hidden_states=prompt_embeds,
|
| 266 |
+
).text_hidden_states
|
| 267 |
+
if attention_mask is not None:
|
| 268 |
+
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
|
| 269 |
+
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
|
| 270 |
+
|
| 271 |
+
return prompt_embeds
|
| 272 |
+
|
| 273 |
+
def encode_duration(
|
| 274 |
+
self,
|
| 275 |
+
audio_start_in_s,
|
| 276 |
+
audio_end_in_s,
|
| 277 |
+
device,
|
| 278 |
+
do_classifier_free_guidance,
|
| 279 |
+
batch_size,
|
| 280 |
+
):
|
| 281 |
+
audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s]
|
| 282 |
+
audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s]
|
| 283 |
+
|
| 284 |
+
if len(audio_start_in_s) == 1:
|
| 285 |
+
audio_start_in_s = audio_start_in_s * batch_size
|
| 286 |
+
if len(audio_end_in_s) == 1:
|
| 287 |
+
audio_end_in_s = audio_end_in_s * batch_size
|
| 288 |
+
|
| 289 |
+
# Cast the inputs to floats
|
| 290 |
+
audio_start_in_s = [float(x) for x in audio_start_in_s]
|
| 291 |
+
audio_start_in_s = torch.tensor(audio_start_in_s).to(device)
|
| 292 |
+
|
| 293 |
+
audio_end_in_s = [float(x) for x in audio_end_in_s]
|
| 294 |
+
audio_end_in_s = torch.tensor(audio_end_in_s).to(device)
|
| 295 |
+
|
| 296 |
+
projection_output = self.projection_model(
|
| 297 |
+
start_seconds=audio_start_in_s,
|
| 298 |
+
end_seconds=audio_end_in_s,
|
| 299 |
+
)
|
| 300 |
+
seconds_start_hidden_states = projection_output.seconds_start_hidden_states
|
| 301 |
+
seconds_end_hidden_states = projection_output.seconds_end_hidden_states
|
| 302 |
+
|
| 303 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 304 |
+
# Here we repeat the audio hidden states to avoid doing two forward passes
|
| 305 |
+
if do_classifier_free_guidance:
|
| 306 |
+
seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states, seconds_start_hidden_states], dim=0)
|
| 307 |
+
seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states, seconds_end_hidden_states], dim=0)
|
| 308 |
+
|
| 309 |
+
return seconds_start_hidden_states, seconds_end_hidden_states
|
| 310 |
+
|
| 311 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 312 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 313 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 314 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 315 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 316 |
+
# and should be between [0, 1]
|
| 317 |
+
|
| 318 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 319 |
+
extra_step_kwargs = {}
|
| 320 |
+
if accepts_eta:
|
| 321 |
+
extra_step_kwargs["eta"] = eta
|
| 322 |
+
|
| 323 |
+
# check if the scheduler accepts generator
|
| 324 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 325 |
+
if accepts_generator:
|
| 326 |
+
extra_step_kwargs["generator"] = generator
|
| 327 |
+
return extra_step_kwargs
|
| 328 |
+
|
| 329 |
+
def check_inputs(
|
| 330 |
+
self,
|
| 331 |
+
prompt,
|
| 332 |
+
audio_start_in_s,
|
| 333 |
+
audio_end_in_s,
|
| 334 |
+
callback_steps,
|
| 335 |
+
negative_prompt=None,
|
| 336 |
+
prompt_embeds=None,
|
| 337 |
+
negative_prompt_embeds=None,
|
| 338 |
+
attention_mask=None,
|
| 339 |
+
negative_attention_mask=None,
|
| 340 |
+
initial_audio_waveforms=None,
|
| 341 |
+
initial_audio_sampling_rate=None,
|
| 342 |
+
):
|
| 343 |
+
if audio_end_in_s < audio_start_in_s:
|
| 344 |
+
raise ValueError(
|
| 345 |
+
f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but "
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
if (
|
| 349 |
+
audio_start_in_s < self.projection_model.config.min_value
|
| 350 |
+
or audio_start_in_s > self.projection_model.config.max_value
|
| 351 |
+
):
|
| 352 |
+
raise ValueError(
|
| 353 |
+
f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
|
| 354 |
+
f"is {audio_start_in_s}."
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
if (
|
| 358 |
+
audio_end_in_s < self.projection_model.config.min_value
|
| 359 |
+
or audio_end_in_s > self.projection_model.config.max_value
|
| 360 |
+
):
|
| 361 |
+
raise ValueError(
|
| 362 |
+
f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
|
| 363 |
+
f"is {audio_end_in_s}."
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
if (callback_steps is None) or (
|
| 367 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
| 368 |
+
):
|
| 369 |
+
raise ValueError(
|
| 370 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 371 |
+
f" {type(callback_steps)}."
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
if prompt is not None and prompt_embeds is not None:
|
| 375 |
+
raise ValueError(
|
| 376 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 377 |
+
" only forward one of the two."
|
| 378 |
+
)
|
| 379 |
+
elif prompt is None and (prompt_embeds is None):
|
| 380 |
+
raise ValueError(
|
| 381 |
+
"Provide either `prompt`, or `prompt_embeds`. Cannot leave"
|
| 382 |
+
"`prompt` undefined without specifying `prompt_embeds`."
|
| 383 |
+
)
|
| 384 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 385 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 386 |
+
|
| 387 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 388 |
+
raise ValueError(
|
| 389 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 390 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 394 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 395 |
+
raise ValueError(
|
| 396 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 397 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 398 |
+
f" {negative_prompt_embeds.shape}."
|
| 399 |
+
)
|
| 400 |
+
if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
|
| 401 |
+
raise ValueError(
|
| 402 |
+
"`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
|
| 403 |
+
f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
if initial_audio_sampling_rate is None and initial_audio_waveforms is not None:
|
| 407 |
+
raise ValueError(
|
| 408 |
+
"`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`."
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate:
|
| 412 |
+
raise ValueError(
|
| 413 |
+
f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`."
|
| 414 |
+
"Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. "
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
def prepare_latents(
|
| 418 |
+
self,
|
| 419 |
+
batch_size,
|
| 420 |
+
num_channels_vae,
|
| 421 |
+
sample_size,
|
| 422 |
+
dtype,
|
| 423 |
+
device,
|
| 424 |
+
generator,
|
| 425 |
+
latents=None,
|
| 426 |
+
initial_audio_waveforms=None,
|
| 427 |
+
num_waveforms_per_prompt=None,
|
| 428 |
+
audio_channels=None,
|
| 429 |
+
):
|
| 430 |
+
shape = (batch_size, num_channels_vae, sample_size)
|
| 431 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 432 |
+
raise ValueError(
|
| 433 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 434 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
if latents is None:
|
| 438 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 439 |
+
else:
|
| 440 |
+
latents = latents.to(device)
|
| 441 |
+
|
| 442 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 443 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 444 |
+
|
| 445 |
+
# encode the initial audio for use by the model
|
| 446 |
+
if initial_audio_waveforms is not None:
|
| 447 |
+
# check dimension
|
| 448 |
+
if initial_audio_waveforms.ndim == 2:
|
| 449 |
+
initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1)
|
| 450 |
+
elif initial_audio_waveforms.ndim != 3:
|
| 451 |
+
raise ValueError(
|
| 452 |
+
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
|
| 456 |
+
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
|
| 457 |
+
|
| 458 |
+
# check num_channels
|
| 459 |
+
if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2:
|
| 460 |
+
initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1)
|
| 461 |
+
elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1:
|
| 462 |
+
initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True)
|
| 463 |
+
|
| 464 |
+
if initial_audio_waveforms.shape[:2] != audio_shape[:2]:
|
| 465 |
+
raise ValueError(
|
| 466 |
+
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`"
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
# crop or pad
|
| 470 |
+
audio_length = initial_audio_waveforms.shape[-1]
|
| 471 |
+
if audio_length < audio_vae_length:
|
| 472 |
+
logger.warning(
|
| 473 |
+
f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded."
|
| 474 |
+
)
|
| 475 |
+
elif audio_length > audio_vae_length:
|
| 476 |
+
logger.warning(
|
| 477 |
+
f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped."
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
audio = initial_audio_waveforms.new_zeros(audio_shape)
|
| 481 |
+
audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length]
|
| 482 |
+
|
| 483 |
+
encoded_audio = self.vae.encode(audio).latent_dist.sample(generator)
|
| 484 |
+
encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))
|
| 485 |
+
latents = encoded_audio + latents
|
| 486 |
+
return latents
|
| 487 |
+
|
| 488 |
+
@torch.no_grad()
|
| 489 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 490 |
+
def __call__(
|
| 491 |
+
self,
|
| 492 |
+
guidance_scale_audio = None,
|
| 493 |
+
extracted_condition_audio = None,
|
| 494 |
+
extracted_condition = None,
|
| 495 |
+
prompt: Union[str, List[str]] = None,
|
| 496 |
+
audio_end_in_s: Optional[float] = None,
|
| 497 |
+
audio_start_in_s: Optional[float] = 0.0,
|
| 498 |
+
num_inference_steps: int = 100,
|
| 499 |
+
guidance_scale_text: float = 7.0,
|
| 500 |
+
guidance_scale_con: float = 2.0,
|
| 501 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 502 |
+
num_waveforms_per_prompt: Optional[int] = 1,
|
| 503 |
+
eta: float = 0.0,
|
| 504 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 505 |
+
latents: Optional[torch.Tensor] = None,
|
| 506 |
+
initial_audio_waveforms: Optional[torch.Tensor] = None,
|
| 507 |
+
initial_audio_sampling_rate: Optional[torch.Tensor] = None,
|
| 508 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 509 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 510 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 511 |
+
negative_attention_mask: Optional[torch.LongTensor] = None,
|
| 512 |
+
return_dict: bool = True,
|
| 513 |
+
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
| 514 |
+
callback_steps: Optional[int] = 1,
|
| 515 |
+
output_type: Optional[str] = "pt",
|
| 516 |
+
):
|
| 517 |
+
r"""
|
| 518 |
+
The call function to the pipeline for generation.
|
| 519 |
+
|
| 520 |
+
Args:
|
| 521 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 522 |
+
The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
|
| 523 |
+
audio_end_in_s (`float`, *optional*, defaults to 47.55):
|
| 524 |
+
Audio end index in seconds.
|
| 525 |
+
audio_start_in_s (`float`, *optional*, defaults to 0):
|
| 526 |
+
Audio start index in seconds.
|
| 527 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
| 528 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
|
| 529 |
+
expense of slower inference.
|
| 530 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 531 |
+
A higher guidance scale value encourages the model to generate audio that is closely linked to the text
|
| 532 |
+
`prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 533 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 534 |
+
The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
|
| 535 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 536 |
+
num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
|
| 537 |
+
The number of waveforms to generate per prompt.
|
| 538 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 539 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
| 540 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 541 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 542 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 543 |
+
generation deterministic.
|
| 544 |
+
latents (`torch.Tensor`, *optional*):
|
| 545 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio
|
| 546 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 547 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 548 |
+
initial_audio_waveforms (`torch.Tensor`, *optional*):
|
| 549 |
+
Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape
|
| 550 |
+
`(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size`
|
| 551 |
+
corresponds to the number of prompts passed to the model.
|
| 552 |
+
initial_audio_sampling_rate (`int`, *optional*):
|
| 553 |
+
Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model.
|
| 554 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 555 |
+
Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs,
|
| 556 |
+
*e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input
|
| 557 |
+
argument.
|
| 558 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 559 |
+
Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text
|
| 560 |
+
inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
|
| 561 |
+
`negative_prompt` input argument.
|
| 562 |
+
attention_mask (`torch.LongTensor`, *optional*):
|
| 563 |
+
Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
|
| 564 |
+
be computed from `prompt` input argument.
|
| 565 |
+
negative_attention_mask (`torch.LongTensor`, *optional*):
|
| 566 |
+
Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`.
|
| 567 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 568 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 569 |
+
plain tuple.
|
| 570 |
+
callback (`Callable`, *optional*):
|
| 571 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
| 572 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
| 573 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 574 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
| 575 |
+
every step.
|
| 576 |
+
output_type (`str`, *optional*, defaults to `"pt"`):
|
| 577 |
+
The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
|
| 578 |
+
`"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
|
| 579 |
+
model (LDM) output.
|
| 580 |
+
|
| 581 |
+
Examples:
|
| 582 |
+
|
| 583 |
+
Returns:
|
| 584 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 585 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
| 586 |
+
otherwise a `tuple` is returned where the first element is a list with the generated audio.
|
| 587 |
+
"""
|
| 588 |
+
# 0. Convert audio input length from seconds to latent length
|
| 589 |
+
downsample_ratio = self.vae.hop_length
|
| 590 |
+
|
| 591 |
+
max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate
|
| 592 |
+
if audio_end_in_s is None:
|
| 593 |
+
audio_end_in_s = max_audio_length_in_s
|
| 594 |
+
|
| 595 |
+
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
|
| 596 |
+
raise ValueError(
|
| 597 |
+
f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
|
| 601 |
+
waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate)
|
| 602 |
+
waveform_length = int(self.transformer.config.sample_size) # * audio_end_in_s / 47.554
|
| 603 |
+
# waveform_length = 646
|
| 604 |
+
# 1. Check inputs. Raise error if not correct
|
| 605 |
+
self.check_inputs(
|
| 606 |
+
prompt,
|
| 607 |
+
audio_start_in_s,
|
| 608 |
+
audio_end_in_s,
|
| 609 |
+
callback_steps,
|
| 610 |
+
negative_prompt,
|
| 611 |
+
prompt_embeds,
|
| 612 |
+
negative_prompt_embeds,
|
| 613 |
+
attention_mask,
|
| 614 |
+
negative_attention_mask,
|
| 615 |
+
initial_audio_waveforms,
|
| 616 |
+
initial_audio_sampling_rate,
|
| 617 |
+
)
|
| 618 |
+
|
| 619 |
+
# 2. Define call parameters
|
| 620 |
+
if prompt is not None and isinstance(prompt, str):
|
| 621 |
+
batch_size = 1
|
| 622 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 623 |
+
batch_size = len(prompt)
|
| 624 |
+
else:
|
| 625 |
+
batch_size = prompt_embeds.shape[0]
|
| 626 |
+
|
| 627 |
+
device = self._execution_device
|
| 628 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 629 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 630 |
+
# corresponds to doing no classifier free guidance.
|
| 631 |
+
do_classifier_free_guidance = True
|
| 632 |
+
|
| 633 |
+
# 3. Encode input prompt
|
| 634 |
+
prompt_embeds = self.encode_prompt(
|
| 635 |
+
prompt,
|
| 636 |
+
device,
|
| 637 |
+
do_classifier_free_guidance,
|
| 638 |
+
negative_prompt,
|
| 639 |
+
prompt_embeds,
|
| 640 |
+
negative_prompt_embeds,
|
| 641 |
+
attention_mask,
|
| 642 |
+
negative_attention_mask,
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
# Encode duration
|
| 646 |
+
seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration(
|
| 647 |
+
audio_start_in_s,
|
| 648 |
+
audio_end_in_s,
|
| 649 |
+
device,
|
| 650 |
+
do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None),
|
| 651 |
+
batch_size,
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
# Create text_audio_duration_embeds and audio_duration_embeds
|
| 655 |
+
text_audio_duration_embeds = torch.cat(
|
| 656 |
+
[prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2)
|
| 660 |
+
|
| 661 |
+
# In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and
|
| 662 |
+
# to concatenate it to the embeddings
|
| 663 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
|
| 664 |
+
negative_text_audio_duration_embeds = torch.zeros_like(
|
| 665 |
+
text_audio_duration_embeds, device=text_audio_duration_embeds.device
|
| 666 |
+
)
|
| 667 |
+
text_audio_duration_embeds = torch.cat(
|
| 668 |
+
[negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0
|
| 669 |
+
)
|
| 670 |
+
audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0)
|
| 671 |
+
# if condition is not None:
|
| 672 |
+
# condition_conditioning = condition_model(condition)
|
| 673 |
+
# condition_no_conditioning = condition_model(torch.full_like(condition, fill_value=0))
|
| 674 |
+
# extracted_condition = torch.cat([condition_no_conditioning, condition_no_conditioning, condition_conditioning], dim=0)
|
| 675 |
+
|
| 676 |
+
bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape
|
| 677 |
+
# duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method
|
| 678 |
+
text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
| 679 |
+
text_audio_duration_embeds = text_audio_duration_embeds.view(
|
| 680 |
+
bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
| 684 |
+
audio_duration_embeds = audio_duration_embeds.view(
|
| 685 |
+
bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1]
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
# 4. Prepare timesteps
|
| 689 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 690 |
+
timesteps = self.scheduler.timesteps
|
| 691 |
+
|
| 692 |
+
# 5. Prepare latent variables
|
| 693 |
+
num_channels_vae = self.transformer.config.in_channels
|
| 694 |
+
latents = self.prepare_latents(
|
| 695 |
+
batch_size * num_waveforms_per_prompt,
|
| 696 |
+
num_channels_vae,
|
| 697 |
+
waveform_length,
|
| 698 |
+
text_audio_duration_embeds.dtype,
|
| 699 |
+
device,
|
| 700 |
+
generator,
|
| 701 |
+
latents,
|
| 702 |
+
initial_audio_waveforms,
|
| 703 |
+
num_waveforms_per_prompt,
|
| 704 |
+
audio_channels=self.vae.config.audio_channels,
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
# 6. Prepare extra step kwargs
|
| 708 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 709 |
+
|
| 710 |
+
# 7. Prepare rotary positional embedding
|
| 711 |
+
rotary_embedding = get_1d_rotary_pos_embed(
|
| 712 |
+
self.rotary_embed_dim,
|
| 713 |
+
latents.shape[2] + audio_duration_embeds.shape[1],
|
| 714 |
+
use_real=True,
|
| 715 |
+
repeat_interleave_real=False,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
# 8. Denoising loop
|
| 719 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 720 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 721 |
+
for i, t in enumerate(timesteps):
|
| 722 |
+
# expand the latents if we are doing classifier free guidance
|
| 723 |
+
latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
|
| 724 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 725 |
+
with autocast():
|
| 726 |
+
noise_pred = self.transformer(
|
| 727 |
+
latent_model_input,
|
| 728 |
+
t.unsqueeze(0),
|
| 729 |
+
encoder_hidden_states=text_audio_duration_embeds,
|
| 730 |
+
encoder_hidden_states_con=extracted_condition,
|
| 731 |
+
global_hidden_states=audio_duration_embeds,
|
| 732 |
+
rotary_embedding=rotary_embedding,
|
| 733 |
+
return_dict=False,
|
| 734 |
+
)[0]
|
| 735 |
+
# transformer_weight_dtype = next(self.transformer.parameters()).dtype
|
| 736 |
+
# print("transformer_weight_dtype",transformer_weight_dtype)
|
| 737 |
+
# noise_pred = noise_pred.half()
|
| 738 |
+
|
| 739 |
+
# perform guidance
|
| 740 |
+
if do_classifier_free_guidance:
|
| 741 |
+
noise_pred_uncond, noise_pred_text, noise_pred_both= noise_pred.chunk(3)
|
| 742 |
+
noise_pred = noise_pred_uncond + guidance_scale_text * (noise_pred_text - noise_pred_uncond) + guidance_scale_con * (noise_pred_both - noise_pred_text)
|
| 743 |
+
|
| 744 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 745 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
| 746 |
+
|
| 747 |
+
# call the callback, if provided
|
| 748 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 749 |
+
progress_bar.update()
|
| 750 |
+
if callback is not None and i % callback_steps == 0:
|
| 751 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 752 |
+
callback(step_idx, t, latents)
|
| 753 |
+
|
| 754 |
+
# 9. Post-processing
|
| 755 |
+
if not output_type == "latent":
|
| 756 |
+
with autocast():
|
| 757 |
+
audio = self.vae.decode(latents).sample
|
| 758 |
+
|
| 759 |
+
else:
|
| 760 |
+
return AudioPipelineOutput(audios=latents)
|
| 761 |
+
|
| 762 |
+
audio = audio[:, :, waveform_start:waveform_end]
|
| 763 |
+
|
| 764 |
+
if output_type == "np":
|
| 765 |
+
audio = audio.cpu().float().numpy()
|
| 766 |
+
|
| 767 |
+
self.maybe_free_model_hooks()
|
| 768 |
+
|
| 769 |
+
if not return_dict:
|
| 770 |
+
return (audio,)
|
| 771 |
+
|
| 772 |
+
return AudioPipelineOutput(audios=audio)
|
pipeline/stable_audio_multi_cfg_pipe_audio.py
ADDED
|
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from typing import Callable, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import (
|
| 21 |
+
T5EncoderModel,
|
| 22 |
+
T5Tokenizer,
|
| 23 |
+
T5TokenizerFast,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from diffusers.models import AutoencoderOobleck, StableAudioDiTModel
|
| 27 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 28 |
+
from diffusers.schedulers import EDMDPMSolverMultistepScheduler
|
| 29 |
+
from diffusers.utils import (
|
| 30 |
+
logging,
|
| 31 |
+
replace_example_docstring,
|
| 32 |
+
)
|
| 33 |
+
import numpy as np
|
| 34 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 35 |
+
from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
| 36 |
+
from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel
|
| 37 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 38 |
+
|
| 39 |
+
def check_and_print_non_float32_parameters(model):
|
| 40 |
+
non_float32_params = []
|
| 41 |
+
for name, param in model.named_parameters():
|
| 42 |
+
if param.dtype != torch.float32:
|
| 43 |
+
non_float32_params.append((name, param.dtype))
|
| 44 |
+
|
| 45 |
+
if non_float32_params:
|
| 46 |
+
print("Not all parameters are in float32!")
|
| 47 |
+
print("The following parameters are not in float32:")
|
| 48 |
+
for name, dtype in non_float32_params:
|
| 49 |
+
print(f"Parameter: {name}, Data Type: {dtype}")
|
| 50 |
+
else:
|
| 51 |
+
print("All parameters are in float32.")
|
| 52 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 53 |
+
|
| 54 |
+
EXAMPLE_DOC_STRING = """
|
| 55 |
+
Examples:
|
| 56 |
+
```py
|
| 57 |
+
>>> import scipy
|
| 58 |
+
>>> import torch
|
| 59 |
+
>>> import soundfile as sf
|
| 60 |
+
>>> from diffusers import StableAudioPipeline
|
| 61 |
+
|
| 62 |
+
>>> repo_id = "stabilityai/stable-audio-open-1.0"
|
| 63 |
+
>>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
|
| 64 |
+
>>> pipe = pipe.to("cuda")
|
| 65 |
+
|
| 66 |
+
>>> # define the prompts
|
| 67 |
+
>>> prompt = "The sound of a hammer hitting a wooden surface."
|
| 68 |
+
>>> negative_prompt = "Low quality."
|
| 69 |
+
|
| 70 |
+
>>> # set the seed for generator
|
| 71 |
+
>>> generator = torch.Generator("cuda").manual_seed(0)
|
| 72 |
+
|
| 73 |
+
>>> # run the generation
|
| 74 |
+
>>> audio = pipe(
|
| 75 |
+
... prompt,
|
| 76 |
+
... negative_prompt=negative_prompt,
|
| 77 |
+
... num_inference_steps=200,
|
| 78 |
+
... audio_end_in_s=10.0,
|
| 79 |
+
... num_waveforms_per_prompt=3,
|
| 80 |
+
... generator=generator,
|
| 81 |
+
... ).audios
|
| 82 |
+
|
| 83 |
+
>>> output = audio[0].T.float().cpu().numpy()
|
| 84 |
+
>>> sf.write("hammer.wav", output, pipe.vae.sampling_rate)
|
| 85 |
+
```
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class StableAudioPipeline(DiffusionPipeline):
|
| 90 |
+
r"""
|
| 91 |
+
Pipeline for text-to-audio generation using StableAudio.
|
| 92 |
+
|
| 93 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 94 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
vae ([`AutoencoderOobleck`]):
|
| 98 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 99 |
+
text_encoder ([`~transformers.T5EncoderModel`]):
|
| 100 |
+
Frozen text-encoder. StableAudio uses the encoder of
|
| 101 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
| 102 |
+
[google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant.
|
| 103 |
+
projection_model ([`StableAudioProjectionModel`]):
|
| 104 |
+
A trained model used to linearly project the hidden-states from the text encoder model and the start and
|
| 105 |
+
end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to
|
| 106 |
+
give the input to the transformer model.
|
| 107 |
+
tokenizer ([`~transformers.T5Tokenizer`]):
|
| 108 |
+
Tokenizer to tokenize text for the frozen text-encoder.
|
| 109 |
+
transformer ([`StableAudioDiTModel`]):
|
| 110 |
+
A `StableAudioDiTModel` to denoise the encoded audio latents.
|
| 111 |
+
scheduler ([`EDMDPMSolverMultistepScheduler`]):
|
| 112 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded audio latents.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae"
|
| 116 |
+
|
| 117 |
+
def __init__(
|
| 118 |
+
self,
|
| 119 |
+
vae: AutoencoderOobleck,
|
| 120 |
+
text_encoder: T5EncoderModel,
|
| 121 |
+
projection_model: StableAudioProjectionModel,
|
| 122 |
+
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
|
| 123 |
+
transformer: StableAudioDiTModel,
|
| 124 |
+
scheduler: EDMDPMSolverMultistepScheduler,
|
| 125 |
+
):
|
| 126 |
+
super().__init__()
|
| 127 |
+
|
| 128 |
+
self.register_modules(
|
| 129 |
+
vae=vae,
|
| 130 |
+
text_encoder=text_encoder,
|
| 131 |
+
projection_model=projection_model,
|
| 132 |
+
tokenizer=tokenizer,
|
| 133 |
+
transformer=transformer,
|
| 134 |
+
scheduler=scheduler,
|
| 135 |
+
)
|
| 136 |
+
self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2
|
| 137 |
+
|
| 138 |
+
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
|
| 139 |
+
def enable_vae_slicing(self):
|
| 140 |
+
r"""
|
| 141 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 142 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 143 |
+
"""
|
| 144 |
+
self.vae.enable_slicing()
|
| 145 |
+
|
| 146 |
+
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
|
| 147 |
+
def disable_vae_slicing(self):
|
| 148 |
+
r"""
|
| 149 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 150 |
+
computing decoding in one step.
|
| 151 |
+
"""
|
| 152 |
+
self.vae.disable_slicing()
|
| 153 |
+
|
| 154 |
+
def encode_prompt(
|
| 155 |
+
self,
|
| 156 |
+
prompt,
|
| 157 |
+
device,
|
| 158 |
+
do_classifier_free_guidance,
|
| 159 |
+
negative_prompt=None,
|
| 160 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 161 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 162 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 163 |
+
negative_attention_mask: Optional[torch.LongTensor] = None,
|
| 164 |
+
):
|
| 165 |
+
if prompt is not None and isinstance(prompt, str):
|
| 166 |
+
batch_size = 1
|
| 167 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 168 |
+
batch_size = len(prompt)
|
| 169 |
+
else:
|
| 170 |
+
batch_size = prompt_embeds.shape[0]
|
| 171 |
+
|
| 172 |
+
if prompt_embeds is None:
|
| 173 |
+
# 1. Tokenize text
|
| 174 |
+
self.tokenizer.model_max_length = 512
|
| 175 |
+
text_inputs = self.tokenizer(
|
| 176 |
+
prompt,
|
| 177 |
+
padding="max_length",
|
| 178 |
+
max_length=self.tokenizer.model_max_length,
|
| 179 |
+
truncation=True,
|
| 180 |
+
return_tensors="pt",
|
| 181 |
+
)
|
| 182 |
+
text_input_ids = text_inputs.input_ids
|
| 183 |
+
attention_mask = text_inputs.attention_mask
|
| 184 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 185 |
+
|
| 186 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
| 187 |
+
text_input_ids, untruncated_ids
|
| 188 |
+
):
|
| 189 |
+
removed_text = self.tokenizer.batch_decode(
|
| 190 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
| 191 |
+
)
|
| 192 |
+
logger.warning(
|
| 193 |
+
f"The following part of your input was truncated because {self.text_encoder.config.model_type} can "
|
| 194 |
+
f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
text_input_ids = text_input_ids.to(device)
|
| 198 |
+
attention_mask = attention_mask.to(device)
|
| 199 |
+
|
| 200 |
+
# 2. Text encoder forward
|
| 201 |
+
self.text_encoder.eval()
|
| 202 |
+
prompt_embeds = self.text_encoder(
|
| 203 |
+
text_input_ids,
|
| 204 |
+
attention_mask=attention_mask,
|
| 205 |
+
)
|
| 206 |
+
prompt_embeds = prompt_embeds[0]
|
| 207 |
+
|
| 208 |
+
if do_classifier_free_guidance and negative_prompt is not None:
|
| 209 |
+
uncond_tokens: List[str]
|
| 210 |
+
if type(prompt) is not type(negative_prompt):
|
| 211 |
+
raise TypeError(
|
| 212 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 213 |
+
f" {type(prompt)}."
|
| 214 |
+
)
|
| 215 |
+
elif isinstance(negative_prompt, str):
|
| 216 |
+
uncond_tokens = [negative_prompt]
|
| 217 |
+
elif batch_size != len(negative_prompt):
|
| 218 |
+
raise ValueError(
|
| 219 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 220 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 221 |
+
" the batch size of `prompt`."
|
| 222 |
+
)
|
| 223 |
+
else:
|
| 224 |
+
uncond_tokens = negative_prompt
|
| 225 |
+
|
| 226 |
+
# 1. Tokenize text
|
| 227 |
+
uncond_input = self.tokenizer(
|
| 228 |
+
uncond_tokens,
|
| 229 |
+
padding="max_length",
|
| 230 |
+
max_length=self.tokenizer.model_max_length,
|
| 231 |
+
truncation=True,
|
| 232 |
+
return_tensors="pt",
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
uncond_input_ids = uncond_input.input_ids.to(device)
|
| 236 |
+
negative_attention_mask = uncond_input.attention_mask.to(device)
|
| 237 |
+
|
| 238 |
+
# 2. Text encoder forward
|
| 239 |
+
self.text_encoder.eval()
|
| 240 |
+
negative_prompt_embeds = self.text_encoder(
|
| 241 |
+
uncond_input_ids,
|
| 242 |
+
attention_mask=negative_attention_mask,
|
| 243 |
+
)
|
| 244 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
| 245 |
+
|
| 246 |
+
if negative_attention_mask is not None:
|
| 247 |
+
# set the masked tokens to the null embed
|
| 248 |
+
negative_prompt_embeds = torch.where(
|
| 249 |
+
negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
# 3. Project prompt_embeds and negative_prompt_embeds
|
| 253 |
+
if do_classifier_free_guidance and negative_prompt_embeds is not None:
|
| 254 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 255 |
+
# Here we concatenate the negative and text embeddings into a single batch
|
| 256 |
+
# to avoid doing two forward passes
|
| 257 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds, prompt_embeds])
|
| 258 |
+
if attention_mask is not None and negative_attention_mask is None:
|
| 259 |
+
negative_attention_mask = torch.ones_like(attention_mask)
|
| 260 |
+
elif attention_mask is None and negative_attention_mask is not None:
|
| 261 |
+
attention_mask = torch.ones_like(negative_attention_mask)
|
| 262 |
+
if attention_mask is not None:
|
| 263 |
+
attention_mask = torch.cat([negative_attention_mask, attention_mask, attention_mask, attention_mask])
|
| 264 |
+
|
| 265 |
+
prompt_embeds = self.projection_model(
|
| 266 |
+
text_hidden_states=prompt_embeds,
|
| 267 |
+
).text_hidden_states
|
| 268 |
+
if attention_mask is not None:
|
| 269 |
+
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
|
| 270 |
+
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
|
| 271 |
+
|
| 272 |
+
return prompt_embeds
|
| 273 |
+
|
| 274 |
+
def encode_duration(
|
| 275 |
+
self,
|
| 276 |
+
audio_start_in_s,
|
| 277 |
+
audio_end_in_s,
|
| 278 |
+
device,
|
| 279 |
+
do_classifier_free_guidance,
|
| 280 |
+
batch_size,
|
| 281 |
+
):
|
| 282 |
+
audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s]
|
| 283 |
+
audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s]
|
| 284 |
+
|
| 285 |
+
if len(audio_start_in_s) == 1:
|
| 286 |
+
audio_start_in_s = audio_start_in_s * batch_size
|
| 287 |
+
if len(audio_end_in_s) == 1:
|
| 288 |
+
audio_end_in_s = audio_end_in_s * batch_size
|
| 289 |
+
|
| 290 |
+
# Cast the inputs to floats
|
| 291 |
+
audio_start_in_s = [float(x) for x in audio_start_in_s]
|
| 292 |
+
audio_start_in_s = torch.tensor(audio_start_in_s).to(device)
|
| 293 |
+
|
| 294 |
+
audio_end_in_s = [float(x) for x in audio_end_in_s]
|
| 295 |
+
audio_end_in_s = torch.tensor(audio_end_in_s).to(device)
|
| 296 |
+
|
| 297 |
+
projection_output = self.projection_model(
|
| 298 |
+
start_seconds=audio_start_in_s,
|
| 299 |
+
end_seconds=audio_end_in_s,
|
| 300 |
+
)
|
| 301 |
+
seconds_start_hidden_states = projection_output.seconds_start_hidden_states
|
| 302 |
+
seconds_end_hidden_states = projection_output.seconds_end_hidden_states
|
| 303 |
+
|
| 304 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 305 |
+
# Here we repeat the audio hidden states to avoid doing two forward passes
|
| 306 |
+
if do_classifier_free_guidance:
|
| 307 |
+
seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states, seconds_start_hidden_states, seconds_start_hidden_states], dim=0)
|
| 308 |
+
seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states, seconds_end_hidden_states, seconds_end_hidden_states], dim=0)
|
| 309 |
+
|
| 310 |
+
return seconds_start_hidden_states, seconds_end_hidden_states
|
| 311 |
+
|
| 312 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 313 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 314 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 315 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 316 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 317 |
+
# and should be between [0, 1]
|
| 318 |
+
|
| 319 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 320 |
+
extra_step_kwargs = {}
|
| 321 |
+
if accepts_eta:
|
| 322 |
+
extra_step_kwargs["eta"] = eta
|
| 323 |
+
|
| 324 |
+
# check if the scheduler accepts generator
|
| 325 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 326 |
+
if accepts_generator:
|
| 327 |
+
extra_step_kwargs["generator"] = generator
|
| 328 |
+
return extra_step_kwargs
|
| 329 |
+
|
| 330 |
+
def check_inputs(
|
| 331 |
+
self,
|
| 332 |
+
prompt,
|
| 333 |
+
audio_start_in_s,
|
| 334 |
+
audio_end_in_s,
|
| 335 |
+
callback_steps,
|
| 336 |
+
negative_prompt=None,
|
| 337 |
+
prompt_embeds=None,
|
| 338 |
+
negative_prompt_embeds=None,
|
| 339 |
+
attention_mask=None,
|
| 340 |
+
negative_attention_mask=None,
|
| 341 |
+
initial_audio_waveforms=None,
|
| 342 |
+
initial_audio_sampling_rate=None,
|
| 343 |
+
):
|
| 344 |
+
if audio_end_in_s < audio_start_in_s:
|
| 345 |
+
raise ValueError(
|
| 346 |
+
f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but "
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
if (
|
| 350 |
+
audio_start_in_s < self.projection_model.config.min_value
|
| 351 |
+
or audio_start_in_s > self.projection_model.config.max_value
|
| 352 |
+
):
|
| 353 |
+
raise ValueError(
|
| 354 |
+
f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
|
| 355 |
+
f"is {audio_start_in_s}."
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
if (
|
| 359 |
+
audio_end_in_s < self.projection_model.config.min_value
|
| 360 |
+
or audio_end_in_s > self.projection_model.config.max_value
|
| 361 |
+
):
|
| 362 |
+
raise ValueError(
|
| 363 |
+
f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
|
| 364 |
+
f"is {audio_end_in_s}."
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
if (callback_steps is None) or (
|
| 368 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
| 369 |
+
):
|
| 370 |
+
raise ValueError(
|
| 371 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 372 |
+
f" {type(callback_steps)}."
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
if prompt is not None and prompt_embeds is not None:
|
| 376 |
+
raise ValueError(
|
| 377 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 378 |
+
" only forward one of the two."
|
| 379 |
+
)
|
| 380 |
+
elif prompt is None and (prompt_embeds is None):
|
| 381 |
+
raise ValueError(
|
| 382 |
+
"Provide either `prompt`, or `prompt_embeds`. Cannot leave"
|
| 383 |
+
"`prompt` undefined without specifying `prompt_embeds`."
|
| 384 |
+
)
|
| 385 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 386 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 387 |
+
|
| 388 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 389 |
+
raise ValueError(
|
| 390 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 391 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 395 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 396 |
+
raise ValueError(
|
| 397 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 398 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 399 |
+
f" {negative_prompt_embeds.shape}."
|
| 400 |
+
)
|
| 401 |
+
if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
|
| 402 |
+
raise ValueError(
|
| 403 |
+
"`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
|
| 404 |
+
f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
if initial_audio_sampling_rate is None and initial_audio_waveforms is not None:
|
| 408 |
+
raise ValueError(
|
| 409 |
+
"`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`."
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate:
|
| 413 |
+
raise ValueError(
|
| 414 |
+
f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`."
|
| 415 |
+
"Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. "
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
def prepare_latents(
|
| 419 |
+
self,
|
| 420 |
+
batch_size,
|
| 421 |
+
num_channels_vae,
|
| 422 |
+
sample_size,
|
| 423 |
+
dtype,
|
| 424 |
+
device,
|
| 425 |
+
generator,
|
| 426 |
+
latents=None,
|
| 427 |
+
initial_audio_waveforms=None,
|
| 428 |
+
num_waveforms_per_prompt=None,
|
| 429 |
+
audio_channels=None,
|
| 430 |
+
):
|
| 431 |
+
shape = (batch_size, num_channels_vae, sample_size)
|
| 432 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 433 |
+
raise ValueError(
|
| 434 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 435 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
if latents is None:
|
| 439 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 440 |
+
else:
|
| 441 |
+
latents = latents.to(device)
|
| 442 |
+
|
| 443 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 444 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 445 |
+
|
| 446 |
+
# encode the initial audio for use by the model
|
| 447 |
+
if initial_audio_waveforms is not None:
|
| 448 |
+
# check dimension
|
| 449 |
+
if initial_audio_waveforms.ndim == 2:
|
| 450 |
+
initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1)
|
| 451 |
+
elif initial_audio_waveforms.ndim != 3:
|
| 452 |
+
raise ValueError(
|
| 453 |
+
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
|
| 457 |
+
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
|
| 458 |
+
|
| 459 |
+
# check num_channels
|
| 460 |
+
if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2:
|
| 461 |
+
initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1)
|
| 462 |
+
elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1:
|
| 463 |
+
initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True)
|
| 464 |
+
|
| 465 |
+
if initial_audio_waveforms.shape[:2] != audio_shape[:2]:
|
| 466 |
+
raise ValueError(
|
| 467 |
+
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`"
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
# crop or pad
|
| 471 |
+
audio_length = initial_audio_waveforms.shape[-1]
|
| 472 |
+
if audio_length < audio_vae_length:
|
| 473 |
+
logger.warning(
|
| 474 |
+
f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded."
|
| 475 |
+
)
|
| 476 |
+
elif audio_length > audio_vae_length:
|
| 477 |
+
logger.warning(
|
| 478 |
+
f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped."
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
audio = initial_audio_waveforms.new_zeros(audio_shape)
|
| 482 |
+
audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length]
|
| 483 |
+
|
| 484 |
+
encoded_audio = self.vae.encode(audio).latent_dist.sample(generator)
|
| 485 |
+
encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))
|
| 486 |
+
latents = encoded_audio + latents
|
| 487 |
+
return latents
|
| 488 |
+
|
| 489 |
+
@torch.no_grad()
|
| 490 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 491 |
+
def __call__(
|
| 492 |
+
self,
|
| 493 |
+
extracted_condition_audio = None,
|
| 494 |
+
extracted_condition = None,
|
| 495 |
+
prompt: Union[str, List[str]] = None,
|
| 496 |
+
audio_end_in_s: Optional[float] = None,
|
| 497 |
+
audio_start_in_s: Optional[float] = 0.0,
|
| 498 |
+
num_inference_steps: int = 100,
|
| 499 |
+
guidance_scale_text: float = 7.0,
|
| 500 |
+
guidance_scale_con: float = 2.0,
|
| 501 |
+
guidance_scale_audio: float = 2.0,
|
| 502 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 503 |
+
num_waveforms_per_prompt: Optional[int] = 1,
|
| 504 |
+
eta: float = 0.0,
|
| 505 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 506 |
+
latents: Optional[torch.Tensor] = None,
|
| 507 |
+
initial_audio_waveforms: Optional[torch.Tensor] = None,
|
| 508 |
+
initial_audio_sampling_rate: Optional[torch.Tensor] = None,
|
| 509 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 510 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 511 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 512 |
+
negative_attention_mask: Optional[torch.LongTensor] = None,
|
| 513 |
+
return_dict: bool = True,
|
| 514 |
+
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
| 515 |
+
callback_steps: Optional[int] = 1,
|
| 516 |
+
output_type: Optional[str] = "pt",
|
| 517 |
+
):
|
| 518 |
+
r"""
|
| 519 |
+
The call function to the pipeline for generation.
|
| 520 |
+
|
| 521 |
+
Args:
|
| 522 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 523 |
+
The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
|
| 524 |
+
audio_end_in_s (`float`, *optional*, defaults to 47.55):
|
| 525 |
+
Audio end index in seconds.
|
| 526 |
+
audio_start_in_s (`float`, *optional*, defaults to 0):
|
| 527 |
+
Audio start index in seconds.
|
| 528 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
| 529 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
|
| 530 |
+
expense of slower inference.
|
| 531 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 532 |
+
A higher guidance scale value encourages the model to generate audio that is closely linked to the text
|
| 533 |
+
`prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 534 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 535 |
+
The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
|
| 536 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 537 |
+
num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
|
| 538 |
+
The number of waveforms to generate per prompt.
|
| 539 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 540 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
| 541 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 542 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 543 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 544 |
+
generation deterministic.
|
| 545 |
+
latents (`torch.Tensor`, *optional*):
|
| 546 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio
|
| 547 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 548 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 549 |
+
initial_audio_waveforms (`torch.Tensor`, *optional*):
|
| 550 |
+
Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape
|
| 551 |
+
`(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size`
|
| 552 |
+
corresponds to the number of prompts passed to the model.
|
| 553 |
+
initial_audio_sampling_rate (`int`, *optional*):
|
| 554 |
+
Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model.
|
| 555 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 556 |
+
Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs,
|
| 557 |
+
*e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input
|
| 558 |
+
argument.
|
| 559 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 560 |
+
Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text
|
| 561 |
+
inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
|
| 562 |
+
`negative_prompt` input argument.
|
| 563 |
+
attention_mask (`torch.LongTensor`, *optional*):
|
| 564 |
+
Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
|
| 565 |
+
be computed from `prompt` input argument.
|
| 566 |
+
negative_attention_mask (`torch.LongTensor`, *optional*):
|
| 567 |
+
Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`.
|
| 568 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 569 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 570 |
+
plain tuple.
|
| 571 |
+
callback (`Callable`, *optional*):
|
| 572 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
| 573 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
| 574 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 575 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
| 576 |
+
every step.
|
| 577 |
+
output_type (`str`, *optional*, defaults to `"pt"`):
|
| 578 |
+
The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
|
| 579 |
+
`"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
|
| 580 |
+
model (LDM) output.
|
| 581 |
+
|
| 582 |
+
Examples:
|
| 583 |
+
|
| 584 |
+
Returns:
|
| 585 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 586 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
| 587 |
+
otherwise a `tuple` is returned where the first element is a list with the generated audio.
|
| 588 |
+
"""
|
| 589 |
+
# 0. Convert audio input length from seconds to latent length
|
| 590 |
+
downsample_ratio = self.vae.hop_length
|
| 591 |
+
|
| 592 |
+
max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate
|
| 593 |
+
if audio_end_in_s is None:
|
| 594 |
+
audio_end_in_s = max_audio_length_in_s
|
| 595 |
+
|
| 596 |
+
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
|
| 597 |
+
raise ValueError(
|
| 598 |
+
f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
|
| 602 |
+
waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate)
|
| 603 |
+
waveform_length = int(self.transformer.config.sample_size) # * audio_end_in_s / 47.554
|
| 604 |
+
# waveform_length = 646
|
| 605 |
+
# 1. Check inputs. Raise error if not correct
|
| 606 |
+
self.check_inputs(
|
| 607 |
+
prompt,
|
| 608 |
+
audio_start_in_s,
|
| 609 |
+
audio_end_in_s,
|
| 610 |
+
callback_steps,
|
| 611 |
+
negative_prompt,
|
| 612 |
+
prompt_embeds,
|
| 613 |
+
negative_prompt_embeds,
|
| 614 |
+
attention_mask,
|
| 615 |
+
negative_attention_mask,
|
| 616 |
+
initial_audio_waveforms,
|
| 617 |
+
initial_audio_sampling_rate,
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
# 2. Define call parameters
|
| 621 |
+
if prompt is not None and isinstance(prompt, str):
|
| 622 |
+
batch_size = 1
|
| 623 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 624 |
+
batch_size = len(prompt)
|
| 625 |
+
else:
|
| 626 |
+
batch_size = prompt_embeds.shape[0]
|
| 627 |
+
|
| 628 |
+
device = self._execution_device
|
| 629 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 630 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 631 |
+
# corresponds to doing no classifier free guidance.
|
| 632 |
+
do_classifier_free_guidance = True
|
| 633 |
+
|
| 634 |
+
# 3. Encode input prompt
|
| 635 |
+
prompt_embeds = self.encode_prompt(
|
| 636 |
+
prompt,
|
| 637 |
+
device,
|
| 638 |
+
do_classifier_free_guidance,
|
| 639 |
+
negative_prompt,
|
| 640 |
+
prompt_embeds,
|
| 641 |
+
negative_prompt_embeds,
|
| 642 |
+
attention_mask,
|
| 643 |
+
negative_attention_mask,
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
# Encode duration
|
| 647 |
+
seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration(
|
| 648 |
+
audio_start_in_s,
|
| 649 |
+
audio_end_in_s,
|
| 650 |
+
device,
|
| 651 |
+
do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None),
|
| 652 |
+
batch_size,
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
# Create text_audio_duration_embeds and audio_duration_embeds
|
| 656 |
+
text_audio_duration_embeds = torch.cat(
|
| 657 |
+
[prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2)
|
| 661 |
+
|
| 662 |
+
# In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and
|
| 663 |
+
# to concatenate it to the embeddings
|
| 664 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
|
| 665 |
+
negative_text_audio_duration_embeds = torch.zeros_like(
|
| 666 |
+
text_audio_duration_embeds, device=text_audio_duration_embeds.device
|
| 667 |
+
)
|
| 668 |
+
text_audio_duration_embeds = torch.cat(
|
| 669 |
+
[negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0
|
| 670 |
+
)
|
| 671 |
+
audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0)
|
| 672 |
+
# if condition is not None:
|
| 673 |
+
# condition_conditioning = condition_model(condition)
|
| 674 |
+
# condition_no_conditioning = condition_model(torch.full_like(condition, fill_value=0))
|
| 675 |
+
# extracted_condition = torch.cat([condition_no_conditioning, condition_no_conditioning, condition_conditioning], dim=0)
|
| 676 |
+
|
| 677 |
+
bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape
|
| 678 |
+
# duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method
|
| 679 |
+
text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
| 680 |
+
text_audio_duration_embeds = text_audio_duration_embeds.view(
|
| 681 |
+
bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
| 685 |
+
audio_duration_embeds = audio_duration_embeds.view(
|
| 686 |
+
bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1]
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
# 4. Prepare timesteps
|
| 690 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 691 |
+
timesteps = self.scheduler.timesteps
|
| 692 |
+
|
| 693 |
+
# 5. Prepare latent variables
|
| 694 |
+
num_channels_vae = self.transformer.config.in_channels
|
| 695 |
+
latents = self.prepare_latents(
|
| 696 |
+
batch_size * num_waveforms_per_prompt,
|
| 697 |
+
num_channels_vae,
|
| 698 |
+
waveform_length,
|
| 699 |
+
text_audio_duration_embeds.dtype,
|
| 700 |
+
device,
|
| 701 |
+
generator,
|
| 702 |
+
latents,
|
| 703 |
+
initial_audio_waveforms,
|
| 704 |
+
num_waveforms_per_prompt,
|
| 705 |
+
audio_channels=self.vae.config.audio_channels,
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
# 6. Prepare extra step kwargs
|
| 709 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 710 |
+
|
| 711 |
+
# 7. Prepare rotary positional embedding
|
| 712 |
+
rotary_embedding = get_1d_rotary_pos_embed(
|
| 713 |
+
self.rotary_embed_dim,
|
| 714 |
+
latents.shape[2] + audio_duration_embeds.shape[1],
|
| 715 |
+
use_real=True,
|
| 716 |
+
repeat_interleave_real=False,
|
| 717 |
+
)
|
| 718 |
+
|
| 719 |
+
# 8. Denoising loop
|
| 720 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 721 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 722 |
+
for i, t in enumerate(timesteps):
|
| 723 |
+
# expand the latents if we are doing classifier free guidance
|
| 724 |
+
latent_model_input = torch.cat([latents] * 4) if do_classifier_free_guidance else latents
|
| 725 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 726 |
+
with autocast():
|
| 727 |
+
noise_pred = self.transformer(
|
| 728 |
+
latent_model_input,
|
| 729 |
+
t.unsqueeze(0),
|
| 730 |
+
encoder_hidden_states=text_audio_duration_embeds,
|
| 731 |
+
encoder_hidden_states_con=extracted_condition,
|
| 732 |
+
encoder_hidden_states_audio = extracted_condition_audio,
|
| 733 |
+
global_hidden_states=audio_duration_embeds,
|
| 734 |
+
rotary_embedding=rotary_embedding,
|
| 735 |
+
return_dict=False,
|
| 736 |
+
)[0]
|
| 737 |
+
# transformer_weight_dtype = next(self.transformer.parameters()).dtype
|
| 738 |
+
# print("transformer_weight_dtype",transformer_weight_dtype)
|
| 739 |
+
# noise_pred = noise_pred.half()
|
| 740 |
+
|
| 741 |
+
# perform guidance
|
| 742 |
+
if do_classifier_free_guidance:
|
| 743 |
+
noise_pred_uncond, noise_pred_text, noise_pred_both, noise_pred_both_audio = noise_pred.chunk(4)
|
| 744 |
+
noise_pred = noise_pred_uncond + guidance_scale_text * (noise_pred_text - noise_pred_uncond) + guidance_scale_con * (noise_pred_both - noise_pred_text) \
|
| 745 |
+
+ guidance_scale_audio * (noise_pred_both_audio - noise_pred_both)
|
| 746 |
+
# print("guidance_scale_audio", guidance_scale_audio)
|
| 747 |
+
# if do_classifier_free_guidance:
|
| 748 |
+
# noise_pred_uncond, noise_pred_text, noise_pred_both, noise_pred_both_audio = noise_pred.chunk(4)
|
| 749 |
+
# noise_pred_uncond_no_mask, noise_pred_text_no_mask, noise_pred_both_no_mask, noise_pred_both_audio_no_mask = noise_pred_uncond[:,:,:323], noise_pred_text[:,:,:323], noise_pred_both[:,:,:323], noise_pred_both_audio[:,:,:323]
|
| 750 |
+
# noise_pred_no_mask = noise_pred_uncond_no_mask + 7.0 * (noise_pred_text_no_mask - noise_pred_uncond_no_mask) + guidance_scale_con * (noise_pred_both_no_mask - noise_pred_text_no_mask) \
|
| 751 |
+
# + 1.5 * (noise_pred_both_audio_no_mask - noise_pred_both_no_mask)
|
| 752 |
+
# noise_pred_uncond_mask, noise_pred_text_mask, noise_pred_both_mask, noise_pred_both_audio_mask = noise_pred_uncond[:,:,323:], noise_pred_text[:,:,323:], noise_pred_both[:,:,323:], noise_pred_both_audio[:,:,323:]
|
| 753 |
+
# noise_pred_mask = noise_pred_uncond_mask + 7.0 * (noise_pred_text_mask - noise_pred_uncond_mask) + guidance_scale_con * (noise_pred_both_mask - noise_pred_text_mask) \
|
| 754 |
+
# + 4.5 * (noise_pred_both_audio_mask - noise_pred_both_mask)
|
| 755 |
+
# noise_pred = torch.concat((noise_pred_no_mask, noise_pred_mask), dim=2)
|
| 756 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 757 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
| 758 |
+
|
| 759 |
+
# call the callback, if provided
|
| 760 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 761 |
+
progress_bar.update()
|
| 762 |
+
if callback is not None and i % callback_steps == 0:
|
| 763 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 764 |
+
callback(step_idx, t, latents)
|
| 765 |
+
|
| 766 |
+
# 9. Post-processing
|
| 767 |
+
if not output_type == "latent":
|
| 768 |
+
with autocast():
|
| 769 |
+
audio = self.vae.decode(latents).sample
|
| 770 |
+
else:
|
| 771 |
+
return AudioPipelineOutput(audios=latents)
|
| 772 |
+
|
| 773 |
+
audio = audio[:, :, waveform_start:waveform_end]
|
| 774 |
+
|
| 775 |
+
if output_type == "np":
|
| 776 |
+
audio = audio.cpu().float().numpy()
|
| 777 |
+
|
| 778 |
+
self.maybe_free_model_hooks()
|
| 779 |
+
|
| 780 |
+
if not return_dict:
|
| 781 |
+
return (audio,)
|
| 782 |
+
|
| 783 |
+
return AudioPipelineOutput(audios=audio)
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
git+https://github.com/fundwotsai2001/[email protected]
|
| 2 |
+
git+https://github.com/YianLai0327/madmom.git
|
| 3 |
+
torch
|
| 4 |
+
torchaudio
|
| 5 |
+
soundfile
|
| 6 |
+
accelerate
|
| 7 |
+
transformers==4.46.1
|
| 8 |
+
matplotlib
|
| 9 |
+
librosa
|
| 10 |
+
torchsde
|
| 11 |
+
gdown
|
| 12 |
+
wandb
|
| 13 |
+
gradio
|
utils/extract_conditions.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchaudio
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy.signal import savgol_filter
|
| 4 |
+
import librosa
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio
|
| 7 |
+
import scipy.signal as signal
|
| 8 |
+
from torchaudio import transforms as T
|
| 9 |
+
import torch
|
| 10 |
+
import torchaudio
|
| 11 |
+
import librosa
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def compute_melody_v2(stereo_audio: torch.Tensor) -> np.ndarray:
|
| 16 |
+
"""
|
| 17 |
+
Args:
|
| 18 |
+
stereo_audio: torch.Tensor of shape (2, N), 其中 stereo_audio[0] 是左聲道,
|
| 19 |
+
stereo_audio[1] 是右聲道。
|
| 20 |
+
sr: 取樣率 (sampling rate)。
|
| 21 |
+
Returns:
|
| 22 |
+
c: np.ndarray of shape (8, T_frames),
|
| 23 |
+
每一列代表: [L1, R1, L2, R2, L3, R3, L4, R4](按 frame 交錯),
|
| 24 |
+
且每個值都 ∈ {1, 2, …, 128},對應 CQT 的頻率 bin。
|
| 25 |
+
"""
|
| 26 |
+
audio, sr = torchaudio.load(stereo_audio)
|
| 27 |
+
# 1. 先針對左、右聲道分別計算 CQT (128 bins),回傳 cqt_db 形狀都是 (128, T_frames)
|
| 28 |
+
cqt_left = compute_music_represent(audio[0], sr) # shape: (128, T_frames)
|
| 29 |
+
cqt_right = compute_music_represent(audio[1], sr) # shape: (128, T_frames)
|
| 30 |
+
|
| 31 |
+
# 2. 取得時框 (frame) 數量
|
| 32 |
+
# 注意:librosa.cqt 的輸出 cqt_db 對應的「時框數」就是第二維度
|
| 33 |
+
T_frames = cqt_left.shape[1]
|
| 34 |
+
|
| 35 |
+
# 3. 預先配置輸出矩陣 c,dtype 用 int,shape = (8, T_frames)
|
| 36 |
+
c = np.zeros((8, T_frames), dtype=np.int32)
|
| 37 |
+
|
| 38 |
+
# 4. 逐一 frame 處理:對每個 frame 的 128 維度做 top-4
|
| 39 |
+
for j in range(T_frames):
|
| 40 |
+
# 4.1 取出當前時框的左、右聲道 CQT 能量(分貝值)
|
| 41 |
+
col_L = cqt_left[:, j] # shape: (128,)
|
| 42 |
+
col_R = cqt_right[:, j] # shape: (128,)
|
| 43 |
+
|
| 44 |
+
# 4.2 用 numpy.argsort 找到「前 4 大」的索引
|
| 45 |
+
# np.argsort 預設是從小到大排序,所以取最後 4 個,再反轉取大到小
|
| 46 |
+
idx4_L = np.argsort(col_L)[-4:][::-1] # 0-based, 長度=4
|
| 47 |
+
idx4_R = np.argsort(col_R)[-4:][::-1] # 0-based, 長度=4
|
| 48 |
+
|
| 49 |
+
# 4.3 轉成 1-based(因為題意寫 pixel ∈ {1,2,…,128})
|
| 50 |
+
idx4_L = idx4_L + 1 # 現在範圍是 1..128
|
| 51 |
+
idx4_R = idx4_R + 1
|
| 52 |
+
|
| 53 |
+
# 4.4 交錯填入 c 的第 j 欄
|
| 54 |
+
# 我們希望 c[:, j] = [L1, R1, L2, R2, L3, R3, L4, R4]
|
| 55 |
+
for k in range(4):
|
| 56 |
+
c[2 * k , j] = idx4_L[k]
|
| 57 |
+
c[2 * k + 1, j] = idx4_R[k]
|
| 58 |
+
|
| 59 |
+
return c[:,:4097]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def compute_music_represent(audio, sr):
|
| 63 |
+
filter_y = torchaudio.functional.highpass_biquad(audio, sr, 261.6)
|
| 64 |
+
fmin = librosa.midi_to_hz(0)
|
| 65 |
+
cqt_spec = librosa.cqt(y=filter_y.numpy(), fmin=fmin, sr=sr, n_bins=128, bins_per_octave=12, hop_length=512)
|
| 66 |
+
cqt_db = librosa.amplitude_to_db(np.abs(cqt_spec), ref=np.max)
|
| 67 |
+
return cqt_db
|
| 68 |
+
|
| 69 |
+
def keep_top4_pitches_per_channel(cqt_db):
|
| 70 |
+
"""
|
| 71 |
+
cqt_db is assumed to have shape: (2, 128, time_frames).
|
| 72 |
+
We return a combined 2D array of shape (128, time_frames)
|
| 73 |
+
where only the top 4 pitch bins in each channel are kept
|
| 74 |
+
(for a total of up to 8 bins per time frame).
|
| 75 |
+
"""
|
| 76 |
+
# Parse shapes
|
| 77 |
+
num_channels, num_bins, num_frames = cqt_db.shape
|
| 78 |
+
|
| 79 |
+
# Initialize an output array that combines both channels
|
| 80 |
+
# and has zeros everywhere initially
|
| 81 |
+
combined = np.zeros((num_bins, num_frames), dtype=cqt_db.dtype)
|
| 82 |
+
|
| 83 |
+
for ch in range(num_channels):
|
| 84 |
+
for t in range(num_frames):
|
| 85 |
+
# Find the top 4 pitch bins for this channel at frame t
|
| 86 |
+
# argsort sorts ascending; we take the last 4 indices for top 4
|
| 87 |
+
top4_indices = np.argsort(cqt_db[ch, :, t])[-4:]
|
| 88 |
+
|
| 89 |
+
# Copy their values into the combined array
|
| 90 |
+
# We add to it in case there's overlap between channels
|
| 91 |
+
combined[top4_indices, t] = 1
|
| 92 |
+
return combined
|
| 93 |
+
def compute_melody(input_audio):
|
| 94 |
+
# Initialize parameters
|
| 95 |
+
sample_rate = 44100
|
| 96 |
+
|
| 97 |
+
# Load audio file
|
| 98 |
+
wav, sr = torchaudio.load(input_audio)
|
| 99 |
+
if sr != sample_rate:
|
| 100 |
+
resample = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
|
| 101 |
+
wav = resample(wav)
|
| 102 |
+
# Truncate or pad the audio to 2097152 samples
|
| 103 |
+
target_length = 2097152
|
| 104 |
+
if wav.size(1) > target_length:
|
| 105 |
+
# Truncate the audio if it is longer than the target length
|
| 106 |
+
wav = wav[:, :target_length]
|
| 107 |
+
elif wav.size(1) < target_length:
|
| 108 |
+
# Pad the audio with zeros if it is shorter than the target length
|
| 109 |
+
padding = target_length - wav.size(1)
|
| 110 |
+
wav = torch.cat([wav, torch.zeros(wav.size(0), padding)], dim=1)
|
| 111 |
+
melody = compute_music_represent(wav, 44100)
|
| 112 |
+
melody = keep_top4_pitches_per_channel(melody)
|
| 113 |
+
return melody
|
| 114 |
+
|
| 115 |
+
def compute_dynamics(audio_file, hop_length=160, target_sample_rate=44100, cut=True):
|
| 116 |
+
"""
|
| 117 |
+
Compute the dynamics curve for a given audio file.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
audio_file (str): Path to the audio file.
|
| 121 |
+
window_length (int): Length of FFT window for computing the spectrogram.
|
| 122 |
+
hop_length (int): Number of samples between successive frames.
|
| 123 |
+
smoothing_window (int): Length of the Savitzky-Golay filter window.
|
| 124 |
+
polyorder (int): Polynomial order of the Savitzky-Golay filter.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
dynamics_curve (numpy.ndarray): The computed dynamic values in dB.
|
| 128 |
+
"""
|
| 129 |
+
# Load audio file
|
| 130 |
+
waveform, original_sample_rate = torchaudio.load(audio_file)
|
| 131 |
+
if original_sample_rate != target_sample_rate:
|
| 132 |
+
resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sample_rate)
|
| 133 |
+
waveform = resampler(waveform)
|
| 134 |
+
if cut:
|
| 135 |
+
waveform = waveform[:, :2097152]
|
| 136 |
+
# Ensure waveform has a single channel (e.g., select the first channel if multi-channel)
|
| 137 |
+
waveform = waveform.mean(dim=0, keepdim=True) # Mix all channels into one
|
| 138 |
+
waveform = waveform.clamp(-1, 1).numpy()
|
| 139 |
+
|
| 140 |
+
S = np.abs(librosa.stft(waveform, n_fft=1024, hop_length=hop_length))
|
| 141 |
+
mel_filter_bank = librosa.filters.mel(sr=target_sample_rate, n_fft=1024, n_mels=64, fmin=0, fmax=8000)
|
| 142 |
+
S = np.dot(mel_filter_bank, S)
|
| 143 |
+
energy = np.sum(S**2, axis=0)
|
| 144 |
+
dynamics_db = np.clip(energy, 1e-6, None)
|
| 145 |
+
dynamics_db = librosa.amplitude_to_db(energy, ref=np.max).squeeze(0)
|
| 146 |
+
smoothed_dynamics = savgol_filter(dynamics_db, window_length=279, polyorder=1)
|
| 147 |
+
# print(smoothed_dynamics.shape)
|
| 148 |
+
return smoothed_dynamics
|
| 149 |
+
def extract_melody_one_hot(audio_path,
|
| 150 |
+
sr=44100,
|
| 151 |
+
cutoff=261.2,
|
| 152 |
+
win_length=2048,
|
| 153 |
+
hop_length=256):
|
| 154 |
+
"""
|
| 155 |
+
Extract a one-hot chromagram-based melody from an audio file (mono).
|
| 156 |
+
|
| 157 |
+
Parameters:
|
| 158 |
+
-----------
|
| 159 |
+
audio_path : str
|
| 160 |
+
Path to the input audio file.
|
| 161 |
+
sr : int
|
| 162 |
+
Target sample rate to resample the audio (default: 44100).
|
| 163 |
+
cutoff : float
|
| 164 |
+
The high-pass filter cutoff frequency in Hz (default: Middle C ~ 261.2 Hz).
|
| 165 |
+
win_length : int
|
| 166 |
+
STFT window length for the chromagram (default: 2048).
|
| 167 |
+
hop_length : int
|
| 168 |
+
STFT hop length for the chromagram (default: 256).
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
--------
|
| 172 |
+
one_hot_chroma : np.ndarray, shape=(12, n_frames)
|
| 173 |
+
One-hot chromagram of the most prominent pitch class per frame.
|
| 174 |
+
"""
|
| 175 |
+
# ---------------------------------------------------------
|
| 176 |
+
# 1. Load audio (Torchaudio => shape: (channels, samples))
|
| 177 |
+
# ---------------------------------------------------------
|
| 178 |
+
audio, in_sr = torchaudio.load(audio_path)
|
| 179 |
+
|
| 180 |
+
# Convert to mono by averaging channels: shape => (samples,)
|
| 181 |
+
audio_mono = audio.mean(dim=0)
|
| 182 |
+
|
| 183 |
+
# Resample if necessary
|
| 184 |
+
if in_sr != sr:
|
| 185 |
+
resample_tf = T.Resample(orig_freq=in_sr, new_freq=sr)
|
| 186 |
+
audio_mono = resample_tf(audio_mono)
|
| 187 |
+
|
| 188 |
+
# Convert torch.Tensor => NumPy array: shape (samples,)
|
| 189 |
+
y = audio_mono.numpy()
|
| 190 |
+
|
| 191 |
+
# ---------------------------------------------------------
|
| 192 |
+
# 2. Design & apply a high-pass filter (Butterworth, order=2)
|
| 193 |
+
# ---------------------------------------------------------
|
| 194 |
+
nyquist = 0.5 * sr
|
| 195 |
+
norm_cutoff = cutoff / nyquist
|
| 196 |
+
b, a = signal.butter(N=2, Wn=norm_cutoff, btype='high', analog=False)
|
| 197 |
+
|
| 198 |
+
# filtfilt expects shape (n_samples,) for 1D
|
| 199 |
+
y_hp = signal.filtfilt(b, a, y)
|
| 200 |
+
|
| 201 |
+
# ---------------------------------------------------------
|
| 202 |
+
# 3. Compute the chromagram (librosa => shape: (12, n_frames))
|
| 203 |
+
# ---------------------------------------------------------
|
| 204 |
+
chroma = librosa.feature.chroma_stft(
|
| 205 |
+
y=y_hp,
|
| 206 |
+
sr=sr,
|
| 207 |
+
n_fft=win_length, # Usually >= win_length
|
| 208 |
+
win_length=win_length,
|
| 209 |
+
hop_length=hop_length
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# ---------------------------------------------------------
|
| 213 |
+
# 4. Convert chromagram to one-hot via argmax along pitch classes
|
| 214 |
+
# ---------------------------------------------------------
|
| 215 |
+
# pitch_class_idx => shape=(n_frames,)
|
| 216 |
+
pitch_class_idx = np.argmax(chroma, axis=0)
|
| 217 |
+
|
| 218 |
+
# Make a zero array of the same shape => (12, n_frames)
|
| 219 |
+
one_hot_chroma = np.zeros_like(chroma)
|
| 220 |
+
|
| 221 |
+
# For each frame (column in chroma), set the argmax row to 1
|
| 222 |
+
one_hot_chroma[pitch_class_idx, np.arange(chroma.shape[1])] = 1.0
|
| 223 |
+
|
| 224 |
+
return one_hot_chroma
|
| 225 |
+
def evaluate_f1_rhythm(input_timestamps, generated_timestamps, tolerance=0.07):
|
| 226 |
+
"""
|
| 227 |
+
Evaluates precision, recall, and F1-score for beat/downbeat timestamp alignment.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
input_timestamps (ndarray): 2D array of shape [n, 2], where column 0 contains timestamps.
|
| 231 |
+
generated_timestamps (ndarray): 2D array of shape [m, 2], where column 0 contains timestamps.
|
| 232 |
+
tolerance (float): Alignment tolerance in seconds (default: 70ms).
|
| 233 |
+
|
| 234 |
+
Returns:
|
| 235 |
+
tuple: (precision, recall, f1)
|
| 236 |
+
"""
|
| 237 |
+
# Extract and sort timestamps
|
| 238 |
+
input_timestamps = np.asarray(input_timestamps)
|
| 239 |
+
generated_timestamps = np.asarray(generated_timestamps)
|
| 240 |
+
|
| 241 |
+
# If you only need the first column
|
| 242 |
+
if input_timestamps.size > 0:
|
| 243 |
+
input_timestamps = input_timestamps[:, 0]
|
| 244 |
+
input_timestamps.sort()
|
| 245 |
+
else:
|
| 246 |
+
input_timestamps = np.array([])
|
| 247 |
+
|
| 248 |
+
if generated_timestamps.size > 0:
|
| 249 |
+
generated_timestamps = generated_timestamps[:, 0]
|
| 250 |
+
generated_timestamps.sort()
|
| 251 |
+
else:
|
| 252 |
+
generated_timestamps = np.array([])
|
| 253 |
+
|
| 254 |
+
# Handle empty cases
|
| 255 |
+
# Case 1: Both are empty
|
| 256 |
+
if len(input_timestamps) == 0 and len(generated_timestamps) == 0:
|
| 257 |
+
# You could argue everything is correct since there's nothing to detect,
|
| 258 |
+
# but returning all zeros is a common convention.
|
| 259 |
+
return 0.0, 0.0, 0.0
|
| 260 |
+
|
| 261 |
+
# Case 2: No ground-truth timestamps, but predictions exist
|
| 262 |
+
if len(input_timestamps) == 0 and len(generated_timestamps) > 0:
|
| 263 |
+
# All predictions are false positives => tp=0, fp = len(generated_timestamps)
|
| 264 |
+
# => precision=0, recall is undefined (tp+fn=0), typically we treat recall=0
|
| 265 |
+
return 0.0, 0.0, 0.0
|
| 266 |
+
|
| 267 |
+
# Case 3: Ground-truth timestamps exist, but no predictions
|
| 268 |
+
if len(input_timestamps) > 0 and len(generated_timestamps) == 0:
|
| 269 |
+
# Everything in input_timestamps is a false negative => tp=0, fn = len(input_timestamps)
|
| 270 |
+
# => recall=0, precision is undefined (tp+fp=0), typically we treat precision=0
|
| 271 |
+
return 0.0, 0.0, 0.0
|
| 272 |
+
|
| 273 |
+
# If we get here, both arrays are non-empty
|
| 274 |
+
tp = 0
|
| 275 |
+
fp = 0
|
| 276 |
+
|
| 277 |
+
# Track matched ground-truth timestamps
|
| 278 |
+
matched_inputs = np.zeros(len(input_timestamps), dtype=bool)
|
| 279 |
+
|
| 280 |
+
for gen_ts in generated_timestamps:
|
| 281 |
+
# Calculate absolute differences to each reference timestamp
|
| 282 |
+
diffs = np.abs(input_timestamps - gen_ts)
|
| 283 |
+
# Find index of the closest input timestamp
|
| 284 |
+
min_diff_idx = np.argmin(diffs)
|
| 285 |
+
|
| 286 |
+
# Check if that difference is within tolerance and unmatched
|
| 287 |
+
if diffs[min_diff_idx] < tolerance and not matched_inputs[min_diff_idx]:
|
| 288 |
+
tp += 1
|
| 289 |
+
matched_inputs[min_diff_idx] = True
|
| 290 |
+
else:
|
| 291 |
+
fp += 1 # no suitable match found or closest was already matched
|
| 292 |
+
|
| 293 |
+
# Remaining unmatched input timestamps are false negatives
|
| 294 |
+
fn = np.sum(~matched_inputs)
|
| 295 |
+
|
| 296 |
+
# Compute precision, recall, f1
|
| 297 |
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
| 298 |
+
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
| 299 |
+
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
|
| 300 |
+
|
| 301 |
+
return precision, recall, f1
|
utils/feature_extractor.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
class dynamics_extractor_full_stereo(nn.Module):
|
| 5 |
+
def __init__(self):
|
| 6 |
+
super().__init__()
|
| 7 |
+
self.conv1d_1 = nn.Conv1d(2, 16, kernel_size=3, padding=1, stride=2)
|
| 8 |
+
self.conv1d_2 = nn.Conv1d(16, 16, kernel_size=3, padding=1)
|
| 9 |
+
self.conv1d_3 = nn.Conv1d(16, 128, kernel_size=3, padding=1, stride=2)
|
| 10 |
+
self.conv1d_4 = nn.Conv1d(128, 128, kernel_size=3, padding=1)
|
| 11 |
+
self.conv1d_5 = nn.Conv1d(128, 256, kernel_size=3, padding=1, stride=2)
|
| 12 |
+
def forward(self, x):
|
| 13 |
+
# original shape: (batchsize, 1, 8280)
|
| 14 |
+
# x = x.unsqueeze(1) # shape: (batchsize, 1, 8280)
|
| 15 |
+
x = self.conv1d_1(x) # shape: (batchsize, 16, 4140)
|
| 16 |
+
x = F.silu(x)
|
| 17 |
+
x = self.conv1d_2(x) # shape: (batchsize, 16, 4140)
|
| 18 |
+
x = F.silu(x)
|
| 19 |
+
x = self.conv1d_3(x) # shape: (batchsize, 128, 2070)
|
| 20 |
+
x = F.silu(x)
|
| 21 |
+
x = self.conv1d_4(x) # shape: (batchsize, 128, 2070)
|
| 22 |
+
x = F.silu(x)
|
| 23 |
+
x = self.conv1d_5(x) # shape: (batchsize, 192, 1035)
|
| 24 |
+
return x
|
| 25 |
+
class melody_extractor_full_mono(nn.Module):
|
| 26 |
+
def __init__(self):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.conv1d_1 = nn.Conv1d(128, 256, kernel_size=3, padding=0, stride=2)
|
| 29 |
+
self.conv1d_2 = nn.Conv1d(256, 256, kernel_size=3, padding=1)
|
| 30 |
+
self.conv1d_3 = nn.Conv1d(256, 512, kernel_size=3, padding=1, stride=2)
|
| 31 |
+
self.conv1d_4 = nn.Conv1d(512, 512, kernel_size=3, padding=1)
|
| 32 |
+
self.conv1d_5 = nn.Conv1d(512, 768, kernel_size=3, padding=1)
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
# original shape: (batchsize, 12, 1296)
|
| 35 |
+
x = self.conv1d_1(x)# shape: (batchsize, 64, 2048)
|
| 36 |
+
x = F.silu(x)
|
| 37 |
+
x = self.conv1d_2(x) # shape: (batchsize, 64, 2048)
|
| 38 |
+
x = F.silu(x)
|
| 39 |
+
x = self.conv1d_3(x) # shape: (batchsize, 128, 1024)
|
| 40 |
+
x = F.silu(x)
|
| 41 |
+
x = self.conv1d_4(x) # shape: (batchsize, 128, 1024)
|
| 42 |
+
x = F.silu(x)
|
| 43 |
+
x = self.conv1d_5(x) # shape: (batchsize, 768, 1024)
|
| 44 |
+
return x
|
| 45 |
+
class melody_extractor_mono(nn.Module):
|
| 46 |
+
def __init__(self):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.conv1d_1 = nn.Conv1d(128, 128, kernel_size=3, padding=0, stride=2)
|
| 49 |
+
self.conv1d_2 = nn.Conv1d(128, 192, kernel_size=3, padding=1, stride=2)
|
| 50 |
+
self.conv1d_3 = nn.Conv1d(192, 192, kernel_size=3, padding=1)
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
# original shape: (batchsize, 12, 1296)
|
| 53 |
+
x = self.conv1d_1(x)# shape: (batchsize, 64, 2048)
|
| 54 |
+
x = F.silu(x)
|
| 55 |
+
x = self.conv1d_2(x) # shape: (batchsize, 64, 2048)
|
| 56 |
+
x = F.silu(x)
|
| 57 |
+
x = self.conv1d_3(x) # shape: (batchsize, 128, 1024)
|
| 58 |
+
return x
|
| 59 |
+
|
| 60 |
+
class melody_extractor_full_stereo(nn.Module):
|
| 61 |
+
def __init__(self):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.embed = nn.Embedding(num_embeddings=129, embedding_dim=48)
|
| 64 |
+
|
| 65 |
+
# Four Conv1d layers, each with kernel_size=3, padding=1:
|
| 66 |
+
self.conv1 = nn.Conv1d(384, 384, kernel_size=3, padding=1)
|
| 67 |
+
self.conv2 = nn.Conv1d(384, 768, kernel_size=3, padding=1)
|
| 68 |
+
self.conv3 = nn.Conv1d(768, 768, kernel_size=3, padding=1)
|
| 69 |
+
|
| 70 |
+
def forward(self, melody_idxs):
|
| 71 |
+
# melody_idxs: LongTensor of shape (B, 8, 4096)
|
| 72 |
+
B, eight, L = melody_idxs.shape # L == 4096
|
| 73 |
+
|
| 74 |
+
# 1) Embed:
|
| 75 |
+
# (B, 8, 4096) → (B, 8, 4096, 48)
|
| 76 |
+
embedded = self.embed(melody_idxs)
|
| 77 |
+
|
| 78 |
+
# 2) Permute & reshape → (B, 8*48, 4096) = (B, 384, 4096)
|
| 79 |
+
x = embedded.permute(0, 1, 3, 2) # (B, 8, 48, 4096)
|
| 80 |
+
x = x.reshape(B, eight * 48, L) # (B, 384, 4096)
|
| 81 |
+
|
| 82 |
+
# 3) Conv1 → (B, 384, 4096)
|
| 83 |
+
x = F.silu(self.conv1(x))
|
| 84 |
+
|
| 85 |
+
# 4) Conv2 → (B, 768, 4096)
|
| 86 |
+
x = F.silu(self.conv2(x))
|
| 87 |
+
|
| 88 |
+
# 5) Conv3 → (B, 768, 4096)
|
| 89 |
+
x = F.silu(self.conv3(x))
|
| 90 |
+
|
| 91 |
+
# Now x is (B, 1536, 4096) and can be sent on to whatever comes next
|
| 92 |
+
return x
|
| 93 |
+
class melody_extractor_stereo(nn.Module):
|
| 94 |
+
def __init__(self):
|
| 95 |
+
super().__init__()
|
| 96 |
+
self.embed = nn.Embedding(num_embeddings=129, embedding_dim=4)
|
| 97 |
+
|
| 98 |
+
# Four Conv1d layers, each with kernel_size=3, padding=1:
|
| 99 |
+
self.conv1 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
|
| 100 |
+
self.conv2 = nn.Conv1d(64, 64, kernel_size=3, padding=0, stride=2)
|
| 101 |
+
self.conv3 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
|
| 102 |
+
self.conv4 = nn.Conv1d(128, 128, kernel_size=3, padding=1, stride=2)
|
| 103 |
+
self.conv5 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
|
| 104 |
+
|
| 105 |
+
def forward(self, melody_idxs):
|
| 106 |
+
# melody_idxs: LongTensor of shape (B, 8, 4096)
|
| 107 |
+
B, eight, L = melody_idxs.shape # L == 4096
|
| 108 |
+
|
| 109 |
+
# 1) Embed:
|
| 110 |
+
# (B, 8, 4096) → (B, 8, 4096, 4)
|
| 111 |
+
embedded = self.embed(melody_idxs)
|
| 112 |
+
|
| 113 |
+
# 2) Permute & reshape → (B, 8*4, 4096) = (B, 32, 4096)
|
| 114 |
+
x = embedded.permute(0, 1, 3, 2) # (B, 8, 4, 4096)
|
| 115 |
+
x = x.reshape(B, eight * 4, L) # (B, 32, 4096)
|
| 116 |
+
|
| 117 |
+
# 3) Conv1 → (B, 384, 4096)
|
| 118 |
+
x = F.silu(self.conv1(x))
|
| 119 |
+
|
| 120 |
+
# 4) Conv2 → (B, 768, 4096)
|
| 121 |
+
x = F.silu(self.conv2(x))
|
| 122 |
+
|
| 123 |
+
# 5) Conv3 → (B, 768, 4096)
|
| 124 |
+
x = F.silu(self.conv3(x))
|
| 125 |
+
|
| 126 |
+
x = F.silu(self.conv4(x))
|
| 127 |
+
|
| 128 |
+
x = F.silu(self.conv5(x))
|
| 129 |
+
|
| 130 |
+
# Now x is (B, 1536, 4096) and can be sent on to whatever comes next
|
| 131 |
+
return x
|
| 132 |
+
|
| 133 |
+
class dynamics_extractor(nn.Module):
|
| 134 |
+
def __init__(self):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.conv1d_1 = nn.Conv1d(1, 16, kernel_size=3, padding=1, stride=2)
|
| 137 |
+
self.conv1d_2 = nn.Conv1d(16, 16, kernel_size=3, padding=1)
|
| 138 |
+
self.conv1d_3 = nn.Conv1d(16, 128, kernel_size=3, padding=1, stride=2)
|
| 139 |
+
self.conv1d_4 = nn.Conv1d(128, 128, kernel_size=3, padding=1)
|
| 140 |
+
self.conv1d_5 = nn.Conv1d(128, 192, kernel_size=3, padding=1, stride=2)
|
| 141 |
+
def forward(self, x):
|
| 142 |
+
# original shape: (batchsize, 1, 8280)
|
| 143 |
+
# x = x.unsqueeze(1) # shape: (batchsize, 1, 8280)
|
| 144 |
+
x = self.conv1d_1(x) # shape: (batchsize, 16, 4140)
|
| 145 |
+
x = F.silu(x)
|
| 146 |
+
x = self.conv1d_2(x) # shape: (batchsize, 16, 4140)
|
| 147 |
+
x = F.silu(x)
|
| 148 |
+
x = self.conv1d_3(x) # shape: (batchsize, 128, 2070)
|
| 149 |
+
x = F.silu(x)
|
| 150 |
+
x = self.conv1d_4(x) # shape: (batchsize, 128, 2070)
|
| 151 |
+
x = F.silu(x)
|
| 152 |
+
x = self.conv1d_5(x) # shape: (batchsize, 192, 1035)
|
| 153 |
+
return x
|
| 154 |
+
class rhythm_extractor(nn.Module):
|
| 155 |
+
def __init__(self):
|
| 156 |
+
super().__init__()
|
| 157 |
+
self.conv1d_1 = nn.Conv1d(2, 16, kernel_size=3, padding=1)
|
| 158 |
+
self.conv1d_2 = nn.Conv1d(16, 64, kernel_size=3, padding=1)
|
| 159 |
+
self.conv1d_3 = nn.Conv1d(64, 128, kernel_size=3, padding=1, stride=2)
|
| 160 |
+
self.conv1d_4 = nn.Conv1d(128, 128, kernel_size=3, padding=1)
|
| 161 |
+
self.conv1d_5 = nn.Conv1d(128, 192, kernel_size=3, padding=1, stride=2)
|
| 162 |
+
def forward(self, x):
|
| 163 |
+
# original shape: (batchsize, 2, 3000)
|
| 164 |
+
x = self.conv1d_1(x)# shape: (batchsize, 64, 3000)
|
| 165 |
+
x = F.silu(x)
|
| 166 |
+
x = self.conv1d_2(x) # shape: (batchsize, 64, 3000)
|
| 167 |
+
x = F.silu(x)
|
| 168 |
+
x = self.conv1d_3(x) # shape: (batchsize, 128, 1500)
|
| 169 |
+
x = F.silu(x)
|
| 170 |
+
x = self.conv1d_4(x) # shape: (batchsize, 128, 1500)
|
| 171 |
+
x = F.silu(x)
|
| 172 |
+
x = self.conv1d_5(x) # shape:
|
| 173 |
+
return x
|
utils/stable_audio_dataset_utils.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
import torchaudio
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torchaudio import transforms as T
|
| 9 |
+
|
| 10 |
+
def load_audio_file(filename, target_sr=44100, target_samples=2097152):
|
| 11 |
+
try:
|
| 12 |
+
audio, in_sr = torchaudio.load(filename)
|
| 13 |
+
# Resample if necessary
|
| 14 |
+
if in_sr != target_sr:
|
| 15 |
+
resampler = T.Resample(in_sr, target_sr)
|
| 16 |
+
audio = resampler(audio)
|
| 17 |
+
augs = torch.nn.Sequential(
|
| 18 |
+
PhaseFlipper(),
|
| 19 |
+
)
|
| 20 |
+
audio = augs(audio)
|
| 21 |
+
audio = audio.clamp(-1, 1)
|
| 22 |
+
encoding = torch.nn.Sequential(
|
| 23 |
+
Stereo(),
|
| 24 |
+
)
|
| 25 |
+
audio = encoding(audio)
|
| 26 |
+
# audio.shape is [channels, samples]
|
| 27 |
+
num_samples = audio.shape[-1]
|
| 28 |
+
|
| 29 |
+
# if num_samples < target_samples:
|
| 30 |
+
# # Pad if it's too short
|
| 31 |
+
# pad_amount = target_samples - num_samples
|
| 32 |
+
# # Zero-pad at the end (or randomly if you prefer)
|
| 33 |
+
# audio = F.pad(audio, (0, pad_amount))
|
| 34 |
+
# print(f"pad {pad_amount}")
|
| 35 |
+
# else:
|
| 36 |
+
audio = audio[:, :target_samples]
|
| 37 |
+
return audio
|
| 38 |
+
except RuntimeError:
|
| 39 |
+
print(f"Failed to decode audio file: {filename}")
|
| 40 |
+
return None
|
| 41 |
+
class PadCrop(nn.Module):
|
| 42 |
+
def __init__(self, n_samples, randomize=True):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.n_samples = n_samples
|
| 45 |
+
self.randomize = randomize
|
| 46 |
+
|
| 47 |
+
def __call__(self, signal):
|
| 48 |
+
n, s = signal.shape
|
| 49 |
+
start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
|
| 50 |
+
end = start + self.n_samples
|
| 51 |
+
output = signal.new_zeros([n, self.n_samples])
|
| 52 |
+
output[:, :min(s, self.n_samples)] = signal[:, start:end]
|
| 53 |
+
return output
|
| 54 |
+
|
| 55 |
+
class PadCrop_Normalized_T(nn.Module):
|
| 56 |
+
|
| 57 |
+
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
|
| 58 |
+
|
| 59 |
+
super().__init__()
|
| 60 |
+
|
| 61 |
+
self.n_samples = n_samples
|
| 62 |
+
self.sample_rate = sample_rate
|
| 63 |
+
self.randomize = randomize
|
| 64 |
+
|
| 65 |
+
def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]:
|
| 66 |
+
|
| 67 |
+
n_channels, n_samples = source.shape
|
| 68 |
+
|
| 69 |
+
# If the audio is shorter than the desired length, pad it
|
| 70 |
+
upper_bound = max(0, n_samples - self.n_samples)
|
| 71 |
+
|
| 72 |
+
# If randomize is False, always start at the beginning of the audio
|
| 73 |
+
offset = 0
|
| 74 |
+
if(self.randomize and n_samples > self.n_samples):
|
| 75 |
+
offset = random.randint(0, upper_bound)
|
| 76 |
+
|
| 77 |
+
# Calculate the start and end times of the chunk
|
| 78 |
+
t_start = offset / (upper_bound + self.n_samples)
|
| 79 |
+
t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
|
| 80 |
+
|
| 81 |
+
# Create the chunk
|
| 82 |
+
chunk = source.new_zeros([n_channels, self.n_samples])
|
| 83 |
+
|
| 84 |
+
# Copy the audio into the chunk
|
| 85 |
+
chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]
|
| 86 |
+
|
| 87 |
+
# Calculate the start and end times of the chunk in seconds
|
| 88 |
+
seconds_start = math.floor(offset / self.sample_rate)
|
| 89 |
+
seconds_total = math.ceil(n_samples / self.sample_rate)
|
| 90 |
+
|
| 91 |
+
# Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
|
| 92 |
+
padding_mask = torch.zeros([self.n_samples])
|
| 93 |
+
padding_mask[:min(n_samples, self.n_samples)] = 1
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
return (
|
| 97 |
+
chunk,
|
| 98 |
+
offset,
|
| 99 |
+
offset + self.n_samples,
|
| 100 |
+
seconds_start,
|
| 101 |
+
seconds_total,
|
| 102 |
+
padding_mask
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
class PhaseFlipper(nn.Module):
|
| 106 |
+
"Randomly invert the phase of a signal"
|
| 107 |
+
def __init__(self, p=0.5):
|
| 108 |
+
super().__init__()
|
| 109 |
+
self.p = p
|
| 110 |
+
def __call__(self, signal):
|
| 111 |
+
return -signal if (random.random() < self.p) else signal
|
| 112 |
+
|
| 113 |
+
class Mono(nn.Module):
|
| 114 |
+
def __call__(self, signal):
|
| 115 |
+
return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal
|
| 116 |
+
|
| 117 |
+
class Stereo(nn.Module):
|
| 118 |
+
def __call__(self, signal):
|
| 119 |
+
signal_shape = signal.shape
|
| 120 |
+
# Check if it's mono
|
| 121 |
+
if len(signal_shape) == 1: # s -> 2, s
|
| 122 |
+
signal = signal.unsqueeze(0).repeat(2, 1)
|
| 123 |
+
elif len(signal_shape) == 2:
|
| 124 |
+
if signal_shape[0] == 1: #1, s -> 2, s
|
| 125 |
+
signal = signal.repeat(2, 1)
|
| 126 |
+
elif signal_shape[0] > 2: #?, s -> 2,s
|
| 127 |
+
signal = signal[:2, :]
|
| 128 |
+
|
| 129 |
+
return signal
|