You Can Just Put an Endpoint Penalty on Your Wasserstein GAN

2024 February 3

When training a Wasserstein GAN, there is a very important constraint: the discriminator network must be a Lipschitz-continuous function. Roughly, this says that the output of the function can't change too fast as its input moves, and that this rate of change is bounded by some constant $K$. If the discriminator function is $f: \mathbb{R}^n \to \mathbb{R}$, then we can write the Lipschitz condition for the discriminator as:
$$|f(x) - f(y)| \le K\,|x - y|$$
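To make the condition concrete, here is a small pure-Python sketch (no deep learning framework; `empirical_lipschitz` and the linear test function are hypothetical names chosen for illustration) that estimates $K$ as the largest ratio $|f(x) - f(y)| / |x - y|$ over all pairs in a sample. A finite sample can only give a lower bound on the true constant.

```python
import math
import random

def empirical_lipschitz(f, points):
    """Largest |f(x) - f(y)| / |x - y| over all distinct pairs in `points`.

    This is only a lower bound on the true Lipschitz constant K of f.
    """
    best = 0.0
    for i in range(len(points)):
        for j in range(i + 1, len(points)):
            dist = math.dist(points[i], points[j])  # Euclidean |x - y|
            if dist > 0:
                best = max(best, abs(f(points[i]) - f(points[j])) / dist)
    return best

# Hypothetical example: f(x) = 3*x0 - 4*x1 has Lipschitz constant |(3, -4)| = 5
f = lambda x: 3 * x[0] - 4 * x[1]
random.seed(0)
pts = [[random.uniform(-1, 1), random.uniform(-1, 1)] for _ in range(200)]
print(empirical_lipschitz(f, pts))  # an estimate of K; the true value is 5
```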
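Usually this is enforced with a gradient penalty: for $K = 1$, we add the following term to the discriminator loss and backpropagate through it, which requires a higher-order gradient, since the penalty itself already contains a gradient:
$$\mathcal{L}_{GP} = \left(|\nabla f(\hat{x})| - 1\right)^2$$
In this expression $\hat{x}$ is sampled as $\hat{x} = \alpha x_r + (1 - \alpha) x_g$, a random mixture of a real data point $x_r$ and a generated data point $x_g$.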
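As a rough illustration of what the penalty computes, here is a sketch that evaluates $(|\nabla f(\hat{x})| - 1)^2$ at a random interpolate. To stay dependency-free it swaps autodiff for a central-difference gradient (`numerical_grad` and `gradient_penalty` are hypothetical helpers); in a real PyTorch WGAN-GP you would instead call `torch.autograd.grad(..., create_graph=True)` so the penalty can itself be backpropagated, which is exactly the higher-order gradient being discussed.

```python
import math
import random

def numerical_grad(f, x, h=1e-5):
    """Central-difference estimate of grad f at x (a stand-in for autograd)."""
    g = []
    for i in range(len(x)):
        xp = list(x); xp[i] += h
        xm = list(x); xm[i] -= h
        g.append((f(xp) - f(xm)) / (2 * h))
    return g

def gradient_penalty(f, x_real, x_gen, rng=random):
    """(|grad f(x_hat)| - 1)^2 at x_hat = alpha*x_real + (1 - alpha)*x_gen."""
    alpha = rng.random()  # mixing weight in [0, 1)
    x_hat = [alpha * r + (1 - alpha) * g for r, g in zip(x_real, x_gen)]
    grad_norm = math.sqrt(sum(c * c for c in numerical_grad(f, x_hat)))
    return (grad_norm - 1) ** 2

random.seed(0)
f = lambda x: 3 * x[0] - 4 * x[1]  # |grad f| = 5 everywhere
print(gradient_penalty(f, [1.0, 2.0], [-0.5, 0.3]))  # about (5 - 1)^2 = 16
```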
But this is complicated to implement, since it involves a higher-order gradient. It turns out we can also just impose the Lipschitz condition directly, via the following penalty on a pair of points $x$ and $y$:
$$l(x, y) = \mathrm{ReLU}\left(\frac{|f(x) - f(y)|}{K\,|x - y|} - 1\right)$$
In practice, to avoid dividing by zero when $x$ and $y$ coincide, we add a small $\varepsilon$ to the denominator, and we also multiply by a reweighting factor of $|x - y|$ (not sure if that is fully necessary, but the intuition is that enforcing the Lipschitz condition for points at large separation is the most important thing):
$$l(x, y) = \mathrm{ReLU}\left(\frac{|f(x) - f(y)|}{K\,|x - y| + \varepsilon} - 1\right)|x - y|$$
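Here is a minimal pure-Python sketch of this per-pair penalty. The name `pair_penalty` is hypothetical, and `eps=1e-6` is an arbitrary placeholder for the small $\varepsilon$, not a value taken from the post.

```python
import math

def pair_penalty(f, x, y, K=1.0, eps=1e-6):
    """ReLU(|f(x) - f(y)| / (K*|x - y| + eps) - 1) * |x - y| for a single pair.

    Zero when the pair satisfies the K-Lipschitz bound; otherwise it grows
    with both the size of the violation and the separation |x - y|.
    eps (a placeholder value) keeps the division finite when x == y.
    """
    d = math.dist(x, y)  # Euclidean |x - y|
    ratio = abs(f(x) - f(y)) / (K * d + eps)
    return max(ratio - 1.0, 0.0) * d  # ReLU, then reweight by distance

f = lambda x: 3 * x[0] - 4 * x[1]  # Lipschitz constant 5, so it violates K = 1
print(pair_penalty(f, [0.0, 0.0], [0.6, -0.8]))  # roughly 5 - 1 = 4
print(pair_penalty(f, [1.0, 1.0], [1.0, 1.0]))   # identical points: 0.0, no division by zero
```

Because the penalty only uses forward evaluations of `f`, a tensor version of it backpropagates with an ordinary first-order backward pass, which is the main simplification over the gradient penalty.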
For the overall loss, we compare all pairwise distances between real data, generated data, and a random mixture of the two. Adding one or two more random mixtures would probably improve things, but I'm not sure and haven't tried it.
$$\mathcal{L} = l(x_r, x_g) + l(x_r, \hat{x}) + l(x_g, \hat{x})$$
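A sketch of the full loss over a batch follows, under two assumptions not spelled out in the post: the $i$-th real sample is paired with the $i$-th generated sample, and the result is averaged over the batch. `lipschitz_loss` is a hypothetical name, and `pair_penalty` is repeated so the block runs on its own.

```python
import math
import random

def pair_penalty(f, x, y, K=1.0, eps=1e-6):
    """ReLU(|f(x) - f(y)| / (K*|x - y| + eps) - 1) * |x - y| for one pair."""
    d = math.dist(x, y)
    return max(abs(f(x) - f(y)) / (K * d + eps) - 1.0, 0.0) * d

def lipschitz_loss(f, reals, gens, K=1.0, eps=1e-6, rng=random):
    """Batch mean of l(x_r, x_g) + l(x_r, x_hat) + l(x_g, x_hat).

    Assumes reals[i] is paired with gens[i]; each pair gets its own
    random mixture x_hat = alpha*x_r + (1 - alpha)*x_g.
    """
    total = 0.0
    for x_r, x_g in zip(reals, gens):
        alpha = rng.random()
        x_hat = [alpha * r + (1 - alpha) * g for r, g in zip(x_r, x_g)]
        total += (pair_penalty(f, x_r, x_g, K, eps)
                  + pair_penalty(f, x_r, x_hat, K, eps)
                  + pair_penalty(f, x_g, x_hat, K, eps))
    return total / len(reals)

random.seed(1)
reals = [[random.uniform(-1, 1), random.uniform(-1, 1)] for _ in range(50)]
gens = [[random.uniform(-1, 1), random.uniform(-1, 1)] for _ in range(50)]
print(lipschitz_loss(lambda x: 0.6 * x[0] + 0.8 * x[1], reals, gens))  # 1-Lipschitz f: 0.0
```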
In any case, this seems to work decently well (tried on MNIST), so it might be a simpler alternative to the gradient penalty. I also used instance noise, which, as pointed out here, is amazingly good for preventing mode collapse and just generally makes training easier. So yeah, instance noise is great and you should use it. And if you really don't want to figure out how to do higher-order gradients in PyTorch for your WGAN, you've still got options.
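The instance noise mentioned above means adding the same kind of random noise to both real and generated samples before the discriminator sees them. Here is a minimal sketch assuming i.i.d. Gaussian noise; the noise scale and the linear decay schedule (`sigma_schedule`, `sigma0=0.1`) are hypothetical choices for illustration, not settings from the post.

```python
import random

def add_instance_noise(batch, sigma, rng=random):
    """Return a copy of `batch` with N(0, sigma^2) noise added to every entry."""
    return [[v + rng.gauss(0.0, sigma) for v in x] for x in batch]

def sigma_schedule(step, total_steps, sigma0=0.1):
    """Hypothetical schedule: decay the noise scale linearly from sigma0 to 0."""
    return sigma0 * max(0.0, 1.0 - step / total_steps)

# Usage: noise real and generated samples with the same sigma before the critic sees them
reals = [[0.5, -1.0], [2.0, 3.0]]
gens = [[0.1, 0.2], [-1.0, 0.7]]
sigma = sigma_schedule(step=10, total_steps=1000)
noisy_reals = add_instance_noise(reals, sigma)
noisy_gens = add_instance_noise(gens, sigma)
```

Noising both sides equally is the point: it makes the real and generated distributions overlap, so the discriminator cannot separate them perfectly on disjoint supports.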