IgorSlinko's picture
Fix file upload and numpy compatibility
1427ba3
import json
import os
from pathlib import Path
import gradio as gr
from PIL import Image
from visualization.corners import Circle, Corners
from visualization.draw import visualize_corners
# Cached ultralytics model instance; stays None until setup_yolo() loads it.
YOLO_MODEL = None
# Reference confidence threshold; predict_yolo() queries the model at half
# of this value.
YOLO_CONF: float = 0.7
def setup_yolo():
    """Load the YOLO corner-detection model once and cache it in ``YOLO_MODEL``.

    Weights are looked up at ``yolo_weights/best.pt`` next to this file.

    Returns:
        True if a model is available (loaded now or by an earlier call);
        False if the weights file is missing or loading failed. Failures are
        not cached, so a later call will try to load again.
    """
    # Only YOLO_MODEL is assigned here; YOLO_CONF is merely read elsewhere
    # and needs no ``global`` declaration.
    global YOLO_MODEL
    if YOLO_MODEL is not None:
        return True
    weights_path = Path(__file__).parent / "yolo_weights" / "best.pt"
    if not weights_path.exists():
        return False
    try:
        # Imported lazily: a missing/broken ultralytics install is caught
        # below instead of failing at module import time.
        from ultralytics import YOLO
        YOLO_MODEL = YOLO(str(weights_path))
    except Exception as e:
        # Best-effort load: report the reason and let callers fall back.
        print(f"Failed to load YOLO model: {e}")
        return False
    return True
def predict_yolo(image):
    """Detect the four squircle corners of ``image`` with the cached YOLO model.

    The image is written to a temporary ``.png`` file whose path is passed to
    ``YOLO_MODEL.predict``; the file is always removed afterwards, even if
    saving or prediction raises.

    Args:
        image: A PIL image (anything whose ``save(path)`` writes a PNG).

    Returns:
        ``(corners, "OK")`` where ``corners`` is ``Corners.from_boxes`` applied
        to the four highest-confidence boxes, or ``(None, reason)`` when the
        model is unavailable or fewer than four boxes were detected.
    """
    if not setup_yolo():
        return None, "YOLO model not available"
    import tempfile
    # Only reserve a unique path here and close the handle immediately; the
    # image is written inside the try so the finally clause always cleans up
    # (previously a failing save leaked the temp file).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        temp_path = f.name
    try:
        image.save(temp_path)
        results = YOLO_MODEL.predict(
            source=temp_path,
            # Half the configured threshold -- presumably so enough weak
            # candidates survive for the top-4 selection below.
            conf=YOLO_CONF / 2.0,
            iou=0.01,
            verbose=False,
        )
        r = results[0]
        n_boxes = len(r.boxes)
        if n_boxes < 4:
            return None, f"Less than 4 corners detected ({n_boxes})"
        boxes = r.boxes.xyxy.cpu().numpy()
        confs = r.boxes.conf.cpu().numpy()
        # Indices of the four most confident boxes, highest confidence first.
        top4_idx = confs.argsort()[-4:][::-1]
        # First four xyxy coordinates of each box, truncated to ints.
        top4_boxes = [tuple(int(v) for v in boxes[i][:4]) for i in top4_idx]
        corners = Corners.from_boxes(top4_boxes)
        return corners, "OK"
    finally:
        os.unlink(temp_path)
def process_single_image(image):
    """Run YOLO corner detection on one image and render the result.

    Args:
        image: PIL image to analyse. It is copied before being passed to
            ``visualize_corners``, so drawing happens on the copy.

    Returns:
        ``(viz, info)``: ``viz`` is the annotated copy (``None`` when nothing
        was detected) and ``info`` is the corners' ``to_dict()`` as indented
        JSON, or ``"Not detected: <reason>"``.
    """
    yolo_corners, yolo_status = predict_yolo(image)
    # predict_yolo signals failure with None; check that explicitly instead
    # of relying on the truthiness of a Corners object.
    if yolo_corners is None:
        return None, f"Not detected: {yolo_status}"
    yolo_viz = visualize_corners(image.copy(), yolo_corners, color=(255, 100, 0, 255))
    yolo_info = json.dumps(yolo_corners.to_dict(), indent=2)
    return yolo_viz, yolo_info
def process_images(images):
    """Open each uploaded item and run corner detection on it.

    Args:
        images: Iterable of uploads -- path strings, PIL images, or any other
            object ``Image.open`` accepts (e.g. a file-like object). ``None``
            yields no results.

    Returns:
        A list of dicts with keys ``filename``, ``original``, ``yolo_viz`` and
        ``yolo_info``, one per processed item. Items in the fallback branch
        that ``Image.open`` cannot read are skipped with a printed message;
        a path string that cannot be opened still raises.
    """
    results = []
    if images is None:
        return results
    for img_data in images:
        if isinstance(img_data, str):
            image = Image.open(img_data)
            filename = Path(img_data).name
        elif isinstance(img_data, Image.Image):
            image = img_data
            filename = "uploaded_image"
        else:
            try:
                image = Image.open(img_data)
            except Exception as e:
                # Best-effort: skip unreadable uploads, but say why instead
                # of dropping them silently.
                print(f"Skipping upload that could not be opened ({img_data!r}): {e}")
                continue
            # File-like objects and path objects usually expose their path as
            # a string ``.name``; use its basename as the label when present.
            name = getattr(img_data, "name", None)
            filename = Path(name).name if isinstance(name, str) else "uploaded_image"
        yolo_viz, yolo_info = process_single_image(image)
        results.append({
            "filename": filename,
            "original": image,
            "yolo_viz": yolo_viz,
            "yolo_info": yolo_info,
        })
    return results
def create_gallery_output(results):
    """Convert per-image results into gallery items and a details markdown string.

    Args:
        results: List of dicts with ``filename``, ``original``, ``yolo_viz``
            and ``yolo_info`` keys, as produced by ``process_images``.

    Returns:
        ``(gallery_items, details)``. ``gallery_items`` is a list of
        ``(image, caption)`` pairs: the visualization captioned with the
        filename, or the original image captioned ``"<filename> (not
        detected)"``. ``details`` is a bold filename line followed by the info
        text for each result, with entries separated by a blank line.
    """
    yolo_images = []
    for r in results:
        # Compare against None explicitly: the visualization could be an
        # array-like (e.g. a NumPy array) whose truth value is ambiguous.
        if r["yolo_viz"] is not None:
            yolo_images.append((r["yolo_viz"], r["filename"]))
        else:
            yolo_images.append((r["original"], f"{r['filename']} (not detected)"))
    yolo_info = "\n\n".join(f"**{r['filename']}**\n{r['yolo_info']}" for r in results)
    return yolo_images, yolo_info
def predict(images):
    """Gradio click handler: detect corners on every uploaded image.

    Args:
        images: The uploads from the file input; ``None`` or empty when
            nothing was uploaded.

    Returns:
        ``(gallery_items, details_markdown)`` for the gallery and markdown
        outputs.
    """
    if not images:
        return [], "No images uploaded"
    results = process_images(images)
    if not results:
        # Every upload was skipped as unreadable; report it instead of
        # rendering an empty details panel.
        return [], "No readable images found"
    return create_gallery_output(results)
# Gradio UI: a title, a description, an upload column with a Predict button,
# then a results row with the annotated gallery and a markdown details panel.
# Layout follows the order in which components are created inside these
# context managers, so statement order matters here.
with gr.Blocks(title="Squircle Corners Prediction") as demo:
    gr.Markdown("# Squircle Corners Prediction")
    # CSS aimed at the gallery below (elem_id="gallery"): removes height caps
    # so all results are visible without an inner scroll area.
    gr.HTML("""
<style>
#gallery { height: auto !important; max-height: none !important; }
</style>
""")
    gr.Markdown("""
Upload iOS-style app icons to detect squircle corners using a trained YOLOv11 model.
The visualization shows detected corner circles and crop bounds.
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Multiple image files; with type="filepath" the handler presumably
            # receives path strings (the str branch of process_images).
            input_images = gr.File(
                label="Upload Images",
                file_count="multiple",
                file_types=["image"],
                type="filepath"
            )
            predict_btn = gr.Button("Predict", variant="primary")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### YOLO Predictions")
            yolo_gallery = gr.Gallery(
                label="YOLO Results",
                columns=3,
                height="auto",
                object_fit="contain",
                elem_id="gallery"
            )
            yolo_info = gr.Markdown(label="YOLO Details")
    # predict() returns (gallery items, details markdown), matching the
    # order of `outputs`; also exposed as the "predict" API endpoint.
    predict_btn.click(
        fn=predict,
        inputs=[input_images],
        outputs=[yolo_gallery, yolo_info],
        api_name="predict"
    )
# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()