English
flux-test2 / handler.py
refoundd's picture
Update handler.py
0f93b31 verified
Raw
History Blame Contribute Delete
3.84 kB
import os
from typing import Any, Dict, Union
from PIL import Image
import torch
from diffusers import FluxPipeline
from huggingface_inference_toolkit.logging import logger
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
from torchao.quantization import autoquant
import time
import gc
import torch._dynamo

# Allow float32 matrix multiplications to run in the "high" precision mode.
# Despite the name, this is a relaxation of the default ("highest"): PyTorch
# may use faster reduced-precision internal math such as TF32 on GPUs that
# support it (Ampere and newer). It trades a little accuracy for throughput.
torch.set_float32_matmul_precision("high")

# Keep torch.compile / dynamo errors fatal rather than suppressed, so
# compilation problems surface immediately while debugging.
torch._dynamo.config.suppress_errors = False
class EndpointHandler:
    """Inference Endpoint handler that serves text-to-image requests with FluxPipeline.

    The pipeline is loaded, optimized and warmed up once in ``__init__``, so the
    cost of the first (compiling) call is paid before any real request arrives.
    """

    # Hub repository the pipeline is loaded from; the ``path`` argument of
    # ``__init__`` is not used.
    MODEL_ID = "NoMoreCopyrightOrg/flux-dev"

    def __init__(self, path=""):
        """Load, optimize and warm up the FLUX pipeline on CUDA.

        Args:
            path: Model directory passed in by the endpoint runtime. It is
                ignored; weights are always loaded from ``MODEL_ID``.
        """
        self.pipe = FluxPipeline.from_pretrained(
            self.MODEL_ID,
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        # diffusers memory-saving options for VAE encode/decode.
        self.pipe.enable_vae_slicing()
        self.pipe.enable_vae_tiling()
        self.pipe.transformer.fuse_qkv_projections()
        self.pipe.vae.fuse_qkv_projections()
        self.pipe.transformer.to(memory_format=torch.channels_last)
        self.pipe.vae.to(memory_format=torch.channels_last)
        # ParaAttention first-block cache; 0.12 is the residual-difference
        # threshold (NOTE(review): semantics per para_attn docs — confirm).
        apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
        self.pipe.transformer = torch.compile(
            self.pipe.transformer, mode="max-autotune-no-cudagraphs",
        )
        self.pipe.vae = torch.compile(
            self.pipe.vae, mode="max-autotune-no-cudagraphs",
        )
        # torchao autoquant wraps the compiled modules. error_on_unseen=False
        # presumably avoids raising for input shapes not seen during
        # calibration — confirm against torchao docs.
        self.pipe.transformer = autoquant(self.pipe.transformer, error_on_unseen=False)
        self.pipe.vae = autoquant(self.pipe.vae, error_on_unseen=False)
        gc.collect()
        torch.cuda.empty_cache()
        self._warm_up()
        # Number of requests currently executing the pipeline; returned when
        # the prompt is the special string "get_queue".
        self.record = 0

    def _warm_up(self):
        """Run one throwaway generation so compilation happens before serving."""
        logger.info("Start warming-up pipeline")
        start_time = time.perf_counter()
        self.pipe("Hello world!")  # Warm-up for compiling
        logger.info("Time taken: %.2f seconds", time.perf_counter() - start_time)

    @staticmethod
    def _extract_prompt(data):
        """Return the first non-empty string found under ``"inputs"`` or ``"prompt"``.

        ``"inputs"`` takes precedence over ``"prompt"``. ``data`` is not modified.

        Raises:
            ValueError: If neither key holds a non-empty string.
        """
        for key in ("inputs", "prompt"):
            value = data.get(key)
            if isinstance(value, str) and value:
                return value
        raise ValueError(
            "Provided input body must contain either the key `inputs` or `prompt` with the"
            " prompt to use for the image generation, and it needs to be a non-empty string."
        )

    def __call__(self, data: Dict[str, Any]) -> Union[Image.Image, int, None]:
        """Generate one image for the prompt contained in ``data``.

        Args:
            data: Request body. The prompt is read from ``"inputs"`` or
                ``"prompt"``. The optional ``"parameters"`` dict may contain
                ``num_inference_steps`` (default 28), ``width`` and ``height``
                (default 1024), ``guidance`` or ``guidance_scale`` (default 3.5),
                and ``seed`` (default 0).

        Returns:
            The generated image. If the prompt is exactly ``"get_queue"``, the
            number of requests currently running the pipeline is returned
            instead. If any error occurs, it is logged and ``None`` is returned.
        """
        try:
            logger.info(f"Received incoming request with {data=}")
            prompt = self._extract_prompt(data)
            if prompt == "get_queue":
                return self.record

            # A JSON `null` for "parameters" is treated like an absent key.
            parameters = data.get("parameters") or {}
            num_inference_steps = parameters.get("num_inference_steps", 28)
            width = parameters.get("width", 1024)
            height = parameters.get("height", 1024)
            # "guidance" is this endpoint's original key; "guidance_scale" (the
            # diffusers argument name) is accepted as a fallback.
            guidance_scale = parameters.get(
                "guidance", parameters.get("guidance_scale", 3.5)
            )
            # Seed a private generator. This keeps results reproducible per seed
            # without reseeding torch's global RNG on every request.
            seed = int(parameters.get("seed", 0))
            generator = torch.Generator(device="cpu").manual_seed(seed)

            self.record += 1
            try:
                start_time = time.perf_counter()
                result = self.pipe(  # type: ignore
                    prompt,
                    height=height,
                    width=width,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                ).images[0]
                logger.info("Time taken: %.2f seconds", time.perf_counter() - start_time)
            finally:
                # Always release the slot, even if generation fails; otherwise
                # "get_queue" would report requests that are no longer running.
                self.record -= 1
            return result
        except Exception:
            # Best-effort endpoint: log the full traceback, answer with None.
            logger.exception("Image generation request failed")
            return None