From 24f2cdad358294f7b1d34633ffeaf43d97bc08e1 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 16 Mar 2026 20:57:49 -0400 Subject: [PATCH 01/39] Added MSE metric for time series forecasting models. --- plato/clients/strategies/defaults.py | 9 +++++++- plato/servers/fedavg.py | 33 +++++++++++++++++++++++----- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/plato/clients/strategies/defaults.py b/plato/clients/strategies/defaults.py index a7a54c86d..376acd6a6 100644 --- a/plato/clients/strategies/defaults.py +++ b/plato/clients/strategies/defaults.py @@ -339,7 +339,14 @@ async def train(self, context: ClientContext) -> tuple[Any, Any]: if context.sio is not None: await context.sio.disconnect() - if hasattr(Config().trainer, "target_perplexity"): + metric_name = getattr( + getattr(context.trainer, "testing_strategy", None), + "metric_name", + "accuracy", + ) + if metric_name == "mse": + LOGGER.info("[%s] Test MSE: %.6f", context, accuracy) + elif hasattr(Config().trainer, "target_perplexity"): LOGGER.info("[%s] Test perplexity: %.2f", context, accuracy) else: LOGGER.info("[%s] Test accuracy: %.2f%%", context, 100 * accuracy) diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index e3418eba6..ad9afe947 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -240,16 +240,25 @@ async def _process_reports(self): self.callback_handler.call_event("on_weights_aggregated", self, self.updates) # Testing the global model accuracy + trainer = self.require_trainer() + metric_name = getattr( + getattr(trainer, "testing_strategy", None), "metric_name", "accuracy" + ) + if hasattr(Config().server, "do_test") and not Config().server.do_test: - # Compute the average accuracy from client reports + # Compute the average metric from client reports self.accuracy, self.accuracy_std = self.get_accuracy_mean_std(self.updates) - logging.info( - "[%s] Average client accuracy: %.2f%%.", self, 100 * self.accuracy - ) + if metric_name == "mse": + logging.info( + "[%s] Average client MSE: %.6f.", self, self.accuracy + ) + else: + logging.info( + "[%s] Average client accuracy: %.2f%%.", self, 100 * self.accuracy + ) else: # Testing the updated model directly at the server logging.info("[%s] Started model testing.", self) - trainer = self.require_trainer() self.accuracy = trainer.test(self.testset, self.testset_sampler) # Extract CORE evaluation results if available (Nanochat CORE evaluation) @@ -260,7 +269,7 @@ async def _process_reports(self): core_results = trainer.context.state["nanochat_core_results"] self._core_metric = core_results.get("core_metric", self.accuracy) - # If CORE benchmark was run via a Nanochat testing strategy, report the specialized CORE metric instead of the generic 'Global model accuracy' label. + # Log with metric-appropriate label and format core_metric = getattr(self, "_core_metric", None) if core_metric is not None: @@ -269,6 +278,10 @@ async def _process_reports(self): f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" ) ) + elif metric_name == "mse": + logging.info( + fonts.colourize(f"[{self}] Global model MSE: {self.accuracy:.6f}\n") + ) elif hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( @@ -345,6 +358,14 @@ def get_logged_items(self) -> dict: if hasattr(self, "_core_metric"): logged["core_metric"] = self._core_metric + metric_name = getattr( + getattr(getattr(self, "trainer", None), "testing_strategy", None), + "metric_name", + "accuracy", + ) + if metric_name and metric_name != "accuracy": + logged[metric_name] = self.accuracy + return logged @staticmethod From 8232690d0f3995bf9dd302c7bb057aafbe91dfb1 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 16 Mar 2026 21:08:52 -0400 Subject: [PATCH 02/39] Added config file for TimesFM model with EV availability forecasting. --- configs/TimeSeries/timesfm_ev_charging.toml | 88 +++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 configs/TimeSeries/timesfm_ev_charging.toml diff --git a/configs/TimeSeries/timesfm_ev_charging.toml b/configs/TimeSeries/timesfm_ev_charging.toml new file mode 100644 index 000000000..426e0d1fa --- /dev/null +++ b/configs/TimeSeries/timesfm_ev_charging.toml @@ -0,0 +1,88 @@ +# Federated Learning with TimesFM for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 168 hours. +# +# Dataset: "EV Charging Reports" – AdO1 garage, 4 users +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 4 clients, one user each. All clients participate every round. +# +# Model: TimesFM (custom, trained from scratch for the small data regime) +# – channel-independent: each of the 6 input features is processed +# as a separate univariate series; only is_charging is evaluated. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm_ev_charging.toml + +[clients] +type = "simple" +total_clients = 4 +per_round = 4 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm_ev" +model_path = "models/timeseries/timesfm_ev" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +garage = "AdO1" # garage id + +# Explicit user IDs to include — one client per user. +users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 4 +model_name = "timesfm" + +context_length = 672 # 4 × 7 × 24 +prediction_length = 168 # 7 × 24 + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +num_hidden_layers = 4 +hidden_size = 256 +intermediate_size = 256 +num_attention_heads = 4 +head_dim = 64 +dropout = 0.1 + +freq = 0 + +# Sliding-window stride for dataset creation +stride = 1 # advance 1 hour at a time to maximizes training windows + +epochs = 10 +batch_size = 16 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" From a19e5acf5bb3de55e52091110b82a79a1e58a766 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 16 Mar 2026 22:20:04 -0400 Subject: [PATCH 03/39] ruff format . . --- .../model_search/fedrlnas/Darts/architect.py | 10 +- .../model_search/fedrlnas/Darts/operations.py | 6 +- .../pfedrlnas/DARTS/Darts/architect.py | 10 +- .../pfedrlnas/DARTS/Darts/operations.py | 6 +- .../NASViT/misc/attentive_nas_eval.py | 4 +- .../models/attentive_nas_dynamic_model.py | 8 +- plato/algorithms/fedavg.py | 37 +- plato/datasources/lerobot.py | 27 +- plato/models/smolvla.py | 4 +- plato/servers/fedavg.py | 4 +- plato/trainers/huggingface.py | 328 +++++++++++++----- plato/trainers/lerobot.py | 8 +- plato/trainers/lr_schedulers.py | 5 +- tests/algorithms/test_fedavg_algorithm.py | 8 +- tests/test_config_loader.py | 2 - tests/trainers/test_lerobot_trainer.py | 2 +- 16 files changed, 313 insertions(+), 156 deletions(-) diff --git a/examples/model_search/fedrlnas/Darts/architect.py b/examples/model_search/fedrlnas/Darts/architect.py index 22660d3a8..721bb5f53 100644 --- a/examples/model_search/fedrlnas/Darts/architect.py +++ b/examples/model_search/fedrlnas/Darts/architect.py @@ -140,10 +140,12 @@ def _parse(weights): weight_matrix = weights[start:end].copy() edges = sorted( range(i + 2), - key=lambda x, wm=weight_matrix: -max( - wm[x][k] - for k in range(len(wm[x])) - if k != PRIMITIVES.index("none") + key=lambda x, wm=weight_matrix: ( + -max( + wm[x][k] + for k in range(len(wm[x])) + if k != PRIMITIVES.index("none") + ) ), )[:2] for j in edges: diff --git a/examples/model_search/fedrlnas/Darts/operations.py b/examples/model_search/fedrlnas/Darts/operations.py index 48bba9a03..92f4ff99f 100644 --- a/examples/model_search/fedrlnas/Darts/operations.py +++ b/examples/model_search/fedrlnas/Darts/operations.py @@ -11,9 +11,9 @@ 3, stride=stride, padding=1, count_include_pad=False ), "max_pool_3x3": lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), - "skip_connect": lambda C, stride, affine: Identity() - if stride == 1 - else FactorizedReduce(C, C, affine=affine), + "skip_connect": lambda C, stride, affine: ( + Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine) + ), "sep_conv_3x3": lambda C, stride, affine: SepConv( C, C, 3, stride, 1, affine=affine ), diff --git a/examples/model_search/pfedrlnas/DARTS/Darts/architect.py b/examples/model_search/pfedrlnas/DARTS/Darts/architect.py index f9e5d5e6e..108cae151 100644 --- a/examples/model_search/pfedrlnas/DARTS/Darts/architect.py +++ b/examples/model_search/pfedrlnas/DARTS/Darts/architect.py @@ -193,10 +193,12 @@ def _parse(weights): weight_matrix = weights[start:end].copy() edges = sorted( range(i + 2), - key=lambda x, wm=weight_matrix: -max( - wm[x][k] - for k in range(len(wm[x])) - if k != PRIMITIVES.index("none") + key=lambda x, wm=weight_matrix: ( + -max( + wm[x][k] + for k in range(len(wm[x])) + if k != PRIMITIVES.index("none") + ) ), )[:2] for j in edges: diff --git a/examples/model_search/pfedrlnas/DARTS/Darts/operations.py b/examples/model_search/pfedrlnas/DARTS/Darts/operations.py index 48bba9a03..92f4ff99f 100644 --- a/examples/model_search/pfedrlnas/DARTS/Darts/operations.py +++ b/examples/model_search/pfedrlnas/DARTS/Darts/operations.py @@ -11,9 +11,9 @@ 3, stride=stride, padding=1, count_include_pad=False ), "max_pool_3x3": lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), - "skip_connect": lambda C, stride, affine: Identity() - if stride == 1 - else FactorizedReduce(C, C, affine=affine), + "skip_connect": lambda C, stride, affine: ( + Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine) + ), "sep_conv_3x3": lambda C, stride, affine: SepConv( C, C, 3, stride, 1, affine=affine ), diff --git a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/misc/attentive_nas_eval.py b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/misc/attentive_nas_eval.py index 00b2281e7..e8b6a44f7 100644 --- a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/misc/attentive_nas_eval.py +++ b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/misc/attentive_nas_eval.py @@ -77,8 +77,8 @@ def validate( top5_list.append(acc5) head_dim = 8 - func = ( - lambda x: x[0] ** 2 + func = lambda x: ( + x[0] ** 2 * ( x[1] ** 2 * 6 + x[1] ** 2 * 8 diff --git a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/models/attentive_nas_dynamic_model.py b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/models/attentive_nas_dynamic_model.py index b31694e0b..54db8a1e9 100644 --- a/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/models/attentive_nas_dynamic_model.py +++ b/examples/model_search/pfedrlnas/VIT/nasvit_wrapper/NASViT/models/attentive_nas_dynamic_model.py @@ -434,8 +434,8 @@ def sample_active_subnet_within_range(self, targeted_min_flops, targeted_max_flo return cfg def _sample_active_subnet(self, min_net=False, max_net=False): - sample_cfg = ( - lambda candidates, sample_min, sample_max: min(candidates) + sample_cfg = lambda candidates, sample_min, sample_max: ( + min(candidates) if sample_min else (max(candidates) if sample_max else random.choice(candidates)) ) @@ -461,8 +461,8 @@ def _sample_active_subnet(self, min_net=False, max_net=False): def mutate_and_reset(self, cfg, prob=0.1, keep_resolution=False): cfg = copy.deepcopy(cfg) - pick_another = ( - lambda x, candidates: x + pick_another = lambda x, candidates: ( + x if len(candidates) == 1 else random.choice([v for v in candidates if v != x]) ) diff --git a/plato/algorithms/fedavg.py b/plato/algorithms/fedavg.py index e1e0b3d97..7328b3c8f 100644 --- a/plato/algorithms/fedavg.py +++ b/plato/algorithms/fedavg.py @@ -22,13 +22,13 @@ class Algorithm(base.Algorithm): def _as_state_mapping(weights: Any, context: str) -> Mapping[str, torch.Tensor]: """Validate and cast a state-dict-like payload.""" if not isinstance(weights, Mapping): - raise TypeError(f"{context} must be a mapping of parameter names to tensors.") + raise TypeError( + f"{context} must be a mapping of parameter names to tensors." + ) return weights @staticmethod - def _to_transport_tensor( - tensor: torch.Tensor, tensor_name: str - ) -> torch.Tensor: + def _to_transport_tensor(tensor: torch.Tensor, tensor_name: str) -> torch.Tensor: """ Convert a tensor to a wire-safe representation for payload transport. @@ -88,7 +88,9 @@ def _compute_tensor_delta( if baseline_weight.dtype == torch.bool: return current_casted.to(torch.int8) - baseline_weight.to(torch.int8) - if torch.is_floating_point(baseline_weight) or torch.is_complex(baseline_weight): + if torch.is_floating_point(baseline_weight) or torch.is_complex( + baseline_weight + ): return current_casted.to(baseline_weight.dtype) - baseline_weight return current_casted.to(torch.int64) - baseline_weight.to(torch.int64) @@ -111,7 +113,9 @@ def _apply_tensor_delta( delta_integral = delta.to(torch.int8) return (baseline_weight.to(torch.int8) + delta_integral).ne(0) - if torch.is_floating_point(baseline_weight) or torch.is_complex(baseline_weight): + if torch.is_floating_point(baseline_weight) or torch.is_complex( + baseline_weight + ): return baseline_weight + delta.to(baseline_weight.dtype) if torch.is_floating_point(delta): @@ -156,7 +160,9 @@ def _estimate_payload_size_bytes(weights: Mapping[str, torch.Tensor]) -> int: size_bytes += tensor.numel() * tensor.element_size() return size_bytes - def _assert_payload_size(self, weights: Mapping[str, torch.Tensor], source: str) -> None: + def _assert_payload_size( + self, weights: Mapping[str, torch.Tensor], source: str + ) -> None: """Enforce an optional payload-size safeguard.""" limit_mb = self._resolve_payload_limit_mb() if limit_mb is None: @@ -175,10 +181,15 @@ def _resolve_adapter_parameter_names( ) -> list[str] | None: """Resolve parameter names to exchange for adapter-only finetuning.""" finetune_mode = getattr(target_model, "plato_finetune_mode", None) - if not isinstance(finetune_mode, str) or finetune_mode.strip().lower() != "adapter": + if ( + not isinstance(finetune_mode, str) + or finetune_mode.strip().lower() != "adapter" + ): return None - trainable_names_attr = getattr(target_model, "plato_trainable_parameter_names", None) + trainable_names_attr = getattr( + target_model, "plato_trainable_parameter_names", None + ) names_from_attr = ( [ name @@ -224,7 +235,9 @@ def compute_weight_deltas( unknown_keys = set(weight_mapping).difference(baseline_mapping) if unknown_keys: unknown = ", ".join(sorted(unknown_keys)) - raise KeyError(f"Received weights include unexpected parameter(s): {unknown}.") + raise KeyError( + f"Received weights include unexpected parameter(s): {unknown}." + ) delta = OrderedDict() for name, current_weight in weight_mapping.items(): @@ -308,7 +321,9 @@ def load_weights(self, weights): unknown_keys = set(weights_mapping).difference(current_state) if unknown_keys: unknown = ", ".join(sorted(unknown_keys)) - raise KeyError(f"Inbound weights include unexpected parameter(s): {unknown}.") + raise KeyError( + f"Inbound weights include unexpected parameter(s): {unknown}." + ) merged_state = OrderedDict(current_state.items()) for name, incoming_tensor in weights_mapping.items(): diff --git a/plato/datasources/lerobot.py b/plato/datasources/lerobot.py index f6d632ff0..a2a7df4b7 100644 --- a/plato/datasources/lerobot.py +++ b/plato/datasources/lerobot.py @@ -96,8 +96,8 @@ def _import_lerobot() -> tuple[Any, Any]: except ImportError as exc: # pragma: no cover - exercised without robotics extra. raise ImportError( "LeRobot datasource requires optional robotics dependencies. " - "Install them with \"uv sync --extra robotics\" before using " - '"data.datasource = \"LeRobot\"". ' + 'Install them with "uv sync --extra robotics" before using ' + '"data.datasource = "LeRobot"". ' ) from exc return LeRobotDataset, LeRobotDatasetMetadata @@ -362,7 +362,9 @@ def _resolve_task_name(row: Mapping[str, Any], tasks_lookup: Any) -> str | None: return None -def _resolve_episode_tasks(metadata: Any, episodes: Sequence[int]) -> dict[int, str | None]: +def _resolve_episode_tasks( + metadata: Any, episodes: Sequence[int] +) -> dict[int, str | None]: episode_tasks = {episode: None for episode in episodes} episode_rows = _episode_rows(getattr(metadata, "episodes", None)) tasks_lookup = _to_plain(getattr(metadata, "tasks", None)) @@ -473,10 +475,14 @@ def _resolve_episode_split( episode_set = set(int(episode) for episode in all_episodes) if explicit_train is None and explicit_test is None: - return _split_episodes(all_episodes, episode_tasks, train_ratio, seed, task_aware) + return _split_episodes( + all_episodes, episode_tasks, train_ratio, seed, task_aware + ) train_episodes = [ - int(episode) for episode in (explicit_train or []) if int(episode) in episode_set + int(episode) + for episode in (explicit_train or []) + if int(episode) in episode_set ] test_episodes = [ int(episode) @@ -569,7 +575,9 @@ def _resolve_total_clients(config: Any) -> int: return total_clients -def _filter_constructor_kwargs(dataset_cls: Any, kwargs: Mapping[str, Any]) -> dict[str, Any]: +def _filter_constructor_kwargs( + dataset_cls: Any, kwargs: Mapping[str, Any] +) -> dict[str, Any]: try: signature = inspect.signature(dataset_cls.__init__) except (TypeError, ValueError): @@ -582,9 +590,7 @@ def _filter_constructor_kwargs(dataset_cls: Any, kwargs: Mapping[str, Any]) -> d if accepts_var_kwargs: return dict(kwargs) - valid_parameters = { - name for name in signature.parameters.keys() if name != "self" - } + valid_parameters = {name for name in signature.parameters.keys() if name != "self"} filtered = {key: value for key, value in kwargs.items() if key in valid_parameters} dropped = sorted(set(kwargs.keys()) - set(filtered.keys())) @@ -646,8 +652,7 @@ def __init__(self, client_id: int = 0, **kwargs): repo_id = str(dataset_cfg.pop("repo_id", "")).strip() if not repo_id: raise ValueError( - "LeRobot datasource requires " - '"parameters.dataset.repo_id" to be set.' + 'LeRobot datasource requires "parameters.dataset.repo_id" to be set.' ) train_split_raw = dataset_cfg.pop("train_split", _DEFAULT_TRAIN_SPLIT) diff --git a/plato/models/smolvla.py b/plato/models/smolvla.py index dd570d8d3..bf89f91c9 100644 --- a/plato/models/smolvla.py +++ b/plato/models/smolvla.py @@ -273,7 +273,9 @@ def get(model_name: str | None = None, **kwargs: Any) -> nn.Module: setattr(policy, "plato_policy_path", policy_path) setattr(policy, "plato_finetune_mode", finetune_mode) setattr(policy, "plato_adapter_patterns", tuple(adapter_patterns)) - setattr(policy, "plato_adapter_fallback_mode", trainable_metadata["fallback_mode"]) + setattr( + policy, "plato_adapter_fallback_mode", trainable_metadata["fallback_mode"] + ) setattr(policy, "plato_trainable_parameter_count", trainable_count) setattr( policy, diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index ad9afe947..112e71bc1 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -249,9 +249,7 @@ async def _process_reports(self): # Compute the average metric from client reports self.accuracy, self.accuracy_std = self.get_accuracy_mean_std(self.updates) if metric_name == "mse": - logging.info( - "[%s] Average client MSE: %.6f.", self, self.accuracy - ) + logging.info("[%s] Average client MSE: %.6f.", self, self.accuracy) else: logging.info( "[%s] Average client accuracy: %.2f%%.", self, 100 * self.accuracy diff --git a/plato/trainers/huggingface.py b/plato/trainers/huggingface.py index e9196d752..aa9855a5e 100644 --- a/plato/trainers/huggingface.py +++ b/plato/trainers/huggingface.py @@ -5,13 +5,14 @@ HuggingFace data handling through strategy objects instead of overriding `load_model`/`save_model` hooks. +Supports both text/NLP models and time series models (e.g., PatchTSMixer, TimesFM). """ import logging import math import os from collections.abc import Iterable, Sequence -from typing import Any, Dict, Optional, Tuple, Union, cast +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -39,6 +40,7 @@ TrainingContext, TrainingStepStrategy, ) +from plato.utils.timeseries_utils import is_timeseries_model class HuggingFaceBatch(dict): @@ -79,6 +81,23 @@ def __call__( return HuggingFaceBatch(batch), labels +class TimeSeriesCollateWrapper: + """Collator for time series data (PatchTSMixer format).""" + + def __call__( + self, examples: Iterable[dict] + ) -> tuple[HuggingFaceBatch, torch.Tensor | None]: + """ + Collate time series examples into batches. + + Expected format: {"past_values": tensor, "future_values": tensor} + """ + batch = default_data_collator(list(examples)) + labels = batch.get("future_values", None) + + return HuggingFaceBatch(batch), labels + + def _resolve_hf_loss(outputs, labels, *, allow_fallback: bool = True): """ Resolve a loss tensor from HuggingFace model outputs. @@ -110,8 +129,10 @@ def _resolve_hf_loss(outputs, labels, *, allow_fallback: bool = True): raise ValueError("HuggingFace model did not return a tensor loss.") logits = getattr(outputs, "logits", None) + if logits is None: + logits = getattr(outputs, "prediction_outputs", None) # PatchTSMixer if logits is None and isinstance(outputs, dict): - logits = outputs.get("logits") + logits = outputs.get("logits") or outputs.get("prediction_outputs") if logits is None and isinstance(outputs, tuple) and len(outputs) > 0: logits = outputs[0] @@ -133,6 +154,13 @@ def _resolve_hf_loss(outputs, labels, *, allow_fallback: bool = True): logits = logits.to(labels.device) if labels.device != logits.device else logits labels = labels.to(logits.device) + # Check if this is a regression task (shapes match) -> use MSE + # Time series: logits (batch, pred_len, channels), labels (batch, pred_len, channels) + # Text generation: logits (batch, seq_len, vocab_size), labels (batch, seq_len) + if logits.shape == labels.shape: + return F.mse_loss(logits, labels) + + # Text generation with causal LM -> use cross-entropy vocab_size = logits.size(-1) if logits.ndim > 2: shift_logits = logits[..., :-1, :].contiguous() @@ -196,12 +224,23 @@ def training_step( optimizer.zero_grad() batch_inputs = dict(examples) - if labels is not None: + + is_timeseries = ( + "past_values" in batch_inputs and "future_values" in batch_inputs + ) + + if not is_timeseries and labels is not None: batch_inputs["labels"] = labels batch_inputs.setdefault("return_dict", True) outputs = model(**batch_inputs) - labels_tensor = batch_inputs.get("labels") + + # For time series, get labels from batch_inputs, otherwise from labels argument + labels_tensor = ( + batch_inputs.get("future_values") + if is_timeseries + else batch_inputs.get("labels") + ) loss = _resolve_hf_loss(outputs, labels_tensor) loss_for_backward = loss.div(accum_steps) if accum_steps > 1 else loss @@ -291,10 +330,21 @@ def finalize(self, model, optimizer, context: TrainingContext): class HuggingFaceTestingStrategy(TestingStrategy): - """Evaluates HuggingFace models and reports perplexity based on loss.""" + """Evaluates HuggingFace models (text: perplexity, time series: MSE).""" - def __init__(self, collate_fn: HuggingFaceCollateWrapper): + def __init__(self, collate_fn, is_timeseries=False): self.collate_fn = collate_fn + self.is_timeseries = is_timeseries + + @property + def metric_name(self) -> str: + """Return the name of the metric this strategy computes.""" + if self.is_timeseries: + return "mse" # For time series models, using mean squared error. + elif hasattr(Config().trainer, "target_perplexity"): + return "perplexity" + else: + return "accuracy" def test_model(self, model, config, testset, sampler, context: TrainingContext): batch_size = config.get("batch_size", 1) @@ -324,41 +374,108 @@ def test_model(self, model, config, testset, sampler, context: TrainingContext): model.eval() context.state["eval_loader"] = data_loader - total_loss = 0.0 - total_weight = 0 - - with torch.no_grad(): - for batch_inputs, labels in data_loader: - batch_inputs = batch_inputs.to(context.device) - if labels is not None: - labels = labels.to(context.device) - batch_inputs["labels"] = labels - - batch_inputs.setdefault("return_dict", True) - outputs = model(**batch_inputs) - loss = _resolve_hf_loss(outputs, labels) - - if labels is not None: - weight = labels.ne(-100).sum().item() - if weight == 0: - continue - else: - weight = 1 - - total_loss += loss.item() * weight - total_weight += weight + if self.is_timeseries: + total_loss = 0.0 + total_samples = 0 + channel_indices = getattr( + Config().trainer, "prediction_channel_indices", None + ) + if channel_indices is not None: + try: + channel_indices = list(channel_indices) + except TypeError: + channel_indices = None + + with torch.no_grad(): + for batch_inputs, labels in data_loader: + batch_inputs = batch_inputs.to(context.device) + if labels is not None: + labels = labels.to(context.device) + batch_inputs["future_values"] = labels + + batch_inputs.setdefault("return_dict", True) + outputs = model(**batch_inputs) + + preds = getattr(outputs, "prediction_outputs", None) + if preds is None: + preds = getattr(outputs, "logits", None) + if preds is None and isinstance(outputs, dict): + preds = outputs.get("prediction_outputs") or outputs.get( + "logits" + ) + + if preds is None: + loss = getattr(outputs, "loss", None) + if loss is None and isinstance(outputs, dict): + loss = outputs.get("loss") + if loss is None: + continue + batch_loss = loss + else: + if labels is None: + continue + preds = preds.to(labels.device) + labels_for_loss = labels + if channel_indices is not None: + if preds.shape[-1] != len(channel_indices): + preds = preds[..., channel_indices] + if labels_for_loss.shape[-1] != len(channel_indices): + labels_for_loss = labels_for_loss[..., channel_indices] + batch_loss = F.mse_loss(preds, labels_for_loss) + + batch_size = ( + batch_inputs["past_values"].size(0) + if "past_values" in batch_inputs + else 1 + ) + total_loss += batch_loss.item() * batch_size + total_samples += batch_size - model.train() - context.state.pop("eval_loader", None) + model.train() + context.state.pop("eval_loader", None) - if total_weight == 0: - return float("inf") + if total_samples == 0: + return float("inf") - avg_loss = total_loss / total_weight - try: - return math.exp(avg_loss) - except OverflowError: - return float("inf") + # Return MSE + return total_loss / total_samples + else: + # Text/NLP: compute perplexity + total_loss = 0.0 + total_weight = 0 + + with torch.no_grad(): + for batch_inputs, labels in data_loader: + batch_inputs = batch_inputs.to(context.device) + if labels is not None: + labels = labels.to(context.device) + batch_inputs["labels"] = labels + + batch_inputs.setdefault("return_dict", True) + outputs = model(**batch_inputs) + loss = _resolve_hf_loss(outputs, labels) + + if labels is not None: + weight = labels.ne(-100).sum().item() + if weight == 0: + continue + else: + weight = 1 + + total_loss += loss.item() * weight + total_weight += weight + + model.train() + context.state.pop("eval_loader", None) + + if total_weight == 0: + return float("inf") + + avg_loss = total_loss / total_weight + try: + return math.exp(avg_loss) + except OverflowError: + return float("inf") def _split_callback_types( @@ -416,8 +533,6 @@ def on_train_step_end(self, trainer, config, batch, loss, **kwargs): class Trainer(ComposableTrainer): """Composable HuggingFace trainer built on Plato's strategy API.""" - training_args: TrainingArguments - def __init__(self, model=None, callbacks=None): hf_callbacks, plato_callbacks = _split_callback_types(callbacks) @@ -427,66 +542,83 @@ def __init__(self, model=None, callbacks=None): self._hf_control = TrainerControl() self._hf_steps_per_epoch: int | None = None - parser = HfArgumentParser(cast(Any, TrainingArguments)) - (training_args,) = parser.parse_args_into_dataclasses( + parser = HfArgumentParser(TrainingArguments) + (self.training_args,) = parser.parse_args_into_dataclasses( args=[ "--output_dir=" + Config.params["checkpoint_path"], "--report_to=none", ] ) - self.training_args = cast(TrainingArguments, training_args) - model_name = Config().trainer.model_name - config_kwargs = { - "cache_dir": None, - "revision": "main", - "use_auth_token": None, - } - self.config = AutoConfig.from_pretrained(model_name, **config_kwargs) + model_name = getattr(Config().trainer, "model_name", "") + model_type = getattr(Config().trainer, "model_type", None) - cache_dir = Config().params["data_path"] - use_fast_tokenizer = True - revision = "main" - auth_token = getattr( - getattr(Config(), "parameters", None), "huggingface_token", None + # Detect if this is a time series model + self._is_timeseries = is_timeseries_model( + model_name=model_name, model_type=model_type ) - if "llama" in model_name: - if isinstance(auth_token, str) and auth_token: - self.tokenizer = LlamaTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - use_auth_token=auth_token, - ) - else: - self.tokenizer = LlamaTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - ) - else: - if isinstance(auth_token, str) and auth_token: - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - use_auth_token=auth_token, - ) + if self._is_timeseries: + logging.info( + "Detected time series model (type: %s, name: %s)", + model_type, + model_name, + ) + + self.config = None + if not self._is_timeseries: + config_kwargs = { + "cache_dir": None, + "revision": "main", + "use_auth_token": None, + } + self.config = AutoConfig.from_pretrained(model_name, **config_kwargs) + + self.tokenizer = None + if not self._is_timeseries: + cache_dir = Config().params["data_path"] + use_fast_tokenizer = True + revision = "main" + auth_token = getattr( + getattr(Config(), "parameters", None), "huggingface_token", None + ) + + if "llama" in model_name: + if isinstance(auth_token, str) and auth_token: + self.tokenizer = LlamaTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + use_auth_token=auth_token, + ) + else: + self.tokenizer = LlamaTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + ) else: - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, - config=self.config, - cache_dir=cache_dir, - use_fast=use_fast_tokenizer, - revision=revision, - ) + if isinstance(auth_token, str) and auth_token: + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + use_auth_token=auth_token, + ) + else: + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + config=self.config, + cache_dir=cache_dir, + use_fast=use_fast_tokenizer, + revision=revision, + ) grad_accum_steps = getattr(Config().trainer, "gradient_accumulation_steps", 1) try: @@ -494,7 +626,15 @@ def __init__(self, model=None, callbacks=None): except (TypeError, ValueError): grad_accum_steps = 1 self._gradient_accumulation_steps = max(grad_accum_steps, 1) - self._collate_wrapper = HuggingFaceCollateWrapper(self.tokenizer) + + # Choose collator based on model type + if self._is_timeseries: + self._collate_wrapper = TimeSeriesCollateWrapper() + logging.info("Using TimeSeriesCollateWrapper for time series model") + else: + self._collate_wrapper = HuggingFaceCollateWrapper(self.tokenizer) + logging.info("Using HuggingFaceCollateWrapper for text model") + self.training_args.gradient_accumulation_steps = ( self._gradient_accumulation_steps ) @@ -516,14 +656,16 @@ def __init__(self, model=None, callbacks=None): num_workers=0, pin_memory=True, ), - testing_strategy=HuggingFaceTestingStrategy(self._collate_wrapper), + testing_strategy=HuggingFaceTestingStrategy( + self._collate_wrapper, is_timeseries=self._is_timeseries + ), ) if hf_callbacks: self.add_callbacks(hf_callbacks) model_instance = self._require_model() - if hasattr(model_instance, "loss_type"): + if hasattr(model_instance, "loss_type") and not self._is_timeseries: setattr(model_instance, "loss_type", "ForCausalLM") # Ensure model checkpoints can be saved when model names include slashes. diff --git a/plato/trainers/lerobot.py b/plato/trainers/lerobot.py index e66d4ae9e..39fb11bab 100644 --- a/plato/trainers/lerobot.py +++ b/plato/trainers/lerobot.py @@ -344,9 +344,7 @@ def _resolve_runtime_device(device_value: Any, fallback_device: Any) -> torch.de try: gpu_index = int(normalized.split(":", 1)[1]) except (IndexError, ValueError) as exc: - raise ValueError( - f"Invalid CUDA device value: '{device_value}'." - ) from exc + raise ValueError(f"Invalid CUDA device value: '{device_value}'.") from exc if gpu_index < 0 or gpu_index >= torch.cuda.device_count(): raise RuntimeError( f"`parameters.policy.device` requested CUDA device {gpu_index}, " @@ -466,9 +464,7 @@ def training_step( ) if not torch.is_tensor(loss): - raise TypeError( - "LeRobot policy forward did not return a tensor loss." - ) + raise TypeError("LeRobot policy forward did not return a tensor loss.") loss.backward() optimizer.step() diff --git a/plato/trainers/lr_schedulers.py b/plato/trainers/lr_schedulers.py index edf87bd9b..fe2b77fef 100644 --- a/plato/trainers/lr_schedulers.py +++ b/plato/trainers/lr_schedulers.py @@ -104,8 +104,9 @@ def get(optimizer: optim.Optimizer, iterations_per_epoch: int, **kwargs: str | d for x in lr_params["milestone_steps"].split(",") ] lambdas.append( - lambda it, milestones=milestones: lr_params["gamma"] - ** bisect.bisect(milestones, it) + lambda it, milestones=milestones: ( + lr_params["gamma"] ** bisect.bisect(milestones, it) + ) ) # Add a linear learning rate warmup if specified diff --git a/tests/algorithms/test_fedavg_algorithm.py b/tests/algorithms/test_fedavg_algorithm.py index 1ccdd8399..69657036e 100644 --- a/tests/algorithms/test_fedavg_algorithm.py +++ b/tests/algorithms/test_fedavg_algorithm.py @@ -39,9 +39,7 @@ class BFloat16ToyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.weight = torch.nn.Parameter( - torch.ones((2, 2), dtype=torch.bfloat16) - ) + self.weight = torch.nn.Parameter(torch.ones((2, 2), dtype=torch.bfloat16)) def _algorithm_for(model: torch.nn.Module) -> FedAvgAlgorithm: @@ -114,9 +112,7 @@ def test_extract_weights_casts_bfloat16_payloads_for_transport(): payload = algorithm.extract_weights() assert payload["weight"].dtype == torch.float32 - inbound = OrderedDict( - {"weight": torch.full((2, 2), 3.5, dtype=torch.float32)} - ) + inbound = OrderedDict({"weight": torch.full((2, 2), 3.5, dtype=torch.float32)}) algorithm.load_weights(inbound) state = model.state_dict() diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 41da1e1d1..19e43cee4 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -212,8 +212,6 @@ def test_config_loads_evaluation_section(tmp_path: Path, monkeypatch): Config._cli_overrides = {} - - def test_config_loads_smolvla_lerobot_parameter_contract(tmp_path: Path, monkeypatch): """SmolVLA/LeRobot config keys should round-trip through Config().""" config_base = tmp_path / "runtime" diff --git a/tests/trainers/test_lerobot_trainer.py b/tests/trainers/test_lerobot_trainer.py index 1d1316aff..9b849bcaa 100644 --- a/tests/trainers/test_lerobot_trainer.py +++ b/tests/trainers/test_lerobot_trainer.py @@ -166,7 +166,7 @@ def test_lerobot_trainer_consumes_policy_precision_and_device( monkeypatch.setattr( lerobot_trainer, "_import_make_pre_post_processors", - lambda: (lambda *_args, **_kwargs: (lambda batch: batch, lambda out: out)), + lambda: lambda *_args, **_kwargs: (lambda batch: batch, lambda out: out), ) trainer = lerobot_trainer.Trainer(model=_TinyLeRobotPolicy()) From f8e1c3ad6f59dd5f90e6a3809656f7c717cd3340 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 16 Mar 2026 22:20:42 -0400 Subject: [PATCH 04/39] Added ev_charging datasource and the data normalization for time series data. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Dataset: "EV Charging Reports" – Mendeley Data (dataset1_ev_charging_reports.csv) --- plato/datasources/ev_charging.py | 420 +++++++++++++++++++++++++++++++ 1 file changed, 420 insertions(+) create mode 100644 plato/datasources/ev_charging.py diff --git a/plato/datasources/ev_charging.py b/plato/datasources/ev_charging.py new file mode 100644 index 000000000..905303eb2 --- /dev/null +++ b/plato/datasources/ev_charging.py @@ -0,0 +1,420 @@ +""" +EV Charging Datasource for Federated Time-Series Forecasting. + +Dataset: "EV Charging Reports" – Mendeley Data (dataset1_ev_charging_reports.csv) + https://data.mendeley.com/datasets/jbks2rcwyj/1 + +Raw CSV format: + session_ID ; Garage_ID ; User_ID ; User_type ; Shared_ID ; + Start_plugin ; Start_plugin_hour ; End_plugout ; End_plugout_hour ; + El_kWh ; Duration_hours ; month_plugin ; weekdays_plugin ; + Plugin_category ; Duration_category + - Datetimes use DD.MM.YYYY HH:MM format. + - El_kWh uses a comma as the decimal separator. + +Preprocessing pipeline +----------------------- +1. Filter to the requested garage (default "AdO1", which has 4 private users). +2. For each user, build a continuous hourly grid from the first to the last + session hour in the dataset. +3. For every hour, mark is_charging = 1 if the user had an active session, + else 0; accumulate energy_kwh proportionally over session hours. +4. Scale energy_kwh ∈ [0, 1] using the training-split maximum. +5. Add cyclic time encodings: + hour_sin = sin(2π · hour / 24) hour_cos = cos(2π · hour / 24) + dow_sin = sin(2π · dow / 7) dow_cos = cos(2π · dow / 7) +6. Split temporally: 70 % train, 15 % val, 15 % test. +7. Build sliding-window samples: + past_values : (context_length, 6) : all features + future_values : (prediction_length, 1) : is_charging only + +Federated split +--------------- +Each client sees only its own user's data. + +TOML configuration +------------------ +[data] +datasource = "EVCharging" +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" +garage = "AdO1" # optional +num_users = 4 # optional + +[trainer] +context_length = 168 # 7 * 24 h +prediction_length = 168 # 7 * 24 h +train_ratio = 0.70 +val_ratio = 0.15 +stride = 1 # slide 1 hour at a time +""" + +from __future__ import annotations + +import logging +import os + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset + +from plato.config import Config + + +# Exact column names from the Mendeley CSV +_CSV_SEP = ";" +_GARAGE_COL = "Garage_ID" +_USER_COL = "User_ID" +_START_COL = "Start_plugin" +_END_COL = "End_plugout" +_ENERGY_COL = "El_kWh" +_DT_FORMAT = "%d.%m.%Y %H:%M" + + +# Preprocessing helpers +def _parse_european_float(series: pd.Series) -> pd.Series: + """Replace comma decimal separator and coerce to float.""" + return ( + series.astype(str) + .str.replace(",", ".", regex=False) + .str.strip() + .pipe(pd.to_numeric, errors="coerce") + .fillna(0.0) + ) + + +def _build_hourly_series( + df: pd.DataFrame, + garage: str, + num_users: int, + user_ids: list[str] | None = None, +) -> dict[str, pd.DataFrame]: + """ + Build per-user hourly DataFrames from raw session records. + + Parameters: + user_ids : explicit list of User_ID strings to include. + num_users : max number of users to take alphabetically when ``user_ids`` + is not given. + + Returns: + dict mapping user_id (str) -> pd.DataFrame with hourly index and columns: + is_charging (0/1 float), energy_kwh (float >= 0) + """ + # Filter to requested garage + mask = df[_GARAGE_COL].astype(str).str.strip() == garage + df = df[mask].copy() + if df.empty: + raise ValueError( + f"No records found for garage '{garage}'. " + f"Available: {sorted(df[_GARAGE_COL].unique())}" + ) + + # Parse datetimes + df[_START_COL] = pd.to_datetime(df[_START_COL].str.strip(), format=_DT_FORMAT) + df[_END_COL] = pd.to_datetime(df[_END_COL].str.strip(), format=_DT_FORMAT) + df[_ENERGY_COL] = _parse_european_float(df[_ENERGY_COL]) + + # Drop invalid rows + df = df.dropna(subset=[_START_COL, _END_COL]) + df = df[df[_END_COL] > df[_START_COL]] + + # Resolve user list + available = sorted(df[_USER_COL].dropna().unique()) + if user_ids is not None: + # Explicit list from config — validate each entry + missing = [u for u in user_ids if u not in available] + if missing: + raise ValueError( + f"Users not found in garage '{garage}': {missing}. " + f"Available: {available}" + ) + users = list(user_ids) # preserve config order + else: + users = available[:num_users] + logging.info("EVCharging: garage '%s' → users %s", garage, users) + + result: dict[str, pd.DataFrame] = {} + for user in users: + udf = df[df[_USER_COL] == user] + + # Per-user hourly index: only spans that user's own activity window. + # Using a global index would pad every user with the same number of + # zero-charging hours, giving all clients identical dataset sizes. + user_start = udf[_START_COL].min().floor("h") + user_end = udf[_END_COL].max().ceil("h") + hourly_index = pd.date_range(user_start, user_end, freq="h") + + is_charging = pd.Series(0.0, index=hourly_index) + energy_kwh = pd.Series(0.0, index=hourly_index) + + for _, row in udf.iterrows(): + # All hours touched by this session + session_hours = pd.date_range( + row[_START_COL].floor("h"), + row[_END_COL].floor("h"), + freq="h", + ) + valid_hours = session_hours[session_hours.isin(hourly_index)] + if valid_hours.empty: + continue + is_charging[valid_hours] = 1.0 + energy_per_hour = float(row[_ENERGY_COL]) / max(len(valid_hours), 1) + energy_kwh[valid_hours] += energy_per_hour + + user_df = pd.DataFrame( + {"is_charging": is_charging, "energy_kwh": energy_kwh}, + index=hourly_index, + ) + user_df.index.name = "timestamp" + result[user] = user_df + + return result + + +def _add_time_features(df: pd.DataFrame) -> pd.DataFrame: + """Append cyclic hour-of-day and day-of-week columns.""" + hour = df.index.hour.astype(float) + dow = df.index.dayofweek.astype(float) + df = df.copy() + df["hour_sin"] = np.sin(2 * np.pi * hour / 24) + df["hour_cos"] = np.cos(2 * np.pi * hour / 24) + df["dow_sin"] = np.sin(2 * np.pi * dow / 7) + df["dow_cos"] = np.cos(2 * np.pi * dow / 7) + return df + + +# Ordered feature columns fed into the model +_FEATURE_COLS = [ + "is_charging", + "energy_scaled", + "hour_sin", + "hour_cos", + "dow_sin", + "dow_cos", +] + + +# Torch Dataset +class _EVChargingDataset(Dataset): + """Sliding-window samples for one user / one split. + + Each sample: + past_values : FloatTensor (context_length, 6) + future_values : FloatTensor (prediction_length, 1) ← is_charging only + """ + + def __init__( + self, + data: np.ndarray, # shape (T, 6), already normalized + context_length: int, + prediction_length: int, + stride: int = 1, + starts: list[int] | None = None, # explicit window start indices + ): + super().__init__() + self.data = torch.FloatTensor(data) + self.context_length = context_length + self.prediction_length = prediction_length + if starts is not None: + # Caller already computed and partitioned the valid starts. + self.indices = starts + else: + total = context_length + prediction_length + max_start = len(data) - total + if max_start < 0: + logging.warning( + "EVCharging: data has only %d steps but needs %d " + "(context=%d + prediction=%d) — dataset will be empty.", + len(data), + total, + context_length, + prediction_length, + ) + self.indices = [] + else: + self.indices = list(range(0, max_start + 1, stride)) + + def __len__(self) -> int: + return len(self.indices) + + def __getitem__(self, idx: int) -> dict: + s = self.indices[idx] + e_ctx = s + self.context_length + e_pred = e_ctx + self.prediction_length + return { + "past_values": self.data[s:e_ctx], # (ctx, 6) + "future_values": self.data[e_ctx:e_pred, :1], # (pred, 1) + } + + +# Plato DataSource +class DataSource: + """EV Charging DataSource for Plato federated learning. + + Each instance represents ONE user (client_id selects the user, 0-indexed + over the alphabetically sorted user list for the requested garage). + + Typical config (timesfm_ev_charging.toml): + + [data] + datasource = "EVCharging" + datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + garage = "AdO1" + num_users = 4 + + [trainer] + context_length = 168 + prediction_length = 168 + train_ratio = 0.70 + val_ratio = 0.15 + stride = 24 + """ + + def __init__(self, client_id: int = 0, **kwargs): + cfg = Config() + data_cfg = cfg.data + trainer_cfg = cfg.trainer + + # Locate CSV + csv_path = kwargs.get( + "datasource_path", + getattr(data_cfg, "datasource_path", None), + ) + if csv_path is None: + raise ValueError( + "EVCharging requires 'datasource_path' in [data] config, " + 'e.g. datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv"' + ) + if not os.path.isabs(csv_path): + csv_path = os.path.join(os.getcwd(), csv_path) + if not os.path.exists(csv_path): + raise FileNotFoundError( + f"EV charging CSV not found: {csv_path}\n" + "Download from https://data.mendeley.com/datasets/jbks2rcwyj/1" + ) + + garage = str(kwargs.get("garage", getattr(data_cfg, "garage", "AdO1"))) + + # Config: users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] + user_ids_cfg = kwargs.get("users", getattr(data_cfg, "users", None)) + if user_ids_cfg is not None: + user_ids: list[str] | None = [str(u) for u in user_ids_cfg] + num_users = len(user_ids) + else: + user_ids = None + num_users = int(kwargs.get("num_users", getattr(data_cfg, "num_users", 4))) + + # Window / split settings + self.context_length = int(getattr(trainer_cfg, "context_length", 168)) + self.prediction_length = int(getattr(trainer_cfg, "prediction_length", 168)) + train_ratio = float(getattr(trainer_cfg, "train_ratio", 0.70)) + val_ratio = float(getattr(trainer_cfg, "val_ratio", 0.15)) + stride = int(getattr(trainer_cfg, "stride", 1)) + + # Load and preprocess + logging.info("EVCharging: loading %s", csv_path) + raw_df = pd.read_csv(csv_path, sep=_CSV_SEP, low_memory=False) + + user_series = _build_hourly_series( + raw_df, garage=garage, num_users=num_users, user_ids=user_ids + ) + + # Preserve config-specified order when user_ids is given + if user_ids is not None: + users = [u for u in user_ids if u in user_series] + else: + users = sorted(user_series.keys()) + + user_index = max(0, client_id - 1) + + if user_index >= len(users): + raise ValueError( + f"client_id={client_id} out of range; " + f"found {len(users)} users in garage '{garage}': {users}" + ) + + user_key = users[user_index] + logging.info("EVCharging: client_id=%d → user '%s'", client_id, user_key) + + user_df = _add_time_features(user_series[user_key]) + raw_array = user_df[ + [ + "is_charging", + "energy_kwh", + "hour_sin", + "hour_cos", + "dow_sin", + "dow_cos", + ] + ].values.astype(np.float32) + + # Split window indices, not raw hours + window = self.context_length + self.prediction_length + all_starts = list(range(0, len(raw_array) - window + 1, stride)) + n_windows = len(all_starts) + + n_train_w = max(1, int(n_windows * train_ratio)) + n_val_w = max(0, int(n_windows * val_ratio)) + train_starts = all_starts[:n_train_w] + val_starts = all_starts[n_train_w : n_train_w + n_val_w] + test_starts = all_starts[n_train_w + n_val_w :] + + # Energy scaling + if train_starts: + train_end = min(len(user_df), train_starts[-1] + window) + else: + train_end = max(1, int(len(user_df) * train_ratio)) + energy_max = float(user_df["energy_kwh"].iloc[:train_end].max()) or 1.0 + + full_array = raw_array.copy() + full_array[:, 1] = full_array[:, 1] / energy_max # -> energy_scaled in [0, 1] + + # Keep the full normalized array for inference scripts + self.normalized_data = full_array + + self._train_set = _EVChargingDataset( + full_array, + self.context_length, + self.prediction_length, + stride=stride, + starts=train_starts, + ) + self._val_set = _EVChargingDataset( + full_array, + self.context_length, + self.prediction_length, + stride=stride, + starts=val_starts, + ) + self._test_set = _EVChargingDataset( + full_array, + self.context_length, + self.prediction_length, + stride=stride, + starts=test_starts, + ) + + logging.info( + "EVCharging user '%s': %d train / %d val / %d test windows", + user_key, + len(self._train_set), + len(self._val_set), + len(self._test_set), + ) + + # Plato DataSource interface + def get_train_set(self) -> _EVChargingDataset: + return self._train_set + + def get_val_set(self) -> _EVChargingDataset: + return self._val_set + + def get_test_set(self) -> _EVChargingDataset: + return self._test_set + + def num_train_examples(self) -> int: + return len(self._train_set) + + def num_test_examples(self) -> int: + return len(self._test_set) From 536f8464088bb2e03e3fa30f72855bb61eb05e48 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 16 Mar 2026 22:22:35 -0400 Subject: [PATCH 05/39] Added ev availablitity datasource in the registry. --- plato/datasources/registry.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index 0a609d5e5..990b32ed5 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -8,6 +8,7 @@ from plato.config import Config from plato.datasources import ( cinic10, + ev_charging, feature, femnist, huggingface, @@ -32,7 +33,11 @@ "Nanochat": nanochat, } -registered_partitioned_datasources = {"FEMNIST": femnist, "LeRobot": lerobot} +registered_partitioned_datasources = { + "FEMNIST": femnist, + "LeRobot": lerobot, + "EVCharging": ev_charging, # per-user split; client_id selects the user +} _datasource_aliases = { "STL10": ("Torchvision", {"dataset_name": "STL10"}), From 907268ab97e063eb160d9fa32ffd64baea84e7f7 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 16 Mar 2026 22:25:37 -0400 Subject: [PATCH 06/39] Added TimesFM model and TimeSeriesUtils for time-series forecasting tasks. (Need to be refactored to be more modular and extensible in the future.) --- plato/models/huggingface.py | 314 +++++++++++++++++++++++++++++++- plato/models/registry.py | 2 + plato/utils/timeseries_utils.py | 43 +++++ 3 files changed, 351 insertions(+), 8 deletions(-) create mode 100644 plato/utils/timeseries_utils.py diff --git a/plato/models/huggingface.py b/plato/models/huggingface.py index 4a8d3aafb..9090b1de9 100644 --- a/plato/models/huggingface.py +++ b/plato/models/huggingface.py @@ -7,9 +7,34 @@ import logging from typing import Any, Dict +import torch +import torch.nn as nn +import torch.nn.functional as F from transformers import AutoConfig, AutoModelForCausalLM from plato.config import Config +from plato.utils.timeseries_utils import is_timeseries_model + +try: + from transformers import ( + PatchTSMixerConfig, + PatchTSMixerForPrediction, + PatchTSMixerForPretraining, + PatchTSMixerForRegression, + PatchTSMixerForTimeSeriesClassification, + ) +except ImportError: + PatchTSMixerConfig = None + PatchTSMixerForPrediction = None + PatchTSMixerForTimeSeriesClassification = None + PatchTSMixerForRegression = None + PatchTSMixerForPretraining = None + +try: + from transformers import TimesFmConfig, TimesFmModelForPrediction +except ImportError: + TimesFmConfig = None + TimesFmModelForPrediction = None try: from peft import LoraConfig, get_peft_model @@ -35,18 +60,263 @@ def _lora_config_dict(lora_config: Any) -> dict[str, Any]: raise TypeError("Unsupported LoRA configuration format.") +class _TimesFmOutput: + """Output container compatible with the Plato time-series training/testing pipeline.""" + + def __init__(self, loss=None, prediction_outputs=None): + self.loss = loss + self.prediction_outputs = prediction_outputs + + +class TimesFmMultivariateWrapper(nn.Module): + """Wraps TimesFmModelForPrediction for batched, multivariate time series. + + TimesFM is natively univariate (each call takes a list of 1-D tensors). + This wrapper accepts a standard batched tensor of shape + ``(batch, context_length)`` or ``(batch, context_length, channels)`` + and handles the reshaping transparently so the rest of the Plato pipeline + (collators, training/testing strategies) needs no changes. + + For multivariate input, every channel is processed independently through + the same TimesFM model (channel-independent forecasting). The outputs are + recombined into ``(batch, prediction_length, channels)``. + + If ``future_values`` is provided the wrapper computes MSE loss against it + and stores it in ``.loss``. ``prediction_outputs`` always holds the mean + predictions in ``(batch, prediction_length, out_channels)`` form. + + Args: + model: An instantiated ``TimesFmModelForPrediction``. + prediction_length: Number of future steps to keep. Predictions are + truncated to this length when the model's ``horizon_length`` + differs from the configured ``prediction_length``. + default_freq: Default frequency token (0 = high/hourly, + 1 = medium/daily-weekly, 2 = low/monthly-yearly). + """ + + def __init__( + self, + model: "TimesFmModelForPrediction", + prediction_length: int | None = None, + default_freq: int = 0, + ): + super().__init__() + self.model = model + self.prediction_length = prediction_length + self.default_freq = default_freq + + def forward( + self, + past_values: torch.Tensor, + future_values: torch.Tensor | None = None, + freq: int | list | torch.Tensor | None = None, + return_dict: bool = True, # accepted for API compat, ignored internally + **kwargs, + ) -> _TimesFmOutput: + if not isinstance(past_values, torch.Tensor): + raise TypeError("past_values must be a torch.Tensor") + + if past_values.dim() == 3: + # Multivariate path + batch, ctx, channels = past_values.shape + # (batch, ctx, ch) -> (batch*ch, ctx) + pv_2d = past_values.permute(0, 2, 1).reshape(batch * channels, ctx) + past_list = [pv_2d[i] for i in range(pv_2d.size(0))] + freq_list = self._build_freq_list(freq, batch, channels) + + outputs = self.model(past_values=past_list, freq=freq_list) + + # (batch*ch, horizon) -> (batch, horizon, ch) + raw = outputs.mean_predictions + horizon = raw.shape[-1] + mean_preds = raw.reshape(batch, channels, horizon).permute(0, 2, 1) + + else: + # Univariate path + batch = past_values.size(0) + past_list = [past_values[i] for i in range(batch)] + freq_list = self._build_freq_list(freq, batch, channels=1) + + outputs = self.model(past_values=past_list, freq=freq_list) + mean_preds = outputs.mean_predictions.unsqueeze(-1) # (batch, horizon, 1) + + # Truncate to configured prediction_length + if self.prediction_length is not None: + mean_preds = mean_preds[:, : self.prediction_length, :] + + # Compute MSE loss when targets are provided + loss = None + if future_values is not None: + fv = future_values + if fv.dim() == 2: + fv = fv.unsqueeze(-1) # (batch, pred) -> (batch, pred, 1) + min_len = min(mean_preds.shape[1], fv.shape[1]) + loss = F.mse_loss( + mean_preds[:, :min_len, : fv.shape[-1]], + fv[:, :min_len, :], + ) + + return _TimesFmOutput(loss=loss, prediction_outputs=mean_preds) + + def _build_freq_list( + self, + freq: int | list | torch.Tensor | None, + batch: int, + channels: int, + ) -> list[int]: + n = batch * channels + if freq is None: + return [self.default_freq] * n + if isinstance(freq, int): + return [freq] * n + if isinstance(freq, torch.Tensor): + freq = freq.tolist() + # freq is list of length batch; expand for each channel + return [int(f) for f in freq for _ in range(channels)] + + class Model: - """The CausalLM model loaded from HuggingFace.""" + """The HuggingFace model factory supporting various model types.""" @staticmethod - def get(model_name=None, **kwargs): # pylint: disable=unused-argument - """Returns a named model from HuggingFace.""" - config_kwargs = { - "cache_dir": None, - "revision": "main", - "use_auth_token": None, + def _get_timeseries_task_type(model_task=None): + """Determine the task type for time series models from config or arguments.""" + trainer_config = Config().trainer + return ( + model_task + or getattr(trainer_config, "model_task", None) + or getattr(trainer_config, "task_type", "forecasting") + ) + + # PatchTSMixer + + @staticmethod + def _get_patchtsmixer_model(resolved_model_name, cache_dir, model_task=None): + """Load or create a PatchTSMixer model.""" + if PatchTSMixerForPrediction is None: + raise ImportError( + "PatchTSMixer models are not available. " + "Ensure you have transformers>=4.35.0 installed." + ) + + task_type = Model._get_timeseries_task_type(model_task) + + task_models = { + "classification": PatchTSMixerForTimeSeriesClassification, + "regression": PatchTSMixerForRegression, + "pretraining": PatchTSMixerForPretraining, + "forecasting": PatchTSMixerForPrediction, } + model_class = task_models.get(task_type, PatchTSMixerForPrediction) + + try: + logging.info( + "Attempting to load pretrained PatchTSMixer model: %s", + resolved_model_name, + ) + model = model_class.from_pretrained( + resolved_model_name, cache_dir=cache_dir + ) + logging.info("Successfully loaded pretrained model") + except (OSError, ValueError, Exception): + logging.info( + "Model '%s' not found as pretrained, creating from config settings", + resolved_model_name, + ) + trainer_config = Config().trainer + + scaling_param = getattr(trainer_config, "scaling", "std") + if isinstance(scaling_param, str) and scaling_param.lower() == "none": + scaling_param = None + + config = PatchTSMixerConfig( + context_length=getattr(trainer_config, "context_length", 512), + prediction_length=getattr(trainer_config, "prediction_length", 96), + num_input_channels=getattr(trainer_config, "num_input_channels", 7), + patch_length=getattr(trainer_config, "patch_length", 8), + patch_stride=getattr(trainer_config, "patch_stride", 8), + d_model=getattr(trainer_config, "d_model", 64), + num_layers=getattr(trainer_config, "num_layers", 8), + expansion_factor=getattr(trainer_config, "expansion_factor", 2), + dropout=getattr(trainer_config, "dropout", 0.2), + head_dropout=getattr(trainer_config, "head_dropout", 0.2), + mode=getattr(trainer_config, "mode", "common_channel"), + gated_attn=getattr(trainer_config, "gated_attn", True), + scaling=scaling_param, + prediction_channel_indices=getattr( + trainer_config, "prediction_channel_indices", None + ), + ) + + if task_type == "classification": + config.num_labels = getattr(trainer_config, "num_classes", 2) + model = PatchTSMixerForTimeSeriesClassification(config) + elif task_type == "regression": + config.num_targets = getattr(trainer_config, "num_targets", 1) + model = PatchTSMixerForRegression(config) + elif task_type == "pretraining": + model = PatchTSMixerForPretraining(config) + else: + model = PatchTSMixerForPrediction(config) + + return model + + # TimesFM + + @staticmethod + def _get_timesfm_model(resolved_model_name, cache_dir): + """Load or create a TimesFM model wrapped for batched multivariate use.""" + if TimesFmModelForPrediction is None: + raise ImportError( + "TimesFM models are not available. " + "Ensure you have transformers>=5.0.0 installed." + ) + + trainer_config = Config().trainer + prediction_length = getattr(trainer_config, "prediction_length", 128) + default_freq = getattr(trainer_config, "freq", 0) + + try: + logging.info( + "Attempting to load pretrained TimesFM model: %s", + resolved_model_name, + ) + inner = TimesFmModelForPrediction.from_pretrained( + resolved_model_name, cache_dir=cache_dir + ) + logging.info("Successfully loaded pretrained TimesFM model") + except (OSError, ValueError, Exception): + logging.info( + "TimesFM model '%s' not found as pretrained, creating from config", + resolved_model_name, + ) + context_length = getattr(trainer_config, "context_length", 512) + horizon_length = prediction_length + + config = TimesFmConfig( + context_length=context_length, + horizon_length=horizon_length, + patch_length=getattr(trainer_config, "patch_length", 32), + num_hidden_layers=getattr(trainer_config, "num_hidden_layers", 20), + hidden_size=getattr(trainer_config, "hidden_size", 1280), + intermediate_size=getattr(trainer_config, "intermediate_size", 1280), + num_attention_heads=getattr(trainer_config, "num_attention_heads", 16), + head_dim=getattr(trainer_config, "head_dim", 80), + attention_dropout=getattr(trainer_config, "dropout", 0.0), + ) + inner = TimesFmModelForPrediction(config) + + return TimesFmMultivariateWrapper( + model=inner, + prediction_length=prediction_length, + default_freq=default_freq, + ) + + # Main factory entry point + @staticmethod + def get(model_name=None, **kwargs): # pylint: disable=unused-argument + """Returns a named model from HuggingFace.""" resolved_model_name = ( model_name if isinstance(model_name, str) and model_name @@ -55,6 +325,35 @@ def get(model_name=None, **kwargs): # pylint: disable=unused-argument if not isinstance(resolved_model_name, str) or not resolved_model_name: raise ValueError("A valid HuggingFace model name must be provided.") + cache_dir = Config().params["model_path"] + "/huggingface" + + model_type = kwargs.get("model_type") or getattr( + getattr(Config(), "trainer", None), "model_type", None + ) + + is_timeseries = is_timeseries_model( + model_name=resolved_model_name, model_type=model_type + ) + + if is_timeseries: + model_type_lower = (model_type or "").lower() + model_name_lower = resolved_model_name.lower() + + if model_type_lower == "timesfm" or "timesfm" in model_name_lower: + return Model._get_timesfm_model(resolved_model_name, cache_dir) + + # Default time-series path -> PatchTSMixer + model_task = kwargs.get("model_task") + return Model._get_patchtsmixer_model( + resolved_model_name, cache_dir, model_task + ) + + # NLP / CausalLM path + config_kwargs = { + "cache_dir": None, + "revision": "main", + "use_auth_token": None, + } config = AutoConfig.from_pretrained(resolved_model_name, **config_kwargs) model = AutoModelForCausalLM.from_pretrained( @@ -70,7 +369,6 @@ def get(model_name=None, **kwargs): # pylint: disable=unused-argument "The 'peft' package is required for LoRA fine-tuning. " "Install it by running `uv add peft`." ) - params_dict = _lora_config_dict(lora_params) logging.info("Configuring LoRA with parameters: %s", params_dict) lora_cfg = LoraConfig(**params_dict) diff --git a/plato/models/registry.py b/plato/models/registry.py index 1dbda56cf..b66c0edbb 100644 --- a/plato/models/registry.py +++ b/plato/models/registry.py @@ -44,6 +44,8 @@ "vit": vit.Model, "nanochat": nanochat.Model, "smolvla": smolvla.Model, + "timesfm": huggingface.Model, + "patchtsmixer": huggingface.Model, } registered_mlx_models = {} diff --git a/plato/utils/timeseries_utils.py b/plato/utils/timeseries_utils.py new file mode 100644 index 000000000..36449d713 --- /dev/null +++ b/plato/utils/timeseries_utils.py @@ -0,0 +1,43 @@ +""" +Utility functions for time series model detection and handling. +""" + +from typing import Optional + + +def is_timeseries_model( + model_name: Optional[str] = None, + model_type: Optional[str] = None, + dataset_type: Optional[str] = None, +) -> bool: + """ + Check if a model/dataset is for time series. + + Args: + model_name: Name of the model + model_type: Type of model from config + dataset_type: Type of dataset from config + + Returns: + True if this is a time series model, False otherwise + """ + model_name_lower = model_name.lower() if model_name else "" + model_type_lower = model_type.lower() if model_type else "" + + # Check for PatchTSMixer + if ( + model_type_lower == "patchtsmixer" + or "patchtsmixer" in model_name_lower + or "timeseries" in model_name_lower + ): + return True + + # Check for TimesFM + if model_type_lower == "timesfm" or "timesfm" in model_name_lower: + return True + + # Check dataset type + if dataset_type and dataset_type.lower() == "timeseries": + return True + + return False From 6fb4af7b16c3bb73f3627ed5c063efab6603e9ea Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 23 Mar 2026 18:19:51 -0400 Subject: [PATCH 07/39] Refactored timeseries models from HuggingFace. --- plato/datasources/ev_charging.py | 4 +- plato/models/huggingface.py | 313 +++++++++++++++++-------------- plato/utils/timeseries_utils.py | 27 +-- 3 files changed, 189 insertions(+), 155 deletions(-) diff --git a/plato/datasources/ev_charging.py b/plato/datasources/ev_charging.py index 905303eb2..d5566c872 100644 --- a/plato/datasources/ev_charging.py +++ b/plato/datasources/ev_charging.py @@ -132,7 +132,7 @@ def _build_hourly_series( users = list(user_ids) # preserve config order else: users = available[:num_users] - logging.info("EVCharging: garage '%s' → users %s", garage, users) + logging.info("EVCharging: garage '%s' -> users %s", garage, users) result: dict[str, pd.DataFrame] = {} for user in users: @@ -335,7 +335,7 @@ def __init__(self, client_id: int = 0, **kwargs): ) user_key = users[user_index] - logging.info("EVCharging: client_id=%d → user '%s'", client_id, user_key) + logging.info("EVCharging: client_id=%d -> user '%s'", client_id, user_key) user_df = _add_time_features(user_series[user_key]) raw_array = user_df[ diff --git a/plato/models/huggingface.py b/plato/models/huggingface.py index 9090b1de9..4efa032e7 100644 --- a/plato/models/huggingface.py +++ b/plato/models/huggingface.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging -from typing import Any, Dict +from typing import Any, Callable, Dict import torch import torch.nn as nn @@ -175,148 +175,187 @@ def _build_freq_list( return [int(f) for f in freq for _ in range(channels)] -class Model: - """The HuggingFace model factory supporting various model types.""" +# --------------------------------------------------------------------------- +# Time-series model loaders +# +# To add a new HuggingFace time series model: +# 1. Implement a loader function with signature: +# def _load_(resolved_model_name, cache_dir, **kwargs) -> nn.Module +# 2. Register it below in _TIMESERIES_LOADERS. +# 3. Add the model type string to TIMESERIES_MODEL_TYPES in +# plato/utils/timeseries_utils.py. +# --------------------------------------------------------------------------- + + +def _load_timesfm(resolved_model_name: str, cache_dir: str, **kwargs) -> nn.Module: + """Load or create a TimesFM model wrapped for batched multivariate use.""" + if TimesFmModelForPrediction is None: + raise ImportError( + "TimesFM models are not available. " + "Ensure you have transformers>=5.0.0 installed." + ) - @staticmethod - def _get_timeseries_task_type(model_task=None): - """Determine the task type for time series models from config or arguments.""" - trainer_config = Config().trainer - return ( - model_task - or getattr(trainer_config, "model_task", None) - or getattr(trainer_config, "task_type", "forecasting") + trainer_config = Config().trainer + prediction_length = getattr(trainer_config, "prediction_length", 128) + default_freq = getattr(trainer_config, "freq", 0) + + try: + logging.info( + "Attempting to load pretrained TimesFM model: %s", + resolved_model_name, ) + inner = TimesFmModelForPrediction.from_pretrained( + resolved_model_name, cache_dir=cache_dir + ) + logging.info("Successfully loaded pretrained TimesFM model") + except (OSError, ValueError, Exception): + logging.info( + "TimesFM model '%s' not found as pretrained, creating from config", + resolved_model_name, + ) + context_length = getattr(trainer_config, "context_length", 512) + horizon_length = prediction_length + + config = TimesFmConfig( + context_length=context_length, + horizon_length=horizon_length, + patch_length=getattr(trainer_config, "patch_length", 32), + num_hidden_layers=getattr(trainer_config, "num_hidden_layers", 20), + hidden_size=getattr(trainer_config, "hidden_size", 1280), + intermediate_size=getattr(trainer_config, "intermediate_size", 1280), + num_attention_heads=getattr(trainer_config, "num_attention_heads", 16), + head_dim=getattr(trainer_config, "head_dim", 80), + attention_dropout=getattr(trainer_config, "dropout", 0.0), + ) + inner = TimesFmModelForPrediction(config) - # PatchTSMixer + return TimesFmMultivariateWrapper( + model=inner, + prediction_length=prediction_length, + default_freq=default_freq, + ) - @staticmethod - def _get_patchtsmixer_model(resolved_model_name, cache_dir, model_task=None): - """Load or create a PatchTSMixer model.""" - if PatchTSMixerForPrediction is None: - raise ImportError( - "PatchTSMixer models are not available. " - "Ensure you have transformers>=4.35.0 installed." - ) - task_type = Model._get_timeseries_task_type(model_task) +def _load_patchtsmixer(resolved_model_name: str, cache_dir: str, **kwargs) -> nn.Module: + """Load or create a PatchTSMixer model.""" + if PatchTSMixerForPrediction is None: + raise ImportError( + "PatchTSMixer models are not available. " + "Ensure you have transformers>=4.35.0 installed." + ) - task_models = { - "classification": PatchTSMixerForTimeSeriesClassification, - "regression": PatchTSMixerForRegression, - "pretraining": PatchTSMixerForPretraining, - "forecasting": PatchTSMixerForPrediction, - } - model_class = task_models.get(task_type, PatchTSMixerForPrediction) + trainer_config = Config().trainer + model_task = ( + kwargs.get("model_task") + or getattr(trainer_config, "model_task", None) + or getattr(trainer_config, "task_type", "forecasting") + ) - try: - logging.info( - "Attempting to load pretrained PatchTSMixer model: %s", - resolved_model_name, - ) - model = model_class.from_pretrained( - resolved_model_name, cache_dir=cache_dir - ) - logging.info("Successfully loaded pretrained model") - except (OSError, ValueError, Exception): - logging.info( - "Model '%s' not found as pretrained, creating from config settings", - resolved_model_name, - ) - trainer_config = Config().trainer - - scaling_param = getattr(trainer_config, "scaling", "std") - if isinstance(scaling_param, str) and scaling_param.lower() == "none": - scaling_param = None - - config = PatchTSMixerConfig( - context_length=getattr(trainer_config, "context_length", 512), - prediction_length=getattr(trainer_config, "prediction_length", 96), - num_input_channels=getattr(trainer_config, "num_input_channels", 7), - patch_length=getattr(trainer_config, "patch_length", 8), - patch_stride=getattr(trainer_config, "patch_stride", 8), - d_model=getattr(trainer_config, "d_model", 64), - num_layers=getattr(trainer_config, "num_layers", 8), - expansion_factor=getattr(trainer_config, "expansion_factor", 2), - dropout=getattr(trainer_config, "dropout", 0.2), - head_dropout=getattr(trainer_config, "head_dropout", 0.2), - mode=getattr(trainer_config, "mode", "common_channel"), - gated_attn=getattr(trainer_config, "gated_attn", True), - scaling=scaling_param, - prediction_channel_indices=getattr( - trainer_config, "prediction_channel_indices", None - ), - ) + task_models = { + "classification": PatchTSMixerForTimeSeriesClassification, + "regression": PatchTSMixerForRegression, + "pretraining": PatchTSMixerForPretraining, + "forecasting": PatchTSMixerForPrediction, + } + model_class = task_models.get(model_task, PatchTSMixerForPrediction) + + try: + logging.info( + "Attempting to load pretrained PatchTSMixer model: %s", + resolved_model_name, + ) + model = model_class.from_pretrained(resolved_model_name, cache_dir=cache_dir) + logging.info("Successfully loaded pretrained model") + except (OSError, ValueError, Exception): + logging.info( + "Model '%s' not found as pretrained, creating from config settings", + resolved_model_name, + ) + scaling_param = getattr(trainer_config, "scaling", "std") + if isinstance(scaling_param, str) and scaling_param.lower() == "none": + scaling_param = None + + config = PatchTSMixerConfig( + context_length=getattr(trainer_config, "context_length", 512), + prediction_length=getattr(trainer_config, "prediction_length", 96), + num_input_channels=getattr(trainer_config, "num_input_channels", 7), + patch_length=getattr(trainer_config, "patch_length", 8), + patch_stride=getattr(trainer_config, "patch_stride", 8), + d_model=getattr(trainer_config, "d_model", 64), + num_layers=getattr(trainer_config, "num_layers", 8), + expansion_factor=getattr(trainer_config, "expansion_factor", 2), + dropout=getattr(trainer_config, "dropout", 0.2), + head_dropout=getattr(trainer_config, "head_dropout", 0.2), + mode=getattr(trainer_config, "mode", "common_channel"), + gated_attn=getattr(trainer_config, "gated_attn", True), + scaling=scaling_param, + prediction_channel_indices=getattr( + trainer_config, "prediction_channel_indices", None + ), + ) - if task_type == "classification": - config.num_labels = getattr(trainer_config, "num_classes", 2) - model = PatchTSMixerForTimeSeriesClassification(config) - elif task_type == "regression": - config.num_targets = getattr(trainer_config, "num_targets", 1) - model = PatchTSMixerForRegression(config) - elif task_type == "pretraining": - model = PatchTSMixerForPretraining(config) - else: - model = PatchTSMixerForPrediction(config) + if model_task == "classification": + config.num_labels = getattr(trainer_config, "num_classes", 2) + model = PatchTSMixerForTimeSeriesClassification(config) + elif model_task == "regression": + config.num_targets = getattr(trainer_config, "num_targets", 1) + model = PatchTSMixerForRegression(config) + elif model_task == "pretraining": + model = PatchTSMixerForPretraining(config) + else: + model = PatchTSMixerForPrediction(config) - return model + return model - # TimesFM - @staticmethod - def _get_timesfm_model(resolved_model_name, cache_dir): - """Load or create a TimesFM model wrapped for batched multivariate use.""" - if TimesFmModelForPrediction is None: - raise ImportError( - "TimesFM models are not available. " - "Ensure you have transformers>=5.0.0 installed." - ) +# Registry mapping model_type (lowercase) -> loader function. +# This is the only place that needs updating when a new HF time series model +# is added (along with TIMESERIES_MODEL_TYPES in timeseries_utils.py). +_TIMESERIES_LOADERS: Dict[str, Callable[..., nn.Module]] = { + "timesfm": _load_timesfm, + "patchtsmixer": _load_patchtsmixer, +} - trainer_config = Config().trainer - prediction_length = getattr(trainer_config, "prediction_length", 128) - default_freq = getattr(trainer_config, "freq", 0) - try: - logging.info( - "Attempting to load pretrained TimesFM model: %s", - resolved_model_name, - ) - inner = TimesFmModelForPrediction.from_pretrained( - resolved_model_name, cache_dir=cache_dir - ) - logging.info("Successfully loaded pretrained TimesFM model") - except (OSError, ValueError, Exception): - logging.info( - "TimesFM model '%s' not found as pretrained, creating from config", - resolved_model_name, - ) - context_length = getattr(trainer_config, "context_length", 512) - horizon_length = prediction_length - - config = TimesFmConfig( - context_length=context_length, - horizon_length=horizon_length, - patch_length=getattr(trainer_config, "patch_length", 32), - num_hidden_layers=getattr(trainer_config, "num_hidden_layers", 20), - hidden_size=getattr(trainer_config, "hidden_size", 1280), - intermediate_size=getattr(trainer_config, "intermediate_size", 1280), - num_attention_heads=getattr(trainer_config, "num_attention_heads", 16), - head_dim=getattr(trainer_config, "head_dim", 80), - attention_dropout=getattr(trainer_config, "dropout", 0.0), - ) - inner = TimesFmModelForPrediction(config) +class Model: + """The HuggingFace model factory supporting various model types.""" - return TimesFmMultivariateWrapper( - model=inner, - prediction_length=prediction_length, - default_freq=default_freq, - ) + @staticmethod + def _get_timeseries_model( + resolved_model_name: str, cache_dir: str, model_type: str = "", **kwargs + ) -> nn.Module: + """Unified entry point for all HuggingFace time series models. + + Dispatches to the appropriate loader in ``_TIMESERIES_LOADERS`` based + on ``model_type`` or a substring match in ``resolved_model_name``. + """ + model_type_lower = model_type.lower() + model_name_lower = resolved_model_name.lower() + + loader = None + for ts_type, ts_loader in _TIMESERIES_LOADERS.items(): + if model_type_lower == ts_type or ts_type in model_name_lower: + loader = ts_loader + break + + if loader is None: + raise ValueError( + f"No time series loader found for model '{resolved_model_name}' " + f"(type='{model_type}'). " + "Register a loader in _TIMESERIES_LOADERS in plato/models/huggingface.py " + "and add the type to TIMESERIES_MODEL_TYPES in plato/utils/timeseries_utils.py." + ) - # Main factory entry point + return loader(resolved_model_name, cache_dir, **kwargs) @staticmethod def get(model_name=None, **kwargs): # pylint: disable=unused-argument - """Returns a named model from HuggingFace.""" + """Returns a named model from HuggingFace. + + Two paths: + - Time series models -> ``_get_timeseries_model()`` + - All other models -> ``AutoModelForCausalLM`` (with optional LoRA) + """ resolved_model_name = ( model_name if isinstance(model_name, str) and model_name @@ -327,25 +366,15 @@ def get(model_name=None, **kwargs): # pylint: disable=unused-argument cache_dir = Config().params["model_path"] + "/huggingface" - model_type = kwargs.get("model_type") or getattr( - getattr(Config(), "trainer", None), "model_type", None - ) - - is_timeseries = is_timeseries_model( - model_name=resolved_model_name, model_type=model_type + model_type = ( + kwargs.get("model_type") + or getattr(getattr(Config(), "trainer", None), "model_type", None) + or "" ) - if is_timeseries: - model_type_lower = (model_type or "").lower() - model_name_lower = resolved_model_name.lower() - - if model_type_lower == "timesfm" or "timesfm" in model_name_lower: - return Model._get_timesfm_model(resolved_model_name, cache_dir) - - # Default time-series path -> PatchTSMixer - model_task = kwargs.get("model_task") - return Model._get_patchtsmixer_model( - resolved_model_name, cache_dir, model_task + if is_timeseries_model(model_name=resolved_model_name, model_type=model_type): + return Model._get_timeseries_model( + resolved_model_name, cache_dir, model_type=model_type, **kwargs ) # NLP / CausalLM path diff --git a/plato/utils/timeseries_utils.py b/plato/utils/timeseries_utils.py index 36449d713..0593695ff 100644 --- a/plato/utils/timeseries_utils.py +++ b/plato/utils/timeseries_utils.py @@ -4,6 +4,11 @@ from typing import Optional +# Single source of truth: all known HuggingFace time series model types. +# When adding a new time series model, register it here AND add a loader to +# plato/models/huggingface.py (_TIMESERIES_LOADERS). +TIMESERIES_MODEL_TYPES: frozenset[str] = frozenset({"timesfm", "patchtsmixer"}) + def is_timeseries_model( model_name: Optional[str] = None, @@ -21,19 +26,19 @@ def is_timeseries_model( Returns: True if this is a time series model, False otherwise """ - model_name_lower = model_name.lower() if model_name else "" - model_type_lower = model_type.lower() if model_type else "" - - # Check for PatchTSMixer - if ( - model_type_lower == "patchtsmixer" - or "patchtsmixer" in model_name_lower - or "timeseries" in model_name_lower - ): + model_name_lower = (model_name or "").lower() + model_type_lower = (model_type or "").lower() + + # Check explicit model type + if model_type_lower in TIMESERIES_MODEL_TYPES: + return True + + # Check if any known time series type appears in the model name + if any(ts_type in model_name_lower for ts_type in TIMESERIES_MODEL_TYPES): return True - # Check for TimesFM - if model_type_lower == "timesfm" or "timesfm" in model_name_lower: + # Generic "timeseries" keyword in name + if "timeseries" in model_name_lower: return True # Check dataset type From ed2e0260defaa34007486dd65fffbc2c61fb2831 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 23 Mar 2026 18:21:30 -0400 Subject: [PATCH 08/39] Updated config files for time-series models. - PatchTSMixer, TimesFM and TimesFM2.5. --- .../TimeSeries/patchtsmixer_ev_charging.toml | 93 +++++++++++++++++++ configs/TimeSeries/timesfm25_ev_charging.toml | 89 ++++++++++++++++++ configs/TimeSeries/timesfm_ev_charging.toml | 3 +- 3 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 configs/TimeSeries/patchtsmixer_ev_charging.toml create mode 100644 configs/TimeSeries/timesfm25_ev_charging.toml diff --git a/configs/TimeSeries/patchtsmixer_ev_charging.toml b/configs/TimeSeries/patchtsmixer_ev_charging.toml new file mode 100644 index 000000000..fc3e5be7b --- /dev/null +++ b/configs/TimeSeries/patchtsmixer_ev_charging.toml @@ -0,0 +1,93 @@ +# Federated Learning with PatchTSMixer for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 168 hours. +# +# Dataset: "EV Charging Reports" – AdO1 garage, 4 users +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 4 clients, one user each. All clients participate every round. +# +# Model: PatchTSMixer (trained from scratch) +# - uses all 6 input features jointly via mix_channel mode +# - predicts only the is_charging channel +# +# Usage: +# uv run plato.py -c configs/TimeSeries/patchtsmixer_ev_charging.toml + +[clients] +type = "simple" +total_clients = 4 +per_round = 4 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/patchtsmixer_ev" +model_path = "models/timeseries/patchtsmixer_ev" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +garage = "AdO1" # garage id + +# Explicit user IDs to include — one client per user. +users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 4 +model_name = "patchtsmixer_scratch" +model_type = "patchtsmixer" +model_task = "forecasting" + +context_length = 672 # 4 × 7 × 24 +prediction_length = 168 # 7 × 24 + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Predict and evaluate only the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +patch_stride = 8 +d_model = 64 +num_layers = 4 +expansion_factor = 2 +dropout = 0.1 +head_dropout = 0.1 + +# Mix all channels so the model can use time features jointly. +mode = "mix_channel" +gated_attn = true +scaling = "std" + +# Sliding-window stride for dataset creation +stride = 1 # advance 1 hour at a time to maximize training windows + +epochs = 10 +batch_size = 16 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/configs/TimeSeries/timesfm25_ev_charging.toml b/configs/TimeSeries/timesfm25_ev_charging.toml new file mode 100644 index 000000000..20ae2b57e --- /dev/null +++ b/configs/TimeSeries/timesfm25_ev_charging.toml @@ -0,0 +1,89 @@ +# Federated Learning with TimesFM2.5 for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 168 hours. +# +# Dataset: "EV Charging Reports" – AdO1 garage, 4 users +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 4 clients, one user each. All clients participate every round. +# +# Model: TimesFM (custom, trained from scratch for the small data regime) +# – channel-independent: each of the 6 input features is processed +# as a separate univariate series; only is_charging is evaluated. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm_ev_charging.toml + +[clients] +type = "simple" +total_clients = 4 +per_round = 4 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm_ev" +model_path = "models/timeseries/timesfm_ev" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +garage = "AdO1" # garage id + +# Explicit user IDs to include — one client per user. +users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 4 +model_name = "google/timesfm-2.5-200m-pytorch" +model_type = "timesfm" + +context_length = 672 # 4 × 7 × 24 +prediction_length = 168 # 7 × 24 + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +num_hidden_layers = 4 +hidden_size = 256 +intermediate_size = 256 +num_attention_heads = 4 +head_dim = 64 +dropout = 0.1 + +freq = 0 + +# Sliding-window stride for dataset creation +stride = 1 # advance 1 hour at a time to maximizes training windows + +epochs = 10 +batch_size = 16 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/configs/TimeSeries/timesfm_ev_charging.toml b/configs/TimeSeries/timesfm_ev_charging.toml index 426e0d1fa..7ced5a721 100644 --- a/configs/TimeSeries/timesfm_ev_charging.toml +++ b/configs/TimeSeries/timesfm_ev_charging.toml @@ -44,7 +44,8 @@ random_seed = 42 type = "HuggingFace" rounds = 100 max_concurrency = 4 -model_name = "timesfm" +model_name = "google/timesfm-2.0-500m-pytorch" +model_type = "timesfm" context_length = 672 # 4 × 7 × 24 prediction_length = 168 # 7 × 24 From 2ed6e0f2812eccef5b6e9bc8e5b347ac5841bf15 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Sun, 29 Mar 2026 19:57:11 -0400 Subject: [PATCH 09/39] Added two new config files for testing data influence, may clean up later. --- .../timesfm25_ev_charging_bl2_5_only.toml | 89 ++++++++++++++++++ .../timesfm25_ev_charging_top4_mixed.toml | 90 +++++++++++++++++++ plato/datasources/ev_charging.py | 46 +++++++--- 3 files changed, 211 insertions(+), 14 deletions(-) create mode 100644 configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml create mode 100644 configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml diff --git a/configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml b/configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml new file mode 100644 index 000000000..1a76cb09b --- /dev/null +++ b/configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml @@ -0,0 +1,89 @@ +# Federated Learning with TimesFM2.5 for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 168 hours. +# +# Dataset: "EV Charging Reports" – single-user run for Bl2-5 +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 1 client, using only user Bl2-5. +# +# Model: TimesFM (custom, trained from scratch for the small data regime) +# – channel-independent: each of the 6 input features is processed +# as a separate univariate series; only is_charging is evaluated. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml + +[clients] +type = "simple" +total_clients = 1 +per_round = 1 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm25_ev_bl2_5_only" +model_path = "models/timeseries/timesfm25_ev_bl2_5_only" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +garage = "Bl2" + +# Explicit user IDs to include — one client per user. +users = ["Bl2-5"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 1 +model_name = "google/timesfm-2.5-200m-pytorch" +model_type = "timesfm" + +context_length = 672 +prediction_length = 168 + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +num_hidden_layers = 4 +hidden_size = 256 +intermediate_size = 256 +num_attention_heads = 4 +head_dim = 64 +dropout = 0.1 + +freq = 0 + +# Sliding-window stride for dataset creation +stride = 1 + +epochs = 10 +batch_size = 16 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml b/configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml new file mode 100644 index 000000000..33b72ac62 --- /dev/null +++ b/configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml @@ -0,0 +1,90 @@ +# Federated Learning with TimesFM2.5 for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 168 hours. +# +# Dataset: "EV Charging Reports" – mixed high-data users across garages +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 4 clients, one user each. All clients participate every round. +# +# Model: TimesFM (custom, trained from scratch for the small data regime) +# – channel-independent: each of the 6 input features is processed +# as a separate univariate series; only is_charging is evaluated. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml + +[clients] +type = "simple" +total_clients = 4 +per_round = 4 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm25_ev_top4_mixed" +model_path = "models/timeseries/timesfm25_ev_top4_mixed" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +# Use explicit users across the whole dataset, not just a single garage. +garage = "all" + +# Explicit user IDs to include — one client per user. +users = ["Bl2-5", "AsO2-1", "Bl2-1", "AdO1-3"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 4 +model_name = "google/timesfm-2.5-200m-pytorch" +model_type = "timesfm" + +context_length = 672 +prediction_length = 168 + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +num_hidden_layers = 4 +hidden_size = 256 +intermediate_size = 256 +num_attention_heads = 4 +head_dim = 64 +dropout = 0.1 + +freq = 0 + +# Sliding-window stride for dataset creation +stride = 1 + +epochs = 10 +batch_size = 16 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/plato/datasources/ev_charging.py b/plato/datasources/ev_charging.py index d5566c872..376f92519 100644 --- a/plato/datasources/ev_charging.py +++ b/plato/datasources/ev_charging.py @@ -14,7 +14,8 @@ Preprocessing pipeline ----------------------- -1. Filter to the requested garage (default "AdO1", which has 4 private users). +1. Filter to the requested garage (default "AdO1", which has 4 private users), + or use all garages when ``garage = "all"``. 2. For each user, build a continuous hourly grid from the first to the last session hour in the dataset. 3. For every hour, mark is_charging = 1 if the user had an active session, @@ -37,7 +38,7 @@ [data] datasource = "EVCharging" datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" -garage = "AdO1" # optional +garage = "AdO1" # optional; use "all" for cross-garage user lists num_users = 4 # optional [trainer] @@ -85,7 +86,7 @@ def _parse_european_float(series: pd.Series) -> pd.Series: def _build_hourly_series( df: pd.DataFrame, - garage: str, + garage: str | None, num_users: int, user_ids: list[str] | None = None, ) -> dict[str, pd.DataFrame]: @@ -101,14 +102,22 @@ def _build_hourly_series( dict mapping user_id (str) -> pd.DataFrame with hourly index and columns: is_charging (0/1 float), energy_kwh (float >= 0) """ - # Filter to requested garage - mask = df[_GARAGE_COL].astype(str).str.strip() == garage - df = df[mask].copy() - if df.empty: - raise ValueError( - f"No records found for garage '{garage}'. " - f"Available: {sorted(df[_GARAGE_COL].unique())}" + garage_name = None if garage is None else str(garage).strip() + use_all_garages = not garage_name or garage_name.lower() in {"all", "*", "any"} + + if use_all_garages: + df = df.copy() + else: + available_garages = sorted( + df[_GARAGE_COL].astype(str).str.strip().dropna().unique() ) + mask = df[_GARAGE_COL].astype(str).str.strip() == garage_name + df = df[mask].copy() + if df.empty: + raise ValueError( + f"No records found for garage '{garage_name}'. " + f"Available: {available_garages}" + ) # Parse datetimes df[_START_COL] = pd.to_datetime(df[_START_COL].str.strip(), format=_DT_FORMAT) @@ -125,14 +134,16 @@ def _build_hourly_series( # Explicit list from config — validate each entry missing = [u for u in user_ids if u not in available] if missing: + scope = "all garages" if use_all_garages else f"garage '{garage_name}'" raise ValueError( - f"Users not found in garage '{garage}': {missing}. " + f"Users not found in {scope}: {missing}. " f"Available: {available}" ) users = list(user_ids) # preserve config order else: users = available[:num_users] - logging.info("EVCharging: garage '%s' -> users %s", garage, users) + scope = "all garages" if use_all_garages else f"garage '{garage_name}'" + logging.info("EVCharging: %s -> users %s", scope, users) result: dict[str, pd.DataFrame] = {} for user in users: @@ -294,7 +305,9 @@ def __init__(self, client_id: int = 0, **kwargs): "Download from https://data.mendeley.com/datasets/jbks2rcwyj/1" ) - garage = str(kwargs.get("garage", getattr(data_cfg, "garage", "AdO1"))) + garage_cfg = kwargs.get("garage", getattr(data_cfg, "garage", "AdO1")) + garage = None if garage_cfg is None else str(garage_cfg) + garage_name = None if garage is None else garage.strip() # Config: users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] user_ids_cfg = kwargs.get("users", getattr(data_cfg, "users", None)) @@ -329,9 +342,14 @@ def __init__(self, client_id: int = 0, **kwargs): user_index = max(0, client_id - 1) if user_index >= len(users): + scope = ( + "all garages" + if not garage_name or garage_name.lower() in {"all", "*", "any"} + else f"garage '{garage_name}'" + ) raise ValueError( f"client_id={client_id} out of range; " - f"found {len(users)} users in garage '{garage}': {users}" + f"found {len(users)} users in {scope}: {users}" ) user_key = users[user_index] From c757fd4c9857bee2112eb26a571fb00756ebf774 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Sun, 29 Mar 2026 21:06:38 -0400 Subject: [PATCH 10/39] Updated prediction length for pretrained model. --- configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml b/configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml index 1a76cb09b..56895618e 100644 --- a/configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml +++ b/configs/TimeSeries/timesfm25_ev_charging_bl2_5_only.toml @@ -1,7 +1,7 @@ # Federated Learning with TimesFM2.5 for EV Charging Prediction # # Task: Given the past 28 days (672 h) of a user's EV charging behaviour, -# predict whether they will be charging in each of the next 168 hours. +# predict whether they will be charging in each of the next 128 hours. # # Dataset: "EV Charging Reports" – single-user run for Bl2-5 # https://data.mendeley.com/datasets/jbks2rcwyj/1 @@ -48,7 +48,7 @@ model_name = "google/timesfm-2.5-200m-pytorch" model_type = "timesfm" context_length = 672 -prediction_length = 168 +prediction_length = 128 # Number of input channels: is_charging, energy_scaled, # hour_sin, hour_cos, dow_sin, dow_cos From c2a3c13c33aa7bed07f0bcf3d9a519452f42d96d Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 30 Mar 2026 15:43:24 -0400 Subject: [PATCH 11/39] Added TimesFM transformer model from HuggingFace. --- .../TimeSeries/timesfm_transformers_bl1.toml | 77 +++++++++++++ plato/models/huggingface.py | 105 ++++++++++++------ pyproject.toml | 4 +- 3 files changed, 150 insertions(+), 36 deletions(-) create mode 100644 configs/TimeSeries/timesfm_transformers_bl1.toml diff --git a/configs/TimeSeries/timesfm_transformers_bl1.toml b/configs/TimeSeries/timesfm_transformers_bl1.toml new file mode 100644 index 000000000..85d360dab --- /dev/null +++ b/configs/TimeSeries/timesfm_transformers_bl1.toml @@ -0,0 +1,77 @@ +# Federated Learning with TimesFM 2.5 (HuggingFace transformers) for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 168 hours. +# +# Dataset: "EV Charging Reports" – AdO1 garage, 4 users +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Model: google/timesfm-2.5-200m-transformers +# Uses Timesfm2P5ModelForPrediction from the transformers library. +# Channel-independent: each of the 6 input features is processed as a +# separate univariate series; only is_charging is evaluated. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm25_transformers_ev_charging.toml + +[clients] +type = "simple" +total_clients = 1 +per_round = 1 +do_test = true + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm25t_ev" +model_path = "models/timeseries/timesfm25t_ev" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +garage = "Bl2" + +# Explicit user IDs to include — one client per user. +users = ["Bl2-1"] +sampler = "all_inclusive" + +[trainer] +type = "HuggingFace" +rounds = 100 +max_concurrency = 1 +model_name = "google/timesfm-2.5-200m-transformers" +model_type = "timesfm" + +context_length = 672 # 4 × 7 × 24 (model supports up to 16384) +prediction_length = 128 # model horizon_length is fixed at 128 steps + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +# Sliding-window stride for dataset creation +stride = 1 # advance 1 hour at a time to maximise training windows + +epochs = 5 +batch_size = 32 +optimizer = "Adam" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" diff --git a/plato/models/huggingface.py b/plato/models/huggingface.py index 4efa032e7..514260314 100644 --- a/plato/models/huggingface.py +++ b/plato/models/huggingface.py @@ -36,6 +36,11 @@ TimesFmConfig = None TimesFmModelForPrediction = None +try: + from transformers import TimesFm2_5ModelForPrediction +except ImportError: + TimesFm2_5ModelForPrediction = None + try: from peft import LoraConfig, get_peft_model except ImportError: # pragma: no cover - handled at runtime with friendly message. @@ -99,11 +104,13 @@ def __init__( model: "TimesFmModelForPrediction", prediction_length: int | None = None, default_freq: int = 0, + use_transformers_api: bool = False, ): super().__init__() self.model = model self.prediction_length = prediction_length self.default_freq = default_freq + self.use_transformers_api = use_transformers_api def forward( self, @@ -122,9 +129,12 @@ def forward( # (batch, ctx, ch) -> (batch*ch, ctx) pv_2d = past_values.permute(0, 2, 1).reshape(batch * channels, ctx) past_list = [pv_2d[i] for i in range(pv_2d.size(0))] - freq_list = self._build_freq_list(freq, batch, channels) - outputs = self.model(past_values=past_list, freq=freq_list) + if self.use_transformers_api: + outputs = self.model(past_values=past_list, forecast_context_len=ctx) + else: + freq_list = self._build_freq_list(freq, batch, channels) + outputs = self.model(past_values=past_list, freq=freq_list) # (batch*ch, horizon) -> (batch, horizon, ch) raw = outputs.mean_predictions @@ -134,10 +144,15 @@ def forward( else: # Univariate path batch = past_values.size(0) + ctx = past_values.size(1) past_list = [past_values[i] for i in range(batch)] - freq_list = self._build_freq_list(freq, batch, channels=1) - outputs = self.model(past_values=past_list, freq=freq_list) + if self.use_transformers_api: + outputs = self.model(past_values=past_list, forecast_context_len=ctx) + else: + freq_list = self._build_freq_list(freq, batch, channels=1) + outputs = self.model(past_values=past_list, freq=freq_list) + mean_preds = outputs.mean_predictions.unsqueeze(-1) # (batch, horizon, 1) # Truncate to configured prediction_length @@ -188,51 +203,73 @@ def _build_freq_list( def _load_timesfm(resolved_model_name: str, cache_dir: str, **kwargs) -> nn.Module: - """Load or create a TimesFM model wrapped for batched multivariate use.""" - if TimesFmModelForPrediction is None: - raise ImportError( - "TimesFM models are not available. " - "Ensure you have transformers>=5.0.0 installed." - ) + """Load or create a TimesFM model wrapped for batched multivariate use. + + Supports two HuggingFace variants: + - ``*-transformers``: Uses ``TimesFm2_5ModelForPrediction`` from the + ``transformers`` library. Forward call uses ``forecast_context_len``. + - ``*-pytorch`` (default): Uses ``TimesFmModelForPrediction``. + Forward call uses ``freq``. + """ + use_transformers_api = "transformers" in resolved_model_name.lower() trainer_config = Config().trainer prediction_length = getattr(trainer_config, "prediction_length", 128) default_freq = getattr(trainer_config, "freq", 0) - try: + if use_transformers_api: + if TimesFm2_5ModelForPrediction is None: + raise ImportError( + "TimesFm2_5ModelForPrediction is not available. " + "Ensure you have a recent transformers version installed." + ) logging.info( - "Attempting to load pretrained TimesFM model: %s", + "Attempting to load pretrained TimesFM 2.5 (transformers) model: %s", resolved_model_name, ) - inner = TimesFmModelForPrediction.from_pretrained( + inner = TimesFm2_5ModelForPrediction.from_pretrained( resolved_model_name, cache_dir=cache_dir ) - logging.info("Successfully loaded pretrained TimesFM model") - except (OSError, ValueError, Exception): - logging.info( - "TimesFM model '%s' not found as pretrained, creating from config", - resolved_model_name, - ) - context_length = getattr(trainer_config, "context_length", 512) - horizon_length = prediction_length - - config = TimesFmConfig( - context_length=context_length, - horizon_length=horizon_length, - patch_length=getattr(trainer_config, "patch_length", 32), - num_hidden_layers=getattr(trainer_config, "num_hidden_layers", 20), - hidden_size=getattr(trainer_config, "hidden_size", 1280), - intermediate_size=getattr(trainer_config, "intermediate_size", 1280), - num_attention_heads=getattr(trainer_config, "num_attention_heads", 16), - head_dim=getattr(trainer_config, "head_dim", 80), - attention_dropout=getattr(trainer_config, "dropout", 0.0), - ) - inner = TimesFmModelForPrediction(config) + logging.info("Successfully loaded pretrained TimesFM 2.5 (transformers) model") + else: + if TimesFmModelForPrediction is None: + raise ImportError( + "TimesFM models are not available. " + "Ensure you have transformers>=5.0.0 installed." + ) + try: + logging.info( + "Attempting to load pretrained TimesFM model: %s", + resolved_model_name, + ) + inner = TimesFmModelForPrediction.from_pretrained( + resolved_model_name, cache_dir=cache_dir + ) + logging.info("Successfully loaded pretrained TimesFM model") + except (OSError, ValueError, Exception): + logging.info( + "TimesFM model '%s' not found as pretrained, creating from config", + resolved_model_name, + ) + context_length = getattr(trainer_config, "context_length", 512) + config = TimesFmConfig( + context_length=context_length, + horizon_length=prediction_length, + patch_length=getattr(trainer_config, "patch_length", 32), + num_hidden_layers=getattr(trainer_config, "num_hidden_layers", 20), + hidden_size=getattr(trainer_config, "hidden_size", 1280), + intermediate_size=getattr(trainer_config, "intermediate_size", 1280), + num_attention_heads=getattr(trainer_config, "num_attention_heads", 16), + head_dim=getattr(trainer_config, "head_dim", 80), + attention_dropout=getattr(trainer_config, "dropout", 0.0), + ) + inner = TimesFmModelForPrediction(config) return TimesFmMultivariateWrapper( model=inner, prediction_length=prediction_length, default_freq=default_freq, + use_transformers_api=use_transformers_api, ) diff --git a/pyproject.toml b/pyproject.toml index 2ca72df43..463ab6d53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "torch", "torch-optimizer", "torchvision", - "transformers", + "transformers>=5.4.0", "zstd", ] @@ -60,7 +60,7 @@ nanochat = [ "PyYAML", ] robotics = [ - "lerobot[smolvla]>=0.4.3,<0.5.0", + "lerobot[smolvla]>=0.5.0", ] [project.urls] From d45f4b23837d2733fc11d4f551190bdf11464cf1 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 12:37:38 -0400 Subject: [PATCH 12/39] Document the DiLoCo implementation contract. Adds the faithful DiLoCo design contract for Plato, including the server-side outer optimizer sign convention, exact local-step H semantics, small-H sampling requirements, client-local optimizer and scheduler state ownership, and the implementation dependency graph.\n\nCovers Linear issue DT-408. --- docs/docs/development/diloco.md | 220 ++++++++++++++++++++++++++++++++ docs/mkdocs.yml | 4 +- 2 files changed, 223 insertions(+), 1 deletion(-) create mode 100644 docs/docs/development/diloco.md diff --git a/docs/docs/development/diloco.md b/docs/docs/development/diloco.md new file mode 100644 index 000000000..94dfd49fc --- /dev/null +++ b/docs/docs/development/diloco.md @@ -0,0 +1,220 @@ +# DiLoCo Design Contract + +This note defines what Plato will call faithful DiLoCo for the initial +implementation. It is a contract for the implementation issues that follow; it +does not describe runtime behavior that already exists in Plato. + +Faithful DiLoCo in Plato means algorithm-faithful execution of the DiLoCo +training loop inside Plato's federated runtime. It does not mean reproducing +the paper's exact C4 dataset, model scale, tokenizer, hardware topology, +pretraining duration, or final benchmark numbers. + +## Algorithm Contract + +DiLoCo has two optimizer levels: + +- The client-local inner optimizer trains each selected logical client for + exactly `H` local optimizer steps between synchronizations. +- The server-side outer optimizer updates the global model from the averaged + outer gradient. + +Plato's FedAvg-style model delta is: + +```text +plato_delta = client_after - global_before +``` + +DiLoCo's outer gradient is: + +```text +outer_gradient = global_before - client_after = -plato_delta +``` + +The DiLoCo server must still return a Plato-compatible model delta because +`algorithm.update_weights()` adds the returned delta to the current global +model. For example, outer SGD with learning rate `1.0` returns the averaged +Plato delta and is equivalent to FedAvg only when the same averaging rule is +used. + +The outer optimizer runs on the server. Clients run only the inner optimizer +and send model weights or weight-equivalent updates. Client-local optimizer and +scheduler state persists per logical client and is never sent to the server. + +## Local Work `H` + +`H` means client-local optimizer steps between synchronizations. It is not: + +- epochs, +- raw dataloader batches, or +- gradient-accumulation micro-batches. + +When gradient accumulation is enabled, `H` counts completed optimizer steps. +Raw batches that do not trigger `optimizer.step()` do not increment `H`. + +`H` may be smaller than one epoch. Faithful DiLoCo must therefore stop local +training mid-epoch after exactly `H` optimizer steps. This early stop must +still run normal trainer cleanup, state persistence, callback completion, and +reporting paths. It must not perform an extra final optimizer step. + +Small-`H` training must not repeatedly replay the same first `H` batches only +because the train loader is recreated each round. The implementation must use +round-aware resampling or an equivalent persistent sampling stream so each +logical client's local data stream advances across rounds in a reproducible +way. + +## State Ownership + +Server-owned state: + +- the global model, +- outer optimizer momentum or other outer optimizer state, +- aggregation metadata needed to update the global model. + +Client-owned state: + +- inner optimizer state, such as AdamW first and second moments, +- scheduler state and global/local optimizer-step counters, +- sampler or dataloader stream position needed for small-`H` continuity. + +Client-owned optimizer and scheduler state must not appear in client-server +payloads. It must remain local to the logical client, including when training +uses subprocesses. + +## Parameter And Buffer Policy + +By default, the outer optimizer applies only to trainable floating parameters. +This matches the algorithm definition, which optimizes model parameters. + +Floating buffers, such as batch normalization running statistics, are +synchronized without outer momentum by default. They use the selected averaging +rule but do not receive server-side momentum or Nesterov treatment. + +Non-floating buffers use conservative FedAvg-style behavior, including casting +or rounding as needed to preserve the buffer's dtype-compatible semantics. + +The implementation may offer `apply_outer_optimizer_to = "all_floating"` for +experiments, but the default must remain `parameters`. + +## Configuration Contract + +The faithful initial mode uses these configuration names and defaults: + +```toml +[server] +type = "diloco" + +[algorithm] +type = "fedavg" + +[trainer] +local_steps_per_round = H +preserve_optimizer_state = true +optimizer = "AdamW" + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" # or "num_samples" +apply_outer_optimizer_to = "parameters" # or "all_floating" +``` + +`algorithm.type = "fedavg"` is intentional. Plato should reuse the existing +FedAvg weight extraction, delta computation, and global model loading path, +while `server.type = "diloco"` selects the server-side DiLoCo aggregation and +outer optimizer behavior. + +`aggregation_weighting = "uniform"` matches the balanced worker setting most +closely. `aggregation_weighting = "num_samples"` matches Plato's traditional +sample-weighted FedAvg behavior. FedAvg equivalence for outer SGD with learning +rate `1.0` is valid only when both runs use the same weighting rule. + +Unsupported modes must fail clearly. They must not silently fall back to an +approximate DiLoCo variant. Examples include trainer backends that cannot count +local optimizer steps exactly, execution paths that cannot preserve +client-local optimizer and scheduler state, or payload paths that would send +optimizer state to the server. + +## Implementation Sequence + +Dependency graph: + +```text +D1 +|-- D2 --> D3 +|-- D4 --> D5 +|-- D6 --> D7 +|-- D8 --> D9 +`-- D10 --> D11 + +D3, D5, D7, D9, D11 --> D12 --> D13 +``` + +Tasks: + +```yaml +- id: D1 + depends_on: [] + task: Document the exact DiLoCo contract and unsupported modes. + +- id: D2 + depends_on: [D1] + task: Add red tests for server-side outer gradient sign, weighting, and + FedAvg equivalence under matching weighting. + +- id: D3 + depends_on: [D2] + task: Implement DiLoCo server aggregation and outer optimizer state for SGD, + momentum SGD, and Nesterov. + +- id: D4 + depends_on: [D1] + task: Add red tests for exact local optimizer-step counting and `H` smaller + than one epoch. + +- id: D5 + depends_on: [D4] + task: Implement `trainer.local_steps_per_round` with mid-epoch termination + after exactly `H` optimizer steps. + +- id: D6 + depends_on: [D1] + task: Add red tests for per-client optimizer and scheduler state + persistence. + +- id: D7 + depends_on: [D6] + task: Persist client-local optimizer and scheduler state without sending it + to the server. + +- id: D8 + depends_on: [D1] + task: Add red tests for round-aware small-`H` sampling. + +- id: D9 + depends_on: [D8] + task: Implement round-aware resampling or an equivalent persistent sampling + stream for each logical client. + +- id: D10 + depends_on: [D1] + task: Add red tests for parameter and buffer eligibility. + +- id: D11 + depends_on: [D10] + task: Implement the default trainable-parameter-only outer optimizer policy + and conservative buffer synchronization. + +- id: D12 + depends_on: [D3, D5, D7, D9, D11] + task: Wire exact DiLoCo configuration, examples, and user-facing + documentation. + +- id: D13 + depends_on: [D12] + task: Add end-to-end faithful-mode validation coverage. +``` + +Every implementation task should use red/green test-driven development. Add +the failing tests that describe the contract first, then implement the smallest +runtime change that makes those tests pass. diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index fa5bc9908..f4177c8b4 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -88,7 +88,9 @@ nav: - Servers: references/servers.md - Trainers: references/trainers.md - Evaluators: references/evaluators.md - - Developer's Guide: development.md + - Developer's Guide: + - Overview: development.md + - DiLoCo Design Contract: development/diloco.md - Deployment Guide: deployment.md - Digital Research Alliance of Canada: ccdb.md - Miscellaneous Notes: misc.md From 0378a578fe59f8ff3805b90abd69ad6631be2b30 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 12:46:34 -0400 Subject: [PATCH 13/39] Implemented DiLoCo outer aggregation. Adds the DiLoCo aggregation strategy with server-side SGD, momentum SGD, and Nesterov outer optimizer behavior over Plato-style client deltas. Covers uniform and sample-weighted aggregation, validates configuration values, and adds focused tests for sign handling, FedAvg equivalence under matching weighting, momentum state persistence, reset, and stale-key cleanup.\n\nValidation reported by worker:\n- uv run pytest tests/servers/test_diloco_strategy.py\n- uv run pytest tests/servers/test_fedavg_strategy.py\n- uv run ruff check . --select I\n\nCovers Linear issue DT-410. --- .../strategies/aggregation/__init__.py | 2 + .../servers/strategies/aggregation/diloco.py | 315 ++++++++++++++++ tests/servers/test_diloco_strategy.py | 336 ++++++++++++++++++ 3 files changed, 653 insertions(+) create mode 100644 plato/servers/strategies/aggregation/diloco.py create mode 100644 tests/servers/test_diloco_strategy.py diff --git a/plato/servers/strategies/aggregation/__init__.py b/plato/servers/strategies/aggregation/__init__.py index f44c420a4..fe4174402 100644 --- a/plato/servers/strategies/aggregation/__init__.py +++ b/plato/servers/strategies/aggregation/__init__.py @@ -4,6 +4,7 @@ Each strategy is defined in its own module for clarity. """ +from plato.servers.strategies.aggregation.diloco import DiLoCoAggregationStrategy from plato.servers.strategies.aggregation.fedasync import FedAsyncAggregationStrategy from plato.servers.strategies.aggregation.fedavg import FedAvgAggregationStrategy from plato.servers.strategies.aggregation.fedbuff import FedBuffAggregationStrategy @@ -16,6 +17,7 @@ __all__ = [ "FedAvgAggregationStrategy", + "DiLoCoAggregationStrategy", "FedBuffAggregationStrategy", "FedNovaAggregationStrategy", "FedAsyncAggregationStrategy", diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py new file mode 100644 index 000000000..2b2df57e3 --- /dev/null +++ b/plato/servers/strategies/aggregation/diloco.py @@ -0,0 +1,315 @@ +""" +DiLoCo aggregation strategy. + +The strategy consumes Plato-style client deltas (`client_after - global_before`), +converts them to DiLoCo outer gradients, and returns Plato-compatible server +deltas for `algorithm.update_weights()` to add to the global model. +""" + +from __future__ import annotations + +import asyncio +import copy +import numbers +from collections.abc import Callable, Mapping +from types import SimpleNamespace +from typing import Any, cast + +import numpy as np + +from plato.servers.strategies.aggregation.fedavg import FedAvgAggregationStrategy +from plato.servers.strategies.base import ServerContext + +try: # pragma: no cover - optional dependency + import torch +except ImportError: # pragma: no cover + torch = cast(Any, None) + + +class DiLoCoAggregationStrategy(FedAvgAggregationStrategy): + """Aggregate client deltas with a server-side DiLoCo outer optimizer.""" + + _SUPPORTED_OPTIMIZERS = {"sgd", "sgdm", "nesterov"} + _SUPPORTED_WEIGHTING_MODES = {"uniform", "num_samples"} + + def __init__( + self, + outer_optimizer: str = "nesterov", + outer_learning_rate: float = 0.7, + outer_momentum: float = 0.9, + aggregation_weighting: str = "uniform", + ): + super().__init__() + self.outer_optimizer = self._validate_outer_optimizer(outer_optimizer) + self.outer_learning_rate = self._validate_learning_rate( + outer_learning_rate + ) + self.outer_momentum = self._validate_momentum(outer_momentum) + self.aggregation_weighting = self._validate_weighting_mode( + aggregation_weighting + ) + self.momentum_state: dict[str, Any] = {} + + async def aggregate_deltas( + self, + updates: list[SimpleNamespace], + deltas_received: list[dict], + context: ServerContext, + ) -> dict: + """Aggregate deltas and apply the configured DiLoCo outer optimizer.""" + eligible = self._eligible_updates(updates, deltas_received) + if not eligible: + self._remove_stale_momentum(set()) + return self._empty_delta(context, self._first_delta(deltas_received)) + + weights = self._aggregation_weights(eligible) + if not weights: + self._remove_stale_momentum(set()) + return self._empty_delta(context, eligible[0][1]) + + avg_delta: Any = None + for (_, delta, _), weight in zip(eligible, weights): + avg_delta = self._accumulate_weighted(avg_delta, delta, weight, context) + await asyncio.sleep(0) + + if avg_delta is None: + self._remove_stale_momentum(set()) + return self._empty_delta(context, eligible[0][1]) + + avg_delta = self._match_reference_structure(avg_delta, eligible[0][1]) + outer_gradient = self._scale_tree(avg_delta, -1.0) + server_delta, active_paths = self._apply_outer_optimizer(outer_gradient) + self._remove_stale_momentum(active_paths) + + return self._match_reference_structure(server_delta, eligible[0][1]) + + @classmethod + def _validate_outer_optimizer(cls, value: str) -> str: + optimizer = str(value).lower() + if optimizer not in cls._SUPPORTED_OPTIMIZERS: + supported = ", ".join(sorted(cls._SUPPORTED_OPTIMIZERS)) + raise ValueError( + f"Invalid outer_optimizer '{value}'. Supported values: {supported}." + ) + return optimizer + + @staticmethod + def _validate_learning_rate(value: float) -> float: + learning_rate = float(value) + if learning_rate < 0: + raise ValueError("outer_learning_rate must be nonnegative.") + return learning_rate + + @staticmethod + def _validate_momentum(value: float) -> float: + momentum = float(value) + if not 0 <= momentum < 1: + raise ValueError("outer_momentum must be in the range [0, 1).") + return momentum + + @classmethod + def _validate_weighting_mode(cls, value: str) -> str: + weighting = str(value).lower() + if weighting not in cls._SUPPORTED_WEIGHTING_MODES: + supported = ", ".join(sorted(cls._SUPPORTED_WEIGHTING_MODES)) + raise ValueError( + "Invalid aggregation_weighting " + f"'{value}'. Supported values: {supported}." + ) + return weighting + + def _eligible_updates( + self, + updates: list[SimpleNamespace], + deltas_received: list[dict], + ) -> list[tuple[SimpleNamespace, dict, float]]: + eligible: list[tuple[SimpleNamespace, dict, float]] = [] + for update, delta in zip(updates, deltas_received): + if getattr(update.report, "type", "weights") == "features": + continue + + num_samples = self._num_samples(update) + if num_samples <= 0: + continue + + eligible.append((update, delta, num_samples)) + + return eligible + + @staticmethod + def _num_samples(update: SimpleNamespace) -> float: + try: + return float(update.report.num_samples) + except (AttributeError, TypeError, ValueError): + return 0.0 + + def _aggregation_weights( + self, eligible: list[tuple[SimpleNamespace, dict, float]] + ) -> list[float]: + if not eligible: + return [] + + if self.aggregation_weighting == "uniform": + return [1.0 / len(eligible)] * len(eligible) + + total_samples = sum(num_samples for _, _, num_samples in eligible) + if total_samples <= 0: + return [] + + return [num_samples / total_samples for _, _, num_samples in eligible] + + def _apply_outer_optimizer(self, outer_gradient: Any) -> tuple[Any, set[str]]: + active_paths: set[str] = set() + + if self.outer_optimizer == "sgd": + return self._scale_tree(outer_gradient, -self.outer_learning_rate), set() + + server_delta = self._map_tree( + outer_gradient, + lambda value, path: self._apply_momentum_leaf( + value, path, active_paths + ), + ) + return server_delta, active_paths + + def _apply_momentum_leaf( + self, outer_gradient: Any, path: str, active_paths: set[str] + ) -> Any: + active_paths.add(path) + previous = self.momentum_state.get(path) + if previous is not None and not self._is_compatible(previous, outer_gradient): + previous = None + + if previous is None: + momentum = self._clone_tree(outer_gradient) + else: + momentum = self._add_values( + self._scale_tree(previous, self.outer_momentum), + outer_gradient, + ) + + self.momentum_state[path] = self._clone_tree(momentum) + + if self.outer_optimizer == "nesterov": + direction = self._add_values( + outer_gradient, + self._scale_tree(momentum, self.outer_momentum), + ) + else: + direction = momentum + + return self._scale_tree(direction, -self.outer_learning_rate) + + def _remove_stale_momentum(self, active_paths: set[str]) -> None: + if self.outer_optimizer == "sgd": + self.momentum_state.clear() + return + + for path in list(self.momentum_state): + if path not in active_paths: + del self.momentum_state[path] + + def _empty_delta(self, context: ServerContext, reference_delta: Any | None) -> dict: + zero_delta = self._zero_delta(context, reference_delta) + if zero_delta is not None: + return zero_delta + + if reference_delta is None: + return {} + + return self._scale_tree(reference_delta, 0.0) + + @staticmethod + def _first_delta(deltas_received: list[dict]) -> dict | None: + return deltas_received[0] if deltas_received else None + + def _map_tree(self, value: Any, leaf_fn: Callable[[Any, str], Any], path="") -> Any: + if isinstance(value, Mapping): + return { + key: self._map_tree(item, leaf_fn, self._join_path(path, key)) + for key, item in value.items() + } + + if isinstance(value, list): + return [ + self._map_tree(item, leaf_fn, self._join_path(path, index)) + for index, item in enumerate(value) + ] + + if isinstance(value, tuple): + return tuple( + self._map_tree(item, leaf_fn, self._join_path(path, index)) + for index, item in enumerate(value) + ) + + return leaf_fn(value, path) + + def _scale_tree(self, value: Any, scalar: float) -> Any: + if isinstance(value, Mapping): + return { + key: self._scale_tree(item, scalar) for key, item in value.items() + } + + if isinstance(value, list): + return [self._scale_tree(item, scalar) for item in value] + + if isinstance(value, tuple): + return tuple(self._scale_tree(item, scalar) for item in value) + + return value * scalar + + @staticmethod + def _add_values(left: Any, right: Any) -> Any: + return left + right + + def _clone_tree(self, value: Any) -> Any: + if isinstance(value, Mapping): + return {key: self._clone_tree(item) for key, item in value.items()} + + if isinstance(value, list): + return [self._clone_tree(item) for item in value] + + if isinstance(value, tuple): + return tuple(self._clone_tree(item) for item in value) + + if torch is not None and isinstance(value, torch.Tensor): + return value.detach().clone() + + if isinstance(value, np.ndarray): + return value.copy() + + try: + return copy.deepcopy(value) + except TypeError: + return value + + @staticmethod + def _is_compatible(left: Any, right: Any) -> bool: + if torch is not None and isinstance(left, torch.Tensor): + return ( + isinstance(right, torch.Tensor) + and left.shape == right.shape + and left.dtype == right.dtype + ) + + if isinstance(left, np.ndarray): + return ( + isinstance(right, np.ndarray) + and left.shape == right.shape + and left.dtype == right.dtype + ) + + left_shape = getattr(left, "shape", None) + right_shape = getattr(right, "shape", None) + if left_shape is not None or right_shape is not None: + return ( + left_shape == right_shape + and getattr(left, "dtype", None) == getattr(right, "dtype", None) + ) + + return isinstance(left, numbers.Number) and isinstance(right, numbers.Number) + + @staticmethod + def _join_path(prefix: str, key: Any) -> str: + key_text = str(key) + return key_text if not prefix else f"{prefix}.{key_text}" diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py new file mode 100644 index 000000000..336e5e1aa --- /dev/null +++ b/tests/servers/test_diloco_strategy.py @@ -0,0 +1,336 @@ +"""Tests for DiLoCo server-side outer aggregation.""" + +import asyncio +from types import SimpleNamespace + +import pytest +import torch + +from plato.servers.strategies.aggregation import DiLoCoAggregationStrategy +from plato.servers.strategies.base import ServerContext + + +class DummyAlgorithm: + """Minimal algorithm stub for zero-delta construction.""" + + def __init__(self, baseline): + self.baseline = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in baseline.items() + } + + def extract_weights(self): + return { + name: value.clone() if hasattr(value, "clone") else value + for name, value in self.baseline.items() + } + + def compute_weight_deltas(self, baseline_weights, weights_list): + return [ + { + name: weights[name] - baseline_weights[name] + for name in baseline_weights.keys() + } + for weights in weights_list + ] + + +def _context(baseline=None): + context = ServerContext() + if baseline is not None: + context.algorithm = DummyAlgorithm(baseline) + return context + + +def _update(num_samples, report_type="weights"): + return SimpleNamespace( + report=SimpleNamespace(num_samples=num_samples, type=report_type) + ) + + +def _aggregate(strategy, updates, deltas, baseline=None): + return asyncio.run( + strategy.aggregate_deltas(updates, deltas, _context(baseline)) + ) + + +def test_sgd_lr_one_uniform_matches_uniform_model_averaging(temp_config): + """Outer SGD with lr=1 should match uniform averaging under uniform mode.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(1), _update(99)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([15.0])) + + +def test_sgd_lr_one_num_samples_matches_weighted_fedavg(temp_config): + """Outer SGD with lr=1 should match sample-weighted FedAvg.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="num_samples", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(1), _update(3)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(server_delta["w"], torch.tensor([6.5])) + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([16.5])) + + +def test_sgd_lr_half_moves_halfway_to_averaged_model(temp_config): + """A lower outer SGD lr should partially move toward the averaged model.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=0.5, + aggregation_weighting="uniform", + ) + + baseline = {"w": torch.tensor([10.0])} + updates = [_update(5), _update(5)] + deltas = [{"w": torch.tensor([2.0])}, {"w": torch.tensor([8.0])}] + + server_delta = _aggregate(strategy, updates, deltas, baseline) + + assert torch.allclose(server_delta["w"], torch.tensor([2.5])) + assert torch.allclose(baseline["w"] + server_delta["w"], torch.tensor([12.5])) + + +def test_sgd_uses_diloco_outer_gradient_sign(temp_config): + """The strategy should negate Plato deltas before applying outer SGD.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=0.25, + aggregation_weighting="uniform", + ) + + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([1.0])) + + +def test_uniform_weighting_ignores_positive_sample_count_magnitude(temp_config): + """Uniform mode should weight eligible clients equally.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + ) + + server_delta = _aggregate( + strategy, + [_update(1), _update(1000)], + [{"w": torch.tensor([0.0])}, {"w": torch.tensor([10.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([5.0])) + + +def test_nonpositive_sample_reports_are_ineligible(temp_config): + """Reports with zero or negative sample counts should not affect averages.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="num_samples", + ) + + server_delta = _aggregate( + strategy, + [_update(0), _update(-5), _update(10)], + [ + {"w": torch.tensor([100.0])}, + {"w": torch.tensor([100.0])}, + {"w": torch.tensor([4.0])}, + ], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0])) + + +def test_empty_eligible_updates_return_zero_delta(temp_config): + """An empty eligible set should produce a zero delta matching the baseline.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + ) + + baseline = {"w": torch.tensor([3.0, 4.0])} + server_delta = _aggregate( + strategy, + [_update(0), _update(5, report_type="features")], + [{"w": torch.tensor([10.0, 10.0])}, {"w": torch.tensor([10.0, 10.0])}], + baseline, + ) + + assert torch.allclose(server_delta["w"], torch.zeros_like(baseline["w"])) + + +def test_empty_eligible_updates_remove_stale_momentum(temp_config): + """A round with no eligible keys should clear stale momentum buffers.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + server_delta = _aggregate( + strategy, + [_update(0)], + [{"w": torch.tensor([10.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([0.0])) + assert strategy.momentum_state == {} + + +def test_sgdm_persists_momentum_across_rounds(temp_config): + """Momentum SGD should reuse server-side outer momentum across rounds.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + + first_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + second_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(first_delta["w"], torch.tensor([2.0])) + assert torch.allclose(second_delta["w"], torch.tensor([5.0])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-5.0])) + + +def test_nesterov_uses_pytorch_style_two_round_recurrence(temp_config): + """Nesterov should use g + beta * m after updating the momentum buffer.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="nesterov", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + + first_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0])}], + {"w": torch.tensor([0.0])}, + ) + second_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0])}], + {"w": torch.tensor([0.0])}, + ) + + assert torch.allclose(first_delta["w"], torch.tensor([3.0])) + assert torch.allclose(second_delta["w"], torch.tensor([6.5])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-5.0])) + + +def test_momentum_state_resets_on_shape_mismatch_and_removes_stale_keys( + temp_config, +): + """Momentum state should reset incompatible keys and prune missing keys.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0]), "b": torch.tensor([1.0])}], + {"w": torch.tensor([0.0]), "b": torch.tensor([0.0])}, + ) + + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0, 6.0])}], + {"w": torch.tensor([0.0, 0.0])}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0, 6.0])) + assert torch.allclose(strategy.momentum_state["w"], torch.tensor([-4.0, -6.0])) + assert "b" not in strategy.momentum_state + + +def test_momentum_state_resets_on_dtype_mismatch(temp_config): + """Momentum state should reset when the tensor dtype changes.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=1.0, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + + _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([2.0], dtype=torch.float32)}], + {"w": torch.tensor([0.0], dtype=torch.float32)}, + ) + server_delta = _aggregate( + strategy, + [_update(1)], + [{"w": torch.tensor([4.0], dtype=torch.float64)}], + {"w": torch.tensor([0.0], dtype=torch.float64)}, + ) + + assert torch.allclose(server_delta["w"], torch.tensor([4.0], dtype=torch.float64)) + assert strategy.momentum_state["w"].dtype == torch.float64 + + +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"outer_optimizer": "adam"}, "outer_optimizer"), + ({"aggregation_weighting": "weighted"}, "aggregation_weighting"), + ({"outer_learning_rate": -0.1}, "outer_learning_rate"), + ({"outer_momentum": -0.1}, "outer_momentum"), + ({"outer_momentum": 1.0}, "outer_momentum"), + ], +) +def test_invalid_config_values_fail_clearly(temp_config, kwargs, match): + """Invalid DiLoCo aggregation configuration should raise clear errors.""" + with pytest.raises(ValueError, match=match): + DiLoCoAggregationStrategy(**kwargs) From 6287b0f94287dd48c18aec525af9aea0dfec3c4c Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 12:47:22 -0400 Subject: [PATCH 14/39] Added exact local step limits for trainers. Adds trainer.local_steps_per_round support to ComposableTrainer so local work can stop after an exact number of completed optimizer steps, including mid-epoch DiLoCo-style runs. The trainer counts optimizer steps rather than raw batches, avoids finalization after the limit is reached, preserves existing epoch behavior when unset, and adds focused tests for delayed optimizer stepping, cleanup, and invalid values.\n\nValidation reported by worker:\n- uv run pytest tests/trainers/test_composable_trainer.py -k local_steps\n- uv run pytest tests/trainers/test_composable_trainer.py\n- uv run ruff check . --select I\n\nThe broader ============================= test session starts ============================== platform darwin -- Python 3.13.12, pytest-8.4.2, pluggy-1.6.0 rootdir: /Users/bli/Playground/plato configfile: pyproject.toml plugins: anyio-4.13.0 collected 106 items / 1 error / 75 deselected / 31 selected ==================================== ERRORS ==================================== _______ ERROR collecting tests/trainers/test_dp_data_loader_strategy.py ________ ImportError while importing test module '/Users/bli/Playground/plato/tests/trainers/test_dp_data_loader_strategy.py'. Hint: make sure your test modules/packages have valid Python names. Traceback: ../../.local/share/uv/python/cpython-3.13.12-macos-aarch64-none/lib/python3.13/importlib/__init__.py:88: in import_module return _bootstrap._gcd_import(name[level:], package, level) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ tests/trainers/test_dp_data_loader_strategy.py:4: in from plato.trainers.diff_privacy import DPDataLoaderStrategy plato/trainers/diff_privacy.py:15: in from opacus import GradSampleModule E ModuleNotFoundError: No module named 'opacus' =========================== short test summary info ============================ ERROR tests/trainers/test_dp_data_loader_strategy.py !!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!! ======================= 75 deselected, 1 error in 0.11s ======================== collection is blocked by missing optional dependency opacus in unrelated DP tests.\n\nCovers Linear issue DT-416. --- plato/trainers/composable.py | 47 ++++- tests/trainers/test_composable_trainer.py | 206 ++++++++++++++++++++++ 2 files changed, 251 insertions(+), 2 deletions(-) diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index 1e98d1128..f2bc2f0ca 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -177,6 +177,29 @@ def _require_model(self) -> nn.Module: ) return cast(nn.Module, self.model) + @staticmethod + def _local_steps_per_round(config: dict[str, Any]) -> int | None: + """Return the optional local optimizer-step limit for one train run.""" + value = config.get("local_steps_per_round") + if value is None: + return None + + if isinstance(value, bool) or not isinstance(value, int) or value <= 0: + raise ValueError( + "trainer.local_steps_per_round must be a positive integer." + ) + + return value + + def _record_local_optimizer_step(self, local_steps_per_round: int | None) -> bool: + """Record one completed optimizer step and report whether H was reached.""" + if local_steps_per_round is None: + return False + + completed_steps = int(self.context.state.get("local_optimizer_steps", 0)) + 1 + self.context.state["local_optimizer_steps"] = completed_steps + return completed_steps >= local_steps_per_round + @staticmethod def _persisted_test_state_keys() -> tuple[str, ...]: """State keys that must survive spawned test subprocesses.""" @@ -397,6 +420,12 @@ def train_model(self, config, trainset, sampler, **kwargs): self.sampler = sampler self.context.config = config self.context.current_round = self.current_round + local_steps_per_round = self._local_steps_per_round(config) + self.context.state["local_optimizer_steps"] = 0 + if local_steps_per_round is None: + self.context.state.pop("local_steps_per_round", None) + else: + self.context.state["local_steps_per_round"] = local_steps_per_round # Ensure training step strategy respects higher-order gradient settings if self.training_step_strategy is not None: @@ -494,6 +523,7 @@ def train_model(self, config, trainset, sampler, **kwargs): total_epochs = config["epochs"] tic = time.perf_counter() training_stop_requested = False + local_step_limit_reached = False try: total_batches = len(self.train_loader) except (TypeError, AttributeError): @@ -564,6 +594,9 @@ def compute_loss(outputs, labels_inner): self.optimizer_strategy.on_optimizer_step( self.optimizer, self.context ) + local_step_limit_reached = self._record_local_optimizer_step( + local_steps_per_round + ) # Strategy hook: after_step self.model_update_strategy.after_step(self.context) @@ -591,7 +624,7 @@ def compute_loss(outputs, labels_inner): ): self._handle_control_log() - if control_actions.get("stop_training"): + if control_actions.get("stop_training") or local_step_limit_reached: training_stop_requested = True break @@ -601,7 +634,11 @@ def compute_loss(outputs, labels_inner): finalize_loss = None finalize_step_done = False finalize_callable = getattr(self.training_step_strategy, "finalize", None) - if batches_seen and callable(finalize_callable): + if ( + batches_seen + and callable(finalize_callable) + and not local_step_limit_reached + ): finalize_loss = finalize_callable( model=model, optimizer=self.optimizer, @@ -613,6 +650,9 @@ def compute_loss(outputs, labels_inner): ) if finalize_step_done: self.optimizer_strategy.on_optimizer_step(self.optimizer, self.context) + local_step_limit_reached = self._record_local_optimizer_step( + local_steps_per_round + ) self.model_update_strategy.after_step(self.context) self.callback_handler.call_event( "on_train_step_end", @@ -652,6 +692,9 @@ def compute_loss(outputs, labels_inner): # No batches remain, but respect control flag. pass + if local_step_limit_reached: + training_stop_requested = True + self.context.state.pop("is_last_batch", None) self.context.state.pop("hf_optimizer_step_index", None) diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index ff35f76ec..ec83e20ea 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -10,6 +10,7 @@ import torch.nn as nn from torch.utils.data import TensorDataset +from plato.callbacks.trainer import TrainerCallback from plato.config import Config from plato.evaluators.runner import EVALUATION_PRIMARY_KEY, EVALUATION_RESULTS_KEY from plato.trainers.composable import ComposableTrainer @@ -25,6 +26,7 @@ LossCriterionStrategy, ModelUpdateStrategy, TrainingContext, + TrainingStepStrategy, ) @@ -178,6 +180,210 @@ def test_multiple_epochs(self, simple_model, simple_dataset): assert len(trainer.run_history.get_metric_values("train_loss")) == 5 +class TestComposableTrainerLocalSteps: + """Test local optimizer-step limits for DiLoCo-style local work.""" + + class CountingCallback(TrainerCallback): + def __init__(self): + self.train_run_end_called = False + self.train_step_end_count = 0 + + def on_train_step_end(self, trainer, config, batch, loss, **kwargs): + self.train_step_end_count += 1 + + def on_train_run_end(self, trainer, config, **kwargs): + self.train_run_end_called = True + + class CountingUpdateStrategy(ModelUpdateStrategy): + def __init__(self): + self.after_step_count = 0 + self.on_train_end_called = False + + def after_step(self, context): + self.after_step_count += 1 + + def on_train_end(self, context): + self.on_train_end_called = True + + class CountingStepStrategy(DefaultTrainingStepStrategy): + def __init__(self): + super().__init__() + self.batch_count = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + self.batch_count += 1 + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + + class DelayedOptimizerStepStrategy(TrainingStepStrategy): + def __init__(self, accumulation_steps=2, finalize_steps=True): + self.accumulation_steps = accumulation_steps + self.finalize_steps = finalize_steps + self.raw_batch_count = 0 + self.optimizer_step_count = 0 + self.finalize_calls = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + outputs = model(examples) + loss = loss_criterion(outputs, labels) + (loss / self.accumulation_steps).backward() + + self.raw_batch_count += 1 + if self.raw_batch_count % self.accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() + self.optimizer_step_count += 1 + context.state["optimizer_step_completed"] = True + else: + context.state["optimizer_step_completed"] = False + + return loss + + def finalize(self, model, optimizer, context): + self.finalize_calls += 1 + if not self.finalize_steps: + return None + + optimizer.step() + optimizer.zero_grad() + self.optimizer_step_count += 1 + context.state["optimizer_step_completed"] = True + return torch.tensor(0.0) + + def test_local_steps_stop_mid_epoch_and_run_cleanup( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + callback = self.CountingCallback() + update_strategy = self.CountingUpdateStrategy() + step_strategy = self.CountingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + callbacks=[callback], + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.batch_count == 3 + assert update_strategy.after_step_count == 3 + assert callback.train_step_end_count == 3 + assert trainer.current_epoch == 1 + assert trainer.context.state["local_optimizer_steps"] == 3 + assert update_strategy.on_train_end_called + assert callback.train_run_end_called + + def test_local_steps_count_optimizer_steps_not_raw_batches( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 2, + } + update_strategy = self.CountingUpdateStrategy() + step_strategy = self.DelayedOptimizerStepStrategy(accumulation_steps=3) + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 6 + assert step_strategy.optimizer_step_count == 2 + assert update_strategy.after_step_count == 2 + assert trainer.context.state["local_optimizer_steps"] == 2 + + def test_local_steps_skip_finalize_after_limit_is_reached( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 2, + "local_steps_per_round": 1, + } + step_strategy = self.DelayedOptimizerStepStrategy(accumulation_steps=2) + update_strategy = self.CountingUpdateStrategy() + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 2 + assert step_strategy.optimizer_step_count == 1 + assert step_strategy.finalize_calls == 0 + assert update_strategy.after_step_count == 1 + assert trainer.context.state["local_optimizer_steps"] == 1 + + def test_epoch_behavior_is_unchanged_when_local_steps_unset( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 10, + "epochs": 2, + } + step_strategy = self.CountingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.batch_count == 20 + assert trainer.current_epoch == 2 + assert len(trainer.run_history.get_metric_values("train_loss")) == 2 + + @pytest.mark.parametrize("local_steps_per_round", [0, -1, 1.5, "2", True]) + def test_invalid_local_steps_fail_clearly( + self, simple_model, simple_dataset, simple_config, local_steps_per_round + ): + config = { + **simple_config, + "local_steps_per_round": local_steps_per_round, + } + trainer = ComposableTrainer(model=simple_model) + + with pytest.raises(ValueError, match="local_steps_per_round"): + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + class TestComposableTrainerStrategies: """Test strategy integration.""" From b91d97d67cf4958de9a72e02b45f7f351ecd92cf Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 12:52:00 -0400 Subject: [PATCH 15/39] Fixed local step counting for accumulation. DT-417 review found that the built-in GradientAccumulationStepStrategy did not publish optimizer_step_completed, so local_steps_per_round counted raw batches instead of optimizer steps when accumulation_steps > 1. Set optimizer_step_completed only when the accumulation strategy actually performs optimizer.step(), and add a regression test that uses the real built-in accumulation strategy with H=2 and accumulation_steps=3. Validation: uv run pytest tests/trainers/test_composable_trainer.py -k local_steps; uv run pytest tests/trainers/test_composable_trainer.py; uv run ruff check . --select I. --- plato/trainers/strategies/training_step.py | 3 ++ tests/trainers/test_composable_trainer.py | 50 ++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/plato/trainers/strategies/training_step.py b/plato/trainers/strategies/training_step.py index b4aba6d9d..5afa9702c 100644 --- a/plato/trainers/strategies/training_step.py +++ b/plato/trainers/strategies/training_step.py @@ -128,6 +128,9 @@ def training_step( if self.current_step % self.accumulation_steps == 0: optimizer.step() optimizer.zero_grad() + context.state["optimizer_step_completed"] = True + else: + context.state["optimizer_step_completed"] = False # Return unscaled loss for logging return loss diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index ec83e20ea..d5daeb4b4 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -19,6 +19,7 @@ CrossEntropyLossStrategy, DefaultDataLoaderStrategy, DefaultTrainingStepStrategy, + GradientAccumulationStepStrategy, NoOpUpdateStrategy, NoSchedulerStrategy, ) @@ -272,6 +273,30 @@ def finalize(self, model, optimizer, context): context.state["optimizer_step_completed"] = True return torch.tensor(0.0) + class CountingGradientAccumulationStepStrategy(GradientAccumulationStepStrategy): + def __init__(self, accumulation_steps): + super().__init__(accumulation_steps=accumulation_steps) + self.raw_batch_count = 0 + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + self.raw_batch_count += 1 + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + def test_local_steps_stop_mid_epoch_and_run_cleanup( self, simple_model, simple_dataset, simple_config ): @@ -325,6 +350,31 @@ def test_local_steps_count_optimizer_steps_not_raw_batches( assert update_strategy.after_step_count == 2 assert trainer.context.state["local_optimizer_steps"] == 2 + def test_local_steps_respect_builtin_gradient_accumulation( + self, simple_model, simple_dataset, simple_config + ): + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 2, + } + update_strategy = self.CountingUpdateStrategy() + step_strategy = self.CountingGradientAccumulationStepStrategy( + accumulation_steps=3 + ) + trainer = ComposableTrainer( + model=simple_model, + model_update_strategy=update_strategy, + training_step_strategy=step_strategy, + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert step_strategy.raw_batch_count == 6 + assert update_strategy.after_step_count == 2 + assert trainer.context.state["local_optimizer_steps"] == 2 + def test_local_steps_skip_finalize_after_limit_is_reached( self, simple_model, simple_dataset, simple_config ): From 6846f5097dffa8c01825189fee51163dc7036017 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 12:56:05 -0400 Subject: [PATCH 16/39] Added DiLoCo parameter eligibility policy. DT-412 makes the DiLoCo outer optimizer apply to trainable floating parameters by default while preserving full state_dict safety for frozen parameters and buffers. The new apply_outer_optimizer_to option supports parameters and all_floating modes, validates unsupported values clearly, resolves trainable parameter names from context.trainer.model for the default mode, and keeps momentum state only for tensors that receive outer optimization. Tests cover trainable parameters, frozen parameters, floating buffers, integer and boolean buffers, all_floating behavior, missing model context, invalid config values, and the existing DiLoCo aggregation math. Validation: uv run pytest tests/servers/test_diloco_strategy.py; uv run pytest tests/servers/test_fedavg_strategy.py; uv run ruff check . --select I. --- .../servers/strategies/aggregation/diloco.py | 142 ++++++++++++- tests/servers/test_diloco_strategy.py | 191 +++++++++++++++++- 2 files changed, 321 insertions(+), 12 deletions(-) diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py index 2b2df57e3..a00316a9a 100644 --- a/plato/servers/strategies/aggregation/diloco.py +++ b/plato/servers/strategies/aggregation/diloco.py @@ -31,6 +31,7 @@ class DiLoCoAggregationStrategy(FedAvgAggregationStrategy): _SUPPORTED_OPTIMIZERS = {"sgd", "sgdm", "nesterov"} _SUPPORTED_WEIGHTING_MODES = {"uniform", "num_samples"} + _SUPPORTED_APPLY_POLICIES = {"parameters", "all_floating"} def __init__( self, @@ -38,6 +39,7 @@ def __init__( outer_learning_rate: float = 0.7, outer_momentum: float = 0.9, aggregation_weighting: str = "uniform", + apply_outer_optimizer_to: str = "parameters", ): super().__init__() self.outer_optimizer = self._validate_outer_optimizer(outer_optimizer) @@ -48,6 +50,9 @@ def __init__( self.aggregation_weighting = self._validate_weighting_mode( aggregation_weighting ) + self.apply_outer_optimizer_to = self._validate_apply_policy( + apply_outer_optimizer_to + ) self.momentum_state: dict[str, Any] = {} async def aggregate_deltas( @@ -77,8 +82,10 @@ async def aggregate_deltas( return self._empty_delta(context, eligible[0][1]) avg_delta = self._match_reference_structure(avg_delta, eligible[0][1]) - outer_gradient = self._scale_tree(avg_delta, -1.0) - server_delta, active_paths = self._apply_outer_optimizer(outer_gradient) + optimizer_paths = self._outer_optimizer_paths(avg_delta, context) + server_delta, active_paths = self._apply_outer_optimizer( + avg_delta, optimizer_paths + ) self._remove_stale_momentum(active_paths) return self._match_reference_structure(server_delta, eligible[0][1]) @@ -118,6 +125,17 @@ def _validate_weighting_mode(cls, value: str) -> str: ) return weighting + @classmethod + def _validate_apply_policy(cls, value: str) -> str: + policy = str(value).lower() + if policy not in cls._SUPPORTED_APPLY_POLICIES: + supported = ", ".join(sorted(cls._SUPPORTED_APPLY_POLICIES)) + raise ValueError( + "Invalid apply_outer_optimizer_to " + f"'{value}'. Supported values: {supported}." + ) + return policy + def _eligible_updates( self, updates: list[SimpleNamespace], @@ -158,20 +176,48 @@ def _aggregation_weights( return [num_samples / total_samples for _, _, num_samples in eligible] - def _apply_outer_optimizer(self, outer_gradient: Any) -> tuple[Any, set[str]]: - active_paths: set[str] = set() + def _outer_optimizer_paths( + self, avg_delta: Any, context: ServerContext + ) -> set[str]: + if self.apply_outer_optimizer_to == "all_floating": + return self._floating_leaf_paths(avg_delta) + + trainable_parameter_names = self._trainable_parameter_names(context) + return self._collect_leaf_paths( + avg_delta, + lambda value, path: path in trainable_parameter_names + and self._is_floating_value(value), + ) - if self.outer_optimizer == "sgd": - return self._scale_tree(outer_gradient, -self.outer_learning_rate), set() + def _apply_outer_optimizer( + self, avg_delta: Any, optimizer_paths: set[str] + ) -> tuple[Any, set[str]]: + active_paths: set[str] = set() server_delta = self._map_tree( - outer_gradient, - lambda value, path: self._apply_momentum_leaf( - value, path, active_paths + avg_delta, + lambda value, path: self._apply_outer_optimizer_leaf( + value, path, optimizer_paths, active_paths ), ) return server_delta, active_paths + def _apply_outer_optimizer_leaf( + self, + avg_delta: Any, + path: str, + optimizer_paths: set[str], + active_paths: set[str], + ) -> Any: + if path not in optimizer_paths: + return avg_delta + + outer_gradient = self._scale_tree(avg_delta, -1.0) + if self.outer_optimizer == "sgd": + return self._scale_tree(outer_gradient, -self.outer_learning_rate) + + return self._apply_momentum_leaf(outer_gradient, path, active_paths) + def _apply_momentum_leaf( self, outer_gradient: Any, path: str, active_paths: set[str] ) -> Any: @@ -209,6 +255,84 @@ def _remove_stale_momentum(self, active_paths: set[str]) -> None: if path not in active_paths: del self.momentum_state[path] + def _trainable_parameter_names(self, context: ServerContext) -> set[str]: + model = self._model_from_context(context) + trainable_names: set[str] = set() + + for name, parameter in model.named_parameters(): + if getattr(parameter, "requires_grad", False) and self._is_floating_value( + parameter + ): + trainable_names.add(name) + + return trainable_names + + @staticmethod + def _model_from_context(context: ServerContext) -> Any: + trainer = getattr(context, "trainer", None) + model = getattr(trainer, "model", None) if trainer is not None else None + if model is None or not hasattr(model, "named_parameters"): + raise AttributeError( + "DiLoCo apply_outer_optimizer_to='parameters' requires " + "context.trainer.model with named_parameters()." + ) + return model + + def _floating_leaf_paths(self, value: Any) -> set[str]: + return self._collect_leaf_paths( + value, lambda leaf, _: self._is_floating_value(leaf) + ) + + def _collect_leaf_paths( + self, + value: Any, + predicate: Callable[[Any, str], bool], + path: str = "", + ) -> set[str]: + if isinstance(value, Mapping): + paths: set[str] = set() + for key, item in value.items(): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, key) + ) + ) + return paths + + if isinstance(value, list): + paths = set() + for index, item in enumerate(value): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, index) + ) + ) + return paths + + if isinstance(value, tuple): + paths = set() + for index, item in enumerate(value): + paths.update( + self._collect_leaf_paths( + item, predicate, self._join_path(path, index) + ) + ) + return paths + + return {path} if predicate(value, path) else set() + + @staticmethod + def _is_floating_value(value: Any) -> bool: + if torch is not None and isinstance(value, torch.Tensor): + return torch.is_floating_point(value) + + if isinstance(value, np.ndarray): + return np.issubdtype(value.dtype, np.floating) + + return isinstance(value, numbers.Real) and not isinstance( + value, (numbers.Integral, bool) + ) + def _empty_delta(self, context: ServerContext, reference_delta: Any | None) -> dict: zero_delta = self._zero_delta(context, reference_delta) if zero_delta is not None: diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py index 336e5e1aa..a7b60b78d 100644 --- a/tests/servers/test_diloco_strategy.py +++ b/tests/servers/test_diloco_strategy.py @@ -35,10 +35,24 @@ def compute_weight_deltas(self, baseline_weights, weights_list): ] -def _context(baseline=None): +class MixedStateModel(torch.nn.Module): + """Model exposing trainable, frozen, floating-buffer, and integer state.""" + + def __init__(self): + super().__init__() + self.trainable = torch.nn.Parameter(torch.tensor([1.0])) + self.frozen = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=False) + self.register_buffer("floating_buffer", torch.tensor([1.0])) + self.register_buffer("integer_buffer", torch.tensor([1], dtype=torch.int64)) + self.register_buffer("bool_buffer", torch.tensor([True], dtype=torch.bool)) + + +def _context(baseline=None, model=None): context = ServerContext() if baseline is not None: context.algorithm = DummyAlgorithm(baseline) + if model is not None: + context.trainer = SimpleNamespace(model=model) return context @@ -48,9 +62,9 @@ def _update(num_samples, report_type="weights"): ) -def _aggregate(strategy, updates, deltas, baseline=None): +def _aggregate(strategy, updates, deltas, baseline=None, model=None): return asyncio.run( - strategy.aggregate_deltas(updates, deltas, _context(baseline)) + strategy.aggregate_deltas(updates, deltas, _context(baseline, model)) ) @@ -60,6 +74,7 @@ def test_sgd_lr_one_uniform_matches_uniform_model_averaging(temp_config): outer_optimizer="sgd", outer_learning_rate=1.0, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) baseline = {"w": torch.tensor([10.0])} @@ -77,6 +92,7 @@ def test_sgd_lr_one_num_samples_matches_weighted_fedavg(temp_config): outer_optimizer="sgd", outer_learning_rate=1.0, aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", ) baseline = {"w": torch.tensor([10.0])} @@ -95,6 +111,7 @@ def test_sgd_lr_half_moves_halfway_to_averaged_model(temp_config): outer_optimizer="sgd", outer_learning_rate=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) baseline = {"w": torch.tensor([10.0])} @@ -113,6 +130,7 @@ def test_sgd_uses_diloco_outer_gradient_sign(temp_config): outer_optimizer="sgd", outer_learning_rate=0.25, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) server_delta = _aggregate( @@ -131,6 +149,7 @@ def test_uniform_weighting_ignores_positive_sample_count_magnitude(temp_config): outer_optimizer="sgd", outer_learning_rate=1.0, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) server_delta = _aggregate( @@ -149,6 +168,7 @@ def test_nonpositive_sample_reports_are_ineligible(temp_config): outer_optimizer="sgd", outer_learning_rate=1.0, aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", ) server_delta = _aggregate( @@ -171,6 +191,7 @@ def test_empty_eligible_updates_return_zero_delta(temp_config): outer_optimizer="sgd", outer_learning_rate=1.0, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) baseline = {"w": torch.tensor([3.0, 4.0])} @@ -191,6 +212,7 @@ def test_empty_eligible_updates_remove_stale_momentum(temp_config): outer_learning_rate=1.0, outer_momentum=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) _aggregate( @@ -217,6 +239,7 @@ def test_sgdm_persists_momentum_across_rounds(temp_config): outer_learning_rate=1.0, outer_momentum=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) first_delta = _aggregate( @@ -244,6 +267,7 @@ def test_nesterov_uses_pytorch_style_two_round_recurrence(temp_config): outer_learning_rate=1.0, outer_momentum=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) first_delta = _aggregate( @@ -273,6 +297,7 @@ def test_momentum_state_resets_on_shape_mismatch_and_removes_stale_keys( outer_learning_rate=1.0, outer_momentum=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) _aggregate( @@ -301,6 +326,7 @@ def test_momentum_state_resets_on_dtype_mismatch(temp_config): outer_learning_rate=1.0, outer_momentum=0.5, aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", ) _aggregate( @@ -320,11 +346,170 @@ def test_momentum_state_resets_on_dtype_mismatch(temp_config): assert strategy.momentum_state["w"].dtype == torch.float64 +def test_parameters_policy_optimizes_only_trainable_floating_parameters( + temp_config, +): + """Default policy should leave frozen parameters and buffers on FedAvg deltas.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = MixedStateModel() + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + first_delta = _aggregate( + strategy, + [_update(1), _update(1)], + [ + { + "trainable": torch.tensor([2.0]), + "frozen": torch.tensor([2.0]), + "floating_buffer": torch.tensor([2.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + }, + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([2], dtype=torch.int64), + "bool_buffer": torch.tensor([True]), + }, + ], + baseline, + model, + ) + + assert torch.allclose(first_delta["trainable"], torch.tensor([2.0])) + assert torch.allclose(first_delta["frozen"], torch.tensor([4.0])) + assert torch.allclose(first_delta["floating_buffer"], torch.tensor([4.0])) + assert torch.equal(first_delta["integer_buffer"], torch.tensor([2])) + assert torch.equal(first_delta["bool_buffer"], torch.tensor([True])) + assert set(strategy.momentum_state) == {"trainable"} + assert torch.allclose(strategy.momentum_state["trainable"], torch.tensor([-4.0])) + + second_delta = _aggregate( + strategy, + [_update(1)], + [ + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + } + ], + baseline, + model, + ) + + assert torch.allclose(second_delta["trainable"], torch.tensor([4.0])) + assert torch.allclose(second_delta["frozen"], torch.tensor([6.0])) + assert torch.allclose(second_delta["floating_buffer"], torch.tensor([6.0])) + assert torch.equal(second_delta["integer_buffer"], torch.tensor([1])) + assert torch.equal(second_delta["bool_buffer"], torch.tensor([False])) + assert set(strategy.momentum_state) == {"trainable"} + assert torch.allclose(strategy.momentum_state["trainable"], torch.tensor([-8.0])) + + +def test_all_floating_policy_optimizes_every_floating_state_tensor(temp_config): + """All-floating mode should not require model context for eligibility.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + model = MixedStateModel() + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + server_delta = _aggregate( + strategy, + [_update(1), _update(1)], + [ + { + "trainable": torch.tensor([2.0]), + "frozen": torch.tensor([2.0]), + "floating_buffer": torch.tensor([2.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + }, + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([2], dtype=torch.int64), + "bool_buffer": torch.tensor([True]), + }, + ], + baseline, + ) + + assert torch.allclose(server_delta["trainable"], torch.tensor([2.0])) + assert torch.allclose(server_delta["frozen"], torch.tensor([2.0])) + assert torch.allclose(server_delta["floating_buffer"], torch.tensor([2.0])) + assert torch.equal(server_delta["integer_buffer"], torch.tensor([2])) + assert torch.equal(server_delta["bool_buffer"], torch.tensor([True])) + assert set(strategy.momentum_state) == { + "trainable", + "frozen", + "floating_buffer", + } + + second_delta = _aggregate( + strategy, + [_update(1)], + [ + { + "trainable": torch.tensor([6.0]), + "frozen": torch.tensor([6.0]), + "floating_buffer": torch.tensor([6.0]), + "integer_buffer": torch.tensor([1], dtype=torch.int64), + "bool_buffer": torch.tensor([False]), + } + ], + baseline, + ) + + assert torch.allclose(second_delta["trainable"], torch.tensor([4.0])) + assert torch.allclose(second_delta["frozen"], torch.tensor([4.0])) + assert torch.allclose(second_delta["floating_buffer"], torch.tensor([4.0])) + assert torch.equal(second_delta["integer_buffer"], torch.tensor([1])) + assert torch.equal(second_delta["bool_buffer"], torch.tensor([False])) + assert set(strategy.momentum_state) == { + "trainable", + "frozen", + "floating_buffer", + } + + +def test_parameters_policy_requires_trainer_model_context(temp_config): + """Default parameter eligibility should fail clearly without a model.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + ) + + with pytest.raises(AttributeError, match="context.trainer.model"): + _aggregate( + strategy, + [_update(1)], + [{"trainable": torch.tensor([2.0])}], + {"trainable": torch.tensor([0.0])}, + ) + + @pytest.mark.parametrize( ("kwargs", "match"), [ ({"outer_optimizer": "adam"}, "outer_optimizer"), ({"aggregation_weighting": "weighted"}, "aggregation_weighting"), + ({"apply_outer_optimizer_to": "buffers"}, "apply_outer_optimizer_to"), ({"outer_learning_rate": -0.1}, "outer_learning_rate"), ({"outer_momentum": -0.1}, "outer_momentum"), ({"outer_momentum": 1.0}, "outer_momentum"), From d122831ecb0471b6acbfddfb5975f70309a6586d Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:02:37 -0400 Subject: [PATCH 17/39] Handled adapter payload names in DiLoCo eligibility. DT-413 review found that default parameter eligibility missed trainable adapter parameters when adapter payload keys omit PEFT adapter-name segments, such as lora_A.weight versus lora_A.default.weight. Resolve trainable payload aliases from model adapter metadata and intersect them with the actual floating payload leaves, so exact state_dict keys still work while PEFT-style adapter payloads receive outer optimization and momentum. Added a PEFT-like regression test that fails without the alias mapping and verifies the payload key receives SGDM scaling and momentum state. Validation: uv run pytest tests/servers/test_diloco_strategy.py -k peft_adapter -q; uv run pytest tests/servers/test_diloco_strategy.py; uv run pytest tests/servers/test_fedavg_strategy.py; uv run ruff check . --select I. --- .../servers/strategies/aggregation/diloco.py | 60 ++++++++++++++++--- tests/servers/test_diloco_strategy.py | 41 +++++++++++++ 2 files changed, 94 insertions(+), 7 deletions(-) diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py index a00316a9a..18162fd14 100644 --- a/plato/servers/strategies/aggregation/diloco.py +++ b/plato/servers/strategies/aggregation/diloco.py @@ -182,12 +182,11 @@ def _outer_optimizer_paths( if self.apply_outer_optimizer_to == "all_floating": return self._floating_leaf_paths(avg_delta) - trainable_parameter_names = self._trainable_parameter_names(context) - return self._collect_leaf_paths( - avg_delta, - lambda value, path: path in trainable_parameter_names - and self._is_floating_value(value), + floating_paths = self._floating_leaf_paths(avg_delta) + trainable_parameter_names = self._trainable_parameter_names( + context, floating_paths ) + return floating_paths.intersection(trainable_parameter_names) def _apply_outer_optimizer( self, avg_delta: Any, optimizer_paths: set[str] @@ -255,18 +254,65 @@ def _remove_stale_momentum(self, active_paths: set[str]) -> None: if path not in active_paths: del self.momentum_state[path] - def _trainable_parameter_names(self, context: ServerContext) -> set[str]: + def _trainable_parameter_names( + self, context: ServerContext, payload_paths: set[str] | None = None + ) -> set[str]: model = self._model_from_context(context) + adapter_names = self._adapter_names(model) trainable_names: set[str] = set() for name, parameter in model.named_parameters(): if getattr(parameter, "requires_grad", False) and self._is_floating_value( parameter ): - trainable_names.add(name) + trainable_names.update( + self._payload_name_candidates(name, adapter_names, payload_paths) + ) return trainable_names + @staticmethod + def _adapter_names(model: Any) -> set[str]: + adapter_names = {"default"} + + peft_config = getattr(model, "peft_config", None) + if isinstance(peft_config, Mapping): + adapter_names.update(str(name) for name in peft_config) + + active_adapter = getattr(model, "active_adapter", None) + if isinstance(active_adapter, str): + adapter_names.add(active_adapter) + + active_adapters = getattr(model, "active_adapters", None) + if callable(active_adapters): + try: + adapter_names.update(str(name) for name in active_adapters()) + except TypeError: + pass + elif isinstance(active_adapters, (list, tuple, set)): + adapter_names.update(str(name) for name in active_adapters) + + return adapter_names + + @classmethod + def _payload_name_candidates( + cls, + parameter_name: str, + adapter_names: set[str], + payload_paths: set[str] | None, + ) -> set[str]: + candidates = {parameter_name} + parts = parameter_name.split(".") + for index, part in enumerate(parts): + if part not in adapter_names: + continue + + candidate = ".".join(parts[:index] + parts[index + 1 :]) + if payload_paths is None or candidate in payload_paths: + candidates.add(candidate) + + return candidates + @staticmethod def _model_from_context(context: ServerContext) -> Any: trainer = getattr(context, "trainer", None) diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py index a7b60b78d..772ed5b0a 100644 --- a/tests/servers/test_diloco_strategy.py +++ b/tests/servers/test_diloco_strategy.py @@ -47,6 +47,20 @@ def __init__(self): self.register_buffer("bool_buffer", torch.tensor([True], dtype=torch.bool)) +class PeftLikeAdapterModel(torch.nn.Module): + """Model whose adapter payload keys omit PEFT's default adapter segment.""" + + def __init__(self): + super().__init__() + self.peft_config = {"default": object()} + self.base_model = torch.nn.Module() + self.base_model.model = torch.nn.Module() + self.base_model.model.linear = torch.nn.Module() + self.base_model.model.linear.lora_A = torch.nn.ModuleDict( + {"default": torch.nn.Linear(1, 1, bias=False)} + ) + + def _context(baseline=None, model=None): context = ServerContext() if baseline is not None: @@ -504,6 +518,33 @@ def test_parameters_policy_requires_trainer_model_context(temp_config): ) +def test_parameters_policy_maps_peft_adapter_payload_names(temp_config): + """PEFT payloads can omit adapter-name segments from trainable param names.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = PeftLikeAdapterModel() + payload_name = "base_model.model.linear.lora_A.weight" + baseline = {payload_name: torch.zeros((1, 1))} + + server_delta = _aggregate( + strategy, + [_update(1)], + [{payload_name: torch.full((1, 1), 4.0)}], + baseline, + model, + ) + + assert torch.allclose(server_delta[payload_name], torch.full((1, 1), 2.0)) + assert set(strategy.momentum_state) == {payload_name} + assert torch.allclose( + strategy.momentum_state[payload_name], torch.full((1, 1), -4.0) + ) + + @pytest.mark.parametrize( ("kwargs", "match"), [ From a9d8f3b9bb518e82fe22efd30cebabb3f5ce0ee3 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:11:11 -0400 Subject: [PATCH 18/39] Avoided DiLoCo adapter alias overmatching. DT-413 re-review found that adapter-name aliasing could include a separate floating payload key when the exact trainable parameter key was also present. Make exact payload key matches take precedence over adapter-name removal, so alias candidates are only considered when the original trainable parameter name is absent from the payload. Added a negative collision regression to keep unrelated payload keys on the plain averaged-delta path. Validation: uv run pytest tests/servers/test_diloco_strategy.py -k "adapter_payload_names or alias_collisions" -q; uv run pytest tests/servers/test_diloco_strategy.py; uv run pytest tests/servers/test_fedavg_strategy.py; uv run ruff check . --select I. --- .../servers/strategies/aggregation/diloco.py | 3 ++ tests/servers/test_diloco_strategy.py | 45 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py index 18162fd14..6504990cf 100644 --- a/plato/servers/strategies/aggregation/diloco.py +++ b/plato/servers/strategies/aggregation/diloco.py @@ -302,6 +302,9 @@ def _payload_name_candidates( payload_paths: set[str] | None, ) -> set[str]: candidates = {parameter_name} + if payload_paths is not None and parameter_name in payload_paths: + return candidates + parts = parameter_name.split(".") for index, part in enumerate(parts): if part not in adapter_names: diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py index 772ed5b0a..196928af8 100644 --- a/tests/servers/test_diloco_strategy.py +++ b/tests/servers/test_diloco_strategy.py @@ -61,6 +61,17 @@ def __init__(self): ) +class AdapterAliasCollisionModel(torch.nn.Module): + """Model with a trainable parameter and separate payload key collision.""" + + def __init__(self): + super().__init__() + self.peft_config = {"default": object()} + self.foo = torch.nn.ModuleDict( + {"default": torch.nn.Linear(1, 1, bias=False)} + ) + + def _context(baseline=None, model=None): context = ServerContext() if baseline is not None: @@ -545,6 +556,40 @@ def test_parameters_policy_maps_peft_adapter_payload_names(temp_config): ) +def test_parameters_policy_does_not_overmatch_adapter_alias_collisions(temp_config): + """Alias support should not optimize unrelated colliding payload names.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="sgdm", + outer_learning_rate=0.5, + outer_momentum=0.5, + aggregation_weighting="uniform", + ) + model = AdapterAliasCollisionModel() + trainable_name = "foo.default.weight" + colliding_name = "foo.weight" + baseline = { + trainable_name: torch.zeros((1, 1)), + colliding_name: torch.zeros((1, 1)), + } + + server_delta = _aggregate( + strategy, + [_update(1)], + [ + { + trainable_name: torch.full((1, 1), 4.0), + colliding_name: torch.full((1, 1), 4.0), + } + ], + baseline, + model, + ) + + assert torch.allclose(server_delta[trainable_name], torch.full((1, 1), 2.0)) + assert torch.allclose(server_delta[colliding_name], torch.full((1, 1), 4.0)) + assert set(strategy.momentum_state) == {trainable_name} + + @pytest.mark.parametrize( ("kwargs", "match"), [ From 679c9f6ba7ccac84af03feea0a0deebebbc5b38b Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:13:31 -0400 Subject: [PATCH 19/39] Persisted in-process optimizer state for DiLoCo. DT-418 adds trainer.preserve_optimizer_state for the PyTorch ComposableTrainer in-process path so client-local AdamW and scheduler state survive communication rounds without entering client-server payloads. The trainer caches optimizer and scheduler state per logical client, restores it after creating the next round optimizer/scheduler, and discards cached state when optimizer type, scheduler type, parameter names, shapes, dtypes, or optimizer parameter ordering no longer match. Focused tests cover AdamW moment persistence, scheduler LR progress, logical-client isolation, payload locality, disabled behavior, optimizer changes, parameter-order changes, and shape/dtype/scheduler compatibility rejection. Validation: uv run pytest tests/trainers/test_composable_optimizer_state.py -q; uv run pytest tests/trainers/test_composable_trainer.py -q; uv run pytest tests/trainers -k optimizer_state-or-scheduler_state-or-composable -q --ignore=tests/trainers/test_dp_data_loader_strategy.py; uv run ruff check . --select I. The unignored trainer selector still hits the repo optional opacus collection dependency in tests/trainers/test_dp_data_loader_strategy.py. --- plato/trainers/composable.py | 147 +++++++++ .../test_composable_optimizer_state.py | 304 ++++++++++++++++++ 2 files changed, 451 insertions(+) create mode 100644 tests/trainers/test_composable_optimizer_state.py diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index f2bc2f0ca..8f16f0656 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -168,6 +168,7 @@ def __init__( self.current_epoch = 0 self.training_start_time = time.time() self.model_state_dict = None + self._preserved_optimizer_states: dict[int, dict[str, Any]] = {} def _require_model(self) -> nn.Module: """Return the underlying model, ensuring it is available.""" @@ -200,6 +201,143 @@ def _record_local_optimizer_step(self, local_steps_per_round: int | None) -> boo self.context.state["local_optimizer_steps"] = completed_steps return completed_steps >= local_steps_per_round + @staticmethod + def _preserve_optimizer_state(config: dict[str, Any]) -> bool: + """Return whether optimizer state should survive local train runs.""" + return bool(config.get("preserve_optimizer_state", False)) + + @staticmethod + def _parameter_signature(name: str | None, parameter: torch.Tensor): + """Build a compatibility signature for one model parameter.""" + return (name, tuple(parameter.shape), str(parameter.dtype)) + + @classmethod + def _model_parameter_signature(cls, model: nn.Module): + """Return parameter names, shapes, dtypes, and order for a model.""" + return tuple( + cls._parameter_signature(name, parameter) + for name, parameter in model.named_parameters() + ) + + @classmethod + def _optimizer_parameter_signature( + cls, model: nn.Module, optimizer: torch.optim.Optimizer + ): + """Return optimizer parameter group ordering with model metadata.""" + named_parameters = { + id(parameter): cls._parameter_signature(name, parameter) + for name, parameter in model.named_parameters() + } + + group_signatures = [] + for group in optimizer.param_groups: + group_signatures.append( + tuple( + named_parameters.get( + id(parameter), + cls._parameter_signature(None, parameter), + ) + for parameter in group.get("params", []) + ) + ) + + return tuple(group_signatures) + + @staticmethod + def _scheduler_type(scheduler: Any | None) -> type | None: + """Return the scheduler type used for compatibility checks.""" + if scheduler is None: + return None + return type(scheduler) + + def _preserved_state_is_compatible( + self, + payload: dict[str, Any], + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: Any | None, + ) -> bool: + """Return whether a cached optimizer bundle matches this train run.""" + if payload.get("optimizer_type") is not type(optimizer): + return False + + if payload.get("scheduler_type") is not self._scheduler_type(scheduler): + return False + + if payload.get("model_parameters") != self._model_parameter_signature(model): + return False + + if payload.get("optimizer_parameters") != self._optimizer_parameter_signature( + model, optimizer + ): + return False + + if not callable(getattr(optimizer, "load_state_dict", None)): + return False + + if payload.get("scheduler_state") is not None and not callable( + getattr(scheduler, "load_state_dict", None) + ): + return False + + return True + + def _restore_preserved_optimizer_state(self) -> None: + """Restore compatible optimizer and scheduler state for this client.""" + payload = self._preserved_optimizer_states.get(self.client_id) + if payload is None or self.optimizer is None: + return + + model = self._require_model() + if not self._preserved_state_is_compatible( + payload, model, self.optimizer, self.lr_scheduler + ): + self._preserved_optimizer_states.pop(self.client_id, None) + return + + try: + scheduler_state = payload.get("scheduler_state") + if scheduler_state is not None: + self.lr_scheduler.load_state_dict(copy.deepcopy(scheduler_state)) + + self.optimizer.load_state_dict(copy.deepcopy(payload["optimizer_state"])) + except Exception as error: + logging.debug( + "[Client #%d] Discarding incompatible optimizer state: %s", + self.client_id, + error, + ) + self._preserved_optimizer_states.pop(self.client_id, None) + self.optimizer = self.optimizer_strategy.create_optimizer( + model, self.context + ) + self.lr_scheduler = self.lr_scheduler_strategy.create_scheduler( + self.optimizer, self.context + ) + + def _save_preserved_optimizer_state(self) -> None: + """Save optimizer and scheduler state locally for this logical client.""" + if self.optimizer is None: + return + + model = self._require_model() + scheduler_state = None + if self.lr_scheduler is not None: + state_dict_fn = getattr(self.lr_scheduler, "state_dict", None) + if callable(state_dict_fn): + scheduler_state = copy.deepcopy(state_dict_fn()) + + self._preserved_optimizer_states[self.client_id] = { + "optimizer_type": type(self.optimizer), + "optimizer_state": copy.deepcopy(self.optimizer.state_dict()), + "scheduler_type": self._scheduler_type(self.lr_scheduler), + "scheduler_state": scheduler_state, + "model_parameters": self._model_parameter_signature(model), + "optimizer_parameters": self._optimizer_parameter_signature( + model, self.optimizer + ), + } + @staticmethod def _persisted_test_state_keys() -> tuple[str, ...]: """State keys that must survive spawned test subprocesses.""" @@ -420,6 +558,10 @@ def train_model(self, config, trainset, sampler, **kwargs): self.sampler = sampler self.context.config = config self.context.current_round = self.current_round + preserve_optimizer_state = self._preserve_optimizer_state(config) + if not preserve_optimizer_state: + self._preserved_optimizer_states.pop(self.client_id, None) + local_steps_per_round = self._local_steps_per_round(config) self.context.state["local_optimizer_steps"] = 0 if local_steps_per_round is None: @@ -513,6 +655,8 @@ def train_model(self, config, trainset, sampler, **kwargs): self.lr_scheduler = self.lr_scheduler_strategy.create_scheduler( self.optimizer, self.context ) + if preserve_optimizer_state: + self._restore_preserved_optimizer_state() # Move model to device model = self._require_model() @@ -744,6 +888,9 @@ def compute_loss(outputs, labels_inner): # Callbacks: train run end self.callback_handler.call_event("on_train_run_end", self, config) + if preserve_optimizer_state: + self._save_preserved_optimizer_state() + def train(self, trainset, sampler, **kwargs) -> float: """ The main training loop in a federated learning workload. diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py new file mode 100644 index 000000000..f37289625 --- /dev/null +++ b/tests/trainers/test_composable_optimizer_state.py @@ -0,0 +1,304 @@ +"""Tests for in-process optimizer state preservation in ComposableTrainer.""" + +import copy +from collections import OrderedDict + +import pytest +import torch +import torch.nn as nn +from torch.utils.data import TensorDataset + +from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies import ( + AdamWOptimizerStrategy, + CrossEntropyLossStrategy, + DefaultTrainingStepStrategy, + SGDOptimizerStrategy, + StepLRSchedulerStrategy, +) + + +@pytest.fixture +def tiny_dataset(): + features = torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + [-1.0, 0.5], + ], + dtype=torch.float32, + ) + labels = torch.tensor([0, 1, 0, 1], dtype=torch.long) + return TensorDataset(features, labels) + + +@pytest.fixture +def one_step_config(): + return { + "batch_size": 4, + "epochs": 1, + "lr": 0.01, + "run_id": "optimizer-state-test", + } + + +class CapturingTrainingStep(DefaultTrainingStepStrategy): + """Record optimizer state before each local optimizer step.""" + + def __init__(self): + super().__init__() + self.pre_step_states = [] + self.pre_step_lrs = [] + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + optimizer_state = optimizer.state_dict() + self.pre_step_states.append(copy.deepcopy(optimizer_state["state"])) + self.pre_step_lrs.append( + [group["lr"] for group in optimizer_state["param_groups"]] + ) + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + + +def _linear_model(): + return nn.Sequential(OrderedDict([("linear", nn.Linear(2, 2))])) + + +def _two_layer_model(first_name="first", second_name="second"): + return nn.Sequential( + OrderedDict( + [ + (first_name, nn.Linear(2, 2, bias=False)), + (second_name, nn.Linear(2, 2, bias=False)), + ] + ) + ) + + +def _first_param_state(optimizer_state): + return next(iter(optimizer_state.values())) + + +def _state_step(param_state): + step = param_state["step"] + if isinstance(step, torch.Tensor): + return int(step.item()) + return int(step) + + +def test_adamw_moment_buffers_persist_between_rounds_for_same_client( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + trainer.set_client_id(7) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + round1_state = copy.deepcopy(trainer.optimizer.state_dict()["state"]) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[0] == {} + restored_state = _first_param_state(step_strategy.pre_step_states[1]) + saved_state = _first_param_state(round1_state) + assert torch.allclose(restored_state["exp_avg"], saved_state["exp_avg"]) + assert torch.allclose(restored_state["exp_avg_sq"], saved_state["exp_avg_sq"]) + final_param_state = _first_param_state(trainer.optimizer.state_dict()["state"]) + assert _state_step(final_param_state) == 2 + + +def test_scheduler_state_and_lr_progress_persist_between_rounds( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=SGDOptimizerStrategy(lr=0.2), + training_step_strategy=step_strategy, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + trainer.set_client_id(3) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_lrs == [[0.2], [0.1]] + assert trainer.lr_scheduler.last_epoch == 2 + assert trainer.optimizer.param_groups[0]["lr"] == pytest.approx(0.05) + + +def test_preserved_optimizer_state_is_local_to_logical_client( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.set_client_id(1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + client1_state = copy.deepcopy(trainer.optimizer.state_dict()["state"]) + + trainer.set_client_id(2) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + trainer.set_client_id(1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[0] == {} + assert step_strategy.pre_step_states[1] == {} + restored_state = _first_param_state(step_strategy.pre_step_states[2]) + saved_state = _first_param_state(client1_state) + assert torch.allclose(restored_state["exp_avg"], saved_state["exp_avg"]) + + +def test_preserved_state_stays_out_of_model_update_payload( + temp_config, tiny_dataset, one_step_config +): + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + update = trainer.obtain_model_update( + config, tiny_dataset, list(range(len(tiny_dataset))) + ) + + assert "optimizer_state" not in update + assert "scheduler_state" not in update + assert all(torch.is_tensor(value) for value in update.values()) + + +def test_preserved_state_invalidates_when_parameter_order_changes( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=lambda: _two_layer_model("first", "second"), + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.model = _two_layer_model("second", "first") + trainer.context.model = trainer.model + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[1] == {} + + +def test_preserved_state_invalidates_when_optimizer_type_changes( + temp_config, tiny_dataset, one_step_config +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.optimizer_strategy = SGDOptimizerStrategy(lr=0.1) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states[1] == {} + assert isinstance(trainer.optimizer, torch.optim.SGD) + + +def test_preserved_state_compatibility_rejects_shape_dtype_and_scheduler_changes( + temp_config, tiny_dataset, one_step_config +): + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(4) + config = {**one_step_config, "preserve_optimizer_state": True} + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + payload = copy.deepcopy(trainer._preserved_optimizer_states[4]) + + current_model = trainer.model + current_optimizer = trainer.optimizer_strategy.create_optimizer( + current_model, trainer.context + ) + changed_scheduler = StepLRSchedulerStrategy( + step_size=1, gamma=0.5 + ).create_scheduler(current_optimizer, trainer.context) + assert not trainer._preserved_state_is_compatible( + payload, current_model, current_optimizer, changed_scheduler + ) + + changed_shape_model = nn.Sequential( + OrderedDict([("linear", nn.Linear(2, 3))]) + ) + changed_shape_optimizer = trainer.optimizer_strategy.create_optimizer( + changed_shape_model, trainer.context + ) + assert not trainer._preserved_state_is_compatible( + payload, changed_shape_model, changed_shape_optimizer, None + ) + + changed_dtype_model = _linear_model().to(torch.float64) + changed_dtype_optimizer = trainer.optimizer_strategy.create_optimizer( + changed_dtype_model, trainer.context + ) + assert not trainer._preserved_state_is_compatible( + payload, changed_dtype_model, changed_dtype_optimizer, None + ) + + +@pytest.mark.parametrize("preserve_value", [None, False]) +def test_optimizer_state_is_not_restored_when_disabled_or_unset( + temp_config, tiny_dataset, one_step_config, preserve_value +): + step_strategy = CapturingTrainingStep() + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + training_step_strategy=step_strategy, + ) + config = dict(one_step_config) + if preserve_value is not None: + config["preserve_optimizer_state"] = preserve_value + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert step_strategy.pre_step_states == [{}, {}] + final_param_state = _first_param_state(trainer.optimizer.state_dict()["state"]) + assert _state_step(final_param_state) == 1 From 114f2a7f05461ec914d89d4fab360ecd9c672e55 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:21:42 -0400 Subject: [PATCH 20/39] Wired DiLoCo server selection. DT-414 adds server.type=diloco as a FedAvg-compatible server that injects DiLoCoAggregationStrategy while keeping algorithm.type=fedavg. The FedAvg delta path now filters non-weight reports before compute_weight_deltas(), so feature or metrics payloads cannot crash delta-only strategies before strategy eligibility handling. DiLoCo remains on aggregate_deltas and does not use inherited direct weight aggregation. Server-level tests cover registry/config selection, delta-path dispatch, inherited aggregate_weights avoidance, non-weight payload filtering, and existing FedAvg delta-strategy behavior. Validation: uv run pytest tests/servers/test_diloco_strategy.py -q; uv run pytest tests/servers/test_fedavg_strategy.py -q; uv run ruff check . --select I. --- plato/servers/diloco.py | 50 +++++++ plato/servers/fedavg.py | 27 +++- plato/servers/registry.py | 2 + tests/servers/test_diloco_strategy.py | 179 ++++++++++++++++++++++++++ 4 files changed, 255 insertions(+), 3 deletions(-) create mode 100644 plato/servers/diloco.py diff --git a/plato/servers/diloco.py b/plato/servers/diloco.py new file mode 100644 index 000000000..bedbb5f05 --- /dev/null +++ b/plato/servers/diloco.py @@ -0,0 +1,50 @@ +"""FedAvg-compatible server using DiLoCo aggregation.""" + +from plato.config import Config +from plato.servers import fedavg +from plato.servers.strategies.aggregation import DiLoCoAggregationStrategy + + +class Server(fedavg.Server): + """Federated learning server with server-side DiLoCo outer aggregation.""" + + def __init__( + self, + model=None, + datasource=None, + algorithm=None, + trainer=None, + callbacks=None, + aggregation_strategy=None, + client_selection_strategy=None, + ): + if aggregation_strategy is None: + aggregation_strategy = DiLoCoAggregationStrategy( + **self._aggregation_config() + ) + + super().__init__( + model=model, + datasource=datasource, + algorithm=algorithm, + trainer=trainer, + callbacks=callbacks, + aggregation_strategy=aggregation_strategy, + client_selection_strategy=client_selection_strategy, + ) + + @staticmethod + def _aggregation_config() -> dict: + """Read optional DiLoCo aggregation settings from [server.diloco].""" + config = getattr(Config().server, "diloco", None) + if config is None: + return {} + + keys = ( + "outer_optimizer", + "outer_learning_rate", + "outer_momentum", + "aggregation_weighting", + "apply_outer_optimizer_to", + ) + return {key: getattr(config, key) for key in keys if hasattr(config, key)} diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index 5f862fc35..e357adc31 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -222,13 +222,20 @@ async def _process_reports(self): # Use delta aggregation (default path) # Computes the weight deltas by comparing the weights received with # the current global model weights - deltas_received = algorithm.compute_weight_deltas( - baseline_weights, weights_received + delta_updates, delta_weights_received = ( + self._weight_updates_and_payloads(self.updates, weights_received) + ) + deltas_received = ( + algorithm.compute_weight_deltas( + baseline_weights, delta_weights_received + ) + if delta_weights_received + else [] ) # Runs a framework-agnostic server aggregation algorithm, such as # the federated averaging algorithm logging.info("[Server #%d] Aggregating model weight deltas.", os.getpid()) - deltas = await self.aggregate_deltas(self.updates, deltas_received) + deltas = await self.aggregate_deltas(delta_updates, deltas_received) # Updates the existing model weights from the provided deltas updated_weights = algorithm.update_weights(deltas) # Loads the new model weights @@ -299,6 +306,20 @@ def _should_prefer_weight_aggregation(self) -> bool: and aggregate_deltas_impl is not FedAvgAggregationStrategy.aggregate_deltas ) + @staticmethod + def _weight_updates_and_payloads(updates, weights_received): + """Return update/payload pairs whose reports contain model weights.""" + delta_updates = [] + delta_weights_received = [] + + for update, weights in zip(updates, weights_received): + if getattr(update.report, "type", "weights") != "weights": + continue + delta_updates.append(update) + delta_weights_received.append(weights) + + return delta_updates, delta_weights_received + def clients_processed(self) -> None: """Additional work to be performed after client reports have been processed.""" diff --git a/plato/servers/registry.py b/plato/servers/registry.py index 2211b8972..b26465a8c 100644 --- a/plato/servers/registry.py +++ b/plato/servers/registry.py @@ -10,6 +10,7 @@ from plato.config import Config from plato.servers import ( + diloco, fedavg, fedavg_cs, fedavg_gan, @@ -30,6 +31,7 @@ registered_servers = { "fedavg": fedavg.Server, "fedavg_lora": fedavg.Server, + "diloco": diloco.Server, "fedavg_cross_silo": fedavg_cs.Server, "fedavg_gan": fedavg_gan.Server, "fedavg_personalized": fedavg_personalized.Server, diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py index 196928af8..f35467344 100644 --- a/tests/servers/test_diloco_strategy.py +++ b/tests/servers/test_diloco_strategy.py @@ -35,6 +35,60 @@ def compute_weight_deltas(self, baseline_weights, weights_list): ] +class ServerAlgorithm(DummyAlgorithm): + """Algorithm stub for exercising FedAvg-compatible server dispatch.""" + + def __init__(self, baseline): + self.current = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in baseline.items() + } + self.delta_payloads = None + + def extract_weights(self): + return { + name: value.clone() if hasattr(value, "clone") else value + for name, value in self.current.items() + } + + def compute_weight_deltas(self, baseline_weights, weights_list): + self.delta_payloads = weights_list + return super().compute_weight_deltas(baseline_weights, weights_list) + + def update_weights(self, deltas): + self.current = { + name: self.current[name] + deltas[name] for name in self.current + } + return self.extract_weights() + + def load_weights(self, weights): + self.current = { + name: value.clone() if hasattr(value, "clone") else value + for name, value in weights.items() + } + + +class RecordingDiLoCoStrategy(DiLoCoAggregationStrategy): + """DiLoCo strategy recording server dispatch calls.""" + + def __init__(self): + super().__init__( + outer_optimizer="sgd", + outer_learning_rate=1.0, + aggregation_weighting="uniform", + apply_outer_optimizer_to="all_floating", + ) + self.delta_calls = 0 + self.last_updates = None + self.last_deltas = None + + async def aggregate_deltas(self, updates, deltas_received, context): + self.delta_calls += 1 + self.last_updates = updates + self.last_deltas = deltas_received + return await super().aggregate_deltas(updates, deltas_received, context) + + class MixedStateModel(torch.nn.Module): """Model exposing trainable, frozen, floating-buffer, and integer state.""" @@ -87,12 +141,137 @@ def _update(num_samples, report_type="weights"): ) +def _server_update(payload, num_samples=1, report_type="weights"): + update = _update(num_samples, report_type) + update.client_id = len(str(payload)) + update.report.accuracy = 0.5 + update.report.processing_time = 0.1 + update.report.comm_time = 0.1 + update.report.training_time = 0.1 + update.payload = payload + return update + + def _aggregate(strategy, updates, deltas, baseline=None, model=None): return asyncio.run( strategy.aggregate_deltas(updates, deltas, _context(baseline, model)) ) +def test_diloco_server_type_uses_fedavg_algorithm_and_strategy(temp_config): + """server.type=diloco should select a FedAvg-compatible DiLoCo server.""" + from plato.algorithms import registry as algorithms_registry + from plato.config import Config + from plato.servers import diloco as diloco_server + from plato.servers import fedavg + from plato.servers import registry as servers_registry + + Config().server.type = "diloco" + Config().algorithm.type = "fedavg" + Config().server.diloco = SimpleNamespace( + outer_optimizer="sgd", + outer_learning_rate=0.25, + outer_momentum=0.1, + aggregation_weighting="num_samples", + apply_outer_optimizer_to="all_floating", + ) + + server = servers_registry.get() + + assert isinstance(server, diloco_server.Server) + assert isinstance(server, fedavg.Server) + assert isinstance(server.aggregation_strategy, DiLoCoAggregationStrategy) + assert server.aggregation_strategy.outer_optimizer == "sgd" + assert server.aggregation_strategy.outer_learning_rate == 0.25 + assert server.aggregation_strategy.outer_momentum == 0.1 + assert server.aggregation_strategy.aggregation_weighting == "num_samples" + assert server.aggregation_strategy.apply_outer_optimizer_to == "all_floating" + assert Config().algorithm.type == "fedavg" + assert "diloco" not in algorithms_registry.registered_algorithms + + +def test_diloco_server_process_reports_uses_delta_aggregation(temp_config): + """DiLoCo server processing should reach the delta aggregation path.""" + from plato.config import Config + from plato.servers import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + server.updates = [ + _server_update({"w": torch.tensor([2.0])}), + _server_update({"w": torch.tensor([4.0])}), + ] + + asyncio.run(server._process_reports()) + + assert strategy.delta_calls == 1 + assert strategy.last_updates == server.updates + assert len(strategy.last_deltas) == 2 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([3.0])) + + +def test_diloco_server_does_not_use_inherited_weight_aggregation(temp_config): + """DiLoCo must not bypass delta aggregation via inherited FedAvg weights.""" + from plato.config import Config + from plato.servers import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + + async def fail_if_called(*_args, **_kwargs): + raise AssertionError("Inherited aggregate_weights() must not be called.") + + strategy.aggregate_weights = fail_if_called + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + server.updates = [_server_update({"w": torch.tensor([2.0])})] + + asyncio.run(server._process_reports()) + + assert strategy.delta_calls == 1 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([2.0])) + + +def test_diloco_server_filters_non_weight_reports_before_delta_computation( + temp_config, +): + """Non-weight payloads should not reach compute_weight_deltas().""" + from plato.config import Config + from plato.servers import diloco + + Config().server.do_test = False + strategy = RecordingDiLoCoStrategy() + server = diloco.Server(aggregation_strategy=strategy) + baseline = {"w": torch.zeros(1)} + server.algorithm = ServerAlgorithm(baseline) + server.context.algorithm = server.algorithm + server.context.server = server + server.context.state["prng_state"] = None + weight_payload = {"w": torch.tensor([2.0])} + server.updates = [ + _server_update("feature payload", report_type="features"), + _server_update({"metrics": 1.0}, report_type="metrics"), + _server_update(weight_payload), + ] + + asyncio.run(server._process_reports()) + + assert server.algorithm.delta_payloads == [weight_payload] + assert strategy.last_updates == [server.updates[2]] + assert len(strategy.last_deltas) == 1 + assert torch.allclose(server.algorithm.current["w"], torch.tensor([2.0])) + + def test_sgd_lr_one_uniform_matches_uniform_model_averaging(temp_config): """Outer SGD with lr=1 should match uniform averaging under uniform mode.""" strategy = DiLoCoAggregationStrategy( From da0f3413af89af9c31716feb98f7bbaa14010585 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:29:27 -0400 Subject: [PATCH 21/39] Persisted optimizer state across train subprocesses. DT-420 extends trainer.preserve_optimizer_state to ComposableTrainer subprocess training by using a local optimizer-state sidecar under the configured model path. Child training loads any preserved sidecar before train_model(), saves updated optimizer and scheduler state after training, and the parent reloads the sidecar after the trained model is loaded. Missing, unreadable, invalid, or incompatible state falls back to fresh optimizer/scheduler state with explicit logging. Tests cover parent reload, optimizer state persistence across two subprocess rounds, scheduler progress, invalid sidecar reset, disabled behavior, and payload non-leakage. State remains local and is not added to network payloads. Validation: uv run pytest tests/trainers/test_composable_optimizer_state.py -k "subprocess and (optimizer_state or scheduler_state)" -q; uv run pytest tests/trainers/test_composable_optimizer_state.py -q; uv run pytest tests/trainers/test_composable_trainer.py -q; uv run pytest tests/trainers -k "subprocess and (optimizer_state or scheduler_state)" -q --ignore=tests/trainers/test_dp_data_loader_strategy.py; uv run ruff check . --select I; git diff --check. --- plato/trainers/composable.py | 105 ++++++++++- .../test_composable_optimizer_state.py | 163 ++++++++++++++++++ 2 files changed, 266 insertions(+), 2 deletions(-) diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index 8f16f0656..fc293302d 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -292,6 +292,11 @@ def _restore_preserved_optimizer_state(self) -> None: if not self._preserved_state_is_compatible( payload, model, self.optimizer, self.lr_scheduler ): + logging.info( + "[Client #%d] Discarding incompatible optimizer state; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + ) self._preserved_optimizer_states.pop(self.client_id, None) return @@ -302,8 +307,9 @@ def _restore_preserved_optimizer_state(self) -> None: self.optimizer.load_state_dict(copy.deepcopy(payload["optimizer_state"])) except Exception as error: - logging.debug( - "[Client #%d] Discarding incompatible optimizer state: %s", + logging.warning( + "[Client #%d] Discarding incompatible optimizer state; " + "starting with fresh optimizer and scheduler state: %s", self.client_id, error, ) @@ -338,6 +344,82 @@ def _save_preserved_optimizer_state(self) -> None: ), } + def _optimizer_state_filename(self, run_id: str) -> str: + """Return the local optimizer-state handoff filename.""" + model_name = Config().trainer.model_name + return f"{model_name}_{self.client_id}_{run_id}.optim.pkl" + + def _optimizer_state_path(self, filename: str) -> str: + """Return the local optimizer-state handoff path.""" + return os.path.join(Config().params["model_path"], filename) + + def _save_preserved_optimizer_state_file(self, filename: str) -> None: + """Persist preserved optimizer state for subprocess handoff.""" + payload = self._preserved_optimizer_states.get(self.client_id) + if payload is None: + return + + model_path = Config().params["model_path"] + os.makedirs(model_path, exist_ok=True) + state_path = self._optimizer_state_path(filename) + tmp_path = f"{state_path}.{os.getpid()}.tmp" + + try: + with open(tmp_path, "wb") as state_file: + pickle.dump(copy.deepcopy(payload), state_file) + os.replace(tmp_path, state_path) + except Exception as error: + if os.path.exists(tmp_path): + os.remove(tmp_path) + logging.warning( + "[Client #%d] Failed to persist optimizer state to %s: %s", + self.client_id, + state_path, + error, + ) + + def _load_preserved_optimizer_state_file( + self, filename: str, *, clear_on_missing: bool = False + ) -> None: + """Load preserved optimizer state from a subprocess handoff file.""" + state_path = self._optimizer_state_path(filename) + if not os.path.exists(state_path): + if clear_on_missing: + self._preserved_optimizer_states.pop(self.client_id, None) + logging.info( + "[Client #%d] No persisted optimizer state found at %s; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + state_path, + ) + return + + try: + with open(state_path, "rb") as state_file: + payload = pickle.load(state_file) + except Exception as error: + self._preserved_optimizer_states.pop(self.client_id, None) + logging.warning( + "[Client #%d] Discarding unreadable optimizer state at %s; " + "starting with fresh optimizer and scheduler state: %s", + self.client_id, + state_path, + error, + ) + return + + if not isinstance(payload, dict): + self._preserved_optimizer_states.pop(self.client_id, None) + logging.warning( + "[Client #%d] Discarding invalid optimizer state at %s; " + "starting with fresh optimizer and scheduler state.", + self.client_id, + state_path, + ) + return + + self._preserved_optimizer_states[self.client_id] = payload + @staticmethod def _persisted_test_state_keys() -> tuple[str, ...]: """State keys that must survive spawned test subprocesses.""" @@ -545,8 +627,16 @@ def simulate_sleep_time(self): def train_process(self, config, trainset, sampler, **kwargs): """The training process in a federated learning workload.""" + preserve_optimizer_state = self._preserve_optimizer_state(config) + if preserve_optimizer_state: + optimizer_state_filename = self._optimizer_state_filename(config["run_id"]) + self._load_preserved_optimizer_state_file(optimizer_state_filename) + self.train_model(config, trainset, sampler, **kwargs) + if preserve_optimizer_state: + self._save_preserved_optimizer_state_file(optimizer_state_filename) + model_name = Config().trainer.model_name filename = f"{model_name}_{self.client_id}_{config['run_id']}.safetensors" self.save_model(filename) @@ -911,6 +1001,12 @@ def train(self, trainset, sampler, **kwargs) -> float: if "max_concurrency" in config: tic = time.perf_counter() + preserve_optimizer_state = self._preserve_optimizer_state(config) + optimizer_state_filename = None + if preserve_optimizer_state: + optimizer_state_filename = self._optimizer_state_filename( + config["run_id"] + ) if mp.get_start_method(allow_none=True) != "spawn": mp.set_start_method("spawn", force=True) @@ -963,6 +1059,11 @@ def train(self, trainset, sampler, **kwargs) -> float: f"Training on client {self.client_id} failed." ) from error + if optimizer_state_filename is not None: + self._load_preserved_optimizer_state_file( + optimizer_state_filename, clear_on_missing=True + ) + toc = time.perf_counter() self.pause_training() else: diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py index f37289625..59f57fd2e 100644 --- a/tests/trainers/test_composable_optimizer_state.py +++ b/tests/trainers/test_composable_optimizer_state.py @@ -1,13 +1,18 @@ """Tests for in-process optimizer state preservation in ComposableTrainer.""" import copy +import os +import pickle +import sys from collections import OrderedDict +from pathlib import Path import pytest import torch import torch.nn as nn from torch.utils.data import TensorDataset +from plato.config import Config from plato.trainers.composable import ComposableTrainer from plato.trainers.strategies import ( AdamWOptimizerStrategy, @@ -101,6 +106,38 @@ def _state_step(param_state): return int(step) +def _configure_subprocess_training( + monkeypatch, + tmp_path, + *, + preserve_optimizer_state, +): + """Configure parent and spawned child processes to share local artifacts.""" + model_path = Path(tmp_path) / "models" / "pretrained" + model_path.mkdir(parents=True, exist_ok=True) + Config.params["model_path"] = str(model_path) + Config.params["checkpoint_path"] = str(Path(tmp_path) / "checkpoints") + Config.params["run_id"] = "subprocess-optimizer-state" + os.makedirs(Config.params["checkpoint_path"], exist_ok=True) + monkeypatch.setattr(sys, "argv", [sys.argv[0], "-b", str(tmp_path)]) + Config().trainer = Config().trainer._replace( + max_concurrency=1, + preserve_optimizer_state=preserve_optimizer_state, + batch_size=4, + epochs=1, + ) + + +def _cached_optimizer_step(trainer): + payload = trainer._preserved_optimizer_states[trainer.client_id] + return _state_step(_first_param_state(payload["optimizer_state"]["state"])) + + +def _cached_scheduler_last_epoch(trainer): + payload = trainer._preserved_optimizer_states[trainer.client_id] + return payload["scheduler_state"]["last_epoch"] + + def test_adamw_moment_buffers_persist_between_rounds_for_same_client( temp_config, tiny_dataset, one_step_config ): @@ -149,6 +186,132 @@ def test_scheduler_state_and_lr_progress_persist_between_rounds( assert trainer.optimizer.param_groups[0]["lr"] == pytest.approx(0.05) +def test_subprocess_optimizer_state_parent_reloads_after_child( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert trainer.client_id in trainer._preserved_optimizer_states + assert _cached_optimizer_step(trainer) == 1 + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + assert state_path.exists() + assert "optimizer_state" not in trainer.obtain_model_update( + { + "batch_size": 4, + "epochs": 1, + "lr": 0.01, + "run_id": "payload-check", + "preserve_optimizer_state": True, + }, + tiny_dataset, + list(range(len(tiny_dataset))), + ) + + +def test_subprocess_optimizer_state_persists_across_rounds_for_same_client( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert _cached_optimizer_step(trainer) == 2 + + +def test_subprocess_scheduler_state_persists_across_rounds( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=SGDOptimizerStrategy(lr=0.2), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + trainer.set_client_id(3) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + payload = trainer._preserved_optimizer_states[trainer.client_id] + assert _cached_scheduler_last_epoch(trainer) == 2 + assert payload["optimizer_state"]["param_groups"][0]["lr"] == pytest.approx(0.05) + + +def test_subprocess_invalid_optimizer_state_resets_safely( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + with open(state_path, "wb") as state_file: + pickle.dump({"optimizer_type": torch.optim.SGD}, state_file) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + payload = trainer._preserved_optimizer_states[trainer.client_id] + assert payload["optimizer_type"] is torch.optim.AdamW + assert _cached_optimizer_step(trainer) == 1 + + +def test_subprocess_optimizer_state_is_not_persisted_when_disabled( + temp_config, monkeypatch, tmp_path, tiny_dataset +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=False + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert trainer._preserved_optimizer_states == {} + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + assert not state_path.exists() + + def test_preserved_optimizer_state_is_local_to_logical_client( temp_config, tiny_dataset, one_step_config ): From c7307320096a3f6881a0f185a7410303931aa9aa Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:39:42 -0400 Subject: [PATCH 22/39] Hardened subprocess optimizer state handoff. DT-421 review found that a missing optimizer sidecar could leave inherited parent cache active in the child, and that parent reload could confuse stale input with current child output. The child now clears inherited cache when the input sidecar is missing. Subprocess training writes to a unique child output sidecar, the parent loads that output, promotes it to the stable input sidecar for the next round, and removes stale stable state if child output is missing or invalid. Added regressions for missing input sidecars clearing inherited cache and missing child output removing stale input sidecars. Validation: uv run pytest tests/trainers/test_composable_optimizer_state.py -k "subprocess or sidecar" -q; uv run pytest tests/trainers/test_composable_optimizer_state.py -q; uv run pytest tests/trainers/test_composable_trainer.py -q; uv run pytest tests/trainers -k "subprocess and (optimizer_state or scheduler_state)" -q --ignore=tests/trainers/test_dp_data_loader_strategy.py; uv run ruff check . --select I; git diff --check. --- plato/trainers/composable.py | 82 ++++++++++++++++--- .../test_composable_optimizer_state.py | 76 +++++++++++++++++ 2 files changed, 146 insertions(+), 12 deletions(-) diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index fc293302d..78ceb07bf 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -349,15 +349,21 @@ def _optimizer_state_filename(self, run_id: str) -> str: model_name = Config().trainer.model_name return f"{model_name}_{self.client_id}_{run_id}.optim.pkl" + def _optimizer_state_output_filename(self, run_id: str) -> str: + """Return a unique subprocess optimizer-state output filename.""" + model_name = Config().trainer.model_name + token = time.time_ns() + return f"{model_name}_{self.client_id}_{run_id}_{os.getpid()}_{token}.optim.pkl" + def _optimizer_state_path(self, filename: str) -> str: """Return the local optimizer-state handoff path.""" return os.path.join(Config().params["model_path"], filename) - def _save_preserved_optimizer_state_file(self, filename: str) -> None: + def _save_preserved_optimizer_state_file(self, filename: str) -> bool: """Persist preserved optimizer state for subprocess handoff.""" payload = self._preserved_optimizer_states.get(self.client_id) if payload is None: - return + return False model_path = Config().params["model_path"] os.makedirs(model_path, exist_ok=True) @@ -368,6 +374,7 @@ def _save_preserved_optimizer_state_file(self, filename: str) -> None: with open(tmp_path, "wb") as state_file: pickle.dump(copy.deepcopy(payload), state_file) os.replace(tmp_path, state_path) + return True except Exception as error: if os.path.exists(tmp_path): os.remove(tmp_path) @@ -377,10 +384,11 @@ def _save_preserved_optimizer_state_file(self, filename: str) -> None: state_path, error, ) + return False def _load_preserved_optimizer_state_file( self, filename: str, *, clear_on_missing: bool = False - ) -> None: + ) -> bool: """Load preserved optimizer state from a subprocess handoff file.""" state_path = self._optimizer_state_path(filename) if not os.path.exists(state_path): @@ -392,7 +400,7 @@ def _load_preserved_optimizer_state_file( self.client_id, state_path, ) - return + return False try: with open(state_path, "rb") as state_file: @@ -406,7 +414,7 @@ def _load_preserved_optimizer_state_file( state_path, error, ) - return + return False if not isinstance(payload, dict): self._preserved_optimizer_states.pop(self.client_id, None) @@ -416,9 +424,38 @@ def _load_preserved_optimizer_state_file( self.client_id, state_path, ) - return + return False self._preserved_optimizer_states[self.client_id] = payload + return True + + def _remove_preserved_optimizer_state_file(self, filename: str) -> None: + """Remove a local optimizer-state sidecar if it exists.""" + state_path = self._optimizer_state_path(filename) + try: + os.remove(state_path) + except FileNotFoundError: + return + except OSError as error: + logging.warning( + "[Client #%d] Failed to remove optimizer state at %s: %s", + self.client_id, + state_path, + error, + ) + + def _finish_subprocess_optimizer_state( + self, input_filename: str, output_filename: str + ) -> None: + """Load the child output sidecar and promote it for the next round.""" + loaded = self._load_preserved_optimizer_state_file( + output_filename, clear_on_missing=True + ) + if loaded: + self._save_preserved_optimizer_state_file(input_filename) + self._remove_preserved_optimizer_state_file(output_filename) + else: + self._remove_preserved_optimizer_state_file(input_filename) @staticmethod def _persisted_test_state_keys() -> tuple[str, ...]: @@ -629,13 +666,22 @@ def train_process(self, config, trainset, sampler, **kwargs): """The training process in a federated learning workload.""" preserve_optimizer_state = self._preserve_optimizer_state(config) if preserve_optimizer_state: - optimizer_state_filename = self._optimizer_state_filename(config["run_id"]) - self._load_preserved_optimizer_state_file(optimizer_state_filename) + optimizer_state_filename = config.get( + "_optimizer_state_input_filename", + self._optimizer_state_filename(config["run_id"]), + ) + optimizer_state_output_filename = config.get( + "_optimizer_state_output_filename", + optimizer_state_filename, + ) + self._load_preserved_optimizer_state_file( + optimizer_state_filename, clear_on_missing=True + ) self.train_model(config, trainset, sampler, **kwargs) if preserve_optimizer_state: - self._save_preserved_optimizer_state_file(optimizer_state_filename) + self._save_preserved_optimizer_state_file(optimizer_state_output_filename) model_name = Config().trainer.model_name filename = f"{model_name}_{self.client_id}_{config['run_id']}.safetensors" @@ -1003,10 +1049,19 @@ def train(self, trainset, sampler, **kwargs) -> float: tic = time.perf_counter() preserve_optimizer_state = self._preserve_optimizer_state(config) optimizer_state_filename = None + optimizer_state_output_filename = None if preserve_optimizer_state: optimizer_state_filename = self._optimizer_state_filename( config["run_id"] ) + optimizer_state_output_filename = self._optimizer_state_output_filename( + config["run_id"] + ) + config = { + **config, + "_optimizer_state_input_filename": optimizer_state_filename, + "_optimizer_state_output_filename": optimizer_state_output_filename, + } if mp.get_start_method(allow_none=True) != "spawn": mp.set_start_method("spawn", force=True) @@ -1059,9 +1114,12 @@ def train(self, trainset, sampler, **kwargs) -> float: f"Training on client {self.client_id} failed." ) from error - if optimizer_state_filename is not None: - self._load_preserved_optimizer_state_file( - optimizer_state_filename, clear_on_missing=True + if ( + optimizer_state_filename is not None + and optimizer_state_output_filename is not None + ): + self._finish_subprocess_optimizer_state( + optimizer_state_filename, optimizer_state_output_filename ) toc = time.perf_counter() diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py index 59f57fd2e..9a8cab29e 100644 --- a/tests/trainers/test_composable_optimizer_state.py +++ b/tests/trainers/test_composable_optimizer_state.py @@ -262,6 +262,82 @@ def test_subprocess_scheduler_state_persists_across_rounds( assert payload["optimizer_state"]["param_groups"][0]["lr"] == pytest.approx(0.05) +def test_subprocess_missing_sidecar_clears_inherited_parent_cache( + temp_config, monkeypatch, tmp_path, tiny_dataset, one_step_config +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + source_trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + source_trainer.set_client_id(7) + config = { + **one_step_config, + "run_id": Config.params["run_id"], + "preserve_optimizer_state": True, + } + source_trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + assert _cached_optimizer_step(source_trainer) == 1 + + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + trainer._preserved_optimizer_states[7] = copy.deepcopy( + source_trainer._preserved_optimizer_states[7] + ) + + state_path = ( + Path(Config.params["model_path"]) + / trainer._optimizer_state_filename(Config.params["run_id"]) + ) + state_path.unlink(missing_ok=True) + + trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) + + assert _cached_optimizer_step(trainer) == 1 + + +def test_missing_subprocess_output_removes_stale_input_sidecar( + temp_config, monkeypatch, tmp_path, tiny_dataset, one_step_config +): + _configure_subprocess_training( + monkeypatch, tmp_path, preserve_optimizer_state=True + ) + trainer = ComposableTrainer( + model=_linear_model, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + ) + trainer.set_client_id(7) + config = { + **one_step_config, + "run_id": Config.params["run_id"], + "preserve_optimizer_state": True, + } + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + input_filename = trainer._optimizer_state_filename(Config.params["run_id"]) + missing_output_filename = trainer._optimizer_state_output_filename( + Config.params["run_id"] + ) + assert trainer._save_preserved_optimizer_state_file(input_filename) + input_path = Path(Config.params["model_path"]) / input_filename + assert input_path.exists() + + trainer._finish_subprocess_optimizer_state( + input_filename, missing_output_filename + ) + + assert trainer.client_id not in trainer._preserved_optimizer_states + assert not input_path.exists() + + def test_subprocess_invalid_optimizer_state_resets_safely( temp_config, monkeypatch, tmp_path, tiny_dataset ): From 33129d083443f932bbf9a4b6fc4e9c54a890d375 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 13:56:47 -0400 Subject: [PATCH 23/39] Added DiLoCo payload safety coverage. DT-422 adds regression tests proving client-local optimizer and scheduler state remains local when trainer.preserve_optimizer_state is enabled. Client tests now cover the FedAvg/DiLoCo-compatible in-process path and subprocess sidecar path, asserting outbound payloads contain exactly model state tensors and reject optimizer_state, scheduler_state, global_step, local metadata, and sidecar filename keys. Trainer tests also verify model-update payloads stay model-only while optimizer and scheduler state are persisted locally. Validation: uv run pytest tests/clients -k "payload or simple" -q; uv run pytest tests/trainers -k "optimizer_state or scheduler_state" -q --ignore=tests/trainers/test_dp_data_loader_strategy.py; uv run ruff check . --select I; git diff --check. --- tests/clients/test_simple_client.py | 155 ++++++++++++++++-- .../test_composable_optimizer_state.py | 29 +++- 2 files changed, 171 insertions(+), 13 deletions(-) diff --git a/tests/clients/test_simple_client.py b/tests/clients/test_simple_client.py index 43ffbb907..11b1bd8e0 100644 --- a/tests/clients/test_simple_client.py +++ b/tests/clients/test_simple_client.py @@ -1,7 +1,10 @@ """End-to-end smoke tests for the strategy-based client runtime.""" import asyncio +import pickle +import sys from dataclasses import dataclass +from pathlib import Path import torch from torch.utils.data import Dataset @@ -10,6 +13,20 @@ from plato.clients import simple from plato.config import Config from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies import AdamWOptimizerStrategy, StepLRSchedulerStrategy +from tests.test_utils.fakes import NoOpCommunicationStrategy + +LOCAL_STATE_PAYLOAD_KEYS = { + "optimizer_state", + "scheduler_state", + "trainer_state", + "local_metadata", + "metadata", + "global_step", + "local_optimizer_steps", + "_optimizer_state_input_filename", + "_optimizer_state_output_filename", +} class ToyDataset(Dataset): @@ -48,37 +65,155 @@ def get_test_set(self): return self._test -def _build_client(): +def _build_client(trainer=ComposableTrainer): """Instantiate a client wired with custom model, datasource, and trainer.""" return simple.Client( model=torch.nn.Linear(4, 2), datasource=ToyDatasource, - trainer=ComposableTrainer, + trainer=trainer, algorithm=lambda trainer: fedavg.Algorithm(trainer), ) -def test_simple_client_trains_with_default_strategies(temp_config): - """A simple client should complete one training round using the strategy stack.""" - Config().trainer = Config().trainer._replace(epochs=1, batch_size=2) +def _build_stateful_trainer(model=None, callbacks=None): + """Build a trainer whose local optimizer and scheduler state is non-empty.""" + return ComposableTrainer( + model=model, + callbacks=callbacks, + optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) - client = _build_client() - # Assign identifiers expected by the client runtime. +def _configure_one_round_client(client): + """Prepare a client for a deterministic single training round.""" client.client_id = 1 client._context.client_id = 1 client.current_round = 1 client._context.current_round = 1 - # Prepare data and runtime components. client._load_data() client.configure() client._allocate_data() + +def _disable_payload_processors(client): + """Keep the test focused on decoded client-server model payload contents.""" + client.inbound_processor = None + client.outbound_processor = None + client._context.inbound_processor = None + client._context.outbound_processor = None + + +def _assert_model_weight_payload(payload, model): + """Assert that an outbound payload contains exactly model state tensors.""" + model_state = model.state_dict() + + assert isinstance(payload, dict) + assert set(payload) == set(model_state) + assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(payload) + assert all(torch.is_tensor(value) for value in payload.values()) + + for name, expected in model_state.items(): + assert torch.equal(payload[name], expected) + + +def _assert_preserved_state_is_local(trainer, client_id): + """Assert optimizer and scheduler persistence exists only in trainer state.""" + state = trainer._preserved_optimizer_states[client_id] + + assert state["optimizer_state"]["state"] + assert state["scheduler_state"] is not None + assert state["scheduler_state"]["last_epoch"] >= 1 + assert state["scheduler_state"]["_step_count"] >= 2 + + +def test_simple_client_trains_with_default_strategies(temp_config): + """A simple client should complete one training round using the strategy stack.""" + Config().trainer = Config().trainer._replace(epochs=1, batch_size=2) + + client = _build_client() + + _configure_one_round_client(client) + report, payload = asyncio.run(client._train()) assert report.client_id == 1 # With partition_size=4 each client receives four samples. assert report.num_samples == 4 - assert isinstance(payload, dict) - assert all(isinstance(value, torch.Tensor) for value in payload.values()) + _assert_model_weight_payload(payload, client.trainer.model) + + +def test_simple_client_payload_excludes_local_state_when_persistence_enabled( + temp_config, +): + """FedAvg/DiLoCo client payloads stay model-only with local persistence.""" + Config.params["run_id"] = "client-payload-in-process" + Config().trainer = Config().trainer._replace( + epochs=1, + batch_size=2, + preserve_optimizer_state=True, + ) + client = _build_client(trainer=_build_stateful_trainer) + client._configure_composable( + lifecycle_strategy=client.lifecycle_strategy, + payload_strategy=client.payload_strategy, + training_strategy=client.training_strategy, + reporting_strategy=client.reporting_strategy, + communication_strategy=NoOpCommunicationStrategy(), + ) + _configure_one_round_client(client) + _disable_payload_processors(client) + + server_payload = client.algorithm.extract_weights() + asyncio.run(client._handle_payload(server_payload)) + + sent_payload = client._context.state["sent_payloads"][-1] + _assert_preserved_state_is_local(client.trainer, client.client_id) + _assert_model_weight_payload(sent_payload, client.trainer.model) + + +def test_simple_client_subprocess_payload_excludes_local_state_sidecar( + temp_config, monkeypatch, tmp_path +): + """Subprocess persistence uses a sidecar without changing server payloads.""" + model_path = Path(tmp_path) / "models" / "pretrained" + checkpoint_path = Path(tmp_path) / "checkpoints" + model_path.mkdir(parents=True, exist_ok=True) + checkpoint_path.mkdir(parents=True, exist_ok=True) + Config.params["model_path"] = str(model_path) + Config.params["checkpoint_path"] = str(checkpoint_path) + Config.params["run_id"] = "client-payload-subprocess" + monkeypatch.setattr(sys, "argv", [sys.argv[0], "-b", str(tmp_path)]) + Config().trainer = Config().trainer._replace( + epochs=1, + batch_size=2, + max_concurrency=1, + preserve_optimizer_state=True, + ) + client = _build_client(trainer=_build_stateful_trainer) + client._configure_composable( + lifecycle_strategy=client.lifecycle_strategy, + payload_strategy=client.payload_strategy, + training_strategy=client.training_strategy, + reporting_strategy=client.reporting_strategy, + communication_strategy=NoOpCommunicationStrategy(), + ) + _configure_one_round_client(client) + _disable_payload_processors(client) + + server_payload = client.algorithm.extract_weights() + asyncio.run(client._handle_payload(server_payload)) + + sent_payload = client._context.state["sent_payloads"][-1] + state_path = ( + Path(Config.params["model_path"]) + / client.trainer._optimizer_state_filename(Config.params["run_id"]) + ) + with state_path.open("rb") as state_file: + sidecar_state = pickle.load(state_file) + + _assert_preserved_state_is_local(client.trainer, client.client_id) + assert sidecar_state["optimizer_state"]["state"] + assert sidecar_state["scheduler_state"] is not None + _assert_model_weight_payload(sent_payload, client.trainer.model) diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py index 9a8cab29e..3620c6248 100644 --- a/tests/trainers/test_composable_optimizer_state.py +++ b/tests/trainers/test_composable_optimizer_state.py @@ -22,6 +22,18 @@ StepLRSchedulerStrategy, ) +LOCAL_STATE_PAYLOAD_KEYS = { + "optimizer_state", + "scheduler_state", + "trainer_state", + "local_metadata", + "metadata", + "global_step", + "local_optimizer_steps", + "_optimizer_state_input_filename", + "_optimizer_state_output_filename", +} + @pytest.fixture def tiny_dataset(): @@ -138,6 +150,14 @@ def _cached_scheduler_last_epoch(trainer): return payload["scheduler_state"]["last_epoch"] +def _assert_model_update_contains_only_model_weights(update, model): + model_state = model.state_dict() + + assert set(update) == set(model_state) + assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(update) + assert all(torch.is_tensor(value) for value in update.values()) + + def test_adamw_moment_buffers_persist_between_rounds_for_same_client( temp_config, tiny_dataset, one_step_config ): @@ -424,16 +444,19 @@ def test_preserved_state_stays_out_of_model_update_payload( model=_linear_model, loss_strategy=CrossEntropyLossStrategy(), optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), ) + trainer.set_client_id(5) config = {**one_step_config, "preserve_optimizer_state": True} update = trainer.obtain_model_update( config, tiny_dataset, list(range(len(tiny_dataset))) ) + preserved_state = trainer._preserved_optimizer_states[trainer.client_id] - assert "optimizer_state" not in update - assert "scheduler_state" not in update - assert all(torch.is_tensor(value) for value in update.values()) + assert preserved_state["optimizer_state"]["state"] + assert preserved_state["scheduler_state"]["last_epoch"] >= 1 + _assert_model_update_contains_only_model_weights(update, trainer.model) def test_preserved_state_invalidates_when_parameter_order_changes( From 21bf980c0c86c5bc732f3aa60613d5dd572c98a5 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 14:06:55 -0400 Subject: [PATCH 24/39] Added round-aware local-step sampling. DT-428 prevents exact local-step training from replaying the same deterministic sampler prefix when H is smaller than one epoch and the train loader is recreated each round. The data-loader strategies now materialize supported sampler streams only when trainer.local_steps_per_round is active, rotate the stream by the deterministic round offset, and leave epoch-based training unchanged when local-step limits are unset. Unsupported non-materializable sampler objects log a clear warning and fall back unchanged. Added focused red/green coverage showing two short local-step rounds for the same client consume different prefixes while repeated runs with the same round sequence remain deterministic. Validation: uv run pytest tests/trainers -k "local_steps or data_loader or sampler" -q; uv run pytest tests/samplers -q; uv run ruff check . --select I; git diff --check. --- plato/trainers/strategies/data_loader.py | 81 +++++++++++++++++++++ tests/trainers/test_composable_trainer.py | 88 ++++++++++++++++++++++- 2 files changed, 168 insertions(+), 1 deletion(-) diff --git a/plato/trainers/strategies/data_loader.py b/plato/trainers/strategies/data_loader.py index 91e5a0482..f65bbf3f3 100644 --- a/plato/trainers/strategies/data_loader.py +++ b/plato/trainers/strategies/data_loader.py @@ -20,6 +20,19 @@ AdjustFn = Callable[[TrainingContext], int] +class _FixedOrderSampler(torch.utils.data.Sampler): + """Sampler that yields precomputed dataset indices in order.""" + + def __init__(self, indices: list[int]): + self._indices = indices + + def __iter__(self): + return iter(self._indices) + + def __len__(self): + return len(self._indices) + + def _context_uses_cuda(context: TrainingContext) -> bool: """Return True if the training context targets a CUDA device.""" device = getattr(context, "device", None) @@ -40,6 +53,54 @@ def _resolve_pin_memory(setting: bool | None, context: TrainingContext) -> bool: return _context_uses_cuda(context) +def _local_step_stream_start( + context: TrainingContext, samples_per_round: int, stream_length: int +) -> int: + """Return the deterministic stream offset for this local-step round.""" + current_round = int(getattr(context, "current_round", 0) or 0) + if current_round > 0: + return ((current_round - 1) * samples_per_round) % stream_length + + offset = int(context.state.get("_local_step_sampler_stream_offset", 0)) + context.state["_local_step_sampler_stream_offset"] = offset + samples_per_round + return offset % stream_length + + +def _apply_local_step_sampling_stream( + sampler_obj, batch_size: int, context: TrainingContext +): + """Advance deterministic samplers across short local-step rounds.""" + local_steps_per_round = context.state.get("local_steps_per_round") + if local_steps_per_round is None or sampler_obj is None: + return sampler_obj + + samples_per_round = int(local_steps_per_round) * int(batch_size) + if samples_per_round <= 0: + return sampler_obj + + try: + indices = list(iter(sampler_obj)) + except TypeError: + logging.warning( + "Sampler %s cannot be materialized for round-aware local-step " + "sampling; using it unchanged. Consecutive short local rounds may " + "replay the same sampler prefix.", + type(sampler_obj), + ) + return sampler_obj + + if len(indices) == 0: + return sampler_obj + + start = _local_step_stream_start(context, samples_per_round, len(indices)) + if start == 0: + ordered_indices = indices + else: + ordered_indices = indices[start:] + indices[:start] + + return _FixedOrderSampler(ordered_indices) + + class DefaultDataLoaderStrategy(DataLoaderStrategy): """ Default data loader strategy. @@ -100,6 +161,10 @@ def create_train_loader( sampler_obj = None shuffle = self.shuffle + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + if sampler is None and not shuffle: logging.warning( "Data loader strategy received no sampler; falling back to SequentialSampler." @@ -174,6 +239,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, @@ -239,6 +308,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, @@ -320,6 +393,10 @@ def create_train_loader( sampler_obj = None shuffle = False + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, actual_batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=actual_batch_size, @@ -383,6 +460,10 @@ def create_train_loader( sampler_obj = None shuffle = True + sampler_obj = _apply_local_step_sampling_stream( + sampler_obj, batch_size, context + ) + return torch.utils.data.DataLoader( dataset=trainset, batch_size=batch_size, diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index d5daeb4b4..441ab7895 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -8,7 +8,7 @@ import pytest import torch import torch.nn as nn -from torch.utils.data import TensorDataset +from torch.utils.data import SubsetRandomSampler, TensorDataset from plato.callbacks.trainer import TrainerCallback from plato.config import Config @@ -184,6 +184,19 @@ def test_multiple_epochs(self, simple_model, simple_dataset): class TestComposableTrainerLocalSteps: """Test local optimizer-step limits for DiLoCo-style local work.""" + class DeterministicPlatoSampler: + def __init__(self, indices, seed=47): + self.indices = list(indices) + self.seed = seed + + def get(self): + generator = torch.Generator() + generator.manual_seed(self.seed) + return SubsetRandomSampler(self.indices, generator=generator) + + def num_samples(self): + return len(self.indices) + class CountingCallback(TrainerCallback): def __init__(self): self.train_run_end_called = False @@ -230,6 +243,33 @@ def training_step( context=context, ) + class RecordingStepStrategy(DefaultTrainingStepStrategy): + def __init__(self): + super().__init__() + self.samples_by_round = {} + + def training_step( + self, + model, + optimizer, + examples, + labels, + loss_criterion, + context, + ): + sample_ids = examples[:, 0].detach().cpu().int().tolist() + self.samples_by_round.setdefault(context.current_round, []).extend( + sample_ids + ) + return super().training_step( + model=model, + optimizer=optimizer, + examples=examples, + labels=labels, + loss_criterion=loss_criterion, + context=context, + ) + class DelayedOptimizerStepStrategy(TrainingStepStrategy): def __init__(self, accumulation_steps=2, finalize_steps=True): self.accumulation_steps = accumulation_steps @@ -420,6 +460,52 @@ def test_epoch_behavior_is_unchanged_when_local_steps_unset( assert trainer.current_epoch == 2 assert len(trainer.run_history.get_metric_values("train_loss")) == 2 + def test_local_steps_do_not_replay_same_deterministic_sampler_prefix( + self, simple_model, simple_config + ): + dataset_size = 10 + features = torch.arange(dataset_size, dtype=torch.float32).view(-1, 1) + features = features.repeat(1, 10) + labels = torch.arange(dataset_size) % 2 + dataset = TensorDataset(features, labels) + config = { + **simple_config, + "batch_size": 1, + "epochs": 1, + "local_steps_per_round": 3, + } + sampler = self.DeterministicPlatoSampler(range(dataset_size)) + step_strategy = self.RecordingStepStrategy() + trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=step_strategy, + ) + trainer.set_client_id(2) + + for round_number in (1, 2): + trainer.current_round = round_number + trainer.train_model(config, dataset, sampler) + + round_one_samples = step_strategy.samples_by_round[1] + round_two_samples = step_strategy.samples_by_round[2] + + assert len(round_one_samples) == config["local_steps_per_round"] + assert len(round_two_samples) == config["local_steps_per_round"] + assert round_one_samples != round_two_samples + + repeat_step_strategy = self.RecordingStepStrategy() + repeat_trainer = ComposableTrainer( + model=simple_model, + training_step_strategy=repeat_step_strategy, + ) + repeat_trainer.set_client_id(2) + + for round_number in (1, 2): + repeat_trainer.current_round = round_number + repeat_trainer.train_model(config, dataset, sampler) + + assert repeat_step_strategy.samples_by_round == step_strategy.samples_by_round + @pytest.mark.parametrize("local_steps_per_round", [0, -1, 1.5, "2", True]) def test_invalid_local_steps_fail_clearly( self, simple_model, simple_dataset, simple_config, local_steps_per_round From f6e81965a1065b296a243322cf19b2773e688a33 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 14:11:30 -0400 Subject: [PATCH 25/39] Handled non-materializable local-step samplers. DT-429 review found that the round-aware local-step sampler wrapper only treated TypeError as an unsupported materialization path. Samplers that raise NotImplementedError during iteration should also warn and fall back unchanged instead of failing while setting up the data loader. This patch catches NotImplementedError in the same warning/fallback path and adds regression coverage with a non-materializable sampler to verify the warning and unchanged sampler handoff. Validation: uv run pytest tests/trainers/test_composable_trainer.py -k "non_materializable or local_steps" -q; uv run pytest tests/trainers -k "local_steps or data_loader or sampler" -q; uv run pytest tests/samplers -q; uv run ruff check plato/trainers/strategies/data_loader.py tests/trainers/test_composable_trainer.py --select I; git diff --check. --- plato/trainers/strategies/data_loader.py | 2 +- tests/trainers/test_composable_trainer.py | 30 +++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/plato/trainers/strategies/data_loader.py b/plato/trainers/strategies/data_loader.py index f65bbf3f3..9d9c5dc0e 100644 --- a/plato/trainers/strategies/data_loader.py +++ b/plato/trainers/strategies/data_loader.py @@ -80,7 +80,7 @@ def _apply_local_step_sampling_stream( try: indices = list(iter(sampler_obj)) - except TypeError: + except (TypeError, NotImplementedError): logging.warning( "Sampler %s cannot be materialized for round-aware local-step " "sampling; using it unchanged. Consecutive short local rounds may " diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index 441ab7895..a2063c933 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -5,6 +5,8 @@ it works correctly in end-to-end training scenarios. """ +import logging + import pytest import torch import torch.nn as nn @@ -197,6 +199,13 @@ def get(self): def num_samples(self): return len(self.indices) + class NonMaterializableSampler(torch.utils.data.Sampler): + def __iter__(self): + raise NotImplementedError("This sampler cannot be materialized.") + + def __len__(self): + return 10 + class CountingCallback(TrainerCallback): def __init__(self): self.train_run_end_called = False @@ -506,6 +515,27 @@ def test_local_steps_do_not_replay_same_deterministic_sampler_prefix( assert repeat_step_strategy.samples_by_round == step_strategy.samples_by_round + def test_local_step_sampling_warns_for_non_materializable_sampler( + self, simple_dataset, caplog + ): + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + sampler = self.NonMaterializableSampler() + + with caplog.at_level(logging.WARNING): + loader = DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + sampler, + batch_size=1, + context=context, + ) + + assert loader.sampler is sampler + assert ( + "cannot be materialized for round-aware local-step sampling" + in caplog.text + ) + @pytest.mark.parametrize("local_steps_per_round", [0, -1, 1.5, "2", True]) def test_invalid_local_steps_fail_clearly( self, simple_model, simple_dataset, simple_config, local_steps_per_round From 711fdb11691780d1a5e65394c99c4bd7457fca10 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 14:18:33 -0400 Subject: [PATCH 26/39] Added exact DiLoCo smoke configuration. DT-424 adds a small MNIST/LeNet DiLoCo smoke config that uses the faithful configuration contract: server.type=diloco, algorithm.type=fedavg, local_steps_per_round=2, preserve_optimizer_state=true, AdamW inner optimizer, Nesterov outer optimizer, uniform weighting, and parameter-only outer updates. The docs now explain how to run the smoke config, distinguish algorithm mechanics from reproducing the paper C4/model/pretraining setup, and document H semantics, mid-epoch stopping, round-aware small-H sampling, local-only optimizer and scheduler state, FedAvg equivalence conditions, and the parameter/buffer policy. The integration smoke test loads the real config, verifies the contract values, and checks that the server registry selects the DiLoCo server and DiLoCo aggregation strategy. Validation: uv run pytest tests/integration/test_smoke_configs.py -k diloco -q; uv run ruff check . --select I; git diff --check. --- configs/MNIST/diloco_lenet5_smoke.toml | 68 +++++++++++++++++++++++++ docs/docs/configurations/server.md | 31 +++++++++++ docs/docs/configurations/trainer.md | 14 +++++ docs/docs/development/diloco.md | 25 +++++++-- tests/integration/test_smoke_configs.py | 36 ++++++++++++- tests/integration/utils.py | 30 +++++++++++ 6 files changed, 198 insertions(+), 6 deletions(-) create mode 100644 configs/MNIST/diloco_lenet5_smoke.toml diff --git a/configs/MNIST/diloco_lenet5_smoke.toml b/configs/MNIST/diloco_lenet5_smoke.toml new file mode 100644 index 000000000..0c17c2ebc --- /dev/null +++ b/configs/MNIST/diloco_lenet5_smoke.toml @@ -0,0 +1,68 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 2 + +# The number of clients selected in each round +per_round = 2 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +type = "diloco" +address = "127.0.0.1" +port = 8000 +random_seed = 1 +simulate_wall_time = true +do_test = false + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[data] +include = "mnist_iid.toml" +partition_size = 16 + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 2 + +# The maximum number of clients running concurrently +max_concurrency = 1 + +# The machine learning model +model_name = "lenet5" + +# DiLoCo local work H, counted in optimizer steps. +local_steps_per_round = 2 +preserve_optimizer_state = true + +epochs = 1 +batch_size = 4 +optimizer = "AdamW" + +[algorithm] + +# Weight extraction and model update path reused by DiLoCo. +type = "fedavg" + +[parameters] + +[parameters.model] +num_classes = 10 + +[parameters.optimizer] +lr = 0.001 +weight_decay = 0.0 diff --git a/docs/docs/configurations/server.md b/docs/docs/configurations/server.md index 2bb800237..42d0bad99 100644 --- a/docs/docs/configurations/server.md +++ b/docs/docs/configurations/server.md @@ -8,6 +8,7 @@ - `fedavg_personalized` a Federated Averaging server that supports all-purpose personalized federated learning by controlling when and which group of clients are to perform local personalization. - `fedavg_mpc_additive` a Federated Averaging server that reconstructs additive MPC shares before aggregation. Requires clients of type `mpc` with the `mpc_model_encrypt_additive` processor. - `fedavg_mpc_shamir` a Federated Averaging server that reconstructs Shamir MPC shares before aggregation. Requires clients of type `mpc` with the `mpc_model_encrypt_shamir` processor. + - `diloco` a FedAvg-compatible server that applies DiLoCo outer aggregation. Use it with `algorithm.type = "fedavg"` and configure the outer optimizer under `[server.diloco]`. - `split_learning` a Split Learning server that supports training different kinds of models in split learning framework. When this server is used, the `clients.per_round` in the configuration should be set to 1. Users should define the rules for updating models weights before cut from the clients to the server in the callback function `on_update_weights_before_cut`, depending on the specific model they use. - `fedavg_personalized` a personalized federated learning server that starts from a number of regular rounds of federated learning. In these regular rounds, only a subset of the total clients can be selected to perform the local update (the ratio of which is a configuration setting). After all regular rounds are completed, it starts a final round of personalization, where a selected subset of clients perform local training using their local dataset. - `pfedgraph` a personalized federated learning server that aggregates models using an inferred collaboration graph and sends per-client aggregated weights. @@ -124,6 +125,36 @@ Default value: `100` +!!! example "diloco" + Settings for `server.type = "diloco"`. DiLoCo reuses `algorithm.type = "fedavg"` for client weight extraction and global model loading, while the DiLoCo server turns client deltas into an outer-gradient update. + + ```toml + [server] + type = "diloco" + + [algorithm] + type = "fedavg" + + [server.diloco] + outer_optimizer = "nesterov" + outer_learning_rate = 0.7 + outer_momentum = 0.9 + aggregation_weighting = "uniform" + apply_outer_optimizer_to = "parameters" + ``` + + `aggregation_weighting = "uniform"` matches balanced IID worker smoke runs. `aggregation_weighting = "num_samples"` matches Plato's traditional sample-weighted FedAvg behavior. With outer SGD and `outer_learning_rate = 1.0`, uniform weighting is equivalent to uniform model averaging; with `num_samples`, it is equivalent to Plato-style sample-weighted FedAvg. + + `apply_outer_optimizer_to = "parameters"` applies the outer optimizer only to trainable floating parameters. Floating buffers are synchronized with the selected averaging rule but do not receive outer momentum. `apply_outer_optimizer_to = "all_floating"` is available for experiments that also apply the outer optimizer to floating buffers. + + A runnable smoke configuration is available at `configs/MNIST/diloco_lenet5_smoke.toml`: + + ```bash + uv run python plato.py --config configs/MNIST/diloco_lenet5_smoke.toml + ``` + + The smoke configuration validates DiLoCo mechanics in Plato; it is not a C4/model/pretraining reproduction of the DiLoCo paper. + !!! example "edge_downlink_bandwidth" The edge server's estimated downlink capacity (an edge server to its clients) in Mbps, used for computing the transmission time (see `compute_comm_time` in the `clients` section). diff --git a/docs/docs/configurations/trainer.md b/docs/docs/configurations/trainer.md index de9eb715e..05fef2f2d 100644 --- a/docs/docs/configurations/trainer.md +++ b/docs/docs/configurations/trainer.md @@ -56,6 +56,20 @@ !!! example "epochs" The total number of epochs in local training in each communication round. +!!! example "local_steps_per_round" + The DiLoCo local work value `H`, counted as completed client-local optimizer steps between synchronizations. + + `H` is not an epoch count, raw dataloader batch count, or gradient-accumulation micro-batch count. When gradient accumulation is enabled, only batches that trigger `optimizer.step()` increment `H`. + + `H` may be smaller than one epoch. In that case, local training stops mid-epoch after exactly `H` optimizer steps while still running normal trainer cleanup, callback completion, state persistence, and reporting. + + Small-`H` DiLoCo runs use round-aware sampling where supported so a logical client does not replay the same first `H` batches every round. Trainers or samplers that cannot count optimizer steps or advance the local stream faithfully must fail or warn clearly instead of silently approximating DiLoCo. + +!!! example "preserve_optimizer_state" + Whether client-local optimizer and scheduler state should persist across a logical client's local train runs. + + DiLoCo should set this to `true` with a stateful inner optimizer such as `AdamW`. Optimizer and scheduler state remains client-local and is not transmitted in client-server payloads. + !!! example "batch_size" The size of the mini-batch of data in each step (iteration) of the training loop. diff --git a/docs/docs/development/diloco.md b/docs/docs/development/diloco.md index 94dfd49fc..b66eae9ad 100644 --- a/docs/docs/development/diloco.md +++ b/docs/docs/development/diloco.md @@ -1,14 +1,27 @@ # DiLoCo Design Contract -This note defines what Plato will call faithful DiLoCo for the initial -implementation. It is a contract for the implementation issues that follow; it -does not describe runtime behavior that already exists in Plato. +This note defines what Plato calls faithful DiLoCo in the current +implementation. Faithful DiLoCo in Plato means algorithm-faithful execution of the DiLoCo training loop inside Plato's federated runtime. It does not mean reproducing the paper's exact C4 dataset, model scale, tokenizer, hardware topology, pretraining duration, or final benchmark numbers. +## Smoke Configuration + +Plato includes a small MNIST/LeNet smoke configuration for checking the DiLoCo +mechanics: + +```bash +uv run python plato.py --config configs/MNIST/diloco_lenet5_smoke.toml +``` + +This smoke run validates configuration loading, DiLoCo server selection, local +optimizer-step work, client-local optimizer-state persistence, and server-side +outer aggregation. It is intentionally tiny and does not reproduce the C4 +language-model pretraining setup or the paper's reported metrics. + ## Algorithm Contract DiLoCo has two optimizer levels: @@ -132,8 +145,10 @@ rate `1.0` is valid only when both runs use the same weighting rule. Unsupported modes must fail clearly. They must not silently fall back to an approximate DiLoCo variant. Examples include trainer backends that cannot count local optimizer steps exactly, execution paths that cannot preserve -client-local optimizer and scheduler state, or payload paths that would send -optimizer state to the server. +client-local optimizer and scheduler state, samplers that cannot advance the +small-`H` local data stream across rounds, or payload paths that would send +optimizer state to the server. Experimental combinations that are allowed but +not faithful must warn clearly. ## Implementation Sequence diff --git a/tests/integration/test_smoke_configs.py b/tests/integration/test_smoke_configs.py index 6dbc1fa08..938ec0c66 100644 --- a/tests/integration/test_smoke_configs.py +++ b/tests/integration/test_smoke_configs.py @@ -4,7 +4,8 @@ from __future__ import annotations -from importlib import import_module +from importlib import import_module, reload +from pathlib import Path from types import SimpleNamespace from typing import cast @@ -17,8 +18,11 @@ async_run, build_minimal_config, configure_environment, + configure_environment_from_path, ) +REPO_ROOT = Path(__file__).resolve().parents[2] + class MNISTSmokeDatasource: """Datasource returning image-shaped tensors for LeNet smoke tests.""" @@ -97,6 +101,36 @@ def test_fedavg_lenet5_smoke(monkeypatch): assert server.accuracy >= 0 +@pytest.mark.integration +def test_diloco_lenet5_smoke_config_contract_loads(): + """Smoke config should load the faithful DiLoCo contract.""" + config_path = REPO_ROOT / "configs" / "MNIST" / "diloco_lenet5_smoke.toml" + + with configure_environment_from_path(config_path) as config: + assert config.server.type == "diloco" + assert config.algorithm.type == "fedavg" + assert config.trainer.local_steps_per_round == 2 + assert config.trainer.preserve_optimizer_state is True + assert config.trainer.optimizer == "AdamW" + assert config.server.diloco.outer_optimizer == "nesterov" + assert config.server.diloco.outer_learning_rate == 0.7 + assert config.server.diloco.outer_momentum == 0.9 + assert config.server.diloco.aggregation_weighting == "uniform" + assert config.server.diloco.apply_outer_optimizer_to == "parameters" + + server_registry = reload(import_module("plato.servers.registry")) + diloco_server = import_module("plato.servers.diloco") + diloco_aggregation = import_module("plato.servers.strategies.aggregation") + + server = server_registry.get() + + assert isinstance(server, diloco_server.Server) + assert isinstance( + server.aggregation_strategy, + diloco_aggregation.DiLoCoAggregationStrategy, + ) + + @pytest.mark.integration def test_split_learning_smoke(monkeypatch): """Smoke test for split-learning trainer orchestrating gradients.""" diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 3cb4a907b..4ff610756 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -107,6 +107,36 @@ def configure_environment(config_dict: dict): Config._instance = None +@contextlib.contextmanager +def configure_environment_from_path(config_path: Path): + """ + Context manager that initialises Config singleton from an existing config. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + Config._instance = None # reset singleton + Config.params = {} + + previous_env = os.environ.get("config_file") + previous_argv = sys.argv[:] + os.environ["config_file"] = str(config_path) + sys.argv = [ + previous_argv[0] if previous_argv else "pytest", + "--base", + tmp_dir, + ] + + try: + config = Config() + yield config + finally: + if previous_env is None: + os.environ.pop("config_file", None) + else: + os.environ["config_file"] = previous_env + sys.argv = previous_argv + Config._instance = None + + def async_run(coro): """Utility to execute the coroutine using asyncio.run (Python 3.7+).""" return asyncio.run(coro) From 082aaf1aa8980ce96ce0e5df0bf9498735a5e62b Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 14:28:21 -0400 Subject: [PATCH 27/39] Added end-to-end DiLoCo validation coverage. DT-426 adds final integration coverage for the faithful DiLoCo path using the exact MNIST smoke config. The test builds the configured DiLoCo server and simple client, runs local training with local_steps_per_round=2 and preserved optimizer state, verifies the outbound client payload remains model weights only, and processes deterministic server updates through the DiLoCo delta aggregation path. The validation would fail if the config selected ordinary FedAvg server aggregation, if local step control were ignored, or if the server bypassed aggregate_deltas. It directly checks the Nesterov outer update differs from ordinary FedAvg averaging, while relying on reviewed lower-level tests for small-H mid-epoch stopping, round-aware sampler non-replay, scheduler sidecar persistence, and broader payload leak coverage. Validation: uv run pytest tests/integration/test_smoke_configs.py -k diloco -q; uv run pytest tests/servers/test_diloco_strategy.py -q; uv run pytest tests/trainers -k "local_steps or optimizer_state or scheduler_state or data_loader" -q; uv run pytest tests/clients -k "payload or simple" -q; uv run ruff check . --select I; git diff --check. --- tests/integration/test_smoke_configs.py | 164 +++++++++++++++++++++++- 1 file changed, 160 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_smoke_configs.py b/tests/integration/test_smoke_configs.py index 938ec0c66..8dd1271a4 100644 --- a/tests/integration/test_smoke_configs.py +++ b/tests/integration/test_smoke_configs.py @@ -28,13 +28,14 @@ class MNISTSmokeDatasource: """Datasource returning image-shaped tensors for LeNet smoke tests.""" def __init__(self, train_size: int = 4, test_size: int = 2): + generator = torch.Generator().manual_seed(13) self._train = TensorDataset( - torch.randn(train_size, 1, 28, 28), - torch.randint(0, 10, (train_size,)), + torch.randn(train_size, 1, 28, 28, generator=generator), + torch.randint(0, 10, (train_size,), generator=generator), ) self._test = TensorDataset( - torch.randn(test_size, 1, 28, 28), - torch.randint(0, 10, (test_size,)), + torch.randn(test_size, 1, 28, 28, generator=generator), + torch.randint(0, 10, (test_size,), generator=generator), ) def num_train_examples(self): @@ -47,6 +48,56 @@ def get_test_set(self): return self._test +LOCAL_STATE_PAYLOAD_KEYS = { + "optimizer_state", + "scheduler_state", + "trainer_state", + "local_metadata", + "metadata", + "global_step", + "local_optimizer_steps", + "_optimizer_state_input_filename", + "_optimizer_state_output_filename", +} + + +def _model_weight_payload(payload, model): + """Assert that a client payload contains model weights only.""" + model_state = model.state_dict() + + assert isinstance(payload, dict) + assert set(payload) == set(model_state) + assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(payload) + assert all(torch.is_tensor(value) for value in payload.values()) + + +def _shifted_payload(weights, amount): + """Build a fake client model payload shifted from the server baseline.""" + shifted = {} + for name, value in weights.items(): + shifted[name] = value.clone() + if torch.is_floating_point(shifted[name]): + shifted[name] = shifted[name] + amount + return shifted + + +def _server_update(client_id, payload): + """Build a minimal server update carrying model weights.""" + return SimpleNamespace( + client_id=client_id, + report=SimpleNamespace( + client_id=client_id, + num_samples=1, + accuracy=0.5, + processing_time=0.1, + comm_time=0.1, + training_time=0.1, + type="weights", + ), + payload=payload, + ) + + @pytest.mark.integration def test_fedavg_lenet5_smoke(monkeypatch): """End-to-end smoke test for a minimal FedAvg run.""" @@ -131,6 +182,111 @@ def test_diloco_lenet5_smoke_config_contract_loads(): ) +@pytest.mark.integration +def test_diloco_lenet5_smoke_config_runs_faithful_path(monkeypatch): + """Exact DiLoCo smoke config exercises local work and outer aggregation.""" + config_path = REPO_ROOT / "configs" / "MNIST" / "diloco_lenet5_smoke.toml" + + with configure_environment_from_path(config_path) as config: + datasources_registry = import_module("plato.datasources.registry") + processor_registry = import_module("plato.processors.registry") + server_registry = reload(import_module("plato.servers.registry")) + client_mod = import_module("plato.clients.simple") + config_mod = import_module("plato.config") + diloco_server = import_module("plato.servers.diloco") + fedavg_algorithm = import_module("plato.algorithms.fedavg") + diloco_aggregation = import_module("plato.servers.strategies.aggregation") + + fake_datasource = MNISTSmokeDatasource(train_size=32, test_size=4) + monkeypatch.setattr( + datasources_registry, + "get", + lambda *args, **kwargs: fake_datasource, + ) + monkeypatch.setattr( + processor_registry, + "get", + lambda *args, **kwargs: (None, None), + ) + + server = server_registry.get() + server.configure() + + assert isinstance(server, diloco_server.Server) + assert isinstance(server.algorithm, fedavg_algorithm.Algorithm) + assert isinstance( + server.aggregation_strategy, + diloco_aggregation.DiLoCoAggregationStrategy, + ) + assert config.server.type == "diloco" + assert config.algorithm.type == "fedavg" + assert config.trainer.local_steps_per_round == 2 + assert config.trainer.preserve_optimizer_state is True + assert config.data.sampler == "iid" + + client = client_mod.Client() + client.client_id = 1 + client._context.client_id = 1 + client.current_round = 1 + client._context.current_round = 1 + client._load_data() + client.configure() + client._allocate_data() + client._load_payload(server.algorithm.extract_weights()) + + train_config = config.trainer._asdict() + train_config["run_id"] = config_mod.Config.params["run_id"] + client.trainer.current_round = client.current_round + client.trainer.train_model(train_config, client.trainset, client.sampler) + payload = client.algorithm.extract_weights() + + assert client.sampler.num_samples() == config.data.partition_size + assert client.trainer.context.state["local_steps_per_round"] == 2 + assert client.trainer.context.state["local_optimizer_steps"] == 2 + assert client.trainer.current_epoch == 1 + assert client.client_id in client.trainer._preserved_optimizer_states + assert client.trainer._preserved_optimizer_states[client.client_id][ + "optimizer_state" + ]["state"] + _model_weight_payload(payload, client.trainer.model) + + # Small-H mid-epoch stopping and round-aware sampler streaming are covered + # in TestComposableTrainerLocalSteps; this integration path verifies the + # exact smoke config enables those runtime flags with the supported sampler. + baseline = server.algorithm.extract_weights() + trainable_name = next(iter(dict(server.trainer.model.named_parameters()))) + server.updates = [ + _server_update(1, _shifted_payload(baseline, 1.0)), + _server_update(2, _shifted_payload(baseline, 3.0)), + ] + server.current_round = 1 + server.context.current_round = 1 + + delta_calls = [] + aggregate_deltas = server.aggregation_strategy.aggregate_deltas + + async def record_delta_aggregation(updates, deltas_received, context): + delta_calls.append((updates, deltas_received)) + return await aggregate_deltas(updates, deltas_received, context) + + monkeypatch.setattr( + server.aggregation_strategy, + "aggregate_deltas", + record_delta_aggregation, + ) + + async_run(server._process_reports()) + + updated = server.algorithm.extract_weights() + ordinary_fedavg_value = baseline[trainable_name] + 2.0 + faithful_diloco_value = baseline[trainable_name] + 2.0 * 1.9 * 0.7 + + assert len(delta_calls) == 1 + assert len(delta_calls[0][1]) == 2 + assert not torch.allclose(updated[trainable_name], ordinary_fedavg_value) + assert torch.allclose(updated[trainable_name], faithful_diloco_value) + + @pytest.mark.integration def test_split_learning_smoke(monkeypatch): """Smoke test for split-learning trainer orchestrating gradients.""" From 1313d30897ee93f0189322367610b65590518a53 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Wed, 29 Apr 2026 15:09:36 -0400 Subject: [PATCH 28/39] Restored optimizer state after moving models to device. Preserved AdamW state was loaded before ComposableTrainer moved the model to the trainer device. On GPU, PyTorch therefore mapped restored optimizer tensors to CPU and later optimizer.step() saw CUDA parameters with CPU Adam state, producing mixed-device runtime errors in later rounds. Move the model to the trainer device before optimizer construction and preserved-state restore so optimizer.load_state_dict() maps state tensors onto the same device as the optimizer parameters. Add a regression test that fails if preserved optimizer state is restored before model.to(). Validation: uv run pytest tests/trainers/test_composable_optimizer_state.py -k "restores_after_model_moves_to_device or optimizer_state or scheduler_state" -q; uv run pytest tests/trainers -k "local_steps or optimizer_state or scheduler_state or data_loader" -q; uv run pytest tests/integration/test_smoke_configs.py -k diloco -q; uv run pytest tests/clients -k "payload or simple" -q; uv run ruff check . --select I; git diff --check. --- plato/trainers/composable.py | 12 +-- .../test_composable_optimizer_state.py | 78 +++++++++++++++++++ 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index 78ceb07bf..cbfb4dd95 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -783,8 +783,13 @@ def train_model(self, config, trainset, sampler, **kwargs): self.context.state["grad_accum_loss_total"] = 0.0 self.context.state["grad_accum_loss_count"] = 0 - # Create optimizer using strategy + # Move the model before optimizer state restore so PyTorch maps restored + # state tensors onto the same device as the optimizer parameters. model = self._require_model() + model.to(self.device) + model.train() + + # Create optimizer using strategy self.optimizer = self.optimizer_strategy.create_optimizer(model, self.context) # Create LR scheduler using strategy @@ -794,11 +799,6 @@ def train_model(self, config, trainset, sampler, **kwargs): if preserve_optimizer_state: self._restore_preserved_optimizer_state() - # Move model to device - model = self._require_model() - model.to(self.device) - model.train() - # Training epochs total_epochs = config["epochs"] tic = time.perf_counter() diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py index 3620c6248..b005981f0 100644 --- a/tests/trainers/test_composable_optimizer_state.py +++ b/tests/trainers/test_composable_optimizer_state.py @@ -21,6 +21,7 @@ SGDOptimizerStrategy, StepLRSchedulerStrategy, ) +from plato.trainers.strategies.base import OptimizerStrategy, TrainingContext LOCAL_STATE_PAYLOAD_KEYS = { "optimizer_state", @@ -96,6 +97,51 @@ def _linear_model(): return nn.Sequential(OrderedDict([("linear", nn.Linear(2, 2))])) +class DeviceTrackingModel(nn.Module): + """Model that records whether it has been moved to a trainer device.""" + + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 2) + self.moved_to_trainer_device = False + + def forward(self, features): + return self.linear(features) + + def to(self, *args, **kwargs): + self.moved_to_trainer_device = True + return super().to(*args, **kwargs) + + +class RestoreOrderOptimizer(torch.optim.SGD): + """Optimizer that records whether state restore happens after model.to().""" + + def __init__(self, params, model: DeviceTrackingModel): + self.model = model + self.loaded_after_model_to = None + super().__init__(params, lr=0.01, momentum=0.9) + + def load_state_dict(self, state_dict): + self.loaded_after_model_to = self.model.moved_to_trainer_device + if not self.loaded_after_model_to: + raise AssertionError("optimizer state restored before model.to()") + return super().load_state_dict(state_dict) + + +class RestoreOrderOptimizerStrategy(OptimizerStrategy): + """Create restore-order-aware optimizers for regression tests.""" + + def __init__(self): + self.optimizers = [] + + def create_optimizer( + self, model: DeviceTrackingModel, context: TrainingContext + ) -> torch.optim.Optimizer: + optimizer = RestoreOrderOptimizer(model.parameters(), model) + self.optimizers.append(optimizer) + return optimizer + + def _two_layer_model(first_name="first", second_name="second"): return nn.Sequential( OrderedDict( @@ -184,6 +230,38 @@ def test_adamw_moment_buffers_persist_between_rounds_for_same_client( assert _state_step(final_param_state) == 2 +def test_preserved_optimizer_state_restores_after_model_moves_to_device( + temp_config, tiny_dataset, one_step_config +): + config = {**one_step_config, "preserve_optimizer_state": True} + source_trainer = ComposableTrainer( + model=DeviceTrackingModel, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=RestoreOrderOptimizerStrategy(), + ) + source_trainer.set_client_id(11) + source_trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + restore_strategy = RestoreOrderOptimizerStrategy() + trainer = ComposableTrainer( + model=DeviceTrackingModel, + loss_strategy=CrossEntropyLossStrategy(), + optimizer_strategy=restore_strategy, + ) + trainer.set_client_id(11) + trainer._preserved_optimizer_states[11] = copy.deepcopy( + source_trainer._preserved_optimizer_states[11] + ) + + trainer.train_model(config, tiny_dataset, list(range(len(tiny_dataset)))) + + assert restore_strategy.optimizers[0].loaded_after_model_to is True + restored_state = _first_param_state( + trainer._preserved_optimizer_states[11]["optimizer_state"]["state"] + ) + assert "momentum_buffer" in restored_state + + def test_scheduler_state_and_lr_progress_persist_between_rounds( temp_config, tiny_dataset, one_step_config ): From f359d789633c8520e9d10c9d387e180925ce96ca Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 30 Apr 2026 02:07:40 -0400 Subject: [PATCH 29/39] Logged DiLoCo outer optimizer application. Emit a server-side info log each time DiLoCo applies the configured outer optimizer to averaged client deltas. The log includes optimizer settings, aggregation weighting, apply policy, eligible update count, and optimized tensor count so runs show where the server update occurs. Validation: - uv run pytest tests/servers/test_diloco_strategy.py -q - uv run pytest tests/integration/test_smoke_configs.py -k diloco -q - uv run ruff check plato/servers/strategies/aggregation/diloco.py tests/servers/test_diloco_strategy.py --select I - git diff --check Co-authored-by: Codex --- .../servers/strategies/aggregation/diloco.py | 13 +++++++ tests/servers/test_diloco_strategy.py | 36 +++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py index 6504990cf..ddc2428fc 100644 --- a/plato/servers/strategies/aggregation/diloco.py +++ b/plato/servers/strategies/aggregation/diloco.py @@ -10,6 +10,7 @@ import asyncio import copy +import logging import numbers from collections.abc import Callable, Mapping from types import SimpleNamespace @@ -86,6 +87,18 @@ async def aggregate_deltas( server_delta, active_paths = self._apply_outer_optimizer( avg_delta, optimizer_paths ) + logging.info( + "[Server] DiLoCo outer optimizer applied: optimizer=%s " + "outer_lr=%g outer_momentum=%g weighting=%s apply_to=%s " + "eligible_updates=%d optimized_tensors=%d.", + self.outer_optimizer, + self.outer_learning_rate, + self.outer_momentum, + self.aggregation_weighting, + self.apply_outer_optimizer_to, + len(eligible), + len(optimizer_paths), + ) self._remove_stale_momentum(active_paths) return self._match_reference_structure(server_delta, eligible[0][1]) diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py index f35467344..4c739051d 100644 --- a/tests/servers/test_diloco_strategy.py +++ b/tests/servers/test_diloco_strategy.py @@ -1,6 +1,7 @@ """Tests for DiLoCo server-side outer aggregation.""" import asyncio +import logging from types import SimpleNamespace import pytest @@ -347,6 +348,41 @@ def test_sgd_uses_diloco_outer_gradient_sign(temp_config): assert torch.allclose(server_delta["w"], torch.tensor([1.0])) +def test_outer_optimizer_application_is_logged(temp_config, caplog): + """A DiLoCo aggregation should report the server-side outer optimizer.""" + strategy = DiLoCoAggregationStrategy( + outer_optimizer="nesterov", + outer_learning_rate=0.7, + outer_momentum=0.9, + aggregation_weighting="uniform", + apply_outer_optimizer_to="parameters", + ) + model = torch.nn.Linear(1, 1, bias=False) + baseline = {name: tensor.clone() for name, tensor in model.state_dict().items()} + + with caplog.at_level(logging.INFO): + _aggregate( + strategy, + [_update(1), _update(1)], + [ + {"weight": torch.tensor([[2.0]])}, + {"weight": torch.tensor([[4.0]])}, + ], + baseline, + model, + ) + + message = caplog.text + assert "DiLoCo outer optimizer applied" in message + assert "optimizer=nesterov" in message + assert "outer_lr=0.7" in message + assert "outer_momentum=0.9" in message + assert "weighting=uniform" in message + assert "apply_to=parameters" in message + assert "eligible_updates=2" in message + assert "optimized_tensors=1" in message + + def test_uniform_weighting_ignores_positive_sample_count_magnitude(temp_config): """Uniform mode should weight eligible clients equally.""" strategy = DiLoCoAggregationStrategy( From c365751e1bd160e0815f4b8bf5abd4e6ce6de28a Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 30 Apr 2026 12:11:20 -0400 Subject: [PATCH 30/39] Added DiLoCo comparison configs and step-based scheduling. --- configs/CIFAR10/diloco_resnet18.toml | 79 ++++++++++++++++++ .../fedavg_resnet18_diloco_comparison.toml | 67 ++++++++++++++++ plato/trainers/composable.py | 25 +++++- plato/trainers/strategies/data_loader.py | 26 +++++- tests/trainers/test_composable_trainer.py | 80 +++++++++++++++++++ 5 files changed, 275 insertions(+), 2 deletions(-) create mode 100644 configs/CIFAR10/diloco_resnet18.toml create mode 100644 configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml diff --git a/configs/CIFAR10/diloco_resnet18.toml b/configs/CIFAR10/diloco_resnet18.toml new file mode 100644 index 000000000..70193dee5 --- /dev/null +++ b/configs/CIFAR10/diloco_resnet18.toml @@ -0,0 +1,79 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +type = "diloco" +address = "127.0.0.1" +port = 8021 + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[data] + +# The training and testing dataset +datasource = "Torchvision" +dataset_name = "CIFAR10" +download = true + +# Number of samples in each partition +partition_size = 1000 + +# IID or non-IID? +sampler = "iid" + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.8 + +# Number of local optimizer steps per DiLoCo synchronization. +local_steps_per_round = 500 +preserve_optimizer_state = true + +# DiLoCo paper inner-optimizer settings. +epochs = 250 +batch_size = 512 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +# The machine learning model +model_name = "resnet_18" + +[algorithm] + +# Weight extraction and model update path reused by DiLoCo. +type = "fedavg" + +[parameters] + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml new file mode 100644 index 000000000..329572e8d --- /dev/null +++ b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml @@ -0,0 +1,67 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +address = "127.0.0.1" +port = 8022 + +[data] + +# The training and testing dataset +datasource = "Torchvision" +dataset_name = "CIFAR10" +download = true + +# Number of samples in each partition +partition_size = 1000 + +# IID or non-IID? +sampler = "iid" + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.8 + +# Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run. +epochs = 5 +batch_size = 512 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +# The machine learning model +model_name = "resnet_18" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index cbfb4dd95..fc6feb886 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -206,6 +206,21 @@ def _preserve_optimizer_state(config: dict[str, Any]) -> bool: """Return whether optimizer state should survive local train runs.""" return bool(config.get("preserve_optimizer_state", False)) + @staticmethod + def _step_lr_scheduler_per_optimizer_step(config: dict[str, Any]) -> bool: + """Return whether LR scheduling should follow optimizer steps.""" + if config.get("local_steps_per_round") is None: + return False + + return getattr(Config().server, "type", None) == "diloco" + + def _step_lr_scheduler_after_optimizer_step( + self, step_lr_per_optimizer_step: bool + ) -> None: + """Advance step-based LR schedules after one completed optimizer step.""" + if step_lr_per_optimizer_step: + self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) + @staticmethod def _parameter_signature(name: str | None, parameter: torch.Tensor): """Build a compatibility signature for one model parameter.""" @@ -801,6 +816,7 @@ def train_model(self, config, trainset, sampler, **kwargs): # Training epochs total_epochs = config["epochs"] + step_lr_per_optimizer_step = self._step_lr_scheduler_per_optimizer_step(config) tic = time.perf_counter() training_stop_requested = False local_step_limit_reached = False @@ -874,6 +890,9 @@ def compute_loss(outputs, labels_inner): self.optimizer_strategy.on_optimizer_step( self.optimizer, self.context ) + self._step_lr_scheduler_after_optimizer_step( + step_lr_per_optimizer_step + ) local_step_limit_reached = self._record_local_optimizer_step( local_steps_per_round ) @@ -930,6 +949,9 @@ def compute_loss(outputs, labels_inner): ) if finalize_step_done: self.optimizer_strategy.on_optimizer_step(self.optimizer, self.context) + self._step_lr_scheduler_after_optimizer_step( + step_lr_per_optimizer_step + ) local_step_limit_reached = self._record_local_optimizer_step( local_steps_per_round ) @@ -979,7 +1001,8 @@ def compute_loss(outputs, labels_inner): self.context.state.pop("hf_optimizer_step_index", None) # LR scheduler step - self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) + if not step_lr_per_optimizer_step: + self.lr_scheduler_strategy.step(self.lr_scheduler, self.context) # Handle optimizer params state update if needed if hasattr(self.optimizer, "params_state_update"): diff --git a/plato/trainers/strategies/data_loader.py b/plato/trainers/strategies/data_loader.py index 9d9c5dc0e..c934e48b4 100644 --- a/plato/trainers/strategies/data_loader.py +++ b/plato/trainers/strategies/data_loader.py @@ -14,6 +14,7 @@ import torch import torch.utils.data +from plato.config import Config from plato.trainers.strategies.base import DataLoaderStrategy, TrainingContext CollateFn = Callable[[list[Any]], Any] @@ -66,12 +67,35 @@ def _local_step_stream_start( return offset % stream_length +def _enforce_diloco_full_participation_for_local_steps() -> None: + """Require DiLoCo workers to train once per outer synchronization.""" + server_type = getattr(Config().server, "type", None) + if server_type != "diloco": + return + + total_clients = int(Config().clients.total_clients) + clients_per_round = int(Config().clients.per_round) + if clients_per_round == total_clients: + return + + raise ValueError( + "DiLoCo local-step data loading requires clients.per_round to equal " + "clients.total_clients so every worker advances its local data stream " + "once per outer round." + ) + + def _apply_local_step_sampling_stream( sampler_obj, batch_size: int, context: TrainingContext ): """Advance deterministic samplers across short local-step rounds.""" local_steps_per_round = context.state.get("local_steps_per_round") - if local_steps_per_round is None or sampler_obj is None: + if local_steps_per_round is None: + return sampler_obj + + _enforce_diloco_full_participation_for_local_steps() + + if sampler_obj is None: return sampler_obj samples_per_round = int(local_steps_per_round) * int(batch_size) diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index a2063c933..fd6b4a668 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -24,6 +24,7 @@ GradientAccumulationStepStrategy, NoOpUpdateStrategy, NoSchedulerStrategy, + StepLRSchedulerStrategy, ) from plato.trainers.strategies.base import ( LossCriterionStrategy, @@ -536,6 +537,85 @@ def test_local_step_sampling_warns_for_non_materializable_sampler( in caplog.text ) + def test_diloco_local_steps_require_full_client_participation( + self, simple_dataset, temp_config + ): + Config().server.type = "diloco" + Config().clients.total_clients = 4 + Config().clients.per_round = 2 + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + + with pytest.raises( + ValueError, match="clients\\.per_round.*clients\\.total_clients" + ): + DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + list(range(len(simple_dataset))), + batch_size=1, + context=context, + ) + + def test_partial_participation_still_allowed_without_diloco_local_steps( + self, simple_dataset, temp_config + ): + Config().server.type = "fedavg" + Config().clients.total_clients = 4 + Config().clients.per_round = 2 + context = TrainingContext() + context.state["local_steps_per_round"] = 2 + + loader = DefaultDataLoaderStrategy().create_train_loader( + simple_dataset, + list(range(len(simple_dataset))), + batch_size=1, + context=context, + ) + + assert len(loader.sampler) == len(simple_dataset) + + def test_diloco_local_steps_advance_lr_scheduler_per_optimizer_step( + self, simple_model, simple_dataset, simple_config, temp_config + ): + Config().server.type = "diloco" + Config().clients.total_clients = 4 + Config().clients.per_round = 4 + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + trainer = ComposableTrainer( + model=simple_model, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert trainer.context.state["local_optimizer_steps"] == 3 + assert trainer.lr_scheduler.last_epoch == 3 + + def test_non_diloco_local_steps_keep_epoch_based_lr_scheduler( + self, simple_model, simple_dataset, simple_config, temp_config + ): + Config().server.type = "fedavg" + config = { + **simple_config, + "batch_size": 1, + "epochs": 3, + "local_steps_per_round": 3, + } + trainer = ComposableTrainer( + model=simple_model, + lr_scheduler_strategy=StepLRSchedulerStrategy(step_size=1, gamma=0.5), + ) + + trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + + assert trainer.context.state["local_optimizer_steps"] == 3 + assert trainer.lr_scheduler.last_epoch == 1 + @pytest.mark.parametrize("local_steps_per_round", [0, -1, 1.5, "2", True]) def test_invalid_local_steps_fail_clearly( self, simple_model, simple_dataset, simple_config, local_steps_per_round From bcc8073728b5ab0e03ed4f28438c1d6616f43505 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 30 Apr 2026 13:28:09 -0400 Subject: [PATCH 31/39] Added MNIST DiLoCo comparison configs. --- ...o_lenet5_smoke.toml => diloco_lenet5.toml} | 33 +-- .../fedavg_lenet5_diloco_comparison.toml | 63 ++++++ docs/docs/configurations/server.md | 7 +- docs/docs/development/diloco.md | 18 +- tests/integration/test_smoke_configs.py | 192 +----------------- 5 files changed, 98 insertions(+), 215 deletions(-) rename configs/MNIST/{diloco_lenet5_smoke.toml => diloco_lenet5.toml} (68%) create mode 100644 configs/MNIST/fedavg_lenet5_diloco_comparison.toml diff --git a/configs/MNIST/diloco_lenet5_smoke.toml b/configs/MNIST/diloco_lenet5.toml similarity index 68% rename from configs/MNIST/diloco_lenet5_smoke.toml rename to configs/MNIST/diloco_lenet5.toml index 0c17c2ebc..1a6e68ee0 100644 --- a/configs/MNIST/diloco_lenet5_smoke.toml +++ b/configs/MNIST/diloco_lenet5.toml @@ -4,10 +4,10 @@ type = "simple" # The total number of clients -total_clients = 2 +total_clients = 50 # The number of clients selected in each round -per_round = 2 +per_round = 50 # Should the clients compute test accuracy locally? do_test = false @@ -15,10 +15,9 @@ do_test = false [server] type = "diloco" address = "127.0.0.1" -port = 8000 +port = 8001 random_seed = 1 simulate_wall_time = true -do_test = false [server.diloco] outer_optimizer = "nesterov" @@ -29,7 +28,7 @@ apply_outer_optimizer_to = "parameters" [data] include = "mnist_iid.toml" -partition_size = 16 +partition_size = 1000 [trainer] @@ -37,21 +36,26 @@ partition_size = 16 type = "basic" # The maximum number of training rounds -rounds = 2 +rounds = 20 # The maximum number of clients running concurrently -max_concurrency = 1 +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.97 # The machine learning model model_name = "lenet5" -# DiLoCo local work H, counted in optimizer steps. -local_steps_per_round = 2 +# Number of local optimizer steps per DiLoCo synchronization. +local_steps_per_round = 500 preserve_optimizer_state = true -epochs = 1 -batch_size = 4 +# DiLoCo paper inner-optimizer settings. +epochs = 250 +batch_size = 512 optimizer = "AdamW" +lr_scheduler = "LambdaLR" [algorithm] @@ -64,5 +68,8 @@ type = "fedavg" num_classes = 10 [parameters.optimizer] -lr = 0.001 -weight_decay = 0.0 +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/configs/MNIST/fedavg_lenet5_diloco_comparison.toml b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml new file mode 100644 index 000000000..ecbd1c544 --- /dev/null +++ b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml @@ -0,0 +1,63 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 50 + +# The number of clients selected in each round +per_round = 50 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +address = "127.0.0.1" +port = 8002 +random_seed = 1 +simulate_wall_time = true + +[data] +include = "mnist_iid.toml" +partition_size = 1000 + +[trainer] + +# The type of the trainer +type = "basic" + +# The maximum number of training rounds +rounds = 20 + +# The maximum number of clients running concurrently +max_concurrency = 7 + +# The target accuracy +target_accuracy = 0.97 + +# The machine learning model +model_name = "lenet5" + +# Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run. +epochs = 5 +batch_size = 512 +optimizer = "AdamW" +lr_scheduler = "LambdaLR" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.model] +num_classes = 10 + +[parameters.optimizer] +lr = 0.0004 +weight_decay = 0.1 + +[parameters.learning_rate] +warmup_steps = "1000it" diff --git a/docs/docs/configurations/server.md b/docs/docs/configurations/server.md index 42d0bad99..cef578eb9 100644 --- a/docs/docs/configurations/server.md +++ b/docs/docs/configurations/server.md @@ -147,13 +147,14 @@ `apply_outer_optimizer_to = "parameters"` applies the outer optimizer only to trainable floating parameters. Floating buffers are synchronized with the selected averaging rule but do not receive outer momentum. `apply_outer_optimizer_to = "all_floating"` is available for experiments that also apply the outer optimizer to floating buffers. - A runnable smoke configuration is available at `configs/MNIST/diloco_lenet5_smoke.toml`: + Runnable comparison configurations are available for MNIST/LeNet and CIFAR-10/ResNet-18: ```bash - uv run python plato.py --config configs/MNIST/diloco_lenet5_smoke.toml + uv run python plato.py --config configs/MNIST/diloco_lenet5.toml + uv run python plato.py --config configs/CIFAR10/diloco_resnet18.toml ``` - The smoke configuration validates DiLoCo mechanics in Plato; it is not a C4/model/pretraining reproduction of the DiLoCo paper. + These configurations validate DiLoCo mechanics in Plato; they are not C4/model/pretraining reproductions of the DiLoCo paper. !!! example "edge_downlink_bandwidth" The edge server's estimated downlink capacity (an edge server to its clients) in Mbps, used for computing the transmission time (see `compute_comm_time` in the `clients` section). diff --git a/docs/docs/development/diloco.md b/docs/docs/development/diloco.md index b66eae9ad..f4b8ce405 100644 --- a/docs/docs/development/diloco.md +++ b/docs/docs/development/diloco.md @@ -8,19 +8,21 @@ training loop inside Plato's federated runtime. It does not mean reproducing the paper's exact C4 dataset, model scale, tokenizer, hardware topology, pretraining duration, or final benchmark numbers. -## Smoke Configuration +## Example Configurations -Plato includes a small MNIST/LeNet smoke configuration for checking the DiLoCo -mechanics: +Plato includes MNIST/LeNet and CIFAR-10/ResNet-18 comparison configurations +for checking DiLoCo against matched FedAvg runs: ```bash -uv run python plato.py --config configs/MNIST/diloco_lenet5_smoke.toml +uv run python plato.py --config configs/MNIST/diloco_lenet5.toml +uv run python plato.py --config configs/MNIST/fedavg_lenet5_diloco_comparison.toml +uv run python plato.py --config configs/CIFAR10/diloco_resnet18.toml +uv run python plato.py --config configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml ``` -This smoke run validates configuration loading, DiLoCo server selection, local -optimizer-step work, client-local optimizer-state persistence, and server-side -outer aggregation. It is intentionally tiny and does not reproduce the C4 -language-model pretraining setup or the paper's reported metrics. +These examples validate Plato's DiLoCo mechanics without reproducing the C4 +dataset, tokenizer, language-model scale, hardware topology, pretraining +duration, or final benchmark numbers from the paper. ## Algorithm Contract diff --git a/tests/integration/test_smoke_configs.py b/tests/integration/test_smoke_configs.py index 8dd1271a4..e499655be 100644 --- a/tests/integration/test_smoke_configs.py +++ b/tests/integration/test_smoke_configs.py @@ -4,8 +4,7 @@ from __future__ import annotations -from importlib import import_module, reload -from pathlib import Path +from importlib import import_module from types import SimpleNamespace from typing import cast @@ -18,11 +17,8 @@ async_run, build_minimal_config, configure_environment, - configure_environment_from_path, ) -REPO_ROOT = Path(__file__).resolve().parents[2] - class MNISTSmokeDatasource: """Datasource returning image-shaped tensors for LeNet smoke tests.""" @@ -48,56 +44,6 @@ def get_test_set(self): return self._test -LOCAL_STATE_PAYLOAD_KEYS = { - "optimizer_state", - "scheduler_state", - "trainer_state", - "local_metadata", - "metadata", - "global_step", - "local_optimizer_steps", - "_optimizer_state_input_filename", - "_optimizer_state_output_filename", -} - - -def _model_weight_payload(payload, model): - """Assert that a client payload contains model weights only.""" - model_state = model.state_dict() - - assert isinstance(payload, dict) - assert set(payload) == set(model_state) - assert LOCAL_STATE_PAYLOAD_KEYS.isdisjoint(payload) - assert all(torch.is_tensor(value) for value in payload.values()) - - -def _shifted_payload(weights, amount): - """Build a fake client model payload shifted from the server baseline.""" - shifted = {} - for name, value in weights.items(): - shifted[name] = value.clone() - if torch.is_floating_point(shifted[name]): - shifted[name] = shifted[name] + amount - return shifted - - -def _server_update(client_id, payload): - """Build a minimal server update carrying model weights.""" - return SimpleNamespace( - client_id=client_id, - report=SimpleNamespace( - client_id=client_id, - num_samples=1, - accuracy=0.5, - processing_time=0.1, - comm_time=0.1, - training_time=0.1, - type="weights", - ), - payload=payload, - ) - - @pytest.mark.integration def test_fedavg_lenet5_smoke(monkeypatch): """End-to-end smoke test for a minimal FedAvg run.""" @@ -151,142 +97,6 @@ def test_fedavg_lenet5_smoke(monkeypatch): async_run(server._process_reports()) assert server.accuracy >= 0 - -@pytest.mark.integration -def test_diloco_lenet5_smoke_config_contract_loads(): - """Smoke config should load the faithful DiLoCo contract.""" - config_path = REPO_ROOT / "configs" / "MNIST" / "diloco_lenet5_smoke.toml" - - with configure_environment_from_path(config_path) as config: - assert config.server.type == "diloco" - assert config.algorithm.type == "fedavg" - assert config.trainer.local_steps_per_round == 2 - assert config.trainer.preserve_optimizer_state is True - assert config.trainer.optimizer == "AdamW" - assert config.server.diloco.outer_optimizer == "nesterov" - assert config.server.diloco.outer_learning_rate == 0.7 - assert config.server.diloco.outer_momentum == 0.9 - assert config.server.diloco.aggregation_weighting == "uniform" - assert config.server.diloco.apply_outer_optimizer_to == "parameters" - - server_registry = reload(import_module("plato.servers.registry")) - diloco_server = import_module("plato.servers.diloco") - diloco_aggregation = import_module("plato.servers.strategies.aggregation") - - server = server_registry.get() - - assert isinstance(server, diloco_server.Server) - assert isinstance( - server.aggregation_strategy, - diloco_aggregation.DiLoCoAggregationStrategy, - ) - - -@pytest.mark.integration -def test_diloco_lenet5_smoke_config_runs_faithful_path(monkeypatch): - """Exact DiLoCo smoke config exercises local work and outer aggregation.""" - config_path = REPO_ROOT / "configs" / "MNIST" / "diloco_lenet5_smoke.toml" - - with configure_environment_from_path(config_path) as config: - datasources_registry = import_module("plato.datasources.registry") - processor_registry = import_module("plato.processors.registry") - server_registry = reload(import_module("plato.servers.registry")) - client_mod = import_module("plato.clients.simple") - config_mod = import_module("plato.config") - diloco_server = import_module("plato.servers.diloco") - fedavg_algorithm = import_module("plato.algorithms.fedavg") - diloco_aggregation = import_module("plato.servers.strategies.aggregation") - - fake_datasource = MNISTSmokeDatasource(train_size=32, test_size=4) - monkeypatch.setattr( - datasources_registry, - "get", - lambda *args, **kwargs: fake_datasource, - ) - monkeypatch.setattr( - processor_registry, - "get", - lambda *args, **kwargs: (None, None), - ) - - server = server_registry.get() - server.configure() - - assert isinstance(server, diloco_server.Server) - assert isinstance(server.algorithm, fedavg_algorithm.Algorithm) - assert isinstance( - server.aggregation_strategy, - diloco_aggregation.DiLoCoAggregationStrategy, - ) - assert config.server.type == "diloco" - assert config.algorithm.type == "fedavg" - assert config.trainer.local_steps_per_round == 2 - assert config.trainer.preserve_optimizer_state is True - assert config.data.sampler == "iid" - - client = client_mod.Client() - client.client_id = 1 - client._context.client_id = 1 - client.current_round = 1 - client._context.current_round = 1 - client._load_data() - client.configure() - client._allocate_data() - client._load_payload(server.algorithm.extract_weights()) - - train_config = config.trainer._asdict() - train_config["run_id"] = config_mod.Config.params["run_id"] - client.trainer.current_round = client.current_round - client.trainer.train_model(train_config, client.trainset, client.sampler) - payload = client.algorithm.extract_weights() - - assert client.sampler.num_samples() == config.data.partition_size - assert client.trainer.context.state["local_steps_per_round"] == 2 - assert client.trainer.context.state["local_optimizer_steps"] == 2 - assert client.trainer.current_epoch == 1 - assert client.client_id in client.trainer._preserved_optimizer_states - assert client.trainer._preserved_optimizer_states[client.client_id][ - "optimizer_state" - ]["state"] - _model_weight_payload(payload, client.trainer.model) - - # Small-H mid-epoch stopping and round-aware sampler streaming are covered - # in TestComposableTrainerLocalSteps; this integration path verifies the - # exact smoke config enables those runtime flags with the supported sampler. - baseline = server.algorithm.extract_weights() - trainable_name = next(iter(dict(server.trainer.model.named_parameters()))) - server.updates = [ - _server_update(1, _shifted_payload(baseline, 1.0)), - _server_update(2, _shifted_payload(baseline, 3.0)), - ] - server.current_round = 1 - server.context.current_round = 1 - - delta_calls = [] - aggregate_deltas = server.aggregation_strategy.aggregate_deltas - - async def record_delta_aggregation(updates, deltas_received, context): - delta_calls.append((updates, deltas_received)) - return await aggregate_deltas(updates, deltas_received, context) - - monkeypatch.setattr( - server.aggregation_strategy, - "aggregate_deltas", - record_delta_aggregation, - ) - - async_run(server._process_reports()) - - updated = server.algorithm.extract_weights() - ordinary_fedavg_value = baseline[trainable_name] + 2.0 - faithful_diloco_value = baseline[trainable_name] + 2.0 * 1.9 * 0.7 - - assert len(delta_calls) == 1 - assert len(delta_calls[0][1]) == 2 - assert not torch.allclose(updated[trainable_name], ordinary_fedavg_value) - assert torch.allclose(updated[trainable_name], faithful_diloco_value) - - @pytest.mark.integration def test_split_learning_smoke(monkeypatch): """Smoke test for split-learning trainer orchestrating gradients.""" From 6ffb475f423891cfde2c9c252d46d0e1752960b6 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 30 Apr 2026 13:53:44 -0400 Subject: [PATCH 32/39] Aligned DiLoCo comparison budgets. --- configs/CIFAR10/diloco_resnet18.toml | 6 +++--- configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml | 7 ++++--- configs/MNIST/diloco_lenet5.toml | 6 +++--- configs/MNIST/fedavg_lenet5_diloco_comparison.toml | 9 ++++++--- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/configs/CIFAR10/diloco_resnet18.toml b/configs/CIFAR10/diloco_resnet18.toml index 70193dee5..ed407000c 100644 --- a/configs/CIFAR10/diloco_resnet18.toml +++ b/configs/CIFAR10/diloco_resnet18.toml @@ -49,15 +49,15 @@ rounds = 20 max_concurrency = 7 # The target accuracy -target_accuracy = 0.8 +target_accuracy = 0.9 # Number of local optimizer steps per DiLoCo synchronization. local_steps_per_round = 500 preserve_optimizer_state = true # DiLoCo paper inner-optimizer settings. -epochs = 250 -batch_size = 512 +epochs = 5 +batch_size = 10 optimizer = "AdamW" lr_scheduler = "LambdaLR" diff --git a/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml index 329572e8d..26f32d0ce 100644 --- a/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml +++ b/configs/CIFAR10/fedavg_resnet18_diloco_comparison.toml @@ -41,11 +41,12 @@ rounds = 20 max_concurrency = 7 # The target accuracy -target_accuracy = 0.8 +target_accuracy = 0.9 -# Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run. +# Match the original FedAvg local training shape while keeping 500 optimizer +# steps per round, equal to DiLoCo's H. epochs = 5 -batch_size = 512 +batch_size = 10 optimizer = "AdamW" lr_scheduler = "LambdaLR" diff --git a/configs/MNIST/diloco_lenet5.toml b/configs/MNIST/diloco_lenet5.toml index 1a6e68ee0..53eff9305 100644 --- a/configs/MNIST/diloco_lenet5.toml +++ b/configs/MNIST/diloco_lenet5.toml @@ -42,7 +42,7 @@ rounds = 20 max_concurrency = 7 # The target accuracy -target_accuracy = 0.97 +target_accuracy = 0.99 # The machine learning model model_name = "lenet5" @@ -52,8 +52,8 @@ local_steps_per_round = 500 preserve_optimizer_state = true # DiLoCo paper inner-optimizer settings. -epochs = 250 -batch_size = 512 +epochs = 5 +batch_size = 32 optimizer = "AdamW" lr_scheduler = "LambdaLR" diff --git a/configs/MNIST/fedavg_lenet5_diloco_comparison.toml b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml index ecbd1c544..e223915bb 100644 --- a/configs/MNIST/fedavg_lenet5_diloco_comparison.toml +++ b/configs/MNIST/fedavg_lenet5_diloco_comparison.toml @@ -28,20 +28,23 @@ partition_size = 1000 type = "basic" # The maximum number of training rounds -rounds = 20 +rounds = 63 # The maximum number of clients running concurrently max_concurrency = 7 # The target accuracy -target_accuracy = 0.97 +target_accuracy = 0.99 # The machine learning model model_name = "lenet5" # Match the DiLoCo paper-style inner optimizer settings used by the DiLoCo run. +# 5 epochs over 1000 samples at batch size 32 gives 160 optimizer steps per +# round. With 63 rounds, FedAvg gets 10,080 local steps, closely matching +# DiLoCo's 20 * H=500 = 10,000-step total budget. epochs = 5 -batch_size = 512 +batch_size = 32 optimizer = "AdamW" lr_scheduler = "LambdaLR" From bd817d37405fa08de322e5a3ae91b799baa19bf9 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Sun, 10 May 2026 16:41:38 -0400 Subject: [PATCH 33/39] Added MSE in metrics for time series forecasting. --- plato/datasources/ev_charging.py | 9 +++- plato/servers/fedavg.py | 72 +++++++++++++++++++++++++------- plato/servers/fedavg_cs.py | 36 ++-------------- plato/servers/split_learning.py | 14 ++----- plato/trainers/composable.py | 7 +++- 5 files changed, 79 insertions(+), 59 deletions(-) diff --git a/plato/datasources/ev_charging.py b/plato/datasources/ev_charging.py index 376f92519..7fa68dd28 100644 --- a/plato/datasources/ev_charging.py +++ b/plato/datasources/ev_charging.py @@ -61,7 +61,6 @@ from plato.config import Config - # Exact column names from the Mendeley CSV _CSV_SEP = ";" _GARAGE_COL = "Garage_ID" @@ -390,6 +389,14 @@ def __init__(self, client_id: int = 0, **kwargs): # Keep the full normalized array for inference scripts self.normalized_data = full_array + self.timestamps = user_df.index + self.user_id = user_key + self.feature_columns = list(_FEATURE_COLS) + self.split_window_starts = { + "train": list(train_starts), + "val": list(val_starts), + "test": list(test_starts), + } self._train_set = _EVChargingDataset( full_array, diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index 5f862fc35..c3f4c5e43 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -63,6 +63,54 @@ def __init__( self.clients_per_round, ) + def _primary_metric_name(self) -> str: + """Return the name of the primary testing metric.""" + trainer = getattr(self, "trainer", None) + testing_strategy = getattr(trainer, "testing_strategy", None) + metric_name = getattr(testing_strategy, "metric_name", None) + + if isinstance(metric_name, str) and metric_name: + metric_name = metric_name.lower() + if metric_name != "accuracy": + return metric_name + + if hasattr(Config().trainer, "target_perplexity"): + return "perplexity" + + return "accuracy" + + def _log_average_client_metric(self, metric_value: float) -> None: + """Log the client-aggregated testing metric with the appropriate label.""" + metric_name = self._primary_metric_name() + + if metric_name == "mse": + logging.info("[%s] Average client MSE: %.6f.", self, metric_value) + elif metric_name == "perplexity": + logging.info("[%s] Average client perplexity: %.2f.", self, metric_value) + else: + logging.info("[%s] Average client accuracy: %.2f%%.", self, 100 * metric_value) + + def _log_global_metric(self, metric_value: float) -> None: + """Log the server-tested metric with the appropriate label.""" + metric_name = self._primary_metric_name() + + if metric_name == "mse": + logging.info( + fonts.colourize(f"[{self}] Global model MSE: {metric_value:.6f}\n") + ) + elif metric_name == "perplexity": + logging.info( + fonts.colourize( + f"[{self}] Global model perplexity: {metric_value:.2f}\n" + ) + ) + else: + logging.info( + fonts.colourize( + f"[{self}] Global model accuracy: {100 * metric_value:.2f}%\n" + ) + ) + def configure(self) -> None: """ Booting the federated learning server by setting up the data, model, and @@ -133,7 +181,8 @@ def configure(self) -> None: accuracy_csv_file = ( f"{Config().params['result_path']}/{os.getpid()}_accuracy.csv" ) - accuracy_headers = ["round", "client_id", "accuracy"] + metric_name = self._primary_metric_name() + accuracy_headers = ["round", "client_id", metric_name] csv_processor.initialize_csv( accuracy_csv_file, accuracy_headers, Config().params["result_path"] ) @@ -243,9 +292,7 @@ async def _process_reports(self): if hasattr(Config().server, "do_test") and not Config().server.do_test: # Compute the average accuracy from client reports self.accuracy, self.accuracy_std = self.get_accuracy_mean_std(self.updates) - logging.info( - "[%s] Average client accuracy: %.2f%%.", self, 100 * self.accuracy - ) + self._log_average_client_metric(self.accuracy) else: # Testing the updated model directly at the server logging.info("[%s] Started model testing.", self) @@ -270,17 +317,9 @@ async def _process_reports(self): ) ) elif hasattr(Config().trainer, "target_perplexity"): - logging.info( - fonts.colourize( - f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" - ) - ) + self._log_global_metric(self.accuracy) else: - logging.info( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n" - ) - ) + self._log_global_metric(self.accuracy) self.clients_processed() self.callback_handler.call_event("on_clients_processed", self) @@ -345,6 +384,11 @@ def get_logged_items(self) -> dict: if hasattr(self, "_core_metric"): logged["core_metric"] = self._core_metric + metric_name = self._primary_metric_name() + if metric_name != "accuracy": + logged[metric_name] = self.accuracy + logged[f"{metric_name}_std"] = self.accuracy_std + logged.update(evaluation_logging.extract_logged_items(self.trainer)) return logged diff --git a/plato/servers/fedavg_cs.py b/plato/servers/fedavg_cs.py index eaba3caf6..fe865dd90 100644 --- a/plato/servers/fedavg_cs.py +++ b/plato/servers/fedavg_cs.py @@ -222,11 +222,7 @@ async def _process_reports(self): self.average_accuracy, self.std_accuracy, ) = self.get_accuracy_mean_std(self.updates) - logging.info( - "[%s] Average client accuracy: %.2f%%.", - self, - 100 * self.average_accuracy, - ) + self._log_average_client_metric(self.average_accuracy) elif Config().is_central_server() and Config().clients.do_test: # Compute the average accuracy from client reports total_samples = sum(update.report.num_samples for update in self.updates) @@ -238,11 +234,7 @@ async def _process_reports(self): / total_samples ) - logging.info( - "[%s] Average client accuracy: %.2f%%.", - self, - 100 * self.average_accuracy, - ) + self._log_average_client_metric(self.average_accuracy) if ( Config().is_central_server() @@ -268,18 +260,8 @@ async def _process_reports(self): f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" ) ) - elif hasattr(Config().trainer, "target_perplexity"): - logging.info( - fonts.colourize( - f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" - ) - ) else: - logging.info( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n" - ) - ) + self._log_global_metric(self.accuracy) elif ( Config().is_edge_server() and hasattr(Config().server, "edge_do_test") @@ -304,18 +286,8 @@ async def _process_reports(self): f"[{self}] Average Centered CORE benchmark metric: {100 * core_metric:.2f}%\n" ) ) - elif hasattr(Config().trainer, "target_perplexity"): - logging.info( - fonts.colourize( - f"[{self}] Global model perplexity: {self.accuracy:.2f}\n" - ) - ) else: - logging.info( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n" - ) - ) + self._log_global_metric(self.accuracy) else: self.accuracy = self.average_accuracy self.accuracy_std = self.std_accuracy diff --git a/plato/servers/split_learning.py b/plato/servers/split_learning.py index 49fdc2d4c..72ffe688a 100644 --- a/plato/servers/split_learning.py +++ b/plato/servers/split_learning.py @@ -91,7 +91,7 @@ async def aggregate_weights(self, updates, baseline_weights, weights_received): self.phase = "gradient" elif report.type == "weights": - logging.warning("[%s] Weights received, start testing accuracy.", self) + logging.warning("[%s] Weights received, start testing.", self) weights = update.payload # The weights after cut layer are not trained by clients @@ -112,17 +112,9 @@ async def aggregate_weights(self, updates, baseline_weights, weights_received): ) ) else: - logging.warning( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" - ) - ) + self._log_global_metric(self.test_accuracy) else: - logging.warning( - fonts.colourize( - f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n" - ) - ) + self._log_global_metric(self.test_accuracy) self.phase = "prompt" # Change client in next round self.next_client = True diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index 1e98d1128..47feb17fc 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -551,7 +551,12 @@ def compute_loss(outputs, labels_inner): ) # Track loss - self._loss_tracker.update(loss, labels.size(0)) + if labels is not None: + batch_size = labels.size(0) + else: + first_val = next(iter(examples.values())) if hasattr(examples, "values") else examples + batch_size = first_val.size(0) if hasattr(first_val, "size") else 1 + self._loss_tracker.update(loss, batch_size) # Store last loss in context self.context.state["last_loss"] = loss.item() From ef6c6067608d374633aaa0eea81cc7a55ce6b8ae Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Sun, 10 May 2026 16:43:03 -0400 Subject: [PATCH 34/39] Added functions for customizing the time series models from configurations. --- plato/models/huggingface.py | 61 ++++++---- plato/trainers/huggingface.py | 210 ++++++++++++++++++++++++---------- 2 files changed, 191 insertions(+), 80 deletions(-) diff --git a/plato/models/huggingface.py b/plato/models/huggingface.py index 514260314..929b3ae4f 100644 --- a/plato/models/huggingface.py +++ b/plato/models/huggingface.py @@ -202,35 +202,64 @@ def _build_freq_list( # --------------------------------------------------------------------------- +def _create_timesfm_from_config(trainer_config, prediction_length: int) -> nn.Module: + """Instantiate a fresh TimesFmModelForPrediction from TOML trainer settings.""" + if TimesFmConfig is None or TimesFmModelForPrediction is None: + raise ImportError( + "TimesFM models are not available. " + "Ensure you have transformers>=5.0.0 installed." + ) + context_length = getattr(trainer_config, "context_length", 512) + config = TimesFmConfig( + context_length=context_length, + horizon_length=prediction_length, + patch_length=getattr(trainer_config, "patch_length", 32), + num_hidden_layers=getattr(trainer_config, "num_hidden_layers", 20), + hidden_size=getattr(trainer_config, "hidden_size", 1280), + intermediate_size=getattr(trainer_config, "intermediate_size", 1280), + num_attention_heads=getattr(trainer_config, "num_attention_heads", 16), + head_dim=getattr(trainer_config, "head_dim", 80), + attention_dropout=getattr(trainer_config, "dropout", 0.0), + ) + return TimesFmModelForPrediction(config) + + def _load_timesfm(resolved_model_name: str, cache_dir: str, **kwargs) -> nn.Module: """Load or create a TimesFM model wrapped for batched multivariate use. - Supports two HuggingFace variants: - - ``*-transformers``: Uses ``TimesFm2_5ModelForPrediction`` from the - ``transformers`` library. Forward call uses ``forecast_context_len``. - - ``*-pytorch`` (default): Uses ``TimesFmModelForPrediction``. - Forward call uses ``freq``. + Model class selection is based on version, not checkpoint format: + + - TimesFM 2.5 (``2.5`` in the name): ``TimesFm2_5ModelForPrediction``. + Both ``*-pytorch`` and ``*-transformers`` checkpoints are supported by + this class. The forward API differs by suffix: + - ``*-transformers``: uses ``forecast_context_len`` + - ``*-pytorch`` (and others): uses ``freq`` + - TimesFM 1.0 / unspecified: ``TimesFmModelForPrediction`` with ``freq``. + Falls back to config-based creation if the checkpoint is not found. """ - use_transformers_api = "transformers" in resolved_model_name.lower() + name_lower = resolved_model_name.lower() + is_v25 = "2.5" in name_lower + # Controls which forward kwarg to use, not which class to load. + use_transformers_api = "transformers" in name_lower trainer_config = Config().trainer prediction_length = getattr(trainer_config, "prediction_length", 128) default_freq = getattr(trainer_config, "freq", 0) - if use_transformers_api: + if is_v25: if TimesFm2_5ModelForPrediction is None: raise ImportError( "TimesFm2_5ModelForPrediction is not available. " "Ensure you have a recent transformers version installed." ) logging.info( - "Attempting to load pretrained TimesFM 2.5 (transformers) model: %s", + "Loading pretrained TimesFM 2.5 model: %s", resolved_model_name, ) inner = TimesFm2_5ModelForPrediction.from_pretrained( resolved_model_name, cache_dir=cache_dir ) - logging.info("Successfully loaded pretrained TimesFM 2.5 (transformers) model") + logging.info("Successfully loaded pretrained TimesFM 2.5 model") else: if TimesFmModelForPrediction is None: raise ImportError( @@ -251,19 +280,7 @@ def _load_timesfm(resolved_model_name: str, cache_dir: str, **kwargs) -> nn.Modu "TimesFM model '%s' not found as pretrained, creating from config", resolved_model_name, ) - context_length = getattr(trainer_config, "context_length", 512) - config = TimesFmConfig( - context_length=context_length, - horizon_length=prediction_length, - patch_length=getattr(trainer_config, "patch_length", 32), - num_hidden_layers=getattr(trainer_config, "num_hidden_layers", 20), - hidden_size=getattr(trainer_config, "hidden_size", 1280), - intermediate_size=getattr(trainer_config, "intermediate_size", 1280), - num_attention_heads=getattr(trainer_config, "num_attention_heads", 16), - head_dim=getattr(trainer_config, "head_dim", 80), - attention_dropout=getattr(trainer_config, "dropout", 0.0), - ) - inner = TimesFmModelForPrediction(config) + inner = _create_timesfm_from_config(trainer_config, prediction_length) return TimesFmMultivariateWrapper( model=inner, diff --git a/plato/trainers/huggingface.py b/plato/trainers/huggingface.py index 734b857b2..f8b54dd46 100644 --- a/plato/trainers/huggingface.py +++ b/plato/trainers/huggingface.py @@ -38,6 +38,7 @@ TrainingContext, TrainingStepStrategy, ) +from plato.utils.timeseries_utils import is_timeseries_model class HuggingFaceBatch(dict): @@ -161,6 +162,84 @@ def __call__( return batch, labels +class TimeSeriesCollateWrapper: + """Collate function for time-series datasets that return tensor dicts. + + Stacks per-sample dicts (e.g. ``{"past_values": ..., "future_values": ...}``) + into a batched ``HuggingFaceBatch``. Labels are always ``None`` because the + model computes its own loss from ``future_values``. + """ + + def __call__( + self, examples: Iterable[dict] + ) -> tuple[HuggingFaceBatch, None]: + example_list = list(examples) + if not example_list: + raise ValueError("TimeSeriesCollateWrapper received an empty batch.") + + keys = example_list[0].keys() + batch = HuggingFaceBatch( + { + key: torch.stack([torch.as_tensor(ex[key]) for ex in example_list]) + for key in keys + } + ) + return batch, None + + +class TimeSeriesTestingStrategy(TestingStrategy): + """Evaluates time-series models and reports mean MSE loss.""" + + metric_name = "mse" + + def __init__(self, collate_fn: TimeSeriesCollateWrapper): + self.collate_fn = collate_fn + + def test_model(self, model, config, testset, sampler, context: TrainingContext): + batch_size = config.get("batch_size", 1) + + if sampler is not None: + if isinstance(sampler, torch.utils.data.Sampler): + sampler_obj = sampler + elif isinstance(sampler, (list, range)): + sampler_obj = torch.utils.data.SubsetRandomSampler(sampler) + elif hasattr(sampler, "get"): + sampler_obj = sampler.get() + else: + sampler_obj = sampler + else: + sampler_obj = None + + data_loader = torch.utils.data.DataLoader( + testset, + batch_size=batch_size, + shuffle=False, + sampler=sampler_obj, + collate_fn=self.collate_fn, + ) + + model.to(context.device) + model.eval() + + total_loss = 0.0 + num_batches = 0 + + with torch.no_grad(): + for batch_inputs, _ in data_loader: + batch_inputs = batch_inputs.to(context.device) + batch_inputs.setdefault("return_dict", True) + outputs = model(**batch_inputs) + loss = _resolve_hf_loss(outputs, labels=None) + total_loss += loss.item() + num_batches += 1 + + model.train() + + if num_batches == 0: + return float("inf") + return total_loss / num_batches + + def _resolve_hf_loss(outputs, labels, *, allow_fallback: bool = True): """ Resolve a loss tensor from HuggingFace model outputs. @@ -520,45 +599,7 @@ def __init__(self, model=None, callbacks=None): self.training_args = cast(TrainingArguments, training_args) model_name = Config().trainer.model_name - tokenizer_name = getattr(Config().trainer, "tokenizer_name", model_name) - if not isinstance(tokenizer_name, str) or not tokenizer_name: - tokenizer_name = model_name - - config_kwargs = { - "cache_dir": None, - "revision": "main", - "use_auth_token": None, - } - self.config = AutoConfig.from_pretrained(model_name, **config_kwargs) - - cache_dir = Config().params["data_path"] - use_fast_tokenizer = True - revision = "main" - auth_token = getattr( - getattr(Config(), "parameters", None), "huggingface_token", None - ) - - tokenizer_loader: Any = ( - LlamaTokenizer if "llama" in tokenizer_name else AutoTokenizer - ) - tokenizer_kwargs: dict[str, Any] = { - "config": self.config, - "cache_dir": cache_dir, - "use_fast": use_fast_tokenizer, - "revision": revision, - } - if isinstance(auth_token, str) and auth_token: - tokenizer_kwargs["use_auth_token"] = auth_token - self.tokenizer: Any = tokenizer_loader.from_pretrained( - tokenizer_name, - **tokenizer_kwargs, - ) - - tokenizer = cast(Any, self.tokenizer) - if getattr(tokenizer, "pad_token_id", None) is None: - eos_token = getattr(tokenizer, "eos_token", None) - if eos_token is not None: - tokenizer.pad_token = eos_token + model_type = getattr(Config().trainer, "model_type", "") grad_accum_steps = getattr(Config().trainer, "gradient_accumulation_steps", 1) try: @@ -566,7 +607,59 @@ def __init__(self, model=None, callbacks=None): except (TypeError, ValueError): grad_accum_steps = 1 self._gradient_accumulation_steps = max(grad_accum_steps, 1) - self._collate_wrapper = HuggingFaceCollateWrapper(tokenizer) + + if is_timeseries_model(model_name=model_name, model_type=model_type): + # Time-series models have no tokenizer. Use a simple tensor-stacking + # collator and return raw MSE from the testing strategy. + self.tokenizer = None + self.config = None + ts_collate = TimeSeriesCollateWrapper() + self._collate_wrapper = ts_collate + testing_strategy: TestingStrategy = TimeSeriesTestingStrategy(ts_collate) + else: + tokenizer_name = getattr(Config().trainer, "tokenizer_name", model_name) + if not isinstance(tokenizer_name, str) or not tokenizer_name: + tokenizer_name = model_name + + config_kwargs = { + "cache_dir": None, + "revision": "main", + "use_auth_token": None, + } + self.config = AutoConfig.from_pretrained(model_name, **config_kwargs) + + cache_dir = Config().params["data_path"] + use_fast_tokenizer = True + revision = "main" + auth_token = getattr( + getattr(Config(), "parameters", None), "huggingface_token", None + ) + + tokenizer_loader: Any = ( + LlamaTokenizer if "llama" in tokenizer_name else AutoTokenizer + ) + tokenizer_kwargs: dict[str, Any] = { + "config": self.config, + "cache_dir": cache_dir, + "use_fast": use_fast_tokenizer, + "revision": revision, + } + if isinstance(auth_token, str) and auth_token: + tokenizer_kwargs["use_auth_token"] = auth_token + self.tokenizer: Any = tokenizer_loader.from_pretrained( + tokenizer_name, + **tokenizer_kwargs, + ) + + tokenizer = cast(Any, self.tokenizer) + if getattr(tokenizer, "pad_token_id", None) is None: + eos_token = getattr(tokenizer, "eos_token", None) + if eos_token is not None: + tokenizer.pad_token = eos_token + + self._collate_wrapper = HuggingFaceCollateWrapper(tokenizer) + testing_strategy = HuggingFaceTestingStrategy(self._collate_wrapper) + self.training_args.gradient_accumulation_steps = ( self._gradient_accumulation_steps ) @@ -593,7 +686,7 @@ def __init__(self, model=None, callbacks=None): num_workers=0, pin_memory=True, ), - testing_strategy=HuggingFaceTestingStrategy(self._collate_wrapper), + testing_strategy=testing_strategy, ) if hf_callbacks: @@ -603,23 +696,24 @@ def __init__(self, model=None, callbacks=None): if hasattr(model_instance, "loss_type"): setattr(model_instance, "loss_type", "ForCausalLM") - tokenizer_vocab_size = None - if hasattr(self.tokenizer, "__len__"): - try: - tokenizer_vocab_size = len(self.tokenizer) - except TypeError: - tokenizer_vocab_size = None - embedding_getter = getattr(model_instance, "get_input_embeddings", None) - embedding_resizer = getattr(model_instance, "resize_token_embeddings", None) - if ( - tokenizer_vocab_size is not None - and callable(embedding_getter) - and callable(embedding_resizer) - ): - embeddings = embedding_getter() - embedding_size = getattr(embeddings, "num_embeddings", None) - if embedding_size is not None and embedding_size != tokenizer_vocab_size: - embedding_resizer(tokenizer_vocab_size) + if self.tokenizer is not None: + tokenizer_vocab_size = None + if hasattr(self.tokenizer, "__len__"): + try: + tokenizer_vocab_size = len(self.tokenizer) + except TypeError: + tokenizer_vocab_size = None + embedding_getter = getattr(model_instance, "get_input_embeddings", None) + embedding_resizer = getattr(model_instance, "resize_token_embeddings", None) + if ( + tokenizer_vocab_size is not None + and callable(embedding_getter) + and callable(embedding_resizer) + ): + embeddings = embedding_getter() + embedding_size = getattr(embeddings, "num_embeddings", None) + if embedding_size is not None and embedding_size != tokenizer_vocab_size: + embedding_resizer(tokenizer_vocab_size) if self.training_args.gradient_checkpointing: model_config = getattr(model_instance, "config", None) From 8e30a8781fc0120add366cf8eae04a402a26abf7 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Sun, 10 May 2026 16:44:38 -0400 Subject: [PATCH 35/39] Added functions to save the personalized models. --- plato/servers/pfedgraph.py | 40 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/plato/servers/pfedgraph.py b/plato/servers/pfedgraph.py index f6d694bd4..73aa57d84 100644 --- a/plato/servers/pfedgraph.py +++ b/plato/servers/pfedgraph.py @@ -4,9 +4,12 @@ from __future__ import annotations +import logging +from pathlib import Path from typing import Any, Sequence from plato.config import Config +from plato.serialization.safetensor import serialize_tree from plato.servers import fedavg from plato.servers.strategies.aggregation.pfedgraph import ( PFedGraphAggregationStrategy, @@ -79,3 +82,40 @@ def customize_server_payload(self, payload: Any) -> Any: if client_id in self.client_models: return self.client_models[client_id] return payload + + def _client_model_path(self, client_id: int) -> Path: + """Return the output path for a saved client-specific pFedGraph model.""" + + model_name = ( + Config().trainer.model_name + if hasattr(Config().trainer, "model_name") + else "custom" + ) + return ( + Path(Config().params["model_path"]) + / f"{model_name}_client_{client_id}.safetensors" + ) + + def save_client_models(self) -> None: + """Persist the latest pFedGraph client-specific models.""" + + if not self.client_models: + return + + for client_id, client_model in sorted(self.client_models.items()): + model_path = self._client_model_path(client_id) + model_path.parent.mkdir(parents=True, exist_ok=True) + with model_path.open("wb") as model_file: + model_file.write(serialize_tree(client_model)) + logging.info( + "[%s] Saved pFedGraph client #%d model to %s.", + self, + client_id, + model_path, + ) + + def server_will_close(self) -> None: + """Save pFedGraph client-specific models before server shutdown.""" + + self.save_client_models() + super().server_will_close() From 2c3f2c3b7db22b128a6f51228064cc1c764b1858 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Sun, 10 May 2026 18:28:17 -0400 Subject: [PATCH 36/39] Added config file for TimesFM25 with diloco. --- ...mesfm25_ev_charging_top4_mixed_diloco.toml | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml diff --git a/configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml b/configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml new file mode 100644 index 000000000..845238c74 --- /dev/null +++ b/configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml @@ -0,0 +1,107 @@ +# Federated Learning with TimesFM2.5 + DiLoCo for EV Charging Prediction +# +# Task: Given the past 28 days (672 h) of a user's EV charging behaviour, +# predict whether they will be charging in each of the next 128 hours. +# +# Dataset: "EV Charging Reports" – mixed high-data users across garages +# https://data.mendeley.com/datasets/jbks2rcwyj/1 +# +# Federated setup: 4 clients, one user each. All clients participate every round. +# +# Model: TimesFM (custom, trained from scratch for the small data regime) +# – channel-independent: each of the 6 input features is processed +# as a separate univariate series; only is_charging is evaluated. +# +# DiLoCo uses server.type = "diloco" with algorithm.type = "fedavg" so the +# standard FedAvg weight extraction/update path is reused for local training. +# local_steps_per_round counts optimizer steps, not epochs; optimizer state is +# preserved locally under model_path between DiLoCo synchronizations. +# +# Usage: +# uv run plato.py -c configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml + +[clients] +type = "simple" +total_clients = 4 +per_round = 4 +do_test = true + +[server] +type = "diloco" +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/timeseries/timesfm25_ev_top4_mixed_diloco" +model_path = "models/timeseries/timesfm25_ev_top4_mixed_diloco" + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[data] +datasource = "EVCharging" + +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + +# Use explicit users across the whole dataset, not just a single garage. +garage = "all" + +# Explicit user IDs to include — one client per user. +users = ["Bl2-5", "AsO2-1", "Bl2-1", "AdO1-3"] +sampler = "all_inclusive" +random_seed = 42 + +[trainer] +type = "HuggingFace" +rounds = 10 +max_concurrency = 1 +model_name = "google/timesfm-2.5-200m-transformers" +model_type = "timesfm" + +context_length = 672 +prediction_length = 128 # TimesFM2.5 transformers horizon_length is fixed at 128 steps. + +# Number of input channels: is_charging, energy_scaled, +# hour_sin, hour_cos, dow_sin, dow_cos +num_input_channels = 6 + +# Only evaluate the is_charging channel (index 0) +prediction_channel_indices = [0] + +patch_length = 8 +num_hidden_layers = 4 +hidden_size = 256 +intermediate_size = 256 +num_attention_heads = 4 +head_dim = 64 +dropout = 0.1 + +freq = 0 + +# Sliding-window stride for dataset creation +stride = 1 + +# Number of local optimizer steps per DiLoCo synchronization. +local_steps_per_round = 10 +preserve_optimizer_state = true + +epochs = 10 +batch_size = 16 +optimizer = "AdamW" + +train_ratio = 0.70 +val_ratio = 0.15 + +[algorithm] +type = "fedavg" + +[parameters] +[parameters.optimizer] +lr = 0.0005 +weight_decay = 1e-4 + +[results] +types = "round, elapsed_time, mse" From dda867e754ffd23a793e78741730da973570162f Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 11 May 2026 11:40:55 -0400 Subject: [PATCH 37/39] Updated DiLoCo steps to match the FedAvg. --- configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml b/configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml index 845238c74..36e450d78 100644 --- a/configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml +++ b/configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml @@ -85,7 +85,7 @@ freq = 0 stride = 1 # Number of local optimizer steps per DiLoCo synchronization. -local_steps_per_round = 10 +local_steps_per_round = 1500 preserve_optimizer_state = true epochs = 10 From fb5d231aae84cfcd785f38dadcd058b18b8ec9e8 Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 11 May 2026 11:46:36 -0400 Subject: [PATCH 38/39] Ruff format . --- .../feddf/feddf_algorithm.py | 8 ++- plato/datasources/ev_charging.py | 3 +- plato/datasources/huggingface.py | 14 +++-- plato/datasources/lerobot.py | 25 +++++---- plato/evaluators/lighteval.py | 4 +- plato/evaluators/lighteval_tasks.py | 5 +- plato/evaluators/nanochat_core.py | 20 +++++-- plato/evaluators/runner.py | 4 +- plato/models/smolvla.py | 2 +- plato/servers/fedavg.py | 8 +-- .../servers/strategies/aggregation/diloco.py | 15 ++---- plato/trainers/composable.py | 10 ++-- plato/trainers/huggingface.py | 14 +++-- plato/utils/tree.py | 5 +- tests/clients/test_feddf_strategy.py | 27 +++++----- tests/clients/test_simple_client.py | 7 ++- .../test_huggingface_datasource.py | 32 +++++++----- tests/evaluators/test_lighteval.py | 28 +++++++--- tests/evaluators/test_registry_runner.py | 12 +++-- .../test_huggingface_smollm_smoke.py | 4 +- tests/integration/test_smoke_configs.py | 1 + tests/servers/test_diloco_strategy.py | 4 +- tests/servers/test_fedavg_strategy.py | 8 +-- tests/servers/test_feddf_server_strategy.py | 8 +-- .../strategies/test_loss_criterion.py | 1 - .../test_composable_optimizer_state.py | 52 ++++++------------- tests/trainers/test_composable_trainer.py | 11 ++-- tests/trainers/test_huggingface_trainer.py | 4 +- 28 files changed, 187 insertions(+), 149 deletions(-) diff --git a/examples/server_aggregation/feddf/feddf_algorithm.py b/examples/server_aggregation/feddf/feddf_algorithm.py index 57164529e..ecab8cfe6 100644 --- a/examples/server_aggregation/feddf/feddf_algorithm.py +++ b/examples/server_aggregation/feddf/feddf_algorithm.py @@ -37,7 +37,9 @@ def aggregate_teacher_logits( "FedDF teacher weighting must be either 'uniform' or 'samples'." ) - total_samples = sum(getattr(update.report, "num_samples", 0) for update in updates) + total_samples = sum( + getattr(update.report, "num_samples", 0) for update in updates + ) use_uniform_average = weighting_name == "uniform" or total_samples <= 0 aggregated = torch.zeros_like(first_logits, dtype=torch.float32) @@ -91,7 +93,9 @@ def distill_weights( inputs.append(extract_batch_inputs(example)) proxy_inputs = torch.stack(inputs) - distillation_dataset = TensorDataset(proxy_inputs, teacher_logits.detach().cpu()) + distillation_dataset = TensorDataset( + proxy_inputs, teacher_logits.detach().cpu() + ) dataloader = DataLoader( distillation_dataset, batch_size=distillation_batch_size, diff --git a/plato/datasources/ev_charging.py b/plato/datasources/ev_charging.py index 7fa68dd28..d992814b4 100644 --- a/plato/datasources/ev_charging.py +++ b/plato/datasources/ev_charging.py @@ -135,8 +135,7 @@ def _build_hourly_series( if missing: scope = "all garages" if use_all_garages else f"garage '{garage_name}'" raise ValueError( - f"Users not found in {scope}: {missing}. " - f"Available: {available}" + f"Users not found in {scope}: {missing}. Available: {available}" ) users = list(user_ids) # preserve config order else: diff --git a/plato/datasources/huggingface.py b/plato/datasources/huggingface.py index 49d9582c3..7240561b8 100644 --- a/plato/datasources/huggingface.py +++ b/plato/datasources/huggingface.py @@ -207,7 +207,9 @@ def __init__(self, **kwargs): if isinstance(tokenizer_name, str) and tokenizer_name else Config().trainer.model_name ) - auth_token = getattr(getattr(Config(), "parameters", None), "huggingface_token", None) + auth_token = getattr( + getattr(Config(), "parameters", None), "huggingface_token", None + ) config_kwargs = { "cache_dir": Config().params["model_path"], "revision": "main", @@ -325,7 +327,11 @@ def preprocess_corpus_lm(self, dataset_split): ) configured_block_size = getattr(Config().data, "block_size", None) - block_size = configured_block_size if configured_block_size is not None else self.block_size + block_size = ( + configured_block_size + if configured_block_size is not None + else self.block_size + ) block_size = int(block_size) if block_size > 1024: logging.warning( @@ -364,9 +370,7 @@ def _build_chat_labels( if self.label_strategy == "full_sequence": return list(input_ids) if self.label_strategy != "assistant_only": - raise ValueError( - f"Unsupported chat label strategy: {self.label_strategy}" - ) + raise ValueError(f"Unsupported chat label strategy: {self.label_strategy}") if not hasattr(self.tokenizer, "apply_chat_template"): raise AttributeError( diff --git a/plato/datasources/lerobot.py b/plato/datasources/lerobot.py index ef2cd6654..679cb3e02 100644 --- a/plato/datasources/lerobot.py +++ b/plato/datasources/lerobot.py @@ -97,7 +97,7 @@ def _import_lerobot() -> tuple[Any, Any]: raise ImportError( "LeRobot datasource requires optional LeRobot / SmolVLA robotics dependencies. " "Install the robotics stack in the active environment before using " - '"data.datasource = \"LeRobot\"". ' + '"data.datasource = "LeRobot"". ' ) from exc return LeRobotDataset, LeRobotDatasetMetadata @@ -362,7 +362,9 @@ def _resolve_task_name(row: Mapping[str, Any], tasks_lookup: Any) -> str | None: return None -def _resolve_episode_tasks(metadata: Any, episodes: Sequence[int]) -> dict[int, str | None]: +def _resolve_episode_tasks( + metadata: Any, episodes: Sequence[int] +) -> dict[int, str | None]: episode_tasks = {episode: None for episode in episodes} episode_rows = _episode_rows(getattr(metadata, "episodes", None)) tasks_lookup = _to_plain(getattr(metadata, "tasks", None)) @@ -473,10 +475,14 @@ def _resolve_episode_split( episode_set = set(int(episode) for episode in all_episodes) if explicit_train is None and explicit_test is None: - return _split_episodes(all_episodes, episode_tasks, train_ratio, seed, task_aware) + return _split_episodes( + all_episodes, episode_tasks, train_ratio, seed, task_aware + ) train_episodes = [ - int(episode) for episode in (explicit_train or []) if int(episode) in episode_set + int(episode) + for episode in (explicit_train or []) + if int(episode) in episode_set ] test_episodes = [ int(episode) @@ -569,7 +575,9 @@ def _resolve_total_clients(config: Any) -> int: return total_clients -def _filter_constructor_kwargs(dataset_cls: Any, kwargs: Mapping[str, Any]) -> dict[str, Any]: +def _filter_constructor_kwargs( + dataset_cls: Any, kwargs: Mapping[str, Any] +) -> dict[str, Any]: try: signature = inspect.signature(dataset_cls.__init__) except (TypeError, ValueError): @@ -582,9 +590,7 @@ def _filter_constructor_kwargs(dataset_cls: Any, kwargs: Mapping[str, Any]) -> d if accepts_var_kwargs: return dict(kwargs) - valid_parameters = { - name for name in signature.parameters.keys() if name != "self" - } + valid_parameters = {name for name in signature.parameters.keys() if name != "self"} filtered = {key: value for key, value in kwargs.items() if key in valid_parameters} dropped = sorted(set(kwargs.keys()) - set(filtered.keys())) @@ -646,8 +652,7 @@ def __init__(self, client_id: int = 0, **kwargs): repo_id = str(dataset_cfg.pop("repo_id", "")).strip() if not repo_id: raise ValueError( - "LeRobot datasource requires " - '"parameters.dataset.repo_id" to be set.' + 'LeRobot datasource requires "parameters.dataset.repo_id" to be set.' ) train_split_raw = dataset_cfg.pop("train_split", _DEFAULT_TRAIN_SPLIT) diff --git a/plato/evaluators/lighteval.py b/plato/evaluators/lighteval.py index 2c623b4ba..bd51e1986 100644 --- a/plato/evaluators/lighteval.py +++ b/plato/evaluators/lighteval.py @@ -98,9 +98,7 @@ def __exit__(self, exc_type, exc, exc_tb) -> None: def advance(self, message: str) -> None: self._current += 1 - logging.info( - "[Lighteval] %s (%d/%d).", message, self._current, self.total - ) + logging.info("[Lighteval] %s (%d/%d).", message, self._current, self.total) if self._bar is not None: self._bar.set_postfix_str(message) self._bar.update(1) diff --git a/plato/evaluators/lighteval_tasks.py b/plato/evaluators/lighteval_tasks.py index dc9731852..e4fcf33a1 100644 --- a/plato/evaluators/lighteval_tasks.py +++ b/plato/evaluators/lighteval_tasks.py @@ -15,7 +15,10 @@ def piqa_hf_prompt(line, task_name: str | None = None): query = "The following are multiple choice questions (with answers) about common sense.\n" query += f"Question: {line['goal']}\n" query += "".join( - [f"{key}. {choice}\n" for key, choice in zip(letters, [line["sol1"], line["sol2"]])] + [ + f"{key}. {choice}\n" + for key, choice in zip(letters, [line["sol1"], line["sol2"]]) + ] ) query += "Answer: " diff --git a/plato/evaluators/nanochat_core.py b/plato/evaluators/nanochat_core.py index d22292315..961a20374 100644 --- a/plato/evaluators/nanochat_core.py +++ b/plato/evaluators/nanochat_core.py @@ -319,8 +319,14 @@ def _resolve_tokenizer(model) -> Any: def _safe_evaluate_task( - model, tokenizer, data, device, task_meta, label, - evaluate_task_fn, evaluate_example_fn, + model, + tokenizer, + data, + device, + task_meta, + label, + evaluate_task_fn, + evaluate_example_fn, ): """Wrap upstream ``evaluate_task`` so that examples whose tokenized prompts exceed the model's ``max_seq_len`` are gracefully skipped @@ -446,8 +452,14 @@ def run_core_evaluation( data = data[:max_per_task] accuracy = _safe_evaluate_task( - model, eval_tokenizer, data, model_device, task_meta, label, - evaluate_task, evaluate_example, + model, + eval_tokenizer, + data, + model_device, + task_meta, + label, + evaluate_task, + evaluate_example, ) if accuracy is None: # All examples were skipped (too long for model's max_seq_len). diff --git a/plato/evaluators/runner.py b/plato/evaluators/runner.py index e9b284f7d..7f2657f93 100644 --- a/plato/evaluators/runner.py +++ b/plato/evaluators/runner.py @@ -26,7 +26,9 @@ def _configured_evaluator_type() -> str | None: evaluator_type = evaluation_cfg.get("type") else: evaluator_type = getattr(evaluation_cfg, "type", None) - return evaluator_type if isinstance(evaluator_type, str) and evaluator_type else None + return ( + evaluator_type if isinstance(evaluator_type, str) and evaluator_type else None + ) def _evaluation_fail_on_error() -> bool: diff --git a/plato/models/smolvla.py b/plato/models/smolvla.py index 9dd59ca2e..6760c29ba 100644 --- a/plato/models/smolvla.py +++ b/plato/models/smolvla.py @@ -43,7 +43,7 @@ def _import_smolvla_policy() -> type[Any]: except ImportError as exc: # pragma: no cover - environment dependent raise ImportError( "SmolVLA requires optional LeRobot robotics dependencies. " - "Install the robotics stack in the active environment before using `model_type = \"smolvla\"`." + 'Install the robotics stack in the active environment before using `model_type = "smolvla"`.' ) from exc return SmolVLAPolicy diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index 29299c7f2..2e6ebc733 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -88,7 +88,9 @@ def _log_average_client_metric(self, metric_value: float) -> None: elif metric_name == "perplexity": logging.info("[%s] Average client perplexity: %.2f.", self, metric_value) else: - logging.info("[%s] Average client accuracy: %.2f%%.", self, 100 * metric_value) + logging.info( + "[%s] Average client accuracy: %.2f%%.", self, 100 * metric_value + ) def _log_global_metric(self, metric_value: float) -> None: """Log the server-tested metric with the appropriate label.""" @@ -271,8 +273,8 @@ async def _process_reports(self): # Use delta aggregation (default path) # Computes the weight deltas by comparing the weights received with # the current global model weights - delta_updates, delta_weights_received = ( - self._weight_updates_and_payloads(self.updates, weights_received) + delta_updates, delta_weights_received = self._weight_updates_and_payloads( + self.updates, weights_received ) deltas_received = ( algorithm.compute_weight_deltas( diff --git a/plato/servers/strategies/aggregation/diloco.py b/plato/servers/strategies/aggregation/diloco.py index ddc2428fc..e1d35aa5f 100644 --- a/plato/servers/strategies/aggregation/diloco.py +++ b/plato/servers/strategies/aggregation/diloco.py @@ -44,9 +44,7 @@ def __init__( ): super().__init__() self.outer_optimizer = self._validate_outer_optimizer(outer_optimizer) - self.outer_learning_rate = self._validate_learning_rate( - outer_learning_rate - ) + self.outer_learning_rate = self._validate_learning_rate(outer_learning_rate) self.outer_momentum = self._validate_momentum(outer_momentum) self.aggregation_weighting = self._validate_weighting_mode( aggregation_weighting @@ -432,9 +430,7 @@ def _map_tree(self, value: Any, leaf_fn: Callable[[Any, str], Any], path="") -> def _scale_tree(self, value: Any, scalar: float) -> Any: if isinstance(value, Mapping): - return { - key: self._scale_tree(item, scalar) for key, item in value.items() - } + return {key: self._scale_tree(item, scalar) for key, item in value.items()} if isinstance(value, list): return [self._scale_tree(item, scalar) for item in value] @@ -488,10 +484,9 @@ def _is_compatible(left: Any, right: Any) -> bool: left_shape = getattr(left, "shape", None) right_shape = getattr(right, "shape", None) if left_shape is not None or right_shape is not None: - return ( - left_shape == right_shape - and getattr(left, "dtype", None) == getattr(right, "dtype", None) - ) + return left_shape == right_shape and getattr( + left, "dtype", None + ) == getattr(right, "dtype", None) return isinstance(left, numbers.Number) and isinstance(right, numbers.Number) diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index 76adaa421..2645b7904 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -880,7 +880,11 @@ def compute_loss(outputs, labels_inner): if labels is not None: batch_size = labels.size(0) else: - first_val = next(iter(examples.values())) if hasattr(examples, "values") else examples + first_val = ( + next(iter(examples.values())) + if hasattr(examples, "values") + else examples + ) batch_size = first_val.size(0) if hasattr(first_val, "size") else 1 self._loss_tracker.update(loss, batch_size) @@ -954,9 +958,7 @@ def compute_loss(outputs, labels_inner): ) if finalize_step_done: self.optimizer_strategy.on_optimizer_step(self.optimizer, self.context) - self._step_lr_scheduler_after_optimizer_step( - step_lr_per_optimizer_step - ) + self._step_lr_scheduler_after_optimizer_step(step_lr_per_optimizer_step) local_step_limit_reached = self._record_local_optimizer_step( local_steps_per_round ) diff --git a/plato/trainers/huggingface.py b/plato/trainers/huggingface.py index f8b54dd46..8d4c7e592 100644 --- a/plato/trainers/huggingface.py +++ b/plato/trainers/huggingface.py @@ -125,7 +125,10 @@ def __call__( if not example_list: raise ValueError("HuggingFace collator received an empty batch.") - feature_rows = [{k: v for k, v in example.items() if k != "labels"} for example in example_list] + feature_rows = [ + {k: v for k, v in example.items() if k != "labels"} + for example in example_list + ] padding_side = getattr(self.tokenizer, "padding_side", "right") batch = self.tokenizer.pad( @@ -170,9 +173,7 @@ class TimeSeriesCollateWrapper: model computes its own loss from ``future_values``. """ - def __call__( - self, examples: Iterable[dict] - ) -> tuple[HuggingFaceBatch, None]: + def __call__(self, examples: Iterable[dict]) -> tuple[HuggingFaceBatch, None]: example_list = list(examples) if not example_list: raise ValueError("TimeSeriesCollateWrapper received an empty batch.") @@ -712,7 +713,10 @@ def __init__(self, model=None, callbacks=None): ): embeddings = embedding_getter() embedding_size = getattr(embeddings, "num_embeddings", None) - if embedding_size is not None and embedding_size != tokenizer_vocab_size: + if ( + embedding_size is not None + and embedding_size != tokenizer_vocab_size + ): embedding_resizer(tokenizer_vocab_size) if self.training_args.gradient_checkpointing: diff --git a/plato/utils/tree.py b/plato/utils/tree.py index ed7c2b647..0f1a13c54 100644 --- a/plato/utils/tree.py +++ b/plato/utils/tree.py @@ -66,7 +66,10 @@ def _ensure_numpy(value: Any) -> np.ndarray: if callable(cpu_fn): tensor = cpu_fn() torch_bfloat16 = getattr(torch, "bfloat16", None) if torch is not None else None - if torch_bfloat16 is not None and getattr(tensor, "dtype", None) == torch_bfloat16: + if ( + torch_bfloat16 is not None + and getattr(tensor, "dtype", None) == torch_bfloat16 + ): tensor = tensor.to(torch.float32) numpy_fn = getattr(tensor, "numpy", None) if callable(numpy_fn): diff --git a/tests/clients/test_feddf_strategy.py b/tests/clients/test_feddf_strategy.py index 6676eda37..518735d6e 100644 --- a/tests/clients/test_feddf_strategy.py +++ b/tests/clients/test_feddf_strategy.py @@ -15,9 +15,7 @@ from tests.test_utils.fakes import FakeModel _TESTS_ROOT = Path(__file__).resolve().parent -_FEDDF_DIR = ( - _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf" -) +_FEDDF_DIR = _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf" if str(_FEDDF_DIR) not in sys.path: sys.path.insert(0, str(_FEDDF_DIR)) @@ -44,7 +42,9 @@ def test_feddf_training_strategy_returns_teacher_logits(temp_config): context = SimpleNamespace( client_id=1, current_round=1, - algorithm=SimpleNamespace(load_weights=lambda weights: loaded_weights.append(weights)), + algorithm=SimpleNamespace( + load_weights=lambda weights: loaded_weights.append(weights) + ), trainer=SimpleNamespace(model=FakeModel(), device="cpu"), state={}, ) @@ -58,14 +58,17 @@ def test_feddf_training_strategy_returns_teacher_logits(temp_config): mock_report = SimpleNamespace(num_samples=8) async_mock = AsyncMock(return_value=(mock_report, {"weights": torch.ones(1)})) - with patch.object( - feddf_client.DefaultTrainingStrategy, - "train", - new=async_mock, - ) as mock_train, patch.object( - feddf_client.time, - "perf_counter", - side_effect=[10.0, 10.25], + with ( + patch.object( + feddf_client.DefaultTrainingStrategy, + "train", + new=async_mock, + ) as mock_train, + patch.object( + feddf_client.time, + "perf_counter", + side_effect=[10.0, 10.25], + ), ): report, payload = asyncio.run(strategy.train(context)) diff --git a/tests/clients/test_simple_client.py b/tests/clients/test_simple_client.py index 11b1bd8e0..db3d7c838 100644 --- a/tests/clients/test_simple_client.py +++ b/tests/clients/test_simple_client.py @@ -206,10 +206,9 @@ def test_simple_client_subprocess_payload_excludes_local_state_sidecar( asyncio.run(client._handle_payload(server_payload)) sent_payload = client._context.state["sent_payloads"][-1] - state_path = ( - Path(Config.params["model_path"]) - / client.trainer._optimizer_state_filename(Config.params["run_id"]) - ) + state_path = Path( + Config.params["model_path"] + ) / client.trainer._optimizer_state_filename(Config.params["run_id"]) with state_path.open("rb") as state_file: sidecar_state = pickle.load(state_file) diff --git a/tests/datasources/test_huggingface_datasource.py b/tests/datasources/test_huggingface_datasource.py index 1937c34b6..532ee126e 100644 --- a/tests/datasources/test_huggingface_datasource.py +++ b/tests/datasources/test_huggingface_datasource.py @@ -30,8 +30,7 @@ def apply_chat_template( ): if not tokenize: return "".join( - f"<{message['role']}>{message['content']}|" - for message in messages + f"<{message['role']}>{message['content']}|" for message in messages ) tokens = [] @@ -125,8 +124,12 @@ def test_huggingface_datasource_keeps_validation_split_for_corpus_mode( } ) - monkeypatch.setattr(huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset) - monkeypatch.setattr(huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset) + monkeypatch.setattr( + huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset + ) + monkeypatch.setattr( + huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset + ) monkeypatch.setattr(huggingface_datasource.os.path, "exists", lambda *args: False) monkeypatch.setattr( huggingface_datasource.AutoConfig, @@ -172,8 +175,12 @@ def test_huggingface_datasource_falls_back_to_test_split(temp_config, monkeypatc } ) - monkeypatch.setattr(huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset) - monkeypatch.setattr(huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset) + monkeypatch.setattr( + huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset + ) + monkeypatch.setattr( + huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset + ) monkeypatch.setattr(huggingface_datasource.os.path, "exists", lambda *args: False) monkeypatch.setattr( huggingface_datasource.AutoConfig, @@ -218,9 +225,7 @@ def test_huggingface_datasource_loads_legacy_cache_path_when_present( } ) - legacy_path = ( - f"{Config().params['data_path']}/{cfg.data.dataset_name}_{cfg.data.dataset_config}" - ) + legacy_path = f"{Config().params['data_path']}/{cfg.data.dataset_name}_{cfg.data.dataset_config}" loaded_paths: list[str] = [] monkeypatch.setattr( @@ -260,7 +265,6 @@ class LargeContextDummyTokenizer(DummyTokenizer): model_max_length = 4096 - def test_huggingface_corpus_mode_keeps_legacy_default_block_size( temp_config, monkeypatch ): @@ -283,8 +287,12 @@ def test_huggingface_corpus_mode_keeps_legacy_default_block_size( } ) - monkeypatch.setattr(huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset) - monkeypatch.setattr(huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset) + monkeypatch.setattr( + huggingface_datasource, "load_dataset", lambda *args, **kwargs: dataset + ) + monkeypatch.setattr( + huggingface_datasource, "load_from_disk", lambda *args, **kwargs: dataset + ) monkeypatch.setattr(huggingface_datasource.os.path, "exists", lambda *args: False) monkeypatch.setattr( huggingface_datasource.AutoConfig, diff --git a/tests/evaluators/test_lighteval.py b/tests/evaluators/test_lighteval.py index 3045d711f..702600520 100644 --- a/tests/evaluators/test_lighteval.py +++ b/tests/evaluators/test_lighteval.py @@ -42,7 +42,13 @@ def test_lighteval_fast_preset_contains_expected_tasks(temp_config): preset = _resolve_preset("smollm_round_fast") - assert preset["tasks"] == ["ifeval", "hellaswag", "arc_easy", "arc_challenge", "piqa"] + assert preset["tasks"] == [ + "ifeval", + "hellaswag", + "arc_easy", + "arc_challenge", + "piqa", + ] assert preset["primary_metric"] == "ifeval_avg" @@ -72,7 +78,9 @@ class FakeParallelismManager(Enum): ACCELERATE = auto() class FakePipelineParameters: - def __init__(self, launcher_type, custom_tasks_directory=None, max_samples=None): + def __init__( + self, launcher_type, custom_tasks_directory=None, max_samples=None + ): calls["launcher_type"] = launcher_type calls["custom_tasks_directory"] = custom_tasks_directory calls["max_samples"] = max_samples @@ -228,7 +236,9 @@ class FakeParallelismManager(Enum): ACCELERATE = auto() class FakePipelineParameters: - def __init__(self, launcher_type, custom_tasks_directory=None, max_samples=None): + def __init__( + self, launcher_type, custom_tasks_directory=None, max_samples=None + ): del launcher_type, custom_tasks_directory, max_samples class FakeEvaluationTracker: @@ -335,7 +345,9 @@ class FakeParallelismManager(Enum): ACCELERATE = auto() class FakePipelineParameters: - def __init__(self, launcher_type, custom_tasks_directory=None, max_samples=None): + def __init__( + self, launcher_type, custom_tasks_directory=None, max_samples=None + ): del launcher_type, custom_tasks_directory, max_samples class FakeEvaluationTracker: @@ -425,7 +437,9 @@ class FakeParallelismManager(Enum): ACCELERATE = auto() class FakePipelineParameters: - def __init__(self, launcher_type, custom_tasks_directory=None, max_samples=None): + def __init__( + self, launcher_type, custom_tasks_directory=None, max_samples=None + ): del launcher_type, custom_tasks_directory captured["max_samples"] = max_samples @@ -605,9 +619,7 @@ def _mock_pipeline(**kwargs): result = LightevalEvaluator( {"type": "lighteval", "preset": "smollm_round_fast"} - ).evaluate( - EvaluationInput(model=SaveableArtifact(), tokenizer=SaveableArtifact()) - ) + ).evaluate(EvaluationInput(model=SaveableArtifact(), tokenizer=SaveableArtifact())) assert result.metrics["ifeval_avg"] == pytest.approx(0.40) assert captured["model_name"] diff --git a/tests/evaluators/test_registry_runner.py b/tests/evaluators/test_registry_runner.py index 193c0ee6d..b4be4e9f9 100644 --- a/tests/evaluators/test_registry_runner.py +++ b/tests/evaluators/test_registry_runner.py @@ -109,7 +109,9 @@ def test_composable_trainer_runs_registered_evaluator_and_stores_results(temp_co testing_strategy=ConstantTestingStrategy(0.5), ) - accuracy = trainer.test_model(config={"batch_size": 1}, testset=[], sampler=None) + accuracy = trainer.test_model( + config={"batch_size": 1}, testset=[], sampler=None + ) assert accuracy == 0.5 assert trainer.accuracy == 0.5 @@ -257,7 +259,9 @@ def test_composable_trainer_tolerates_evaluator_runtime_failure_by_default( testing_strategy=ConstantTestingStrategy(0.5), ) - accuracy = trainer.test_model(config={"batch_size": 1}, testset=[], sampler=None) + accuracy = trainer.test_model( + config={"batch_size": 1}, testset=[], sampler=None + ) assert accuracy == 0.5 assert trainer.accuracy == 0.5 @@ -306,7 +310,9 @@ def test_composable_trainer_restores_grad_mode_after_evaluator_side_effect( ) assert torch.is_grad_enabled() is True - accuracy = trainer.test_model(config={"batch_size": 1}, testset=[], sampler=None) + accuracy = trainer.test_model( + config={"batch_size": 1}, testset=[], sampler=None + ) assert accuracy == 0.5 assert torch.is_grad_enabled() is True diff --git a/tests/integration/test_huggingface_smollm_smoke.py b/tests/integration/test_huggingface_smollm_smoke.py index 4ac9320c3..b1de91f89 100644 --- a/tests/integration/test_huggingface_smollm_smoke.py +++ b/tests/integration/test_huggingface_smollm_smoke.py @@ -93,7 +93,9 @@ def gradient_checkpointing_enable(self): def test_smollm_smoltalk_config_smoke(monkeypatch, tmp_path): """Smoke test the SmolLM2 + smol-smoltalk config with mocked HF/Lighteval hooks.""" repo_root = Path(__file__).resolve().parents[2] - config_path = repo_root / "configs/HuggingFace/fedavg_smol_smoltalk_smollm2_135m.toml" + config_path = ( + repo_root / "configs/HuggingFace/fedavg_smol_smoltalk_smollm2_135m.toml" + ) assert config_path.exists() dataset = DatasetDict( diff --git a/tests/integration/test_smoke_configs.py b/tests/integration/test_smoke_configs.py index e499655be..ea99bb401 100644 --- a/tests/integration/test_smoke_configs.py +++ b/tests/integration/test_smoke_configs.py @@ -97,6 +97,7 @@ def test_fedavg_lenet5_smoke(monkeypatch): async_run(server._process_reports()) assert server.accuracy >= 0 + @pytest.mark.integration def test_split_learning_smoke(monkeypatch): """Smoke test for split-learning trainer orchestrating gradients.""" diff --git a/tests/servers/test_diloco_strategy.py b/tests/servers/test_diloco_strategy.py index 4c739051d..35b685b63 100644 --- a/tests/servers/test_diloco_strategy.py +++ b/tests/servers/test_diloco_strategy.py @@ -122,9 +122,7 @@ class AdapterAliasCollisionModel(torch.nn.Module): def __init__(self): super().__init__() self.peft_config = {"default": object()} - self.foo = torch.nn.ModuleDict( - {"default": torch.nn.Linear(1, 1, bias=False)} - ) + self.foo = torch.nn.ModuleDict({"default": torch.nn.Linear(1, 1, bias=False)}) def _context(baseline=None, model=None): diff --git a/tests/servers/test_fedavg_strategy.py b/tests/servers/test_fedavg_strategy.py index 2f788a9f1..41f0c89a5 100644 --- a/tests/servers/test_fedavg_strategy.py +++ b/tests/servers/test_fedavg_strategy.py @@ -242,9 +242,7 @@ def test_fedavg_server_prefers_custom_delta_strategy_over_inherited_weights( assert torch.allclose(server.algorithm.current["bias"], torch.ones(1)) -def test_fedavg_server_logged_items_flatten_evaluator_metrics( - temp_config, tmp_path -): +def test_fedavg_server_logged_items_flatten_evaluator_metrics(temp_config, tmp_path): """FedAvg should keep accuracy while surfacing evaluator summary metrics.""" from plato.config import Config from plato.servers import fedavg @@ -309,9 +307,7 @@ def test_fedavg_server_logged_items_include_detailed_lighteval_metrics( assert logged_items["evaluation_arc_challenge_acc_stderr"] == 0.0701 -def test_fedavg_server_does_not_persist_evaluator_jsonl_sidecar( - temp_config, tmp_path -): +def test_fedavg_server_does_not_persist_evaluator_jsonl_sidecar(temp_config, tmp_path): """FedAvg should rely on CSV logging instead of a JSONL sidecar.""" from plato.config import Config from plato.servers import fedavg diff --git a/tests/servers/test_feddf_server_strategy.py b/tests/servers/test_feddf_server_strategy.py index edf382b1d..39b584372 100644 --- a/tests/servers/test_feddf_server_strategy.py +++ b/tests/servers/test_feddf_server_strategy.py @@ -16,9 +16,7 @@ from plato.config import Config _TESTS_ROOT = Path(__file__).resolve().parent -_FEDDF_DIR = ( - _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf" -) +_FEDDF_DIR = _TESTS_ROOT.parent.parent / "examples" / "server_aggregation" / "feddf" if str(_FEDDF_DIR) not in sys.path: sys.path.insert(0, str(_FEDDF_DIR)) @@ -71,7 +69,9 @@ def __init__( self._unlabeled = TensorDataset(proxy_inputs, torch.zeros(len(proxy_inputs))) self._test = TensorDataset( test_inputs if test_inputs is not None else proxy_inputs, - torch.zeros(len(test_inputs) if test_inputs is not None else len(proxy_inputs)), + torch.zeros( + len(test_inputs) if test_inputs is not None else len(proxy_inputs) + ), ) def get_unlabeled_set(self): diff --git a/tests/trainers/strategies/test_loss_criterion.py b/tests/trainers/strategies/test_loss_criterion.py index f414cb031..7c798d813 100644 --- a/tests/trainers/strategies/test_loss_criterion.py +++ b/tests/trainers/strategies/test_loss_criterion.py @@ -44,7 +44,6 @@ def import_without_lightly(name, package=None): assert isinstance(criterion, nn.CrossEntropyLoss) - def test_loss_registry_ssl_loss_requires_optional_lightly(temp_config, monkeypatch): from plato.trainers import loss_criterion as loss_criterion_registry diff --git a/tests/trainers/test_composable_optimizer_state.py b/tests/trainers/test_composable_optimizer_state.py index b005981f0..2ddb53f86 100644 --- a/tests/trainers/test_composable_optimizer_state.py +++ b/tests/trainers/test_composable_optimizer_state.py @@ -287,9 +287,7 @@ def test_scheduler_state_and_lr_progress_persist_between_rounds( def test_subprocess_optimizer_state_parent_reloads_after_child( temp_config, monkeypatch, tmp_path, tiny_dataset ): - _configure_subprocess_training( - monkeypatch, tmp_path, preserve_optimizer_state=True - ) + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) trainer = ComposableTrainer( model=_linear_model, loss_strategy=CrossEntropyLossStrategy(), @@ -301,9 +299,8 @@ def test_subprocess_optimizer_state_parent_reloads_after_child( assert trainer.client_id in trainer._preserved_optimizer_states assert _cached_optimizer_step(trainer) == 1 - state_path = ( - Path(Config.params["model_path"]) - / trainer._optimizer_state_filename(Config.params["run_id"]) + state_path = Path(Config.params["model_path"]) / trainer._optimizer_state_filename( + Config.params["run_id"] ) assert state_path.exists() assert "optimizer_state" not in trainer.obtain_model_update( @@ -322,9 +319,7 @@ def test_subprocess_optimizer_state_parent_reloads_after_child( def test_subprocess_optimizer_state_persists_across_rounds_for_same_client( temp_config, monkeypatch, tmp_path, tiny_dataset ): - _configure_subprocess_training( - monkeypatch, tmp_path, preserve_optimizer_state=True - ) + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) trainer = ComposableTrainer( model=_linear_model, loss_strategy=CrossEntropyLossStrategy(), @@ -341,9 +336,7 @@ def test_subprocess_optimizer_state_persists_across_rounds_for_same_client( def test_subprocess_scheduler_state_persists_across_rounds( temp_config, monkeypatch, tmp_path, tiny_dataset ): - _configure_subprocess_training( - monkeypatch, tmp_path, preserve_optimizer_state=True - ) + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) trainer = ComposableTrainer( model=_linear_model, loss_strategy=CrossEntropyLossStrategy(), @@ -363,9 +356,7 @@ def test_subprocess_scheduler_state_persists_across_rounds( def test_subprocess_missing_sidecar_clears_inherited_parent_cache( temp_config, monkeypatch, tmp_path, tiny_dataset, one_step_config ): - _configure_subprocess_training( - monkeypatch, tmp_path, preserve_optimizer_state=True - ) + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) source_trainer = ComposableTrainer( model=_linear_model, loss_strategy=CrossEntropyLossStrategy(), @@ -390,9 +381,8 @@ def test_subprocess_missing_sidecar_clears_inherited_parent_cache( source_trainer._preserved_optimizer_states[7] ) - state_path = ( - Path(Config.params["model_path"]) - / trainer._optimizer_state_filename(Config.params["run_id"]) + state_path = Path(Config.params["model_path"]) / trainer._optimizer_state_filename( + Config.params["run_id"] ) state_path.unlink(missing_ok=True) @@ -404,9 +394,7 @@ def test_subprocess_missing_sidecar_clears_inherited_parent_cache( def test_missing_subprocess_output_removes_stale_input_sidecar( temp_config, monkeypatch, tmp_path, tiny_dataset, one_step_config ): - _configure_subprocess_training( - monkeypatch, tmp_path, preserve_optimizer_state=True - ) + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) trainer = ComposableTrainer( model=_linear_model, loss_strategy=CrossEntropyLossStrategy(), @@ -428,9 +416,7 @@ def test_missing_subprocess_output_removes_stale_input_sidecar( input_path = Path(Config.params["model_path"]) / input_filename assert input_path.exists() - trainer._finish_subprocess_optimizer_state( - input_filename, missing_output_filename - ) + trainer._finish_subprocess_optimizer_state(input_filename, missing_output_filename) assert trainer.client_id not in trainer._preserved_optimizer_states assert not input_path.exists() @@ -439,18 +425,15 @@ def test_missing_subprocess_output_removes_stale_input_sidecar( def test_subprocess_invalid_optimizer_state_resets_safely( temp_config, monkeypatch, tmp_path, tiny_dataset ): - _configure_subprocess_training( - monkeypatch, tmp_path, preserve_optimizer_state=True - ) + _configure_subprocess_training(monkeypatch, tmp_path, preserve_optimizer_state=True) trainer = ComposableTrainer( model=_linear_model, loss_strategy=CrossEntropyLossStrategy(), optimizer_strategy=AdamWOptimizerStrategy(lr=0.01), ) trainer.set_client_id(7) - state_path = ( - Path(Config.params["model_path"]) - / trainer._optimizer_state_filename(Config.params["run_id"]) + state_path = Path(Config.params["model_path"]) / trainer._optimizer_state_filename( + Config.params["run_id"] ) with open(state_path, "wb") as state_file: pickle.dump({"optimizer_type": torch.optim.SGD}, state_file) @@ -479,9 +462,8 @@ def test_subprocess_optimizer_state_is_not_persisted_when_disabled( trainer.train(tiny_dataset, list(range(len(tiny_dataset)))) assert trainer._preserved_optimizer_states == {} - state_path = ( - Path(Config.params["model_path"]) - / trainer._optimizer_state_filename(Config.params["run_id"]) + state_path = Path(Config.params["model_path"]) / trainer._optimizer_state_filename( + Config.params["run_id"] ) assert not state_path.exists() @@ -602,9 +584,7 @@ def test_preserved_state_compatibility_rejects_shape_dtype_and_scheduler_changes payload, current_model, current_optimizer, changed_scheduler ) - changed_shape_model = nn.Sequential( - OrderedDict([("linear", nn.Linear(2, 3))]) - ) + changed_shape_model = nn.Sequential(OrderedDict([("linear", nn.Linear(2, 3))])) changed_shape_optimizer = trainer.optimizer_strategy.create_optimizer( changed_shape_model, trainer.context ) diff --git a/tests/trainers/test_composable_trainer.py b/tests/trainers/test_composable_trainer.py index fd6b4a668..e9ca21090 100644 --- a/tests/trainers/test_composable_trainer.py +++ b/tests/trainers/test_composable_trainer.py @@ -533,8 +533,7 @@ def test_local_step_sampling_warns_for_non_materializable_sampler( assert loader.sampler is sampler assert ( - "cannot be materialized for round-aware local-step sampling" - in caplog.text + "cannot be materialized for round-aware local-step sampling" in caplog.text ) def test_diloco_local_steps_require_full_client_participation( @@ -627,7 +626,9 @@ def test_invalid_local_steps_fail_clearly( trainer = ComposableTrainer(model=simple_model) with pytest.raises(ValueError, match="local_steps_per_round"): - trainer.train_model(config, simple_dataset, list(range(len(simple_dataset)))) + trainer.train_model( + config, simple_dataset, list(range(len(simple_dataset))) + ) class TestComposableTrainerStrategies: @@ -920,9 +921,7 @@ def test_test_state_roundtrip_persists_evaluation_metadata(self, temp_config): "metric": "ifeval_avg", "value": 0.31, } - assert trainer.context.state["nanochat_core_results"] == { - "core_metric": 0.9 - } + assert trainer.context.state["nanochat_core_results"] == {"core_metric": 0.9} def test_test_state_restore_clears_stale_evaluation_metadata(self, temp_config): trainer = ComposableTrainer(model=nn.Linear(2, 1)) diff --git a/tests/trainers/test_huggingface_trainer.py b/tests/trainers/test_huggingface_trainer.py index 597c8084c..144c37ea8 100644 --- a/tests/trainers/test_huggingface_trainer.py +++ b/tests/trainers/test_huggingface_trainer.py @@ -30,7 +30,9 @@ def pad(self, features, padding=True, return_tensors=None): for feature in features: pad_width = max_len - len(feature["input_ids"]) - batch["input_ids"].append(feature["input_ids"] + [self.pad_token_id] * pad_width) + batch["input_ids"].append( + feature["input_ids"] + [self.pad_token_id] * pad_width + ) batch["attention_mask"].append( feature.get("attention_mask", [1] * len(feature["input_ids"])) + [0] * pad_width From 34cf705b955e2ab02a00b526a1838f5980d0edda Mon Sep 17 00:00:00 2001 From: Jasmine-Yuting-Zhang Date: Mon, 11 May 2026 14:40:59 -0400 Subject: [PATCH 39/39] Updated documents for time series models, including the models and a case study. --- docs/docs/configurations/data.md | 15 ++ docs/docs/configurations/results.md | 3 + docs/docs/configurations/trainer.md | 21 +- docs/docs/examples/Getting Started.md | 4 + ...6. Time-Series Forecasting with TimesFM.md | 210 ++++++++++++++++++ docs/docs/index.md | 1 + docs/mkdocs.yml | 1 + 7 files changed, 254 insertions(+), 1 deletion(-) create mode 100644 docs/docs/examples/case-studies/6. Time-Series Forecasting with TimesFM.md diff --git a/docs/docs/configurations/data.md b/docs/docs/configurations/data.md index 2173be10e..ff2b1c481 100644 --- a/docs/docs/configurations/data.md +++ b/docs/docs/configurations/data.md @@ -5,6 +5,7 @@ - `Torchvision`: including torchvision datasets such as MNIST, FashionMNIST, EMNIST, CIFAR10, CIFAR100, CelebA, or STL10 (requires `dataset_name`) - `CINIC10` - `FEMNIST`: Federated EMNIST + - `EVCharging`: per-user EV charging time-series forecasting windows - `TinyImageNet` - `Purchase` - `Texas` @@ -57,6 +58,20 @@ !!! example "test_path" Where the test dataset is located. +!!! tip "EVCharging time-series datasource" + `EVCharging` builds per-user hourly time series from the [_Residential electric vehicle charging datasets from apartment buildings_](https://data.mendeley.com/datasets/jbks2rcwyj/1/files/2e3b8ced-9887-4a91-b721-8e510e18a127) [doi: 10.17632/jbks2rcwyj.1]. Each client receives one configured user, so use `sampler = "all_inclusive"` rather than class-label partitioning. + + ```toml + [data] + datasource = "EVCharging" + datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" + garage = "AdO1" # use "all" for users across garages + users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] + sampler = "all_inclusive" + ``` + + The datasource creates `past_values` / `future_values` sliding-window samples. The input feature order is `is_charging`, `energy_scaled`, `hour_sin`, `hour_cos`, `dow_sin`, and `dow_cos`; the reference configs forecast only `is_charging`. + !!! example "sampler" How to divide the entire dataset to the clients. The following options are available: diff --git a/docs/docs/configurations/results.md b/docs/docs/configurations/results.md index 56537231c..dea8d2d7b 100644 --- a/docs/docs/configurations/results.md +++ b/docs/docs/configurations/results.md @@ -6,6 +6,8 @@ - `round` - `accuracy` - `accuracy_std` + - `mse` + - `mse_std` - `elapsed_time` - `comm_time` - `processing_time` @@ -19,6 +21,7 @@ !!! note "Note" Use commas to separate them. The default is `round, accuracy, elapsed_time`. + Time-series configs commonly use `round, elapsed_time, mse` instead. !!! note "Structured evaluators" When `[evaluation]` is configured, Plato automatically appends any new `evaluation_*` columns that appear at runtime. You do **not** need to predeclare every Lighteval task metric in `results.types`, although predeclaring the summary columns can keep the CSV order stable. diff --git a/docs/docs/configurations/trainer.md b/docs/docs/configurations/trainer.md index 05fef2f2d..9c8852971 100644 --- a/docs/docs/configurations/trainer.md +++ b/docs/docs/configurations/trainer.md @@ -5,7 +5,7 @@ - `composable` the strategy-based trainer that exposes loss, optimiser, scheduler, data-loader, model-update, and testing strategies directly. - `timm_basic` a basic trainer with the [timm](https://timm.fast.ai/) learning rate scheduler. - `diff_privacy` a trainer that supports local differential privacy in its training loop by adding noise to the gradients during each step of training. - - `HuggingFace` a trainer for Hugging Face causal language models and tokenizers. + - `HuggingFace` a trainer for Hugging Face causal language models, tokenizers, and time-series models. - `nanochat` a trainer for Nanochat language-model workloads. - `lerobot` a trainer for LeRobot / SmolVLA workloads. - `split_learning` a trainer that supports the split learning framework. @@ -179,6 +179,8 @@ - `cnn_encoder` (for generating various encoders by extracting from CNN models such as ResNet models) - `general_multilayer` (for generating a multi-layer perceptron using a provided configuration) - `huggingface` (for [HuggingFace](https://huggingface.co/models) causal language models) + - `timesfm` (for Hugging Face TimesFM time-series forecasting models) + - `patchtsmixer` (for Hugging Face PatchTSMixer time-series models) - `torch_hub` (for models from [PyTorch Hub](https://pytorch.org/hub/)) - `vit` (for Vision Transformer models from [HuggingFace](https://huggingface.co/models), [Tokens-to-Token ViT](https://github.com/yitu-opensource/T2T-ViT), and [Deep Vision Transformer](https://github.com/zhoudaquan/dvit_repo)) - `smolvla` (for LeRobot / SmolVLA robotics policies) @@ -204,6 +206,8 @@ - `multilayer` - `nanochat` - `smolvla` + - `timesfm` + - `patchtsmixer` !!! note "Note" If the `model_type` above specified a model repository, supply the name of the model, such as `gpt2`, `HuggingFaceTB/SmolLM2-135M`, or `smolvla`, here. @@ -214,3 +218,18 @@ An optional tokenizer identifier to use instead of `trainer.model_name`. This is mainly useful for Hugging Face language-model workloads where the tokenizer/chat template comes from a separate repository. + +!!! tip "HuggingFace time-series models" + Set `trainer.type = "HuggingFace"` with `model_type = "timesfm"` or `model_type = "patchtsmixer"` to use Plato's time-series collator and MSE testing strategy instead of the tokenizer-based language-model path. + + Common fields include: + + - `context_length`: number of historical time steps in each input window. + - `prediction_length`: number of future time steps to forecast. + - `num_input_channels`: number of features in `past_values`. + - `prediction_channel_indices`: channels to keep/evaluate from model output. + - `stride`: sliding-window stride used by the datasource. + - `train_ratio` and `val_ratio`: temporal split ratios for per-user windows. + - `freq`: TimesFM frequency token (`0` for high-frequency/hourly data). + + TimesFM models are wrapped channel-independently for multivariate tensors; PatchTSMixer configs can use model options such as `mode = "mix_channel"` to mix features jointly. See the [TimesFM case study](../examples/case-studies/6. Time-Series Forecasting with TimesFM.md) for complete configs. diff --git a/docs/docs/examples/Getting Started.md b/docs/docs/examples/Getting Started.md index c8e5a157a..7109746e2 100644 --- a/docs/docs/examples/Getting Started.md +++ b/docs/docs/examples/Getting Started.md @@ -54,3 +54,7 @@ Plato supports both Linux with NVIDIA GPUs and macOS with M1/M2/M4/M4 GPUs. It w - [Server-side Lighteval for SmolLM2](case-studies/4. Server-side Lighteval for SmolLM2.md) - [SmolVLA Trainer with LeRobot](case-studies/3. SmolVLA Trainer with LeRobot.md) + +- [Nanochat in Plato](case-studies/5. Nanochat in Plato.md) + +- [Time-Series Forecasting with TimesFM](case-studies/6. Time-Series Forecasting with TimesFM.md) diff --git a/docs/docs/examples/case-studies/6. Time-Series Forecasting with TimesFM.md b/docs/docs/examples/case-studies/6. Time-Series Forecasting with TimesFM.md new file mode 100644 index 000000000..05d62a0d8 --- /dev/null +++ b/docs/docs/examples/case-studies/6. Time-Series Forecasting with TimesFM.md @@ -0,0 +1,210 @@ +# Time-Series Forecasting with TimesFM + +Plato includes a reference workflow for **federated time-series forecasting** with Hugging Face time-series models. The initial case study predicts EV charging availability: each client owns one user's charging history and trains on sliding windows from that user's hourly sequence. + +Reference files: + +- `configs/TimeSeries/timesfm25_ev_charging.toml` +- `configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml` +- `configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml` +- `configs/TimeSeries/patchtsmixer_ev_charging.toml` +- `plato/datasources/ev_charging.py` + +## Dataset preparation + +The configs use the [_Residential electric vehicle charging datasets from apartment buildings_](https://data.mendeley.com/datasets/jbks2rcwyj/1/files/2e3b8ced-9887-4a91-b721-8e510e18a127) [doi: 10.17632/jbks2rcwyj.1]. + +Download `dataset1_ev_charging_reports.csv` and place it at the path used by the configs, for example: + +```text +runtime/data/ado1/dataset1_ev_charging_reports.csv +``` + +The dataset is not bundled with Plato. The datasource expects the raw semicolon-separated CSV and performs the preprocessing at runtime. + +## EVCharging datasource behavior + +Use the datasource with: + +```toml +[data] +datasource = "EVCharging" +datasource_path = "runtime/data/ado1/dataset1_ev_charging_reports.csv" +garage = "AdO1" +users = ["AdO1-1", "AdO1-2", "AdO1-3", "AdO1-4"] +sampler = "all_inclusive" +``` + +The datasource: + +- filters to one garage, or uses the whole CSV when `garage = "all"`; +- preserves the configured `users` order; +- maps client IDs to users in that order (`client_id = 1` selects the first user); +- builds a continuous hourly grid for each user's active date range; +- marks `is_charging = 1` when a charging session overlaps an hour; +- accumulates `energy_kwh` over active charging hours; +- scales energy with the training-window maximum; +- adds cyclic time features for hour-of-day and day-of-week; +- splits valid sliding-window starts into train / validation / test windows. + +The model input feature order is: + +```text +is_charging, energy_scaled, hour_sin, hour_cos, dow_sin, dow_cos +``` + +The reference configs forecast only the first channel, `is_charging`. + +## Model choices + +### TimesFM + +Select TimesFM through the HuggingFace trainer: + +```toml +[trainer] +type = "HuggingFace" +model_name = "google/timesfm-2.5-200m-transformers" +model_type = "timesfm" +context_length = 672 +prediction_length = 128 +num_input_channels = 6 +prediction_channel_indices = [0] +freq = 0 +``` + +Supported reference variants include: + +- `google/timesfm-2.0-500m-pytorch` +- `google/timesfm-2.5-200m-pytorch` +- `google/timesfm-2.5-200m-transformers` + +The TimesFM reference configs use `prediction_length = 128` because the selected TimesFM checkpoints expose a fixed 128-step native horizon. + +### PatchTSMixer + +PatchTSMixer is useful as a smaller scratch baseline: + +```toml +[trainer] +type = "HuggingFace" +model_name = "patchtsmixer_scratch" +model_type = "patchtsmixer" +model_task = "forecasting" +context_length = 672 +prediction_length = 168 +num_input_channels = 6 +prediction_channel_indices = [0] +mode = "mix_channel" +``` + +Unlike the TimesFM wrapper's channel-independent path, the reference PatchTSMixer config uses `mode = "mix_channel"` so the model can use the time features jointly. + +## Run the reference configs + +Install the normal Plato environment first: + +```bash +uv sync +``` + +Then run one of the configs from the repository root. + +TimesFM 2.5 on the four AdO1 users: + +```bash +uv run plato.py --config configs/TimeSeries/timesfm25_ev_charging.toml +``` + +TimesFM 2.5 on selected high-data users across garages: + +```bash +uv run plato.py --config configs/TimeSeries/timesfm25_ev_charging_top4_mixed.toml +``` + +PatchTSMixer scratch baseline: + +```bash +uv run plato.py --config configs/TimeSeries/patchtsmixer_ev_charging.toml +``` + +Single-client TimesFM 2.5 transformers smoke run: + +```bash +uv run plato.py --config configs/TimeSeries/timesfm_transformers_bl1.toml +``` + +## DiLoCo variant + +The branch also includes a TimesFM 2.5 + DiLoCo config: + +```bash +uv run plato.py --config configs/TimeSeries/timesfm25_ev_charging_top4_mixed_diloco.toml +``` + +The config uses: + +```toml +[server] +type = "diloco" + +[server.diloco] +outer_optimizer = "nesterov" +outer_learning_rate = 0.7 +outer_momentum = 0.9 +aggregation_weighting = "uniform" +apply_outer_optimizer_to = "parameters" + +[trainer] +local_steps_per_round = 1500 +preserve_optimizer_state = true +``` + +`local_steps_per_round` is counted in completed optimizer steps, not epochs. See the [DiLoCo design contract](../../development/diloco.md) for the mechanics behind this server type. + +## Result logging + +The time-series configs use MSE as the scalar test metric: + +```toml +[results] +types = "round, elapsed_time, mse" +``` + +A lower MSE is better. + +## Troubleshooting + +### Dataset file not found + +Make sure `data.datasource_path` points to the downloaded `dataset1_ev_charging_reports.csv`. Relative paths are resolved from the Plato repository root when using the reference commands above. + +### User not found + +If a configured user is missing, check the `garage` setting. Users from multiple garages require: + +```toml +garage = "all" +``` + +### TimesFM class not available + +TimesFM 2.5 requires a recent `transformers` version that exposes `TimesFm2_5ModelForPrediction`. If model import fails, update the environment and verify the class can be imported before launching a long run. + +### Metric looks like accuracy in old scripts + +For time-series runs, use configs that include: + +```toml +[results] +types = "round, elapsed_time, mse" +``` + +The server and client logs label the primary metric as MSE when the active trainer testing strategy reports `metric_name = "mse"`. + +## Related documentation + +- [Data configuration](../../configurations/data.md) +- [Trainer configuration](../../configurations/trainer.md) +- [Results logging](../../configurations/results.md) +- [DiLoCo design contract](../../development/diloco.md) diff --git a/docs/docs/index.md b/docs/docs/index.md index 1a105dbcf..4c44fd94f 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -39,6 +39,7 @@ Welcome to *Plato*, a software framework to facilitate scalable, reproducible, a - **[SmolVLA Trainer with LeRobot](examples/case-studies/3. SmolVLA Trainer with LeRobot.md)** - **[Server-side Lighteval for SmolLM2](examples/case-studies/4. Server-side Lighteval for SmolLM2.md)** - **[Nanochat in Plato](examples/case-studies/5. Nanochat in Plato.md)** + - **[Time-Series Forecasting with TimesFM](examples/case-studies/6. Time-Series Forecasting with TimesFM.md)** ## Configuration Settings diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index f4177c8b4..4da2e01b2 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -72,6 +72,7 @@ nav: - SmolVLA Trainer with LeRobot: examples/case-studies/3. SmolVLA Trainer with LeRobot.md - Server-side Lighteval for SmolLM2: examples/case-studies/4. Server-side Lighteval for SmolLM2.md - Nanochat in Plato: examples/case-studies/5. Nanochat in Plato.md + - Time-Series Forecasting with TimesFM: examples/case-studies/6. Time-Series Forecasting with TimesFM.md - Configuration Settings: - Overview: configurations/overview.md - General: configurations/general.md