Model Description
This model is a deep learning-based MRI anatomy classification system built on the ResNet18 architecture and trained on MRI data from InfoBay.AI.
The training pipeline processes MRI images from multiple anatomical regions, applies preprocessing and normalization, and trains a convolutional neural network to classify each scan into one of five anatomical categories.
This approach illustrates how structured image data can be used to build lightweight medical imaging classifiers for research and automation tasks.
Training Pipeline
The complete pipeline used for training is as follows:
Raw MRI Images → Image Preprocessing → Normalization → Dataset Labeling → ResNet18 Training
- Data Source: MRI datasets (Brain, Cervical Spine, Whole Spine, Pelvic, Shoulder)
- Preprocessing: Grayscale conversion, resizing (224×224)
- Normalization: Pixel scaling to [-1, 1]
- Labeling: Manual annotation by anatomical region
- Model Training: ResNet18 for multi-class classification
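The grayscale conversion and normalization steps listed above can be sketched in NumPy as follows. This is an illustration, not InfoBay.AI's training code. The function names are hypothetical, the grayscale weights are the ITU-R 601-2 luma coefficients that Pillow uses for `convert("L")`, and resizing to 224×224 is assumed to be handled by Pillow, as in the Usage code.

```python
import numpy as np


def to_grayscale(rgb: np.ndarray) -> np.ndarray:
    """Convert an (H, W, 3) uint8 RGB array to an (H, W) uint8 grayscale array.

    Uses the ITU-R 601-2 luma weights (the same coefficients Pillow applies
    for Image.convert("L")), so results should closely match Pillow.
    """
    weights = np.array([0.299, 0.587, 0.114])
    gray = rgb[..., :3].astype(np.float64) @ weights
    return gray.round().clip(0, 255).astype(np.uint8)


def normalize(gray: np.ndarray) -> np.ndarray:
    """Scale uint8 pixels from [0, 255] to [-1, 1] via (x / 255 - 0.5) / 0.5."""
    return ((gray.astype(np.float32) / 255.0) - 0.5) / 0.5


def to_model_input(gray: np.ndarray) -> np.ndarray:
    """Normalize and add batch and channel axes: (H, W) -> (1, 1, H, W)."""
    return normalize(gray)[np.newaxis, np.newaxis, :, :]
```

Scaling to [-1, 1] is equivalent to `Normalize(mean=0.5, std=0.5)` applied after dividing by 255, which is why the Usage code uses `(img / 255.0 - 0.5) / 0.5`.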
Key Insight
This model suggests that even a relatively small dataset of medical images can be used to train a useful deep learning classifier.
It illustrates how structured MRI datasets can support:
- Medical image classification
- Automated dataset organization
- Preprocessing pipelines in radiology AI
- Computer vision applications in healthcare
Dataset Split
- Train/Test Split: 80% / 20%
- Split Strategy: Random sampling
- Number of Classes: 5
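A random 80/20 split like the one described above can be sketched with the standard library alone. The function name `random_split` and the fixed `seed` are hypothetical choices for reproducibility; the card does not state how the split was seeded.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def random_split(
    items: Sequence[T], train_frac: float = 0.8, seed: int = 42
) -> Tuple[List[T], List[T]]:
    """Shuffle a copy of `items` and cut it into train/test partitions.

    `seed` is a hypothetical value; a fixed seed makes the split repeatable.
    """
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_frac)
    return shuffled[:cut], shuffled[cut:]
```

For example, splitting 100 image paths yields 80 training and 20 test items. Note that plain random sampling does not guarantee that each of the five classes keeps the same proportion in both partitions; with small datasets, a stratified split is a common alternative.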
Training Hyperparameters
- Number of Epochs: 20
- Batch Size: 16
- Learning Rate: 1e-5
- Optimizer: Adam
- Loss Function: Cross-Entropy Loss
- Input Size: 224 × 224
Model Performance
The model's performance on internal validation data is:
- Accuracy: ~70% (may vary with the data distribution)
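A single overall accuracy figure can hide uneven performance across the five classes, so it can help to report per-class recall alongside it. The helper names below are hypothetical, and the example labels are made up for illustration; they are not this model's predictions.

```python
from collections import Counter
from typing import Dict, Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of predictions that match the true class ID."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equal length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_recall(
    y_true: Sequence[int], y_pred: Sequence[int], num_classes: int = 5
) -> Dict[int, float]:
    """For each class present in y_true, the fraction of its samples predicted correctly."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {c: hits[c] / totals[c] for c in range(num_classes) if totals[c]}


# Hypothetical labels, for illustration only
y_true = [0, 0, 1, 1, 2, 3, 4, 4, 4, 4]
y_pred = [0, 0, 1, 2, 2, 3, 4, 0, 1, 4]
print(accuracy(y_true, y_pred))          # 0.7
print(per_class_recall(y_true, y_pred))  # {0: 1.0, 1: 0.5, 2: 1.0, 3: 1.0, 4: 0.5}
```

Here the overall accuracy is 70%, but classes 1 and 4 are recognized only half the time, which the single number does not reveal.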
Classification Labels
| Class ID | Label |
|---|---|
| 0 | Brain |
| 1 | Cervical Spine |
| 2 | Whole Spine |
| 3 | Pelvic |
| 4 | Right Shoulder |
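The Usage code indexes `classes[pred]` with an integer, which implies `classes.json` holds a JSON list ordered by class ID as in the table above. The sketch below assumes that layout; `parse_classes` and `load_classes` are hypothetical helpers that also accept a `{"0": "Brain", ...}` mapping in case the file uses string keys instead.

```python
import json
from typing import List

# Assumed layout: list index i holds the label for class ID i (see the table above)
CLASSES: List[str] = ["Brain", "Cervical Spine", "Whole Spine", "Pelvic", "Right Shoulder"]


def parse_classes(text: str) -> List[str]:
    """Parse class labels from JSON text, normalizing a dict keyed by ID to a list."""
    data = json.loads(text)
    if isinstance(data, dict):
        # JSON object keys are strings, so look up "0", "1", ... in order
        data = [data[str(i)] for i in range(len(data))]
    return list(data)


def load_classes(path: str = "classes.json") -> List[str]:
    """Read and parse the class-label file."""
    with open(path) as f:
        return parse_classes(f.read())
```

Writing the file with `json.dumps(CLASSES, indent=2)` produces a list that the Usage code can index directly.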
Usage
Install Dependencies
```bash
pip install torch torchvision pillow numpy
```
Load and Run Model
```python
import json

import numpy as np
import torch
import torchvision.models as models
from PIL import Image

# Load class names (expected to be a JSON list indexed by class ID)
with open("classes.json") as f:
    classes = json.load(f)

# Recreate the architecture: ResNet18 with a single-channel (grayscale) input stem
model = models.resnet18(weights=None)  # torchvision >= 0.13; older versions use pretrained=False
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = torch.nn.Linear(model.fc.in_features, len(classes))

# Load the trained weights (map_location lets this run on CPU-only machines)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

# Preprocess: grayscale, resize to 224x224, scale pixels to [-1, 1]
def preprocess(image_path):
    img = Image.open(image_path).convert("L")
    img = img.resize((224, 224))
    img = np.array(img, dtype=np.float32)
    img = (img / 255.0 - 0.5) / 0.5
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W)
    return torch.from_numpy(img).unsqueeze(0).unsqueeze(0)

# Predict
img = preprocess("test_image.png")
with torch.no_grad():
    output = model(img)
pred = torch.argmax(output, dim=1).item()
print("Prediction:", classes[pred])
```
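The Usage code reports only the single most likely class. A small extension can return the top few labels with softmax probabilities, which helps when two anatomies look similar (for example, cervical versus whole spine). `predict_topk` is a hypothetical helper; softmax scores are not calibrated confidence estimates, so treat them as relative rankings.

```python
from typing import List, Sequence, Tuple

import torch
import torch.nn as nn


def predict_topk(
    model: nn.Module, img: torch.Tensor, classes: Sequence[str], k: int = 3
) -> List[Tuple[str, float]]:
    """Return up to k (label, probability) pairs for one preprocessed image.

    `img` is expected to be a (1, 1, 224, 224) tensor like the one preprocess() returns.
    """
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(img), dim=1)[0]
    top = torch.topk(probs, k=min(k, len(classes)))
    return [(classes[i], p) for p, i in zip(top.values.tolist(), top.indices.tolist())]
```

With the `model`, `preprocess`, and `classes` from the Usage code, `for label, p in predict_topk(model, preprocess("test_image.png"), classes): print(f"{label}: {p:.1%}")` prints the three most likely anatomies in descending order.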
Considerations
This model was trained on an MRI image dataset from InfoBay.AI and is provided for research and evaluation purposes. The full dataset contains a larger collection of high-quality MRI images. For access to the full dataset or enterprise licensing inquiries, please visit the InfoBay.AI website or contact us directly.
Ph: (91) 8303174762
Email: vipul@infobay.ai
