| import json |
| import os |
| from pathlib import Path |
|
|
| import gradio as gr |
| from PIL import Image |
|
|
| from visualization.corners import Circle, Corners |
| from visualization.draw import visualize_corners |
|
|
# Lazily-initialised ultralytics model; populated by setup_yolo() on first use.
YOLO_MODEL = None
# Confidence setting; predict_yolo() passes half of this value to the model.
YOLO_CONF = 0.7


def setup_yolo():
    """Load the YOLO corner-detection model once and cache it in ``YOLO_MODEL``.

    Weights are read from ``yolo_weights/best.pt`` located next to this file.
    A failed load is not cached, so the next call simply tries again.

    Returns:
        True if a model is loaded (now or by an earlier call), False if the
        weights file is missing or loading raised an exception.
    """
    # Only YOLO_MODEL is assigned here; YOLO_CONF is merely read elsewhere and
    # needs no ``global`` declaration.
    global YOLO_MODEL

    if YOLO_MODEL is not None:
        return True

    weights_path = Path(__file__).parent / "yolo_weights" / "best.pt"
    if not weights_path.exists():
        # Report the reason instead of failing silently; callers only turn a
        # False result into a generic "model not available" message.
        print(f"YOLO weights not found at {weights_path}")
        return False

    try:
        # Imported lazily so the app can still start without ultralytics.
        from ultralytics import YOLO

        YOLO_MODEL = YOLO(str(weights_path))
    except Exception as e:
        # Best-effort boundary: any import/load failure means "unavailable".
        print(f"Failed to load YOLO model: {e}")
        return False
    return True
|
|
|
|
def predict_yolo(image):
    """Detect the four squircle corners in ``image`` with the YOLO model.

    The image is written to a temporary PNG, which is passed to the model as
    its ``source``. The temporary file is always deleted afterwards, including
    when saving or prediction raises.

    Args:
        image: the image to analyse; it must support ``save(path)``. Callers
            in this module pass PIL images.

    Returns:
        ``(corners, "OK")`` on success, where ``corners`` is built by
        ``Corners.from_boxes`` from the four highest-confidence boxes, or
        ``(None, reason)`` when the model is unavailable or fewer than four
        boxes were detected.
    """
    if not setup_yolo():
        return None, "YOLO model not available"

    import tempfile

    # Only reserve a unique file name here. The image is written inside the
    # ``try`` below so that ``finally`` removes the file even if encoding fails.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        temp_path = f.name

    try:
        image.save(temp_path)
        results = YOLO_MODEL.predict(
            source=temp_path,
            # Detection threshold is half of YOLO_CONF; the best four boxes
            # are picked below regardless of their absolute confidence.
            conf=YOLO_CONF / 2.0,
            # Very low NMS IoU threshold, presumably so overlapping boxes on
            # the same corner collapse into one -- confirm with model author.
            iou=0.01,
            verbose=False,
        )

        # One source image was given, so only the first result is relevant.
        r = results[0]
        if len(r.boxes) < 4:
            return None, f"Less than 4 corners detected ({len(r.boxes)})"

        boxes = r.boxes.xyxy.cpu().numpy()
        confs = r.boxes.conf.cpu().numpy()

        # Indices of the four highest-confidence detections, best first.
        top4_idx = confs.argsort()[-4:][::-1]
        top4_boxes = [
            (int(boxes[i][0]), int(boxes[i][1]), int(boxes[i][2]), int(boxes[i][3]))
            for i in top4_idx
        ]

        corners = Corners.from_boxes(top4_boxes)
        return corners, "OK"
    finally:
        os.unlink(temp_path)
|
|
|
|
def process_single_image(image):
    """Run corner detection on one image and render the result.

    Args:
        image: the image to analyse. Drawing happens on ``image.copy()``, so
            the annotation is not drawn onto the caller's object.

    Returns:
        ``(viz, info)``. ``viz`` is the annotated copy, or None when no
        corners were found. ``info`` is the corners serialised as indented
        JSON, or ``"Not detected: <reason>"``.
    """
    yolo_corners, yolo_status = predict_yolo(image)

    # predict_yolo signals failure with None; test for that explicitly rather
    # than relying on the truthiness of a Corners object.
    if yolo_corners is None:
        return None, f"Not detected: {yolo_status}"

    yolo_viz = visualize_corners(image.copy(), yolo_corners, color=(255, 100, 0, 255))
    yolo_info = json.dumps(yolo_corners.to_dict(), indent=2)
    return yolo_viz, yolo_info
|
|
|
|
def _open_upload(img_data):
    """Return ``(image, display_name)`` for one upload, or ``(None, None)`` if it cannot be opened."""
    if isinstance(img_data, Image.Image):
        # Already an in-memory PIL image; nothing to open.
        return img_data, "uploaded_image"

    if isinstance(img_data, (str, os.PathLike)):
        filename = Path(img_data).name
    else:
        filename = "uploaded_image"

    try:
        image = Image.open(img_data)
        # Image.open only parses the header; decode now so corrupt or
        # truncated files are rejected here instead of failing later.
        image.load()
    except Exception as exc:
        # Per-upload boundary: one bad file must not abort the whole batch.
        print(f"Skipping {filename}: cannot open image ({exc})")
        return None, None
    return image, filename


def process_images(images):
    """Open each uploaded item and run corner prediction on it.

    Args:
        images: an iterable of uploads, or None. Items may be file paths
            (``str`` or ``os.PathLike``), ``PIL.Image.Image`` objects, or
            anything else ``Image.open`` accepts, such as file objects.

    Returns:
        A list with one dict per item that could be opened, with keys
        ``filename``, ``original``, ``yolo_viz`` and ``yolo_info``. Items that
        cannot be opened are reported on stdout and skipped.
    """
    results = []
    if images is None:
        return results

    for img_data in images:
        image, filename = _open_upload(img_data)
        if image is None:
            continue

        yolo_viz, yolo_info = process_single_image(image)
        results.append({
            "filename": filename,
            "original": image,
            "yolo_viz": yolo_viz,
            "yolo_info": yolo_info,
        })

    return results
|
|
|
|
def create_gallery_output(results):
    """Build the gallery items and the Markdown summary for a batch of results.

    Args:
        results: a list of dicts with keys ``filename``, ``original``,
            ``yolo_viz`` and ``yolo_info``, in the format ``process_images``
            returns.

    Returns:
        ``(gallery_items, markdown)``. ``gallery_items`` holds one
        ``(image, caption)`` tuple per result: the annotated image when
        present, otherwise the original captioned "(not detected)".
        ``markdown`` holds each file name in bold followed by its info in a
        fenced block.
    """
    yolo_images = []
    for r in results:
        if r["yolo_viz"] is not None:
            yolo_images.append((r["yolo_viz"], r["filename"]))
        else:
            yolo_images.append((r["original"], f"{r['filename']} (not detected)"))

    # Fence each info block: Markdown treats single newlines as soft breaks,
    # which would reflow the multi-line JSON into one paragraph, and `_`/`*`
    # inside it could be parsed as emphasis.
    yolo_info = "\n\n".join(
        f"**{r['filename']}**\n```\n{r['yolo_info']}\n```" for r in results
    )

    return yolo_images, yolo_info
|
|
|
|
def predict(images):
    """Click handler: run the whole pipeline on the uploaded files.

    Returns the ``(gallery_items, markdown)`` pair expected by the two output
    components, or an empty gallery plus a notice when nothing was uploaded.
    """
    has_uploads = images is not None and len(images) > 0
    if has_uploads:
        return create_gallery_output(process_images(images))
    return [], "No images uploaded"
|
|
|
|
# Inline CSS so the results gallery grows with its content instead of scrolling.
_GALLERY_STYLE_HTML = """
<style>
#gallery { height: auto !important; max-height: none !important; }
</style>
"""

# Introductory text shown under the page title.
_INTRO_MARKDOWN = """
Upload iOS-style app icons to detect squircle corners using a trained YOLOv11 model.

The visualization shows detected corner circles and crop bounds.
"""

with gr.Blocks(title="Squircle Corners Prediction") as demo:
    gr.Markdown("# Squircle Corners Prediction")
    gr.HTML(_GALLERY_STYLE_HTML)
    gr.Markdown(_INTRO_MARKDOWN)

    # Input row: multi-file uploader (yields file paths) and the run button.
    with gr.Row():
        with gr.Column(scale=1):
            input_images = gr.File(
                type="filepath",
                label="Upload Images",
                file_types=["image"],
                file_count="multiple",
            )
            predict_btn = gr.Button("Predict", variant="primary")

    # Output row: annotated gallery followed by per-file details.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### YOLO Predictions")
            yolo_gallery = gr.Gallery(
                elem_id="gallery",
                label="YOLO Results",
                object_fit="contain",
                height="auto",
                columns=3,
            )
            yolo_info = gr.Markdown(label="YOLO Details")

    # Wire the button to the pipeline; api_name also exposes it as "predict".
    predict_btn.click(
        api_name="predict",
        fn=predict,
        inputs=[input_images],
        outputs=[yolo_gallery, yolo_info],
    )


if __name__ == "__main__":
    # Start the Gradio app when run as a script.
    demo.launch()
|
|