Creating Conditional GANs for Class-Specific Image Generation: Hands-On Implementation with PyTorch



Key Takeaways

  • Conditional GANs (cGANs) give you control over AI image generation by adding a "condition," such as a class label, to guide the creative process.
  • The Generator uses this label to create a specific image, and the Discriminator checks if the image is both realistic and matches its label.
  • By feeding the same label into both networks, you can build a model in PyTorch that generates specific images on command, like a handwritten "7" or "9."

You’ve probably seen those eerie AI-generated faces that look almost real, but not quite. What you might not know is that the early AI models that created them were like abstract artists throwing paint at a canvas. They had no idea what they were painting.

Ask a vanilla Generative Adversarial Network (GAN) to draw a cat, and it might give you a blurry mess with a dozen eyes. Why? Because it learned the features of animals, but not the concept of a "cat." It had no control.

That lack of control is infuriating. What’s the point of a powerful image generator if you can't tell it what to generate?

This is where Conditional GANs (cGANs) change the game entirely. They introduce one simple, yet revolutionary, idea: giving the AI a "clue." Today, I’m going to walk you through exactly how to build one from scratch using PyTorch. We're moving from chaotic artist to digital puppet master.

From Chaos to Control: Why Conditional GANs?

The Limitation of Vanilla GANs

I remember my first time training a standard GAN on the MNIST dataset of handwritten digits. It was magical! It started with random noise and slowly began spitting out things that vaguely resembled numbers.

But here’s the problem: I couldn’t ask it for a "7." I just had to keep running it and hope a "7" would pop out. It's like having a brilliant chef who only cooks random dishes.

Introducing the 'Condition': The Core Idea of cGANs

The "condition" is the magic ingredient. It’s a piece of extra information we feed into the GAN to guide the creation process. This information is usually a class label.

Instead of just telling the Generator, "Hey, make an image from this random noise," we now say, "Hey, make an image that looks like a digit 4 using this random noise."

We also upgrade the Discriminator. It no longer just asks, "Is this image real or fake?" It now asks two questions: 1. Is this image of a handwritten digit realistic? 2. Does this image actually match the label it claims to be (e.g., does this look like a 4)?

This forces the entire system to not only create realistic images but also to make sure they correspond to the correct class.

What We'll Build: A Class-Specific MNIST Digit Generator

By the end of this tutorial, we'll have a cGAN that can do something a vanilla GAN can't: generate a specific handwritten digit on command. We'll be able to ask for a "9," and it will give us a "9." This is the first step toward truly controllable, creative AI.

Prerequisites and Environment Setup

Before we dive in, let’s get our digital workshop ready. I’m assuming you have a basic Python environment set up.

Installing PyTorch and Torchvision

If you don't have them already, open your terminal and get the essentials. I highly recommend using a version with CUDA support if you have a compatible GPU—it will speed things up dramatically.

pip install torch torchvision torchaudio

Importing Essential Libraries

Let's get all our imports out of the way at the top of our script. It’s just good practice.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

Setting Up Device Configuration (CPU/GPU)

I always start my PyTorch scripts with this little block of code. It automatically detects if a GPU is available and uses it; otherwise, it falls back to the CPU. Trust me, your future self will thank you for this.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Architecting the Conditional Generator

The Generator is our creative artist. Its job is to take a blob of random noise (the latent vector) and sculpt it into an image. But now, we're giving it a blueprint: the class label.

The Role of the Generator

In a cGAN, the Generator's input is a combination of two things: 1. Latent Noise (z): The source of randomness and variation. 2. Class Label (c): The condition that dictates what to create.

Integrating the Class Label with an Embedding Layer

How do we combine a number (like 5) with a 100-dimensional noise vector? The simplest option is a one-hot vector, but a much more powerful method is to use an nn.Embedding layer. This layer learns a dense vector representation for each class.

The label "5" gets turned into its own unique, learnable vector, which we then concatenate with the noise vector. This gives the model a much richer signal to work with.
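Here's a minimal, self-contained sketch of that combination step in isolation, using the same dimensions we'll use later (the batch of 4 and the specific labels are just for illustration):

```python
import torch
import torch.nn as nn

latent_dim, num_classes = 100, 10
label_emb = nn.Embedding(num_classes, num_classes)  # one learnable 10-dim vector per class

z = torch.randn(4, latent_dim)         # a batch of 4 noise vectors
labels = torch.tensor([5, 5, 3, 0])    # the digits we want to generate
c = label_emb(labels)                  # shape: (4, 10)
gen_input = torch.cat((z, c), dim=-1)  # shape: (4, 110)
print(gen_input.shape)                 # torch.Size([4, 110])
```

Note that the two "5" rows get the *same* embedding vector but different noise, which is exactly what lets the model produce varied styles of the same digit.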

Code: Building the Generator nn.Module

Here’s what our Generator looks like in PyTorch. It takes the combined noise and label embedding and pushes it through a stack of fully connected (Linear) layers that progressively expand it into a 28x28 image.

class Generator(nn.Module):
    def __init__(self, latent_dim, num_classes, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.model = nn.Sequential(
            # Input is latent_dim + num_classes (for the embedding)
            nn.Linear(latent_dim + num_classes, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh() # Tanh brings the output to a [-1, 1] range
        )

    def forward(self, z, labels):
        # Concatenate label embedding and latent noise
        c = self.label_emb(labels)
        gen_input = torch.cat((z, c), -1)
        img = self.model(gen_input)
        # Reshape to the image size
        img = img.view(img.size(0), *self.img_shape)
        return img

Architecting the Conditional Discriminator

The Discriminator is our art critic. In a standard GAN, it's a simple binary classifier: real or fake. Here, it gets a promotion and a much tougher job.

The Discriminator's Dual Role: Realism and Correctness

Our new Discriminator has to judge an image based on both its quality and its adherence to the provided label. It needs to look at a picture and its supposed label (e.g., "7") and decide if it's a convincing, real-looking "7," or a fake.

Conditioning the Discriminator on the Class Label

Just like the Generator, we need to give the Discriminator the class label. We'll use an embedding layer again and concatenate this embedding with the flattened image data before passing it through the network.

Code: Building the Discriminator nn.Module

The Discriminator is essentially a standard image classifier, but it takes the label embedding as an extra input. It outputs a single value (a probability) indicating how "real" it thinks the input image-label pair is.

class Discriminator(nn.Module):
    def __init__(self, num_classes, img_shape):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.model = nn.Sequential(
            # Input is image size + num_classes (for the embedding)
            nn.Linear(int(np.prod(img_shape)) + num_classes, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid() # Sigmoid squashes output to a [0, 1] probability
        )

    def forward(self, img, labels):
        # Flatten image and concatenate with label embedding
        img_flat = img.view(img.size(0), -1)
        c = self.label_emb(labels)
        d_in = torch.cat((img_flat, c), -1)
        validity = self.model(d_in)
        return validity

Data Loading and Preprocessing

Garbage in, garbage out. Let's make sure we're feeding our models clean, well-prepared data. We'll use the classic MNIST dataset.

Loading the MNIST Dataset

Torchvision makes this ridiculously easy. We'll download the dataset if it's not already on our machine.

Defining Transformations

The Tanh function in our Generator outputs values from -1 to 1, so we need to normalize our real images to match that range.

Creating the DataLoader

The DataLoader handles batching, shuffling, and loading the data in parallel. It's a non-negotiable for any serious PyTorch project.

# Hyperparameters
latent_dim = 100
num_classes = 10
img_shape = (1, 28, 28) # MNIST is 1 channel, 28x28
batch_size = 64

# Configure data loader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]) # Normalize to [-1, 1]
])

mnist_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

dataloader = DataLoader(
    mnist_dataset,
    batch_size=batch_size,
    shuffle=True
)

The Training Loop: A Step-by-Step Implementation

This is where the magic happens. The training loop is an intricate dance between the Generator and the Discriminator.

Defining Hyperparameters, Models, and Optimizers

First, let's set everything up: learning rate, loss function, our models, and their respective optimizers. We use separate optimizers because we train them in alternating steps.

# Hyperparameters
lr = 0.0002
epochs = 200 # I recommend at least 50 to see decent results

# Initialize models
generator = Generator(latent_dim, num_classes, img_shape).to(device)
discriminator = Discriminator(num_classes, img_shape).to(device)

# Loss function
adversarial_loss = nn.BCELoss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

The Core Loop: Training the Discriminator

For each batch of data, we perform two major steps. First, we train the Discriminator.

  1. Real Images: We show it a batch of real images with their real labels and teach it to output 1 (for "real").
  2. Fake Images: We generate a batch of fake images with random labels, show them to the Discriminator, and teach it to output 0 (for "fake").
  3. We add the losses from these two steps and perform a single backward pass to update the Discriminator's weights.
# Inside the training loop
for epoch in range(epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        batch_size = imgs.size(0)  # use the actual batch size (the last batch may be smaller)

        # Adversarial ground truths
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Configure input
        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # --- Train Discriminator ---
        optimizer_D.zero_grad()

        # Real images loss
        real_pred = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(real_pred, valid)

        # Fake images loss
        z = torch.randn(batch_size, latent_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)
        gen_imgs = generator(z, gen_labels)

        fake_pred = discriminator(gen_imgs.detach(), gen_labels)  # Use .detach()!
        d_fake_loss = adversarial_loss(fake_pred, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

Crucial point: Notice gen_imgs.detach(). This is vital. We treat the generated images as fixed inputs because we don't want gradients flowing back into the Generator just yet.
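If you want to convince yourself of what .detach() does, here's a tiny stand-alone demo (the variable names are mine and only stand in for the real networks):

```python
import torch

gen_param = torch.randn(2, requires_grad=True)  # stands in for a Generator weight
fake = gen_param * 3.0                          # stands in for a generated image

# Discriminator-side loss on the detached output: the graph back to
# gen_param is cut, so this loss cannot update the "Generator".
d_loss = (fake.detach() ** 2).sum()
print(d_loss.requires_grad)  # False: nothing to backpropagate into

# Without .detach(), gradients flow all the way back.
g_loss = (fake ** 2).sum()
g_loss.backward()
print(gen_param.grad is not None)  # True
```

This is precisely the asymmetry we want: the Discriminator step sees fixed fakes, while the Generator step (next) keeps the full graph.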

The Core Loop: Training the Generator

Now it's the Generator's turn. Its goal is to fool the Discriminator.

  1. We generate a new batch of fake images and their labels.
  2. We pass them through the Discriminator.
  3. We calculate the Generator's loss by comparing the Discriminator's output against a tensor of 1s.
  4. We perform a backward pass to update the Generator's weights.
        # --- Train Generator --- (still inside the batch loop)
        optimizer_G.zero_grad()

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Generator's loss: try to make the Discriminator output "real"
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

Logging Losses and Saving Sample Images

It's essential to monitor the training process. Printing losses and saving sample images helps you visually check if the Generator is actually learning or just producing noise.

# At the end of each epoch or every N batches:
print(
    f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
    f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]"
)
# Code to save sample images would go here

Putting Our cGAN to the Test: Generating Specific Digits

After training, it's time for the fun part!

Loading the Trained Generator

If you saved your model weights, you can load them back in. For now, we'll just use the generator we have in memory.
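If you do want persistence, the standard PyTorch pattern is to save the state_dict, rebuild the architecture, and load the weights back in. The sketch below uses a tiny stand-in module so it runs on its own; with our Generator the calls are identical (the file name is arbitrary):

```python
import torch
import torch.nn as nn

# Stand-in for the trained Generator; the save/load pattern is the same.
model = nn.Linear(100, 10)
torch.save(model.state_dict(), "cgan_generator.pth")

# Later (or in another script): rebuild the same architecture, restore weights
restored = nn.Linear(100, 10)
restored.load_state_dict(torch.load("cgan_generator.pth", map_location="cpu"))
restored.eval()

# Sanity check: both models now produce identical outputs
x = torch.randn(1, 100)
print(torch.equal(model(x), restored(x)))  # True
```

Saving the state_dict rather than the whole pickled model object keeps the checkpoint portable across code refactors, which is why it's the recommended approach.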

Creating a Function to Generate Images by Class

Let's write a simple helper function that takes a number (0-9) and the quantity of images we want, and returns a grid of generated digits.

def generate_digit(digit, num_images=10):
    generator.eval() # Set model to evaluation mode
    with torch.no_grad():
        # Prepare latent noise and labels
        z = torch.randn(num_images, latent_dim).to(device)
        labels = torch.LongTensor([digit] * num_images).to(device)

        # Generate images
        gen_imgs = generator(z, labels)

        # Move to CPU and arrange in a grid
        gen_imgs = gen_imgs.cpu().numpy()

        fig, axes = plt.subplots(1, num_images, figsize=(10, 2))
        for i, ax in enumerate(axes):
            ax.imshow(gen_imgs[i, 0, :, :], cmap='gray')
            ax.axis('off')
        plt.suptitle(f"Generated Images for Digit: {digit}")
        plt.show()

Visualizing the Results: A Grid of Generated Digits

Now, let's call our function and see the fruits of our labor.

# Generate 10 images of the digit '7'
generate_digit(7, num_images=10)

# Generate 10 images of the digit '3'
generate_digit(3, num_images=10)

You should see two rows of images: one containing various styles of the digit "7" and the other for "3." It worked! We now have explicit control over our AI artist.

Conclusion and Next Steps

Recap of What We Achieved

We went from theory to a fully functional Conditional GAN. We took a standard GAN architecture and injected a "condition" (the class label) into both the Generator and the Discriminator.

The result is a model that generates specific, targeted content on command. This is a fundamental building block for more advanced generative AI.

Ideas for Improvement: Different Datasets (Fashion-MNIST, CIFAR-10)

MNIST is great for learning, but the real fun begins with more complex datasets.

  • Fashion-MNIST: Try this exact same code on the Fashion-MNIST dataset. Can you generate specific items of clothing like "t-shirt" or "boot"?
  • CIFAR-10: This is a bigger challenge. You'll need to adjust the image shape (img_shape = (3, 32, 32)) and likely use a more complex, convolution-based architecture.

Further Reading and Resources

The world of GANs is massive and evolving at a breakneck pace. Look into architectures like AC-GANs, StyleGANs, and Diffusion Models (the new kid on the block that powers models like DALL-E 2 and Midjourney).

The power to control what AI creates is one of the most exciting frontiers in tech. Now you have the tools to start building it yourself. Happy generating!


