Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions Diff-MN/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Diff-MN

Time series generation (TSG) is widely used across domains, yet most existing methods assume regular sampling and fixed output resolutions. These assumptions are often violated in practice, where observations are irregular and sparse, while downstream applications require continuous and high-resolution TS.

Although Neural Controlled Differential Equation (NCDE) is promising for modeling irregular TS, it is constrained by a single dynamics function, tightly coupled optimization, and limited ability to adapt learned dynamics to newly generated samples from the generative model.

We propose Diff-MN, a continuous TSG framework that enhances NCDE with a Mixture-of-Experts (MoE) dynamics function and a decoupled architectural design for dynamics-focused training.

To further enable NCDE to generalize to newly generated samples, Diff-MN employs a diffusion model to parameterize the NCDE temporal dynamics parameters (MoE weights), i.e.,
jointly learn the distribution of TS data and MoE weights. This design allows sample-specific NCDE parameters to be generated for continuous TS generation.

Experiments on ten public and synthetic datasets demonstrate that Diff-MN consistently outperforms strong baselines on both irregular-to-regular and irregular-to-continuous TSG tasks.



## Environment

Install the environment using the YAML file: `./environment_diffmn.yml`:

```bash
conda env create -f environment_diffmn.yml --force --no-deps
```

## Data
Stocks and Energy data are located in `./datasets`. Sine, MuJoCo, polynomial datasets are generated and the scripts are included in `./datasets` folder.

`utils_data.py` provides functions for loading data in both regular and irregular settings. In particular, irregular data are preprocessed using the Python class `TimeDataset_irregular`, which may take some time to run. Once preprocessing is complete, the processed data are saved in the `./datasets` directory for future use.


## Reproducing the paper results
By setting the time series length and missing values within the script, the results in the paper can be reproduced:


**Step 1:** The initial MoE NeuralCDE can be trained by `run_irregular_moencde.py`.

---

**Step 2:** Parameterizing the MoE weights can be achieved by jointly training the TS samples and their corresponding MoE weights through script `run_diffmn_diffsuion.py`.

---

**Step 3:** Through Step 2, we generate new samples along with their corresponding MoE weights. These weights are then fed into the pretrained MoE Neural CDE to perform continuous time series generation for each new sample. Finally, the refined high-frequency continuous time series are obtained using the script `run_irregular_moencde_continues.py`, providing richer temporal information and improving the accuracy of downstream tasks.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion MNTSG/configs/MNTSG.yaml → Diff-MN/configs/DiffMN.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ lightning:


# image_logger:
# target: utils_mntsg.callback_utils.TSLogger
# target: utils_diffmn.callback_utils.TSLogger
# params:
# batch_frequency: 2000
# max_images: 8
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: mntsg
name: diffmn
channels:
- defaults
dependencies:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def prepare_data(self,) -> None:
data_path = f'./logs_irgen_moencde_final_ECG_womoe_baseline/{self.args.d_name}/MoeNcdeIrreg-seed=42-miss={self.args.miss_}seqlen=24epochs=30'
data_path_fix = f'./logs_irgen_moencde_final_ECG/{self.args.d_name}/MoeNcdeIrreg-seed=42-miss={self.args.miss_}seqlen=24epochs=30'
else:
data_path=f'./logs_irgen_moencde_mntsg/{self.args.d_name}/MoeNcdeIrreg-seed=42-miss={self.args.miss_}seqlen={self.args.seq_len}epochs=30'
data_path=f'./logs_irgen_moencde_diffmn/{self.args.d_name}/MoeNcdeIrreg-seed=42-miss={self.args.miss_}seqlen={self.args.seq_len}epochs=30'
data_path_fix = f'./logs_irgen_moencde_final/{self.args.d_name}/MoeNcdeIrreg-seed=42-miss={self.args.miss_}seqlen={self.args.seq_len}epochs=30'

else:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
import time

from pytorch_lightning.trainer import Trainer
from utils_mntsg.cli_utils import get_parser
from utils_mntsg.init_utils import init_model_data_trainer
from utils_mntsg.test_utils import test_model_with_diffcde
from utils_diffmn.cli_utils import get_parser
from utils_diffmn.init_utils import init_model_data_trainer
from utils_diffmn.test_utils import test_model_with_diffcde


if __name__ == "__main__":
sys.argv = [
"--base", "configs/MNTSG.yaml",
"--base", "configs/Diff-MN.yaml",
"--gpus", "0,",
"--logdir", "./logs/",
"-sl", "28",
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import logging
from utils_mntsg.utils_data_continues import TimeDataset_irregular
from utils_diffmn.utils_data_continues import TimeDataset_irregular
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def prepare_trainer_configs(nowname, logdir, opt, lightning_config, ckptdir, mod
"save_dir": logdir,
"offline": opt.debug,
"id": f"{nowname}_{now}",
"project": "MNTSG",
"project": "Diff-MN",
}
}
}
Expand Down Expand Up @@ -280,7 +280,7 @@ def prepare_trainer_configs(nowname, logdir, opt, lightning_config, ckptdir, mod
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {
"target": "utils_mntsg.callback_utils.SetupCallback",
"target": "utils_diffmn.callback_utils.SetupCallback",
"params": {
"resume": opt.resume,
"now": now,
Expand All @@ -298,7 +298,7 @@ def prepare_trainer_configs(nowname, logdir, opt, lightning_config, ckptdir, mod
}
},
"cuda_callback": {
"target": "utils_mntsg.callback_utils.CUDACallback"
"target": "utils_diffmn.callback_utils.CUDACallback"
},
}
if version.parse(pl.__version__) >= version.parse('1.4.0'):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def parse_args_with_defaults():

# 模拟命令行参数
default_args = [
"--base", "configs/MNTSG.yaml",
"--base", "configs/Diff-MN.yaml",
"--gpus", "6",
"--logdir", "./logs/",
"-sl", "168",
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

from pytorch_lightning import Trainer
from omegaconf import OmegaConf
from utils_mntsg.cli_utils import nondefault_trainer_args
from utils_mntsg.callback_utils import prepare_trainer_configs
from utils_diffmn.cli_utils import nondefault_trainer_args
from utils_diffmn.callback_utils import prepare_trainer_configs
from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from pathlib import Path
import datetime
from utils_mntsg.cli_utils import nondefault_trainer_args
from utils_diffmn.cli_utils import nondefault_trainer_args

# data_root = os.environ['DATA_ROOT']
data_root='./data'
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch
import controldiffeq
import pathlib
from utils_mntsg.dataset_polynomial import generate_polynomial_dataset
from utils_diffmn.dataset_polynomial import generate_polynomial_dataset
from utils_kovae import datautils
PROJECT_DIR = pathlib.Path(__file__).resolve().parent.parent

Expand Down
File renamed without changes.
32 changes: 0 additions & 32 deletions MNTSG/README.md

This file was deleted.

Loading