Building a classifier on your own dataset
What we see as images on our screens are interpreted by the computer as grids (matrices) of numbers. So, to load images in a form the computer can understand, fastai provides a couple of classes: the DataBlock and the DataLoaders. It is essential to understand these two classes in detail to be able to feed our data to the model, let it do its weight updates and come up with a function that approximates the relation between the inputs (images) and the outputs (targets).
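To make the "images are matrices" idea concrete, here is a minimal sketch assuming NumPy; the 2x2 image is made up for illustration. A color image is stored as height x width x 3 numbers (one per red, green and blue channel), usually as integers from 0 to 255, which are commonly rescaled to floats before reaching a model. PyTorch, which fastai builds on, stores image tensors channels-first, i.e. (channels, height, width).

```python
import numpy as np

# A toy 2x2 RGB "image": shape is (height, width, channels), values are 0-255 integers.
img = np.array(
    [[[255, 0, 0], [0, 255, 0]],        # red pixel, green pixel
     [[0, 0, 255], [255, 255, 255]]],   # blue pixel, white pixel
    dtype=np.uint8,
)

# Models usually work with floats, so pixel values are commonly rescaled to [0, 1].
img_float = img.astype(np.float32) / 255.0

print(img.shape)         # (2, 2, 3)
print(img_float[1, 1])   # [1. 1. 1.]
```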
First, then, let's understand what a DataBlock is. A DataBlock is a template that tells the DataLoaders (which are eventually built from it) the following five things.
- blocks: What are our inputs and outputs? For image classification, we are mapping an image to a category, so our input will be an ImageBlock and our output will be a CategoryBlock. For text classification, we would do a text-to-category mapping, so we would have a TextBlock as the input and a CategoryBlock as the output. If we wish to have multiple inputs, we list all the blocks and tell fastai how many of them are inputs with the n_inp argument. For example, if we wish to pass two images and check whether they're the same or different in some aspect, we can have
blocks = (ImageBlock, ImageBlock, CategoryBlock) together with n_inp = 2.
- get_items: We have to tell fastai how to get the inputs. fastai has a prebuilt function, get_image_files, which collects all the images in a folder and its subfolders into a list (well, the most common formats such as .jpg are picked up; formats such as .webp may not be).
- splitter: Any sufficiently flexible machine learning model can map a set of training inputs perfectly to their outputs simply by memorizing them. However, such a model will not perform well on new data at run time, because it has not learnt the underlying concept; it has only memorized which input goes with which output. To detect this, we hold back part of the data as a validation set that the model never trains on, and the splitter specifies how to split the data into training and validation sets. fastai has its own set of splitters for this. The most common is the RandomSplitter, which takes the fraction of the data to use for validation (valid_pct, e.g. 0.2 for 20%), puts that fraction into the validation set and the rest into the training set.
If your training and validation folders are explicitly defined in your folder structure, as shown below, you can choose to use the GrandparentSplitter instead.
- train
  - bugs bunny
  - donald duck
  - mickey mouse
- valid
  - bugs bunny
  - donald duck
  - mickey mouse
By default, the GrandparentSplitter assumes your training set is called train and your validation set is called valid. If that's not the case, you can instantiate it with the names of your folders through its train_name and valid_name arguments.
- get_y: This specifies how to obtain the category (label) that should be assigned to each input in the ImageBlock. With a folder structure like the one above, the folder that directly contains each image (the second-to-last component of its path) is always the class name. When that's the case, you can use the parent_label function provided by fastai to specify that's where the targets should be picked from.
- item_tfms: When we download images from the web, we observe that not all of them have the same resolution. Some are thumbnails, some are gigantic poster-sized images, some are reasonably sized images, etc. But when we feed our data to the model, we need to make sure they're all the same size, because we need to batch them together before sending them to the model for training. This is where item_tfms comes in handy: any manipulation you want to apply to each individual input before it is sent to the model can be done here. We just use Resize to ensure that what goes in is consistent. There are several ways to resize a picture, such as squishing, padding or cropping; with the default method, Resize crops the images: a center crop on the validation set, while on the training set the crop position is picked at random.
Now we're finally ready to define our DataBlock.
This template can now be reused to create new templates or to define DataLoaders. Unlike a DataBlock, which is only a blueprint, DataLoaders are the actual mechanism that passes data to a model. We can create DataLoaders out of the DataBlock defined above as follows:
dls = cartoons.dataloaders(path, bs=32)
Then we can see the data as it's loaded by calling the show_batch method of the DataLoaders, i.e. dls.show_batch().
Here bs is the batch size, which means how many images the model sees at once. If we don't pass it, fastai uses a default (64), but it's better to set it based on your own hardware, as larger batches might not fit into the memory of smaller GPUs. So be sure to check that out.
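To get a feel for what the batch size means in practice, here is a small sketch with hypothetical numbers (640 or 650 training images, 224x224 RGB inputs stored as 32-bit floats). It counts how many batches one epoch takes and how much memory one input batch occupies. Keep in mind that the input batch is the small part: the activations and gradients computed inside the network usually take far more GPU memory, and they also grow with the batch size.

```python
import math

def batches_per_epoch(n_items: int, bs: int, drop_last: bool = False) -> int:
    """Number of mini-batches needed to go through n_items once with batch size bs.

    If drop_last is True, a final incomplete batch is skipped.
    """
    return n_items // bs if drop_last else math.ceil(n_items / bs)

def input_batch_bytes(bs: int, channels: int = 3, height: int = 224,
                      width: int = 224, bytes_per_value: int = 4) -> int:
    """Memory used by one float32 input batch of shape (bs, channels, height, width)."""
    return bs * channels * height * width * bytes_per_value

# Hypothetical numbers for illustration only.
print(batches_per_epoch(640, 32))                   # 20
print(batches_per_epoch(650, 32))                   # 21
print(batches_per_epoch(650, 32, drop_last=True))   # 20
print(input_batch_bytes(32) / 2**20)                # 18.375 (MiB)
```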
Deep learning models are data hungry: in most cases, the more data you throw at them, the better they learn. Now that we have some data, we can create modified copies of it by changing elements like the camera angle, brightness, warp etc. This technique is called data augmentation, i.e. virtually increasing the amount of data by creating altered copies of it.
By copies, I mean that these copies are created only at run time; nothing extra is stored on disk.
batch_tfms, short for batch transforms, are applied to a whole batch at once on the GPU (graphics card), and the result is fed directly to the model at run time. So you don't have to augment the data beforehand (you can totally do it if you want to, but it's not necessary…)
In the above snippet, we can see how augmenting the same image can produce different images through flipping, random cropping, brightness changes etc., as shown in the second row.
Since we now have all the elements together, we can build our very own cartoon classifier by creating a model and passing it the DataLoaders created above. Doing that takes just two simple lines of code:
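from fastai.vision.all import *
# Train a CNN to tell these 4 different cartoons apart
# (since fastai 2.7, cnn_learner is deprecated in favour of vision_learner, which takes the same arguments)
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.fine_tune(8, freeze_epochs=3)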
We first create a cnn_learner using the resnet18 model architecture, a network pretrained on ImageNet whose weights get updated to minimize a loss on the data fed to it by dls; the accuracy metric is only reported so that we can track how well the model is doing.
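Why does a loss, rather than accuracy, drive the weight updates? Accuracy only changes when a prediction flips from wrong to right, so small weight changes usually leave it unchanged and give the optimizer nothing to follow. A loss such as cross-entropy (the loss fastai picks by default for a single-label classifier like this one) changes smoothly with the predicted probabilities. The sketch below uses two hypothetical sets of predictions, all for class 0, with identical accuracy but very different losses.

```python
import math

def cross_entropy(probs, target_idx):
    """Negative log of the probability assigned to the correct class."""
    return -math.log(probs[target_idx])

def accuracy(batch):
    """Fraction of (probs, target) pairs whose highest-probability class is the target."""
    return sum(max(range(len(p)), key=p.__getitem__) == t for p, t in batch) / len(batch)

def mean_loss(batch):
    return sum(cross_entropy(p, t) for p, t in batch) / len(batch)

# Hypothetical predictions over 3 classes; the correct class is 0 every time.
confident = [([0.9, 0.05, 0.05], 0), ([0.8, 0.1, 0.1], 0)]
hesitant = [([0.4, 0.3, 0.3], 0), ([0.5, 0.25, 0.25], 0)]

print(accuracy(confident), accuracy(hesitant))     # 1.0 1.0
print(mean_loss(confident) < mean_loss(hesitant))  # True
```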
We then use the learn.fine_tune method, which first trains for freeze_epochs epochs (here 3) with the pretrained body frozen, so that only the newly added head is updated, and then unfreezes the whole network and trains it for the given number of epochs (here 8). On the fastbook slack, there's a really beautiful answer by Henrik explaining the importance of fine-tuning a pretrained model.
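The two phases of fine_tune, and why they add up to 11 epochs here, can be sketched with a tiny hypothetical helper. This is an illustration, not fastai code; the comments show roughly what the equivalent manual fastai calls would be.

```python
def fine_tune_schedule(epochs: int, freeze_epochs: int = 1) -> list:
    """Hypothetical helper: label each epoch of a fine_tune-style run.

    Phase 1 ("frozen"): pretrained body frozen, only the new head is trained.
    Phase 2 ("unfrozen"): the whole network is trained for `epochs` more epochs.
    Roughly the manual fastai equivalent:
        learn.freeze();   learn.fit_one_cycle(freeze_epochs, ...)
        learn.unfreeze(); learn.fit_one_cycle(epochs, ...)  # with discriminative learning rates
    """
    return ["frozen"] * freeze_epochs + ["unfrozen"] * epochs

schedule = fine_tune_schedule(8, freeze_epochs=3)
print(len(schedule))   # 11
print(schedule[:4])    # ['frozen', 'frozen', 'frozen', 'unfrozen']
```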
As we can see, by the end of the 11 epochs (3 frozen + 8 unfrozen) we're at 75% accuracy with just close to 200 images per class and so much variation within each class. Not bad at all, is it? :) Especially considering how few lines of code we've written…
Now it's time to evaluate the model, i.e. identify where it mispredicted and which classes it struggled with the most.
A confusion matrix is a very common tool for assessing the performance of classification models. It is a grid of values with one row per actual class and one column per predicted class, where each cell holds the number of items that actually belonged to that row's class and were predicted as that column's class. This article nicely explains the confusion matrix concept in depth; I found it helpful.
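To make the definition concrete, here is a plain-Python sketch that builds a confusion matrix from made-up labels (not the article's results): rows are actual classes, columns are predicted classes, and the diagonal holds the correct predictions, so accuracy is the sum of the diagonal divided by the total count.

```python
from collections import Counter

def confusion_matrix(actual, predicted, classes):
    """Rows = actual class, columns = predicted class, cells = counts."""
    counts = Counter(zip(actual, predicted))
    return [[counts[(a, p)] for p in classes] for a in classes]

classes = ["bugs bunny", "donald duck", "mickey mouse"]
# Made-up labels for illustration only.
actual = ["bugs bunny", "bugs bunny", "donald duck", "donald duck", "mickey mouse"]
predicted = ["bugs bunny", "bugs bunny", "mickey mouse", "donald duck", "bugs bunny"]

cm = confusion_matrix(actual, predicted, classes)
for name, row in zip(classes, cm):
    print(f"{name:>13}: {row}")

# Diagonal cells are the correct predictions.
acc = sum(cm[i][i] for i in range(len(classes))) / len(actual)
print(acc)  # 0.6
```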
For our case, we can use fastai's ClassificationInterpretation class to compute and plot a confusion matrix.
Reading the matrix row-wise (the rows hold the actual classes), we can see that bugs bunny images have been classified very nicely. However, donald duck and mickey mouse seem very hard to classify: for those categories, only a few counts lie on the diagonal and most of them are scattered across the other classes.
If we look at the images that were misclassified, there's some resemblance to shinchan in the second image (shinchan always wears a red shirt and khaki pants, and mickey's attire has similar pants and boots). In the first image, alongside mickey mouse there's another character who has long ears and is tall like bugs bunny. So our model has quite some scope for improvement, but it's alright for now.
Finally, we can save the model to a file by calling the .export method, which creates a file named export.pkl in the same directory as the notebook. This file contains everything needed to make inferences at run time: the architecture, the trained weights and the data transforms.