PyTorch Multi-Weight Support API Makes Transfer Learning Trivial Again

The Old Ways: A Journey on Foot

To have a solid understanding of what is happening, we will examine the old ways first. We won’t train a model, but we’ll do almost everything else:

  • Load a pre-trained version of a neural network architecture
  • Preprocess the dataset
  • Use the neural network to get the predictions on a test set
  • Use the metadata of your dataset to get a human-readable result

The following snippet summarizes what you need to do to tick all the boxes of the list above:

In this example, you first load the ResNet neural network architecture. You set the pretrained flag to True to tell PyTorch that you do not want it to initialize the model’s weights randomly. Instead, it should use the weights obtained by training the model on the ImageNet dataset.

Then, you define a composition of data transformations, so that, before feeding the image to the model, PyTorch will:

  • Resize the image to 224x224 pixels
  • Convert it to a tensor
  • Convert each value of the tensor to type float
  • Normalize it using a set of given means and standard deviations

Next, you are ready to process the image and pass it through the neural network layers to get your output. This step is the most straightforward one.

Finally, you want to print the result in a human-readable way. This means that you do not want to print that your model predicted that this image is class 4. That does not mean anything to you or the users of your application. You want to print that your model predicted that the image displays a dog with 95% confidence.

To do that, you first need to load your metadata file containing the names of the classes and pinpoint the right one for your prediction.
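Once the class names from the metadata file are in a list, turning raw model outputs into something like "dog: 95.0%" is a softmax followed by a lookup. A minimal sketch, where `format_top_k` is a hypothetical helper and the logits are assumed to be a 1-D tensor for a single image:

```python
import torch


def format_top_k(logits, categories, k=3):
    """Map one image's raw logits to 'name: xx.x%' strings, best first."""
    probs = torch.softmax(logits, dim=-1)  # logits -> probabilities summing to 1
    scores, ids = probs.topk(k)            # k highest probabilities and their indices
    return [f"{categories[i]}: {100 * s:.1f}%"
            for s, i in zip(scores.tolist(), ids.tolist())]


# Toy example with three classes instead of ImageNet's 1,000.
print(format_top_k(torch.tensor([0.0, 5.0, 1.0]), ["cat", "dog", "fox"], k=2))
```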

I admit that this script is not too much work; however, it has two drawbacks:

  1. A minor change in how you process your test dataset can lead to errors that are hard to debug. Remember, you should transform your test dataset the same way you processed your training dataset. If you don’t know how you did that or someone else ran the training procedure, you’re screwed.
  2. You must always carry a metadata file with you. Again, any tampering with this file can lead to unexpected results. These are errors that can have you bumping your head on the keyboard for days.

Let’s now see how a new PyTorch API makes all this better.

Moving Forward

As we saw before, you have two sore points to address: (i) always process your train and test subsets the same way, (ii) eliminate the need to carry a separate metadata file.

Let’s see how a new PyTorch API addresses these challenges with an example:

Immediately you see that the script is significantly smaller. But this is not the point; the script is smaller because you did not have to define a preprocess function from scratch, and you did not have to load any metadata file.

The key change is in how you load the model. Instead of telling PyTorch that you need a pre-trained version of ResNet, you first select a weights object and then use it to instantiate the model. This weights object carries the preprocessing transforms that match how the model was trained and evaluated, as well as the dataset’s metadata, such as the class names. This is awesome!

Another observation is that you can easily choose the weights to preload now. For example, you could choose to load a different set of weights with minor changes to your code:

```python
# New weights with accuracy 80.674%
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V2)
```

Or ask for the ones that yielded the best result on ImageNet:

```python
# Best available weights (currently alias for ImageNet1K_V2)
model = resnet50(weights=ResNet50_Weights.default)
```

Conclusion

Fine-tuning a Deep Learning (DL) model has never been more straightforward. Modern DL frameworks like TensorFlow and PyTorch make this a trivial task.

However, there are still some pitfalls to avoid. Most notably, you must process your training and test subsets the same way, and you should avoid having to carry a separate metadata file around.

Also, what happens if you need to use a different set of weights as your starting point, or if you already have a set of weights and want to share them through a central repository? Is there an easy way to achieve that?

As you’ve seen, the new PyTorch multi-weight support API addresses these challenges. You can experiment with it by installing the nightly builds of PyTorch and TorchVision (the API lives in TorchVision), and you can provide feedback on this GitHub issue.

About the Author

My name is Dimitris Poulopoulos, and I’m a machine learning engineer working for Arrikto. I have designed and implemented AI and software solutions for major clients such as the European Commission, Eurostat, IMF, the European Central Bank, OECD, and IKEA.

If you are interested in reading more posts about Machine Learning, Deep Learning, Data Science, and DataOps, follow me on Medium, LinkedIn, or @james2pl on Twitter.

Opinions expressed are solely my own and do not express the views or opinions of my employer.

