Image Classification
Transformers
Safetensors
total_classifier
feature-extraction
radiology
ct
organ
classification
custom_code
Instructions for using ianpan/total-classifier with libraries, inference providers, notebooks, and local apps. Follow the links below to get started.
- Libraries
- Transformers
How to use ianpan/total-classifier with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-classification", model="ianpan/total-classifier", trust_remote_code=True)
pipe("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png")
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("ianpan/total-classifier", trust_remote_code=True, dtype="auto")
```

- Notebooks
- Google Colab
- Kaggle
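For local use, the model ships custom code (hence `trust_remote_code=True`) with helper methods for reading CT DICOMs. A minimal end-to-end sketch, assuming pydicom and pandas are installed and a local folder of single-frame axial CT slices; the path is a placeholder:

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("ianpan/total-classifier", trust_remote_code=True)
model = model.eval().to("cuda" if torch.cuda.is_available() else "cpu")

# placeholder path to a folder of CT DICOM files
stack = model.load_stack_from_dicom_folder("/path/to/study", windows=(50, 400))
x = model.preprocess(stack, mode="3d", add_batch_dim=True, device=model.device)
with torch.inference_mode():
    df = model(x, return_as_df=True)[0]  # DataFrame: one row per slice, one column per organ
```

The custom modeling code behind these methods is reproduced below.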
```python
import cv2
import glob
import numpy as np
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from transformers import PreTrainedModel
from timm import create_model

from .configuration import TotalClassifierConfig
from .label2index import label2index

# optional dependencies
_PYDICOM_AVAILABLE = False
try:
    from pydicom import dcmread

    _PYDICOM_AVAILABLE = True
except ModuleNotFoundError:
    pass

_PANDAS_AVAILABLE = False
try:
    import pandas as pd

    _PANDAS_AVAILABLE = True
except ModuleNotFoundError:
    pass

class RNNHead(nn.Module):
    """Bidirectional RNN over per-slice CNN features, with a residual
    connection and a linear classifier on top."""

    def __init__(
        self,
        rnn_type: str,
        rnn_num_layers: int,
        rnn_dropout: float,
        feature_dim: int,
        linear_dropout: float,
        num_classes: int,
    ):
        super().__init__()
        # e.g., rnn_type="GRU" or "LSTM"; the bidirectional RNN uses half the
        # hidden size so its output dimension equals feature_dim
        self.rnn = getattr(nn, rnn_type)(
            input_size=feature_dim,
            hidden_size=feature_dim // 2,
            num_layers=rnn_num_layers,
            dropout=rnn_dropout,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(linear_dropout)
        self.linear = nn.Linear(feature_dim, num_classes)

    @staticmethod
    def convert_seq_and_mask_to_packed_sequence(
        seq: torch.Tensor, mask: torch.Tensor
    ) -> nn.utils.rnn.PackedSequence:
        assert seq.shape[0] == mask.shape[0]
        lengths = mask.sum(1)
        seq = nn.utils.rnn.pack_padded_sequence(
            seq, lengths.cpu().int(), batch_first=True, enforce_sorted=False
        )
        return seq

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        skip = x
        if mask is not None:
            # convert to PackedSequence so padded slices are ignored by the RNN
            L = x.shape[1]
            x = self.convert_seq_and_mask_to_packed_sequence(x, mask)
        x, _ = self.rnn(x)
        if mask is not None:
            # convert back to a padded tensor
            x = nn.utils.rnn.pad_packed_sequence(x, batch_first=True, total_length=L)[0]
        x = x + skip
        return self.linear(self.dropout(x))
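```

A quick shape check of the head in isolation; the feature dimension, slice counts, and class count below are illustrative, not the model's actual configuration:

```python
import torch

head = RNNHead(rnn_type="GRU", rnn_num_layers=1, rnn_dropout=0.0,
               feature_dim=256, linear_dropout=0.0, num_classes=13)
feats = torch.randn(2, 32, 256)            # (batch, num_slices, feature_dim)
mask = torch.ones(2, 32, dtype=torch.bool)
mask[1, 24:] = False                       # second study has only 24 real slices
out = head(feats, mask=mask)               # -> (2, 32, 13) logits
```

```python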
class TotalClassifierModel(PreTrainedModel):
    config_class = TotalClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.image_size = config.image_size
        self.backbone = create_model(
            model_name=config.backbone,
            pretrained=False,
            num_classes=0,
            global_pool="",
            features_only=True,
            in_chans=config.in_chans,
        )
        self.cnn_dropout = nn.Dropout(p=config.cnn_dropout)
        self.head = RNNHead(
            rnn_type=config.rnn_type,
            rnn_num_layers=config.rnn_num_layers,
            rnn_dropout=config.rnn_dropout,
            feature_dim=config.feature_dim,
            linear_dropout=config.linear_dropout,
            num_classes=config.num_classes,
        )
        self.label2index = label2index
        self.index2label = {v: k for k, v in self.label2index.items()}

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        return_logits: bool = False,
        return_as_dict: bool = False,
        return_as_list: bool = False,
        return_as_df: bool = False,
        threshold: float = 0.5,  # only used for return_as_list=True
    ) -> torch.Tensor | list:
        if return_as_df:
            assert (
                _PANDAS_AVAILABLE
            ), "`return_as_df=True` requires pandas to be installed"
        # x.shape = (b, n, c, h, w)
        b, n, c, h, w = x.shape
        # x = rearrange(x, "b n c h w -> (b n) c h w")
        x = x.reshape(b * n, c, h, w)
        x = self.normalize(x)
        features = self.backbone(x)
        # take last feature map, then global average pooling
        features = F.adaptive_avg_pool2d(features[-1], 1).flatten(1)
        features = self.cnn_dropout(features)
        # features = rearrange(features, "(b n) d -> b n d", b=b, n=n)
        features = features.reshape(b, n, -1)
        logits = self.head(features, mask=mask)
        if return_logits:
            # return raw logits
            return logits
        probas = logits.sigmoid()
        if return_as_dict or return_as_df:
            # list of dictionaries (or DataFrames), one per batch element,
            # mapping each label to its per-slice probabilities
            batch_list = []
            for i in range(probas.shape[0]):
                dict_for_batch = {}
                probas_i = probas[i]
                for each_class in range(probas_i.shape[1]):
                    dict_for_batch[self.index2label[each_class]] = probas_i[
                        :, each_class
                    ]
                if return_as_df:
                    batch_list.append(
                        pd.DataFrame(
                            {k: v.cpu().numpy() for k, v in dict_for_batch.items()}
                        )
                    )
                else:
                    batch_list.append(dict_for_batch)
            return batch_list
        if return_as_list:
            # returns a list of lists of lists of strings:
            # innermost list - labels of organs present in a slice, based on threshold
            # inner list - one such list per slice
            # outer list - one such list per batch element (study)
            batch_list = []
            # probas.shape = (batch_size, num_slices, num_classes)
            for i in range(probas.shape[0]):
                probas_i = probas[i]
                # probas_i.shape = (num_slices, num_classes)
                list_for_batch = []
                for each_slice in range(probas_i.shape[0]):
                    list_for_batch.append(
                        [
                            self.index2label[each_class]
                            for each_class in range(probas_i.shape[1])
                            if probas_i[each_slice, each_class] >= threshold
                        ]
                    )
                batch_list.append(list_for_batch)
            return batch_list
        return probas

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # [0, 255] -> [-1, 1]
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)
        x = (x - 0.5) * 2.0
        return x
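```

With a loaded model (see the snippet above), `forward` accepts a `(batch, slices, channels, height, width)` tensor and can return probabilities, raw logits, per-label dictionaries/DataFrames, or per-slice organ lists. A sketch with random data standing in for a preprocessed stack; the spatial size and channel count here are illustrative and must match the model config:

```python
x = torch.randint(0, 256, (1, 16, 1, 256, 256)).float()  # values in [0, 255]
with torch.inference_mode():
    probas = model(x)                        # (1, 16, num_classes) probabilities
    organs = model(x, return_as_list=True)   # per-slice lists of organ labels >= 0.5
```

```python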
    @staticmethod
    def window(x: np.ndarray, WL: int, WW: int) -> np.ndarray:
        # apply a CT window with level WL and width WW, rescaling to uint8
        lower, upper = WL - WW // 2, WL + WW // 2
        x = np.clip(x, lower, upper)
        x = (x - lower) / (upper - lower)
        return (x * 255.0).astype("uint8")
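```

For intuition: a soft-tissue window of `WL=50, WW=400` clips Hounsfield units to `[-150, 250]` before rescaling to `[0, 255]`:

```python
import numpy as np

hu = np.array([-1000.0, 0.0, 50.0, 300.0])  # air, water, soft tissue, contrast/bone
model.window(hu, WL=50, WW=400)             # -> array([  0,  95, 127, 255], dtype=uint8)
```

```python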
    @staticmethod
    def validate_windows_type(windows):
        assert isinstance(windows, (tuple, list))
        if isinstance(windows, tuple):
            assert len(windows) == 2
            assert all(isinstance(_, int) for _ in windows)
        elif isinstance(windows, list):
            assert all([isinstance(_, tuple) for _ in windows])
            assert all([len(_) == 2 for _ in windows])
            assert all([isinstance(__, int) for _ in windows for __ in _])

    @staticmethod
    def determine_dicom_orientation(ds) -> int:
        iop = ds.ImageOrientationPatient
        # Calculate the direction cosine for the normal vector of the plane
        normal_vector = np.cross(iop[:3], iop[3:])
        # Determine the plane based on the largest component of the normal vector
        abs_normal = np.abs(normal_vector)
        if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
            return 0  # sagittal
        elif abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
            return 1  # coronal
        else:
            return 2  # axial

    def load_image_from_dicom(
        self, path: str, windows: tuple[int, int] | list[tuple[int, int]] | None = None
    ) -> np.ndarray:
        # windows can be a tuple of (WINDOW_LEVEL, WINDOW_WIDTH),
        # or a list of tuples to generate a multi-channel image using
        # more than 1 window
        if not _PYDICOM_AVAILABLE:
            raise Exception("`pydicom` is not installed")
        dicom = dcmread(path)
        array = dicom.pixel_array.astype("float32")
        # rescale to Hounsfield units
        m, b = float(dicom.RescaleSlope), float(dicom.RescaleIntercept)
        array = array * m + b
        if windows is None:
            return array
        self.validate_windows_type(windows)
        if isinstance(windows, tuple):
            windows = [windows]
        arr_list = []
        for WL, WW in windows:
            arr_list.append(self.window(array.copy(), WL, WW))
        array = np.stack(arr_list, axis=-1)
        if array.shape[-1] == 1:
            array = np.squeeze(array, axis=-1)
        return array
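```

Single-slice usage; the path is a placeholder, and the second (bone) window is just an example of stacking more than one window as channels:

```python
img = model.load_image_from_dicom("/path/to/slice.dcm", windows=[(50, 400), (400, 1800)])
img.shape  # (H, W, 2), uint8, one channel per window
```

```python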
    @staticmethod
    def is_valid_dicom(
        ds,
        fname: str = "",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
    ) -> bool:
        attributes = [
            "pixel_array",
            "RescaleSlope",
            "RescaleIntercept",
        ]
        if sort_by_instance_number:
            attributes.append("InstanceNumber")
        else:
            attributes.append("ImagePositionPatient")
            attributes.append("ImageOrientationPatient")
        attributes_present = [hasattr(ds, attr) for attr in attributes]
        valid = all(attributes_present)
        if not valid and not exclude_invalid_dicoms:
            raise Exception(
                f"invalid DICOM file [{fname}]: missing attributes: "
                f"{list(np.array(attributes)[~np.array(attributes_present)])}"
            )
        return valid

    @staticmethod
    def most_common_element(lst):
        return max(set(lst), key=lst.count)

    @staticmethod
    def center_crop_or_pad_borders(image, size):
        height, width = image.shape[:2]
        new_height, new_width = size
        if new_height < height:
            # crop top and bottom
            crop_top = (height - new_height) // 2
            crop_bottom = height - new_height - crop_top
            image = image[crop_top:-crop_bottom]
        elif new_height > height:
            # pad top and bottom
            pad_top = (new_height - height) // 2
            pad_bottom = new_height - height - pad_top
            image = np.pad(
                image,
                ((pad_top, pad_bottom), (0, 0)),
                mode="constant",
                constant_values=0,
            )
        if new_width < width:
            # crop left and right
            crop_left = (width - new_width) // 2
            crop_right = width - new_width - crop_left
            image = image[:, crop_left:-crop_right]
        elif new_width > width:
            # pad left and right
            pad_left = (new_width - width) // 2
            pad_right = new_width - width - pad_left
            image = np.pad(
                image,
                ((0, 0), (pad_left, pad_right)),
                mode="constant",
                constant_values=0,
            )
        return image
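```

This pads and/or center-crops each spatial dimension independently, e.g.:

```python
import numpy as np

a = np.zeros((500, 520))
model.center_crop_or_pad_borders(a, (512, 512)).shape  # -> (512, 512)
```

```python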
    def load_stack_from_dicom_folder(
        self,
        path: str,
        windows: tuple[int, int] | list[tuple[int, int]] | None = None,
        dicom_extension: str = ".dcm",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
        fix_unequal_shapes: str = "crop_pad",
        return_sorted_dicom_files: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, list[str]]:
        if not _PYDICOM_AVAILABLE:
            raise Exception("`pydicom` is not installed")
        dicom_files = glob.glob(os.path.join(path, f"*{dicom_extension}"))
        if len(dicom_files) == 0:
            raise Exception(
                f"No DICOM files found in `{path}` using `dicom_extension={dicom_extension}`"
            )
        dicoms = [dcmread(f) for f in dicom_files]
        dicoms = [
            (d, dicom_files[idx])
            for idx, d in enumerate(dicoms)
            if self.is_valid_dicom(
                d, dicom_files[idx], sort_by_instance_number, exclude_invalid_dicoms
            )
        ]
        # handles exclude_invalid_dicoms=True and return_sorted_dicom_files=True
        # by only including valid DICOM filenames
        dicom_files = [_[1] for _ in dicoms]
        dicoms = [_[0] for _ in dicoms]
        slices = [dcm.pixel_array.astype("float32") for dcm in dicoms]
        shapes = np.stack([s.shape for s in slices], axis=0)
        if not np.all(shapes == shapes[0]):
            unique_shapes, counts = np.unique(shapes, axis=0, return_counts=True)
            standard_shape = tuple(int(_) for _ in unique_shapes[np.argmax(counts)])
            print(
                f"warning: different array shapes present, using {fix_unequal_shapes} -> {standard_shape}"
            )
            if fix_unequal_shapes == "crop_pad":
                slices = [
                    self.center_crop_or_pad_borders(s, standard_shape)
                    if s.shape != standard_shape
                    else s
                    for s in slices
                ]
            elif fix_unequal_shapes == "resize":
                # cv2.resize expects (width, height)
                slices = [
                    cv2.resize(s, standard_shape[::-1])
                    if s.shape != standard_shape
                    else s
                    for s in slices
                ]
        slices = np.stack(slices, axis=0)
        # find orientation (sagittal / coronal / axial); use the most common one
        orientation = [self.determine_dicom_orientation(dcm) for dcm in dicoms]
        orientation = self.most_common_element(orientation)
        # sort slices using ImagePositionPatient (or InstanceNumber);
        # orientation is the index of the position component to sort by
        if sort_by_instance_number:
            positions = [float(d.InstanceNumber) for d in dicoms]
        else:
            positions = [float(d.ImagePositionPatient[orientation]) for d in dicoms]
        indices = np.argsort(positions)
        slices = slices[indices]
        # rescale to Hounsfield units using the most common slope/intercept
        m = self.most_common_element([float(d.RescaleSlope) for d in dicoms])
        b = self.most_common_element([float(d.RescaleIntercept) for d in dicoms])
        slices = slices * m + b
        if windows is not None:
            self.validate_windows_type(windows)
            if isinstance(windows, tuple):
                windows = [windows]
            arr_list = []
            for WL, WW in windows:
                arr_list.append(self.window(slices.copy(), WL, WW))
            slices = np.stack(arr_list, axis=-1)
            if slices.shape[-1] == 1:
                slices = np.squeeze(slices, axis=-1)
        if return_sorted_dicom_files:
            return slices, [dicom_files[idx] for idx in indices]
        return slices
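```

When per-slice predictions need to be mapped back to files, the sorted filenames can be returned alongside the stack (the path again a placeholder):

```python
stack, files = model.load_stack_from_dicom_folder(
    "/path/to/study", windows=(50, 400), return_sorted_dicom_files=True
)
stack.shape  # (num_slices, H, W), uint8; files[i] corresponds to stack[i]
```

```python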
    def preprocess(
        self,
        x: np.ndarray,
        mode: str = "2d",
        torchify: bool = True,
        add_batch_dim: bool = False,
        device: str | torch.device | None = None,
    ) -> np.ndarray | torch.Tensor:
        if device is not None:
            assert torchify, "`torchify` must be `True` if specifying `device`"
        mode = mode.lower()
        if mode == "2d":
            # x is a single image: (h, w) or (h, w, c)
            x = cv2.resize(x, self.image_size)
            if x.ndim == 2:
                x = x[:, :, np.newaxis]
        elif mode == "3d":
            # x is a stack of images: (n, h, w) or (n, h, w, c)
            x = np.stack([cv2.resize(s, self.image_size) for s in x], axis=0)
            if x.ndim == 3:
                x = x[:, :, :, np.newaxis]
        if torchify:
            if x.ndim == 3:
                x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
            elif x.ndim == 4:
                x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
        if add_batch_dim:
            if torchify:
                x = x.unsqueeze(0)
            else:
                x = x[np.newaxis]
        if device is not None:
            x = x.to(device)
        return x
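```

`preprocess` resizes, adds channel/batch dimensions, and optionally converts to a torch tensor on the target device, so model input comes from one call (continuing from the `stack` loaded above):

```python
x3d = model.preprocess(stack, mode="3d", add_batch_dim=True, device=model.device)     # (1, n, c, h, w)
x2d = model.preprocess(stack[0], mode="2d", add_batch_dim=True, device=model.device)  # (1, c, h, w)
```

```python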
    def crop_single_plane(
        self,
        x: np.ndarray,
        device: str | torch.device,
        organ: str | list[str],
        threshold: float = 0.5,
        buffer: float | int = 0,
        speed_up: str | None = None,
    ) -> tuple[int, int]:
        if not isinstance(organ, list):
            organ = [organ]
        num_slices = x.shape[0]
        if speed_up is not None:
            assert speed_up in ["fast", "faster", "fastest"]
            if speed_up == "fast":
                # 75% of slices
                reduce_num_slices = 3 * num_slices // 4
            elif speed_up == "faster":
                # 50% of slices
                reduce_num_slices = num_slices // 2
            elif speed_up == "fastest":
                # 33% of slices
                reduce_num_slices = num_slices // 3
            indices = np.linspace(0, num_slices - 1, reduce_num_slices).astype(int)
            x = x[indices]
        x = self.preprocess(x, mode="3d", torchify=False)
        x = torch.from_numpy(x)
        x = rearrange(x, "n h w c -> n c h w").float().to(device)
        x = rearrange(x, "n c h w -> 1 n c h w")
        if x.size(2) > 1:
            # if multi-channel, take mean
            x = x.mean(2, keepdim=True)
        organ_cls = self.forward(x)[0]
        if speed_up is not None:
            # organ_cls.shape = (reduced_num_slices, num_classes);
            # interpolate back to the original number of slices
            organ_cls = (
                F.interpolate(
                    organ_cls.transpose(1, 0).unsqueeze(0),
                    size=(num_slices,),
                    mode="linear",
                )
                .squeeze(0)
                .transpose(1, 0)
            )
            assert organ_cls.shape[0] == num_slices
        slices = []
        for each_organ in organ:
            slices.append(
                torch.where(organ_cls[:, self.label2index[each_organ]] >= threshold)[0]
            )
        slices = torch.cat(slices)
        slice_min, slice_max = slices.min().item(), slices.max().item()
        if buffer > 0:
            if isinstance(buffer, float):
                # buffer as a fraction of the cropped extent
                diff = slice_max - slice_min
                buf = int(buffer * diff)
            else:
                # absolute slice buffer
                buf = buffer
            slice_min = max(0, slice_min - buf)
            slice_max = min(num_slices - 1, slice_max + buf)
        return slice_min, slice_max
    def crop(
        self,
        x: np.ndarray,
        organ: str | list[str],
        crop_dims: int | list[int] = 0,
        device: str | torch.device | None = None,
        raw_hu: bool = False,
        threshold: float = 0.5,
        buffer: float | int = 0,
        speed_up: str | None = None,
    ) -> np.ndarray:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        assert isinstance(x, np.ndarray)
        assert x.ndim in {
            3,
            4,
        }, f"x should be a 3D or 4D array, but got {x.ndim} dimensions"
        if raw_hu:
            # if input is in Hounsfield units, apply soft tissue window
            x = self.window(x, WL=50, WW=400)
        x0 = x
        if not isinstance(organ, list):
            organ = [organ]
        if not isinstance(crop_dims, list):
            crop_dims = [crop_dims]
        assert max(crop_dims) <= 2
        assert min(crop_dims) >= 0
        if isinstance(buffer, float):
            # buffer is a fraction of the cropped axis extent
            assert buffer < 1
        if 0 in crop_dims:
            smin0, smax0 = self.crop_single_plane(
                x0, device, organ, threshold, buffer, speed_up
            )
        else:
            smin0, smax0 = 0, x0.shape[0]
        if 1 in crop_dims:
            # swap plane
            x = x0.swapaxes(1, 0)
            smin1, smax1 = self.crop_single_plane(
                x, device, organ, threshold, buffer, speed_up
            )
        else:
            smin1, smax1 = 0, x0.shape[1]
        if 2 in crop_dims:
            # swap plane
            x = x0.swapaxes(2, 0)
            smin2, smax2 = self.crop_single_plane(
                x, device, organ, threshold, buffer, speed_up
            )
        else:
            smin2, smax2 = 0, x0.shape[2]
        return x0[smin0 : smax0 + 1, smin1 : smax1 + 1, smin2 : smax2 + 1]
```
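Putting it together: `crop` localizes one or more organs and trims the volume to their extent. `crop_dims=0` restricts only the slice (z) axis; `[0, 1, 2]` crops all three axes by re-running the classifier on swapped planes. A sketch, assuming "liver" appears in the model's `label2index` mapping and `hu_stack` is a raw Hounsfield-unit volume:

```python
hu_stack = model.load_stack_from_dicom_folder("/path/to/study")  # windows=None -> raw HU
cropped = model.crop(hu_stack, organ="liver", crop_dims=0, raw_hu=True, buffer=0.05)
```

`speed_up="faster"` classifies roughly half the slices and linearly interpolates the per-slice probabilities back to full length, trading some boundary precision for speed.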