Generative Adversarial Network
1. Introduction
A generative adversarial network (GAN) is a framework for learning an estimate $p_{\text{model}}$ of a data distribution $p_{\text{data}}$ from samples of that distribution. The standard GAN estimates $p_{\text{data}}$ implicitly: it never writes down a density, but it can generate samples from $p_{\text{model}}$. In contrast, explicit estimation requires a defined density function $p_{\text{model}}(x; \theta)$.
GANs learn through a game between two players: a generator and a discriminator. The generator’s goal is to learn to generate samples from the same distribution as the training data. The discriminator’s goal is to learn to distinguish fake generated samples from true training samples. Both players optimize the same value function $V(D, G)$, but each only with respect to its own parameters ($D$ maximizes it and $G$ minimizes it). The solution to this game is a local differential Nash equilibrium: a point in parameter space where each player’s cost is at a local minimum with respect to that player’s own parameters, so neither player can improve by a small unilateral change. After training, the generator is used to generate samples that may look similar to samples from the training data. The following post is based on the original GAN paper as well as the NIPS 2016 tutorial on GANs [1, 2].
2. Method
Let the generator $G$ and discriminator $D$ be parameterized by neural networks. GANs learn by having $G$ and $D$ play the following minimax game
\[\min_{G}\max_{D}V(D, G) = \min_{G}\max_{D}\mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1 - D(G(z)))],\]where $p_{z}(z)$ is a prior distribution of input noise variables. The generator learns a mapping from $z$ to the data space. The discriminator $D(x) \in [0, 1]$ outputs a probability that an input sample comes from $p_{\text{data}}$ rather than the generator’s distribution $p_{g}$. This game is a zero-sum game because if $\max_{D}V(D, G)$ is re-expressed as minimizing a cost $\min_{D}-V(D, G) = \min_{D}J^{(D)}$, then $\min_{G}V(D, G) = \min_{G}J^{(G)}$ implies $J^{(G)} = -J^{(D)}$. Zero-sum games are alternatively called minimax games.
Intuitively, \(\max_{D}\mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)]\) maximizes the log-probability of the discriminator being correct on real samples (pushing $D(x)$ toward 1), and \(\max_{D}\mathbb{E}_{z \sim p_{z}(z)}[\log(1 - D(G(z)))]\) maximizes the log-probability of it being correct on fake samples, because \(\arg\max_{D \in [0, 1]}\log(1 - D) = 0\). For $G$, only the second term depends on its parameters, so \(\min_{G}V(D, G)\) amounts to minimizing the log-probability of the discriminator being correct on fake samples, i.e., pushing $D(G(z))$ toward 1, since \(\arg\min_{D \in [0, 1]}\log(1 - D) = 1\).
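As a minimal sketch in plain Python (value_estimate is a hypothetical helper, not part of the implementation in Section 3), $V(D, G)$ can be estimated from a batch of discriminator outputs by replacing each expectation with a sample mean. A discriminator that outputs $1/2$ everywhere gives $V = 2\log\frac{1}{2} = -\log 4$, the value that reappears in Section 2.2.

```python
import math

def value_estimate(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) from discriminator outputs.

    d_real: values D(x) on a batch of real samples x ~ p_data
    d_fake: values D(G(z)) on a batch of generated samples, z ~ p_z
    All values are probabilities strictly between 0 and 1.
    """
    real_term = sum(math.log(d) for d in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - d) for d in d_fake) / len(d_fake)
    return real_term + fake_term

# A confident, correct discriminator: V is close to its upper bound of 0.
v_good_d = value_estimate([0.99, 0.98, 0.97], [0.01, 0.02, 0.03])
# A discriminator that cannot tell real from fake outputs 1/2 everywhere: V = -log 4.
v_confused = value_estimate([0.5] * 4, [0.5] * 4)
# In the zero-sum view, the discriminator's cost is J^(D) = -V.
j_d = -v_confused
print(v_good_d, v_confused, j_d)
```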
2.1 Optimal Discriminator
For a given $G$, $\arg\max_{D}V(D, G) = \frac{p_{\text{data}}}{p_{\text{data}} + p_{g}}$. This is determined by solving for $D$ such that
\[\frac{\partial}{\partial D}V(D, G) = 0\]for every given $x$. To start, rewrite $V(D, G)$ in terms of $p_{\text{data}}$ and $p_{g}$
\[\begin{align*} V(D, G) &= \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1 - D(G(z)))]\\ &= \int_{x \in \mathcal{X}}\log D(x) p_{\text{data}}(x)dx + \int_{z \in \mathcal{Z}}\log(1 - D(G(z)))p_{z}(z)dz\\ &= \int_{x \in \mathcal{X}}\log D(x) p_{\text{data}}(x)dx + \int_{x \in \mathcal{X}}\log(1 - D(x))p_{g}(x)dx\\ &= \int_{x \in \mathcal{X}}\left[\log D(x)\, p_{\text{data}}(x) + \log(1 - D(x))\,p_{g}(x)\right]dx, \end{align*}\]where in the third step we apply the law of the unconscious statistician: since $p_{g}$ is the distribution of $x = G(z)$ for $z \sim p_{z}$, \(\mathbb{E}_{z \sim p_{z}}[f(G(z))] = \mathbb{E}_{x \sim p_{g}}[f(x)]\) for any (integrable) function $f$, here $f(x) = \log(1 - D(x))$. Because $D(x)$ can be chosen independently at each $x$, maximizing $V$ reduces to maximizing the integrand pointwise. Then
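\[\begin{align*} \frac{\partial V(D, G)}{\partial D(x)} = \frac{p_{\text{data}}(x)}{D(x)} - \frac{p_{g}(x)}{1 - D(x)} = 0 \Rightarrow D^{*}(x) = \frac{p_{\text{data}}(x)}{p_{g}(x) + p_{\text{data}}(x)}. \end{align*}\]The second derivative, $-\frac{p_{\text{data}}(x)}{D(x)^{2}} - \frac{p_{g}(x)}{(1 - D(x))^{2}}$, is negative wherever $p_{\text{data}}(x) + p_{g}(x) > 0$, so this stationary point is a maximum.
2.2 Optimal Generator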
Define $C(G) = \max_{D}V(D, G) = V(D^{*}, G)$. The global minimum of $C(G)$ is achieved if and only if $p_{g} = p_{\text{data}}$, and at that point $C(G) = -\log 4$. We first show that $\min_{G} C(G) = -\log 4$. Start by adding $-\log 4 + 2\log 2 = 0$ to $C(G) = V(D^{*}, G)$, then absorb the $2\log 2$ into the two expectations:
\[\begin{align*} C(G) &= -\log 4 +2\log 2 + \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{g}(x)}\right] + \mathbb{E}_{x \sim p_{g}}\left[\log \frac{p_{g}(x)}{p_{\text{data}}(x) + p_{g}(x)}\right]\\ &= -\log 4 + \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log 2] + \mathbb{E}_{x \sim p_{g}}[\log 2] + \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{g}(x)}\right]\\ &\quad + \mathbb{E}_{x \sim p_{g}}\left[\log \frac{p_{g}(x)}{p_{\text{data}}(x) + p_{g}(x)}\right]\\ &= -\log 4 + \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log \frac{2p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{g}(x)}\right] + \mathbb{E}_{x \sim p_{g}}\left[\log \frac{2p_{g}(x)}{p_{\text{data}}(x) + p_{g}(x)}\right]\\ &= -\log 4 + KL\left(p_{\text{data}}(x) \bigg\lvert \bigg\lvert \frac{p_{\text{data}}(x) + p_{g}(x)}{2}\right) + KL\left(p_{g}(x) \bigg\lvert \bigg\lvert \frac{p_{\text{data}}(x) + p_{g}(x)}{2} \right)\\ &= -\log 4 + \underbrace{2 JSD(p_{\text{data}} \lvert \lvert p_{g})}_{\geq 0}, \end{align*}\]since Jensen-Shannon divergence $JSD(p_{\text{data}} \lvert \lvert p_{g}) \geq 0$, $\min_{G}C(G) = -\log 4$.
It remains to show that $C(G) = -\log 4$ if and only if $p_{g} = p_{\text{data}}$. If $p_{g} = p_{\text{data}}$, then $JSD(p_{\text{data}} \lvert\lvert p_{g}) = 0$ and $C(G)$ attains its minimum value $-\log 4$. Conversely, if $C(G) = -\log 4$, then $JSD(p_{\text{data}} \lvert\lvert p_{g}) = 0$, and since the Jensen-Shannon divergence is zero only when its two arguments are equal, $p_{\text{data}} = p_{g}$.
In other words, an optimal generator learns the distribution of the data completely.
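Both results can be checked numerically on a finite sample space, where integrals become sums. The sketch below is an illustration under assumptions: p_data and p_g are made-up distributions over four outcomes, and value, kl and jsd are hypothetical helpers. It confirms that $D^{*}$ maximizes $V(D, G)$ for the given $p_{g}$, that $C(G) = -\log 4 + 2JSD(p_{\text{data}} \lvert\lvert p_{g})$, and that $C(G) = -\log 4$ when $p_{g} = p_{\text{data}}$.

```python
import math

# Made-up discrete distributions over four outcomes (illustrative values only).
p_data = [0.1, 0.4, 0.3, 0.2]
p_g = [0.25, 0.25, 0.25, 0.25]

def value(D, p_data, p_g):
    """V(D, G) on a finite space: sum over x of p_data(x) log D(x) + p_g(x) log(1 - D(x))."""
    return sum(pd * math.log(d) + pg * math.log(1.0 - d)
               for d, pd, pg in zip(D, p_data, p_g))

def kl(p, q):
    """KL(p || q) for discrete distributions (terms with p(x) = 0 contribute 0)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(p, q):
    """Jensen-Shannon divergence: average KL to the mixture m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Optimal discriminator D*(x) = p_data(x) / (p_data(x) + p_g(x)) and C(G) = V(D*, G).
D_star = [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]
C = value(D_star, p_data, p_g)

# Nudging D* at any single point lowers V, as expected of a maximizer.
for i in range(len(D_star)):
    for eps in (-0.05, 0.05):
        D = list(D_star)
        D[i] += eps
        assert value(D, p_data, p_g) < C

# C(G) = -log 4 + 2 JSD(p_data || p_g) ...
assert math.isclose(C, -math.log(4) + 2 * jsd(p_data, p_g))
# ... and C(G) = -log 4 when p_g = p_data (then D* = 1/2 everywhere).
C_matched = value([0.5] * len(p_data), p_data, p_data)
assert math.isclose(C_matched, -math.log(4))
```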
2.3 Heuristic, non-saturating game
The value function $V(D, G)$ is useful for theoretical analysis, but as a generator objective it performs poorly in practice. In the initial training iterations, the discriminator tends to reject generator samples with high confidence, and this causes the generator’s gradient to vanish. Concretely, when $D$ rejects samples from $G(z)$ with high confidence, $D(G(z)) \approx 0$, so $\log(1 - D(G(z)))$ sits on a flat region near $\log 1 = 0$. Writing $D(G(z)) = \sigma(a)$, where $a$ is the discriminator’s pre-sigmoid output (logit), \(\frac{\partial}{\partial a}\log(1 - \sigma(a)) = -\sigma(a) = -D(G(z)) \approx 0\), so the gradient that flows back through $a$ to the generator’s weights is close to zero.
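The heuristic game addresses this by having the generator minimize \(-\mathbb{E}_{z \sim p_{z}(z)}\log D(G(z))\) instead, which corresponds to maximizing the log-probability of the discriminator being mistaken. With $D(G(z)) = \sigma(a)$ for the discriminator’s logit $a$, the derivative of this loss is \(\frac{\partial}{\partial a}\left(-\log \sigma(a)\right) = -(1 - \sigma(a)) \approx -1\) when $D(G(z)) \approx 0$, so the generator still receives a strong gradient early in training. This version of the game is no longer zero-sum.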
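A small sketch comparing the two generator losses as functions of the discriminator’s logit $a$, assuming (as in the Discriminator implemented below) that $D$ ends in a sigmoid so that $D(G(z)) = \sigma(a)$. The helper names are hypothetical. By the chain rule, the gradient with respect to the generator’s weights is this derivative times $\partial a / \partial \theta_{G}$, so when the derivative with respect to $a$ vanishes, the generator’s gradient vanishes too.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_minimax(a):
    """d/da of log(1 - sigmoid(a)), the generator's term in V(D, G)."""
    return -sigmoid(a)

def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)), the heuristic generator loss."""
    return -(1.0 - sigmoid(a))

# Very negative logits mean D confidently rejects the fake sample (D(G(z)) close to 0).
for a in (-8.0, -4.0, 0.0):
    print(f"D(G(z)) = {sigmoid(a):.4f}  minimax: {grad_minimax(a):+.4f}  "
          f"non-saturating: {grad_non_saturating(a):+.4f}")
```

Both are losses the generator minimizes, so a negative derivative means that increasing $a$ (making $D(G(z))$ larger) lowers the loss; the two differ only in how large that signal is when $D(G(z))$ is near 0.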
3. Implementation
We will implement a GAN to generate realistic-looking images. Since the input data are images, the generator and discriminator are parameterized by convolutional neural networks. This type of GAN is called the deep convolutional generative adversarial network (DCGAN) [3]. Our implementation of a DCGAN in PyTorch 1.5.1 largely follows the one presented by this resource. The diagram below illustrates the architecture of a DCGAN.
Following the resource, we also apply the DCGAN to the LFW (Labeled Faces in the Wild) dataset, which consists of about 13,000 images of faces. First, download the dataset.
import os
import wget
import tarfile
lfw_url = 'http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz'
data_path = 'data'
# make sure the target directory exists before downloading into it
os.makedirs(data_path, exist_ok=True)
wget.download(lfw_url, data_path + "/lfw.tgz")
# extract compressed files
with tarfile.open(data_path + "/lfw.tgz") as tar:
    tar.extractall(path = data_path)
'data/lfw.tgz'
We first create a Dataset class that loads and returns this data. Following Radford et al. [3], we normalize the pixel values to $[-1, 1]$, the output range of the generator’s tanh activation function.
import os
import warnings
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
# Ignore warnings
warnings.filterwarnings("ignore")
class LWFDataset(Dataset):
    def __init__(self, data_dir, width = 64, height = 64):
        # load every .jpg under data_dir into memory
        self.width = width
        self.height = height
        self.images = []
        for path, _, fnames in os.walk(data_dir):
            for fname in fnames:
                if not fname.endswith('.jpg'):
                    continue
                filepath = os.path.join(path, fname)
                img = plt.imread(filepath)
                self.images.append(img)
        self.images = np.array(self.images)
        self.toPIL = transforms.Compose([transforms.ToPILImage()])
        self.toTensor = transforms.Compose([transforms.ToTensor()])

    def transform(self, img):
        img = self.toPIL(img)
        img = transforms.functional.resize(img, (self.height, self.width))
        img = np.array(img)  # H x W x C (or H x W for a greyscale image)
        # normalize to [-1, 1]
        img = img.astype(np.float32) / 127.5 - 1
        # if image is greyscale (H x W), repeat it 3 times along a new channel axis to get an RGB image
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)
        return img

    def __len__(self):
        # return size of dataset
        return len(self.images)

    def __getitem__(self, idx):
        # apply transforms; ToTensor turns the H x W x C float array into a C x H x W tensor
        # without rescaling (float input), so values stay in [-1, 1]. No label is returned.
        image = self.transform(self.images[idx])
        return self.toTensor(image)

    def getUntransformedImage(self, idx):
        return self.images[idx]
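The normalization in transform maps a uint8 pixel $v \in \{0, \dots, 255\}$ to $v/127.5 - 1 \in [-1, 1]$. A minimal NumPy sketch of this map and its inverse follows; the function names are hypothetical, and the inverse is what one would use to turn a Tanh output back into displayable pixels.

```python
import numpy as np

def to_model_range(pixels):
    """uint8 pixels in [0, 255] -> float32 values in [-1, 1] (same formula as transform)."""
    return pixels.astype(np.float32) / 127.5 - 1

def to_pixel_range(x):
    """Inverse map, e.g. for a Tanh output in [-1, 1] -> displayable uint8 pixels."""
    return np.clip(np.round((x + 1) * 127.5), 0, 255).astype(np.uint8)

pixels = np.array([0, 127, 128, 255], dtype=np.uint8)
scaled = to_model_range(pixels)     # endpoints map to exactly -1.0 and 1.0
restored = to_pixel_range(scaled)   # round trip recovers the original pixels
```

For visualization, the training loop instead relies on vutils.make_grid(..., normalize=True), which rescales the whole grid to [0, 1] using its minimum and maximum values.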
Load the dataset and wrap it in a DataLoader that yields shuffled batches of 128 images.
lwfDataset = LWFDataset('data/lfw-deepfunneled')
dataloader = torch.utils.data.DataLoader(lwfDataset, batch_size=128, shuffle=True)
We follow the architecture guidelines for stable DCGANs [3]:
- Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).
- Use batch normalization in both the generator and the discriminator.
- Remove fully connected hidden layers for deeper architectures.
- Use ReLU activation in the generator for all layers except the output, which uses Tanh (so that the output range matches the range of the normalized pixel values).
- Use LeakyReLU activation in the discriminator for all layers.
We first implement the generator. The first layer is a “project and reshape” layer, which we implement as a fully connected linear layer that projects the input $z \in \mathbb{R}^{100 \times 1 \times 1}$ to a vector in $\mathbb{R}^{8192}$, which is then reshaped into a tensor in $\mathbb{R}^{512 \times 4 \times 4}$. The remaining layers are 2D transposed convolutions, each of which doubles the spatial resolution (4 → 8 → 16 → 32 → 64) while reducing the number of channels (512 → 256 → 128 → 64 → 3).
class Generator(nn.Module):
    def __init__(self, z = 100, imgDim = 64, nc = 3):
        super(Generator, self).__init__()
        self.imgDim = imgDim
        self.z = z
        self.nc = nc
        self.__initialize_graph__()
        self.__initialize_weights__()

    def __initialize_graph__(self):
        self.main = nn.Sequential(
            # transpose Conv1: 512 x 4 x 4 -> 256 x 8 x 8
            nn.ConvTranspose2d(in_channels=self.imgDim * 8, out_channels=self.imgDim * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=self.imgDim * 4),
            nn.ReLU(inplace=True),
            # transpose Conv2: 256 x 8 x 8 -> 128 x 16 x 16
            nn.ConvTranspose2d(in_channels=self.imgDim * 4, out_channels=self.imgDim * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=self.imgDim * 2),
            nn.ReLU(inplace=True),
            # transpose Conv3: 128 x 16 x 16 -> 64 x 32 x 32
            nn.ConvTranspose2d(in_channels=self.imgDim * 2, out_channels=self.imgDim,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=self.imgDim),
            nn.ReLU(inplace=True),
            # transpose Conv4: 64 x 32 x 32 -> 3 x 64 x 64
            nn.ConvTranspose2d(in_channels=self.imgDim, out_channels=self.nc,
                               kernel_size=4, stride=2, padding=1, bias=False),
            # G(z), squashed to [-1, 1]
            nn.Tanh()
        )
        # "project and reshape" layer: z (100) -> 8192 = 512 * 4 * 4
        self.project = nn.Linear(in_features=self.z, out_features=self.imgDim * 128, bias=True)

    def __initialize_weights__(self):
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                # DCGAN initialization: zero-centered normal with std 0.02
                nn.init.normal_(m.weight, mean=0, std=0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # project and reshape
        x = torch.flatten(x, start_dim=1)
        x = self.project(x)
        x = x.view(x.shape[0], self.imgDim * 8, 4, 4)
        # transpose convolutions
        return self.main(x)
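The generator’s layer shapes follow from PyTorch’s output-size formula for nn.ConvTranspose2d, $H_{\text{out}} = (H_{\text{in}} - 1)s - 2p + d(k - 1) + p_{\text{out}} + 1$ for kernel size $k$, stride $s$, padding $p$, dilation $d$ and output padding $p_{\text{out}}$. With $k = 4$, $s = 2$, $p = 1$ this reduces to $H_{\text{out}} = 2H_{\text{in}}$. The plain-Python sketch below (hypothetical helper names) traces the spatial sizes and parameter counts, which can be compared against the summary that follows.

```python
def conv_transpose2d_out(size, kernel=4, stride=2, padding=1, dilation=1, output_padding=0):
    """Spatial output size of nn.ConvTranspose2d (shape formula from the PyTorch docs)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

nz, ngf, nc, k = 100, 64, 3, 4      # z, imgDim, nc and kernel_size used in Generator
channels = [ngf * 8, ngf * 4, ngf * 2, ngf, nc]   # 512 -> 256 -> 128 -> 64 -> 3

# "Project and reshape": Linear(100 -> 8192) with bias, then view as 512 x 4 x 4.
assert ngf * 128 == channels[0] * 4 * 4
linear_params = nz * ngf * 128 + ngf * 128

sizes, conv_params = [4], []
for c_in, c_out in zip(channels, channels[1:]):
    sizes.append(conv_transpose2d_out(sizes[-1]))
    conv_params.append(c_in * c_out * k * k)      # bias=False: weights only

# Each BatchNorm2d has a weight and a bias per channel (none after the last layer).
bn_params = sum(2 * c for c in channels[1:-1])

print(sizes)  # [4, 8, 16, 32, 64]
print(linear_params, conv_params, bn_params)
print(linear_params + sum(conv_params) + bn_params)  # 3583872
```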
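Create an instance of the generator and view its summary.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = Generator().to(device)
# torchsummary defaults to device="cuda"; pass the actual device type so this also runs on CPU
summary(netG, input_size=(100, 1, 1), device=device.type)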
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Linear-1 [-1, 8192] 827,392
ConvTranspose2d-2 [-1, 256, 8, 8] 2,097,152
BatchNorm2d-3 [-1, 256, 8, 8] 512
ReLU-4 [-1, 256, 8, 8] 0
ConvTranspose2d-5 [-1, 128, 16, 16] 524,288
BatchNorm2d-6 [-1, 128, 16, 16] 256
ReLU-7 [-1, 128, 16, 16] 0
ConvTranspose2d-8 [-1, 64, 32, 32] 131,072
BatchNorm2d-9 [-1, 64, 32, 32] 128
ReLU-10 [-1, 64, 32, 32] 0
ConvTranspose2d-11 [-1, 3, 64, 64] 3,072
Tanh-12 [-1, 3, 64, 64] 0
================================================================
Total params: 3,583,872
Trainable params: 3,583,872
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 2.88
Params size (MB): 13.67
Estimated Total Size (MB): 16.55
----------------------------------------------------------------
We implement the discriminator with the following features:
- Strided convolutions in place of pooling, allowing the network to learn its own spatial downsampling.
- Batch normalization (on every layer except the input layer) and leaky ReLU to promote healthy gradient flow.
class Discriminator(nn.Module):
    def __init__(self, imgDim=64, nc=3, ndf=64):
        super(Discriminator, self).__init__()
        self.imgDim = imgDim
        self.nc = nc
        self.ndf = ndf
        self.__initialize_graph__()
        self.__initialize_weights__()

    def __initialize_graph__(self):
        self.main = nn.Sequential(
            # Conv1: 3 x 64 x 64 -> 64 x 32 x 32 (no batch norm on the input layer)
            nn.Conv2d(in_channels=self.nc, out_channels=self.ndf, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Conv2: 64 x 32 x 32 -> 128 x 16 x 16
            nn.Conv2d(in_channels=self.ndf, out_channels=self.ndf * 2,
                      kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=self.ndf * 2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Conv3: 128 x 16 x 16 -> 256 x 8 x 8
            nn.Conv2d(in_channels=self.ndf * 2, out_channels=self.ndf * 4,
                      kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=self.ndf * 4),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Conv4: 256 x 8 x 8 -> 512 x 4 x 4
            nn.Conv2d(in_channels=self.ndf * 4, out_channels=self.ndf * 8,
                      kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=self.ndf * 8),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Conv5: 512 x 4 x 4 -> 1 x 1 x 1, squashed to a probability
            nn.Conv2d(in_channels=self.ndf * 8, out_channels=1,
                      kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )

    def __initialize_weights__(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # DCGAN initialization: zero-centered normal with std 0.02
                nn.init.normal_(m.weight, mean=0, std=0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.main(x)
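Similarly, nn.Conv2d produces $H_{\text{out}} = \lfloor (H_{\text{in}} + 2p - d(k - 1) - 1)/s \rfloor + 1$, so each layer with $k = 4$, $s = 2$, $p = 1$ halves the resolution, and the final layer with $s = 1$, $p = 0$ collapses the $4 \times 4$ map to a single score. A plain-Python sketch of the discriminator’s shapes and parameter counts (hypothetical helper names), which can be compared with the summary that follows:

```python
def conv2d_out(size, kernel=4, stride=2, padding=1, dilation=1):
    """Spatial output size of nn.Conv2d (shape formula from the PyTorch docs)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

nc, ndf, k = 3, 64, 4   # nc, ndf and kernel_size used in Discriminator
# (c_in, c_out, stride, padding) for Conv1..Conv5
layers = [(nc, ndf, 2, 1), (ndf, ndf * 2, 2, 1), (ndf * 2, ndf * 4, 2, 1),
          (ndf * 4, ndf * 8, 2, 1), (ndf * 8, 1, 1, 0)]

sizes, conv_params = [64], []
for c_in, c_out, stride, padding in layers:
    sizes.append(conv2d_out(sizes[-1], stride=stride, padding=padding))
    conv_params.append(c_in * c_out * k * k)      # bias=False: weights only

# BatchNorm2d (weight + bias per channel) follows Conv2, Conv3 and Conv4 only.
bn_params = sum(2 * c_out for _, c_out, _, _ in layers[1:4])

print(sizes)  # [64, 32, 16, 8, 4, 1]
print(sum(conv_params) + bn_params)  # 2765568
```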
Create an instance of the discriminator and view its summary.
netD = Discriminator().to(device)
summary(netD, input_size=(3, 64, 64), device=device.type)
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 32, 32] 3,072
LeakyReLU-2 [-1, 64, 32, 32] 0
Conv2d-3 [-1, 128, 16, 16] 131,072
BatchNorm2d-4 [-1, 128, 16, 16] 256
LeakyReLU-5 [-1, 128, 16, 16] 0
Conv2d-6 [-1, 256, 8, 8] 524,288
BatchNorm2d-7 [-1, 256, 8, 8] 512
LeakyReLU-8 [-1, 256, 8, 8] 0
Conv2d-9 [-1, 512, 4, 4] 2,097,152
BatchNorm2d-10 [-1, 512, 4, 4] 1,024
LeakyReLU-11 [-1, 512, 4, 4] 0
Conv2d-12 [-1, 1, 1, 1] 8,192
Sigmoid-13 [-1, 1, 1, 1] 0
================================================================
Total params: 2,765,568
Trainable params: 2,765,568
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 2.31
Params size (MB): 10.55
Estimated Total Size (MB): 12.91
----------------------------------------------------------------
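Now we implement the training procedure. Following a recommendation from the NIPS 2016 GAN tutorial [2], we implement one-sided label smoothing: the discriminator’s targets for real samples are set to 0.9 instead of 1, while the targets for fake samples stay at 0. This prevents extreme extrapolation behavior; in particular, it discourages the discriminator from predicting extremely large logits. With a real target of $\alpha = 0.9$, the optimal discriminator becomes $\alpha\frac{p_{\text{data}}(x)}{p_{g}(x) + p_{\text{data}}(x)}$, a scaled-down version of $D^{*}(x)$ with the same shape. The smoothing is one-sided because giving fake samples a nonzero target $\beta$ would add a $\beta p_{g}(x)$ term to the numerator, which changes the shape of the optimal discriminator and creates peaks near spurious modes of $p_{g}$ where $p_{\text{data}}$ is small.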
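A minimal numerical check of these claims, in plain Python with hypothetical helper names and made-up densities at a single point $x$: minimizing the expected binary cross-entropy over $D(x)$ on a grid recovers $\alpha p_{\text{data}}/(p_{\text{data}} + p_{g})$ for one-sided smoothing, and $(\alpha p_{\text{data}} + \beta p_{g})/(p_{\text{data}} + p_{g})$ when the fake targets are smoothed to $\beta$ as well.

```python
import math

def expected_bce(d, p_data, p_g, real_target, fake_target):
    """Expected BCE loss of discriminator output d at a point x where real samples
    have density p_data (target real_target) and fake samples density p_g (target fake_target)."""
    def bce(y):
        return -(y * math.log(d) + (1 - y) * math.log(1 - d))
    return p_data * bce(real_target) + p_g * bce(fake_target)

def grid_argmin(f, n=20000):
    """Minimize f over a uniform grid on the open interval (0, 1)."""
    grid = [(i + 0.5) / n for i in range(n)]
    return min(grid, key=f)

p_data, p_g = 0.6, 0.4      # hypothetical densities at one point x
alpha, beta = 0.9, 0.1      # smoothed real target, hypothetical smoothed fake target

# One-sided smoothing (fake target 0), as in the training code.
d_one_sided = grid_argmin(lambda d: expected_bce(d, p_data, p_g, alpha, 0.0))
# Two-sided smoothing, shown only for comparison.
d_two_sided = grid_argmin(lambda d: expected_bce(d, p_data, p_g, alpha, beta))

print(d_one_sided, alpha * p_data / (p_data + p_g))                  # both ~0.54
print(d_two_sided, (alpha * p_data + beta * p_g) / (p_data + p_g))   # both ~0.58
```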
# Initialize binary cross entropy function
criterion = nn.BCELoss()
# Create batch of latent vectors that we will use to visualize
# the progression of the generator (on the same device as netG)
fixed_noise = torch.randn(128, netG.z, 1, 1, device=device)
# Establish convention for real and fake labels during training;
# real labels are smoothed to 0.9 (one-sided label smoothing)
real_label = 0.9
fake_label = 0.
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
A few notes on implementation details:
- Empirically, sampling $z \sim \mathcal{N}(0, I)$ performed better in this implementation than $z_{i} \sim Unif(0, 1)$ for $i = 1, \dots ,100$.
- When computing the discriminator’s gradient, call .detach() on the generated images so that gradients with respect to the generator’s weights are not computed.
# Training Loop
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
num_epochs = 100
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch and move it to the same device as the networks
        data = data.to(device)
        b_size = data.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(data).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors (netG.z = 100 is the latent size)
        noise = torch.randn(b_size, netG.z, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all-fake batch with D; detach() so no gradient flows into G here
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch; they accumulate with the real-batch gradients
        errD_fake.backward()
        D_G_z = output.mean().item()
        # Total discriminator loss (for logging; the gradients were already accumulated above)
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z))) (non-saturating loss)
        ###########################
        netG.zero_grad()
        # Fake labels are real for the generator cost. Use the unsmoothed target 1.0:
        # label smoothing is meant for the discriminator's targets only.
        label.fill_(1.0)
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 5 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z))
        # Save losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1
Sample training output messages (from a run with num_epochs = 20):
Starting Training Loop...
[0/20][0/104] Loss_D: 1.4444 Loss_G: 5.3994 D(x): 0.9483 D(G(z)): 0.4281
[0/20][5/104] Loss_D: 1.3202 Loss_G: 3.9553 D(x): 0.9328 D(G(z)): 0.3406
[0/20][10/104] Loss_D: 1.0306 Loss_G: 4.3411 D(x): 0.8991 D(G(z)): 0.2811
[0/20][15/104] Loss_D: 1.0606 Loss_G: 2.9521 D(x): 0.8990 D(G(z)): 0.3044
[0/20][20/104] Loss_D: 1.1041 Loss_G: 2.1576 D(x): 0.6295 D(G(z)): 0.1640
[0/20][25/104] Loss_D: 1.6851 Loss_G: 1.2120 D(x): 0.3495 D(G(z)): 0.0336
[0/20][30/104] Loss_D: 1.3260 Loss_G: 1.7409 D(x): 0.4221 D(G(z)): 0.0491
[0/20][35/104] Loss_D: 1.0421 Loss_G: 2.4594 D(x): 0.5918 D(G(z)): 0.1079
[0/20][40/104] Loss_D: 1.1397 Loss_G: 4.1079 D(x): 0.5686 D(G(z)): 0.0549
[0/20][45/104] Loss_D: 0.7310 Loss_G: 5.0861 D(x): 0.8848 D(G(z)): 0.1409
4. Results
Visualize the progression of the generated images from the same input noise $z$.
import matplotlib.animation as animation
from IPython.display import HTML
fig = plt.figure(figsize=(20,10))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())