AItool commited on
Commit
986380a
·
verified ·
1 Parent(s): 982580e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -18,11 +18,13 @@ import torch
18
  from segment_anything import sam_model_registry, SamPredictor
19
 
20
  # --- CONFIG ---
 
21
  SAM_MODEL_TYPE = "vit_h"
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
  BLUR_RADIUS = 10
24
  # --------------
25
 
 
26
  sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
27
  sam.to(device=DEVICE)
28
  predictor = SamPredictor(sam)
@@ -32,21 +34,30 @@ def soft_alpha(mask_uint8, blur_radius=10):
32
  return (blurred.astype(np.float32) / 255.0).clip(0.0, 1.0)
33
 
34
  def isolate_with_click(image: Image.Image, evt: gr.SelectData):
 
 
 
 
35
  img_rgb = np.array(image.convert("RGB"))
36
  predictor.set_image(img_rgb)
37
 
 
38
  input_point = np.array([[evt.index[0], evt.index[1]]])
39
- input_label = np.array([1])
40
 
41
- masks, scores, _ = predictor.predict(
42
  point_coords=input_point,
43
  point_labels=input_label,
44
  multimask_output=True
45
  )
46
 
 
47
  best_mask = masks[np.argmax(scores)].astype(np.uint8) * 255
 
 
48
  alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)
49
 
 
50
  ys, xs = np.where(best_mask == 255)
51
  if len(xs) == 0 or len(ys) == 0:
52
  return None
@@ -65,17 +76,15 @@ def isolate_with_click(image: Image.Image, evt: gr.SelectData):
65
  # --- Gradio UI ---
66
  with gr.Blocks() as demo:
67
  gr.Markdown("### SAM Object Isolation\nUpload an image, then click on the object to isolate it.")
68
-
69
  inp = gr.Image(type="pil", label="Upload image", interactive=True)
70
  out = gr.Image(type="pil", label="Isolated cutout (RGBA)")
71
-
72
  inp.select(isolate_with_click, inputs=[inp], outputs=out)
73
 
74
- # Examples section at the bottom
75
  gr.Examples(
76
- examples=["demo.png"],
77
  inputs=inp,
78
  label="Try with demo image"
79
  )
80
 
81
- demo.launch()
 
18
  from segment_anything import sam_model_registry, SamPredictor
19
 
20
  # --- CONFIG ---
21
+ SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
22
  SAM_MODEL_TYPE = "vit_h"
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
  BLUR_RADIUS = 10
25
  # --------------
26
 
27
+ # Load SAM once
28
  sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
29
  sam.to(device=DEVICE)
30
  predictor = SamPredictor(sam)
 
34
  return (blurred.astype(np.float32) / 255.0).clip(0.0, 1.0)
35
 
36
  def isolate_with_click(image: Image.Image, evt: gr.SelectData):
37
+ """
38
+ image: PIL image uploaded
39
+ evt: click coordinates from Gradio (x,y)
40
+ """
41
  img_rgb = np.array(image.convert("RGB"))
42
  predictor.set_image(img_rgb)
43
 
44
+ # SAM expects input points as numpy array [[x,y]]
45
  input_point = np.array([[evt.index[0], evt.index[1]]])
46
+ input_label = np.array([1]) # 1 = foreground
47
 
48
+ masks, scores, logits = predictor.predict(
49
  point_coords=input_point,
50
  point_labels=input_label,
51
  multimask_output=True
52
  )
53
 
54
+ # Pick the highest score mask
55
  best_mask = masks[np.argmax(scores)].astype(np.uint8) * 255
56
+
57
+ # Soft alpha
58
  alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)
59
 
60
+ # Crop to bounding box
61
  ys, xs = np.where(best_mask == 255)
62
  if len(xs) == 0 or len(ys) == 0:
63
  return None
 
76
  # --- Gradio UI ---
77
  with gr.Blocks() as demo:
78
  gr.Markdown("### SAM Object Isolation\nUpload an image, then click on the object to isolate it.")
 
79
  inp = gr.Image(type="pil", label="Upload image", interactive=True)
80
  out = gr.Image(type="pil", label="Isolated cutout (RGBA)")
 
81
  inp.select(isolate_with_click, inputs=[inp], outputs=out)
82
 
83
+ # Add demo example at the bottom
84
  gr.Examples(
85
+ examples=["demo.png"], # make sure demo.png is in your repo
86
  inputs=inp,
87
  label="Try with demo image"
88
  )
89
 
90
+ demo.launch(share=True)