
Autoencoders, Variational Autoencoders and Conditional Variational Autoencoders#

Autoencoders (AEs) are neural networks designed for unsupervised learning, primarily used for dimensionality reduction and feature extraction. They consist of an encoder, which compresses input data into a lower-dimensional latent space, and a decoder, which reconstructs the original input from this compressed representation. Standard autoencoders can learn meaningful representations but often struggle with generating diverse, high-quality samples.

Variational Autoencoders (VAEs) improve upon AEs by introducing a probabilistic framework, ensuring that the learned latent space follows a structured distribution (typically Gaussian). Instead of mapping inputs to fixed points in the latent space, VAEs learn distributions, allowing for smooth interpolation and controlled sampling. This makes them powerful for generative tasks, as they can generate novel data points that resemble the training distribution.

Conditional Variational Autoencoders (cVAEs) extend VAEs by incorporating additional information, such as labels or descriptors, into the encoding and decoding process. This enables the generation of data conditioned on specific attributes, making cVAEs particularly useful in chemistry for designing molecules with desired properties.

These models form the backbone of modern generative approaches in machine learning, with applications in molecular design, materials discovery, and reaction prediction.


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

The Dataset#

In this notebook, we will use the MNIST digits dataset. It is simpler than the Fashion-MNIST dataset, allowing us to use only two latent features so that we can conveniently visualise and examine the distribution of the encodings in the latent space.

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
mnist_train = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=128, shuffle=False)

Take a look at the size of the data.#

We have 60k (28, 28) images over 10 separate classes.

# Print dataset info
print(f'Training dataset size: {len(mnist_train)}')
print(f'Image size (pixels per sample): {mnist_train.data.shape[1] * mnist_train.data.shape[2]}')
print(f'Number of classes: {len(mnist_train.classes)}')

Dataset visualisation#

Look at some of the samples from the data#

def plot_images(images, title="Images"):
    fig, axes = plt.subplots(4, 4, figsize=(5, 5))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i].squeeze().reshape(28, 28), cmap='gray')
        ax.axis('off')
    plt.suptitle(title)
    plt.show()

# Visualize some training images
images, _ = next(iter(train_loader))
plot_images(images[:16], "Training Data Samples")

Autoencoder#

Why do we need a VAE? To answer this question, let us start with an ordinary AE and see what is unsatisfactory when we use it to generate new images.

Model Structure#

The autoencoder consists of two main components: an encoder and a decoder.

Encoder#

  • The encoder compresses the input image into a latent representation.

  • It consists of:

    • A fully connected layer (fc1) that transforms the input into a hidden representation.

    • A final fully connected layer (fc_latent) that reduces the hidden representation into a 2D latent space.

Decoder#

  • The decoder reconstructs the original image from the latent representation.

  • It consists of:

    • A fully connected layer (fc_decode) that expands the latent representation back to the hidden dimension.

    • A final fully connected layer (fc_out) that reconstructs the original image using a sigmoid activation.

Loss Function#

The autoencoder is trained using the binary cross-entropy (BCE) loss:

\( \mathcal{L}(x, \hat{x}) = - \sum \left[ x \log(\hat{x}) + (1 - x) \log(1 - \hat{x}) \right] \)

where:

  • \(x\) is the original input image.

  • \(\hat{x}\) is the reconstructed image.

  • The loss measures how well the reconstructed image matches the original.

This approach helps the autoencoder learn meaningful features from the data and reconstruct images accurately.
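To make the BCE formula concrete, here is a minimal, dependency-free sketch that evaluates the summed loss by hand. The 4-pixel "image" and the two candidate reconstructions are made-up values for illustration, and the small clamp `eps` is an assumption added to avoid `log(0)`; it is not part of the formula above.

```python
import math

def bce_sum(x, x_hat, eps=1e-12):
    """Summed binary cross-entropy between targets x and predictions x_hat."""
    total = 0.0
    for xi, xhi in zip(x, x_hat):
        xhi = min(max(xhi, eps), 1.0 - eps)  # clamp to avoid log(0) (assumption)
        total -= xi * math.log(xhi) + (1.0 - xi) * math.log(1.0 - xhi)
    return total

x = [0.0, 1.0, 1.0, 0.0]      # hypothetical 4-pixel "image"
good = [0.1, 0.9, 0.8, 0.2]   # reconstruction close to x
bad = [0.9, 0.1, 0.2, 0.8]    # reconstruction far from x

loss_good = bce_sum(x, good)
loss_bad = bce_sum(x, bad)
print(f"good reconstruction: {loss_good:.4f}")
print(f"bad reconstruction:  {loss_bad:.4f}")
```

The closer reconstruction gives the smaller loss, which is the behaviour the training loop relies on.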

Training Process#

  • The model is trained using the Adam optimizer with a learning rate of \(10^{-3}\).

  • The loss is minimized over multiple epochs using a batch size of 128.

  • The encoder’s learned representations can be visualized in a 2D latent space to analyze digit clustering.

Visualization#

  1. Reconstruction: Displays original and reconstructed images to assess model performance.

  2. Latent Space: Plots the encoded representations to understand digit distributions.

# Define the Autoencoder
class Autoencoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2):
        super(Autoencoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_latent = nn.Linear(hidden_dim, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_latent(h)

    def decode(self, z):
        h = F.relu(self.fc_decode(z))
        return torch.sigmoid(self.fc_out(h))

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z)

Training#

Set up the training loop#

Using this code template, set up the code to train for 20 epochs with a latent dimensionality of 2. Use a learning rate of 0.001.

# Template: fill in the three blanks, then uncomment and run
# epochs = 
# latent_dim = 
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# autoencoder = Autoencoder(latent_dim=latent_dim).to(device)
# optimizer = optim.Adam(autoencoder.parameters(), lr=)

# Solution: initialize the model and optimizer
epochs = 20
latent_dim = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autoencoder = Autoencoder(latent_dim=latent_dim).to(device)
optimizer = optim.Adam(autoencoder.parameters(), lr=1e-3)
# Loss function
def loss_function(recon_x, x):
    return F.binary_cross_entropy(recon_x, x, reduction='sum')

Run training#

# Training loop
for epoch in range(epochs):
    autoencoder.train()
    train_loss = 0
    for x, _ in train_loader:
        x = x.to(device)
        optimizer.zero_grad()
        recon_x = autoencoder(x)
        loss = loss_function(recon_x, x)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {train_loss/len(train_loader.dataset):.4f}")

Visualise the results#

Reconstructed outputs#

We first plot a series of input images and their outputs.

# Visualize original and reconstructed images
def visualize_reconstructions():
    autoencoder.eval()
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.to(device)
            recon_x = autoencoder(x)
            break
    x = x.cpu().view(-1, 28, 28)[:10]
    recon_x = recon_x.cpu().view(-1, 28, 28)[:10]
    
    fig, axes = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(10):
        axes[0, i].imshow(x[i], cmap="gray")
        axes[1, i].imshow(recon_x[i], cmap="gray")
        axes[0, i].axis("off")
        axes[1, i].axis("off")
    plt.show()

visualize_reconstructions()

Visualise the latent space#

We can look at the latent space and see how the different digits are distributed. Note that although this was technically an unsupervised learning exercise, we do have access to labels, but they are never used in training.

# Visualize latent space
def visualize_latent_space():
    autoencoder.eval()
    embeddings, labels = [], []
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            z = autoencoder.encode(x)
            embeddings.append(z.cpu())
            labels.append(y.cpu())
    
    embeddings = torch.cat(embeddings).numpy()
    labels = torch.cat(labels).numpy()
    
    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(embeddings[:, 0], embeddings[:, 1], c=labels, cmap="tab10", alpha=0.5)
    plt.colorbar(scatter, label="Digit Label")
    plt.title("Latent Space Visualization")
    plt.xlabel("z1")
    plt.ylabel("z2")
    plt.show()

visualize_latent_space()

Visualise Digit distributions#

For a different perspective, we can plot the distribution of the latent dimensions for each digit.

# KDE Plot for each digit in subplots
def visualize_kde_plots():
    autoencoder.eval()
    embeddings, labels = [], []
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            z = autoencoder.encode(x)
            embeddings.append(z.cpu())
            labels.append(y.cpu())
    
    embeddings = torch.cat(embeddings).numpy()
    labels = torch.cat(labels).numpy()
    
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    axes = axes.flatten()
    for digit in range(10):
        idx = labels == digit
        sns.kdeplot(x=embeddings[idx, 0], y=embeddings[idx, 1], fill=True, ax=axes[digit])
        axes[digit].set_title(f"Digit {digit}")
        axes[digit].set_xlabel("z1")
        axes[digit].set_ylabel("z2")
    plt.tight_layout()
    plt.show()


visualize_kde_plots()

Regularity of the latent space#

Both the scatter plot and the KDE plots show that the data distributions in the latent space are rather irregular. Some of the digits have very wide distributions (such as 1 and 7) and some very narrow distributions (such as 2 and 3).

Remember that our goal in training this AE is neither dimensionality reduction nor denoising but generating new images that resemble the original dataset. Image generation is done by the decoder, taking the latent representation (z1 and z2 in the plots) as the input. An irregular latent space makes image generation less controllable and less robust. In our case, two shortcomings are likely to emerge:

  1. Controllability: if we sample the entire latent space, we will generate many more of the widely distributed digits than the narrowly distributed ones; if we instead limit the range of the latent space, we will lose some characteristics of the widely distributed ones;

  2. Robustness: images that do not resemble any of the digits will be generated from the gaps between the distributions of the digits; such gaps increase with the range of the latent space.

Variational Autoencoder#

Overfitting is the essential reason behind an irregular latent space of a naive AE, that is, the neural networks for encoding and decoding try their best to fit the data from end to end without caring about how the latent space is organised with respect to the original data. A VAE can regularise the latent space by imposing additional distributional properties on the latent space.

The following figure summarises the two extensions from an AE to a VAE:

  1. Unlike a naive AE that encodes an input \(x\) as a single point \(z\) in the latent space, a VAE encodes it as a normal distribution \(\mathcal{N}(\mu, \sigma^2)\), and the latent representation \(z\) is sampled from this distribution and then passed to the decoder;

  2. An AE only minimises the reconstruction error (for example \(\lVert x-x'\rVert^2\), or the BCE loss used in this notebook) to fit the data, whereas a VAE minimises the sum of the reconstruction error and the KL divergence (Kullback–Leibler divergence) between the latent distribution \(\mathcal{N}(\mu, \sigma^2)\) and the standard normal distribution \(\mathcal{N}(0, 1)\).

How does a VAE regularise the latent space? The loss function provides a straightforward answer: in addition to fitting the data by minimising the reconstruction error, it also drags the latent distribution to a standard normal distribution. The final model is a trade-off between the two effects. Also, because each input image is encoded as a Gaussian blob instead of a single point, the gaps in the latent space can be filled by such blurring so that meaningless decodings can be largely avoided.

[Figure: ae-vae.png]

📚 Theory for VAEs


Derivation of ELBO#

The posterior probability \(p(z|x)\) can be expressed as:

\begin{eqnarray} p(z|x)&=&\frac{p(x|z)p(z)}{p(x)}\nonumber\\ &=& \frac{p(x|z)p(z)}{\int p(x|z)p(z)\,dz} \end{eqnarray}

where the marginal \(\int p(x|z)p(z)\,dz\) is generally intractable and cannot be computed directly. One way to approximate the posterior \(p(z|x)\) is to use Monte Carlo methods (such as sampling). The method used in this notebook (and the underlying VAE paper) is variational inference.
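As a small illustration of the Monte Carlo route mentioned above, the sketch below estimates the marginal \(p(x)=\int p(x|z)p(z)\,dz\) by sampling \(z\) from the prior. The toy model (\(z\sim\mathcal{N}(0,1)\), \(x|z\sim\mathcal{N}(z,1)\)) is an assumption chosen only because its marginal is known in closed form, \(x\sim\mathcal{N}(0,2)\); the function names are hypothetical and nothing here is part of the notebook's VAE.

```python
import math
import random

def normal_pdf(v, mean, var):
    """Density of N(mean, var) evaluated at v."""
    return math.exp(-(v - mean) ** 2 / (2.0 * var)) / math.sqrt(2.0 * math.pi * var)

def mc_marginal(x, n_samples=100_000, seed=0):
    """Estimate p(x) = E_{z ~ p(z)}[p(x|z)] by sampling z from the prior."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)         # z ~ p(z) = N(0, 1)
        total += normal_pdf(x, z, 1.0)  # p(x|z) = N(x; z, 1)
    return total / n_samples

x = 1.0
estimate = mc_marginal(x)
exact = normal_pdf(x, 0.0, 2.0)  # closed-form marginal of the toy model
print(f"Monte Carlo estimate: {estimate:.4f}, exact: {exact:.4f}")
```

In one dimension this works well, but the number of samples needed grows quickly with the dimension of \(z\) and the complexity of \(p(x|z)\), which is one motivation for the variational approach.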

The idea of variational inference is to identify a proxy distribution \(q(z|x)\) that reasonably approximates the true posterior \(p(z|x)=p(x|z)p(z)/p(x)\). The quality of the approximation is measured by the KL divergence between the two pdfs \(q(z|x)\) and \(p(z|x)\),

\[\mathrm{KL}(q(z|x)||p(z|x))\]

which is minimised by selecting a \(q(z|x)\) that is a good proxy for \(p(z|x)\). But

\begin{eqnarray} \mathrm{KL}(q(z|x)||p(z|x)) &=& -\int q(z|x)\log\frac{p(z|x)}{q(z|x)} dz\nonumber\\ &=& -\int q(z|x)\log\frac{p(x|z)p(z)}{p(x)q(z|x)} dz\nonumber\\ &=& -\int q(z|x)\log\frac{p(x|z)p(z)}{q(z|x)}dz + \int_{z} q(z|x)\log p(x)dz \nonumber\\ &=& -\int q(z|x)\log\frac{p(x|z)p(z)}{q(z|x)}dz + \log p(x)\int_{z} q(z|x)dz\nonumber\\ &=& -\int q(z|x)\log\frac{p(x|z)p(z)}{q(z|x)}dz + \log p(x)\nonumber\\ &=& -\int q(z|x)\log\frac{p(z)}{q(z|x)}dz -\int q(z|x)\log{p(x|z)}dz + \log p(x) \end{eqnarray}

Given that \(\mathrm{KL}\left(q(z|x)||p(z|x)\right)\geq 0\),

\begin{eqnarray} -\int q(z|x)\log\frac{p(z)}{q(z|x)}dz -\int q(z|x)\log{p(x|z)}dz + \log p(x) &\geq& 0 \\ \log p(x) &\geq& \int q(z|x)\log\frac{p(z)}{q(z|x)}dz + \int q(z|x)\log{p(x|z)}dz\\ \log p(x) &\geq& - \mathrm{KL}(q(z|x)||p(z)) + \int q(z|x)\log p(x|z)dz \nonumber\\ \log p(x) &\geq& - \mathrm{KL}(q(z|x)||p(z)) + \mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] \nonumber \end{eqnarray}

This is the variational lower bound, or evidence lower bound (ELBO), and it is the objective function for the VAE. However, frameworks like TensorFlow or PyTorch need a loss function to be minimised. Maximising the log-likelihood of the model evidence \(p(x)\) is the same as minimising \(-\log p(x)\), so in practice we minimise the negative ELBO. The first term of the ELBO, namely \(-\mathrm{KL}(q(z|x)||p(z))\), is the regularising term and constrains the approximate posterior distribution. The second term is the expected log-likelihood, whose negative is the reconstruction loss.
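The bound can be checked numerically on a model where everything is available in closed form. The sketch below assumes a toy model that is not part of the notebook: \(z\sim\mathcal{N}(0,1)\) and \(x|z\sim\mathcal{N}(z,1)\), so that \(\log p(x)\) is the log-density of \(\mathcal{N}(0,2)\) and the true posterior is \(\mathcal{N}(x/2, 1/2)\). The helper names are hypothetical.

```python
import math

def log_evidence(x):
    """log p(x) for the toy model z ~ N(0, 1), x|z ~ N(z, 1), i.e. x ~ N(0, 2)."""
    return -0.5 * math.log(2.0 * math.pi * 2.0) - x ** 2 / 4.0

def elbo(x, m, s2):
    """ELBO for q(z|x) = N(m, s2): E_q[log p(x|z)] - KL(q || p(z))."""
    expected_loglik = -0.5 * math.log(2.0 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    kl = 0.5 * (m ** 2 + s2 - 1.0 - math.log(s2))
    return expected_loglik - kl

x = 1.5
candidates = [(0.0, 1.0), (1.0, 1.0), (x / 2.0, 0.5)]  # last entry is the true posterior
for m, s2 in candidates:
    gap = log_evidence(x) - elbo(x, m, s2)
    print(f"q = N({m:.2f}, {s2:.2f}): ELBO = {elbo(x, m, s2):.4f}, gap = {gap:.4f}")
```

The gap is the KL divergence between \(q(z|x)\) and the true posterior, so it is non-negative for every candidate and vanishes only when \(q\) matches the posterior exactly.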

Now, this leaves a fair bit of freedom in the choice of the prior \(p(z)\). Let’s assume:

\[ p(z)={\cal N}(\mu_p, \sigma_p^2) \]

and

\[ q(z|x)={\cal N}(\mu_q, \sigma_q^2) \]

Thus,

\[ p(z)=\frac{1}{\sqrt{2\pi\sigma_p^2}}\exp\left(-\frac{(z-\mu_p)^2}{2\sigma_p^2}\right) \]

and

\[ q(z|x)=\frac{1}{\sqrt{2\pi\sigma_q^2}}\exp\left(-\frac{(z-\mu_q)^2}{2\sigma_q^2}\right) \]

The direct derivation of \(\mathrm{KL}(q(z|x)||p(z))\) will give (with some simplifications)

\[ -\mathrm{KL}(q(z|x)||p(z)) = \log\frac{\sigma_q}{\sigma_p} - \frac{\sigma_q^2+(\mu_p-\mu_q)^2}{2\sigma_p^2} +\frac{1}{2} \]

By fixing the prior distribution \(p(z)={\cal N}(0,1^2)\),

\[ -\mathrm{KL}(q(z|x)||p(z)) = \frac{1}{2}\left[ 1 + \log\sigma_q^2 - \sigma_q^2 -\mu_q^2\right] \]

Hence, the new ELBO is

\[ \frac{1}{2}\left[ 1 + \log\sigma_q^2 - \sigma_q^2 -\mu_q^2\right] + \mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] \]

Let \(J\) be the dimension of the latent space and \(B\) the number of samples \(z^{(i,l)}\) drawn from \(q(z|x_i)\) (the batch size over which the sampling is done) to estimate the expectation. The loss function \(\cal{L}\) we need to minimise (from the point of view of implementation) is

\[ {\cal L} = - \sum_{j=1}^J \frac{1}{2}\Bigl[ 1 + \log\sigma_j^2 - \sigma_j^2 -\mu_j^2\Bigr] - \frac{1}{B}\sum_{l=1}^{B}\log p(x_i|z^{(i,l)}) \]

This can be observed in the code implementation below (see the function loss_function).
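As a sanity check on the KL term, the sketch below compares three ways of computing \(-\mathrm{KL}(q(z|x)||p(z))\) for a single latent dimension: the standard general expression for two univariate Gaussians, the simplified form for a standard normal prior written in terms of the log-variance, and a sampling estimate. The function names and the values of \(\mu_q\) and \(\sigma_q\) are illustrative assumptions.

```python
import math
import random

def neg_kl_general(mu_q, sigma_q, mu_p, sigma_p):
    """-KL(q || p) for two univariate Gaussians."""
    return (math.log(sigma_q / sigma_p)
            - (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
            + 0.5)

def neg_kl_standard(mu_q, logvar_q):
    """-KL(q || N(0, 1)) written in terms of the log-variance of q."""
    return 0.5 * (1.0 + logvar_q - math.exp(logvar_q) - mu_q ** 2)

def neg_kl_monte_carlo(mu_q, sigma_q, n_samples=100_000, seed=0):
    """Sampling estimate of -KL = E_q[log p(z) - log q(z)] with p = N(0, 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(mu_q, sigma_q)
        log_p = -0.5 * math.log(2.0 * math.pi) - 0.5 * z ** 2
        log_q = (-0.5 * math.log(2.0 * math.pi * sigma_q ** 2)
                 - (z - mu_q) ** 2 / (2.0 * sigma_q ** 2))
        total += log_p - log_q
    return total / n_samples

mu_q, sigma_q = 0.7, 0.5  # illustrative values
closed = neg_kl_standard(mu_q, math.log(sigma_q ** 2))
general = neg_kl_general(mu_q, sigma_q, 0.0, 1.0)
sampled = neg_kl_monte_carlo(mu_q, sigma_q)
print(f"closed form: {closed:.4f}, general: {general:.4f}, sampled: {sampled:.4f}")
```

The two analytic forms agree to rounding error and the sampled value fluctuates around them, which is why the closed form is preferred in the loss function: it is exact and adds no sampling noise to the gradients.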

Reparameterisation#

A valid reparameterization would be

\[ z = \mu+\sigma\epsilon \]

where \(\epsilon\sim{\cal{N}}(0, 1)\) is an auxiliary noise variable. Moving the randomness into \(\epsilon\) is what enables the reparameterization technique, because \(z\) becomes a deterministic, differentiable function of \(\mu\) and \(\sigma\). Although it is possible to work with \(\sigma\), or more specifically \(\sigma^2\), directly, working on a log scale improves numerical stability, i.e. the network outputs

\begin{eqnarray} p &=& \log(\sigma^2)\\ &=& 2 \log(\sigma) \end{eqnarray}

The log standard deviation is then

\[ \log(\sigma) = p/2 \]

and hence

\[ \sigma = \exp(p/2) \]

The resulting estimator (or the loss function) becomes (see Page 5 of Auto-Encoding Variational Bayes Paper),

\[ -\text{KLD} = \frac{1}{2}\sum_{j=1}^{J}\left(1+\log\left((\sigma_j^{(i)})^2\right) - (\mu_j^{(i)})^2 -(\sigma_j^{(i)})^2\right) \]

It is important to see that the KL divergence can be computed and differentiated without estimation. This is quite remarkable (no estimation!).

The \(\boldsymbol{\epsilon}\) must be sampled from a zero-mean, unit-variance Gaussian distribution, and should be of the same size as \(\boldsymbol{\sigma}\).
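The following dependency-free sketch illustrates the two properties of the reparameterisation described above, for a single latent dimension. It first checks that \(z=\mu+\sigma\epsilon\) with \(\sigma=\exp(p/2)\) has the intended mean and standard deviation, and then, with \(\epsilon\) held fixed, compares finite-difference derivatives of \(z\) with the analytic ones. The values of \(\mu\), the log-variance and the step size `h` are arbitrary choices for illustration.

```python
import math
import random

def reparameterize(mu, logvar, eps):
    """z = mu + sigma * eps with sigma = exp(logvar / 2)."""
    return mu + math.exp(0.5 * logvar) * eps

mu, logvar = 1.0, math.log(0.25)  # corresponds to sigma = 0.5
rng = random.Random(0)
eps_samples = [rng.gauss(0.0, 1.0) for _ in range(50_000)]
z_samples = [reparameterize(mu, logvar, e) for e in eps_samples]

mean_z = sum(z_samples) / len(z_samples)
std_z = math.sqrt(sum((z - mean_z) ** 2 for z in z_samples) / len(z_samples))
print(f"sample mean: {mean_z:.3f}, sample std: {std_z:.3f}")

# With eps fixed, z is a deterministic function of (mu, logvar),
# so its derivatives exist and can be checked by central differences.
eps, h = 0.8, 1e-6
dz_dmu = (reparameterize(mu + h, logvar, eps)
          - reparameterize(mu - h, logvar, eps)) / (2 * h)
dz_dlogvar = (reparameterize(mu, logvar + h, eps)
              - reparameterize(mu, logvar - h, eps)) / (2 * h)
analytic_dlogvar = 0.5 * math.exp(0.5 * logvar) * eps  # d z / d logvar
print(f"dz/dmu: {dz_dmu:.4f}, dz/dlogvar: {dz_dlogvar:.4f}")
```

In an autograd framework the same structure is what lets gradients of the loss reach the encoder outputs through the sampled \(z\).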


Model structure#

Encoder#

Maps the input data to a latent space by computing the mean (μ) and log-variance (log(σ²)) using fully connected layers.

  • A fully connected layer (fc1) maps the input (input_dim = 784) to a hidden dimension (hidden_dim = 400).

  • Two separate fully connected layers (fc_mu and fc_logvar) transform the hidden representation into the mean (μ) and log-variance (logvar) of the latent space (latent_dim = 2).

Decoder#

Reconstructs the input from the latent representation z.

  • A fully connected layer (fc_decode) maps z back to the hidden dimension (hidden_dim = 400).

  • A second fully connected layer (fc_out) transforms the hidden representation back to the original input dimension (input_dim = 784).

  • The final output is passed through a sigmoid activation to constrain pixel values between 0 and 1.

Loss function#

The loss function is the VAE loss discussed above.

Training Process#

  • The model is trained using the Adam optimizer with a learning rate of \(10^{-3}\).

  • The loss is minimized over multiple epochs using a batch size of 128.

  • The encoder’s learned representations can be visualized in a 2D latent space to analyze digit clustering.

Visualization#

  1. Reconstruction: Displays original and reconstructed images to assess model performance.

  2. Latent Space: Plots the encoded representations to understand digit distributions.

# Define the Variational Autoencoder
class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc_decode(z))
        return torch.sigmoid(self.fc_out(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    
def loss_function(recon_x, x, mu, logvar):
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div

Training#

The setup for the VAE is very similar to the AE that we did earlier.

Set up the training loop#

Based on the code that you used earlier for the AE, set up the training parameters for the VAE.

# Initialize the model, optimizer, and train
epochs = 20
latent_dim = 2
vae = VAE(latent_dim=latent_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

Run training#

# Training loop
for epoch in range(epochs):
    vae.train()
    train_loss = 0
    for x, _ in train_loader:
        x = x.to(device)
        optimizer.zero_grad()
        recon_x, mu, logvar = vae(x)
        loss = loss_function(recon_x, x, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {train_loss/len(train_loader.dataset):.4f}")

Visualise the results#

Reconstructed plots#

# Visualize original and reconstructed images
def visualize_reconstructions():
    vae.eval()
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.to(device)
            recon_x, _, _ = vae(x)
            break
    x = x.cpu().view(-1, 28, 28)[:10]
    recon_x = recon_x.cpu().view(-1, 28, 28)[:10]
    
    fig, axes = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(10):
        axes[0, i].imshow(x[i], cmap="gray")
        axes[1, i].imshow(recon_x[i], cmap="gray")
        axes[0, i].axis("off")
        axes[1, i].axis("off")
    plt.show()

visualize_reconstructions()

Visualise the latent space#

# KDE Plot for each digit in subplots
def visualize_kde_plots():
    vae.eval()
    embeddings, labels = [], []
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            mu, _ = vae.encode(x)
            embeddings.append(mu.cpu())
            labels.append(y.cpu())
    
    embeddings = torch.cat(embeddings).numpy()
    labels = torch.cat(labels).numpy()
    
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    axes = axes.flatten()
    for digit in range(10):
        idx = labels == digit
        sns.kdeplot(x=embeddings[idx, 0], y=embeddings[idx, 1], fill=True, ax=axes[digit])
        axes[digit].set_title(f"Digit {digit}")
        axes[digit].set_xlabel("z1")
        axes[digit].set_ylabel("z2")
    plt.tight_layout()
    plt.show()

# Run visualization
visualize_kde_plots()
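Because the KL term pulls the latent distribution towards \(\mathcal{N}(0, 1)\), new images can be generated by drawing \(z\) from the prior and passing it through the decoder. The sketch below shows only the mechanics: `decoder` is a hypothetical, untrained stand-in with the same layer sizes as the VAE decoder above, so its outputs are not meaningful digits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a trained decoder (same layer sizes as the VAE above).
latent_dim, hidden_dim, input_dim = 2, 400, 784
decoder = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, input_dim),
    nn.Sigmoid(),
)

decoder.eval()
with torch.no_grad():
    z = torch.randn(16, latent_dim)        # z ~ N(0, I), the prior
    samples = decoder(z).view(-1, 28, 28)  # 16 images of 28x28 pixels

print(samples.shape)
```

With the trained model from this notebook, the analogous call would presumably be `vae.decode(z.to(device)).cpu()` inside `torch.no_grad()`, and the result could be passed to `plot_images` to view a 4×4 grid of generated digits.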

Conditional Variational Autoencoder (CVAE) Explanation#

Overview#

A Conditional Variational Autoencoder (CVAE) is a type of generative model that extends the Variational Autoencoder (VAE) by incorporating labels during training and generation. This allows the model to generate specific outputs based on given labels.

In our implementation, the CVAE is trained on the MNIST dataset, using digit labels to guide the latent space representation.

Model Structure#

The CVAE consists of three main components:

1. Encoder#

The encoder maps an input image \(x\) and its corresponding one-hot encoded label \( y \) into a latent space:

  • The input image \(x \) (flattened 28×28 pixels) is concatenated with the label \(y\).

  • This combined input is passed through a fully connected encoder network.

  • The network produces the mean \(\mu\) and log-variance \(\log \sigma^2\) for the latent representation.

2. Latent Space Sampling (Reparameterization Trick)#

To enable backpropagation, the model uses the reparameterization trick: \( z = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \) where \(z\) is a sampled latent vector that represents the compressed information.

3. Decoder#

The decoder reconstructs the image using the sampled latent vector \(z\) and the label \(y\):

  • The latent vector \(z\) is concatenated with the label \(y\).

  • This combined input is passed through a decoder network.

  • The network outputs a reconstructed image.
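The conditioning described in the encoder and decoder steps amounts to concatenating a one-hot label with the image (for the encoder) or with the latent vector (for the decoder). A minimal shape-level sketch, using random placeholder tensors rather than real MNIST data and `torch.nn.functional.one_hot` as one possible way to build the label vector:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, feature_size, latent_size, class_size = 4, 784, 20, 10

x = torch.rand(batch_size, feature_size)  # placeholder images, not real MNIST data
labels = torch.tensor([3, 1, 4, 1])       # placeholder digit labels
c = F.one_hot(labels, num_classes=class_size).float()

encoder_input = torch.cat([x, c], dim=1)  # what the encoder sees: (4, 794)
z = torch.randn(batch_size, latent_size)  # stands in for a sampled latent vector
decoder_input = torch.cat([z, c], dim=1)  # what the decoder sees: (4, 30)

print(encoder_input.shape, decoder_input.shape)
```

This is why the first encoder layer takes `feature_size + class_size` inputs and the first decoder layer takes `latent_size + class_size` inputs.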

Loss Function#

The CVAE is trained using a combination of two loss terms:

  1. Reconstruction Loss: Measures how accurately the reconstructed image resembles the original. \( \mathcal{L}_{\text{recon}} = -\sum \left[ x \log(\hat{x}) + (1 - x) \log(1 - \hat{x}) \right] \)

  2. KL Divergence Loss: Encourages the learned latent distribution to be close to a standard normal distribution. \( \mathcal{L}_{\text{KL}} = -\frac{1}{2} \sum (1 + \log \sigma^2 - \mu^2 - \sigma^2) \)

The total loss is: \( \mathcal{L} = \mathcal{L}_{\text{recon}} + \mathcal{L}_{\text{KL}} \)

Code Explanation#

1. Model Definition (CVAE)#

  • encode(x, y): Encodes the input and label into latent space (produces \(\mu\) and \(\log \sigma^2\)).

  • reparameterize(mu, logvar): Samples from the latent space.

  • decode(z, y): Reconstructs the image from the latent vector and label.

  • forward(x, y): Full forward pass through encoder, reparameterization, and decoder.

2. Training (training loop)#

  • Uses Adam optimizer to minimize the total loss.

  • Runs for multiple epochs, updating weights via backpropagation.

3. Conditional Image Generation (generate_from_prompt)#

  • Allows users to generate images by specifying a digit.

  • Uses a random latent vector combined with a one-hot encoded label.

Summary#

  • The CVAE can generate conditional outputs, making it more powerful than a standard VAE.

  • The latent space visualization helps understand how different digits are encoded.

  • The generation function allows creating specific digits on demand.

This implementation provides a flexible way to explore generative modeling using deep learning. 🚀

 
class CVAE(nn.Module):
    def __init__(self, feature_size, latent_size, class_size):
        super(CVAE, self).__init__()
        self.feature_size = feature_size
        self.class_size = class_size

        # encode
        self.fc1  = nn.Linear(feature_size + class_size, 400)
        self.fc21 = nn.Linear(400, latent_size)
        self.fc22 = nn.Linear(400, latent_size)

        # decode
        self.fc3 = nn.Linear(latent_size + class_size, 400)
        self.fc4 = nn.Linear(400, feature_size)

        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x, c): # Q(z|x, c)
        '''
        x: (bs, feature_size)
        c: (bs, class_size)
        '''
        inputs = torch.cat([x, c], 1) # (bs, feature_size+class_size)
        h1 = self.elu(self.fc1(inputs))
        z_mu = self.fc21(h1)      # mean of q(z|x, c)
        z_logvar = self.fc22(h1)  # log-variance of q(z|x, c)
        return z_mu, z_logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z, c): # P(x|z, c)
        '''
        z: (bs, latent_size)
        c: (bs, class_size)
        '''
        inputs = torch.cat([z, c], 1) # (bs, latent_size+class_size)
        h3 = self.elu(self.fc3(inputs))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x, c):
        mu, logvar = self.encode(x.view(-1, 28*28), c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar
# Convert labels to one-hot encoding
def one_hot_encoding(labels, num_classes=10):
    return torch.eye(num_classes, device=device)[labels]

# Loss function
def loss_function(recon_x, x, mu, logvar):
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div
# Initialize the model, optimizer, and train
epochs = 20
latent_dim = 20
cvae = CVAE(feature_size=28*28, latent_size=latent_dim, class_size=10).to(device)
optimizer = optim.Adam(cvae.parameters(), lr=5e-4)

# Training loop
for epoch in range(epochs):
    cvae.train()
    train_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        y_one_hot = one_hot_encoding(y).to(device)
        optimizer.zero_grad()
        recon_x, mu, logvar = cvae(x, y_one_hot)
        loss = loss_function(recon_x, x, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {train_loss/len(train_loader.dataset):.4f}")
# Conditional generation of new outputs based on prompts
def generate_from_prompt(digit, num_samples=5):
    cvae.eval()
    fig, axes = plt.subplots(1, num_samples, figsize=(10, 2))
    with torch.no_grad():
        for i in range(num_samples):
            y = torch.zeros(1, 10).to(device)
            y[0, digit] = 1  # One-hot encode the desired digit
            z = torch.randn(1, latent_dim).to(device)  # Random latent vector from the N(0, I) prior
            generated = cvae.decode(z, y).cpu().view(28, 28)
            axes[i].imshow(generated, cmap="gray")
            axes[i].axis("off")
    plt.suptitle(f"Generated Samples for Digit {digit}")
    plt.show()
generate_from_prompt(9)

You can always try to improve the quality of the generated data by training for longer.