A PyTorch framework for continuous learning that automatically detects concept drift in data streams and adapts models through JVP-regularized retraining.
The pipeline runs on a changing data stream and loops through these stages:
- Evaluate the current model on stream batches.
- Aggregate monitored metrics at a configured interval.
- Run a drift detector on the aggregated metric.
- If drift is detected, pause monitoring and run a continual-learning update loop.
- Resume monitoring on the updated model and continue until stream limits are reached.
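The stages above can be sketched as a small Python loop. This is an illustrative outline only, not the framework's actual API: `evaluate`, `cl_update`, and the detector interface are hypothetical stand-ins (the real logic lives in `src/driver/continuous_monitor.py` and `src/training/continuous_trainer.py`).

```python
from statistics import mean

def monitor_stream(model, stream, detector, evaluate, cl_update, interval=4):
    """Illustrative monitoring loop: evaluate, aggregate, detect, adapt."""
    window = []
    for batch in stream:
        window.append(evaluate(model, batch))   # stage 1: per-batch metric
        if len(window) == interval:             # stage 2: aggregate at interval
            metric = mean(window)
            window.clear()
            if detector.update(metric):         # stage 3: drift test on aggregate
                model = cl_update(model, batch) # stage 4: pause + CL update loop
                detector.reset()                # stage 5: resume on updated model
    return model
```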
Core modules:
- `src/main.py`: entry point
- `src/config/configuration.py`: TOML/env/CLI config assembly
- `src/driver/continuous_monitor.py`: monitoring + drift loop
- `src/training/continuous_trainer.py`: CL training loop
- `src/training/updater/`: CL update strategies
- `src/drift_detection/`: detectors and detector factory
- `examples/`: concrete model harness implementations
Requires Python >=3.13,<3.15 and Poetry.
```shell
poetry install
```

From the project root:

```shell
poetry run python -m src.main --config examples/mnist/mnist.toml
poetry run python -m src.main --config examples/cifar/cifar10_vit.toml
```

Currently, we support two metrics-logging backends: Weights & Biases (WandB) and MLflow. Configure the desired backend in the config file's `logging` section; set it to `none` to disable logging. Alternatively, you can set the logging choice via command-line arguments, for example:
```shell
poetry run python -m src.main --config examples/mnist/mnist.toml --set logging.backend=mlflow --set logging.experiment_name="My Experiment"
# To view MLflow results, run `mlflow ui` in another terminal and navigate to http://localhost:5000
```

Currently the MNIST example sets the logging backend to `wandb` in its TOML config file. The other examples do not set a logging backend, which defaults to `wandb`.
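Conceptually, the backend setting acts as a simple dispatch on `logging.backend`. The sketch below is a hedged illustration of that behavior only; the framework's real initialization (and the actual WandB/MLflow setup calls) is not shown here.

```python
def make_logger(backend: str):
    """Illustrative dispatch on the logging.backend setting."""
    if backend == "none":
        return lambda metrics, step: None  # logging disabled: no-op
    if backend == "wandb":
        return lambda metrics, step: print(f"[wandb] step {step}: {metrics}")
    if backend == "mlflow":
        return lambda metrics, step: print(f"[mlflow] step {step}: {metrics}")
    raise ValueError(f"unknown logging backend: {backend}")
```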
Primary sections in config TOML:
- `[model]`
- `[data]`
- `[train]`
- `[drift_detection]`
- `[continual_learning]` (optional but recommended)
- `[visualization]` (optional)
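A skeletal layout of these sections might look like the following. The section names come from the list above; the individual keys shown (other than `detector_name`, which appears in the override example) are illustrative placeholders, not the framework's exact schema.

```toml
seed = 42        # top-level fields sit above the sections
device = "cuda"

[model]
# model class / harness selection

[data]
# dataset and stream settings

[train]
# optimizer, batch size, iteration limits

[drift_detection]
detector_name = "KSWINDetector"

[continual_learning]   # optional but recommended
# CL update strategy settings

[visualization]        # optional
```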
Top-level fields commonly used:
- `seed`
- `device`
- `multi_gpu`
- `verbosity`
Override precedence:
- Base TOML (`--config`)
- Environment overrides prefixed with `APP_`
- CLI overrides via repeated `--set key=value`
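The precedence chain can be illustrated with a small merge sketch. This is a hypothetical helper for exposition only (the real assembly lives in `src/config/configuration.py`), and the `__` → `.` mapping for environment variables is an assumed convention, not the framework's documented one:

```python
def set_dotted(cfg: dict, dotted: str, value) -> None:
    """Apply one 'a.b.c=value' style override to a nested config dict."""
    *path, leaf = dotted.split(".")
    node = cfg
    for part in path:
        node = node.setdefault(part, {})
    node[leaf] = value

def assemble(base: dict, env: dict, cli: list[str]) -> dict:
    """Base TOML < APP_-prefixed env vars < repeated --set overrides."""
    cfg = {k: dict(v) if isinstance(v, dict) else v for k, v in base.items()}
    for key, val in env.items():
        if key.startswith("APP_"):
            # e.g. APP_TRAIN__MAX_ITER -> train.max_iter (assumed convention)
            set_dotted(cfg, key[4:].lower().replace("__", "."), val)
    for item in cli:  # each --set key=value, applied last
        key, _, val = item.partition("=")
        set_dotted(cfg, key, val)
    return cfg
```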
Example override:
```shell
poetry run python -m src.main \
  --config examples/mnist/mnist.toml \
  --set drift_detection.detector_name="KSWINDetector" \
  --set train.max_iter=200
```

Detailed docs are in docs/:
- docs/README.md
- docs/model_harness.md
- docs/drift_detectors.md
- docs/continuous_learning.md
```shell
poetry run pytest
poetry run ruff check .
poetry run mypy .
```

Platform-specific deployment guides:
- Builds the `DummyCNN_MNIST` model defined in `src/model/DummyCNN_MNIST.py`, a cross-entropy loss, and an Adam optimizer.
- Loads the MNIST training split, stacks the tensors, and iterates over 10 tasks (digits 0–9). Each task applies random rotation and translation to encourage continual adaptation.
- Maintains replay buffers (`memory_image`, `memory_label`, etc.) so past samples remain available for rehearsal while training new tasks.
- Calls `CL(...)` to assemble task-specific dataloaders and drive the `One_task_CL` loop. The loop trains for five epochs, records loss/accuracy metrics, and prints periodic progress reports.
- Computes sensitivity scores with `src/validation/validation_utils/return_score` after each task; you can repurpose these values for analysis or adaptive triggers.
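The replay-buffer idea above can be sketched independently of the example's tensors. This is a minimal, framework-agnostic version for intuition only; the capacity, eviction policy, and sampling strategy here are illustrative choices, not what `memory_image`/`memory_label` actually do.

```python
import random

class ReplayBuffer:
    """Minimal rehearsal buffer: keeps past (image, label) pairs for replay."""

    def __init__(self, capacity: int, seed: int = 0):
        self.capacity = capacity
        self.items: list[tuple] = []
        self.rng = random.Random(seed)

    def add(self, image, label) -> None:
        if len(self.items) < self.capacity:
            self.items.append((image, label))
        else:
            # overwrite a random slot so the buffer keeps a mix of old and new tasks
            self.items[self.rng.randrange(self.capacity)] = (image, label)

    def sample(self, batch_size: int) -> list[tuple]:
        # draw a rehearsal mini-batch (capped at the current buffer size)
        return self.rng.sample(self.items, min(batch_size, len(self.items)))
```

During training on a new task, each optimizer step would mix a fresh batch with a `sample(...)` from the buffer so earlier digits are rehearsed.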
- Change the number of epochs by editing `n_epoch` inside `CL`.
- Adjust replay/adversarial update counts through the `params` dictionaries in `One_task_CL` and `util.update_CL_`.
- Experiment with different transforms or task definitions by modifying `data.py`.
- Update batch sizes by changing the `batch_size` parameter used when constructing the dataloaders.
Training logs report the task id, training/test accuracy, and replay-memory accuracy every five epochs. Accuracy is computed via `test(...)` on both the current task and the accumulated memory set.