MLP-Mixer: An all-MLP Architecture for Vision




Back to the good old MLP?

Introduction

Researchers from the Google Research and Google Brain teams have released a new architecture, MLP-Mixer, which uses neither convolutions nor attention layers yet achieves performance on par with architectures such as ViT (Vision Transformer), BiT (Big Transfer), HaloNet, and NF-Net. It does so with roughly a 3x speedup in training and scales excellently with the amount of training data. The authors designed this architecture to show that excellent results can be achieved without the compute-hungry CNN and attention layers. In the process, they also demonstrate higher throughput (images/sec/core) and better scaling to larger datasets than most of the current state of the art.

Background

Until this paper was published, CNNs and self-attention were the go-to mechanisms for image classification and computer vision in general. This was showcased by the recent state-of-the-art Vision Transformer (ViT), which successfully applied a Transformer with attention layers to computer vision tasks. Since these operations are computationally intensive, often taking multiple days to train on TPUs, the authors instead build on a fundamental concept in deep learning: the MLP (multi-layer perceptron). The architecture relies on matrix multiplications applied repeatedly across spatial locations (tokens) and feature channels.

How is that possible? Didn't we move on from MLPs to CNNs, residual CNNs (ResNets), DenseNets, NF-Nets, ViT, and so on? And now the authors are saying: go back to the MLP? Yes, you heard them right,

While convolutions and attention are both sufficient for good performance, neither of them are necessary

These are big claims; let us understand why this architecture works and whether the claims hold up.

Architecture

In the MLP-Mixer model, quite a few things are going on. The input is split into patches, just as in ViT. These are followed by a stack of Mixer layers, each of which performs two mixing operations built from fully connected layers with GELU activations, and finally a linear classification head. Skip connections, dropout, and layer normalization also make their way into the architecture.

MLP-Mixer architecture. Image Credits — MLP-Mixer paper

As we see at the bottom of the architecture diagram, the input to the network is a sequence of image patches. Each patch is linearly projected into a hidden (latent) representation of dimension H and passed on to the Mixer layers. One thing to note is that the hidden widths used inside the Mixer are chosen independently of the number of patches and the patch size, so the network's computation grows linearly with the number of input patches rather than quadratically as in ViT's self-attention. This results in reduced computational cost and a higher throughput of roughly 105 images/sec/core, almost 3x ViT's 32 images/sec/core (see the results table below).
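To make the patch projection concrete, here is a minimal PyTorch sketch (not the authors' code). It assumes square images, a square patch size, and a hypothetical hidden width `hidden_dim`, and implements the shared per-patch linear projection as a strided convolution, a common equivalent trick.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly project
    each patch to a hidden_dim-dimensional token (a sketch, not the paper's code)."""
    def __init__(self, image_size=224, patch_size=16, in_channels=3, hidden_dim=512):
        super().__init__()
        assert image_size % patch_size == 0, "image must split into whole patches"
        self.num_patches = (image_size // patch_size) ** 2
        # A conv with kernel = stride = patch_size is equivalent to flattening each
        # patch and applying one shared linear projection to all of them.
        self.proj = nn.Conv2d(in_channels, hidden_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                      # x: (batch, 3, H, W)
        x = self.proj(x)                       # (batch, hidden_dim, H/p, W/p)
        return x.flatten(2).transpose(1, 2)    # (batch, num_patches, hidden_dim)

tokens = PatchEmbedding()(torch.randn(2, 3, 224, 224))
print(tokens.shape)                            # torch.Size([2, 196, 512])
```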

Now things get interesting. Referring to the top part of the figure, which shows the Mixer layer, the projected patches are arranged into a table-like matrix we will call X: each row is one patch (token) and each column is one channel, with the patches stacked on top of each other.

The Mixer layer consists of two MLP blocks. The first (the token-mixing MLP block) acts on the columns of X, i.e. on the transpose of X, so each row now holds the values of a single channel across all the patches. Each such row is fed through a block of two fully connected layers with a GELU in between. This block mixes spatial information: for each feature channel, it aggregates evidence from all the patches where that feature occurs. The same weights (MLP 1 in the figure) are shared across all channels.

The second block (the channel-mixing MLP block) acts on the rows of X after transposing back, so each row is again a single patch. Here every patch is processed independently, with the MLP mixing information across all the channels of that patch. In other words, channel mixing combines features within a patch, whereas token mixing combines the same feature across patches. A minimal sketch of a full Mixer layer, covering both blocks, follows below.
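The sketch below shows one Mixer layer in PyTorch, assuming `num_patches` tokens with `hidden_dim` channels each. The class and argument names (`MixerBlock`, `tokens_mlp_dim`, `channels_mlp_dim`) are illustrative choices, not the authors' official implementation.

```python
import torch
import torch.nn as nn

class Mlp(nn.Module):
    """Two fully connected layers with a GELU activation in between."""
    def __init__(self, dim, hidden):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
    def forward(self, x):
        return self.net(x)

class MixerBlock(nn.Module):
    """One Mixer layer: token mixing across patches, then channel mixing across features."""
    def __init__(self, num_patches, hidden_dim, tokens_mlp_dim=256, channels_mlp_dim=2048):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.token_mlp = Mlp(num_patches, tokens_mlp_dim)     # MLP 1: shared across channels
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.channel_mlp = Mlp(hidden_dim, channels_mlp_dim)  # MLP 2: shared across patches

    def forward(self, x):                           # x: (batch, num_patches, hidden_dim)
        # Token mixing: transpose so the MLP runs along the patch (token) dimension.
        y = self.norm1(x).transpose(1, 2)            # (batch, hidden_dim, num_patches)
        x = x + self.token_mlp(y).transpose(1, 2)    # skip connection
        # Channel mixing: the MLP runs along the channel dimension of each patch.
        x = x + self.channel_mlp(self.norm2(x))      # skip connection
        return x

x = torch.randn(2, 196, 512)                         # e.g. 196 patches, 512 channels each
print(MixerBlock(196, 512)(x).shape)                 # torch.Size([2, 196, 512])
```

The full model simply stacks several of these layers after the patch projection, applies a final layer norm and global average pooling over the tokens, and feeds the result to the linear classification head.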

As the authors note, the architecture can be seen as a special kind of CNN:

Our architecture can be seen as a unique CNN, which uses (1×1) convolutions for channel mixing, and single-channel depth-wise convolutions for token mixing. However, the converse is not true as CNNs are not special cases of Mixer.
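As a quick sanity check of the first part of that statement (a self-contained sketch with made-up sizes, not code from the paper), a linear layer applied to the channel dimension at every spatial position computes exactly the same function as a 1x1 convolution carrying the same weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C_in, C_out, H, W = 8, 16, 4, 4              # made-up channel counts and spatial size

linear = nn.Linear(C_in, C_out)
conv = nn.Conv2d(C_in, C_out, kernel_size=1)

# Copy the linear weights into the 1x1 conv so both modules use identical parameters.
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(C_out, C_in, 1, 1))
    conv.bias.copy_(linear.bias)

x = torch.randn(2, C_in, H, W)                       # (batch, channels, height, width)
out_conv = conv(x)                                   # 1x1 convolution over channels
out_linear = linear(x.permute(0, 2, 3, 1))           # same op in channels-last layout
out_linear = out_linear.permute(0, 3, 1, 2)

print(torch.allclose(out_conv, out_linear, atol=1e-6))  # True
```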

The architecture also uses layer normalization, which is common in Transformer architectures. The image below illustrates Layer Norm vs. Batch Norm for a quick comparison (see [5] for more on Layer Norm). Skip connections are included in each block, along with GELU activations for non-linearity and dropout for regularization.

Layer Normalization vs Batch Normalization. Image Credits — PowerNorm Paper
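For a concrete feel of the difference (a minimal sketch, not from either paper): Layer Norm computes mean and variance per sample across its features, while Batch Norm computes them per feature across the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 6)          # (batch, features): 4 samples, 6 features each

# LayerNorm normalizes across the feature dimension of each sample,
# so its statistics do not depend on the other samples in the batch.
ln = nn.LayerNorm(6)
y_ln = ln(x)
print(y_ln[0].mean().item(), y_ln[0].std(unbiased=False).item())    # ~0, ~1 per sample

# BatchNorm normalizes each feature across the batch dimension,
# so its statistics are shared across samples.
bn = nn.BatchNorm1d(6)
y_bn = bn(x)
print(y_bn[:, 0].mean().item(), y_bn[:, 0].std(unbiased=False).item())  # ~0, ~1 per feature
```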

Benchmarks and Results

Comparison of Mixer with other models across datasets. Image Credits — MLP-Mixer paper

As we see in the table, the ImageNet top-1 validation accuracy is quite good compared with the other SOTA models. But the real selling point for MLP-Mixer is twofold: throughput (images/sec/core) and TPU training time. Mixer is far ahead of ViT in throughput, with 105 images/sec/core versus ViT's 32 (ImageNet), and a solid 40 images/sec/core against ViT's 15 (JFT-300).

ViT outperforms Mixer in TPU training time on ImageNet but not on JFT-300. Looking over the rest of the table, the takeaway is that MLP-Mixer is out there challenging the other SOTA models: not necessarily outperforming them, but still very much competitive.

Coming to the claim that it scales with more data, the graph below illustrates the performance of Mixer, ViT, and BiT as the training set size increases.

Training subset size vs ImageNet top-1 val accuracy. Image Credits — MLP-Mixer paper

Mixer is closely matched with the others at the right side of the graph, which reports 5-shot ImageNet top-1 accuracy for the largest training sets. But the more telling part is how the curves evolve from the left, where the training subset size starts at 10M examples and goes all the way to 300M. Across this range, Mixer scales better and its curve does not flatten out at any point: it does not saturate with more training data, whereas the ViT and BiT models saturate (accuracy stops improving) once the training subset grows past roughly 100M examples.

The paper also includes a fuller comparison of all model variants against other SOTA models as a function of pre-training data, and it confirms that Mixer benefits from more data: the smaller the pre-training set, the larger its gap to the SOTA models in validation accuracy. There is also a very interesting section on how MLP-Mixer's design can be traced back to CNNs and Transformers. It is worth giving the paper a read for the ablations, related work, experimental setups, weight visualizations, and the things that did not work out.

Conclusion

MLP-Mixer is right up there on important metrics like throughput and training time, and it is not far behind in ImageNet top-1 validation accuracy, reaching 87.78% when pre-trained on JFT-300. The network also scales better than ViT and BiT as the amount of training data grows. However, there is one point to ponder: inductive bias. According to the training-subset-size vs. ImageNet top-1 accuracy graph in the paper, ViT flattens out at the 300M mark whereas Mixer is still on the rise. What might happen at a scale of 600M? Keep a close eye on the interplay between pre-training dataset size and inductive bias. Finally, the authors raise another interesting question: would this design also work for NLP? Time will tell.

References

[1] MLP-Mixer: https://arxiv.org/pdf/2105.01601.pdf

[2] ViT: https://arxiv.org/pdf/2010.11929

[3] BiT: https://arxiv.org/pdf/1912.11370

[4] HaloNet: https://arxiv.org/pdf/2103.12731

[5] Layer Norm: https://arxiv.org/pdf/2003.07845

[6] NF-Nets: https://arxiv.org/pdf/2102.06171v1.pdf

[7] ImageNet: http://www.image-net.org/
