manoskary commited on
Commit
ef862da
·
1 Parent(s): c096349

Add TOKEN retrieval for Hugging Face API in setup function

Browse files
Files changed (1) hide show
  1. MuseControlLite_setup.py +8 -1
MuseControlLite_setup.py CHANGED
@@ -27,6 +27,9 @@ from utils.stable_audio_dataset_utils import load_audio_file
27
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
  import soundfile as sf
29
 
 
 
 
30
  # For zero initialized 1D CNN in the attention processor
31
  def zero_module(module):
32
  for p in module.parameters():
@@ -507,6 +510,7 @@ class StableAudioAttnProcessor2_0_rotary_double(torch.nn.Module):
507
  hidden_states = hidden_states / attn.rescale_output_factor
508
 
509
  return hidden_states
 
510
  def setup_MuseControlLite(config, weight_dtype, transformer_ckpt):
511
  """
512
  Setup AP-adapter pipeline with attention processors and load checkpoints.
@@ -525,7 +529,10 @@ def setup_MuseControlLite(config, weight_dtype, transformer_ckpt):
525
  else:
526
  from pipeline.stable_audio_multi_cfg_pipe import StableAudioPipeline
527
  attn_processor = StableAudioAttnProcessor2_0_rotary
528
- pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", torch_dtype=weight_dtype)
 
 
 
529
  pipe.scheduler.config.sigma_max = config["sigma_max"]
530
  pipe.scheduler.config.sigma_min = config["sigma_min"]
531
  transformer = pipe.transformer
 
27
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
  import soundfile as sf
29
 
30
+
31
+ TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HF_HUB_TOKEN")
32
+
33
  # For zero initialized 1D CNN in the attention processor
34
  def zero_module(module):
35
  for p in module.parameters():
 
510
  hidden_states = hidden_states / attn.rescale_output_factor
511
 
512
  return hidden_states
513
+
514
  def setup_MuseControlLite(config, weight_dtype, transformer_ckpt):
515
  """
516
  Setup AP-adapter pipeline with attention processors and load checkpoints.
 
529
  else:
530
  from pipeline.stable_audio_multi_cfg_pipe import StableAudioPipeline
531
  attn_processor = StableAudioAttnProcessor2_0_rotary
532
+ pipe = StableAudioPipeline.from_pretrained(
533
+ "stabilityai/stable-audio-open-1.0",
534
+ torch_dtype=weight_dtype, token=TOKEN
535
+ )
536
  pipe.scheduler.config.sigma_max = config["sigma_max"]
537
  pipe.scheduler.config.sigma_min = config["sigma_min"]
538
  transformer = pipe.transformer