# ZIT-Controlnet / app.py
# Author: Alexander Bagus
# Commit: 83cd1bb (file size 7.03 kB)
import gradio as gr
import numpy as np
import random
import json
import spaces
import torch
from diffusers import DiffusionPipeline
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.models import ZImageControlTransformer2DModel
from transformers import AutoTokenizer, Qwen3ForCausalLM
from diffusers import AutoencoderKL
from image_utils import get_image_latent, scale_image
# Alternatives kept for reference (not active):
#   from videox_fun.utils.utils import get_image_latent
#   MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"

# Largest seed the UI offers or randomly draws: the int32 maximum.
MAX_SEED = np.iinfo(np.int32).max
# NOTE(review): not referenced anywhere else in this file.
MAX_IMAGE_SIZE = 1280

# Local checkpoint locations. Fetch them by running these from inside models/:
#   git clone https://huggingface.co/Tongyi-MAI/Z-Image-Turbo
#   curl -L -o Z-Image-Turbo-Fun-Controlnet-Union.safetensors https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/Z-Image-Turbo-Fun-Controlnet-Union.safetensors
MODEL_LOCAL = "models/Z-Image-Turbo/"  # base model directory
TRANSFORMER_LOCAL = "models/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"  # ControlNet-Union weights

# Dtype the model components are loaded in.
weight_dtype = torch.bfloat16
# Load the ControlNet-augmented transformer from the base model's "transformer"
# subfolder. transformer_additional_kwargs is consumed by
# ZImageControlTransformer2DModel; presumably it inserts control layers at these
# block indices with a 16-channel control input -- TODO confirm against videox_fun.
transformer = ZImageControlTransformer2DModel.from_pretrained(
    MODEL_LOCAL,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    transformer_additional_kwargs={
        "control_layers_places": [0, 5, 10, 15, 20, 25],
        "control_in_dim": 16,
    },
).to(weight_dtype)

# Overlay the ControlNet-Union checkpoint on top of the base transformer weights.
if TRANSFORMER_LOCAL is not None:
    print(f"From checkpoint: {TRANSFORMER_LOCAL}")
    if TRANSFORMER_LOCAL.endswith("safetensors"):
        from safetensors.torch import load_file

        state_dict = load_file(TRANSFORMER_LOCAL)
    else:
        # NOTE(review): torch.load unpickles the file, which can execute code;
        # only point this at trusted checkpoints (consider weights_only=True).
        state_dict = torch.load(TRANSFORMER_LOCAL, map_location="cpu")
    # Some checkpoints nest the tensors under a top-level "state_dict" key.
    state_dict = state_dict.get("state_dict", state_dict)
    # strict=False tolerates keys present on only one side; report how many.
    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
# ---- Assemble the ZImageControlPipeline from local components ----
vae = AutoencoderKL.from_pretrained(MODEL_LOCAL, subfolder="vae").to(weight_dtype)
tokenizer = AutoTokenizer.from_pretrained(MODEL_LOCAL, subfolder="tokenizer")
text_encoder = Qwen3ForCausalLM.from_pretrained(
    MODEL_LOCAL,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)
# The scheduler is constructed in code (not loaded from MODEL_LOCAL) with a
# fixed timestep shift of 3.
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)

pipe = ZImageControlPipeline(
    scheduler=scheduler,
    transformer=transformer,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    vae=vae,
)
# NOTE(review): likely redundant since the constructor already received
# `transformer`; kept so the pipeline certainly holds the checkpoint-patched instance.
pipe.transformer = transformer
pipe.to("cuda")
# ======== AoTI compilation + FA3 ========
# Declare ZImageTransformerBlock as the repeated block type of the transformer's
# layer stack; presumably spaces.aoti_blocks_load reads `_repeated_blocks` to find
# which submodules to replace -- confirm against the `spaces` package docs.
setattr(pipe.transformer.layers, "_repeated_blocks", ["ZImageTransformerBlock"])
# Load the prebuilt ahead-of-time-compiled blocks from the "zerogpu-aoti/Z-Image"
# repo, using its "fa3" variant (per the header above, FlashAttention-3).
spaces.aoti_blocks_load(
    pipe.transformer.layers,
    "zerogpu-aoti/Z-Image",
    variant="fa3",
)
@spaces.GPU
def inference(
    prompt,
    input_image,
    image_scale=1.0,
    control_context_scale=0.75,
    seed=42,
    randomize_seed=True,
    guidance_scale=1.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image from *prompt*, conditioned on *input_image*.

    Args:
        prompt: Text prompt forwarded to ``pipe``.
        input_image: Control image from the upload component, or ``None``
            when nothing was uploaded.
        image_scale: Factor passed to ``scale_image``, which returns the image
            to use plus the width/height used for generation.
        control_context_scale: Forwarded to ``pipe``.
        seed: Seed for the ``torch.Generator``; replaced when ``randomize_seed``.
        randomize_seed: If true, draw a fresh seed from ``[0, MAX_SEED]``.
        guidance_scale: Forwarded to ``pipe``.
        num_inference_steps: Forwarded to ``pipe``.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm bars.

    Returns:
        ``(image, seed)``: the first generated image and the seed actually
        used, matching the ``[output_image, seed]`` outputs wired in the UI.
        Returns ``(None, seed)`` when no input image was provided.
    """
    # process image
    if input_image is None:
        print("Error: input_image is empty.")
        # The event has two outputs ([output_image, seed]); a bare None would
        # make Gradio fail with an output-count error instead of clearing the result.
        return None, seed
    input_image, width, height = scale_image(input_image, image_scale)
    # [:, :, 0] drops axis 2 by taking its first index; presumably this is the
    # frame axis of a video-style latent tensor -- TODO confirm.
    control_image = get_image_latent(input_image, sample_size=[height, width])[:, :, 0]

    # generation
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; Generator.manual_seed requires an int.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        generator=generator,
        guidance_scale=guidance_scale,
        control_image=control_image,
        num_inference_steps=num_inference_steps,
        control_context_scale=control_context_scale,
    ).images[0]
    return image, seed
def read_file(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at *path*."""
    with open(path, mode="r", encoding="utf-8") as handle:
        return handle.read()
# Custom stylesheet: centers the element with id "col-container" (the outer
# layout column) and caps its width at 960px.
# NOTE(review): it only takes effect when passed to gr.Blocks(css=...).
css = """
#col-container {
margin: 0 auto;
max-width: 960px;
}
"""
# Load the gallery examples from static/data.json; each entry presumably pairs
# an image with a prompt, matching gr.Examples(inputs=[input_image, prompt]).
# Explicit UTF-8 because JSON is UTF-8 and the locale default may not be.
with open("static/data.json", "r", encoding="utf-8") as file:
    data = json.load(file)
examples = data["examples"]
# NOTE(review): the source lost its indentation; the component nesting below is
# a best-effort reconstruction -- confirm the intended layout.
# css is passed here; without it the #col-container styles are never applied.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))
        with gr.Row(equal_height=True):
            # Left: control-image upload, prompt and run button.
            with gr.Column():
                input_image = gr.Image(
                    height=290, sources=['upload', 'clipboard'],
                    image_mode='RGB',
                    # elem_id="image_upload",
                    type="pil", label="Upload")
                prompt = gr.Textbox(
                    label="Prompt",
                    show_label=False,
                    lines=2,
                    placeholder="Enter your prompt",
                    container=False,
                )
                run_button = gr.Button("Run", variant="primary")
            # Right: generated result.
            with gr.Column():
                output_image = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                image_scale = gr.Slider(
                    label="Image scale",
                    minimum=0.5,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )
                control_context_scale = gr.Slider(
                    label="Control context scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.75,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=2.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=8,
                )
        gr.Examples(examples=examples, inputs=[input_image, prompt])
        gr.HTML(read_file("static/footer.html"))

    # Run generation on button click or on Enter in the prompt box. The seed
    # output writes back the seed actually used (relevant when randomized).
    # (The original's trailing empty `.then()` registered nothing and was removed.)
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=inference,
        inputs=[
            prompt,
            input_image,
            image_scale,
            control_context_scale,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed],
    )
# Script entry point: start the Gradio app with mcp_server=True (Gradio's
# MCP server option) in addition to the web UI.
if __name__ == "__main__":
    demo.launch(mcp_server=True)