Demystified: Wasserstein GANs (WGAN)

What is the Wasserstein distance? What is the intuition behind using Wasserstein distance to train GANs? How is it implemented?

Aadhithya Sankar
Towards Data Science


Fig. 1: Optimal discriminator and critic when learning to differentiate two Gaussians[1].

In this article we will look at Wasserstein GANs. Specifically, we will focus on the following: i) What is the Wasserstein distance? ii) Why use it? iii) How do we use it to train GANs?

Wasserstein Distance

The Wasserstein distance (also called the Earth Mover’s distance) is a distance metric between two probability distributions on a given metric space. Intuitively, it can be seen as the minimum amount of work needed to transform one distribution into the other, where work is defined as the product of the amount of mass that has to be moved and the distance it has to be moved. Mathematically, it is defined as:

W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma}\big[\, \lVert x - y \rVert \,\big]

Eq. 1: Wasserstein distance between distributions P_r and P_g.

In Eq. 1, Π(P_r, P_g) is the set of all joint distributions over x and y whose marginal distributions are equal to P_r and P_g. γ(x, y) can be seen as the amount of mass that must be moved from x to y to transform P_r into P_g[1]. The Wasserstein distance is then the cost of the optimal transport plan.
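
As a quick sanity check of the “mass times distance” intuition, here is a small sketch using SciPy’s scipy.stats.wasserstein_distance for one-dimensional distributions. The two toy point-mass distributions below are made up purely for illustration and are not from the paper.

```python
import numpy as np
from scipy.stats import wasserstein_distance

# Two toy 1-D distributions given as point locations and probability masses.
# P_r puts all of its mass at x = 0; P_g splits its mass between x = 2 and x = 4.
locs_r, mass_r = np.array([0.0]), np.array([1.0])
locs_g, mass_g = np.array([2.0, 4.0]), np.array([0.5, 0.5])

# Optimal transport plan: move 0.5 mass a distance of 2 and 0.5 mass a distance of 4,
# so the total cost is 0.5 * 2 + 0.5 * 4 = 3.0.
print(wasserstein_distance(locs_r, locs_g, mass_r, mass_g))  # 3.0
```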

Wasserstein Distance vs. Jensen-Shannon Divergence

The original GAN objective can be shown to minimise the Jensen-Shannon (JS) divergence[2]. The JS divergence is defined as:

JS(P_r, P_g) = \tfrac{1}{2} KL(P_r \parallel P_m) + \tfrac{1}{2} KL(P_g \parallel P_m), \qquad P_m = \tfrac{1}{2}(P_r + P_g)

Eq. 2: JS divergence between P_r and P_g, with P_m = (P_r + P_g)/2 and KL the Kullback-Leibler divergence.

Compared to JS, Wasserstein distance has the following advantages:

  • The Wasserstein distance is continuous and differentiable almost everywhere, which allows us to train the critic to optimality.
  • The JS divergence saturates locally as the discriminator gets better, so its gradients become zero and vanish.
  • The Wasserstein distance is a meaningful metric, i.e., it converges to 0 as the distributions get closer to each other and grows as they move apart.
  • The Wasserstein distance is a more stable objective function than the JS divergence, and using it also mitigates the mode collapse problem.

From Fig. 1 we can clearly see that the optimal GAN discriminator saturates and results in vanishing gradients, while the WGAN critic optimising the Wasserstein distance provides stable gradients throughout.
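
To make this contrast concrete, here is a small illustrative sketch (not from the paper) comparing the two quantities for a pair of point masses, one at 0 and one at θ, as θ grows: the JS divergence jumps to log 2 as soon as the supports stop overlapping and stays there, offering no useful gradient, while the Wasserstein distance keeps growing linearly with θ.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def js_divergence(p, q, eps=1e-12):
    """Jensen-Shannon divergence (natural log) between two discrete pmfs."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    m = 0.5 * (p + q)
    kl = lambda a, b: np.sum(a * np.log((a + eps) / (b + eps)))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

support = np.arange(11, dtype=float)   # discrete support {0, 1, ..., 10}
p_r = np.zeros(11)
p_r[0] = 1.0                           # P_r: point mass at 0

for theta in [0, 1, 2, 5, 10]:
    p_g = np.zeros(11)
    p_g[theta] = 1.0                   # P_g: point mass at theta
    js = js_divergence(p_r, p_g)
    w = wasserstein_distance(support, support, p_r, p_g)
    print(f"theta={theta:2d}  JS={js:.3f}  W={w:.1f}")

# JS saturates at log(2), roughly 0.693, for every theta > 0,
# while W grows linearly with theta, matching the behaviour in Fig. 1.
```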

For the mathematical proofs and a more detailed look into this topic, check out the paper[1]!

Wasserstein GAN

Now that it is clear that optimising the Wasserstein distance makes more sense than optimising the JS divergence, it must also be noted that the Wasserstein distance as defined in Eq. 1 is highly intractable[1], since we cannot possibly compute the infimum (greatest lower bound) over all γ ∈ Π(P_r, P_g). However, from the Kantorovich-Rubinstein duality we have:

W(P_r, P_g) = \sup_{\lVert f \rVert_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]

Eq. 3: Wasserstein distance under the 1-Lipschitz condition (Kantorovich-Rubinstein duality).

Here we have W(P_r, P_g) as the supremum (least upper bound) over all 1-Lipschitz functions f: X → R.

K-Lipschitz continuity: given two metric spaces (X, d_X) and (Y, d_Y), a function f: X → Y is K-Lipschitz continuous if

d_Y(f(x_1), f(x_2)) \le K \cdot d_X(x_1, x_2) \quad \forall\, x_1, x_2 \in X

where d_X and d_Y are the distance functions in their respective metric spaces. If the supremum in Eq. 3 is taken over K-Lipschitz functions instead of 1-Lipschitz functions, we end up with K · W(P_r, P_g).

Now, if we have a family of parameterised functions {f_w}, w ∈ W, that are all K-Lipschitz continuous, we can write:

K \cdot W(P_r, P_g) = \max_{w \in \mathcal{W}} \; \mathbb{E}_{x \sim P_r}[f_w(x)] - \mathbb{E}_{x \sim P_g}[f_w(x)]

Eq. 4

i.e., the w ∈ W that maximises Eq. 4 gives the Wasserstein distance multiplied by a constant K.

WGAN Critic

To this effect, the WGAN introduces a critic instead of the discriminator we know from standard GANs. The critic network is similar in design to a discriminator network, but estimates the Wasserstein distance by finding the w* that maximises Eq. 4. To that end, the objective function of the critic is as follows:

w^* = \arg\max_{w \in \mathcal{W}} \; \mathbb{E}_{x \sim P_r}[f_w(x)] - \mathbb{E}_{z \sim p(z)}[f_w(g_\theta(z))]

Eq. 5: Critic objective function, where g_θ(z) is the generator output for noise z ∼ p(z).

Here, to enforce Lipschitz continuity on the function f, the authors resort to restricting the weights w to a compact space. This is done by clamping the weights to a small range ([-0.01, 0.01] in the paper[1]).

The difference between the discriminator and the critic is that the discriminator is trained to correctly distinguish samples from P_r from samples from P_g, whereas the critic estimates the Wasserstein distance between P_r and P_g.

Here is the Python code to train the critic.
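
The exact training code lives in the linked repository[3]; what follows is only a minimal PyTorch sketch of a single critic update in the spirit of Algorithm 1 of the paper[1]. The names critic, generator, critic_opt, real and z_dim are assumptions made for illustration, not the repository’s actual identifiers.

```python
import torch

def train_critic_step(critic, generator, critic_opt, real, z_dim, clip_value=0.01):
    """One WGAN critic update: maximise E[f_w(x)] - E[f_w(g_theta(z))] (Eq. 5)."""
    batch_size = real.size(0)
    z = torch.randn(batch_size, z_dim, device=real.device)
    fake = generator(z).detach()  # do not backpropagate into the generator here

    # Negate the objective because PyTorch optimisers minimise by default.
    loss = -(critic(real).mean() - critic(fake).mean())

    critic_opt.zero_grad()
    loss.backward()
    critic_opt.step()

    # Weight clipping: restrict w to a compact space to (roughly) enforce
    # the Lipschitz constraint, as described above.
    for p in critic.parameters():
        p.data.clamp_(-clip_value, clip_value)

    return -loss.item()  # running estimate of (a scaled) Wasserstein distance
```

In the paper, the critic is updated n_critic times (5 by default) for every single generator update, using RMSProp with a small learning rate (5e-5).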

WGAN Generator Objective

Naturally, the objective of the generator is to minimise the Wasserstein distance between P_r and P_g: it tries to find the θ* that minimises the Wasserstein distance between the generated distribution P_g and the real distribution P_r. To that end, the objective function of the generator is as follows:

\theta^* = \arg\min_{\theta} \; -\mathbb{E}_{z \sim p(z)}[f_w(g_\theta(z))]

Eq. 6: Generator objective function.

Again, the main difference between the WGAN generator and the standard generator is that the WGAN generator tries to minimise the Wasserstein distance between P_r and P_g, while the standard generator tries to fool the discriminator with its generated images.

Here is the Python code to train the generator:
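
Again, this is only a minimal PyTorch sketch under the same assumed names (critic, generator, gen_opt, z_dim); the full implementation is in the repository[3].

```python
import torch

def train_generator_step(critic, generator, gen_opt, batch_size, z_dim, device="cpu"):
    """One WGAN generator update: minimise -E[f_w(g_theta(z))] (Eq. 6)."""
    z = torch.randn(batch_size, z_dim, device=device)
    fake = generator(z)           # keep the graph so gradients flow through the critic

    loss = -critic(fake).mean()   # generator wants the critic score on fakes to be high

    gen_opt.zero_grad()
    loss.backward()
    gen_opt.step()
    return loss.item()
```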

Training Results

Fig. 2: Early results from WGAN training[3].

Fig. 2 shows some early results from training the WGAN. Please note that these are early results; training was stopped as soon as it was confirmed that the model was learning as expected.

Code

The complete implementation of the Wasserstein GAN can be found here[3].

Conclusion

WGANs offer much more stable training and a meaningful training objective. This article introduced the Wasserstein distance, gave an intuitive explanation of what it is, discussed the advantages it has over the Jensen-Shannon divergence used by the standard GAN, and showed how it is used to train the WGAN. We also saw code snippets to train the critic and the generator, and sample outputs from the early stages of training the model. Although WGANs have many advantages over the standard GAN, the authors of the WGAN paper clearly acknowledge that weight clipping is not the optimal way to enforce Lipschitz continuity[1]. To fix this they propose the Wasserstein GAN with Gradient Penalty[4], which we will discuss in a later post.

If you liked this, check out the next post in this series, which talks about WGAN-GP!

References

[1] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. “Wasserstein generative adversarial networks.” International conference on machine learning. PMLR, 2017.

[2] Goodfellow, Ian, et al. “Generative adversarial networks.” Communications of the ACM 63.11 (2020): 139–144.

[3] gan-zoo-pytorch (https://github.com/aadhithya/gan-zoo-pytorch).

[4] Gulrajani, Ishaan, et al. “Improved training of Wasserstein GANs.” arXiv preprint arXiv:1704.00028 (2017).
