SAM Brain Tumor Segmentation Model

This model is a fine-tuned Segment Anything Model (SAM) for brain tumor segmentation from medical imaging data. It was trained using a simulated dataset of 2D slices derived from 3D NIfTI (.nii.gz) images and their corresponding segmentation masks.

Model Description

The original SAM model is a powerful general-purpose image segmentation model. This fine-tuned version specializes in identifying brain tumors, leveraging the prompt-based segmentation capabilities of SAM. The model is prompted with bounding boxes around the tumor regions (derived from ground truth masks during training) to generate precise segmentation masks.

Training Details

  • Base Model: facebook/sam-vit-base
  • Dataset: Simulated 2D axial slices from 3D NIfTI images, with intensities normalized to the 0-1 range.
  • Image Preprocessing: Grayscale slices were duplicated across 3 channels to match SAM's expected RGB input; bounding box prompts were generated from the ground truth masks (see the preprocessing sketch after this list).
  • Loss Functions: Binary Cross-Entropy (BCE) Loss combined with Dice Loss (see the loss sketch after this list).
  • Optimizer: AdamW with a learning rate of 1e-5.
  • Epochs: 5
  • Average Dice Score on Validation Set: 0.9756 (on simulated data)
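
The preprocessing described above can be sketched roughly as follows. This is a minimal illustration, assuming nibabel is used to read the NIfTI files; the file names, the mid-volume slice choice, and the min-max normalization are placeholders, not the exact training code.

import nibabel as nib
import numpy as np

# Load a 3D volume and its segmentation mask (placeholder paths)
volume = nib.load("scan.nii.gz").get_fdata()
mask_volume = nib.load("mask.nii.gz").get_fdata()

# Take one axial slice and normalize intensities to the 0-1 range
z = volume.shape[2] // 2
slice_2d = volume[:, :, z].astype(np.float32)
slice_2d = (slice_2d - slice_2d.min()) / (slice_2d.max() - slice_2d.min() + 1e-8)

# Duplicate the grayscale slice across 3 channels for SAM's RGB input
image = np.stack([slice_2d] * 3, axis=-1)

# Derive a bounding box prompt (x_min, y_min, x_max, y_max) from the mask;
# assumes the slice actually contains tumor pixels
mask_2d = mask_volume[:, :, z] > 0
ys, xs = np.where(mask_2d)
box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]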
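
Similarly, the combined objective can be sketched as below. The exact weighting between the two terms and the precise Dice formulation used in training are not specified, so this equal-weight, soft-Dice version is an assumption.

import torch
import torch.nn.functional as F

def bce_dice_loss(logits, targets, eps=1e-6):
    # BCE computed on raw logits
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    # Soft Dice computed on sigmoid probabilities
    probs = torch.sigmoid(logits)
    intersection = (probs * targets).sum()
    dice = (2 * intersection + eps) / (probs.sum() + targets.sum() + eps)
    # Equal weighting of the two terms is an assumption
    return bce + (1 - dice)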

Usage

To use this model for inference, you can load it with the transformers library and provide an image along with a bounding box prompt for the region of interest. The model will then predict a segmentation mask.

from transformers import SamModel, SamProcessor
from PIL import Image
import torch
import numpy as np

# Load the fine-tuned model and processor
processor = SamProcessor.from_pretrained("Lorenzob/sam-brain-tumor-segmentation")
model = SamModel.from_pretrained("Lorenzob/sam-brain-tumor-segmentation")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Example: create a dummy image (replace with your actual medical image).
# Real slices should be 2D grayscale, scaled to 0-255 uint8 for PIL, then
# duplicated to 3 channels (here via PIL's "RGB" conversion).
image_size = 256  # example size
dummy_image_data = np.random.rand(image_size, image_size) * 255
dummy_image = Image.fromarray(dummy_image_data.astype(np.uint8)).convert("RGB")

# Example: define a bounding box for the tumor region (x_min, y_min, x_max, y_max).
# In a real scenario, this would come from an expert or a detection model.
# The processor expects boxes nested per image: [image][box][coords].
input_boxes = [[[100, 100, 200, 200]]]  # one box for one image

# Preprocess the image and bounding box
inputs = processor(dummy_image, input_boxes=input_boxes, return_tensors="pt").to(device)

# Perform inference
with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

# Post-process the predicted masks back to the original image resolution
masks = processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)

# `masks` is a list with one boolean tensor per input image, shaped
# (num_boxes, num_masks, H, W). With a single box and multimask_output=False,
# squeezing yields the (H, W) mask.
predicted_mask = masks[0].squeeze().numpy()

print("Predicted mask shape:", predicted_mask.shape)
# You can visualize 'predicted_mask' using matplotlib or other image libraries.
# For example:
# import matplotlib.pyplot as plt
# plt.imshow(predicted_mask, cmap='gray')
# plt.title('Predicted Segmentation Mask')
# plt.show()

Inference Endpoint Configuration (Optional)

If you wish to deploy this model as an Inference Endpoint on Hugging Face, here's a sample configuration you might use in your README.md (or directly in the UI):

widget:
- src: "app.py"
  example_title: "Brain Tumor Segmentation Example"
  inputs:
  - filename: "image.png"
    image: https://huggingface.co/datasets/huggingface/sample-images/resolve/main/segmentation_image_input.png
    input_boxes: [[100, 100, 200, 200]]

--- # Optional section for specific endpoint settings

parameters:
  do_normalize: false # Assuming inputs are already normalized 0-1
  do_rescale: false   # Assuming inputs are already scaled correctly
  multimask_output: false # For single best mask output

# Example of specific hardware/software config for advanced users
# inference:
#   accelerator: cuda
#   container: pytorch_latest
#   hardware: gpu_small
#   task: image-segmentation

Note: The example image and input_boxes in the YAML configuration are placeholders. For a real medical image endpoint, you would provide a relevant example image and a bounding box corresponding to a tumor within that image.

Limitations

  • The model was fine-tuned on a simulated dataset. Its performance on real, diverse clinical data may vary and needs further rigorous validation.
  • The model relies on a bounding box prompt. Its accuracy is highly dependent on the quality and precision of the provided bounding box.
  • Currently, the model handles only 2D slices; adapting it to full 3D volume segmentation would require further development (a naive per-slice sketch follows this list).
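
A naive way to cover a 3D volume with the current 2D model is to run it slice by slice and stack the results. This is only a sketch: `predict_slice` is a hypothetical wrapper around the Usage pipeline above, and per-slice boxes must be supplied by the user.

import numpy as np

def segment_volume(volume, boxes_per_slice, predict_slice):
    # volume: (H, W, D) array; boxes_per_slice: one [x_min, y_min, x_max, y_max] per slice
    # predict_slice(image_2d, box) -> (H, W) binary mask, wrapping the inference code above
    masks = [predict_slice(volume[:, :, z], boxes_per_slice[z]) for z in range(volume.shape[2])]
    return np.stack(masks, axis=2)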

Future Work

  • Evaluate and fine-tune the model on large, real-world medical imaging datasets (e.g., BraTS, TCIA).
  • Explore methods for automatic bounding box generation for tumor regions.
  • Extend the model to handle 3D medical images directly.
  • Implement quantitative metrics (e.g., IoU, Hausdorff Distance) during evaluation on real data (simple metric sketches follow this list).
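
For reference, IoU and Dice on binary masks can be computed as below; this is a standard formulation, not code shipped with the model.

import numpy as np

def iou_score(pred, gt):
    pred, gt = pred.astype(bool), gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    return np.logical_and(pred, gt).sum() / union if union else 1.0

def dice_score(pred, gt):
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    return 2 * np.logical_and(pred, gt).sum() / denom if denom else 1.0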