import copy
import os
import subprocess
import time
from typing import Dict, List, Optional, Tuple

import spaces
import gradio as gr
import soundfile as sf
import torch

from MuseControlLite_setup import initialize_condition_extractors, process_musical_conditions, setup_MuseControlLite
from config_inference import get_config

# Stable Audio uses fixed-length 47.5s chunks (2097152 / 44100)
TOTAL_AUDIO_SECONDS = 2097152 / 44100

DEFAULT_CONFIG = get_config()
DEFAULT_PROMPT = DEFAULT_CONFIG["text"][0] if DEFAULT_CONFIG.get("text") else ""
OUTPUT_ROOT = os.path.join(DEFAULT_CONFIG["output_dir"], "gradio_runs")
CONDITION_CHOICES = ["melody_stereo", "melody_mono", "dynamics", "rhythm", "audio"]
CHECKPOINT_EXPECTED = [
    "./checkpoints/woSDD-all/model_3.safetensors",
    "./checkpoints/woSDD-all/model_1.safetensors",
    "./checkpoints/woSDD-all/model_2.safetensors",
    "./checkpoints/woSDD-all/model.safetensors",
]

os.makedirs(OUTPUT_ROOT, exist_ok=True)


def ensure_checkpoints() -> None:
    """Download checkpoints with gdown if they are missing."""
    if all(os.path.exists(path) for path in CHECKPOINT_EXPECTED):
        return
    os.makedirs("checkpoints", exist_ok=True)
    try:
        subprocess.run(
            ["gdown", "1Q9B333jcq1czA11JKTbM-DHANJ8YqGbP", "--folder"],
            check=True,
        )
    except Exception as exc:  # pylint: disable=broad-except
        # Do not crash the space on startup; inference will surface an error later if checkpoints are missing.
        print(f"[warn] Checkpoint download failed: {exc}")


ensure_checkpoints()


class ModelCache:
    """Lazy loader for heavy pipelines and condition extractors."""

    def __init__(self) -> None:
        self.cache: Dict[Tuple, Dict] = {}

    def get(self, config: Dict) -> Dict:
        key = (
            tuple(sorted(config["condition_type"])),
            config["weight_dtype"],
            float(config["ap_scale"]),
            config["apadapter"],
        )
        if key in self.cache:
            return self.cache[key]
        weight_dtype = torch.float16 if config["weight_dtype"] == "fp16" else torch.float32
        if config["apadapter"]:
            condition_extractors, transformer_ckpt = initialize_condition_extractors(config)
            pipe = setup_MuseControlLite(config, weight_dtype, transformer_ckpt).to("cuda")
            payload = {
                "pipe": pipe,
                "condition_extractors": condition_extractors,
                "weight_dtype": weight_dtype,
                "mode": "musecontrol",
            }
        else:
            from diffusers import StableAudioPipeline

            pipe = StableAudioPipeline.from_pretrained(
                "stabilityai/stable-audio-open-1.0",
                torch_dtype=weight_dtype,
            ).to("cuda")
            payload = {"pipe": pipe, "condition_extractors": None, "weight_dtype": weight_dtype, "mode": "vanilla"}
        self.cache[key] = payload
        return payload


model_cache = ModelCache()


def _build_base_config() -> Dict:
    return copy.deepcopy(DEFAULT_CONFIG)


def _create_run_dir() -> str:
    run_dir = os.path.join(OUTPUT_ROOT, f"run_{int(time.time() * 1000)}")
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def _seed_to_generator(seed: Optional[float]) -> Optional[torch.Generator]:
    if seed is None or seed == "":
        return None
    try:
        seed_int = int(seed)
    except (TypeError, ValueError):
        return None
    generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
    return generator.manual_seed(seed_int)


def _validate_condition_choices(condition_type: Optional[List[str]]) -> List[str]:
    condition_type = condition_type or []
    if "melody_stereo" in condition_type and any(
        choice in condition_type for choice in ("dynamics", "rhythm", "melody_mono")
    ):
        raise gr.Error("`melody_stereo` cannot be combined with dynamics, rhythm, or melody_mono.")
    return condition_type


@spaces.GPU
def run_inference(
    prompt_text: str,
    condition_audio: Optional[str],
    condition_type: Optional[List[str]],
    use_musecontrol: bool,
    no_text: bool,
    negative_text_prompt: str,
    guidance_scale_text: float,
    guidance_scale_con: float,
    guidance_scale_audio: float,
    denoise_step: int,
    weight_dtype: str,
    ap_scale: float,
    sigma_min: float,
    sigma_max: float,
    audio_mask_start: float,
    audio_mask_end: float,
    musical_mask_start: float,
    musical_mask_end: float,
    seed: Optional[float],
):
    condition_type = _validate_condition_choices(condition_type)
    config = _build_base_config()
    config.update(
        {
            "text": [prompt_text or ""],
            "audio_files": [condition_audio or ""],
            "apadapter": use_musecontrol,
            "no_text": bool(no_text),
            "negative_text_prompt": negative_text_prompt or "",
            "guidance_scale_text": float(guidance_scale_text),
            "guidance_scale_con": float(guidance_scale_con),
            "guidance_scale_audio": float(guidance_scale_audio),
            "denoise_step": int(denoise_step),
            "weight_dtype": weight_dtype,
            "ap_scale": float(ap_scale),
            "sigma_min": float(sigma_min),
            "sigma_max": float(sigma_max),
            "audio_mask_start_seconds": float(audio_mask_start or 0),
            "audio_mask_end_seconds": float(audio_mask_end or 0),
            "musical_attribute_mask_start_seconds": float(musical_mask_start or 0),
            "musical_attribute_mask_end_seconds": float(musical_mask_end or 0),
            "show_result_and_plt": False,
        }
    )
    config["condition_type"] = condition_type

    if config["apadapter"]:
        if not condition_type:
            raise gr.Error("Select at least one condition type when using MuseControlLite.")
        if not condition_audio:
            raise gr.Error("Upload an audio file for conditioning.")
        if not os.path.exists(condition_audio):
            raise gr.Error("Condition audio file not found.")

    run_dir = _create_run_dir()
    config["output_dir"] = run_dir
    generator = _seed_to_generator(seed)

    try:
        models = model_cache.get(config)
        pipe = models["pipe"].to("cuda")
        pipe.enable_attention_slicing()
        pipe.scheduler.config.sigma_min = config["sigma_min"]
        pipe.scheduler.config.sigma_max = config["sigma_max"]

        prompt_for_model = "" if config["no_text"] else (prompt_text or "")

        with torch.no_grad():
            if config["apadapter"]:
                final_condition, final_condition_audio = process_musical_conditions(
                    config, condition_audio, models["condition_extractors"], run_dir, 0, models["weight_dtype"], pipe
                )
                waveform = pipe(
                    extracted_condition=final_condition,
                    extracted_condition_audio=final_condition_audio,
                    prompt=prompt_for_model,
                    negative_prompt=config["negative_text_prompt"],
                    num_inference_steps=config["denoise_step"],
                    guidance_scale_text=config["guidance_scale_text"],
                    guidance_scale_con=config["guidance_scale_con"],
                    guidance_scale_audio=config["guidance_scale_audio"],
                    num_waveforms_per_prompt=1,
                    audio_end_in_s=TOTAL_AUDIO_SECONDS,
                    generator=generator,
                ).audios
                output = waveform[0].T.float().cpu().numpy()
                sr = pipe.vae.sampling_rate
            else:
                audio = pipe(
                    prompt=prompt_for_model,
                    negative_prompt=config["negative_text_prompt"],
                    num_inference_steps=config["denoise_step"],
                    guidance_scale=config["guidance_scale_text"],
                    num_waveforms_per_prompt=1,
                    audio_end_in_s=TOTAL_AUDIO_SECONDS,
                    generator=generator,
                ).audios
                output = audio[0].T.float().cpu().numpy()
                sr = pipe.vae.sampling_rate

        generated_path = os.path.join(run_dir, "generated.wav")
        sf.write(generated_path, output, sr)

        status_lines = [
            f"Run directory: `{run_dir}`",
            f"Mode: {'MuseControlLite' if config['apadapter'] else 'Stable Audio base'}",
            f"Condition type: {', '.join(condition_type) if condition_type else 'text only'}",
            f"Dtype: {config['weight_dtype']}, steps: {config['denoise_step']}, sigma [{config['sigma_min']}, {config['sigma_max']}]",
        ]
        if config["apadapter"]:
            status_lines.append(
                f"Guidance (text/cond/audio): {config['guidance_scale_text']}/{config['guidance_scale_con']}/{config['guidance_scale_audio']}"
            )
        if generator is not None:
            status_lines.append(f"Seed: {int(seed)}")
        status_md = "\n".join(f"- {line}" for line in status_lines)
        return generated_path, status_md
    except gr.Error:
        raise
    except Exception as err:  # pylint: disable=broad-except
        raise gr.Error(f"Generation failed: {err}") from err


EXAMPLES = [
    [
        "Electronic music that has a constant melody throughout with accompanying instruments used to supplement the melody which can be heard in possibly a casual setting",
        "melody_condition_audio/49_piano.mp3",
        ["melody_stereo"],
        True,
        False,
        "",
        7.0,
        1.5,
        1.0,
        50,
        "fp16",
        1.0,
        0.3,
        500,
        0,
        0,
        0,
        0,
        42,
    ],
    [
        "fast and fun beat-based indie pop to set a protagonist-gets-good-at-x movie montage to.",
        "melody_condition_audio/610_bass.mp3",
        ["melody_mono", "dynamics", "rhythm"],
        True,
        False,
        "",
        7.0,
        1.5,
        1.0,
        50,
        "fp16",
        1.0,
        0.3,
        500,
        0,
        0,
        0,
        0,
        7,
    ],
]


def build_interface() -> gr.Blocks:
    with gr.Blocks(title="MuseControlLite") as demo:
        gr.Markdown(
            """
## MuseControlLite demo

UI for MuseControlLite (47.5s generations). This Space downloads checkpoints on startup with gdown and expects a GPU runtime; duplicate to a GPU Space or run locally for actual generation.
"""
        )
        with gr.Row():
            prompt = gr.Textbox(label="Text prompt", lines=3, value=DEFAULT_PROMPT)
            use_musecontrol = gr.Checkbox(label="Use MuseControlLite adapters", value=True)
            no_text = gr.Checkbox(label="Ignore text prompt (audio-only guidance)", value=False)
        condition_audio = gr.Audio(
            label="Condition audio (required for MuseControlLite)", type="filepath", sources=["upload", "microphone"]
        )
        condition_type = gr.CheckboxGroup(
            CONDITION_CHOICES, label="Condition types", value=DEFAULT_CONFIG.get("condition_type", [])
        )
        with gr.Accordion("Advanced controls", open=False):
            negative_prompt = gr.Textbox(label="Negative prompt", lines=2, value=DEFAULT_CONFIG.get("negative_text_prompt", ""))
            with gr.Row():
                guidance_scale_text = gr.Slider(
                    minimum=0.0,
                    maximum=12.0,
                    value=DEFAULT_CONFIG["guidance_scale_text"],
                    step=0.1,
                    label="Guidance scale (text)",
                )
                guidance_scale_con = gr.Slider(
                    minimum=0.0,
                    maximum=5.0,
                    value=DEFAULT_CONFIG["guidance_scale_con"],
                    step=0.1,
                    label="Guidance scale (conditions)",
                )
                guidance_scale_audio = gr.Slider(
                    minimum=0.0,
                    maximum=5.0,
                    value=DEFAULT_CONFIG["guidance_scale_audio"],
                    step=0.1,
                    label="Guidance scale (audio)",
                )
            with gr.Row():
                denoise_step = gr.Slider(
                    minimum=10, maximum=100, value=DEFAULT_CONFIG["denoise_step"], step=1, label="Denoising steps"
                )
                weight_dtype = gr.Radio(["fp16", "fp32"], value=DEFAULT_CONFIG["weight_dtype"], label="Weight dtype")
                ap_scale = gr.Slider(
                    minimum=0.5, maximum=2.0, value=DEFAULT_CONFIG["ap_scale"], step=0.05, label="AP scale"
                )
            with gr.Row():
                sigma_min = gr.Slider(
                    minimum=0.1, maximum=5.0, value=DEFAULT_CONFIG["sigma_min"], step=0.05, label="Scheduler sigma min"
                )
                sigma_max = gr.Slider(
                    minimum=50, maximum=700, value=DEFAULT_CONFIG["sigma_max"], step=1, label="Scheduler sigma max"
                )
                seed = gr.Number(label="Seed (optional)", precision=0)
            with gr.Row():
                audio_mask_start = gr.Number(
                    label="Audio mask start (s)", value=DEFAULT_CONFIG["audio_mask_start_seconds"]
                )
                audio_mask_end = gr.Number(label="Audio mask end (s)", value=DEFAULT_CONFIG["audio_mask_end_seconds"])
            with gr.Row():
                musical_mask_start = gr.Number(
                    label="Musical attribute mask start (s)", value=DEFAULT_CONFIG["musical_attribute_mask_start_seconds"]
                )
                musical_mask_end = gr.Number(
                    label="Musical attribute mask end (s)", value=DEFAULT_CONFIG["musical_attribute_mask_end_seconds"]
                )

        generate_btn = gr.Button("Generate", variant="primary")
        generated_audio = gr.Audio(label="Generated audio", type="filepath")
        status = gr.Markdown(label="Run details")

        generate_btn.click(
            fn=run_inference,
            inputs=[
                prompt,
                condition_audio,
                condition_type,
                use_musecontrol,
                no_text,
                negative_prompt,
                guidance_scale_text,
                guidance_scale_con,
                guidance_scale_audio,
                denoise_step,
                weight_dtype,
                ap_scale,
                sigma_min,
                sigma_max,
                audio_mask_start,
                audio_mask_end,
                musical_mask_start,
                musical_mask_end,
                seed,
            ],
            outputs=[generated_audio, status],
        )

        gr.Examples(
            examples=EXAMPLES,
            inputs=[
                prompt,
                condition_audio,
                condition_type,
                use_musecontrol,
                no_text,
                negative_prompt,
                guidance_scale_text,
                guidance_scale_con,
                guidance_scale_audio,
                denoise_step,
                weight_dtype,
                ap_scale,
                sigma_min,
                sigma_max,
                audio_mask_start,
                audio_mask_end,
                musical_mask_start,
                musical_mask_end,
                seed,
            ],
            label="Quick start examples (click to populate the form)",
        )

    return demo


if __name__ == "__main__":
    demo = build_interface()
    demo.launch()
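
# Sketch: calling run_inference directly, bypassing the Gradio UI. The argument
# values mirror the first entry of EXAMPLES; the condition audio path and the
# checkpoints must exist locally, and a CUDA device is required because the
# pipelines are moved to "cuda". This is an illustrative example, not part of
# the app flow; uncomment to try it.
#
# generated_path, report_md = run_inference(
#     prompt_text=EXAMPLES[0][0],
#     condition_audio="melody_condition_audio/49_piano.mp3",
#     condition_type=["melody_stereo"],
#     use_musecontrol=True,
#     no_text=False,
#     negative_text_prompt="",
#     guidance_scale_text=7.0,
#     guidance_scale_con=1.5,
#     guidance_scale_audio=1.0,
#     denoise_step=50,
#     weight_dtype="fp16",
#     ap_scale=1.0,
#     sigma_min=0.3,
#     sigma_max=500,
#     audio_mask_start=0,
#     audio_mask_end=0,
#     musical_mask_start=0,
#     musical_mask_end=0,
#     seed=42,
# )
# print(generated_path)  # path to the generated WAV inside the run directory
# print(report_md)       # markdown summary of the run settings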