"""Gradio demo that detects squircle corners in app icons with a YOLO model."""

import json
import os
import tempfile
from pathlib import Path

import gradio as gr
from PIL import Image

from visualization.corners import Circle, Corners
from visualization.draw import visualize_corners

# Lazily-loaded ultralytics YOLO model; populated by setup_yolo().
YOLO_MODEL = None
# Configured confidence; predict_yolo() passes half of it to YOLO as the
# detection threshold.
YOLO_CONF = 0.7


def setup_yolo():
    """Load the YOLO weights on first use.

    Returns:
        True if the model is loaded (now or previously), False if the weights
        file is missing or loading failed.
    """
    global YOLO_MODEL

    if YOLO_MODEL is not None:
        return True

    weights_path = Path(__file__).parent / "yolo_weights" / "best.pt"
    if not weights_path.exists():
        return False

    try:
        # Imported here so the app can start even when ultralytics is absent.
        from ultralytics import YOLO

        YOLO_MODEL = YOLO(str(weights_path))
        return True
    except Exception as e:
        print(f"Failed to load YOLO model: {e}")
        return False


def predict_yolo(image):
    """Run the YOLO corner detector on a PIL image.

    The image is written to a temporary PNG, which is passed to the model as
    its source. The file is always removed afterwards.

    Args:
        image: A PIL image.

    Returns:
        ``(corners, "OK")`` on success, where ``corners`` is built by
        ``Corners.from_boxes`` from the four highest-confidence boxes, or
        ``(None, reason)`` when the model is unavailable or fewer than four
        boxes were detected.
    """
    if not setup_yolo():
        return None, "YOLO model not available"

    # mkstemp + close (rather than an open NamedTemporaryFile): the file can be
    # reopened by name on every platform, and because image.save() is inside
    # the try/finally, a failed save no longer leaks the temporary file.
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(temp_path)
        results = YOLO_MODEL.predict(
            source=temp_path,
            # NOTE(review): half of YOLO_CONF is used as the threshold and the
            # top 4 boxes are kept regardless — presumably intentional, confirm.
            conf=YOLO_CONF / 2.0,
            # Very low NMS IoU threshold: almost any overlap between two boxes
            # causes the less confident one to be suppressed.
            iou=0.01,
            verbose=False,
        )
        r = results[0]
        if len(r.boxes) < 4:
            return None, f"Less than 4 corners detected ({len(r.boxes)})"

        boxes = r.boxes.xyxy.cpu().numpy()
        confs = r.boxes.conf.cpu().numpy()

        # Indices of the four most confident detections, highest first.
        top4_idx = confs.argsort()[-4:][::-1]
        top4_boxes = [tuple(int(v) for v in boxes[i][:4]) for i in top4_idx]

        corners = Corners.from_boxes(top4_boxes)
        return corners, "OK"
    finally:
        os.unlink(temp_path)


def process_single_image(image):
    """Detect corners in one image and build its gallery image and details text.

    Returns:
        ``(viz, info)``. ``viz`` is the annotated copy of ``image``, or None
        when nothing was detected. ``info`` is the JSON of the detected corners
        or a "Not detected: ..." message.
    """
    yolo_corners, yolo_status = predict_yolo(image)

    # predict_yolo signals failure with None; compare explicitly rather than
    # relying on the truthiness of a Corners object.
    if yolo_corners is None:
        return None, f"Not detected: {yolo_status}"

    yolo_viz = visualize_corners(image.copy(), yolo_corners, color=(255, 100, 0, 255))
    yolo_info = json.dumps(yolo_corners.to_dict(), indent=2)
    return yolo_viz, yolo_info


def _open_image(img_data):
    """Return ``(image, filename)`` for one upload entry, or None if it can't be read."""
    if isinstance(img_data, Image.Image):
        return img_data, "uploaded_image"

    filename = Path(img_data).name if isinstance(img_data, str) else "uploaded_image"
    try:
        image = Image.open(img_data)
        # Decode now, so corrupt files are caught here instead of crashing
        # later, and so Pillow can release the file it opened by path.
        image.load()
    except Exception as e:
        print(f"Skipping unreadable image {filename}: {e}")
        return None
    return image, filename


def process_images(images):
    """Run detection on every readable upload.

    Args:
        images: Iterable of file paths, PIL images, or file-like objects, or None.

    Returns:
        A list of dicts with keys ``filename``, ``original``, ``yolo_viz`` and
        ``yolo_info``. Entries that cannot be opened as images are skipped.
    """
    results = []
    if images is None:
        return results

    for img_data in images:
        opened = _open_image(img_data)
        if opened is None:
            continue
        image, filename = opened

        yolo_viz, yolo_info = process_single_image(image)
        results.append({
            "filename": filename,
            "original": image,
            "yolo_viz": yolo_viz,
            "yolo_info": yolo_info,
        })
    return results


def create_gallery_output(results):
    """Convert per-image results into gallery items and a Markdown summary.

    Images with no detection are shown as the original, with a
    "(not detected)" caption.
    """
    yolo_images = []
    for r in results:
        if r["yolo_viz"] is not None:
            yolo_images.append((r["yolo_viz"], r["filename"]))
        else:
            yolo_images.append((r["original"], f"{r['filename']} (not detected)"))

    # Fence each details entry so the indented JSON keeps its layout when
    # rendered by gr.Markdown.
    yolo_info = "\n\n".join(
        f"**{r['filename']}**\n```\n{r['yolo_info']}\n```" for r in results
    )
    return yolo_images, yolo_info


def predict(images):
    """Gradio callback: return ``(gallery_items, markdown_details)`` for the uploads."""
    if not images:
        return [], "No images uploaded"
    return create_gallery_output(process_images(images))


with gr.Blocks(title="Squircle Corners Prediction") as demo:
    gr.Markdown("# Squircle Corners Prediction")
    # NOTE(review): this HTML block is empty here; it may originally have held
    # custom markup/styling (e.g. for elem_id="gallery") — confirm.
    gr.HTML(""" """)
    gr.Markdown(
        "Upload iOS-style app icons to detect squircle corners using a trained "
        "YOLOv11 model. The visualization shows detected corner circles and crop bounds."
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_images = gr.File(
                label="Upload Images",
                file_count="multiple",
                file_types=["image"],
                type="filepath",
            )
            predict_btn = gr.Button("Predict", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### YOLO Predictions")
            yolo_gallery = gr.Gallery(
                label="YOLO Results",
                columns=3,
                height="auto",
                object_fit="contain",
                elem_id="gallery",
            )
            yolo_info = gr.Markdown(label="YOLO Details")

    predict_btn.click(
        fn=predict,
        inputs=[input_images],
        outputs=[yolo_gallery, yolo_info],
        api_name="predict",
    )


if __name__ == "__main__":
    demo.launch()