
cuburt/varima-nn


VARIMA-NN: Multivariate Time Series Forecasting Engine

VARIMA-NN (Vector AutoRegressive Integrated Moving Average - Neural Network) is a machine learning system that forecasts multiple correlated time-series variables simultaneously.

Unlike traditional ARIMA models, this system uses a Multi-Layer Perceptron (MLP) neural network to learn complex non-linear relationships between variables (e.g., how Oil Prices affect Retail Sales and Transactions).

🚀 Key Features

  • Multivariate Forecasting (VAR): Predicts multiple targets (Oil, Sales, Transactions, Holidays) at once.
  • Recursive Inference: Uses its own predictions as inputs for future time steps to generate long-horizon forecasts.
  • Decoupled Architecture:
    • Backend: High-performance FastAPI server for inference.
    • Frontend: Interactive Plotly Dash web application.
  • Smart Preprocessing: Automatic stationarity checks (Differencing), volatility scaling, and normalization.
  • Dynamic Uncertainty: Calculates confidence intervals (bounds) based on model noise (Sigma).
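Recursive inference can be sketched as follows. This is a minimal sketch under stated assumptions: a one-step model maps a window of the last `n_lags` observations (all target variables) to the next observation. `recursive_forecast` and `toy_model` are hypothetical names, not functions from process.py, and the toy model stands in for the trained MLP.

```python
from collections import deque
from typing import Callable, List, Sequence

Vector = List[float]
# Hypothetical one-step model: lag window (oldest -> newest) -> next vector.
OneStepModel = Callable[[Sequence[Vector]], Vector]


def recursive_forecast(model: OneStepModel, history: Sequence[Vector],
                       n_lags: int, horizon: int) -> List[Vector]:
    """Forecast `horizon` steps by feeding each prediction back as input."""
    if len(history) < n_lags:
        raise ValueError("history must contain at least n_lags observations")
    # Sliding window holding the most recent n_lags observations.
    window = deque(history[-n_lags:], maxlen=n_lags)
    forecasts: List[Vector] = []
    for _ in range(horizon):
        next_step = model(list(window))
        forecasts.append(next_step)
        # The prediction becomes the newest "observation"; the oldest drops out.
        window.append(next_step)
    return forecasts


if __name__ == "__main__":
    # Toy stand-in for the MLP: each variable grows by 1 per step.
    def toy_model(window: Sequence[Vector]) -> Vector:
        return [v + 1.0 for v in window[-1]]

    hist = [[0.0, 10.0], [1.0, 11.0], [2.0, 12.0]]
    print(recursive_forecast(toy_model, hist, n_lags=2, horizon=3))
    # [[3.0, 13.0], [4.0, 14.0], [5.0, 15.0]]
```

Because each step consumes earlier predictions, errors compound over the horizon. This is one reason the bounds in this README are tied to model noise rather than being fixed.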

📂 Project Structure

varima-nn/
├── config.yaml             # Central configuration (Data paths, Model params)
├── train.py                # Training pipeline (Preprocessing -> Training -> Saving)
├── process.py              # Shared logic for Feature Engineering & Inference
├── fastapi-server.py       # API Backend (Serves the model)
├── prepare_data.py         # ETL script to merge Kaggle datasets
├── data/                   # Data storage
│   ├── oil.csv
│   ├── train.csv
│   ├── transactions.csv
│   ├── holidays_events.csv
│   └── unified_data.csv    # Generated by prepare_data.py
├── model_storage/          # Saved models (.sav files)
└── app.py                  # Dashboard Frontend (Consumes the API)

🛠️ Installation

Install Dependencies:

pip install pandas numpy scikit-learn statsmodels joblib pyyaml fastparquet uvicorn fastapi dash plotly requests scikit-optimize

📊 Data Preparation

This project uses the Store Sales - Time Series Forecasting dataset.

  1. Download Data:
    • Get train.csv, oil.csv, transactions.csv, and holidays_events.csv from Kaggle.
    • Place them inside the data/ folder.
  2. Run the ETL Script: This merges all files, handles missing values, and aggregates sales/transactions to a daily global level.
python prepare_data.py

Output: data/unified_data.csv
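The following pandas code is a minimal sketch of what this ETL step does. It is not the project's prepare_data.py. The function name `build_unified` is hypothetical, and it assumes the Kaggle column names `date`, `sales`, `transactions`, and `dcoilwtico`. The filling rules are illustrative choices: oil prices are carried forward, missing transaction days count as zero, and every date listed in holidays_events.csv is marked as a holiday.

```python
import pandas as pd


def build_unified(train: pd.DataFrame, oil: pd.DataFrame,
                  transactions: pd.DataFrame, holidays: pd.DataFrame) -> pd.DataFrame:
    """Merge the four Kaggle tables into one daily, store-agnostic table (hypothetical helper)."""
    # Aggregate store/product-level rows to one global value per day.
    sales = (train.groupby("date", as_index=False)["sales"].sum()
             .rename(columns={"sales": "total_sales"}))
    txn = (transactions.groupby("date", as_index=False)["transactions"].sum()
           .rename(columns={"transactions": "total_transactions"}))

    df = sales.merge(txn, on="date", how="left")
    df = df.merge(oil.rename(columns={"dcoilwtico": "oil_price"}), on="date", how="left")

    # Simplification: any date listed in holidays_events.csv counts as a holiday.
    df["is_holiday"] = df["date"].isin(holidays["date"].unique()).astype(int)

    df = df.sort_values("date").reset_index(drop=True)
    # Oil is not quoted every day: carry the last price forward, then back-fill the start.
    df["oil_price"] = df["oil_price"].ffill().bfill()
    # Assumption: days without transaction records had zero transactions.
    df["total_transactions"] = df["total_transactions"].fillna(0)
    return df[["date", "oil_price", "total_sales", "total_transactions", "is_holiday"]]
```

To produce the file, load each CSV with `pd.read_csv(path, parse_dates=["date"])` and pass the four frames to `build_unified`. Then write the result with `.to_csv("data/unified_data.csv", index=False)`.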


🧠 Training the Model

To train the Neural Network, run:

python train.py

What happens?

  1. Loads unified_data.csv.
  2. Preprocesses data (Normalization, Volatility Scaling).
  3. Creates Lag features (e.g., t-1, t-7, t-30).
  4. Splits into Train/Test sets to prevent leakage.
  5. Optimizes Hyperparameters (using BayesSearchCV).
  6. Saves the model to model_storage/model-new.sav.
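Steps 2 to 4 can be sketched with NumPy. This is a minimal sketch, not train.py. The helper names are hypothetical, first-order differencing stands in for the stationarity transform, and volatility scaling and normalization are omitted. It also assumes the split is chronological (no shuffling), which is the usual way to avoid look-ahead leakage in time series. The real pipeline derives its window from `lag_size` in config.yaml.

```python
from typing import Sequence, Tuple

import numpy as np


def difference(data: np.ndarray) -> np.ndarray:
    """First-order differencing along time: d[t] = x[t+1] - x[t]."""
    return np.diff(data, axis=0)


def undifference(last_level: np.ndarray, diffs: np.ndarray) -> np.ndarray:
    """Invert differencing: add cumulative predicted diffs to the last observed level."""
    return last_level + np.cumsum(diffs, axis=0)


def make_lagged(data: np.ndarray, lags: Sequence[int]) -> Tuple[np.ndarray, np.ndarray]:
    """Return (X, y) for data of shape (T, n_vars).

    Row i of X concatenates data[t - k] for every lag k, where t = max(lags) + i;
    y[i] = data[t] holds all variables at once (multivariate target).
    """
    max_lag = max(lags)
    X = np.hstack([data[max_lag - k: len(data) - k] for k in lags])
    y = data[max_lag:]
    return X, y


def chronological_split(X: np.ndarray, y: np.ndarray, train_size: float):
    """Split without shuffling, so every test row comes after every training row."""
    cut = int(len(X) * train_size)
    return X[:cut], X[cut:], y[:cut], y[cut:]
```

For lags `[1, 7, 30]`, each sample carries three snapshots of all four targets, so X has 12 columns. The first 30 rows of history are consumed to build the first sample.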

🌐 Running the System

This system uses a Client-Server architecture. You need two terminal windows.

1. Start the API Server (Backend)

This exposes the model via REST API endpoints.

python fastapi-server.py
  • URL: http://localhost:8080
  • Docs: http://localhost:8080/docs

2. Start the Dashboard (Frontend)

This runs the UI that visualizes the forecasts.

python app.py
  • URL: http://localhost:8050

⚙️ Configuration (config.yaml)

You can tune the entire system without changing code by editing config.yaml.

data:
  targets: ["oil_price", "total_sales", "total_transactions", "is_holiday"]
  primary_target: "total_sales" # Default view in Dashboard

preprocessing:
  train_size: 0.95      # Use 95% of data for training
  lag_size: 0.20        # Look back window size (20% of history)
  volatility_range: "M" # Scale volatility by Month (M) or Day (D)

training:
  max_iter: 2000        # Max training epochs
  bayes_n_iter: 1       # Number of hyperparameter combinations to try (1 = Fast)

inference:
  method: "recursive"   # "recursive" (Best for trends) or "batched"
  confidence_z_score: 1.0 # 1.0 = 68% Confidence, 1.96 = 95% Confidence
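The z-score comments can be checked with Python's standard library. This is a minimal sketch, assuming the bounds are symmetric (forecast ± z · sigma) and the forecast errors are roughly normal. `coverage` and `interval` are hypothetical helpers, and how the project estimates sigma is not described here.

```python
from statistics import NormalDist
from typing import Tuple


def coverage(z: float) -> float:
    """Two-sided probability that a normal error falls within ±z standard deviations."""
    return 2.0 * NormalDist().cdf(z) - 1.0


def interval(forecast: float, sigma: float, z: float = 1.0) -> Tuple[float, float]:
    """Symmetric bounds (lower_bound, upper_bound) = forecast -/+ z * sigma."""
    half_width = z * sigma
    return forecast - half_width, forecast + half_width


if __name__ == "__main__":
    for z in (1.0, 1.96):
        print(f"z = {z}: {coverage(z):.4f}")
    # z = 1.0: 0.6827
    # z = 1.96: 0.9500
```

Raising `confidence_z_score` widens the bounds linearly: at 1.96 they are 1.96 times as wide as at 1.0.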

🔌 API Usage

You can interact with the model programmatically using curl or Postman.

Endpoint: POST /forecast

Payload:

{
    "range_start": "2017-09-01",
    "range_end": "2017-09-15",
    "target_feature": "total_sales"
}

Response:

{
    "success": true,
    "target": "total_sales",
    "forecast": {
        "2017-09-01": {
            "forecast": 12500.5,
            "upper_bound": 13100.2,
            "lower_bound": 11900.8
        },
        ...
    }
}
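The same call can also be made from Python using only the standard library. In this minimal sketch, `build_forecast_request` and `parse_forecast` are hypothetical helpers, and the URL assumes the default port 8080. Treating `"success": false` as an error is also an assumption, because the error response format isn't documented here.

```python
import json
import urllib.request
from typing import Dict, Tuple

API_URL = "http://localhost:8080/forecast"  # assumes the default server port


def build_forecast_request(range_start: str, range_end: str,
                           target_feature: str = "total_sales",
                           url: str = API_URL) -> urllib.request.Request:
    """Build a POST /forecast request with the JSON payload shown above."""
    payload = {"range_start": range_start, "range_end": range_end,
               "target_feature": target_feature}
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def parse_forecast(body: str) -> Dict[str, Tuple[float, float, float]]:
    """Map each date to (lower_bound, forecast, upper_bound)."""
    result = json.loads(body)
    if not result.get("success"):
        raise RuntimeError(f"forecast request failed: {result}")
    return {day: (v["lower_bound"], v["forecast"], v["upper_bound"])
            for day, v in result["forecast"].items()}
```

With the server running, send the request using `urllib.request.urlopen(build_forecast_request("2017-09-01", "2017-09-15"), timeout=30)`. Then pass `resp.read().decode()` to `parse_forecast`.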

About

Multivariable Forecast with Vector ARIMA + MLP
