
Fashion MNIST CNN Explorer

This is an educational project focused on building, training, and evaluating various Convolutional Neural Network (CNN) models for image classification on the Fashion MNIST dataset. The primary goal is to explore different CNN architectures and hyperparameter tuning techniques to understand their impact on model performance, particularly in managing overfitting.


📖 Project Overview

This repository documents an iterative process of model development:

  1. Baseline CNN Model: A simple, standard CNN to establish a performance baseline.
  2. Overfitting Model: A deliberately larger model built to demonstrate the effects of overfitting.
  3. Regularized Model: Techniques such as Dropout, L2 Regularization, and Early Stopping are applied to combat overfitting and improve generalization.

For each experiment, the script automatically generates:

  • A saved model file (.keras).
  • A training summary file (_summary.txt) detailing the hyperparameters, model architecture, and final performance metrics.
  • A plot (_plots.png) visualizing the training & validation accuracy and loss over epochs.
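The naming convention above can be captured in a small helper. This is a hypothetical sketch (the function name `artifact_paths` and the `results/` root are assumptions based on the layout described, not code from the repository):

```python
import os

def artifact_paths(name, results_root="results"):
    """Hypothetical helper mirroring the per-experiment output layout:
    results/<name>/<name>.keras, <name>_summary.txt, <name>_plots.png."""
    out_dir = os.path.join(results_root, name)
    return {
        "model": os.path.join(out_dir, f"{name}.keras"),
        "summary": os.path.join(out_dir, f"{name}_summary.txt"),
        "plots": os.path.join(out_dir, f"{name}_plots.png"),
    }

print(artifact_paths("model_1_cnn"))
```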

💾 Dataset

The project uses the Fashion MNIST dataset, which is a popular benchmark for image classification tasks.

  • Content: 70,000 grayscale images (28x28 pixels) of 10 different fashion categories.
  • Split: 60,000 images for training and 10,000 images for testing.
  • Classes: T-shirt/top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot.
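Before training, the 28x28 uint8 images are typically scaled to [0, 1] and given a channel axis so Conv2D layers accept them. A minimal sketch of that preprocessing step (the `preprocess` helper is illustrative, not taken from the repository; in the actual scripts the data would come from `tf.keras.datasets.fashion_mnist.load_data()`):

```python
import numpy as np

CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

def preprocess(images):
    """Scale pixel values to [0, 1] and add a trailing channel axis."""
    return images.astype("float32")[..., np.newaxis] / 255.0

# Dummy batch standing in for the real 28x28 grayscale images:
x_dummy = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)
x_ready = preprocess(x_dummy)
print(x_ready.shape)  # (4, 28, 28, 1)
```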

⚙️ Setup and Usage

Prerequisites

  • Python 3.8+
  • An NVIDIA GPU with CUDA and cuDNN installed (optional; TensorFlow falls back to the CPU, but training will be slower).
  • pip and venv for package management.

Installation & Running an Experiment

  1. Clone the repository:

    git clone https://github.com/your-username/your-repository-name.git
    cd your-repository-name
  2. Create and activate a virtual environment:

    python3 -m venv .venv
    source .venv/bin/activate
  3. Install the required packages (afterwards, you can pin the versions with pip freeze > requirements.txt in the activated environment):

    pip install tensorflow[and-cuda] scikit-learn matplotlib numpy
  4. Run a model training script:

    python model_1_cnn.py
  5. Check the results: After the script finishes, check the newly created results/model_1_cnn/ directory for the saved model, summary text file, and plots.


🚀 Models and Experiments

This project includes the following models, each representing a step in the learning process:

model_1_cnn.py - Baseline Model

  • Architecture: A simple CNN with two convolutional blocks (32 filters -> 64 filters) followed by a dense classifier.
  • Purpose: To establish a solid baseline performance.
  • Result: Achieved a validation accuracy of around 91%, demonstrating the advantage of CNNs over simpler MLP models.
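The baseline described above can be sketched in Keras roughly as follows. This is an illustrative reconstruction from the bullet points (two conv blocks, 32 -> 64 filters, dense classifier); layer sizes other than the filter counts are assumptions, not the repository's exact code:

```python
import tensorflow as tf
from tensorflow.keras import layers, models

def build_baseline():
    """Sketch of a two-block CNN baseline for 28x28x1 Fashion MNIST."""
    model = models.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(64, activation="relu"),   # assumed width
        layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model

model = build_baseline()
print(model.output_shape)  # (None, 10)
```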

model_2_cnn.py - Overfitting Model

  • Architecture: A deeper and wider CNN with three convolutional blocks (32 -> 64 -> 128 filters) and a larger dense layer.
  • Purpose: To observe how increasing model capacity without proper regularization leads to overfitting.
  • Result: Showcased a classic overfitting pattern: training accuracy kept increasing while validation accuracy plateaued and validation loss started to increase.
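The overfitting pattern described above shows up as a growing gap between training and validation accuracy in the Keras training history. A minimal, hypothetical way to quantify it (the numbers below are made up for illustration, not results from the project):

```python
def overfitting_gap(history):
    """Gap between final training and validation accuracy; a large,
    growing gap signals the classic overfitting pattern."""
    return history["accuracy"][-1] - history["val_accuracy"][-1]

# Illustrative numbers: training accuracy keeps climbing while
# validation accuracy plateaus.
hist = {"accuracy": [0.85, 0.93, 0.98],
        "val_accuracy": [0.88, 0.90, 0.905]}
print(round(overfitting_gap(hist), 3))  # 0.075
```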

model_3_cnn.py - Regularized Model

  • Architecture: Similar to model_2 but with architectural improvements and strong regularization.
  • Techniques Applied:
    • Added a MaxPooling2D layer to balance the architecture.
    • Increased Dropout rate.
    • Added L2 Kernel Regularization to the dense layer.
    • Implemented EarlyStopping to halt training when performance on the validation set stops improving.
  • Purpose: To effectively combat overfitting and find a model that generalizes well.
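The techniques listed above can be sketched together in Keras as follows. This is a hedged reconstruction from the bullet points (three conv blocks with pooling, Dropout, an L2-penalized dense layer, and EarlyStopping); the dropout rate, L2 strength, and patience values are assumptions, not the repository's actual hyperparameters:

```python
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers, callbacks

def build_regularized():
    """Sketch of the regularized three-block CNN described above."""
    return models.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(128, 3, activation="relu"),
        layers.MaxPooling2D(),   # extra pooling to balance the architecture
        layers.Flatten(),
        layers.Dropout(0.5),     # assumed rate
        layers.Dense(128, activation="relu",
                     kernel_regularizer=regularizers.l2(1e-3)),  # assumed strength
        layers.Dense(10, activation="softmax"),
    ])

# EarlyStopping halts training once validation loss stops improving.
early_stop = callbacks.EarlyStopping(monitor="val_loss", patience=3,
                                     restore_best_weights=True)
# Usage: model.fit(x, y, validation_data=..., callbacks=[early_stop])
```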

🛠️ Technologies Used

  • TensorFlow & Keras: For building and training the neural networks.
  • Scikit-learn: For splitting the data into training and validation sets.
  • NumPy: For numerical operations and data manipulation.
  • Matplotlib: For generating and saving the training plots.
  • Python 3: The core programming language.

About

An exploratory project for training and comparing various CNN models on the Fashion MNIST dataset using TensorFlow and Keras.
