Demystified: Wasserstein GAN with Gradient Penalty
What is Gradient Penalty? Why is it better than weight clipping? How do you implement Gradient Penalty?
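In this post we will look into Wasserstein GANs with Gradient Penalty (WGAN-GP). While the original Wasserstein GAN improves training stability, there are still cases where it generates poor samples or fails to converge. Recall that the cost function of WGAN is:
$$\min_G \max_{f} \; \mathbb{E}_{x \sim \mathbb{P}_r}\left[f(x)\right] - \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[f(\tilde{x})\right] \qquad (1)$$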
where f is 1-Lipschitz continuous. The issues with WGAN arise mainly from the weight clipping used to enforce Lipschitz continuity on the critic. WGAN-GP replaces weight clipping with a constraint on the gradient norm of the critic. This allows more stable training than WGAN and requires very little hyper-parameter tuning. WGAN-GP, and this post, build on top of Wasserstein GANs, which were discussed in a previous post in the Demystified series; see that post to learn about WGAN.
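In practice, the expectations in the WGAN critic objective, E_{x∼ℙr}[f(x)] − E_{x̃∼ℙg}[f(x̃)], are replaced by minibatch averages. The following is a minimal, stdlib-only Python sketch of that estimate; the name `critic_objective`, the toy critic f(x) = −x and the sample values are illustrative assumptions, not code from the original post.

```python
from statistics import mean
from typing import Callable, Sequence


def critic_objective(f: Callable[[float], float],
                     real: Sequence[float],
                     fake: Sequence[float]) -> float:
    """Minibatch estimate of E_{x~Pr}[f(x)] - E_{x~Pg}[f(x)]."""
    return mean(f(x) for x in real) - mean(f(x) for x in fake)


# Hypothetical 1-D data: real samples around 0, generated samples around 2.
real = [0.0, 0.1, -0.1]
fake = [2.0, 2.1, 1.9]

# f(x) = -x is 1-Lipschitz; the estimate is the gap between the batch means.
print(critic_objective(lambda x: -x, real, fake))
```

A real critic would be a neural network; the critic is trained to maximise this quantity while the generator is trained to minimise it.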
Statement 1: If the optimal 1-Lipschitz critic f*, which solves the inner maximisation in Eq. 1, is differentiable, then it has unit gradient norm almost everywhere under ℙr and ℙg.
Here ℙr and ℙg are the real and generated (fake) distributions, respectively. A proof of Statement 1 can be found in the WGAN-GP paper.
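As a quick one-dimensional check of Statement 1 (a toy example, not from the original post), let ℙr be a point mass at 0 and ℙg a point mass at θ > 0. Every 1-Lipschitz f satisfies f(0) − f(θ) ≤ |0 − θ| = θ. If equality holds, then for any t in [0, θ] both f(0) − f(t) ≤ t and f(t) − f(θ) ≤ θ − t must be tight, so f decreases with slope exactly −1 along the whole segment:

$$\max_{\|f\|_L \le 1} \left[f(0) - f(\theta)\right] = \theta, \qquad f^*(x) = -x, \qquad \left|{f^*}'(x)\right| = 1 \ \text{ for all } x \in [0, \theta].$$

So any optimal critic has unit gradient norm at both samples and along the straight line joining them.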
Issues with Weight Clipping
Using weight clipping to enforce the k-Lipschitz constraint leads the critic to learn very simple functions.
From Statement 1, we know that the gradient norm of the optimal critic is 1 almost everywhere under ℙr and ℙg. Under a weight-clipping constraint, however, the critic tries to attain its maximum gradient norm k and ends up learning overly simple functions.
Fig. 2 shows this effect. The critics are trained to convergence with a fixed generated distribution ℙg, obtained by adding unit Gaussian noise to the real distribution ℙr. The critic trained with weight clipping learns simple functions that fail to capture the higher moments of the data distribution, whereas the critic trained with Gradient Penalty does not suffer from this issue.
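To make the mechanism concrete, here is a toy sketch in plain Python (not from the original post): a hypothetical linear critic f(x) = w·x is trained by gradient ascent on the WGAN objective, and its weights are clipped to [−c, c] after every step, as in WGAN. The batch means, learning rate and step count are made-up values; c = 0.01 is the default clipping value from the original WGAN paper.

```python
from typing import List


def clip_weights(weights: List[float], c: float) -> List[float]:
    """WGAN weight clipping: force every critic parameter into [-c, c]."""
    return [max(-c, min(c, w)) for w in weights]


def train_linear_critic(mean_real: List[float], mean_fake: List[float],
                        c: float = 0.01, lr: float = 0.005,
                        steps: int = 100) -> List[float]:
    """Gradient ascent on E_r[w.x] - E_g[w.x] for a toy linear critic.

    For f(x) = w.x the gradient of the objective w.r.t. w is simply
    mean_real - mean_fake, so every step moves w in that direction and
    then clips it back into [-c, c].
    """
    w = [0.0] * len(mean_real)
    grad = [r - g for r, g in zip(mean_real, mean_fake)]
    for _ in range(steps):
        w = [wi + lr * gi for wi, gi in zip(w, grad)]
        w = clip_weights(w, c)
    return w


w = train_linear_critic(mean_real=[1.0, -2.0, 0.3], mean_fake=[0.0, 0.5, 0.1])
print(w)  # [0.01, -0.01, 0.01]: every weight sits at +c or -c
```

Every weight ends up at one of the two clipping bounds, so the critic's input gradient w has norm c√3 no matter how the data are shaped: the critic collapses to a fixed-slope function determined only by the sign of the mean difference. The same saturation of weights at the clipping extremes appears in Fig. 1 (right).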
Exploding and Vanishing Gradients
The interaction between the weight constraint and the loss function makes training WGAN difficult and can result in either exploding or vanishing gradients.
This can be seen in Fig. 1 (left), where the gradient norms of the critic explode or vanish depending on the clipping value. Fig. 1 (right) also shows that weight clipping pushes the critic's weights towards the two extremes of the clipping range. The critic trained with Gradient Penalty, on the other hand, does not suffer from these issues.
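The exploding/vanishing behaviour can be reproduced in a deliberately simplified setting. The plain-Python sketch below assumes a deep linear critic (no activations) whose weights have all saturated at the clipping bound +c, with a hypothetical width of 64 and depth of 8; none of these numbers come from the post. Backpropagating through one such layer multiplies the gradient by 64·c, so the input gradient norm scales like (64c)^8.

```python
import math
from typing import List

Matrix = List[List[float]]  # w[i][j]: weight from input j to output i


def input_gradient_norm(weights: List[Matrix]) -> float:
    """L2 norm of df/dx for a deep linear critic f(x) = sum(W_L ... W_1 x).

    Backpropagation through linear layers: start from the gradient of the
    final sum (a vector of ones) and repeatedly multiply by W_k transposed.
    """
    grad = [1.0] * len(weights[-1])  # df / d(output of the last layer)
    for w in reversed(weights):
        n_in = len(w[0])
        grad = [sum(w[i][j] * grad[i] for i in range(len(w))) for j in range(n_in)]
    return math.sqrt(sum(g * g for g in grad))


def saturated_critic(c: float, width: int, depth: int) -> List[Matrix]:
    """Hypothetical critic whose weights all sit at the clipping bound +c."""
    return [[[c] * width for _ in range(width)] for _ in range(depth)]


width, depth = 64, 8
for c in (0.1, 0.01, 0.001):
    print(c, input_gradient_norm(saturated_critic(c, width, depth)))
```

With c = 0.1 the per-layer factor is 6.4 and the gradient norm explodes to about 2.3×10^7; with c = 0.01 or 0.001 the factor is below 1 and the norm shrinks towards zero. Real critics have nonlinearities and unsaturated weights, so this only mirrors Fig. 1 (left) in miniature.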
The idea of Gradient Penalty is to constrain the gradient of the critic's output with respect to its input to have unit norm (Statement 1).
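Instead of enforcing this constraint strictly, the authors propose a soft version of it: a penalty on the gradient norm at samples x̂ ∼ ℙx̂. The new critic loss, minimised with respect to the critic f, is
$$L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[f(\tilde{x})\right] - \mathbb{E}_{x \sim \mathbb{P}_r}\left[f(x)\right] + \lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\left[\left(\lVert \nabla_{\hat{x}} f(\hat{x}) \rVert_2 - 1\right)^2\right] \qquad (2)$$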
In Eq. 2, the terms to the left of the plus sign form the original critic loss, and the term to the right of it is the gradient penalty.
ℙx̂ is the distribution obtained by sampling uniformly along straight lines between pairs of points drawn from the real and generated distributions ℙr and ℙg. This choice is motivated by the fact that the optimal critic has unit gradient norm along the straight lines connecting coupled points from ℙr and ℙg.
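Sampling from ℙx̂ only needs one random number per (real, generated) pair: x̂ = εx + (1 − ε)x̃ with ε ∼ U[0, 1], as in the WGAN-GP paper's training algorithm. A minimal stdlib sketch, assuming samples are plain lists of floats; `interpolate` is a hypothetical helper name:

```python
import random
from typing import List, Sequence


def interpolate(real: Sequence[float], fake: Sequence[float],
                rng: random.Random) -> List[float]:
    """Sample x_hat = eps * x + (1 - eps) * x_tilde with eps ~ U[0, 1].

    A single eps is drawn for the whole (real, fake) pair, so x_hat lies on
    the straight segment between the two samples.
    """
    eps = rng.random()
    return [eps * r + (1.0 - eps) * f for r, f in zip(real, fake)]


rng = random.Random(0)
x_real, x_fake = [0.0, 0.0], [2.0, 4.0]
x_hat = interpolate(x_real, x_fake, rng)
print(x_hat)  # a point of the form (t, 2t) on the segment from (0, 0) to (2, 4)
```

Drawing a separate ε per coordinate would scatter x̂ inside a box instead of placing it on the line, which is why ε is shared across all dimensions of a sample.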
The penalty coefficient λ weights the gradient penalty term; in the paper, the authors set λ = 10 for all their experiments.
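Putting Eq. 2 together with λ = 10, the critic loss for one minibatch can be computed from the critic scores and the gradients at the interpolated points. The plain-Python sketch below assumes those gradients have already been obtained (in practice via automatic differentiation); the function name and the numbers in the example are illustrative.

```python
import math
from statistics import mean
from typing import Sequence


def wgan_gp_critic_loss(scores_real: Sequence[float],
                        scores_fake: Sequence[float],
                        grads_at_interp: Sequence[Sequence[float]],
                        lam: float = 10.0) -> float:
    """Minibatch version of Eq. 2.

    scores_real / scores_fake: critic outputs f(x) on real and generated samples.
    grads_at_interp: gradient vectors of f at the interpolated samples x_hat.
    """
    wasserstein_term = mean(scores_fake) - mean(scores_real)
    penalty = mean((math.sqrt(sum(g * g for g in grad)) - 1.0) ** 2
                   for grad in grads_at_interp)
    return wasserstein_term + lam * penalty


# Gradient norms 5 and 1 give per-sample penalties 16 and 0 (mean 8).
loss = wgan_gp_critic_loss([1.0, 3.0], [0.0, 0.0], [[3.0, 4.0], [0.0, 1.0]])
print(loss)  # (0 - 2) + 10 * 8 = 78.0
```

The penalty is two-sided: gradient norms below 1 are penalised as well as norms above 1, matching the unit-norm property of the optimal critic.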
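Batch normalisation is no longer used in the critic. Batch norm turns the critic into a function that maps a whole batch of inputs to a batch of outputs, so each output depends on every input in the batch. The gradient penalty, however, constrains the gradient of each output with respect to its own input independently, which is no longer valid in that setting.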
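A small plain-Python experiment shows the cross-sample dependence. It compares batch normalisation (statistics over the batch) with a per-sample normalisation. The WGAN-GP paper suggests layer normalisation as a drop-in replacement for batch norm in the critic; `layer_norm` below is a simplified version of it without learned scale and shift. All names and values are illustrative.

```python
import math
from statistics import fmean, pvariance
from typing import List

Batch = List[List[float]]  # batch[i][k] is feature k of sample i


def _normalise(values: List[float], eps: float = 1e-5) -> List[float]:
    """Shift to zero mean and scale to unit variance (no learned affine parameters)."""
    mu, var = fmean(values), pvariance(values)
    return [(v - mu) / math.sqrt(var + eps) for v in values]


def batch_norm(batch: Batch) -> Batch:
    """Normalise each feature with statistics taken across all samples in the batch."""
    cols = [_normalise([row[k] for row in batch]) for k in range(len(batch[0]))]
    return [[col[i] for col in cols] for i in range(len(batch))]


def layer_norm(batch: Batch) -> Batch:
    """Normalise each sample using only its own features."""
    return [_normalise(row) for row in batch]


a = [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]
b = [[1.0, 2.0], [2.0, 3.0], [30.0, -4.0]]  # only the last sample differs
print(batch_norm(a)[0] == batch_norm(b)[0])  # False: sample 0's output depends on sample 2
print(layer_norm(a)[0] == layer_norm(b)[0])  # True: sample 0 is normalised on its own
```

Because sample 0's batch-normalised output changes when a different sample changes, its "gradient with respect to its own input" is not a property of that sample alone, which is exactly what the per-sample penalty assumes.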
The implementation of gradient penalty is shown below.
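The author's code is not preserved in this copy of the post. As a stand-in, here is a minimal PyTorch sketch, assuming a critic `nn.Module` that maps a batch of inputs to one score per sample; the function name `gradient_penalty`, its signature and the toy critic are illustrative choices, not the author's implementation.

```python
import torch
from torch import nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor,
                     lam: float = 10.0) -> torch.Tensor:
    """Gradient penalty term of Eq. 2: lam * E[(||grad_xhat f(x_hat)||_2 - 1)^2]."""
    batch_size = real.size(0)
    # One eps ~ U[0, 1] per sample, broadcast over all remaining dimensions.
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    # Detach so the penalty does not flow into the generator; x_hat is a new leaf.
    x_hat = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True keeps the graph so the penalty can be backpropagated
    # into the critic's parameters.
    (grads,) = torch.autograd.grad(outputs=scores, inputs=x_hat,
                                   grad_outputs=torch.ones_like(scores),
                                   create_graph=True)
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return lam * ((grad_norm - 1) ** 2).mean()


torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))  # toy critic, no batch norm
real, fake = torch.randn(16, 2), torch.randn(16, 2)  # stand-ins for data and G(z)
gp = gradient_penalty(critic, real, fake)
critic_loss = critic(fake).mean() - critic(real).mean() + gp  # Eq. 2
critic_loss.backward()
```

Flattening the gradients before taking the norm makes the same function work for image batches, where `eps` has shape (batch, 1, 1, 1).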
The code to train the WGAN-GP model can be found here:
Fig. 3 shows some early results from training WGAN-GP. Note that these are early results: training was stopped as soon as it was confirmed that the model was training as expected, so the model was not trained to convergence.
Wasserstein GANs offer much-needed stability in training Generative Adversarial Networks. However, the use of weight clipping leads to various issues, such as exploding and vanishing gradients. The Gradient Penalty constraint does not suffer from these issues and therefore allows easier optimisation and convergence than the original WGAN. This post looked at these issues, introduced the Gradient Penalty constraint and showed how to implement Gradient Penalty using PyTorch. Finally, the code to train the WGAN-GP model was provided, along with some early-stage outputs.
If you liked this post, consider following the author, Aadhithya Sankar.