dvtailor/bayesian-data-attribution

A Bayesian Information-Theoretic Approach to Data Attribution

This is a PyTorch implementation of the following paper:

A Bayesian Information-Theoretic Approach to Data Attribution
Dharmesh Tailor, Nicolò Felicioni, Kamil Ciosek
29th International Conference on Artificial Intelligence and Statistics (AISTATS 2026)
Paper arxiv

Environment setup

The following has only been tested with CUDA 12.6, Python 3.11, and the specific package versions installed by the commands below. Other configurations may work but are untested.

conda create -n tda python=3.11
conda activate tda
# The CUDA toolkit is needed to compile TRAK's custom CUDA kernel for projections
conda install nvidia::cuda-toolkit=12.6
conda install numpy scipy matplotlib jupyterlab jupyter_console scikit-learn
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
pip install "traker[fast]"  # quoted so shells such as zsh do not expand the brackets
pip install kronfluence
pip install transformers evaluate datasets
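
After installation, a quick sanity check can confirm that the packages are importable and that PyTorch sees a GPU. The snippet below is a minimal sketch, not part of the repository: the module list is an assumption derived from the install commands above (the `traker` distribution is imported as `trak`), and the helper names `check_modules` and `cuda_summary` are hypothetical.

```python
import importlib.util

# Import names assumed from the install commands above
# (the `traker` distribution is imported as `trak`).
REQUIRED_MODULES = [
    "numpy", "scipy", "torch", "torchvision", "trak",
    "kronfluence", "transformers", "evaluate", "datasets",
]


def check_modules(names):
    """Return {module_name: bool} without actually importing the modules."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


def cuda_summary():
    """Report the PyTorch CUDA build and GPU visibility, or None if torch is absent."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch  # imported lazily so the check also runs without torch

    return {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,  # CUDA version torch was built against
        "gpu_visible": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for name, found in check_modules(REQUIRED_MODULES).items():
        print(f"{name:<14}{'ok' if found else 'MISSING'}")
    print(cuda_summary())
```

Run it inside the activated `tda` environment; any module reported as MISSING points to an install step that did not complete.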

Demo notebook

tda_cifar10_demo.ipynb provides a self-contained walkthrough of our method on CIFAR-10 (ResNet-9). It compares the EK-FAC influence function baseline (kronfluence) against three Bayesian information-theoretic variants (Info-Gain, Info-Gain variance approximation, and Info-Loss), identifying the top-5 most influential training examples for three test queries.

Reproducing paper figures and tables

Pre-computed results are included in the results/ directory, so paper figures and tables can be reproduced without re-running experiments.

| Paper | Script | Notes |
|---|---|---|
| Figure 1 | scripts_plotting/paper_ready/fig1_teaser_images.py | |
| Figure 2 | | Not included; contact Dharmesh if interested |
| Figure 3 | scripts_plotting/paper_ready/tda_all.py | |
| Figure 4 | scripts_plotting/paper_ready/backdoor_vis.py | |
| Table 1 | scripts_backdoor/analyze_backdoor_results.sh | |
| Figure 5 | scripts_plotting/paper_ready/coreset_cifar10.py | |
| Figure 6 | scripts_plotting/paper_ready/coreset_cifar10.py | Set PREFER_CLASS_BALANCED = True |
| Figures 7, 8, 9 | scripts_plotting/paper_ready/appendix_tda_vis.py | |

For example:

conda run -n tda python scripts_plotting/paper_ready/tda_all.py
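
To look at the pre-computed results directly, the `.jsonl` files in results/ can be read line by line. The sketch below assumes each line is a standalone JSON object; it does not assume any particular field names, since the schema is not described here, and only reports which top-level keys occur. The helpers `load_jsonl` and `key_counts` are hypothetical, not part of the repository.

```python
import json
from collections import Counter
from pathlib import Path


def load_jsonl(path):
    """Read a JSON Lines file into a list of dicts, skipping blank lines."""
    records = []
    with Path(path).open(encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def key_counts(records):
    """Count how often each top-level key appears across all records."""
    return Counter(key for rec in records for key in rec)


if __name__ == "__main__":
    # Path taken from the experiment sections of this README.
    path = Path("results/coreset_cifar10.jsonl")
    if path.exists():
        records = load_jsonl(path)
        print(len(records), "records")
        print(key_counts(records).most_common())
```

Listing the keys first is a cheap way to discover the record layout before writing any custom analysis on top of the plotting scripts.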

Running experiments from scratch

Each experiment is organised as three files: a Python script that implements the experiment, a bash wrapper that parameterises it for a given method and task index, and an SGE template for cluster job-array submission. For local runs, use the bash wrapper directly; for cluster runs, adapt the SGE template to your environment.

All bash wrappers accept a 1-based task_id that indexes into the cross-product of seeds and query indices (or just seeds for the coreset experiment). Run a single task locally or loop over the full range to reproduce all results.
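
As an illustration of how a 1-based task_id can index a seeds × queries cross-product, here is a minimal sketch. The seed-major ordering (all queries for the first seed, then the second, and so on) is an assumption made for this example; the actual wrappers may order tasks differently, so check the bash scripts before relying on it. The function name `decode_task_id` is hypothetical.

```python
def decode_task_id(task_id, n_seeds, n_queries):
    """Map a 1-based task_id to 0-based (seed_idx, query_idx).

    Assumes seed-major ordering; this is an illustrative assumption,
    not read from the repository's wrappers.
    """
    total = n_seeds * n_queries
    if not 1 <= task_id <= total:
        raise ValueError(f"task_id must be in 1..{total}, got {task_id}")
    # Shift to 0-based, then split into (which seed, which query).
    return divmod(task_id - 1, n_queries)


if __name__ == "__main__":
    # 5 seeds x 100 queries = 500 tasks, as in the CIFAR-10 / Fashion-MNIST sweep.
    for tid in (1, 100, 101, 500):
        print(tid, decode_task_id(tid, n_seeds=5, n_queries=100))
```

Under this assumed ordering, task 1 is the first query of the first seed and task 500 is the last query of the last seed.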

TDA brittleness — CIFAR-10 / Fashion-MNIST

Produces results/run_cifar10_tda_cls.jsonl and results/run_fmnist_tda_cls.jsonl.

# Single run (e.g. method=infogain, task 1 of 500, on Fashion-MNIST)
bash scripts_tda/tda_brittleness.sh infogain 1

# Change dataset to CIFAR-10
DATASET=cifar10 bash scripts_tda/tda_brittleness.sh infogain 1

# Full sweep on a cluster (5 seeds × 100 queries = 500 tasks)
# Adapt tda_brittleness.sge.sh to your cluster, then:
qsub scripts_tda/tda_brittleness.sge.sh infogain

Available methods: random, trak, trak_noq, infogain, infogain_approx, infoloss, kroninfluence, repsim, tracin.
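
Without a cluster, the full sweep can be approximated by calling the bash wrapper once per task_id. The following is a minimal sketch of such a sequential local driver; `sweep_commands` and `run_sweep` are hypothetical helpers, not part of the repository, and the default `dry_run=True` only prints the commands so the loop can be inspected before committing to 500 runs.

```python
import subprocess


def sweep_commands(wrapper, method, n_tasks):
    """Yield one argv list per 1-based task_id in 1..n_tasks."""
    for task_id in range(1, n_tasks + 1):
        yield ["bash", wrapper, method, str(task_id)]


def run_sweep(wrapper, method, n_tasks, dry_run=True):
    """Print each command (dry run) or execute it, stopping at the first failure."""
    for cmd in sweep_commands(wrapper, method, n_tasks):
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)  # raises CalledProcessError on failure


if __name__ == "__main__":
    # 5 seeds x 100 queries = 500 tasks; set dry_run=False to actually run them.
    run_sweep("scripts_tda/tda_brittleness.sh", "infogain", 500, dry_run=True)
```

Tasks run one after another here, so a job array on a cluster will be much faster when one is available.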

TDA brittleness — BERT

Produces results/run_bert_rte_tda_cls.jsonl.

# Single run
bash scripts_tda/tda_brittleness_bert.sh infogain 1

# Full sweep (5 seeds × 50 queries = 250 tasks)
qsub scripts_tda/tda_brittleness_bert.sge.sh infogain

Coreset selection — CIFAR-10

Produces results/coreset_cifar10.jsonl.

# Single run (task_id maps directly to a seed, 1–5)
bash scripts_coreset/tda_coreset.sh infogain 1

# With class-balanced selection (for trak, kroninfluence, repsim, tracin)
bash scripts_coreset/tda_coreset.sh trak 1 true

# Full sweep (5 seeds = 5 tasks)
qsub scripts_coreset/tda_coreset.sge.sh infogain

Backdoor detection — CIFAR-10

Produces results/backdoor_cifar10.jsonl.

# Single run
bash scripts_backdoor/backdoor_cifar10.sh infogain 1

# Full sweep (5 seeds × 100 queries = 500 tasks)
qsub scripts_backdoor/backdoor_cifar10.sge.sh infogain

# Analyse results for a given method (produces Table 1 statistics)
bash scripts_backdoor/analyze_backdoor_results.sh infogain

Acknowledgements

This codebase builds on code from kronfluence [Grosse et al.], TRAK [Park et al.], and LogIX [Choe et al.].

Troubleshooting

Please open an issue in this repository or contact Dharmesh.

Citation

Please consider citing our conference paper:

@inproceedings{tailor2026bayesian,
  title           = {{A Bayesian Information-Theoretic Approach to Data Attribution}},
  booktitle       = {Proceedings of the 29th International Conference on Artificial Intelligence and Statistics},
  author          = {Tailor, Dharmesh and Felicioni, Nicol\`o and Ciosek, Kamil},
  year            = {2026}
}
