Supercharge Image Classification with Transfer Learning




The field of computer vision has existed since roughly the 1960s, when the initial key objective was to build an artificial system that mimics human visual capabilities. Over the last decade, deep learning models like CNNs have driven a resurgence in the field, owing to a number of factors including better algorithms, faster compute (GPUs), easy-to-use software and tools (TensorFlow, PyTorch) and greater data availability.

Over time, CNNs have grown large and complex, but also good enough to be the go-to solutions for most computer vision tasks (especially classification). Yet not everyone (or every project) has the resources (or the feasibility) to run a large GPU cluster for these state-of-the-art behemoth models. Enter Transfer Learning!

Put simply, transfer learning is a method of leveraging such state-of-the-art models (pre-trained on huge datasets) for specific use cases without having to prepare large datasets or get access to the latest GPU setup on the planet. For an in-depth understanding of transfer learning, check out [1] and [2].

Let’s dive into a hands-on example showcasing transfer learning in the context of image classification (categorisation). The objective here is to take a few sample images of animals and see how some canned, pre-trained models fare in classifying them. We will pick a couple of pre-trained state-of-the-art models of differing complexity, to compare and contrast how they interpret the true category of the input images.

Methodology

The key objective here is to take a pre-trained model off the shelf and use it directly to predict the class of an input image. We focus on inference to keep things simple, without diving into how to train or fine-tune these models. Our methodology takes an input image, loads a pre-trained model from TensorFlow Hub in Python and predicts the top-5 most probable classes for that image. This workflow is depicted in figure 1.

Figure 1 Image Classification with Pre-trained CNNs. The figure depicts the top-5 class probabilities for the given input image using a pre-trained CNN model. Image Source: by author

Pre-trained Model Architectures

For our experiment, we will be leveraging two state-of-the-art pre-trained convolutional neural network (CNN) models, namely:

  • ResNet-50: This is a residual deep convolutional neural network (CNN) with a total of 50 layers, built from the standard convolution and pooling layers of a typical CNN along with batch normalisation layers that help stabilise training. The key novelty of this model is its residual or skip connections. This model was trained on the standard ImageNet-1K dataset, which has a total of 1000 distinct classes.
  • BiT MultiClass ResNet-152 4x: This model belongs to Google’s Big Transfer (BiT) family of state-of-the-art (SOTA) computer vision models, released in May 2020. Here, Google built its flagship model architecture: a pre-trained ResNet-152 model (152 layers) that is four times wider than the original. This model uses group normalisation layers instead of batch normalisation. The model was trained on the ImageNet-21K dataset [3], which has a total of 21843 classes.
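To make the difference concrete, here is a minimal pure-Python sketch of group normalisation for a single image, plus the weight standardisation that BiT pairs it with (described later in this article). The function names and the nested-list layout are illustrative assumptions, not a TensorFlow or BiT API, and the learnable scale and shift parameters that real layers add are left out. The key point is that the statistics come from groups of channels within one image, so, unlike batch normalisation, the result does not depend on the batch size.

```python
import math
from typing import List

# Hypothetical helpers for illustration only; not a TensorFlow or BiT API.


def group_norm(x: List[List[float]], num_groups: int, eps: float = 1e-5) -> List[List[float]]:
    """Group-normalise ONE sample.

    x[c] holds the (flattened) spatial values of channel c. Channels are split
    into `num_groups` equal groups and each group is normalised with its own
    mean and variance. Learnable scale/shift parameters are omitted.
    """
    num_channels = len(x)
    if num_channels % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    per_group = num_channels // num_groups
    out: List[List[float]] = []
    for g in range(num_groups):
        group = x[g * per_group:(g + 1) * per_group]
        values = [v for channel in group for v in channel]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        inv_std = 1.0 / math.sqrt(var + eps)
        out.extend([[(v - mean) * inv_std for v in channel] for channel in group])
    return out


def standardize_weights(filter_weights: List[float], eps: float = 1e-5) -> List[float]:
    """Weight Standardisation: rescale one conv filter's weights to zero mean, unit variance."""
    mean = sum(filter_weights) / len(filter_weights)
    var = sum((w - mean) ** 2 for w in filter_weights) / len(filter_weights)
    return [(w - mean) / math.sqrt(var + eps) for w in filter_weights]


# 4 channels with 2 spatial positions each, normalised in 2 groups of 2 channels.
sample = [[1.0, 2.0], [3.0, 4.0], [10.0, 10.0], [20.0, 20.0]]
normalised = group_norm(sample, num_groups=2)
```

Because no batch statistics are involved, the same image produces the same output whether it is processed alone or alongside others.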

The foundational architecture behind both models is a convolutional neural network (CNN) which works on the principle of leveraging a multi-layered hierarchical architecture of several convolution and pooling layers with non-linear activation functions.

Convolutional neural networks

Let’s look at the essential components of CNN models. Typically, a convolutional neural network (CNN) consists of a layered architecture that includes convolution, pooling and dense layers besides the input and output layers. A typical architecture is depicted in figure 2.

Figure 2 Architecture of a typical Convolutional Neural Network. This usually includes a stacked hierarchy of convolution and pooling layers. Image Source: by author

The CNN model leverages convolution and pooling layers to automatically extract different hierarchies of features, from very generic features like edges and corners to very specific features like the facial structure, whiskers and ears of the tiger depicted as an input image in figure 3. The feature maps are usually flattened using a flatten or global pooling operator to obtain a 1-dimensional feature vector. This vector is then sent as an input through a few fully-connected dense layers, and the output class is finally predicted using a softmax output layer.
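The two ways of collapsing feature maps into a 1-dimensional vector mentioned above can be sketched in plain Python. This is an illustration under simplifying assumptions (a single image, feature maps stored as nested lists indexed [channel][row][col]), not how TensorFlow implements these layers.

```python
from typing import List

FeatureMaps = List[List[List[float]]]  # [channel][row][col]


def flatten(maps: FeatureMaps) -> List[float]:
    """Concatenate every value of every feature map into one long vector."""
    return [v for channel in maps for row in channel for v in row]


def global_average_pool(maps: FeatureMaps) -> List[float]:
    """Average each feature map down to a single number (one per channel)."""
    out = []
    for channel in maps:
        values = [v for row in channel for v in row]
        out.append(sum(values) / len(values))
    return out


# Two 2x2 feature maps, e.g. the output of a tiny convolution stage.
maps = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [4.0, 8.0]],
]
print(flatten(maps))              # [1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 4.0, 8.0]
print(global_average_pool(maps))  # [2.5, 3.0]
```

Flattening keeps all C × H × W values, while global average pooling keeps one value per channel; the latter is the operation ResNet-50 applies just before its final dense layer.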

The key objective of this multi-stage hierarchical architecture is to learn spatial hierarchies of patterns which are also translation invariant. This is possible through two main layers in the CNN architecture: the convolution and pooling layers.

Convolution Layers: The secret sauce of CNNs is their convolution layers! These layers convolve multiple filters (kernels) with patches of the input image, automatically extracting specific features from it. Stacking convolution layers in a layered architecture helps the network learn spatial features with a certain hierarchy, as depicted in figure 3.

Figure 3 Hierarchical feature maps extracted from convolutional layers. Each layer extracts relevant features for the input image. Shallower layers extract more generic features and deeper layers extract specific features pertaining to the given input image. Image Source: by author

While figure 3 provides a simplistic view of a CNN, the core idea holds: coarse and generic features like edges and corners are extracted in the initial convolution layers (giving feature maps). Combining these feature maps in deeper convolutional layers helps the CNN learn more complex visual features like the mane, eyes, cheeks and nose. Finally, the overall visual representation and concept of what a tiger looks like is built from a combination of these features.
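The convolution operation itself is simple enough to write out by hand. The minimal sketch below slides a 2×2 kernel over a tiny grayscale image (stride 1, no padding) and sums the element-wise products at each position. The image and the hand-crafted vertical-edge kernel are made-up assumptions; in a real CNN the kernel values are learned during training. Like most deep learning libraries, it computes a cross-correlation, meaning the kernel is not flipped.

```python
from typing import List

Matrix = List[List[float]]


def conv2d(image: Matrix, kernel: Matrix) -> Matrix:
    """'Valid' 2-D convolution (cross-correlation) with stride 1 and no padding."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Multiply the kernel with the image patch at (i, j) and sum.
            acc = 0.0
            for di in range(kh):
                for dj in range(kw):
                    acc += image[i + di][j + dj] * kernel[di][dj]
            row.append(acc)
        out.append(row)
    return out


# 4x4 image: dark left half, bright right half (a vertical edge).
image = [
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
]
# Hand-crafted vertical-edge kernel; a CNN would learn kernels like this from data.
kernel = [
    [-1, 1],
    [-1, 1],
]
feature_map = conv2d(image, kernel)
print(feature_map)  # [[0.0, 2.0, 0.0], [0.0, 2.0, 0.0], [0.0, 2.0, 0.0]]
```

The response is strongest exactly where dark pixels meet bright ones, which is the kind of generic edge feature the shallow layers in figure 3 pick up.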

Pooling Layers: The pooling layers downsample the feature maps from the convolutional layers using an aggregation operation like max, min or mean. Max-pooling is the usual choice: we take patches of pixels (e.g. a 2×2 patch) and reduce each patch to its maximum value (giving one pixel with the max value). Max-pooling is preferred because it is cheap to compute and because it keeps the strongest activations of the feature maps (by taking the maximal pixel value of each patch rather than the average). Pooling also helps reduce overfitting, decreases computation time and enables the CNN to learn translation-invariant features.
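The 2×2 max-pooling described above, with average pooling for comparison, can be sketched as follows. This assumes a simplified setup: a single feature map stored as a nested list, non-overlapping windows (stride equal to the window size), and any trailing rows or columns that don't fill a whole window are dropped.

```python
from typing import Callable, List, Sequence

Matrix = List[List[float]]


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def pool2d(fmap: Matrix, size: int = 2,
           reduce: Callable[[Sequence[float]], float] = max) -> Matrix:
    """Non-overlapping pooling: each size x size window is reduced to one value."""
    out = []
    for i in range(0, len(fmap) - size + 1, size):
        row = []
        for j in range(0, len(fmap[0]) - size + 1, size):
            window = [fmap[i + di][j + dj] for di in range(size) for dj in range(size)]
            row.append(reduce(window))
        out.append(row)
    return out


fmap = [
    [1, 3, 2, 0],
    [4, 6, 1, 1],
    [0, 2, 9, 5],
    [1, 1, 3, 7],
]
max_pooled = pool2d(fmap)               # [[6, 2], [2, 9]]
avg_pooled = pool2d(fmap, reduce=mean)  # [[3.5, 1.0], [1.0, 6.0]]
```

Moving the 6 to any other position inside its top-left window would leave `max_pooled` unchanged, a small-scale version of the translation invariance mentioned earlier.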

The ResNet architecture

Both of the pre-trained models we mentioned earlier are variants of the ResNet CNN architecture. ResNet stands for Residual Network; it introduced the novel concept of residual or skip connections, which make it possible to build deeper neural network models without suffering from vanishing gradients or degraded performance. A simplified version of the typical ResNet-50 architecture is depicted in figure 4.

Figure 4 ResNet-50 CNN architecture and its components. The key components include the convolution and identity block with residual (skip) connections. Image Source: by author

As figure 4 shows, the ResNet-50 architecture consists of several stacked convolutional and pooling layers followed by a final global average pooling layer and a fully connected layer with 1000 units to make the final class prediction. This model also uses batch normalisation layers interspersed between layers to help stabilise training. The stacked conv and identity blocks are novel concepts introduced in the ResNet architecture; they make use of residual or skip connections, as seen in the detailed block diagrams in figure 4.

The whole idea of a skip connection (also known as a residual or shortcut connection) is to not just stack layers, but also to connect the original input directly to the output of a few stacked layers, as seen in figure 5, where the original input is added to the output of the conv or identity block. Skip connections make it possible to build deeper networks without facing problems like vanishing gradients and saturating performance, because they give gradients alternate paths to flow through the network. We see different variants of the ResNet architecture in figure 5.

Figure 5 Various ResNet Architectures. The figure indicates the various ResNet models based on the total layers present in the model. Image Source: by author
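A toy calculation shows why skip connections help gradients flow. The sketch below is a deliberate simplification: each "layer" is a scalar linear map y = w·x with no activation, so the chain rule reduces to multiplying per-layer derivatives. In a plain stack the gradient is w^depth; with a skip connection each block computes F(x) + x, so its derivative becomes w + 1. A small forward-pass helper, `residual_block` (a hypothetical name), shows the ResNet-style ReLU(F(x) + x) as well.

```python
from typing import Callable, List


def plain_stack_grad(w: float, depth: int) -> float:
    """d(output)/d(input) for `depth` stacked scalar layers y = w * x."""
    grad = 1.0
    for _ in range(depth):
        grad *= w  # chain rule: multiply by each layer's derivative
    return grad


def residual_stack_grad(w: float, depth: int) -> float:
    """Same stack, but each layer is a residual block y = w * x + x."""
    grad = 1.0
    for _ in range(depth):
        grad *= w + 1.0  # the identity shortcut contributes the "+ 1"
    return grad


def residual_block(x: List[float], layer: Callable[[List[float]], List[float]]) -> List[float]:
    """ReLU(F(x) + x): add the block input back onto the layer output."""
    return [max(0.0, fx + xi) for fx, xi in zip(layer(x), x)]


print(plain_stack_grad(0.1, 20))     # vanishingly small (about 1e-20)
print(residual_stack_grad(0.1, 20))  # stays above 1 (about 6.7)

# If the residual branch outputs zeros, the block simply passes its input through.
print(residual_block([1.0, 2.0], lambda v: [0.0 for _ in v]))  # [1.0, 2.0]
```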

For our first pre-trained model, we will use a ResNet-50 model trained on the ImageNet-1K dataset for a multi-class classification task. Our second pre-trained model is Google’s pre-trained Big Transfer model BiT-M (the “M” refers to its medium-sized upstream dataset, ImageNet-21K), trained for multi-label classification and available in variants based on ResNet-50, 101 and 152. The model we use is based on a variant of the ResNet-152 architecture that is four times wider.
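The numbers in these names count the weighted layers, which we can sanity-check with a little arithmetic. The sketch uses the per-stage bottleneck block counts from the original ResNet paper (each bottleneck block has three convolution layers), plus the first 7×7 convolution and the final fully connected layer; pooling layers and the 1×1 projection convolutions on the shortcuts are not counted, by the usual convention. Widening a network, as in BiT’s 4x variant, multiplies the number of channels per layer but leaves this layer count unchanged.

```python
# Bottleneck blocks in each of the four stages (conv2_x .. conv5_x).
BLOCKS_PER_STAGE = {
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
    152: (3, 8, 36, 3),
}


def count_weighted_layers(blocks_per_stage) -> int:
    stem_conv = 1             # initial 7x7 convolution
    convs_per_bottleneck = 3  # 1x1 -> 3x3 -> 1x1
    classifier = 1            # final fully connected layer
    return stem_conv + convs_per_bottleneck * sum(blocks_per_stage) + classifier


for depth, blocks in BLOCKS_PER_STAGE.items():
    print(f"ResNet-{depth}: {count_weighted_layers(blocks)} weighted layers")
```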

Big Transfer (BiT) Pre-Trained Models

The Big Transfer (BiT) models were trained and released by Google in May 2020 as part of their seminal research paper [4]. These pre-trained models build on the basic ResNet architecture discussed in the previous section, with a few tricks and enhancements. The key aspects of the BiT models include the following:

  • Upstream Training: Here, large model architectures (e.g. ResNet) are trained on large datasets (e.g. ImageNet-21k) for a long pre-training time, using Group Normalisation with Weight Standardisation instead of Batch Normalisation. The general observation has been that GroupNorm with Weight Standardisation scales better to large batch sizes than BatchNorm.
  • Downstream Fine-tuning: Once pre-trained, the model can be fine-tuned and ‘adapted’ to any new dataset with relatively few samples. Google uses a hyperparameter heuristic called BiT-HyperRule, in which stochastic gradient descent (SGD) is used with an initial learning rate of 0.003, decayed by a factor of 10 at 30%, 60% and 90% of the training steps.
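The decay rule of BiT-HyperRule, as stated above, can be written as a small step function. This sketch implements only that rule (an initial rate of 0.003, divided by 10 at 30%, 60% and 90% of the total steps); the function name is hypothetical, and any other parts of Google’s fine-tuning recipe are out of scope here.

```python
def bit_hyperrule_lr(step: int, total_steps: int, base_lr: float = 0.003) -> float:
    """Learning rate at `step`: base_lr, divided by 10 at 30%, 60% and 90% of training."""
    # Integer arithmetic keeps the boundaries exact (e.g. 300, 600, 900 for 1000 steps).
    boundaries = [total_steps * pct // 100 for pct in (30, 60, 90)]
    num_decays = sum(1 for b in boundaries if step >= b)
    return base_lr / (10 ** num_decays)


for step in (0, 299, 300, 600, 900, 999):
    print(step, bit_hyperrule_lr(step, total_steps=1000))
```

A step function like this could be handed to an optimiser’s learning-rate scheduler; the decay points scale with whatever total number of steps a given dataset calls for.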

In our following experiments, we will be using the BiTM-R152x4 model which is a pre-trained Big Transfer model based on Google’s flagship CNN model architecture of a ResNet-152 which is four times wider and trained to perform multi-label classification on the ImageNet-21k dataset.

Implementation

Let’s now use these pre-trained models to solve our objective of predicting the Top-5 classes of input images.

TIP: The supporting code notebooks are available in the GitHub repository at https://github.com/dipanjanS/transfer-learning-in-action

We start by loading up the specific dependencies for image processing, modeling and inference.

Code Listing: Import TensorFlow and TensorFlow Hub

Do note that we use TensorFlow 2.x here, the latest version at the time of writing. Since we will be using the pre-trained models directly for inference, we need the class labels of the original ImageNet-1K and ImageNet-21K datasets for the ResNet-50 and BiTM-R152x4 models respectively, as depicted in listing 1.

Code Listing 1: View sample class labels from the ImageNet-1K and ImageNet-21K datasets

The next step is to load the two pre-trained models we discussed earlier from TensorFlow Hub.

Code Listing: Load pre-trained ResNet and BiT models

Once our pre-trained models are ready, the next step is to build some specific utility functions, which you can access from the notebook for this article. To give some perspective:

  • The preprocess_image(...) function pre-processes and reshapes the input image and scales its pixel values to the range 0–1.
  • The visualize_predictions(...) function takes in the pre-trained model, the class label mappings, the model type and the input image as inputs to visualise the top-5 predictions as a bar chart.
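The notebook’s preprocess_image(...) is not reproduced here, but its 0–1 scaling step can be sketched in plain Python. Assumptions: the image arrives as 8-bit RGB values (0–255) in a nested [row][col][channel] list, and the model, like most Keras models, expects a leading batch dimension, which is why the result is wrapped in an outer list. Resizing is left out; in practice an image library or TensorFlow would handle all of this on arrays.

```python
from typing import List

# Hypothetical helper for illustration; not the notebook's preprocess_image(...).
Image = List[List[List[int]]]  # [row][col][channel], values 0..255


def scale_to_unit_range(image: Image) -> List[List[List[List[float]]]]:
    """Scale 8-bit pixel values into [0, 1] and add a batch dimension of size 1."""
    for row in image:
        for pixel in row:
            for value in pixel:
                if not 0 <= value <= 255:
                    raise ValueError(f"expected 8-bit pixel values, got {value}")
    scaled = [[[value / 255.0 for value in pixel] for pixel in row] for row in image]
    return [scaled]  # shape becomes [1, rows, cols, channels]


tiny_image = [[[0, 128, 255], [51, 51, 51]]]  # 1 row x 2 pixels, RGB
batch = scale_to_unit_range(tiny_image)
```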

The ResNet-50 model directly gives class probabilities as outputs, but the BiTM-R152x4 model gives class logits as outputs, which need to be converted to class probabilities. Listing 2 shows the section of the visualize_predictions(...) function that achieves this.

Code Listing 2: visualize_predictions(...) function, which gets model probabilities for the top-5 predictions

Remember that logits are unnormalised log-probabilities (raw class scores), so you need to compute the softmax of these logits to get normalised class probabilities that sum up to 1. Figure 6 depicts this with a sample neural network architecture showing the logits and the class probabilities for a hypothetical 3-class classification problem.

Figure 6 Logits and Softmax values in a Neural Network. Image Source: by author
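The logits-to-probabilities conversion described above, followed by picking the top-5 classes, can be sketched without any framework. This is an illustrative stand-in for the relevant part of listing 2, not the author’s code: the function names, the made-up 3-class logits (in the spirit of figure 6) and the toy labels are all assumptions. Subtracting the largest logit before exponentiating does not change the result but avoids overflow for large logits.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn unnormalised logits into probabilities that sum to 1."""
    largest = max(logits)  # shift for numerical stability; the output is unchanged
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs: Sequence[float], labels: Sequence[str], k: int = 5) -> List[Tuple[str, float]]:
    """Return the k most probable (label, probability) pairs, highest first."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(labels[i], probs[i]) for i in ranked[:k]]


# Made-up 3-class example.
probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]

# Top-5 over a made-up 6-class label set.
labels = ["lion", "tiger", "cheetah", "snow leopard", "lynx", "jaguar"]
logits = [0.5, 3.0, -1.0, 2.0, 1.0, 0.0]
for label, p in top_k(softmax(logits), labels):
    print(f"{label}: {p:.3f}")
```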

The softmax function squashes the logits using the transform depicted in figure 6 to give us the normalised class probabilities. Let’s now put our code into action! You can apply these functions to any downloaded image using the sequence of steps depicted in listing 3 to visualise the top-5 predictions of our two pre-trained models.

Code Listing 3: Analyse an input image and visualise model predictions

Voila! We have the Top-5 predictions from our two pre-trained models depicted in a nice visualisation in figure 7.

Figure 7 Prediction results on a Snow Leopard Image. Image Source: by author

It looks like both our models performed well and, as expected, the BiTM model is more specific and more accurate, given that it has been trained on over 21K classes covering very specific animal species and breeds.

The ResNet-50 model is less consistent than the BiTM model when predicting animals of the same genus but different species, such as tigers and lions, as depicted in figure 8.

Figure 8 Correct vs. Incorrect predictions of the BiTM and ResNet-50 models. Image Source: by author

Limitations and Possible Improvements

Another aspect to keep in mind is that these models are not exhaustive: they don’t cover every entity on the planet. That would be impossible, considering that collecting the data for such a task would take centuries, if not forever! Figure 9 showcases an example where our models try to predict a very specific dog breed, the Afghan Hound, from a given image.

Figure 9 Both our models struggle to predict an Afghan Hound. Image Source: by author

Based on the top-5 predictions in figure 9, you can see that while our BiTM model does get the right prediction, the prediction probability is very low, indicating that the model is not very confident (probably because it hasn’t seen many examples of this dog breed in its training data during the pre-training phase). This is where we can fine-tune and adapt our models to our specific datasets, output labels and outcomes. This forms the basis of a number of applications where we leverage pre-trained models and adapt them to very different and novel problems. One such interesting use case is covered in the article “supercharge your image search with transfer learning”.

Summary

  • Image classification is one of the most popular and well-researched applications of deep learning, as well as of transfer learning.
  • Convolutional Neural Networks (CNNs) are extremely powerful when it comes to extracting features from image inputs.
  • Residual Networks (ResNets) and their variants have proven performance on large multi-class datasets such as ImageNet.
  • The latest ResNet variant from Google, the BiT model, is extremely powerful and provides state-of-the-art performance for image classification tasks.
  • We easily leveraged pre-trained ResNets from TensorFlow Hub to understand the ability of such models to be transfer-learnt/fine-tuned on new datasets.
