# File-viewer header (scrape residue): Subh775 - "Update app.py" - commit f782dc8 (verified) - 12 kB
import os
import io
import base64
import threading
import traceback
import gc
from typing import Optional
from flask import Flask, request, jsonify, send_from_directory
from PIL import Image
import numpy as np
import requests
import torch
# Process-wide defaults for CPU-only operation. setdefault() is used so that
# a value already exported by the deployment is left untouched.
for _env_name, _env_default in (
    ("MPLCONFIGDIR", "/tmp/.matplotlib"),
    ("FONTCONFIG_PATH", "/tmp/.fontconfig"),
    ("FONTCONFIG_FILE", "/etc/fonts/fonts.conf"),
    ("CUDA_VISIBLE_DEVICES", ""),
    ("OMP_NUM_THREADS", "4"),
    ("MKL_NUM_THREADS", "4"),
    ("OPENBLAS_NUM_THREADS", "4"),
):
    os.environ.setdefault(_env_name, _env_default)

# Create the default fontconfig / matplotlib cache directories under /tmp
for _cache_dir in ("/tmp/.fontconfig", "/tmp/.matplotlib"):
    os.makedirs(_cache_dir, exist_ok=True)

# Keep the module namespace free of the loop helpers
del _env_name, _env_default, _cache_dir

# Limit torch threads (best effort: a failure here is deliberately ignored)
try:
    torch.set_num_threads(4)
except Exception:
    pass
import supervision as sv
from rfdetr import RFDETRSegPreview
# Flask application; the "static" folder is exposed at the URL root ("/")
app = Flask(__name__, static_folder="static", static_url_path="/")
# Remote location of the model weights and the local path they are stored at
CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
# Lock held by init_model() while it checks for / creates the cached model
MODEL_LOCK = threading.Lock()
# Cached model instance; stays None until init_model() has created it
MODEL = None
def download_file(url: str, dst: str, chunk_size: int = 8192):
    """Download ``url`` to ``dst`` unless a non-empty file already exists there.

    The body is streamed into ``dst + ".part"`` and moved onto ``dst`` only
    after the whole transfer succeeded. An interrupted download therefore
    never leaves a truncated file at ``dst`` that a later call would accept
    as a complete checkpoint.

    Args:
        url: HTTP(S) location to fetch.
        dst: Local path the file should end up at.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        ``dst``.

    Raises:
        Exception: Whatever the request or the file operations raised; the
            error is logged and re-raised unchanged.
    """
    if os.path.exists(dst) and os.path.getsize(dst) > 0:
        print(f"[INFO] Checkpoint already exists at {dst}")
        return dst
    print(f"[INFO] Downloading weights from {url} -> {dst}")
    tmp_path = f"{dst}.part"
    try:
        # Using the response as a context manager releases the connection
        # even when the transfer fails half-way.
        with requests.get(url, stream=True, timeout=180) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as fh:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:
                        fh.write(chunk)
        # Replace in one step so dst is either absent/old or complete
        os.replace(tmp_path, dst)
        print("[INFO] Download complete.")
        return dst
    except Exception as e:
        print(f"[ERROR] Download failed: {e}")
        # Best-effort removal of the partial file; it may not exist yet
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def init_model():
    """Lazily initialize the RF-DETR model and cache it in global MODEL.

    The whole check/download/load sequence runs while holding MODEL_LOCK.
    MODEL is assigned only after loading and the optional
    optimize_for_inference() step have finished, so code that merely reads
    MODEL (without the lock) never sees a half-prepared instance.

    Returns:
        The cached model instance.

    Raises:
        Exception: Any error from downloading the checkpoint or constructing
            the model; it is logged with a traceback and re-raised. A failure
            of optimize_for_inference() alone is only logged as a warning.
    """
    global MODEL
    with MODEL_LOCK:
        if MODEL is not None:
            print("[INFO] Model already loaded, returning cached instance")
            return MODEL
        try:
            # Ensure checkpoint present. A zero-byte file counts as missing,
            # which is the same rule download_file() applies.
            have_checkpoint = (
                os.path.exists(CHECKPOINT_PATH)
                and os.path.getsize(CHECKPOINT_PATH) > 0
            )
            if have_checkpoint:
                print(f"[INFO] Using existing checkpoint at {CHECKPOINT_PATH}")
            else:
                print("[INFO] Checkpoint not found, downloading...")
                download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
            print("[INFO] Loading RF-DETR model (CPU mode)...")
            model = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH)
            # Try to optimize for inference; the model is still used if this fails
            try:
                print("[INFO] Optimizing model for inference...")
                model.optimize_for_inference()
                print("[INFO] Model optimization complete")
            except Exception as e:
                print(f"[WARN] optimize_for_inference() skipped/failed: {e}")
            # Publish last: "MODEL is not None" now implies "ready"
            MODEL = model
            print("[INFO] Model ready for inference")
            return MODEL
        except Exception as e:
            print(f"[ERROR] Model initialization failed: {e}")
            traceback.print_exc()
            raise
def decode_data_url(data_url: str) -> Image.Image:
    """Decode a base64 image string into an RGB PIL image.

    Args:
        data_url: Either a ``data:...,<base64>`` URL or a bare base64 string.

    Returns:
        The decoded image converted to mode "RGB".

    Raises:
        ValueError: If ``data_url`` is not a string, is a ``data:`` URL
            without a comma-separated payload, or is not valid base64.
        Exception: Errors raised by PIL while opening or converting the
            decoded bytes are propagated unchanged.
    """
    if not isinstance(data_url, str):
        raise ValueError("Invalid image data")
    if data_url.startswith("data:"):
        # Everything up to the first comma is the media-type header
        _, comma, b64 = data_url.partition(",")
        if not comma:
            raise ValueError("Invalid image data")
    else:
        b64 = data_url
    try:
        data = base64.b64decode(b64)
    except Exception as err:
        # Keep the original decoding error attached as the cause
        raise ValueError("Invalid image data") from err
    return Image.open(io.BytesIO(data)).convert("RGB")
def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
    """Serialize ``pil_img`` as ``fmt`` and return it as a base64 data URL.

    The URL has the form ``data:image/<fmt lower-cased>;base64,<payload>``.
    """
    stream = io.BytesIO()
    pil_img.save(stream, format=fmt, optimize=False)
    prefix = f"data:image/{fmt.lower()};base64,"
    payload = base64.b64encode(stream.getvalue()).decode("ascii")
    return prefix + payload
def annotate_segmentation(image: Image.Image, detections: sv.Detections) -> Image.Image:
    """
    Draw masks, polygon outlines and confidence labels onto a copy of ``image``.

    The drawing is done with supervision annotators (mask, then polygon, then
    label), mirroring the visualization of the rfdetr_seg_infer.py script.
    If anything fails, the error and traceback are printed and the original,
    unmodified ``image`` is returned instead.
    """
    try:
        # Colors shared by the mask fill and the label background
        colors = sv.ColorPalette.from_hex([
            "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
            "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
        ])
        # Text size is derived from the image resolution
        scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
        print(f"[INFO] Creating annotators with text_scale={scale}")
        draw_masks = sv.MaskAnnotator(color=colors)
        draw_polygons = sv.PolygonAnnotator(color=sv.Color.WHITE)
        draw_labels = sv.LabelAnnotator(
            color=colors,
            text_color=sv.Color.BLACK,
            text_scale=scale,
            text_position=sv.Position.CENTER_OF_MASS
        )
        # One caption per detection: fixed class name plus its confidence
        captions = []
        for score in detections.confidence:
            captions.append(f"Tulsi {float(score):.2f}")
        print(f"[INFO] Annotating {len(captions)} detections")
        # (log line, annotator, extra positional arguments) in drawing order
        stages = (
            ("[INFO] Applying mask annotation...", draw_masks, ()),
            ("[INFO] Applying polygon annotation...", draw_polygons, ()),
            ("[INFO] Applying label annotation...", draw_labels, (captions,)),
        )
        canvas = image.copy()
        for message, annotator, extra_args in stages:
            print(message)
            canvas = annotator.annotate(canvas, detections, *extra_args)
        print("[INFO] Annotation complete")
        return canvas
    except Exception as e:
        print(f"[ERROR] Annotation failed: {e}")
        traceback.print_exc()
        # Fall back to the input image when annotation fails
        return image
@app.route("/", methods=["GET"])
def index():
    """Return static/index.html when it exists, otherwise a JSON status message."""
    static_dir = app.static_folder or "static"
    if not os.path.exists(os.path.join(static_dir, "index.html")):
        # No UI bundled: answer with a small JSON banner instead
        return jsonify({"message": "RF-DETR Segmentation API is running.", "status": "ready"})
    return send_from_directory(app.static_folder, "index.html")
@app.route("/health", methods=["GET"])
def health():
    """Report service status, whether MODEL is set and whether the checkpoint file exists."""
    report = {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "checkpoint_exists": os.path.exists(CHECKPOINT_PATH),
    }
    return jsonify(report)
def _parse_conf_threshold(raw, default: float) -> float:
    """Return ``raw`` as a float, or ``default`` when no value was supplied.

    Raises:
        ValueError: If ``raw`` is present but cannot be converted to a float.
    """
    if raw is None or raw == "":
        return default
    try:
        return float(raw)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Invalid 'conf' value: {raw!r}") from err


@app.route("/predict", methods=["POST"])
def predict():
    """
    Run segmentation on one image and return the annotated result.

    Accepts:
    - multipart/form-data with file field "file" (optional form field "conf")
    - or JSON {"image": "<data:url...>", "conf": 0.25}
    Returns JSON:
    {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}
    Errors are returned as {"error": "..."} with status 400 for bad input
    (missing/undecodable image, non-numeric "conf") and 500 when model
    initialization or inference fails.
    """
    print("\n[INFO] ========== New prediction request ==========")
    try:
        print("[INFO] Initializing model...")
        model = init_model()
        print("[INFO] Model ready")
    except Exception as e:
        error_msg = f"Model initialization failed: {e}"
        print(f"[ERROR] {error_msg}")
        return jsonify({"error": error_msg}), 500
    # Parse input
    img: Optional[Image.Image] = None
    conf_threshold = 0.25
    # Check if file uploaded
    if "file" in request.files:
        upload = request.files["file"]
        print(f"[INFO] Processing uploaded file: {upload.filename}")
        try:
            img = Image.open(upload.stream).convert("RGB")
        except Exception as e:
            error_msg = f"Invalid uploaded image: {e}"
            print(f"[ERROR] {error_msg}")
            return jsonify({"error": error_msg}), 400
        raw_conf = request.form.get("conf")
    else:
        # Try JSON payload; anything that is not a JSON object is rejected
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict) or "image" not in payload:
            return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
        try:
            print("[INFO] Decoding image from data URL...")
            img = decode_data_url(payload["image"])
        except Exception as e:
            error_msg = f"Invalid image data: {e}"
            print(f"[ERROR] {error_msg}")
            return jsonify({"error": error_msg}), 400
        raw_conf = payload.get("conf")
    # A malformed threshold is a client error, not a server failure
    try:
        conf_threshold = _parse_conf_threshold(raw_conf, conf_threshold)
    except ValueError as e:
        print(f"[ERROR] {e}")
        return jsonify({"error": str(e)}), 400
    print(f"[INFO] Image size: {img.size}, Confidence threshold: {conf_threshold}")
    # Optionally downscale large images to reduce memory usage
    MAX_SIZE = 1024
    if max(img.size) > MAX_SIZE:
        w, h = img.size
        scale = MAX_SIZE / float(max(w, h))
        new_w, new_h = int(round(w * scale)), int(round(h * scale))
        print(f"[INFO] Resizing image from {w}x{h} to {new_w}x{new_h}")
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
    # Run inference with no_grad for memory efficiency
    try:
        print("[INFO] Running inference...")
        with torch.no_grad():
            detections = model.predict(img, threshold=conf_threshold)
        print(f"[INFO] Raw detections: {len(detections)} objects")
        # Check if detections exist
        scores = getattr(detections, "confidence", None)
        if len(detections) == 0 or scores is None or len(scores) == 0:
            print("[INFO] No detections above threshold")
            # Return original image
            data_url = encode_pil_to_dataurl(img, fmt="PNG")
            return jsonify({
                "annotated": data_url,
                "confidences": [],
                "count": 0
            })
        print(f"[INFO] Detections have {len(scores)} confidence scores")
        print(f"[INFO] Confidence range: {min(scores):.3f} - {max(scores):.3f}")
        # Check if masks exist (supervision stores them in the "mask" attribute)
        mask = getattr(detections, "mask", None)
        if mask is not None:
            print(f"[INFO] Masks present: shape={np.asarray(mask).shape}")
        else:
            print("[WARN] No masks found in detections!")
        # Annotate image using supervision library
        print("[INFO] Starting annotation...")
        annotated_pil = annotate_segmentation(img, detections)
        # Extract confidence scores as plain floats so they are JSON-serializable
        confidences = [float(conf) for conf in scores]
        print(f"[INFO] Final confidences: {confidences}")
        # Encode to data URL
        print("[INFO] Encoding annotated image...")
        data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
        # Clean up
        del detections, scores, mask
        gc.collect()
        print(f"[INFO] ========== Prediction complete: {len(confidences)} leaves detected ==========\n")
        return jsonify({
            "annotated": data_url,
            "confidences": confidences,
            "count": len(confidences)
        })
    except Exception as e:
        error_msg = f"Inference failed: {e}"
        print(f"[ERROR] {error_msg}")
        traceback.print_exc()
        return jsonify({"error": error_msg}), 500
if __name__ == "__main__":
    print("\n" + "="*60)
    print("Starting Tulsi Leaf Segmentation Server")
    print("="*60 + "\n")

    # Warm model in background thread
    def warm():
        """Call init_model() once and log the outcome; errors are not re-raised."""
        try:
            print("[INFO] Starting model warmup in background...")
            init_model()
            print("[INFO] ✓ Model warmup complete - ready for predictions")
        except Exception as e:
            print(f"[ERROR] ✗ Model warmup failed: {e}")
            traceback.print_exc()

    # Daemon thread: it does not keep the process alive on shutdown
    threading.Thread(target=warm, daemon=True).start()
    # Run Flask app on all interfaces; port comes from $PORT (default 7860)
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False
    )