Consistent Semi-Supervised, Explainable Multi-Tasking for Medical Imaging
MultiMix: Sparingly Supervised, Extreme Multitask Learning from Medical Images
In this article, I will discuss MultiMix, a new semi-supervised, multi-tasking medical imaging method by Ayaan Haque (me), Abdullah-Al-Zubaer Imran, Adam Wang, and Demetri Terzopoulos. Our paper was accepted to the full-paper track of ISBI 2021 and was presented at the conference in April. An extension of the paper with improved results was also published in the MELBA Journal. This article reviews the methods and results and includes a short code review. The code is available at https://github.com/ayaanzhaque/MultiMix.
MultiMix performs joint semi-supervised classification and segmentation by employing a confidence-based augmentation strategy and a novel saliency bridge module that provides explainability for the joint tasks. Fully supervised deep learning models can be highly effective at complex image analysis tasks, but their performance relies heavily on the availability of large labeled datasets. In medical imaging especially, labels are expensive, time-consuming to obtain, and prone to inter-observer variation. As a result, semi-supervised learning, which allows models to learn from limited quantities of labeled data, has been investigated as an alternative to fully supervised methods.
Moreover, learning multiple tasks within the same model further improves generalizability. Multi-tasking allows representations to be shared between tasks while requiring fewer parameters and less compute, making models more efficient and less prone to overfitting.
Our extensive experiments with varied quantities of labeled data and with multi-source data demonstrate the effectiveness of our method. We also present both in-domain and cross-domain evaluations for each task to show that the model can adapt to challenging generalization scenarios, an important requirement for medical imaging methods.
Learning-based medical imaging has grown in recent years, mostly because of advances in deep learning. However, a fundamental problem of deep learning remains: models require large amounts of labeled data to perform well. This problem is even larger in the medical imaging domain, where large datasets and annotations are difficult to collect because they require domain expertise, are expensive and time-consuming to produce, and are hard to organize into centralized datasets. Generalization is another key problem: images from different sources can differ significantly, both qualitatively and quantitatively, which makes it difficult to build a model that performs well across multiple domains. We address these problems with a few key approaches centered on semi-supervised and multi-task learning.
What is Semi-Supervised Learning?
To address the limited labeled data problem, semi-supervised learning (SSL) has gained lots of attention as a promising alternative. In semi-supervised learning, unlabeled examples are leveraged in combination with labeled examples to maximize information gains. There has been lots of research in semi-supervised learning, both general and medical-domain specific. I won’t discuss these methods in detail, but here is a list of prominent methods to refer to if you are interested [1, 2, 3, 4].
Another way to address limited-sample learning is to use data from multiple sources, which increases both the number and the diversity of samples. Doing so is challenging and requires specific training methods, but if done correctly, it can be very impactful.
What is Multi-Task Learning?
Multi-Task Learning (MTL) has been studied as a way to improve the generalizability of many models. In multi-task learning, more than one loss is optimized in a single model so that multiple related tasks are performed through shared representation learning. Jointly training multiple tasks improves generalizability because the tasks regularize one another. In addition, when the training data for different tasks come from different distributions with limited annotations, multi-tasking can help the model learn in a scarcely supervised manner. Combining multi-tasking with semi-supervised learning can therefore improve performance on both tasks considered here, classification and segmentation. Accomplishing both tasks at once can be extremely beneficial: instead of relying on medically trained professionals for each task, a single deep learning model can perform both with high accuracy.
For related work in the medical domain, I won’t go into detail on the methods, but here is a list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]. The primary limitations of these works are that they do not use data from multiple sources, which limits their generalizability, and that most of them are single-task methods.
We therefore propose MultiMix, a novel, better-generalized multi-tasking model that incorporates confidence-based augmentation and a saliency bridge module to jointly learn diagnostic classification and anatomical structure segmentation from multi-source data. A saliency map enables analysis of a model’s predictions by visualizing the image features that drive them. Saliency maps can be produced in a few ways, most notably by computing the gradient of the class score with respect to the input image. While any deep learning model can be made more explainable through saliency maps, to our knowledge a saliency bridge between two tasks that share a single model has not yet been explored.
Let’s begin by defining our problem. We use two datasets for training, one for segmentation and one for classification. For the segmentation data, we denote the images and segmentation masks as $X_s$ and $Y$, respectively. For the classification data, we denote the images and class labels as $X_c$ and $C$.
Regarding our model architecture, we use a baseline U-Net architecture, which is a commonly used segmentation architecture using an encoder-decoder framework. The encoder functions similarly to a standard CNN. To perform multi-tasking with a U-Net, we branch off from the encoder with pooling and fully-connected layers to get a final classification output.
For our proposed classification method, we leverage data augmentation and pseudo-labeling. Inspired by prior work, we take an unlabeled image and perform two separate augmentations. First, the unlabeled image is weakly augmented, and a pseudo-label is generated from the current model’s prediction on that weakly augmented version. Pseudo-labeling is what makes the method semi-supervised; we will discuss the process in more detail a little later.
Second, the same unlabeled image is strongly augmented, and a loss is computed between the model’s prediction on the strongly augmented image and the pseudo-label from the weakly augmented image. Essentially, we teach the model to make the same prediction for the strongly augmented image as for the weakly augmented one, which forces it to learn the fundamental underlying features required for diagnostic classification. Augmenting each image twice also maximizes the knowledge gained from a single image. It also improves generalization: if the model is forced to learn only the most important parts of the image, it can overcome the differences that appear in images from different domains.
For the weakly augmented images, we use conventional augmentations such as horizontal flipping and slight rotations. The strong augmentation strategy is more interesting. We create a pool of unconventional, strong augmentations and apply a random number of them to any given image. These augmentations are quite distortive and include cropping, autocontrast, brightness, contrast, equalization, identity, rotation, sharpness, shearing, and more. By applying a random number of them, we create an extremely wide variety of images, which is especially important when dealing with low-sample datasets. We ultimately found this augmentation strategy to be quite important for strong performance.
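To make the two augmentation levels concrete, here is a minimal sketch using Pillow. It is not the MultiMix implementation: the function names, the flip probability, the rotation angles, and every magnitude range are illustrative assumptions, and the pool contains only the operations named above.

```python
import random
from PIL import Image, ImageEnhance, ImageOps


def weak_augment(img: Image.Image) -> Image.Image:
    """Mild, conventional augmentation: random horizontal flip plus a slight rotation."""
    if random.random() < 0.5:
        img = ImageOps.mirror(img)
    return img.rotate(random.uniform(-10, 10))  # small angle; output size is unchanged


def _random_crop(img: Image.Image) -> Image.Image:
    """Crop a random 70-95% window and resize it back to the original size."""
    w, h = img.size
    scale = random.uniform(0.7, 0.95)
    cw, ch = int(w * scale), int(h * scale)
    left, top = random.randint(0, w - cw), random.randint(0, h - ch)
    return img.crop((left, top, left + cw, top + ch)).resize((w, h))


def _shear_x(img: Image.Image) -> Image.Image:
    """Horizontal shear via an affine transform."""
    s = random.uniform(-0.3, 0.3)
    return img.transform(img.size, Image.AFFINE, (1, s, 0, 0, 1, 0))


# Pool of strong, distortive operations (magnitude ranges are illustrative guesses).
STRONG_POOL = [
    _random_crop,
    ImageOps.autocontrast,
    lambda im: ImageEnhance.Brightness(im).enhance(random.uniform(0.5, 1.5)),
    lambda im: ImageEnhance.Contrast(im).enhance(random.uniform(0.5, 1.5)),
    ImageOps.equalize,
    lambda im: im,  # identity
    lambda im: im.rotate(random.uniform(-30, 30)),
    lambda im: ImageEnhance.Sharpness(im).enhance(random.uniform(0.1, 1.9)),
    _shear_x,
]


def strong_augment(img: Image.Image) -> Image.Image:
    """Apply a random number of randomly chosen operations from STRONG_POOL."""
    n_ops = random.randint(1, len(STRONG_POOL))
    for op in random.sample(STRONG_POOL, n_ops):
        img = op(img)
    return img


if __name__ == "__main__":
    random.seed(0)
    xray = Image.effect_noise((64, 64), 64)  # stand-in grayscale "image"
    weak, strong = weak_augment(xray), strong_augment(xray)
    print(weak.size, strong.size)  # (64, 64) (64, 64)
```

Because `strong_augment` samples both the number and the choice of operations, two calls on the same image usually produce very different outputs, which is the variety this strategy relies on.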
Now let’s return to the pseudo-labeling process. Once the weakly augmented images are converted to pseudo-labels, we only use an image-label pair if the confidence with which the model generates the pseudo-label is above a tuned threshold, which prevents the model from learning from incorrect or poor labels. This produces a “curriculum for free” effect: early in training, when predictions are less confident, the model learns mainly from the labeled data. As the model becomes more confident in the labels it generates for the unlabeled images, more of them pass the threshold, and the model learns more efficiently. This is also a really important feature for increasing performance.
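Now let’s take a look at the loss function. The classification loss can be modeled by the following equation:
$$\mathcal{L}_c = \mathcal{L}_l(\hat{c}_l, c_l) + \lambda\, \mathcal{L}_u\big(\hat{c}_s, \operatorname{argmax}(\hat{c}_w)\big)\, \mathbb{1}\big[\max(\hat{c}_w) \ge t\big]$$
where $\mathcal{L}_l$ is the supervised loss, $\hat{c}_l$ is the classification prediction, $c_l$ is the label, $\lambda$ is the unsupervised classification weight, $\mathcal{L}_u$ is the unsupervised loss, $\hat{c}_s$ is the prediction on strongly augmented images, $\operatorname{argmax}(\hat{c}_w)$ gives the pseudo-labels from weakly augmented images, and $t$ is the pseudo-labeling threshold. The indicator $\mathbb{1}[\cdot]$ keeps a pseudo-label only when the model’s confidence on the weakly augmented image, $\max(\hat{c}_w)$, is at least $t$.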
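A minimal PyTorch sketch of this classification loss follows, assuming cross-entropy for both the supervised and unsupervised terms (the code review below uses cross-entropy for classification). The function name, the placeholder values `lam=1.0` and `t=0.95` (not the tuned MultiMix threshold), and computing pseudo-labels without gradients are assumptions.

```python
import torch
import torch.nn.functional as F


def classification_loss(logits_l, labels_l, logits_w, logits_s, lam=1.0, t=0.95):
    """Supervised CE plus confidence-masked CE on strongly augmented views.

    logits_l/labels_l: labeled batch; logits_w/logits_s: the same unlabeled images
    under weak and strong augmentation. lam and t are placeholder values.
    """
    supervised = F.cross_entropy(logits_l, labels_l)

    # Pseudo-labels come from the weak view and act as fixed targets (no gradient).
    with torch.no_grad():
        probs_w = torch.softmax(logits_w, dim=1)
        confidence, pseudo_labels = torch.max(probs_w, dim=1)
        mask = confidence.ge(t).float()  # 1 where the model is confident enough

    per_sample = F.cross_entropy(logits_s, pseudo_labels, reduction="none")
    unsupervised = (per_sample * mask).mean()
    return supervised + lam * unsupervised, mask


if __name__ == "__main__":
    torch.manual_seed(0)
    logits_l, labels_l = torch.randn(4, 2), torch.tensor([0, 1, 0, 1])
    logits_s = torch.randn(6, 2)
    # Early in training: uncertain weak-view predictions, so nothing passes the threshold.
    _, early_mask = classification_loss(logits_l, labels_l, torch.zeros(6, 2), logits_s)
    # Later: confident weak-view predictions, so every pseudo-label is used.
    confident = torch.tensor([[8.0, -8.0]] * 6)
    _, late_mask = classification_loss(logits_l, labels_l, confident, logits_s)
    print(early_mask.sum().item(), late_mask.sum().item())  # 0.0 6.0
```

The two calls illustrate the “curriculum for free” effect: with uncertain weak-view predictions no pseudo-label passes the threshold and only the supervised term contributes, while confident predictions let every unlabeled image contribute.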
That essentially summarizes the classification method, so now let’s move on to the segmentation method.
For segmentation, predictions are made through the encoder-decoder architecture with skip connections, which is quite straightforward. Our main contribution for segmentation is a saliency bridge module that bridges the two tasks, as seen in the figure above. We generate saliency maps based on the classes predicted by the model, using the gradients of the classification output with respect to the input image, computed through the encoder and its classification branch. The entire process is shown above, but in essence, a saliency map highlights which parts of the image the model uses to classify the image for pneumonia. When visualized, these maps look similar to segmentation maps, making them a natural fit for bridging the two tasks.
Although we do not know whether the images in the segmentation dataset show pneumonia, the generated saliency maps highlight the lungs and are produced at the final segmentation resolution. Thus, when the class prediction for an image is produced and visualized with the saliency map, it somewhat resembles the lung mask. We hypothesize that these saliency maps can guide the segmentation during the decoder phase, yielding improved segmentation while learning from limited labeled data.
In MultiMix, the generated saliency maps are concatenated with the input images, downsampled, and added to the feature maps input to the first decoder stage. Concatenating the saliency map with the input image strengthens the connection between the two tasks and makes the bridge module more effective because of the additional context it gives the decoder, which can be especially important when learning from low-sample data.
Now let’s discuss training and loss. For the labeled samples, we conventionally calculate the segmentation loss using dice loss between the reference lung mask and predicted segmentation.
Since we don’t have segmentation masks for the unlabeled segmentation samples, we can’t directly compute a supervised segmentation loss for them. Instead, we compute the KL divergence between the segmentation predictions for labeled and unlabeled examples. This penalizes the model for making predictions on unlabeled images that differ increasingly from those on labeled images, which helps the model fit the unlabeled data more appropriately. Although this is an indirect way of computing the loss, it still allows the model to learn a lot from the unlabeled segmentation data.
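Regarding loss, our segmentation loss can be written as:
$$\mathcal{L}_s = \alpha\Big(\mathcal{L}_{\text{dice}}(\hat{y}_l, y_l) + \beta\, \mathcal{L}_{\text{KL}}(\hat{y}_l, \hat{y}_u)\Big)$$
where $\alpha$ is the weight of the segmentation loss relative to the classification loss, $\hat{y}_l$ are the segmentation predictions on labeled images, $y_l$ are the corresponding masks, $\beta$ is the unsupervised segmentation weight, and $\hat{y}_u$ are the segmentation predictions on unlabeled images.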
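Here is a minimal PyTorch sketch of this segmentation loss under stated assumptions: predictions are sigmoid probabilities of shape (B, 1, H, W), the labeled and unlabeled prediction batches have the same shape, and the KL term treats each pixel as a Bernoulli distribution and computes KL(labeled || unlabeled). The exact form and direction of the KL term in MultiMix may differ, and the default weights are placeholders.

```python
import torch


def dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss on probabilities of shape (B, 1, H, W), averaged over the batch."""
    dims = (1, 2, 3)
    intersection = (pred * target).sum(dims)
    denominator = pred.sum(dims) + target.sum(dims)
    return (1 - (2 * intersection + smooth) / (denominator + smooth)).mean()


def bernoulli_kl(p, q, eps=1e-6):
    """Mean per-pixel KL(p || q), treating each sigmoid output as a Bernoulli probability."""
    p, q = p.clamp(eps, 1 - eps), q.clamp(eps, 1 - eps)
    kl = p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))
    return kl.mean()


def segmentation_loss(pred_l, mask_l, pred_u, alpha=1.0, beta=1.0):
    """alpha * (Dice on labeled images + beta * KL(labeled preds || unlabeled preds))."""
    return alpha * (dice_loss(pred_l, mask_l) + beta * bernoulli_kl(pred_l, pred_u))


if __name__ == "__main__":
    torch.manual_seed(0)
    mask = torch.zeros(2, 1, 16, 16)
    mask[:, :, 4:12, 4:12] = 1.0  # toy square "lung" region
    pred_l, pred_u = torch.rand(2, 1, 16, 16), torch.rand(2, 1, 16, 16)
    print(segmentation_loss(pred_l, mask, pred_u).item())
```

The smoothing constant in `dice_loss` keeps the loss defined when both the prediction and the mask are empty.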
Our model is trained on the combined objective of the classification and segmentation losses. That wraps up both the segmentation method and the entire methods section.
The models were trained and tested on the classification and segmentation tasks, with the data for each task coming from a different source: a pneumonia detection dataset, which we will call CheX, for classification, and the Japanese Society of Radiological Technology dataset, or JSRT, for segmentation. When we refer to in-domain datasets, we mean these two.
Importantly, we also validated the models on two external datasets, one for each task: the Montgomery County chest X-rays, or MCU, for segmentation, and a subset of the NIH chest X-ray dataset, which we will refer to as NIHX, for classification. The diversity of these sources poses a significant challenge for our model, as image quality, image size, the proportion of normal and abnormal images, and intensity distributions vary considerably across the four datasets. The figures below show the differences in intensity distribution along with example images from each dataset. All four datasets are licensed under CC BY 4.0.
We conducted many experiments with varying quantities of labeled data on multiple datasets, both in-domain and cross-domain.
To preface the results, we used multiple baselines in our tests, as we have a baseline for each addition to our model. We start with a barebones U-Net and a standard classifier (enc), which is the encoder feature extractor with dense layers. We then combined the two for our baseline multi-tasking model (UMTL). We also used an encoder with the semi-supervised method (EncSSL), a multi-tasking model with the saliency bridge (UMTLS), and the multi-tasking model with the saliency bridge and the proposed semi-supervised method (UMTLS-SSL), which is basically MultiMix without KL divergence for semi-supervised segmentation. Then we of course have MultiMix.
For training, we used multiple levels of labeled data. For classification, we used 100, 1000, and all of the labels, and for segmentation, we used 10, 50, and all of the labels. In the results, we use the notation model-seglabels-classlabels (e.g., MultiMix-10-100). For evaluation, we used accuracy (Acc) and F1 scores (F1-N and F1-P) for classification, and for segmentation, we used the Dice similarity (DS), Jaccard similarity score (JS), structural similarity index measure (SSIM), average Hausdorff distance (HD), precision (P), and recall (R).
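For reference, DS and JS on binarized masks can be computed as in this short sketch; the 0.5 threshold and the convention of returning 1.0 for two empty masks are assumptions, not necessarily the choices used in the paper’s evaluation.

```python
import torch


def dice_and_jaccard(pred, target, threshold=0.5):
    """Dice similarity (DS) and Jaccard similarity (JS) between binarized masks."""
    p = pred >= threshold
    t = target.bool()
    inter = (p & t).sum().item()
    size_p, size_t = p.sum().item(), t.sum().item()
    union = size_p + size_t - inter
    dice = 2 * inter / (size_p + size_t) if size_p + size_t > 0 else 1.0
    jaccard = inter / union if union > 0 else 1.0
    return dice, jaccard


if __name__ == "__main__":
    target = torch.zeros(8, 8)
    target[0:4, 0:4] = 1     # 16-pixel square
    pred = torch.zeros(8, 8)
    pred[0:4, 2:6] = 0.9     # same-size square shifted right by 2: 8 overlapping pixels
    print(dice_and_jaccard(pred, target))  # (0.5, 0.3333333333333333)
```

In the example, the two 16-pixel squares overlap in 8 pixels, giving DS = 2 × 8 / 32 = 0.5 and JS = 8 / 24 ≈ 0.333. For a single pair of masks the two scores are linked by JS = DS / (2 - DS).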
The following figure is a table of the performance of MultiMix against multiple baselines. The best fully-supervised scores are underlined and the best semi-supervised scores are bolded.
The table shows how model performance improves with the inclusion of each novel component. For classification, our confidence-based augmentation approach for semi-supervised learning significantly improved performance compared to the baseline models. Even with the minimum labeled data for each task, MultiMix-10-100 outperforms the fully supervised baseline encoder in terms of accuracy. For segmentation, the saliency bridge module yields large improvements over the baseline U-Net and UMTL models. Even with the minimum number of segmentation labels, we see a 30% performance gain over the baseline counterparts, demonstrating the effectiveness of MultiMix.
We focus heavily on generalization, and our results show that our model generalizes quite well. MultiMix consistently performs well in cross-domain evaluations, with improved generalizability in both tasks. As the table shows, the cross-domain segmentation performance of MultiMix is as promising as its in-domain performance. For classification, MultiMix achieved better cross-domain scores than all the baseline models. Due to the significant differences between the NIHX and CheX datasets discussed earlier, these scores are not as good as the in-domain results, but MultiMix still performs better than the other models.
The next figure is a box plot showing the consistency of our segmentation results in both in-domain and cross-domain evaluations. It compares the per-image Dice scores of each model across the dataset. The plot shows that MultiMix is the strongest model compared to the baselines.
The final figure we will discuss visualizes the segmentation predictions of our model. For each of the proposed segmentation additions, we show the predicted boundary against the ground truth at varied amounts of labeled data, for both in-domain and cross-domain evaluations. MultiMix’s boundary predictions agree strongly with the ground truth boundaries, especially compared to the baseline. For cross-domain evaluation in particular, MultiMix is the best by a substantial margin, demonstrating its strong ability to generalize.
Now that we have covered the methods and the results, let’s get into the code. I will mostly go over the model architecture and the training loop, as those are the main areas of contribution. Note that the code is written in Python using PyTorch.
Let’s start by checking out our convolutional blocks.
Each block is a double-convolutional block. We start with a 2d convolutional layer with a kernel size of 3, and then we use an instance normalization layer and an activation function of LeakyReLU with a negative slope of 0.2. We then repeat this sequence again to finish the convolution block.
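A minimal PyTorch sketch of such a block is shown below. The name `double_conv` follows the architecture description later in this section; `padding=1`, which keeps the spatial size unchanged, is an assumption about a detail the text does not spell out.

```python
import torch
import torch.nn as nn


def double_conv(in_channels: int, out_channels: int) -> nn.Sequential:
    """Two stages of (3x3 conv -> InstanceNorm -> LeakyReLU with slope 0.2)."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    )


if __name__ == "__main__":
    block = double_conv(1, 16)
    x = torch.randn(2, 1, 64, 64)
    print(block(x).shape)  # torch.Size([2, 16, 64, 64])
```

Instance normalization normalizes each image separately, so the block behaves the same regardless of how images are grouped into a batch.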
Now let’s take a look at the saliency bridge.
This code generates the saliency map. We pass in the inputs, the encoder, and the optimizer. We first create a copy of the images so that the gradient computation does not modify the original images. We then set requires_grad to True on the input copy and put the encoder in eval mode. Next, we get the feature maps and output of the encoder so we can generate the saliency maps: we take the index of the maximum classification output and call .backward() on that class score to compute the gradients with respect to the input. The saliency map is the absolute value of those input gradients, obtained with .abs(). Importantly, we have to zero the optimizer’s gradients afterwards, because the .backward() call also computes gradients for the encoder’s parameters, which would interfere with the next parameter update.
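The following is a minimal sketch of these steps, not the original code. The function name `compute_saliency`, the assumption that the encoder returns `(feature_maps, logits)`, collapsing the input channels with a max to get a one-channel map, and restoring the encoder’s previous train/eval mode are all assumptions; `ToyEncoder` is a hypothetical stand-in used only to make the example runnable.

```python
import torch
import torch.nn as nn


def compute_saliency(images, encoder, optimizer):
    """Vanilla-gradient saliency: |d(score of predicted class) / d(input)|."""
    x = images.detach().clone()          # copy so the caller's tensor is never touched
    x.requires_grad_(True)
    was_training = encoder.training
    encoder.eval()

    _, logits = encoder(x)               # assumed to return (feature_maps, class_logits)
    predicted = logits.argmax(dim=1)     # index of the maximum class score per image
    score = logits.gather(1, predicted.unsqueeze(1)).sum()
    score.backward()                     # fills x.grad and the parameters' .grad fields

    saliency = x.grad.abs().amax(dim=1, keepdim=True)  # (B, 1, H, W)
    optimizer.zero_grad()                # discard parameter gradients from this pass
    encoder.train(was_training)
    return saliency.detach()


class ToyEncoder(nn.Module):
    """Hypothetical stand-in encoder used only to make the example runnable."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.head = nn.Linear(8, num_classes)

    def forward(self, x):
        features = torch.relu(self.conv(x))
        return [features], self.head(features.mean(dim=(2, 3)))


if __name__ == "__main__":
    encoder = ToyEncoder()
    optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)
    images = torch.rand(4, 1, 32, 32)
    print(compute_saliency(images, encoder, optimizer).shape)  # torch.Size([4, 1, 32, 32])
```

Summing the predicted-class scores over the batch before calling `.backward()` still yields a separate gradient for each image, as long as the images do not interact in the forward pass (true for per-sample normalization such as InstanceNorm, and for BatchNorm in eval mode).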
Now that we have covered the components of the architecture, let’s put it all together and check out the entire architecture.
We split the model into separate encoder and decoder modules and combine them in the MultiMix class. In the encoder, we stack double_conv blocks, doubling the number of channels at each block. In the forward function, we save the feature maps after each convolution block for the skip connections between the encoder and decoder, and we use max-pooling layers to downsample the image. We then add the classification branch for multi-tasking, using an average pooling layer and a dense layer to produce the final classification output (outC). The encoder returns all the feature maps, for the decoder to use, as well as the classification prediction.
In the decoder, we use convolution layers that reduce the number of feature channels and upsampling layers that restore the image resolution. The forward function is where all the magic happens. We start by concatenating the saliency map with the original image. We then downsample this input so that it can be concatenated, along with the skip connection, at the first convolutional block. For the remaining convolutional blocks, we perform standard upsampling and skip connections to get the final output (out).
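Putting the encoder and decoder descriptions together gives roughly the following sketch. Only the overall structure follows the text (double_conv blocks, max pooling, skip connections, an average-pooling plus dense classification branch, and the image-saliency bridge concatenated at the first decoder stage). The module names `Encoder` and `Decoder`, the depth of 4, the base width of 16, bilinear upsampling, the sigmoid output, and passing a precomputed one-channel saliency map into `forward` are assumptions. With this depth, input height and width must be divisible by 8.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def double_conv(in_ch, out_ch):
    """(3x3 conv -> InstanceNorm -> LeakyReLU(0.2)) x 2."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.InstanceNorm2d(out_ch), nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.InstanceNorm2d(out_ch), nn.LeakyReLU(0.2, inplace=True),
    )


class Encoder(nn.Module):
    def __init__(self, in_ch=1, base=16, depth=4, num_classes=2):
        super().__init__()
        widths = [base * 2 ** i for i in range(depth)]  # channels double at each block
        self.blocks = nn.ModuleList(
            double_conv(c_in, c_out) for c_in, c_out in zip([in_ch] + widths[:-1], widths)
        )
        self.pool = nn.MaxPool2d(2)
        self.classifier = nn.Linear(widths[-1], num_classes)  # classification branch

    def forward(self, x):
        feats = []
        for i, block in enumerate(self.blocks):
            x = block(x)
            feats.append(x)                   # saved for skip connections
            if i < len(self.blocks) - 1:
                x = self.pool(x)              # downsample between blocks
        outC = self.classifier(F.adaptive_avg_pool2d(x, 1).flatten(1))
        return feats, outC


class Decoder(nn.Module):
    def __init__(self, in_ch=1, base=16, depth=4, out_ch=1):
        super().__init__()
        widths = [base * 2 ** i for i in range(depth)]
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.blocks = nn.ModuleList()
        prev = widths[-1]
        for i, skip in enumerate(reversed(widths[:-1])):
            bridge = in_ch + 1 if i == 0 else 0  # image + 1-channel saliency, first stage only
            self.blocks.append(double_conv(prev + skip + bridge, skip))
            prev = skip
        self.head = nn.Conv2d(prev, out_ch, kernel_size=1)

    def forward(self, feats, image, saliency):
        bridge = torch.cat([image, saliency], dim=1)  # stack saliency map with the input image
        x = feats[-1]
        for i, (block, skip) in enumerate(zip(self.blocks, feats[-2::-1])):
            x = self.up(x)
            parts = [x, skip]
            if i == 0:  # downsample the bridge to the first decoder stage's resolution
                parts.append(F.interpolate(bridge, size=skip.shape[-2:], mode="bilinear", align_corners=False))
            x = block(torch.cat(parts, dim=1))
        return torch.sigmoid(self.head(x))           # lung-mask probabilities


class MultiMix(nn.Module):
    def __init__(self, in_ch=1, base=16, depth=4, num_classes=2):
        super().__init__()
        self.encoder = Encoder(in_ch, base, depth, num_classes)
        self.decoder = Decoder(in_ch, base, depth)

    def forward(self, image, saliency):
        feats, outC = self.encoder(image)
        out = self.decoder(feats, image, saliency)
        return out, outC


if __name__ == "__main__":
    model = MultiMix()
    image, saliency = torch.rand(2, 1, 64, 64), torch.rand(2, 1, 64, 64)
    out, outC = model(image, saliency)
    print(out.shape, outC.shape)  # torch.Size([2, 1, 64, 64]) torch.Size([2, 2])
```

Keeping the saliency computation outside the model’s forward pass, as here, makes it easy to test the architecture with any one-channel map; during training, the map would come from the encoder’s gradients as described above.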
Once we have built the model, we can build our training loop. This is a lengthy and intimidating block of code, but don’t worry, we will break it down.
Before we discuss the loop, note that we have left out many of the helper methods and parts of the training loop for simplicity.
At the start of the training loop, we combine all the training datasets: the supervised segmentation training set, the unlabeled segmentation training set, the supervised classification training set, the weakly augmented classification set, and the strongly augmented classification set. The latter two contain exactly the same images, augmented at different strengths. The next few lines simply slice and concatenate the data so that all of it passes through the model together.
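The batching step can be sketched as follows: concatenate every batch, run a single forward pass, and split each output back by the original batch sizes. `forward_combined` and `ToyMultiTask` are hypothetical names, and the toy model only stands in for a network that returns segmentation and classification outputs.

```python
import torch
import torch.nn as nn


class ToyMultiTask(nn.Module):
    """Hypothetical stand-in returning (segmentation, classification) outputs."""

    def __init__(self):
        super().__init__()
        self.seg = nn.Conv2d(1, 1, kernel_size=3, padding=1)
        self.cls = nn.Linear(1, 2)

    def forward(self, x):
        return torch.sigmoid(self.seg(x)), self.cls(x.mean(dim=(2, 3)))


def forward_combined(model, batches):
    """Concatenate batches, run one forward pass, and split every output back."""
    sizes = [b.shape[0] for b in batches]
    outputs = model(torch.cat(batches, dim=0))
    return [torch.split(o, sizes, dim=0) for o in outputs]


if __name__ == "__main__":
    names = ["seg_labeled", "seg_unlabeled", "cls_labeled", "cls_weak", "cls_strong"]
    batch_sizes = [2, 2, 3, 4, 4]  # weak/strong views share the same 4 unlabeled images
    batches = [torch.rand(n, 1, 32, 32) for n in batch_sizes]
    seg_parts, cls_parts = forward_combined(ToyMultiTask(), batches)
    seg_out, cls_out = dict(zip(names, seg_parts)), dict(zip(names, cls_parts))
    print(seg_out["seg_unlabeled"].shape, cls_out["cls_strong"].shape)
    # torch.Size([2, 1, 32, 32]) torch.Size([4, 2])
```

One forward pass over the concatenated batch is efficient, and with per-sample normalization such as InstanceNorm, mixing images from different sets in one batch does not change how any single image is normalized; this would not hold for BatchNorm in training mode. All images must share the same spatial size to be concatenated.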
Once we pass all the inputs into the model, we pass them all to the calc_loss function. In the calc_loss function, we start by getting the basic supervised classification and segmentation loss (dice and lossClassifier). We use dice loss for segmentation and cross-entropy for classification.
For semi-supervised classification, we begin by passing the weakly augmented image predictions through a softmax function to get the probabilities, and we use the torch.max function to get the label. Then we use the .ge function to only keep the predictions that are above the confidence threshold, which is an important factor discussed in the methods. Then we compute the unsupervised classification loss (lossUnlabeled).
Lastly, we compute the KL divergence using the labeled and unlabeled segmentation predictions (kl_seg). Once all the computations are complete, we combine them into a single loss calculation by summing all the losses after they are multiplied by their respective weights (lambda, alpha, beta). Once this is passed back to the main train loop, we simply compute the gradients with loss.backward() and update the parameters of the model with optimizer.step().
That concludes the code review section. We didn’t go over the augmentation and data-processing portion as it is quite tedious. If you are interested, check out the full code at the following repo: https://github.com/ayaanzhaque/MultiMix
Conclusion and Thoughts:
In this blog post, we explain MultiMix, a novel sparingly supervised, multitask learning model for jointly learning classification and segmentation tasks. Through the incorporation of consistency augmentation and a novel saliency bridge module for better explainability, MultiMix performs improved and consistent pneumonia detection and lung segmentation even when trained on limited-labeled data and multi-source data. Our extensive experimentation using four different chest X-ray datasets truly demonstrates the effectiveness of MultiMix both in in-domain and cross-domain evaluations, for either task. Our future work will focus on further improving MultiMix’s cross-domain performance, especially for classification. We are currently preparing a full journal submission with more results and extensions of the work.
Doing this work was really exciting for me. As a high schooler, I am grateful for the opportunity to work with qualified and experienced research faculty on cutting-edge research. The entire process was quite challenging for me, as I had little experience writing formal research papers and conducting proper, convincing experiments. Even coding and building the actual additions took quite a bit of time, since I was still getting familiar with PyTorch, but working on this project was very fun and exciting, and I learned so much about deep learning and medical imaging. I am thrilled about the conference and the chance to meet fellow researchers and learn about new research in the field, and I am confident that our future work will see the same success as this project. Thank you for reading.
If you found any part of this blog or the paper interesting, please consider citing:
title = "Generalized Multi-Task Learning from Substantially Unlabeled Multi-Source Medical Image Data",
author = "Haque, Ayaan and Imran, Abdullah-Al-Zubaer and Wang, Adam and Terzopoulos, Demetri",
journal = "Machine Learning for Biomedical Imaging",
volume = "1",
issue = "October 2021 issue",
year = "2021"