AItool commited on
Commit
20cd6c7
·
verified ·
1 Parent(s): b6780bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -14
app.py CHANGED
@@ -33,12 +33,37 @@ def soft_alpha(mask_uint8, blur_radius=10):
33
  blurred = cv2.GaussianBlur(mask_uint8, (0,0), sigmaX=blur_radius, sigmaY=blur_radius)
34
  return (blurred.astype(np.float32) / 255.0).clip(0.0, 1.0)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def isolate_with_click(image: Image.Image, evt: gr.SelectData):
 
 
 
 
37
  img_rgb = np.array(image.convert("RGB"))
38
  predictor.set_image(img_rgb)
39
 
40
- input_point = np.array([[evt.index[0], evt.index[1]]])
41
- input_label = np.array([1]) # foreground
 
 
42
 
43
  masks, scores, _ = predictor.predict(
44
  point_coords=input_point,
@@ -46,12 +71,28 @@ def isolate_with_click(image: Image.Image, evt: gr.SelectData):
46
  multimask_output=True
47
  )
48
 
49
- best_mask = masks[np.argmax(scores)].astype(np.uint8) * 255
 
 
 
 
 
 
 
 
 
 
 
50
  alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)
51
 
 
52
  ys, xs = np.where(best_mask == 255)
53
  if len(xs) == 0 or len(ys) == 0:
54
- return None, None
 
 
 
 
55
  x0, x1 = xs.min(), xs.max()
56
  y0, y1 = ys.min(), ys.max()
57
  pad = int(max(img_rgb.shape[:2]) * 0.02)
@@ -61,30 +102,28 @@ def isolate_with_click(image: Image.Image, evt: gr.SelectData):
61
  fg_rgb = img_rgb[y0:y1+1, x0:x1+1]
62
  fg_alpha = alpha[y0:y1+1, x0:x1+1]
63
 
64
- rgba = np.dstack((fg_rgb, (fg_alpha * 255).astype(np.uint8)))
 
65
  cutout = Image.fromarray(rgba)
66
 
67
- # Build overlay preview (purple tint on mask)
68
- overlay = img_rgb.copy()
69
- tint = np.array([180, 0, 180], dtype=np.uint8) # purple
70
- sel = best_mask == 255
71
- overlay[sel] = (0.6 * overlay[sel] + 0.4 * tint).astype(np.uint8)
72
- overlay_img = Image.fromarray(overlay)
73
 
74
  return cutout, overlay_img
75
 
76
  # --- Gradio UI ---
77
  with gr.Blocks() as demo:
78
- gr.Markdown("### SAM Object Isolation\nUpload an image, then click on the object to isolate it.")
79
  inp = gr.Image(type="pil", label="Upload image", interactive=True)
80
  out_cutout = gr.Image(type="pil", label="Isolated cutout (RGBA)")
81
  out_overlay = gr.Image(type="pil", label="Segmentation overlay preview")
82
 
 
83
  inp.select(isolate_with_click, inputs=[inp], outputs=[out_cutout, out_overlay])
84
 
85
- # Demo example at the bottom
86
  gr.Examples(
87
- examples=["demo.png"], # make sure demo.png is in your repo
88
  inputs=inp,
89
  label="Try with demo image"
90
  )
 
33
  blurred = cv2.GaussianBlur(mask_uint8, (0,0), sigmaX=blur_radius, sigmaY=blur_radius)
34
  return (blurred.astype(np.float32) / 255.0).clip(0.0, 1.0)
35
 
36
+ def make_overlay(img_rgb: np.ndarray, mask_uint8: np.ndarray) -> Image.Image:
37
+ """Return a purple-tinted overlay on the selected region."""
38
+ # Ensure 0/255 mask and uint8 image
39
+ mask = (mask_uint8 > 127).astype(np.uint8)
40
+ overlay = img_rgb.copy()
41
+
42
+ # Create a purple layer
43
+ purple = np.zeros_like(img_rgb, dtype=np.uint8)
44
+ purple[..., 0] = 180 # R
45
+ purple[..., 1] = 0 # G
46
+ purple[..., 2] = 180 # B
47
+
48
+ # Blend only on selected pixels
49
+ sel = mask.astype(bool)
50
+ blended = cv2.addWeighted(overlay[sel], 0.6, purple[sel], 0.4, 0)
51
+ overlay[sel] = blended
52
+
53
+ return Image.fromarray(overlay)
54
+
55
  def isolate_with_click(image: Image.Image, evt: gr.SelectData):
56
+ # Guard: if no image or no click event, return original image and a subtle info overlay
57
+ if image is None or evt is None:
58
+ return None, None
59
+
60
  img_rgb = np.array(image.convert("RGB"))
61
  predictor.set_image(img_rgb)
62
 
63
+ # SAM expects input points as numpy array [[x,y]]
64
+ x, y = evt.index # (x, y) from Gradio click
65
+ input_point = np.array([[x, y]], dtype=np.float32)
66
+ input_label = np.array([1], dtype=np.int32) # 1 = foreground
67
 
68
  masks, scores, _ = predictor.predict(
69
  point_coords=input_point,
 
71
  multimask_output=True
72
  )
73
 
74
+ # If SAM didn't return masks, show original plus a faint marker
75
+ if masks is None or len(masks) == 0:
76
+ # Create a simple marker overlay to indicate click
77
+ overlay = img_rgb.copy()
78
+ cv2.circle(overlay, (int(x), int(y)), 12, (180, 0, 180), thickness=3)
79
+ return None, Image.fromarray(overlay)
80
+
81
+ # Pick the highest score mask
82
+ best_idx = int(np.argmax(scores))
83
+ best_mask = masks[best_idx].astype(np.uint8) * 255
84
+
85
+ # Soft alpha for RGBA cutout
86
  alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)
87
 
88
+ # Crop to bounding box (with small pad)
89
  ys, xs = np.where(best_mask == 255)
90
  if len(xs) == 0 or len(ys) == 0:
91
+ # If mask is empty, still return overlay with click marker
92
+ overlay = img_rgb.copy()
93
+ cv2.circle(overlay, (int(x), int(y)), 12, (180, 0, 180), thickness=3)
94
+ return None, Image.fromarray(overlay)
95
+
96
  x0, x1 = xs.min(), xs.max()
97
  y0, y1 = ys.min(), ys.max()
98
  pad = int(max(img_rgb.shape[:2]) * 0.02)
 
102
  fg_rgb = img_rgb[y0:y1+1, x0:x1+1]
103
  fg_alpha = alpha[y0:y1+1, x0:x1+1]
104
 
105
+ # Compose RGBA correctly once
106
+ rgba = np.dstack((fg_rgb, (fg_alpha * 255.0).astype(np.uint8)))
107
  cutout = Image.fromarray(rgba)
108
 
109
+ # Build and return the purple overlay on the original image
110
+ overlay_img = make_overlay(img_rgb, best_mask)
 
 
 
 
111
 
112
  return cutout, overlay_img
113
 
114
  # --- Gradio UI ---
115
  with gr.Blocks() as demo:
116
+ gr.Markdown("### SAM Object Isolation\nUpload an image (or pick the demo), then click on the object to isolate it. The right panel shows a purple overlay of the mask.")
117
  inp = gr.Image(type="pil", label="Upload image", interactive=True)
118
  out_cutout = gr.Image(type="pil", label="Isolated cutout (RGBA)")
119
  out_overlay = gr.Image(type="pil", label="Segmentation overlay preview")
120
 
121
+ # Click-to-segment
122
  inp.select(isolate_with_click, inputs=[inp], outputs=[out_cutout, out_overlay])
123
 
124
+ # Demo example at the bottom that populates the upload image
125
  gr.Examples(
126
+ examples=["demo.png"], # ensure demo.png is in your repo
127
  inputs=inp,
128
  label="Try with demo image"
129
  )