Implementing RepVGG in PyTorch






Make your CNN >100x faster

Hello there!! Today we'll see how to implement RepVGG in PyTorch, as proposed in RepVGG: Making VGG-style ConvNets Great Again.

Code is here, an interactive version of this article can be downloaded from here.

Let’s get started!

The paper proposes a new architecture that can be re-tuned after training to make it faster on modern hardware. And by faster I mean lightning fast: this idea was later used by Apple's MobileOne model.

Image by Xiaohan Ding, Xiangyu Zhang, Ningning Ma, Jungong Han, Guiguang Ding, Jian Sun

Single vs Multi Branch Models

A lot of recent models use multi-branching, where the input is passed through different layers and then aggregated somehow (usually with addition).

Image by the Author

This is great because it makes the multi-branch model an implicit ensemble of numerous shallower models. More specifically, a model with n such blocks can be interpreted as an ensemble of 2^n models, since every block branches the flow into two paths.

Unfortunately, multi-branch models consume more memory and are slower than single-branch ones. Let’s create a classic ResNetBlock to see why (check out my article about ResNet in PyTorch).
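A simplified sketch of such a block (same input and output channels, no downsampling; this is an illustration, not the exact code from my ResNet article):

import torch
from torch import nn


class ResNetBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        res = x  # the residual must be kept in memory until the addition
        x = self.block(x)
        return self.relu(x + res)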

Storing the residual doubles memory consumption. This is also shown in the following image from the paper:

The authors noticed that the multi-branch architecture is useful only at train time. Thus, if we can find a way to remove it at test time, we can improve the model's speed and memory consumption.

From Multi Branches to Single Branch

Consider the following situation: you have two branches, each composed of a 3x3 conv.
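Something along these lines (the channel count and input size are assumptions, chosen to match the shape printed below):

# a toy input: batch of 1, 8 channels, 5x5 spatial resolution
x = torch.randn(1, 8, 5, 5)

# two parallel branches, each a single 3x3 conv that preserves the spatial size
conv1 = nn.Conv2d(8, 8, kernel_size=3, padding=1)
conv2 = nn.Conv2d(8, 8, kernel_size=3, padding=1)

res = conv1(x) + conv2(x)
print(res.shape)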

torch.Size([1, 8, 5, 5])

Now, we can create a single conv, let's call it conv_fused, such that conv_fused(x) = conv1(x) + conv2(x). Since convolution is linear in its weights, we can just sum up the weights and the biases of the two convs! Thus we only need to run one conv instead of two.
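A sketch of that fusion, continuing the toy example above:

# sum the weights and biases of the two branches into a single conv
conv_fused = nn.Conv2d(8, 8, kernel_size=3, padding=1)
with torch.no_grad():
    conv_fused.weight.copy_(conv1.weight + conv2.weight)
    conv_fused.bias.copy_(conv1.bias + conv2.bias)

# sanity check: one conv now computes what the two branches did
assert torch.allclose(conv1(x) + conv2(x), conv_fused(x), atol=1e-5)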

Let’s see how much faster it is!
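A naive timing sketch (the helper below is made up for illustration; it is not the benchmark used later in the article):

import time

def timeit(fn, n=100):
    # average wall-clock time over n runs, with gradients disabled
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(n):
            fn()
    return (time.perf_counter() - start) / n

print(f"conv1(x) + conv2(x) took {timeit(lambda: conv1(x) + conv2(x)):.6f}s")
print(f"conv_fused(x) took {timeit(lambda: conv_fused(x)):.6f}s")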

conv1(x) + conv2(x) took 0.000421s
conv_fused(x) took 0.000215s

Almost 50% less! (Keep in mind this is a very naive benchmark; we will see a better one later on.)

Fuse Conv and BatchNorm

In modern network architectures, BatchNorm is used as a regularization layer after a convolution block. We may want to fuse them together, i.e. create a conv_fused such that conv_fused(x) = batchnorm(conv(x)). The idea is to change the weights and bias of conv in order to incorporate the shifting and scaling from BatchNorm.

The paper explains it as follows:
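In plain notation (rewriting the paper's equation): for each output channel i, the fused conv has

W'_i = (γ_i / sqrt(σ²_i + ε)) · W_i
b'_i = β_i − μ_i · γ_i / sqrt(σ²_i + ε)

where W is the conv's original weight and μ, σ², γ, β and ε are BatchNorm's running mean, running variance, scale (weight), shift (bias) and epsilon.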

The code is the following:
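A sketch of such a helper, assuming the conv is created with bias=False (as is usual right before a BatchNorm):

def get_fused_bn_to_conv_state_dict(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> dict:
    # BatchNorm statistics and affine parameters (gamma/beta in the paper)
    bn_mean, bn_var = bn.running_mean, bn.running_var
    bn_gamma, bn_beta = bn.weight, bn.bias
    bn_std = (bn_var + bn.eps).sqrt()
    # scale each output channel of the conv's weight by gamma / std ...
    fused_weight = nn.Parameter((bn_gamma / bn_std).reshape(-1, 1, 1, 1) * conv.weight)
    # ... and fold the shift into the bias
    fused_bias = nn.Parameter(bn_beta - bn_mean * bn_gamma / bn_std)
    return {"weight": fused_weight, "bias": fused_bias}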

Let’s see if it works
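For example, with a toy conv-bn pair (the sizes here are made up):

conv = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
bn.eval()  # fusion uses the running statistics, so BatchNorm must be in eval mode

conv_fused = nn.Conv2d(8, 8, kernel_size=3, padding=1)
conv_fused.load_state_dict(get_fused_bn_to_conv_state_dict(conv, bn))

assert torch.allclose(bn(conv(x)), conv_fused(x), atol=1e-5)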

Yes, we fused a Conv2d and a BatchNorm2d layer. There is also an article from PyTorch about this.

So our goal is to fuse all the branches into one single conv, making the network faster!

The authors propose a new type of block, called RepVGG. Similar to ResNet, it has a shortcut, but it also has an identity connection (or better, an identity branch).

Image by Xiaohan Ding, Xiangyu Zhang, Ningning Ma, Jungong Han, Guiguang Ding, Jian Sun

In PyTorch:
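A simplified sketch of the block (stride and grouped convolutions from the paper are left out; the attribute names block and shortcut are the ones referenced later on):

class RepVGGBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # main branch: 3x3 conv followed by BatchNorm
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # shortcut branch: 1x1 conv followed by BatchNorm
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # identity branch: only a BatchNorm, and only when the shapes match
        self.identity = (
            nn.BatchNorm2d(out_channels) if in_channels == out_channels else None
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        res = self.block(x) + self.shortcut(x)
        if self.identity is not None:
            res = res + self.identity(x)
        return self.relu(res)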

Reparametrization

We have one 3x3 conv-bn, one 1x1 conv-bn and (sometimes) one batchnorm (the identity branch). We want to fuse them together to create one single conv_fused such that conv_fused(x) = 3x3conv-bn(x) + 1x1conv-bn(x) + bn(x), or, if we don't have an identity branch, conv_fused(x) = 3x3conv-bn(x) + 1x1conv-bn(x).

Let’s go step by step. To create conv_fused we have to:

  • fuse the 3x3conv-bn(x) into one 3x3conv
  • fuse the 1x1conv-bn(x), then convert it to a 3x3conv
  • convert the identity bn to a 3x3conv
  • add all three 3x3convs together

Summarized by the image below:

Image by Xiaohan Ding, Xiangyu Zhang, Ningning Ma, Jungong Han, Guiguang Ding, Jian Sun

The first step is easy: we can use get_fused_bn_to_conv_state_dict on RepVGGBlock.block (the main 3x3 conv-bn).

The second step is similar: get_fused_bn_to_conv_state_dict on RepVGGBlock.shortcut (the 1x1 conv-bn). Then we zero-pad each fused 1x1 kernel by 1 on each side, creating a 3x3 kernel.
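That padding is a one-liner, e.g.:

import torch.nn.functional as F

# e.g. a fused 1x1 kernel with 8 output and 8 input channels
kernel_1x1 = torch.randn(8, 8, 1, 1)
# zero-pad the last two dimensions on every side: (8, 8, 1, 1) -> (8, 8, 3, 3)
kernel_3x3 = F.pad(kernel_1x1, [1, 1, 1, 1])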

The identity bn is trickier. We need to create a 3x3 conv that acts as an identity function and then use get_fused_bn_to_conv_state_dict to fuse it with the identity bn. This can be done by placing a 1 in the center of the kernel that maps each channel to itself.

Recall that a conv's weight is a tensor of shape out_channels, in_channels, kernel_h, kernel_w. If we want to create an identity conv, such that conv(x) = x, we need a single 1 in the center of the kernel that maps channel i to channel i, and zeros everywhere else.

For example:

https://gist.github.com/c499d53431d243e9fc811f394f95aa05
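Something along these lines produces the output below (two channels, chosen to match the printed shape):

with torch.no_grad():
    identity_conv = nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=False)
    identity_conv.weight.zero_()
    # a single 1 in the centre of the kernel that maps channel i to channel i
    for i in range(2):
        identity_conv.weight[i, i, 1, 1] = 1.0

print(identity_conv.weight.shape)
print(identity_conv.weight)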

torch.Size([2, 2, 3, 3])
Parameter containing:
tensor([[[[0., 0., 0.],
          [0., 1., 0.],
          [0., 0., 0.]],
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],
        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
         [[0., 0., 0.],
          [0., 1., 0.],
          [0., 0., 0.]]]], requires_grad=True)

See, we created a Conv that acts like an identity function.

Now, let's put everything together. This step is formally called reparametrization.
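A sketch of the full fusion, built on the helpers and the block defined above (the function name here is mine, chosen for illustration):

import torch.nn.functional as F

@torch.no_grad()
def get_fused_conv_state_dict_from_block(block: RepVGGBlock) -> dict:
    # 1) fuse the main 3x3 conv-bn branch
    fused_block = get_fused_bn_to_conv_state_dict(block.block[0], block.block[1])
    # 2) fuse the 1x1 conv-bn shortcut, then zero-pad its kernels to 3x3
    fused_shortcut = get_fused_bn_to_conv_state_dict(block.shortcut[0], block.shortcut[1])
    fused_shortcut["weight"] = F.pad(fused_shortcut["weight"], [1, 1, 1, 1])
    weight = fused_block["weight"] + fused_shortcut["weight"]
    bias = fused_block["bias"] + fused_shortcut["bias"]
    # 3) if present, turn the identity bn into a 3x3 identity conv and fuse it too
    if block.identity is not None:
        channels = block.block[0].out_channels
        identity_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        identity_conv.weight.zero_()
        for i in range(channels):
            identity_conv.weight[i, i, 1, 1] = 1.0
        fused_identity = get_fused_bn_to_conv_state_dict(identity_conv, block.identity)
        weight = weight + fused_identity["weight"]
        bias = bias + fused_identity["bias"]
    return {"weight": nn.Parameter(weight), "bias": nn.Parameter(bias)}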

Finally, let's define a RepVGGFastBlock. It's composed of only a conv + ReLU:
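A sketch:

class RepVGGFastBlock(nn.Sequential):
    def __init__(self, in_channels: int, out_channels: int):
        # a single 3x3 conv (with bias, since bn is now folded into it) + ReLU
        super().__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )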

and add a to_fast method to RepVGGBlock to quickly create the corresponding RepVGGFastBlock:
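Something like the following (again a sketch built on the helpers above; for brevity the method is attached by assignment rather than inside the class definition):

def to_fast(self: RepVGGBlock) -> RepVGGFastBlock:
    # create the fast block and load the reparametrized conv weights into it
    fused_conv_state_dict = get_fused_conv_state_dict_from_block(self)
    fast_block = RepVGGFastBlock(
        self.block[0].in_channels, self.block[0].out_channels
    )
    fast_block[0].load_state_dict(fused_conv_state_dict)
    return fast_block

RepVGGBlock.to_fast = to_fast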

RepVGG

Let's define RepVGGStage (a collection of blocks) and RepVGG with a handy switch_to_fast method that will switch to the fast blocks in-place:
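A compact sketch (no striding or classification head, to keep the focus on reparametrization; the widths and depths in the usage example follow the values described below):

class RepVGGStage(nn.Sequential):
    def __init__(self, in_channels: int, out_channels: int, depth: int):
        super().__init__(
            RepVGGBlock(in_channels, out_channels),
            *[RepVGGBlock(out_channels, out_channels) for _ in range(depth - 1)],
        )


class RepVGG(nn.Sequential):
    def __init__(self, widths: list, depths: list):
        super().__init__(
            *[
                RepVGGStage(in_c, out_c, depth)
                for in_c, out_c, depth in zip([3] + widths, widths, depths)
            ]
        )

    def switch_to_fast(self) -> "RepVGG":
        # swap every RepVGGBlock with its reparametrized RepVGGFastBlock, in-place
        for stage in self:
            for i, block in enumerate(stage):
                stage[i] = block.to_fast()
        return self


# usage: build the model, then fuse it for inference (BatchNorm must be in eval mode)
repvgg = RepVGG(widths=[64, 128, 256, 512], depths=[2, 2, 2, 2]).eval()
fast_repvgg = repvgg.switch_to_fast()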

Let’s test it out!

I've created a benchmark inside benchmark.py, running the model on my GTX 1080 Ti with different batch sizes, and this is the result:

The model has two layers per stage, four stages, and widths of 64, 128, 256, 512.

In the paper, these widths are scaled by some amount (called a and b) and grouped convolutions are used. Since we are more interested in the reparametrization part, I skipped them.

Image by the Author

Yeah, so basically the reparametrized model runs on a completely different time scale compared to the vanilla one. Wow!

Let me copy and paste the dataframe I used to store the benchmark

You can see that the default model (multi-branch) took 1.45s for batch_size=128, while the reparametrized one (fast) only took 0.0134s. That is 108x faster 🚀🚀🚀


Conclusions

In this article we have seen, step by step, how to create RepVGG, a blazing fast model that uses a clever reparametrization technique.

This technique can be ported to other architectures as well.

Thank you for reading!

👉 Implementing SegFormer in PyTorch

👉 Implementing ConvNext in PyTorch

Francesco


