This repository provides a PyTorch implementation for generating synthetic time-series data using Conditional Generative Adversarial Networks (cGANs). The project serves as a practical exploration of generative modeling, offering a comparative analysis of two distinct neural network architectures: a standard feed-forward network and a recurrent LSTM-based network.
Both models are trained with the Wasserstein GAN with Gradient Penalty (WGAN-GP) objective, which is designed to stabilize adversarial training.
Generative Adversarial Networks (GANs) are powerful tools for data synthesis, but their application to sequential or time-series data presents unique challenges, including training instability and mode collapse. This project addresses these issues directly by:
- Implementing the WGAN-GP Framework: The critic estimates the Wasserstein distance between the real and generated distributions. This yields useful gradients even when the two distributions barely overlap, unlike the Jensen-Shannon (JS) divergence minimized by the original GAN.
- Enforcing the Lipschitz Constraint: A gradient penalty added to the critic's loss softly enforces the 1-Lipschitz constraint that the Wasserstein formulation requires of the critic.
- Comparing Architectures: The repository provides two distinct models to investigate how different architectures handle the task of time-series generation:
  - A standard feed-forward cGAN with linear layers (`models.py`).
  - An LSTM-based cGAN specifically designed to capture temporal dependencies in sequential data (`lstm_gan.py`).
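
To make the contrast between the two architectures concrete, here is a minimal sketch of a conditional generator in each style. The class names, layer sizes, sequence length, and the choice to condition through a label embedding are all assumptions made for illustration; the actual definitions live in `models.py` and `lstm_gan.py` and may differ.

```python
import torch
import torch.nn as nn


class FeedForwardGenerator(nn.Module):
    """Hypothetical feed-forward conditional generator (not the repo's class)."""

    def __init__(self, noise_dim=128, n_classes=3, seq_len=24, hidden=256):
        super().__init__()
        self.embed = nn.Embedding(n_classes, n_classes)
        self.net = nn.Sequential(
            nn.Linear(noise_dim + n_classes, hidden),
            nn.ReLU(),
            nn.Linear(hidden, seq_len),
        )

    def forward(self, z, labels):
        # z: (batch, noise_dim); the whole sequence is emitted in one shot.
        return self.net(torch.cat([z, self.embed(labels)], dim=1))


class LSTMGenerator(nn.Module):
    """Hypothetical LSTM conditional generator (not the repo's class)."""

    def __init__(self, noise_dim=128, n_classes=3, n_features=1, hidden=64):
        super().__init__()
        self.embed = nn.Embedding(n_classes, n_classes)
        self.lstm = nn.LSTM(noise_dim + n_classes, hidden, batch_first=True)
        self.out = nn.Linear(hidden, n_features)

    def forward(self, z, labels):
        # z: (batch, seq_len, noise_dim); one noise vector per time step.
        cond = self.embed(labels).unsqueeze(1).expand(-1, z.size(1), -1)
        h, _ = self.lstm(torch.cat([z, cond], dim=2))
        return self.out(h)  # (batch, seq_len, n_features)


labels = torch.randint(0, 3, (4,))
ff_out = FeedForwardGenerator()(torch.randn(4, 128), labels)
lstm_out = LSTMGenerator()(torch.randn(4, 24, 128), labels)
print(ff_out.shape, lstm_out.shape)
```

The feed-forward variant treats each time step as an independent output unit, while the LSTM variant carries a hidden state from one step to the next, which is what lets it model temporal dependencies.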
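The core of this project is the WGAN-GP loss function. The critic (discriminator) $D$ is trained to maximize the following objective, which includes the gradient penalty (shown here in the standard WGAN-GP form):

$$
L_D = \mathbb{E}_{x \sim P_r}\left[D(x)\right] - \mathbb{E}_{\tilde{x} \sim P_g}\left[D(\tilde{x})\right] - \lambda \, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]
$$

The generator is trained to minimize:

$$
L_G = -\mathbb{E}_{\tilde{x} \sim P_g}\left[D(\tilde{x})\right]
$$

Where:
- $P_r$, $P_g$, and $P_{\hat{x}}$ are the real, generated, and interpolated data distributions, respectively.
- $\lambda$ is the gradient penalty coefficient, set via the `--gradient-penalty-lambda-term` argument.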
This formulation is implemented in the `_calculate_gradient_penalty` method within the `gan_trainer_base.py` script.
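
As a rough illustration of how a gradient penalty of this kind is typically computed in PyTorch, the sketch below interpolates between real and generated samples and penalizes the critic's gradient norm for deviating from 1. The function name, the toy critic, and the default coefficient of 10 are assumptions; this is not the code from `gan_trainer_base.py`, and it omits the class-label input that a conditional critic would also receive.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, real, fake, lambda_term=10.0):
    """Sketch of a WGAN-GP penalty; `critic` returns one score per sample."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over remaining dims.
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1.0 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty is differentiable
    )
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return lambda_term * ((grad_norm - 1.0) ** 2).mean()


# Toy critic and random data, purely for illustration.
critic = nn.Sequential(nn.Linear(24, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
real, fake = torch.randn(8, 24), torch.randn(8, 24)

gp = gradient_penalty(critic, real, fake)
# In practice the critic minimizes the negative of the objective it maximizes.
critic_loss = critic(fake).mean() - critic(real).mean() + gp
generator_loss = -critic(fake).mean()
critic_loss.backward()
```

Passing `create_graph=True` matters here: the penalty depends on a gradient, so backpropagating through it requires second-order derivatives of the critic.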
## How to Run
You can train the models either directly from the command line or by using the provided Jupyter Notebook.
### 1. Training from the Command Line (Recommended)
#### LSTM-based cGAN
Train the LSTM-based model using `train_gan.py`. This is ideal for capturing temporal patterns.
```bash
python train_gan.py \
  --data-path './data/snl_dataset_444.npy' \
  --gan-epochs 200 \
  --batch-size 64 \
  --n_classes 3 \
  --generator-input-dim 128 \
  --generator-learning-rate 0.0001 \
  --discriminator-learning-rate 0.0001
```

#### Standard Feed-Forward cGAN
Train the standard model using `train.py`.

```bash
python train.py \
  --data_directory './data/snl_data_76.npy' \
  --label_directory './data/snl_label_76.npy' \
  --class_size 3 \
  --batch_size 128 \
  --epochs 100
```

For a full list of tunable hyperparameters, please see `arg_parser.py`.
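
For orientation, a parser exposing the LSTM training flags used in this README might look roughly like the sketch below. Only the flag names come from this document; the function name, types, and defaults are assumptions (the defaults simply echo the example command, and the penalty coefficient of 10 is a commonly used value), so treat `arg_parser.py` as the source of truth.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags shown in this README."""
    p = argparse.ArgumentParser(description="Train the LSTM-based cGAN (sketch)")
    p.add_argument("--data-path", type=str, required=True)
    p.add_argument("--gan-epochs", type=int, default=200)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--n_classes", type=int, default=3)
    p.add_argument("--generator-input-dim", type=int, default=128)
    p.add_argument("--generator-learning-rate", type=float, default=1e-4)
    p.add_argument("--discriminator-learning-rate", type=float, default=1e-4)
    # Default of 10 is an assumption, not taken from the repository.
    p.add_argument("--gradient-penalty-lambda-term", type=float, default=10.0)
    return p


# Parse an explicit argument list so the sketch runs without a real CLI call.
args = build_parser().parse_args(
    ["--data-path", "./data/snl_dataset_444.npy", "--batch-size", "32"]
)
print(args.batch_size, args.gan_epochs, args.gradient_penalty_lambda_term)
```

Note that `argparse` converts hyphens in flag names to underscores, so `--gan-epochs` is read back as `args.gan_epochs`.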
## Project Structure
- `train_gan.py`: Main training script for the LSTM-based cGAN.
- `train.py`: Main training script for the standard feed-forward cGAN.
- `lstm_gan.py`: Defines the LSTM-based Generator and Discriminator architectures.
- `models.py`: Defines the standard feed-forward Generator and Discriminator architectures.
- `gan_trainer_base.py`: A base class containing the shared WGAN-GP training loop, loss calculations, and evaluation logic.
- `snl_cycledata.py`: A PyTorch Dataset class for loading and preprocessing sequential data.
- `arg_parser.py`: Manages command-line arguments for setting hyperparameters.
- `wgan-gp-script_run.ipynb`: A Jupyter notebook for running and experimenting with the standard GAN.
- `train_lstm_gan.ipynb`: A Jupyter notebook for visualizing training results (loss curves, PCA, t-SNE).
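
As an illustration of the role a dataset class such as the one in `snl_cycledata.py` plays, the sketch below wraps labeled sequences in a PyTorch `Dataset`. The class name, the `(samples, time steps, features)` array layout, and the min-max scaling step are assumptions; the repository's actual preprocessing may differ. Random arrays stand in for data that would normally be loaded with `np.load`.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class SequenceDataset(Dataset):
    """Hypothetical stand-in for the dataset class in snl_cycledata.py."""

    def __init__(self, data, labels):
        data = np.asarray(data, dtype=np.float32)
        # Assumed preprocessing: min-max scale each feature to [-1, 1].
        lo = data.min(axis=(0, 1), keepdims=True)
        hi = data.max(axis=(0, 1), keepdims=True)
        scaled = 2.0 * (data - lo) / (hi - lo + 1e-8) - 1.0
        self.data = torch.from_numpy(scaled)
        self.labels = torch.as_tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


# Synthetic stand-in data: 100 sequences, 24 time steps, 1 feature, 3 classes.
rng = np.random.default_rng(0)
dataset = SequenceDataset(
    rng.normal(size=(100, 24, 1)), rng.integers(0, 3, size=100)
)
x, y = next(iter(DataLoader(dataset, batch_size=16, shuffle=True)))
print(x.shape, y.shape)
```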