manoskary committed on
Commit
7a421a5
·
1 Parent(s): a52e073

Add audio utilities and track sample audio with LFS

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
37
+ *.wav filter=lfs diff=lfs merge=lfs -text
MuseControlLite_setup.py ADDED
@@ -0,0 +1,874 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Callable, List, Optional, Tuple, Union
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch import nn
18
+ from diffusers.utils import deprecate, logging
19
+ from safetensors.torch import load_file
20
+ from diffusers.loaders import AttnProcsLayers
21
+ from utils.extract_conditions import compute_melody, compute_melody_v2, compute_dynamics, extract_melody_one_hot, evaluate_f1_rhythm
22
+ from madmom.features.downbeats import DBNDownBeatTrackingProcessor,RNNDownBeatProcessor
23
+ import numpy as np
24
+ import matplotlib.pyplot as plt
25
+ import os
26
+ from utils.stable_audio_dataset_utils import load_audio_file
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+ import soundfile as sf
29
+
30
+ # For zero initialized 1D CNN in the attention processor
31
+ def zero_module(module):
32
+ for p in module.parameters():
33
+ nn.init.zeros_(p)
34
+ return module
35
+
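# Illustrative sketch (not part of this commit): because the adapters below pass their
# `conv_out` projection through `zero_module`, the decoupled cross-attention branch
# contributes exactly zero at initialization, so training starts from the unmodified
# Stable Audio behaviour. A minimal, self-contained check of that property
# (`demo_conv` / `demo_x` are example names):

demo_conv = zero_module(nn.Conv1d(4, 4, kernel_size=1, bias=False))
demo_x = torch.randn(1, 4, 8)
assert torch.all(demo_conv(demo_x) == 0)  # the zero-initialized branch is a no-op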
36
+ # Original attention processor for Stable Audio (used here for the self-attention layers)
37
+ class StableAudioAttnProcessor2_0(torch.nn.Module):
38
+ r"""
39
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
40
+ used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
41
+ """
42
+
43
+ def __init__(self):
44
+ super().__init__()
45
+ if not hasattr(F, "scaled_dot_product_attention"):
46
+ raise ImportError(
47
+ "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
48
+ )
49
+ def apply_partial_rotary_emb(
50
+ self,
51
+ x: torch.Tensor,
52
+ freqs_cis: Tuple[torch.Tensor],
53
+ ) -> torch.Tensor:
54
+ from diffusers.models.embeddings import apply_rotary_emb
55
+
56
+ rot_dim = freqs_cis[0].shape[-1]
57
+ x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
58
+
59
+ x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
60
+
61
+ out = torch.cat((x_rotated, x_unrotated), dim=-1)
62
+ return out
63
+
64
+ def __call__(
65
+ self,
66
+ attn,
67
+ hidden_states: torch.Tensor,
68
+ encoder_hidden_states: Optional[torch.Tensor] = None,
69
+ attention_mask: Optional[torch.Tensor] = None,
70
+ rotary_emb: Optional[torch.Tensor] = None,
71
+ ) -> torch.Tensor:
72
+ from diffusers.models.embeddings import apply_rotary_emb
73
+
74
+ residual = hidden_states
75
+
76
+ input_ndim = hidden_states.ndim
77
+
78
+ if input_ndim == 4:
79
+ batch_size, channel, height, width = hidden_states.shape
80
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
81
+
82
+ batch_size, sequence_length, _ = (
83
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
84
+ )
85
+ if attention_mask is not None:
86
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
87
+ # scaled_dot_product_attention expects attention_mask shape to be
88
+ # (batch, heads, source_length, target_length)
89
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
90
+
91
+ query = attn.to_q(hidden_states)
92
+
93
+ if encoder_hidden_states is None:
94
+ encoder_hidden_states = hidden_states
95
+ elif attn.norm_cross:
96
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
97
+
98
+ key = attn.to_k(encoder_hidden_states)
99
+ value = attn.to_v(encoder_hidden_states)
100
+ head_dim = query.shape[-1] // attn.heads
101
+ kv_heads = key.shape[-1] // head_dim
102
+
103
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
104
+
105
+ key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
106
+ value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
107
+
108
+ if kv_heads != attn.heads:
109
+ # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
110
+ heads_per_kv_head = attn.heads // kv_heads
111
+ key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
112
+ value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
113
+
114
+ if attn.norm_q is not None:
115
+ query = attn.norm_q(query)
116
+ if attn.norm_k is not None:
117
+ key = attn.norm_k(key)
118
+
119
+ # Apply RoPE if needed
120
+ if rotary_emb is not None:
121
+ query_dtype = query.dtype
122
+ key_dtype = key.dtype
123
+ query = query.to(torch.float32)
124
+ key = key.to(torch.float32)
125
+
126
+ rot_dim = rotary_emb[0].shape[-1]
127
+ query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
128
+ query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
129
+
130
+ query = torch.cat((query_rotated, query_unrotated), dim=-1)
131
+
132
+ if not attn.is_cross_attention:
133
+ key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
134
+ key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
135
+
136
+ key = torch.cat((key_rotated, key_unrotated), dim=-1)
137
+
138
+ query = query.to(query_dtype)
139
+ key = key.to(key_dtype)
140
+
141
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
142
+ # TODO: add support for attn.scale when we move to Torch 2.1
143
+ hidden_states = F.scaled_dot_product_attention(
144
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
145
+ )
146
+ # print("hidden_states", hidden_states.shape)
147
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
148
+ hidden_states = hidden_states.to(query.dtype)
149
+
150
+ # linear proj
151
+ hidden_states = attn.to_out[0](hidden_states)
152
+ # dropout
153
+ hidden_states = attn.to_out[1](hidden_states)
154
+
155
+ if input_ndim == 4:
156
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
157
+
158
+ if attn.residual_connection:
159
+ hidden_states = hidden_states + residual
160
+
161
+ hidden_states = hidden_states / attn.rescale_output_factor
162
+
163
+ return hidden_states
164
+
165
+ # The attention processor used in MuseControlLite, using 1 decoupled cross-attention layer
166
+ class StableAudioAttnProcessor2_0_rotary(torch.nn.Module):
167
+ r"""
168
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
169
+ used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
170
+ """
171
+ def __init__(self, layer_id, hidden_size, name, cross_attention_dim=None, num_tokens=4, scale=1.0):
172
+ if not hasattr(F, "scaled_dot_product_attention"):
173
+ raise ImportError(
174
+ "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
175
+ )
176
+ super().__init__()
177
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
178
+ self.layer_id = layer_id
179
+ self.hidden_size = hidden_size
180
+ self.cross_attention_dim = cross_attention_dim
181
+ self.num_tokens = num_tokens
182
+ self.scale = scale
183
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
184
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
185
+ self.name = name
186
+ self.conv_out = zero_module(nn.Conv1d(1536,1536,kernel_size=1, padding=0, bias=False))
187
+ self.rotary_emb = LlamaRotaryEmbedding(dim = 64)
188
+ self.to_k_ip.weight.requires_grad = True
189
+ self.to_v_ip.weight.requires_grad = True
190
+ self.conv_out.weight.requires_grad = True
191
+ def rotate_half(self, x):
192
+ x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
193
+ x1, x2 = x.unbind(-1)
194
+ return torch.cat((-x2, x1), dim=-1)
195
+
196
+
197
+ def __call__(
198
+ self,
199
+ attn,
200
+ hidden_states: torch.Tensor,
201
+ encoder_hidden_states: Optional[torch.Tensor] = None,
202
+ encoder_hidden_states_con: Optional[torch.Tensor] = None,
203
+ encoder_hidden_states_audio: Optional[torch.Tensor] = None,
204
+ attention_mask: Optional[torch.Tensor] = None,
205
+ rotary_emb: Optional[torch.Tensor] = None,
206
+ ) -> torch.Tensor:
207
+ from diffusers.models.embeddings import apply_rotary_emb
208
+
209
+ residual = hidden_states
210
+
211
+ input_ndim = hidden_states.ndim
212
+
213
+ if input_ndim == 4:
214
+ batch_size, channel, height, width = hidden_states.shape
215
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
216
+
217
+ batch_size, sequence_length, _ = (
218
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
219
+ )
220
+ if attention_mask is not None:
221
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
222
+ # scaled_dot_product_attention expects attention_mask shape to be
223
+ # (batch, heads, source_length, target_length)
224
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
225
+
226
+ # The original cross attention in Stable-audio
227
+ ###############################################################
228
+ query = attn.to_q(hidden_states)
229
+ ip_hidden_states = encoder_hidden_states_con
230
+ key = attn.to_k(encoder_hidden_states)
231
+ value = attn.to_v(encoder_hidden_states)
232
+ head_dim = query.shape[-1] // attn.heads
233
+ kv_heads = key.shape[-1] // head_dim
234
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
235
+ key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
236
+ value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
237
+
238
+ if kv_heads != attn.heads:
239
+ # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
240
+ heads_per_kv_head = attn.heads // kv_heads
241
+ key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
242
+ value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
243
+ if attn.norm_q is not None:
244
+ query = attn.norm_q(query)
245
+ if attn.norm_k is not None:
246
+ key = attn.norm_k(key)
247
+ # TODO: add support for attn.scale when we move to Torch 2.1
248
+ hidden_states = F.scaled_dot_product_attention(
249
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
250
+ )
251
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
252
+ hidden_states = hidden_states.to(query.dtype)
253
+ ###############################################################
254
+
255
+
256
+ # The decoupled cross-attention used in MuseControlLite to handle the additional conditions
257
+ ###############################################################
258
+ ip_key = self.to_k_ip(ip_hidden_states)
259
+ ip_value = self.to_v_ip(ip_hidden_states)
260
+ ip_key = ip_key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
261
+ ip_key_length = ip_key.shape[2]
262
+ ip_value = ip_value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
263
+ if kv_heads != attn.heads:
264
+ # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
265
+ heads_per_kv_head = attn.heads // kv_heads
266
+ ip_key = torch.repeat_interleave(ip_key, heads_per_kv_head, dim=1)
267
+ ip_value = torch.repeat_interleave(ip_value, heads_per_kv_head, dim=1)
268
+ ip_value_length = ip_value.shape[2]
269
+ seq_len_query = query.shape[2]
270
+
271
+ # Generate position_ids for query, keys, values
272
+ position_ids_query = torch.arange(seq_len_query, dtype=torch.long, device=query.device) * (ip_key_length / seq_len_query)
273
+ position_ids_query = position_ids_query.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_query]
274
+ position_ids_key = torch.arange(ip_key_length, dtype=torch.long, device=key.device)
275
+ position_ids_key = position_ids_key.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
276
+ position_ids_value = torch.arange(ip_value_length, dtype=torch.long, device=value.device)
277
+ position_ids_value = position_ids_value.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
278
+
279
+ # Rotate query, keys, values
280
+ cos, sin = self.rotary_emb(query, position_ids_query)
281
+ query_pos = (query * cos.unsqueeze(1)) + (self.rotate_half(query) * sin.unsqueeze(1))
282
+ cos, sin = self.rotary_emb(ip_key, position_ids_key)
283
+ ip_key = (ip_key * cos.unsqueeze(1)) + (self.rotate_half(ip_key) * sin.unsqueeze(1))
284
+ cos, sin = self.rotary_emb(ip_value, position_ids_value)
285
+ ip_value = (ip_value * cos.unsqueeze(1)) + (self.rotate_half(ip_value) * sin.unsqueeze(1))
286
+
287
+ ip_hidden_states = F.scaled_dot_product_attention(
288
+ query_pos, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
289
+ )
290
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
291
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
292
+ ip_hidden_states = ip_hidden_states.transpose(1, 2)
293
+ ip_hidden_states = self.conv_out(ip_hidden_states)
294
+ ip_hidden_states = ip_hidden_states.transpose(1, 2)
295
+ ###############################################################
296
+
297
+ # Combine the output of the two cross-attention layers
298
+ hidden_states = hidden_states + self.scale * ip_hidden_states
299
+ # linear proj
300
+ hidden_states = attn.to_out[0](hidden_states)
301
+ # dropout
302
+ hidden_states = attn.to_out[1](hidden_states)
303
+
304
+ if input_ndim == 4:
305
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
306
+
307
+ if attn.residual_connection:
308
+ hidden_states = hidden_states + residual
309
+
310
+ hidden_states = hidden_states / attn.rescale_output_factor
311
+
312
+ return hidden_states
313
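# Illustrative sketch (not part of this commit): the processor above runs two attentions
# from the same query -- one against the text tokens with the frozen Stable Audio weights,
# and one against the extra condition tokens via `to_k_ip` / `to_v_ip` with rotary position
# encoding -- and sums them as
#
#     out = Attn(q, k_text, v_text) + scale * conv_out(Attn(RoPE(q), RoPE(k_cond), RoPE(v_cond)))
#
# A toy-shaped version of the combination step (RoPE and conv_out omitted for brevity,
# all `toy_*` names are example placeholders):

toy_q = torch.randn(1, 8, 128, 64)        # (batch, heads, query_len, head_dim)
toy_k_text = torch.randn(1, 8, 77, 64)    # text keys/values
toy_v_text = torch.randn(1, 8, 77, 64)
toy_k_cond = torch.randn(1, 8, 1024, 64)  # extra-condition keys/values (e.g. melody features)
toy_v_cond = torch.randn(1, 8, 1024, 64)
toy_out = (
    F.scaled_dot_product_attention(toy_q, toy_k_text, toy_v_text)
    + 1.0 * F.scaled_dot_product_attention(toy_q, toy_k_cond, toy_v_cond)
)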
+ # The attention processor used in MuseControlLite, using 2 decoupled cross-attention layers. It needs further examination; do not use it for now.
314
+ class StableAudioAttnProcessor2_0_rotary_double(torch.nn.Module):
315
+ r"""
316
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
317
+ used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
318
+ """
319
+ def __init__(self, layer_id, hidden_size, name, cross_attention_dim=None, num_tokens=4, scale=1.0):
320
+ if not hasattr(F, "scaled_dot_product_attention"):
321
+ raise ImportError(
322
+ "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
323
+ )
324
+ super().__init__()
325
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
326
+ self.hidden_size = hidden_size
327
+ self.cross_attention_dim = cross_attention_dim
328
+ self.num_tokens = num_tokens
329
+ self.layer_id = layer_id
330
+ self.scale = scale
331
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
332
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
333
+ self.to_k_ip_audio = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
334
+ self.to_v_ip_audio = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
335
+ self.name = name
336
+ self.conv_out = zero_module(nn.Conv1d(1536,1536,kernel_size=1, padding=0, bias=False))
337
+ self.conv_out_audio = zero_module(nn.Conv1d(1536,1536,kernel_size=1, padding=0, bias=False))
338
+ self.rotary_emb = LlamaRotaryEmbedding(64)
339
+ self.to_k_ip.weight.requires_grad = True
340
+ self.to_v_ip.weight.requires_grad = True
341
+ self.conv_out.weight.requires_grad = True
342
+ # Note: the decoupled cross-attention weights can be initialized by copying the corresponding original attention weights.
343
+ def rotate_half(self, x):
344
+ x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
345
+ x1, x2 = x.unbind(-1)
346
+ return torch.cat((-x2, x1), dim=-1)
347
+
348
+
349
+ def __call__(
350
+ self,
351
+ attn,
352
+ hidden_states: torch.Tensor,
353
+ encoder_hidden_states: Optional[torch.Tensor] = None,
354
+ encoder_hidden_states_con: Optional[torch.Tensor] = None,
355
+ encoder_hidden_states_audio: Optional[torch.Tensor] = None,
356
+ attention_mask: Optional[torch.Tensor] = None,
357
+ ) -> torch.Tensor:
358
+ from diffusers.models.embeddings import apply_rotary_emb
359
+
360
+ residual = hidden_states
361
+
362
+ input_ndim = hidden_states.ndim
363
+
364
+ if input_ndim == 4:
365
+ batch_size, channel, height, width = hidden_states.shape
366
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
367
+
368
+ batch_size, sequence_length, _ = (
369
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
370
+ )
371
+ if attention_mask is not None:
372
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
373
+ # scaled_dot_product_attention expects attention_mask shape to be
374
+ # (batch, heads, source_length, target_length)
375
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
376
+
377
+ # The original cross attention in Stable-audio
378
+ ###############################################################
379
+ query = attn.to_q(hidden_states)
380
+ key = attn.to_k(encoder_hidden_states)
381
+ value = attn.to_v(encoder_hidden_states)
382
+ head_dim = query.shape[-1] // attn.heads
383
+ kv_heads = key.shape[-1] // head_dim
384
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
385
+ key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
386
+ value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
387
+
388
+ if kv_heads != attn.heads:
389
+ # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
390
+ heads_per_kv_head = attn.heads // kv_heads
391
+ key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
392
+ value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
393
+
394
+ if attn.norm_q is not None:
395
+ query = attn.norm_q(query)
396
+ if attn.norm_k is not None:
397
+ key = attn.norm_k(key)
398
+ # TODO: add support for attn.scale when we move to Torch 2.1
399
+ hidden_states = F.scaled_dot_product_attention(
400
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
401
+ )
402
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
403
+ hidden_states = hidden_states.to(query.dtype)
404
+ # if self.layer_id == "0":
405
+ # hidden_states_sliced = hidden_states[:,1:,:]
406
+ # # Create a tensor of zeros with shape (bs, 1, 768)
407
+ # bs, _, dim2 = hidden_states_sliced.shape
408
+ # zeros = torch.zeros(bs, 1, dim2).cuda()
409
+ # # Concatenate the zero tensor along the second dimension (dim=1)
410
+ # hidden_states_sliced = torch.cat((hidden_states_sliced, zeros), dim=1)
411
+ # query_sliced = attn.to_q(hidden_states_sliced)
412
+ # query_sliced = query_sliced.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
413
+ # query = query_sliced
414
+ ip_hidden_states = encoder_hidden_states_con
415
+ ip_hidden_states_audio = encoder_hidden_states_audio
416
+ ip_key = self.to_k_ip(ip_hidden_states)
417
+ ip_value = self.to_v_ip(ip_hidden_states)
418
+ ip_key = ip_key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
419
+ ip_key_length = ip_key.shape[2]
420
+ ip_value = ip_value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
421
+ ip_key_audio = self.to_k_ip_audio(ip_hidden_states_audio)
422
+ ip_value_audio = self.to_v_ip_audio(ip_hidden_states_audio)
423
+ ip_key_audio = ip_key_audio.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
424
+ ip_key_audio_length = ip_key_audio.shape[2]
425
+ ip_value_audio = ip_value_audio.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
426
+
427
+ if kv_heads != attn.heads:
428
+ # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
429
+ heads_per_kv_head = attn.heads // kv_heads
430
+ ip_key = torch.repeat_interleave(ip_key, heads_per_kv_head, dim=1)
431
+ ip_value = torch.repeat_interleave(ip_value, heads_per_kv_head, dim=1)
432
+ ip_key_audio = torch.repeat_interleave(ip_key_audio, heads_per_kv_head, dim=1)
433
+ ip_value_audio = torch.repeat_interleave(ip_value_audio, heads_per_kv_head, dim=1)
434
+
435
+ ip_value_length = ip_value.shape[2]
436
+ seq_len_query = query.shape[2]
437
+ ip_value_audio_length = ip_value_audio.shape[2]
438
+
439
+ position_ids_query = torch.arange(seq_len_query, dtype=torch.long, device=query.device) * (ip_key_length / seq_len_query)
440
+ position_ids_query = position_ids_query.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_query]
441
+
442
+ # Generate position_ids for keys
443
+ position_ids_key = torch.arange(ip_key_length, dtype=torch.long, device=key.device)
444
+ position_ids_key = position_ids_key.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
445
+ position_ids_value = torch.arange(ip_value_length, dtype=torch.long, device=value.device)
446
+ position_ids_value = position_ids_value.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
447
+ # Generate position_ids for keys
448
+ position_ids_query_audio = torch.arange(seq_len_query, dtype=torch.long, device=query.device) * (ip_key_audio_length / seq_len_query)
449
+ position_ids_query_audio = position_ids_query_audio.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_query]
450
+ position_ids_key_audio = torch.arange(ip_key_audio_length, dtype=torch.long, device=key.device)
451
+ position_ids_key_audio = position_ids_key_audio.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
452
+ position_ids_value_audio = torch.arange(ip_value_audio_length, dtype=torch.long, device=value.device)
453
+ position_ids_value_audio = position_ids_value_audio.unsqueeze(0).expand(batch_size, -1) # Shape: [batch_size, seq_len_key]
454
+ cos, sin = self.rotary_emb(query, position_ids_query)
455
+ cos_audio, sin_audio = self.rotary_emb(query, position_ids_query_audio)
456
+ query_pos = (query * cos.unsqueeze(1)) + (self.rotate_half(query) * sin.unsqueeze(1))
457
+ query_pos_audio = (query * cos_audio.unsqueeze(1)) + (self.rotate_half(query) * sin_audio.unsqueeze(1))
458
+
459
+ cos, sin = self.rotary_emb(ip_key, position_ids_key)
460
+ cos_audio, sin_audio = self.rotary_emb(ip_key_audio, position_ids_key_audio)
461
+ ip_key = (ip_key * cos.unsqueeze(1)) + (self.rotate_half(ip_key) * sin.unsqueeze(1))
462
+ ip_key_audio = (ip_key_audio * cos_audio.unsqueeze(1)) + (self.rotate_half(ip_key_audio) * sin_audio.unsqueeze(1))
463
+
464
+ cos, sin = self.rotary_emb(ip_value, position_ids_value)
465
+ cos_audio, sin_audio = self.rotary_emb(ip_value_audio, position_ids_value_audio)
466
+ ip_value = (ip_value * cos.unsqueeze(1)) + (self.rotate_half(ip_value) * sin.unsqueeze(1))
467
+ ip_value_audio = (ip_value_audio * cos_audio.unsqueeze(1)) + (self.rotate_half(ip_value_audio) * sin_audio.unsqueeze(1))
468
+
469
+ with torch.amp.autocast(device_type='cuda'):
470
+ ip_hidden_states = F.scaled_dot_product_attention(
471
+ query_pos, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
472
+ )
473
+ with torch.amp.autocast(device_type='cuda'):
474
+ ip_hidden_states_audio = F.scaled_dot_product_attention(
475
+ query_pos_audio, ip_key_audio, ip_value_audio, attn_mask=None, dropout_p=0.0, is_causal=False
476
+ )
477
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
478
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
479
+ ip_hidden_states = ip_hidden_states.transpose(1, 2)
480
+
481
+ ip_hidden_states_audio = ip_hidden_states_audio.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
482
+ ip_hidden_states_audio = ip_hidden_states_audio.to(query.dtype)
483
+ ip_hidden_states_audio = ip_hidden_states_audio.transpose(1, 2)
484
+
485
+ with torch.amp.autocast(device_type='cuda'):
486
+ ip_hidden_states = self.conv_out(ip_hidden_states)
487
+ ip_hidden_states = ip_hidden_states.transpose(1, 2)
488
+
489
+ with torch.amp.autocast(device_type='cuda'):
490
+ ip_hidden_states_audio = self.conv_out_audio(ip_hidden_states_audio)
491
+ ip_hidden_states_audio = ip_hidden_states_audio.transpose(1, 2)
492
+
493
+ # Combine the tensors
494
+ hidden_states = hidden_states + self.scale * ip_hidden_states + ip_hidden_states_audio
495
+
496
+ # linear proj
497
+ hidden_states = attn.to_out[0](hidden_states)
498
+ # dropout
499
+ hidden_states = attn.to_out[1](hidden_states)
500
+
501
+ if input_ndim == 4:
502
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
503
+
504
+ if attn.residual_connection:
505
+ hidden_states = hidden_states + residual
506
+
507
+ hidden_states = hidden_states / attn.rescale_output_factor
508
+
509
+ return hidden_states
510
+ def setup_MuseControlLite(config, weight_dtype, transformer_ckpt):
511
+ """
512
+ Set up the MuseControlLite (AP-adapter) pipeline with attention processors and load checkpoints.
513
+
514
+ Args:
515
+ config: Configuration dictionary
516
+ weight_dtype: Weight data type for the pipeline
517
+ transformer_ckpt: Path to transformer checkpoint
518
+ Returns:
519
+ StableAudioPipeline: the configured pipeline with MuseControlLite attention processors installed
520
+ """
521
+ if 'audio' in config['condition_type'] and len(config['condition_type'])!=1:
522
+ from pipeline.stable_audio_multi_cfg_pipe_audio import StableAudioPipeline
523
+ attn_processor = StableAudioAttnProcessor2_0_rotary_double
524
+ audio_state_dict = load_file(config["audio_transformer_ckpt"], device="cpu")
525
+ else:
526
+ from pipeline.stable_audio_multi_cfg_pipe import StableAudioPipeline
527
+ attn_processor = StableAudioAttnProcessor2_0_rotary
528
+ pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", torch_dtype=weight_dtype)
529
+ pipe.scheduler.config.sigma_max = config["sigma_max"]
530
+ pipe.scheduler.config.sigma_min = config["sigma_min"]
531
+ transformer = pipe.transformer
532
+ attn_procs = {}
533
+ for name in transformer.attn_processors.keys():
534
+ if name.endswith("attn1.processor"):
535
+ attn_procs[name] = StableAudioAttnProcessor2_0()
536
+ else:
537
+ attn_procs[name] = attn_processor(
538
+ layer_id = name.split(".")[1],
539
+ hidden_size=768,
540
+ name=name,
541
+ cross_attention_dim=768,
542
+ scale=config['ap_scale'],
543
+ ).to("cuda", dtype=torch.float)
544
+ if transformer_ckpt is not None:
545
+ state_dict = load_file(transformer_ckpt, device="cuda")
546
+ for name, processor in attn_procs.items():
547
+ if isinstance(processor, attn_processor):
548
+ weight_name_v = name + ".to_v_ip.weight"
549
+ weight_name_k = name + ".to_k_ip.weight"
550
+ conv_out_weight = name + ".conv_out.weight"
551
+ processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].to(torch.float32))
552
+ processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].to(torch.float32))
553
+ processor.conv_out.weight = torch.nn.Parameter(state_dict[conv_out_weight].to(torch.float32))
554
+ if attn_processor == StableAudioAttnProcessor2_0_rotary_double:
555
+ audio_weight_name_v = name + ".to_v_ip.weight"
556
+ audio_weight_name_k = name + ".to_k_ip.weight"
557
+ audio_conv_out_weight = name + ".conv_out.weight"
558
+ processor.to_v_ip_audio.weight = torch.nn.Parameter(audio_state_dict[audio_weight_name_v].to(torch.float32))
559
+ processor.to_k_ip_audio.weight = torch.nn.Parameter(audio_state_dict[audio_weight_name_k].to(torch.float32))
560
+ processor.conv_out_audio.weight = torch.nn.Parameter(audio_state_dict[audio_conv_out_weight].to(torch.float32))
561
+ transformer.set_attn_processor(attn_procs)
562
+ class _Wrapper(AttnProcsLayers):
563
+ def forward(self, *args, **kwargs):
564
+ return pipe.transformer(*args, **kwargs)
565
+ transformer = _Wrapper(pipe.transformer.attn_processors)
566
+
567
+ return pipe
568
+ def initialize_condition_extractors(config):
569
+ """
570
+ Initialize condition extractors based on configuration.
571
+
572
+ Args:
573
+ config: Configuration dictionary containing condition types and checkpoint paths
574
+
575
+ Returns:
576
+ tuple: (condition_extractors, transformer_ckpt)
577
+ """
578
+ condition_extractors = {}
579
+ extractor_ckpt = {}
580
+ from utils.feature_extractor import dynamics_extractor, rhythm_extractor, melody_extractor_mono, melody_extractor_stereo, melody_extractor_full_mono, melody_extractor_full_stereo, dynamics_extractor_full_stereo
581
+ if not ("rhythm" in config['condition_type'] or "dynamics" in config['condition_type']):
582
+ if "melody_stereo" in config['condition_type']:
583
+ transformer_ckpt = config['transformer_ckpt_melody_stero']
584
+ extractor_ckpt = config['extractor_ckpt_melody_stero']
585
+ print(f"using model: {transformer_ckpt}, {extractor_ckpt}")
586
+ melody_conditoner = melody_extractor_full_stereo().cuda().float()
587
+ condition_extractors["melody"] = melody_conditoner
588
+ elif "melody_mono" in config['condition_type']:
589
+ transformer_ckpt = config['transformer_ckpt_melody_mono']
590
+ extractor_ckpt = config['extractor_ckpt_melody_mono']
591
+ print(f"using model: {transformer_ckpt}, {extractor_ckpt}")
592
+ melody_conditoner = melody_extractor_full_mono().cuda().float()
593
+ condition_extractors["melody"] = melody_conditoner
594
+ elif "audio" in config['condition_type']:
595
+ transformer_ckpt = config['audio_transformer_ckpt']
596
+ print(f"using model: {transformer_ckpt}")
597
+ else:
598
+ dynamics_conditoner = dynamics_extractor().cuda().float()
599
+ condition_extractors["dynamics"] = dynamics_conditoner
600
+ rhythm_conditoner = rhythm_extractor().cuda().float()
601
+ condition_extractors["rhythm"] = rhythm_conditoner
602
+ melody_conditoner = melody_extractor_mono().cuda().float()
603
+ condition_extractors["melody"] = melody_conditoner
604
+ transformer_ckpt = config['transformer_ckpt_musical']
605
+ extractor_ckpt = config['extractor_ckpt_musical']
606
+ print(f"using model: {transformer_ckpt}, {extractor_ckpt}")
607
+
608
+ for conditioner_type, ckpt_path in extractor_ckpt.items():
609
+ state_dict = load_file(ckpt_path, device="cpu")
610
+ condition_extractors[conditioner_type].load_state_dict(state_dict)
611
+ condition_extractors[conditioner_type].eval()
612
+
613
+ return condition_extractors, transformer_ckpt
614
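# Usage sketch (not part of this commit): this mirrors how app.py wires the two helpers
# above together, and assumes the MuseControlLite checkpoints referenced by
# config_inference.get_config() have already been downloaded.
#
#     import torch
#     from config_inference import get_config
#
#     config = get_config()
#     weight_dtype = torch.float16 if config["weight_dtype"] == "fp16" else torch.float32
#     condition_extractors, transformer_ckpt = initialize_condition_extractors(config)
#     pipe = setup_MuseControlLite(config, weight_dtype, transformer_ckpt).to("cuda")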
+ def evaluate_and_plot_results(audio_file, gen_file_path, output_dir, i):
615
+ """
616
+ Evaluate and plot results comparing original and generated audio.
617
+
618
+ Args:
619
+ audio_file (str): Path to the original audio file
620
+ gen_file_path (str): Path to the generated audio file
621
+ output_dir (str): Directory to save the plot
622
+ i (int): Index for naming the output file
623
+
624
+ Returns:
625
+ tuple: (pearson_corr, f1, accuracy) - dynamics correlation, rhythm F1 score, and melody accuracy
626
+ """
627
+
628
+ dynamics_condition = compute_dynamics(audio_file)
629
+ gen_dynamics = compute_dynamics(gen_file_path)
630
+ min_len_dynamics = min(gen_dynamics.shape[0], dynamics_condition.shape[0])
631
+ pearson_corr = np.corrcoef(gen_dynamics[:min_len_dynamics], dynamics_condition[:min_len_dynamics])[0, 1]
632
+ print("pearson_corr", pearson_corr)
633
+
634
+ melody_condition = extract_melody_one_hot(audio_file)
635
+ gen_melody = extract_melody_one_hot(gen_file_path)
636
+ min_len_melody = min(gen_melody.shape[1], melody_condition.shape[1])
637
+ matches = ((gen_melody[:, :min_len_melody] == melody_condition[:, :min_len_melody]) & (gen_melody[:, :min_len_melody] == 1)).sum()
638
+ accuracy = matches / min_len_melody
639
+ print("melody accuracy", accuracy)
640
+
641
+ # Rhythm: estimate beat/downbeat activations and compute the rhythm F1 score
642
+ processor = RNNDownBeatProcessor()
643
+ original_path = os.path.join(output_dir, f"original_{i}.wav")
644
+ input_probabilities = processor(original_path)
645
+ generated_probabilities = processor(gen_file_path)
646
+ hmm_processor = DBNDownBeatTrackingProcessor(beats_per_bar=[3,4], fps=100)
647
+ input_timestamps = hmm_processor(input_probabilities)
648
+ generated_timestamps = hmm_processor(generated_probabilities)
649
+ precision, recall, f1 = evaluate_f1_rhythm(input_timestamps, generated_timestamps)
650
+ # Output results
651
+ print(f"F1 Score: {f1:.2f}")
652
+
653
+ # Plotting
654
+ frame_rate = 100 # Frames per second
655
+ input_time_axis = np.linspace(0, len(input_probabilities) / frame_rate, len(input_probabilities))
656
+ generate_time_axis = np.linspace(0, len(generated_probabilities) / frame_rate, len(generated_probabilities))
657
+ fig, axes = plt.subplots(2, 3, figsize=(18, 10)) # Adjust figsize as needed
658
+
659
+ # ----------------------------
660
+ # Subplot (0,0): Dynamics Plot
661
+ ax = axes[0, 0]
662
+ ax.plot(dynamics_condition[:min_len_dynamics].squeeze(), linewidth=1, label='Dynamics condition')
663
+ ax.set_title('Dynamics')
664
+ ax.set_xlabel('Time Frame')
665
+ ax.set_ylabel('Dynamics (dB)')
666
+ ax.legend(fontsize=8)
667
+ ax.grid(True)
668
+ # ----------------------------
669
+ # Subplot (1,0): Generated Dynamics Plot
670
+ ax = axes[1, 0]
671
+ ax.plot(gen_dynamics[:min_len_dynamics].squeeze(), linewidth=1, label='Generated Dynamics')
672
+ ax.set_title('Dynamics')
673
+ ax.set_xlabel('Time Frame')
674
+ ax.set_ylabel('Dynamics (dB)')
675
+ ax.legend(fontsize=8)
676
+ ax.grid(True)
677
+
678
+ # ----------------------------
679
+ # Subplot (0,1): Melody Condition (Chromagram)
680
+ ax = axes[0, 1]
681
+ im2 = ax.imshow(melody_condition[:, :min_len_melody], aspect='auto', origin='lower',
682
+ interpolation='nearest', cmap='plasma')
683
+ ax.set_title('Melody Condition')
684
+ ax.set_xlabel('Time')
685
+ ax.set_ylabel('Chroma Features')
686
+
687
+ # ----------------------------
688
+ # Subplot (1,1): Generated Melody (Chromagram)
689
+ ax = axes[1, 1]
690
+ im1 = ax.imshow(gen_melody[:, :min_len_melody], aspect='auto', origin='lower',
691
+ interpolation='nearest', cmap='viridis')
692
+ ax.set_title('Generated Melody')
693
+ ax.set_xlabel('Time')
694
+ ax.set_ylabel('Chroma Features')
695
+
696
+ # ----------------------------
697
+ # Subplot (0,2): Rhythm Input Probabilities
698
+ ax = axes[0, 2]
699
+ ax.plot(input_time_axis, input_probabilities,
700
+ label="Input Beat Probability")
701
+ ax.plot(input_time_axis, input_probabilities,
702
+ label="Input Downbeat Probability", alpha=0.8)
703
+ ax.set_title('Rhythm: Input')
704
+ ax.set_xlabel('Time (s)')
705
+ ax.set_ylabel('Probability')
706
+ ax.legend()
707
+ ax.grid(True)
708
+
709
+ # ----------------------------
710
+ # Subplot (1,2): Rhythm Generated Probabilities
711
+ ax = axes[1, 2]
712
+ ax.plot(generate_time_axis, generated_probabilities,
713
+ color='orange', label="Generated Beat Probability")
714
+ ax.plot(generate_time_axis, generated_probabilities,
715
+ alpha=0.8, color='red', label="Generated Downbeat Probability")
716
+ ax.set_title('Rhythm: Generated')
717
+ ax.set_xlabel('Time (s)')
718
+ ax.set_ylabel('Probability')
719
+ ax.legend()
720
+ ax.grid(True)
721
+
722
+ # Adjust layout and save the combined image
723
+ plt.tight_layout()
724
+ combined_path = os.path.join(output_dir, f"combined_{i}.png")
725
+ plt.savefig(combined_path)
726
+ plt.close()
727
+
728
+ print(f"Combined plot saved to {combined_path}")
729
+
730
+ return pearson_corr, f1, accuracy
731
+
732
+ def process_musical_conditions(config, audio_file, condition_extractors, output_dir, i, weight_dtype, MuseControlLite):
733
+ """
734
+ Process and extract musical conditions (dynamics, rhythm, melody) from audio file.
735
+
736
+ Args:
737
+ config: Configuration dictionary
738
+ audio_file: Path to the audio file
739
+ condition_extractors: Dictionary of condition extractors
740
+ output_dir: Output directory path
741
+ i: Index for file naming
742
+ weight_dtype: Weight data type for torch tensors
743
+ MuseControlLite: The MuseControlLite model instance
744
+ (Note: the audio and musical-attribute mask boundaries are read from `config`, not passed as arguments.)
748
+
749
+ Returns:
750
+ tuple: (final_condition, final_condition_audio)
751
+ """
752
+ total_seconds = 2097152/44100
753
+ use_audio_mask = False
754
+ use_musical_attribute_mask = False
755
+ if (config["audio_mask_start_seconds"] and config["audio_mask_end_seconds"]) != 0 and "audio" in config["condition_type"]:
756
+ use_audio_mask = True
757
+ audio_mask_start = int(config["audio_mask_start_seconds"] / total_seconds * 1024) # 1024 is the latent length for 2097152/44100 seconds
758
+ audio_mask_end = int(config["audio_mask_end_seconds"] / total_seconds * 1024)
759
+ print(
760
+ f"using mask for 'audio' from "
761
+ f"{config['audio_mask_start_seconds']}~{config['audio_mask_end_seconds']}"
762
+ )
763
+ if (config["musical_attribute_mask_start_seconds"] and config["musical_attribute_mask_end_seconds"]) != 0:
764
+ use_musical_attribute_mask = True
765
+ musical_attribute_mask_start = int(config["musical_attribute_mask_start_seconds"] / total_seconds * 1024)
766
+ musical_attribute_mask_end = int(config["musical_attribute_mask_end_seconds"] / total_seconds * 1024)
767
+ masked_types = [t for t in config['condition_type'] if t != 'audio']
768
+ print(
769
+ f"using mask for {', '.join(masked_types)} "
770
+ f"from {config['musical_attribute_mask_start_seconds']}~"
771
+ f"{config['musical_attribute_mask_end_seconds']}"
772
+ )
773
+ if "dynamics" in config["condition_type"]:
774
+ dynamics_condition = compute_dynamics(audio_file)
775
+ dynamics_condition = torch.from_numpy(dynamics_condition).cuda()
776
+ dynamics_condition = dynamics_condition.unsqueeze(0).unsqueeze(0)
777
+ print("dynamics_condition", dynamics_condition.shape)
778
+ extracted_dynamics_condition = condition_extractors["dynamics"](dynamics_condition.to(torch.float32))
779
+ masked_extracted_dynamics_condition = torch.zeros_like(extracted_dynamics_condition)
780
+ extracted_dynamics_condition = F.interpolate(extracted_dynamics_condition, size=1024, mode='linear', align_corners=False)
781
+ masked_extracted_dynamics_condition = F.interpolate(masked_extracted_dynamics_condition, size=1024, mode='linear', align_corners=False)
782
+ else:
783
+ extracted_dynamics_condition = torch.zeros((1, 192, 1024), device="cuda")
784
+ masked_extracted_dynamics_condition = extracted_dynamics_condition
785
+ if "rhythm" in config["condition_type"]:
786
+ rnn_processor = RNNDownBeatProcessor()
787
+ wave = load_audio_file(audio_file)
788
+ if wave is not None:
789
+ original_path = os.path.join(output_dir, f"original_{i}.wav")
790
+ sf.write(original_path, wave.T.float().cpu().numpy(), 44100)
791
+ rhythm_curve = rnn_processor(original_path)
792
+ rhythm_condition = torch.from_numpy(rhythm_curve).cuda()
793
+ rhythm_condition = rhythm_condition.transpose(0,1).unsqueeze(0)
794
+ print("rhythm_condition", rhythm_condition.shape)
795
+ extracted_rhythm_condition = condition_extractors["rhythm"](rhythm_condition.to(torch.float32))
796
+ masked_extracted_rhythm_condition = torch.zeros_like(extracted_rhythm_condition)
797
+ extracted_rhythm_condition = F.interpolate(extracted_rhythm_condition, size=1024, mode='linear', align_corners=False)
798
+ masked_extracted_rhythm_condition = F.interpolate(masked_extracted_rhythm_condition, size=1024, mode='linear', align_corners=False)
799
+ else:
800
+ extracted_rhythm_condition = torch.zeros((1, 192, 1024), device="cuda")
801
+ masked_extracted_rhythm_condition = extracted_rhythm_condition
802
+ else:
803
+ extracted_rhythm_condition = torch.zeros((1, 192, 1024), device="cuda")
804
+ masked_extracted_rhythm_condition = extracted_rhythm_condition
805
+
806
+ if "melody_mono" in config["condition_type"]:
807
+ melody_condition = compute_melody(audio_file)
808
+ melody_condition = torch.from_numpy(melody_condition).cuda().unsqueeze(0)
809
+ print("melody_condition", melody_condition.shape)
810
+ extracted_melody_condition = condition_extractors["melody"](melody_condition.to(torch.float32))
811
+ masked_extracted_melody_condition = torch.zeros_like(extracted_melody_condition)
812
+ extracted_melody_condition = F.interpolate(extracted_melody_condition, size=1024, mode='linear', align_corners=False)
813
+ masked_extracted_melody_condition = F.interpolate(masked_extracted_melody_condition, size=1024, mode='linear', align_corners=False)
814
+ elif "melody_stereo" in config["condition_type"]:
815
+ melody_condition = compute_melody_v2(audio_file)
816
+ melody_condition = torch.from_numpy(melody_condition).cuda().unsqueeze(0)
817
+ print("melody_condition", melody_condition.shape)
818
+ extracted_melody_condition = condition_extractors["melody"](melody_condition)
819
+ masked_extracted_melody_condition = torch.zeros_like(extracted_melody_condition)
820
+ extracted_melody_condition = F.interpolate(extracted_melody_condition, size=1024, mode='linear', align_corners=False)
821
+ masked_extracted_melody_condition = F.interpolate(masked_extracted_melody_condition, size=1024, mode='linear', align_corners=False)
822
+ else:
823
+ if not ("rhythm" in config['condition_type'] or "dynamics" in config['condition_type']):
824
+ extracted_melody_condition = torch.zeros((1, 768, 1024), device="cuda")
825
+ else:
826
+ extracted_melody_condition = torch.zeros((1, 192, 1024), device="cuda")
827
+ masked_extracted_melody_condition = extracted_melody_condition
828
+
829
+ # Use multiple cfg
830
+ if not ("rhythm" in config['condition_type'] or "dynamics" in config['condition_type']):
831
+ extracted_condition = extracted_melody_condition
832
+ final_condition = torch.concat((masked_extracted_melody_condition, masked_extracted_melody_condition, extracted_melody_condition), dim=0)
833
+ else:
834
+ extracted_blank_condition = torch.zeros((1, 192, 1024), device="cuda")
835
+ extracted_condition = torch.concat((extracted_rhythm_condition, extracted_dynamics_condition, extracted_melody_condition, extracted_blank_condition), dim=1)
836
+ masked_extracted_condition = torch.concat((masked_extracted_rhythm_condition, masked_extracted_dynamics_condition, masked_extracted_melody_condition, extracted_blank_condition), dim=1)
837
+ final_condition = torch.concat((masked_extracted_condition, masked_extracted_condition, extracted_condition), dim=0)
838
+ if "audio" in config["condition_type"]:
839
+ desired_repeats = 768 // 64 # Number of repeats needed
840
+ audio = load_audio_file(audio_file)
841
+ if audio is not None:
842
+ audio_condition = MuseControlLite.vae.encode(audio.unsqueeze(0).to(weight_dtype).cuda()).latent_dist.sample()
843
+ extracted_audio_condition = audio_condition.repeat_interleave(desired_repeats, dim=1).float()
844
+ pad_len = 1024 - extracted_audio_condition.shape[-1]
845
+ if pad_len > 0:
846
+ # Pad on the right side (last dimension)
847
+ extracted_audio_condition = F.pad(extracted_audio_condition, (0, pad_len))
848
+ masked_extracted_audio_condition = torch.zeros_like(extracted_audio_condition)
849
+ if len(config["condition_type"]) == 1:
850
+ final_condition = torch.concat((masked_extracted_audio_condition, masked_extracted_audio_condition, extracted_audio_condition), dim=0)
851
+ else:
852
+ final_condition_audio = torch.concat((masked_extracted_audio_condition, masked_extracted_audio_condition, masked_extracted_audio_condition, extracted_audio_condition), dim=0)
853
+ final_condition = torch.concat((final_condition, extracted_condition), dim=0)
854
+ final_condition_audio = final_condition_audio.transpose(1, 2)
855
+ else:
856
+ final_condition_audio = None
857
+ final_condition = final_condition.transpose(1, 2)
858
+ if "audio" in config["condition_type"] and len(config["condition_type"])==1:
859
+ final_condition[:,audio_mask_start:audio_mask_end,:] = 0
860
+ if use_audio_mask:
861
+ config["guidance_scale_con"] = config["guidance_scale_audio"]
862
+ elif "audio" in config["condition_type"] and len(config["condition_type"])!=1 and use_audio_mask:
863
+ final_condition[:,:audio_mask_start,:] = 0
864
+ final_condition[:,audio_mask_end:,:] = 0
865
+ if 'final_condition_audio' in locals() and final_condition_audio is not None:
866
+ final_condition_audio[:,audio_mask_start:audio_mask_end,:] = 0
867
+ elif use_musical_attribute_mask:
868
+ final_condition[:,musical_attribute_mask_start:musical_attribute_mask_end,:] = 0
869
+ if 'final_condition_audio' in locals() and final_condition_audio is not None:
870
+ final_condition_audio[:,:musical_attribute_mask_start,:] = 0
871
+ final_condition_audio[:,musical_attribute_mask_end:,:] = 0
872
+
873
+ return final_condition, final_condition_audio if 'final_condition_audio' in locals() else None
874
+
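# Worked example (not part of this commit) of the seconds -> latent-frame mapping used in
# process_musical_conditions above: the fixed 2097152-sample window at 44.1 kHz (~47.55 s)
# maps to 1024 latent frames, so mask boundaries given in seconds are scaled by
# 1024 / total_seconds before indexing the condition tensor. The helper name below is
# hypothetical.

EXAMPLE_TOTAL_SECONDS = 2097152 / 44100                 # ~47.55 s

def seconds_to_latent_index(sec: float) -> int:
    return int(sec / EXAMPLE_TOTAL_SECONDS * 1024)

print(seconds_to_latent_index(10.0))                    # 215
print(seconds_to_latent_index(EXAMPLE_TOTAL_SECONDS))   # 1024 (end of the window)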
README.md CHANGED
@@ -11,4 +11,38 @@ license: mit
11
  short_description: Inference for Stable-Audio-Open with more controls
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+ ## MuseControlLite (Space)
15
+
16
+ Gradio UI for MuseControlLite adapters on top of `stabilityai/stable-audio-open-1.0`.
17
+
18
+ ### Requirements
19
+ - **GPU Space** is required for generation (fp16 by default).
20
+ - A Hugging Face token with access to `stabilityai/stable-audio-open-1.0` (set as a Space secret, e.g., `HF_TOKEN`).
21
+
22
+ ### What happens on startup
23
+ 1) Installs Python dependencies from `requirements.txt` (including `gradio`, `gdown`, and a `diffusers` fork).
24
+ 2) Downloads MuseControlLite checkpoints with
25
+ `gdown 1Q9B333jcq1czA11JKTbM-DHANJ8YqGbP --folder`
26
+ into `checkpoints/` if they are missing.
27
+
28
+ ### Using the Space
29
+ 1) Provide a text prompt.
30
+ 2) Upload a 47.5s (or longer) audio file when using MuseControlLite conditions.
31
+ 3) Select condition types (`melody_stereo`, `melody_mono`, `dynamics`, `rhythm`, `audio`) and adjust guidance/scales if needed.
32
+ 4) Click **Generate**. Output is a single 47.5s WAV plus a short status summary.
33
+
34
+ ### Tips
35
+ - `melody_stereo` cannot be combined with `dynamics`, `rhythm`, or `melody_mono`.
36
+ - For audio in/out-painting, use the audio condition with the masking sliders.
37
+ - Default examples are preloaded in the UI for quick tests.
38
+
39
+ ### Local run (optional)
40
+ ```bash
41
+ pip install -r requirements.txt
42
+ gdown 1Q9B333jcq1czA11JKTbM-DHANJ8YqGbP --folder
43
+ huggingface-cli login
44
+ python app.py
45
+ ```
46
+
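For reference, a minimal programmatic sketch (not part of this commit) of the plain Stable Audio path that `app.py` falls back to when MuseControlLite conditioning is disabled; it assumes a CUDA GPU and access to `stabilityai/stable-audio-open-1.0`:

```python
import soundfile as sf
import torch
from diffusers import StableAudioPipeline

pipe = StableAudioPipeline.from_pretrained(
    "stabilityai/stable-audio-open-1.0", torch_dtype=torch.float16
).to("cuda")
audio = pipe(
    prompt="fast and fun beat-based indie pop",
    num_inference_steps=50,
    audio_end_in_s=2097152 / 44100,  # the fixed ~47.5 s window used by the Space
).audios
sf.write("generated.wav", audio[0].T.float().cpu().numpy(), pipe.vae.sampling_rate)
```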
47
+ ### Acknowledgments
48
+ - Original repository: https://github.com/fundwotsai2001/MuseControlLite
app.py ADDED
@@ -0,0 +1,424 @@
1
+ import copy
2
+ import os
3
+ import subprocess
4
+ import time
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import gradio as gr
8
+ import soundfile as sf
9
+ import torch
10
+
11
+ from MuseControlLite_setup import initialize_condition_extractors, process_musical_conditions, setup_MuseControlLite
12
+ from config_inference import get_config
13
+
14
+ # Stable Audio uses fixed-length 47.5s chunks (2097152 / 44100)
15
+ TOTAL_AUDIO_SECONDS = 2097152 / 44100
16
+ DEFAULT_CONFIG = get_config()
17
+ DEFAULT_PROMPT = DEFAULT_CONFIG["text"][0] if DEFAULT_CONFIG.get("text") else ""
18
+ OUTPUT_ROOT = os.path.join(DEFAULT_CONFIG["output_dir"], "gradio_runs")
19
+ CONDITION_CHOICES = ["melody_stereo", "melody_mono", "dynamics", "rhythm", "audio"]
20
+ CHECKPOINT_EXPECTED = [
21
+ "./checkpoints/woSDD-all/model_3.safetensors",
22
+ "./checkpoints/woSDD-all/model_1.safetensors",
23
+ "./checkpoints/woSDD-all/model_2.safetensors",
24
+ "./checkpoints/woSDD-all/model.safetensors",
25
+ ]
26
+
27
+ os.makedirs(OUTPUT_ROOT, exist_ok=True)
28
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(DEFAULT_CONFIG.get("GPU_id", "0")))
29
+
30
+
31
+ def ensure_checkpoints() -> None:
32
+ """Download checkpoints with gdown if they are missing."""
33
+ if all(os.path.exists(path) for path in CHECKPOINT_EXPECTED):
34
+ return
35
+ os.makedirs("checkpoints", exist_ok=True)
36
+ try:
37
+ subprocess.run(
38
+ ["gdown", "1Q9B333jcq1czA11JKTbM-DHANJ8YqGbP", "--folder"],
39
+ check=True,
40
+ )
41
+ except Exception as exc: # pylint: disable=broad-except
42
+ # Do not crash the space on startup; inference will surface an error later if checkpoints are missing.
43
+ print(f"[warn] Checkpoint download failed: {exc}")
44
+
45
+
46
+ ensure_checkpoints()
47
+
48
+
49
+ class ModelCache:
50
+ """Lazy loader for heavy pipelines and condition extractors."""
51
+
52
+ def __init__(self) -> None:
53
+ self.cache: Dict[Tuple, Dict] = {}
54
+
55
+ def get(self, config: Dict) -> Dict:
56
+ key = (
57
+ tuple(sorted(config["condition_type"])),
58
+ config["weight_dtype"],
59
+ float(config["ap_scale"]),
60
+ config["apadapter"],
61
+ )
62
+ if key in self.cache:
63
+ return self.cache[key]
64
+
65
+ weight_dtype = torch.float16 if config["weight_dtype"] == "fp16" else torch.float32
66
+ if config["apadapter"]:
67
+ condition_extractors, transformer_ckpt = initialize_condition_extractors(config)
68
+ pipe = setup_MuseControlLite(config, weight_dtype, transformer_ckpt).to("cuda")
69
+ payload = {
70
+ "pipe": pipe,
71
+ "condition_extractors": condition_extractors,
72
+ "weight_dtype": weight_dtype,
73
+ "mode": "musecontrol",
74
+ }
75
+ else:
76
+ from diffusers import StableAudioPipeline
77
+
78
+ pipe = StableAudioPipeline.from_pretrained(
79
+ "stabilityai/stable-audio-open-1.0",
80
+ torch_dtype=weight_dtype,
81
+ ).to("cuda")
82
+ payload = {"pipe": pipe, "condition_extractors": None, "weight_dtype": weight_dtype, "mode": "vanilla"}
83
+ self.cache[key] = payload
84
+ return payload
85
+
86
+
87
+ model_cache = ModelCache()
88
+
89
+
90
+ def _build_base_config() -> Dict:
91
+ return copy.deepcopy(DEFAULT_CONFIG)
92
+
93
+
94
+ def _create_run_dir() -> str:
95
+ run_dir = os.path.join(OUTPUT_ROOT, f"run_{int(time.time() * 1000)}")
96
+ os.makedirs(run_dir, exist_ok=True)
97
+ return run_dir
98
+
99
+
100
+ def _seed_to_generator(seed: Optional[float]) -> Optional[torch.Generator]:
101
+ if seed is None or seed == "":
102
+ return None
103
+ try:
104
+ seed_int = int(seed)
105
+ except (TypeError, ValueError):
106
+ return None
107
+ generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
108
+ return generator.manual_seed(seed_int)
109
+
110
+
111
+ def _validate_condition_choices(condition_type: Optional[List[str]]) -> List[str]:
112
+ condition_type = condition_type or []
113
+ if "melody_stereo" in condition_type and any(
114
+ choice in condition_type for choice in ("dynamics", "rhythm", "melody_mono")
115
+ ):
116
+ raise gr.Error("`melody_stereo` cannot be combined with dynamics, rhythm, or melody_mono.")
117
+ return condition_type
118
+
119
+
120
+ def run_inference(
121
+ prompt_text: str,
122
+ condition_audio: Optional[str],
123
+ condition_type: Optional[List[str]],
124
+ use_musecontrol: bool,
125
+ no_text: bool,
126
+ negative_text_prompt: str,
127
+ guidance_scale_text: float,
128
+ guidance_scale_con: float,
129
+ guidance_scale_audio: float,
130
+ denoise_step: int,
131
+ weight_dtype: str,
132
+ ap_scale: float,
133
+ sigma_min: float,
134
+ sigma_max: float,
135
+ audio_mask_start: float,
136
+ audio_mask_end: float,
137
+ musical_mask_start: float,
138
+ musical_mask_end: float,
139
+ seed: Optional[float],
140
+ ):
141
+ if not torch.cuda.is_available():
142
+ raise gr.Error("This Space has no GPU attached. Please run locally with a GPU or duplicate to a GPU Space.")
143
+
144
+ condition_type = _validate_condition_choices(condition_type)
145
+ config = _build_base_config()
146
+ config.update(
147
+ {
148
+ "text": [prompt_text or ""],
149
+ "audio_files": [condition_audio or ""],
150
+ "apadapter": use_musecontrol,
151
+ "no_text": bool(no_text),
152
+ "negative_text_prompt": negative_text_prompt or "",
153
+ "guidance_scale_text": float(guidance_scale_text),
154
+ "guidance_scale_con": float(guidance_scale_con),
155
+ "guidance_scale_audio": float(guidance_scale_audio),
156
+ "denoise_step": int(denoise_step),
157
+ "weight_dtype": weight_dtype,
158
+ "ap_scale": float(ap_scale),
159
+ "sigma_min": float(sigma_min),
160
+ "sigma_max": float(sigma_max),
161
+ "audio_mask_start_seconds": float(audio_mask_start or 0),
162
+ "audio_mask_end_seconds": float(audio_mask_end or 0),
163
+ "musical_attribute_mask_start_seconds": float(musical_mask_start or 0),
164
+ "musical_attribute_mask_end_seconds": float(musical_mask_end or 0),
165
+ "show_result_and_plt": False,
166
+ }
167
+ )
168
+ config["condition_type"] = condition_type
169
+ if config["apadapter"]:
170
+ if not condition_type:
171
+ raise gr.Error("Select at least one condition type when using MuseControlLite.")
172
+ if not condition_audio:
173
+ raise gr.Error("Upload an audio file for conditioning.")
174
+ if not os.path.exists(condition_audio):
175
+ raise gr.Error("Condition audio file not found.")
176
+
177
+ run_dir = _create_run_dir()
178
+ config["output_dir"] = run_dir
179
+ generator = _seed_to_generator(seed)
180
+
181
+ try:
182
+ models = model_cache.get(config)
183
+ pipe = models["pipe"]
184
+ pipe.scheduler.config.sigma_min = config["sigma_min"]
185
+ pipe.scheduler.config.sigma_max = config["sigma_max"]
186
+ prompt_for_model = "" if config["no_text"] else (prompt_text or "")
187
+
188
+ with torch.no_grad():
189
+ if config["apadapter"]:
190
+ final_condition, final_condition_audio = process_musical_conditions(
191
+ config, condition_audio, models["condition_extractors"], run_dir, 0, models["weight_dtype"], pipe
192
+ )
193
+ waveform = pipe(
194
+ extracted_condition=final_condition,
195
+ extracted_condition_audio=final_condition_audio,
196
+ prompt=prompt_for_model,
197
+ negative_prompt=config["negative_text_prompt"],
198
+ num_inference_steps=config["denoise_step"],
199
+ guidance_scale_text=config["guidance_scale_text"],
200
+ guidance_scale_con=config["guidance_scale_con"],
201
+ guidance_scale_audio=config["guidance_scale_audio"],
202
+ num_waveforms_per_prompt=1,
203
+ audio_end_in_s=TOTAL_AUDIO_SECONDS,
204
+ generator=generator,
205
+ ).audios
206
+ output = waveform[0].T.float().cpu().numpy()
207
+ sr = pipe.vae.sampling_rate
208
+ else:
209
+ audio = pipe(
210
+ prompt=prompt_for_model,
211
+ negative_prompt=config["negative_text_prompt"],
212
+ num_inference_steps=config["denoise_step"],
213
+ guidance_scale=config["guidance_scale_text"],
214
+ num_waveforms_per_prompt=1,
215
+ audio_end_in_s=TOTAL_AUDIO_SECONDS,
216
+ generator=generator,
217
+ ).audios
218
+ output = audio[0].T.float().cpu().numpy()
219
+ sr = pipe.vae.sampling_rate
220
+
221
+ generated_path = os.path.join(run_dir, "generated.wav")
222
+ sf.write(generated_path, output, sr)
223
+
224
+ status_lines = [
225
+ f"Run directory: `{run_dir}`",
226
+ f"Mode: {'MuseControlLite' if config['apadapter'] else 'Stable Audio base'}",
227
+ f"Condition type: {', '.join(condition_type) if condition_type else 'text only'}",
228
+ f"Dtype: {config['weight_dtype']}, steps: {config['denoise_step']}, sigma [{config['sigma_min']}, {config['sigma_max']}]",
229
+ ]
230
+ if config["apadapter"]:
231
+ status_lines.append(
232
+ f"Guidance (text/cond/audio): {config['guidance_scale_text']}/{config['guidance_scale_con']}/{config['guidance_scale_audio']}"
233
+ )
234
+ if generator is not None:
235
+ status_lines.append(f"Seed: {int(seed)}")
236
+
237
+ status_md = "\n".join(f"- {line}" for line in status_lines)
238
+ return generated_path, status_md
239
+ except gr.Error:
240
+ raise
241
+ except Exception as err: # pylint: disable=broad-except
242
+ raise gr.Error(f"Generation failed: {err}") from err
243
+
244
+
245
+ EXAMPLES = [
246
+ [
247
+ "Electronic music that has a constant melody throughout with accompanying instruments used to supplement the melody which can be heard in possibly a casual setting",
248
+ "melody_condition_audio/49_piano.mp3",
249
+ ["melody_stereo"],
250
+ True,
251
+ False,
252
+ "",
253
+ 7.0,
254
+ 1.5,
255
+ 1.0,
256
+ 50,
257
+ "fp16",
258
+ 1.0,
259
+ 0.3,
260
+ 500,
261
+ 0,
262
+ 0,
263
+ 0,
264
+ 0,
265
+ 42,
266
+ ],
267
+ [
268
+ "fast and fun beat-based indie pop to set a protagonist-gets-good-at-x movie montage to.",
269
+ "melody_condition_audio/610_bass.mp3",
270
+ ["melody_mono", "dynamics", "rhythm"],
271
+ True,
272
+ False,
273
+ "",
274
+ 7.0,
275
+ 1.5,
276
+ 1.0,
277
+ 50,
278
+ "fp16",
279
+ 1.0,
280
+ 0.3,
281
+ 500,
282
+ 0,
283
+ 0,
284
+ 0,
285
+ 0,
286
+ 7,
287
+ ],
288
+ ]
289
+
290
+
291
+ def build_interface() -> gr.Blocks:
292
+ with gr.Blocks(title="MuseControlLite") as demo:
293
+ gr.Markdown(
294
+ """
295
+ ## MuseControlLite demo
296
+ UI for MuseControlLite (47.5s generations). This Space downloads checkpoints on startup with gdown and expects a GPU runtime; duplicate to a GPU Space or run locally for actual generation.
297
+ """
298
+ )
299
+ with gr.Row():
300
+ prompt = gr.Textbox(label="Text prompt", lines=3, value=DEFAULT_PROMPT)
301
+ use_musecontrol = gr.Checkbox(label="Use MuseControlLite adapters", value=True)
302
+ no_text = gr.Checkbox(label="Ignore text prompt (audio-only guidance)", value=False)
303
+
304
+ condition_audio = gr.Audio(
305
+ label="Condition audio (required for MuseControlLite)", type="filepath", sources=["upload", "microphone"]
306
+ )
307
+ condition_type = gr.CheckboxGroup(
308
+ CONDITION_CHOICES, label="Condition types", value=DEFAULT_CONFIG.get("condition_type", [])
309
+ )
310
+
311
+ with gr.Accordion("Advanced controls", open=False):
312
+ negative_prompt = gr.Textbox(label="Negative prompt", lines=2, value=DEFAULT_CONFIG.get("negative_text_prompt", ""))
313
+ with gr.Row():
314
+ guidance_scale_text = gr.Slider(
315
+ minimum=0.0,
316
+ maximum=12.0,
317
+ value=DEFAULT_CONFIG["guidance_scale_text"],
318
+ step=0.1,
319
+ label="Guidance scale (text)",
320
+ )
321
+ guidance_scale_con = gr.Slider(
322
+ minimum=0.0,
323
+ maximum=5.0,
324
+ value=DEFAULT_CONFIG["guidance_scale_con"],
325
+ step=0.1,
326
+ label="Guidance scale (conditions)",
327
+ )
328
+ guidance_scale_audio = gr.Slider(
329
+ minimum=0.0,
330
+ maximum=5.0,
331
+ value=DEFAULT_CONFIG["guidance_scale_audio"],
332
+ step=0.1,
333
+ label="Guidance scale (audio)",
334
+ )
335
+ with gr.Row():
336
+ denoise_step = gr.Slider(
337
+ minimum=10, maximum=100, value=DEFAULT_CONFIG["denoise_step"], step=1, label="Denoising steps"
338
+ )
339
+ weight_dtype = gr.Radio(["fp16", "fp32"], value=DEFAULT_CONFIG["weight_dtype"], label="Weight dtype")
340
+ ap_scale = gr.Slider(
341
+ minimum=0.5, maximum=2.0, value=DEFAULT_CONFIG["ap_scale"], step=0.05, label="AP scale"
342
+ )
343
+ with gr.Row():
344
+ sigma_min = gr.Slider(
345
+ minimum=0.1, maximum=5.0, value=DEFAULT_CONFIG["sigma_min"], step=0.05, label="Scheduler sigma min"
346
+ )
347
+ sigma_max = gr.Slider(
348
+ minimum=50, maximum=700, value=DEFAULT_CONFIG["sigma_max"], step=1, label="Scheduler sigma max"
349
+ )
350
+ seed = gr.Number(label="Seed (optional)", precision=0)
351
+ with gr.Row():
352
+ audio_mask_start = gr.Number(
353
+ label="Audio mask start (s)", value=DEFAULT_CONFIG["audio_mask_start_seconds"]
354
+ )
355
+ audio_mask_end = gr.Number(label="Audio mask end (s)", value=DEFAULT_CONFIG["audio_mask_end_seconds"])
356
+ with gr.Row():
357
+ musical_mask_start = gr.Number(
358
+ label="Musical attribute mask start (s)", value=DEFAULT_CONFIG["musical_attribute_mask_start_seconds"]
359
+ )
360
+ musical_mask_end = gr.Number(
361
+ label="Musical attribute mask end (s)", value=DEFAULT_CONFIG["musical_attribute_mask_end_seconds"]
362
+ )
363
+
364
+ generate_btn = gr.Button("Generate", variant="primary")
365
+ generated_audio = gr.Audio(label="Generated audio", type="filepath")
366
+ status = gr.Markdown(label="Run details")
367
+
368
+ generate_btn.click(
369
+ fn=run_inference,
370
+ inputs=[
371
+ prompt,
372
+ condition_audio,
373
+ condition_type,
374
+ use_musecontrol,
375
+ no_text,
376
+ negative_prompt,
377
+ guidance_scale_text,
378
+ guidance_scale_con,
379
+ guidance_scale_audio,
380
+ denoise_step,
381
+ weight_dtype,
382
+ ap_scale,
383
+ sigma_min,
384
+ sigma_max,
385
+ audio_mask_start,
386
+ audio_mask_end,
387
+ musical_mask_start,
388
+ musical_mask_end,
389
+ seed,
390
+ ],
391
+ outputs=[generated_audio, status],
392
+ )
393
+
394
+ gr.Examples(
395
+ examples=EXAMPLES,
396
+ inputs=[
397
+ prompt,
398
+ condition_audio,
399
+ condition_type,
400
+ use_musecontrol,
401
+ no_text,
402
+ negative_prompt,
403
+ guidance_scale_text,
404
+ guidance_scale_con,
405
+ guidance_scale_audio,
406
+ denoise_step,
407
+ weight_dtype,
408
+ ap_scale,
409
+ sigma_min,
410
+ sigma_max,
411
+ audio_mask_start,
412
+ audio_mask_end,
413
+ musical_mask_start,
414
+ musical_mask_end,
415
+ seed,
416
+ ],
417
+ label="Quick start examples (click to populate the form)",
418
+ )
419
+ return demo
420
+
421
+
422
+ if __name__ == "__main__":
423
+ demo = build_interface()
424
+ demo.launch()
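For orientation, a minimal sketch (not part of the committed app) of driving `run_inference` directly with a quick-start row: each `EXAMPLES` entry follows the same positional order as the `inputs` list wired into `generate_btn.click`, so a row can simply be unpacked. It assumes a CUDA GPU and the downloaded checkpoints.

```py
# Hypothetical usage sketch, assuming a CUDA GPU and local checkpoints.
# The 19 entries of an EXAMPLES row line up with run_inference's 19 parameters.
audio_path, status_md = run_inference(*EXAMPLES[0])
print(audio_path)  # e.g. a generated.wav inside a fresh run_<timestamp> directory
print(status_md)   # markdown bullet list summarising the run settings
```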
config_inference.py ADDED
@@ -0,0 +1,134 @@
1
+ def get_config():
2
+ return {
3
+ "condition_type": ["melody_stereo"], # choose any combination within one of the two sets: ["dynamics", "rhythm", "melody_mono", "audio"] or ["melody_stereo", "audio"]
4
+ # When using the audio condition, it is recommended to use an empty string "" as the prompt
5
+ "output_dir": "./generated_audio/output",
6
+
7
+ "GPU_id": "0",
8
+
9
+ "apadapter": True, # True for MuseControlLite, False for original Stable-audio
10
+
11
+ "ap_scale": 1.0, # recommend 1.0 for MuseControlLite, other values are not tested
12
+
13
+ "guidance_scale_text": 7.0,
14
+
15
+ "guidance_scale_con": 1.5, # The separate guidance scale for the musical attribute condition
16
+
17
+ "guidance_scale_audio": 1.0,
18
+
19
+ "denoise_step": 50,
20
+
21
+ "sigma_min": 0.3, # sigma_min and sigma_max are for the scheduler.
22
+
23
+ "sigma_max": 500, # Note that if sigma_max is too large or too small, audio-conditioned generation degrades.
24
+
25
+ "weight_dtype": "fp16", # fp16 and fp32 sound almost identical.
26
+
27
+ "negative_text_prompt": "",
28
+
29
+ ###############
30
+
31
+ "audio_mask_start_seconds": 14, # Apply a mask to the musical attributes. Choose only one mask to use; a complementary mask is automatically generated for the other condition.
32
+
33
+ "audio_mask_end_seconds": 47,
34
+
35
+ "musical_attribute_mask_start_seconds": 0, # Apply a mask to the audio condition. Choose only one mask to use; a complementary mask is automatically generated for the other condition.
36
+
37
+ "musical_attribute_mask_end_seconds": 0,
38
+
39
+ ###############
40
+
41
+ "no_text": False, # Optional; set to True if no text prompt is needed (e.g., for audio inpainting or outpainting)
42
+
43
+ "show_result_and_plt": True,
44
+
45
+ "audio_files": [
46
+ "melody_condition_audio/49_piano.mp3",
47
+ "melody_condition_audio/49_piano.mp3",
48
+ "melody_condition_audio/49_piano.mp3",
49
+ "melody_condition_audio/322_piano.mp3",
50
+ "melody_condition_audio/322_piano.mp3",
51
+ "melody_condition_audio/322_piano.mp3",
52
+ "melody_condition_audio/610_bass.mp3",
53
+ "melody_condition_audio/610_bass.mp3",
54
+ "melody_condition_audio/785_piano.mp3",
55
+ "melody_condition_audio/785_piano.mp3",
56
+ "melody_condition_audio/933_string.mp3",
57
+ "melody_condition_audio/933_string.mp3",
58
+ "melody_condition_audio/6_uke_12.wav",
59
+ "melody_condition_audio/6_uke_12.wav",
60
+ "melody_condition_audio/57_jazz.mp3",
61
+ "melody_condition_audio/703_mideast.mp3",
62
+
63
+ ],
64
+ # "audio_files": [
65
+ # "SDD_nosinging/SDD_audio/34/1004034.mp3",
66
+ # "original_15s/original_9.wav",
67
+ # "original_15s/original_10.wav",
68
+ # "original_15s/original_11.wav",
69
+ # "original_15s/original_15.wav",
70
+ # "original_15s/original_16.wav",
71
+ # "original_15s/original_21.wav",
72
+ # "original_15s/original_25.wav",
73
+ # ],
74
+
75
+ "text": [
76
+ "Electronic music that has a constant melody throughout with accompanying instruments used to supplement the melody which can be heard in possibly a casual setting",
77
+ "A heartfelt, warm acoustic guitar performance, evoking a sense of tenderness and deep emotion, with a melody that truly resonates and touches the heart.",
78
+ "A vibrant MIDI electronic composition with a hopeful and optimistic vibe.",
79
+ "This track composed of electronic instruments gives a sense of opening and clearness.",
80
+ "This track composed of electronic instruments gives a sense of opening and clearness.",
81
+ "Hopeful instrumental with guitar being the lead and tabla used for percussion in the middle giving a feeling of going somewhere with positive outlook.",
82
+ "A string ensemble opens the track with legato, melancholic melodies. The violins and violas play beautifully, while the cellos and bass provide harmonic support for the moving passages. The overall feel is deeply melancholic, with an emotionally stirring performance that remains harmonious and a sense of clearness.",
83
+ "An exceptionally harmonious string performance with a lively tempo in the first half, transitioning to a gentle and beautiful melody in the second half. It creates a warm and comforting atmosphere, featuring cellos and bass providing a solid foundation, while violins and violas showcase the main theme, all without any noise, resulting in a cohesive and serene sound.",
84
+ "Pop solo piano instrumental song. Simple harmony and emotional theme. Makes you feel nostalgic and wanting a cup of warm tea sitting on the couch while holding the person you love.",
85
+ "A whimsical string arrangement with rich layers, featuring violins as the main melody, accompanied by violas and cellos. The light, playful melody blends harmoniously, creating a sense of clarity.",
86
+ "An instrumental piece primarily featuring acoustic guitar, with a lively and nimble feel. The melody is bright, delivering an overall sense of joy.",
87
+ "A joyful saxophone performance that is smooth and cohesive, accompanied by cello. The first half features a relaxed tempo, while the second half picks up with an upbeat rhythm, creating a lively and energetic atmosphere. The overall sound is harmonious and clear, evoking feelings of happiness and vitality.",
88
+ "A cheerful piano performance with a smooth and flowing rhythm, evoking feelings of joy and vitality.",
89
+ "An instrumental piece primarily featuring piano, with a lively rhythm and cheerful melodies that evoke a sense of joyful childhood playfulness. The melodies are clear and bright.",
90
+ "fast and fun beat-based indie pop to set a protagonist-gets-good-at-x movie montage to.",
91
+ "A lively 70s style British pop song featuring drums, electric guitars, and synth violin. The instruments blend harmoniously, creating a dynamic, clean sound without any noise or clutter.",
92
+ "A soothing acoustic guitar song that evokes nostalgia, featuring intricate fingerpicking. The melody is both sacred and mysterious, with a rich texture."
93
+ ],
94
+
95
+ ########## adapters available ############
96
+ # We trained 4 sets of adapters:
97
+ # 1. with conditions ["melody_mono", "dynamics", "rhythm"]
98
+ # 2. with conditions ["melody_mono"]
99
+ # 3. with conditions ["melody_stereo"]
100
+ # 4. with conditions ["audio"]
101
+ # MuseControlLite_inference_all.py will automatically choose the most suitable model according to the condition type:
102
+ ###############
103
+ # Works for condition ["dynamics", "rhythm", "melody_mono"]
104
+ "transformer_ckpt_musical": "./checkpoints/woSDD-all/model_3.safetensors",
105
+
106
+ "extractor_ckpt_musical": {
107
+ "dynamics": "./checkpoints/woSDD-all/model_1.safetensors",
108
+ "melody": "./checkpoints/woSDD-all/model.safetensors",
109
+ "rhythm": "./checkpoints/woSDD-all/model_2.safetensors",
110
+ },
111
+ ###############
112
+
113
+ # Works for ['audio']; it needs no feature extractor and can be combined with other adapters
114
+ #################
115
+ "audio_transformer_ckpt": "./checkpoints/70000_Audio/model.safetensors",
116
+
117
+ # Specialized for ['melody_stereo']
118
+ ###############
119
+ "transformer_ckpt_melody_stero": "./checkpoints/70000_Melody_stereo/model_1.safetensors",
120
+
121
+ "extractor_ckpt_melody_stero": {
122
+ "melody": "./checkpoints/70000_Melody_stereo/model.safetensors",
123
+ },
124
+ ###############
125
+
126
+ # Specialized for ['melody_mono']
127
+ ###############
128
+ "transformer_ckpt_melody_mono": "./checkpoints/40000_Melody_mono/model_1.safetensors",
129
+
130
+ "extractor_ckpt_melody_mono": {
131
+ "melody": "./checkpoints/40000_Melody_mono/model.safetensors",
132
+ },
133
+ ###############
134
+ }
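The mask comments above describe one active time-range mask whose complement is applied to the other condition. Below is a minimal, hypothetical illustration of that idea over latent frames; the 1024-frame length and the ≈21.5 frames-per-second rate (44.1 kHz sampling rate divided by the 2048-sample hop of Stable Audio Open) are assumptions, and the real masking logic lives in the MuseControlLite code, not here.

```py
import torch

FRAMES = 1024          # latent frames for a full ~47.55 s clip (assumption)
FPS = 44100 / 2048     # ~21.5 latent frames per second (assumption)

def seconds_mask(start_s: float, end_s: float, frames: int = FRAMES) -> torch.Tensor:
    """Boolean mask that is True inside [start_s, end_s) and False elsewhere."""
    mask = torch.zeros(frames, dtype=torch.bool)
    mask[int(start_s * FPS): int(end_s * FPS)] = True
    return mask

# With the defaults above: the audio condition is active in [14 s, 47 s),
# and the musical-attribute conditions get the complementary region.
audio_mask = seconds_mask(14, 47)
musical_attribute_mask = ~audio_mask
```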
melody_condition_audio/322_piano.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:698e40e4067efa3b181ea367ec8b0bc76b651cc0ca9bee329a3833565f35a800
3
+ size 915798
melody_condition_audio/49_piano.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b00df95a62c91e33c71a4ee312fb84883b3ff58cadb66de8582055ce89d72636
3
+ size 1106827
melody_condition_audio/57_jazz.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9a4cf5f07b40270500ea05e3c756f1d02817c1c6cdd07724e7c102b33d71d2f
3
+ size 1101758
melody_condition_audio/610_bass.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5ead8df05aa7cd33f193c315691cfb6bc8f23bc651f6c4947e3685ab11503bb
3
+ size 1133359
melody_condition_audio/703_mideast.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8db8aaa68b2d749fcf6d4e0cfc1ccd7a41fe4e0e942f0c0dbab2033d5eca07b1
3
+ size 1143212
melody_condition_audio/785_piano.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e8cef236b1723ced1fb55c6a5f28205a36babbd8add95553a508a88519d74f7
3
+ size 1110813
melody_condition_audio/933_string.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f686104a4849b96dea470db2067df342d9e1265b30e254da5af85d83edcac45
3
+ size 1097973
pipeline/stable_audio_multi_cfg_pipe.py ADDED
@@ -0,0 +1,772 @@
1
+ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import matplotlib.pyplot as plt
15
+
16
+ import inspect
17
+ from typing import Callable, List, Optional, Union
18
+
19
+ import torch
20
+ from transformers import (
21
+ T5EncoderModel,
22
+ T5Tokenizer,
23
+ T5TokenizerFast,
24
+ )
25
+
26
+ from diffusers.models import AutoencoderOobleck, StableAudioDiTModel
27
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
28
+ from diffusers.schedulers import EDMDPMSolverMultistepScheduler
29
+ from diffusers.utils import (
30
+ logging,
31
+ replace_example_docstring,
32
+ )
33
+ import numpy as np
34
+ from diffusers.utils.torch_utils import randn_tensor
35
+ from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline
36
+ from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel
37
+ from torch.cuda.amp import autocast, GradScaler
38
+
39
+ def check_and_print_non_float32_parameters(model):
40
+ non_float32_params = []
41
+ for name, param in model.named_parameters():
42
+ if param.dtype != torch.float32:
43
+ non_float32_params.append((name, param.dtype))
44
+
45
+ if non_float32_params:
46
+ print("Not all parameters are in float32!")
47
+ print("The following parameters are not in float32:")
48
+ for name, dtype in non_float32_params:
49
+ print(f"Parameter: {name}, Data Type: {dtype}")
50
+ else:
51
+ print("All parameters are in float32.")
52
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
+
54
+ EXAMPLE_DOC_STRING = """
55
+ Examples:
56
+ ```py
57
+ >>> import scipy
58
+ >>> import torch
59
+ >>> import soundfile as sf
60
+ >>> from diffusers import StableAudioPipeline
61
+
62
+ >>> repo_id = "stabilityai/stable-audio-open-1.0"
63
+ >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
64
+ >>> pipe = pipe.to("cuda")
65
+
66
+ >>> # define the prompts
67
+ >>> prompt = "The sound of a hammer hitting a wooden surface."
68
+ >>> negative_prompt = "Low quality."
69
+
70
+ >>> # set the seed for generator
71
+ >>> generator = torch.Generator("cuda").manual_seed(0)
72
+
73
+ >>> # run the generation
74
+ >>> audio = pipe(
75
+ ... prompt,
76
+ ... negative_prompt=negative_prompt,
77
+ ... num_inference_steps=200,
78
+ ... audio_end_in_s=10.0,
79
+ ... num_waveforms_per_prompt=3,
80
+ ... generator=generator,
81
+ ... ).audios
82
+
83
+ >>> output = audio[0].T.float().cpu().numpy()
84
+ >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate)
85
+ ```
86
+ """
87
+
88
+
89
+ class StableAudioPipeline(DiffusionPipeline):
90
+ r"""
91
+ Pipeline for text-to-audio generation using StableAudio.
92
+
93
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
94
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
95
+
96
+ Args:
97
+ vae ([`AutoencoderOobleck`]):
98
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
99
+ text_encoder ([`~transformers.T5EncoderModel`]):
100
+ Frozen text-encoder. StableAudio uses the encoder of
101
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
102
+ [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant.
103
+ projection_model ([`StableAudioProjectionModel`]):
104
+ A trained model used to linearly project the hidden-states from the text encoder model and the start and
105
+ end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to
106
+ give the input to the transformer model.
107
+ tokenizer ([`~transformers.T5Tokenizer`]):
108
+ Tokenizer to tokenize text for the frozen text-encoder.
109
+ transformer ([`StableAudioDiTModel`]):
110
+ A `StableAudioDiTModel` to denoise the encoded audio latents.
111
+ scheduler ([`EDMDPMSolverMultistepScheduler`]):
112
+ A scheduler to be used in combination with `transformer` to denoise the encoded audio latents.
113
+ """
114
+
115
+ model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae"
116
+
117
+ def __init__(
118
+ self,
119
+ vae: AutoencoderOobleck,
120
+ text_encoder: T5EncoderModel,
121
+ projection_model: StableAudioProjectionModel,
122
+ tokenizer: Union[T5Tokenizer, T5TokenizerFast],
123
+ transformer: StableAudioDiTModel,
124
+ scheduler: EDMDPMSolverMultistepScheduler,
125
+ ):
126
+ super().__init__()
127
+
128
+ self.register_modules(
129
+ vae=vae,
130
+ text_encoder=text_encoder,
131
+ projection_model=projection_model,
132
+ tokenizer=tokenizer,
133
+ transformer=transformer,
134
+ scheduler=scheduler,
135
+ )
136
+ self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2
137
+
138
+ # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
139
+ def enable_vae_slicing(self):
140
+ r"""
141
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
142
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
143
+ """
144
+ self.vae.enable_slicing()
145
+
146
+ # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
147
+ def disable_vae_slicing(self):
148
+ r"""
149
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
150
+ computing decoding in one step.
151
+ """
152
+ self.vae.disable_slicing()
153
+
154
+ def encode_prompt(
155
+ self,
156
+ prompt,
157
+ device,
158
+ do_classifier_free_guidance,
159
+ negative_prompt=None,
160
+ prompt_embeds: Optional[torch.Tensor] = None,
161
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
162
+ attention_mask: Optional[torch.LongTensor] = None,
163
+ negative_attention_mask: Optional[torch.LongTensor] = None,
164
+ ):
165
+ if prompt is not None and isinstance(prompt, str):
166
+ batch_size = 1
167
+ elif prompt is not None and isinstance(prompt, list):
168
+ batch_size = len(prompt)
169
+ else:
170
+ batch_size = prompt_embeds.shape[0]
171
+
172
+ if prompt_embeds is None:
173
+ # 1. Tokenize text
174
+ text_inputs = self.tokenizer(
175
+ prompt,
176
+ padding="max_length",
177
+ max_length=self.tokenizer.model_max_length,
178
+ truncation=True,
179
+ return_tensors="pt",
180
+ )
181
+ text_input_ids = text_inputs.input_ids
182
+ attention_mask = text_inputs.attention_mask
183
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
184
+
185
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
186
+ text_input_ids, untruncated_ids
187
+ ):
188
+ removed_text = self.tokenizer.batch_decode(
189
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
190
+ )
191
+ # logger.warning(
192
+ # f"The following part of your input was truncated because {self.text_encoder.config.model_type} can "
193
+ # f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
194
+ # )
195
+
196
+ text_input_ids = text_input_ids.to(device)
197
+ attention_mask = attention_mask.to(device)
198
+
199
+ # 2. Text encoder forward
200
+ self.text_encoder.eval()
201
+ prompt_embeds = self.text_encoder(
202
+ text_input_ids,
203
+ attention_mask=attention_mask,
204
+ )
205
+ prompt_embeds = prompt_embeds[0]
206
+
207
+ if do_classifier_free_guidance and negative_prompt is not None:
208
+ uncond_tokens: List[str]
209
+ if type(prompt) is not type(negative_prompt):
210
+ raise TypeError(
211
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
212
+ f" {type(prompt)}."
213
+ )
214
+ elif isinstance(negative_prompt, str):
215
+ uncond_tokens = [negative_prompt]
216
+ elif batch_size != len(negative_prompt):
217
+ raise ValueError(
218
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
219
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
220
+ " the batch size of `prompt`."
221
+ )
222
+ else:
223
+ uncond_tokens = negative_prompt
224
+
225
+ # 1. Tokenize text
226
+ uncond_input = self.tokenizer(
227
+ uncond_tokens,
228
+ padding="max_length",
229
+ max_length=self.tokenizer.model_max_length,
230
+ truncation=True,
231
+ return_tensors="pt",
232
+ )
233
+
234
+ uncond_input_ids = uncond_input.input_ids.to(device)
235
+ negative_attention_mask = uncond_input.attention_mask.to(device)
236
+
237
+ # 2. Text encoder forward
238
+ self.text_encoder.eval()
239
+ negative_prompt_embeds = self.text_encoder(
240
+ uncond_input_ids,
241
+ attention_mask=negative_attention_mask,
242
+ )
243
+ negative_prompt_embeds = negative_prompt_embeds[0]
244
+
245
+ if negative_attention_mask is not None:
246
+ # set the masked tokens to the null embed
247
+ negative_prompt_embeds = torch.where(
248
+ negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0
249
+ )
250
+
251
+ # 3. Project prompt_embeds and negative_prompt_embeds
252
+ if do_classifier_free_guidance and negative_prompt_embeds is not None:
253
+ # For classifier free guidance, we need to do two forward passes.
254
+ # Here we concatenate the negative and text embeddings into a single batch
255
+ # to avoid doing two forward passes
256
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])
257
+ if attention_mask is not None and negative_attention_mask is None:
258
+ negative_attention_mask = torch.ones_like(attention_mask)
259
+ elif attention_mask is None and negative_attention_mask is not None:
260
+ attention_mask = torch.ones_like(negative_attention_mask)
261
+ if attention_mask is not None:
262
+ attention_mask = torch.cat([negative_attention_mask, attention_mask, attention_mask])
263
+
264
+ prompt_embeds = self.projection_model(
265
+ text_hidden_states=prompt_embeds,
266
+ ).text_hidden_states
267
+ if attention_mask is not None:
268
+ prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
270
+
271
+ return prompt_embeds
272
+
273
+ def encode_duration(
274
+ self,
275
+ audio_start_in_s,
276
+ audio_end_in_s,
277
+ device,
278
+ do_classifier_free_guidance,
279
+ batch_size,
280
+ ):
281
+ audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s]
282
+ audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s]
283
+
284
+ if len(audio_start_in_s) == 1:
285
+ audio_start_in_s = audio_start_in_s * batch_size
286
+ if len(audio_end_in_s) == 1:
287
+ audio_end_in_s = audio_end_in_s * batch_size
288
+
289
+ # Cast the inputs to floats
290
+ audio_start_in_s = [float(x) for x in audio_start_in_s]
291
+ audio_start_in_s = torch.tensor(audio_start_in_s).to(device)
292
+
293
+ audio_end_in_s = [float(x) for x in audio_end_in_s]
294
+ audio_end_in_s = torch.tensor(audio_end_in_s).to(device)
295
+
296
+ projection_output = self.projection_model(
297
+ start_seconds=audio_start_in_s,
298
+ end_seconds=audio_end_in_s,
299
+ )
300
+ seconds_start_hidden_states = projection_output.seconds_start_hidden_states
301
+ seconds_end_hidden_states = projection_output.seconds_end_hidden_states
302
+
303
+ # For classifier free guidance, we need to do two forward passes.
304
+ # Here we repeat the audio hidden states to avoid doing two forward passes
305
+ if do_classifier_free_guidance:
306
+ seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states, seconds_start_hidden_states], dim=0)
307
+ seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states, seconds_end_hidden_states], dim=0)
308
+
309
+ return seconds_start_hidden_states, seconds_end_hidden_states
310
+
311
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
312
+ def prepare_extra_step_kwargs(self, generator, eta):
313
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
314
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
315
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
316
+ # and should be between [0, 1]
317
+
318
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
319
+ extra_step_kwargs = {}
320
+ if accepts_eta:
321
+ extra_step_kwargs["eta"] = eta
322
+
323
+ # check if the scheduler accepts generator
324
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
325
+ if accepts_generator:
326
+ extra_step_kwargs["generator"] = generator
327
+ return extra_step_kwargs
328
+
329
+ def check_inputs(
330
+ self,
331
+ prompt,
332
+ audio_start_in_s,
333
+ audio_end_in_s,
334
+ callback_steps,
335
+ negative_prompt=None,
336
+ prompt_embeds=None,
337
+ negative_prompt_embeds=None,
338
+ attention_mask=None,
339
+ negative_attention_mask=None,
340
+ initial_audio_waveforms=None,
341
+ initial_audio_sampling_rate=None,
342
+ ):
343
+ if audio_end_in_s < audio_start_in_s:
344
+ raise ValueError(
345
+ f"`audio_end_in_s={audio_end_in_s}` must be higher than `audio_start_in_s={audio_start_in_s}`."
346
+ )
347
+
348
+ if (
349
+ audio_start_in_s < self.projection_model.config.min_value
350
+ or audio_start_in_s > self.projection_model.config.max_value
351
+ ):
352
+ raise ValueError(
353
+ f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
354
+ f"is {audio_start_in_s}."
355
+ )
356
+
357
+ if (
358
+ audio_end_in_s < self.projection_model.config.min_value
359
+ or audio_end_in_s > self.projection_model.config.max_value
360
+ ):
361
+ raise ValueError(
362
+ f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
363
+ f"is {audio_end_in_s}."
364
+ )
365
+
366
+ if (callback_steps is None) or (
367
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
368
+ ):
369
+ raise ValueError(
370
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
371
+ f" {type(callback_steps)}."
372
+ )
373
+
374
+ if prompt is not None and prompt_embeds is not None:
375
+ raise ValueError(
376
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
377
+ " only forward one of the two."
378
+ )
379
+ elif prompt is None and (prompt_embeds is None):
380
+ raise ValueError(
381
+ "Provide either `prompt`, or `prompt_embeds`. Cannot leave"
382
+ "`prompt` undefined without specifying `prompt_embeds`."
383
+ )
384
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
385
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
386
+
387
+ if negative_prompt is not None and negative_prompt_embeds is not None:
388
+ raise ValueError(
389
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
390
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
391
+ )
392
+
393
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
394
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
395
+ raise ValueError(
396
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
397
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
398
+ f" {negative_prompt_embeds.shape}."
399
+ )
400
+ if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
401
+ raise ValueError(
402
+ "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
403
+ f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
404
+ )
405
+
406
+ if initial_audio_sampling_rate is None and initial_audio_waveforms is not None:
407
+ raise ValueError(
408
+ "`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`."
409
+ )
410
+
411
+ if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate:
412
+ raise ValueError(
413
+ f"`initial_audio_sampling_rate` must be `{self.vae.sampling_rate}` but is `{initial_audio_sampling_rate}`."
414
+ "Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. "
415
+ )
416
+
417
+ def prepare_latents(
418
+ self,
419
+ batch_size,
420
+ num_channels_vae,
421
+ sample_size,
422
+ dtype,
423
+ device,
424
+ generator,
425
+ latents=None,
426
+ initial_audio_waveforms=None,
427
+ num_waveforms_per_prompt=None,
428
+ audio_channels=None,
429
+ ):
430
+ shape = (batch_size, num_channels_vae, sample_size)
431
+ if isinstance(generator, list) and len(generator) != batch_size:
432
+ raise ValueError(
433
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
434
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
435
+ )
436
+
437
+ if latents is None:
438
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
439
+ else:
440
+ latents = latents.to(device)
441
+
442
+ # scale the initial noise by the standard deviation required by the scheduler
443
+ latents = latents * self.scheduler.init_noise_sigma
444
+
445
+ # encode the initial audio for use by the model
446
+ if initial_audio_waveforms is not None:
447
+ # check dimension
448
+ if initial_audio_waveforms.ndim == 2:
449
+ initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1)
450
+ elif initial_audio_waveforms.ndim != 3:
451
+ raise ValueError(
452
+ f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
453
+ )
454
+
455
+ audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
456
+ audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
457
+
458
+ # check num_channels
459
+ if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2:
460
+ initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1)
461
+ elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1:
462
+ initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True)
463
+
464
+ if initial_audio_waveforms.shape[:2] != audio_shape[:2]:
465
+ raise ValueError(
466
+ f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`"
467
+ )
468
+
469
+ # crop or pad
470
+ audio_length = initial_audio_waveforms.shape[-1]
471
+ if audio_length < audio_vae_length:
472
+ logger.warning(
473
+ f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded."
474
+ )
475
+ elif audio_length > audio_vae_length:
476
+ logger.warning(
477
+ f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped."
478
+ )
479
+
480
+ audio = initial_audio_waveforms.new_zeros(audio_shape)
481
+ audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length]
482
+
483
+ encoded_audio = self.vae.encode(audio).latent_dist.sample(generator)
484
+ encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))
485
+ latents = encoded_audio + latents
486
+ return latents
487
+
488
+ @torch.no_grad()
489
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
490
+ def __call__(
491
+ self,
492
+ guidance_scale_audio = None,
493
+ extracted_condition_audio = None,
494
+ extracted_condition = None,
495
+ prompt: Union[str, List[str]] = None,
496
+ audio_end_in_s: Optional[float] = None,
497
+ audio_start_in_s: Optional[float] = 0.0,
498
+ num_inference_steps: int = 100,
499
+ guidance_scale_text: float = 7.0,
500
+ guidance_scale_con: float = 2.0,
501
+ negative_prompt: Optional[Union[str, List[str]]] = None,
502
+ num_waveforms_per_prompt: Optional[int] = 1,
503
+ eta: float = 0.0,
504
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
505
+ latents: Optional[torch.Tensor] = None,
506
+ initial_audio_waveforms: Optional[torch.Tensor] = None,
507
+ initial_audio_sampling_rate: Optional[torch.Tensor] = None,
508
+ prompt_embeds: Optional[torch.Tensor] = None,
509
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
510
+ attention_mask: Optional[torch.LongTensor] = None,
511
+ negative_attention_mask: Optional[torch.LongTensor] = None,
512
+ return_dict: bool = True,
513
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
514
+ callback_steps: Optional[int] = 1,
515
+ output_type: Optional[str] = "pt",
516
+ ):
517
+ r"""
518
+ The call function to the pipeline for generation.
519
+
520
+ Args:
521
+ prompt (`str` or `List[str]`, *optional*):
522
+ The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
523
+ audio_end_in_s (`float`, *optional*, defaults to 47.55):
524
+ Audio end index in seconds.
525
+ audio_start_in_s (`float`, *optional*, defaults to 0):
526
+ Audio start index in seconds.
527
+ num_inference_steps (`int`, *optional*, defaults to 100):
528
+ The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
529
+ expense of slower inference.
530
+ guidance_scale_text (`float`, *optional*, defaults to 7.0):
531
+ A higher guidance scale value encourages the model to generate audio that is closely linked to the text
532
+ `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
533
+ negative_prompt (`str` or `List[str]`, *optional*):
534
+ The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
535
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
536
+ num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
537
+ The number of waveforms to generate per prompt.
538
+ eta (`float`, *optional*, defaults to 0.0):
539
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
540
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
541
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
542
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
543
+ generation deterministic.
544
+ latents (`torch.Tensor`, *optional*):
545
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio
546
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
547
+ tensor is generated by sampling using the supplied random `generator`.
548
+ initial_audio_waveforms (`torch.Tensor`, *optional*):
549
+ Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape
550
+ `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size`
551
+ corresponds to the number of prompts passed to the model.
552
+ initial_audio_sampling_rate (`int`, *optional*):
553
+ Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model.
554
+ prompt_embeds (`torch.Tensor`, *optional*):
555
+ Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs,
556
+ *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input
557
+ argument.
558
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
559
+ Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text
560
+ inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
561
+ `negative_prompt` input argument.
562
+ attention_mask (`torch.LongTensor`, *optional*):
563
+ Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
564
+ be computed from `prompt` input argument.
565
+ negative_attention_mask (`torch.LongTensor`, *optional*):
566
+ Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`.
567
+ return_dict (`bool`, *optional*, defaults to `True`):
568
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
569
+ plain tuple.
570
+ callback (`Callable`, *optional*):
571
+ A function that calls every `callback_steps` steps during inference. The function is called with the
572
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
573
+ callback_steps (`int`, *optional*, defaults to 1):
574
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
575
+ every step.
576
+ output_type (`str`, *optional*, defaults to `"pt"`):
577
+ The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
578
+ `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
579
+ model (LDM) output.
580
+
581
+ Examples:
582
+
583
+ Returns:
584
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
585
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
586
+ otherwise a `tuple` is returned where the first element is a list with the generated audio.
587
+ """
588
+ # 0. Convert audio input length from seconds to latent length
589
+ downsample_ratio = self.vae.hop_length
590
+
591
+ max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate
592
+ if audio_end_in_s is None:
593
+ audio_end_in_s = max_audio_length_in_s
594
+
595
+ if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
596
+ raise ValueError(
597
+ f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
598
+ )
599
+
600
+ waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
601
+ waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate)
602
+ waveform_length = int(self.transformer.config.sample_size) # * audio_end_in_s / 47.554
603
+ # waveform_length = 646
604
+ # 1. Check inputs. Raise error if not correct
605
+ self.check_inputs(
606
+ prompt,
607
+ audio_start_in_s,
608
+ audio_end_in_s,
609
+ callback_steps,
610
+ negative_prompt,
611
+ prompt_embeds,
612
+ negative_prompt_embeds,
613
+ attention_mask,
614
+ negative_attention_mask,
615
+ initial_audio_waveforms,
616
+ initial_audio_sampling_rate,
617
+ )
618
+
619
+ # 2. Define call parameters
620
+ if prompt is not None and isinstance(prompt, str):
621
+ batch_size = 1
622
+ elif prompt is not None and isinstance(prompt, list):
623
+ batch_size = len(prompt)
624
+ else:
625
+ batch_size = prompt_embeds.shape[0]
626
+
627
+ device = self._execution_device
628
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
629
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
630
+ # corresponds to doing no classifier free guidance.
631
+ do_classifier_free_guidance = True
632
+
633
+ # 3. Encode input prompt
634
+ prompt_embeds = self.encode_prompt(
635
+ prompt,
636
+ device,
637
+ do_classifier_free_guidance,
638
+ negative_prompt,
639
+ prompt_embeds,
640
+ negative_prompt_embeds,
641
+ attention_mask,
642
+ negative_attention_mask,
643
+ )
644
+
645
+ # Encode duration
646
+ seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration(
647
+ audio_start_in_s,
648
+ audio_end_in_s,
649
+ device,
650
+ do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None),
651
+ batch_size,
652
+ )
653
+
654
+ # Create text_audio_duration_embeds and audio_duration_embeds
655
+ text_audio_duration_embeds = torch.cat(
656
+ [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1
657
+ )
658
+
659
+ audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2)
660
+
661
+ # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and
662
+ # to concatenate it to the embeddings
663
+ if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
664
+ negative_text_audio_duration_embeds = torch.zeros_like(
665
+ text_audio_duration_embeds, device=text_audio_duration_embeds.device
666
+ )
667
+ text_audio_duration_embeds = torch.cat(
668
+ [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0
669
+ )
670
+ audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0)
671
+ # if condition is not None:
672
+ # condition_conditioning = condition_model(condition)
673
+ # condition_no_conditioning = condition_model(torch.full_like(condition, fill_value=0))
674
+ # extracted_condition = torch.cat([condition_no_conditioning, condition_no_conditioning, condition_conditioning], dim=0)
675
+
676
+ bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape
677
+ # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method
678
+ text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
679
+ text_audio_duration_embeds = text_audio_duration_embeds.view(
680
+ bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
681
+ )
682
+
683
+ audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
684
+ audio_duration_embeds = audio_duration_embeds.view(
685
+ bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1]
686
+ )
687
+
688
+ # 4. Prepare timesteps
689
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
690
+ timesteps = self.scheduler.timesteps
691
+
692
+ # 5. Prepare latent variables
693
+ num_channels_vae = self.transformer.config.in_channels
694
+ latents = self.prepare_latents(
695
+ batch_size * num_waveforms_per_prompt,
696
+ num_channels_vae,
697
+ waveform_length,
698
+ text_audio_duration_embeds.dtype,
699
+ device,
700
+ generator,
701
+ latents,
702
+ initial_audio_waveforms,
703
+ num_waveforms_per_prompt,
704
+ audio_channels=self.vae.config.audio_channels,
705
+ )
706
+
707
+ # 6. Prepare extra step kwargs
708
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
709
+
710
+ # 7. Prepare rotary positional embedding
711
+ rotary_embedding = get_1d_rotary_pos_embed(
712
+ self.rotary_embed_dim,
713
+ latents.shape[2] + audio_duration_embeds.shape[1],
714
+ use_real=True,
715
+ repeat_interleave_real=False,
716
+ )
717
+
718
+ # 8. Denoising loop
719
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
720
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
721
+ for i, t in enumerate(timesteps):
722
+ # expand the latents if we are doing classifier free guidance
723
+ latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
724
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
725
+ with autocast():
726
+ noise_pred = self.transformer(
727
+ latent_model_input,
728
+ t.unsqueeze(0),
729
+ encoder_hidden_states=text_audio_duration_embeds,
730
+ encoder_hidden_states_con=extracted_condition,
731
+ global_hidden_states=audio_duration_embeds,
732
+ rotary_embedding=rotary_embedding,
733
+ return_dict=False,
734
+ )[0]
735
+ # transformer_weight_dtype = next(self.transformer.parameters()).dtype
736
+ # print("transformer_weight_dtype",transformer_weight_dtype)
737
+ # noise_pred = noise_pred.half()
738
+
739
+ # perform guidance
740
+ if do_classifier_free_guidance:
741
+ noise_pred_uncond, noise_pred_text, noise_pred_both = noise_pred.chunk(3)
742
+ noise_pred = noise_pred_uncond + guidance_scale_text * (noise_pred_text - noise_pred_uncond) + guidance_scale_con * (noise_pred_both - noise_pred_text)
743
+
744
+ # compute the previous noisy sample x_t -> x_t-1
745
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
746
+
747
+ # call the callback, if provided
748
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
749
+ progress_bar.update()
750
+ if callback is not None and i % callback_steps == 0:
751
+ step_idx = i // getattr(self.scheduler, "order", 1)
752
+ callback(step_idx, t, latents)
753
+
754
+ # 9. Post-processing
755
+ if not output_type == "latent":
756
+ with autocast():
757
+ audio = self.vae.decode(latents).sample
758
+
759
+ else:
760
+ return AudioPipelineOutput(audios=latents)
761
+
762
+ audio = audio[:, :, waveform_start:waveform_end]
763
+
764
+ if output_type == "np":
765
+ audio = audio.cpu().float().numpy()
766
+
767
+ self.maybe_free_model_hooks()
768
+
769
+ if not return_dict:
770
+ return (audio,)
771
+
772
+ return AudioPipelineOutput(audios=audio)
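The denoising loop above runs a three-way classifier-free-guidance batch (unconditional, text-only, text plus musical condition) and blends the three predictions with two independent scales. A small numeric sketch of that blend, using stand-in tensors rather than real DiT outputs:

```py
import torch

# Stand-in for the transformer output over the 3-way CFG batch:
# [unconditional, text-only, text + musical condition].
noise_pred = torch.randn(3, 64, 1024)
noise_pred_uncond, noise_pred_text, noise_pred_both = noise_pred.chunk(3)

guidance_scale_text, guidance_scale_con = 7.0, 1.5
guided = (
    noise_pred_uncond
    + guidance_scale_text * (noise_pred_text - noise_pred_uncond)
    + guidance_scale_con * (noise_pred_both - noise_pred_text)
)
print(guided.shape)  # torch.Size([1, 64, 1024])
```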
pipeline/stable_audio_multi_cfg_pipe_audio.py ADDED
@@ -0,0 +1,783 @@
1
+ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import matplotlib.pyplot as plt
15
+
16
+ import inspect
17
+ from typing import Callable, List, Optional, Union
18
+
19
+ import torch
20
+ from transformers import (
21
+ T5EncoderModel,
22
+ T5Tokenizer,
23
+ T5TokenizerFast,
24
+ )
25
+
26
+ from diffusers.models import AutoencoderOobleck, StableAudioDiTModel
27
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
28
+ from diffusers.schedulers import EDMDPMSolverMultistepScheduler
29
+ from diffusers.utils import (
30
+ logging,
31
+ replace_example_docstring,
32
+ )
33
+ import numpy as np
34
+ from diffusers.utils.torch_utils import randn_tensor
35
+ from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline
36
+ from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel
37
+ from torch.cuda.amp import autocast, GradScaler
38
+
39
+ def check_and_print_non_float32_parameters(model):
40
+ non_float32_params = []
41
+ for name, param in model.named_parameters():
42
+ if param.dtype != torch.float32:
43
+ non_float32_params.append((name, param.dtype))
44
+
45
+ if non_float32_params:
46
+ print("Not all parameters are in float32!")
47
+ print("The following parameters are not in float32:")
48
+ for name, dtype in non_float32_params:
49
+ print(f"Parameter: {name}, Data Type: {dtype}")
50
+ else:
51
+ print("All parameters are in float32.")
52
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
+
54
+ EXAMPLE_DOC_STRING = """
55
+ Examples:
56
+ ```py
57
+ >>> import scipy
58
+ >>> import torch
59
+ >>> import soundfile as sf
60
+ >>> from diffusers import StableAudioPipeline
61
+
62
+ >>> repo_id = "stabilityai/stable-audio-open-1.0"
63
+ >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
64
+ >>> pipe = pipe.to("cuda")
65
+
66
+ >>> # define the prompts
67
+ >>> prompt = "The sound of a hammer hitting a wooden surface."
68
+ >>> negative_prompt = "Low quality."
69
+
70
+ >>> # set the seed for generator
71
+ >>> generator = torch.Generator("cuda").manual_seed(0)
72
+
73
+ >>> # run the generation
74
+ >>> audio = pipe(
75
+ ... prompt,
76
+ ... negative_prompt=negative_prompt,
77
+ ... num_inference_steps=200,
78
+ ... audio_end_in_s=10.0,
79
+ ... num_waveforms_per_prompt=3,
80
+ ... generator=generator,
81
+ ... ).audios
82
+
83
+ >>> output = audio[0].T.float().cpu().numpy()
84
+ >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate)
85
+ ```
86
+ """
87
+
88
+
89
+ class StableAudioPipeline(DiffusionPipeline):
90
+ r"""
91
+ Pipeline for text-to-audio generation using StableAudio.
92
+
93
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
94
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
95
+
96
+ Args:
97
+ vae ([`AutoencoderOobleck`]):
98
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
99
+ text_encoder ([`~transformers.T5EncoderModel`]):
100
+ Frozen text-encoder. StableAudio uses the encoder of
101
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
102
+ [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant.
103
+ projection_model ([`StableAudioProjectionModel`]):
104
+ A trained model used to linearly project the hidden-states from the text encoder model and the start and
105
+ end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to
106
+ give the input to the transformer model.
107
+ tokenizer ([`~transformers.T5Tokenizer`]):
108
+ Tokenizer to tokenize text for the frozen text-encoder.
109
+ transformer ([`StableAudioDiTModel`]):
110
+ A `StableAudioDiTModel` to denoise the encoded audio latents.
111
+ scheduler ([`EDMDPMSolverMultistepScheduler`]):
112
+ A scheduler to be used in combination with `transformer` to denoise the encoded audio latents.
113
+ """
114
+
115
+ model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae"
116
+
117
+ def __init__(
118
+ self,
119
+ vae: AutoencoderOobleck,
120
+ text_encoder: T5EncoderModel,
121
+ projection_model: StableAudioProjectionModel,
122
+ tokenizer: Union[T5Tokenizer, T5TokenizerFast],
123
+ transformer: StableAudioDiTModel,
124
+ scheduler: EDMDPMSolverMultistepScheduler,
125
+ ):
126
+ super().__init__()
127
+
128
+ self.register_modules(
129
+ vae=vae,
130
+ text_encoder=text_encoder,
131
+ projection_model=projection_model,
132
+ tokenizer=tokenizer,
133
+ transformer=transformer,
134
+ scheduler=scheduler,
135
+ )
136
+ self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2
137
+
138
+ # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
139
+ def enable_vae_slicing(self):
140
+ r"""
141
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
142
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
143
+ """
144
+ self.vae.enable_slicing()
145
+
146
+ # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
147
+ def disable_vae_slicing(self):
148
+ r"""
149
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
150
+ computing decoding in one step.
151
+ """
152
+ self.vae.disable_slicing()
153
+
154
+ def encode_prompt(
155
+ self,
156
+ prompt,
157
+ device,
158
+ do_classifier_free_guidance,
159
+ negative_prompt=None,
160
+ prompt_embeds: Optional[torch.Tensor] = None,
161
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
162
+ attention_mask: Optional[torch.LongTensor] = None,
163
+ negative_attention_mask: Optional[torch.LongTensor] = None,
164
+ ):
165
+ if prompt is not None and isinstance(prompt, str):
166
+ batch_size = 1
167
+ elif prompt is not None and isinstance(prompt, list):
168
+ batch_size = len(prompt)
169
+ else:
170
+ batch_size = prompt_embeds.shape[0]
171
+
172
+ if prompt_embeds is None:
173
+ # 1. Tokenize text
174
+ self.tokenizer.model_max_length = 512
175
+ text_inputs = self.tokenizer(
176
+ prompt,
177
+ padding="max_length",
178
+ max_length=self.tokenizer.model_max_length,
179
+ truncation=True,
180
+ return_tensors="pt",
181
+ )
182
+ text_input_ids = text_inputs.input_ids
183
+ attention_mask = text_inputs.attention_mask
184
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
185
+
186
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
187
+ text_input_ids, untruncated_ids
188
+ ):
189
+ removed_text = self.tokenizer.batch_decode(
190
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
191
+ )
192
+ logger.warning(
193
+ f"The following part of your input was truncated because {self.text_encoder.config.model_type} can "
194
+ f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
195
+ )
196
+
197
+ text_input_ids = text_input_ids.to(device)
198
+ attention_mask = attention_mask.to(device)
199
+
200
+ # 2. Text encoder forward
201
+ self.text_encoder.eval()
202
+ prompt_embeds = self.text_encoder(
203
+ text_input_ids,
204
+ attention_mask=attention_mask,
205
+ )
206
+ prompt_embeds = prompt_embeds[0]
207
+
208
+ if do_classifier_free_guidance and negative_prompt is not None:
209
+ uncond_tokens: List[str]
210
+ if type(prompt) is not type(negative_prompt):
211
+ raise TypeError(
212
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
213
+ f" {type(prompt)}."
214
+ )
215
+ elif isinstance(negative_prompt, str):
216
+ uncond_tokens = [negative_prompt]
217
+ elif batch_size != len(negative_prompt):
218
+ raise ValueError(
219
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
220
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
221
+ " the batch size of `prompt`."
222
+ )
223
+ else:
224
+ uncond_tokens = negative_prompt
225
+
226
+ # 1. Tokenize text
227
+ uncond_input = self.tokenizer(
228
+ uncond_tokens,
229
+ padding="max_length",
230
+ max_length=self.tokenizer.model_max_length,
231
+ truncation=True,
232
+ return_tensors="pt",
233
+ )
234
+
235
+ uncond_input_ids = uncond_input.input_ids.to(device)
236
+ negative_attention_mask = uncond_input.attention_mask.to(device)
237
+
238
+ # 2. Text encoder forward
239
+ self.text_encoder.eval()
240
+ negative_prompt_embeds = self.text_encoder(
241
+ uncond_input_ids,
242
+ attention_mask=negative_attention_mask,
243
+ )
244
+ negative_prompt_embeds = negative_prompt_embeds[0]
245
+
246
+ if negative_attention_mask is not None:
247
+ # set the masked tokens to the null embed
248
+ negative_prompt_embeds = torch.where(
249
+ negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0
250
+ )
251
+
252
+ # 3. Project prompt_embeds and negative_prompt_embeds
253
+ if do_classifier_free_guidance and negative_prompt_embeds is not None:
254
+ # For classifier free guidance, we need to do two forward passes.
255
+ # Here we concatenate the negative and text embeddings into a single batch
256
+ # to avoid doing two forward passes
257
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds, prompt_embeds])
258
+ if attention_mask is not None and negative_attention_mask is None:
259
+ negative_attention_mask = torch.ones_like(attention_mask)
260
+ elif attention_mask is None and negative_attention_mask is not None:
261
+ attention_mask = torch.ones_like(negative_attention_mask)
262
+ if attention_mask is not None:
263
+ attention_mask = torch.cat([negative_attention_mask, attention_mask, attention_mask, attention_mask])
264
+
265
+ prompt_embeds = self.projection_model(
266
+ text_hidden_states=prompt_embeds,
267
+ ).text_hidden_states
268
+ if attention_mask is not None:
269
+ prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
270
+ prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
271
+
272
+ return prompt_embeds
273
+
274
+ def encode_duration(
275
+ self,
276
+ audio_start_in_s,
277
+ audio_end_in_s,
278
+ device,
279
+ do_classifier_free_guidance,
280
+ batch_size,
281
+ ):
282
+ audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s]
283
+ audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s]
284
+
285
+ if len(audio_start_in_s) == 1:
286
+ audio_start_in_s = audio_start_in_s * batch_size
287
+ if len(audio_end_in_s) == 1:
288
+ audio_end_in_s = audio_end_in_s * batch_size
289
+
290
+ # Cast the inputs to floats
291
+ audio_start_in_s = [float(x) for x in audio_start_in_s]
292
+ audio_start_in_s = torch.tensor(audio_start_in_s).to(device)
293
+
294
+ audio_end_in_s = [float(x) for x in audio_end_in_s]
295
+ audio_end_in_s = torch.tensor(audio_end_in_s).to(device)
296
+
297
+ projection_output = self.projection_model(
298
+ start_seconds=audio_start_in_s,
299
+ end_seconds=audio_end_in_s,
300
+ )
301
+ seconds_start_hidden_states = projection_output.seconds_start_hidden_states
302
+ seconds_end_hidden_states = projection_output.seconds_end_hidden_states
303
+
304
+ # For classifier free guidance, we need to do two forward passes.
305
+ # Here we repeat the audio hidden states to avoid doing two forward passes
306
+ if do_classifier_free_guidance:
307
+ seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states, seconds_start_hidden_states, seconds_start_hidden_states], dim=0)
308
+ seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states, seconds_end_hidden_states, seconds_end_hidden_states], dim=0)
309
+
310
+ return seconds_start_hidden_states, seconds_end_hidden_states
311
+
312
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
313
+ def prepare_extra_step_kwargs(self, generator, eta):
314
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
315
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
316
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
317
+ # and should be between [0, 1]
318
+
319
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
320
+ extra_step_kwargs = {}
321
+ if accepts_eta:
322
+ extra_step_kwargs["eta"] = eta
323
+
324
+ # check if the scheduler accepts generator
325
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
326
+ if accepts_generator:
327
+ extra_step_kwargs["generator"] = generator
328
+ return extra_step_kwargs
329
+
330
+ def check_inputs(
331
+ self,
332
+ prompt,
333
+ audio_start_in_s,
334
+ audio_end_in_s,
335
+ callback_steps,
336
+ negative_prompt=None,
337
+ prompt_embeds=None,
338
+ negative_prompt_embeds=None,
339
+ attention_mask=None,
340
+ negative_attention_mask=None,
341
+ initial_audio_waveforms=None,
342
+ initial_audio_sampling_rate=None,
343
+ ):
344
+ if audio_end_in_s < audio_start_in_s:
345
+ raise ValueError(
346
+ f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but "
347
+ )
348
+
349
+ if (
350
+ audio_start_in_s < self.projection_model.config.min_value
351
+ or audio_start_in_s > self.projection_model.config.max_value
352
+ ):
353
+ raise ValueError(
354
+ f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
355
+ f"is {audio_start_in_s}."
356
+ )
357
+
358
+ if (
359
+ audio_end_in_s < self.projection_model.config.min_value
360
+ or audio_end_in_s > self.projection_model.config.max_value
361
+ ):
362
+ raise ValueError(
363
+ f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
364
+ f"is {audio_end_in_s}."
365
+ )
366
+
367
+ if (callback_steps is None) or (
368
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
369
+ ):
370
+ raise ValueError(
371
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
372
+ f" {type(callback_steps)}."
373
+ )
374
+
375
+ if prompt is not None and prompt_embeds is not None:
376
+ raise ValueError(
377
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
378
+ " only forward one of the two."
379
+ )
380
+ elif prompt is None and (prompt_embeds is None):
381
+ raise ValueError(
382
+ "Provide either `prompt`, or `prompt_embeds`. Cannot leave"
383
+ "`prompt` undefined without specifying `prompt_embeds`."
384
+ )
385
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
386
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
387
+
388
+ if negative_prompt is not None and negative_prompt_embeds is not None:
389
+ raise ValueError(
390
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
391
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
392
+ )
393
+
394
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
395
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
396
+ raise ValueError(
397
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
398
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
399
+ f" {negative_prompt_embeds.shape}."
400
+ )
401
+ if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
402
+ raise ValueError(
403
+ "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
404
+ f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
405
+ )
406
+
407
+ if initial_audio_sampling_rate is None and initial_audio_waveforms is not None:
408
+ raise ValueError(
409
+ "`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`."
410
+ )
411
+
412
+ if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate:
413
+ raise ValueError(
414
+ f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`."
415
+ "Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. "
416
+ )
417
+
418
+ def prepare_latents(
419
+ self,
420
+ batch_size,
421
+ num_channels_vae,
422
+ sample_size,
423
+ dtype,
424
+ device,
425
+ generator,
426
+ latents=None,
427
+ initial_audio_waveforms=None,
428
+ num_waveforms_per_prompt=None,
429
+ audio_channels=None,
430
+ ):
431
+ shape = (batch_size, num_channels_vae, sample_size)
432
+ if isinstance(generator, list) and len(generator) != batch_size:
433
+ raise ValueError(
434
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
435
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
436
+ )
437
+
438
+ if latents is None:
439
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
440
+ else:
441
+ latents = latents.to(device)
442
+
443
+ # scale the initial noise by the standard deviation required by the scheduler
444
+ latents = latents * self.scheduler.init_noise_sigma
445
+
446
+ # encode the initial audio for use by the model
447
+ if initial_audio_waveforms is not None:
448
+ # check dimension
449
+ if initial_audio_waveforms.ndim == 2:
450
+ initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1)
451
+ elif initial_audio_waveforms.ndim != 3:
452
+ raise ValueError(
453
+ f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
454
+ )
455
+
456
+ audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
457
+ audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
458
+
459
+ # check num_channels
460
+ if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2:
461
+ initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1)
462
+ elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1:
463
+ initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True)
464
+
465
+ if initial_audio_waveforms.shape[:2] != audio_shape[:2]:
466
+ raise ValueError(
467
+ f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`"
468
+ )
469
+
470
+ # crop or pad
471
+ audio_length = initial_audio_waveforms.shape[-1]
472
+ if audio_length < audio_vae_length:
473
+ logger.warning(
474
+ f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded."
475
+ )
476
+ elif audio_length > audio_vae_length:
477
+ logger.warning(
478
+ f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped."
479
+ )
480
+
481
+ audio = initial_audio_waveforms.new_zeros(audio_shape)
482
+ audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length]
483
+
484
+ encoded_audio = self.vae.encode(audio).latent_dist.sample(generator)
485
+ encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))
486
+ latents = encoded_audio + latents
487
+ return latents
488
+
489
+ @torch.no_grad()
490
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
491
+ def __call__(
492
+ self,
493
+ extracted_condition_audio = None,
494
+ extracted_condition = None,
495
+ prompt: Union[str, List[str]] = None,
496
+ audio_end_in_s: Optional[float] = None,
497
+ audio_start_in_s: Optional[float] = 0.0,
498
+ num_inference_steps: int = 100,
499
+ guidance_scale_text: float = 7.0,
500
+ guidance_scale_con: float = 2.0,
501
+ guidance_scale_audio: float = 2.0,
502
+ negative_prompt: Optional[Union[str, List[str]]] = None,
503
+ num_waveforms_per_prompt: Optional[int] = 1,
504
+ eta: float = 0.0,
505
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
506
+ latents: Optional[torch.Tensor] = None,
507
+ initial_audio_waveforms: Optional[torch.Tensor] = None,
508
+ initial_audio_sampling_rate: Optional[torch.Tensor] = None,
509
+ prompt_embeds: Optional[torch.Tensor] = None,
510
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
511
+ attention_mask: Optional[torch.LongTensor] = None,
512
+ negative_attention_mask: Optional[torch.LongTensor] = None,
513
+ return_dict: bool = True,
514
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
515
+ callback_steps: Optional[int] = 1,
516
+ output_type: Optional[str] = "pt",
517
+ ):
518
+ r"""
519
+ The call function to the pipeline for generation.
520
+
521
+ Args:
522
+ prompt (`str` or `List[str]`, *optional*):
523
+ The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
524
+ audio_end_in_s (`float`, *optional*, defaults to 47.55):
525
+ Audio end index in seconds.
526
+ audio_start_in_s (`float`, *optional*, defaults to 0):
527
+ Audio start index in seconds.
528
+ num_inference_steps (`int`, *optional*, defaults to 100):
529
+ The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
530
+ expense of slower inference.
531
+ guidance_scale_text (`float`, *optional*, defaults to 7.0):
532
+ Classifier-free guidance weight for the text prompt; higher values tie the audio more closely to the
533
+ text at the expense of quality. `guidance_scale_con` and `guidance_scale_audio` weight the musical-attribute and audio conditions.
534
+ negative_prompt (`str` or `List[str]`, *optional*):
535
+ The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
536
+ pass `negative_prompt_embeds` instead.
537
+ num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
538
+ The number of waveforms to generate per prompt.
539
+ eta (`float`, *optional*, defaults to 0.0):
540
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
541
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
542
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
543
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
544
+ generation deterministic.
545
+ latents (`torch.Tensor`, *optional*):
546
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio
547
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
548
+ tensor is generated by sampling using the supplied random `generator`.
549
+ initial_audio_waveforms (`torch.Tensor`, *optional*):
550
+ Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape
551
+ `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size`
552
+ corresponds to the number of prompts passed to the model.
553
+ initial_audio_sampling_rate (`int`, *optional*):
554
+ Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model.
555
+ prompt_embeds (`torch.Tensor`, *optional*):
556
+ Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs,
557
+ *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input
558
+ argument.
559
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
560
+ Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text
561
+ inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
562
+ `negative_prompt` input argument.
563
+ attention_mask (`torch.LongTensor`, *optional*):
564
+ Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
565
+ be computed from `prompt` input argument.
566
+ negative_attention_mask (`torch.LongTensor`, *optional*):
567
+ Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`.
568
+ return_dict (`bool`, *optional*, defaults to `True`):
569
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
570
+ plain tuple.
571
+ callback (`Callable`, *optional*):
572
+ A function that calls every `callback_steps` steps during inference. The function is called with the
573
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
574
+ callback_steps (`int`, *optional*, defaults to 1):
575
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
576
+ every step.
577
+ output_type (`str`, *optional*, defaults to `"pt"`):
578
+ The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
579
+ `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
580
+ model (LDM) output.
581
+
582
+ Examples:
583
+
584
+ Returns:
585
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
586
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
587
+ otherwise a `tuple` is returned where the first element is a list with the generated audio.
588
+ """
589
+ # 0. Convert audio input length from seconds to latent length
590
+ downsample_ratio = self.vae.hop_length
591
+
592
+ max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate
593
+ if audio_end_in_s is None:
594
+ audio_end_in_s = max_audio_length_in_s
595
+
596
+ if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
597
+ raise ValueError(
598
+ f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
599
+ )
600
+
601
+ waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
602
+ waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate)
603
+ waveform_length = int(self.transformer.config.sample_size) # * audio_end_in_s / 47.554
604
+ # waveform_length = 646
605
+ # 1. Check inputs. Raise error if not correct
606
+ self.check_inputs(
607
+ prompt,
608
+ audio_start_in_s,
609
+ audio_end_in_s,
610
+ callback_steps,
611
+ negative_prompt,
612
+ prompt_embeds,
613
+ negative_prompt_embeds,
614
+ attention_mask,
615
+ negative_attention_mask,
616
+ initial_audio_waveforms,
617
+ initial_audio_sampling_rate,
618
+ )
619
+
620
+ # 2. Define call parameters
621
+ if prompt is not None and isinstance(prompt, str):
622
+ batch_size = 1
623
+ elif prompt is not None and isinstance(prompt, list):
624
+ batch_size = len(prompt)
625
+ else:
626
+ batch_size = prompt_embeds.shape[0]
627
+
628
+ device = self._execution_device
629
+ # This pipeline always builds a four-way classifier-free guidance batch: unconditional,
630
+ # text-only, text + musical-attribute condition, and text + musical-attribute + audio
631
+ # condition; the three guidance scales below weight the corresponding directions.
632
+ do_classifier_free_guidance = True
633
+
634
+ # 3. Encode input prompt
635
+ prompt_embeds = self.encode_prompt(
636
+ prompt,
637
+ device,
638
+ do_classifier_free_guidance,
639
+ negative_prompt,
640
+ prompt_embeds,
641
+ negative_prompt_embeds,
642
+ attention_mask,
643
+ negative_attention_mask,
644
+ )
645
+
646
+ # Encode duration
647
+ seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration(
648
+ audio_start_in_s,
649
+ audio_end_in_s,
650
+ device,
651
+ do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None),
652
+ batch_size,
653
+ )
654
+
655
+ # Create text_audio_duration_embeds and audio_duration_embeds
656
+ text_audio_duration_embeds = torch.cat(
657
+ [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1
658
+ )
659
+
660
+ audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2)
661
+
662
+ # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and
663
+ # to concatenate it to the embeddings
664
+ if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
665
+ negative_text_audio_duration_embeds = torch.zeros_like(
666
+ text_audio_duration_embeds, device=text_audio_duration_embeds.device
667
+ )
668
+ text_audio_duration_embeds = torch.cat(
669
+ [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0
670
+ )
671
+ audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0)
672
+ # if condition is not None:
673
+ # condition_conditioning = condition_model(condition)
674
+ # condition_no_conditioning = condition_model(torch.full_like(condition, fill_value=0))
675
+ # extracted_condition = torch.cat([condition_no_conditioning, condition_no_conditioning, condition_conditioning], dim=0)
676
+
677
+ bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape
678
+ # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method
679
+ text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
680
+ text_audio_duration_embeds = text_audio_duration_embeds.view(
681
+ bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
682
+ )
683
+
684
+ audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
685
+ audio_duration_embeds = audio_duration_embeds.view(
686
+ bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1]
687
+ )
688
+
689
+ # 4. Prepare timesteps
690
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
691
+ timesteps = self.scheduler.timesteps
692
+
693
+ # 5. Prepare latent variables
694
+ num_channels_vae = self.transformer.config.in_channels
695
+ latents = self.prepare_latents(
696
+ batch_size * num_waveforms_per_prompt,
697
+ num_channels_vae,
698
+ waveform_length,
699
+ text_audio_duration_embeds.dtype,
700
+ device,
701
+ generator,
702
+ latents,
703
+ initial_audio_waveforms,
704
+ num_waveforms_per_prompt,
705
+ audio_channels=self.vae.config.audio_channels,
706
+ )
707
+
708
+ # 6. Prepare extra step kwargs
709
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
710
+
711
+ # 7. Prepare rotary positional embedding
712
+ rotary_embedding = get_1d_rotary_pos_embed(
713
+ self.rotary_embed_dim,
714
+ latents.shape[2] + audio_duration_embeds.shape[1],
715
+ use_real=True,
716
+ repeat_interleave_real=False,
717
+ )
718
+
719
+ # 8. Denoising loop
720
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
721
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
722
+ for i, t in enumerate(timesteps):
723
+ # expand the latents if we are doing classifier free guidance
724
+ latent_model_input = torch.cat([latents] * 4) if do_classifier_free_guidance else latents
725
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
726
+ with autocast():
727
+ noise_pred = self.transformer(
728
+ latent_model_input,
729
+ t.unsqueeze(0),
730
+ encoder_hidden_states=text_audio_duration_embeds,
731
+ encoder_hidden_states_con=extracted_condition,
732
+ encoder_hidden_states_audio = extracted_condition_audio,
733
+ global_hidden_states=audio_duration_embeds,
734
+ rotary_embedding=rotary_embedding,
735
+ return_dict=False,
736
+ )[0]
737
+ # transformer_weight_dtype = next(self.transformer.parameters()).dtype
738
+ # print("transformer_weight_dtype",transformer_weight_dtype)
739
+ # noise_pred = noise_pred.half()
740
+
741
+ # perform guidance
742
+ if do_classifier_free_guidance:
743
+ noise_pred_uncond, noise_pred_text, noise_pred_both, noise_pred_both_audio = noise_pred.chunk(4)
744
+ noise_pred = noise_pred_uncond + guidance_scale_text * (noise_pred_text - noise_pred_uncond) + guidance_scale_con * (noise_pred_both - noise_pred_text) \
745
+ + guidance_scale_audio * (noise_pred_both_audio - noise_pred_both)
746
+ # print("guidance_scale_audio", guidance_scale_audio)
747
+ # if do_classifier_free_guidance:
748
+ # noise_pred_uncond, noise_pred_text, noise_pred_both, noise_pred_both_audio = noise_pred.chunk(4)
749
+ # noise_pred_uncond_no_mask, noise_pred_text_no_mask, noise_pred_both_no_mask, noise_pred_both_audio_no_mask = noise_pred_uncond[:,:,:323], noise_pred_text[:,:,:323], noise_pred_both[:,:,:323], noise_pred_both_audio[:,:,:323]
750
+ # noise_pred_no_mask = noise_pred_uncond_no_mask + 7.0 * (noise_pred_text_no_mask - noise_pred_uncond_no_mask) + guidance_scale_con * (noise_pred_both_no_mask - noise_pred_text_no_mask) \
751
+ # + 1.5 * (noise_pred_both_audio_no_mask - noise_pred_both_no_mask)
752
+ # noise_pred_uncond_mask, noise_pred_text_mask, noise_pred_both_mask, noise_pred_both_audio_mask = noise_pred_uncond[:,:,323:], noise_pred_text[:,:,323:], noise_pred_both[:,:,323:], noise_pred_both_audio[:,:,323:]
753
+ # noise_pred_mask = noise_pred_uncond_mask + 7.0 * (noise_pred_text_mask - noise_pred_uncond_mask) + guidance_scale_con * (noise_pred_both_mask - noise_pred_text_mask) \
754
+ # + 4.5 * (noise_pred_both_audio_mask - noise_pred_both_mask)
755
+ # noise_pred = torch.concat((noise_pred_no_mask, noise_pred_mask), dim=2)
756
+ # compute the previous noisy sample x_t -> x_t-1
757
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
758
+
759
+ # call the callback, if provided
760
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
761
+ progress_bar.update()
762
+ if callback is not None and i % callback_steps == 0:
763
+ step_idx = i // getattr(self.scheduler, "order", 1)
764
+ callback(step_idx, t, latents)
765
+
766
+ # 9. Post-processing
767
+ if output_type != "latent":
768
+ with autocast():
769
+ audio = self.vae.decode(latents).sample
770
+ else:
771
+ return AudioPipelineOutput(audios=latents)
772
+
773
+ audio = audio[:, :, waveform_start:waveform_end]
774
+
775
+ if output_type == "np":
776
+ audio = audio.cpu().float().numpy()
777
+
778
+ self.maybe_free_model_hooks()
779
+
780
+ if not return_dict:
781
+ return (audio,)
782
+
783
+ return AudioPipelineOutput(audios=audio)
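
Note on the multi-CFG combination: unlike the stock StableAudioPipeline, `__call__` always builds a four-way batch (unconditional, text-only, text + musical-attribute condition, text + condition + audio) and blends the four noise predictions with `guidance_scale_text`, `guidance_scale_con`, and `guidance_scale_audio`. The sketch below is not part of the diff; it only reproduces that combination step on dummy tensors (shapes and scale values are illustrative) so the guidance math can be inspected in isolation.

import torch

# Dummy noise predictions standing in for the four chunks returned by noise_pred.chunk(4)
# in the denoising loop above; the (batch, channels, length) shape is illustrative.
noise_pred_uncond = torch.randn(1, 64, 1024)
noise_pred_text = torch.randn(1, 64, 1024)
noise_pred_both = torch.randn(1, 64, 1024)        # text + musical-attribute condition
noise_pred_both_audio = torch.randn(1, 64, 1024)  # text + musical-attribute + audio condition

guidance_scale_text, guidance_scale_con, guidance_scale_audio = 7.0, 2.0, 2.0

# Nested classifier-free guidance: each scale weights the extra direction contributed
# by one additional conditioning signal on top of the previous one.
noise_pred = (
    noise_pred_uncond
    + guidance_scale_text * (noise_pred_text - noise_pred_uncond)
    + guidance_scale_con * (noise_pred_both - noise_pred_text)
    + guidance_scale_audio * (noise_pred_both_audio - noise_pred_both)
)
print(noise_pred.shape)  # torch.Size([1, 64, 1024])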
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ git+https://github.com/fundwotsai2001/[email protected]
2
+ git+https://github.com/YianLai0327/madmom.git
3
+ torch
4
+ torchaudio
5
+ soundfile
6
+ accelerate
7
+ transformers==4.46.1
8
+ matplotlib
9
+ librosa
10
+ torchsde
11
+ gdown
12
+ wandb
13
+ gradio
utils/extract_conditions.py ADDED
@@ -0,0 +1,301 @@
1
+ import torchaudio
2
+ import numpy as np
3
+ from scipy.signal import savgol_filter
4
+ import librosa
5
+ import torch
6
+ import torchaudio
7
+ import scipy.signal as signal
8
+ from torchaudio import transforms as T
9
+ import torch
10
+ import torchaudio
11
+ import librosa
12
+ import numpy as np
13
+
14
+
15
+ def compute_melody_v2(stereo_audio: str) -> np.ndarray:
16
+ """
17
+ Args:
18
+ stereo_audio: path to a stereo audio file; it is loaded with torchaudio.load, and
19
+ channel 0 / channel 1 are treated as the left / right channels.
20
+ (The sampling rate is read from the file itself.)
21
+ Returns:
22
+ c: np.ndarray of shape (8, T_frames),
23
+ where each column is [L1, R1, L2, R2, L3, R3, L4, R4] (left/right interleaved per frame),
24
+ and every value is in {1, 2, …, 128}, indexing a CQT frequency bin.
25
+ """
26
+ audio, sr = torchaudio.load(stereo_audio)
27
+ # 1. Compute the CQT (128 bins) for the left and right channels; each cqt_db has shape (128, T_frames)
28
+ cqt_left = compute_music_represent(audio[0], sr) # shape: (128, T_frames)
29
+ cqt_right = compute_music_represent(audio[1], sr) # shape: (128, T_frames)
30
+
31
+ # 2. Get the number of time frames
32
+ # Note: the frame count of librosa.cqt's output cqt_db is its second dimension
33
+ T_frames = cqt_left.shape[1]
34
+
35
+ # 3. Pre-allocate the output matrix c as int, shape = (8, T_frames)
36
+ c = np.zeros((8, T_frames), dtype=np.int32)
37
+
38
+ # 4. Process frame by frame: take the top-4 of the 128 bins in each frame
39
+ for j in range(T_frames):
40
+ # 4.1 Get the left/right channel CQT energies (in dB) for the current frame
41
+ col_L = cqt_left[:, j] # shape: (128,)
42
+ col_R = cqt_right[:, j] # shape: (128,)
43
+
44
+ # 4.2 Use np.argsort to find the indices of the 4 largest values
46
+ # np.argsort sorts ascending, so take the last 4 entries and reverse them (largest first)
47
+ idx4_L = np.argsort(col_L)[-4:][::-1]  # 0-based, length 4
48
+ idx4_R = np.argsort(col_R)[-4:][::-1]  # 0-based, length 4
48
+
49
+ # 4.3 Convert to 1-based indices (the representation uses values in {1, 2, …, 128})
51
+ idx4_L = idx4_L + 1  # now in the range 1..128
51
+ idx4_R = idx4_R + 1
52
+
53
+ # 4.4 Interleave into column j of c
55
+ # We want c[:, j] = [L1, R1, L2, R2, L3, R3, L4, R4]
55
+ for k in range(4):
56
+ c[2 * k , j] = idx4_L[k]
57
+ c[2 * k + 1, j] = idx4_R[k]
58
+
59
+ return c[:,:4097]
60
+
61
+
62
+ def compute_music_represent(audio, sr):
63
+ filter_y = torchaudio.functional.highpass_biquad(audio, sr, 261.6)
64
+ fmin = librosa.midi_to_hz(0)
65
+ cqt_spec = librosa.cqt(y=filter_y.numpy(), fmin=fmin, sr=sr, n_bins=128, bins_per_octave=12, hop_length=512)
66
+ cqt_db = librosa.amplitude_to_db(np.abs(cqt_spec), ref=np.max)
67
+ return cqt_db
68
+
69
+ def keep_top4_pitches_per_channel(cqt_db):
70
+ """
71
+ cqt_db is assumed to have shape: (2, 128, time_frames).
72
+ We return a combined 2D array of shape (128, time_frames)
73
+ where only the top 4 pitch bins in each channel are kept
74
+ (for a total of up to 8 bins per time frame).
75
+ """
76
+ # Parse shapes
77
+ num_channels, num_bins, num_frames = cqt_db.shape
78
+
79
+ # Initialize an output array that combines both channels
80
+ # and has zeros everywhere initially
81
+ combined = np.zeros((num_bins, num_frames), dtype=cqt_db.dtype)
82
+
83
+ for ch in range(num_channels):
84
+ for t in range(num_frames):
85
+ # Find the top 4 pitch bins for this channel at frame t
86
+ # argsort sorts ascending; we take the last 4 indices for top 4
87
+ top4_indices = np.argsort(cqt_db[ch, :, t])[-4:]
88
+
89
+ # Copy their values into the combined array
90
+ # We add to it in case there's overlap between channels
91
+ combined[top4_indices, t] = 1
92
+ return combined
93
+ def compute_melody(input_audio):
94
+ # Initialize parameters
95
+ sample_rate = 44100
96
+
97
+ # Load audio file
98
+ wav, sr = torchaudio.load(input_audio)
99
+ if sr != sample_rate:
100
+ resample = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
101
+ wav = resample(wav)
102
+ # Truncate or pad the audio to 2097152 samples
103
+ target_length = 2097152
104
+ if wav.size(1) > target_length:
105
+ # Truncate the audio if it is longer than the target length
106
+ wav = wav[:, :target_length]
107
+ elif wav.size(1) < target_length:
108
+ # Pad the audio with zeros if it is shorter than the target length
109
+ padding = target_length - wav.size(1)
110
+ wav = torch.cat([wav, torch.zeros(wav.size(0), padding)], dim=1)
111
+ melody = compute_music_represent(wav, 44100)
112
+ melody = keep_top4_pitches_per_channel(melody)
113
+ return melody
114
+
115
+ def compute_dynamics(audio_file, hop_length=160, target_sample_rate=44100, cut=True):
116
+ """
117
+ Compute the dynamics curve for a given audio file.
118
+
119
+ Args:
120
+ audio_file (str): Path to the audio file.
121
+ hop_length (int): Number of samples between successive frames.
122
+ target_sample_rate (int): Sample rate the audio is resampled to before analysis.
123
+ cut (bool): If True, truncate the waveform to 2097152 samples.
124
+ (A Savitzky-Golay filter is applied internally to smooth the dB curve.)
125
+
126
+ Returns:
127
+ dynamics_curve (numpy.ndarray): The computed dynamic values in dB.
128
+ """
129
+ # Load audio file
130
+ waveform, original_sample_rate = torchaudio.load(audio_file)
131
+ if original_sample_rate != target_sample_rate:
132
+ resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sample_rate)
133
+ waveform = resampler(waveform)
134
+ if cut:
135
+ waveform = waveform[:, :2097152]
136
+ # Ensure waveform has a single channel (e.g., select the first channel if multi-channel)
137
+ waveform = waveform.mean(dim=0, keepdim=True) # Mix all channels into one
138
+ waveform = waveform.clamp(-1, 1).numpy()
139
+
140
+ S = np.abs(librosa.stft(waveform, n_fft=1024, hop_length=hop_length))
141
+ mel_filter_bank = librosa.filters.mel(sr=target_sample_rate, n_fft=1024, n_mels=64, fmin=0, fmax=8000)
142
+ S = np.dot(mel_filter_bank, S)
143
+ energy = np.sum(S**2, axis=0)
144
+ energy = np.clip(energy, 1e-6, None)
145
+ dynamics_db = librosa.amplitude_to_db(energy, ref=np.max).squeeze(0)
146
+ smoothed_dynamics = savgol_filter(dynamics_db, window_length=279, polyorder=1)
147
+ # print(smoothed_dynamics.shape)
148
+ return smoothed_dynamics
149
+ def extract_melody_one_hot(audio_path,
150
+ sr=44100,
151
+ cutoff=261.2,
152
+ win_length=2048,
153
+ hop_length=256):
154
+ """
155
+ Extract a one-hot chromagram-based melody from an audio file (mono).
156
+
157
+ Parameters:
158
+ -----------
159
+ audio_path : str
160
+ Path to the input audio file.
161
+ sr : int
162
+ Target sample rate to resample the audio (default: 44100).
163
+ cutoff : float
164
+ The high-pass filter cutoff frequency in Hz (default: Middle C ~ 261.2 Hz).
165
+ win_length : int
166
+ STFT window length for the chromagram (default: 2048).
167
+ hop_length : int
168
+ STFT hop length for the chromagram (default: 256).
169
+
170
+ Returns:
171
+ --------
172
+ one_hot_chroma : np.ndarray, shape=(12, n_frames)
173
+ One-hot chromagram of the most prominent pitch class per frame.
174
+ """
175
+ # ---------------------------------------------------------
176
+ # 1. Load audio (Torchaudio => shape: (channels, samples))
177
+ # ---------------------------------------------------------
178
+ audio, in_sr = torchaudio.load(audio_path)
179
+
180
+ # Convert to mono by averaging channels: shape => (samples,)
181
+ audio_mono = audio.mean(dim=0)
182
+
183
+ # Resample if necessary
184
+ if in_sr != sr:
185
+ resample_tf = T.Resample(orig_freq=in_sr, new_freq=sr)
186
+ audio_mono = resample_tf(audio_mono)
187
+
188
+ # Convert torch.Tensor => NumPy array: shape (samples,)
189
+ y = audio_mono.numpy()
190
+
191
+ # ---------------------------------------------------------
192
+ # 2. Design & apply a high-pass filter (Butterworth, order=2)
193
+ # ---------------------------------------------------------
194
+ nyquist = 0.5 * sr
195
+ norm_cutoff = cutoff / nyquist
196
+ b, a = signal.butter(N=2, Wn=norm_cutoff, btype='high', analog=False)
197
+
198
+ # filtfilt expects shape (n_samples,) for 1D
199
+ y_hp = signal.filtfilt(b, a, y)
200
+
201
+ # ---------------------------------------------------------
202
+ # 3. Compute the chromagram (librosa => shape: (12, n_frames))
203
+ # ---------------------------------------------------------
204
+ chroma = librosa.feature.chroma_stft(
205
+ y=y_hp,
206
+ sr=sr,
207
+ n_fft=win_length, # Usually >= win_length
208
+ win_length=win_length,
209
+ hop_length=hop_length
210
+ )
211
+
212
+ # ---------------------------------------------------------
213
+ # 4. Convert chromagram to one-hot via argmax along pitch classes
214
+ # ---------------------------------------------------------
215
+ # pitch_class_idx => shape=(n_frames,)
216
+ pitch_class_idx = np.argmax(chroma, axis=0)
217
+
218
+ # Make a zero array of the same shape => (12, n_frames)
219
+ one_hot_chroma = np.zeros_like(chroma)
220
+
221
+ # For each frame (column in chroma), set the argmax row to 1
222
+ one_hot_chroma[pitch_class_idx, np.arange(chroma.shape[1])] = 1.0
223
+
224
+ return one_hot_chroma
225
+ def evaluate_f1_rhythm(input_timestamps, generated_timestamps, tolerance=0.07):
226
+ """
227
+ Evaluates precision, recall, and F1-score for beat/downbeat timestamp alignment.
228
+
229
+ Args:
230
+ input_timestamps (ndarray): 2D array of shape [n, 2], where column 0 contains timestamps.
231
+ generated_timestamps (ndarray): 2D array of shape [m, 2], where column 0 contains timestamps.
232
+ tolerance (float): Alignment tolerance in seconds (default: 70ms).
233
+
234
+ Returns:
235
+ tuple: (precision, recall, f1)
236
+ """
237
+ # Extract and sort timestamps
238
+ input_timestamps = np.asarray(input_timestamps)
239
+ generated_timestamps = np.asarray(generated_timestamps)
240
+
241
+ # If you only need the first column
242
+ if input_timestamps.size > 0:
243
+ input_timestamps = input_timestamps[:, 0]
244
+ input_timestamps.sort()
245
+ else:
246
+ input_timestamps = np.array([])
247
+
248
+ if generated_timestamps.size > 0:
249
+ generated_timestamps = generated_timestamps[:, 0]
250
+ generated_timestamps.sort()
251
+ else:
252
+ generated_timestamps = np.array([])
253
+
254
+ # Handle empty cases
255
+ # Case 1: Both are empty
256
+ if len(input_timestamps) == 0 and len(generated_timestamps) == 0:
257
+ # You could argue everything is correct since there's nothing to detect,
258
+ # but returning all zeros is a common convention.
259
+ return 0.0, 0.0, 0.0
260
+
261
+ # Case 2: No ground-truth timestamps, but predictions exist
262
+ if len(input_timestamps) == 0 and len(generated_timestamps) > 0:
263
+ # All predictions are false positives => tp=0, fp = len(generated_timestamps)
264
+ # => precision=0, recall is undefined (tp+fn=0), typically we treat recall=0
265
+ return 0.0, 0.0, 0.0
266
+
267
+ # Case 3: Ground-truth timestamps exist, but no predictions
268
+ if len(input_timestamps) > 0 and len(generated_timestamps) == 0:
269
+ # Everything in input_timestamps is a false negative => tp=0, fn = len(input_timestamps)
270
+ # => recall=0, precision is undefined (tp+fp=0), typically we treat precision=0
271
+ return 0.0, 0.0, 0.0
272
+
273
+ # If we get here, both arrays are non-empty
274
+ tp = 0
275
+ fp = 0
276
+
277
+ # Track matched ground-truth timestamps
278
+ matched_inputs = np.zeros(len(input_timestamps), dtype=bool)
279
+
280
+ for gen_ts in generated_timestamps:
281
+ # Calculate absolute differences to each reference timestamp
282
+ diffs = np.abs(input_timestamps - gen_ts)
283
+ # Find index of the closest input timestamp
284
+ min_diff_idx = np.argmin(diffs)
285
+
286
+ # Check if that difference is within tolerance and unmatched
287
+ if diffs[min_diff_idx] < tolerance and not matched_inputs[min_diff_idx]:
288
+ tp += 1
289
+ matched_inputs[min_diff_idx] = True
290
+ else:
291
+ fp += 1 # no suitable match found or closest was already matched
292
+
293
+ # Remaining unmatched input timestamps are false negatives
294
+ fn = np.sum(~matched_inputs)
295
+
296
+ # Compute precision, recall, f1
297
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
298
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
299
+ f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
300
+
301
+ return precision, recall, f1
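
For orientation, here is a small usage sketch for the condition extractors above (not part of the diff). `song.wav` is a placeholder path, and the timestamp arrays fed to `evaluate_f1_rhythm` are made-up stand-ins for madmom downbeat-tracker output; exact output shapes depend on the audio length and the hard-coded hop sizes.

import numpy as np
from utils.extract_conditions import (
    compute_melody,
    compute_dynamics,
    extract_melody_one_hot,
    evaluate_f1_rhythm,
)

melody = compute_melody("song.wav")          # (128, T_frames) binary mask of the top-4 CQT bins per frame
dynamics = compute_dynamics("song.wav")      # 1-D smoothed loudness curve in dB
chroma = extract_melody_one_hot("song.wav")  # (12, n_frames) one-hot chromagram

# evaluate_f1_rhythm expects two [n, 2] arrays whose first column holds beat timestamps,
# e.g. as produced by madmom's DBNDownBeatTrackingProcessor.
ref = np.array([[0.50, 1.0], [1.00, 2.0], [1.50, 3.0]])
gen = np.array([[0.52, 1.0], [1.04, 2.0], [2.00, 4.0]])
precision, recall, f1 = evaluate_f1_rhythm(ref, gen, tolerance=0.07)
print(melody.shape, dynamics.shape, chroma.shape, (precision, recall, f1))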
utils/feature_extractor.py ADDED
@@ -0,0 +1,173 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+ class dynamics_extractor_full_stereo(nn.Module):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.conv1d_1 = nn.Conv1d(2, 16, kernel_size=3, padding=1, stride=2)
8
+ self.conv1d_2 = nn.Conv1d(16, 16, kernel_size=3, padding=1)
9
+ self.conv1d_3 = nn.Conv1d(16, 128, kernel_size=3, padding=1, stride=2)
10
+ self.conv1d_4 = nn.Conv1d(128, 128, kernel_size=3, padding=1)
11
+ self.conv1d_5 = nn.Conv1d(128, 256, kernel_size=3, padding=1, stride=2)
12
+ def forward(self, x):
13
+ # input shape: (batchsize, 2, 8280)
14
+ # x = x.unsqueeze(1) # shape: (batchsize, 1, 8280)
15
+ x = self.conv1d_1(x) # shape: (batchsize, 16, 4140)
16
+ x = F.silu(x)
17
+ x = self.conv1d_2(x) # shape: (batchsize, 16, 4140)
18
+ x = F.silu(x)
19
+ x = self.conv1d_3(x) # shape: (batchsize, 128, 2070)
20
+ x = F.silu(x)
21
+ x = self.conv1d_4(x) # shape: (batchsize, 128, 2070)
22
+ x = F.silu(x)
23
+ x = self.conv1d_5(x) # shape: (batchsize, 256, 1035)
24
+ return x
25
+ class melody_extractor_full_mono(nn.Module):
26
+ def __init__(self):
27
+ super().__init__()
28
+ self.conv1d_1 = nn.Conv1d(128, 256, kernel_size=3, padding=0, stride=2)
29
+ self.conv1d_2 = nn.Conv1d(256, 256, kernel_size=3, padding=1)
30
+ self.conv1d_3 = nn.Conv1d(256, 512, kernel_size=3, padding=1, stride=2)
31
+ self.conv1d_4 = nn.Conv1d(512, 512, kernel_size=3, padding=1)
32
+ self.conv1d_5 = nn.Conv1d(512, 768, kernel_size=3, padding=1)
33
+ def forward(self, x):
34
+ # input shape: (batchsize, 128, T)
35
+ x = self.conv1d_1(x) # shape: (batchsize, 256, ~T/2)
36
+ x = F.silu(x)
37
+ x = self.conv1d_2(x) # shape: (batchsize, 256, ~T/2)
38
+ x = F.silu(x)
39
+ x = self.conv1d_3(x) # shape: (batchsize, 512, ~T/4)
40
+ x = F.silu(x)
41
+ x = self.conv1d_4(x) # shape: (batchsize, 512, ~T/4)
42
+ x = F.silu(x)
43
+ x = self.conv1d_5(x) # shape: (batchsize, 768, ~T/4)
44
+ return x
45
+ class melody_extractor_mono(nn.Module):
46
+ def __init__(self):
47
+ super().__init__()
48
+ self.conv1d_1 = nn.Conv1d(128, 128, kernel_size=3, padding=0, stride=2)
49
+ self.conv1d_2 = nn.Conv1d(128, 192, kernel_size=3, padding=1, stride=2)
50
+ self.conv1d_3 = nn.Conv1d(192, 192, kernel_size=3, padding=1)
51
+ def forward(self, x):
52
+ # input shape: (batchsize, 128, T)
53
+ x = self.conv1d_1(x) # shape: (batchsize, 128, ~T/2)
54
+ x = F.silu(x)
55
+ x = self.conv1d_2(x) # shape: (batchsize, 192, ~T/4)
56
+ x = F.silu(x)
57
+ x = self.conv1d_3(x) # shape: (batchsize, 192, ~T/4)
58
+ return x
59
+
60
+ class melody_extractor_full_stereo(nn.Module):
61
+ def __init__(self):
62
+ super().__init__()
63
+ self.embed = nn.Embedding(num_embeddings=129, embedding_dim=48)
64
+
65
+ # Three Conv1d layers, each with kernel_size=3, padding=1:
66
+ self.conv1 = nn.Conv1d(384, 384, kernel_size=3, padding=1)
67
+ self.conv2 = nn.Conv1d(384, 768, kernel_size=3, padding=1)
68
+ self.conv3 = nn.Conv1d(768, 768, kernel_size=3, padding=1)
69
+
70
+ def forward(self, melody_idxs):
71
+ # melody_idxs: LongTensor of shape (B, 8, 4096)
72
+ B, eight, L = melody_idxs.shape # L == 4096
73
+
74
+ # 1) Embed:
75
+ # (B, 8, 4096) → (B, 8, 4096, 48)
76
+ embedded = self.embed(melody_idxs)
77
+
78
+ # 2) Permute & reshape → (B, 8*48, 4096) = (B, 384, 4096)
79
+ x = embedded.permute(0, 1, 3, 2) # (B, 8, 48, 4096)
80
+ x = x.reshape(B, eight * 48, L) # (B, 384, 4096)
81
+
82
+ # 3) Conv1 → (B, 384, 4096)
83
+ x = F.silu(self.conv1(x))
84
+
85
+ # 4) Conv2 → (B, 768, 4096)
86
+ x = F.silu(self.conv2(x))
87
+
88
+ # 5) Conv3 → (B, 768, 4096)
89
+ x = F.silu(self.conv3(x))
90
+
91
+ # Now x is (B, 768, 4096) and can be sent on to whatever comes next
92
+ return x
93
+ class melody_extractor_stereo(nn.Module):
94
+ def __init__(self):
95
+ super().__init__()
96
+ self.embed = nn.Embedding(num_embeddings=129, embedding_dim=4)
97
+
98
+ # Five Conv1d layers with kernel_size=3; conv2 and conv4 downsample with stride=2:
99
+ self.conv1 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
100
+ self.conv2 = nn.Conv1d(64, 64, kernel_size=3, padding=0, stride=2)
101
+ self.conv3 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
102
+ self.conv4 = nn.Conv1d(128, 128, kernel_size=3, padding=1, stride=2)
103
+ self.conv5 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
104
+
105
+ def forward(self, melody_idxs):
106
+ # melody_idxs: LongTensor of shape (B, 8, 4096)
107
+ B, eight, L = melody_idxs.shape # L == 4096
108
+
109
+ # 1) Embed:
110
+ # (B, 8, 4096) → (B, 8, 4096, 4)
111
+ embedded = self.embed(melody_idxs)
112
+
113
+ # 2) Permute & reshape → (B, 8*4, 4096) = (B, 32, 4096)
114
+ x = embedded.permute(0, 1, 3, 2) # (B, 8, 4, 4096)
115
+ x = x.reshape(B, eight * 4, L) # (B, 32, 4096)
116
+
117
+ # 3) Conv1 → (B, 64, 4096)
118
+ x = F.silu(self.conv1(x))
119
+
120
+ # 4) Conv2 (stride 2, no padding) → (B, 64, 2047)
121
+ x = F.silu(self.conv2(x))
122
+
123
+ # 5) Conv3 → (B, 128, 2047)
124
+ x = F.silu(self.conv3(x))
125
+
126
+ x = F.silu(self.conv4(x)) # 6) Conv4 (stride 2) → (B, 128, 1024)
127
+
128
+ x = F.silu(self.conv5(x)) # 7) Conv5 → (B, 256, 1024)
129
+
130
+ # Now x is (B, 256, 1024) and can be sent on to whatever comes next
131
+ return x
132
+
133
+ class dynamics_extractor(nn.Module):
134
+ def __init__(self):
135
+ super().__init__()
136
+ self.conv1d_1 = nn.Conv1d(1, 16, kernel_size=3, padding=1, stride=2)
137
+ self.conv1d_2 = nn.Conv1d(16, 16, kernel_size=3, padding=1)
138
+ self.conv1d_3 = nn.Conv1d(16, 128, kernel_size=3, padding=1, stride=2)
139
+ self.conv1d_4 = nn.Conv1d(128, 128, kernel_size=3, padding=1)
140
+ self.conv1d_5 = nn.Conv1d(128, 192, kernel_size=3, padding=1, stride=2)
141
+ def forward(self, x):
142
+ # original shape: (batchsize, 1, 8280)
143
+ # x = x.unsqueeze(1) # shape: (batchsize, 1, 8280)
144
+ x = self.conv1d_1(x) # shape: (batchsize, 16, 4140)
145
+ x = F.silu(x)
146
+ x = self.conv1d_2(x) # shape: (batchsize, 16, 4140)
147
+ x = F.silu(x)
148
+ x = self.conv1d_3(x) # shape: (batchsize, 128, 2070)
149
+ x = F.silu(x)
150
+ x = self.conv1d_4(x) # shape: (batchsize, 128, 2070)
151
+ x = F.silu(x)
152
+ x = self.conv1d_5(x) # shape: (batchsize, 192, 1035)
153
+ return x
154
+ class rhythm_extractor(nn.Module):
155
+ def __init__(self):
156
+ super().__init__()
157
+ self.conv1d_1 = nn.Conv1d(2, 16, kernel_size=3, padding=1)
158
+ self.conv1d_2 = nn.Conv1d(16, 64, kernel_size=3, padding=1)
159
+ self.conv1d_3 = nn.Conv1d(64, 128, kernel_size=3, padding=1, stride=2)
160
+ self.conv1d_4 = nn.Conv1d(128, 128, kernel_size=3, padding=1)
161
+ self.conv1d_5 = nn.Conv1d(128, 192, kernel_size=3, padding=1, stride=2)
162
+ def forward(self, x):
163
+ # original shape: (batchsize, 2, 3000)
164
+ x = self.conv1d_1(x) # shape: (batchsize, 16, 3000)
165
+ x = F.silu(x)
166
+ x = self.conv1d_2(x) # shape: (batchsize, 64, 3000)
167
+ x = F.silu(x)
168
+ x = self.conv1d_3(x) # shape: (batchsize, 128, 1500)
169
+ x = F.silu(x)
170
+ x = self.conv1d_4(x) # shape: (batchsize, 128, 1500)
171
+ x = F.silu(x)
172
+ x = self.conv1d_5(x) # shape: (batchsize, 192, 750)
173
+ return x
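
A quick shape check for the extractor modules above (not part of the diff). Batch size and input lengths are illustrative; the expected output shapes follow from the kernel/stride settings shown in the class definitions.

import torch
from utils.feature_extractor import melody_extractor_stereo, dynamics_extractor, rhythm_extractor

melody_idxs = torch.randint(0, 129, (2, 8, 4096))    # interleaved top-4 CQT bin indices, as produced by compute_melody_v2
print(melody_extractor_stereo()(melody_idxs).shape)  # torch.Size([2, 256, 1024])

dynamics = torch.randn(2, 1, 8280)                   # smoothed loudness curve
print(dynamics_extractor()(dynamics).shape)          # torch.Size([2, 192, 1035])

rhythm = torch.randn(2, 2, 3000)                     # beat / downbeat activation curves
print(rhythm_extractor()(rhythm).shape)              # torch.Size([2, 192, 750])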
utils/stable_audio_dataset_utils.py ADDED
@@ -0,0 +1,129 @@
1
+ import math
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from typing import Tuple
6
+ import torchaudio
7
+ import torch.nn.functional as F
8
+ from torchaudio import transforms as T
9
+
10
+ def load_audio_file(filename, target_sr=44100, target_samples=2097152):
11
+ try:
12
+ audio, in_sr = torchaudio.load(filename)
13
+ # Resample if necessary
14
+ if in_sr != target_sr:
15
+ resampler = T.Resample(in_sr, target_sr)
16
+ audio = resampler(audio)
17
+ augs = torch.nn.Sequential(
18
+ PhaseFlipper(),
19
+ )
20
+ audio = augs(audio)
21
+ audio = audio.clamp(-1, 1)
22
+ encoding = torch.nn.Sequential(
23
+ Stereo(),
24
+ )
25
+ audio = encoding(audio)
26
+ # audio.shape is [channels, samples]
27
+ num_samples = audio.shape[-1]
28
+
29
+ # if num_samples < target_samples:
30
+ # # Pad if it's too short
31
+ # pad_amount = target_samples - num_samples
32
+ # # Zero-pad at the end (or randomly if you prefer)
33
+ # audio = F.pad(audio, (0, pad_amount))
34
+ # print(f"pad {pad_amount}")
35
+ # else:
36
+ audio = audio[:, :target_samples]
37
+ return audio
38
+ except RuntimeError:
39
+ print(f"Failed to decode audio file: {filename}")
40
+ return None
41
+ class PadCrop(nn.Module):
42
+ def __init__(self, n_samples, randomize=True):
43
+ super().__init__()
44
+ self.n_samples = n_samples
45
+ self.randomize = randomize
46
+
47
+ def __call__(self, signal):
48
+ n, s = signal.shape
49
+ start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
50
+ end = start + self.n_samples
51
+ output = signal.new_zeros([n, self.n_samples])
52
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
53
+ return output
54
+
55
+ class PadCrop_Normalized_T(nn.Module):
56
+
57
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
58
+
59
+ super().__init__()
60
+
61
+ self.n_samples = n_samples
62
+ self.sample_rate = sample_rate
63
+ self.randomize = randomize
64
+
65
+ def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]:
66
+
67
+ n_channels, n_samples = source.shape
68
+
69
+ # If the audio is shorter than the desired length, pad it
70
+ upper_bound = max(0, n_samples - self.n_samples)
71
+
72
+ # If randomize is False, always start at the beginning of the audio
73
+ offset = 0
74
+ if(self.randomize and n_samples > self.n_samples):
75
+ offset = random.randint(0, upper_bound)
76
+
77
+ # Calculate the start and end times of the chunk
78
+ t_start = offset / (upper_bound + self.n_samples)
79
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
80
+
81
+ # Create the chunk
82
+ chunk = source.new_zeros([n_channels, self.n_samples])
83
+
84
+ # Copy the audio into the chunk
85
+ chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]
86
+
87
+ # Calculate the start and end times of the chunk in seconds
88
+ seconds_start = math.floor(offset / self.sample_rate)
89
+ seconds_total = math.ceil(n_samples / self.sample_rate)
90
+
91
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
92
+ padding_mask = torch.zeros([self.n_samples])
93
+ padding_mask[:min(n_samples, self.n_samples)] = 1
94
+
95
+
96
+ return (
97
+ chunk,
98
+ offset,
99
+ offset + self.n_samples,
100
+ seconds_start,
101
+ seconds_total,
102
+ padding_mask
103
+ )
104
+
105
+ class PhaseFlipper(nn.Module):
106
+ "Randomly invert the phase of a signal"
107
+ def __init__(self, p=0.5):
108
+ super().__init__()
109
+ self.p = p
110
+ def __call__(self, signal):
111
+ return -signal if (random.random() < self.p) else signal
112
+
113
+ class Mono(nn.Module):
114
+ def __call__(self, signal):
115
+ return torch.mean(signal, dim=0, keepdim=True) if len(signal.shape) > 1 else signal
116
+
117
+ class Stereo(nn.Module):
118
+ def __call__(self, signal):
119
+ signal_shape = signal.shape
120
+ # Check if it's mono
121
+ if len(signal_shape) == 1: # s -> 2, s
122
+ signal = signal.unsqueeze(0).repeat(2, 1)
123
+ elif len(signal_shape) == 2:
124
+ if signal_shape[0] == 1: #1, s -> 2, s
125
+ signal = signal.repeat(2, 1)
126
+ elif signal_shape[0] > 2: #?, s -> 2,s
127
+ signal = signal[:2, :]
128
+
129
+ return signal
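
Finally, a usage sketch for the dataset utilities above (not part of the diff); `song.wav` is a placeholder path. `load_audio_file` resamples to 44.1 kHz, randomly phase-flips, forces stereo and truncates to 2097152 samples, while `PadCrop_Normalized_T` pads or randomly crops to a fixed length and also returns timing metadata and a padding mask.

import torch
from utils.stable_audio_dataset_utils import load_audio_file, PadCrop_Normalized_T

audio = load_audio_file("song.wav", target_sr=44100, target_samples=2097152)  # (2, <=2097152) or None on decode failure
if audio is not None:
    cropper = PadCrop_Normalized_T(n_samples=2097152, sample_rate=44100, randomize=True)
    chunk, start, end, seconds_start, seconds_total, padding_mask = cropper(audio)
    print(chunk.shape, seconds_start, seconds_total, padding_mask.sum())  # torch.Size([2, 2097152]) ...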