Learning Deep Learning — MNIST with FastAI (Part 1)


In this series of posts my goal is to document and illustrate my journey as I learn the art and science of “deep learning”. I know these posts will be useful to me as I look back and reflect on how far I’ve come, and I hope they can be great starting points for others as well.

In Part 1 I’ll set the stage and we’ll look at solving a subset of the full MNIST problem. In Part 2 we’ll take what we learned and apply it to solve the whole thing. Let’s get to it!


Classifying MNIST, a dataset with images of handwritten numbers 0–9, is a common starting point for learning how to build neural nets. Some have compared it to the hello-world of programming, but I find this comparison quite unfair and misleading as even something as simple as MNIST image classification is significantly more complex than a few lines of a hello-world program.

Example images from the MNIST dataset.

However, it’s still a great starting point to learn the basics of deep learning, which is exactly what we’ll be doing here!


With a goal set before us, what tools do we use to accomplish it? There are a variety of popular deep learning frameworks, and the two major players as of mid-2021 seem to be TensorFlow and PyTorch. Then there are higher-level “wrapper” frameworks like Keras, PyTorch Lightning, and FastAI, which abstract away some of the details and are great for beginners.

I chose FastAI after listening to Lex Fridman interview Jeremy Howard (co-founder of FastAI) in this podcast. It seemed like an especially great place to start because of the high-quality, free online course that FastAI offers. I highly recommend the videos and material if these posts whet your appetite.

Setup and Data

Learning can’t happen without experience and examples, and for deep learning those come in the form of datasets. Often data collection and curation are a HUGE part of building a deep learning model, but fortunately FastAI makes this process very simple for common datasets.

The following Python code is meant to be run in a Jupyter Notebook, and I’ve combined multiple cells here for readability. If you want to run the code I recommend FastAI’s getting started section here. (I run my notebooks via Google Colab.)

!pip install -q fastbook
import fastbook
from fastai.vision.all import *
from fastbook import *
# Set the default image color map to grayscale
matplotlib.rc('image', cmap='Greys')
# Download the dataset
path = untar_data(URLs.MNIST_SAMPLE)
Path.BASE_PATH = path
# Gathers paths to each image
three_paths_t = (path / 'train' / '3').ls()
three_paths_v = (path / 'valid' / '3').ls()
seven_paths_t = (path / 'train' / '7').ls()
seven_paths_v = (path / 'valid' / '7').ls()
# Let's look at one - GET FAMILIAR WITH YOUR DATA!!!
im3 = Image.open(three_paths_t[0])
Tada! It’s a three.

Wow, what did we just do?! Hopefully if you take your time to read the code it (along with the comments) will be fairly self-explanatory, but I’ll cover the highlights.

We downloaded a subset of the full MNIST dataset, in this case just 3’s and 7’s. Did you see how easy FastAI made this with untar_data(URLs.MNIST_SAMPLE)? That dataset comes with “train” and “valid” subfolders, each of which has a subfolder for each number. It’s really nice that the data is so well organized for us. Normally you’d have to do all that work yourself.
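The folder layout is what carries the labels. As a minimal sketch using only the standard library (the file names below are made up for illustration and are not actual files from the dataset), both the split and the label can be read straight off a path:

```python
from pathlib import PurePosixPath

# Hypothetical relative paths in the same layout as MNIST_SAMPLE:
# <split>/<label>/<file>.png
example_paths = [
    "train/3/0001.png",
    "train/7/0002.png",
    "valid/3/0003.png",
]

def split_and_label(path_str):
    """Return (split, label) taken from the two parent folder names."""
    p = PurePosixPath(path_str)
    return p.parent.parent.name, p.parent.name

parsed = [split_and_label(p) for p in example_paths]
print(parsed)
```

Note that the labels come back as strings ("3", "7"), since they are just folder names.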

After gathering the paths to all the images, we looked at an example. Looking at your data is critical to understand what you’re feeding your model. In this case we pulled the first “3” out of the training set and it looks correct.

# Convert paths to image data and stack them into a single tensor
def stack(paths):
    # Scale pixel values from 0-255 down to 0-1 while stacking
    return torch.stack([
        tensor(Image.open(path)).float() / 255 for path in paths
    ])
sevens_train = stack(seven_paths_t)
sevens_valid = stack(seven_paths_v)
threes_train = stack(three_paths_t)
threes_valid = stack(three_paths_v)
# Examine the shape of our tensors
print(sevens_train.shape, sevens_valid.shape, threes_train.shape, threes_valid.shape)
# Look at more examples

torch.Size([6265, 28, 28]) torch.Size([1028, 28, 28]) torch.Size([6131, 28, 28]) torch.Size([1010, 28, 28])
Sure enough, a 7 and different 3 from before. Data looks correct.

Finally, we transform our image paths into actual image data (pixel values scaled to the range 0 to 1). We print the shape of each tensor (the images stacked into a single 3-dimensional array) to make sure we get something sensible. We have roughly 6,000 of each digit in the training set and 1,000 of each in the validation set, and each image is 28×28 pixels. Looking at a couple of new examples, things look as expected.
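To make the shapes concrete, here is a small sketch with nested Python lists standing in for tensors (the 2×2 "images" and the helper names are illustrative assumptions, not part of the original notebook). Dividing by 255 maps pixel values into the 0 to 1 range, and stacking N images of size H×W gives a shape of (N, H, W):

```python
def scale(image):
    """Map 0-255 integer pixel values to floats between 0 and 1."""
    return [[px / 255 for px in row] for row in image]

def shape(stacked):
    """Shape of a regular 3-level nested list: (images, rows, columns)."""
    return (len(stacked), len(stacked[0]), len(stacked[0][0]))

# Three tiny 2x2 "images" with raw 0-255 pixel values
raw_images = [
    [[0, 255], [255, 0]],
    [[0, 0], [255, 255]],
    [[51, 102], [153, 204]],
]

stacked = [scale(img) for img in raw_images]
print(shape(stacked))   # (3, 2, 2)
print(stacked[2])       # [[0.2, 0.4], [0.6, 0.8]]
```

The real tensors follow the same pattern, just with 28×28 images and thousands of them per stack.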

Forming a Baseline

With our data downloaded and transformed, we could proceed to shove it into a deep learning model. But it’s important to pause here to answer the question: “What does success look like?”

We may have an idea of the level of accuracy we need, or we may not. While that’s good to keep in mind, there’s more to consider. A baseline helps us measure success by setting the expected lower limit of whatever metric we’re after (in this case, accuracy of the classification). With two buckets to sort images into, we can get 50% accuracy with random guessing, but we can do better than that.

The trick to forming a good baseline is to find some approach that’s easy to implement and gives better-than-chance results. When working with image data, one of those approaches is to compare any individual image to the “average” image. The idea is that any given “3” will be more similar to the average “3” than to the average “7”. Let’s try it.
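Before running it on real digits, here is the idea at toy scale. This is only a sketch: plain Python lists stand in for tensors, and the 2×2 "images", the class names, and the helper names are all made up for illustration.

```python
def mean_image(images):
    """Pixel-wise average of a list of equally sized 2-D images."""
    n = len(images)
    rows, cols = len(images[0]), len(images[0][0])
    return [[sum(img[r][c] for img in images) / n for c in range(cols)]
            for r in range(rows)]

def abs_dist(a, b):
    """Mean absolute pixel difference between two images."""
    diffs = [abs(x - y) for row_a, row_b in zip(a, b)
             for x, y in zip(row_a, row_b)]
    return sum(diffs) / len(diffs)

# Made-up 2x2 "images": class A is bright on the top row, class B on the bottom.
class_a = [[[1.0, 0.9], [0.0, 0.1]], [[0.8, 1.0], [0.2, 0.0]]]
class_b = [[[0.0, 0.1], [1.0, 0.9]], [[0.2, 0.0], [0.8, 1.0]]]

mean_a, mean_b = mean_image(class_a), mean_image(class_b)

def is_a(img):
    """Classify by whichever average image is closer."""
    return abs_dist(img, mean_a) < abs_dist(img, mean_b)

new_image = [[0.9, 0.7], [0.1, 0.3]]  # unseen, top-heavy image
print(is_a(new_image))
```

The unseen image is not identical to either average, but it sits much closer to the class A average, so it gets labeled as class A.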

# Build our "average" image by taking the mean over the 0-axis
mean_valid_7 = sevens_valid.mean(0)
mean_valid_3 = threes_valid.mean(0)
# You guessed it - LOOK AT YOUR DATA
Nice, blurry averages.

Notice that we built the averages from the validation images and will score them against the training images. The point is to keep the two roles on separate data: comparing a set of images to averages computed from those same images would flatter the result. It matters less for a simple baseline like this, but evaluating a neural net on the data it was trained on would give a misleadingly high accuracy.

Ok, now we need a way to compare a single image to the averages. In fact, let’s do it all at once for every training image and average the results. Tensors are cool like that.

# Measure the average, absolute pixel difference
def abs_dist(a, b): return (a-b).abs().mean((-1, -2))
def is_3(img):
    return abs_dist(img, mean_valid_3) < abs_dist(img, mean_valid_7)
# Create the baselines
baseline_3 = is_3(threes_train).float().mean().item()
baseline_7 = (1 - is_3(sevens_train).float()).mean().item()
# Look at the results
baseline_3, baseline_7, (baseline_3 + baseline_7) / 2

(0.8822377920150757, 0.9969672560691833, 0.9396025240421295)

WOW. It blew me away that a naïve approach of simply comparing to the average yields 88% accuracy for 3’s and 99.7% accuracy for 7’s (which suggests that people write a “7” more consistently than a “3”).

Now we have our baseline: 94.0%, averaged across 3’s and 7’s. This is way higher than I expected, and it perfectly illustrates why forming a baseline is important. Otherwise we might have hit 93% with our fancy model and been happy with ourselves despite doing worse than this simple method.
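A quick arithmetic check on that number, using only figures already printed above (the variable names are mine). The 94.0% is an unweighted mean of the two per-class accuracies. Weighting each class by its image count gives the accuracy over all training images, and always guessing the larger class gives the floor that any baseline should beat:

```python
# Image counts from the tensor shapes printed earlier
n_threes, n_sevens = 6131, 6265
# Per-class accuracies reported by the baseline
acc_3, acc_7 = 0.8822377920150757, 0.9969672560691833

# Unweighted mean of the two class accuracies (what the post reports)
unweighted = (acc_3 + acc_7) / 2

# Accuracy over all training images, weighting each class by its size
correct = acc_3 * n_threes + acc_7 * n_sevens
weighted = correct / (n_threes + n_sevens)

# Floor from always guessing the larger class
majority_guess = max(n_threes, n_sevens) / (n_threes + n_sevens)

print(f"{unweighted:.4f} {weighted:.4f} {majority_guess:.4f}")
```

Because the two classes are almost the same size, the weighted and unweighted figures barely differ, and the majority-class guess is only slightly above 50%.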

Can a Neural Net Do Better?

I sure hope so, or we’re wasting our time!

Now, this lesson in the FastAI course digs into all the nitty-gritty details we’ll be skipping here. It covers Datasets, DataLoaders, loss functions, stochastic gradient descent (SGD), activation functions, model architecture, and more. While I could replicate all of that here, why bother? Just go check it out! The full notebook they reference for this lesson is a treasure trove of knowledge, and any beginner serious about learning deep learning should work through it (ideally starting from the first FastAI lesson).

At the end of the day though, we’re using FastAI as a high-level abstraction library. So without further ado — MAGIC.

# Just a reminder about where our data comes from
path = untar_data(URLs.MNIST_SAMPLE)
# DataLoaders feed batches of images for training
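# (the arguments below follow the walkthrough after this block)
dls = ImageDataLoaders.from_folder(path, train='train', valid='valid')
# Learners handle the training loop for us
learner = cnn_learner(dls, resnet18, pretrained=False,
                      loss_func=F.cross_entropy, metrics=accuracy)
learner.fit_one_cycle(1)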
Hell yeah, that’s a LOT better than our baseline of 94%!

We did it. I warned you about magic — let’s break it down.

ImageDataLoaders.from_folder() is built to load our images from a certain file structure. We tell it the training images are in the train folder, and the validation images are in the valid folder. It knows to use further subfolders as labels for the images therein.

cnn_learner handles our training loop, but it needs a few things. First, we feed it our DataLoaders. resnet18 is the model architecture we want it to use (a common model that works well with image data). pretrained is False because we want to train the model from scratch (leaving it as True enables transfer learning, which is insanely powerful, but we aren’t covering it here). F.cross_entropy is the loss function we’re optimizing (it measures how far the predicted class probabilities are from the true labels, which is related to accuracy but not the same thing). And finally, metrics tells the Learner which measurements to report as it trains, in this case accuracy.
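Cross-entropy and accuracy are related but distinct: accuracy only asks whether the top guess was right, while cross-entropy also rewards confidence in the right answer. The sketch below uses made-up probabilities to show this. It takes probabilities directly for simplicity, whereas PyTorch's F.cross_entropy takes raw scores (logits) and applies the softmax itself.

```python
import math

def cross_entropy(probs, target):
    """Negative log of the probability assigned to the correct class."""
    return -math.log(probs[target])

# Hypothetical predicted probabilities for the classes ("3", "7")
confident_right = [0.95, 0.05]
hesitant_right = [0.55, 0.45]
confident_wrong = [0.05, 0.95]
target = 0  # the true class is "3"

for name, probs in [("confident_right", confident_right),
                    ("hesitant_right", hesitant_right),
                    ("confident_wrong", confident_wrong)]:
    predicted = probs.index(max(probs))
    print(name, predicted == target, round(cross_entropy(probs, target), 3))
```

The first two predictions both count as correct for accuracy, yet the hesitant one has a much higher loss. That smooth signal is what lets gradient descent keep improving the model.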

fit_one_cycle(1) trains for a single epoch (one pass over the training data) using the one-cycle policy, which ramps the learning rate up and then back down, and voilà: we have a 99.6% accuracy “3” and “7” classifier! That’s significantly better than our 94.0% baseline, and it only took 15 seconds to train.
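To give a feel for what a one-cycle learning rate schedule looks like, here is a simplified sketch of its shape: a cosine warm-up to a peak followed by a cosine decay to a very small value. This is not fastai's implementation. The function name and the constants (lr_max, pct_start, div, div_final) are assumptions chosen for illustration, and the momentum scheduling that fastai also performs is left out.

```python
import math

def one_cycle_lr(step, total_steps, lr_max=1e-3, pct_start=0.25,
                 div=25.0, div_final=1e5):
    """Learning rate at `step`: cosine warm-up to lr_max, then cosine decay."""
    def cos_interp(start, end, frac):
        # Moves smoothly from start (frac=0) to end (frac=1)
        return start + (end - start) * (1 - math.cos(math.pi * frac)) / 2

    warmup_steps = pct_start * total_steps
    if step < warmup_steps:
        return cos_interp(lr_max / div, lr_max, step / warmup_steps)
    decay_frac = (step - warmup_steps) / (total_steps - warmup_steps)
    return cos_interp(lr_max, lr_max / div_final, decay_frac)

total = 100
schedule = [one_cycle_lr(s, total) for s in range(total + 1)]
print(schedule[0], max(schedule), schedule[-1])
```

With these settings the rate starts at a fraction of the peak, reaches the peak a quarter of the way through, and then falls to almost zero by the end of training.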

Honestly that felt a little anticlimactic — it happened so fast. We worked harder on our baseline than on the nice model. Well, if anything this should illustrate the power of a good abstraction library like FastAI. It’s also a good lesson about where the hard work lies when developing a deep learning model — training a fancy model is just the tip of the iceberg.

Next Steps

Well, we still don’t have an MNIST digit classifier, just a “3” and “7” classifier. So obvious next steps would be to repeat this with the full dataset (baseline + model training). I’ll be tackling that in the next post.

So feel free to proceed to see how we can create a full 10-digit MNIST classifier, or step back and dig into the details more with FastAI’s course or the full notebook for this lesson.


