Spaces:
Running
on
Zero
Running
on
Zero
Add TOKEN retrieval for Hugging Face API in setup function
Browse files- MuseControlLite_setup.py +8 -1
MuseControlLite_setup.py
CHANGED
|
@@ -27,6 +27,9 @@ from utils.stable_audio_dataset_utils import load_audio_file
|
|
| 27 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 28 |
import soundfile as sf
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
# For zero initialized 1D CNN in the attention processor
|
| 31 |
def zero_module(module):
|
| 32 |
for p in module.parameters():
|
|
@@ -507,6 +510,7 @@ class StableAudioAttnProcessor2_0_rotary_double(torch.nn.Module):
|
|
| 507 |
hidden_states = hidden_states / attn.rescale_output_factor
|
| 508 |
|
| 509 |
return hidden_states
|
|
|
|
| 510 |
def setup_MuseControlLite(config, weight_dtype, transformer_ckpt):
|
| 511 |
"""
|
| 512 |
Setup AP-adapter pipeline with attention processors and load checkpoints.
|
|
@@ -525,7 +529,10 @@ def setup_MuseControlLite(config, weight_dtype, transformer_ckpt):
|
|
| 525 |
else:
|
| 526 |
from pipeline.stable_audio_multi_cfg_pipe import StableAudioPipeline
|
| 527 |
attn_processor = StableAudioAttnProcessor2_0_rotary
|
| 528 |
-
pipe = StableAudioPipeline.from_pretrained(
|
|
|
|
|
|
|
|
|
|
| 529 |
pipe.scheduler.config.sigma_max = config["sigma_max"]
|
| 530 |
pipe.scheduler.config.sigma_min = config["sigma_min"]
|
| 531 |
transformer = pipe.transformer
|
|
|
|
| 27 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 28 |
import soundfile as sf
|
| 29 |
|
| 30 |
+
|
| 31 |
+
TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HF_HUB_TOKEN")
|
| 32 |
+
|
| 33 |
# For zero initialized 1D CNN in the attention processor
|
| 34 |
def zero_module(module):
|
| 35 |
for p in module.parameters():
|
|
|
|
| 510 |
hidden_states = hidden_states / attn.rescale_output_factor
|
| 511 |
|
| 512 |
return hidden_states
|
| 513 |
+
|
| 514 |
def setup_MuseControlLite(config, weight_dtype, transformer_ckpt):
|
| 515 |
"""
|
| 516 |
Setup AP-adapter pipeline with attention processors and load checkpoints.
|
|
|
|
| 529 |
else:
|
| 530 |
from pipeline.stable_audio_multi_cfg_pipe import StableAudioPipeline
|
| 531 |
attn_processor = StableAudioAttnProcessor2_0_rotary
|
| 532 |
+
pipe = StableAudioPipeline.from_pretrained(
|
| 533 |
+
"stabilityai/stable-audio-open-1.0",
|
| 534 |
+
torch_dtype=weight_dtype, token=TOKEN
|
| 535 |
+
)
|
| 536 |
pipe.scheduler.config.sigma_max = config["sigma_max"]
|
| 537 |
pipe.scheduler.config.sigma_min = config["sigma_min"]
|
| 538 |
transformer = pipe.transformer
|