Quantcast
Channel: Machine Learning | Towards AI
Viewing all articles
Browse latest Browse all 819

Explainability for a 3D ResNet Classifier

$
0
0
Author(s): Shashwat (Shawn) Gupta. Originally published on Towards AI.

GradCAM is one of the simplest techniques for gaining explainability insights into a model's predictions. I was surprised to find that, while there are many blogs on Medium about using GradCAM with ResNet, there aren't any specifically about GradCAM with 3D images (e.g., for ResNet3D), and almost none in PyTorch. Furthermore, most code on GitHub is adapted from 2D GradCAM and implements GradCAM incorrectly. Determined to fill this gap, I spent an entire night understanding the intricate details of the code and successfully wrote my own implementation.

Figure: 2D explainability by GradCAM (source: author; X-ray image from the Kaggle Chest X-ray dataset).

What This Code Does

This code builds a ResNet3D model from scratch: a 3D version of the popular ResNet (Residual Network) used for image recognition. It incorporates GradCAM (Gradient-weighted Class Activation Mapping), a technique that helps visualize which parts of the input data the model focuses on when making decisions. The model processes NIfTI files (a common format for medical imaging data) listed in train.txt and test.txt. Instead of performing image segmentation, we modify the model to do classification by replacing the segmentation layer with a feedforward network (FFN) initialised using Xavier initialisation. Initially, only the new layers are trained while the existing weights are kept frozen. After a set number of training epochs, the entire network is fine-tuned. To speed up training, the code uses multiple GPUs (Graphics Processing Units) through data parallelism, allowing the model to use all available GPUs efficiently.

The train.txt, test.txt, and gradcam.txt files should list one .nii.gz file per line: its path, a space, and its class label. For example:

./file1/a.nii.gz 1
./file2/b.nii.gz 0
…
# Imports and Parameters
import math
import os
import time
import warnings
from functools import partial

import nibabel
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy import ndimage
from scipy.ndimage import zoom
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset

num_gpus = torch.cuda.device_count()
print(f"Number of GPUs available: {num_gpus}")

# Parameters (Command Line)
n_epochs = 700
epoch_unfreeze_all = 300
data_root = './data'
train_img_list = './data/train.txt'
test_img_list = './data/test.txt'
manual_seed = 1
num_classes = 2  # Updated for classification
learning_rate = 0.001
num_workers = 4
batch_size = 1
save_intervals = 30

# Input size along D, H and W.
input_D = 56
input_H = 448
input_W = 448

resume_path = ''  # Resume from this if it's a file
model_depth = 10  # 10 | 18 | 34 | 50 | 101 | 152 | 200
pretrain_path = f'pretrain/resnet_{model_depth}.pth'
new_layer_names = ['fc']  # Updated to 'fc' for classification head
gpu_id = list(range(num_gpus))  # one id per visible GPU
model = 'resnet'
# Shortcut used by residual blocks that change shape:
#   'A' - parameter-free: subsample, then zero-pad the extra channels
#   'B' - projection: 1x1x1 Conv3d followed by BatchNorm3d
resnet_shortcut = 'B'
save_folder = "./trails/models/{}_{}".format(model, model_depth)
test_batch_size = 1
test_num_workers = 4

no_cuda = not torch.cuda.is_available()
# Fall back to CPU unless CUDA is usable and reports at least one device.
if no_cuda or torch.cuda.device_count() <= 0:
    pin_memory = False
    test_pin_memory = False
    print("No GPU available, using CPU.")
else:
    pin_memory = True
    test_pin_memory = True  # Set to True if using GPU
    print(f"Using GPU(s). Number of GPUs available: {torch.cuda.device_count()}")

# Model Description
# The ResNet (Residual Network) is a type of neural network that uses residual
# blocks to allow the network to learn more effectively. In this
# implementation, we define different types of blocks and layers to build the
# ResNet3D model tailored for classification tasks.
def conv3x3x3(in_planes, out_planes, stride=1, dilation=1): # 3x3x3 convolution with padding return nn.Conv3d( in_planes, out_planes, kernel_size=3, dilation=dilation, stride=stride, padding=dilation, bias=False)def downsample_basic_block(x, planes, stride, no_cuda=False): out = F.avg_pool3d(x, kernel_size=1, stride=stride) zero_pads = torch.Tensor( out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4)).zero_() if not no_cuda: if isinstance(out.data, torch.cuda.FloatTensor): zero_pads = zero_pads.cuda() out = Variable(torch.cat([out.data, zero_pads], dim=1)) return outclass BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation) self.bn1 = nn.BatchNorm3d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3x3(planes, planes, dilation=dilation) self.bn2 = nn.BatchNorm3d(planes) self.downsample = downsample self.stride = stride self.dilation = dilation def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return outclass Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm3d(planes) self.conv2 = nn.Conv3d( planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False) self.bn2 = nn.BatchNorm3d(planes) self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm3d(planes * 4) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation def forward(self, x): residual = x out = self.conv1(x) out = 
self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return outclass ResNet(nn.Module): def __init__(self, block, layers, sample_input_D, sample_input_H, sample_input_W, num_classes, # Changed from num_seg_classes to num_classes shortcut_type='B', no_cuda=False): super(ResNet, self).__init__() self.inplanes = 64 self.no_cuda = no_cuda self.conv1 = nn.Conv3d( 1, 64, kernel_size=7, stride=(2, 2, 2), padding=(3, 3, 3), bias=False) self.bn1 = nn.BatchNorm3d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) self.layer2 = self._make_layer( block, 128, layers[1], shortcut_type, stride=2) self.layer3 = self._make_layer( block, 256, layers[2], shortcut_type, stride=1, dilation=2) self.layer4 = self._make_layer( block, 512, layers[3], shortcut_type, stride=1, dilation=4) # placeholder for the gradients self.gradients = None # Remove or comment out the segmentation head # self.conv_seg = nn.Sequential( # ... 
# ) # Add a classification head self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) # Binary classification (2 classes) # Initialize weights for new layers self._initialize_weights() def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: if shortcut_type == 'A': downsample = partial( downsample_basic_block, planes=planes * block.expansion, stride=stride, no_cuda=self.no_cuda) else: downsample = nn.Sequential( nn.Conv3d( self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm3d(planes * block.expansion)) layers = [] layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample)) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes, dilation=dilation)) return nn.Sequential(*layers) def _initialize_weights(self): # Initialize weights for the new classification head using Xavier initialization nn.init.xavier_normal_(self.fc.weight) if self.fc.bias is not None: nn.init.constant_(self.fc.bias, 0) def forward(self, x,reg_hook=True): # Feature extraction x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = […]

Viewing all articles
Browse latest Browse all 819

Latest Images

Trending Articles



Latest Images