AItool commited on
Commit
b6780bc
·
verified ·
1 Parent(s): 986380a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -34,33 +34,24 @@ def soft_alpha(mask_uint8, blur_radius=10):
34
  return (blurred.astype(np.float32) / 255.0).clip(0.0, 1.0)
35
 
36
  def isolate_with_click(image: Image.Image, evt: gr.SelectData):
37
- """
38
- image: PIL image uploaded
39
- evt: click coordinates from Gradio (x,y)
40
- """
41
  img_rgb = np.array(image.convert("RGB"))
42
  predictor.set_image(img_rgb)
43
 
44
- # SAM expects input points as numpy array [[x,y]]
45
  input_point = np.array([[evt.index[0], evt.index[1]]])
46
- input_label = np.array([1]) # 1 = foreground
47
 
48
- masks, scores, logits = predictor.predict(
49
  point_coords=input_point,
50
  point_labels=input_label,
51
  multimask_output=True
52
  )
53
 
54
- # Pick the highest score mask
55
  best_mask = masks[np.argmax(scores)].astype(np.uint8) * 255
56
-
57
- # Soft alpha
58
  alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)
59
 
60
- # Crop to bounding box
61
  ys, xs = np.where(best_mask == 255)
62
  if len(xs) == 0 or len(ys) == 0:
63
- return None
64
  x0, x1 = xs.min(), xs.max()
65
  y0, y1 = ys.min(), ys.max()
66
  pad = int(max(img_rgb.shape[:2]) * 0.02)
@@ -71,16 +62,27 @@ def isolate_with_click(image: Image.Image, evt: gr.SelectData):
71
  fg_alpha = alpha[y0:y1+1, x0:x1+1]
72
 
73
  rgba = np.dstack((fg_rgb, (fg_alpha * 255).astype(np.uint8)))
74
- return Image.fromarray(rgba)
 
 
 
 
 
 
 
 
 
75
 
76
  # --- Gradio UI ---
77
  with gr.Blocks() as demo:
78
  gr.Markdown("### SAM Object Isolation\nUpload an image, then click on the object to isolate it.")
79
  inp = gr.Image(type="pil", label="Upload image", interactive=True)
80
- out = gr.Image(type="pil", label="Isolated cutout (RGBA)")
81
- inp.select(isolate_with_click, inputs=[inp], outputs=out)
82
 
83
- # Add demo example at the bottom
 
 
84
  gr.Examples(
85
  examples=["demo.png"], # make sure demo.png is in your repo
86
  inputs=inp,
@@ -88,3 +90,4 @@ with gr.Blocks() as demo:
88
  )
89
 
90
  demo.launch(share=True)
 
 
34
  return (blurred.astype(np.float32) / 255.0).clip(0.0, 1.0)
35
 
36
  def isolate_with_click(image: Image.Image, evt: gr.SelectData):
 
 
 
 
37
  img_rgb = np.array(image.convert("RGB"))
38
  predictor.set_image(img_rgb)
39
 
 
40
  input_point = np.array([[evt.index[0], evt.index[1]]])
41
+ input_label = np.array([1]) # foreground
42
 
43
+ masks, scores, _ = predictor.predict(
44
  point_coords=input_point,
45
  point_labels=input_label,
46
  multimask_output=True
47
  )
48
 
 
49
  best_mask = masks[np.argmax(scores)].astype(np.uint8) * 255
 
 
50
  alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)
51
 
 
52
  ys, xs = np.where(best_mask == 255)
53
  if len(xs) == 0 or len(ys) == 0:
54
+ return None, None
55
  x0, x1 = xs.min(), xs.max()
56
  y0, y1 = ys.min(), ys.max()
57
  pad = int(max(img_rgb.shape[:2]) * 0.02)
 
62
  fg_alpha = alpha[y0:y1+1, x0:x1+1]
63
 
64
  rgba = np.dstack((fg_rgb, (fg_alpha * 255).astype(np.uint8)))
65
+ cutout = Image.fromarray(rgba)
66
+
67
+ # Build overlay preview (purple tint on mask)
68
+ overlay = img_rgb.copy()
69
+ tint = np.array([180, 0, 180], dtype=np.uint8) # purple
70
+ sel = best_mask == 255
71
+ overlay[sel] = (0.6 * overlay[sel] + 0.4 * tint).astype(np.uint8)
72
+ overlay_img = Image.fromarray(overlay)
73
+
74
+ return cutout, overlay_img
75
 
76
  # --- Gradio UI ---
77
  with gr.Blocks() as demo:
78
  gr.Markdown("### SAM Object Isolation\nUpload an image, then click on the object to isolate it.")
79
  inp = gr.Image(type="pil", label="Upload image", interactive=True)
80
+ out_cutout = gr.Image(type="pil", label="Isolated cutout (RGBA)")
81
+ out_overlay = gr.Image(type="pil", label="Segmentation overlay preview")
82
 
83
+ inp.select(isolate_with_click, inputs=[inp], outputs=[out_cutout, out_overlay])
84
+
85
+ # Demo example at the bottom
86
  gr.Examples(
87
  examples=["demo.png"], # make sure demo.png is in your repo
88
  inputs=inp,
 
90
  )
91
 
92
  demo.launch(share=True)
93
+