Subh775 commited on
Commit
f782dc8
·
verified ·
1 Parent(s): 2f068aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -259
app.py CHANGED
@@ -1,205 +1,9 @@
1
- # import os
2
- # import io
3
- # import base64
4
- # import tempfile
5
- # import threading
6
- # from PIL import Image, ImageDraw, ImageFont
7
- # import numpy as np
8
- # from flask import Flask, request, jsonify, send_from_directory
9
- # import requests
10
-
11
- # # Force CPU-only (prevents accidental GPU usage); works by hiding CUDA devices
12
- # os.environ["CUDA_VISIBLE_DEVICES"] = ""
13
-
14
- # # --- model import (ensure rfdetr package is available in requirements) ---
15
- # try:
16
- # from rfdetr import RFDETRSegPreview
17
- # except Exception as e:
18
- # raise RuntimeError("rfdetr package import failed. Make sure `rfdetr` is in requirements.") from e
19
-
20
- # app = Flask(__name__, static_folder="static", static_url_path="/")
21
-
22
- # # HF checkpoint raw resolve URL (use the 'resolve/main' raw link)
23
- # CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
24
- # CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
25
-
26
- # MODEL_LOCK = threading.Lock()
27
- # MODEL = None
28
-
29
- # def download_file(url: str, dst: str):
30
- # if os.path.exists(dst):
31
- # return dst
32
- # print(f"[INFO] Downloading weights from {url} ...")
33
- # r = requests.get(url, stream=True, timeout=60)
34
- # r.raise_for_status()
35
- # with open(dst, "wb") as fh:
36
- # for chunk in r.iter_content(chunk_size=8192):
37
- # if chunk:
38
- # fh.write(chunk)
39
- # print("[INFO] Download complete.")
40
- # return dst
41
-
42
- # def init_model():
43
- # global MODEL
44
- # with MODEL_LOCK:
45
- # if MODEL is None:
46
- # # Ensure model checkpoint
47
- # try:
48
- # download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
49
- # except Exception as e:
50
- # print(f"[WARN] Failed to download checkpoint: {e}. Attempting to init model without weights.")
51
- # # continue; model may fallback to default weights
52
- # print("[INFO] Loading RF-DETR model (CPU mode)...")
53
- # MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None)
54
- # try:
55
- # MODEL.optimize_for_inference()
56
- # except Exception:
57
- # # optimization may fail on CPU or if not implemented; ignore
58
- # pass
59
- # print("[INFO] Model ready.")
60
- # return MODEL
61
-
62
- # @app.route("/")
63
- # def index():
64
- # return send_from_directory("static", "index.html")
65
-
66
- # def decode_data_url(data_url: str) -> Image.Image:
67
- # if data_url.startswith("data:"):
68
- # header, b64 = data_url.split(",", 1)
69
- # data = base64.b64decode(b64)
70
- # return Image.open(io.BytesIO(data)).convert("RGB")
71
- # else:
72
- # # assume plain base64 or path
73
- # data = base64.b64decode(data_url)
74
- # return Image.open(io.BytesIO(data)).convert("RGB")
75
-
76
- # def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG"):
77
- # buf = io.BytesIO()
78
- # pil_img.save(buf, format=fmt)
79
- # b = base64.b64encode(buf.getvalue()).decode("ascii")
80
- # return f"data:image/{fmt.lower()};base64,{b}"
81
-
82
- # def overlay_mask_on_image(pil_img: Image.Image, masks, confidences, threshold=0.01, mask_color=(255,77,166), alpha=0.45):
83
- # """
84
- # masks: either list of HxW bool arrays or numpy array (N,H,W)
85
- # confidences: list of floats
86
- # Returns annotated PIL image and list of kept confidences and count.
87
- # """
88
- # base = pil_img.convert("RGBA")
89
- # W, H = base.size
90
-
91
- # # Normalize masks to N,H,W
92
- # if masks is None:
93
- # return base, []
94
-
95
- # if isinstance(masks, list):
96
- # masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
97
- # else:
98
- # masks_arr = np.asarray(masks)
99
- # # masks might be (H,W,N) -> transpose
100
- # if masks_arr.ndim == 3 and masks_arr.shape[0] == H and masks_arr.shape[1] == W:
101
- # masks_arr = masks_arr.transpose(2, 0, 1)
102
-
103
- # # create overlay
104
- # overlay = Image.new("RGBA", (W, H), (0,0,0,0))
105
- # draw = ImageDraw.Draw(overlay)
106
-
107
- # kept_confidences = []
108
- # for i in range(masks_arr.shape[0]):
109
- # conf = float(confidences[i]) if confidences is not None and i < len(confidences) else 1.0
110
- # if conf < threshold:
111
- # continue
112
- # mask = masks_arr[i].astype(np.uint8) * 255
113
- # mask_img = Image.fromarray(mask).convert("L").resize((W, H), resample=Image.NEAREST)
114
- # # create colored mask image
115
- # color_layer = Image.new("RGBA", (W,H), mask_color + (0,))
116
- # # put alpha using mask
117
- # color_layer.putalpha(mask_img.point(lambda p: int(p * alpha)))
118
- # overlay = Image.alpha_composite(overlay, color_layer)
119
- # kept_confidences.append(conf)
120
-
121
- # # composite
122
- # annotated = Image.alpha_composite(base, overlay)
123
-
124
- # # add confidence text (show highest kept confidence)
125
- # if len(kept_confidences) > 0:
126
- # best = max(kept_confidences)
127
- # draw = ImageDraw.Draw(annotated)
128
- # try:
129
- # # Try to use a builtin font
130
- # font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(16, W//30))
131
- # except Exception:
132
- # font = ImageFont.load_default()
133
- # text = f"Confidence: {best:.2f}"
134
- # # draw background box for text
135
- # tw, th = draw.textsize(text, font=font)
136
- # pad = 8
137
- # draw.rectangle([6,6, 6+tw+pad, 6+th+pad], fill=(0,0,0,180))
138
- # draw.text((6+4,6+2), text, font=font, fill=(255,255,255,255))
139
- # return annotated.convert("RGB"), kept_confidences
140
-
141
- # @app.route("/predict", methods=["POST"])
142
- # def predict():
143
- # payload = request.get_json(force=True)
144
- # if not payload or "image" not in payload:
145
- # return jsonify({"error": "Missing image"}), 400
146
- # conf = float(payload.get("conf", 0.25))
147
-
148
- # # ensure model ready
149
- # model = init_model()
150
-
151
- # # decode image
152
- # try:
153
- # pil = decode_data_url(payload["image"])
154
- # except Exception as e:
155
- # return jsonify({"error": f"Invalid image: {e}"}), 400
156
-
157
- # # perform prediction (model.predict expects PIL image)
158
- # try:
159
- # detections = model.predict(pil, threshold=0.0) # we filter using conf manually
160
- # except Exception as e:
161
- # return jsonify({"error": f"Inference failure: {e}"}), 500
162
-
163
- # # extract masks and confidences
164
- # masks = getattr(detections, "masks", None)
165
- # confidences = []
166
- # # attempt to read per-instance confidence
167
- # try:
168
- # confidences = [float(x) for x in getattr(detections, "confidence", [])]
169
- # except Exception:
170
- # # fallback: attempt attribute 'scores' or 'scores_' or generate ones
171
- # confidences = []
172
- # try:
173
- # confidences = [float(x) for x in getattr(detections, "scores", [])]
174
- # except Exception:
175
- # confidences = [1.0] * (masks.shape[0] if masks is not None and hasattr(masks, "shape") and masks.shape[0] else 0)
176
-
177
- # # overlay mask with pink-red color
178
- # mask_color = (255, 77, 166) # pinkish
179
- # annotated_pil, kept_conf = overlay_mask_on_image(pil, masks, confidences, threshold=conf, mask_color=mask_color, alpha=0.45)
180
-
181
- # data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
182
- # return jsonify({
183
- # "annotated": data_url,
184
- # "confidences": kept_conf,
185
- # "count": len(kept_conf)
186
- # })
187
-
188
- # if __name__ == "__main__":
189
- # # warm up model on startup (non-blocking)
190
- # try:
191
- # init_model()
192
- # except Exception as e:
193
- # print("Model init warning:", e)
194
- # app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
195
-
196
-
197
-
198
  import os
199
  import io
200
  import base64
201
  import threading
202
  import traceback
 
203
  from typing import Optional
204
 
205
  from flask import Flask, request, jsonify, send_from_directory
@@ -211,13 +15,21 @@ import torch
211
  # Set environment variables for CPU-only operation
212
  os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
213
  os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
 
214
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
215
  os.environ.setdefault("OMP_NUM_THREADS", "4")
216
  os.environ.setdefault("MKL_NUM_THREADS", "4")
217
  os.environ.setdefault("OPENBLAS_NUM_THREADS", "4")
218
 
 
 
 
 
219
  # Limit torch threads
220
- torch.set_num_threads(4)
 
 
 
221
 
222
  import supervision as sv
223
  from rfdetr import RFDETRSegPreview
@@ -238,14 +50,18 @@ def download_file(url: str, dst: str, chunk_size: int = 8192):
238
  print(f"[INFO] Checkpoint already exists at {dst}")
239
  return dst
240
  print(f"[INFO] Downloading weights from {url} -> {dst}")
241
- r = requests.get(url, stream=True, timeout=120)
242
- r.raise_for_status()
243
- with open(dst, "wb") as fh:
244
- for chunk in r.iter_content(chunk_size=chunk_size):
245
- if chunk:
246
- fh.write(chunk)
247
- print("[INFO] Download complete.")
248
- return dst
 
 
 
 
249
 
250
 
251
  def init_model():
@@ -253,28 +69,31 @@ def init_model():
253
  global MODEL
254
  with MODEL_LOCK:
255
  if MODEL is not None:
 
256
  return MODEL
257
  try:
258
  # Ensure checkpoint present
259
- try:
 
260
  download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
261
- except Exception as e:
262
- print("[WARN] Failed to download checkpoint:", e)
263
- if not os.path.exists(CHECKPOINT_PATH):
264
- raise
265
 
266
  print("[INFO] Loading RF-DETR model (CPU mode)...")
267
  MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH)
268
 
269
  # Try to optimize for inference
270
  try:
 
271
  MODEL.optimize_for_inference()
 
272
  except Exception as e:
273
- print("[WARN] optimize_for_inference() skipped/failed:", e)
274
 
275
- print("[INFO] Model ready.")
276
  return MODEL
277
- except Exception:
 
278
  traceback.print_exc()
279
  raise
280
 
@@ -295,7 +114,8 @@ def decode_data_url(data_url: str) -> Image.Image:
295
  def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
296
  """Encode PIL Image to data URL"""
297
  buf = io.BytesIO()
298
- pil_img.save(buf, format=fmt)
 
299
  return "data:image/{};base64,".format(fmt.lower()) + base64.b64encode(buf.getvalue()).decode("ascii")
300
 
301
 
@@ -304,38 +124,53 @@ def annotate_segmentation(image: Image.Image, detections: sv.Detections) -> Imag
304
  Annotate image with segmentation masks using supervision library.
305
  This matches the visualization from rfdetr_seg_infer.py script.
306
  """
307
- # Define color palette
308
- palette = sv.ColorPalette.from_hex([
309
- "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
310
- "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
311
- ])
312
-
313
- # Calculate optimal text scale based on image resolution
314
- text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
315
-
316
- # Create annotators
317
- mask_annotator = sv.MaskAnnotator(color=palette)
318
- polygon_annotator = sv.PolygonAnnotator(color=sv.Color.WHITE)
319
- label_annotator = sv.LabelAnnotator(
320
- color=palette,
321
- text_color=sv.Color.BLACK,
322
- text_scale=text_scale,
323
- text_position=sv.Position.CENTER_OF_MASS
324
- )
325
-
326
- # Create labels with class IDs and confidence scores
327
- labels = [
328
- f"Tulsi {float(conf):.2f}"
329
- for conf in detections.confidence
330
- ]
331
-
332
- # Apply annotations
333
- out = image.copy()
334
- out = mask_annotator.annotate(out, detections)
335
- out = polygon_annotator.annotate(out, detections)
336
- out = label_annotator.annotate(out, detections, labels)
337
-
338
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
 
341
  @app.route("/", methods=["GET"])
@@ -344,7 +179,18 @@ def index():
344
  index_path = os.path.join(app.static_folder or "static", "index.html")
345
  if os.path.exists(index_path):
346
  return send_from_directory(app.static_folder, "index.html")
347
- return jsonify({"message": "RF-DETR Segmentation API is running."})
 
 
 
 
 
 
 
 
 
 
 
348
 
349
 
350
  @app.route("/predict", methods=["POST"])
@@ -356,10 +202,16 @@ def predict():
356
  Returns JSON:
357
  {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}
358
  """
 
 
359
  try:
 
360
  model = init_model()
 
361
  except Exception as e:
362
- return jsonify({"error": f"Model initialization failed: {e}"}), 500
 
 
363
 
364
  # Parse input
365
  img: Optional[Image.Image] = None
@@ -368,10 +220,13 @@ def predict():
368
  # Check if file uploaded
369
  if "file" in request.files:
370
  file = request.files["file"]
 
371
  try:
372
  img = Image.open(file.stream).convert("RGB")
373
  except Exception as e:
374
- return jsonify({"error": f"Invalid uploaded image: {e}"}), 400
 
 
375
  conf_threshold = float(request.form.get("conf", conf_threshold))
376
  else:
377
  # Try JSON payload
@@ -379,31 +234,37 @@ def predict():
379
  if not payload or "image" not in payload:
380
  return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
381
  try:
 
382
  img = decode_data_url(payload["image"])
383
  except Exception as e:
384
- return jsonify({"error": f"Invalid image data: {e}"}), 400
 
 
385
  conf_threshold = float(payload.get("conf", conf_threshold))
386
 
 
 
387
  # Optionally downscale large images to reduce memory usage
388
  MAX_SIZE = 1024
389
  if max(img.size) > MAX_SIZE:
390
  w, h = img.size
391
  scale = MAX_SIZE / float(max(w, h))
392
  new_w, new_h = int(round(w * scale)), int(round(h * scale))
 
393
  img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
394
- print(f"[INFO] Resized image to {new_w}x{new_h}")
395
 
396
  # Run inference with no_grad for memory efficiency
397
  try:
 
398
  with torch.no_grad():
399
  detections = model.predict(img, threshold=conf_threshold)
400
 
401
- print(f"[INFO] Detected {len(detections)} objects")
402
 
403
  # Check if detections exist
404
- if len(detections) == 0:
405
  print("[INFO] No detections above threshold")
406
- # Return original image with message
407
  data_url = encode_pil_to_dataurl(img, fmt="PNG")
408
  return jsonify({
409
  "annotated": data_url,
@@ -411,15 +272,33 @@ def predict():
411
  "count": 0
412
  })
413
 
 
 
 
 
 
 
 
 
 
414
  # Annotate image using supervision library
 
415
  annotated_pil = annotate_segmentation(img, detections)
416
 
417
  # Extract confidence scores
418
  confidences = [float(conf) for conf in detections.confidence]
 
419
 
420
  # Encode to data URL
 
421
  data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
422
 
 
 
 
 
 
 
423
  return jsonify({
424
  "annotated": data_url,
425
  "confidences": confidences,
@@ -427,19 +306,25 @@ def predict():
427
  })
428
 
429
  except Exception as e:
 
 
430
  traceback.print_exc()
431
- return jsonify({"error": f"Inference failed: {e}"}), 500
432
 
433
 
434
  if __name__ == "__main__":
 
 
 
 
435
  # Warm model in background thread
436
  def warm():
437
  try:
438
- print("[INFO] Starting model warmup...")
439
  init_model()
440
- print("[INFO] Model warmup complete")
441
  except Exception as e:
442
- print(f"[ERROR] Model warmup failed: {e}")
443
  traceback.print_exc()
444
 
445
  threading.Thread(target=warm, daemon=True).start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import io
3
  import base64
4
  import threading
5
  import traceback
6
+ import gc
7
  from typing import Optional
8
 
9
  from flask import Flask, request, jsonify, send_from_directory
 
15
  # Set environment variables for CPU-only operation
16
  os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
17
  os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
18
+ os.environ.setdefault("FONTCONFIG_FILE", "/etc/fonts/fonts.conf")
19
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
20
  os.environ.setdefault("OMP_NUM_THREADS", "4")
21
  os.environ.setdefault("MKL_NUM_THREADS", "4")
22
  os.environ.setdefault("OPENBLAS_NUM_THREADS", "4")
23
 
24
+ # Create writable fontconfig cache
25
+ os.makedirs("/tmp/.fontconfig", exist_ok=True)
26
+ os.makedirs("/tmp/.matplotlib", exist_ok=True)
27
+
28
  # Limit torch threads
29
+ try:
30
+ torch.set_num_threads(4)
31
+ except Exception:
32
+ pass
33
 
34
  import supervision as sv
35
  from rfdetr import RFDETRSegPreview
 
50
  print(f"[INFO] Checkpoint already exists at {dst}")
51
  return dst
52
  print(f"[INFO] Downloading weights from {url} -> {dst}")
53
+ try:
54
+ r = requests.get(url, stream=True, timeout=180)
55
+ r.raise_for_status()
56
+ with open(dst, "wb") as fh:
57
+ for chunk in r.iter_content(chunk_size=chunk_size):
58
+ if chunk:
59
+ fh.write(chunk)
60
+ print("[INFO] Download complete.")
61
+ return dst
62
+ except Exception as e:
63
+ print(f"[ERROR] Download failed: {e}")
64
+ raise
65
 
66
 
67
  def init_model():
 
69
  global MODEL
70
  with MODEL_LOCK:
71
  if MODEL is not None:
72
+ print("[INFO] Model already loaded, returning cached instance")
73
  return MODEL
74
  try:
75
  # Ensure checkpoint present
76
+ if not os.path.exists(CHECKPOINT_PATH):
77
+ print("[INFO] Checkpoint not found, downloading...")
78
  download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
79
+ else:
80
+ print(f"[INFO] Using existing checkpoint at {CHECKPOINT_PATH}")
 
 
81
 
82
  print("[INFO] Loading RF-DETR model (CPU mode)...")
83
  MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH)
84
 
85
  # Try to optimize for inference
86
  try:
87
+ print("[INFO] Optimizing model for inference...")
88
  MODEL.optimize_for_inference()
89
+ print("[INFO] Model optimization complete")
90
  except Exception as e:
91
+ print(f"[WARN] optimize_for_inference() skipped/failed: {e}")
92
 
93
+ print("[INFO] Model ready for inference")
94
  return MODEL
95
+ except Exception as e:
96
+ print(f"[ERROR] Model initialization failed: {e}")
97
  traceback.print_exc()
98
  raise
99
 
 
114
  def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
115
  """Encode PIL Image to data URL"""
116
  buf = io.BytesIO()
117
+ pil_img.save(buf, format=fmt, optimize=False)
118
+ buf.seek(0)
119
  return "data:image/{};base64,".format(fmt.lower()) + base64.b64encode(buf.getvalue()).decode("ascii")
120
 
121
 
 
124
  Annotate image with segmentation masks using supervision library.
125
  This matches the visualization from rfdetr_seg_infer.py script.
126
  """
127
+ try:
128
+ # Define color palette
129
+ palette = sv.ColorPalette.from_hex([
130
+ "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
131
+ "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
132
+ ])
133
+
134
+ # Calculate optimal text scale based on image resolution
135
+ text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
136
+
137
+ print(f"[INFO] Creating annotators with text_scale={text_scale}")
138
+
139
+ # Create annotators
140
+ mask_annotator = sv.MaskAnnotator(color=palette)
141
+ polygon_annotator = sv.PolygonAnnotator(color=sv.Color.WHITE)
142
+ label_annotator = sv.LabelAnnotator(
143
+ color=palette,
144
+ text_color=sv.Color.BLACK,
145
+ text_scale=text_scale,
146
+ text_position=sv.Position.CENTER_OF_MASS
147
+ )
148
+
149
+ # Create labels with confidence scores
150
+ labels = [
151
+ f"Tulsi {float(conf):.2f}"
152
+ for conf in detections.confidence
153
+ ]
154
+
155
+ print(f"[INFO] Annotating {len(labels)} detections")
156
+
157
+ # Apply annotations step by step
158
+ out = image.copy()
159
+ print("[INFO] Applying mask annotation...")
160
+ out = mask_annotator.annotate(out, detections)
161
+ print("[INFO] Applying polygon annotation...")
162
+ out = polygon_annotator.annotate(out, detections)
163
+ print("[INFO] Applying label annotation...")
164
+ out = label_annotator.annotate(out, detections, labels)
165
+
166
+ print("[INFO] Annotation complete")
167
+ return out
168
+
169
+ except Exception as e:
170
+ print(f"[ERROR] Annotation failed: {e}")
171
+ traceback.print_exc()
172
+ # Return original image if annotation fails
173
+ return image
174
 
175
 
176
  @app.route("/", methods=["GET"])
 
179
  index_path = os.path.join(app.static_folder or "static", "index.html")
180
  if os.path.exists(index_path):
181
  return send_from_directory(app.static_folder, "index.html")
182
+ return jsonify({"message": "RF-DETR Segmentation API is running.", "status": "ready"})
183
+
184
+
185
+ @app.route("/health", methods=["GET"])
186
+ def health():
187
+ """Health check endpoint"""
188
+ model_loaded = MODEL is not None
189
+ return jsonify({
190
+ "status": "healthy",
191
+ "model_loaded": model_loaded,
192
+ "checkpoint_exists": os.path.exists(CHECKPOINT_PATH)
193
+ })
194
 
195
 
196
  @app.route("/predict", methods=["POST"])
 
202
  Returns JSON:
203
  {"annotated": "<data:image/png;base64,...>", "confidences": [..], "count": N}
204
  """
205
+ print("\n[INFO] ========== New prediction request ==========")
206
+
207
  try:
208
+ print("[INFO] Initializing model...")
209
  model = init_model()
210
+ print("[INFO] Model ready")
211
  except Exception as e:
212
+ error_msg = f"Model initialization failed: {e}"
213
+ print(f"[ERROR] {error_msg}")
214
+ return jsonify({"error": error_msg}), 500
215
 
216
  # Parse input
217
  img: Optional[Image.Image] = None
 
220
  # Check if file uploaded
221
  if "file" in request.files:
222
  file = request.files["file"]
223
+ print(f"[INFO] Processing uploaded file: {file.filename}")
224
  try:
225
  img = Image.open(file.stream).convert("RGB")
226
  except Exception as e:
227
+ error_msg = f"Invalid uploaded image: {e}"
228
+ print(f"[ERROR] {error_msg}")
229
+ return jsonify({"error": error_msg}), 400
230
  conf_threshold = float(request.form.get("conf", conf_threshold))
231
  else:
232
  # Try JSON payload
 
234
  if not payload or "image" not in payload:
235
  return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
236
  try:
237
+ print("[INFO] Decoding image from data URL...")
238
  img = decode_data_url(payload["image"])
239
  except Exception as e:
240
+ error_msg = f"Invalid image data: {e}"
241
+ print(f"[ERROR] {error_msg}")
242
+ return jsonify({"error": error_msg}), 400
243
  conf_threshold = float(payload.get("conf", conf_threshold))
244
 
245
+ print(f"[INFO] Image size: {img.size}, Confidence threshold: {conf_threshold}")
246
+
247
  # Optionally downscale large images to reduce memory usage
248
  MAX_SIZE = 1024
249
  if max(img.size) > MAX_SIZE:
250
  w, h = img.size
251
  scale = MAX_SIZE / float(max(w, h))
252
  new_w, new_h = int(round(w * scale)), int(round(h * scale))
253
+ print(f"[INFO] Resizing image from {w}x{h} to {new_w}x{new_h}")
254
  img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
 
255
 
256
  # Run inference with no_grad for memory efficiency
257
  try:
258
+ print("[INFO] Running inference...")
259
  with torch.no_grad():
260
  detections = model.predict(img, threshold=conf_threshold)
261
 
262
+ print(f"[INFO] Raw detections: {len(detections)} objects")
263
 
264
  # Check if detections exist
265
+ if len(detections) == 0 or not hasattr(detections, 'confidence') or len(detections.confidence) == 0:
266
  print("[INFO] No detections above threshold")
267
+ # Return original image
268
  data_url = encode_pil_to_dataurl(img, fmt="PNG")
269
  return jsonify({
270
  "annotated": data_url,
 
272
  "count": 0
273
  })
274
 
275
+ print(f"[INFO] Detections have {len(detections.confidence)} confidence scores")
276
+ print(f"[INFO] Confidence range: {min(detections.confidence):.3f} - {max(detections.confidence):.3f}")
277
+
278
+ # Check if masks exist
279
+ if hasattr(detections, 'masks') and detections.masks is not None:
280
+ print(f"[INFO] Masks present: shape={np.array(detections.masks).shape if hasattr(detections.masks, '__len__') else 'unknown'}")
281
+ else:
282
+ print("[WARN] No masks found in detections!")
283
+
284
  # Annotate image using supervision library
285
+ print("[INFO] Starting annotation...")
286
  annotated_pil = annotate_segmentation(img, detections)
287
 
288
  # Extract confidence scores
289
  confidences = [float(conf) for conf in detections.confidence]
290
+ print(f"[INFO] Final confidences: {confidences}")
291
 
292
  # Encode to data URL
293
+ print("[INFO] Encoding annotated image...")
294
  data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
295
 
296
+ # Clean up
297
+ del detections
298
+ gc.collect()
299
+
300
+ print(f"[INFO] ========== Prediction complete: {len(confidences)} leaves detected ==========\n")
301
+
302
  return jsonify({
303
  "annotated": data_url,
304
  "confidences": confidences,
 
306
  })
307
 
308
  except Exception as e:
309
+ error_msg = f"Inference failed: {e}"
310
+ print(f"[ERROR] {error_msg}")
311
  traceback.print_exc()
312
+ return jsonify({"error": error_msg}), 500
313
 
314
 
315
  if __name__ == "__main__":
316
+ print("\n" + "="*60)
317
+ print("Starting Tulsi Leaf Segmentation Server")
318
+ print("="*60 + "\n")
319
+
320
  # Warm model in background thread
321
  def warm():
322
  try:
323
+ print("[INFO] Starting model warmup in background...")
324
  init_model()
325
+ print("[INFO] Model warmup complete - ready for predictions")
326
  except Exception as e:
327
+ print(f"[ERROR] Model warmup failed: {e}")
328
  traceback.print_exc()
329
 
330
  threading.Thread(target=warm, daemon=True).start()