import os
import subprocess
import sys
import io
import gradio as gr
import numpy as np
import random
import torch
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
import requests
from PIL import Image
import base64
from huggingface_hub import InferenceClient

# Install the spaces (ZeroGPU helper) package if it is not already available
try:
    import spaces
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "spaces==0.43.0"])
    import spaces
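# Weights are loaded in bfloat16; CUDA is used when available, otherwise
# everything falls back to CPU.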
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Hugging Face token for gated repo authentication
HF_TOKEN = os.environ.get("HF_TOKEN", os.environ.get("HUGGING_FACE_HUB_TOKEN"))

hf_client = (
    InferenceClient(
        api_key=HF_TOKEN,
    )
    if HF_TOKEN
    else None
)

VLM_MODEL = "baidu/ERNIE-4.5-VL-424B-A47B-Base-PT"
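# System prompts for the optional VLM prompt upsampling: the first rewrites
# plain text-to-image prompts, the second turns editing requests (with
# reference images) into a single concise instruction.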
SYSTEM_PROMPT_TEXT_ONLY = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.

Guidelines:
1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.

Output only the revised prompt and nothing else."""

SYSTEM_PROMPT_WITH_IMAGES = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).

Rules:
- Single instruction only, no commentary
- Use clear, analytical language (avoid "whimsical," "cascading," etc.)
- Specify what changes AND what stays the same (face, lighting, composition)
- Reference actual image elements
- Turn negatives into positives ("don't change X" → "keep X")
- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels")
- Keep content PG-13

Output only the final instruction in plain text and nothing else."""
def remote_text_encoder(prompts):
    """Encode prompts on a remote Space running the Mistral text encoder."""
    from gradio_client import Client

    client = Client("multimodalart/mistral-text-encoder")
    result = client.predict(prompt=prompts, api_name="/encode_text")
    # torch.load returns the embedding tensor, on CPU by default
    prompt_embeds = torch.load(result[0])
    return prompt_embeds
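# Because prompts are encoded remotely, the pipeline below is created with
# text_encoder=None, so the large text encoder never has to be held in this
# Space's memory.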
# Load model
repo_id = "black-forest-labs/FLUX.2-dev"
print("Loading Flux.2 model...")

dit = Flux2Transformer2DModel.from_pretrained(
    repo_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)

pipe = Flux2Pipeline.from_pretrained(
    repo_id,
    text_encoder=None,
    transformer=dit,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
pipe.to(device)

# Pull pre-compiled FLUX.2 transformer blocks from the HF Hub for ZeroGPU
print("Loading pre-compiled blocks for ZeroGPU...")
spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/FLUX.2", variant="fa3")
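# aoti_blocks_load swaps in ahead-of-time compiled transformer blocks published
# in the zerogpu-aoti/FLUX.2 repo, avoiding per-request compilation on ZeroGPU;
# "fa3" presumably selects the FlashAttention-3 build of those blocks.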
def image_to_data_uri(img):
    """Serialize a PIL image to a base64 PNG data URI."""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{img_str}"
def upsample_prompt_logic(prompt, image_list):
    """Upsample the prompt with the VLM when a client is available.

    Falls back to the original prompt on any failure so generation never
    blocks on the upsampling step.
    """
    if not hf_client:
        return prompt
    try:
        if image_list and len(image_list) > 0:
            # Image + text editing mode
            system_content = SYSTEM_PROMPT_WITH_IMAGES
            # Construct the user message with text and images
            user_content = [{"type": "text", "text": prompt}]
            for img in image_list:
                data_uri = image_to_data_uri(img)
                user_content.append(
                    {"type": "image_url", "image_url": {"url": data_uri}}
                )
            messages = [
                {"role": "system", "content": system_content},
                {"role": "user", "content": user_content},
            ]
        else:
            # Text-only mode
            system_content = SYSTEM_PROMPT_TEXT_ONLY
            messages = [
                {"role": "system", "content": system_content},
                {"role": "user", "content": prompt},
            ]
        completion = hf_client.chat.completions.create(
            model=VLM_MODEL, messages=messages, max_tokens=1024
        )
        return completion.choices[0].message.content
    except Exception as e:
        print(f"Upsampling failed: {e}")
        return prompt
def update_dimensions_from_image(image_list):
    """Update the width/height sliders from the first uploaded image's aspect ratio.

    Keeps the longer side at 1024 and scales the other proportionally, with
    both sides rounded to multiples of 8.
    """
    if image_list is None or len(image_list) == 0:
        return 1024, 1024  # Default dimensions

    # Use the first image to determine dimensions
    img = image_list[0][0]  # Gallery returns a list of (image, caption) tuples
    img_width, img_height = img.size
    aspect_ratio = img_width / img_height

    if aspect_ratio >= 1:  # Landscape or square
        new_width = 1024
        new_height = int(1024 / aspect_ratio)
    else:  # Portrait
        new_height = 1024
        new_width = int(1024 * aspect_ratio)

    # Round to the nearest multiple of 8
    new_width = round(new_width / 8) * 8
    new_height = round(new_height / 8) * 8

    # Clamp to the valid slider range (256 to 1024)
    new_width = max(256, min(1024, new_width))
    new_height = max(256, min(1024, new_height))

    return new_width, new_height
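# Example: a 1920x1080 upload sets the sliders to 1024x576, while a 1080x1920
# upload sets them to 576x1024.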
# Duration estimator for ZeroGPU; it must accept the same arguments as
# generate_image (including progress).
def get_duration(
    prompt_embeds,
    image_list,
    width,
    height,
    num_inference_steps,
    guidance_scale,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    num_images = 0 if image_list is None else len(image_list)
    step_duration = 1 + 0.8 * num_images
    return max(65, num_inference_steps * step_duration + 10)
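# The estimate assumes roughly 1 s per denoising step plus ~0.8 s per reference
# image per step, with a 65 s floor; it is passed to the @spaces.GPU decorator
# below so longer jobs request a longer ZeroGPU slot.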
@spaces.GPU(duration=get_duration)
def generate_image(
    prompt_embeds,
    image_list,
    width,
    height,
    num_inference_steps,
    guidance_scale,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    # Move embeddings to the GPU only inside the GPU-decorated function
    prompt_embeds = prompt_embeds.to(device)
    generator = torch.Generator(device=device).manual_seed(seed)

    pipe_kwargs = {
        "prompt_embeds": prompt_embeds,
        "image": image_list,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "generator": generator,
        "width": width,
        "height": height,
    }

    # Progress bar for the actual generation steps
    if progress:
        progress(0, desc="Starting generation...")

    image = pipe(**pipe_kwargs).images[0]
    return image
def infer(
    prompt,
    input_images=None,
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=30,
    guidance_scale=4.0,
    prompt_upsampling=False,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Prepare the image list (convert None or an empty gallery to None)
    image_list = None
    if input_images is not None and len(input_images) > 0:
        image_list = []
        for item in input_images:
            image_list.append(item[0])

    # 1. Upsampling (network bound - no GPU needed)
    final_prompt = prompt
    if prompt_upsampling:
        progress(0.05, desc="Upsampling prompt...")
        final_prompt = upsample_prompt_logic(prompt, image_list)
        print(f"Original Prompt: {prompt}")
        print(f"Upsampled Prompt: {final_prompt}")

    # 2. Text encoding (network bound - no GPU needed)
    progress(0.1, desc="Encoding prompt...")
    # remote_text_encoder returns CPU tensors
    prompt_embeds = remote_text_encoder(final_prompt)

    # 3. Image generation (GPU bound)
    progress(0.3, desc="Waiting for GPU...")
    image = generate_image(
        prompt_embeds,
        image_list,
        width,
        height,
        num_inference_steps,
        guidance_scale,
        seed,
        progress,
    )
    return image, seed
examples = [
    ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"],
    ["An astronaut riding a green horse"],
    ["A delicious ceviche cheesecake slice"],
    [
        "Create a vase on a table in living room, the color of the vase is a gradient of color, starting with #02eb3c color and finishing with #edfa3c. The flowers inside the vase have the color #ff0088"
    ],
    [
        "Soaking wet capybara taking shelter under a banana leaf in the rainy jungle, close up photo"
    ],
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 1200px;
}
.gallery-container img {
    object-fit: contain;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """# FLUX.2 [dev] Text-to-Image
FLUX.2 [dev] is a 32B-parameter rectified flow model capable of generating, editing and combining images based on text instructions [[model](https://huggingface.co/black-forest-labs/FLUX.2-dev)], [[blog](https://bfl.ai/blog/flux-2)]
"""
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=2,
                        placeholder="Enter your prompt",
                        container=False,
                        scale=3,
                    )
                    run_button = gr.Button("Run", scale=1, variant="primary")

                with gr.Accordion("Input image(s) (optional)", open=False):
                    gr.Markdown("Upload images for editing or combining")
                    input_images = gr.Gallery(
                        label="Input Image(s)",
                        type="pil",
                        columns=3,
                        rows=1,
                    )

                with gr.Accordion("Advanced Settings", open=False):
                    prompt_upsampling = gr.Checkbox(
                        label="Prompt Upsampling",
                        value=False,
                        info="Automatically enhance the prompt using a VLM (requires HF_TOKEN)",
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=8,
                            value=1024,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=8,
                            value=1024,
                        )
                    with gr.Row():
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=30,
                            info="More steps = higher quality but slower",
                        )
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=4.0,
                            info="How closely to follow the prompt",
                        )

            with gr.Column():
                result = gr.Image(label="Result", show_label=False)

        gr.Examples(examples=examples, inputs=[prompt], cache_examples=False)

    # Auto-update dimensions when images are uploaded
    input_images.upload(
        fn=update_dimensions_from_image, inputs=[input_images], outputs=[width, height]
    )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            input_images,
            seed,
            randomize_seed,
            width,
            height,
            num_inference_steps,
            guidance_scale,
            prompt_upsampling,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()