Generative Adversarial Network for MNIST Dataset (digit generation) using TensorFlow

- Project Overview
- Installation
- Dataset
- Model Architecture
- Training
- Results
- Future Improvements
- Contributing
- License
This project demonstrates how to implement and train a basic Generative Adversarial Network (GAN) on the MNIST dataset of handwritten digits. The GAN consists of two neural networks:
- Generator : This network generates fake images based on random noise (latent vectors).
- Discriminator : This network tries to distinguish between real images (MNIST dataset) and fake images generated by the Generator.
Both networks are trained adversarially: the Generator tries to fool the Discriminator, while the Discriminator learns to improve its accuracy. By the end of training, the Generator produces realistic handwritten digits that are hard to distinguish from real MNIST images.
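This adversarial game is the standard GAN minimax objective from Goodfellow et al. (2014), with Generator $G$, Discriminator $D$, data distribution $p_{\text{data}}$, and noise prior $p_z$:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

In practice both players are trained with gradient descent on surrogate cross-entropy losses, as described in the Training section below.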
'Clone the Repository'
First, clone this repository to your local machine:
```shell
git clone https://github.com/adiManethia/GAN_from_Scratch.git
cd GAN_from_Scratch
```
'Install Dependencies'
Install the required dependencies using pip:
```shell
pip install -r requirements.txt
```
Alternatively, you can install the dependencies manually.
We use the MNIST dataset of handwritten digits, which contains 60,000 training images and 10,000 test images. Each image is 28x28 pixels in grayscale. TensorFlow automatically downloads the MNIST dataset through its datasets API.
The images are normalized to the range [-1, 1] for better performance with the tanh activation function used in the Generator.
```python
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = (train_images - 127.5) / 127.5  # normalize to [-1, 1]
```
'Generator'
The Generator network takes a random noise vector (latent vector) as input and generates a fake MNIST image. It uses a series of `Dense`, `BatchNormalization`, and `Conv2DTranspose` layers to upsample the noise vector into a realistic image.

Key layers:
- `Dense` layer to expand the noise to 7x7x256.
- `Conv2DTranspose` layers with strides to upsample the image to 28x28.
- `LeakyReLU` and `tanh` for activation.
```python
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
```
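Putting the layer plan above together, the full Generator might look like the sketch below (exact filter counts and kernel sizes are assumptions based on the standard DCGAN-on-MNIST recipe; the repository's model may differ slightly):

```python
import tensorflow as tf
from tensorflow.keras import layers

def make_generator():
    return tf.keras.Sequential([
        tf.keras.Input(shape=(100,)),
        layers.Dense(7*7*256, use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((7, 7, 256)),
        # 7x7 -> 7x7
        layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        # 7x7 -> 14x14
        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        # 14x14 -> 28x28; tanh keeps pixel values in [-1, 1]
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False,
                               activation='tanh'),
    ])
```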
'Discriminator'
The Discriminator network takes an image (real or fake) as input and outputs a binary classification (0/1). It uses a series of `Conv2D` layers followed by a `Flatten` and a `Dense` layer.

Key layers:
- `Conv2D` layers for downsampling.
- `LeakyReLU` activation.
- `Dropout` for regularization to prevent overfitting.
```python
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
```
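A fuller sketch of the Discriminator following the layer plan above (the second convolution's filter count is an assumption taken from the common DCGAN-on-MNIST recipe):

```python
import tensorflow as tf
from tensorflow.keras import layers

def make_discriminator():
    return tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28, 1)),
        # 28x28 -> 14x14
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        # 14x14 -> 7x7
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1),  # raw logit: positive -> "real", negative -> "fake"
    ])
```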
The GAN is trained for a fixed number of epochs. In each epoch:
- Generator: produces fake images from random noise vectors.
- Discriminator: attempts to classify real images (from the MNIST dataset) and fake images (from the Generator).
- Loss Function:
- Generator loss: Measures how well the Generator can fool the Discriminator.
- Discriminator loss: Measures how well the Discriminator can distinguish between real and fake images.
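One way to sketch the two losses and the per-batch update described above, following the standard TensorFlow DCGAN tutorial pattern (`train_step`, its arguments, and `noise_dim` are illustrative, not necessarily the repository's exact code):

```python
import tensorflow as tf

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(fake_output):
    # The Generator "wins" when the Discriminator labels its fakes as real (1).
    return cross_entropy(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    # The Discriminator should output 1 for real images and 0 for fakes.
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

def train_step(images, generator, discriminator, gen_opt, disc_opt, noise_dim=100):
    noise = tf.random.normal([tf.shape(images)[0], noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(fake_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    # Each network is updated only with the gradients of its own loss.
    gen_opt.apply_gradients(zip(gen_tape.gradient(gen_loss, generator.trainable_variables),
                                generator.trainable_variables))
    disc_opt.apply_gradients(zip(disc_tape.gradient(disc_loss, discriminator.trainable_variables),
                                 discriminator.trainable_variables))
    return gen_loss, disc_loss
```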
Both the Generator and Discriminator are optimized with the Adam optimizer using a learning rate of 1e-4 for stable training.
```python
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
```

Checkpoints are saved during training so the model can be restored without losing progress.

```python
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
```

After training for several epochs (100 in this case), the Generator is able to produce images that resemble real handwritten digits. Below are some example outputs generated by the model after 50, 100, and 200 epochs:
Train the GAN:
Generate Images:
Here, we will load the trained model from checkpoints and generate new handwritten digits.
```python
generate_images(generator_trained)
```
Create a GIF: We can create a GIF of the generated images over time.

Some potential future improvements for the project include:
- Training the GAN on larger, more complex datasets.
- Experimenting with different GAN architectures.
- Adding evaluation metrics to assess the quality of generated images.
- Deploying the model as an interactive web application using Streamlit or Flask.
Contributions are welcome! Please fork the repository and submit a pull request for any changes or improvements you'd like to make. Please follow the steps outlined below.
- Fork the project
- Create your feature branch (`git checkout -b feature/NewFeature`)
- Commit your changes (`git commit -m 'Add some NewFeature'`)
- Push to the branch (`git push origin feature/NewFeature`)
- Open a pull request
This project is licensed under the MIT License. See the LICENSE file for details.

