
msudars/Sequential-Data-Synthesis-GAN


# A Framework for Sequential Data Synthesis

This repository provides a PyTorch implementation for generating synthetic time-series data using Conditional Generative Adversarial Networks (cGANs). The project serves as a practical exploration of generative modeling, offering a comparative analysis of two distinct neural network architectures: a standard feed-forward network and a recurrent LSTM-based network.

Both models are trained with the Wasserstein GAN with Gradient Penalty (WGAN-GP) objective, which is designed to make adversarial training more stable.
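
The README does not specify how the class label enters the networks. A common cGAN scheme appends a one-hot label to the generator's noise input. For an LSTM, the label is usually appended to the input at every time step. The plain-Python sketch below assumes that scheme. The function names are hypothetical, and `n_classes=3` and `noise_dim=128` only loosely mirror the `--n_classes` and `--generator-input-dim` values in the training command. It is not taken from `models.py` or `lstm_gan.py`.

```python
import random

def one_hot(label, n_classes):
    """Encode a class index as a one-hot vector."""
    vec = [0.0] * n_classes
    vec[label] = 1.0
    return vec

def conditional_input(label, n_classes, noise_dim, rng):
    """Feed-forward style (assumed): one noise vector with the one-hot label appended."""
    z = [rng.gauss(0.0, 1.0) for _ in range(noise_dim)]
    return z + one_hot(label, n_classes)

def conditional_sequence_input(label, n_classes, noise_dim, seq_len, rng):
    """LSTM style (assumed): fresh noise at each time step, same label repeated."""
    cond = one_hot(label, n_classes)
    return [[rng.gauss(0.0, 1.0) for _ in range(noise_dim)] + cond
            for _ in range(seq_len)]

rng = random.Random(0)
x = conditional_input(label=2, n_classes=3, noise_dim=128, rng=rng)
print(len(x), x[-3:])  # 131 [0.0, 0.0, 1.0]
```

Appending the label, instead of mixing it into the noise, keeps the condition explicit. This lets a trained generator produce samples of a chosen class on demand.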


## Project Overview

Generative Adversarial Networks (GANs) are powerful tools for data synthesis. Applying them to sequential or time-series data is challenging, however, because training is often unstable and prone to mode collapse. This project addresses these issues by:

- Implementing the WGAN-GP framework: the Wasserstein-1 (Earth Mover's) distance serves as the training objective. It provides smoother, more informative gradients than the Jensen-Shannon (JS) divergence, which the original GAN objective minimizes for an optimal discriminator. The difference is largest when the real and generated distributions barely overlap.
- Enforcing the Lipschitz constraint: a gradient penalty is added to the critic's loss. It pushes the norm of the critic's input gradient toward 1 at points interpolated between real and generated samples. This is a soft version of the 1-Lipschitz constraint required by the Wasserstein formulation.
- Comparing architectures: the repository provides two models for studying how different architectures handle time-series generation:
  - a standard feed-forward cGAN built from linear layers (`models.py`);
  - an LSTM-based cGAN designed to capture temporal dependencies in sequential data (`lstm_gan.py`).
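
A 1-D toy case shows why the Wasserstein distance gives more useful gradients than JS. This is a standalone illustration, not code from this repository. It shifts a small empirical "generated" distribution further and further from a fixed "real" one. Once the two supports stop overlapping, the JS divergence is stuck at log 2, so its gradient with respect to the shift is zero. W1, by contrast, keeps growing with the shift.

```python
import math
from collections import Counter

def js_divergence(xs, ys):
    """Jensen-Shannon divergence (natural log) between two empirical
    distributions over discrete values."""
    p, q = Counter(xs), Counter(ys)
    js = 0.0
    for v in set(p) | set(q):
        pv, qv = p[v] / len(xs), q[v] / len(ys)
        mv = 0.5 * (pv + qv)  # mixture M = (P + Q) / 2
        if pv > 0:
            js += 0.5 * pv * math.log(pv / mv)  # 0.5 * KL(P || M) term
        if qv > 0:
            js += 0.5 * qv * math.log(qv / mv)  # 0.5 * KL(Q || M) term
    return js

def wasserstein_1d(xs, ys):
    """W1 between two equal-size 1-D empirical distributions; in 1-D the
    optimal transport plan matches the sorted samples."""
    if len(xs) != len(ys):
        raise ValueError("this shortcut needs equal sample counts")
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

real = [0, 1, 2, 3]
for shift in (0, 2, 4, 8, 16):
    fake = [x + shift for x in real]
    print(shift, round(js_divergence(real, fake), 3), wasserstein_1d(real, fake))
# JS stops changing at log 2 (about 0.693) once the supports are disjoint
# (shift >= 4), while W1 keeps growing in proportion to the shift.
```

A generator whose samples lie far from the data therefore gets no signal from a saturated JS objective. The Wasserstein distance, in contrast, still indicates which direction to move.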

## Mathematical Framework

The core of this project is the WGAN-GP loss. The critic (discriminator) is trained to minimize the following loss, which includes the gradient penalty:

$$L_{D} = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[\left(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1\right)^2\right]$$

The generator is trained to minimize:

$$L_{G} = -\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})]$$

Where:

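- $P_r$ and $P_g$ are the real and generated data distributions. $P_{\hat{x}}$ is the distribution of interpolates $\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}$, with $x \sim P_r$, $\tilde{x} \sim P_g$, and $\epsilon \sim U[0, 1]$.
- $\lambda$ is the gradient penalty coefficient, set via the `--gradient-penalty-lambda-term` argument.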

The gradient penalty term is computed by the `_calculate_gradient_penalty` method in `gan_trainer_base.py`.
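
The sketch below evaluates $L_D$ and $L_G$ for a single batch in plain Python so that every term of the equations is visible. It is not the repository's implementation. The critic is a hypothetical toy, $D(x) = \sum_i \tanh(w_i x_i)$, chosen because its input gradient has a closed form. In a PyTorch implementation, this gradient is usually obtained with `torch.autograd.grad(..., create_graph=True)`, which allows the penalty itself to be backpropagated into the critic's weights.

```python
import math
import random

def critic(x, w):
    """Hypothetical toy critic D(x) = sum_i tanh(w_i * x_i)."""
    return sum(math.tanh(wi * xi) for wi, xi in zip(w, x))

def critic_grad(x, w):
    """Closed-form gradient of the toy critic with respect to its input x."""
    return [wi * (1.0 - math.tanh(wi * xi) ** 2) for wi, xi in zip(w, x)]

def wgan_gp_losses(real, fake, w, lam=10.0, rng=random):
    """Return (L_D, L_G) for one batch of paired real/generated samples."""
    n = len(real)
    mean_real = sum(critic(x, w) for x in real) / n  # E_{x ~ P_r}[D(x)]
    mean_fake = sum(critic(x, w) for x in fake) / n  # E_{x~ ~ P_g}[D(x~)]
    penalty = 0.0
    for x, x_tilde in zip(real, fake):
        eps = rng.random()  # epsilon ~ U[0, 1], one per sample pair
        x_hat = [eps * a + (1.0 - eps) * b for a, b in zip(x, x_tilde)]
        grad_norm = math.sqrt(sum(g * g for g in critic_grad(x_hat, w)))
        penalty += (grad_norm - 1.0) ** 2  # (||grad D(x_hat)||_2 - 1)^2
    penalty /= n
    critic_loss = mean_fake - mean_real + lam * penalty  # L_D
    generator_loss = -mean_fake                           # L_G
    return critic_loss, generator_loss

rng = random.Random(0)
real = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(8)]
fake = [[rng.gauss(2.0, 1.0) for _ in range(4)] for _ in range(8)]
w = [0.5, -0.2, 0.9, 0.1]
loss_d, loss_g = wgan_gp_losses(real, fake, w, lam=10.0, rng=rng)
print(f"critic loss {loss_d:.3f}, generator loss {loss_g:.3f}")
```

Sampling a fresh $\epsilon$ for each pair means the penalty is evaluated along random line segments between real and generated samples. These are the points where the standard WGAN-GP formulation enforces the gradient-norm constraint.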


## How to Run

You can train the models either directly from the command line or with one of the provided Jupyter notebooks.

### 1. Training from the Command Line (Recommended)

#### LSTM-based cGAN
Train the LSTM-based model using `train_gan.py`. This variant is designed to capture temporal patterns.

```bash
python train_gan.py \
    --data-path './data/snl_dataset_444.npy' \
    --gan-epochs 200 \
    --batch-size 64 \
    --n_classes 3 \
    --generator-input-dim 128 \
    --generator-learning-rate 0.0001 \
    --discriminator-learning-rate 0.0001
```

#### Standard Feed-Forward cGAN

Train the standard model using `train.py`.

```bash
python train.py \
    --data_directory './data/snl_data_76.npy' \
    --label_directory './data/snl_label_76.npy' \
    --class_size 3 \
    --batch_size 128 \
    --epochs 100
```

For a full list of tunable hyperparameters, see `arg_parser.py`.
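
`arg_parser.py` is not reproduced in this README. The `argparse` sketch below is a hedged illustration of how such a parser maps the LSTM-cGAN flags. It declares only flags that appear in this README; their defaults, types, and `required=True` are assumptions. It also shows an `argparse` detail that matters when reading the training code: a flag such as `--gan-epochs` is stored as `args.gan_epochs`, while `--n_classes` keeps its name.

```python
import argparse

def build_parser():
    """Hypothetical subset of the train_gan.py arguments; defaults are illustrative."""
    p = argparse.ArgumentParser(description="LSTM-based cGAN training (sketch)")
    p.add_argument("--data-path", type=str, required=True)
    p.add_argument("--gan-epochs", type=int, default=200)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--n_classes", type=int, default=3)
    p.add_argument("--generator-input-dim", type=int, default=128)
    p.add_argument("--generator-learning-rate", type=float, default=1e-4)
    p.add_argument("--discriminator-learning-rate", type=float, default=1e-4)
    # 10.0 is the coefficient suggested in the WGAN-GP paper; assumed here.
    p.add_argument("--gradient-penalty-lambda-term", type=float, default=10.0)
    return p

# Parse an explicit list instead of sys.argv so the example is reproducible.
args = build_parser().parse_args([
    "--data-path", "./data/snl_dataset_444.npy",
    "--gan-epochs", "200",
    "--batch-size", "64",
])
# argparse replaces hyphens with underscores in attribute names.
print(args.gan_epochs, args.batch_size, args.gradient_penalty_lambda_term)
```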


## Code Structure

- `train_gan.py`: Main training script for the LSTM-based cGAN.
- `train.py`: Main training script for the standard feed-forward cGAN.
- `lstm_gan.py`: Defines the LSTM-based Generator and Discriminator architectures.
- `models.py`: Defines the standard feed-forward Generator and Discriminator architectures.
- `gan_trainer_base.py`: A base class containing the shared WGAN-GP training loop, loss calculations, and evaluation logic.
- `snl_cycledata.py`: A PyTorch Dataset class for loading and preprocessing sequential data.
- `arg_parser.py`: Manages command-line arguments for setting hyperparameters.
- `wgan-gp-script_run.ipynb`: A Jupyter notebook for running and experimenting with the standard GAN.
- `train_lstm_gan.ipynb`: A Jupyter notebook for visualizing training results (loss curves, PCA, t-SNE).
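
The README describes `snl_cycledata.py` only briefly, so its exact preprocessing is unknown. A PyTorch map-style dataset needs only `__len__` and `__getitem__`. The plain-Python sketch below follows that protocol for `(sequence, label)` pairs, matching the separate data and label files passed to `train.py`. The class and helper names are hypothetical. The global min-max scaling to [-1, 1] is an assumed preprocessing step, not one confirmed by this repository. In a PyTorch project, the class would subclass `torch.utils.data.Dataset` and return tensors.

```python
import random

class SequenceDataset:
    """Hypothetical map-style dataset of (sequence, label) pairs."""

    def __init__(self, sequences, labels):
        if len(sequences) != len(labels):
            raise ValueError("sequences and labels must have the same length")
        # Assumed preprocessing: global min-max scaling to [-1, 1].
        lo = min(min(s) for s in sequences)
        hi = max(max(s) for s in sequences)
        span = (hi - lo) or 1.0  # avoid dividing by zero for constant data
        self.sequences = [[2.0 * (v - lo) / span - 1.0 for v in s] for s in sequences]
        self.labels = list(labels)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]

def iterate_minibatches(dataset, batch_size, shuffle=True, seed=0):
    """Simplified stand-in for torch.utils.data.DataLoader batching."""
    order = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        batch = [dataset[i] for i in order[start:start + batch_size]]
        xs, ys = zip(*batch)
        yield list(xs), list(ys)

ds = SequenceDataset([[0, 5, 10], [10, 5, 0], [2, 4, 6]], labels=[0, 1, 2])
print(len(ds), ds[0])  # 3 ([-1.0, 0.0, 1.0], 0)
for xs, ys in iterate_minibatches(ds, batch_size=2):
    print(len(xs), ys)
```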

## About

An exploration of GAN architectures for sequential data synthesis using the Wasserstein GAN with Gradient Penalty framework in PyTorch.
