Implementing RepVGG in PyTorch
Make your CNN >100x faster
Hello There!! Today we’ll see how to implement RepVGG in PyTorch, proposed in RepVGG: Making VGG-style ConvNets Great Again.
Code is here; an interactive version of this article can be downloaded from here.
Let’s get started!
The paper proposes a new architecture that can be re-parameterized after training to make it faster on modern hardware. And by faster I mean lightning fast: this idea was later used by Apple’s MobileOne model.
Single vs Multi Branch Models
A lot of recent models use multi-branching, where the input is passed through different layers and then aggregated somehow (usually with addition).
This is great because it makes the multi-branch model an implicit ensemble of numerous shallower models. More specifically, the model can be interpreted as an ensemble of 2^n models (where n is the number of blocks), since every block branches the flow into two paths.
Unfortunately, multi-branch models consume more memory and are slower than single-branch ones. Let’s create a classic ResNetBlock to see why (check out my article about ResNet in PyTorch).
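Here is a minimal sketch of a residual block (the exact layer layout here is illustrative, not the paper’s; what matters is the stored residual):

```python
import torch
from torch import nn


class ResNetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # project the input when the number of channels changes
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )
        self.act = nn.ReLU()

    def forward(self, x):
        res = self.shortcut(x)    # the residual has to be kept in memory ...
        x = self.block(x)
        return self.act(x + res)  # ... until it is added back here
```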
Storing the residual doubles memory consumption. This is also shown in the following image from the paper.
The authors noticed that the multi-branch architecture is useful only at train time. Thus, if we find a way to remove it at test time, we can improve the model’s speed and memory consumption.
From Multi Branches to Single Branch
Consider the following situation: you have two branches, each composed of a 3x3 conv.
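For example (the channel and spatial sizes here are just illustrative):

```python
import torch
from torch import nn

torch.manual_seed(0)

conv1 = nn.Conv2d(8, 8, kernel_size=3, padding=1)
conv2 = nn.Conv2d(8, 8, kernel_size=3, padding=1)

x = torch.randn(1, 8, 5, 5)
print((conv1(x) + conv2(x)).shape)
```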
torch.Size([1, 8, 5, 5])
Now, we can create one conv, let’s call it conv_fused, such that conv_fused(x) = conv1(x) + conv2(x). Very easily: we can just sum up the weights and the biases of the two convs! Thus we only need to run one conv instead of two.
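Reusing conv1 and conv2 from the snippet above, the fusion is just:

```python
conv_fused = nn.Conv2d(8, 8, kernel_size=3, padding=1)

with torch.no_grad():
    conv_fused.weight.copy_(conv1.weight + conv2.weight)
    conv_fused.bias.copy_(conv1.bias + conv2.bias)

# conv_fused(x) == conv1(x) + conv2(x), up to floating point error
assert torch.allclose(conv_fused(x), conv1(x) + conv2(x), atol=1e-5)
```

This works because convolution is linear in its parameters: (W1 ∗ x + b1) + (W2 ∗ x + b2) = (W1 + W2) ∗ x + (b1 + b2).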
Let’s see how much faster it is!
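To get a rough idea, here is a naive timing loop reusing the convs from above (your exact numbers will of course differ from the ones reported below):

```python
import time


def naive_benchmark(fn, n_iters: int = 100) -> float:
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    return (time.perf_counter() - start) / n_iters


with torch.no_grad():
    print(f"conv1(x) + conv2(x) took {naive_benchmark(lambda: conv1(x) + conv2(x)):.6f}s")
    print(f"conv_fused(x) took {naive_benchmark(lambda: conv_fused(x)):.6f}s")
```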
conv1(x) + conv2(x) took 0.000421s
conv_fused(x) took 0.000215s
Almost 50% less time (keep in mind this is a very naive benchmark; we will see a better one later on).
Fuse Conv and BatchNorm
In modern network architectures, BatchNorm is used as a regularization layer after a convolution block. We may want to fuse them together, i.e. create a conv such that conv_fused(x) = batchnorm(conv(x)). The idea is to change the weights of conv in order to incorporate the shifting and scaling from BatchNorm.
The paper explains it as follows: for every output channel i, the fused weight and bias are W'ᵢ = (γᵢ / σᵢ) · Wᵢ and b'ᵢ = βᵢ − (μᵢ γᵢ) / σᵢ, where μᵢ, σᵢ, γᵢ and βᵢ are BatchNorm’s running mean, running standard deviation, learned scale and learned shift.
The code is the following:
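Here is a sketch of such a helper. The name matches the one used in the rest of the article, but the body below is a reconstruction of the standard fusion, not necessarily the exact code from the repo:

```python
import torch
from torch import nn


def get_fused_bn_to_conv_state_dict(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> dict:
    # In eval mode BatchNorm is a per-channel affine transform, so it can be
    # folded into the conv:
    #   W_fused_i = (gamma_i / sigma_i) * W_i
    #   b_fused_i = beta_i + (b_i - mu_i) * gamma_i / sigma_i
    bn_std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / bn_std  # gamma / sigma, one value per output channel
    fused_weight = conv.weight * scale.reshape(-1, 1, 1, 1)
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused_bias = bn.bias + (conv_bias - bn.running_mean) * scale
    return {"weight": fused_weight, "bias": fused_bias}
```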
Let’s see if it works
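A quick sanity check (the shapes here are arbitrary):

```python
torch.manual_seed(0)

conv = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)

# give the BatchNorm some non-trivial running statistics
with torch.no_grad():
    for _ in range(10):
        bn(conv(torch.randn(4, 8, 5, 5)))
bn.eval()  # use the running statistics, as at inference time

conv_fused = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=True)
conv_fused.load_state_dict(get_fused_bn_to_conv_state_dict(conv, bn))

x = torch.randn(1, 8, 5, 5)
with torch.no_grad():
    assert torch.allclose(conv_fused(x), bn(conv(x)), atol=1e-5)
```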
Yes, we fused a Conv2d and a BatchNorm2d layer. There is also an article from PyTorch about this.
So our goal is to fuse all the branches into one single conv, making the network faster!
The authors propose a new type of block, called RepVGG. Similar to ResNet, it has a shortcut, but it also has an identity connection (or better, an identity branch).
In PyTorch:
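A sketch of the block (stride and groups are omitted for clarity; the block and shortcut attribute names match the ones used in the reparametrization section below, while identity is just my naming):

```python
class RepVGGBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # main 3x3 conv-bn branch
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # 1x1 conv-bn branch
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # identity branch: just a BatchNorm, only possible when shapes match
        self.identity = nn.BatchNorm2d(out_channels) if in_channels == out_channels else None
        self.relu = nn.ReLU()

    def forward(self, x):
        res = self.shortcut(x)
        if self.identity is not None:
            res = res + self.identity(x)
        return self.relu(self.block(x) + res)
```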
Reparametrization
We have one 3x3 conv-bn, one 1x1 conv-bn and (sometimes) one batchnorm (the identity branch). We want to fuse them together to create one single conv_fused such that conv_fused(x) = 3x3conv-bn(x) + 1x1conv-bn(x) + bn(x), or, if we don’t have an identity branch, conv_fused(x) = 3x3conv-bn(x) + 1x1conv-bn(x).
Let’s go step by step. To create conv_fused we have to:

- fuse the 3x3conv-bn(x) into one 3x3conv
- fuse the 1x1conv-bn(x), then convert it to a 3x3conv
- convert the identity bn to a 3x3conv
- add all three 3x3convs
Summarized by the image below:
The first step is easy: we can use get_fused_bn_to_conv_state_dict on RepVGGBlock.block (the main 3x3 conv-bn).
The second step is similar: get_fused_bn_to_conv_state_dict on RepVGGBlock.shortcut (the 1x1 conv-bn). Then we pad each kernel of the fused 1x1 conv by 1 in each dimension, creating a 3x3 kernel.
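The padding itself is just zero-padding of the kernel, for example (arbitrary channel counts):

```python
import torch
import torch.nn.functional as F

w_1x1 = torch.randn(8, 8, 1, 1)     # (out_channels, in_channels, 1, 1)
w_3x3 = F.pad(w_1x1, [1, 1, 1, 1])  # zeros all around the centre value
print(w_3x3.shape)                  # torch.Size([8, 8, 3, 3])
```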
The identity bn is trickier. We need to create a 3x3 conv that acts as an identity function and then use get_fused_bn_to_conv_state_dict to fuse it with the identity bn. This can be done by placing a 1 in the center of the kernel that matches the channel itself. Recall that a conv’s weight is a tensor of shape out_channels, in_channels, kernel_h, kernel_w. If we want to create an identity conv, such that conv(x) = x, each output channel needs a single 1 in the center of the kernel for its own input channel, and zeros everywhere else.
For example:
https://gist.github.com/c499d53431d243e9fc811f394f95aa05
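In code, a minimal version with 2 channels (matching the printed weights below):

```python
import torch
from torch import nn

identity_conv = nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    identity_conv.weight.zero_()
    for i in range(identity_conv.in_channels):
        # a single 1 in the centre of the kernel matching the channel itself
        identity_conv.weight[i, i, 1, 1] = 1

print(identity_conv.weight.shape)
print(identity_conv.weight)
```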
torch.Size([2, 2, 3, 3])
Parameter containing:
tensor([[[[0., 0., 0.],
          [0., 1., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 1., 0.],
          [0., 0., 0.]]]], requires_grad=True)
See, we created a Conv that acts like an identity function.
Now let’s put everything together; this step is formally called reparametrization.
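Using the helpers sketched above, fusing a whole RepVGGBlock could look like this (get_fused_conv_state_dict_from_block is a hypothetical name; the logic just follows the three steps listed earlier):

```python
import torch
import torch.nn.functional as F
from torch import nn


def get_fused_conv_state_dict_from_block(block: RepVGGBlock) -> dict:
    # 1) fuse the main 3x3 conv-bn
    fused_main = get_fused_bn_to_conv_state_dict(block.block[0], block.block[1])
    # 2) fuse the 1x1 conv-bn and pad its kernel to 3x3
    fused_shortcut = get_fused_bn_to_conv_state_dict(block.shortcut[0], block.shortcut[1])
    fused_shortcut["weight"] = F.pad(fused_shortcut["weight"], [1, 1, 1, 1])
    weight = fused_main["weight"] + fused_shortcut["weight"]
    bias = fused_main["bias"] + fused_shortcut["bias"]
    # 3) if there is an identity branch, build an identity 3x3 conv and fuse it with its bn
    if block.identity is not None:
        channels = block.identity.num_features
        identity_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=True)
        with torch.no_grad():
            identity_conv.weight.zero_()
            identity_conv.bias.zero_()
            for i in range(channels):
                identity_conv.weight[i, i, 1, 1] = 1
        fused_identity = get_fused_bn_to_conv_state_dict(identity_conv, block.identity)
        weight = weight + fused_identity["weight"]
        bias = bias + fused_identity["bias"]
    # the three branches collapse into a single 3x3 conv
    return {"weight": weight, "bias": bias}
```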
Finally, let’s define a RepVGGFastBlock. It’s composed of just a conv + relu. We also add a to_fast method to RepVGGBlock to quickly create the correct RepVGGFastBlock.
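One way to write it, on top of the sketches above:

```python
class RepVGGFastBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


def to_fast(self: RepVGGBlock) -> RepVGGFastBlock:
    # build the single fused conv and load the reparametrized weights into it
    fused_state_dict = get_fused_conv_state_dict_from_block(self)
    fast = RepVGGFastBlock(self.block[0].in_channels, self.block[0].out_channels)
    fast.conv.load_state_dict(fused_state_dict)
    return fast


RepVGGBlock.to_fast = to_fast  # attach it as a method
```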
RepVGG
Let’s define RepVGGStage (a collection of blocks) and RepVGG, with a handy switch_to_fast method that switches to the fast blocks in-place:
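A sketch of both (downsampling and the classification head are left out, since we only care about the reparametrization mechanics here):

```python
class RepVGGStage(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, depth: int):
        super().__init__()
        self.blocks = nn.ModuleList(
            [RepVGGBlock(in_channels if i == 0 else out_channels, out_channels)
             for i in range(depth)]
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class RepVGG(nn.Module):
    def __init__(self, widths=(64, 128, 256, 512), depth: int = 2, in_channels: int = 3):
        super().__init__()
        self.stages = nn.ModuleList()
        for width in widths:
            self.stages.append(RepVGGStage(in_channels, width, depth))
            in_channels = width

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        return x

    def switch_to_fast(self):
        # swap every RepVGGBlock for its fused RepVGGFastBlock, in-place
        for stage in self.stages:
            for i, block in enumerate(stage.blocks):
                stage.blocks[i] = block.to_fast()
        return self
```

The fusion uses BatchNorm’s running statistics, so switch once the model is in eval mode; the fast model then produces the same outputs as the multi-branch one:

```python
model = RepVGG().eval()
x = torch.randn(1, 3, 56, 56)
with torch.no_grad():
    out_slow = model(x)
    model.switch_to_fast()
    out_fast = model(x)
print((out_slow - out_fast).abs().max())  # ~0, only floating point error
```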
Let’s test it out!
I’ve created a benchmark inside benchmark.py, running the model on my GTX 1080ti with different batch sizes, and this is the result:
The model has two layers per stage, four stages, and widths of 64, 128, 256, 512.
In their paper, the authors scale these values by some amount (called a and b) and use grouped convs. Since we are more interested in the reparametrization part, I skip them.
Yeah, so basically the reparametrized model runs on a completely different time scale compared to the vanilla one. Wow!
Let me copy and paste the dataframe I used to store the benchmark:
You can see that the default model (multi-branch) took 1.45s for batch_size=128, while the reparametrized one (fast) only took 0.0134s. That is 108x faster 🚀🚀🚀.
Conclusions
In this article we have seen, step by step, how to create RepVGG: a blazing fast model using a clever reparametrization technique.
This technique can be ported to other architectures as well.
Thank you for reading!
👉 Implementing SegFormer in PyTorch
👉 Implementing ConvNext in PyTorch
Francesco