This repository is the official release of the paper "FedRTS: Federated Robust Pruning via Combinatorial Thompson Sampling".

In this work, we propose Federated Robust pruning via combinatorial Thompson Sampling (FedRTS), a novel framework designed to derive robust sparse models for cross-device FL, as illustrated in the figure below. FedRTS introduces a Thompson Sampling-based Adjustment (TSAdj) mechanism to address the shortcomings of existing methods. First, TSAdj leverages farsighted probability distributions, which incorporate prior information from unseen clients, to mitigate the impact of partial client participation. Second, it makes probabilistic decisions based on stable, comprehensive information rather than relying solely on aggregated data, keeping the model topology stable despite data heterogeneity. Third, FedRTS minimizes communication costs by requiring clients to upload only the indices of their top gradients for adjustment.
*Figure: Overview of the proposed FedRTS.*
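For intuition, here is a minimal, self-contained sketch of a Thompson-sampling-style mask adjustment in the spirit of TSAdj. It is illustrative only: the function name, the Beta parameterization, and the update rule are our assumptions for exposition, not the paper's exact formulation.

```python
import torch

def tsadj_step(alpha, beta, density, grad_topk_counts):
    """Illustrative Thompson-sampling mask adjustment (assumed, simplified).

    alpha, beta: per-weight Beta parameters, playing the role of the
        "farsighted" distributions accumulated over rounds (they retain
        evidence from clients that do not participate this round).
    density: fraction of weights to keep active.
    grad_topk_counts: how many of this round's clients reported each weight
        among their uploaded top-gradient indices.
    """
    # Fold this round's aggregated evidence into the distributions.
    alpha = alpha + grad_topk_counts
    beta = beta + (grad_topk_counts == 0).float()
    # Thompson sampling: draw a score per weight, keep the top fraction.
    scores = torch.distributions.Beta(alpha, beta).sample()
    k = int(density * scores.numel())
    mask = torch.zeros_like(scores)
    mask.view(-1)[scores.view(-1).topk(k).indices] = 1.0
    return mask, alpha, beta

# Toy usage: 1,000 weights, 10 participating clients, 50% density.
alpha = torch.ones(1000)
beta = torch.ones(1000)
counts = torch.randint(0, 11, (1000,)).float()
mask, alpha, beta = tsadj_step(alpha, beta, density=0.5, grad_topk_counts=counts)
```

Because decisions are drawn from distributions that accumulate evidence across rounds, a single round of skewed client participation perturbs the sampled mask far less than a purely greedy, aggregation-only adjustment would.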
Set up the environment with:

```bash
conda create -n fedrts python=3.10
conda activate fedrts
conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia
conda install -c anaconda mpi4py
pip install -r requirements.txt
```

We use three computer vision datasets, CIFAR-10, CINIC-10, and SVHN, and a large natural language processing dataset, TinyStories, for all experiments in this paper.
To get the computer vision datasets, e.g., CIFAR-10, run the following commands:

```bash
cd data/cifar10
sh download_cifar10.sh
```

You do not need to manually download the TinyStories dataset; we fetch it directly from Hugging Face. As long as the transformers package is properly installed, TinyStories will be downloaded automatically when you run the NLP task.
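For reference, the automatic download amounts to roughly the following. This is a sketch: the use of the `datasets` library and the `roneneldan/TinyStories` dataset id are our assumptions about the Hugging Face source, and the repository's loading code may differ.

```python
from datasets import load_dataset

# First call downloads and caches the corpus; later calls reuse the cache.
tinystories = load_dataset("roneneldan/TinyStories")
print(tinystories["train"][0]["text"][:200])
```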
You should first decide how many processes to use: the number of processes is client_num_per_round + 1, where client_num_per_round is the number of clients selected per round and the additional process is for the server. Set this number of processes in gpu_mapping.yaml.
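For intuition, here is a minimal sketch of how a per-GPU process-count list (like the four_gpu entry shown below) can be expanded into an MPI-rank-to-GPU assignment. The helper name is hypothetical; the repository's own mapping logic may differ.

```python
# Hypothetical helper: expand a per-GPU process-count list into a
# rank -> GPU assignment. With [3, 3, 3, 2], the 11 processes
# (1 server + 10 clients) land on 4 GPUs.
def expand_gpu_mapping(counts):
    rank_to_gpu = {}
    rank = 0
    for gpu_id, n in enumerate(counts):
        for _ in range(n):
            rank_to_gpu[rank] = gpu_id
            rank += 1
    return rank_to_gpu

print(expand_gpu_mapping([3, 3, 3, 2]))
# {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1, 6: 2, 7: 2, 8: 2, 9: 3, 10: 3}
```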
For example, if client_num_per_round is set to 10, you should configure 11 processes in gpu_mapping.yaml like this:
```yaml
mapping_default:
    four_gpu: [3, 3, 3, 2]
```

Here 3 + 3 + 3 + 2 = 11: three processes run on each of GPUs 0-2 and two on GPU 3. Once the mapping is configured, you can launch training, for example:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 sh run.sh resnet18 cifar10 100 10 500 5 0.5 0.001 --delta_T 10 --partition_alpha 0.5 --adjust_alpha 0.2
```

In general, `[.]` denotes mandatory arguments and `{.}` denotes optional arguments:
```bash
CUDA_VISIBLE_DEVICES=[gpus] sh run.sh [model] [dataset] [client_num_in_total] \
    [client_num_per_round] [comm_round] [epochs] [target_density] [initial_lr] \
    {--delta_T, --T_end, --partition_alpha, --num_eval, --frequency_of_the_test}
```

where
- `[gpus]` specifies which GPUs to use.
- `[model]` is the name of the model.
- `[dataset]` is the name of the dataset.
- `[client_num_in_total]` is the total number of clients.
- `[client_num_per_round]` is the number of clients selected per round.
- `[comm_round]` is the number of communication rounds.
- `[epochs]` is the number of local epochs.
- `[target_density]` is the target density of the sparse model.
- `[initial_lr]` is the initial learning rate.
- `{--delta_T}` is the number of rounds between two adjustment rounds.
- `{--T_end}` is the last round in which adjustment is performed.
- `{--partition_alpha}` is the data-partition alpha; a higher value yields a lower degree of data heterogeneity (see the sketch after this list).
- `{--num_eval}` is the number of data samples used for validation; -1 means using the whole test set.
- `{--frequency_of_the_test}` is how often performance is tested/validated during training, using `num_eval` data samples.
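Knobs like `partition_alpha` commonly parameterize Dirichlet-based label partitioning in FL codebases. As a point of reference, here is a minimal sketch of that scheme; it is an assumption for illustration, not necessarily this repository's exact partitioning code.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with a per-class Dirichlet prior.

    A smaller alpha concentrates each class on fewer clients (more data
    heterogeneity); a larger alpha approaches an IID split.
    """
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        rng.shuffle(idx)
        # Fraction of class-c samples assigned to each client.
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cut_points = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cut_points)):
            client_indices[client].extend(part.tolist())
    return client_indices

labels = np.random.randint(0, 10, size=5000)
parts = dirichlet_partition(labels, num_clients=100, alpha=0.5)
```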
More detailed arguments can be found in `main.py`.
If you find this work useful, please cite:

```bibtex
@article{huang2025fedrts,
  title={{FedRTS}: Federated Robust Pruning via Combinatorial {Thompson} Sampling},
  author={Huang, Hong and Yang, Hai and Chen, Yuan and Ye, Jiaxun and Wu, Dapeng},
  journal={arXiv preprint arXiv:2501.19122},
  year={2025}
}
```
