Spaces:
Running
Running
File size: 13,246 Bytes
db12d23 c8d1052 f782dc8 c8d1052 84395d7 db12d23 84395d7 db12d23 84395d7 c8d1052 f782dc8 c8d1052 84395d7 f782dc8 84395d7 f782dc8 c8d1052 84395d7 c8d1052 1009e84 c8d1052 db12d23 c8d1052 84395d7 c8d1052 84395d7 db12d23 c8d1052 f782dc8 db12d23 c8d1052 84395d7 c8d1052 f782dc8 c8d1052 84395d7 f782dc8 c8d1052 f782dc8 c8d1052 84395d7 c8d1052 f782dc8 c8d1052 f782dc8 c8d1052 f782dc8 84395d7 f782dc8 c8d1052 f782dc8 c8d1052 84395d7 c8d1052 84395d7 c8d1052 f782dc8 c8d1052 8eb3166 c8d1052 84395d7 8eb3166 c8d1052 f782dc8 8eb3166 f782dc8 8eb3166 f782dc8 c8d1052 84395d7 c8d1052 f782dc8 f8a9f51 db12d23 c8d1052 8eb3166 c8d1052 f782dc8 c8d1052 f782dc8 c8d1052 f782dc8 c8d1052 f782dc8 c8d1052 84395d7 c8d1052 7622e4a 2313bae c8d1052 84395d7 c8d1052 f782dc8 c8d1052 f782dc8 c8d1052 2313bae c8d1052 84395d7 c8d1052 f782dc8 c8d1052 f782dc8 c8d1052 2313bae c8d1052 f782dc8 8eb3166 f782dc8 84395d7 f782dc8 84395d7 c8d1052 f782dc8 84395d7 f782dc8 84395d7 f782dc8 84395d7 f782dc8 84395d7 f782dc8 84395d7 f782dc8 8eb3166 84395d7 f782dc8 84395d7 f782dc8 84395d7 f782dc8 84395d7 c8d1052 f782dc8 c8d1052 f782dc8 f8a9f51 ecee7e2 db12d23 f782dc8 84395d7 c8d1052 f782dc8 c8d1052 f782dc8 c8d1052 f782dc8 84395d7 c8d1052 84395d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 |
import os

# Environment configuration for CPU-only operation.
# NOTE: this must run BEFORE numpy / torch / supervision are imported. The
# thread-count variables (OMP/MKL/OpenBLAS) are normally read by the native
# runtimes when they are loaded, so setting them after the imports has no
# effect. MPLCONFIGDIR / FONTCONFIG_* point at writable locations under /tmp.
# setdefault() keeps any value already provided by the deployment.
os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
os.environ.setdefault("FONTCONFIG_FILE", "/etc/fonts/fonts.conf")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "4")

# Create writable fontconfig / matplotlib cache directories
os.makedirs("/tmp/.fontconfig", exist_ok=True)
os.makedirs("/tmp/.matplotlib", exist_ok=True)

import io
import base64
import threading
import traceback
import gc
from typing import Optional

from flask import Flask, request, jsonify, send_from_directory
from PIL import Image
import numpy as np
import requests
import torch

# Limit torch's intra-op thread pool. This is best-effort: a failure here
# must not prevent the server from starting, but it should be visible.
try:
    torch.set_num_threads(4)
except Exception as exc:
    print(f"[WARN] torch.set_num_threads(4) failed: {exc}")

import supervision as sv
from rfdetr import RFDETRSegPreview
# Remote location of the model checkpoint and the local path it is cached at
CHECKPOINT_URL = "https://huggingface.co/Subh775/Seg-Basil-rfdetr/resolve/main/checkpoint_best_total.pth"
CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")

# Process-wide model singleton; it stays None until init_model() populates it.
# MODEL_LOCK is held by init_model() while the model is being created.
MODEL = None
MODEL_LOCK = threading.Lock()

# Flask application; the "static" folder is exposed at the URL root
app = Flask(__name__, static_folder="static", static_url_path="/")
def download_file(url: str, dst: str, chunk_size: int = 8192):
    """Download ``url`` to ``dst`` unless a non-empty file already exists there.

    The response body is streamed into a sibling ``<dst>.part`` file and only
    moved onto ``dst`` (via ``os.replace``) once the whole body has been
    written. An interrupted transfer therefore never leaves a truncated file
    at ``dst`` that a later call would mistake for a complete download.

    Args:
        url: HTTP(S) address of the file to fetch.
        dst: Local path the file should end up at.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        ``dst``, whether the file was already present or freshly downloaded.

    Raises:
        Exception: Any error raised by the HTTP request (including non-2xx
            status codes) or by file I/O is logged and re-raised unchanged.
    """
    if os.path.exists(dst) and os.path.getsize(dst) > 0:
        print(f"[INFO] Checkpoint already exists at {dst}")
        return dst

    print(f"[INFO] Downloading weights from {url} -> {dst}")
    tmp_path = f"{dst}.part"
    try:
        # The context manager releases the connection even if streaming fails.
        with requests.get(url, stream=True, timeout=180) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as fh:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        fh.write(chunk)
        # Publish the file only after it has been written completely.
        os.replace(tmp_path, dst)
        print("[INFO] Download complete.")
        return dst
    except Exception as e:
        print(f"[ERROR] Download failed: {e}")
        # Best-effort cleanup of the partial file; it may not exist at all
        # if the request failed before anything was written.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def init_model():
    """Lazily initialize the RF-DETR model and cache it in global MODEL.

    The whole initialization runs under ``MODEL_LOCK`` so concurrent callers
    do not download or construct the model twice. The checkpoint is fetched
    with ``download_file`` when it is missing or empty.

    Returns:
        The cached ``RFDETRSegPreview`` instance.

    Raises:
        Exception: Any download or model-construction error is logged with a
            traceback and re-raised; ``MODEL`` stays ``None`` in that case.
    """
    global MODEL
    with MODEL_LOCK:
        if MODEL is not None:
            print("[INFO] Model already loaded, returning cached instance")
            return MODEL
        try:
            # Ensure checkpoint present. A zero-byte file (e.g. left behind by
            # an aborted download) is treated as missing, matching the check
            # performed inside download_file().
            checkpoint_ready = (
                os.path.exists(CHECKPOINT_PATH)
                and os.path.getsize(CHECKPOINT_PATH) > 0
            )
            if not checkpoint_ready:
                print("[INFO] Checkpoint not found, downloading...")
                download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
            else:
                print(f"[INFO] Using existing checkpoint at {CHECKPOINT_PATH}")

            print("[INFO] Loading RF-DETR model (CPU mode)...")
            model = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH)

            # Try to optimize for inference; failure is non-fatal and the
            # un-optimized model is used instead.
            try:
                print("[INFO] Optimizing model for inference...")
                model.optimize_for_inference()
                print("[INFO] Model optimization complete")
            except Exception as e:
                print(f"[WARN] optimize_for_inference() skipped/failed: {e}")

            # Publish the global only once the model is fully prepared, so
            # readers that do not take the lock never see a half-built model.
            MODEL = model
            print("[INFO] Model ready for inference")
            return MODEL
        except Exception as e:
            print(f"[ERROR] Model initialization failed: {e}")
            traceback.print_exc()
            raise
def decode_data_url(data_url: str) -> Image.Image:
    """Decode a base64 image, given as a data URL or a bare base64 string.

    Args:
        data_url: Either ``"data:<mime>;base64,<payload>"`` or just the
            base64 payload itself.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        ValueError: If the input is not a string, a data URL has no ``,``
            separating header and payload, or the payload is not valid
            base64.
        Exception: Errors raised by ``Image.open`` for bytes that are not a
            readable image propagate unchanged.
    """
    if not isinstance(data_url, str):
        raise ValueError("Invalid image data: expected a base64 string or data URL")

    if data_url.startswith("data:"):
        # Everything after the first comma is the payload.
        _, sep, b64 = data_url.partition(",")
        if not sep:
            raise ValueError("Invalid data URL: missing ',' before the base64 payload")
    else:
        b64 = data_url

    try:
        data = base64.b64decode(b64)
    except (ValueError, TypeError) as exc:  # binascii.Error is a ValueError
        raise ValueError(f"base64 decoding failed: {exc}") from exc

    return Image.open(io.BytesIO(data)).convert("RGB")
def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
    """Serialize a PIL image into a ``data:image/<fmt>;base64,...`` URL.

    Args:
        pil_img: Image to serialize.
        fmt: Format name handed to ``Image.save``; its lower-cased form is
            used as the MIME subtype of the resulting URL.

    Returns:
        The data URL as an ASCII string.
    """
    buffer = io.BytesIO()
    pil_img.save(buffer, format=fmt, optimize=False)
    # getvalue() returns the whole buffer regardless of the stream position.
    payload = base64.b64encode(buffer.getvalue()).decode("ascii")
    prefix = "data:image/{};base64,".format(fmt.lower())
    return prefix + payload
def annotate_segmentation(image: Image.Image, detections: sv.Detections,
                          show_labels: bool = True, show_confidence: bool = True) -> Image.Image:
    """
    Draw segmentation masks, outlines and optional labels onto a copy of an image.

    Masks and white polygon outlines are always drawn. A text label is added
    per detection when at least one of the two flags is set; it consists of
    the word "Tulsi" and/or the confidence formatted with two decimals.
    (The original author notes this mirrors the visualization of the
    rfdetr_seg_infer.py script.)

    Args:
        image: Input PIL Image (left unmodified)
        detections: Supervision Detections object
        show_labels: Whether to include the "Tulsi" text in each label
        show_confidence: Whether to include the confidence score in each label

    Returns:
        The annotated image, or the original ``image`` if any step raised.
    """
    try:
        # Colors cycled through by the mask and label annotators
        hex_colors = [
            "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
            "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
        ]
        palette = sv.ColorPalette.from_hex(hex_colors)

        # Text scale derived from the image resolution
        text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
        print(f"[INFO] Creating annotators with text_scale={text_scale}")

        mask_painter = sv.MaskAnnotator(color=palette)
        outline_painter = sv.PolygonAnnotator(color=sv.Color.WHITE)

        # Base layers: masks first, then polygon outlines
        canvas = image.copy()
        print("[INFO] Applying mask annotation...")
        canvas = mask_painter.annotate(canvas, detections)
        print("[INFO] Applying polygon annotation...")
        canvas = outline_painter.annotate(canvas, detections)

        if not (show_labels or show_confidence):
            print("[INFO] Skipping label annotation (both labels and confidence disabled)")
        else:
            label_painter = sv.LabelAnnotator(
                color=palette,
                text_color=sv.Color.BLACK,
                text_scale=text_scale,
                text_position=sv.Position.CENTER_OF_MASS,
            )
            # One label per detection, assembled from the enabled parts
            labels = [
                " ".join(
                    (["Tulsi"] if show_labels else [])
                    + ([f"{float(score):.2f}"] if show_confidence else [])
                )
                for score in detections.confidence
            ]
            print(f"[INFO] Applying label annotation with {len(labels)} labels...")
            canvas = label_painter.annotate(canvas, detections, labels)

        print("[INFO] Annotation complete")
        return canvas
    except Exception as e:
        print(f"[ERROR] Annotation failed: {e}")
        traceback.print_exc()
        # Fall back to the untouched input image
        return image
@app.route("/", methods=["GET"])
def index():
    """Return static/index.html when it exists, otherwise a JSON status message."""
    static_dir = app.static_folder or "static"
    if not os.path.exists(os.path.join(static_dir, "index.html")):
        # No UI bundled: answer with a small JSON banner instead
        return jsonify({"message": "RF-DETR Segmentation API is running.", "status": "ready"})
    return send_from_directory(app.static_folder, "index.html")
@app.route("/health", methods=["GET"])
def health():
    """Report whether the model is loaded and whether the checkpoint file exists."""
    report = {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "checkpoint_exists": os.path.exists(CHECKPOINT_PATH),
    }
    return jsonify(report)
def _coerce_flag(value) -> bool:
    """Interpret a request-supplied display flag as a boolean.

    Strings count as true only when they spell "true" (case-insensitive),
    which is how the multipart form fields are interpreted; every other
    value falls back to ordinary truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() == "true"
    return bool(value)


@app.route("/predict", methods=["POST"])
def predict():
    """
    Run segmentation on an uploaded image and return the annotated result.

    Accepts:
    - multipart/form-data with file field "file"
      (optional form fields: "conf", "show_labels", "show_confidence")
    - or JSON {"image": "<data:url...>", "conf": 0.05, "show_labels": true, "show_confidence": true}
    Returns JSON:
    {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}

    Error responses are JSON {"error": "..."} with status 400 for bad input
    (missing/undecodable image, non-numeric "conf") and 500 when model
    initialization or inference fails.
    """
    print("\n[INFO] ========== New prediction request ==========")
    try:
        print("[INFO] Initializing model...")
        model = init_model()
        print("[INFO] Model ready")
    except Exception as e:
        error_msg = f"Model initialization failed: {e}"
        print(f"[ERROR] {error_msg}")
        return jsonify({"error": error_msg}), 500

    # Parse input
    img: Optional[Image.Image] = None
    raw_conf = 0.05
    show_labels = True
    show_confidence = True

    # Check if file uploaded
    if "file" in request.files:
        file = request.files["file"]
        print(f"[INFO] Processing uploaded file: {file.filename}")
        try:
            img = Image.open(file.stream).convert("RGB")
        except Exception as e:
            error_msg = f"Invalid uploaded image: {e}"
            print(f"[ERROR] {error_msg}")
            return jsonify({"error": error_msg}), 400
        raw_conf = request.form.get("conf", raw_conf)
        show_labels = _coerce_flag(request.form.get("show_labels", "true"))
        show_confidence = _coerce_flag(request.form.get("show_confidence", "true"))
    else:
        # Try JSON payload; anything other than an object with "image" is rejected
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict) or "image" not in payload:
            return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
        try:
            print("[INFO] Decoding image from data URL...")
            img = decode_data_url(payload["image"])
        except Exception as e:
            error_msg = f"Invalid image data: {e}"
            print(f"[ERROR] {error_msg}")
            return jsonify({"error": error_msg}), 400
        raw_conf = payload.get("conf", raw_conf)
        # Strings such as "false" must not be treated as truthy
        show_labels = _coerce_flag(payload.get("show_labels", True))
        show_confidence = _coerce_flag(payload.get("show_confidence", True))

    # A malformed threshold is a client error, not a server failure
    try:
        conf_threshold = float(raw_conf)
    except (TypeError, ValueError):
        error_msg = f"Invalid 'conf' value: {raw_conf!r}"
        print(f"[ERROR] {error_msg}")
        return jsonify({"error": error_msg}), 400

    print(f"[INFO] Image size: {img.size}, Confidence threshold: {conf_threshold}")
    print(f"[INFO] Display options - Labels: {show_labels}, Confidence: {show_confidence}")

    # Optionally downscale large images to reduce memory usage
    MAX_SIZE = 1024
    if max(img.size) > MAX_SIZE:
        w, h = img.size
        scale = MAX_SIZE / float(max(w, h))
        # Clamp to 1 px so extreme aspect ratios never round a side down to 0
        new_w = max(1, int(round(w * scale)))
        new_h = max(1, int(round(h * scale)))
        print(f"[INFO] Resizing image from {w}x{h} to {new_w}x{new_h}")
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)

    # Run inference with no_grad for memory efficiency
    try:
        print("[INFO] Running inference...")
        with torch.no_grad():
            detections = model.predict(img, threshold=conf_threshold)
        print(f"[INFO] Raw detections: {len(detections)} objects")

        # Check if detections exist (confidence may be absent or None)
        scores = getattr(detections, "confidence", None)
        if len(detections) == 0 or scores is None or len(scores) == 0:
            print("[INFO] No detections above threshold")
            # Return original image
            data_url = encode_pil_to_dataurl(img, fmt="PNG")
            return jsonify({
                "annotated": data_url,
                "confidences": [],
                "count": 0
            })

        print(f"[INFO] Detections have {len(scores)} confidence scores")
        print(f"[INFO] Confidence range: {min(scores):.3f} - {max(scores):.3f}")

        # Check if masks exist. supervision's Detections exposes them as
        # `mask`; `masks` is still consulted as a fallback.
        masks = getattr(detections, "mask", None)
        if masks is None:
            masks = getattr(detections, "masks", None)
        if masks is not None:
            print(f"[INFO] Masks present: shape={np.array(masks).shape if hasattr(masks, '__len__') else 'unknown'}")
        else:
            print("[WARN] No masks found in detections!")

        # Annotate image using supervision library
        print("[INFO] Starting annotation...")
        annotated_pil = annotate_segmentation(img, detections, show_labels, show_confidence)

        # Extract confidence scores as plain floats so they are JSON-serializable
        confidences = [float(conf) for conf in scores]
        print(f"[INFO] Final confidences: {confidences}")

        # Encode to data URL
        print("[INFO] Encoding annotated image...")
        data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")

        # Clean up
        del detections, scores, masks
        gc.collect()

        print(f"[INFO] ========== Prediction complete: {len(confidences)} leaves detected ==========\n")
        return jsonify({
            "annotated": data_url,
            "confidences": confidences,
            "count": len(confidences)
        })
    except Exception as e:
        error_msg = f"Inference failed: {e}"
        print(f"[ERROR] {error_msg}")
        traceback.print_exc()
        return jsonify({"error": error_msg}), 500
if __name__ == "__main__":
    print("\n" + "="*60)
    print("Starting Tulsi Leaf Segmentation Server")
    print("="*60 + "\n")

    # Warm model in background thread so the HTTP server can start
    # accepting connections while the checkpoint is loaded.
    def warm():
        """Load the model once at startup; failures are logged, not raised."""
        try:
            print("[INFO] Starting model warmup in background...")
            init_model()
            # The status glyphs were mis-encoded in the source; restored here.
            print("[INFO] ✓ Model warmup complete - ready for predictions")
        except Exception as e:
            print(f"[ERROR] ✗ Model warmup failed: {e}")
            traceback.print_exc()

    # daemon=True: the warm-up thread must not keep the process alive on exit
    threading.Thread(target=warm, daemon=True).start()

    # Run Flask app; the port comes from $PORT and defaults to 7860
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False
    )