diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index abc2d20..2b59938 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -42,7 +42,7 @@ jobs: - name: Test with pytest run: | - pytest tests/ --ignore-glob='tests/test_ml_*.py' --cov=coco_pipe/ --cov-report=xml --verbose -s + pytest tests/ --cov=coco_pipe/ --cov-report=xml --verbose -s - name: Upload coverage reports to Codecov if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.10' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5490491..d64121e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,4 @@ +exclude: '^coco_pipe/decoding/fm_hub/cbramod_src/' repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 diff --git a/README.md b/README.md index 3330f3d..f95d73e 100644 --- a/README.md +++ b/README.md @@ -42,127 +42,77 @@ Whether you're conducting clinical research, developing ML models for brain-comp For detailed development instructions, please see [CONTRIBUTING.md](CONTRIBUTING.md). -## Using the ML Module +## Using the Decoding Module -CoCo Pipe provides two main ways to use the ML module: +The supported modeling API is `coco_pipe.decoding.Experiment`. It is array-first: +prepare `X` and `y` explicitly, then pass optional sample IDs, groups, feature +names, and time labels when they matter for the analysis. -### 1. Direct Python API Usage +```python +from coco_pipe.decoding import Experiment, ExperimentConfig +from coco_pipe.decoding.configs import ( + CVConfig, + FeatureSelectionConfig, + LogisticRegressionConfig, + TuningConfig, +) -You can use the ML module directly in your Python scripts by importing from `coco_pipe.io` for data loading and `coco_pipe.ml` for machine learning pipelines: +config = ExperimentConfig( + task="classification", + models={"logreg": LogisticRegressionConfig(max_iter=500)}, + metrics=["accuracy", "roc_auc"], + cv=CVConfig(strategy="stratified", n_splits=5, shuffle=True, random_state=42), + feature_selection=FeatureSelectionConfig( + enabled=True, + method="k_best", + n_features=20, + scoring="f_classif", + ), + tuning=TuningConfig( + enabled=True, + param_grid={"model__C": [0.1, 1.0, 10.0]}, + scoring="roc_auc", + cv=CVConfig(strategy="stratified", n_splits=3, shuffle=True, random_state=42), + ), + n_jobs=1, +) -```python -from coco_pipe.io import load_data -from coco_pipe.ml import MLPipeline - -# Load your data into the canonical package container -container = load_data( - "data/your_dataset.csv", - mode="tabular", - target_col="target_class", - sep=",", +result = Experiment(config).run( + X, + y, + groups=subject_ids, + sample_ids=trial_ids, + feature_names=feature_names, ) -# Select a subset explicitly from the container when needed -container = container.select(feature=["feat1", "feat2"], y=["case", "control"]) -X = container.X -y = container.y - -# Configure and run ML pipeline -config = { - "task": "classification", # or 'regression' - "analysis_type": "baseline", # Options: 'baseline', 'feature_selection', 'hp_search', 'hp_search_fs' - "models": "all", # or list of specific models - "metrics": ["accuracy", "f1-score"], - "cv_strategy": "stratified", - "n_splits": 5, - "n_features": 10, # For feature selection - "direction": "forward", # For feature selection - "search_type": "grid", # For hyperparameter search - "n_iter": 100, # For random search - "scoring": "accuracy", - "n_jobs": -1 -} - -pipeline = MLPipeline(X=X, y=y, config=config) -results = pipeline.run() +summary = result.summary() +predictions = result.get_predictions() +splits = result.get_splits() +selected = result.get_selected_features() ``` -### 2. Using the CLI Tool - -For batch processing or experiment management, use the CLI tool with a YAML configuration file: - -```yaml -# ----------------------------------------------------------------------------- -# Toy config for MLPipeline -# ----------------------------------------------------------------------------- - -# Global parameters shared across analyses -global_experiment_id: "toy_ml_config" -data_path: "../datasets/toy_dataset.csv" -results_dir: "../results" -results_file: "toy_ml_config" - -# Default analysis parameters (can be overridden per analysis) -defaults: - random_state: 42 - n_jobs: -1 - cv_kwargs: - strategy: "stratified" - n_splits: 5 - shuffle: true - random_state: 42 - covariates: ["age"] - spatial_units: ["regionX", "regionY"] - feature_names: ["feat1", "feat2", "feat3"] - -# List of analyses to run -analyses: - - id: "classification_baseline" - task: "classification" - analysis_type: "baseline" - target_columns: ["target_class"] - row_filter: - - column: "age" - values: 13 - operator: ">" - - column: "sex" - values: ["male"] - models: - - "Logistic Regression" - - "Random Forest" - metrics: - - "accuracy" - - "roc_auc" - - - id: "regression_hp_search" - task: "regression" - analysis_type: "hp_search" - target_columns: ["target_reg"] - feature_names: ["feat1"] - spatial_units: ["regionX"] - models: "all" - metrics: - - "r2" - - "neg_mse" - cv_kwargs: - strategy: "kfold" - n_splits: 3 - search_type: "grid" - n_iter: 20 - scoring: "r2" -``` +For grouped EEG studies, make the outer and inner CV decisions explicit: -Run the analysis using: +```python +config = ExperimentConfig( + task="classification", + models={"logreg": LogisticRegressionConfig(max_iter=500)}, + metrics=["accuracy"], + cv=CVConfig(strategy="group_kfold", n_splits=5), + tuning=TuningConfig( + enabled=True, + param_grid={"model__C": [0.1, 1.0, 10.0]}, + scoring="accuracy", + cv=CVConfig(strategy="group_kfold", n_splits=3), + ), +) -```bash -python scripts/run_ml.py --config configs/your_config.yml +result = Experiment(config).run(X, y, groups=subject_ids) ``` -The pipeline will: -- Load and preprocess your data -- Run all specified analyses -- Save results for each model/analysis -- Generate a combined results file +See the decoding documentation for feature selection, temporal decoding, result +tables, plotting helpers, and report integration. Batch decoding CLIs are not +part of the public surface yet; use the Python API for now. ## Documentation @@ -179,12 +129,6 @@ Contributions are welcome! If you have suggestions or find any bugs, please open - Implement CSV loading and M/EEG data loading functionalities. - Develop comprehensive unit tests. -#### ML Module -- Restructure to mirror the design of the dim_reduction module. -- Consolidate scripts within the main pipeline. -- Add regression support and enhance cross-validation methods. -- Update and expand unit tests. - #### DL Module - Define and implement deep learning functionalities. - Create corresponding unit tests. diff --git a/coco_pipe/decoding/__init__.py b/coco_pipe/decoding/__init__.py index 3215082..52d3f2e 100644 --- a/coco_pipe/decoding/__init__.py +++ b/coco_pipe/decoding/__init__.py @@ -1,12 +1,67 @@ -from .configs import ExperimentConfig -from .core import Experiment -from .registry import get_estimator_cls, register_estimator -from .utils import cross_validate_score +""" +Decoding Module +=============== + +Core module for scientific decoding and machine learning experiments on +electrophysiological and behavioral data. +""" + +from .configs import ( + CheckpointConfig, + ClassicalModelConfig, + DeviceConfig, + ExperimentConfig, + FoundationEmbeddingModelConfig, + FrozenBackboneDecoderConfig, + LoRAConfig, + NeuralFineTuneConfig, + QuantizationConfig, + StatisticalAssessmentConfig, + TemporalDecoderConfig, + TrainerConfig, + TrainStageConfig, +) +from .experiment import Experiment +from .registry import ( + EstimatorCapabilities, + get_capabilities, + list_capabilities, + register_estimator, + register_estimator_spec, +) +from .result import ExperimentResult +from .stats import ( + aggregate_predictions_for_inference, + binomial_accuracy_test, + run_statistical_assessment, +) __all__ = [ + # Configs "ExperimentConfig", - "register_estimator", - "get_estimator_cls", + "ClassicalModelConfig", + "FoundationEmbeddingModelConfig", + "FrozenBackboneDecoderConfig", + "NeuralFineTuneConfig", + "TemporalDecoderConfig", + "LoRAConfig", + "QuantizationConfig", + "DeviceConfig", + "CheckpointConfig", + "TrainerConfig", + "TrainStageConfig", + "StatisticalAssessmentConfig", + # Execution "Experiment", - "cross_validate_score", + "ExperimentResult", + # Model Discovery & Metadata + "register_estimator", + "register_estimator_spec", + "get_capabilities", + "list_capabilities", + "EstimatorCapabilities", + # Stats Utilities + "run_statistical_assessment", + "binomial_accuracy_test", + "aggregate_predictions_for_inference", ] diff --git a/coco_pipe/decoding/_cache.py b/coco_pipe/decoding/_cache.py new file mode 100644 index 0000000..f828c54 --- /dev/null +++ b/coco_pipe/decoding/_cache.py @@ -0,0 +1,82 @@ +""" +Cache-key helpers for decoding feature extraction. +================================================== + +The decoding module uses these helpers to generate stable, split-safe keys for +caching intermediate artifacts like embeddings or fitted preprocessing steps. +""" + +from __future__ import annotations + +import hashlib +import json +from typing import Any, Sequence + + +def make_feature_cache_key( + train_sample_ids: Sequence[Any], + test_sample_ids: Sequence[Any], + preprocessing_fingerprint: str, + backbone_fingerprint: str, + extra_metadata: dict[str, Any] | None = None, + sort_ids: bool = True, +) -> str: + """ + Build a stable cache key for split-specific feature extraction artifacts. + + This generates a SHA256 hex digest of a JSON-serialized payload containing + the identities of the train/test samples and the configuration of the + preprocessing and backbone modules. This ensures that fitted transforms + or extracted embeddings cannot be reused for incompatible splits or + different model configurations, preventing data leakage and silent errors. + + The sample IDs are converted to strings to ensure stability across + different ID types. By default, IDs are sorted to ensure the cache key + is order-insensitive. If order-dependent preprocessing is used, + set `sort_ids=False`. + + Parameters + ---------- + train_sample_ids : Sequence[Any] + Sample IDs identifying the training fold. + test_sample_ids : Sequence[Any] + Sample IDs identifying the test/validation fold. + preprocessing_fingerprint : str + A unique hash or string representing the preprocessing configuration. + backbone_fingerprint : str + A unique hash or string representing the model/extractor configuration. + extra_metadata : dict[str, Any], optional + Additional dimensions that affect the output (e.g., time indices, + target labels, or stage names). Default is None. + sort_ids : bool, default=True + Whether to sort the sample IDs before hashing. Sorting makes the + key order-insensitive, which is usually desired for reproducibility. + + Returns + ------- + key : str + The SHA256 hex digest of the normalized JSON payload. + """ + # 1. Normalize identifiers + train_ids = [str(value) for value in train_sample_ids] + test_ids = [str(value) for value in test_sample_ids] + + if sort_ids: + train_ids.sort() + test_ids.sort() + + payload = { + "train_sample_ids": train_ids, + "test_sample_ids": test_ids, + "preprocessing_fingerprint": preprocessing_fingerprint, + "backbone_fingerprint": backbone_fingerprint, + } + + # 2. Handle metadata path (explicit for coverage) + if extra_metadata is not None: + payload["extra_metadata"] = extra_metadata + else: + payload["extra_metadata"] = {} + + encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode() + return hashlib.sha256(encoded).hexdigest() diff --git a/coco_pipe/decoding/_constants.py b/coco_pipe/decoding/_constants.py new file mode 100644 index 0000000..d2461c6 --- /dev/null +++ b/coco_pipe/decoding/_constants.py @@ -0,0 +1,119 @@ +""" +Decoding Constants +================== + +Shared literal types and constants for the decoding module. This module +centralizes the taxonomy used for models, metrics, and experimental +configurations to ensure scientific consistency and runtime safety. +""" + +from typing import Literal + +# --- Cross-Module Constants --- + +GROUP_CV_STRATEGIES = { + "group_kfold", + "stratified_group_kfold", + "leave_p_out", + "leave_one_group_out", + "group_shuffle_split", +} +"""Set of scikit-learn CV strategies that require a 'groups' array.""" + +CLASSICAL_FAMILIES = { + "linear", + "tree", + "ensemble", + "svm", + "neighbors", + "bayes", + "dummy", +} +"""Architectural families that follow standard scikit-learn API patterns. +(e.g., scaling, feature selection). +""" + +RESULT_SCHEMA_VERSION = "decoding_result_v1" +"""Version identifier for the serialized ExperimentResult payload.""" + +# --- Common Literal Types --- + +MetricTask = Literal["classification", "regression"] +"""Type of predictive task being performed.""" + +ResponseMethod = Literal["predict", "proba", "score", "proba_or_score"] +"""The required estimator method to produce predictions for a metric.""" + +PredictionInterface = Literal["predict", "predict_proba", "decision_function"] +"""The actual scikit-learn API method names used for prediction.""" + +InputRank = Literal["2d", "3d_temporal", "tokens"] +"""The dimensionality/rank of the input data X.""" + +InputKind = Literal[ + "tabular", + "temporal", + "epoched", + "embeddings", + "tokens", + "tabular_2d", + "embedding_2d", + "temporal_3d", +] +"""Semantic classification of the input data structure.""" + +EstimatorFamily = Literal[ + "linear", + "tree", + "ensemble", + "svm", + "neighbors", + "neural", + "bayes", + "dummy", + "temporal", + "foundation", +] +"""High-level architectural family of the estimator.""" + +GroupedMetadata = Literal["none", "search_cv", "sfs_metadata_routing"] +"""Types of metadata generated by meta-estimators (e.g., GridSearch).""" + +FeatureSelectionSupport = Literal["univariate", "sfs", "disabled"] +"""Level of feature selection supported or enabled.""" + +CalibrationSupport = Literal["eligible", "already_probabilistic", "unsupported"] +"""Whether an estimator supports post-hoc probability calibration.""" + +ImportanceSupport = Literal[ + "coefficients", + "feature_importances", + "permutation", + "saliency", + "unavailable", +] +"""The mechanism used to extract feature importance or weights.""" + +TemporalSupport = Literal["none", "sliding", "generalizing", "native"] +"""The level of temporal decoding logic supported by the model.""" + +DependencyGroup = Literal[ + "core", + "mne", + "torch", + "braindecode", + "transformers", + "peft", + "quant", +] +"""Optional dependency groups for lazy-loading and capability checking.""" + +MetricFamily = Literal[ + "label", + "score_probability", + "threshold_sweep", + "calibration", + "confusion", + "regression", +] +"""Categorization of metrics for reporting and diagnostic grouping.""" diff --git a/coco_pipe/decoding/_diagnostics.py b/coco_pipe/decoding/_diagnostics.py new file mode 100644 index 0000000..bdf31b7 --- /dev/null +++ b/coco_pipe/decoding/_diagnostics.py @@ -0,0 +1,673 @@ +""" +Decoding Diagnostics & Tidy Data Helpers +======================================== +Functions for expanding and tidying raw decoding results into DataFrames. +""" + +from typing import Any, Dict, Iterator, Optional, Sequence + +import numpy as np +import pandas as pd + +from ._metrics import get_metric_spec + + +def time_value(index: int, time_axis: Optional[Sequence[Any]]) -> Any: + """ + Map a raw integer index to a meaningful scientific time value. + + This helper ensures that temporal decoding results (from sliding or + generalizing estimators) are human-readable by aligning array indices + with the actual experiment time points (e.g., mapping index 0 to -0.2s). + + Parameters + ---------- + index : int + The raw integer index from the results array. + time_axis : Sequence[Any], optional + A sequence of time values (e.g., a numpy array of seconds) + corresponding to the temporal dimension of the data. + + Returns + ------- + Any + The scientific time value if the axis is provided and index is in range; + otherwise, returns the raw index. + + Examples + -------- + >>> time_value(0, [-0.2, -0.1, 0.0]) + -0.2 + >>> time_value(5, None) + 5 + """ + if time_axis is None or index >= len(time_axis): + return index + return time_axis[index] + + +def score_rows( + model: str, + fold_idx: int, + metric: str, + score: Any, + time_axis: Optional[Sequence[Any]] = None, +) -> list[Dict[str, Any]]: + """ + Expand scalar or temporal fold scores into tidy data rows. + + This function unrolls raw result arrays into a flat list of dictionaries, + automatically handling three distinct scientific patterns: + 1. Scalar: Standard decoding (1 row per fold/metric). + 2. 1D Array: Sliding Estimator (N rows per fold, mapping 'Time'). + 3. 2D Array: Generalizing Estimator (N*M rows, mapping 'TrainTime' and 'TestTime'). + + Parameters + ---------- + model : str + Name of the estimator. + fold_idx : int + The cross-validation fold index. + metric : str + The name of the scoring metric. + score : Any + The raw score (float, 1D array, or 2D array). + time_axis : Sequence[Any], optional + The scientific time points for temporal mapping. Default is None. + + Returns + ------- + list[Dict[str, Any]] + A list of "tidy" rows ready for DataFrame conversion. + + Examples + -------- + >>> score_rows("svc", 0, "accuracy", 0.8) + [{'Model': 'svc', 'Fold': 0, 'Metric': 'accuracy', 'Value': 0.8}] + """ + score = np.asarray(score) + rows = [] + + if score.ndim == 0: + return [ + { + "Model": model, + "Fold": fold_idx, + "Metric": metric, + "Value": float(score), + } + ] + + if score.ndim == 1: + for t_idx, val in enumerate(score): + rows.append( + { + "Model": model, + "Fold": fold_idx, + "Metric": metric, + "Time": time_value(t_idx, time_axis), + "Value": val, + } + ) + return rows + + if score.ndim == 2: + for t_tr in range(score.shape[0]): + for t_te in range(score.shape[1]): + rows.append( + { + "Model": model, + "Fold": fold_idx, + "Metric": metric, + "TrainTime": time_value(t_tr, time_axis), + "TestTime": time_value(t_te, time_axis), + "Value": score[t_tr, t_te], + } + ) + return rows + + return [{"Model": model, "Fold": fold_idx, "Metric": metric, "Value": score}] + + +def prediction_rows( + model: str, + fold_idx: int, + preds: Dict[str, Any], + time_axis: Optional[Sequence[Any]] = None, +) -> list[Dict[str, Any]]: + """ + Expand raw predictions from a results dictionary into tidy data rows. + + This function is the primary engine for converting raw estimator outputs + into analyzable DataFrames. It automatically handles the expansion of + standard, sliding (temporal), and generalizing (train-time x test-time) + predictions into a flat, tabular format while aligning metadata. + + Parameters + ---------- + model : str + Name of the estimator. + fold_idx : int + The cross-validation fold index. + preds : Dict[str, Any] + Raw predictions dictionary containing 'y_true', 'y_pred', + and optionally 'y_proba', 'sample_id', and 'sample_metadata'. + time_axis : Sequence[Any], optional + Scientific time points for coordinate mapping. Default is None. + + Returns + ------- + list[Dict[str, Any]] + A list of "tidy" records (list of dictionaries). + + Examples + -------- + >>> preds = {"y_true": [0], "y_pred": [0], "sample_id": ["s1"]} + >>> # result = prediction_rows("svc", 0, preds) + """ + y_true = np.asarray(preds["y_true"]) + y_pred = np.asarray(preds["y_pred"]) + y_proba = np.asarray(preds["y_proba"]) if "y_proba" in preds else None + y_score = np.asarray(preds["y_score"]) if "y_score" in preds else None + + n_samples = len(y_true) + sample_index = np.asarray(preds.get("sample_index", np.arange(n_samples))) + sample_id = np.asarray(preds.get("sample_id", sample_index)) + groups = optional_values(preds.get("group"), n_samples) + metadata = preds.get("sample_metadata") or {} + + # 1. Determine Expansion Factor (n_exp) and Temporal Coordinates + pattern = "standard" + n_exp = 1 + coords = {} + + if y_pred.ndim == 2: # Sliding + pattern = "sliding" + n_times = y_pred.shape[1] + n_exp = n_times + coords["Time"] = np.tile( + [time_value(t, time_axis) for t in range(n_times)], n_samples + ) + elif y_pred.ndim == 3: # Generalizing + pattern = "generalizing" + n_tr, n_te = y_pred.shape[1], y_pred.shape[2] + n_exp = n_tr * n_te + tr_vals = [time_value(t, time_axis) for t in range(n_tr)] + te_vals = [time_value(t, time_axis) for t in range(n_te)] + coords["TrainTime"] = np.tile(np.repeat(tr_vals, n_te), n_samples) + coords["TestTime"] = np.tile(np.tile(te_vals, n_tr), n_samples) + + # 2. Vectorized Backbone Construction + data = { + "Model": [model] * (n_samples * n_exp), + "Fold": [fold_idx] * (n_samples * n_exp), + "SampleIndex": np.repeat(sample_index, n_exp), + "SampleID": np.repeat(sample_id, n_exp), + "Group": np.repeat(groups, n_exp), + "y_true": np.repeat([row_value(y_true, i) for i in range(n_samples)], n_exp), + "y_pred": y_pred.ravel(), + } + + # Add temporal coordinates + data.update(coords) + + # Inject metadata (broadcasted) + for key, values in metadata.items(): + v_arr = np.asarray(values, dtype=object) + data[key] = np.repeat(v_arr[:n_samples], n_exp) + + df = pd.DataFrame(data) + + # 3. Add Probabilities (Standard / Sliding / Generalizing) + if y_proba is not None: + if pattern == "standard": + if y_proba.ndim == 1: + df["y_proba_0"] = 1.0 - y_proba + df["y_proba_1"] = y_proba + elif y_proba.ndim == 2: + for c in range(y_proba.shape[1]): + df[f"y_proba_{c}"] = y_proba[:, c] + elif pattern == "sliding": + # 2D: (samples, times) -> binary, return proba for class 1 + if y_proba.ndim == 2: + df["y_proba_0"] = 1.0 - y_proba.ravel() + df["y_proba_1"] = y_proba.ravel() + # 3D: (samples, times, classes) -> multiclass + elif y_proba.ndim == 3: + # MNE SlidingEstimator returns (samples, times, classes) + for c in range(y_proba.shape[2]): + df[f"y_proba_{c}"] = y_proba[:, :, c].ravel() + elif pattern == "generalizing": + # 3D: (samples, tr, te) -> binary + if y_proba.ndim == 3: + df["y_proba_0"] = 1.0 - y_proba.ravel() + df["y_proba_1"] = y_proba.ravel() + # 4D: (samples, tr, te, classes) -> multiclass + elif y_proba.ndim == 4: + # MNE GeneralizingEstimator returns (samples, tr, te, classes) + for c in range(y_proba.shape[3]): + df[f"y_proba_{c}"] = y_proba[:, :, :, c].ravel() + + # 4. Add Decision Scores + if y_score is not None: + if pattern == "standard": + if y_score.ndim == 1: + df["y_score"] = y_score + elif y_score.ndim == 2: + for c in range(y_score.shape[1]): + df[f"y_score_{c}"] = y_score[:, c] + else: # Temporal scores are less common but supported + df["y_score"] = y_score.ravel() + + return df.to_dict(orient="records") + + +def row_value(values: np.ndarray, row_idx: int) -> Any: + """ + Extract a value from an array while ensuring JSON serialization safety. + + This helper extracts a single row/item and converts any nested NumPy + arrays into standard Python lists. This is critical for ensuring that + the final tidy records can be serialized to JSON without errors. + + Parameters + ---------- + values : np.ndarray + The source array (e.g., y_true or metadata). + row_idx : int + The index to extract. + + Returns + ------- + Any + The extracted value, converted to a list if it was an array. + """ + val = values[row_idx] + if isinstance(val, np.ndarray): + return val.tolist() + return val + + +def optional_values(values: Optional[Any], length: int) -> np.ndarray: + """ + Ensure a sequence exists and has the correct length for broadcasting. + + This helper is used during tidy data expansion to handle optional + fields (like 'group' or custom metadata). If the source is None, + it generates a 'ghost' array of Nones, ensuring that downstream + vectorized operations (like np.repeat) do not crash. + + Parameters + ---------- + values : Any, optional + The source sequence or None. + length : int + The required length of the resulting array. + + Returns + ------- + np.ndarray + An array of the specified length. + """ + if values is None: + return np.full(length, None, dtype=object) + return np.asarray(values) + + +def proba_matrix(group: pd.DataFrame, n_classes: int) -> Optional[np.ndarray]: + """ + Re-assemble a probability matrix from tidy prediction columns. + + This is an "inverse tidy" operation used by statistical assessment + routines. It finds columns like 'y_proba_0', 'y_proba_1', etc., + and packs them back into a single 2D NumPy array. + + Parameters + ---------- + group : pd.DataFrame + A DataFrame containing prediction rows. + n_classes : int + The expected number of classes (determines how many columns to look for). + + Returns + ------- + np.ndarray, optional + A (n_samples, n_classes) matrix if all columns are present + and valid; otherwise None. + + Examples + -------- + >>> df = pd.DataFrame({"y_proba_0": [0.8, 0.2], "y_proba_1": [0.2, 0.8]}) + >>> proba_matrix(df, 2) + array([[0.8, 0.2], + [0.2, 0.8]]) + """ + cols = [f"y_proba_{idx}" for idx in range(n_classes)] + if not all(c in group for c in cols): + return None + pb = group[cols] + if pb.isna().any().any(): + return None + return pb.to_numpy(dtype=float) + + +def unit_indices(group: pd.DataFrame, unit: str) -> list[np.ndarray]: + """ + Identify row-index blocks for unit-based resampling. + + In scientific assessment, samples are often non-independent (e.g., + multiple trials from the same subject). This helper identifies the + blocks of indices belonging to each independent unit, enabling + 'Block Bootstrapping' or unit-level permutations. + + Parameters + ---------- + group : pd.DataFrame + The tidy prediction DataFrame. + unit : str + The level of independence (e.g., 'sample', 'subject', 'session'). + + Returns + ------- + list[np.ndarray] + A list of index arrays, one per independent unit. + + Raises + ------ + ValueError + If the unit is unknown or if required columns are missing/empty. + """ + return _get_unit_blocks(group, unit) + + +def paired_unit_indices(merged: pd.DataFrame, unit: str) -> list[np.ndarray]: + """ + Identify row-index blocks for paired unit-based resampling. + + Used specifically for model comparisons. Ensures that when + performing paired permutation tests, the indices for both models + are retrieved from a merged DataFrame while maintaining pairing + consistency (e.g., Subject 1 in Model A paired with Subject 1 in Model B). + + Parameters + ---------- + merged : pd.DataFrame + The merged tidy DataFrame containing results for two models. + unit : str + The level of independence. + + Returns + ------- + list[np.ndarray] + A list of index arrays for the units in the merged Frame. + + Raises + ------ + ValueError + If the unit is unknown or if required columns are missing/empty. + """ + return _get_unit_blocks(merged, unit, suffix="_A") + + +def _get_unit_blocks(df: pd.DataFrame, unit: str, suffix: str = "") -> list[np.ndarray]: + """Identify blocks of row indices belonging to the same unit.""" + values = _resolve_unit_values(df, unit, suffix) + return [np.flatnonzero(values == v) for v in pd.unique(values)] + + +def _resolve_unit_values(df: pd.DataFrame, unit: str, suffix: str = "") -> np.ndarray: + """Extract the column identifying individual units with fallback logic.""" + valid_units = {"sample", "epoch", "group", "subject", "session", "site"} + if unit not in valid_units: + raise ValueError(f"unit must be one of {valid_units}, got '{unit}'.") + + if unit in {"sample", "epoch"}: + col = "SampleID" + elif unit == "group": + col = f"Group{suffix}" + else: + col = f"{unit.capitalize()}{suffix}" + + if col not in df or df[col].isna().all(): + raise ValueError(f"unit='{unit}' requires a non-empty '{col}' column.") + + return df[col].to_numpy() + + +def score_frame(frame: pd.DataFrame, metric: str) -> float: + """ + Score a tidy prediction frame using the specified metric. + + This dispatcher automatically routes the correct columns from the tidy + DataFrame to the underlying Scikit-Learn scorer based on the metric's + required response method (labels, probabilities, or decision scores). + + Parameters + ---------- + frame : pd.DataFrame + The tidy prediction DataFrame containing 'y_true', 'y_pred', + and optional 'y_proba_X' or 'y_score' columns. + metric : str + The name of the metric to compute (e.g., 'roc_auc', 'accuracy'). + + Returns + ------- + float + The calculated scientific score. + + Raises + ------ + ValueError + If the required columns for the metric are missing (e.g., scoring + ROC-AUC without probabilities) or if binary-only metrics are + applied to multiclass data. + """ + metric_spec = get_metric_spec(metric) + y_true = frame["y_true"].to_numpy() + + # 1. Label-based metrics + if metric_spec.response_method == "predict": + return float(metric_spec.scorer(y_true, frame["y_pred"].to_numpy())) + + # 2. Probability-based metrics + proba_cols = sorted( + [col for col in frame.columns if col.startswith("y_proba_")], + key=lambda value: int(value.rsplit("_", 1)[-1]), + ) + if metric_spec.response_method in {"proba", "proba_or_score"} and proba_cols: + proba = frame[proba_cols].to_numpy(dtype=float) + labels = sorted(pd.unique(y_true).tolist()) + if metric == "brier_score": + if proba.shape[1] != 2: + raise ValueError("brier_score supports binary classification only.") + return float(metric_spec.scorer(y_true, proba[:, 1])) + + if metric == "roc_auc" and proba.shape[1] == 2: + return float(metric_spec.scorer(y_true, proba[:, 1])) + + if metric in {"average_precision", "pr_auc"}: + if proba.shape[1] == 2: + return float(metric_spec.scorer(y_true, proba[:, 1])) + # AP multiclass: requires binarized y_true + from sklearn.preprocessing import label_binarize + + y_bin = label_binarize(y_true, classes=labels) + return float(metric_spec.scorer(y_bin, proba)) + + if metric == "log_loss": + return float(metric_spec.scorer(y_true, proba, labels=labels)) + + # Fallback for custom multiclass metrics that support multi_class='ovr'. + # Standard metrics like average_precision are handled explicitly above + # as they require custom binarization logic. + return float(metric_spec.scorer(y_true, proba, multi_class="ovr")) + + # 3. Decision-function based metrics + if ( + metric_spec.response_method in {"score", "proba_or_score"} + and "y_score" in frame + ): + return float(metric_spec.scorer(y_true, frame["y_score"].to_numpy(dtype=float))) + + raise ValueError(f"Metric '{metric}' cannot be scored from available predictions.") + + +def scalar_prediction_frame(preds: pd.DataFrame) -> pd.DataFrame: + """ + Filter a prediction DataFrame to include only scalar results. + + Excludes rows with temporal coordinates (Time, TrainTime, TestTime), + which is a common requirement for standard non-temporal diagnostics. + + Parameters + ---------- + preds : pd.DataFrame + The input prediction DataFrame. + + Returns + ------- + pd.DataFrame + A filtered DataFrame containing only scalar (standard) results. + """ + if preds.empty: + return preds + + mask = pd.Series(True, index=preds.index) + for col in ["Time", "TrainTime", "TestTime"]: + if col in preds: + mask &= preds[col].isna() + + return preds[mask] + + +def confusion_matrix_frame( + preds: pd.DataFrame, + labels: Sequence[Any], + normalize: Optional[str] = None, + group_cols: Optional[list[str]] = None, +) -> pd.DataFrame: + """ + Compute and tidy confusion matrices from a prediction frame. + + Parameters + ---------- + preds : pd.DataFrame + Prediction frame containing 'y_true' and 'y_pred'. + labels : Sequence[Any] + The set of class labels to include in the matrix. + normalize : {'true', 'pred', 'all'}, optional + Normalization strategy for the matrix. + group_cols : list[str], optional + Columns to group by (e.g., ['Model', 'Fold']). + + Returns + ------- + pd.DataFrame + A tidy DataFrame with grouping columns, TrueLabel, PredictedLabel, and Value. + """ + from sklearn.metrics import confusion_matrix + + group_cols = group_cols or ["Model", "Fold"] + frames = [] + + for names, group in preds.groupby(group_cols): + matrix = confusion_matrix( + group["y_true"], group["y_pred"], labels=labels, normalize=normalize + ) + df_m = pd.DataFrame(matrix, index=labels, columns=labels) + df_m.index.name = "TrueLabel" + df_m.columns.name = "PredictedLabel" + df_m = df_m.stack().reset_index(name="Value") + + # Map group column names to their values + if isinstance(names, (list, tuple)): + for col, val in zip(group_cols, names): + df_m[col] = val + else: + df_m[group_cols[0]] = names + frames.append(df_m) + + if not frames: + return pd.DataFrame( + columns=group_cols + ["TrueLabel", "PredictedLabel", "Value"] + ) + + return pd.concat(frames, ignore_index=True)[ + group_cols + ["TrueLabel", "PredictedLabel", "Value"] + ] + + +def curve_score_groups( + preds: pd.DataFrame, + model: Optional[str] = None, + require_probability: bool = False, + pos_label: Optional[Any] = None, +) -> Iterator[tuple[str, int, Any, np.ndarray, np.ndarray]]: + """ + Yield binary or one-vs-rest score groups for curve plotting. + + This helper handles the complexity of resolving positive labels, + identifying probability columns, and falling back to decision scores + across multiple models and folds. + + Yields + ------ + model_name : str + fold_idx : int + class_label : Any + y_binary : np.ndarray + y_score : np.ndarray + """ + if model is not None: + preds = preds[preds["Model"] == model] + + if preds.empty: + return + + for (m_name, f_idx), group in preds.groupby(["Model", "Fold"]): + y_true = group["y_true"].to_numpy() + unique_labels = sorted(pd.unique(y_true).tolist()) + if len(unique_labels) < 2: + continue + + # Binary Case + if len(unique_labels) == 2: + if pos_label is not None: + if pos_label not in unique_labels: + continue + target_label = pos_label + else: + target_label = unique_labels[1] + + l_idx = unique_labels.index(target_label) + p_col = f"y_proba_{l_idx}" + + if p_col in group and group[p_col].notna().all(): + y_score = group[p_col].to_numpy(dtype=float) + elif ( + not require_probability + and "y_score" in group + and group["y_score"].notna().all() + ): + y_score = group["y_score"].to_numpy(dtype=float) + if l_idx == 0: + y_score = -y_score + else: + continue + yield m_name, f_idx, target_label, (y_true == target_label), y_score + continue + + # Multiclass Case (One-vs-Rest) + for c_idx, label in enumerate(unique_labels): + p_col, s_col = f"y_proba_{c_idx}", f"y_score_{c_idx}" + if p_col in group and group[p_col].notna().all(): + y_score = group[p_col].to_numpy(dtype=float) + elif ( + not require_probability + and s_col in group + and group[s_col].notna().all() + ): + y_score = group[s_col].to_numpy(dtype=float) + else: + continue + yield m_name, f_idx, label, (y_true == label), y_score diff --git a/coco_pipe/decoding/_engine.py b/coco_pipe/decoding/_engine.py new file mode 100644 index 0000000..9c7c9e4 --- /dev/null +++ b/coco_pipe/decoding/_engine.py @@ -0,0 +1,719 @@ +""" +Decoding Engine +=============== +Functions for fitting, scoring, and metadata extraction. + +This module provides the core execution logic for cross-validation folds. +It is designed for high-performance, parallel execution. +""" + +import logging +import time +import warnings +from contextlib import nullcontext +from typing import Any, Callable, Dict, Optional, Sequence, Union + +import joblib +import numpy as np +import pandas as pd +from sklearn.base import BaseEstimator +from sklearn.feature_selection import SequentialFeatureSelector +from sklearn.pipeline import Pipeline +from sklearn.utils.multiclass import type_of_target + +from ._constants import GROUP_CV_STRATEGIES +from ._metrics import get_metric_spec +from ._splitters import _CVWithGroups, cv_uses_groups, get_cv_splitter +from .interfaces import NeuralTrainable + +logger = logging.getLogger(__name__) + + +class GroupedSequentialFeatureSelector(SequentialFeatureSelector): + """ + SequentialFeatureSelector that accepts groups via Pipeline fit parameters. + + In sklearn 1.8, SFS can route metadata to its internal CV splitter when + called directly, but Pipeline calls intermediate transformers through + ``fit_transform`` and does not forward top-level ``groups`` to SFS. This + adapter keeps the public SFS behavior but accepts a sliced ``groups`` fit + parameter, binds it to the SFS CV for this fit call, and then restores the + original ``cv`` object. + """ + + def fit(self, X: Any, y: Any = None, groups: Any = None, **params: Any): + """Fit SFS, using ``groups`` for its internal grouped CV if supplied.""" + if groups is None: + return super().fit(X, y, **params) + + groups_arr = ( + groups + if isinstance(groups, (np.ndarray, pd.Series)) + else np.asarray(groups) + ) + if len(groups_arr) != len(X): + raise ValueError( + "SequentialFeatureSelector groups length does not match X: " + f"{len(groups_arr)} != {len(X)}." + ) + + original_cv = self.cv + self.cv = _CVWithGroups(original_cv, groups_arr) + try: + return super().fit(X, y, **params) + finally: + self.cv = original_cv + + def fit_transform( + self, + X: Any, + y: Any = None, + groups: Any = None, + **params: Any, + ): + """Fit to data, then transform it.""" + return self.fit(X, y, groups=groups, **params).transform(X) + + +def fit_and_score_fold( + estimator: BaseEstimator, + X: np.ndarray, + y: np.ndarray, + groups: Optional[np.ndarray], + sample_ids: np.ndarray, + sample_metadata: Optional[Dict[str, np.ndarray]], + train_idx: np.ndarray, + test_idx: np.ndarray, + metrics: Sequence[str], + feature_selection_config: Any, + calibration_config: Any, + spec: Any, + tuning_config: Any = None, + feature_names: Optional[list[str]] = None, + search_enabled: bool = False, + force_serial: bool = False, +) -> Dict[str, Any]: + """ + Execute a single Cross-Validation fold: Fit, Predict, and Score. + + This function is designed to be pure and standalone, making it safe for + parallel execution via joblib. It handles the entire lifecycle of a fold, + including feature selection, calibration, and metadata extraction. + + Parameters + ---------- + estimator : BaseEstimator + The un-fitted estimator instance (or pipeline) for this fold. + X : np.ndarray + The full feature matrix of shape (n_samples, n_features). + y : np.ndarray + The full target vector of shape (n_samples,). + groups : np.ndarray, optional + Group labels (e.g., Subject IDs) for group-aware splitting. + sample_ids : np.ndarray + Unique IDs for each sample, used for tracking predictions. + sample_metadata : dict, optional + Pre-converted metadata dictionary (column: numpy array) for the split. + train_idx : np.ndarray + The indices of X/y to use for training. + test_idx : np.ndarray + The indices of X/y to use for testing. + metrics : Sequence[str] + List of metric names to compute (e.g., ['accuracy', 'roc_auc']). + feature_selection_config : Any + Configuration for the feature selection step (from CVConfig). + calibration_config : Any + Configuration for probability calibration (from CVConfig). + spec : EstimatorSpec + Hardened registry specification for the model. + tuning_config : Any, optional + Hyperparameter tuning settings. + feature_names : list of str, optional + Original names of the features, used for importance labeling. + force_serial : bool, default=False + If True, forces the internal estimator fit to be serial. + + Returns + ------- + Dict[str, Any] + A dictionary containing test indices, predictions, scores, + feature importances, and diagnostic timing information. + """ + X_train, X_test = X[train_idx], X[test_idx] + y_train, y_test = y[train_idx], y[test_idx] + + groups_train = groups[train_idx] if groups is not None else None + captured_warnings = [] + fit_time = np.nan + predict_time = np.nan + score_time = np.nan + + # 1. Fit + fit_start = time.perf_counter() + with warnings.catch_warnings(record=True) as warning_records: + warnings.simplefilter("always") + backend = ( + joblib.parallel_backend("sequential") if force_serial else nullcontext() + ) + with backend: + fit_estimator( + estimator, + X_train, + y_train, + groups_train, + feature_selection_config=feature_selection_config, + calibration_config=calibration_config, + tuning_config=tuning_config, + ) + fit_time = time.perf_counter() - fit_start + captured_warnings.extend(warning_records_to_dict("fit", warning_records)) + + # 2. Predict (Standard or Temporal) + predict_start = time.perf_counter() + with warnings.catch_warnings(record=True) as warning_records: + warnings.simplefilter("always") + y_pred = estimator.predict(X_test) + predict_time = time.perf_counter() - predict_start + captured_warnings.extend(warning_records_to_dict("predict", warning_records)) + test_groups = groups[test_idx] if groups is not None else None + + fold_data = { + "sample_index": test_idx, + "sample_id": sample_ids[test_idx], + "group": test_groups, + "sample_metadata": metadata_slice(sample_metadata, test_idx), + "y_true": y_test, + "y_pred": y_pred, + } + + # 3. Predict probabilities + if spec.supports_proba: + with warnings.catch_warnings(record=True) as warning_records: + warnings.simplefilter("always") + fold_data["y_proba"] = estimator.predict_proba(X_test) + captured_warnings.extend( + warning_records_to_dict("predict_proba", warning_records) + ) + + if "y_proba" not in fold_data and spec.supports_decision_function: + with warnings.catch_warnings(record=True) as warning_records: + warnings.simplefilter("always") + fold_data["y_score"] = estimator.decision_function(X_test) + captured_warnings.extend( + warning_records_to_dict("decision_function", warning_records) + ) + + # 4. Extract Feature Importances (Zero Guesswork) + imp = ( + extract_feature_importances( + estimator, + spec, + fs_enabled=feature_selection_config.enabled, + search_enabled=search_enabled, + calibration_enabled=calibration_config.enabled, + ) + if spec.importance[0] != "unavailable" + else None + ) + + # 5. Compute Metrics (Pre-fetched Specs) + scores = {} + is_multiclass = type_of_target(y_test) == "multiclass" + metric_specs = {m: get_metric_spec(m) for m in metrics} + + score_start = time.perf_counter() + with warnings.catch_warnings(record=True) as warning_records: + warnings.simplefilter("always") + for name, m_spec in metric_specs.items(): + if m_spec.response_method == "predict": + y_est, is_p = y_pred, False + elif m_spec.response_method == "proba": + y_est, is_p = fold_data.get("y_proba"), True + else: # proba_or_score + y_est = fold_data.get("y_proba") + is_p = True + if y_est is None: + y_est, is_p = fold_data.get("y_score"), False + + scores[name] = compute_metric_safe( + m_spec.scorer, + y_test, + y_est, + is_multiclass, + is_proba=is_p, + name=name, + ) + captured_warnings.extend(warning_records_to_dict("score", warning_records)) + score_time = time.perf_counter() - score_start + + # 6. Extract Metadata + meta = extract_metadata( + estimator, + spec, + feature_selection_config=feature_selection_config, + search_enabled=search_enabled, + feature_names=feature_names, + ) + + return { + "test_idx": test_idx, + "preds": fold_data, + "scores": scores, + "importance": imp, + "metadata": meta, + "split": { + "train_idx": train_idx, + "test_idx": test_idx, + "train_sample_id": sample_ids[train_idx], + "test_sample_id": sample_ids[test_idx], + "train_group": groups[train_idx] if groups is not None else None, + "test_group": groups[test_idx] if groups is not None else None, + "train_metadata": metadata_slice(sample_metadata, train_idx), + "test_metadata": metadata_slice(sample_metadata, test_idx), + }, + "diagnostics": { + "fit_time": fit_time, + "predict_time": predict_time, + "score_time": score_time, + "total_time": fit_time + predict_time + score_time, + "warnings": captured_warnings, + }, + } + + +def fit_estimator( + estimator: BaseEstimator, + X_train: np.ndarray, + y_train: np.ndarray, + groups_train: Optional[np.ndarray], + feature_selection_config: Any, + calibration_config: Any, + tuning_config: Any = None, +) -> None: + """ + Fit an estimator with intelligent metadata and group routing. + + Handles specialized logic for group-aware internal CV using standard + scikit-learn fit-parameter slicing. SearchCV receives ``groups`` for its + splitter, Pipeline receives ``fs__groups`` for SFS, and calibration CV gets + a fold-local group binding because CalibratedClassifierCV does not pass + ``groups`` to its splitter unless global metadata routing is enabled. + + Parameters + ---------- + estimator : BaseEstimator + The estimator or pipeline to fit. + X_train : np.ndarray + Training feature matrix. + y_train : np.ndarray + Training target vector. + groups_train : np.ndarray, optional + Training group labels. + feature_selection_config : Any + Feature selection settings. + calibration_config : Any + Probability calibration settings. + tuning_config : Any + Hyperparameter tuning settings. + """ + from sklearn.calibration import CalibratedClassifierCV + from sklearn.model_selection import GridSearchCV, RandomizedSearchCV + + calibrated = isinstance(estimator, CalibratedClassifierCV) + fitted_estimator = estimator.estimator if calibrated else estimator + search_cv = isinstance(fitted_estimator, (GridSearchCV, RandomizedSearchCV)) + + pipeline = fitted_estimator.estimator if search_cv else fitted_estimator + sfs = None + if isinstance(pipeline, Pipeline) and "fs" in pipeline.named_steps: + sfs = pipeline.named_steps["fs"] + if ( + getattr(feature_selection_config, "enabled", False) + and getattr(feature_selection_config, "method", None) == "sfs" + and _config_uses_group_cv(getattr(feature_selection_config, "cv", None)) + and hasattr(sfs, "cv") + and not cv_uses_groups(sfs.cv) + ): + sfs.cv = get_cv_splitter(feature_selection_config.cv, require_groups=False) + + fit_params: Dict[str, Any] = {} + if groups_train is not None: + if calibrated and _config_uses_group_cv( + getattr(calibration_config, "cv", None) + ): + cal_cv = get_cv_splitter(calibration_config.cv, require_groups=False) + estimator.cv = _CVWithGroups(cal_cv, groups_train) + + if search_cv and _config_uses_group_cv(getattr(tuning_config, "cv", None)): + fit_params["groups"] = groups_train + + if sfs is not None and _config_uses_group_cv( + getattr(feature_selection_config, "cv", None) + ): + fit_params["fs__groups"] = groups_train + + estimator.fit(X_train, y_train, **fit_params) + + +def _config_uses_group_cv(cv_config: Any) -> bool: + """Return True when a decoding CV config needs subject groups.""" + strategy = getattr(cv_config, "strategy", None) + return isinstance(strategy, str) and strategy.lower() in GROUP_CV_STRATEGIES + + +def extract_feature_importances( + estimator: BaseEstimator, + spec: Any, + fs_enabled: bool = False, + search_enabled: bool = False, + calibration_enabled: bool = False, +) -> Optional[np.ndarray]: + """ + Extract and aggregate feature importances or coefficients from a fitted model. + + This function drills down through potential pipeline wrappers (calibration, + tuning, and feature selection) to reach the base estimator. It then + delegates to `_get_raw_importance` and handles masking if feature + selection was applied. + + Parameters + ---------- + estimator : BaseEstimator + The fitted estimator, potentially wrapped in several layers. + spec : EstimatorSpec + The model capability registry entry defining how to extract weights. + fs_enabled : bool, optional + Whether feature selection (and thus a projection mask) was used. + search_enabled : bool, optional + Whether hyperparameter search (and thus `best_estimator_`) was used. + calibration_enabled : bool, optional + Whether calibration (and thus `.estimator`) was used. + + Returns + ------- + np.ndarray, optional + A vector of importances aligned with the input features, or None + if the model does not support weight extraction (e.g., k-NN). + """ + # 1. Drill down through wrappers + if calibration_enabled: + calibrated = getattr(estimator, "calibrated_classifiers_", None) + if calibrated: + fold_imps = [] + for calibrated_classifier in calibrated: + base = getattr(calibrated_classifier, "estimator", None) + imp = extract_feature_importances( + base, + spec, + fs_enabled=fs_enabled, + search_enabled=search_enabled, + calibration_enabled=False, + ) + if imp is not None: + fold_imps.append(imp) + if fold_imps and all(imp.shape == fold_imps[0].shape for imp in fold_imps): + return np.mean(np.vstack(fold_imps), axis=0) + return None + + estimator = getattr(estimator, "estimator", estimator) + + if search_enabled: + estimator = getattr(estimator, "best_estimator_", estimator) + + # 2. Handle Pipeline (Scaler + [FS] + Clf) + if hasattr(estimator, "named_steps"): + clf = estimator.named_steps["clf"] + if fs_enabled: + fs = estimator.named_steps["fs"] + raw_imp = _get_raw_importance(clf, spec) + if raw_imp is not None: + mask = fs.get_support() + full_imp = np.zeros_like(mask, dtype=float) + full_imp[mask] = raw_imp + return full_imp + return None + estimator = clf + + # 3. Direct extraction from base estimator + return _get_raw_importance(estimator, spec) + + +def _get_raw_importance(estimator: Any, spec: Any) -> Optional[np.ndarray]: + """ + Internal helper to extract importance from the base estimator. + + Handles sparse coefficients and multiclass magnitude aggregation. + """ + imp_type = spec.importance[0] + + if imp_type == "coefficients": + if not hasattr(estimator, "coef_"): + return None + vals = estimator.coef_ + # Zero-guesswork sparse handling + if spec.is_sparse_capable and hasattr(vals, "toarray"): + vals = vals.toarray() + + if vals.ndim > 1: + # Aggregate across classes (multiclass). Note: binary LDA or + # LogisticRegression often return shape (1, n_features). + return np.mean(np.abs(vals), axis=0) + return np.abs(vals) + + if imp_type == "feature_importances": + if not hasattr(estimator, "feature_importances_"): + return None + return estimator.feature_importances_ + + return None + + +def compute_metric_safe( + scorer: Callable, + y_true: np.ndarray, + y_est: np.ndarray, + is_multiclass: bool, + is_proba: bool = False, + name: Optional[str] = None, +) -> Union[float, np.ndarray]: + """ + Handles Standard, Sliding, and Generalizing decoding results efficiently. + + Parameters + ---------- + scorer : Callable + Scikit-learn compatible scoring function. + y_true : np.ndarray + Ground truth labels. + y_est : np.ndarray + Predictions or probabilities (can be 2D, 3D, or 4D). + is_multiclass : bool + Whether the task is multiclass. + is_proba : bool + Whether y_est contains probabilities. + + Returns + ------- + float or np.ndarray + Single score or temporal score matrix/vector. + """ + # 1. Configuration Setup + kwargs = {"multi_class": "ovr"} if (is_proba and is_multiclass) else {} + + # 2. Handle Binary Probability Slicing (Positive class only) + if ( + is_proba + and not is_multiclass + and y_est is not None + and y_est.ndim >= 2 + and y_est.shape[1] == 2 + ): + y_est = y_est[:, 1, ...] + + # 3. Shape Analysis + if y_est is None: + return np.nan + + slice_ndim = 2 if (is_proba and is_multiclass) else 1 + n_temporal_dims = y_est.ndim - slice_ndim + + # 4. Temporal Dispatch + if n_temporal_dims == 0: # Standard decoding + return float(scorer(y_true, y_est, **kwargs)) + + if n_temporal_dims == 1: # Sliding decoding (n_times,) + if name == "accuracy" and not is_proba: + return np.mean(y_true[:, None] == y_est, axis=0) + + return np.array( + [ + float(scorer(y_true, y_est[..., t], **kwargs)) + for t in range(y_est.shape[-1]) + ] + ) + + if n_temporal_dims == 2: # Generalizing decoding + if name == "accuracy" and not is_proba: + return np.mean(y_true[:, None, None] == y_est, axis=0) + + n_tr, n_te = y_est.shape[-2], y_est.shape[-1] + results = np.zeros((n_tr, n_te)) + for tr in range(n_tr): + for te in range(n_te): + results[tr, te] = float(scorer(y_true, y_est[..., tr, te], **kwargs)) + return results + + raise ValueError(f"Unsupported y_est shape for scoring: {y_est.shape}") + + +def extract_metadata( + estimator: BaseEstimator, + spec: Any, + feature_selection_config: Any, + search_enabled: bool = False, + feature_names: Optional[list[str]] = None, +) -> Dict[str, Any]: + """ + Metadata extraction. + + Parameters + ---------- + estimator : BaseEstimator + The fitted estimator or pipeline. + spec : EstimatorSpec + Model registry entry. + feature_selection_config : Any + FS configuration. + search_enabled : bool + Whether search wrapper is active. + feature_names : list of str, optional + Original feature names. + + Returns + ------- + Dict[str, Any] + Aggregated fold metadata. + """ + meta = {} + + # 1. Search Diagnostics + if search_enabled: + meta["best_params"] = getattr(estimator, "best_params_", {}) + meta["best_score"] = getattr(estimator, "best_score_", np.nan) + meta["best_index"] = getattr(estimator, "best_index_", -1) + if hasattr(estimator, "cv_results_"): + meta["search_results"] = compact_search_results(estimator) + estimator = getattr(estimator, "best_estimator_", estimator) + + # 2. Pipeline / Feature Selection Diagnostics + if feature_selection_config.enabled: + if hasattr(estimator, "named_steps"): + fs_step = estimator.named_steps.get("fs") + clf_step = estimator.named_steps.get("clf") + + if fs_step is not None: + mask = fs_step.get_support() + indices = np.flatnonzero(mask) + n_feat = len(mask) + + actual_names = ( + feature_names + if (feature_names and len(feature_names) == n_feat) + else [f"feature_{i}" for i in range(n_feat)] + ) + + meta.update( + { + "feature_selection_method": feature_selection_config.method, + "selected_features": mask, + "selected_feature_indices": indices, + "selected_feature_names": [actual_names[i] for i in indices], + "feature_names": actual_names, + } + ) + + if feature_selection_config.method == "sfs" and hasattr( + fs_step, "ranking_" + ): + meta["selection_order"] = fs_step.ranking_ + elif feature_selection_config.method == "k_best" and hasattr( + fs_step, "scores_" + ): + meta["feature_scores"] = fs_step.scores_ + + if clf_step is not None: + estimator = clf_step + + # 3. Custom Artifacts (Structural Type Check via Protocol) + if isinstance(estimator, NeuralTrainable): + meta["artifacts"] = estimator.get_artifact_metadata() + + return meta + + +def compact_search_results(estimator: BaseEstimator) -> list[Dict[str, Any]]: + """ + Return compact search diagnostics with pre-standardized arrays. + + Converts Scikit-learn's cv_results_ to a serializable list of records. + + Parameters + ---------- + estimator : BaseEstimator + The estimator (e.g., GridSearchCV) containing cv_results_. + + Returns + ------- + list[dict[str, Any]] + A list of dictionaries, where each dictionary represents a single + hyperparameter combination and its scores. + """ + cv_res = estimator.cv_results_ + params = cv_res["params"] + + ranks = ( + np.asarray(cv_res["rank_test_score"]) if "rank_test_score" in cv_res else None + ) + means = ( + np.asarray(cv_res["mean_test_score"], dtype=float) + if "mean_test_score" in cv_res + else None + ) + stds = ( + np.asarray(cv_res["std_test_score"], dtype=float) + if "std_test_score" in cv_res + else None + ) + + results = [] + for idx, p in enumerate(params): + row = {"candidate": idx, "params": dict(p)} + if ranks is not None: + row["rank"] = int(ranks[idx]) + if means is not None: + row["mean"] = float(means[idx]) + if stds is not None: + row["std"] = float(stds[idx]) + results.append(row) + + return results + + +def metadata_slice( + metadata_dict: Optional[Dict[str, np.ndarray]], + indices: np.ndarray, +) -> Optional[Dict[str, list[Any]]]: + """ + Slicing of pre-converted metadata. + + Parameters + ---------- + metadata_dict : dict, optional + Dictionary of numpy arrays. + indices : np.ndarray + Indices to select. + + Returns + ------- + dict, optional + Slicing result as a serializable dictionary of lists. + """ + if metadata_dict is None: + return None + return {k: v[indices].tolist() for k, v in metadata_dict.items()} + + +def warning_records_to_dict( + stage: str, warning_records: Sequence[Any] +) -> list[Dict[str, Any]]: + """ + Return serializable warning records captured in one fold stage. + """ + return [ + { + "stage": stage, + "category": record.category.__name__, + "message": str(record.message), + } + for record in warning_records + ] diff --git a/coco_pipe/decoding/_metrics.py b/coco_pipe/decoding/_metrics.py new file mode 100644 index 0000000..103c691 --- /dev/null +++ b/coco_pipe/decoding/_metrics.py @@ -0,0 +1,497 @@ +""" +Internal Metric Registry for Decoding. +===================================== + +This module defines the registry of available metrics for classification +and regression tasks. It provides metadata about each metric, such as +the required estimator response (predict vs predict_proba) and whether +higher values are better. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +import numpy as np +from sklearn.metrics import ( + accuracy_score, + auc, + average_precision_score, + balanced_accuracy_score, + brier_score_loss, + cohen_kappa_score, + explained_variance_score, + f1_score, + hamming_loss, + jaccard_score, + log_loss, + matthews_corrcoef, + mean_absolute_error, + mean_squared_error, + precision_recall_curve, + precision_score, + r2_score, + recall_score, + roc_auc_score, + zero_one_loss, +) + +from ._constants import MetricFamily, MetricTask, ResponseMethod + + +@dataclass(frozen=True) +class MetricSpec: + """ + Metadata specification for a decoding metric. + + This container stores all necessary information to resolve and compute + a metric, including its task type, family for reporting, and the + required estimator response method. + + Parameters + ---------- + name : str + The unique identifier for the metric. + task : MetricTask + The type of task ("classification" or "regression"). + scorer : Callable + The callable with signature ``(y_true, y_pred) -> float``. + response_method : ResponseMethod, default="predict" + The estimator method required to produce ``y_pred`` (e.g., "proba"). + family : MetricFamily, default="label" + The category of the metric for reporting and visualization. + greater_is_better : bool, default=True + Whether a higher score indicates better performance. + + Examples + -------- + >>> from sklearn.metrics import accuracy_score + >>> spec = MetricSpec("accuracy", "classification", accuracy_score) + >>> spec.name + 'accuracy' + """ + + name: str + task: MetricTask + scorer: Callable + response_method: ResponseMethod = "predict" + family: MetricFamily = "label" + greater_is_better: bool = True + + +def _specificity_score(y_true, y_pred) -> float: + """ + Compute specificity (recall of the negative class). + + In many clinical or brain-decoding tasks, the ability to correctly + identify non-targets (negatives) is as important as identifying + targets. Specificity measures the true negative rate ($TN / (TN + FP)$). + + Parameters + ---------- + y_true : array-like + Ground truth labels. + y_pred : array-like + Predicted labels. + + Returns + ------- + score : float + The calculated specificity. + + Examples + -------- + >>> _specificity_score([0, 1], [0, 1]) + 1.0 + + See Also + -------- + _sensitivity_score : True positive rate. + """ + return recall_score(y_true, y_pred, pos_label=0, zero_division=0) + + +def _sensitivity_score(y_true, y_pred) -> float: + """ + Compute sensitivity (recall of the positive class). + + This implementation ensures binary-only enforcement and robust + zero-division handling for unbalanced neuroimaging datasets. It + measures the true positive rate ($TP / (TP + FN)$). + + Parameters + ---------- + y_true : array-like + Ground truth labels. + y_pred : array-like + Predicted labels. + + Returns + ------- + score : float + The calculated sensitivity. + + Raises + ------ + ValueError + If the input labels contain more than 2 classes. + + Examples + -------- + >>> _sensitivity_score([0, 1], [0, 1]) + 1.0 + + See Also + -------- + _specificity_score : True negative rate. + """ + labels = np.unique(y_true) + if len(labels) > 2: + raise ValueError("Sensitivity is only defined for binary classification.") + return recall_score(y_true, y_pred, pos_label=1, zero_division=0) + + +def _pr_auc_score(y_true, probas_pred) -> float: + """ + Compute Area Under the Precision-Recall Curve via trapezoidal integration. + + PR-AUC computed via trapezoidal integration can differ from Average + Precision (AP) on imbalanced datasets. This method directly + approximates the integral of the PR curve, providing a more direct + estimate of the curve's area for scientific reporting. + + Parameters + ---------- + y_true : array-like + Ground truth labels. + probas_pred : array-like + Predicted probabilities for the positive class. + + Returns + ------- + score : float + The calculated PR-AUC. + + Examples + -------- + >>> _pr_auc_score([0, 1], [0.1, 0.9]) + 1.0 + + See Also + -------- + sklearn.metrics.average_precision_score : Standard AP implementation. + """ + # Note: precision_recall_curve returns precision, recall, thresholds + # where recall is in descending order; auc() handles this. + precision, recall, _ = precision_recall_curve(y_true, probas_pred) + return float(auc(recall, precision)) + + +METRIC_REGISTRY: dict[str, MetricSpec] = { + # Classification from hard predictions (family="label" or "confusion") + "accuracy": MetricSpec("accuracy", "classification", accuracy_score), + "balanced_accuracy": MetricSpec( + "balanced_accuracy", + "classification", + balanced_accuracy_score, + family="confusion", + ), + "f1": MetricSpec( + "f1", + "classification", + lambda y, p: f1_score(y, p, average="weighted"), + family="confusion", + ), + "f1_macro": MetricSpec( + "f1_macro", + "classification", + lambda y, p: f1_score(y, p, average="macro"), + family="confusion", + ), + "f1_micro": MetricSpec( + "f1_micro", + "classification", + lambda y, p: f1_score(y, p, average="micro"), + family="confusion", + ), + "precision": MetricSpec( + "precision", + "classification", + lambda y, p: precision_score(y, p, average="weighted", zero_division=0), + family="confusion", + ), + "recall": MetricSpec( + "recall", + "classification", + lambda y, p: recall_score(y, p, average="weighted", zero_division=0), + family="confusion", + ), + "sensitivity": MetricSpec( + "sensitivity", + "classification", + _sensitivity_score, + family="confusion", + ), + "specificity": MetricSpec( + "specificity", + "classification", + _specificity_score, + family="confusion", + ), + "zero_one_loss": MetricSpec( + "zero_one_loss", + "classification", + zero_one_loss, + family="label", + greater_is_better=False, + ), + "hamming_loss": MetricSpec( + "hamming_loss", + "classification", + hamming_loss, + family="label", + greater_is_better=False, + ), + "jaccard": MetricSpec( + "jaccard", + "classification", + lambda y, p: jaccard_score(y, p, average="weighted"), + family="confusion", + ), + "matthews_corrcoef": MetricSpec( + "matthews_corrcoef", + "classification", + matthews_corrcoef, + family="confusion", + ), + "cohen_kappa": MetricSpec( + "cohen_kappa", + "classification", + cohen_kappa_score, + family="confusion", + ), + # Classification from probabilities (family="threshold_sweep") + "roc_auc": MetricSpec( + "roc_auc", + "classification", + roc_auc_score, + "proba_or_score", + family="threshold_sweep", + ), + "roc_auc_ovr_weighted": MetricSpec( + "roc_auc_ovr_weighted", + "classification", + lambda y, p: roc_auc_score(y, p, multi_class="ovr", average="weighted"), + "proba", + family="threshold_sweep", + ), + "average_precision": MetricSpec( + "average_precision", + "classification", + average_precision_score, + "proba_or_score", + family="threshold_sweep", + ), + "pr_auc": MetricSpec( + "pr_auc", + "classification", + _pr_auc_score, + "proba_or_score", + family="threshold_sweep", + ), + "log_loss": MetricSpec( + "log_loss", + "classification", + log_loss, + "proba", + family="score_probability", + greater_is_better=False, + ), + "brier_score": MetricSpec( + "brier_score", + "classification", + brier_score_loss, + "proba", + family="calibration", + greater_is_better=False, + ), + # Regression (family="regression") + "r2": MetricSpec("r2", "regression", r2_score, family="regression"), + "neg_mean_squared_error": MetricSpec( + "neg_mean_squared_error", + "regression", + lambda y, p: -mean_squared_error(y, p), + family="regression", + greater_is_better=True, + ), + "neg_mean_absolute_error": MetricSpec( + "neg_mean_absolute_error", + "regression", + lambda y, p: -mean_absolute_error(y, p), + family="regression", + greater_is_better=True, + ), + "explained_variance": MetricSpec( + "explained_variance", + "regression", + explained_variance_score, + family="regression", + ), + "neg_root_mean_squared_error": MetricSpec( + "neg_root_mean_squared_error", + "regression", + lambda y, p: -float(np.sqrt(mean_squared_error(y, p))), + family="regression", + greater_is_better=True, + ), +} + + +def get_scorer(name: str) -> Callable: + """ + Retrieve the callable scoring function for a given metric. + + Returns the bare function required by scikit-learn's `fit` and + `score` APIs. + + Parameters + ---------- + name : str + Metric name, for example ``accuracy`` or ``neg_mean_squared_error``. + + Returns + ------- + scorer : Callable + Metric function with signature ``(y_true, y_pred) -> float``. + + Raises + ------ + ValueError + If the metric name is not found in the registry. + + Examples + -------- + >>> scorer = get_scorer("accuracy") + >>> scorer([0, 1], [0, 1]) + 1.0 + + See Also + -------- + get_metric_spec : Retrieve the full metadata object. + """ + return get_metric_spec(name).scorer + + +def get_metric_spec(name: str) -> MetricSpec: + """ + Return the full metadata specification for a given metric. + + The specification includes the task type, response method, and + reporting family for the metric. + + Parameters + ---------- + name : str + The unique name of the metric to look up. + + Returns + ------- + spec : MetricSpec + The metadata object for the requested metric. + + Raises + ------ + ValueError + If the metric name is not found in the registry. + + Examples + -------- + >>> spec = get_metric_spec("roc_auc") + >>> spec.response_method + 'proba_or_score' + + See Also + -------- + get_scorer : Retrieve the bare scoring function. + """ + if name not in METRIC_REGISTRY: + raise ValueError( + f"Unknown metric '{name}'. Available: " + f"{sorted(list(METRIC_REGISTRY.keys()))}" + ) + return METRIC_REGISTRY[name] + + +def get_metric_names( + task: MetricTask | None = None, + family: MetricFamily | None = None, +) -> list[str]: + """ + Return a list of known metric names, optionally filtered. + + Enables dynamic discovery of metrics based on the current decoding + task or the desired reporting family. + + Parameters + ---------- + task : MetricTask, optional + Filter by task type ("classification" or "regression"). + family : MetricFamily, optional + Filter by metric family (e.g., "confusion"). + + Returns + ------- + names : list of str + The sorted list of matching metric names. + + Examples + -------- + >>> get_metric_names(task="regression") + ['explained_variance', ..., 'r2'] + + See Also + -------- + get_metric_families : Discover available reporting families. + """ + return sorted( + name + for name, spec in METRIC_REGISTRY.items() + if (task is None or spec.task == task) + and (family is None or spec.family == family) + ) + + +def get_metric_families(task: MetricTask | None = None) -> list[MetricFamily]: + """ + Return a list of known metric families, optionally filtered. + + Returns the categories of metrics available, useful for generating + comprehensive reporting dashboards. + + Parameters + ---------- + task : MetricTask, optional + Filter families by the task type they support. + + Returns + ------- + families : list of MetricFamily + The sorted list of matching metric families. + + Examples + -------- + >>> get_metric_families(task="classification") + ['calibration', 'confusion', 'label', 'score_probability', 'threshold_sweep'] + + See Also + -------- + get_metric_names : Retrieve the individual metric identifiers. + """ + return sorted( + { + spec.family + for spec in METRIC_REGISTRY.values() + if task is None or spec.task == task + } + ) diff --git a/coco_pipe/decoding/_specs.py b/coco_pipe/decoding/_specs.py new file mode 100644 index 0000000..b4b7286 --- /dev/null +++ b/coco_pipe/decoding/_specs.py @@ -0,0 +1,660 @@ +""" +Estimator Specifications and Capability Metadata +=============================================== + +Internal module containing the static database of estimator metadata +and the dataclasses used to represent them. +""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from typing import Any + +from ._constants import ( + CalibrationSupport, + DependencyGroup, + EstimatorFamily, + FeatureSelectionSupport, + GroupedMetadata, + ImportanceSupport, + InputKind, + InputRank, + MetricTask, + PredictionInterface, + TemporalSupport, +) + + +@dataclass(frozen=True) +class EstimatorCapabilities: + """Machine-readable capabilities for a decoding estimator.""" + + method: str + tasks: tuple[MetricTask, ...] + input_ranks: tuple[InputRank, ...] = ("2d",) + prediction_interfaces: tuple[PredictionInterface, ...] = ("predict",) + grouped_metadata: tuple[GroupedMetadata, ...] = ("none",) + feature_selection: tuple[FeatureSelectionSupport, ...] = ("univariate", "sfs") + calibration: CalibrationSupport = "eligible" + importance: tuple[ImportanceSupport, ...] = ("unavailable",) + temporal: TemporalSupport = "none" + dependencies: tuple[DependencyGroup, ...] = ("core",) + + def to_dict(self) -> dict[str, Any]: + """ + Return a JSON-friendly capability dictionary. + + This ensures that the capability metadata can be safely serialized + for reporting or API responses. + + Returns + ------- + caps : dict[str, Any] + The capability dictionary. + """ + return asdict(self) + + def supports_task(self, task: str) -> bool: + """ + Check if the estimator supports the given task type. + + Parameters + ---------- + task : str + The task type to check (e.g., "classification"). + + Returns + ------- + available : bool + True if the task is supported. + """ + return task in self.tasks + + def has_response(self, response: str) -> bool: + """ + Check if the estimator supports the given prediction interface. + + Parameters + ---------- + response : str + The interface to check (e.g., "predict_proba"). + + Returns + ------- + available : bool + True if the response interface is available. + """ + return response in self.prediction_interfaces + + +@dataclass(frozen=True) +class EstimatorSpec: + """Typed registry entry for a decoding estimator.""" + + name: str + import_path: str + family: EstimatorFamily + task: tuple[MetricTask, ...] + input_kinds: tuple[InputKind, ...] = ("tabular_2d",) + supports_groups: bool = False + supports_proba: bool = False + supports_decision_function: bool = False + supports_calibration: bool = True + supports_feature_names: bool = True + dependency_extra: DependencyGroup = "core" + fit_smoke_required: bool = True + default_search_space: dict[str, list[Any]] = field(default_factory=dict) + feature_selection: tuple[FeatureSelectionSupport, ...] = ("univariate", "sfs") + importance: tuple[ImportanceSupport, ...] = ("unavailable",) + """Type of feature importance available (coefficients, importance).""" + + temporal: TemporalSupport = "none" + """Temporal decoding wrapper type (sliding or generalizing).""" + + calibration: CalibrationSupport = "eligible" + """Whether the model is eligible for probability calibration.""" + + supports_random_state: bool = False + """Whether the estimator class accepts a random_state parameter.""" + + is_sparse_capable: bool = False + """Whether the model can produce sparse coefficients (e.g. L1 regularization).""" + + @property + def module_path(self) -> str: + """ + Return the module part of the import path. + + Returns + ------- + path : str + The module path (e.g., "sklearn.linear_model"). + """ + return self.import_path.split(":")[0] + + @property + def class_name(self) -> str: + """ + Return the class name to be imported. + + Returns + ------- + name : str + The class name (e.g., "LogisticRegression"). + """ + if ":" in self.import_path: + return self.import_path.split(":", 1)[1] + return self.name + + def to_dict(self) -> dict[str, Any]: + """ + Return a JSON-friendly spec dictionary. + + Returns + ------- + spec : dict[str, Any] + The full specification as a dictionary. + """ + return asdict(self) + + def to_capabilities(self) -> EstimatorCapabilities: + """ + Derive lightweight capability metadata from this spec. + + This conversion distills the full specification into a format + used by the engine for runtime validation and capability-based + routing. + + Returns + ------- + caps : EstimatorCapabilities + The derived capability metadata. + + See Also + -------- + EstimatorCapabilities : The destination capability container. + """ + responses = ["predict"] + if self.supports_proba: + responses.append("predict_proba") + if self.supports_decision_function: + responses.append("decision_function") + + input_ranks = [] + for kind in self.input_kinds: + if kind in {"temporal_3d", "epoched"}: + rank = "3d_temporal" + elif kind == "tokens": + rank = "tokens" + else: + rank = "2d" + if rank not in input_ranks: + input_ranks.append(rank) + + return EstimatorCapabilities( + method=self.name, + tasks=self.task, + input_ranks=tuple(input_ranks), + prediction_interfaces=tuple(responses), + grouped_metadata=("search_cv",) if self.supports_groups else ("none",), + feature_selection=self.feature_selection, + calibration=self.calibration, + importance=self.importance, + temporal=self.temporal, + dependencies=(self.dependency_extra,), + ) + + +@dataclass(frozen=True) +class SelectorCapabilities: + """Machine-readable capabilities for a decoding feature selector.""" + + method: str + input_ranks: tuple[InputRank, ...] + support: tuple[FeatureSelectionSupport, ...] + grouped_metadata: tuple[GroupedMetadata, ...] = ("none",) + + def to_dict(self) -> dict[str, Any]: + """ + Return a JSON-friendly dictionary representation. + + Returns + ------- + caps : dict[str, Any] + The selector capabilities dictionary. + """ + return asdict(self) + + +# Shared Task Tuples +_CLASSIFICATION = ("classification",) +_REGRESSION = ("regression",) +_BOTH_TASKS = ("classification", "regression") + +# Shared Importance Tuples +_COEF = ("coefficients",) +_TREE_IMPORTANCE = ("feature_importances",) + + +def _spec( + name: str, + import_path: str, + family: EstimatorFamily, + task: tuple[MetricTask, ...], + **kwargs: Any, +) -> EstimatorSpec: + """Helper to create an EstimatorSpec directly.""" + return EstimatorSpec( + name=name, import_path=import_path, family=family, task=task, **kwargs + ) + + +ESTIMATOR_SPECS: dict[str, EstimatorSpec] = { + # --- Classifiers --- + "LogisticRegression": _spec( + "LogisticRegression", + "sklearn.linear_model", + "linear", + _CLASSIFICATION, + supports_proba=True, + supports_decision_function=True, + importance=_COEF, + supports_random_state=True, + is_sparse_capable=True, + default_search_space={"C": [0.1, 1.0, 10.0]}, + ), + "RandomForestClassifier": _spec( + "RandomForestClassifier", + "sklearn.ensemble", + "ensemble", + _CLASSIFICATION, + supports_proba=True, + importance=_TREE_IMPORTANCE, + supports_random_state=True, + default_search_space={ + "n_estimators": [100, 300], + "max_depth": [None, 5, 10], + }, + ), + "ExtraTreesClassifier": _spec( + "ExtraTreesClassifier", + "sklearn.ensemble", + "ensemble", + _CLASSIFICATION, + supports_proba=True, + importance=_TREE_IMPORTANCE, + supports_random_state=True, + default_search_space={ + "n_estimators": [100, 300], + "max_depth": [None, 5, 10], + }, + ), + "SVC": _spec( + "SVC", + "sklearn.svm", + "svm", + _CLASSIFICATION, + supports_proba=True, + supports_decision_function=True, + supports_random_state=True, + default_search_space={"C": [0.1, 1.0, 10.0]}, + ), + "LinearSVC": _spec( + "LinearSVC", + "sklearn.svm", + "svm", + _CLASSIFICATION, + supports_proba=False, + supports_decision_function=True, + importance=_COEF, + supports_random_state=True, + is_sparse_capable=True, + default_search_space={"C": [0.1, 1.0, 10.0]}, + ), + "KNeighborsClassifier": _spec( + "KNeighborsClassifier", + "sklearn.neighbors", + "neighbors", + _CLASSIFICATION, + supports_proba=True, + supports_random_state=False, + default_search_space={"n_neighbors": [3, 5, 7]}, + ), + "GradientBoostingClassifier": _spec( + "GradientBoostingClassifier", + "sklearn.ensemble", + "ensemble", + _CLASSIFICATION, + supports_proba=True, + importance=_TREE_IMPORTANCE, + supports_random_state=True, + default_search_space={"n_estimators": [100, 300], "learning_rate": [0.03, 0.1]}, + ), + "HistGradientBoostingClassifier": _spec( + "HistGradientBoostingClassifier", + "sklearn.ensemble", + "ensemble", + _CLASSIFICATION, + supports_proba=True, + supports_random_state=True, + default_search_space={"max_iter": [100, 300], "learning_rate": [0.03, 0.1]}, + ), + "SGDClassifier": _spec( + "SGDClassifier", + "sklearn.linear_model", + "linear", + _CLASSIFICATION, + supports_decision_function=True, + importance=_COEF, + supports_random_state=True, + is_sparse_capable=True, + default_search_space={"alpha": [0.0001, 0.001, 0.01]}, + ), + "MLPClassifier": _spec( + "MLPClassifier", + "sklearn.neural_network", + "neural", + _CLASSIFICATION, + supports_proba=True, + supports_random_state=True, + default_search_space={"alpha": [0.0001, 0.001]}, + ), + "GaussianNB": _spec( + "GaussianNB", + "sklearn.naive_bayes", + "bayes", + _CLASSIFICATION, + supports_proba=True, + calibration="eligible", + supports_random_state=False, + # Note: GaussianNB is probabilistic but benefits from calibration + # when priors are misspecified or features are non-independent. + default_search_space={"var_smoothing": [1e-9, 1e-8, 1e-7]}, + ), + "LinearDiscriminantAnalysis": _spec( + "LinearDiscriminantAnalysis", + "sklearn.discriminant_analysis", + "linear", + _CLASSIFICATION, + supports_proba=True, + importance=_COEF, + supports_random_state=False, + ), + "AdaBoostClassifier": _spec( + "AdaBoostClassifier", + "sklearn.ensemble", + "ensemble", + _CLASSIFICATION, + supports_proba=True, + importance=_TREE_IMPORTANCE, + supports_random_state=True, + default_search_space={"n_estimators": [50, 100], "learning_rate": [0.5, 1.0]}, + ), + "DummyClassifier": _spec( + "DummyClassifier", + "sklearn.dummy", + "dummy", + _CLASSIFICATION, + supports_proba=True, + supports_calibration=False, + calibration="unsupported", + ), + # --- Regressors --- + "LinearRegression": _spec( + "LinearRegression", + "sklearn.linear_model", + "linear", + _REGRESSION, + importance=_COEF, + supports_random_state=False, + ), + "Ridge": _spec( + "Ridge", + "sklearn.linear_model", + "linear", + _REGRESSION, + importance=_COEF, + supports_random_state=True, + default_search_space={"alpha": [0.1, 1.0, 10.0]}, + ), + "Lasso": _spec( + "Lasso", + "sklearn.linear_model", + "linear", + _REGRESSION, + importance=_COEF, + supports_random_state=True, + is_sparse_capable=True, + default_search_space={"alpha": [0.001, 0.01, 0.1, 1.0]}, + ), + "ElasticNet": _spec( + "ElasticNet", + "sklearn.linear_model", + "linear", + _REGRESSION, + importance=_COEF, + supports_random_state=True, + is_sparse_capable=True, + default_search_space={"alpha": [0.001, 0.01, 0.1], "l1_ratio": [0.2, 0.5, 0.8]}, + ), + "RandomForestRegressor": _spec( + "RandomForestRegressor", + "sklearn.ensemble", + "ensemble", + _REGRESSION, + importance=_TREE_IMPORTANCE, + supports_random_state=True, + default_search_space={ + "n_estimators": [100, 300], + "max_depth": [None, 5, 10], + }, + ), + "ExtraTreesRegressor": _spec( + "ExtraTreesRegressor", + "sklearn.ensemble", + "ensemble", + _REGRESSION, + importance=_TREE_IMPORTANCE, + supports_random_state=True, + default_search_space={ + "n_estimators": [100, 300], + "max_depth": [None, 5, 10], + }, + ), + "SVR": _spec( + "SVR", + "sklearn.svm", + "svm", + _REGRESSION, + supports_random_state=False, + default_search_space={"C": [0.1, 1.0, 10.0]}, + ), + "GradientBoostingRegressor": _spec( + "GradientBoostingRegressor", + "sklearn.ensemble", + "ensemble", + _REGRESSION, + importance=_TREE_IMPORTANCE, + supports_random_state=True, + default_search_space={"n_estimators": [100, 300], "learning_rate": [0.03, 0.1]}, + ), + "HistGradientBoostingRegressor": _spec( + "HistGradientBoostingRegressor", + "sklearn.ensemble", + "ensemble", + _REGRESSION, + supports_random_state=True, + default_search_space={"max_iter": [100, 300], "learning_rate": [0.03, 0.1]}, + ), + "SGDRegressor": _spec( + "SGDRegressor", + "sklearn.linear_model", + "linear", + _REGRESSION, + importance=_COEF, + supports_random_state=True, + default_search_space={"alpha": [0.0001, 0.001, 0.01]}, + ), + "MLPRegressor": _spec( + "MLPRegressor", + "sklearn.neural_network", + "neural", + _REGRESSION, + supports_random_state=True, + default_search_space={"alpha": [0.0001, 0.001]}, + ), + "DummyRegressor": _spec( + "DummyRegressor", + "sklearn.dummy", + "dummy", + _REGRESSION, + supports_calibration=False, + calibration="unsupported", + ), + "DecisionTreeRegressor": _spec( + "DecisionTreeRegressor", + "sklearn.tree", + "tree", + _REGRESSION, + importance=_TREE_IMPORTANCE, + supports_random_state=True, + default_search_space={"max_depth": [None, 5, 10]}, + ), + "KNeighborsRegressor": _spec( + "KNeighborsRegressor", + "sklearn.neighbors", + "neighbors", + _REGRESSION, + supports_random_state=False, + default_search_space={"n_neighbors": [3, 5, 7]}, + ), + "AdaBoostRegressor": _spec( + "AdaBoostRegressor", + "sklearn.ensemble", + "ensemble", + _REGRESSION, + importance=_TREE_IMPORTANCE, + supports_random_state=True, + default_search_space={"n_estimators": [50, 100], "learning_rate": [0.5, 1.0]}, + ), + "BayesianRidge": _spec( + "BayesianRidge", + "sklearn.linear_model", + "linear", + _REGRESSION, + importance=_COEF, + supports_random_state=False, + default_search_space={"alpha_1": [1e-7, 1e-6]}, + ), + "ARDRegression": _spec( + "ARDRegression", + "sklearn.linear_model", + "linear", + _REGRESSION, + importance=_COEF, + supports_random_state=False, + default_search_space={"alpha_1": [1e-7, 1e-6]}, + ), + # --- Custom Wrappers --- + "SlidingEstimator": _spec( + "SlidingEstimator", + "mne.decoding", + "temporal", + _BOTH_TASKS, + input_kinds=("temporal_3d",), + dependency_extra="mne", + fit_smoke_required=False, + feature_selection=("disabled",), + temporal="sliding", + ), + "GeneralizingEstimator": _spec( + "GeneralizingEstimator", + "mne.decoding", + "temporal", + _BOTH_TASKS, + input_kinds=("temporal_3d",), + dependency_extra="mne", + fit_smoke_required=False, + feature_selection=("disabled",), + temporal="generalizing", + ), + # --- Foundation Models --- + "reve": _spec( + "REVEModel", + "coco_pipe.decoding.fm_hub:REVEModel", + "foundation", + _BOTH_TASKS, + input_kinds=("epoched",), + supports_calibration=False, + fit_smoke_required=False, + feature_selection=("disabled",), + dependency_extra="torch", + ), + "cbramod": _spec( + "CBraModModel", + "coco_pipe.decoding.fm_hub:CBraModModel", + "foundation", + _BOTH_TASKS, + input_kinds=("epoched",), + supports_calibration=False, + fit_smoke_required=False, + feature_selection=("disabled",), + dependency_extra="torch", + ), +} + +SELECTOR_CAPABILITIES: dict[str, SelectorCapabilities] = { + "k_best": SelectorCapabilities( + "k_best", input_ranks=("2d",), support=("univariate",) + ), + "sfs": SelectorCapabilities( + "sfs", + input_ranks=("2d",), + support=("sfs",), + grouped_metadata=("sfs_metadata_routing",), + ), +} + + +def canonical_estimator_name(name: str) -> str: + """ + Map common model aliases to their canonical registry names. + + This ensures that user-friendly strings like 'lda' or 'logistic_regression' + resolve correctly to the internal 'LinearDiscriminantAnalysis' and + 'LogisticRegression' specifications. + + Parameters + ---------- + name : str + The input name or alias (e.g., 'lda', 'logistic_regression'). + + Returns + ------- + str + The canonical name used in the ESTIMATOR_SPECS registry. + + Examples + -------- + >>> canonical_estimator_name("lda") + 'LinearDiscriminantAnalysis' + + See Also + -------- + ESTIMATOR_SPECS : The registry of estimator specifications. + """ + aliases = { + "logistic_regression": "LogisticRegression", + "random_forest_classifier": "RandomForestClassifier", + "extra_trees_classifier": "ExtraTreesClassifier", + "linear_svc": "LinearSVC", + "lda": "LinearDiscriminantAnalysis", + "dummy_classifier": "DummyClassifier", + "ridge": "Ridge", + "random_forest_regressor": "RandomForestRegressor", + "extra_trees_regressor": "ExtraTreesRegressor", + "hist_gradient_boosting_classifier": "HistGradientBoostingClassifier", + "hist_gradient_boosting_regressor": "HistGradientBoostingRegressor", + } + return aliases.get(name, name) diff --git a/coco_pipe/decoding/_splitters.py b/coco_pipe/decoding/_splitters.py new file mode 100644 index 0000000..7c7189c --- /dev/null +++ b/coco_pipe/decoding/_splitters.py @@ -0,0 +1,380 @@ +""" +Decoding Splitters +================== + +Internal cross-validation splitters for the decoding module. +""" + +from __future__ import annotations + +from typing import Any, Dict, Optional, Sequence, Union + +import numpy as np +import pandas as pd +from sklearn.model_selection import ( + BaseCrossValidator, + GroupKFold, + GroupShuffleSplit, + KFold, + LeaveOneGroupOut, + LeavePGroupsOut, + StratifiedGroupKFold, + StratifiedKFold, + TimeSeriesSplit, + train_test_split, +) + +from ._constants import GROUP_CV_STRATEGIES, MetricTask +from .configs import CVConfig + + +class _CVWithGroups(BaseCrossValidator): + """ + Bind fixed groups to a cross-validator. + + This wrapper ensures that the same group array is supplied whenever + ``split`` or ``get_n_splits`` is called. It is particularly useful + for scientific validation where group identities (e.g., Subject IDs) + must be strictly preserved across nested CV folds to prevent leakage. + + Parameters + ---------- + cv : BaseCrossValidator + The underlying scikit-learn splitter to wrap. + groups : Any + The group labels (e.g., Subject IDs) to bind to the splitter. + """ + + def __init__(self, cv: BaseCrossValidator, groups: Any): + """ + Initialize the group wrapper. + + Parameters + ---------- + cv : BaseCrossValidator + The underlying scikit-learn splitter to wrap. + groups : Any + The group labels (e.g., Subject IDs) to bind to the splitter. + """ + self.cv = cv + self.groups = ( + groups + if isinstance(groups, (np.ndarray, pd.Series)) + else np.asarray(groups) + ) + + def _get_effective_groups(self, X: Any, groups: Any = None) -> Any: + """ + Return groups aligned to ``X`` or fail before leakage can occur. + + Scientific rationale: In nested cross-validation, `X` is often a + subset (fold) of the original data. This method ensures that the + groups being passed to the inner splitter match the rows of `X`. + + Parameters + ---------- + X : array-like + The data being split. + groups : array-like, optional + Explicit groups to use. If provided, these take precedence. + + Returns + ------- + aligned_groups : array-like + The groups array aligned to `X`. + + Raises + ------ + ValueError + If group lengths do not match `X`. + """ + if groups is not None: + groups_arr = ( + groups + if isinstance(groups, (np.ndarray, pd.Series)) + else np.asarray(groups) + ) + if X is not None and len(groups_arr) != len(X): + raise ValueError( + "Explicit groups length does not match X for CV splitting: " + f"{len(groups_arr)} != {len(X)}." + ) + return groups_arr + + if X is None: + return self.groups + + if len(X) == len(self.groups): + return self.groups + + if hasattr(X, "index") and hasattr(self.groups, "loc"): + try: + aligned = self.groups.loc[X.index] + if len(aligned) == len(X): + return aligned + except Exception: + pass + + raise ValueError( + "Bound groups length does not match X for CV splitting: " + f"{len(self.groups)} != {len(X)}. Pass a groups array aligned to " + "the X being split instead of reusing a full-array group binding " + "inside nested CV." + ) + + def split(self, X: Any, y: Any = None, groups: Any = None): + """Generate indices to split data into training and test set.""" + return self.cv.split(X, y, self._get_effective_groups(X, groups)) + + def get_n_splits(self, X: Any = None, y: Any = None, groups: Any = None) -> int: + """Return the number of splitting iterations in the cross-validator.""" + if X is not None: + effective_groups = self._get_effective_groups(X, groups) + else: + effective_groups = groups if groups is not None else self.groups + return self.cv.get_n_splits(X, y, effective_groups) + + def __sklearn_tags__(self) -> Dict[str, Any]: + """Sklearn 1.6+ compatibility for estimator tags.""" + tags = getattr(self.cv, "__sklearn_tags__", lambda: {})() + if not tags: + tags = getattr(self.cv, "_get_tags", lambda: {})() + return {**tags, "non_deterministic": tags.get("non_deterministic", False)} + + def _get_tags(self) -> Dict[str, Any]: + """Legacy sklearn tag support.""" + return self.__sklearn_tags__() + + def get_params(self, deep: bool = True) -> Dict[str, Any]: + """Get parameters for this estimator.""" + return {"cv": self.cv, "groups": self.groups} + + def __repr__(self) -> str: + return f"_CVWithGroups(cv={self.cv!r})" + + +def cv_uses_groups(cv: Any) -> bool: + """Return True for runtime CV splitter objects that consume groups.""" + return isinstance( + cv, + ( + _CVWithGroups, + GroupKFold, + StratifiedGroupKFold, + LeaveOneGroupOut, + LeavePGroupsOut, + GroupShuffleSplit, + ), + ) + + +class SimpleSplit(BaseCrossValidator): + """ + One-shot train/test split using ``train_test_split``. + + This provides a scikit-learn compatible interface for a single hold-out + validation set. It is often used for final model evaluation or when + the dataset is large enough that K-Fold CV is computationally + prohibitive. + + Parameters + ---------- + test_size : float, default=0.2 + The proportion of the dataset to include in the test split. + shuffle : bool, default=True + Whether to shuffle the data before splitting. + random_state : int, optional + Random seed for reproducibility. + stratify : bool or array-like, optional + If True, use 'y' for stratification. If an array, use it directly. + """ + + def __init__( + self, + test_size: float = 0.2, + shuffle: bool = True, + random_state: Optional[int] = None, + stratify: Optional[Union[bool, pd.Series, np.ndarray]] = None, + ): + """ + Initialize the hold-out splitter. + + Parameters + ---------- + test_size : float, default=0.2 + The proportion of the dataset to include in the test split. + shuffle : bool, default=True + Whether to shuffle the data before splitting. + random_state : int, optional + Random seed for reproducibility. + stratify : bool or array-like, optional + If True, use 'y' for stratification. If an array, use it directly. + """ + if not (0 < test_size < 1): + raise ValueError("test_size must be between 0 and 1.") + self.test_size = test_size + self.shuffle = shuffle + self.random_state = random_state + self.stratify = stratify + + def split( + self, + X: Union[pd.DataFrame, np.ndarray], + y: Optional[Union[pd.Series, np.ndarray]] = None, + groups: Optional[Sequence] = None, + ): + """Generate indices for a single hold-out split.""" + idx = np.arange(len(X)) + if self.stratify is True: + strat = y + elif self.stratify is False: + strat = None + else: + strat = self.stratify + + train_idx, test_idx = train_test_split( + idx, + test_size=self.test_size, + shuffle=self.shuffle, + random_state=self.random_state if self.shuffle else None, + stratify=strat, + ) + yield train_idx, test_idx + + def get_n_splits( + self, + X: Any = None, + y: Any = None, + groups: Any = None, + ) -> int: + """Return the number of splits (always 1).""" + return 1 + + def __sklearn_tags__(self) -> Dict[str, Any]: + """Sklearn 1.6+ compatibility for estimator tags.""" + return {"non_deterministic": self.shuffle} + + def _get_tags(self) -> Dict[str, Any]: + """Legacy sklearn tag support.""" + return self.__sklearn_tags__() + + def __repr__(self) -> str: + return f"SimpleSplit(test_size={self.test_size}, shuffle={self.shuffle})" + + +def get_cv_splitter( + config: CVConfig, + groups: Optional[Sequence[Any]] = None, + y: Optional[Sequence[Any]] = None, + task: Optional[MetricTask] = None, + require_groups: bool = True, +) -> BaseCrossValidator: + """ + Create a scikit-learn cross-validator from ``CVConfig``. + + This factory handles the mapping from strategy names to scikit-learn + splitter objects and performs validation to ensure the strategy is + scientifically sound for the given task (e.g., preventing + stratification for regression). + + Parameters + ---------- + config : CVConfig + The cross-validation configuration. + groups : Sequence, optional + Grouping labels for the samples. Required for group-based strategies. + y : Sequence, optional + Target labels. Used for stratified strategies. + task : MetricTask, optional + The task type ('classification' or 'regression'). Used to validate + stratification requests. + require_groups : bool, default=True + Whether to raise an error if groups are missing for group-based strategies. + + Returns + ------- + splitter : BaseCrossValidator + A configured scikit-learn compatible splitter. + + Raises + ------ + ValueError + If the strategy is unknown, if groups are missing for a group strategy, + or if stratification is requested for a regression task. + + Examples + -------- + >>> from coco_pipe.decoding.configs import CVConfig + >>> cfg = CVConfig(strategy="stratified", n_splits=5) + >>> splitter = get_cv_splitter(cfg) + + See Also + -------- + SimpleSplit : One-shot holdout splitter. + _CVWithGroups : Wrapper for binding groups to splitters. + """ + strat = config.strategy.lower() + + # 1. Scientific Validation + if strat in GROUP_CV_STRATEGIES and require_groups and groups is None: + raise ValueError(f"CV strategy '{config.strategy}' requires groups.") + + if task == "regression" and "stratified" in strat: + raise ValueError( + f"Stratified CV strategy '{config.strategy}' is not supported for " + "regression tasks. Use 'kfold' or 'group_kfold' instead." + ) + + common_kwargs = {} + if strat not in ["leave_one_group_out", "leave_p_out", "split", "timeseries"]: + common_kwargs["n_splits"] = config.n_splits + + if strat in ["stratified", "kfold", "stratified_group_kfold", "split"]: + common_kwargs["shuffle"] = config.shuffle + common_kwargs["random_state"] = config.random_state if config.shuffle else None + + # 2. Factory Logic + if strat == "stratified": + splitter = StratifiedKFold(**common_kwargs) + elif strat == "kfold": + splitter = KFold(**common_kwargs) + elif strat == "group_kfold": + splitter = GroupKFold(n_splits=config.n_splits) + elif strat == "stratified_group_kfold": + splitter = StratifiedGroupKFold(**common_kwargs) + elif strat == "leave_p_out": + splitter = LeavePGroupsOut(n_groups=config.n_splits) + elif strat == "leave_one_group_out": + splitter = LeaveOneGroupOut() + elif strat == "group_shuffle_split": + splitter = GroupShuffleSplit( + n_splits=config.n_splits, + test_size=config.test_size, + random_state=config.random_state, + ) + elif strat == "timeseries": + splitter = TimeSeriesSplit(n_splits=config.n_splits) + elif strat == "split": + splitter = SimpleSplit( + test_size=config.test_size, + shuffle=config.shuffle, + random_state=config.random_state, + stratify=y if config.stratify and y is not None else config.stratify, + ) + else: + raise ValueError(f"Unknown CV strategy: {config.strategy}") + + # 3. Contextual Wrapping + if groups is not None: + if strat not in GROUP_CV_STRATEGIES: + import warnings + + warnings.warn( + f"CV groups were provided but the strategy '{config.strategy}' is not " + "group-aware. Groups will be bound for technical compatibility but " + "ignored during splitting logic.", + UserWarning, + ) + splitter = _CVWithGroups(splitter, groups) + + return splitter diff --git a/coco_pipe/decoding/configs.py b/coco_pipe/decoding/configs.py index df982a4..1354a08 100644 --- a/coco_pipe/decoding/configs.py +++ b/coco_pipe/decoding/configs.py @@ -3,16 +3,18 @@ ======================= Comprehensive Pydantic models for strict validation of Decoding/ML experiments. - -Key Components: -- ModelConfigs: extensive hyperparameters for each estimator. -- ExperimentConfig: Top-level configuration for the entire analysis workflow. +These models ensure that all parameters are scientifically sound before +any computation begins. """ +from __future__ import annotations + from pathlib import Path from typing import Annotated, Any, Callable, Dict, List, Literal, Optional, Union -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ._constants import MetricTask # --- Base Schemas --- @@ -21,9 +23,12 @@ class BaseEstimatorConfig(BaseModel): """Base configuration for any estimator.""" model_config = ConfigDict(extra="forbid") - random_state: Optional[int] = Field( - 42, description="Random seed for reproducibility." - ) + + +class ClassicalEstimatorConfig(BaseEstimatorConfig): + """Base for scikit-learn compatible classical estimators.""" + + kind: Literal["classical"] = "classical" # --- Mixins --- @@ -37,14 +42,17 @@ class LinearMixin(BaseModel): n_jobs: Optional[int] = None -class RegularizedLinearMixin(LinearMixin): +class RegularizedLinearMixin(BaseModel): """Parameters for regularized linear models.""" + fit_intercept: bool = True + copy_X: bool = True tol: float = 1e-3 max_iter: Optional[int] = None - solver: str = "auto" - warm_start: bool = False positive: bool = False + random_state: Optional[int] = Field( + 42, description="Random seed for reproducibility." + ) class TreeMixin(BaseModel): @@ -60,6 +68,9 @@ class TreeMixin(BaseModel): min_impurity_decrease: float = 0.0 ccp_alpha: float = 0.0 n_jobs: Optional[int] = None + random_state: Optional[int] = Field( + 42, description="Random seed for reproducibility." + ) verbose: int = 0 warm_start: bool = False @@ -80,8 +91,10 @@ class SupportVectorMixin(BaseModel): class SGDMixin(BaseModel): + """Stochastic Gradient Descent parameters.""" + loss: str = "hinge" - penalty: Literal["l2", "l1", "elasticnet", "null"] = "l2" + penalty: Optional[Literal["l2", "l1", "elasticnet"]] = "l2" alpha: float = 0.0001 l1_ratio: float = 0.15 fit_intercept: bool = True @@ -90,23 +103,76 @@ class SGDMixin(BaseModel): shuffle: bool = True verbose: int = 0 epsilon: float = 0.1 - n_jobs: Optional[int] = None learning_rate: str = "optimal" - eta0: float = 0.0 + eta0: float = 0.01 power_t: float = 0.5 early_stopping: bool = False + + +class MLPMixin(BaseModel): + """Common parameters for Multi-layer Perceptron models.""" + + hidden_layer_sizes: tuple = (100,) + activation: Literal["identity", "logistic", "tanh", "relu"] = "relu" + solver: Literal["lbfgs", "sgd", "adam"] = "adam" + alpha: float = 0.0001 + batch_size: Union[int, str] = "auto" + learning_rate: Literal["constant", "invscaling", "adaptive"] = "constant" + learning_rate_init: float = 0.001 + power_t: float = 0.5 + max_iter: int = 200 + shuffle: bool = True + tol: float = 1e-4 + verbose: bool = False + warm_start: bool = False + momentum: float = 0.9 + nesterovs_momentum: bool = True + early_stopping: bool = False validation_fraction: float = 0.1 - n_iter_no_change: int = 5 + beta_1: float = 0.9 + beta_2: float = 0.999 + epsilon: float = 1e-8 + n_iter_no_change: int = 10 + max_fun: int = 15000 + random_state: Optional[int] = Field( + 42, description="Random seed for reproducibility." + ) + + +class GradientBoostingMixin(BaseModel): + """Common parameters for Gradient Boosting models.""" + + learning_rate: float = 0.1 + n_estimators: int = 100 + subsample: float = 1.0 + criterion: Literal["friedman_mse", "squared_error"] = "friedman_mse" + min_samples_split: Union[int, float] = 2 + min_samples_leaf: Union[int, float] = 1 + min_weight_fraction_leaf: float = 0.0 + max_depth: int = 3 + min_impurity_decrease: float = 0.0 + init: Optional[str] = None + max_features: Union[str, int, float, None] = None + verbose: int = 0 + max_leaf_nodes: Optional[int] = None warm_start: bool = False - average: bool = False + validation_fraction: float = 0.1 + n_iter_no_change: int = 5 + tol: float = 1e-4 + ccp_alpha: float = 0.0 + random_state: Optional[int] = Field( + 42, description="Random seed for reproducibility." + ) # --- Classifiers --- -class LogisticRegressionConfig(BaseEstimatorConfig): +class LogisticRegressionConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.linear_model.LogisticRegression.""" + method: Literal["LogisticRegression"] = "LogisticRegression" - penalty: Literal["l1", "l2", "elasticnet", "none", None] = "l2" + penalty: Literal["l1", "l2", "elasticnet", None] = "l2" dual: bool = False tol: float = 1e-4 C: float = Field(1.0, gt=0.0) @@ -115,14 +181,18 @@ class LogisticRegressionConfig(BaseEstimatorConfig): class_weight: Optional[Union[Dict, str]] = None solver: Literal["newton-cg", "lbfgs", "liblinear", "sag", "saga"] = "lbfgs" max_iter: int = 100 - multiclass: Literal["auto", "ovr", "multinomial"] = "auto" verbose: int = 0 warm_start: bool = False n_jobs: Optional[int] = None l1_ratio: Optional[float] = None + random_state: Optional[int] = Field( + 42, description="Random seed for reproducibility." + ) + +class RandomForestClassifierConfig(ClassicalEstimatorConfig, TreeMixin): + """Configuration for sklearn.ensemble.RandomForestClassifier.""" -class RandomForestClassifierConfig(BaseEstimatorConfig, TreeMixin): method: Literal["RandomForestClassifier"] = "RandomForestClassifier" criterion: Literal["gini", "entropy", "log_loss"] = "gini" bootstrap: bool = True @@ -131,15 +201,42 @@ class RandomForestClassifierConfig(BaseEstimatorConfig, TreeMixin): max_samples: Optional[Union[int, float]] = None -class SVCConfig(BaseEstimatorConfig, SupportVectorMixin): +class SVCConfig(ClassicalEstimatorConfig, SupportVectorMixin): + """Configuration for sklearn.svm.SVC.""" + method: Literal["SVC"] = "SVC" probability: bool = True # Default to True for metrics requiring proba class_weight: Optional[Union[Dict, str]] = None decision_function_shape: Literal["ovo", "ovr"] = "ovr" break_ties: bool = False + random_state: Optional[int] = Field( + 42, description="Random seed for reproducibility." + ) + + +class LinearSVCConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.svm.LinearSVC.""" + + method: Literal["LinearSVC"] = "LinearSVC" + penalty: Literal["l1", "l2"] = "l2" + loss: Literal["hinge", "squared_hinge"] = "squared_hinge" + dual: Union[bool, Literal["auto"]] = "auto" + tol: float = 1e-4 + C: float = Field(1.0, gt=0.0) + multi_class: Literal["ovr", "crammer_singer"] = "ovr" + fit_intercept: bool = True + intercept_scaling: float = 1.0 + class_weight: Optional[Union[Dict, str]] = None + verbose: int = 0 + max_iter: int = 1000 + random_state: Optional[int] = Field( + 42, description="Random seed for reproducibility." + ) + +class KNeighborsClassifierConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.neighbors.KNeighborsClassifier.""" -class KNeighborsClassifierConfig(BaseEstimatorConfig): method: Literal["KNeighborsClassifier"] = "KNeighborsClassifier" n_neighbors: int = Field(5, ge=1) weights: Literal["uniform", "distance"] = "uniform" @@ -151,67 +248,61 @@ class KNeighborsClassifierConfig(BaseEstimatorConfig): n_jobs: Optional[int] = None -class GradientBoostingClassifierConfig(BaseEstimatorConfig): +class GradientBoostingClassifierConfig(ClassicalEstimatorConfig, GradientBoostingMixin): + """Configuration for sklearn.ensemble.GradientBoostingClassifier.""" + method: Literal["GradientBoostingClassifier"] = "GradientBoostingClassifier" loss: Literal["log_loss", "exponential"] = "log_loss" + + +class HistGradientBoostingClassifierConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.ensemble.HistGradientBoostingClassifier.""" + + method: Literal["HistGradientBoostingClassifier"] = "HistGradientBoostingClassifier" learning_rate: float = 0.1 - n_estimators: int = 100 - subsample: float = 1.0 - criterion: Literal["friedman_mse", "squared_error"] = "friedman_mse" - min_samples_split: Union[int, float] = 2 - min_samples_leaf: Union[int, float] = 1 - min_weight_fraction_leaf: float = 0.0 - max_depth: int = 3 - min_impurity_decrease: float = 0.0 - init: Optional[str] = None - max_features: Union[str, int, float, None] = None - verbose: int = 0 - max_leaf_nodes: Optional[int] = None + max_iter: int = 100 + max_leaf_nodes: int = 31 + max_depth: Optional[int] = None + min_samples_leaf: int = 20 + l2_regularization: float = 0.0 + max_bins: int = 255 + categorical_features: Optional[Union[List[int], List[str], List[bool]]] = None + monotonic_cst: Optional[Any] = None + interaction_cst: Optional[Any] = None warm_start: bool = False + early_stopping: Union[bool, Literal["auto"]] = "auto" + scoring: Optional[str] = "loss" validation_fraction: float = 0.1 - n_iter_no_change: Optional[int] = None - tol: float = 1e-4 - ccp_alpha: float = 0.0 + n_iter_no_change: int = 10 + tol: float = 1e-7 + verbose: int = 0 + random_state: Optional[int] = None + +class SGDClassifierConfig(ClassicalEstimatorConfig, SGDMixin): + """Configuration for sklearn.linear_model.SGDClassifier.""" -class SGDClassifierConfig(BaseEstimatorConfig, SGDMixin): method: Literal["SGDClassifier"] = "SGDClassifier" class_weight: Optional[Union[Dict, str]] = None -class MLPClassifierConfig(BaseEstimatorConfig): +class MLPClassifierConfig(ClassicalEstimatorConfig, MLPMixin): + """Configuration for sklearn.neural_network.MLPClassifier.""" + method: Literal["MLPClassifier"] = "MLPClassifier" - hidden_layer_sizes: tuple = (100,) - activation: Literal["identity", "logistic", "tanh", "relu"] = "relu" - solver: Literal["lbfgs", "sgd", "adam"] = "adam" - alpha: float = 0.0001 - batch_size: Union[int, str] = "auto" - learning_rate: Literal["constant", "invscaling", "adaptive"] = "constant" - learning_rate_init: float = 0.001 - power_t: float = 0.5 - max_iter: int = 200 - shuffle: bool = True - tol: float = 1e-4 - verbose: bool = False - warm_start: bool = False - momentum: float = 0.9 - nesterovs_momentum: bool = True - early_stopping: bool = False - validation_fraction: float = 0.1 - beta_1: float = 0.9 - beta_2: float = 0.999 - epsilon: float = 1e-8 - n_iter_no_change: int = 10 - max_fun: int = 15000 -class GaussianNBConfig(BaseEstimatorConfig): +class GaussianNBConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.naive_bayes.GaussianNB.""" + method: Literal["GaussianNB"] = "GaussianNB" priors: Optional[List[float]] = None var_smoothing: float = 1e-9 -class LDAConfig(BaseEstimatorConfig): +class LDAConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.discriminant_analysis.LinearDiscriminantAnalysis.""" + method: Literal["LinearDiscriminantAnalysis"] = "LinearDiscriminantAnalysis" solver: Literal["svd", "lsqr", "eigen"] = "svd" shrinkage: Optional[Union[str, float]] = None @@ -221,17 +312,26 @@ class LDAConfig(BaseEstimatorConfig): tol: float = 1e-4 -class AdaBoostClassifierConfig(BaseEstimatorConfig): +class AdaBoostClassifierConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.ensemble.AdaBoostClassifier.""" + method: Literal["AdaBoostClassifier"] = "AdaBoostClassifier" n_estimators: int = 50 learning_rate: float = 1.0 - algorithm: Literal["SAMME", "SAMME.R"] = "SAMME.R" + random_state: Optional[int] = Field( + 42, description="Random seed for reproducibility." + ) + +class DummyClassifierConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.dummy.DummyClassifier.""" -class DummyClassifierConfig(BaseEstimatorConfig): method: Literal["DummyClassifier"] = "DummyClassifier" strategy: Literal["stratified", "most_frequent", "prior", "uniform"] = "prior" constant: Optional[Any] = None + random_state: Optional[int] = Field( + 42, description="Random seed for reproducibility." + ) # --- Deep Learning / Foundation Models --- @@ -273,12 +373,12 @@ class SkorchClassifierConfig(BaseEstimatorConfig): class SlidingEstimatorConfig(BaseEstimatorConfig): """ - Configuration for MNE's SlidingEstimator. + Configuration for MNE-style SlidingEstimator. Fits a separate estimator for each time point. """ method: Literal["SlidingEstimator"] = "SlidingEstimator" - base_estimator: "EstimatorConfigType" + base_estimator: EstimatorConfigType scoring: Optional[Union[str, Callable]] = None n_jobs: Optional[int] = 1 position: Optional[float] = 0 @@ -288,12 +388,12 @@ class SlidingEstimatorConfig(BaseEstimatorConfig): class GeneralizingEstimatorConfig(BaseEstimatorConfig): """ - Configuration for MNE's GeneralizingEstimator. + Configuration for MNE-style GeneralizingEstimator. Fits an estimator on each time point and tests on all other time points. """ method: Literal["GeneralizingEstimator"] = "GeneralizingEstimator" - base_estimator: "EstimatorConfigType" + base_estimator: EstimatorConfigType scoring: Optional[Union[str, Callable]] = None n_jobs: Optional[int] = 1 position: Optional[float] = 0 @@ -304,28 +404,38 @@ class GeneralizingEstimatorConfig(BaseEstimatorConfig): # --- Regressors --- -class LinearRegressionConfig(BaseEstimatorConfig, LinearMixin): +class LinearRegressionConfig(ClassicalEstimatorConfig, LinearMixin): + """Configuration for sklearn.linear_model.LinearRegression.""" + method: Literal["LinearRegression"] = "LinearRegression" positive: bool = False -class RidgeConfig(BaseEstimatorConfig, RegularizedLinearMixin): +class RidgeConfig(ClassicalEstimatorConfig, RegularizedLinearMixin): + """Configuration for sklearn.linear_model.Ridge.""" + method: Literal["Ridge"] = "Ridge" alpha: float = 1.0 fit_intercept: bool = True copy_X: bool = True + solver: str = "auto" + +class LassoConfig(ClassicalEstimatorConfig, RegularizedLinearMixin): + """Configuration for sklearn.linear_model.Lasso.""" -class LassoConfig(BaseEstimatorConfig, RegularizedLinearMixin): method: Literal["Lasso"] = "Lasso" alpha: float = 1.0 precompute: Union[bool, List] = False fit_intercept: bool = True copy_X: bool = True selection: Literal["cyclic", "random"] = "cyclic" + warm_start: bool = False -class ElasticNetConfig(BaseEstimatorConfig, RegularizedLinearMixin): +class ElasticNetConfig(ClassicalEstimatorConfig, RegularizedLinearMixin): + """Configuration for sklearn.linear_model.ElasticNet.""" + method: Literal["ElasticNet"] = "ElasticNet" alpha: float = 1.0 l1_ratio: float = 0.5 @@ -333,9 +443,12 @@ class ElasticNetConfig(BaseEstimatorConfig, RegularizedLinearMixin): fit_intercept: bool = True copy_X: bool = True selection: Literal["cyclic", "random"] = "cyclic" + warm_start: bool = False -class RandomForestRegressorConfig(BaseEstimatorConfig, TreeMixin): +class RandomForestRegressorConfig(ClassicalEstimatorConfig, TreeMixin): + """Configuration for sklearn.ensemble.RandomForestRegressor.""" + method: Literal["RandomForestRegressor"] = "RandomForestRegressor" criterion: Literal["squared_error", "absolute_error", "friedman_mse", "poisson"] = ( "squared_error" @@ -345,75 +458,48 @@ class RandomForestRegressorConfig(BaseEstimatorConfig, TreeMixin): max_samples: Optional[Union[int, float]] = None -class SVRConfig(BaseEstimatorConfig, SupportVectorMixin): +class SVRConfig(ClassicalEstimatorConfig, SupportVectorMixin): + """Configuration for sklearn.svm.SVR.""" + method: Literal["SVR"] = "SVR" epsilon: float = 0.1 -class GradientBoostingRegressorConfig(BaseEstimatorConfig): +class GradientBoostingRegressorConfig(ClassicalEstimatorConfig, GradientBoostingMixin): + """Configuration for sklearn.ensemble.GradientBoostingRegressor.""" + method: Literal["GradientBoostingRegressor"] = "GradientBoostingRegressor" loss: Literal["squared_error", "absolute_error", "huber", "quantile"] = ( "squared_error" ) - learning_rate: float = 0.1 - n_estimators: int = 100 - subsample: float = 1.0 - criterion: Literal["friedman_mse", "squared_error"] = "friedman_mse" - min_samples_split: Union[int, float] = 2 - min_samples_leaf: Union[int, float] = 1 - min_weight_fraction_leaf: float = 0.0 - max_depth: int = 3 - min_impurity_decrease: float = 0.0 - init: Optional[str] = None - max_features: Union[str, int, float, None] = None alpha: float = 0.9 - verbose: int = 0 - max_leaf_nodes: Optional[int] = None - warm_start: bool = False - validation_fraction: float = 0.1 - n_iter_no_change: Optional[int] = None - tol: float = 1e-4 - ccp_alpha: float = 0.0 -class SGDRegressorConfig(BaseEstimatorConfig, SGDMixin): +class SGDRegressorConfig(ClassicalEstimatorConfig, SGDMixin): + """Configuration for sklearn.linear_model.SGDRegressor.""" + method: Literal["SGDRegressor"] = "SGDRegressor" loss: str = "squared_error" -class MLPRegressorConfig(BaseEstimatorConfig): +class MLPRegressorConfig(ClassicalEstimatorConfig, MLPMixin): + """Configuration for sklearn.neural_network.MLPRegressor.""" + method: Literal["MLPRegressor"] = "MLPRegressor" - hidden_layer_sizes: tuple = (100,) - activation: Literal["identity", "logistic", "tanh", "relu"] = "relu" - alpha: float = 0.0001 - batch_size: Union[int, str] = "auto" - learning_rate: Literal["constant", "invscaling", "adaptive"] = "constant" - learning_rate_init: float = 0.001 - power_t: float = 0.5 - max_iter: int = 200 - shuffle: bool = True - tol: float = 1e-4 - verbose: bool = False - warm_start: bool = False - momentum: float = 0.9 - nesterovs_momentum: bool = True - early_stopping: bool = False - validation_fraction: float = 0.1 - beta_1: float = 0.9 - beta_2: float = 0.999 - epsilon: float = 1e-8 - n_iter_no_change: int = 10 - max_fun: int = 15000 -class DummyRegressorConfig(BaseEstimatorConfig): +class DummyRegressorConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.dummy.DummyRegressor.""" + method: Literal["DummyRegressor"] = "DummyRegressor" strategy: Literal["mean", "median", "quantile", "constant"] = "mean" constant: Optional[Union[int, float, List]] = None quantile: Optional[float] = None -class DecisionTreeRegressorConfig(BaseEstimatorConfig): +class DecisionTreeRegressorConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.tree.DecisionTreeRegressor.""" + method: Literal["DecisionTreeRegressor"] = "DecisionTreeRegressor" criterion: Literal["squared_error", "friedman_mse", "absolute_error", "poisson"] = ( "squared_error" @@ -430,7 +516,9 @@ class DecisionTreeRegressorConfig(BaseEstimatorConfig): ccp_alpha: float = 0.0 -class KNeighborsRegressorConfig(BaseEstimatorConfig): +class KNeighborsRegressorConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.neighbors.KNeighborsRegressor.""" + method: Literal["KNeighborsRegressor"] = "KNeighborsRegressor" n_neighbors: int = Field(5, ge=1) weights: Literal["uniform", "distance"] = "uniform" @@ -442,14 +530,18 @@ class KNeighborsRegressorConfig(BaseEstimatorConfig): n_jobs: Optional[int] = None -class ExtraTreesRegressorConfig(BaseEstimatorConfig, TreeMixin): +class ExtraTreesRegressorConfig(ClassicalEstimatorConfig, TreeMixin): + """Configuration for sklearn.ensemble.ExtraTreesRegressor.""" + method: Literal["ExtraTreesRegressor"] = "ExtraTreesRegressor" bootstrap: bool = False oob_score: bool = False max_samples: Optional[Union[int, float]] = None -class HistGradientBoostingRegressorConfig(BaseEstimatorConfig): +class HistGradientBoostingRegressorConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.ensemble.HistGradientBoostingRegressor.""" + method: Literal["HistGradientBoostingRegressor"] = "HistGradientBoostingRegressor" loss: Literal["squared_error", "absolute_error", "poisson", "quantile"] = ( "squared_error" @@ -465,7 +557,7 @@ class HistGradientBoostingRegressorConfig(BaseEstimatorConfig): monotonic_cst: Optional[Any] = None interaction_cst: Optional[Any] = None warm_start: bool = False - early_stopping: str = "auto" + early_stopping: Union[bool, Literal["auto"]] = "auto" scoring: Optional[str] = "loss" validation_fraction: float = 0.1 n_iter_no_change: int = 10 @@ -474,16 +566,23 @@ class HistGradientBoostingRegressorConfig(BaseEstimatorConfig): random_state: Optional[int] = None -class AdaBoostRegressorConfig(BaseEstimatorConfig): +class AdaBoostRegressorConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.ensemble.AdaBoostRegressor.""" + method: Literal["AdaBoostRegressor"] = "AdaBoostRegressor" n_estimators: int = 50 learning_rate: float = 1.0 loss: Literal["linear", "square", "exponential"] = "linear" + random_state: Optional[int] = Field( + 42, description="Random seed for reproducibility." + ) + +class BayesianRidgeConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.linear_model.BayesianRidge.""" -class BayesianRidgeConfig(BaseEstimatorConfig): method: Literal["BayesianRidge"] = "BayesianRidge" - n_iter: int = 300 + max_iter: int = 300 tol: float = 1e-3 alpha_1: float = 1e-6 alpha_2: float = 1e-6 @@ -497,9 +596,11 @@ class BayesianRidgeConfig(BaseEstimatorConfig): verbose: bool = False -class ARDRegressionConfig(BaseEstimatorConfig): +class ARDRegressionConfig(ClassicalEstimatorConfig): + """Configuration for sklearn.linear_model.ARDRegression.""" + method: Literal["ARDRegression"] = "ARDRegression" - n_iter: int = 300 + max_iter: int = 300 tol: float = 1e-3 alpha_1: float = 1e-6 alpha_2: float = 1e-6 @@ -515,21 +616,145 @@ class ARDRegressionConfig(BaseEstimatorConfig): # --- Recursive Unions Implementation --- -# 1. Define Atomic Unions (Terminal nodes) +class ClassicalModelConfig(ClassicalEstimatorConfig): + """Configuration for standard scikit-learn estimators.""" + + method: Literal["ClassicalModel"] = "ClassicalModel" + estimator: str + params: Dict[str, Any] = Field(default_factory=dict) + input_kind: Literal["tabular", "embeddings"] = "tabular" + + +class FoundationEmbeddingModelConfig(BaseEstimatorConfig): + """Configuration for pretrained feature extraction backbones.""" + + kind: Literal["foundation_embedding"] = "foundation_embedding" + provider: Literal["dummy", "braindecode", "huggingface", "reve"] = "dummy" + model_name: str = "dummy" + input_kind: Literal["tabular", "temporal", "epoched", "embeddings", "tokens"] = ( + "epoched" + ) + pooling: Literal["mean", "flatten", "last"] = "mean" + normalize_embeddings: bool = True + cache_embeddings: bool = True + embedding_dim: Optional[int] = None + + +class LoRAConfig(BaseModel): + """Low-Rank Adaptation (LoRA) configuration.""" + + model_config = ConfigDict(extra="forbid") + + r: int = Field(16, ge=1) + alpha: int = Field(32, ge=1) + dropout: float = Field(0.0, ge=0.0, le=1.0) + target_modules: Union[str, List[str]] = "all-linear" + + +class QuantizationConfig(BaseModel): + """Model quantization settings.""" + + model_config = ConfigDict(extra="forbid") + + enabled: bool = False + load_in_4bit: bool = True + quant_type: Literal["nf4", "fp4"] = "nf4" + compute_dtype: Literal["bf16", "fp16", "fp32"] = "bf16" + + +class DeviceConfig(BaseModel): + """Compute device and precision policy.""" + + model_config = ConfigDict(extra="forbid") + + device: Literal["auto", "cpu", "cuda", "mps"] = "auto" + precision: Literal["fp32", "fp16", "bf16"] = "fp32" + + +class CheckpointConfig(BaseModel): + """Neural training checkpoint policy.""" + + model_config = ConfigDict(extra="forbid") + + save: Literal["none", "best", "last", "all"] = "best" + monitor: str = "val_loss" + output_dir: Optional[Path] = None + + +class TrainerConfig(BaseModel): + """Neural training loop configuration.""" + + model_config = ConfigDict(extra="forbid") + + max_epochs: int = Field(10, ge=1) + early_stopping_patience: Optional[int] = Field(None, ge=1) + batch_size: int = Field(32, ge=1) + validation_fraction: float = Field(0.2, ge=0.0, lt=1.0) + + +class TrainStageConfig(BaseModel): + """Single stage in a multi-stage training schedule.""" + + model_config = ConfigDict(extra="forbid") + + name: str + epochs: int = Field(..., ge=1) + train_backbone: bool = False + train_head: bool = True + + +class FrozenBackboneDecoderConfig(BaseEstimatorConfig): + """Config for a frozen backbone followed by a classical decoding head.""" + + kind: Literal["frozen_backbone"] = "frozen_backbone" + backbone: FoundationEmbeddingModelConfig + head: ClassicalModelConfig + + +class NeuralFineTuneConfig(BaseEstimatorConfig): + """Configuration for end-to-end neural fine-tuning.""" + + kind: Literal["neural_finetune"] = "neural_finetune" + provider: Literal["dummy", "braindecode", "huggingface"] = "dummy" + model_name: str = "dummy" + input_kind: Literal["temporal", "epoched", "tokens"] = "epoched" + train_mode: Literal["full", "frozen", "linear_probe", "lora", "qlora"] = "full" + optimizer: Dict[str, Any] = Field(default_factory=lambda: {"name": "adamw"}) + trainer: TrainerConfig = Field(default_factory=TrainerConfig) + device: DeviceConfig = Field(default_factory=DeviceConfig) + checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig) + lora: Optional[LoRAConfig] = None + quantization: Optional[QuantizationConfig] = None + stages: List[TrainStageConfig] = Field(default_factory=list) + + +class TemporalDecoderConfig(BaseEstimatorConfig): + """Config for MNE-style temporal meta-estimators.""" + + kind: Literal["temporal"] = "temporal" + wrapper: Literal["sliding", "generalizing"] = "sliding" + base: ClassicalModelConfig + scoring: Optional[Union[str, Callable]] = None + n_jobs: Optional[int] = 1 + position: Optional[float] = 0 + allow_2d: bool = False + verbose: Optional[Union[bool, str, int]] = None + + AtomicEstimator = Union[ LogisticRegressionConfig, RandomForestClassifierConfig, SVCConfig, + LinearSVCConfig, KNeighborsClassifierConfig, GradientBoostingClassifierConfig, + HistGradientBoostingClassifierConfig, SGDClassifierConfig, MLPClassifierConfig, GaussianNBConfig, LDAConfig, AdaBoostClassifierConfig, DummyClassifierConfig, - LPFTConfig, - SkorchClassifierConfig, # Regressors LinearRegressionConfig, RidgeConfig, @@ -550,26 +775,44 @@ class ARDRegressionConfig(BaseEstimatorConfig): ARDRegressionConfig, ] -# 2. Define the Recursive Union using Annotated Discriminator -# This allows Pydantic to choose the correct class based on 'method' field EstimatorConfigType = Annotated[ Union[AtomicEstimator, SlidingEstimatorConfig, GeneralizingEstimatorConfig], Field(discriminator="method"), ] +ClassicalModelType = Annotated[ + Union[ClassicalModelConfig, AtomicEstimator], Field(discriminator="method") +] -# --- Experiment Config --- - - -class TemporalConfig(BaseModel): - """Configuration for temporal decoding (Sliding/Generalizing).""" - enabled: bool = False - window_interaction: Literal["sliding", "generalizing"] = "sliding" +# --- Experiment Config --- class CVConfig(BaseModel): - """Cross-validation settings.""" + """ + Cross-validation configuration settings. + + Parameters + ---------- + strategy : str, default="stratified" + The splitting strategy. Note that 'stratified' strategies require + classification tasks. + n_splits : int, default=5 + Number of folds. Must be at least 2. + shuffle : bool, default=True + Whether to shuffle data before splitting. + random_state : int, default=42 + Random seed for the splitter. + test_size : float, default=0.2 + The proportion of the dataset to include in the test split for + strategy='split'. + stratify : bool, default=False + Whether to use stratified sampling for strategy='split'. + group_key : str, optional + The column name in sample_metadata to use for group-aware strategies. + """ + + model_config = ConfigDict(extra="forbid") strategy: Literal[ "stratified", @@ -577,51 +820,185 @@ class CVConfig(BaseModel): "group_kfold", "stratified_group_kfold", "leave_p_out", - "leave_one_out", + "leave_one_group_out", "timeseries", "split", + "group_shuffle_split", ] = "stratified" - n_splits: int = Field(5, ge=2) + n_splits: int = Field(5, ge=1) shuffle: bool = True random_state: int = 42 + test_size: float = Field( + 0.2, gt=0.0, lt=1.0, description="Holdout size for strategy='split'." + ) + stratify: bool = Field( + False, description="Whether strategy='split' should stratify by y." + ) + group_key: Optional[str] = Field( + None, description="sample_metadata column used by grouped CV strategies." + ) class TuningConfig(BaseModel): """ Hyperparameter Tuning Configuration. - Use this to define HOW to search (random vs grid). - The WHAT (the grid itself) is passed in ExperimentConfig.grids. """ enabled: bool = False search_type: Literal["grid", "random"] = "grid" - n_iter: int = Field(10, description="Number of iterations for random search") + n_iter: int = Field(10, ge=1, description="Number of iterations for random search") scoring: Optional[str] = None # Metric to optimize (defaults to first in list) n_jobs: int = -1 + random_state: Optional[int] = Field( + 42, description="Random seed used by RandomizedSearchCV." + ) + cv: Optional[CVConfig] = Field( + None, + description=( + "Inner CV used for model selection. Defaults to the outer CV family." + ), + ) + allow_nongroup_inner_cv: bool = Field( + False, + description=( + "Allow a non-grouped tuning CV under grouped outer CV. This explicitly " + "acknowledges the leakage/generalization trade-off." + ), + ) class FeatureSelectionConfig(BaseModel): - """Configuration for Sequential Feature Selection.""" + """Feature selection settings.""" enabled: bool = False method: Literal["k_best", "sfs"] = "sfs" - n_features: Optional[int] = Field(None, description="Number of features to select.") + n_features: Optional[int] = Field(None, gt=0, description="Number of features.") direction: Literal["forward", "backward"] = "forward" - cv: Optional[int] = Field(None, description="Inner CV splits. Required for SFS.") + cv: Optional[CVConfig] = Field( + None, + description=( + "Inner CV used by SequentialFeatureSelector. Defaults to tuning.cv " + "when available, otherwise the outer CV family." + ), + ) scoring: Optional[str] = None + allow_nongroup_inner_cv: bool = Field( + False, + description=( + "Allow a non-grouped SFS CV under grouped outer CV. This explicitly " + "acknowledges the leakage/generalization trade-off." + ), + ) + + +class CalibrationConfig(BaseModel): + """Probability calibration settings for classification estimators.""" + + enabled: bool = False + method: Literal["sigmoid", "isotonic"] = "sigmoid" + cv: Optional[CVConfig] = Field( + None, + description=( + "Inner CV used by CalibratedClassifierCV. Defaults to the outer CV family." + ), + ) + n_jobs: Optional[int] = None + allow_nongroup_inner_cv: bool = Field( + False, + description=("Allow a non-grouped calibration CV under grouped outer CV."), + ) + + +class ConfidenceIntervalConfig(BaseModel): + """Analytical confidence interval settings.""" + + model_config = ConfigDict(extra="forbid") + + alpha: float = Field(0.05, gt=0.0, lt=1.0) + method: Literal["wilson", "clopper_pearson"] = "wilson" + + +class ChanceAssessmentConfig(BaseModel): + """Null-hypothesis (chance level) assessment settings.""" + + model_config = ConfigDict(extra="forbid") + + method: Literal["permutation", "binomial", "auto"] = "permutation" + n_permutations: int = Field(1000, ge=1) + p0: Union[float, Literal["auto"], None] = Field( + "auto", + description="Chance level for binomial test (e.g., 0.5 for binary).", + ) + temporal_correction: Literal["max_stat", "fdr_bh", "none"] = Field( + "max_stat", description="Method to correct for multiple comparisons." + ) + store_null_distribution: bool = False + + +class StatisticalAssessmentConfig(BaseModel): + """ + Settings for finite-sample statistical inference and uncertainty estimation. + """ + + model_config = ConfigDict(extra="forbid") + + enabled: bool = False + random_state: Optional[int] = 42 + metrics: Optional[List[str]] = Field( + None, description="Subset of experiment metrics to run assessment for." + ) + + chance: ChanceAssessmentConfig = Field(default_factory=ChanceAssessmentConfig) + confidence_intervals: ConfidenceIntervalConfig = Field( + default_factory=ConfidenceIntervalConfig + ) + + unit_of_inference: Optional[ + Literal["sample", "group_mean", "group_majority", "custom"] + ] = Field( + None, + description="Independent unit for label permutation or binomial counts.", + ) + custom_unit_column: Optional[str] = None + custom_aggregation: Literal["mean", "majority"] = "mean" + + +ModelConfigType = Annotated[ + Union[ + ClassicalModelType, + FoundationEmbeddingModelConfig, + FrozenBackboneDecoderConfig, + NeuralFineTuneConfig, + TemporalDecoderConfig, + ], + Field(discriminator="kind"), +] class ExperimentConfig(BaseModel): """ Master configuration for a Decoding Experiment. + + This model serves as the single source of truth for an entire analysis, + including data handling, model selection, hyperparameter tuning, + feature selection, and statistical inference. """ - task: Literal["classification", "regression"] = "classification" + model_config = ConfigDict(extra="forbid") + + task: MetricTask = "classification" output_dir: Optional[Path] = None tag: str = "experiment" + random_state: Optional[int] = Field( + None, + description=( + "Master random seed. If set, it is used to derive seeds for all " + "components (CV, Tuning, Models, etc.) ensuring global reproducibility." + ), + ) # Map of Friendly Name -> Polymorphic Config Object - models: Dict[str, EstimatorConfigType] + models: Dict[str, ModelConfigType] # Map of Friendly Name -> Parameter Grid (Search Space) grids: Optional[Dict[str, Dict[str, List[Any]]]] = None @@ -631,16 +1008,97 @@ class ExperimentConfig(BaseModel): feature_selection: FeatureSelectionConfig = Field( default_factory=FeatureSelectionConfig ) + calibration: CalibrationConfig = Field(default_factory=CalibrationConfig) + evaluation: StatisticalAssessmentConfig = Field( + default_factory=StatisticalAssessmentConfig + ) metrics: List[str] = Field( default_factory=lambda: ["accuracy", "roc_auc"], description="List of metrics to compute.", ) - temporal: TemporalConfig = Field(default_factory=TemporalConfig) - use_scaler: bool = Field( True, description="Whether to scalar normalize features upstream." ) n_jobs: int = -1 verbose: bool = True + + def get_all_evaluation_metrics(self) -> list[str]: + """Union of primary experiment metrics and stats-specific metrics.""" + primary = list(self.metrics) + eval_metrics = self.evaluation.metrics or [] + return sorted(set(primary + list(eval_metrics))) + + @model_validator(mode="after") + def _validate_task_consistency(self) -> ExperimentConfig: + """Ensure CV strategies and metrics match the selected task.""" + # 1. Validate Metrics + from ._metrics import get_metric_names, get_metric_spec + + registered_metrics = get_metric_names() + + for metric in self.get_all_evaluation_metrics(): + if metric in registered_metrics: + spec = get_metric_spec(metric) + if spec.task != self.task: + raise ValueError( + f"Metric '{metric}' is for {spec.task} but experiment task " + f"is {self.task}." + ) + # Metrics not in the registry are assumed to be custom callables or + # user-defined strings handled by the execution engine at runtime. + + # 2. Validate CV Strategy + stratified_strategies = {"stratified", "stratified_group_kfold"} + if self.task == "regression" and self.cv.strategy in stratified_strategies: + raise ValueError( + f"CV strategy '{self.cv.strategy}' is not valid for regression." + ) + + # 3. Validate Calibration task + if self.calibration.enabled and self.task != "classification": + raise ValueError("calibration is only available for classification.") + + # 4. Validate Tuning Metrics + if self.tuning.enabled and self.tuning.scoring: + if self.tuning.scoring in registered_metrics: + spec = get_metric_spec(self.tuning.scoring) + if spec.task != self.task: + raise ValueError( + f"Tuning metric '{self.tuning.scoring}' is for {spec.task} " + f"but task is {self.task}." + ) + + # 4. Validate Tuning CV + if self.tuning.enabled and self.tuning.cv is None: + if self.tuning.allow_nongroup_inner_cv: + self.tuning.cv = self.cv + else: + raise ValueError( + "Tuning is enabled but tuning.cv is not defined. " + "You must explicitly define an inner CV strategy to ensure " + "scientific validity and acknowledge computational cost." + ) + + # 5. Validate FS CV (MANDATORY if SFS) + if ( + self.feature_selection.enabled + and self.feature_selection.method == "sfs" + and self.feature_selection.cv is None + ): + if self.feature_selection.allow_nongroup_inner_cv: + self.feature_selection.cv = self.cv + else: + raise ValueError( + "Sequential Feature Selection (SFS) is enabled but " + "feature_selection.cv is not defined." + ) + + # 6. Validate Calibration CV (MANDATORY) + if self.calibration.enabled and self.calibration.cv is None: + raise ValueError( + "Calibration is enabled but calibration.cv is not defined." + ) + + return self diff --git a/coco_pipe/decoding/core.py b/coco_pipe/decoding/core.py deleted file mode 100644 index 1eb3ed8..0000000 --- a/coco_pipe/decoding/core.py +++ /dev/null @@ -1,1080 +0,0 @@ -""" -Decoding Core -============= -This module is responsible for: -1. Orchestrating the Cross-Validation loop. -2. Managing Estimator lifecycles (instantiation, fitting, prediction). -3. Computing metrics dynamically based on task type. -4. Aggregating results for downstream analysis. -""" - -import atexit -import logging -import time -from collections import defaultdict -from datetime import datetime -from pathlib import Path -from shutil import rmtree -from tempfile import mkdtemp -from typing import Any, Dict, Optional, Union - -import joblib -import numpy as np -import pandas as pd -from sklearn.base import BaseEstimator, clone -from sklearn.feature_selection import ( - SelectKBest, - SequentialFeatureSelector, - f_classif, - f_regression, -) -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import StandardScaler -from sklearn.utils.multiclass import type_of_target - -from ..report.provenance import get_package_version -from .configs import ExperimentConfig -from .registry import get_estimator_cls -from .utils import get_cv_splitter, get_scorer - -logger = logging.getLogger(__name__) - - -class Experiment: - """ - Main executor for decoding experiments. - - Parameters - ---------- - config : ExperimentConfig - The complete configuration for the experiment. - """ - - def __init__(self, config: ExperimentConfig): - self.config = config - self.results: Dict[str, Any] = {} - self._validate_config() - - def _validate_config(self): - """ - Perform comprehensive runtime validation of the configuration. - - Logic - ----- - 1. **Tuning Consistency**: Warns if `tuning.enabled` but no `grids` - are provided. - 2. **Task vs Metrics**: Checks if metrics match the task (e.g. no 'accuracy' - for regression). Raises ValueError if incompatible. - 3. **Task vs CV**: Checks if CV strategy matches task (e.g. no 'stratified' - for regression). Raises ValueError if incompatible. - 4. **Task vs Model**: Heuristic check for model type (e.g. no Regressor for - Classification). Raises ValueError if incompatible. - - Raises - ------ - ValueError - If configuration contains incompatible settings. - """ - task = self.config.task - - # 1. Tuning Consistency - if self.config.tuning.enabled and not self.config.grids: - logger.warning( - "Hyperparameter tuning is enabled but no 'grids' are defined in the " - "config." - ) - - # 2. Task vs Metrics - # Define forbidden substrings for each task - forbidden_metrics = { - "classification": ["r2", "squared_error", "absolute_error"], - "regression": [ - "accuracy", - "roc_auc", - "f1", - "precision", - "recall", - "log_loss", - ], - } - - for metric in self.config.metrics: - # Check internal sklearn/scorer names - if any(bad in metric for bad in forbidden_metrics.get(task, [])): - suggestions = ( - forbidden_metrics["regression"] - if task == "classification" - else forbidden_metrics["classification"] - ) - raise ValueError( - f"Metric '{metric}' is incompatible with task '{task}'. " - f"Please choose suitable metrics (e.g., {suggestions}...)" - ) - - # 3. Task vs CV Strategy - if task == "regression": - if "stratified" in self.config.cv.strategy: - raise ValueError( - f"CV strategy '{self.config.cv.strategy}' implies stratification, " - f"which is invalid for regression tasks." - ) - - # 4. Task vs Model Type - # We infer type from the config class name or method string - for name, model_cfg in self.config.models.items(): - method_name = model_cfg.method.lower() - - if task == "classification": - if "regressor" in method_name or "regression" in method_name: - # Exception: LogisticRegression is a classifier - if "logistic" not in method_name: - raise ValueError( - f"Model '{name}' ({model_cfg.method}) appears to be a " - f"Regressor, but task is 'classification'." - ) - - elif task == "regression": - if ( - "classifier" in method_name - or "svc" in method_name - or "logistic" in method_name - ): - # SVR is valid, SVC is not (usually) - raise ValueError( - f"Model '{name}' ({model_cfg.method}) appears to be a " - f"Classifier, but task is 'regression'." - ) - - def _prepare_estimator(self, model_name: str, model_config: Any) -> BaseEstimator: - """ - Orchestrate the creation of the full Estimator Pipeline. - - Steps - ----- - 1. **Instantiation**: Calls `_instantiate_model` to get the base estimator - (handling recursion). - 2. **Scaling**: If `use_scaler=True`, prepends a StandardScaler. - 3. **Feature Selection**: If enabled, prepends the FS step (Filter or Wrapper). - 4. **Pipeline Construction**: wraps steps in `sklearn.pipeline.Pipeline`. - - Enables caching if FS + Tuning are both active. - 5. **Tuning Wrapper**: If tuning is enabled for this model, wraps the Pipeline - in GridSearchCV/RandomizedSearchCV via `_wrap_with_tuning`. - - Parameters - ---------- - model_name : str - Friendly name from config (used for grid lookup). - model_config : Any - Pydantic configuration object for the model. - - Returns - ------- - BaseEstimator - Final ready-to-run estimator (Pipeline or SearchCV). - """ - # 1. Instantiate the Core Estimator - full_est = self._instantiate_model(model_name, model_config) - - # 2. Build Pipeline Steps - steps = [] - - # Scaling - if self.config.use_scaler: - steps.append(("scaler", StandardScaler())) - - # Feature Selection - if self.config.feature_selection.enabled: - fs_step = self._create_fs_step(full_est) - if fs_step: - steps.append(fs_step) - - # Final Estimator - steps.append(("clf", full_est)) - - # 3. Create Pipeline with Caching if needed - if ( - self.config.feature_selection.enabled - and self.config.tuning.enabled - and self.config.grids - ): - cachedir = mkdtemp() - atexit.register(lambda: rmtree(cachedir, ignore_errors=True)) - est = Pipeline(steps, memory=cachedir) - else: - est = Pipeline(steps) - - # 4. Wrap with Tuning if enabled - if ( - self.config.tuning.enabled - and self.config.grids - and model_name in self.config.grids - ): - est = self._wrap_with_tuning(est, model_name) - - return est - - def _instantiate_model(self, name: str, config: Any) -> BaseEstimator: - """ - Instantiate a raw estimator from its configuration object. - - Logic - ----- - 1. **Registry Lookup**: Resolves class from `config.method`. - 2. **Recursion**: If config implies a meta-estimator (has `base_estimator`), - recursively calls `_prepare_estimator` for the child. - 3. **Parameter Injection**: passed config fields as kwargs to `__init__`. - - Automatically filters out invalid parameters if `TypeError` occurs - (robustness for mismatched config/class versions). - - Returns - ------- - BaseEstimator - The instantiated model (e.g., LogisticRegression instance) without pipeline - wrappers. - """ - # 1. Get Class - est_cls = get_estimator_cls(config.method) - - # 2. Extract Params - params = config.model_dump(exclude={"method"}) - - # 3. Recursive Preparation (for Sliding/Generalizing internal 'base_estimator') - if "base_estimator" in params: - base_conf = params["base_estimator"] - params["base_estimator"] = self._prepare_estimator( - f"{name}_base", base_conf - ) - - # 4. Instantiate with Parameter Filtering - try: - return est_cls(**params) - except TypeError: - # Fallback: Filter invalid params (e.g. metadata fields in config) - valid_params = est_cls().get_params().keys() - filtered = {k: v for k, v in params.items() if k in valid_params} - dropped = set(params) - set(filtered) - if dropped: - logger.debug(f"[{name}] Dropping invalid params: {dropped}") - return est_cls(**filtered) - - def _create_fs_step(self, estimator: BaseEstimator) -> Optional[tuple]: - """ - Create a Feature Selection step for the pipeline. - - Logic - ----- - - **Filter (k_best)**: Fast. selected before training the classifier based on - statistical test. No inner CV loop required. - - **Wrapper (sfs)**: Slow but accurate. Wraps the estimator in a - SequentialFeatureSelector. This runs an **Inner CV Loop** - (size = config.feature_selection.cv) to validate feature subsets. - - If used inside Hyperparameter Tuning, this step is part of the Pipeline, - ensuring features are re-selected for every fold and every parameter - combination (Nested Simplification). - - Returns - ------- - tuple or None - ("fs", Transformer) step for sklearn Pipeline. - """ - fs_conf = self.config.feature_selection - - if fs_conf.method == "k_best": - score_func = ( - f_classif if self.config.task == "classification" else f_regression - ) - return ( - "fs", - SelectKBest(score_func=score_func, k=fs_conf.n_features or 10), - ) - - elif fs_conf.method == "sfs": - inner_cv = fs_conf.cv or 3 - return ( - "fs", - SequentialFeatureSelector( - estimator=clone(estimator), - n_features_to_select=fs_conf.n_features, - direction=fs_conf.direction, - cv=inner_cv, - n_jobs=self.config.n_jobs, - ), - ) - return None - - def _wrap_with_tuning(self, estimator: BaseEstimator, name: str) -> BaseEstimator: - """ - Wrap the estimator (or pipeline) in a Hyperparameter Search object. - - This implements **Nested Cross-Validation** (Middle Loop): - 1. **Input**: A Pipeline (Scaler + FS + Classifier). - 2. **Search**: Creates a GridSearchCV / RandomizedSearchCV. - 3. **Process**: - - For each fold of the *tuning* CV (defined by config.cv): - - Train the Pipeline (including FS!) on the tuning train set. - - Evaluate on the tuning validation set. - - Finds the best (Hyperparameters + Features) combination. - - Refits on the entire training set provided by the Outer Loop. - - This ensures simultaneous optimization of Preprocessing (FS) and Modeling - parameters. - """ - from sklearn.model_selection import GridSearchCV, RandomizedSearchCV - - grid = self.config.grids[name] - - new_grid = {} - for k, v in grid.items(): - if "__" in k: - new_grid[k] = v # trusted user input - else: - new_grid[f"clf__{k}"] = v - grid = new_grid - - cv_splitter = get_cv_splitter(self.config.cv) - # Note: We don't pass groups here; they are passed to fit() - - search_kwargs = { - "estimator": estimator, - "param_grid" - if self.config.tuning.search_type == "grid" - else "param_distributions": grid, - "cv": cv_splitter, - "scoring": self.config.tuning.scoring or self.config.metrics[0], - "n_jobs": self.config.tuning.n_jobs, - "refit": True, - } - - if self.config.tuning.search_type == "grid": - return GridSearchCV(**search_kwargs) - else: - return RandomizedSearchCV(n_iter=self.config.tuning.n_iter, **search_kwargs) - - def run( - self, - X: Union[pd.DataFrame, np.ndarray], - y: Union[pd.Series, np.ndarray], - groups: Optional[Union[pd.Series, np.ndarray]] = None, - ) -> "ExperimentResult": - """ - Execute the full experiment pipeline. - - This is the main entry point. It orchestrates: - 1. **Data Validation**: Checks input shapes and types. - 2. **Model Loop**: Iterates through all configured models. - 3. **Preparation**: Instantiates models -> Builds Pipelines (Scaler/FS) -> - Wraps in Tuning. - 4. **Validation**: Runs the Outer Cross-Validation loop (optionally - parallelized). - 5. **Aggregation**: Collects scores, predictions, and importances. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - Training data (2D) or Time-Series data (3D). - y : array-like of shape (n_samples,) or (n_samples, n_targets) - Target labels or values. - groups : array-like of shape (n_samples,), optional - Group labels for splitting (e.g., subject-specific splits). - - Returns - ------- - ExperimentResult - Object containing results with methods to export to Tidy DataFrames. - """ - start_time = time.time() - logger.info(f"Starting Experiment: Task={self.config.task}") - - # 1. Validate Inputs - X = np.asarray(X) - y = np.asarray(y) - if len(X) == 0: - raise ValueError("Input X is empty.") - if len(y) != len(X): - raise ValueError( - f"Length mismatch: X has {len(X)} samples, y has {len(y)}." - ) - - if groups is not None: - groups = np.asarray(groups) - if len(groups) != len(X): - raise ValueError( - f"Length mismatch: groups has {len(groups)}, X has {len(X)}." - ) - - # 2. Check Task Consistency (Classification vs Regression) - target_type = type_of_target(y) - if self.config.task == "classification" and target_type == "continuous": - raise ValueError( - f"Task is 'classification' but target type is '{target_type}'. " - "Please check your labels or switch task to 'regression'." - ) - - # 3. Main Loop over Configured Models - for friendly_name, model_cfg in self.config.models.items(): - logger.info(f"Evaluating Model: {friendly_name} ({model_cfg.method})") - - try: - # A. Prepare (Instantiate + Scale + FS + Tune Wrapper) - estimator = self._prepare_estimator(friendly_name, model_cfg) - - # B. Execute Cross-Validation - # Note: Parallelism is handled inside _cross_validate if - # config.n_jobs > 1 - cv_results = self._cross_validate(estimator, X, y, groups) - - # C. Store Results - self.results[friendly_name] = cv_results - - except Exception as e: - logger.error( - f"Failed to evaluate model '{friendly_name}': {e}", exc_info=True - ) - self.results[friendly_name] = {"error": str(e), "status": "failed"} - - total_time = time.time() - start_time - logger.info(f"Experiment Completed in {total_time:.2f}s") - - return ExperimentResult(self.results) - - def save_results(self, path: Optional[Union[str, Path]] = None): - """ - Serialize results, configuration, and metadata to disk. - - Parameters - ---------- - path : str or Path, optional - Path to save the results. If None, uses config.output_dir. - If both are None, raises ValueError. - """ - if path is None: - path = self.config.output_dir - if path is None: - raise ValueError("No output path specified in config or arguments.") - - path = Path(path) - - # 1. Prepare Metadata - meta = { - "timestamp": datetime.now().isoformat(), - "tag": self.config.tag, - "coco_pipe_version": get_package_version("coco-pipe"), - } - - # 2. Bundle - payload = { - "config": self.config.model_dump(), - "results": self.results, - "meta": meta, - } - - # 3. Create Directory - # If path is a directory (no extension), append filename - if path.suffix == "": - path.mkdir(parents=True, exist_ok=True) - filename = ( - f"{self.config.tag}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pkl" - ) - target = path / filename - else: - path.parent.mkdir(parents=True, exist_ok=True) - target = path - - # 4. Save - logger.info(f"Saving results to {target}") - joblib.dump(payload, target) - return target - - @staticmethod - def load_results(path: Union[str, Path]) -> "ExperimentResult": - """ - Load a saved experiment payload and wrap it in ExperimentResult. - - Returns - ------- - ExperimentResult - The loaded results wrapper. - """ - path = Path(path) - if not path.exists(): - raise FileNotFoundError(f"Result file not found: {path}") - - payload = joblib.load(path) - # Handle backward compatibility or raw load - results = payload.get("results", payload) - return ExperimentResult(results) - - def _cross_validate( - self, - estimator: BaseEstimator, - X: np.ndarray, - y: np.ndarray, - groups: Optional[np.ndarray], - ) -> Dict[str, Any]: - """ - Execute the Outer Cross-Validation Loop (Evaluation). - - This is the **Level 1 (Top Level)** Splits: - - Splits the entire dataset into K folds (defined by config.cv). - - For each fold: - 1. **Training Data**: 80% (if 5-fold). Passed to `estimator.fit()`. - - If `estimator` is a GridSearch (Tuning Enabled), it will internally split - this 80% again for validation (Level 2 Split). - 2. **Test Data**: 20%. Used strictly for final `estimator.predict()` - evaluation. - - Parallelization - --------------- - If `config.n_jobs > 1`, these folds run in parallel processes to speed up - execution. - """ - cv = get_cv_splitter(self.config.cv, groups=groups) - - # Prepare CV iterator - splits = list(cv.split(X, y, groups)) - - # Parallel Execution - # We use n_jobs from config. - n_jobs_outer = self.config.n_jobs - - # OVERSUBSCRIPTION PROTECTION - # If outer loop is parallel, force inner estimators to run sequentially. - # Otherwise, we might spawn N_outer * N_inner threads, crashing the system. - parallel_estimator = clone(estimator) - if n_jobs_outer != 1: - parallel_estimator = self._force_serial_execution(parallel_estimator) - - parallel = joblib.Parallel(n_jobs=n_jobs_outer, verbose=self.config.verbose) - - results = parallel( - joblib.delayed(self._fit_and_score_fold)( - clone(parallel_estimator), X, y, train_idx, test_idx - ) - for train_idx, test_idx in splits - ) - - # Unpack Results - fold_scores = defaultdict(list) - fold_preds = [] - fold_indices = [] - fold_importances = [] - fold_metadata = [] - - for res in results: - fold_indices.append(res["test_idx"]) - fold_preds.append(res["preds"]) - fold_importances.append(res["importance"]) - fold_metadata.append(res.get("metadata", {})) - - for m, s in res["scores"].items(): - fold_scores[m].append(s) - - # Aggregate Metrics - metrics_summary = { - m: {"mean": np.nanmean(s), "std": np.nanstd(s), "folds": s} - for m, s in fold_scores.items() - } - - # Aggregate Importances - valid_imps = [f for f in fold_importances if f is not None] - aggregated_importances = None - if valid_imps: - try: - # Check consistency - first_shape = valid_imps[0].shape - if all(imp.shape == first_shape for imp in valid_imps): - stack = np.vstack(valid_imps) - aggregated_importances = { - "mean": np.mean(stack, axis=0), - "std": np.std(stack, axis=0), - "raw": stack, - } - except Exception: - pass - - return { - "metrics": metrics_summary, - "predictions": fold_preds, - "indices": fold_indices, - "importances": aggregated_importances, - "metadata": fold_metadata, - } - - def _fit_and_score_fold( - self, - estimator: BaseEstimator, - X: np.ndarray, - y: np.ndarray, - train_idx: np.ndarray, - test_idx: np.ndarray, - ) -> Dict[str, Any]: - """ - Execute a single Cross-Validation fold: Fit, Predict, and Score. - - Optimized for: - - **Standard Estimators**: (N, F) input -> (N,) output. - - **Sliding Estimators**: (N, F, T) input -> (N, T) output (Diagonal Decoding). - - Returns - ------- - dict - Contains 'test_idx', 'preds' (y_pred, y_true, y_proba), - 'scores' (dict of metric values), and 'importance'. - """ - X_train, X_test = X[train_idx], X[test_idx] - y_train, y_test = y[train_idx], y[test_idx] - - # 1. Fit - estimator.fit(X_train, y_train) - - # 2. Predict (Standard or Temporal) - y_pred = estimator.predict(X_test) - fold_data = {"y_true": y_test, "y_pred": y_pred} - - # 3. Predict Proba (if available and needed) - # Optimization: We always check/compute this if available, as 'roc_auc' - # is common. - if hasattr(estimator, "predict_proba"): - try: - fold_data["y_proba"] = estimator.predict_proba(X_test) - except Exception: - pass # Some estimators have the method but fail if not calibrated - # or supported correctly - - # 4. Extract Feature Importances - imp = None - try: - imp = self._extract_feature_importances(estimator) - except Exception: - pass - - # 5. Compute Metrics - scores = {} - is_multiclass = type_of_target(y_test) == "multiclass" - - for metric_name in self.config.metrics: - scorer = get_scorer(metric_name) - try: - # Determine if we should use Proba or Predictions - use_proba = ( - metric_name in ["roc_auc", "log_loss"] and "y_proba" in fold_data - ) - - if use_proba: - val = self._compute_metric_safe( - scorer, - y_test, - fold_data["y_proba"], - is_multiclass, - is_proba=True, - ) - else: - val = self._compute_metric_safe( - scorer, y_test, y_pred, is_multiclass, is_proba=False - ) - - scores[metric_name] = val - except Exception as e: - logger.warning(f"Metric '{metric_name}' failed in CV fold: {e}") - scores[metric_name] = np.nan - - # 6. Extract Metadata (Best Params, Selected Features) - meta = {} - try: - meta = self._extract_metadata(estimator) - except Exception as e: - logger.warning(f"Failed to extract metadata: {e}") - - return { - "test_idx": test_idx, - "preds": fold_data, - "scores": scores, - "importance": imp, - "metadata": meta, - } - - @staticmethod - def _extract_metadata(estimator: BaseEstimator) -> Dict[str, Any]: - """ - Extract training metadata like best Hyperparameters and Selected Features. - """ - meta = {} - - # 1. Best Params (from GridSearchCV/RandomizedSearchCV) - if hasattr(estimator, "best_params_"): - meta["best_params"] = estimator.best_params_ - # Unwrap best estimator for feature selection - search_best = estimator.best_estimator_ - else: - search_best = estimator - - # 2. Selected Features (from Pipeline step 'fs') - if isinstance(search_best, Pipeline): - fs_step = search_best.named_steps.get("fs") - if fs_step and hasattr(fs_step, "get_support"): - meta["selected_features"] = fs_step.get_support() - - return meta - - @staticmethod - def _compute_metric_safe(scorer, y_true, y_est, is_multiclass, is_proba=False): - """ - Compute metric handling standard and temporal (diagonal) shapes. - - Shapes Handled - -------------- - - **Standard**: y_est is (N,) or (N, C) - - **Generalizing (Matrix)**: - - y_pred: (N, T_train, T_test) -> Score each (T_train, T_test) pair. - - y_proba: (N, C, T_train, T_test) -> Score each (T_train, T_test) pair. - """ - # 1. Temporal / Sliding Case (Extra Dimension) - # Check for (N, T) predictions or (N, C, T) probabilities - is_temporal = (y_est.ndim == 2 and not is_proba and y_true.ndim == 1) or ( - y_est.ndim == 3 - ) - - if is_temporal: - # Case A: Binary/Regression Predictions (N, T) - if y_est.ndim == 2: - # Iterate over time (dim 1) - return np.array( - [scorer(y_true, y_est[:, t]) for t in range(y_est.shape[1])] - ) - - # Case B: Probabilities (N, C, T) or Generalizing (N, T_train, T_test) - if y_est.ndim == 3: - # Logic: - # - If input is NOT proba, (N, T, T) implies Generalizing Predictions. - # - If input IS proba, (N, C, T) implies Sliding Probabilities. - - if not is_proba: - # Generalizing Predictions (N, T_train, T_test) - n_train = y_est.shape[1] - n_test = y_est.shape[2] - matrix_scores = np.zeros((n_train, n_test)) - - for t_tr in range(n_train): - for t_te in range(n_test): - y_slice = y_est[:, t_tr, t_te] - matrix_scores[t_tr, t_te] = scorer(y_true, y_slice) - return matrix_scores - - # Sliding Probabilities (N, C, T) - n_times = y_est.shape[2] - scores = [] - for t in range(n_times): - slice_y = y_est[:, :, t] # (N, C) - - if not is_multiclass: - if slice_y.shape[1] == 2: - slice_y = slice_y[:, 1] - - kwargs = {"multi_class": "ovr"} if is_multiclass else {} - scores.append(scorer(y_true, slice_y, **kwargs)) - return np.array(scores) - - # Case C: GenEst Probabilities (N, C, T_train, T_test) -> 4D - if y_est.ndim == 4: - n_train = y_est.shape[2] - n_test = y_est.shape[3] - matrix_scores = np.zeros((n_train, n_test)) - - for t_tr in range(n_train): - for t_te in range(n_test): - slice_y = y_est[:, :, t_tr, t_te] # (N, C) - - if not is_multiclass: - if slice_y.shape[1] == 2: - slice_y = slice_y[:, 1] - - kwargs = {"multi_class": "ovr"} if is_multiclass else {} - matrix_scores[t_tr, t_te] = scorer(y_true, slice_y, **kwargs) - return matrix_scores - - # 2. Standard Case (N,) or (N, C) - kwargs = {} - if is_proba: - if is_multiclass: - kwargs = {"multi_class": "ovr"} - elif y_est.ndim == 2 and y_est.shape[1] == 2: - # Standard Binary Probabilities -> Take Positive Class - y_est = y_est[:, 1] - - return scorer(y_true, y_est, **kwargs) - - def _force_serial_execution(self, estimator: BaseEstimator) -> BaseEstimator: - """ - Recursively set n_jobs=1 for the estimator and its sub-components. - Used when the outer loop is already parallelized to avoid oversubscription. - """ - # 1. Get all parameters - params = estimator.get_params() - - # 2. Identify keys ending in 'n_jobs' - updates = {} - for key, value in params.items(): - if key.endswith("n_jobs") and value is not None and value != 1: - updates[key] = 1 - - # 3. Apply updates - if updates: - estimator.set_params(**updates) - - return estimator - - @staticmethod - def _extract_feature_importances(estimator: BaseEstimator) -> Optional[np.ndarray]: - """ - Extract feature importances or coefficients from a fitted estimator. - Handles Pipelines and Feature Selection. - """ - # 1. Unwrap Pipeline - if isinstance(estimator, Pipeline): - # Check for FS step - fs_step = estimator.named_steps.get("fs") - clf_step = estimator.named_steps.get("clf") - - # Get raw importances from classifier - raw_imp = Experiment._extract_feature_importances(clf_step) - if raw_imp is None: - return None - - # Map back if FS was used - if fs_step: - support = fs_step.get_support() # bool mask of selected features - # We need to reconstruct the full importance array with zeros (or NaNs) - # for unselected - full_imp = np.zeros_like(support, dtype=float) - full_imp[support] = raw_imp - return full_imp - - return raw_imp - - # 2. Extract from Base Estimator - if hasattr(estimator, "feature_importances_"): - return estimator.feature_importances_ - if hasattr(estimator, "coef_"): - # Handle multi-class coefs (n_classes, n_features) -> take magnitude/mean? - # For strict "importance", usually mean of abs(coefs) across classes - if estimator.coef_.ndim > 1: - return np.mean(np.abs(estimator.coef_), axis=0) - return np.abs(estimator.coef_) - - return None - - -class ExperimentResult: - """ - Unified Container for Experiment Results. - Provides Tidy Data views for easier analysis. - """ - - def __init__(self, raw_results: Dict[str, Any]): - self.raw = raw_results - - def summary(self) -> pd.DataFrame: - """ - Get a high-level summary of performance (Mean/Std across folds). - - Returns - ------- - pd.DataFrame - Index: Model Name - Columns: Metric Mean/Std - """ - rows = [] - for model, res in self.raw.items(): - if "error" in res: - continue - - row = {"Model": model} - for metric, stats in res["metrics"].items(): - row[f"{metric}_mean"] = stats["mean"] - row[f"{metric}_std"] = stats["std"] - rows.append(row) - - return pd.DataFrame(rows).set_index("Model") - - def get_detailed_scores(self) -> pd.DataFrame: - """ - Get fold-level scores for all models in long format. - - Returns - ------- - pd.DataFrame - Columns: Model, Fold, Metric, Value - """ - rows = [] - for model, res in self.raw.items(): - if "error" in res: - continue - - metrics_data = res["metrics"] - # Assume all metrics have same number of folds - n_folds = len(next(iter(metrics_data.values()))["folds"]) - - for fold_idx in range(n_folds): - for metric, stats in metrics_data.items(): - rows.append( - { - "Model": model, - "Fold": fold_idx, - "Metric": metric, - "Value": stats["folds"][fold_idx], - } - ) - return pd.DataFrame(rows) - - def get_predictions(self) -> pd.DataFrame: - """ - Get concatenated predictions for all models. - - Returns - ------- - pd.DataFrame - Columns: Model, Fold, y_true, y_pred, (y_proba if available) - """ - dfs = [] - for model, res in self.raw.items(): - if "error" in res: - continue - - for fold_idx, preds in enumerate(res["predictions"]): - # preds is dict: y_true, y_pred, y_proba - df = pd.DataFrame( - {"y_true": preds["y_true"], "y_pred": preds["y_pred"]} - ) - df["Model"] = model - df["Fold"] = fold_idx - - if "y_proba" in preds: - # Handle proba columns (might be multi-class) - proba = preds["y_proba"] - if proba.ndim == 1: - df["y_proba"] = proba - elif proba.ndim == 2: - for c in range(proba.shape[1]): - df[f"y_proba_{c}"] = proba[:, c] - - dfs.append(df) - - if not dfs: - return pd.DataFrame() - - return pd.concat(dfs, ignore_index=True) - - def get_best_params(self) -> pd.DataFrame: - """ - Get the best hyperparameters selected per fold (if Tuning was enabled). - - Returns - ------- - pd.DataFrame - Columns: Model, Fold, Param, Value - """ - rows = [] - for model_name, res in self.raw.items(): - if "error" in res: - continue - - # Check if metadata exists (handling backward compatibility) - if "metadata" in res: - for fold_idx, meta in enumerate(res["metadata"]): - if "best_params" in meta: - for p_name, p_val in meta["best_params"].items(): - rows.append( - { - "Model": model_name, - "Fold": fold_idx, - "Param": p_name, - "Value": p_val, - } - ) - - return pd.DataFrame(rows) - - def get_feature_stability(self) -> pd.DataFrame: - """ - Analyze feature selection stability across folds. - - Returns - ------- - pd.DataFrame - Index: Feature Index/Name - Columns: Selection Frequency (0.0 - 1.0) - """ - rows = [] - for model_name, res in self.raw.items(): - if "error" in res: - continue - - if "metadata" in res: - # Collect masks - masks = [] - for meta in res["metadata"]: - if "selected_features" in meta: - masks.append(meta["selected_features"]) - - if masks: - # Stack: (n_folds, n_features) - stack = np.vstack(masks) - stability = np.mean(stack, axis=0) # 0 to 1 - - for feat_idx, freq in enumerate(stability): - rows.append( - { - "Model": model_name, - "Feature": feat_idx, - "Frequency": freq, - } - ) - - if not rows: - return pd.DataFrame() - - return pd.DataFrame(rows) - - def get_generalization_matrix(self, metric: str = None) -> pd.DataFrame: - """ - Get Generalization Matrix (Train Time x Test Time) averaged across folds. - - Parameters - ---------- - metric : str, optional - The metric to retrieve (e.g., 'accuracy', 'roc_auc'). - Defaults to the first metric found in results. - - Returns - ------- - pd.DataFrame - Index: Train Time - Columns: Test Time - Values: Average Score - """ - # 1. Collect all matrices for the metric - for model_name, res in self.raw.items(): - if "error" in res: - continue - - metrics_data = res["metrics"] - if metric is None: - metric = list(metrics_data.keys())[0] - - if metric not in metrics_data: - continue - - fold_scores = metrics_data[metric]["folds"] - # Check if scores are matrices (2D arrays) - valid_matrices = [ - s for s in fold_scores if isinstance(s, np.ndarray) and s.ndim == 2 - ] - - if valid_matrices: - # Stack and Mean -> (n_folds, n_train, n_test) -> (n_train, n_test) - stack = np.stack(valid_matrices) - mean_matrix = np.mean(stack, axis=0) - return pd.DataFrame(mean_matrix) - - return pd.DataFrame() diff --git a/coco_pipe/decoding/experiment.py b/coco_pipe/decoding/experiment.py new file mode 100644 index 0000000..0e16912 --- /dev/null +++ b/coco_pipe/decoding/experiment.py @@ -0,0 +1,871 @@ +""" +Decoding Experiment +=================== +Main executor for decoding experiments. +""" + +import atexit +import logging +import time +from collections import defaultdict +from shutil import rmtree +from tempfile import mkdtemp +from typing import Any, Dict, Optional, Sequence, Union + +import joblib +import numpy as np +import pandas as pd +from sklearn.base import BaseEstimator, clone +from sklearn.feature_selection import ( + SelectKBest, + f_classif, + f_regression, +) +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.utils.multiclass import type_of_target + +from ..report.provenance import get_environment_info +from ._constants import CLASSICAL_FAMILIES, GROUP_CV_STRATEGIES, RESULT_SCHEMA_VERSION +from ._engine import GroupedSequentialFeatureSelector, fit_and_score_fold +from ._metrics import get_metric_spec +from ._splitters import get_cv_splitter +from .configs import ExperimentConfig +from .registry import ( + _get_val, + get_estimator_cls, + get_selector_capabilities, + resolve_estimator_spec, +) +from .result import ExperimentResult + +logger = logging.getLogger(__name__) + + +class Experiment: + """ + Main executor for decoding experiments. + + Parameters + ---------- + config : ExperimentConfig + The complete configuration for the experiment. + + Examples + -------- + >>> from coco_pipe.decoding import Experiment, ExperimentConfig + >>> config = ExperimentConfig( + ... task="classification", models={"lr": {"kind": "classical"}} + ... ) + >>> exp = Experiment(config) + + See Also + -------- + ExperimentResult : Container for experiment outputs. + get_cv_splitter : Factory for cross-validation splitters. + """ + + def __init__(self, config: ExperimentConfig): + self.config = config + self.results: Dict[str, Any] = {} + self.result_: Optional[ExperimentResult] = None + self._model_specs: Dict[str, Any] = {} + self._model_capabilities: Dict[str, Any] = {} + self._propagate_random_state() + self._validate_config() + + def _validate_config(self): + """Perform comprehensive runtime validation of the configuration.""" + task = self.config.task + if self.config.tuning.enabled and not self.config.grids: + logger.warning( + "Hyperparameter tuning is enabled but no 'grids' are defined " + "in the config." + ) + + if self.config.calibration.enabled and task != "classification": + raise ValueError("calibration is only available for classification.") + + if task == "regression" and "stratified" in self.config.cv.strategy: + raise ValueError( + f"CV strategy '{self.config.cv.strategy}' is invalid " + "for regression tasks." + ) + + # 1. Inner CV Grouping Validation (Leakage Guard) + if self.config.cv.strategy in GROUP_CV_STRATEGIES: + targets = [] + if self.config.tuning.enabled: + targets.append( + ( + "tuning.cv", + self.config.tuning.cv, + self.config.tuning.allow_nongroup_inner_cv, + ) + ) + fs = self.config.feature_selection + if fs.enabled and fs.method == "sfs": + targets.append( + ("feature_selection.cv", fs.cv, fs.allow_nongroup_inner_cv) + ) + cal = self.config.calibration + if cal.enabled: + targets.append(("calibration.cv", cal.cv, cal.allow_nongroup_inner_cv)) + + for name, cv_cfg, allowed in targets: + if ( + cv_cfg + and cv_cfg.strategy not in GROUP_CV_STRATEGIES + and not allowed + ): + raise ValueError( + f"Outer CV strategy is group-based, but {name} strategy " + f"'{cv_cfg.strategy}' is not. This leads to data leakage. " + "Set allow_nongroup_inner_cv=True if this is intentional." + ) + + # Validate FS scoring if explicitly set + fs_scoring = self.config.feature_selection.scoring + if fs_scoring and get_metric_spec(fs_scoring).task != task: + raise ValueError( + f"Feature selection scoring '{fs_scoring}' is incompatible with " + f"task '{task}'." + ) + + for name, model_cfg in self.config.models.items(): + spec = resolve_estimator_spec(model_cfg) + caps = spec.to_capabilities() + self._model_specs[name] = spec + self._model_capabilities[name] = caps + if not caps.supports_task(task): + raise ValueError(f"Model '{name}' does not support task '{task}'.") + + # 2. Metric Compatibility Check + has_proba = ( + caps.has_response("predict_proba") or self.config.calibration.enabled + ) + has_score = caps.has_response("decision_function") + for metric in self.config.get_all_evaluation_metrics(): + m_spec = get_metric_spec(metric) + if m_spec.task != task: + raise ValueError( + f"Metric '{metric}' is for {m_spec.task} but experiment task " + f"is {task}." + ) + if m_spec.response_method == "proba" and not has_proba: + raise ValueError( + f"Metric '{metric}' requires probabilities, but model " + f"'{name}' doesn't provide them (and calibration is disabled)." + ) + if m_spec.response_method == "proba_or_score" and not ( + has_proba or has_score + ): + raise ValueError( + f"Metric '{metric}' requires probabilities or decision scores, " + f"but model '{name}' provides neither." + ) + + def _prepare_estimator(self, model_name: str, model_config: Any) -> BaseEstimator: + """Orchestrate the creation of the full Estimator Pipeline.""" + + full_est = self._instantiate_model(model_name, model_config) + spec = self._model_specs.get(model_name) or resolve_estimator_spec(model_config) + steps = [] + + # Classical models on tabular/embedding data support standard preprocessing + allow_prep = spec.family in CLASSICAL_FAMILIES and any( + k in {"tabular_2d", "embedding_2d", "tabular", "embeddings"} + for k in spec.input_kinds + ) + + if self.config.use_scaler and allow_prep: + steps.append(("scaler", StandardScaler())) + + if self.config.feature_selection.enabled and allow_prep: + fs_step = self._create_fs_step(full_est) + if fs_step: + steps.append(fs_step) + elif self.config.feature_selection.enabled and not allow_prep: + raise ValueError( + f"Feature selection is only valid for classical 2D tabular " + f"inputs. Model '{model_name}' uses {spec.input_kinds} data." + ) + + steps.append(("clf", full_est)) + + if ( + self.config.feature_selection.enabled + and self.config.tuning.enabled + and self.config.grids + ): + cachedir = mkdtemp() + atexit.register(lambda: rmtree(cachedir, ignore_errors=True)) + est = Pipeline(steps, memory=cachedir) + else: + est = Pipeline(steps) + + if ( + self.config.tuning.enabled + and self.config.grids + and model_name in self.config.grids + ): + est = self._wrap_with_tuning(model_name, est) + + if self.config.calibration.enabled: + from sklearn.calibration import CalibratedClassifierCV + + cal_cv = get_cv_splitter(self.config.calibration.cv, require_groups=False) + est = CalibratedClassifierCV( + estimator=est, + method=self.config.calibration.method, + cv=cal_cv, + n_jobs=self.config.calibration.n_jobs or self.config.n_jobs, + ) + return est + + def _propagate_random_state(self): + """Ensure the global random state is distributed to all config sub-objects.""" + seed = self.config.random_state + if seed is None: + return + + self._inject_seed(self.config.cv, seed) + self._inject_seed(self.config.feature_selection, seed + 1) + self._inject_seed(self.config.tuning, seed + 2) + self._inject_seed(self.config.calibration, seed + 3) + + # 2. Model seeds + model_names = sorted(self.config.models.keys()) + from numpy.random import SeedSequence + + ss = SeedSequence(seed + 4) + model_seeds = ss.spawn(len(model_names)) + for name, m_ss in zip(model_names, model_seeds): + self._inject_seed(self.config.models[name], int(m_ss.generate_state(1)[0])) + + def _inject_seed(self, cfg: Any, seed: int): + """Safely inject a random seed into a config object if it supports it.""" + if hasattr(cfg, "random_state"): + cfg.random_state = seed + if hasattr(cfg, "cv") and cfg.cv and hasattr(cfg.cv, "random_state"): + cfg.cv.random_state = seed + + # Classical parameters dictionary + if getattr(cfg, "kind", None) == "classical" and hasattr(cfg, "params"): + if resolve_estimator_spec(cfg).supports_random_state: + cfg.params["random_state"] = seed + + # Recursion into sub-components (backbone, head, base) + for attr in ("backbone", "head", "base"): + if hasattr(cfg, attr): + self._inject_seed(getattr(cfg, attr), seed) + + def _instantiate_model(self, model_name: str, config: Any) -> BaseEstimator: + """Create a concrete scikit-learn estimator instance from a model config.""" + # 1. Use the pre-resolved spec for explicit dispatch + spec = self._model_specs.get(model_name) or resolve_estimator_spec(config) + + if spec.family in CLASSICAL_FAMILIES: + est_cls = get_estimator_cls(spec.name) + if hasattr(config, "params"): + params = config.params + elif isinstance(config, dict): + params = config.get("params", {}) + else: + params = config.model_dump(exclude={"method", "kind"}) + try: + return est_cls(**params) + except Exception as e: + raise ValueError( + f"Failed to instantiate model '{model_name}': {e}" + ) from e + + if spec.family == "foundation": + from .fm_hub import build_foundation_model + + return build_foundation_model(config) + + if spec.family == "temporal": + # wrapper is 'sliding' or 'generalizing' + wrapper = _get_val(config, "wrapper") + method = ( + "SlidingEstimator" if wrapper == "sliding" else "GeneralizingEstimator" + ) + est_cls = get_estimator_cls(method) + + if hasattr(config, "model_dump"): + params = config.model_dump(exclude={"kind", "wrapper", "base"}) + elif isinstance(config, dict): + params = { + k: v + for k, v in config.items() + if k not in {"kind", "wrapper", "base"} + } + else: + params = {} + + params["base_estimator"] = self._prepare_estimator( + f"{model_name}_base", _get_val(config, "base") + ) + try: + return est_cls(**params) + except Exception as e: + raise ValueError( + f"Failed to instantiate model '{model_name}': {e}" + ) from e + + # Fallback for other registry-based estimators + method = _get_val(config, "method") + est_cls = get_estimator_cls(method) + if hasattr(config, "model_dump"): + params = config.model_dump(exclude={"method", "kind"}) + elif isinstance(config, dict): + params = {k: v for k, v in config.items() if k not in {"method", "kind"}} + else: + params = {} + + if "base_estimator" in params: + params["base_estimator"] = self._prepare_estimator( + f"{model_name}_base", _get_val(config, "base") + ) + return est_cls(**params) + + def _create_fs_step(self, estimator: BaseEstimator) -> Optional[tuple]: + """Create a feature selection step compatible with the chosen model.""" + fs_conf = self.config.feature_selection + if fs_conf.method == "k_best": + score_func = ( + f_classif if self.config.task == "classification" else f_regression + ) + return ( + "fs", + SelectKBest(score_func=score_func, k=fs_conf.n_features or "all"), + ) + if fs_conf.method == "sfs": + cv = get_cv_splitter(fs_conf.cv, require_groups=False) + scoring = ( + fs_conf.scoring or self.config.tuning.scoring or self.config.metrics[0] + ) + sfs = GroupedSequentialFeatureSelector( + estimator=clone(estimator), + n_features_to_select=fs_conf.n_features, + direction=fs_conf.direction, + cv=cv, + scoring=scoring, + n_jobs=self.config.n_jobs, + ) + return ("fs", sfs) + return None + + def _wrap_with_tuning(self, name: str, estimator: BaseEstimator) -> BaseEstimator: + """Wrap an estimator with hyperparameter search (GridSearch/RandomSearch).""" + from sklearn.model_selection import GridSearchCV, RandomizedSearchCV + + grid = self.config.grids[name] + mapped = {k if "__" in k else f"clf__{k}": v for k, v in grid.items()} + valid = estimator.get_params(deep=True) + invals = [k for k in mapped if k not in valid] + if invals: + raise ValueError(f"Invalid tuning keys for '{name}': {invals}") + cv = get_cv_splitter(self.config.tuning.cv, require_groups=False) + kwargs = { + "estimator": estimator, + "cv": cv, + "scoring": self.config.tuning.scoring or self.config.metrics[0], + "n_jobs": self.config.tuning.n_jobs, + "refit": True, + } + if self.config.tuning.search_type == "grid": + return GridSearchCV(param_grid=mapped, **kwargs) + return RandomizedSearchCV( + param_distributions=mapped, + n_iter=self.config.tuning.n_iter, + random_state=self.config.tuning.random_state, + **kwargs, + ) + + def run( + self, + X: np.ndarray, + y: Union[pd.Series, np.ndarray], + groups: Optional[Union[pd.Series, np.ndarray]] = None, + feature_names: Optional[Sequence[str]] = None, + sample_ids: Optional[Sequence[Any]] = None, + sample_metadata: Optional[Union[pd.DataFrame, Dict[str, Sequence[Any]]]] = None, + observation_level: str = "sample", + inferential_unit: Optional[str] = None, + time_axis: Optional[Sequence[Any]] = None, + ) -> ExperimentResult: + """ + Execute the complete decoding experiment pipeline. + + This method orchestrates the full scientific workflow: + 1. Resolves and validates data hierarchy (metadata, groups, sample IDs). + 2. Aligns temporal dimensions and feature names. + 3. Performs model-by-model evaluation using stratified/grouped cross-validation. + 4. Aggregates results into a unified ExperimentResult object. + 5. Performs statistical assessment (permutations/bootstrapping) if enabled. + + Parameters + ---------- + X : np.ndarray + The input data. Can be 2D (samples x features) or 3D temporal + (samples x sensors x time). + y : Union[pd.Series, np.ndarray] + The target labels or values. + groups : Union[pd.Series, np.ndarray], optional + Grouping labels for grouped cross-validation (e.g., subject IDs). + feature_names : Sequence[str], optional + Human-readable names for features. Auto-generated if None. + sample_ids : Sequence[Any], optional + Unique identifiers for each sample. Auto-generated if None. + sample_metadata : Union[pd.DataFrame, Dict], optional + Additional scientific context for each sample (BIDS-like). + observation_level : str, default='sample' + Level of the input rows ('sample' or 'epoch'). + inferential_unit : str, optional + The level of statistical independence ('sample' or 'subject'). + time_axis : Sequence[Any], optional + Scientific time points for 3D temporal data. + + Returns + ------- + ExperimentResult + A container holding all raw results, metrics, and diagnostics. + + Raises + ------ + ValueError + If input lengths mismatch, data is empty, or configuration + is scientifically invalid (e.g. regression with stratification). + + Examples + -------- + >>> from coco_pipe.decoding import Experiment, ExperimentConfig + >>> config = ExperimentConfig( + ... task="classification", models={"lr": {"kind": "classical"}} + ... ) + >>> result = Experiment(config).run(X, y) + + See Also + -------- + ExperimentResult.get_predictions : Tidy prediction accessor. + """ + start_time = time.time() + logger.info(f"Starting Experiment: Task={self.config.task}") + X, y = np.asarray(X), np.asarray(y) + if len(X) == 0: + raise ValueError("X is empty.") + if len(y) != len(X): + raise ValueError("Length mismatch between X and y.") + + # 1. Scientific Guard: Double-Normalization Warning + if self.config.use_scaler and X.ndim == 2: + # Simple heuristic: if means are near 0 and stds are near 1, warn. + means = np.nanmean(X, axis=0) + stds = np.nanstd(X, axis=0) + if np.all(np.abs(means) < 1e-3) and np.all(np.abs(stds - 1.0) < 1e-2): + logger.warning( + "Input data X appears to be already normalized " + "(means ~0, stds ~1). Enabling 'use_scaler' will " + "result in redundant double-normalization." + ) + + self._feature_names = self._resolve_feature_names(X, feature_names) + self._sample_ids = self._resolve_sample_ids(len(X), sample_ids) + if observation_level not in {"sample", "epoch"}: + raise ValueError("observation_level must be 'sample' or 'epoch'.") + self._observation_level = observation_level + self._sample_metadata, groups = self._resolve_metadata_and_groups( + len(X), sample_metadata, groups + ) + + # 3. Resolve Inferential Unit (Level of statistical independence) + if inferential_unit is not None: + self._inferential_unit = inferential_unit + elif ( + observation_level == "epoch" and "Subject" in self._sample_metadata.columns + ): + self._inferential_unit = "subject" + else: + self._inferential_unit = "sample" + + # 4. Resolve Time Axis + if X.ndim == 3: + if time_axis is None: + self._time_axis = np.arange(X.shape[-1]) + else: + self._time_axis = np.asarray(time_axis) + if len(self._time_axis) != X.shape[-1]: + raise ValueError( + f"time_axis length mismatch: {len(self._time_axis)} vs " + f"{X.shape[-1]}" + ) + else: + self._time_axis = np.asarray(time_axis) if time_axis is not None else None + + # 5. Input Rank Capability Guard + rank = "3d_temporal" if X.ndim == 3 else "2d" + + # 6. Group Validation: Early check before entering model loop + if groups is None: + from ._constants import GROUP_CV_STRATEGIES + + cv_configs = [self.config.cv] + if self.config.tuning.enabled and self.config.tuning.cv: + cv_configs.append(self.config.tuning.cv) + if ( + self.config.feature_selection.enabled + and self.config.feature_selection.cv + ): + cv_configs.append(self.config.feature_selection.cv) + if self.config.calibration.enabled and self.config.calibration.cv: + cv_configs.append(self.config.calibration.cv) + + for cv_conf in cv_configs: + if cv_conf.strategy in GROUP_CV_STRATEGIES: + raise ValueError( + f"Strategy '{cv_conf.strategy}' requires groups, but none " + "were provided." + ) + + for name, caps in self._model_capabilities.items(): + if rank not in caps.input_ranks: + raise ValueError(f"Model '{name}' doesn't support rank '{rank}'.") + if self.config.feature_selection.enabled: + sel = get_selector_capabilities(self.config.feature_selection.method) + if rank not in sel.input_ranks: + raise ValueError( + f"FS method '{sel.method}' doesn't support rank '{rank}'." + ) + if self.config.task == "classification" and type_of_target(y) == "continuous": + raise ValueError("Task is 'classification' but target is 'continuous'.") + + for name, cfg in self.config.models.items(): + label = getattr(cfg, "method", getattr(cfg, "kind", "Unknown")) + logger.info(f"Evaluating Model: {name} ({label})") + try: + # 1. Resolve Spec & Capabilities + from .registry import resolve_estimator_spec + + spec = resolve_estimator_spec(cfg) + + # 2. Parallelism Safety + is_fm = spec.family in {"foundation", "neural"} + model_n_jobs = 1 if is_fm else self.config.n_jobs + + est = self._prepare_estimator(name, cfg) + self.results[name] = self._cross_validate( + est, + X, + y, + groups, + self._sample_ids, + self._sample_metadata, + n_jobs=model_n_jobs, + spec=spec, + model_name=name, + ) + except Exception as e: + logger.error(f"Failed model '{name}': {e}", exc_info=True) + self.results[name] = {"error": str(e), "status": "failed"} + + logger.info(f"Experiment Completed in {time.time() - start_time:.2f}s") + res_obj = ExperimentResult( + self.results, + config=self.config.model_dump(), + meta=self._build_result_meta(X, self._time_axis), + ) + + if self.config.evaluation.enabled: + from .stats import run_statistical_assessment + + assessment = run_statistical_assessment( + res_obj, + self.config, + X, + y, + groups, + self._sample_ids, + self._sample_metadata, + self._feature_names, + self._time_axis, + self._observation_level, + self._inferential_unit, + ) + res_obj.meta["statistical_assessment"] = assessment["meta"] + for m_name, m_res in res_obj.raw.items(): + if "error" in m_res: + continue + m_res["statistical_assessment"] = [ + r for r in assessment["rows"] if r.get("Model") == m_name + ] + if m_name in assessment["nulls"]: + m_res["statistical_nulls"] = assessment["nulls"][m_name] + return res_obj + + def _cross_validate( + self, + estimator: BaseEstimator, + X: np.ndarray, + y: np.ndarray, + groups: Optional[np.ndarray], + sample_ids: np.ndarray, + sample_metadata: pd.DataFrame, + n_jobs: int = 1, + spec: Optional[Any] = None, + model_name: Optional[str] = None, + ) -> Dict[str, Any]: + """Perform parallel cross-validation for a single estimator.""" + cv = get_cv_splitter(self.config.cv, groups=groups, y=y) + splits = list(cv.split(X, y, groups)) + + for train_idx, test_idx in splits: + _validate_fold_integrity(y[train_idx], y[test_idx], spec.task) + + est = clone(estimator) + + meta_dict = None + if sample_metadata is not None: + meta_dict = { + col: sample_metadata[col].values for col in sample_metadata.columns + } + + # 4. Execute Parallel CV + parallel = joblib.Parallel(n_jobs=n_jobs, verbose=self.config.verbose) + results = parallel( + joblib.delayed(fit_and_score_fold)( + clone(est), + X, + y, + groups, + sample_ids, + meta_dict, + train_idx, + test_idx, + metrics=self.config.get_all_evaluation_metrics(), + feature_selection_config=self.config.feature_selection, + calibration_config=self.config.calibration, + spec=spec, + tuning_config=self.config.tuning, + feature_names=self._feature_names, + search_enabled=( + self.config.tuning.enabled + and self.config.grids is not None + and model_name in self.config.grids + ), + force_serial=(n_jobs == 1), + ) + for train_idx, test_idx in splits + ) + + fold_scores = defaultdict(list) + f_idx, f_preds, f_imps, f_meta, f_splits, f_diags = [], [], [], [], [], [] + for res in results: + f_idx.append(res["test_idx"]) + f_preds.append(res["preds"]) + f_imps.append(res["importance"]) + f_meta.append(res.get("metadata", {})) + f_splits.append(res["split"]) + f_diags.append(res.get("diagnostics", {})) + for m, s in res["scores"].items(): + fold_scores[m].append(s) + + metrics = {} + for m, s in fold_scores.items(): + if np.isnan(s).any(): + logger.warning( + f"NaN score detected in one or more folds for metric '{m}'. " + "This usually indicates a model failure or degenerate test fold." + ) + metrics[m] = {"mean": np.nanmean(s), "std": np.nanstd(s), "folds": s} + valid_imps = [f for f in f_imps if f is not None] + agg_imp = None + if valid_imps and all(imp.shape == valid_imps[0].shape for imp in valid_imps): + stack = np.vstack(valid_imps) + agg_imp = { + "mean": np.mean(stack, axis=0), + "std": np.std(stack, axis=0), + "raw": stack, + "feature_names": self._feature_names + if len(self._feature_names) == stack.shape[1] + else [f"feature_{idx}" for idx in range(stack.shape[1])], + } + + return { + "status": "success", + "metrics": metrics, + "predictions": f_preds, + "indices": f_idx, + "importances": agg_imp, + "metadata": f_meta, + "splits": f_splits, + "diagnostics": f_diags, + } + + @staticmethod + def _resolve_sample_ids(n: int, ids: Optional[Sequence[Any]]) -> np.ndarray: + """Ensure sample IDs are provided and have correct length.""" + if ids is None: + return np.arange(n) + ids = np.asarray(ids) + if len(ids) != n: + raise ValueError(f"sample_ids length mismatch: {len(ids)} vs {n}") + if len(pd.unique(ids)) != n: + raise ValueError("sample_ids must be unique.") + return ids + + def _resolve_metadata_and_groups( + self, + n: int, + meta_in: Optional[Union[pd.DataFrame, Dict[str, Sequence[Any]]]], + groups_in: Optional[np.ndarray], + ) -> tuple[pd.DataFrame, Optional[np.ndarray]]: + """Validate metadata and extract cross-validation groups if required.""" + # 1. Standardize Metadata to DataFrame + if meta_in is None: + meta = pd.DataFrame(index=range(n)) + else: + meta = pd.DataFrame(meta_in).reset_index(drop=True) + meta.columns = [str(c).capitalize() for c in meta.columns] + if len(meta) != n: + raise ValueError(f"sample_metadata length mismatch: {len(meta)} vs {n}") + + # 2. Scientific Guard: Metadata Requirements + # We must track subject and session to ensure independent validation + # and prevent pseudoreplication, especially for epoch-level data. + if meta_in is not None: + missing = [c for c in ["Subject", "Session"] if c not in meta.columns] + if missing: + raise ValueError( + f"sample_metadata must include Subject and Session for " + f"proper independence tracking. Missing: {missing}" + ) + + # 2. Resolve Groups + gv = None + key = self.config.cv.group_key + has_grouped_cv = ( + self.config.cv.strategy in GROUP_CV_STRATEGIES + or ( + self.config.tuning.enabled + and self.config.tuning.cv.strategy in GROUP_CV_STRATEGIES + ) + or ( + self.config.calibration.enabled + and self.config.calibration.cv.strategy in GROUP_CV_STRATEGIES + ) + ) + + # Case A: Explicit groups array provided + if groups_in is not None: + gv = np.asarray(groups_in) + if len(gv) != n: + raise ValueError(f"groups length mismatch: {len(gv)} vs {n}") + if key is not None: + meta[key] = gv + + # Case B: Extract from metadata using group_key + elif key is not None: + if key not in meta.columns: + raise ValueError(f"group_key '{key}' not found in sample_metadata.") + gv = meta[key].to_numpy() + + # Validation: Grouped strategies with only 1 group will fail in sklearn + if has_grouped_cv: + if gv is None: + raise ValueError( + f"CV strategy '{self.config.cv.strategy}' requires groups " + "via 'groups' parameter or 'group_key' in config." + ) + unique_groups = len(pd.unique(gv)) + if unique_groups < 2: + raise ValueError( + f"Grouped CV requires at least 2 unique groups, but found " + f"only {unique_groups}. Check your sample_metadata or " + "groups array." + ) + return meta, gv + + def _build_result_meta( + self, X: np.ndarray, t_axis: Optional[np.ndarray] + ) -> Dict[str, Any]: + meta = get_environment_info() + meta.update( + { + "tag": self.config.tag, + "task": self.config.task, + "n_samples": int(X.shape[0]), + "n_features": int(X.shape[1]) if X.ndim > 1 else 1, + "observation_level": self._observation_level, + "inferential_unit": self._inferential_unit, + "sample_metadata_columns": self._sample_metadata.columns.tolist(), + "run_manifest": { + "schema_version": RESULT_SCHEMA_VERSION, + "model_names": list(self.config.models), + "cv_strategy": self.config.cv.strategy, + "metrics": self.config.get_all_evaluation_metrics(), + }, + "hardware_provenance": {"n_jobs": self.config.n_jobs}, + "capabilities": self._capability_payload(), + } + ) + if t_axis is not None: + meta["time_axis"] = t_axis.tolist() + return meta + + def _capability_payload(self) -> Dict[str, Any]: + sels = {} + if self.config.feature_selection.enabled: + sels[self.config.feature_selection.method] = get_selector_capabilities( + self.config.feature_selection.method + ).to_dict() + return { + "models": {n: c.to_dict() for n, c in self._model_capabilities.items()}, + "estimator_specs": {n: s.to_dict() for n, s in self._model_specs.items()}, + "feature_selectors": sels, + "metrics": { + m: { + "task": get_metric_spec(m).task, + "response_method": get_metric_spec(m).response_method, + "family": get_metric_spec(m).family, + } + for m in self.config.get_all_evaluation_metrics() + }, + } + + def _resolve_feature_names( + self, X: np.ndarray, names: Optional[Sequence[str]] + ) -> list[str]: + exp = 1 if X.ndim < 2 else X.shape[1] + if names is not None: + if len(names) != exp: + raise ValueError( + f"feature_names length mismatch: {len(names)} vs {exp}" + ) + return [str(n) for n in names] + return ( + ["feature_0"] + if X.ndim < 2 + else [f"feature_{idx}" for idx in range(X.shape[1])] + ) + + +def _validate_fold_integrity( + y_train: np.ndarray, y_test: np.ndarray, tasks: tuple +) -> None: + """Check if CV folds are degenerate before fitting.""" + if y_train.size == 0 or y_test.size == 0: + raise ValueError("Empty fold detected.") + + if np.min(y_train) == np.max(y_train): + raise ValueError( + f"Degenerate Train Fold: Only one value found ({y_train[0]}). " + "Scoring metrics are undefined for constant targets." + ) + + if "classification" in tasks and np.min(y_test) == np.max(y_test): + raise ValueError( + f"Degenerate Test Fold: Only one class found ({y_test[0]}). " + "Metrics like ROC-AUC are undefined for single-class test sets." + ) diff --git a/coco_pipe/decoding/fm_hub/__init__.py b/coco_pipe/decoding/fm_hub/__init__.py new file mode 100644 index 0000000..341c219 --- /dev/null +++ b/coco_pipe/decoding/fm_hub/__init__.py @@ -0,0 +1,18 @@ +""" +Foundation Model Hub (fm_hub) +============================ +Unified access to foundation models for both extraction and fine-tuning. +""" + +from ._factory import build_foundation_model +from .base import BaseFoundationModel, EmbeddingInfo +from .cbramod import CBraModModel +from .reve import REVEModel + +__all__ = [ + "BaseFoundationModel", + "EmbeddingInfo", + "build_foundation_model", + "REVEModel", + "CBraModModel", +] diff --git a/coco_pipe/decoding/fm_hub/_factory.py b/coco_pipe/decoding/fm_hub/_factory.py new file mode 100644 index 0000000..68d9ae3 --- /dev/null +++ b/coco_pipe/decoding/fm_hub/_factory.py @@ -0,0 +1,57 @@ +""" +Foundation Model Factory +======================== +Handles instantiation of foundation models with lazy loading of providers. +""" + +from typing import Any + +_PROVIDER_MAP = { + "reve": (".reve", "REVEModel"), + "cbramod": (".cbramod", "CBraModModel"), + "custom": (".custom", "CustomNeuralModel"), +} + + +def build_foundation_model(config: Any) -> Any: + """ + Instantiate a foundation model from a config object. + + This function uses lazy loading to avoid importing heavy dependencies + (torch, transformers) until a model is actually requested. + """ + provider = getattr(config, "provider", None) + if provider not in _PROVIDER_MAP: + supported = list(_PROVIDER_MAP.keys()) + raise ValueError( + f"Unknown foundation model provider '{provider}'. " + f"Supported providers: {supported}" + ) + + # Lazy import of the provider class + from importlib import import_module + + module_path, class_name = _PROVIDER_MAP[provider] + + # Resolve relative import + module = import_module(module_path, package="coco_pipe.decoding.fm_hub") + model_cls = getattr(module, class_name) + + # Extract common parameters from config + params = { + "model_name": getattr(config, "model_name", None), + "electrode_names": getattr(config, "electrode_names", None), + "sfreq": getattr(config, "sfreq", 200.0), + "train_mode": getattr(config, "train_mode", "frozen"), + "pooling": getattr(config, "pooling", "mean"), + "device": getattr(config, "device", None), + "token": getattr(config, "token", None), + "task": getattr(config, "task", "classification"), + "patch_size": getattr(config, "patch_size", 200), + "seq_len": getattr(config, "seq_len", 4), + } + + # Filter out None values to allow class defaults to take over + params = {k: v for k, v in params.items() if v is not None} + + return model_cls(**params) diff --git a/coco_pipe/decoding/fm_hub/base.py b/coco_pipe/decoding/fm_hub/base.py new file mode 100644 index 0000000..79eb165 --- /dev/null +++ b/coco_pipe/decoding/fm_hub/base.py @@ -0,0 +1,231 @@ +""" +Base classes for the Foundation Model Hub (fm_hub). +================================================= +Unified API for foundation models using skorch for robust Scikit-Learn +compatibility and training orchestration. +""" + +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import numpy as np +import torch +import torch.nn as nn +from sklearn.base import BaseEstimator, TransformerMixin + + +@dataclass +class EmbeddingInfo: + """ + Metadata about the embeddings produced by a foundation model. + + Parameters + ---------- + n_embeddings : int + The dimensionality of the embedding vector. + embedding_name : str + A descriptive name for the embedding (e.g., "CLS token", "Mean Pool"). + provider : str + The model provider (e.g., "reve", "neuroai"). + model_name : str + The specific model identifier (e.g., "brain-bzh/reve-large"). + sfreq : float + The sampling frequency the model expects. + """ + + n_embeddings: int + embedding_name: str + provider: str + model_name: str + sfreq: float + + +class BaseTransformerModule(nn.Module): + """ + Shared PyTorch base for Transformer models (REVE, CBRAMOD, etc.). + + This class centralizes the complex logic of Parameter-Efficient Fine-Tuning + (PEFT), Quantization, and parameter freezing, ensuring consistent behavior + across different foundation model architectures. + """ + + def load_backbone( + self, model_name: str, train_mode: str, token: Optional[str] = None, **kwargs + ) -> nn.Module: + """ + Initialize and configure a HuggingFace Transformer backbone. + + Parameters + ---------- + model_name : str + HuggingFace model ID. + train_mode : str + One of {"frozen", "full", "lora", "qlora"}. + token : str, optional + HuggingFace API token for private models. + **kwargs : dict + Additional parameters passed to LoRA/Quantization configs. + + Returns + ------- + nn.Module + The configured backbone (potentially wrapped with LoRA/Quantization). + """ + from transformers import AutoModel, BitsAndBytesConfig + + hf_kwargs = {"trust_remote_code": True, "token": token} + + # 1. Handle 4-bit Quantization (QLoRA) + if train_mode == "qlora": + hf_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + ) + + backbone = AutoModel.from_pretrained(model_name, **hf_kwargs) + + # 2. Handle PEFT (LoRA) + if train_mode in {"lora", "qlora"}: + from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training + + if train_mode == "qlora": + backbone = prepare_model_for_kbit_training(backbone) + + lora_config = LoraConfig( + r=kwargs.get("lora_r", 16), + lora_alpha=kwargs.get("lora_alpha", 32), + target_modules=kwargs.get("lora_target_modules", ["query", "value"]), + lora_dropout=kwargs.get("lora_dropout", 0.05), + bias="none", + task_type="FEATURE_EXTRACTION", + ) + backbone = get_peft_model(backbone, lora_config) + + # 3. Handle Parameter Freezing (Linear Probing) + if train_mode == "frozen": + for param in backbone.parameters(): + param.requires_grad = False + backbone.eval() + + return backbone + + +class BaseFoundationModel(BaseEstimator, TransformerMixin): + """ + Abstract base class for all foundation models in fm_hub. + + Provides a standardized Scikit-Learn wrapper around skorch, enabling + automated training history collection, hardware-agnostic device management, + and diagnostic reporting. + """ + + def __init__( + self, + model_name: str, + train_mode: str = "frozen", + task: str = "classification", + device: Optional[str] = None, + **kwargs, + ): + valid_modes = {"frozen", "full", "lora", "qlora"} + if train_mode not in valid_modes: + raise ValueError( + f"train_mode must be one of {valid_modes}, got {train_mode}" + ) + + self.model_name = model_name + self.train_mode = train_mode + self.task = task + self.device = device + self.kwargs = kwargs + self.net_ = None + + @abstractmethod + def get_module_cls(self) -> Type[nn.Module]: + """Return the PyTorch Module class to be instantiated by skorch.""" + pass + + def fit( + self, X: np.ndarray, y: Optional[np.ndarray] = None + ) -> "BaseFoundationModel": + """ + Fit the foundation model or its head using skorch. + """ + from skorch import NeuralNetClassifier, NeuralNetRegressor + + self.device_ = self._resolve_device() + + # Infer output dimensionality + if self.task == "classification" and y is not None: + self.out_dim_ = len(np.unique(y)) + else: + self.out_dim_ = 1 + + net_cls = ( + NeuralNetClassifier if self.task == "classification" else NeuralNetRegressor + ) + + # Configure skorch Net + self.net_ = net_cls( + module=self.get_module_cls(), + module__model_name=self.model_name, + module__train_mode=self.train_mode, + module__output_dim=self.out_dim_, + device=self.device_, + **self._get_net_params(), + ) + + self.net_.fit(X, y) + return self + + def transform(self, X: np.ndarray) -> np.ndarray: + """ + Extract high-dimensional embeddings from the model. + """ + if self.net_ is None: + raise RuntimeError("Model must be fitted before transform.") + + return self.net_.forward(X, training=False, return_embeddings=True) + + def predict(self, X: np.ndarray) -> np.ndarray: + """Perform task-specific inference.""" + if self.net_ is None: + raise RuntimeError("Model must be fitted before predict.") + return self.net_.predict(X) + + def get_training_history(self) -> List[Dict[str, Any]]: + """Return the skorch training history.""" + return self.net_.history if self.net_ else [] + + def get_artifact_metadata(self) -> Dict[str, Any]: + """Diagnostic reporting for the training engine.""" + return { + "model_type": "foundation", + "model_card": { + "model_name": self.model_name, + "train_mode": self.train_mode, + "task": self.task, + }, + "history": self.get_training_history(), + } + + def _get_net_params(self) -> Dict[str, Any]: + """Package standard skorch parameters.""" + return { + "max_epochs": self.kwargs.get("max_epochs", 10), + "lr": self.kwargs.get("lr", 0.001), + "batch_size": self.kwargs.get("batch_size", 32), + } + + def _resolve_device(self) -> str: + """Determine the best available compute device.""" + if self.device: + return self.device + if torch.cuda.is_available(): + return "cuda" + if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): + return "mps" + return "cpu" diff --git a/coco_pipe/decoding/fm_hub/cbramod.py b/coco_pipe/decoding/fm_hub/cbramod.py new file mode 100644 index 0000000..38a6e61 --- /dev/null +++ b/coco_pipe/decoding/fm_hub/cbramod.py @@ -0,0 +1,141 @@ +""" +CBraMod Foundation Model Provider +=================================== +Implementation for the CBraMod EEG foundation model. +""" + +from typing import Optional, Type + +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download + +from .base import BaseFoundationModel, BaseTransformerModule, EmbeddingInfo +from .cbramod_src.cbramod import CBraMod as RawCBraMod + + +class CBraModModule(BaseTransformerModule): + """ + CBraMod wrapper module. + """ + + def __init__( + self, + model_name: str = "braindecode/cbramod-pretrained", + output_dim: int = 2, + train_mode: str = "frozen", + pooling: str = "mean", + token: Optional[str] = None, + patch_size: int = 200, + seq_len: int = 4, + **kwargs, + ): + super().__init__() + + self.pooling = pooling + self.train_mode = train_mode + self.patch_size = patch_size + self.seq_len = seq_len + + # 1. Instantiate raw CBraMod + self.backbone = RawCBraMod( + in_dim=self.patch_size, + out_dim=200, + d_model=200, + dim_feedforward=800, + seq_len=self.seq_len, + n_layer=12, + nhead=8, + ) + + # 2. Download and load weights from HuggingFace Hub + weights_path = hf_hub_download( + repo_id=model_name, filename="pytorch_model.bin", token=token + ) + state_dict = torch.load(weights_path, map_location="cpu") + self.backbone.load_state_dict(state_dict, strict=False) + + # 3. Handle Parameter Freezing (Linear Probing) + if self.train_mode == "frozen": + for param in self.backbone.parameters(): + param.requires_grad = False + self.backbone.eval() + + # 4. Task-Specific Head + self.backbone.proj_out = nn.Identity() + feat_dim = 200 # CBraMod's d_model dimension + self.head = nn.Linear(feat_dim, output_dim) + + def forward( + self, X: torch.Tensor, return_embeddings: bool = False, **kwargs + ) -> torch.Tensor: + """ + Forward pass for CBraMod. + + Parameters + ---------- + X : torch.Tensor + Input EEG data of shape (batch, channels, time). + return_embeddings : bool + If True, returns the pooled representation instead of logits. + """ + bz, ch_num, time = X.shape + + # Reshape to (batch, channels, patch_num, patch_size) + patch_num = time // self.patch_size + X_reshaped = X.contiguous().view(bz, ch_num, patch_num, self.patch_size) + + # Transform backbone pass + # hidden shape: (batch, channels, patch_num, d_model) + hidden = self.backbone(X_reshaped) + + # Spatial and Temporal Pooling + if self.pooling == "mean": + # Pool over channels (dim 1) and patches (dim 2) + pooled = hidden.mean(dim=(1, 2)) + else: + # First token pooling (mean over channels, first patch) + pooled = hidden.mean(dim=1)[:, 0, :] + + if return_embeddings: + return pooled + + return self.head(pooled) + + +class CBraModModel(BaseFoundationModel): + """ + Unified CBraMod wrapper for the coco-pipe decoding engine. + """ + + def __init__(self, **kwargs): + model_name = kwargs.pop("model_name", "braindecode/cbramod-pretrained") + super().__init__(model_name=model_name, **kwargs) + self.provider = "cbramod" + + def get_module_cls(self) -> Type[nn.Module]: + """Return the CBraModModule class.""" + return CBraModModule + + def _get_net_params(self) -> dict: + """Add CBraMod-specific parameters for skorch initialization.""" + params = super()._get_net_params() + params.update( + { + "module__pooling": self.kwargs.get("pooling", "mean"), + "module__token": self.kwargs.get("token"), + "module__patch_size": self.kwargs.get("patch_size", 200), + "module__seq_len": self.kwargs.get("seq_len", 4), + } + ) + return params + + def get_embedding_info(self) -> EmbeddingInfo: + """Return metadata about CBraMod embeddings.""" + return EmbeddingInfo( + n_embeddings=200, + embedding_name=f"CBraMod ({self.kwargs.get('pooling', 'mean')})", + provider="cbramod", + model_name=self.model_name, + sfreq=self.kwargs.get("sfreq", 200.0), + ) diff --git a/coco_pipe/decoding/fm_hub/cbramod_src/__init__.py b/coco_pipe/decoding/fm_hub/cbramod_src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/coco_pipe/decoding/fm_hub/cbramod_src/cbramod.py b/coco_pipe/decoding/fm_hub/cbramod_src/cbramod.py new file mode 100644 index 0000000..f520b16 --- /dev/null +++ b/coco_pipe/decoding/fm_hub/cbramod_src/cbramod.py @@ -0,0 +1,119 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .criss_cross_transformer import TransformerEncoderLayer, TransformerEncoder + + +class CBraMod(nn.Module): + def __init__(self, in_dim=200, out_dim=200, d_model=200, dim_feedforward=800, seq_len=30, n_layer=12, + nhead=8): + super().__init__() + self.patch_embedding = PatchEmbedding(in_dim, out_dim, d_model, seq_len) + encoder_layer = TransformerEncoderLayer( + d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True, norm_first=True, + activation=F.gelu + ) + self.encoder = TransformerEncoder(encoder_layer, num_layers=n_layer, enable_nested_tensor=False) + self.proj_out = nn.Sequential( + # nn.Linear(d_model, d_model*2), + # nn.GELU(), + # nn.Linear(d_model*2, d_model), + # nn.GELU(), + nn.Linear(d_model, out_dim), + ) + self.apply(_weights_init) + + def forward(self, x, mask=None): + patch_emb = self.patch_embedding(x, mask) + feats = self.encoder(patch_emb) + + out = self.proj_out(feats) + + return out + +class PatchEmbedding(nn.Module): + def __init__(self, in_dim, out_dim, d_model, seq_len): + super().__init__() + self.d_model = d_model + self.positional_encoding = nn.Sequential( + nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(19, 7), stride=(1, 1), padding=(9, 3), + groups=d_model), + ) + self.mask_encoding = nn.Parameter(torch.zeros(in_dim), requires_grad=False) + # self.mask_encoding = nn.Parameter(torch.randn(in_dim), requires_grad=True) + + self.proj_in = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=25, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)), + nn.GroupNorm(5, 25), + nn.GELU(), + + nn.Conv2d(in_channels=25, out_channels=25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)), + nn.GroupNorm(5, 25), + nn.GELU(), + + nn.Conv2d(in_channels=25, out_channels=25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)), + nn.GroupNorm(5, 25), + nn.GELU(), + ) + self.spectral_proj = nn.Sequential( + nn.Linear(101, d_model), + nn.Dropout(0.1), + # nn.LayerNorm(d_model, eps=1e-5), + ) + # self.norm1 = nn.LayerNorm(d_model, eps=1e-5) + # self.norm2 = nn.LayerNorm(d_model, eps=1e-5) + # self.proj_in = nn.Sequential( + # nn.Linear(in_dim, d_model, bias=False), + # ) + + + def forward(self, x, mask=None): + bz, ch_num, patch_num, patch_size = x.shape + if mask == None: + mask_x = x + else: + mask_x = x.clone() + mask_x[mask == 1] = self.mask_encoding + + mask_x = mask_x.contiguous().view(bz, 1, ch_num * patch_num, patch_size) + patch_emb = self.proj_in(mask_x) + patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model) + + mask_x = mask_x.contiguous().view(bz*ch_num*patch_num, patch_size) + spectral = torch.fft.rfft(mask_x, dim=-1, norm='forward') + spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, 101) + spectral_emb = self.spectral_proj(spectral) + # print(patch_emb[5, 5, 5, :]) + # print(spectral_emb[5, 5, 5, :]) + patch_emb = patch_emb + spectral_emb + + positional_embedding = self.positional_encoding(patch_emb.permute(0, 3, 1, 2)) + positional_embedding = positional_embedding.permute(0, 2, 3, 1) + + patch_emb = patch_emb + positional_embedding + + return patch_emb + + +def _weights_init(m): + if isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + + +if __name__ == '__main__': + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model = CBraMod(in_dim=200, out_dim=200, d_model=200, dim_feedforward=800, seq_len=30, n_layer=12, + nhead=8).to(device) + model.load_state_dict(torch.load('pretrained_weights/pretrained_weights.pth', + map_location=device)) + a = torch.randn((8, 16, 10, 200)).cuda() + b = model(a) + print(a.shape, b.shape) diff --git a/coco_pipe/decoding/fm_hub/cbramod_src/criss_cross_transformer.py b/coco_pipe/decoding/fm_hub/cbramod_src/criss_cross_transformer.py new file mode 100644 index 0000000..8f76374 --- /dev/null +++ b/coco_pipe/decoding/fm_hub/cbramod_src/criss_cross_transformer.py @@ -0,0 +1,219 @@ +import copy +from typing import Optional, Any, Union, Callable + +import torch +import torch.nn as nn +# import torch.nn.functional as F +import warnings +from torch import Tensor +from torch.nn import functional as F + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True): + super().__init__() + torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: Optional[bool] = None) -> Tensor: + + output = src + for mod in self.layers: + output = mod(output, src_mask=mask) + if self.norm is not None: + output = self.norm(output) + return output + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ['norm_first'] + + def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, + bias: bool = True, device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.self_attn_s = nn.MultiheadAttention(d_model//2, nhead // 2, dropout=dropout, + bias=bias, batch_first=batch_first, + **factory_kwargs) + self.self_attn_t = nn.MultiheadAttention(d_model//2, nhead // 2, dropout=dropout, + bias=bias, batch_first=batch_first, + **factory_kwargs) + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) + + self.norm_first = norm_first + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + + # We can't test self.activation in forward() in TorchScript, + # so stash some information about it instead. + if activation is F.relu or isinstance(activation, torch.nn.ReLU): + self.activation_relu_or_gelu = 1 + elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + self.activation_relu_or_gelu = 2 + else: + self.activation_relu_or_gelu = 0 + self.activation = activation + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, 'activation'): + self.activation = F.relu + + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: bool = False) -> Tensor: + + x = src + x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal) + x = x + self._ff_block(self.norm2(x)) + return x + + # self-attention block + def _sa_block(self, x: Tensor, + attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor: + bz, ch_num, patch_num, patch_size = x.shape + xs = x[:, :, :, :patch_size // 2] + xt = x[:, :, :, patch_size // 2:] + xs = xs.transpose(1, 2).contiguous().view(bz*patch_num, ch_num, patch_size // 2) + xt = xt.contiguous().view(bz*ch_num, patch_num, patch_size // 2) + xs = self.self_attn_s(xs, xs, xs, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False)[0] + xs = xs.contiguous().view(bz, patch_num, ch_num, patch_size//2).transpose(1, 2) + xt = self.self_attn_t(xt, xt, xt, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False)[0] + xt = xt.contiguous().view(bz, ch_num, patch_num, patch_size//2) + x = torch.concat((xs, xt), dim=3) + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError(f"activation should be relu/gelu, not {activation}") + +def _get_clones(module, N): + # FIXME: copy.deepcopy() is not defined on nn.module + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_seq_len( + src: Tensor, + batch_first: bool +) -> Optional[int]: + + if src.is_nested: + return None + else: + src_size = src.size() + if len(src_size) == 2: + # unbatched: S, E + return src_size[0] + else: + # batched: B, S, E if batch_first else S, B, E + seq_len_pos = 1 if batch_first else 0 + return src_size[seq_len_pos] + + +def _detect_is_causal_mask( + mask: Optional[Tensor], + is_causal: Optional[bool] = None, + size: Optional[int] = None, +) -> bool: + """Return whether the given attention mask is causal. + + Warning: + If ``is_causal`` is not ``None``, its value will be returned as is. If a + user supplies an incorrect ``is_causal`` hint, + + ``is_causal=False`` when the mask is in fact a causal attention.mask + may lead to reduced performance relative to what would be achievable + with ``is_causal=True``; + ``is_causal=True`` when the mask is in fact not a causal attention.mask + may lead to incorrect and unpredictable execution - in some scenarios, + a causal mask may be applied based on the hint, in other execution + scenarios the specified mask may be used. The choice may not appear + to be deterministic, in that a number of factors like alignment, + hardware SKU, etc influence the decision whether to use a mask or + rely on the hint. + ``size`` if not None, check whether the mask is a causal mask of the provided size + Otherwise, checks for any causal mask. + """ + # Prevent type refinement + make_causal = (is_causal is True) + + if is_causal is None and mask is not None: + sz = size if size is not None else mask.size(-2) + causal_comparison = _generate_square_subsequent_mask( + sz, device=mask.device, dtype=mask.dtype) + + # Do not use `torch.equal` so we handle batched masks by + # broadcasting the comparison. + if mask.size() == causal_comparison.size(): + make_causal = bool((mask == causal_comparison).all()) + else: + make_causal = False + + return make_causal + + +def _generate_square_subsequent_mask( + sz: int, + device: torch.device = torch.device(torch._C._get_default_device()), # torch.device('cpu'), + dtype: torch.dtype = torch.get_default_dtype(), +) -> Tensor: + r"""Generate a square causal mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). + """ + return torch.triu( + torch.full((sz, sz), float('-inf'), dtype=dtype, device=device), + diagonal=1, + ) + + +if __name__ == '__main__': + encoder_layer = TransformerEncoderLayer( + d_model=256, nhead=4, dim_feedforward=1024, batch_first=True, norm_first=True, + activation=F.gelu + ) + encoder = TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False) + encoder = encoder.cuda() + + a = torch.randn((4, 19, 30, 256)).cuda() + b = encoder(a) + print(a.shape, b.shape) \ No newline at end of file diff --git a/coco_pipe/decoding/fm_hub/reve.py b/coco_pipe/decoding/fm_hub/reve.py new file mode 100644 index 0000000..0d29c5b --- /dev/null +++ b/coco_pipe/decoding/fm_hub/reve.py @@ -0,0 +1,130 @@ +""" +REVE Foundation Model Provider +============================== +Implementation for the Representation for EEG with Versatile Embeddings (REVE). +""" + +from typing import List, Optional, Type + +import torch +import torch.nn as nn + +from .base import BaseFoundationModel, BaseTransformerModule, EmbeddingInfo + + +class REVEModule(BaseTransformerModule): + """ + Pure PyTorch implementation of the REVE architecture. + + This module handles EEG-specific positional encoding using a position bank + and performs feature extraction through a Transformer backbone. + """ + + def __init__( + self, + model_name: str = "brain-bzh/reve-large", + output_dim: int = 2, + electrode_names: Optional[List[str]] = None, + train_mode: str = "frozen", + pooling: str = "mean", + token: Optional[str] = None, + **kwargs, + ): + super().__init__() + from transformers import AutoModel + + self.pooling = pooling + self.electrode_names = electrode_names + + # 1. Initialize Backbone (Generic Transformer logic) + self.backbone = self.load_backbone(model_name, train_mode, token, **kwargs) + + # 2. REVE-Specific: Position Bank + # The position bank maps electrode names to spatial embeddings + hf_kwargs = {"trust_remote_code": True, "token": token} + self.pos_bank = AutoModel.from_pretrained( + "brain-bzh/reve-positions", **hf_kwargs + ) + + # 3. Task-Specific Head + feat_dim = ( + self.backbone.config.hidden_size + if hasattr(self.backbone, "config") + else 1024 + ) + self.head = nn.Linear(feat_dim, output_dim) + + def forward( + self, X: torch.Tensor, return_embeddings: bool = False, **kwargs + ) -> torch.Tensor: + """ + Forward pass for REVE. + + Parameters + ---------- + X : torch.Tensor + Input EEG data of shape (batch, channels, time). + return_embeddings : bool + If True, returns the 1024D pooled representation instead of logits. + """ + # Resolve electrode names for positional encoding + n_channels = X.shape[1] + elec = self.electrode_names or [f"e{i}" for i in range(n_channels)] + + # Lookup spatial positions and expand to batch size + pos = self.pos_bank(elec).unsqueeze(0).expand(len(X), -1, -1) + + # Transform backbone pass + out = self.backbone(X, pos) + hidden = out.last_hidden_state + + # Temporal Pooling + if self.pooling == "mean": + pooled = hidden.mean(dim=1) + else: + pooled = hidden[:, 0, :] # CLS-style pooling + + if return_embeddings: + return pooled + + return self.head(pooled) + + +class REVEModel(BaseFoundationModel): + """ + Unified REVE Model wrapper for the coco-pipe decoding engine. + + Provides a Scikit-Learn compatible interface for REVE-large, supporting + frozen extraction, linear probing, and parameter-efficient fine-tuning (LoRA). + """ + + def __init__(self, **kwargs): + model_name = kwargs.pop("model_name", "brain-bzh/reve-large") + super().__init__(model_name=model_name, **kwargs) + self.provider = "reve" + + def get_module_cls(self) -> Type[nn.Module]: + """Return the REVEModule class.""" + return REVEModule + + def _get_net_params(self) -> dict: + """Add REVE-specific parameters for skorch initialization.""" + params = super()._get_net_params() + params.update( + { + "module__electrode_names": self.kwargs.get("electrode_names"), + "module__pooling": self.kwargs.get("pooling", "mean"), + "module__token": self.kwargs.get("token"), + } + ) + return params + + def get_embedding_info(self) -> EmbeddingInfo: + """Return metadata about REVE embeddings.""" + return EmbeddingInfo( + n_embeddings=1024, + embedding_name=f"REVE ({self.kwargs.get('pooling', 'mean')})", + provider="reve", + model_name=self.model_name, + sfreq=self.kwargs.get("sfreq", 200.0), + ) diff --git a/coco_pipe/decoding/interfaces.py b/coco_pipe/decoding/interfaces.py new file mode 100644 index 0000000..dd89eb6 --- /dev/null +++ b/coco_pipe/decoding/interfaces.py @@ -0,0 +1,233 @@ +""" +Lightweight public interfaces for decoding estimator families. +============================================================ + +These protocols define the structural contracts for models and extractors +used within the decoding pipeline. Since they are runtime-checkable, +the pipeline can verify capabilities without strict inheritance from +scikit-learn base classes. +""" + +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class DecoderEstimator(Protocol): + """ + Protocol for scikit-learn-compatible decoding estimators. + + This interface defines the minimal set of methods required for an + estimator to be integrated into the cross-validation engine. + """ + + def fit(self, X: Any, y: Any = None, **fit_params: Any) -> DecoderEstimator: + """ + Fit the estimator to the provided data. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training vector, where n_samples is the number of samples and + n_features is the number of features. + y : array-like of shape (n_samples,), optional + Target values (class labels in classification, real numbers in + regression). + **fit_params : dict + Parameters to pass to the underlying fit method. + + Returns + ------- + self : DecoderEstimator + The fitted estimator. + """ + ... # pragma: no cover + + def predict(self, X: Any) -> Any: + """ + Predict targets for the provided data. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Samples to predict. + + Returns + ------- + y_pred : array-like of shape (n_samples,) + Predicted target values per sample. + """ + ... # pragma: no cover + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """ + Get parameters for this estimator. + + Parameters + ---------- + deep : bool, default=True + If True, will return the parameters for this estimator and + contained subobjects that are estimators. + + Returns + ------- + params : dict + Parameter names mapped to their values. + """ + ... # pragma: no cover + + def set_params(self, **params: Any) -> DecoderEstimator: + """ + Set the parameters of this estimator. + + Parameters + ---------- + **params : dict + Estimator parameters. + + Returns + ------- + self : DecoderEstimator + The estimator instance. + """ + ... # pragma: no cover + + +@runtime_checkable +class EmbeddingExtractor(Protocol): + """ + Interface for pretrained or frozen feature extraction backbones. + + Embedding extractors typically represent foundation models or frozen + neural networks that transform raw data into a fixed-dimensional + vector space before classical decoding. + """ + + def transform(self, X: Any) -> Any: + """ + Extract features from the provided data. + + Parameters + ---------- + X : array-like + The raw data to be transformed. + + Returns + ------- + embeddings : array-like + The extracted feature vectors. + """ + ... # pragma: no cover + + def get_embedding_info(self) -> dict[str, Any]: + """ + Return technical metadata about the extractor and its output space. + + Returns + ------- + info : dict + A dictionary containing provider name, model name, pooling + strategy, and output dimensionality. + """ + ... # pragma: no cover + + +@runtime_checkable +class NeuralTrainable(Protocol): + """ + Interface for trainable neural estimators with diagnostic metadata. + + This protocol exposes internal training states and histories for + reporting and verification purposes. + """ + + def get_training_history(self) -> list[dict[str, Any]]: + """ + Get the step-by-step training history (e.g., loss per epoch). + + Returns + ------- + history : list of dict + A list of diagnostic records, one per training iteration. + """ + ... # pragma: no cover + + def get_checkpoint_manifest(self) -> dict[str, Any]: + """ + Get information about saved model checkpoints. + + Returns + ------- + manifest : dict + Metadata including checkpoint paths and best-epoch indices. + """ + ... # pragma: no cover + + def get_model_card_info(self) -> dict[str, Any]: + """ + Get high-level model card metadata for the artifact registry. + + Returns + ------- + info : dict + Information about model architecture, training configuration, + and hyperparameters. + """ + ... # pragma: no cover + + def get_failure_diagnostics(self) -> dict[str, Any]: + """ + Get technical diagnostics if training failed or diverged. + + Returns + ------- + diagnostics : dict + Information about gradients, NaN detection, or hardware state. + """ + ... # pragma: no cover + + def get_artifact_metadata(self) -> dict[str, Any]: + """ + Aggregate all diagnostic metadata into a single dictionary. + + Returns + ------- + metadata : dict + A serializable dictionary containing history, model card, and checkpoints. + """ + ... # pragma: no cover + + +@runtime_checkable +class StagedTrainable(Protocol): + """ + Interface for estimators that support multi-stage training schedules. + """ + + def set_train_stage(self, stage: str) -> StagedTrainable: + """ + Configure the active training stage (e.g., 'pretrain', 'finetune'). + + Parameters + ---------- + stage : str + The name of the training stage to activate. + + Returns + ------- + self : StagedTrainable + The estimator instance. + """ + ... # pragma: no cover + + def get_train_stage(self) -> str: + """ + Get the name of the currently active training stage. + + Returns + ------- + stage : str + The active stage name. + """ + ... # pragma: no cover diff --git a/coco_pipe/decoding/registry.py b/coco_pipe/decoding/registry.py index 71d550d..73af924 100644 --- a/coco_pipe/decoding/registry.py +++ b/coco_pipe/decoding/registry.py @@ -1,102 +1,92 @@ """ -Decoding Registry -================= - -Central registry for decoding estimators (classifiers, regressors, and FMs). -This allows instantiating models from string names in configuration files, -avoiding circular imports and simplifying the config layer. - -Usage ------ ->>> from coco_pipe.decoding.registry import register_estimator, get_estimator_cls ->>> ->>> @register_estimator("MyModel") ->>> class MyModel: ... ->>> ->>> cls = get_estimator_cls("MyModel") +Decoding Registry Engine +======================== + +Central engine for resolving, registering, and lazy-loading decoding +estimators. This allows configurations to refer to models by name +without triggering eager imports of heavyweight dependencies. """ +from __future__ import annotations + +import difflib import importlib import pkgutil +import threading import warnings -from importlib.metadata import entry_points -from typing import Callable, Dict, Type - -# Registry Storage -# Maps string alias -> class object +from dataclasses import replace +from typing import Any, Callable, Dict, Type + +from ._specs import ( + ESTIMATOR_SPECS, + SELECTOR_CAPABILITIES, + EstimatorCapabilities, + EstimatorSpec, + SelectorCapabilities, + canonical_estimator_name, +) + +# Runtime class cache _ESTIMATOR_REGISTRY: Dict[str, Type] = {} +_INTERNAL_SCANNED = False +_REGISTRY_LOCK = threading.Lock() + + +class EstimatorNotFoundError(KeyError, ValueError): + """Raised when an estimator is not found in the registry.""" + pass + + +def _discover_entry_points(): # pragma: no cover + """Import 'coco_pipe.estimators' entry points.""" + try: + from importlib.metadata import entry_points + + try: + eps = entry_points(group="coco_pipe.estimators") + except TypeError: + eps = entry_points().get("coco_pipe.estimators", []) + except Exception: + return -_LAZY_MODULES = { - # MNE - "SlidingEstimator": "mne.decoding", - "GeneralizingEstimator": "mne.decoding", - # Classifiers - "LogisticRegression": "sklearn.linear_model", - "RandomForestClassifier": "sklearn.ensemble", - "SVC": "sklearn.svm", - "KNeighborsClassifier": "sklearn.neighbors", - "GradientBoostingClassifier": "sklearn.ensemble", - "SGDClassifier": "sklearn.linear_model", - "MLPClassifier": "sklearn.neural_network", - "GaussianNB": "sklearn.naive_bayes", - "LDA": "sklearn.discriminant_analysis", - "AdaBoostClassifier": "sklearn.ensemble", - "DummyClassifier": "sklearn.dummy", - # Regressors - "LinearRegression": "sklearn.linear_model", - "Ridge": "sklearn.linear_model", - "Lasso": "sklearn.linear_model", - "ElasticNet": "sklearn.linear_model", - "RandomForestRegressor": "sklearn.ensemble", - "SVR": "sklearn.svm", - "ARDRegression": "sklearn.linear_model", -} - - -def _discover_entry_points(): - """ - Populate _LAZY_MODULES from 'coco_pipe.estimators' entry points. - This allows plugins to register estimators without modifying code. - """ - eps = entry_points(group="coco_pipe.estimators") for ep in eps: - if ep.name not in _LAZY_MODULES: - _LAZY_MODULES[ep.name] = ep.value + try: + ep.load() + except Exception: + warnings.warn(f"Could not load estimator entry point '{ep.name}'") -def _discover_internal_modules(): - """ - Walk through the 'coco_pipe.decoding' subpackage and import all modules. - This triggers the @register_estimator decorators. - """ +def _discover_internal_modules(): # pragma: no cover + """Import all internal decoding submodules to trigger decorators.""" package = importlib.import_module("coco_pipe.decoding") if not hasattr(package, "__path__"): return - for _, name, ispkg in pkgutil.walk_packages( - package.__path__, package.__name__ + "." - ): + for _, name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."): try: importlib.import_module(name) except ImportError: - # warn but continue - we don't want to crash if deep learning libs are - # missing pass -# 1. Load Entry Points on startup (lazy map update only) +# Lazy entry point discovery _discover_entry_points() def register_estimator(name: str) -> Callable[[Type], Type]: """ - Decorator to register an estimator class under a specific name. + Decorator to register a custom estimator class under a specific name. Parameters ---------- name : str - The unique alias for the estimator (e.g., "RandomForestClassifier"). + The unique name to register the estimator under. + + Returns + ------- + Callable[[Type], Type] + A decorator that adds the class to the internal registry. """ def decorator(cls: Type) -> Type: @@ -108,76 +98,330 @@ def decorator(cls: Type) -> Type: return decorator +def register_estimator_spec(spec: EstimatorSpec) -> EstimatorSpec: + """ + Register or replace an estimator spec in the global specs registry. + + This allows the execution engine to support new model types by name, + defining their capabilities, required input formats, and importance + extraction logic. + + Parameters + ---------- + spec : EstimatorSpec + The typed specification object for the estimator. + + Returns + ------- + EstimatorSpec + The registered specification object. + + See Also + -------- + get_estimator_spec : Retrieve a registered specification. + """ + if spec.name in ESTIMATOR_SPECS: + warnings.warn(f"Overwriting existing estimator spec for '{spec.name}'") + ESTIMATOR_SPECS[spec.name] = spec + return spec + + def get_estimator_cls(name: str) -> Type: """ - Retrieve an estimator class by name. + Retrieve an estimator class by name, triggering lazy loading if needed. Parameters ---------- name : str - Name of the estimator. + The canonical name of the estimator (e.g., 'LogisticRegression'). Returns ------- Type - The class object. + The uninstantiated estimator class. Raises ------ - ValueError - If name is not found. + EstimatorNotFoundError + If the estimator is unknown. + ImportError + If the underlying module cannot be imported. + + Examples + -------- + >>> from coco_pipe.decoding.registry import get_estimator_cls + >>> cls = get_estimator_cls('LogisticRegression') + + See Also + -------- + get_estimator_spec : Retrieve the metadata for an estimator. """ - # 1. Check if already loaded if name in _ESTIMATOR_REGISTRY: return _ESTIMATOR_REGISTRY[name] - # 2. Try Lazy Loading Map - if name in _LAZY_MODULES: + # Try lazy import from spec + spec = ESTIMATOR_SPECS.get(name) + if spec is not None: try: - mod_path = _LAZY_MODULES[name] - if ":" in mod_path: - mod_path = mod_path.split(":")[0] - - module = importlib.import_module(mod_path) + module = importlib.import_module(spec.module_path) except ImportError as e: raise ImportError( - f"Could not load estimator '{name}' from '{_LAZY_MODULES[name]}'. " + f"Could not load estimator '{name}' from '{spec.import_path}'. " f"Ensure optional dependencies are installed." ) from e - if hasattr(module, name): - cls = getattr(module, name) + if hasattr(module, spec.class_name): + cls = getattr(module, spec.class_name) _ESTIMATOR_REGISTRY[name] = cls return cls - # Check if the import triggered a decorator registration - if name in _ESTIMATOR_REGISTRY: + if name in _ESTIMATOR_REGISTRY: # pragma: no cover return _ESTIMATOR_REGISTRY[name] - # 3. Last Ditch: Internal Discovery - if not getattr(get_estimator_cls, "_internal_scanned", False): - _discover_internal_modules() - setattr(get_estimator_cls, "_internal_scanned", True) - if name in _ESTIMATOR_REGISTRY: + # Try internal discovery + global _INTERNAL_SCANNED + if not _INTERNAL_SCANNED: + with _REGISTRY_LOCK: + if not _INTERNAL_SCANNED: + _discover_internal_modules() # pragma: no cover + _INTERNAL_SCANNED = True # pragma: no cover + if name in _ESTIMATOR_REGISTRY: # pragma: no cover return _ESTIMATOR_REGISTRY[name] if name not in _ESTIMATOR_REGISTRY: - # Generate helpful error - available = sorted(list(_ESTIMATOR_REGISTRY.keys())) - raise ValueError( - f"Estimator '{name}' not found in registry.\n" - f"Available estimators: {available}\n" - f"Tip: Ensure the containing module is imported or registered via " - f"entry points." + available = sorted(set(_ESTIMATOR_REGISTRY) | set(ESTIMATOR_SPECS)) + matches = difflib.get_close_matches(name, available, n=3, cutoff=0.6) + msg = f"Estimator '{name}' not found in registry." + if matches: + msg += f" Did you mean: {matches}?" + msg += f"\nAvailable estimators: {available[:10]}... (Total: {len(available)})" + raise EstimatorNotFoundError(msg) + + return _ESTIMATOR_REGISTRY[name] # pragma: no cover + + +def get_estimator_spec(name: str) -> EstimatorSpec: + """ + Return the typed estimator spec for a given name. + + Parameters + ---------- + name : str + The canonical name of the estimator (e.g., 'LogisticRegression'). + + Returns + ------- + EstimatorSpec + The registered specification object. + + Raises + ------ + ValueError + If no specification is registered for the given name. + + See Also + -------- + get_capabilities : Retrieve derived lightweight capabilities. + """ + if name not in ESTIMATOR_SPECS: + raise ValueError(f"No decoding estimator spec registered for '{name}'.") + return ESTIMATOR_SPECS[name] + + +def get_capabilities(name: str) -> EstimatorCapabilities: + """ + Return machine-readable capability metadata for a given estimator. + + Parameters + ---------- + name : str + The canonical name of the estimator. + + Returns + ------- + EstimatorCapabilities + Lightweight metadata summary for validation. + """ + return get_estimator_spec(name).to_capabilities() + + +def list_capabilities() -> Dict[str, EstimatorCapabilities]: + """ + Return capability metadata for all registered estimators. + + Returns + ------- + Dict[str, EstimatorCapabilities] + A dictionary mapping estimator names to their capability objects. + + See Also + -------- + get_capabilities : Retrieve capabilities for a single estimator. + """ + return {name: spec.to_capabilities() for name, spec in ESTIMATOR_SPECS.items()} + + +def list_estimator_specs() -> Dict[str, EstimatorSpec]: + """ + Return all registered estimator specs. + + Returns + ------- + Dict[str, EstimatorSpec] + A dictionary mapping estimator names to their full specification objects. + + See Also + -------- + get_estimator_spec : Retrieve a single estimator specification. + """ + return dict(ESTIMATOR_SPECS) + + +def _get_val(obj: Any, key: str, default: Any = None) -> Any: + """ + Retrieve a value from a configuration object or dictionary. + + This helper provides a unified interface for accessing parameters from + both Pydantic models and raw dictionaries, which is common when + handling polymorphic or nested estimator configurations. + + Parameters + ---------- + obj : Any + The configuration source (dict or object). + key : str + The parameter name to retrieve. + default : Any, optional + The default value if not found. + + Returns + ------- + Any + The retrieved value. + """ + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + +def resolve_estimator_spec(config: Any) -> EstimatorSpec: + """ + Resolve a hydrated EstimatorSpec from a model configuration. + + This function handles polymorphic model types (Foundation, Temporal, + Neural) and applies runtime parameter fixups based on the user's + specific configuration (e.g., handling SVC probability flags). + + Parameters + ---------- + config : Any + A configuration object (typically a Pydantic model) containing + the 'kind' and specific estimator parameters. + + Returns + ------- + EstimatorSpec + A hydrated spec object containing accurate flags for the training engine. + + Examples + -------- + >>> from coco_pipe.decoding.configs import LogisticRegressionConfig + >>> from coco_pipe.decoding.registry import resolve_estimator_spec + >>> config = LogisticRegressionConfig() + >>> spec = resolve_estimator_spec(config) + + See Also + -------- + resolve_estimator_capabilities : Lightweight capability summary. + """ + # 1. Temporal Wrapper logic (Needs special handling for base spec) + kind = _get_val(config, "kind", "classical") + if kind == "temporal": + base_spec = resolve_estimator_spec(_get_val(config, "base")) + wrapper = _get_val(config, "wrapper", "sliding") + wrapper_name = ( + "SlidingEstimator" if wrapper == "sliding" else "GeneralizingEstimator" + ) + return replace( + get_estimator_spec(wrapper_name), + task=base_spec.task, + supports_proba=base_spec.supports_proba, + supports_decision_function=base_spec.supports_decision_function, + supports_calibration=base_spec.supports_calibration, + supports_feature_names=False, ) - return _ESTIMATOR_REGISTRY[name] + if kind == "classical": + name_val = _get_val(config, "method") + if name_val == "ClassicalModel": + name_val = _get_val(config, "estimator") + spec_name = canonical_estimator_name(name_val or str(config)) + elif kind == "foundation_embedding": + spec_name = _get_val(config, "provider", kind) + else: + spec_name = kind + + spec = get_estimator_spec(spec_name or str(config)) + + # Runtime Fixups (Original logic) + if spec.name == "SVC" and not _get_val(config, "probability", True): + spec = replace(spec, supports_proba=False, supports_decision_function=True) + + if spec.name == "SGDClassifier" and _get_val(config, "loss") in { + "log_loss", + "modified_huber", + }: + spec = replace(spec, supports_proba=True, supports_decision_function=True) + + return spec + + +def resolve_estimator_capabilities(config: Any) -> EstimatorCapabilities: + """ + Resolve lightweight capabilities derived from ``resolve_estimator_spec``. + + This is a convenience wrapper that first resolves the full spec (handling + polymorphism and runtime fixups) and then converts it to the capability + summary used by the engine for validation. + + Parameters + ---------- + config : Any + The model configuration object. + + Returns + ------- + EstimatorCapabilities + The resolved capability metadata. + + See Also + -------- + resolve_estimator_spec : The underlying spec resolution logic. + """ + return resolve_estimator_spec(config).to_capabilities() + + +def get_selector_capabilities(method: str) -> SelectorCapabilities: + """ + Return feature-selector capabilities for a given name. + + Parameters + ---------- + method : str + The name of the feature selection method (e.g., 'k_best', 'sfs'). + Returns + ------- + SelectorCapabilities + The registered capability metadata for the selector. -def list_estimators() -> Dict[str, Type]: - """Return a copy of the current registry.""" - # Ensure everything is discovered before listing - if not getattr(get_estimator_cls, "_internal_scanned", False): - _discover_internal_modules() - setattr(get_estimator_cls, "_internal_scanned", True) - return dict(_ESTIMATOR_REGISTRY) + Raises + ------ + ValueError + If the method name is not found in the SELECTOR_CAPABILITIES registry. + """ + if method not in SELECTOR_CAPABILITIES: + raise ValueError( + f"No decoding capabilities registered for selector '{method}'." + ) + return SELECTOR_CAPABILITIES[method] diff --git a/coco_pipe/decoding/result.py b/coco_pipe/decoding/result.py new file mode 100644 index 0000000..1a7254f --- /dev/null +++ b/coco_pipe/decoding/result.py @@ -0,0 +1,2159 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any, Dict, Optional, Sequence, Union + +import numpy as np +import pandas as pd + +from ._constants import RESULT_SCHEMA_VERSION +from ._diagnostics import ( + confusion_matrix_frame, + curve_score_groups, + prediction_rows, + proba_matrix, + scalar_prediction_frame, + score_rows, +) + +logger = logging.getLogger(__name__) + + +class ExperimentResult: + """ + Unified Container for Experiment Results. + + Provides tidy data views for easier analysis, visualization, and + statistical assessment of decoding performance across multiple models, + folds, and temporal coordinates. + + Examples + -------- + >>> result = Experiment(config).run(X, y) + >>> summary_df = result.summary() + >>> preds_df = result.get_predictions() + """ + + def __init__( + self, + raw_results: Dict[str, Any], + config: Optional[Dict[str, Any]] = None, + meta: Optional[Dict[str, Any]] = None, + time_axis: Optional[Sequence[Any]] = None, + schema_version: str = RESULT_SCHEMA_VERSION, + ): + """ + Initialize the ExperimentResult container. + + This object serves as the primary interface for exploring, visualizing, + and validating decoding results. It encapsulates raw metrics, + predictions, and cross-validation splits into a unified structure. + + Parameters + ---------- + raw_results : dict + The raw results dictionary returned by the Experiment engine, + keyed by model name. + config : dict, optional + The configuration dictionary used for the experiment. + meta : dict, optional + Additional metadata (e.g., sample IDs, unit of inference, versions). + time_axis : sequence, optional + The scientific time points (e.g., in seconds or ms) corresponding + to the temporal coordinates in the results. + schema_version : str, optional + The version of the result schema for forward compatibility. + """ + self.raw = raw_results + self.config = config or {} + self.meta = meta or {} + self.schema_version = schema_version + + # Explicit time axis resolution + self._time_axis_cache = None + if time_axis is not None: + self._time_axis_cache = list(time_axis) + elif "time_axis" in self.meta: + t = self.meta["time_axis"] + self._time_axis_cache = list(t) if t is not None else None + + @property + def time_axis(self) -> Optional[list[Any]]: + """The scientific time points for temporal decoding results.""" + return self._time_axis_cache + + def to_payload(self, serializable: bool = False) -> Dict[str, Any]: + """ + Return the result payload for persistence or transmission. + + Converts the internal state into a dictionary containing the schema + version, configuration, metadata, and raw model results. + + Parameters + ---------- + serializable : bool, default=False + If True, recursively converts all NumPy arrays, integers, floats, + and booleans into standard Python primitives (lists, ints, etc.) + suitable for JSON serialization. + + Returns + ------- + payload : dict + The consolidated result payload. + + See Also + -------- + ExperimentResult.save : Persist results to disk. + """ + payload = { + "schema_version": self.schema_version, + "config": self.config, + "meta": self.meta, + "results": self.raw, + } + return make_serializable(payload) if serializable else payload + + def save(self, path: Optional[Union[str, Path, Any]] = None, indent: int = 2): + """ + Save results to a file, auto-detecting the format from the extension. + + Supports both binary formats (via joblib) for speed and disk space, + and JSON format for interoperability and human-readability. + + Parameters + ---------- + path : str or Path, optional + The destination path. + - If None, uses 'output_dir' from the experiment config. + - If a directory, generates a timestamped filename with a '.pkl' extension. + - If a file path ending in '.json', performs JSON serialization. + - Otherwise, uses joblib binary serialization. + indent : int, default=2 + JSON indentation level (only applicable for .json files). + + Returns + ------- + path : Path + The path where the results were saved. + + See Also + -------- + ExperimentResult.load : Load results from disk. + Experiment.save_results : Experiment-level wrapper. + """ + from datetime import datetime + + if path is None: + path = self.config.get("output_dir", ".") + + path = Path(path) + if path.suffix == "" or path.is_dir(): + path.mkdir(parents=True, exist_ok=True) + tag = self.config.get("tag", "result") + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + path = path / f"{tag}_{ts}.pkl" + else: + path.parent.mkdir(parents=True, exist_ok=True) + + if path.suffix == ".json": + import json + + payload = self.to_payload(serializable=True) + with open(path, "w") as f: + json.dump(payload, f, indent=indent) + else: + import joblib + + joblib.dump(self.to_payload(), path) + + return path + + @classmethod + def load(cls, path: Union[str, Path, Any]) -> "ExperimentResult": + """ + Load results from a file (auto-detects JSON or Pickle). + + Reconstructs an ExperimentResult instance from a previously saved + payload on disk. + + Parameters + ---------- + path : str or Path + The path to the result file. + + Returns + ------- + result : ExperimentResult + The rehydrated result container. + + Raises + ------ + FileNotFoundError + If the specified path does not exist. + ValueError + If the file format is unrecognized or corrupted. + + See Also + -------- + ExperimentResult.save : Persist results to disk. + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Result file not found: {path}") + + if path.suffix == ".json": + import json + + with open(path, "r") as f: + payload = json.load(f) + else: + import joblib + + payload = joblib.load(path) + + return cls( + raw_results=payload["results"], + config=payload.get("config"), + meta=payload.get("meta"), + schema_version=payload.get("schema_version", RESULT_SCHEMA_VERSION), + ) + + def summary(self) -> pd.DataFrame: + """ + Get a high-level summary of performance (Mean/Std and Stats). + + Aggregates results across all models and folds into a single + benchmarking table. + + Scientific Rationale + -------------------- + A summary table provides a concise overview of model performance + expectations and their reliability. By including standard deviations + and p-values alongside means, it allows for immediate identification + of significant decoding effects and model stability. + + Returns + ------- + summary_df : pd.DataFrame + DataFrame with models as index and scalar metrics as columns. + Includes p-values and significance markers ('*') if statistical + assessments were executed. + + Examples + -------- + >>> # df = result.summary() + >>> # print(df[['accuracy_mean', 'accuracy_p_val']]) + + See Also + -------- + ExperimentResult.get_detailed_scores : Get fold-level results. + ExperimentResult.get_temporal_score_summary : Temporal-resolved summary. + """ + rows = [] + for model, res in self.raw.items(): + if "error" in res: + continue + + # 1. Performance Metrics + row = {"Model": model} + for metric, stats in res.get("metrics", {}).items(): + mean = np.asarray(stats["mean"]) + std = np.asarray(stats["std"]) + if mean.ndim == 0 and std.ndim == 0: + row[f"{metric}_mean"] = float(mean) + row[f"{metric}_std"] = float(std) + + # 2. Statistical Assessment (if available) + stats_rows = res.get("statistical_assessment", []) + for s in stats_rows: + # Only include scalar stats (where Time/TrainTime are None) + if s.get("Time") is None and s.get("TrainTime") is None: + m = s.get("Metric") + p_val = s.get("PValue") + if p_val is not None: + row[f"{m}_p_val"] = p_val + # Add significance marker + if s.get("Significant"): + row[f"{m}_sig"] = "*" + + if len(row) > 1: + rows.append(row) + + if not rows: + return pd.DataFrame() + + df = pd.DataFrame(rows).set_index("Model") + # Sort columns to group Mean/Std/P-val together for each metric + cols = sorted(df.columns) + return df[cols] + + def get_detailed_scores(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get fold-level scores for all models or a specific model in long format. + + Expands results into a 'tidy' format where each row represents a + single score for one fold, model, and metric. + + Parameters + ---------- + model : str, optional + The name of the model to filter by. Default is None (all models). + + Returns + ------- + scores_df : pd.DataFrame + Tidy DataFrame with columns: Model, Fold, Metric, and Value. + Includes temporal coordinates if the data is time-resolved. + + See Also + -------- + ExperimentResult.summary : Mean/Std aggregate view. + """ + rows = [] + target_models = [model] if model is not None else self.raw.keys() + for m_name in target_models: + res = self.raw.get(m_name, {}) + if "error" in res: + continue + metrics_data = res["metrics"] + n_folds = len(next(iter(metrics_data.values()))["folds"]) + for fold_idx in range(n_folds): + for metric, stats in metrics_data.items(): + rows.extend( + score_rows( + m_name, + fold_idx, + metric, + stats["folds"][fold_idx], + time_axis=self.time_axis, + ) + ) + return pd.DataFrame(rows) + + def get_temporal_score_summary(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get temporal metric means/stds and significance across folds. + + Averages performance metrics across cross-validation folds for each + temporal coordinate (Time or TrainTime/TestTime pair). + + Scientific Rationale + -------------------- + Temporal decoding and time-generalization analysis yield multi-dimensional + performance arrays. Aggregating these across folds provides an estimate of + the central tendency and variance of the model's ability to decode at + specific latency points. Integrating p-values into this view allows for + the identification of 'significant' time windows. + + Parameters + ---------- + model : str, optional + The name of the model to filter by. Default is None (all models). + + Returns + ------- + summary_df : pd.DataFrame + DataFrame in long format with Model, Metric, Time coordinates, + Mean, Std, and PValue/Significant columns if statistical + assessments were executed. + + See Also + -------- + ExperimentResult.summary : Scalar-only summary view. + ExperimentResult.get_generalization_matrix : 2D matrix view of TG results. + """ + rows = [] + columns = [ + "Model", + "Metric", + "Time", + "TrainTime", + "TestTime", + "Mean", + "Std", + "PValue", + "Significant", + ] + + target_models = [model] if model is not None else self.raw.keys() + for m_name in target_models: + res = self.raw.get(m_name, {}) + if "error" in res: + continue + + # 1. Base Metrics (Mean/Std across folds) + for metric, stats in res.get("metrics", {}).items(): + folds = [np.asarray(fold) for fold in stats.get("folds", [])] + if not folds or folds[0].ndim == 0: + continue + stack = np.stack(folds) + mean = np.nanmean(stack, axis=0) + std = np.nanstd(stack, axis=0) + + # Prepare stats lookup for this model/metric + # Using a dict for O(1) lookup: (Time) or (TrainTime, TestTime) -> row + stats_rows = res.get("statistical_assessment", []) + stats_lookup = {} + for s in stats_rows: + if s.get("Metric") == metric: + key = (s.get("Time"), s.get("TrainTime"), s.get("TestTime")) + stats_lookup[key] = s + + if mean.ndim == 1: + for t_idx, val in enumerate(mean): + t_val = self._time_value(t_idx) + s_row = stats_lookup.get((t_val, None, None), {}) + rows.append( + { + "Model": m_name, + "Metric": metric, + "Time": t_val, + "Mean": val, + "Std": std[t_idx], + "PValue": s_row.get("PValue"), + "Significant": s_row.get("Significant", False), + } + ) + elif mean.ndim == 2: + for t_tr in range(mean.shape[0]): + tr_val = self._time_value(t_tr) + for t_te in range(mean.shape[1]): + te_val = self._time_value(t_te) + s_row = stats_lookup.get((None, tr_val, te_val), {}) + rows.append( + { + "Model": m_name, + "Metric": metric, + "TrainTime": tr_val, + "TestTime": te_val, + "Mean": mean[t_tr, t_te], + "Std": std[t_tr, t_te], + "PValue": s_row.get("PValue"), + "Significant": s_row.get("Significant", False), + } + ) + return pd.DataFrame(rows, columns=columns) + + def get_predictions(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get concatenated predictions for all models or a specific model. + + Converts nested prediction dictionaries from all folds into a single + flattened DataFrame. + + Parameters + ---------- + model : str, optional + The name of the model to filter by. Default is None (all models). + + Returns + ------- + predictions_df : pd.DataFrame + Tidy DataFrame of predictions. Includes SampleID, y_true, y_pred, + y_score, and probability columns if available. + + See Also + -------- + ExperimentResult.get_splits : Membership of samples in each fold. + """ + rows = [] + time_axis = self.time_axis + target_models = [model] if model is not None else self.raw.keys() + + for m_name in target_models: + res = self.raw.get(m_name, {}) + if "error" in res or "predictions" not in res: + continue + for fold_idx, preds in enumerate(res["predictions"]): + rows.extend( + prediction_rows(m_name, fold_idx, preds, time_axis=time_axis) + ) + return pd.DataFrame(rows) + + def get_splits(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get outer-CV train/test membership in long format for all models. + + Tracks which samples were used for training and testing in each fold + of the cross-validation procedure. + + Parameters + ---------- + model : str, optional + The name of the model to filter by. Default is None (all models). + + Returns + ------- + splits_df : pd.DataFrame + DataFrame with Model, Fold, Set (train/test), SampleIndex, SampleID, + and associated metadata columns (e.g., Subject, Session). + + See Also + -------- + ExperimentResult.get_predictions : Link predictions to splits via SampleID. + """ + from ._diagnostics import optional_values + + all_rows = [] + target_models = [model] if model is not None else self.raw.keys() + for m_name in target_models: + res = self.raw.get(m_name, {}) + if "error" in res: + continue + for fold_idx, split in enumerate(res.get("splits", [])): + for set_name, idx_key, id_key, group_key, meta_key in [ + ( + "train", + "train_idx", + "train_sample_id", + "train_group", + "train_metadata", + ), + ( + "test", + "test_idx", + "test_sample_id", + "test_group", + "test_metadata", + ), + ]: + indices = np.asarray(split[idx_key]) + n = len(indices) + if n == 0: + continue + + ids = np.asarray(split[id_key]) + groups = optional_values(split.get(group_key), n) + metadata = split.get(meta_key) or {} + + # Flatten metadata into columns + meta_arrays = { + k: np.asarray(v, dtype=object)[:n] for k, v in metadata.items() + } + + for i in range(n): + row = { + "Model": m_name, + "Fold": fold_idx, + "Set": set_name, + "SampleIndex": indices[i], + "SampleID": ids[i], + "Group": groups[i], + } + # Add metadata columns + for k, v_arr in meta_arrays.items(): + row[k] = v_arr[i] + all_rows.append(row) + + if not all_rows: + return pd.DataFrame() + + return pd.DataFrame(all_rows) + + def get_fit_diagnostics(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get fold-level timing and warning diagnostics for all models. + + Aggregates operational metrics such as execution time and runtime + warnings encountered during the model fit and predict stages. + + Scientific Rationale + -------------------- + Runtime diagnostics are critical for identifying computational + bottlenecks and ensuring model validity. Long training times may + suggest the need for dimensionality reduction, while consistent + warnings (e.g., convergence failures) can signal that model + hyperparameters are poorly suited to the dataset. + + Parameters + ---------- + model : str, optional + The name of the model to filter by. Default is None (all models). + + Returns + ------- + diagnostics_df : pd.DataFrame + DataFrame with Model, Fold, and timing columns (FitTime, PredictTime, + ScoreTime, TotalTime). If warnings were captured, includes Stage, + WarningCategory, and WarningMessage columns. + + Examples + -------- + >>> diagnostics = result.get_fit_diagnostics() + >>> # Identify the slowest model + >>> slow_model = diagnostics.groupby("Model")["TotalTime"].mean().idxmax() + + See Also + -------- + ExperimentResult.summary : General performance summary. + """ + rows = [] + columns = [ + "Model", + "Fold", + "FitTime", + "PredictTime", + "ScoreTime", + "TotalTime", + "Stage", + "WarningCategory", + "WarningMessage", + ] + target_models = [model] if model is not None else self.raw.keys() + for m_name in target_models: + res = self.raw.get(m_name, {}) + if "error" in res: + continue + for fold_idx, diag in enumerate(res.get("diagnostics", [])): + base = { + "Model": m_name, + "Fold": fold_idx, + "FitTime": diag.get("fit_time"), + "PredictTime": diag.get("predict_time"), + "ScoreTime": diag.get("score_time"), + "TotalTime": diag.get("total_time"), + } + warnings_ = diag.get("warnings") or [] + if not warnings_: + rows.append( + { + **base, + "Stage": None, + "WarningCategory": None, + "WarningMessage": None, + } + ) + continue + for w in warnings_: + rows.append( + { + **base, + "Stage": w.get("stage"), + "WarningCategory": w.get("category"), + "WarningMessage": w.get("message"), + } + ) + return pd.DataFrame(rows, columns=columns) + + def _build_confusion_df( + self, + model: Optional[str], + labels: Optional[Sequence[Any]], + normalize: Optional[str], + group_cols: list[str], + ) -> pd.DataFrame: + """Shared logic for building confusion matrix DataFrames.""" + preds = scalar_prediction_frame(self.get_predictions(model=model)) + if preds.empty: + cols = group_cols + ["TrueLabel", "PredictedLabel", "Value"] + return pd.DataFrame(columns=cols) + + if labels is None: + labels = sorted( + pd.unique(pd.concat([preds["y_true"], preds["y_pred"]])).tolist() + ) + + return confusion_matrix_frame(preds, labels, normalize, group_cols=group_cols) + + def get_confusion_matrices( + self, + model: Optional[str] = None, + labels: Optional[Sequence[Any]] = None, + normalize: Optional[str] = None, + ) -> pd.DataFrame: + """ + Get fold-level confusion matrices in long format. + + Computes the confusion between true and predicted labels for each + cross-validation fold. + + Scientific Rationale + -------------------- + Confusion matrices provide a granular view of model errors, identifying + specific classes that are frequently misidentified. Analyzing these + per-fold allows for assessing the consistency of error patterns across + different data splits. + + Parameters + ---------- + model : str, optional + The name of the model to filter by. Default is None (all models). + labels : sequence of any, optional + The list of labels to use for the matrix axes. If None, uses all + labels present in the predictions. + normalize : {'true', 'pred', 'all'}, optional + Normalization strategy: + - 'true': Normalize by true labels (rows). + - 'pred': Normalize by predicted labels (columns). + - 'all': Normalize by total number of samples. + + Returns + ------- + confusion_df : pd.DataFrame + Tidy DataFrame with Model, Fold, TrueLabel, PredictedLabel, and Value. + + See Also + -------- + ExperimentResult.get_pooled_confusion_matrix : Aggregate across folds. + """ + return self._build_confusion_df( + model=model, + labels=labels, + normalize=normalize, + group_cols=["Model", "Fold"], + ) + + def get_confusion_counts( + self, model: Optional[str] = None, labels: Optional[Sequence[Any]] = None + ) -> pd.DataFrame: + """ + Get unnormalized per-fold confusion counts. + + Equivalent to `get_confusion_matrices(normalize=None)`. + + Parameters + ---------- + model : str, optional + The name of the model to filter by. + labels : sequence of any, optional + The list of labels to use. + + Returns + ------- + counts_df : pd.DataFrame + Unnormalized confusion counts. + """ + return self.get_confusion_matrices(model=model, labels=labels, normalize=None) + + def get_pooled_confusion_matrix( + self, + model: Optional[str] = None, + labels: Optional[Sequence[Any]] = None, + normalize: Optional[str] = None, + ) -> pd.DataFrame: + """ + Get pooled out-of-fold confusion matrices in long format. + + Aggregates predictions from all cross-validation folds before + calculating the confusion matrix. + + Parameters + ---------- + model : str, optional + The name of the model to filter by. + labels : sequence of any, optional + The list of labels to use. + normalize : {'true', 'pred', 'all'}, optional + Normalization strategy. + + Returns + ------- + confusion_df : pd.DataFrame + Pooled confusion matrix with Model, TrueLabel, PredictedLabel, + and Value. + + See Also + -------- + ExperimentResult.get_confusion_matrices : Fold-level view. + """ + return self._build_confusion_df( + model=model, labels=labels, normalize=normalize, group_cols=["Model"] + ) + + def get_roc_curve( + self, model: Optional[str] = None, pos_label: Optional[Any] = None + ) -> pd.DataFrame: + """ + Get binary or one-vs-rest ROC curve coordinates. + + Calculates False Positive Rate (FPR) and True Positive Rate (TPR) at + various thresholds for each fold. For multiclass problems, computes + One-vs-Rest (OvR) curves for each class. + + Scientific Rationale + -------------------- + Receiver Operating Characteristic (ROC) curves illustrate the + diagnostic ability of a classifier as its discrimination threshold is + varied. Analyzing the spread of these curves across folds helps in + assessing the robustness of the model's probabilistic rankings. + + Parameters + ---------- + model : str, optional + The model name to filter by. + pos_label : any, optional + The label to treat as the positive class in binary cases. If None, + uses the second class in alphabetical order. + + Returns + ------- + roc_df : pd.DataFrame + DataFrame with Model, Fold, Class, Threshold, FPR, and TPR. + + See Also + -------- + ExperimentResult.get_roc_auc_summary : Aggregate AUC metrics. + """ + from sklearn.metrics import roc_curve + + frames = [] + preds = scalar_prediction_frame(self.get_predictions(model=model)) + for m_name, f_idx, label, y_binary, y_score in curve_score_groups( + preds, model=model, pos_label=pos_label + ): + fpr, tpr, thresholds = roc_curve(y_binary, y_score, pos_label=True) + df_c = pd.DataFrame( + { + "Model": m_name, + "Fold": f_idx, + "Class": label, + "Threshold": thresholds, + "FPR": fpr, + "TPR": tpr, + } + ) + frames.append(df_c) + + if not frames: + return pd.DataFrame( + columns=["Model", "Fold", "Class", "Threshold", "FPR", "TPR"] + ) + return pd.concat(frames, ignore_index=True) + + def get_pr_curve( + self, model: Optional[str] = None, pos_label: Optional[Any] = None + ) -> pd.DataFrame: + """ + Get binary or one-vs-rest precision-recall curve coordinates. + + Calculates Precision and Recall at various thresholds for each fold. + For multiclass problems, computes One-vs-Rest (OvR) curves for each class. + + Scientific Rationale + -------------------- + Precision-Recall (PR) curves are often more informative than ROC + curves for imbalanced datasets, as they focus on the model's + performance on the minority (positive) class. + + Parameters + ---------- + model : str, optional + The model name to filter by. + pos_label : any, optional + The label to treat as positive. + + Returns + ------- + pr_df : pd.DataFrame + DataFrame with Model, Fold, Class, Threshold, Precision, and Recall. + + See Also + -------- + ExperimentResult.get_pr_auc_summary : Average Precision summary. + """ + from sklearn.metrics import precision_recall_curve + + frames = [] + preds = scalar_prediction_frame(self.get_predictions(model=model)) + for m_name, f_idx, label, y_binary, y_score in curve_score_groups( + preds, model=model, pos_label=pos_label + ): + precision, recall, thresholds = precision_recall_curve( + y_binary, y_score, pos_label=True + ) + # thresholds is 1 element shorter than precision/recall + thresh_vals = np.append(thresholds, np.nan) + df_c = pd.DataFrame( + { + "Model": m_name, + "Fold": f_idx, + "Class": label, + "Threshold": thresh_vals, + "Precision": precision, + "Recall": recall, + } + ) + frames.append(df_c) + + if not frames: + return pd.DataFrame( + columns=["Model", "Fold", "Class", "Threshold", "Precision", "Recall"] + ) + return pd.concat(frames, ignore_index=True) + + def get_roc_auc_summary(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get summary ROC-AUC metrics across models and folds. + + Calculates the Area Under the ROC Curve for each fold, using macro- and + weighted-averaging for multiclass tasks. + + Parameters + ---------- + model : str, optional + Model name to filter by. + + Returns + ------- + auc_df : pd.DataFrame + Summary with Model, Fold, MacroROCAUC, and WeightedROCAUC. + + See Also + -------- + ExperimentResult.get_roc_curve : Detailed curve coordinates. + """ + from sklearn.metrics import roc_auc_score + from sklearn.preprocessing import LabelBinarizer + + rows = [] + preds = scalar_prediction_frame(self.get_predictions(model=model)) + if preds.empty: + return pd.DataFrame( + columns=["Model", "Fold", "MacroROCAUC", "WeightedROCAUC"] + ) + + proba_cols = sorted( + [col for col in preds.columns if col.startswith("y_proba_")], + key=lambda v: int(v.rsplit("_", 1)[-1]), + ) + if not proba_cols: + return pd.DataFrame( + columns=["Model", "Fold", "MacroROCAUC", "WeightedROCAUC"] + ) + + for (m_name, f_idx), group in preds.groupby(["Model", "Fold"]): + y_true = group["y_true"].to_numpy() + y_proba = group[proba_cols].to_numpy() + + lb = LabelBinarizer() + y_true_bin = lb.fit_transform(y_true) + if y_true_bin.shape[1] == 1: + score = roc_auc_score(y_true_bin, y_proba[:, -1]) + macro = weighted = score + else: + macro = roc_auc_score( + y_true_bin, y_proba, multi_class="ovr", average="macro" + ) + weighted = roc_auc_score( + y_true_bin, y_proba, multi_class="ovr", average="weighted" + ) + + rows.append( + { + "Model": m_name, + "Fold": f_idx, + "MacroROCAUC": float(macro), + "WeightedROCAUC": float(weighted), + } + ) + return pd.DataFrame(rows) + + def get_pr_auc_summary(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get summary PR-AUC (Average Precision) metrics across models and folds. + + Calculates the Area Under the Precision-Recall Curve for each fold. + + Parameters + ---------- + model : str, optional + Model name to filter by. + + Returns + ------- + auc_df : pd.DataFrame + Summary with Model, Fold, MacroPRAUC, and WeightedPRAUC. + + See Also + -------- + ExperimentResult.get_pr_curve : Detailed curve coordinates. + """ + from sklearn.metrics import average_precision_score + from sklearn.preprocessing import LabelBinarizer + + rows = [] + preds = scalar_prediction_frame(self.get_predictions(model=model)) + if preds.empty: + return pd.DataFrame( + columns=["Model", "Fold", "MacroPRAUC", "WeightedPRAUC"] + ) + + proba_cols = sorted( + [col for col in preds.columns if col.startswith("y_proba_")], + key=lambda v: int(v.rsplit("_", 1)[-1]), + ) + if not proba_cols: + return pd.DataFrame( + columns=["Model", "Fold", "MacroPRAUC", "WeightedPRAUC"] + ) + + for (m_name, f_idx), group in preds.groupby(["Model", "Fold"]): + y_true = group["y_true"].to_numpy() + y_proba = group[proba_cols].to_numpy() + + lb = LabelBinarizer() + y_true_bin = lb.fit_transform(y_true) + if y_true_bin.shape[1] == 1: + score = average_precision_score(y_true_bin, y_proba[:, -1]) + macro = weighted = score + else: + macro = average_precision_score(y_true_bin, y_proba, average="macro") + weighted = average_precision_score( + y_true_bin, y_proba, average="weighted" + ) + + rows.append( + { + "Model": m_name, + "Fold": f_idx, + "MacroPRAUC": float(macro), + "WeightedPRAUC": float(weighted), + } + ) + return pd.DataFrame(rows) + + def get_calibration_curve( + self, + model: Optional[str] = None, + n_bins: int = 5, + pos_label: Optional[Any] = None, + strategy: str = "uniform", + ) -> pd.DataFrame: + """ + Get binary reliability/calibration curve coordinates. + + Calculates the fraction of positive samples vs. mean predicted + probabilities for each probability bin. + + Scientific Rationale + -------------------- + A well-calibrated classifier provides probabilistic outputs that + reflect the true likelihood of the predicted event. Calibration + curves (reliability diagrams) are essential for assessing whether + predicted probabilities can be interpreted as confidence levels. + + Parameters + ---------- + model : str, optional + The model name to filter by. + n_bins : int, default=5 + Number of bins to use for the calibration curve. + pos_label : any, optional + The label to treat as positive. + strategy : {'uniform', 'quantile'}, default='uniform' + Strategy used to define the widths of the bins. + - 'uniform': Bins have identical widths. + - 'quantile': Bins have the same number of samples. + + Returns + ------- + calibration_df : pd.DataFrame + DataFrame with Model, Fold, Class, MeanPredictedProbability, + and FractionPositive. + + See Also + -------- + ExperimentResult.get_probability_diagnostics : Brier score and Log Loss. + """ + from sklearn.calibration import calibration_curve + + frames = [] + preds = scalar_prediction_frame(self.get_predictions(model=model)) + for m_name, f_idx, label, y_binary, y_score in curve_score_groups( + preds, model=model, require_probability=True, pos_label=pos_label + ): + p_true, p_pred = calibration_curve( + y_binary.astype(int), y_score, n_bins=n_bins, strategy=strategy + ) + df_c = pd.DataFrame( + { + "Model": m_name, + "Fold": f_idx, + "Class": label, + "MeanPredictedProbability": p_pred, + "FractionPositive": p_true, + } + ) + frames.append(df_c) + + if not frames: + return pd.DataFrame( + columns=[ + "Model", + "Fold", + "Class", + "MeanPredictedProbability", + "FractionPositive", + ] + ) + return pd.concat(frames, ignore_index=True) + + def get_probability_diagnostics(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get fold-level log-loss and Brier summaries when probabilities exist. + + Computes summary metrics that penalize poor probability calibration + and high-uncertainty predictions. + + Parameters + ---------- + model : str, optional + The model name to filter by. + + Returns + ------- + diagnostics_df : pd.DataFrame + DataFrame in long format with Model, Fold, Metric, Class, and Value. + Metrics include 'log_loss', 'brier_score_ovr', and 'brier_score_macro'. + + See Also + -------- + ExperimentResult.get_calibration_curve : Visual calibration view. + """ + from sklearn.metrics import log_loss + + rows = [] + preds = scalar_prediction_frame(self.get_predictions(model=model)) + if preds.empty: + return pd.DataFrame(columns=["Model", "Fold", "Metric", "Class", "Value"]) + + for (m_name, f_idx), group in preds.groupby(["Model", "Fold"]): + y_true = group["y_true"].to_numpy() + unique_labels = sorted(pd.unique(y_true).tolist()) + y_proba = proba_matrix(group, len(unique_labels)) + if y_proba is None: + continue + + # 1. Log Loss (Overall) + try: + ll = log_loss(y_true, y_proba, labels=unique_labels) + rows.append( + { + "Model": m_name, + "Fold": f_idx, + "Metric": "log_loss", + "Class": None, + "Value": float(ll), + } + ) + except Exception as e: # noqa: BLE001 + logger.debug(f"log_loss skipped for {m_name} fold {f_idx}: {e}") + + # 2. Brier Scores (Vectorized) + # Create binary matrix [n_samples, n_classes] + y_binary = (y_true[:, None] == np.array(unique_labels)).astype(float) + # Brier Score = Mean Squared Error per class + brier_ovr = np.mean((y_binary - y_proba) ** 2, axis=0) + + for c_idx, label in enumerate(unique_labels): + rows.append( + { + "Model": m_name, + "Fold": f_idx, + "Metric": "brier_score_ovr", + "Class": label, + "Value": float(brier_ovr[c_idx]), + } + ) + + # 3. Macro Brier Score + rows.append( + { + "Model": m_name, + "Fold": f_idx, + "Metric": "brier_score_macro", + "Class": None, + "Value": float(np.mean(brier_ovr)), + } + ) + + return pd.DataFrame(rows) + + def get_statistical_assessment( + self, + lightweight: bool = False, + metric: str = "accuracy", + unit: Optional[str] = None, + n_permutations: int = 1000, + random_state: Optional[int] = None, + ) -> pd.DataFrame: + """ + Get finite-sample statistical assessment rows in long form. + + Returns p-values and significance markers for model performance, + supporting both full-pipeline and post-hoc permutation methods. + + Scientific Rationale + -------------------- + Statistical significance in decoding ensures that observed performance + deltas are not due to chance fluctuations. This method allows + accessing pre-calculated results from the full experimental pipeline + (the gold standard) or running a faster post-hoc permutation test + directly on stored predictions. + + Parameters + ---------- + lightweight : bool, default=False + If True, perform a post-hoc label permutation on out-of-fold + predictions. Fast but does not account for pipeline leakage + (e.g., in tuning). + If False, returns results from the full-pipeline assessment if + they were computed during the experiment. + metric : str, default='accuracy' + Metric to use for the assessment. + unit : str, optional + The level of independence for the permutation test (e.g., 'subject'). + n_permutations : int, default=1000 + Number of permutations for the lightweight assessment. + random_state : int, optional + Seed for reproducible permutations. + + Returns + ------- + stats_df : pd.DataFrame + Tidy DataFrame with Model, Metric, Observed, PValue, and Significance. + + See Also + -------- + coco_pipe.decoding.stats.run_statistical_assessment : Underlying engine. + """ + u_type = self._resolve_inference_unit(unit) + rows = [] + + # 1. Pull pre-calculated results from the full pipeline if requested + if not lightweight: + for model, res in self.raw.items(): + if "error" in res: + continue + stats_rows = res.get("statistical_assessment", []) + for s in stats_rows: + row = dict(s) + row["Model"] = model + rows.append(row) + + # 2. Fallback to lightweight if no stats found + if not rows and not lightweight: + logger.info( + "Full-pipeline statistical assessment results not found. " + "Falling back to lightweight post-hoc assessment." + ) + lightweight = True + + # 3. Lightweight post-hoc permutation assessment + if lightweight: + from .stats import assess_post_hoc_permutation + + for model, res in self.raw.items(): + if "error" in res: + continue + try: + df_l = assess_post_hoc_permutation( + res, + metric=metric, + unit=u_type, + n_permutations=n_permutations, + random_state=random_state, + ) + df_l["Model"] = model + rows.extend(df_l.to_dict("records")) + except Exception as e: + logger.warning(f"Lightweight assessment failed for {model}: {e}") + + if not rows: + return pd.DataFrame() + + # Consistent column ordering + cols = [ + "Model", + "Metric", + "Observed", + "PValue", + "Significant", + "NullMethod", + "NPermutations", + "InferentialUnit", + "Time", + "TrainTime", + "TestTime", + "NullLower", + "NullUpper", + ] + df = pd.DataFrame(rows) + # Sort and filter columns to match what's present + present_cols = [c for c in cols if c in df.columns] + return df[present_cols] + + def get_statistical_nulls(self, model: Optional[str] = None) -> Dict[str, Any]: + """ + Return stored statistical null distributions, when configured. + + Accesses the empirical null distributions (e.g., from permutation + tests) stored during the experiment. + + Parameters + ---------- + model : str, optional + Model name to filter by. Default is None (all models). + + Returns + ------- + nulls : dict + Dictionary mapping model names to their null distribution payloads, + containing coordinates and permuted score arrays. + + See Also + -------- + ExperimentResult.get_statistical_assessment : P-values derived from these nulls. + """ + nulls = {} + for m_name, res in self.raw.items(): + if model is not None and m_name != model: + continue + if "error" in res: + continue + if "statistical_nulls" in res: + nulls[m_name] = res["statistical_nulls"] + return nulls + + def get_model_artifacts(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Return fold-level model artifact metadata in long form. + + Accesses non-metric outputs stored by models, such as learned + coefficients, intercept values, or class labels. + + Parameters + ---------- + model : str, optional + The model name to filter by. Default is None (all models). + + Returns + ------- + artifacts_df : pd.DataFrame + DataFrame with Model, Fold, ArtifactType, Key, and Value. + + See Also + -------- + ExperimentResult.get_feature_importances : Specifically for importances. + """ + rows = [] + cols = ["Model", "Fold", "ArtifactType", "Key", "Value"] + for m_name, res in self.raw.items(): + if model is not None and m_name != model: + continue + if "error" in res: + continue + for f_idx, m in enumerate(res.get("metadata", [])): + artifacts = m.get("artifacts", {}) + for a_type, payload in artifacts.items(): + if isinstance(payload, dict): + for k, v in payload.items(): + rows.append( + { + "Model": m_name, + "Fold": f_idx, + "ArtifactType": a_type, + "Key": k, + "Value": v, + } + ) + else: + rows.append( + { + "Model": model, + "Fold": f_idx, + "ArtifactType": "model", + "Key": a_type, + "Value": payload, + } + ) + return pd.DataFrame(rows, columns=cols) + + def get_bootstrap_confidence_intervals( + self, + metric: str = "accuracy", + model: Optional[str] = None, + unit: Optional[str] = None, + n_bootstraps: int = 1000, + ci: float = 0.95, + random_state: Optional[int] = None, + ) -> pd.DataFrame: + """ + Bootstrap metric confidence intervals over configured inference units. + + Estimates the uncertainty of a performance metric by resampling + independent units (e.g., subjects) with replacement. + + Scientific Rationale + -------------------- + Bootstrapping provides a non-parametric estimate of the sampling + distribution of a metric. By resampling at the 'unit' level, we + account for within-unit correlations (e.g., multiple trials from the + same subject) and provide more realistic uncertainty bounds than + sample-level analytical methods. + + Parameters + ---------- + metric : str, default='accuracy' + The metric to estimate uncertainty for. + model : str, optional + The model name to filter by. + unit : str, optional + The level of independence for resampling (e.g., 'subject'). + n_bootstraps : int, default=1000 + Number of bootstrap iterations. + ci : float, default=0.95 + Confidence interval level (e.g., 0.95 for 95% CI). + random_state : int, optional + Random seed for reproducibility. + + Returns + ------- + bootstrap_df : pd.DataFrame + DataFrame with Model, Metric, Estimate (observed), CILower, and CIUpper. + + See Also + -------- + coco_pipe.decoding.stats.assess_bootstrap_ci : Underlying engine. + """ + from .stats import assess_bootstrap_ci + + u_type = self._resolve_inference_unit(unit) + rows = [] + for m_name, res in self.raw.items(): + if model is not None and m_name != model: + continue + if "error" in res: + continue + + try: + df_b = assess_bootstrap_ci( + res, + metric=metric, + unit=u_type, + n_bootstraps=n_bootstraps, + ci=ci, + random_state=random_state, + ) + df_b["Model"] = m_name + rows.extend(df_b.to_dict("records")) + except Exception as e: + logger.warning(f"Bootstrap failed for {m_name}: {e}") + + if not rows: + return pd.DataFrame() + + cols = [ + "Model", + "Metric", + "Estimate", + "CILower", + "CIUpper", + "Unit", + "NUnits", + "NBootstraps", + ] + return pd.DataFrame(rows)[cols] + + def compare_models( + self, + models: Optional[Sequence[str]] = None, + metric: str = "accuracy", + unit: Optional[str] = None, + n_permutations: int = 1000, + correction: str = "fdr_bh", + random_state: Optional[int] = None, + ) -> pd.DataFrame: + """ + Perform exhaustive pairwise comparisons between multiple models. + + Automatically applies p-value correction (e.g., FDR) for the multiple + comparisons performed across all pairs of models. + + Scientific Rationale + -------------------- + Benchmarking multiple models requires controlling for the 'multiple + comparisons problem'—the increased risk of Type I errors (false + positives) when testing many hypotheses. This method automates the + pairwise testing and subsequent error-rate control. + + Parameters + ---------- + models : list of str, optional + List of model names to compare. Default is all models in the result. + metric : str, default='accuracy' + Metric to use for comparison. + unit : str, optional + Level of independence for permutation testing (e.g., 'subject'). + n_permutations : int, default=1000 + Number of permutations for each paired test. + correction : str, default='fdr_bh' + Multiple comparison correction method (e.g., 'bonferroni', 'fdr_bh'). + random_state : int, optional + Random seed for permutations. + + Returns + ------- + comparison_df : pd.DataFrame + DataFrame containing ModelA, ModelB, Difference, and corrected + PValue (PValueCorrected). + + See Also + -------- + ExperimentResult.compare_models_paired : Underlying paired test. + """ + from itertools import combinations + + if models is None: + models = sorted(self.raw.keys()) + + if len(models) < 2: + raise ValueError("Need at least two models to perform a comparison.") + + all_results = [] + # 1. Run all pairwise comparisons + for m_a, m_b in combinations(models, 2): + try: + res = self.compare_models_paired( + model_a=m_a, + model_b=m_b, + metric=metric, + unit=unit, + n_permutations=n_permutations, + random_state=random_state, + ) + all_results.append(res) + except Exception as e: + logger.warning(f"Comparison failed for {m_a} vs {m_b}: {e}") + + if not all_results: + return pd.DataFrame() + + df = pd.concat(all_results, ignore_index=True) + + # 2. Apply multiple comparison correction + from .stats import apply_multiple_comparison_correction + + return apply_multiple_comparison_correction(df, method=correction) + + def compare_models_paired( + self, + model_a: str, + model_b: str, + metric: str = "accuracy", + unit: Optional[str] = None, + n_permutations: int = 1000, + random_state: Optional[int] = None, + ) -> pd.DataFrame: + """ + Paired model comparison using outer-fold predictions on shared samples. + + Performs a within-unit permutation test (e.g., swapping model labels + per subject) to determine if the performance difference is significant. + + Scientific Rationale + -------------------- + Paired tests are generally more powerful than independent-sample tests + because they control for unit-specific baseline variance (e.g., a subject + who is overall 'harder' to decode). By aligning predictions at the + sample level across models, we ensure a valid paired comparison. + + Parameters + ---------- + model_a, model_b : str + The names of the two models to compare. + metric : str, default='accuracy' + Metric to use for comparison. + unit : str, optional + Level of independence (e.g., 'subject'). + n_permutations : int, default=1000 + Number of permutations for the test. + random_state : int, optional + Random seed for reproducible permutations. + + Returns + ------- + paired_df : pd.DataFrame + DataFrame with ModelA, ModelB, ScoreA, ScoreB, Difference, and PValue. + + See Also + -------- + coco_pipe.decoding.stats.assess_paired_comparison : Underlying engine. + """ + from .stats import assess_paired_comparison + + u_type = self._resolve_inference_unit(unit) + preds = scalar_prediction_frame(self.get_predictions()) + a = preds[preds["Model"] == model_a] + b = preds[preds["Model"] == model_b] + + if a.empty or b.empty: + raise ValueError(f"One or both models not found: {model_a}, {model_b}") + + # Merge to find shared samples across all relevant coordinates + merge_cols = ["SampleID", "y_true", "Fold"] + for col in ["Time", "TrainTime", "TestTime"]: + if col in a and col in b: + merge_cols.append(col) + + merged = a.merge(b, on=merge_cols, suffixes=("_A", "_B")) + if merged.empty: + raise ValueError("No overlapping samples found between the two models.") + + df_res = assess_paired_comparison( + merged, + metric=metric, + unit=u_type, + n_permutations=n_permutations, + random_state=random_state, + ) + + df_res["ModelA"] = model_a + df_res["ModelB"] = model_b + df_res["Unit"] = u_type + + # Reorder columns + cols = [ + "ModelA", + "ModelB", + "Metric", + "Unit", + "NUnits", + "ScoreA", + "ScoreB", + "Difference", + "PValue", + "Significant", + ] + # Preserve temporal columns if present + for c in ["Time", "TrainTime", "TestTime"]: + if c in df_res.columns: + cols.insert(4, c) + + return df_res[cols] + + def get_feature_importances( + self, model: Optional[str] = None, fold_level: bool = False + ) -> pd.DataFrame: + """ + Get feature importances in long format. + + Aggregates relative feature contributions (e.g., coefficients, + Gini importance) across all folds. + + Scientific Rationale + -------------------- + Feature importances identify the data dimensions that drive the model's + predictions. Analyzing these across folds ensures that identified + features are robust and not artifacts of a specific data split. Ranking + features provides a prioritized list for subsequent biological + interpretation. + + Parameters + ---------- + model : str, optional + The model name to filter by. Default is None (all models). + fold_level : bool, default=False + - If True: Returns importance for each fold individually. + - If False: Returns the mean and standard deviation across folds. + + Returns + ------- + importances_df : pd.DataFrame + DataFrame with Model, FeatureName, and Importance (or Mean/Std). + Includes a 'Rank' column based on the importance magnitude. + + See Also + -------- + ExperimentResult.get_selected_features : If feature selection was used. + """ + + frames = [] + for m_name, res in self.raw.items(): + if model is not None and m_name != model: + continue + if "error" in res: + continue + + imp = res.get("importances") + if not imp: + continue + + if fold_level: + raw = np.asarray(imp.get("raw", []), dtype=float) + if raw.ndim != 2: + continue + n_feats = raw.shape[1] + f_names = imp.get("feature_names") + if f_names is None or len(f_names) != n_feats: + f_names = [f"feature_{i}" for i in range(n_feats)] + + df_m = pd.DataFrame(raw, columns=f_names) + df_m.index.name = "Fold" + df_m = df_m.reset_index().melt( + id_vars="Fold", var_name="FeatureName", value_name="Importance" + ) + df_m["Model"] = m_name + # Reconstruct Feature Index + name_to_idx = {name: i for i, name in enumerate(f_names)} + df_m["Feature"] = df_m["FeatureName"].map(name_to_idx) + frames.append(df_m) + else: + means = np.asarray(imp.get("mean", []), dtype=float).ravel() + if len(means) == 0: + continue + stds = np.asarray(imp.get("std", []), dtype=float).ravel() + if len(stds) != len(means): + stds = np.full(len(means), np.nan) + + f_names = imp.get("feature_names") + if f_names is None or len(f_names) != len(means): + f_names = [f"feature_{i}" for i in range(len(means))] + df_m = pd.DataFrame( + { + "Model": m_name, + "Feature": np.arange(len(means)), + "FeatureName": f_names, + "Mean": means, + "Std": stds, + } + ) + frames.append(df_m) + + if not frames: + return pd.DataFrame() + + df = pd.concat(frames, ignore_index=True) + + # Add ranks + group_cols = ["Model"] + if fold_level: + group_cols.append("Fold") + val_col = "Importance" + else: + val_col = "Mean" + + df["Rank"] = df.groupby(group_cols)[val_col].rank(ascending=False, method="min") + + # Ensure stable column order for API consistency + if fold_level: + ordered_cols = [ + "Model", + "Fold", + "Feature", + "FeatureName", + "Importance", + "Rank", + ] + else: + ordered_cols = ["Model", "Feature", "FeatureName", "Mean", "Std", "Rank"] + + return df[ordered_cols] + + def _resolve_inference_unit(self, unit: Optional[str]) -> str: + if unit is not None: + return unit + return self.meta.get("inferential_unit") or "sample" + + def _time_value(self, index: int) -> Any: + from ._diagnostics import time_value as get_time_val + + return get_time_val(index, self.time_axis) + + def get_best_params(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get the best hyperparameters selected per fold. + + If hyperparameter tuning was enabled, returns the optimal + configuration found in each cross-validation outer fold. + + Parameters + ---------- + model : str, optional + The name of the model to filter by. Default is None (all models). + + Returns + ------- + params_df : pd.DataFrame + DataFrame with Model, Fold, Param, and Value. + + See Also + -------- + ExperimentResult.get_search_results : Detailed tuning diagnostics. + """ + rows = [] + for m_name, res in self.raw.items(): + if model is not None and m_name != model: + continue + if "error" in res: + continue + if "metadata" in res: + for f_idx, meta in enumerate(res["metadata"]): + if "best_params" in meta: + for p_name, p_val in meta["best_params"].items(): + rows.append( + { + "Model": m_name, + "Fold": f_idx, + "Param": p_name, + "Value": p_val, + } + ) + return pd.DataFrame(rows) + + def get_search_results(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get compact hyperparameter-search diagnostics in long form. + + Provides a summary of all candidate configurations evaluated during + tuning, including their mean performance and ranking. + + Parameters + ---------- + model : str, optional + The model name to filter by. Default is None (all models). + + Returns + ------- + search_df : pd.DataFrame + DataFrame with Model, Fold, Candidate, Rank, MeanTestScore, + StdTestScore, and Params. + + See Also + -------- + ExperimentResult.get_best_params : Just the winner per fold. + """ + rows = [] + for m_name, res in self.raw.items(): + if model is not None and m_name != model: + continue + if "error" in res: + continue + for f_idx, meta in enumerate(res.get("metadata", [])): + for s_row in meta.get("search_results", []): + rows.append( + { + "Model": m_name, + "Fold": f_idx, + "Candidate": s_row.get("candidate"), + "Rank": s_row.get("rank_test_score"), + "MeanTestScore": s_row.get("mean_test_score"), + "StdTestScore": s_row.get("std_test_score"), + "Params": s_row.get("params"), + } + ) + cols = [ + "Model", + "Fold", + "Candidate", + "Rank", + "MeanTestScore", + "StdTestScore", + "Params", + ] + return pd.DataFrame(rows, columns=cols) + + def get_selected_features(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get fold-level selected feature masks in long format. + + Returns a boolean mask indicating which features were retained by + automated feature selection in each fold. + + Parameters + ---------- + model : str, optional + The model name to filter by. Default is None (all models). + + Returns + ------- + features_df : pd.DataFrame + DataFrame with Model, Fold, FeatureName, and Selected status. + Includes an 'Order' column if recursive or sequential selection + was used. + + See Also + -------- + ExperimentResult.get_feature_stability : Cross-fold selection consistency. + """ + frames = [] + for m_name, res in self.raw.items(): + if model is not None and m_name != model: + continue + if "error" in res: + continue + for f_idx, meta in enumerate(res.get("metadata", [])): + if "selected_features" not in meta: + continue + + mask = np.asarray(meta["selected_features"], dtype=bool) + order = meta.get("selection_order") + f_names = meta.get("feature_names") + if f_names is None or len(f_names) != len(mask): + f_names = [f"feature_{idx}" for idx in range(len(mask))] + + df_f = pd.DataFrame( + { + "Model": m_name, + "Fold": f_idx, + "Feature": np.arange(len(mask)), + "FeatureName": f_names, + "Selected": mask, + } + ) + + if order is not None: + # Resolve order (Rank) + order_arr = np.asarray(order) + if len(order_arr) == len(mask): + df_f["Order"] = order_arr + elif len(order_arr) < len(mask): + # Order is a list of selected indices + order_map = {idx: i + 1 for i, idx in enumerate(order_arr)} + df_f["Order"] = df_f["Feature"].map(order_map) + else: + df_f["Order"] = np.nan + + frames.append(df_f) + + if not frames: + return pd.DataFrame() + return pd.concat(frames, ignore_index=True) + + def get_feature_scores(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Get fold-level feature-selection scores. + + Accesses raw univariate or multivariate scores (e.g., F-values, + p-values, or internal selector scores) used for feature ranking. + + Parameters + ---------- + model : str, optional + The model name to filter by. Default is None (all models). + + Returns + ------- + scores_df : pd.DataFrame + DataFrame with Model, Fold, FeatureName, Score, and PValue + (if available). + + See Also + -------- + ExperimentResult.get_selected_features : Final binary selection mask. + """ + frames = [] + for m_name, res in self.raw.items(): + if model is not None and m_name != model: + continue + if "error" in res: + continue + for f_idx, meta in enumerate(res.get("metadata", [])): + if "feature_scores" not in meta: + continue + + scores = np.asarray(meta["feature_scores"], dtype=float) + pvals = meta.get("feature_pvalues") + f_names = meta.get("feature_names") + if f_names is None or len(f_names) != len(scores): + f_names = [f"feature_{idx}" for idx in range(len(scores))] + + sel = meta.get("selected_features") + + df_f = pd.DataFrame( + { + "Model": m_name, + "Fold": f_idx, + "Feature": np.arange(len(scores)), + "FeatureName": f_names, + "Selector": meta.get("feature_selection_method"), + "Score": scores, + } + ) + + if pvals is not None and len(pvals) == len(scores): + df_f["PValue"] = np.asarray(pvals, dtype=float) + else: + df_f["PValue"] = np.nan + + if sel is not None and len(sel) == len(scores): + df_f["Selected"] = np.asarray(sel, dtype=bool) + else: + df_f["Selected"] = np.nan + + frames.append(df_f) + + if not frames: + return pd.DataFrame() + return pd.concat(frames, ignore_index=True) + + def get_feature_stability(self, model: Optional[str] = None) -> pd.DataFrame: + """ + Analyze feature selection stability across folds. + + Calculates the frequency with which each feature was selected across + the cross-validation procedure. + + Scientific Rationale + -------------------- + Stability analysis helps distinguish robust predictors from features + that are selected due to noise in specific data splits. High stability + (e.g., > 90% of folds) provides strong evidence for the relevance of + a feature to the decoding task. + + Parameters + ---------- + model : str, optional + The model name to filter by. Default is None (all models). + + Returns + ------- + stability_df : pd.DataFrame + DataFrame with Model, FeatureName, SelectionFrequency (0.0 to 1.0), + and NFolds. + + See Also + -------- + ExperimentResult.get_selected_features : Fold-level selection data. + """ + frames = [] + for m_name, res in self.raw.items(): + if model is not None and m_name != model: + continue + if "error" in res: + continue + if "metadata" in res: + masks = [] + f_names = None + for meta in res["metadata"]: + if "selected_features" in meta: + masks.append(meta["selected_features"]) + if f_names is None and "feature_names" in meta: + f_names = meta["feature_names"] + + if not masks: + continue + + stack = np.vstack(masks) # [n_folds, n_features] + n_folds = stack.shape[0] + freq = np.mean(stack, axis=0) + + if f_names is None or len(f_names) != stack.shape[1]: + f_names = [f"feature_{i}" for i in range(stack.shape[1])] + + df_m = pd.DataFrame( + { + "Model": m_name, + "Feature": np.arange(len(freq)), + "FeatureName": f_names, + "SelectionFrequency": freq, + "NFolds": n_folds, + } + ) + frames.append(df_m) + + if not frames: + return pd.DataFrame() + return pd.concat(frames, ignore_index=True) + + def get_generalization_matrix( + self, model: Optional[str] = None, metric: str = "accuracy" + ) -> pd.DataFrame: + """ + Get Generalization Matrix (Train Time x Test Time) averaged across folds. + + Computes the cross-temporal performance matrix for Time-Generalization + analysis. + + Scientific Rationale + -------------------- + Temporal Generalization (TG) analysis reveals the dynamics of neural + representations. By training a classifier at one time point and + testing it across all others, we can identify whether a neural + pattern is transient, sustained, or reoccurring. + + Parameters + ---------- + model : str, optional + The model name to filter by. Default is None (all models). + If None, returns a long-format DataFrame suitable for plotting. + metric : str, default='accuracy' + The metric to retrieve. + + Returns + ------- + gen_df : pd.DataFrame + - If model is specified: A square matrix (2D DataFrame) with + TrainTime as index and TestTime as columns. + - If model is None: A tidy long-format DataFrame with Model, + Metric, TrainTime, TestTime, and Value. + + See Also + -------- + ExperimentResult.get_temporal_score_summary : Linear temporal summary. + """ + frames = [] + for m_name, res in self.raw.items(): + if model is not None and m_name != model: + continue + if "error" in res: + continue + + metrics_data = res.get("metrics", {}) + if metric not in metrics_data: + # Fallback to first available metric if requested one is missing + if not metrics_data: + continue + metric = next(iter(metrics_data.keys())) + + fold_scores = metrics_data[metric].get("folds", []) + valid_matrices = [ + s for s in fold_scores if isinstance(s, np.ndarray) and s.ndim == 2 + ] + + if valid_matrices: + stack = np.stack(valid_matrices) + mean = np.nanmean(stack, axis=0) + + labels = self.time_axis + if labels is None or len(labels) != mean.shape[0]: + labels = list(range(mean.shape[0])) + + df_gen = pd.DataFrame(mean, index=labels, columns=labels) + df_gen.index.name = "TrainTime" + df_gen.columns.name = "TestTime" + + if model is not None: + return df_gen + + df_gen = df_gen.stack().reset_index(name="Value") + df_gen["Model"] = m_name + df_gen["Metric"] = metric + frames.append(df_gen) + + if not frames: + return pd.DataFrame() + + return pd.concat(frames, ignore_index=True) + + +def make_serializable(obj: Any) -> Any: + """Recursively convert NumPy types to JSON-safe Python primitives.""" + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, (np.int64, np.int32, np.int16, np.integer)): + return int(obj) + if isinstance(obj, (np.float64, np.float32, np.floating)): + return float(obj) + if isinstance(obj, (np.bool_, bool)): + return bool(obj) + if isinstance(obj, dict): + return {str(k): make_serializable(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [make_serializable(v) for v in obj] + if isinstance(obj, Path): + return str(obj) + if hasattr(obj, "model_dump"): + return obj.model_dump() + return obj diff --git a/coco_pipe/decoding/stats.py b/coco_pipe/decoding/stats.py new file mode 100644 index 0000000..e668ba1 --- /dev/null +++ b/coco_pipe/decoding/stats.py @@ -0,0 +1,1637 @@ +""" +Finite-sample statistical assessment for decoding results. + +This module separates descriptive performance from inferential claims. The +default inferential path reruns the complete decoding experiment under label +permutations so learned preprocessing, feature selection, tuning, and +calibration remain inside each null pipeline. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence + +if TYPE_CHECKING: + from .result import ExperimentResult + +import numpy as np +import pandas as pd +from scipy.stats import beta, binom, false_discovery_control, norm + +from ._metrics import get_metric_spec +from .configs import StatisticalAssessmentConfig + +logger = logging.getLogger(__name__) + +TEMPORAL_COLUMNS = ["Time", "TrainTime", "TestTime"] + + +def aggregate_predictions_for_inference( + predictions: pd.DataFrame, + metric: str, + task: str = "classification", + unit_of_inference: str = "sample", + custom_unit_column: Optional[str] = None, + custom_aggregation: str = "mean", + require_single_prediction: bool = False, +) -> pd.DataFrame: + """ + Aggregate prediction rows to independent units for inference. + + This ensures that each independent unit (e.g., a subject or a specific trial) + contributes exactly one prediction per temporal coordinate to the statistical + test, preventing pseudoreplication. + + Scientific Rationale + -------------------- + Inferential statistics assume independence between observations. In EEG/MEG, + multiple epochs from the same subject are correlated. By aggregating + predictions to the 'subject' level before calculating p-values, we ensure + the degrees of freedom in the test reflect the number of independent + biological units rather than the number of recorded segments. + + Parameters + ---------- + predictions : pd.DataFrame + Raw predictions from the experiment. + metric : str + The metric to optimize aggregation for (e.g., 'accuracy'). + task : str, default='classification' + Task type ('classification' or 'regression'). + unit_of_inference : str, default='sample' + The level at which independence is assumed ('sample', 'subject', or 'custom'). + custom_unit_column : str, optional + Column name in metadata to use as the independence unit if + unit_of_inference is 'custom'. + custom_aggregation : str, default='mean' + Aggregation mode ('mean' or 'majority'). + require_single_prediction : bool, default=False + If True, ensures that each unit has exactly one prediction per coordinate. + + Returns + ------- + aggregated_df : pd.DataFrame + Aggregated predictions with an 'InferentialUnitID' column. + + Raises + ------ + ValueError + If the unit column is missing or aggregation is incompatible with the task. + + Examples + -------- + >>> import pandas as pd + >>> from coco_pipe.decoding.stats import aggregate_predictions_for_inference + >>> df = pd.DataFrame({ + ... 'Subject': ['S1', 'S1'], 'y_true': [1, 1], 'y_pred': [1, 0], + ... 'SampleID': [0, 1] + ... }) + >>> res = aggregate_predictions_for_inference( + ... df, 'accuracy', unit_of_inference='Subject' + ... ) + + See Also + -------- + ExperimentResult.get_predictions : Tidy prediction accessor. + """ + if predictions.empty: + return predictions.copy() + + frame = predictions + temporal_cols = [ + col for col in TEMPORAL_COLUMNS if col in frame and frame[col].notna().any() + ] + # 1. Resolve Unit Column (Explicitly) + if unit_of_inference == "sample": + unit_col, aggregation = "SampleID", "identity" + else: + unit_col = ( + custom_unit_column if unit_of_inference == "custom" else unit_of_inference + ) + if unit_col not in frame.columns: + raise ValueError( + f"Inference unit '{unit_col}' not found in result columns. " + f"Available: {list(frame.columns)}" + ) + aggregation = custom_aggregation + + if unit_of_inference == "sample": + # Fast path: No aggregation needed + return frame.rename(columns={unit_col: "InferentialUnitID"}) + + # 2. Perform Aggregation + if task != "classification" and aggregation == "majority": + raise ValueError("majority aggregation is only valid for classification.") + + group_cols = [unit_col, *temporal_cols] + proba_cols = sorted( + [col for col in frame.columns if col.startswith("y_proba_")], + key=lambda value: int(value.rsplit("_", 1)[-1]), + ) + + agg_dict = {"y_true": "first"} + if task == "classification": + if aggregation == "mean": + if not proba_cols: + raise ValueError("mean aggregation requires probability columns.") + for col in proba_cols: + agg_dict[col] = "mean" + elif aggregation == "majority": + # Avoid slow lambda/mode: count occurrences and pick first mode + agg_dict["y_pred"] = lambda x: x.value_counts().index[0] + if proba_cols: + for col in proba_cols: + agg_dict[col] = "mean" + else: # regression + agg_dict["y_pred"] = "mean" + + # Execute Aggregation + res = frame.groupby(group_cols, dropna=False).agg(agg_dict).reset_index() + res = res.rename(columns={unit_col: "InferentialUnitID"}) + + # 3. Resolve y_pred for classification mean-aggregation + if task == "classification" and aggregation == "mean": + labels = sorted(pd.unique(frame["y_true"]).tolist()) + probs = res[proba_cols].to_numpy() + # Fast vectorized label assignment + res["y_pred"] = np.array(labels)[np.argmax(probs, axis=1)] + + return res + + +def binomial_accuracy_test( + y_true: Sequence[Any], + y_pred: Sequence[Any], + p0: Optional[float], + alpha: float = 0.05, + ci_method: str = "wilson", +) -> dict[str, Any]: + """ + Exact upper-tail binomial test for top-1 classification accuracy. + + This test computes the probability of obtaining at least the observed number + of correct predictions under the null hypothesis (theoretical chance). + + Scientific Rationale + -------------------- + For classification tasks with a known number of categories, the number of + correct predictions follows a Binomial distribution B(n, p0) under the null + hypothesis. This exact test is more robust than z-tests for small sample + sizes and provides a rigorous bound for 'chance-level' performance. + + Parameters + ---------- + y_true : Sequence[Any] + Actual ground-truth labels. + y_pred : Sequence[Any] + Predicted labels. + p0 : float + The theoretical chance level (e.g., 0.5 for binary classification). + alpha : float, default=0.05 + Significance level for p-values and confidence intervals. + ci_method : str, default='wilson' + Method for calculating confidence intervals ('wilson' or 'clopper_pearson'). + + Returns + ------- + result : dict + Dictionary containing 'observed' accuracy, 'p_value', 'n_eff', + 'chance_threshold', and confidence intervals. + + Raises + ------ + ValueError + If p0 is missing or input arrays are empty. + + Examples + -------- + >>> from coco_pipe.decoding.stats import binomial_accuracy_test + >>> res = binomial_accuracy_test([1, 0, 1], [1, 1, 1], p0=0.5) + >>> print(res['p_value']) + + See Also + -------- + run_statistical_assessment : Full-pipeline assessment driver. + """ + if p0 is None: + raise ValueError("Analytical binomial testing requires an explicit p0.") + + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + n_eff = len(y_true) + if n_eff == 0: + raise ValueError("Cannot run a binomial test with zero predictions.") + + correct = y_true == y_pred + k_correct = int(np.sum(correct)) + observed = k_correct / n_eff + + # Upper-tail (Is accuracy > p0?) + p_upper = float(binom.sf(k_correct - 1, n_eff, p0)) + + # Chance Threshold (Smallest k such that P(X >= k) <= alpha) + k_alpha = int(binom.isf(alpha, n_eff, p0)) + 1 + if binom.sf(k_alpha - 2, n_eff, p0) <= alpha: + k_alpha -= 1 + + ci_lower, ci_upper = _accuracy_ci(k_correct, n_eff, alpha, ci_method) + return { + "observed": observed, + "k_correct": k_correct, + "n_eff": n_eff, + "p_value": p_upper, + "chance_threshold": k_alpha / n_eff, + "chance_threshold_count": k_alpha, + "ci_lower": ci_lower, + "ci_upper": ci_upper, + } + + +def run_statistical_assessment( + observed_result: Any, + experiment_config: Any, + X: np.ndarray, + y: np.ndarray, + groups: Optional[np.ndarray], + sample_ids: np.ndarray, + sample_metadata: Optional[pd.DataFrame], + feature_names: Optional[Sequence[str]], + time_axis: Optional[np.ndarray], + observation_level: str, + inferential_unit: str, +) -> dict[str, Any]: + """ + Orchestrate the statistical assessment of experiment results. + + Resolves the chosen statistical method (binomial or permutation) and + dispatches analysis for each model and metric. + + Scientific Rationale + -------------------- + Statistical significance in decoding is often non-trivial due to temporal + autocorrelations and multiple comparisons. This orchestrator handles + either analytical binomial tests (fast, theoretical chance) or full-pipeline + permutation tests (rigorous, empirical null) to provide scientifically + grounded inferential claims about model performance. + + Parameters + ---------- + observed_result : ExperimentResult + The result of the actual experiment run. + experiment_config : ExperimentConfig + The full configuration of the experiment. + X, y : np.ndarray + The raw features and targets. + groups : np.ndarray, optional + CV grouping vector. + sample_ids : np.ndarray + Unique identifiers for samples. + sample_metadata : pd.DataFrame, optional + Metadata for unit resolution. + feature_names : list of str, optional + Names of input features. + time_axis : np.ndarray, optional + Time coordinates for temporal data. + observation_level : str + Level of input rows ('sample' or 'epoch'). + inferential_unit : str + Definition of statistical independence ('sample' or 'subject'). + + Returns + ------- + assessment_payload : dict + Summary containing 'rows' (standardized results) and 'nulls'. + + Examples + -------- + >>> # Internal use within Experiment.run() + >>> # res = run_statistical_assessment(observed, config, X, y, ...) + + See Also + -------- + binomial_accuracy_test : Core analytical test. + assess_post_hoc_permutation : Fast post-hoc alternative. + """ + stats_config = experiment_config.evaluation + unit = inferential_unit + metrics = experiment_config.get_all_evaluation_metrics() + rows: list[dict[str, Any]] = [] + nulls: dict[str, dict[str, Any]] = {} + + for model in observed_result.raw: + if "error" in observed_result.raw[model]: + continue + model_predictions = observed_result.get_predictions() + model_predictions = model_predictions[model_predictions["Model"] == model] + for metric in metrics: + method = stats_config.chance.method + if method == "auto": + method = ( + "binomial" + if (metric == "accuracy" and stats_config.chance.p0) + else "permutation" + ) + + if method == "binomial": + rows.extend( + _run_binomial_assessment( + model, + metric, + model_predictions, + experiment_config.task, + stats_config, + unit, + ) + ) + continue + + perm_rows, perm_null = _run_permutation_assessment( + model, + metric, + observed_result, + experiment_config, + X, + y, + groups, + sample_ids, + sample_metadata, + feature_names, + time_axis, + observation_level, + inferential_unit, + stats_config, + unit, + ) + rows.extend(perm_rows) + if stats_config.chance.store_null_distribution and perm_null is not None: + nulls.setdefault(model, {})[metric] = perm_null + + return { + "rows": rows, + "nulls": nulls, + "meta": { + "enabled": True, + "method": stats_config.chance.method, + "resolved_unit_of_inference": unit, + "metrics": metrics, + "n_permutations": stats_config.chance.n_permutations, + "alpha": stats_config.confidence_intervals.alpha, + "temporal_correction": stats_config.chance.temporal_correction, + }, + } + + +def _run_binomial_assessment( + model: str, + metric: str, + predictions: pd.DataFrame, + task: str, + config: StatisticalAssessmentConfig, + unit: str, +) -> list[dict[str, Any]]: + """ + Internal driver for analytical binomial significance testing. + + Examples + -------- + >>> # rows = _run_binomial_assessment( + >>> # "LR", "accuracy", preds, "classification", config, "subject" + >>> # ) + """ + if task != "classification" or metric != "accuracy": + raise ValueError( + "Analytical binomial testing only supports classification accuracy." + ) + + aggregated = aggregate_predictions_for_inference( + predictions, + metric=metric, + task=task, + unit_of_inference=unit, + custom_unit_column=config.custom_unit_column, + custom_aggregation=config.custom_aggregation, + require_single_prediction=True, + ) + p0 = config.chance.p0 + if p0 == "auto": + n_classes = len(pd.unique(aggregated["y_true"])) + p0 = 1.0 / n_classes + + temporal_cols = [ + col + for col in TEMPORAL_COLUMNS + if col in aggregated and aggregated[col].notna().any() + ] + + n_units = aggregated["InferentialUnitID"].nunique() + + if not temporal_cols: + result = binomial_accuracy_test( + aggregated["y_true"], + aggregated["y_pred"], + p0=p0, + alpha=config.confidence_intervals.alpha, + ci_method=config.confidence_intervals.method, + ) + return [ + _build_binomial_row( + model, metric, result, unit, p0, n_units, config, (), () + ) + ] + + rows = [] + for key, group in aggregated.groupby(temporal_cols, dropna=False): + coord_key = (key,) if not isinstance(key, tuple) else key + result = binomial_accuracy_test( + group["y_true"], + group["y_pred"], + p0=p0, + alpha=config.confidence_intervals.alpha, + ci_method=config.confidence_intervals.method, + ) + rows.append( + _build_binomial_row( + model, + metric, + result, + unit, + p0, + n_units, + config, + coord_key, + temporal_cols, + ) + ) + return rows + + +def _build_binomial_row( + model: str, + metric: str, + result: dict[str, Any], + unit: str, + p0: float, + n_units: int, + config: StatisticalAssessmentConfig, + key: tuple, + temporal_cols: list[str], +) -> dict[str, Any]: + """Format a binomial test result into a standardized result row.""" + coord = _coord_dict(key, temporal_cols) + return { + "Model": model, + "Metric": metric, + "Observed": result["observed"], + "InferentialUnit": unit, + "NEff": n_units, + "NullMethod": "binomial", + "NPermutations": None, + "P0": p0, + "PValue": result["p_value"], + "CILower": result["ci_lower"], + "CIUpper": result["ci_upper"], + "CorrectionMethod": "none", + "ChanceThreshold": result["chance_threshold"], + "Significant": result["p_value"] <= config.confidence_intervals.alpha, + "Assumptions": "classification accuracy; one prediction per unit", + "Caveat": f"Independence assumed at the '{unit}' level.", + **coord, + } + + +def _run_permutation_assessment( + model: str, + metric: str, + observed_result: Any, + experiment_config: Any, + X: np.ndarray, + y: np.ndarray, + groups: Optional[np.ndarray], + sample_ids: np.ndarray, + sample_metadata: Optional[pd.DataFrame], + feature_names: Optional[Sequence[str]], + time_axis: Optional[np.ndarray], + observation_level: str, + inferential_unit: str, + config: StatisticalAssessmentConfig, + unit: str, +) -> tuple[list[dict[str, Any]], Optional[dict[str, Any]]]: + """ + Internal driver for full-pipeline permutation testing. + + Examples + -------- + >>> # rows, nulls = _run_permutation_assessment( + >>> # "LR", "accuracy", res, cfg, X, y, ... + >>> # ) + """ + observed_predictions = observed_result.get_predictions() + observed_predictions = observed_predictions[observed_predictions["Model"] == model] + observed_agg = aggregate_predictions_for_inference( + observed_predictions, + metric=metric, + task=experiment_config.task, + unit_of_inference=unit, + custom_unit_column=config.custom_unit_column, + custom_aggregation=config.custom_aggregation, + ) + observed_scores = _score_by_coordinates(observed_agg, metric) + score_keys = list(observed_scores.keys()) + n_units = observed_agg["InferentialUnitID"].nunique() + temporal_cols = [ + col + for col in TEMPORAL_COLUMNS + if col in observed_agg and observed_agg[col].notna().any() + ] + + null_array = _run_permutation_loop( + model=model, + metric=metric, + score_keys=score_keys, + experiment_config=experiment_config, + X=X, + y=y, + groups=groups, + sample_ids=sample_ids, + sample_metadata=sample_metadata, + feature_names=feature_names, + time_axis=time_axis, + observation_level=observation_level, + inferential_unit=inferential_unit, + config=config, + unit=unit, + ) + + # Bootstrap Observed Scores for Confidence Intervals + boot_array = _bootstrap_scores( + observed_agg, + metric=metric, + score_keys=score_keys, + n_bootstraps=1000, + random_state=config.random_state, + ) + + obs_ci_lower = np.nanpercentile(boot_array, 2.5, axis=0) + obs_ci_upper = np.nanpercentile(boot_array, 97.5, axis=0) + + observed_array = np.asarray([observed_scores[key] for key in score_keys]) + rows = _build_permutation_rows( + model=model, + metric=metric, + observed_array=observed_array, + null_array=null_array, + obs_ci_lower=obs_ci_lower, + obs_ci_upper=obs_ci_upper, + score_keys=score_keys, + temporal_cols=temporal_cols, + unit=unit, + n_units=n_units, + config=config, + task=experiment_config.task, + ) + + null_payload = None + if config.chance.store_null_distribution: + null_payload = { + "metric": metric, + "unit": unit, + "coordinates": [_coord_dict(key, temporal_cols) for key in score_keys], + "values": null_array, + } + return rows, null_payload + + +def _run_permutation_loop( + model: str, + metric: str, + score_keys: list[tuple], + experiment_config: Any, + X: np.ndarray, + y: np.ndarray, + groups: Optional[np.ndarray], + sample_ids: np.ndarray, + sample_metadata: Optional[pd.DataFrame], + feature_names: Optional[Sequence[str]], + time_axis: Optional[np.ndarray], + observation_level: str, + inferential_unit: str, + config: StatisticalAssessmentConfig, + unit: str, +) -> np.ndarray: + """ + Execute the core permutation loop using parallel processing. + """ + from .experiment import Experiment + + rng = np.random.default_rng(config.random_state) + if unit == "sample": + unit_values = np.arange(len(y)) + elif sample_metadata is not None and unit in sample_metadata.columns: + unit_values = sample_metadata[unit].to_numpy() + elif groups is not None and unit in {"Group", "group"}: + unit_values = np.asarray(groups) + else: + target_col = config.custom_unit_column if unit == "custom" else unit + if sample_metadata is not None and target_col in sample_metadata.columns: + unit_values = sample_metadata[target_col].to_numpy() + else: + raise ValueError(f"Could not resolve unit values for '{unit}'.") + unique_units, unit_map_idx = np.unique(unit_values, return_inverse=True) + n_unique = len(unique_units) + + unit_labels_orig = [] + for u in unique_units: + mask = unit_values == u + u_y = y[mask] + if experiment_config.task == "classification": + unit_labels_orig.append(u_y[0]) + else: + unit_labels_orig.append(np.mean(u_y)) + unit_labels_orig = np.array(unit_labels_orig) + + import joblib + + parallel = joblib.Parallel(n_jobs=experiment_config.n_jobs) + + def _run_single_permutation(p_idx, seed): + local_rng = np.random.default_rng(seed) + perm_idx = local_rng.permutation(n_unique) + y_perm = unit_labels_orig[perm_idx][unit_map_idx] + perm_config = experiment_config.model_copy() + p_res = Experiment(perm_config).run( + X, + y_perm, + groups=groups, + feature_names=feature_names, + sample_ids=sample_ids, + sample_metadata=sample_metadata, + observation_level=observation_level, + inferential_unit=inferential_unit, + time_axis=time_axis, + ) + p_preds = p_res.get_predictions() + p_preds = p_preds[p_preds["Model"] == model] + p_agg = aggregate_predictions_for_inference( + p_preds, + metric=metric, + task=experiment_config.task, + unit_of_inference=unit, + custom_unit_column=config.custom_unit_column, + custom_aggregation=config.custom_aggregation, + ) + p_scores = _score_by_coordinates(p_agg, metric) + return [p_scores[key] for key in score_keys] + + seeds = rng.integers(0, 2**32, size=config.chance.n_permutations) + results = parallel( + joblib.delayed(_run_single_permutation)(i, seeds[i]) + for i in range(config.chance.n_permutations) + ) + + return np.array(results) + + +def _build_permutation_rows( + model: str, + metric: str, + observed_array: np.ndarray, + null_array: np.ndarray, + obs_ci_lower: np.ndarray, + obs_ci_upper: np.ndarray, + score_keys: list[tuple[Any, ...]], + temporal_cols: list[str], + unit: str, + n_units: int, + config: StatisticalAssessmentConfig, + task: str, +) -> list[dict[str, Any]]: + """Format a permutation test result into a standardized result row.""" + metric_spec = get_metric_spec(metric) + greater_is_better = metric_spec.greater_is_better + + p_values = _empirical_p_values( + observed_array, + null_array, + greater_is_better, + ) + + corrected = _correct_p_values( + observed_array, + null_array, + p_values, + config.chance.temporal_correction, + metric_spec.greater_is_better, + ) + + null_median = np.nanmedian(null_array, axis=0) + null_lower = np.nanpercentile(null_array, 2.5, axis=0) + null_upper = np.nanpercentile(null_array, 97.5, axis=0) + + rows = [] + for idx, key in enumerate(score_keys): + coord = _coord_dict(key, temporal_cols) + rows.append( + { + "Model": model, + "Metric": metric, + "Observed": observed_array[idx], + "InferentialUnit": unit, + "NEff": n_units, + "NullMethod": "permutation_full_pipeline", + "NPermutations": config.chance.n_permutations, + "P0": null_median[idx], + "PValue": p_values[idx], + "CILower": obs_ci_lower[idx], + "CIUpper": obs_ci_upper[idx], + "CorrectionMethod": config.chance.temporal_correction, + "CorrectedPValue": corrected[idx], + "ChanceThreshold": np.nanpercentile( + null_array[:, idx], 95 if greater_is_better else 5 + ), + "NullMedian": null_median[idx], + "NullLower": null_lower[idx], + "NullUpper": null_upper[idx], + "Significant": corrected[idx] <= config.confidence_intervals.alpha, + "Assumptions": ( + "full outer-CV pipeline rerun under label permutations; " + "regression targets averaged by unit" + if task == "regression" and unit != "sample" + else "full outer-CV pipeline rerun under label permutations" + ), + "Caveat": f"Independence assumed at the '{unit}' level.", + **coord, + } + ) + return rows + + +def run_paired_permutation_assessment( + results_a: "ExperimentResult", + results_b: "ExperimentResult", + model: str, + metric: str, + config: StatisticalAssessmentConfig, +) -> pd.DataFrame: + """ + Run a paired permutation test to compare two experimental results. + + Tests the null hypothesis that the difference between two models is zero + by randomly swapping model labels within each independent unit. + + Scientific Rationale + -------------------- + This function performs a rigorous comparison of two experimental pipelines + by aligning predictions at the 'SampleID' level and performing + within-unit label swaps. This ensures that the comparison is not biased + by unit-specific performance baselines and correctly estimates the + p-value for the observed performance delta. + + Parameters + ---------- + results_a, results_b : ExperimentResult + The results of the two experiments to compare. + model : str + The name of the model to compare. + metric : str + Metric to use for the comparison. + config : StatisticalAssessmentConfig + Configuration for permutations and significance. + + Returns + ------- + paired_df : pd.DataFrame + DataFrame with Difference, PValue, and confidence intervals. + + Examples + -------- + >>> # diff = run_paired_permutation_assessment(res1, res2, 'LR', 'accuracy', config) + + See Also + -------- + assess_paired_comparison : Fast post-hoc alternative. + + Examples + -------- + >>> # paired_df = run_paired_permutation_assessment( + >>> # res_a, res_b, "LR", "accuracy", config + >>> # ) + """ + from ._diagnostics import score_frame + + preds_a = results_a.get_predictions() + preds_b = results_b.get_predictions() + preds_a = preds_a[preds_a["Model"] == model] + preds_b = preds_b[preds_b["Model"] == model] + + merge_cols = ["SampleID", "Fold"] + temporal_cols = [c for c in TEMPORAL_COLUMNS if c in preds_a] + merge_cols.extend(temporal_cols) + + unit_col = ( + config.unit_of_inference if config.unit_of_inference != "sample" else "SampleID" + ) + if unit_col in preds_a and unit_col not in merge_cols: + merge_cols.append(unit_col) + + merged = pd.merge(preds_a, preds_b, on=merge_cols, suffixes=("_A", "_B")) + if merged.empty: + raise ValueError("Could not align predictions for paired test.") + + def get_diff(group: pd.DataFrame) -> float: + s_a = score_frame( + group.filter(regex=".*_A$|SampleID|y_true").rename( + columns=lambda x: x[:-2] if x.endswith("_A") else x + ), + metric, + ) + s_b = score_frame( + group.filter(regex=".*_B$|SampleID|y_true").rename( + columns=lambda x: x[:-2] if x.endswith("_B") else x + ), + metric, + ) + return s_a - s_b + + obs_scores_dummy = _score_by_coordinates(preds_a, metric) + score_keys = list(obs_scores_dummy.keys()) + + observed_diff_array = np.zeros(len(score_keys)) + for idx, key in enumerate(score_keys): + m = np.ones(len(merged), dtype=bool) + for i, c in enumerate(temporal_cols): + m &= merged[c] == key[i] + observed_diff_array[idx] = get_diff(merged[m]) + + boot_results = _bootstrap_scores_paired( + merged, + metric=metric, + score_keys=score_keys, + temporal_cols=temporal_cols, + unit_col=unit_col, + n_bootstraps=1000, + random_state=config.random_state, + ) + + obs_ci_lower = np.nanpercentile(boot_results, 2.5, axis=0) + obs_ci_upper = np.nanpercentile(boot_results, 97.5, axis=0) + + import joblib + + n_perm = config.chance.n_permutations + perm_rng = np.random.default_rng(config.random_state + 1) + unique_units = merged[unit_col].unique() + n_units = len(unique_units) + + def _run_single_perm(seed): + local_rng = np.random.default_rng(seed) + swaps = local_rng.choice([False, True], size=n_units) + swap_units = unique_units[swaps] + + perm_merged = merged.copy() + if np.any(swaps): + mask = merged[unit_col].isin(swap_units) + cols_a = [c for c in merged.columns if c.endswith("_A")] + for c_a in cols_a: + c_b = c_a[:-2] + "_B" + a_vals = merged.loc[mask, c_a].copy() + perm_merged.loc[mask, c_a] = merged.loc[mask, c_b] + perm_merged.loc[mask, c_b] = a_vals + + p_diffs = np.empty(len(score_keys)) + for idx, key in enumerate(score_keys): + m = np.ones(len(perm_merged), dtype=bool) + for j, c in enumerate(temporal_cols): + m &= perm_merged[c] == key[j] + p_diffs[idx] = get_diff(perm_merged[m]) + return p_diffs + + seeds = perm_rng.integers(0, 2**32, size=n_perm) + n_jobs = getattr(config, "n_jobs", 1) + null_results = joblib.Parallel(n_jobs=n_jobs)( + joblib.delayed(_run_single_perm)(s) for s in seeds + ) + null_array = np.array(null_results) + + p_values = _empirical_p_values( + observed_diff_array, null_array, greater_is_better=True, two_sided=True + ) + corrected = _correct_p_values( + observed_diff_array, + null_array, + p_values, + config.chance.temporal_correction, + greater_is_better=True, + ) + + null_median = np.nanmedian(null_array, axis=0) + null_lower = np.nanpercentile(null_array, 2.5, axis=0) + null_upper = np.nanpercentile(null_array, 97.5, axis=0) + + rows = [] + for idx, key in enumerate(score_keys): + coord = _coord_dict(key, temporal_cols) + rows.append( + { + "Model": model, + "Metric": metric, + "Comparison": "Paired Difference (A-B)", + "Observed": observed_diff_array[idx], + "InferentialUnit": unit_col, + "NEff": n_units, + "NullMethod": "paired_permutation", + "NPermutations": n_perm, + "P0": null_median[idx], + "PValue": p_values[idx], + "CILower": obs_ci_lower[idx], + "CIUpper": obs_ci_upper[idx], + "CorrectionMethod": config.chance.temporal_correction, + "CorrectedPValue": corrected[idx], + "NullMedian": null_median[idx], + "NullLower": null_lower[idx], + "NullUpper": null_upper[idx], + "Significant": corrected[idx] <= config.confidence_intervals.alpha, + "Caveat": f"Independence assumed at the '{unit_col}' level.", + **coord, + } + ) + + return pd.DataFrame(rows) + + +def _score_by_coordinates( + frame: pd.DataFrame, metric: str +) -> dict[tuple[Any, ...], float]: + """Score predictions across all temporal coordinates.""" + from ._diagnostics import score_frame + + temporal_cols = [ + col for col in TEMPORAL_COLUMNS if col in frame and frame[col].notna().any() + ] + if not temporal_cols: + return {(): score_frame(frame, metric)} + + m_spec = get_metric_spec(metric) + if m_spec.response_method == "predict" and metric in { + "accuracy", + "zero_one_loss", + "hamming_loss", + }: + y_true_mat = frame.pivot( + index="InferentialUnitID", columns=temporal_cols, values="y_true" + ) + y_pred_mat = frame.pivot( + index="InferentialUnitID", columns=temporal_cols, values="y_pred" + ) + + if metric == "accuracy": + scores_array = (y_true_mat.values == y_pred_mat.values).mean(axis=0) + else: + scores_array = (y_true_mat.values != y_pred_mat.values).mean(axis=0) + + return dict(zip(y_true_mat.columns, scores_array)) + + scores = {} + for key, group in frame.groupby(temporal_cols, dropna=False): + coord_key = (key,) if not isinstance(key, tuple) else key + scores[coord_key] = score_frame(group, metric) + return scores + + +def _bootstrap_engine( + units: np.ndarray, + unit_map: dict[Any, pd.DataFrame], + score_func: Callable[[pd.DataFrame], np.ndarray], + n_bootstraps: int = 1000, + random_state: Optional[int] = None, +) -> np.ndarray: + """ + Core engine for unit-based bootstrap resampling. + + Examples + -------- + >>> # boot_dist = _bootstrap_engine(units, unit_map, score_func, n_bootstraps=100) + """ + rng = np.random.default_rng(random_state) + n_units = len(units) + results = [] + for _ in range(n_bootstraps): + boot_units = rng.choice(units, size=n_units, replace=True) + boot_frame = pd.concat([unit_map[u] for u in boot_units]) + results.append(score_func(boot_frame)) + return np.array(results) + + +def _bootstrap_scores( + frame: pd.DataFrame, + metric: str, + score_keys: list[tuple], + n_bootstraps: int = 1000, + random_state: Optional[int] = None, +) -> np.ndarray: + """Resample independent units with replacement and re-score.""" + unique_units = frame["InferentialUnitID"].unique() + unit_map = {u: frame[frame["InferentialUnitID"] == u] for u in unique_units} + + def score_func(df: pd.DataFrame) -> np.ndarray: + boot_scores = _score_by_coordinates(df, metric) + return np.array([boot_scores.get(key, np.nan) for key in score_keys]) + + return _bootstrap_engine( + unique_units, unit_map, score_func, n_bootstraps, random_state + ) + + +def _bootstrap_scores_paired( + merged: pd.DataFrame, + metric: str, + score_keys: list[tuple], + temporal_cols: list[str], + unit_col: str, + n_bootstraps: int = 1000, + random_state: Optional[int] = None, +) -> np.ndarray: + """Resample independent units for paired differences.""" + from ._diagnostics import score_frame + + unique_units = merged[unit_col].unique() + unit_map = {u: merged[merged[unit_col] == u] for u in unique_units} + + def get_diff(group: pd.DataFrame) -> float: + s_a = score_frame( + group.filter(regex=".*_A$|SampleID|y_true").rename( + columns=lambda x: x[:-2] if x.endswith("_A") else x + ), + metric, + ) + s_b = score_frame( + group.filter(regex=".*_B$|SampleID|y_true").rename( + columns=lambda x: x[:-2] if x.endswith("_B") else x + ), + metric, + ) + return s_a - s_b + + def score_func(df: pd.DataFrame) -> np.ndarray: + res = np.empty(len(score_keys)) + for idx, key in enumerate(score_keys): + m = np.ones(len(df), dtype=bool) + for j, c in enumerate(temporal_cols): + m &= df[c] == key[j] + res[idx] = get_diff(df[m]) + return res + + return _bootstrap_engine( + unique_units, unit_map, score_func, n_bootstraps, random_state + ) + + +def _empirical_p_values( + observed: np.ndarray, + null: np.ndarray, + greater_is_better: bool, + two_sided: bool = False, +) -> np.ndarray: + """ + Calculate empirical p-values from a null distribution. + + Uses the recommended (k+1)/(n+1) estimator to avoid p=0. + Supports both one-sided and asymmetric two-sided calculations. + + Parameters + ---------- + observed : np.ndarray + The observed scores. + null : np.ndarray + Null distribution array (permutations, coordinates). + greater_is_better : bool + Metric directionality. + two_sided : bool, default=False + Whether to compute a two-sided p-value. + + Returns + ------- + p_values : np.ndarray + Empirical p-values for each coordinate. + """ + if two_sided: + p_upper = (np.sum(null >= observed, axis=0) + 1) / (null.shape[0] + 1) + p_lower = (np.sum(null <= observed, axis=0) + 1) / (null.shape[0] + 1) + return np.minimum(1.0, 2 * np.minimum(p_upper, p_lower)) + + if greater_is_better: + return (np.sum(null >= observed, axis=0) + 1) / (null.shape[0] + 1) + return (np.sum(null <= observed, axis=0) + 1) / (null.shape[0] + 1) + + +def _correct_p_values( + observed: np.ndarray, + null: np.ndarray, + p_values: np.ndarray, + method: str, + greater_is_better: bool, +) -> np.ndarray: + """ + Apply multiple-comparison correction across temporal coordinates. + + Supported methods include standard corrections (Bonferroni, FDR) and + permutation-based Max-Stat (recommended for cluster-based temporal data). + + Parameters + ---------- + observed : np.ndarray + The observed scores. + null : np.ndarray + Null distribution array. + p_values : np.ndarray + Raw p-values. + method : str + Correction method name. + greater_is_better : bool + Metric directionality. + + Returns + ------- + corrected_p : np.ndarray + Corrected p-values. + """ + if method == "none" or observed.size == 1: + return p_values + if method == "bonferroni": + return np.minimum(1.0, p_values * len(p_values)) + if method == "fdr_bh": + return false_discovery_control(p_values, method="bh") + if method == "fdr_by": + return false_discovery_control(p_values, method="by") + if method == "max_stat": + if greater_is_better: + max_null = np.nanmax(null, axis=1) + return (np.sum(max_null[:, None] >= observed[None, :], axis=0) + 1) / ( + null.shape[0] + 1 + ) + min_null = np.nanmin(null, axis=1) + return (np.sum(min_null[:, None] <= observed[None, :], axis=0) + 1) / ( + null.shape[0] + 1 + ) + raise ValueError(f"Unknown temporal correction: {method}.") + + +def _accuracy_ci( + k_correct: np.ndarray, + n_eff: np.ndarray, + alpha: float, + method: str, +) -> tuple[np.ndarray, np.ndarray]: + """ + Calculate confidence intervals for accuracy, vectorized. + + Parameters + ---------- + k_correct : np.ndarray + Number of correct predictions. + n_eff : np.ndarray + Effective sample size. + alpha : float + Significance level. + method : str + CI method ('wilson' or 'clopper_pearson'). + + Returns + ------- + lower, upper : tuple of np.ndarray + Lower and upper CI bounds. + """ + if method == "clopper_pearson": + lower = beta.ppf(alpha / 2, k_correct, n_eff - k_correct + 1) + lower = np.where(k_correct == 0, 0.0, lower) + upper = beta.ppf(1 - alpha / 2, k_correct + 1, n_eff - k_correct) + upper = np.where(k_correct == n_eff, 1.0, upper) + return lower, upper + + if method != "wilson": + raise ValueError("ci_method must be 'wilson' or 'clopper_pearson'.") + + z = norm.ppf(1 - alpha / 2) + phat = k_correct / n_eff + denom = 1 + z**2 / n_eff + center = (phat + z**2 / (2 * n_eff)) / denom + half = z * np.sqrt((phat * (1 - phat) + z**2 / (4 * n_eff)) / n_eff) / denom + return np.maximum(0.0, center - half), np.minimum(1.0, center + half) + + +def _coord_dict(key: tuple[Any, ...], names: list[str]) -> dict[str, Any]: + """Map a coordinate tuple to its dimension names.""" + result = {"Time": None, "TrainTime": None, "TestTime": None} + if not names or len(key) == 0: + return result + + for i, name in enumerate(names): + if i < len(key): + result[name] = key[i] + return result + + +def assess_post_hoc_permutation( + res: dict[str, Any], + metric: str = "accuracy", + unit: Optional[str] = None, + n_permutations: int = 1000, + random_state: Optional[int] = None, +) -> pd.DataFrame: + """ + Perform a post-hoc label permutation assessment on out-of-fold predictions. + + Shuffles labels relative to fixed predictions to estimate the null + distribution under exchangeability. + + Scientific Rationale + -------------------- + Unlike full-pipeline permutations, post-hoc permutations do not rerun + feature selection or tuning. This makes them significantly faster but + potentially over-optimistic if those steps 'leaked' label information. + However, if the independence unit (e.g., subject) is respected during + the shuffle, it provides a valid test of whether the model's predictions + are significantly associated with the labels beyond chance. + + Parameters + ---------- + res : dict + Result dictionary from ExperimentResult.raw. + metric : str, default='accuracy' + The metric to evaluate. + unit : str, optional + Level of independence (e.g., 'subject'). + n_permutations : int, default=1000 + Number of null permutations. + random_state : int, optional + Seed for reproducibility. + + Returns + ------- + posthoc_df : pd.DataFrame + DataFrame with Observed score, PValue, and Significant status. + + Examples + -------- + >>> # posthoc = assess_post_hoc_permutation(res.raw['LR'], metric='accuracy') + + See Also + -------- + run_statistical_assessment : Full-pipeline assessment driver. + """ + from ._diagnostics import prediction_rows, score_frame + + preds = [] + for fold_idx, p_list in enumerate(res.get("predictions", [])): + rows = prediction_rows(model="temp", fold_idx=fold_idx, preds=p_list) + preds.extend([r for r in rows if r.get("Time") is None]) + + if not preds: + raise ValueError("No scalar predictions found for post-hoc assessment.") + + df = pd.DataFrame(preds) + y_true = df["y_true"].to_numpy() + obs_score = score_frame(df, metric) + + rng = np.random.default_rng(random_state) + null_scores = np.zeros(n_permutations) + + u_col = unit if unit != "sample" else None + if u_col is None and "Group" in df.columns: + u_col = "Group" + + if u_col is not None and u_col in df.columns: + unique_units = df[u_col].unique() + label_map = df.groupby(u_col)["y_true"].apply(list).to_dict() + for i in range(n_permutations): + shuffled_units = rng.permutation(unique_units) + unit_map = dict(zip(unique_units, shuffled_units)) + new_labels = [] + for val in df[u_col]: + target_u = unit_map[val] + new_labels.append(rng.choice(label_map[target_u])) + new_df = df.copy() + new_df["y_true"] = new_labels + null_scores[i] = score_frame(new_df, metric) + else: + for i in range(n_permutations): + new_labels = rng.permutation(y_true) + new_df = df.copy() + new_df["y_true"] = new_labels + null_scores[i] = score_frame(new_df, metric) + + spec = get_metric_spec(metric) + if spec.greater_is_better: + p_val = (np.sum(null_scores >= obs_score) + 1) / (n_permutations + 1) + else: + p_val = (np.sum(null_scores <= obs_score) + 1) / (n_permutations + 1) + + return pd.DataFrame( + [ + { + "Metric": metric, + "Observed": obs_score, + "PValue": float(p_val), + "Significant": p_val < 0.05, + "NullMethod": "posthoc_label_permutation", + "NPermutations": n_permutations, + "InferentialUnit": unit or "sample", + "NullLower": float(np.quantile(null_scores, 0.025)), + "NullUpper": float(np.quantile(null_scores, 0.975)), + } + ] + ) + + +def assess_paired_comparison( + merged: pd.DataFrame, + metric: str = "accuracy", + unit: Optional[str] = None, + n_permutations: int = 1000, + random_state: Optional[int] = None, +) -> pd.DataFrame: + """ + Perform a paired permutation test between two models. + + Tests the null hypothesis that the difference between two models is zero + by randomly swapping model labels within each independent unit. + + Scientific Rationale + -------------------- + To compare two models (A and B), we test if the observed difference in + performance is greater than what would be expected by chance if the labels + 'A' and 'B' were interchangeable. By swapping labels within units (e.g., + within subject), we control for subject-specific performance baselines and + focus on the model-driven variance. + + Parameters + ---------- + merged : pd.DataFrame + Merged prediction frame with suffixes '_A' and '_B'. + metric : str, default='accuracy' + Metric to evaluate. + unit : str, optional + Level of independence (e.g., 'subject'). + n_permutations : int, default=1000 + Number of permutations. + random_state : int, optional + Seed for reproducibility. + + Returns + ------- + comparison_df : pd.DataFrame + DataFrame with ScoreA, ScoreB, Difference, and PValue. + + Examples + -------- + >>> # comp = assess_paired_comparison(merged_df, metric='accuracy') + + See Also + -------- + run_paired_permutation_assessment : Full-pipeline paired comparison. + """ + + coord_cols = [c for c in ["Time", "TrainTime", "TestTime"] if c in merged] + if coord_cols: + results = [] + for coords, group in merged.groupby(coord_cols): + res = _assess_paired_comparison_internal( + group, metric, unit, n_permutations, random_state + ) + # Add coordinates back + if len(coord_cols) == 1: + res[coord_cols[0]] = coords + else: + for i, col in enumerate(coord_cols): + res[col] = coords[i] + results.append(res) + return pd.concat(results, ignore_index=True) + + return _assess_paired_comparison_internal( + merged, metric, unit, n_permutations, random_state + ) + + +def _assess_paired_comparison_internal( + merged: pd.DataFrame, + metric: str, + unit: Optional[str], + n_permutations: int, + random_state: Optional[int], +) -> pd.DataFrame: + """Internal core for paired comparison on a single coordinate.""" + from ._diagnostics import paired_unit_indices, score_frame + + frame_a = merged.copy() + for col in merged.columns: + if col.endswith("_A"): + frame_a[col[:-2]] = merged[col] + + frame_b = merged.copy() + for col in merged.columns: + if col.endswith("_B"): + frame_b[col[:-2]] = merged[col] + + score_a = score_frame(frame_a, metric) + score_b = score_frame(frame_b, metric) + obs_diff = score_a - score_b + + u_indices = paired_unit_indices(merged, unit or "sample") + n_units = len(u_indices) + rng = np.random.default_rng(random_state) + + null_diffs = np.zeros(n_permutations) + for i in range(n_permutations): + swaps = rng.choice([True, False], size=n_units) + perm_a = frame_a.copy() + perm_b = frame_b.copy() + + for unit_idx, should_swap in enumerate(swaps): + if should_swap: + idx = u_indices[unit_idx] + swap_cols = ["y_pred"] + if "y_score" in frame_a.columns: + swap_cols.append("y_score") + swap_cols.extend( + [c for c in frame_a.columns if c.startswith("y_proba_")] + ) + for col in swap_cols: + temp = perm_a.iloc[idx, perm_a.columns.get_loc(col)].copy() + perm_a.iloc[idx, perm_a.columns.get_loc(col)] = perm_b.iloc[ + idx, perm_b.columns.get_loc(col) + ] + perm_b.iloc[idx, perm_b.columns.get_loc(col)] = temp + + null_diffs[i] = score_frame(perm_a, metric) - score_frame(perm_b, metric) + + p_val = (np.sum(np.abs(null_diffs) >= np.abs(obs_diff)) + 1) / (n_permutations + 1) + + return pd.DataFrame( + [ + { + "Metric": metric, + "ScoreA": score_a, + "ScoreB": score_b, + "Difference": obs_diff, + "PValue": float(p_val), + "Significant": p_val < 0.05, + "NUnits": n_units, + "NPermutations": n_permutations, + } + ] + ) + + +def assess_bootstrap_ci( + res: dict[str, Any], + metric: str = "accuracy", + unit: Optional[str] = None, + n_bootstraps: int = 1000, + ci: float = 0.95, + random_state: Optional[int] = None, +) -> pd.DataFrame: + """ + Estimate uncertainty of a metric via bootstrapping over units. + + This function computes the observed metric on the provided results + and then generates a distribution of scores by resampling independent + units with replacement. + + Scientific Rationale + -------------------- + Bootstrapping provides a non-parametric estimate of the sampling + distribution of the metric. By resampling at the 'unit' level (e.g., + subjects rather than individual trials), we account for within-unit + correlations and avoid pseudoreplication, ensuring that the confidence + intervals accurately reflect the uncertainty at the intended level of + inference. + + Parameters + ---------- + res : dict + Result dictionary for a single model from ExperimentResult.raw. + metric : str, default='accuracy' + Metric to evaluate. + unit : str, optional + The level of independence (e.g., 'subject'). + n_bootstraps : int, default=1000 + Number of bootstrap iterations. + ci : float, default=0.95 + Confidence level (0.95 for 95% intervals). + random_state : int, optional + Seed for reproducibility. + + Returns + ------- + bootstrap_df : pd.DataFrame + DataFrame with estimate, CILower, and CIUpper. + + Examples + -------- + >>> # ci_df = assess_bootstrap_ci(res.raw['LR'], unit='subject') + + See Also + -------- + binomial_accuracy_test : Analytical CI alternative. + """ + from ._diagnostics import prediction_rows, score_frame, unit_indices + + preds = [] + for fold_idx, p_list in enumerate(res.get("predictions", [])): + rows = prediction_rows(model="temp", fold_idx=fold_idx, preds=p_list) + preds.extend([r for r in rows if r.get("Time") is None]) + + if not preds: + raise ValueError("No scalar predictions found for bootstrap.") + + df = pd.DataFrame(preds) + obs_score = score_frame(df, metric) + + u_indices = unit_indices(df, unit) + unit_map = {i: df.iloc[idx] for i, idx in enumerate(u_indices)} + unique_units = np.arange(len(u_indices)) + + def score_func(sample_df: pd.DataFrame) -> np.ndarray: + try: + return np.array([score_frame(sample_df, metric)]) + except Exception: + return np.array([np.nan]) + + boot_scores = _bootstrap_engine( + unique_units, unit_map, score_func, n_bootstraps, random_state + ).flatten() + + alpha = (1 - ci) / 2 + + return pd.DataFrame( + [ + { + "Metric": metric, + "Estimate": obs_score, + "CILower": float(np.nanquantile(boot_scores, alpha)), + "CIUpper": float(np.nanquantile(boot_scores, 1 - alpha)), + "Unit": unit or "sample", + "NUnits": len(u_indices), + "NBootstraps": n_bootstraps, + } + ] + ) + + +def apply_multiple_comparison_correction( + df: pd.DataFrame, + p_col: str = "PValue", + method: str = "fdr_bh", + alpha: float = 0.05, +) -> pd.DataFrame: + """ + Apply multiple comparison correction to a DataFrame of results. + + Scientific Rationale + -------------------- + When testing multiple hypotheses (e.g., across many timepoints or models), + the probability of a Type I error (false positive) increases. This + utility applies standard corrections like Bonferroni (strict) or False + Discovery Rate (FDR; Benjamini-Hochberg) to control the family-wise error + rate or the expected proportion of false discoveries. + + Parameters + ---------- + df : pd.DataFrame + Results DataFrame containing p-values. + p_col : str, default='PValue' + Name of the column containing raw p-values. + method : str, default='fdr_bh' + Correction method (e.g., 'fdr_bh', 'bonferroni'). + alpha : float, default=0.05 + Significance level. + + Returns + ------- + corrected_df : pd.DataFrame + The DataFrame with updated 'PValueCorrected' and 'Significant' columns. + + Examples + -------- + >>> # corrected = apply_multiple_comparison_correction(results_df, method='fdr_bh') + + See Also + -------- + statsmodels.stats.multitest.multipletests : Underlying implementation. + """ + from statsmodels.stats.multitest import multipletests + + if df.empty or len(df) < 2 or not method: + if "Significant" not in df.columns and not df.empty: + df["Significant"] = df[p_col] < alpha + return df + + reject, corrected, _, _ = multipletests( + df[p_col].to_numpy(), alpha=alpha, method=method + ) + + df = df.copy() + df["PValueCorrected"] = corrected + df["Significant"] = reject + df["CorrectionMethod"] = method + return df diff --git a/coco_pipe/decoding/utils.py b/coco_pipe/decoding/utils.py deleted file mode 100644 index b660042..0000000 --- a/coco_pipe/decoding/utils.py +++ /dev/null @@ -1,343 +0,0 @@ -""" -Decoding Utilities -================== - -Helper functions and classes for the decoding module, primarily focused on -Cross-Validation (CV) strategy management. - -This module provides: -- `get_cv_splitter`: A factory function to instantiate Scikit-Learn cross-validators - from a Pydantic `CVConfig`. -- `SimpleSplit`: A custom validator for a single train/test split. -- `_CVWithGroups`: A wrapper to ensure group constraints are respected even when - Scikit-Learn's `cross_val_score` internals might obscure them. -""" - -from typing import Any, Callable, Optional, Sequence, Union - -import numpy as np -import pandas as pd -from sklearn.base import BaseEstimator, clone -from sklearn.metrics import ( - accuracy_score, - balanced_accuracy_score, - explained_variance_score, - f1_score, - mean_absolute_error, - mean_squared_error, - precision_score, - r2_score, - recall_score, - roc_auc_score, -) -from sklearn.model_selection import ( - BaseCrossValidator, - GroupKFold, - KFold, - LeaveOneGroupOut, - LeavePGroupsOut, - StratifiedGroupKFold, - StratifiedKFold, - train_test_split, -) -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import StandardScaler - -from .configs import CVConfig - - -class _CVWithGroups(BaseCrossValidator): - """ - Internal wrapper to bind specific groups to a CV splitter. - - This ensures that `.split(X, y)` always uses the strict `groups` provided - at initialization, ignoring any groups passed at runtime. This is critical - for preventing data leakage when complex grouping logic is defined upstream. - - Parameters - ---------- - cv : BaseCrossValidator - The underlying Scikit-Learn cross-validator (e.g., GroupKFold). - groups : array-like - The group labels to enforce for all splits. - """ - - def __init__(self, cv, groups): - self.cv = cv - self.groups = groups - - def split(self, X, y=None, groups=None): - # ignore incoming groups, always use our stored one - return self.cv.split(X, y, self.groups) - - def get_n_splits(self, X=None, y=None, groups=None): - return self.cv.get_n_splits(X, y, self.groups) - - -class SimpleSplit(BaseCrossValidator): - """ - A unified 1-fold CV strategy wrapping `train_test_split`. - - This allows "hold-out" validation to be treated as a Cross-Validation - strategy with `n_splits=1`, integrating seamlessly into loops that - expect a generator of indices. - - Parameters - ---------- - test_size : float, default=0.2 - Proportion of the dataset to include in the test split. - shuffle : bool, default=True - Whether to shuffle the data before splitting. - random_state : int, optional - Controls the shuffling applied to the data before applying the split. - stratify : array-like, optional - If not None, data is split in a stratified fashion, using this array - as the class labels. - """ - - def __init__( - self, - test_size: float = 0.2, - shuffle: bool = True, - random_state: Optional[int] = None, - stratify: Optional[Union[pd.Series, np.ndarray]] = None, - ): - if not (0 < test_size < 1): - raise ValueError("test_size must be between 0 and 1.") - self.test_size = test_size - self.shuffle = shuffle - self.random_state = random_state - self.stratify = stratify - - def split( - self, - X: Union[pd.DataFrame, np.ndarray], - y: Optional[Union[pd.Series, np.ndarray]] = None, - groups: Optional[Sequence] = None, - ): - """ - Yield a single (train_index, test_index) tuple. - """ - idx = np.arange(len(X)) - strat = self.stratify if self.stratify is not None else None - train_idx, test_idx = train_test_split( - idx, - test_size=self.test_size, - shuffle=self.shuffle, - random_state=self.random_state if self.shuffle else None, - stratify=strat, - ) - yield train_idx, test_idx - - def get_n_splits( - self, - X: Any = None, - y: Any = None, - groups: Any = None, - ) -> int: - """Always returns 1 split.""" - return 1 - - -def get_cv_splitter( - config: CVConfig, groups: Optional[Sequence] = None -) -> BaseCrossValidator: - """ - Factory function to create a Scikit-Learn compliant cross-validator. - - Constructs the appropriate splitter based on the provided `CVConfig` strategy. - If `groups` are provided, they are bound to the splitter using `_CVWithGroups` - to guarantee consistent grouping across pipeline steps. - - Parameters - ---------- - config : CVConfig - Validated configuration object specifying: - - strategy: 'stratified', 'kfold', 'group_kfold', 'leave_p_out', etc. - - n_splits: Number of folds (where applicable). - - shuffle: Whether to shuffle data (where applicable). - - random_state: Seed for reproducibility. - groups : sequence, optional - Group labels for the samples. Required for 'group_kfold', 'leave_p_out', - and 'stratified_group_kfold'. - If provided, the returned validator will ignore any groups passed to its - `.split()` method and use these instead. - - Returns - ------- - BaseCrossValidator - An initialized cross-validator instance. - - Raises - ------ - ValueError - If an unknown CV strategy is specified or if required parameters (like - n_groups for leave_p_out) are missing from the configuration. - """ - strat = config.strategy.lower() - - # Common arguments - common_kwargs = {} - if strat not in ["leave_one_out", "leave_p_out", "split"]: - common_kwargs["n_splits"] = config.n_splits - - if strat in ["stratified", "kfold", "stratified_group_kfold", "split"]: - common_kwargs["shuffle"] = config.shuffle - common_kwargs["random_state"] = config.random_state if config.shuffle else None - - # Strategy Selection - if strat == "stratified": - splitter = StratifiedKFold(**common_kwargs) - - elif strat == "kfold": - splitter = KFold(**common_kwargs) - - elif strat == "group_kfold": - # GroupKFold doesn't take shuffle/random_state - splitter = GroupKFold(n_splits=config.n_splits) - - elif strat == "stratified_group_kfold": - splitter = StratifiedGroupKFold(**common_kwargs) - - elif strat == "leave_p_out": - splitter = LeavePGroupsOut(n_groups=config.n_splits) - - elif strat == "leave_one_out": - splitter = LeaveOneGroupOut() - - elif strat == "split": - splitter = SimpleSplit( - test_size=0.2, - shuffle=config.shuffle, - random_state=config.random_state, - ) - - else: - raise ValueError(f"Unknown CV strategy: {config.strategy}") - - # if the user provided groups, wrap the splitter so .split always sees them - if groups is not None: - splitter = _CVWithGroups(splitter, groups) - - return splitter - - -def get_scorer(name: str) -> Callable: - """ - Retrieve or construct a Scikit-Learn compliant scorer by name. - - Parameters - ---------- - name : str - The name of the metric (e.g., 'accuracy', 'f1_macro', 'neg_mean_squared_error'). - - Returns - ------- - Callable - A scoring function with signature `(y_true, y_pred) -> float`. - - Raises - ------ - ValueError - If the metric name is unknown. - """ - metrics = { - # Classification - "accuracy": accuracy_score, - "balanced_accuracy": balanced_accuracy_score, - "roc_auc": roc_auc_score, - "f1": lambda y, p: f1_score(y, p, average="weighted"), - "f1_macro": lambda y, p: f1_score(y, p, average="macro"), - "f1_micro": lambda y, p: f1_score(y, p, average="micro"), - "precision": lambda y, p: precision_score( - y, p, average="weighted", zero_division=0 - ), - "recall": lambda y, p: recall_score(y, p, average="weighted", zero_division=0), - # Regression - "r2": r2_score, - "neg_mean_squared_error": lambda y, p: -mean_squared_error(y, p), - "neg_mean_absolute_error": lambda y, p: -mean_absolute_error(y, p), - "explained_variance": explained_variance_score, - } - - if name not in metrics: - raise ValueError( - f"Unknown metric '{name}'. Available: {sorted(list(metrics.keys()))}" - ) - return metrics[name] - - -def cross_validate_score( - estimator: BaseEstimator, - X: np.ndarray, - y: Sequence, - *, - groups: Optional[Sequence] = None, - cv_config: Optional[CVConfig] = None, - metric: str = "balanced_accuracy", - use_scaler: bool = False, -) -> float: - """ - Compute one mean cross-validated score for an estimator. - - Parameters - ---------- - estimator : BaseEstimator - Estimator to fit inside each fold. - X : np.ndarray - Input features with shape ``(n_samples, n_features)``. - y : sequence - Target labels aligned with ``X``. - groups : sequence, optional - Group labels aligned with ``X``. - cv_config : CVConfig, optional - Cross-validation configuration. Defaults to a 5-fold stratified - strategy, or 5-fold stratified-group strategy when groups are - provided. - metric : str, default="balanced_accuracy" - Metric name resolved through :func:`get_scorer`. - use_scaler : bool, default=False - When ``True``, wraps the estimator in a ``StandardScaler`` pipeline. - - Returns - ------- - float - Mean cross-validated score. - """ - X = np.asarray(X) - y = np.asarray(y).reshape(-1) - if len(X) != len(y): - raise ValueError("X and y must have matching sample counts.") - - group_values = None - if groups is not None: - group_values = np.asarray(groups).reshape(-1) - if len(group_values) != len(y): - raise ValueError("groups must align with X and y.") - - if cv_config is None: - cv_config = CVConfig( - strategy="stratified_group_kfold" - if group_values is not None - else "stratified", - n_splits=5, - shuffle=True, - random_state=42, - ) - - scorer = get_scorer(metric) - cv = get_cv_splitter(cv_config, groups=group_values) - base_estimator = estimator - if use_scaler: - base_estimator = Pipeline( - [("scaler", StandardScaler()), ("clf", clone(estimator))] - ) - - scores = [] - for train_idx, test_idx in cv.split(X, y, group_values): - model = clone(base_estimator) - model.fit(X[train_idx], y[train_idx]) - y_pred = model.predict(X[test_idx]) - scores.append(float(scorer(y[test_idx], y_pred))) - - return float(np.nanmean(scores)) if scores else float("nan") diff --git a/coco_pipe/dim_reduction/evaluation/_supervised.py b/coco_pipe/dim_reduction/evaluation/_supervised.py new file mode 100644 index 0000000..8f40653 --- /dev/null +++ b/coco_pipe/dim_reduction/evaluation/_supervised.py @@ -0,0 +1,110 @@ +""" +Private supervised scoring helpers for dim-reduction evaluation. + +These helpers are intentionally local to ``coco_pipe.dim_reduction``. They are +used to score embedding separability and are not part of the decoding API. +""" + +from __future__ import annotations + +from typing import Optional, Sequence + +import numpy as np +from sklearn.base import BaseEstimator, clone +from sklearn.metrics import balanced_accuracy_score +from sklearn.model_selection import GroupKFold, StratifiedGroupKFold, StratifiedKFold +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler + + +def _cv_random_state(shuffle: bool, random_state: Optional[int]) -> Optional[int]: + return random_state if shuffle else None + + +def _make_splitter( + strategy: str, + *, + n_splits: int, + shuffle: bool, + random_state: Optional[int], + groups: Optional[np.ndarray], +): + if strategy == "stratified_group_kfold": + if groups is None: + raise ValueError("groups are required for stratified_group_kfold.") + return StratifiedGroupKFold( + n_splits=n_splits, + shuffle=shuffle, + random_state=_cv_random_state(shuffle, random_state), + ) + if strategy == "group_kfold": + if groups is None: + raise ValueError("groups are required for group_kfold.") + return GroupKFold(n_splits=n_splits) + if strategy == "stratified": + return StratifiedKFold( + n_splits=n_splits, + shuffle=shuffle, + random_state=_cv_random_state(shuffle, random_state), + ) + raise ValueError(f"Unsupported supervised scoring CV strategy: {strategy}.") + + +def _cross_validate_score( + estimator: BaseEstimator, + X: np.ndarray, + y: Sequence, + *, + groups: Optional[Sequence] = None, + cv_strategy: str = "stratified", + n_splits: int = 5, + shuffle: bool = True, + random_state: Optional[int] = 42, + metric: str = "balanced_accuracy", + use_scaler: bool = False, +) -> float: + """ + Compute a mean supervised CV score for dim-reduction separation metrics. + + This private helper is deliberately narrow. It currently exists for + ``separation_logreg_balanced_accuracy`` and supports only the CV strategies + used by dim-reduction evaluation. + """ + if metric != "balanced_accuracy": + raise ValueError( + "Dim-reduction supervised scoring currently supports only " + "'balanced_accuracy'." + ) + + X_values = np.asarray(X) + y_values = np.asarray(y).reshape(-1) + if len(X_values) != len(y_values): + raise ValueError("X and y must have matching sample counts.") + + group_values = None + if groups is not None: + group_values = np.asarray(groups).reshape(-1) + if len(group_values) != len(y_values): + raise ValueError("groups must align with X and y.") + + splitter = _make_splitter( + cv_strategy, + n_splits=n_splits, + shuffle=shuffle, + random_state=random_state, + groups=group_values, + ) + base_estimator = estimator + if use_scaler: + base_estimator = Pipeline( + [("scaler", StandardScaler()), ("clf", clone(estimator))] + ) + + scores = [] + for train_idx, test_idx in splitter.split(X_values, y_values, group_values): + model = clone(base_estimator) + model.fit(X_values[train_idx], y_values[train_idx]) + y_pred = model.predict(X_values[test_idx]) + scores.append(float(balanced_accuracy_score(y_values[test_idx], y_pred))) + + return float(np.nanmean(scores)) if scores else float("nan") diff --git a/coco_pipe/dim_reduction/evaluation/core.py b/coco_pipe/dim_reduction/evaluation/core.py index a97bd1b..021f53b 100644 --- a/coco_pipe/dim_reduction/evaluation/core.py +++ b/coco_pipe/dim_reduction/evaluation/core.py @@ -41,8 +41,7 @@ if TYPE_CHECKING: from ..core import DimReduction -from ...decoding.configs import CVConfig -from ...decoding.utils import cross_validate_score +from ._supervised import _cross_validate_score from .geometry import ( trajectory_acceleration, trajectory_curvature, @@ -620,17 +619,15 @@ def evaluate_embedding( f"`labels` and `groups` are required for " f"'{SEPARATION_LOGREG_BALANCED_ACCURACY}'." ) - separation_score = cross_validate_score( + separation_score = _cross_validate_score( LogisticRegression(max_iter=1000, class_weight="balanced"), X_emb, labels, groups=groups, - cv_config=CVConfig( - strategy="stratified_group_kfold", - n_splits=5, - shuffle=True, - random_state=42, - ), + cv_strategy="stratified_group_kfold", + n_splits=5, + shuffle=True, + random_state=42, metric="balanced_accuracy", use_scaler=True, ) diff --git a/coco_pipe/fm/__init__.py b/coco_pipe/fm/__init__.py index 0b05497..c196895 100644 --- a/coco_pipe/fm/__init__.py +++ b/coco_pipe/fm/__init__.py @@ -1,8 +1,3 @@ -"""Foundation model pipelines for CoCo Pipe.""" +"""Foundation-model integration namespace for CoCo Pipe.""" -from .cbramod import CBRAModRegressionPipeline, FoundationRegressor - -__all__ = [ - "CBRAModRegressionPipeline", - "FoundationRegressor", -] +__all__: list[str] = [] diff --git a/coco_pipe/fm/cbramod/__init__.py b/coco_pipe/fm/cbramod/__init__.py index bc61592..085a12f 100644 --- a/coco_pipe/fm/cbramod/__init__.py +++ b/coco_pipe/fm/cbramod/__init__.py @@ -1,12 +1,3 @@ -""" -Foundation model pipelines for CoCo Pipe. +"""CBRAMod integration namespace.""" -This module hosts pipelines built on top of the CBRAMod foundation model. -""" - -from .pipeline import CBRAModRegressionPipeline, FoundationRegressor - -__all__ = [ - "CBRAModRegressionPipeline", - "FoundationRegressor", -] +__all__: list[str] = [] diff --git a/coco_pipe/fm/cbramod/pipeline.py b/coco_pipe/fm/cbramod/pipeline.py deleted file mode 100644 index d5c1921..0000000 --- a/coco_pipe/fm/cbramod/pipeline.py +++ /dev/null @@ -1,190 +0,0 @@ -""" -Pipeline wrapper for running regression with the CBRAMod foundation model. - -This mirrors the public surface of the classical regression pipeline but swaps -the estimator for a foundation-model-backed regressor. The embedder is expected -to turn raw inputs into embeddings; a light downstream regressor is trained on -top of those embeddings. -""" - -from typing import Any, Callable, Dict, Optional, Sequence, Union - -import numpy as np -import pandas as pd -from sklearn.base import BaseEstimator, RegressorMixin -from sklearn.linear_model import Ridge -from sklearn.multioutput import MultiOutputRegressor - -from coco_pipe.ml.base import BasePipeline -from coco_pipe.ml.config import DEFAULT_CV, REGRESSION_METRICS - - -class FoundationRegressor(BaseEstimator, RegressorMixin): - """ - Thin sklearn-compatible wrapper around a foundation model embedder. - - Parameters - ---------- - embed_fn : callable - Callable that maps ``X`` to a 2D numpy array of embeddings. Either a - ``__call__`` or ``transform`` method will be used. - base_regressor : BaseEstimator, optional - Downstream regressor trained on the embeddings. Defaults to ``Ridge``. - multioutput : bool, optional - If True, wraps the base regressor in ``MultiOutputRegressor`` for - multi-target regression. - """ - - def __init__( - self, - embed_fn: Callable[[Any], np.ndarray], - base_regressor: Optional[BaseEstimator] = None, - multioutput: bool = False, - ): - self.embed_fn = embed_fn - self.base_regressor = base_regressor or Ridge(random_state=42) - self.multioutput = multioutput - - if self.multioutput: - self.base_regressor = MultiOutputRegressor(self.base_regressor) - - def _embed(self, X: Any) -> np.ndarray: - if hasattr(self.embed_fn, "transform"): - emb = self.embed_fn.transform(X) - else: - emb = self.embed_fn(X) - emb = np.asarray(emb) - if emb.ndim != 2: - raise ValueError( - f"Expected 2D embeddings, got shape {emb.shape} from embed_fn" - ) - return emb - - def fit(self, X: Any, y: Any): - X_emb = self._embed(X) - self.base_regressor.fit(X_emb, y) - return self - - def predict(self, X: Any) -> np.ndarray: - X_emb = self._embed(X) - return self.base_regressor.predict(X_emb) - - -class CBRAModRegressionPipeline(BasePipeline): - """ - Regression pipeline that routes features through the CBRAMod foundation - model before fitting a lightweight regressor. - - The interface mirrors ``RegressionPipeline`` for analysis_type dispatch but - exposes a required ``embed_fn`` to obtain embeddings from the foundation - model. - """ - - def __init__( - self, - X: Union[pd.DataFrame, np.ndarray], - y: Union[pd.Series, np.ndarray], - embed_fn: Callable[[Any], np.ndarray], - metrics: Union[str, Sequence[str], None] = None, - base_regressor: Optional[BaseEstimator] = None, - hp_search_params: Optional[Dict[str, Sequence[Any]]] = None, - use_scaler: bool = False, - random_state: int = 42, - n_jobs: int = -1, - cv_kwargs: Optional[Dict[str, Any]] = None, - groups: Optional[Union[pd.Series, np.ndarray]] = None, - verbose: bool = False, - ): - self.embed_fn = embed_fn - self.verbose = verbose - - # Determine if we need multi-output support for the downstream regressor - is_multi = hasattr(y, "ndim") and getattr(y, "ndim", 1) == 2 - - metric_funcs = REGRESSION_METRICS - default_metrics = [metrics] if isinstance(metrics, str) else (metrics or ["r2"]) - - foundation_estimator = FoundationRegressor( - embed_fn=embed_fn, - base_regressor=base_regressor, - multioutput=is_multi, - ) - - model_configs = { - "CBRAMod": { - "estimator": foundation_estimator, - "default_params": {}, - "hp_search_params": hp_search_params - or ({} if is_multi else {"base_regressor__alpha": [0.1, 1.0, 10.0]}), - } - } - - cv = dict(DEFAULT_CV) - # Regression should not stratify continuous targets; default to kfold - cv["cv_strategy"] = "kfold" - if cv_kwargs: - cv.update(cv_kwargs) - - super().__init__( - X=X, - y=y, - metric_funcs=metric_funcs, - model_configs=model_configs, - use_scaler=use_scaler, - default_metrics=default_metrics, - cv_kwargs=cv, - groups=groups, - n_jobs=n_jobs, - random_state=random_state, - verbose=verbose, - ) - - self.model_name = "CBRAMod" - - def run( - self, - analysis_type: str = "baseline", - n_features: Optional[int] = None, - direction: str = "forward", - search_type: str = "grid", - n_iter: int = 50, - scoring: Optional[str] = None, - ) -> Dict[str, Any]: - analysis_type = analysis_type.lower() - if analysis_type not in { - "baseline", - "feature_selection", - "hp_search", - "hp_search_fs", - }: - raise ValueError(f"Invalid analysis type: {analysis_type}") - - if analysis_type == "baseline": - return self.baseline_evaluation(self.model_name) - - if analysis_type == "feature_selection": - return self.feature_selection( - self.model_name, - n_features=n_features, - direction=direction, - scoring=scoring, - ) - - if analysis_type == "hp_search": - return self.hp_search( - self.model_name, - param_grid=None, - search_type=search_type, - n_iter=n_iter, - scoring=scoring, - ) - - return self.hp_search_fs( - self.model_name, - param_grid=None, - search_type=search_type, - n_features=n_features, - direction=direction, - n_iter=n_iter, - scoring=scoring, - ) diff --git a/coco_pipe/report/core.py b/coco_pipe/report/core.py index a67e8d9..0709b3d 100644 --- a/coco_pipe/report/core.py +++ b/coco_pipe/report/core.py @@ -11,6 +11,7 @@ import html import io import json +import logging import re import uuid from abc import ABC, abstractmethod @@ -31,6 +32,8 @@ check_outliers_zscore, ) +logger = logging.getLogger(__name__) + def _get_reducer_summary(reducer: Any) -> Dict[str, Any]: """Collect the strict summary payload from a reduction-like object.""" @@ -1121,8 +1124,8 @@ def add_raw_preview(self, data: Any, name: str = "Raw Data Inspector") -> "Repor res_outlier = check_outliers_zscore(sample_X) if res_outlier: sec.add_finding(res_outlier) - except Exception: - pass + except Exception as e: + logger.debug(f"Data quality checks failed: {e}") # Ensure 2D if hasattr(X, "ndim") and X.ndim == 1: @@ -1206,6 +1209,263 @@ def add_comparison( self.add_section(sec) return self + def add_decoding_temporal( + self, + result: Any, + metric: Optional[str] = None, + model: Optional[str] = None, + name: str = "Temporal Decoding", + ) -> "Report": + """ + Add a compact temporal decoding section from an ExperimentResult. + """ + from coco_pipe.viz.decoding import ( + plot_temporal_generalization_matrix, + plot_temporal_score_curve, + ) + + if not hasattr(result, "get_temporal_score_summary"): + raise TypeError( + "result must provide get_temporal_score_summary() for temporal " + "decoding report sections." + ) + + summary = result.get_temporal_score_summary() + if metric is not None: + summary = summary[summary["Metric"] == metric] + if model is not None: + summary = summary[summary["Model"] == model] + if summary.empty: + raise ValueError("No temporal decoding scores available for report.") + + sec = Section(title=name, icon="📈") + sec.add_element(TableElement(summary, title="Temporal Score Summary")) + + if "Time" in summary and summary["Time"].notna().any(): + fig_curve = plot_temporal_score_curve( + summary, metric=metric, model=model, title="Temporal Score Curve" + ) + sec.add_element(ImageElement(fig_curve, caption="Temporal score curve")) + + if ( + {"TrainTime", "TestTime"}.issubset(summary.columns) + and summary["TrainTime"].notna().any() + and summary["TestTime"].notna().any() + ): + fig_matrix = plot_temporal_generalization_matrix( + summary, + metric=metric, + model=model, + title="Temporal Generalization Matrix", + ) + sec.add_element( + ImageElement(fig_matrix, caption="Temporal generalization matrix") + ) + + def add_decoding_summary( + self, + result: Any, + name: str = "Decoding Summary", + ) -> "Report": + """Add a summary performance table for all models.""" + if not hasattr(result, "summary"): + raise TypeError("result must provide a summary() method.") + + summary = result.summary() + if summary.empty: + return self + + sec = Section(title=name) + sec.add_element( + MetricsTableElement( + summary, + title="Model Performance Summary", + highlight_best=True, + ) + ) + self.add_section(sec) + return self + + def add_decoding_diagnostics( + self, + result: Any, + metric: Optional[str] = None, + model: Optional[str] = None, + name: str = "Decoding Diagnostics", + ) -> "Report": + """Add shallow decoding diagnostics from an ExperimentResult.""" + from coco_pipe.viz.decoding import ( + plot_calibration_curve, + plot_confusion_matrix, + plot_fold_score_dispersion, + plot_pr_curve, + plot_roc_curve, + ) + + required = [ + "get_detailed_scores", + "get_fit_diagnostics", + "get_confusion_matrices", + ] + if not all(hasattr(result, attr) for attr in required): + raise TypeError("result must provide decoding diagnostic accessors.") + + sec = Section(title=name) + + meta = getattr(result, "meta", {}) or {} + if meta.get("observation_level") or meta.get("inferential_unit"): + inference_context = pd.DataFrame( + [ + { + "ObservationLevel": meta.get("observation_level", "sample"), + "InferentialUnit": meta.get("inferential_unit", "sample"), + } + ] + ) + sec.add_element(TableElement(inference_context, title="Inference Context")) + + scores = result.get_detailed_scores() + if metric is not None and "Metric" in scores: + scores = scores[scores["Metric"] == metric] + if model is not None and "Model" in scores: + scores = scores[scores["Model"] == model] + if not scores.empty: + sec.add_element(TableElement(scores, title="Fold Scores")) + try: + fig_scores = plot_fold_score_dispersion( + scores, + metric=metric, + model=model, + title="Fold Score Dispersion", + ) + sec.add_element( + ImageElement(fig_scores, caption="Fold score dispersion") + ) + except ValueError as e: + logger.debug(f"Could not plot fold score dispersion: {e}") + + diagnostics = result.get_fit_diagnostics() + if model is not None and "Model" in diagnostics: + diagnostics = diagnostics[diagnostics["Model"] == model] + if not diagnostics.empty: + # 1. Clean Timing Table + timing_cols = ["Model", "Fold", "FitTime", "PredictTime", "TotalTime"] + timing_cols = [c for c in timing_cols if c in diagnostics.columns] + timings = diagnostics[timing_cols].drop_duplicates() + sec.add_element(TableElement(timings, title="Fit Diagnostics")) + + # 2. Warnings Table (only if they exist) + warns = diagnostics[diagnostics["WarningMessage"].notna()] + if not warns.empty: + warn_cols = [ + "Model", + "Fold", + "Stage", + "WarningCategory", + "WarningMessage", + ] + warn_cols = [c for c in warn_cols if c in warns.columns] + sec.add_element( + TableElement(warns[warn_cols], title="Training Warnings") + ) + + confusion = result.get_confusion_matrices(model=model) + if not confusion.empty: + sec.add_element(TableElement(confusion, title="Confusion Matrix Data")) + try: + fig_confusion = plot_confusion_matrix( + confusion, + model=model, + title="Confusion Matrix", + ) + sec.add_element(ImageElement(fig_confusion, caption="Confusion matrix")) + except ValueError as e: + logger.debug(f"Could not plot confusion matrix: {e}") + + for title, plotter in [ + ("ROC Curve", plot_roc_curve), + ("Precision-Recall Curve", plot_pr_curve), + ("Calibration Curve", plot_calibration_curve), + ]: + try: + fig = plotter(result, model=model, title=title) + sec.add_element(ImageElement(fig, caption=title)) + except ValueError as e: + logger.debug(f"Could not plot curve '{title}': {e}") + + self.add_section(sec) + return self + + def add_decoding_statistical_assessment( + self, + result: Any, + metric: Optional[str] = None, + model: Optional[str] = None, + name: str = "Statistical Assessment", + ) -> "Report": + """Add finite-sample decoding statistical assessment rows and plots.""" + from coco_pipe.viz.decoding import plot_temporal_statistical_assessment + + if not hasattr(result, "get_statistical_assessment"): + raise TypeError("result must provide get_statistical_assessment().") + + assessment = result.get_statistical_assessment() + if metric is not None and "Metric" in assessment: + assessment = assessment[assessment["Metric"] == metric] + if model is not None and "Model" in assessment: + assessment = assessment[assessment["Model"] == model] + if assessment.empty: + raise ValueError("No statistical assessment rows available for report.") + + sec = Section(title=name) + sec.add_element( + TableElement(assessment, title="Finite-Sample Statistical Assessment") + ) + + if "Time" in assessment and assessment["Time"].notna().any(): + try: + fig = plot_temporal_statistical_assessment( + assessment, + metric=metric, + model=model, + title="Temporal Statistical Assessment", + ) + sec.add_element( + ImageElement(fig, caption="Temporal statistical assessment") + ) + except ValueError as e: + logger.debug(f"Could not plot temporal statistical assessment: {e}") + + self.add_section(sec) + return self + + def add_decoding_neural_artifacts( + self, + result: Any, + model: Optional[str] = None, + name: str = "Neural Artifacts", + ) -> "Report": + """Add neural/foundation-model artifact metadata to the report.""" + from coco_pipe.viz.decoding import plot_training_history + + if not hasattr(result, "get_model_artifacts"): + raise TypeError("result must provide get_model_artifacts().") + artifacts = result.get_model_artifacts() + if model is not None and "Model" in artifacts: + artifacts = artifacts[artifacts["Model"] == model] + if artifacts.empty: + raise ValueError("No model artifacts available for report.") + + sec = Section(title=name) + sec.add_element(TableElement(artifacts, title="Model Artifacts")) + try: + fig = plot_training_history(artifacts, model=model) + sec.add_element(ImageElement(fig, caption="Training history")) + except ValueError: + pass + self.add_section(sec) + return self + def render(self) -> str: """ Render the full HTML report. diff --git a/coco_pipe/viz/__init__.py b/coco_pipe/viz/__init__.py index a265acc..d01cffb 100644 --- a/coco_pipe/viz/__init__.py +++ b/coco_pipe/viz/__init__.py @@ -1,6 +1,18 @@ #!/usr/bin/env python3 """Curated plotting helpers for coco_pipe.""" +from .decoding import ( + plot_calibration_curve, + plot_confusion_matrix, + plot_fold_score_dispersion, + plot_pr_curve, + plot_roc_curve, + plot_statistical_null_distribution, + plot_temporal_generalization_matrix, + plot_temporal_score_curve, + plot_temporal_statistical_assessment, + plot_training_history, +) from .dim_reduction import ( plot_eigenvalues, plot_embedding, @@ -33,6 +45,16 @@ "plot_interpretation", "plot_trajectory", "plot_trajectory_metric_series", + "plot_confusion_matrix", + "plot_roc_curve", + "plot_pr_curve", + "plot_calibration_curve", + "plot_fold_score_dispersion", + "plot_statistical_null_distribution", + "plot_training_history", + "plot_temporal_score_curve", + "plot_temporal_generalization_matrix", + "plot_temporal_statistical_assessment", "plot_local_metrics", "plot_channel_traces_interactive", ] diff --git a/coco_pipe/viz/decoding.py b/coco_pipe/viz/decoding.py new file mode 100644 index 0000000..e65d83b --- /dev/null +++ b/coco_pipe/viz/decoding.py @@ -0,0 +1,610 @@ +""" +Decoding Visualization +====================== + +Focused plotting helpers for decoding result tables. +""" + +from typing import Any, Optional + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + + +def _temporal_summary_frame(result_or_scores: Any) -> pd.DataFrame: + if hasattr(result_or_scores, "get_temporal_score_summary"): + return result_or_scores.get_temporal_score_summary() + return pd.DataFrame(result_or_scores) + + +def _filter_temporal_summary( + summary: pd.DataFrame, + metric: Optional[str] = None, + model: Optional[str] = None, +) -> pd.DataFrame: + frame = summary.copy() + if metric is not None: + frame = frame[frame["Metric"] == metric] + if model is not None: + frame = frame[frame["Model"] == model] + return frame + + +def _result_frame(result_or_frame: Any, accessor: str) -> pd.DataFrame: + if hasattr(result_or_frame, accessor): + return getattr(result_or_frame, accessor)() + return pd.DataFrame(result_or_frame) + + +def _filter_frame( + frame: pd.DataFrame, + model: Optional[str] = None, + fold: Optional[int] = None, +) -> pd.DataFrame: + data = frame.copy() + if model is not None and "Model" in data: + data = data[data["Model"] == model] + if fold is not None and "Fold" in data: + data = data[data["Fold"] == fold] + return data + + +def plot_confusion_matrix( + result_or_matrix: Any, + model: Optional[str] = None, + fold: Optional[int] = None, + title: Optional[str] = None, + ax: Optional[plt.Axes] = None, + figsize: Optional[tuple[float, float]] = None, +): + """Plot an aggregated confusion matrix from decoding diagnostics.""" + frame = _filter_frame( + _result_frame(result_or_matrix, "get_confusion_matrices"), + model=model, + fold=fold, + ) + if frame.empty: + raise ValueError("No confusion-matrix rows available to plot.") + + matrix = frame.pivot_table( + index="TrueLabel", + columns="PredictedLabel", + values="Value", + aggfunc="sum", + fill_value=0, + ) + + if ax is None: + fig, ax = plt.subplots(figsize=(6, 5)) + else: + fig = ax.get_figure() + + im = ax.imshow(np.asarray(matrix, dtype=float), cmap="Blues") + ax.set_xticks(np.arange(matrix.shape[1])) + ax.set_xticklabels([str(value) for value in matrix.columns], rotation=45) + ax.set_yticks(np.arange(matrix.shape[0])) + ax.set_yticklabels([str(value) for value in matrix.index]) + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + ax.set_title(title or "Confusion Matrix") + for row_idx in range(matrix.shape[0]): + for col_idx in range(matrix.shape[1]): + ax.text( + col_idx, + row_idx, + f"{matrix.iloc[row_idx, col_idx]:.3g}", + ha="center", + va="center", + ) + fig.colorbar(im, ax=ax, label="Count") + return fig + + +def plot_roc_curve( + result_or_curve: Any, + model: Optional[str] = None, + fold: Optional[int] = None, + title: Optional[str] = None, + ax: Optional[plt.Axes] = None, + figsize: Optional[tuple[float, float]] = None, + mean_only: bool = False, +): + """Plot ROC curves from decoding curve diagnostics.""" + frame = _filter_frame( + _result_frame(result_or_curve, "get_roc_curve"), + model=model, + fold=fold, + ) + if frame.empty: + raise ValueError("No ROC curve rows available to plot.") + + if ax is None: + fig, ax = plt.subplots(figsize=(6, 5)) + else: + fig = ax.get_figure() + + group_cols = ["Model"] + (["Class"] if "Class" in frame else []) + for keys, group in frame.groupby(group_cols): + if not isinstance(keys, tuple): + keys = (keys,) + + if mean_only: + # Pivot to align FPR and interpolate TPR + from scipy import interpolate + + all_fpr = np.unique( + np.concatenate([g["FPR"].to_numpy() for _, g in group.groupby("Fold")]) + ) + all_fpr = np.sort(all_fpr) + tprs = [] + for _, fold_group in group.groupby("Fold"): + interp = interpolate.interp1d( + fold_group["FPR"], + fold_group["TPR"], + bounds_error=False, + fill_value=(0, 1), + ) + tprs.append(interp(all_fpr)) + mean_tpr = np.mean(tprs, axis=0) + std_tpr = np.std(tprs, axis=0) + label = f"{keys[0]}" + if len(keys) > 1: + label = f"{label} class {keys[1]}" + ax.plot(all_fpr, mean_tpr, label=f"{label} (mean)", linewidth=2) + ax.fill_between(all_fpr, mean_tpr - std_tpr, mean_tpr + std_tpr, alpha=0.15) + else: + for f_idx, fold_group in group.groupby("Fold"): + label = f"{keys[0]} fold {f_idx}" + if len(keys) > 1: + label = f"{label} class {keys[1]}" + ax.plot(fold_group["FPR"], fold_group["TPR"], label=label, alpha=0.5) + ax.plot([0, 1], [0, 1], linestyle="--", color="0.5", linewidth=1) + ax.set_xlabel("False Positive Rate") + ax.set_ylabel("True Positive Rate") + ax.set_title(title or "ROC Curve") + ax.legend(frameon=False) + ax.grid(True, linestyle="--", alpha=0.3) + return fig + + +def plot_pr_curve( + result_or_curve: Any, + model: Optional[str] = None, + fold: Optional[int] = None, + title: Optional[str] = None, + ax: Optional[plt.Axes] = None, + figsize: Optional[tuple[float, float]] = None, + mean_only: bool = False, +): + """Plot precision-recall curves from decoding diagnostics.""" + frame = _filter_frame( + _result_frame(result_or_curve, "get_pr_curve"), + model=model, + fold=fold, + ) + if frame.empty: + raise ValueError("No precision-recall curve rows available to plot.") + + if ax is None: + fig, ax = plt.subplots(figsize=(6, 5)) + else: + fig = ax.get_figure() + + group_cols = ["Model"] + (["Class"] if "Class" in frame else []) + for keys, group in frame.groupby(group_cols): + if not isinstance(keys, tuple): + keys = (keys,) + + if mean_only: + from scipy import interpolate + + all_recall = np.unique( + np.concatenate( + [g["Recall"].to_numpy() for _, g in group.groupby("Fold")] + ) + ) + all_recall = np.sort(all_recall) + precs = [] + for _, fold_group in group.groupby("Fold"): + # PR curves are not necessarily monotonic, but interpolation is + # standard for mean PR + interp = interpolate.interp1d( + fold_group["Recall"], + fold_group["Precision"], + bounds_error=False, + fill_value=(1, 0), + ) + precs.append(interp(all_recall)) + mean_prec = np.mean(precs, axis=0) + std_prec = np.std(precs, axis=0) + label = f"{keys[0]}" + if len(keys) > 1: + label = f"{label} class {keys[1]}" + ax.plot(all_recall, mean_prec, label=f"{label} (mean)", linewidth=2) + ax.fill_between( + all_recall, mean_prec - std_prec, mean_prec + std_prec, alpha=0.15 + ) + else: + for f_idx, fold_group in group.groupby("Fold"): + label = f"{keys[0]} fold {f_idx}" + if len(keys) > 1: + label = f"{label} class {keys[1]}" + ax.plot( + fold_group["Recall"], + fold_group["Precision"], + label=label, + alpha=0.5, + ) + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") + ax.set_title(title or "Precision-Recall Curve") + ax.legend(frameon=False) + ax.grid(True, linestyle="--", alpha=0.3) + return fig + + +def plot_calibration_curve( + result_or_curve: Any, + model: Optional[str] = None, + fold: Optional[int] = None, + title: Optional[str] = None, + ax: Optional[plt.Axes] = None, + figsize: Optional[tuple[float, float]] = None, +): + """Plot reliability curves from decoding calibration diagnostics.""" + frame = _filter_frame( + _result_frame(result_or_curve, "get_calibration_curve"), + model=model, + fold=fold, + ) + if frame.empty: + raise ValueError("No calibration curve rows available to plot.") + + if ax is None: + fig, ax = plt.subplots(figsize=(6, 5)) + else: + fig = ax.get_figure() + + group_cols = ["Model", "Fold"] + (["Class"] if "Class" in frame else []) + for keys, group in frame.groupby(group_cols): + if not isinstance(keys, tuple): + keys = (keys,) + label = f"{keys[0]} fold {keys[1]}" + if len(keys) > 2: + label = f"{label} class {keys[2]}" + ax.plot( + group["MeanPredictedProbability"], + group["FractionPositive"], + marker="o", + label=label, + ) + ax.plot([0, 1], [0, 1], linestyle="--", color="0.5", linewidth=1) + ax.set_xlabel("Mean Predicted Probability") + ax.set_ylabel("Fraction Positive") + ax.set_title(title or "Calibration Curve") + ax.legend(frameon=False) + ax.grid(True, linestyle="--", alpha=0.3) + return fig + + +def plot_fold_score_dispersion( + result_or_scores: Any, + metric: Optional[str] = None, + model: Optional[str] = None, + title: Optional[str] = None, + ax: Optional[plt.Axes] = None, + figsize: Optional[tuple[float, float]] = None, +): + """Plot scalar fold-score dispersion by model and metric.""" + frame = _result_frame(result_or_scores, "get_detailed_scores") + if frame.empty: + raise ValueError("No fold score rows available to plot.") + if metric is not None: + frame = frame[frame["Metric"] == metric] + if model is not None: + frame = frame[frame["Model"] == model] + if "Value" not in frame: + raise ValueError("No scalar fold score rows available to plot.") + frame = frame[frame["Value"].notna()].copy() + for column in ["Time", "TrainTime", "TestTime"]: + if column in frame: + frame = frame[frame[column].isna()] + if frame.empty: + raise ValueError("No scalar fold score rows available to plot.") + + labels = [] + values = [] + for (model_name, metric_name), group in frame.groupby(["Model", "Metric"]): + labels.append(f"{model_name}\n{metric_name}") + values.append(group["Value"].astype(float).to_numpy()) + + if ax is None: + fig, ax = plt.subplots(figsize=figsize or (max(6, len(labels) * 1.5), 5)) + else: + fig = ax.get_figure() + + for idx, val_set in enumerate(values): + x = np.random.normal(idx + 1, 0.04, size=len(val_set)) + ax.scatter(x, val_set, alpha=0.5, color="black", zorder=3, s=15) + + if all(len(v) >= 8 for v in values) and len(values) > 0: + ax.violinplot(values, showmeans=True, showmedians=False) + ax.set_xticks(np.arange(1, len(labels) + 1)) + ax.set_xticklabels(labels) + else: + ax.boxplot(values, tick_labels=labels, showmeans=True) + ax.set_ylabel("Score") + ax.set_title(title or "Fold Score Dispersion") + ax.grid(True, axis="y", linestyle="--", alpha=0.3) + return fig + + +def plot_temporal_score_curve( + result_or_scores: Any, + metric: Optional[str] = None, + model: Optional[str] = None, + title: Optional[str] = None, + ax: Optional[plt.Axes] = None, + figsize: Optional[tuple[float, float]] = None, +): + """ + Plot mean temporal decoding score curves. + + Parameters + ---------- + result_or_scores : ExperimentResult or DataFrame-like + Result object or output from ``get_temporal_score_summary()``. + metric, model : str, optional + Optional filters. + title : str, optional + Figure title. + ax : matplotlib.axes.Axes, optional + Existing axes to draw on. + """ + summary = _filter_temporal_summary( + _temporal_summary_frame(result_or_scores), metric=metric, model=model + ) + if summary.empty or "Time" not in summary: + raise ValueError("No 1D temporal score rows available to plot.") + + curve_data = summary[summary["Time"].notna()].copy() + if curve_data.empty: + raise ValueError("No 1D temporal score rows available to plot.") + + if ax is None: + fig, ax = plt.subplots(figsize=(10, 5)) + else: + fig = ax.get_figure() + + for (model_name, metric_name), group in curve_data.groupby(["Model", "Metric"]): + times = group["Time"].to_numpy() + numeric_times = pd.api.types.is_numeric_dtype(group["Time"]) + x_vals = times if numeric_times else np.arange(len(times)) + label = f"{model_name} / {metric_name}" + ax.plot(x_vals, group["Mean"], marker="o", linewidth=2, label=label) + if "Std" in group: + ax.fill_between( + x_vals, + group["Mean"] - group["Std"], + group["Mean"] + group["Std"], + alpha=0.15, + ) + if not numeric_times: + ax.set_xticks(x_vals) + ax.set_xticklabels([str(value) for value in times], rotation=45) + + ax.set_title(title or "Temporal Decoding Score") + ax.set_xlabel("Time") + ax.set_ylabel("Score") + ax.grid(True, linestyle="--", alpha=0.3) + ax.legend(frameon=False) + return fig + + +def plot_temporal_generalization_matrix( + result_or_scores: Any, + metric: Optional[str] = None, + model: Optional[str] = None, + title: Optional[str] = None, + ax: Optional[plt.Axes] = None, + figsize: Optional[tuple[float, float]] = None, +): + """ + Plot a train-time by test-time temporal generalization heatmap. + + Parameters + ---------- + result_or_scores : ExperimentResult or DataFrame-like + Result object or output from ``get_temporal_score_summary()``. + metric, model : str, optional + Optional filters. When omitted and multiple matrices are present, the + first model/metric pair in the table is plotted. + title : str, optional + Figure title. + ax : matplotlib.axes.Axes, optional + Existing axes to draw on. + """ + summary = _filter_temporal_summary( + _temporal_summary_frame(result_or_scores), metric=metric, model=model + ) + required = {"TrainTime", "TestTime", "Mean"} + if summary.empty or not required.issubset(summary.columns): + raise ValueError("No temporal generalization matrix rows available to plot.") + + matrix_data = summary[ + summary["TrainTime"].notna() & summary["TestTime"].notna() + ].copy() + if matrix_data.empty: + raise ValueError("No temporal generalization matrix rows available to plot.") + + first = matrix_data.iloc[0] + matrix_data = matrix_data[ + (matrix_data["Model"] == first["Model"]) + & (matrix_data["Metric"] == first["Metric"]) + ] + train_order = pd.unique(matrix_data["TrainTime"]) + test_order = pd.unique(matrix_data["TestTime"]) + matrix = matrix_data.pivot(index="TrainTime", columns="TestTime", values="Mean") + matrix = matrix.reindex(index=train_order, columns=test_order) + + if ax is None: + fig, ax = plt.subplots(figsize=figsize or (7, 6)) + else: + fig = ax.get_figure() + + im = ax.imshow(np.asarray(matrix), aspect="auto", origin="lower", cmap="viridis") + ax.set_xticks(np.arange(matrix.shape[1])) + ax.set_xticklabels([str(value) for value in matrix.columns], rotation=45) + ax.set_yticks(np.arange(matrix.shape[0])) + ax.set_yticklabels([str(value) for value in matrix.index]) + ax.set_xlabel("Test Time") + ax.set_ylabel("Train Time") + ax.set_title(title or f"{first['Model']} / {first['Metric']}") + fig.colorbar(im, ax=ax, label="Score") + return fig + + +def plot_temporal_statistical_assessment( + result_or_assessment: Any, + metric: Optional[str] = None, + model: Optional[str] = None, + title: Optional[str] = None, + ax: Optional[plt.Axes] = None, + figsize: Optional[tuple[float, float]] = None, +): + """Plot temporal observed scores, null bands, and significant segments.""" + frame = _result_frame(result_or_assessment, "get_statistical_assessment") + if frame.empty: + raise ValueError("No statistical assessment rows available to plot.") + if metric is not None and "Metric" in frame: + frame = frame[frame["Metric"] == metric] + if model is not None and "Model" in frame: + frame = frame[frame["Model"] == model] + if "Time" not in frame: + raise ValueError("No temporal statistical assessment rows available.") + frame = frame[frame["Time"].notna()].copy() + if frame.empty: + raise ValueError("No temporal statistical assessment rows available.") + + first = frame.iloc[0] + frame = frame[ + (frame["Model"] == first["Model"]) & (frame["Metric"] == first["Metric"]) + ] + numeric_times = pd.api.types.is_numeric_dtype(frame["Time"]) + x_vals = frame["Time"].to_numpy() if numeric_times else np.arange(len(frame)) + + if ax is None: + fig, ax = plt.subplots(figsize=(10, 5)) + else: + fig = ax.get_figure() + + ax.plot(x_vals, frame["Observed"], marker="o", linewidth=2, label="Observed") + if {"NullLower", "NullUpper"}.issubset(frame.columns): + if frame["NullLower"].notna().any() and frame["NullUpper"].notna().any(): + ax.fill_between( + x_vals, + frame["NullLower"].astype(float), + frame["NullUpper"].astype(float), + alpha=0.2, + label="Permutation null band", + ) + if "Significant" in frame and frame["Significant"].fillna(False).any(): + sig = frame["Significant"].fillna(False).to_numpy(dtype=bool) + ax.scatter( + x_vals[sig], + frame["Observed"].to_numpy(dtype=float)[sig], + marker="s", + color="black", + label="Corrected significant", + zorder=3, + ) + if not numeric_times: + ax.set_xticks(x_vals) + ax.set_xticklabels([str(value) for value in frame["Time"]], rotation=45) + ax.set_xlabel("Time") + ax.set_ylabel("Score") + default_title = f"{first['Model']} / {first['Metric']} statistical assessment" + ax.set_title(title or default_title) + ax.grid(True, linestyle="--", alpha=0.3) + ax.legend(frameon=False) + return fig + + +def plot_statistical_null_distribution( + result_or_assessment: Any, + metric: Optional[str] = None, + model: Optional[str] = None, + title: Optional[str] = None, + ax: Optional[plt.Axes] = None, + figsize: Optional[tuple[float, float]] = None, +): + """Plot observed scores with available null interval summaries.""" + frame = _result_frame(result_or_assessment, "get_statistical_assessment") + if metric is not None and "Metric" in frame: + frame = frame[frame["Metric"] == metric] + if model is not None and "Model" in frame: + frame = frame[frame["Model"] == model] + frame = frame[frame["Observed"].notna()].copy() + if frame.empty: + raise ValueError("No statistical assessment rows available to plot.") + + labels = [f"{row.Model}\n{row.Metric}" for row in frame.itertuples()] + x_vals = np.arange(len(frame)) + if ax is None: + fig, ax = plt.subplots(figsize=figsize or (max(6, len(frame) * 1.2), 5)) + else: + fig = ax.get_figure() + ax.scatter(x_vals, frame["Observed"].astype(float), label="Observed") + if {"NullLower", "NullUpper"}.issubset(frame.columns): + lower = frame["NullLower"].astype(float) + upper = frame["NullUpper"].astype(float) + center = (lower + upper) / 2 + yerr = np.vstack([center - lower, upper - center]) + ax.errorbar(x_vals, center, yerr=yerr, fmt="o", label="Null band") + ax.set_xticks(x_vals) + ax.set_xticklabels(labels, rotation=45, ha="right") + ax.set_ylabel("Score") + ax.set_title(title or "Statistical Assessment") + ax.grid(True, axis="y", linestyle="--", alpha=0.3) + ax.legend(frameon=False) + return fig + + +def plot_training_history( + result_or_artifacts: Any, + model: Optional[str] = None, + title: Optional[str] = None, + ax: Optional[plt.Axes] = None, + figsize: Optional[tuple[float, float]] = None, +): + """Plot neural training and validation history from model artifacts.""" + artifacts = _result_frame(result_or_artifacts, "get_model_artifacts") + if model is not None and "Model" in artifacts: + artifacts = artifacts[artifacts["Model"] == model] + rows = artifacts[ + (artifacts["Key"].isin(["training", "validation"])) + | (artifacts["ArtifactType"] == "history") + ] + if rows.empty: + raise ValueError("No training history artifacts available to plot.") + if ax is None: + fig, ax = plt.subplots(figsize=figsize or (8, 5)) + else: + fig = ax.get_figure() + for row in rows.itertuples(): + history = row.Value or [] + if not history: + continue + frame = pd.DataFrame(history) + if "epoch" not in frame: + continue + value_cols = [col for col in frame.columns if col != "epoch"] + for col in value_cols: + ax.plot(frame["epoch"], frame[col], marker="o", label=f"{row.Model} {col}") + ax.set_xlabel("Epoch") + ax.set_ylabel("Value") + ax.set_title(title or "Training History") + ax.grid(True, linestyle="--", alpha=0.3) + ax.legend(frameon=False) + return fig diff --git a/configs/toy_ml_config.yml b/configs/toy_ml_config.yml deleted file mode 100644 index 4c0c6bd..0000000 --- a/configs/toy_ml_config.yml +++ /dev/null @@ -1,225 +0,0 @@ -# ----------------------------------------------------------------------------- -# Toy config for coco_pipe MLPipeline -# ----------------------------------------------------------------------------- - -# A unique identifier for this entire experiment -global_experiment_id: "toy_ml_config" - -# Path to your input CSV/Parquet/etc. -data_path: "./datasets/toy_dataset.csv" - -# Where to write all the per‐analysis and final results -results_dir: "./results" - -# Base filename (without extension) for per‐analysis outputs -results_file: "toy_ml_config" - -# ----------------------------------------------------------------------------- -# Defaults: values here are applied to every analysis unless overridden -# ----------------------------------------------------------------------------- -defaults: - # Random seed for reproducibility - random_state: 42 - - # Number of parallel jobs (−1 = all cores) - n_jobs: -1 - - # Cross‐validation settings; any can be overridden per‐analysis - cv_kwargs: - cv_strategy: "stratified" # Options: "kfold", "stratified", "group", etc. - n_splits: 5 # Number of folds - shuffle: true # Whether to shuffle before splitting - random_state: 42 # Seed for the splitter - - # Always‐include covariates (by column name) for all analyses - covariates: ["age"] - - # If you want to do group‐ or spatial‐CV, list your grouping columns here - spatial_units: ["regionX", "regionY"] - - # By default include all features; you can also set a list of column names - feature_names: "all" - - -# ----------------------------------------------------------------------------- -# Analyses to run: each entry here produces one folder/file in results_dir -# ----------------------------------------------------------------------------- -analyses: - - # 1) Classification Baseline (multivariate mode) - - id: "classification_baseline" - task: "classification" # "classification" or "regression" - mode: "multivariate" # "multivariate" (one pipeline on all outputs) - # or "univariate" (loop per target column) - analysis_type: "baseline" # baseline, feature_selection, hp_search, hp_search_fs - target_columns: ["target_class"] - # (optional) override defaults - # covariates: ["gender", "income"] - # feature_names: "all" - # spatial_units: ["regionX"] - # cv_kwargs: - # cv_strategy: "kfold" - # n_splits: 4 - - models: - - "Logistic Regression" # must match a key in BINARY_MODELS or MULTICLASS_MODELS - - "Random Forest" - - metrics: - - "accuracy" # keys in BINARY_METRICS / MULTICLASS_METRICS - - "roc_auc" - - - # 2) Classification + Feature Selection - - id: "classification_feature_selection" - task: "classification" - mode: "multivariate" - analysis_type: "feature_selection" - target_columns: ["target_class"] - - # If you only want to consider a subset of columns: - # feature_names: ["feat1","feat2","feat3"] - - row_filter: - - column: "age" # only keep samples with age < 40 - operator: "<" - values: 40 - - models: - - "SVC" - - metrics: - - "f1" - - # Feature‐selection parameters - n_features: 3 # how many features to pick - direction: "backward" # "forward", "backward", or "both" - scoring: "f1" # which metric to use inside SFS - - - # 3) Classification Hyperparameter Search - - id: "classification_hp_search" - task: "classification" - mode: "multivariate" - analysis_type: "hp_search" - target_columns: ["target_class"] - - models: "all" # run HP search on every model in your config - - metrics: - - "accuracy" - - "average_precision" - - # override cross‐val for this analysis - cv_kwargs: - cv_strategy: "kfold" - n_splits: 3 - - # Hyperparameter‐search parameters - search_type: "grid" # "grid" or "random" - n_iter: 20 # only used for random search - scoring: "accuracy" # metric to optimize - - # You can also supply your own grid if you like: - # param_grid: - # max_depth: [3,5,10] - # n_estimators: [50,100] - - - # 4) Classification + Combined FS & HP Search - - id: "classification_fs_hp_search" - task: "classification" - mode: "multivariate" - analysis_type: "hp_search_fs" - target_columns: ["target_class"] - - models: - - "Random Forest" - - metrics: - - "accuracy" - - # Feature‐selection parameters - n_features: 4 - direction: "forward" - - # Hyperparameter‐search parameters - search_type: "random" - n_iter: 50 - scoring: "roc_auc" - - - # 5) Regression Baseline (single‐output) - - id: "regression_baseline" - task: "regression" - mode: "multivariate" # for regression, if y has multiple columns you can also do univariate - analysis_type: "baseline" - target_columns: ["target_reg"] - - models: - - "Linear Regression" - - "Random Forest" - - metrics: - - "r2" # keys in REGRESSION_METRICS - - "neg_mean_squared_error" - - - # 6) Regression + Feature Selection - - id: "regression_feature_selection" - task: "regression" - mode: "multivariate" - analysis_type: "feature_selection" - target_columns: ["target_reg"] - - models: - - "Random Forest" - - metrics: - - "r2" - - # FS params - n_features: 5 - direction: "forward" - scoring: "r2" - - - # 7) Regression Hyperparameter Search - - id: "regression_hp_search" - task: "regression" - mode: "multivariate" - analysis_type: "hp_search" - target_columns: ["target_reg"] - - models: "all" - metrics: - - "r2" - - "neg_mean_squared_error" - - cv_kwargs: - cv_strategy: "kfold" - n_splits: 4 - - search_type: "random" - n_iter: 30 - scoring: "neg_mean_squared_error" - - - # 8) Regression + Combined FS & HP Search - - id: "regression_fs_hp_search" - task: "regression" - mode: "multivariate" - analysis_type: "hp_search_fs" - target_columns: ["target_reg"] - - models: - - "Random Forest" - - metrics: - - "r2" - - n_features: 6 - direction: "backward" - search_type: "grid" - n_iter: 10 - scoring: "r2" diff --git a/configs/venk_ml_config.yml b/configs/venk_ml_config.yml deleted file mode 100644 index ac5e598..0000000 --- a/configs/venk_ml_config.yml +++ /dev/null @@ -1,15 +0,0 @@ -ID: decoding_test -description: Decoding Multi feature every regions -data: - file: /Users/hamzaabdelhedi/Projects/data/EEG_psychostimulant_data/EEG_psychostimulants_2025-02/csv/demo_psychostim_vs_none.csv - target: target - features_groups: - groups: ['C3', 'C4', 'Cz', 'F3', 'F4', 'F7', 'F8', 'Fp1', 'Fp2', 'Fz', 'O1', 'O2', 'P3', 'P4', 'Pz', 'T3', 'T4', 'T5', 'T6'] - features: ["feature-foofOffset.spaces-"] -analysis: - - name: "Baseline Global" - type: "baseline" - subset: "all_features_all_groups" - models: "all" - scoring: "accuracy" -output: test diff --git a/docs/source/_ext/capability_table.py b/docs/source/_ext/capability_table.py new file mode 100644 index 0000000..12c8f73 --- /dev/null +++ b/docs/source/_ext/capability_table.py @@ -0,0 +1,201 @@ +""" +capability_table +================ + +A custom Sphinx directive that reads ``ESTIMATOR_SPECS`` from +``coco_pipe.decoding._specs`` at build time and emits a formatted RST +``list-table`` showing each estimator's capabilities. + +Usage in any .rst file:: + + .. capability-table:: + :task: classification + + .. capability-table:: + :task: regression + + .. capability-table:: + :task: all + +Options +------- +task : str, default="all" + Filter by task: ``classification``, ``regression``, or ``all``. +show-search-space : flag + If present, append a column with the default search space keys. +""" + +from __future__ import annotations + +import os +import sys +from typing import List + +from docutils import nodes +from docutils.parsers.rst import Directive, directives +from sphinx.util import logging + +logger = logging.getLogger(__name__) + +# Ensure the package root is on sys.path so we can import coco_pipe +_REPO_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..", "..") +) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + + +def _yes_no(value: bool) -> str: + return "✅" if value else "❌" + + +def _importance_label(tup: tuple) -> str: + mapping = { + "coefficients": "coef\\_", + "feature_importances": "feat\\_imp", + "permutation": "permutation", + "unavailable": "❌", + } + return " / ".join(mapping.get(v, v) for v in tup) + + +def _task_label(tup: tuple) -> str: + abbrev = {"classification": "clf", "regression": "reg"} + return " + ".join(abbrev.get(t, t) for t in tup) + + +def _family_label(family: str) -> str: + return family.capitalize() + + +class CapabilityTableDirective(Directive): + """ + Sphinx directive ``.. capability-table::`` + + Emits an RST list-table of all registered estimators and their key + capabilities, optionally filtered by task. + """ + + has_content = False + optional_arguments = 0 + option_spec = { + "task": directives.unchanged, + "show-search-space": directives.flag, + } + + # Column definitions: (header, width, extractor) + _COLUMNS = [ + ("Estimator", 28, lambda s: f"``{s.name}``"), + ("Family", 12, lambda s: _family_label(s.family)), + ("Task", 9, lambda s: _task_label(s.task)), + ("Proba", 6, lambda s: _yes_no(s.supports_proba)), + ("Score fn", 8, lambda s: _yes_no(s.supports_decision_function)), + ("Calibrate", 9, lambda s: _yes_no(s.supports_calibration)), + ("Feature sel", 11, lambda s: _yes_no("disabled" not in s.feature_selection)), + ("Importances", 13, lambda s: _importance_label(s.importance)), + ("Temporal", 11, lambda s: s.temporal if s.temporal != "none" else "❌"), + ( + "Dep", + 7, + lambda s: s.dependency_extra if s.dependency_extra != "core" else "—", + ), + ] + + def run(self) -> List[nodes.Node]: + try: + # Import directly from submodule to avoid triggering coco_pipe.__init__ + # which may pull in optional heavy dependencies (pydantic, etc.) + import importlib.util + import pathlib + + spec_file = ( + pathlib.Path(_REPO_ROOT) / "coco_pipe" / "decoding" / "_specs.py" + ) + spec_mod = importlib.util.spec_from_file_location("_specs", spec_file) + mod = importlib.util.module_from_spec(spec_mod) + # _specs.py depends on _constants.py — load that first + const_file = ( + pathlib.Path(_REPO_ROOT) / "coco_pipe" / "decoding" / "_constants.py" + ) + const_spec = importlib.util.spec_from_file_location( + "_constants", const_file + ) + const_mod = importlib.util.module_from_spec(const_spec) + const_spec.loader.exec_module(const_mod) + import sys as _sys + + _sys.modules["coco_pipe.decoding._constants"] = const_mod + spec_mod.loader.exec_module(mod) + ESTIMATOR_SPECS = mod.ESTIMATOR_SPECS + except Exception as exc: + error_msg = f"capability-table: could not load ESTIMATOR_SPECS: {exc}" + logger.warning(error_msg) + return [nodes.warning("", nodes.paragraph(text=error_msg))] + + task_filter = self.options.get("task", "all").strip().lower() + show_search = "show-search-space" in self.options + + specs = list(ESTIMATOR_SPECS.values()) + if task_filter != "all": + specs = [s for s in specs if task_filter in s.task] + + if not specs: + return [ + nodes.paragraph(text=f"No estimators found for task='{task_filter}'.") + ] + + columns = list(self._COLUMNS) + if show_search: + columns.append( + ( + "Search space keys", + 20, + lambda s: ", ".join(f"``{k}``" for k in s.default_search_space) + or "—", + ) + ) + + # Build the RST list-table text + col_widths = " ".join(str(c[1]) for c in columns) + header_cells = "\n".join(f" * - {c[0]}" for c in columns) + + rows_rst = [] + for spec in sorted(specs, key=lambda s: (s.family, s.name)): + cells = [] + for i, (_, _, extractor) in enumerate(columns): + prefix = " - " if i > 0 else " * - " + cells.append(f"{prefix}{extractor(spec)}") + rows_rst.append("\n".join(cells)) + + table_rst = ( + f".. list-table::\n" + f" :header-rows: 1\n" + f" :widths: {col_widths}\n" + f"\n" + f"{header_cells}\n" + "\n".join(rows_rst) + ) + + # Parse the generated RST into docutils nodes + from docutils.statemachine import ViewList + from sphinx.util.docutils import switch_source_input + + result = ViewList() + source = self.get_source_info()[0] + for lineno, line in enumerate(table_rst.splitlines()): + result.append(line, source, lineno) + + node = nodes.section() + node.document = self.state.document + with switch_source_input(self.state, result): + self.state.nested_parse(result, 0, node) + + return node.children + + +def setup(app): + app.add_directive("capability-table", CapabilityTableDirective) + return { + "version": "1.0", + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/docs/source/api_reference.md b/docs/source/api_reference.md new file mode 100644 index 0000000..6272869 --- /dev/null +++ b/docs/source/api_reference.md @@ -0,0 +1,128 @@ +# API Reference + +This page lists the stable public API entry points that should be used from +source code and examples. The modeling API is `coco_pipe.decoding`; older +modeling surfaces are not part of the supported public API. + +## Decoding + +Use `coco_pipe.decoding` for classification, regression, cross-validation, +feature selection, hyperparameter tuning, temporal decoding, and result +accessors. + +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + coco_pipe.decoding.Experiment + coco_pipe.decoding.ExperimentConfig + coco_pipe.decoding.EstimatorSpec + coco_pipe.decoding.EstimatorCapabilities + coco_pipe.decoding.SelectorCapabilities + coco_pipe.decoding.result.ExperimentResult + coco_pipe.decoding.result.ExperimentResult.to_payload + coco_pipe.decoding.result.ExperimentResult.save + coco_pipe.decoding.result.ExperimentResult.load + coco_pipe.decoding.get_estimator_cls + coco_pipe.decoding.register_estimator + coco_pipe.decoding.register_estimator_spec + coco_pipe.decoding.get_estimator_spec + coco_pipe.decoding.list_estimator_specs + coco_pipe.decoding.get_capabilities + coco_pipe.decoding.list_capabilities + coco_pipe.decoding.run_statistical_assessment + coco_pipe.decoding.binomial_accuracy_test + coco_pipe.decoding.aggregate_predictions_for_inference +``` + +### Decoding Configs + +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + coco_pipe.decoding.configs.CVConfig + coco_pipe.decoding.configs.FeatureSelectionConfig + coco_pipe.decoding.configs.TuningConfig + coco_pipe.decoding.configs.CalibrationConfig + coco_pipe.decoding.configs.StatisticalAssessmentConfig + coco_pipe.decoding.configs.ClassicalModelConfig + coco_pipe.decoding.configs.LogisticRegressionConfig + coco_pipe.decoding.configs.RandomForestClassifierConfig + coco_pipe.decoding.configs.SVCConfig + coco_pipe.decoding.configs.LinearSVCConfig + coco_pipe.decoding.configs.RidgeConfig + coco_pipe.decoding.configs.RandomForestRegressorConfig + coco_pipe.decoding.configs.SVRConfig + coco_pipe.decoding.configs.SlidingEstimatorConfig + coco_pipe.decoding.configs.GeneralizingEstimatorConfig +``` + +### Decoding Splitters, Metrics, And Registry + +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + coco_pipe.decoding.registry.get_estimator_cls + coco_pipe.decoding.registry.get_estimator_spec + coco_pipe.decoding.registry.get_capabilities + coco_pipe.decoding.registry.resolve_estimator_spec +``` + +## Dimensionality Reduction + +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + coco_pipe.dim_reduction.DimReduction + coco_pipe.dim_reduction.config.EvaluationConfig + coco_pipe.dim_reduction.BaseReducer + coco_pipe.dim_reduction.config.BaseReducerConfig +``` + +## IO + +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + coco_pipe.io.DataContainer + coco_pipe.io.load_data +``` + +## Reports + +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + coco_pipe.report.Report + coco_pipe.report.core.Report.add_decoding_diagnostics + coco_pipe.report.core.Report.add_decoding_statistical_assessment + coco_pipe.report.core.Report.add_decoding_temporal + coco_pipe.report.config.ReportConfig +``` + +## Visualization + +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + coco_pipe.viz.plot_temporal_score_curve + coco_pipe.viz.plot_temporal_generalization_matrix + coco_pipe.viz.plot_temporal_statistical_assessment + coco_pipe.viz.plot_confusion_matrix + coco_pipe.viz.plot_roc_curve + coco_pipe.viz.plot_pr_curve + coco_pipe.viz.plot_calibration_curve + coco_pipe.viz.plot_fold_score_dispersion +``` + +## Full Module Index + +The generated [AutoAPI module index](autoapi/index) is still available for +lower-level internals and module exploration, but the public modeling API +should be documented and used through `coco_pipe.decoding`. diff --git a/docs/source/conf.py b/docs/source/conf.py index e1ed661..4cf469e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,6 +19,8 @@ curdir = os.path.dirname(__file__) sys.path.append(os.path.abspath(os.path.join(curdir, "..", "coco-pipe"))) +# Make the local _ext/ directory importable +sys.path.insert(0, os.path.abspath(os.path.join(curdir, "_ext"))) def copy_readme(): @@ -77,6 +79,7 @@ def setup(app): "sphinxcontrib.mermaid", "sphinx.ext.napoleon", "myst_parser", + "capability_table", ] # Allow Markdown files to be used as documentation pages diff --git a/docs/source/decoding/advanced/custom_estimators.rst b/docs/source/decoding/advanced/custom_estimators.rst new file mode 100644 index 0000000..340609c --- /dev/null +++ b/docs/source/decoding/advanced/custom_estimators.rst @@ -0,0 +1,150 @@ +.. _decoding-custom-estimators: + +============================ +Custom Estimators +============================ + +``coco_pipe.decoding`` is extensible. You can register any scikit-learn-compatible +estimator with the registry to enable capability contracts, metric compatibility +checks, and diagnostic reporting for your custom model. + +--- + +1. Protocol Requirements +========================= + +Any custom estimator used in ``coco_pipe.decoding`` must implement the +``DecoderEstimator`` protocol: + +.. code-block:: python + + from coco_pipe.decoding.interfaces import DecoderEstimator + + class MyCustomClassifier: + def fit(self, X, y=None, **fit_params): + # ... training logic + return self + + def predict(self, X): + # ... inference logic + return y_pred + + def get_params(self, deep=True): + return {} + + def set_params(self, **params): + return self + + assert isinstance(MyCustomClassifier(), DecoderEstimator) # runtime check + +For models that provide probability estimates, also implement ``predict_proba``. +For neural models with training diagnostics, implement the ``NeuralTrainable`` +protocol (see :ref:`decoding-foundation-models`). + +--- + +2. Registering an Estimator +============================= + +Register your estimator in ``ESTIMATOR_SPECS`` so it is discoverable by the +capability checking system: + +.. code-block:: python + + from coco_pipe.decoding._specs import ESTIMATOR_SPECS, EstimatorSpec + + ESTIMATOR_SPECS["MyCustomClassifier"] = EstimatorSpec( + name="MyCustomClassifier", + import_path="mypackage.models.MyCustomClassifier", + family="classical", + task=["classification"], + input_kinds=["tabular_2d"], + response_methods=["predict", "predict_proba"], + supports_feature_selection=True, + supports_calibration=True, + importance_attr="coef_", # or "feature_importances_" + ) + +2.1 EstimatorSpec Fields +-------------------------- + +.. list-table:: + :header-rows: 1 + :widths: 30 70 + + * - Field + - Description + * - ``name`` + - Must match the registry key and the class name. + * - ``import_path`` + - Full dotted import path to the class. + * - ``family`` + - ``"classical"`` | ``"tree_ensemble"`` | ``"boosting"`` | ``"neural"`` | ``"dummy"`` + * - ``task`` + - List of ``"classification"`` and/or ``"regression"``. + * - ``input_kinds`` + - ``["tabular_2d"]`` for 2D arrays; ``["epoched_3d"]`` for 3D temporal. + * - ``response_methods`` + - Which prediction interfaces the model provides. + * - ``supports_feature_selection`` + - Whether it works inside a SelectKBest / SFS pipeline. + * - ``supports_calibration`` + - Whether CalibratedClassifierCV can wrap it. + * - ``importance_attr`` + - Attribute name for feature importances (e.g. ``"coef_"``, ``"feature_importances_"``). + * - ``default_search_space`` + - Optional dict of hyperparameter name → list of values for tuning. + +--- + +3. Using the Custom Estimator in an Experiment +=============================================== + +After registration, use the estimator name in ``ClassicalModelConfig``: + +.. code-block:: python + + from coco_pipe.decoding.configs import ExperimentConfig, ClassicalModelConfig, CVConfig + + config = ExperimentConfig( + task="classification", + models={ + "my_model": ClassicalModelConfig( + estimator="MyCustomClassifier", + params={"my_param": 1.0}, + ) + }, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified", n_splits=5), + ) + + result = Experiment(config).run(X, y) + +--- + +4. Custom Feature Importances +================================ + +If your custom model exposes feature importances differently, override the +importance extraction logic by adding a callable to ``importance_attr``: + +.. code-block:: python + + class MyCustomClassifier: + def fit(self, X, y=None): + self.importances_ = self._compute_importances(X, y) + return self + + def _compute_importances(self, X, y): + # custom importance logic + return np.ones(X.shape[1]) + + ESTIMATOR_SPECS["MyCustomClassifier"] = EstimatorSpec( + ..., + importance_attr="importances_", # will be read after .fit() + ) + +``coco_pipe.decoding._engine.extract_feature_importances`` reads +``importance_attr`` via ``getattr(fitted_model, importance_attr)`` and handles +1D arrays (feature importance vectors) and 2D arrays (class-specific weights +like ``coef_``). diff --git a/docs/source/decoding/advanced/foundation_models.rst b/docs/source/decoding/advanced/foundation_models.rst new file mode 100644 index 0000000..7bc644a --- /dev/null +++ b/docs/source/decoding/advanced/foundation_models.rst @@ -0,0 +1,202 @@ +.. _decoding-foundation-models: + +=========================== +Foundation Models (fm_hub) +=========================== + +``coco_pipe.decoding.fm_hub`` provides a unified interface to pretrained +neural network backbones for EEG/MEG decoding. Foundation models can be used +in three modes: + +1. **Frozen embedding extraction**: extract features from a fixed backbone, + then decode with a classical scikit-learn estimator. +2. **Frozen backbone + trainable head**: fine-tune only the output head. +3. **Full fine-tuning**: update all backbone parameters (LoRA, QLoRA, or full). + +All modes enter through ``Experiment.run(...)`` and are compatible with the +outer CV loop, meaning the foundation model is fit/embedded inside the +training partition of each fold. + +--- + +1. Embedding Extraction (Frozen Backbone) +========================================== + +The simplest foundation model workflow: freeze the backbone and use it as a +fixed feature extractor. A classical scikit-learn model decodes from the +extracted embeddings. + +.. code-block:: python + + from coco_pipe.decoding.configs import ( + ExperimentConfig, CVConfig, + FoundationEmbeddingModelConfig, + FrozenBackboneDecoderConfig, + ClassicalModelConfig, + ) + + config = ExperimentConfig( + task="classification", + models={ + "labram_probe": FrozenBackboneDecoderConfig( + backbone=FoundationEmbeddingModelConfig( + provider="braindecode", + model_name="labram-pretrained", + input_kind="epoched", + pooling="mean", + cache_embeddings=True, # cache to disk for reuse across folds + ), + head=ClassicalModelConfig( + estimator="LogisticRegression", + params={"max_iter": 1000}, + ), + ) + }, + metrics=["balanced_accuracy"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + ) + + result = Experiment(config).run(X_epochs, y, sample_metadata=meta) + +1.1 Embedding Cache +-------------------- + +``cache_embeddings=True`` writes extracted embeddings to a fold-keyed disk +cache. On subsequent runs with the same data and backbone configuration, the +backbone forward pass is skipped and embeddings are loaded from cache. The cache +key is computed from the training split identity and backbone fingerprint. + +.. code-block:: python + + from coco_pipe.decoding import make_feature_cache_key + + key = make_feature_cache_key( + train_sample_ids=train_ids, + test_sample_ids=test_ids, + backbone_fingerprint=backbone_hash, + ) + +1.2 Supported Providers +------------------------ + +.. list-table:: + :header-rows: 1 + :widths: 20 80 + + * - Provider + - Description + * - ``"braindecode"`` + - Pretrained models from the Braindecode library (ShallowFBCSPNet, EEGNetv4, + LaBraM, etc.). + * - ``"huggingface"`` + - Any HuggingFace Hub model compatible with the EEG tokenizer interface. + * - ``"reve"`` + - REVE pretrained EEG backbone. + * - ``"dummy"`` + - Returns random embeddings. For testing pipeline integrity only. + +--- + +2. Neural Fine-Tuning (LoRA / QLoRA) +======================================= + +.. code-block:: python + + from coco_pipe.decoding.configs import ( + NeuralFineTuneConfig, LoRAConfig, QuantizationConfig, DeviceConfig, CheckpointConfig + ) + + config = ExperimentConfig( + task="classification", + models={ + "reve_qlora": NeuralFineTuneConfig( + provider="huggingface", + model_name="brain-bzh/reve-base", + input_kind="epoched", + train_mode="qlora", + lora=LoRAConfig(r=16, alpha=32, dropout=0.05), + quantization=QuantizationConfig(enabled=True, bits=4), + device=DeviceConfig(device="auto", precision="bf16"), + checkpoints=CheckpointConfig(save="best"), + ) + }, + metrics=["balanced_accuracy"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + ) + +2.1 Training Modes +------------------- + +.. list-table:: + :header-rows: 1 + :widths: 20 80 + + * - Mode + - Description + * - ``"full"`` + - Update all backbone parameters. Highest capacity; requires most memory. + * - ``"lora"`` + - Low-Rank Adaptation. Trains small rank-decomposed matrices injected into + transformer attention. Memory-efficient. + * - ``"qlora"`` + - Quantized LoRA. Backbone quantized to 4-bit for inference; LoRA adapters + trained in higher precision. Most memory-efficient option. + +2.2 LoRA Configuration +----------------------- + +.. list-table:: + :header-rows: 1 + :widths: 20 80 + + * - Parameter + - Description + * - ``r`` + - Rank of the LoRA decomposition. Higher rank → more parameters. Default 16. + * - ``alpha`` + - Scaling factor. ``alpha / r`` scales the LoRA output. Default 32. + * - ``dropout`` + - Dropout on LoRA layers. Default 0.0. + +--- + +3. Diagnostic Artifacts +========================= + +Trainable neural models expose training diagnostics via ``NeuralTrainable`` +protocol methods: + +.. code-block:: python + + artifacts = result.get_model_artifacts() + # columns: Model, Fold, ArtifactKey, ArtifactValue + + # Per-fold training history + history = result.get_model_artifacts(artifact_type="training_history") + + # Checkpoint manifest + checkpoints = result.get_model_artifacts(artifact_type="checkpoints") + +The ``NeuralTrainable`` protocol requires: + +- ``get_training_history() → list[dict]``: loss/metric per epoch. +- ``get_checkpoint_manifest() → dict``: saved checkpoint paths and best epoch. +- ``get_model_card_info() → dict``: architecture and training summary. +- ``get_failure_diagnostics() → dict``: NaN detection, gradient norms. +- ``get_artifact_metadata() → dict``: aggregated artifact dictionary. + +--- + +4. Required Dependencies +========================== + +Foundation models require optional extras: + +- ``braindecode`` provider: ``pip install coco-pipe[braindecode]`` +- ``huggingface`` / ``qlora`` provider: ``pip install coco-pipe[hf,peft,quant]`` +- ``reve`` provider: Contact the REVE team for access. + +.. code-block:: bash + + pip install coco-pipe[hf,peft,quant] # QLoRA path + pip install coco-pipe[braindecode] # Braindecode backbone diff --git a/docs/source/decoding/advanced/reproducibility.rst b/docs/source/decoding/advanced/reproducibility.rst new file mode 100644 index 0000000..b7adb41 --- /dev/null +++ b/docs/source/decoding/advanced/reproducibility.rst @@ -0,0 +1,146 @@ +.. _decoding-reproducibility: + +============================= +Reproducibility Architecture +============================= + +``coco_pipe.decoding`` is designed so that every run with the same configuration +and data produces bit-identical results. This section documents how random seeds +are propagated, where they appear in the result schema, and how to validate +reproducibility. + +--- + +1. Seed Propagation via SeedSequence +====================================== + +Setting ``ExperimentConfig.random_state`` propagates derived, independent seeds +to every sub-component through NumPy's ``SeedSequence``: + +.. code-block:: python + + config = ExperimentConfig( + task="classification", + models={"lr": ClassicalModelConfig(estimator="LogisticRegression")}, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified", n_splits=5), + random_state=42, # master seed + ) + +Internally, ``Experiment._propagate_random_state()`` derives: + +.. list-table:: + :header-rows: 1 + :widths: 25 75 + + * - Component + - Derived seed (offset from master) + * - ``cv`` + - ``master + 0`` + * - ``feature_selection`` + - ``master + 1`` + * - ``tuning`` + - ``master + 2`` + * - ``calibration`` + - ``master + 3`` + * - Per-model seeds + - Spawned from ``master + 4`` via ``SeedSequence.spawn`` + +The per-model seeds are ordered by model name (alphabetically) for determinism. + +.. note:: + + Even if you add models or change their order in the ``models`` dict, + alphabetical seed assignment ensures each model always receives the same seed + regardless of insertion order. + +--- + +2. What Is Seeded +================== + +Every stochastic component in the pipeline is seeded: + +- **CV splitters**: ``StratifiedKFold``, ``StratifiedGroupKFold``, ``KFold``. +- **Hyperparameter search**: ``RandomizedSearchCV`` uses ``tuning.random_state``. +- **Model initialization**: models with ``random_state`` parameters receive + model-specific seeds. +- **Calibration**: ``CalibratedClassifierCV`` uses ``calibration.random_state``. +- **Bootstrap CI**: ``get_bootstrap_confidence_intervals`` accepts + ``random_state``. +- **Permutation tests**: ``ChanceAssessmentConfig.random_state`` seeds the null + permutation engine. + +Not seeded (intentionally): + +- **Data loading** and **preprocessing** outside the ``Experiment.run`` call. +- **MNE meta-estimator** internal parallelism (joblib workers), which may vary + between runs if parallelism order is non-deterministic. + +--- + +3. Result Schema Provenance +============================== + +Every ``ExperimentResult`` stores reproducibility metadata in +``result.meta["hardware_provenance"]``: + +.. code-block:: python + + print(result.meta["hardware_provenance"]) + # { + # "python_version": "3.11.12", + # "sklearn_version": "1.6.1", + # "numpy_version": "1.26.4", + # "platform": "macOS-14.5", + # "n_jobs": 1, + # "timestamp": "2026-05-14T04:30:00Z", + # } + +This provenance is captured by ``get_environment_info()`` from +``coco_pipe.report.provenance`` at the time of ``Experiment.run``. + +--- + +4. Validating Reproducibility +================================ + +To verify that two runs produce identical results: + +.. code-block:: python + + import numpy as np + import pandas as pd + + # Run A + result_a = Experiment(config).run(X, y, sample_metadata=meta) + + # Run B (identical config and data) + result_b = Experiment(config).run(X, y, sample_metadata=meta) + + scores_a = result_a.get_detailed_scores() + scores_b = result_b.get_detailed_scores() + + pd.testing.assert_frame_equal( + scores_a.sort_values(["Model", "Fold", "Metric"]).reset_index(drop=True), + scores_b.sort_values(["Model", "Fold", "Metric"]).reset_index(drop=True), + ) + +--- + +5. Known Non-Determinism Sources +=================================== + +Some operations may produce slightly different results even with the same seed: + +- **Parallel outer CV** (``n_jobs > 1``): scikit-learn's parallel backends can + schedule workers in different orders between runs. For exact reproducibility, + use ``n_jobs=1``. +- **GPU operations** (for foundation models with LoRA/QLoRA): CUDA operations + are non-deterministic by default unless ``torch.use_deterministic_algorithms(True)`` + is set. +- **OS-level RNG state** leaking into ``random.random()`` or ``os.urandom()`` + calls in third-party libraries. + +For fully deterministic publication runs, set ``n_jobs=1`` and document the +exact library versions from ``result.meta["hardware_provenance"]``. diff --git a/docs/source/decoding/concepts.rst b/docs/source/decoding/concepts.rst new file mode 100644 index 0000000..56ad614 --- /dev/null +++ b/docs/source/decoding/concepts.rst @@ -0,0 +1,267 @@ +.. _decoding-concepts: + +================================== +Scientific Concepts and Principles +================================== + +This page explains the foundational principles that govern every design decision +in ``coco_pipe.decoding``. Understanding these concepts is essential for +interpreting results correctly and avoiding common pitfalls in brain decoding. + +--- + +1. Cross-Validation and Data Leakage +===================================== + +1.1 Why Outer-Only Scoring Is Insufficient +------------------------------------------- + +A decoding score is an estimate of *generalization performance* — how well a +trained model predicts labels from **unseen** brain data. The critical word is +"unseen". If any information from the test partition is visible during training +(even implicitly, through preprocessing), the score is optimistically biased. + +In practice, leakage occurs when: + +- A scaler is fit on the **whole dataset** then applied fold-locally. +- A feature selector's statistics are computed on the full feature matrix. +- A hyperparameter is tuned using a validation set that overlaps with the test set. +- A class-label encoder is fit on all samples before splitting. + +``coco_pipe.decoding`` prevents all of these by construction: every preprocessing +transformer is created inside a scikit-learn ``Pipeline`` that is fit **only on +the training partition** of each outer fold. The test partition is never touched +during training. + +1.2 The Outer Cross-Validation Loop +------------------------------------- + +The outer CV loop is controlled by ``ExperimentConfig.cv``. It defines the +*evaluation splits*. For each fold: + +1. ``X_train, y_train`` → fit scaler, feature selector, hyperparameter search, + calibration, and the final model. +2. ``X_test, y_test`` → predict and score using the fold-trained pipeline. +3. Fold scores, predictions, importances, and diagnostics are stored. + +The final score (e.g., ``get_detailed_scores()``) is the **average over outer +folds**. It is an unbiased estimate of generalization performance — provided +independence is respected (see section 2). + +1.3 Inner CV Loops +------------------- + +When hyperparameter tuning (``TuningConfig``) or Sequential Feature Selection +(``FeatureSelectionConfig(method='sfs')``) is enabled, an **inner** CV loop +operates on the training partition of each outer fold. This inner loop selects +the best model configuration without access to the test set. + +When the outer CV is group-based (e.g., ``group_kfold``), the inner CV is +**automatically made group-based** as well. Overriding this requires setting +``allow_nongroup_inner_cv=True`` and explicitly acknowledges the data-leakage +trade-off. + +.. warning:: + + Mixing group-based outer CV with non-group inner CV can cause test-set + group information to leak into model selection, inflating performance. + Always use matching group strategies for inner and outer CV when subjects + must remain exclusive to test folds. + +--- + +2. Independence and the Unit of Inference +========================================== + +2.1 Pseudoreplication in Neural Data +-------------------------------------- + +EEG and MEG experiments typically produce many epochs per subject. If a model +is trained and tested on **epochs** from the same subject, the test scores are +not independent. Each subject's neural patterns are correlated across epochs, +so the effective sample size for statistical inference is the **number of +subjects**, not the number of epochs. + +Using epochs as the unit of inference inflates degrees of freedom and produces +incorrect p-values. This is called *pseudoreplication*. + +2.2 Group-Based CV +-------------------- + +The correct solution is to ensure all epochs from a given subject belong +**exclusively** to either the training set or the test set — never both. This +is achieved with ``CVConfig(strategy="group_kfold", group_key="Subject")``. + +``coco_pipe.decoding`` accepts two equivalent ways to specify groups: + +- ``sample_metadata={"Subject": subject_ids}`` with ``cv.group_key="Subject"`` + (recommended — keeps metadata tidy and allows downstream subject-level analysis). +- ``groups=subject_ids`` (compatibility alias — binds groups directly to the + splitter). + +2.3 Subject-Level Aggregation for Statistical Tests +------------------------------------------------------ + +Even with group-based CV, the **predictions** stored after each fold are +epoch-level. Before performing a statistical test, predictions must be aggregated +to the independent unit (subjects) to restore correct degrees of freedom. + +``coco_pipe.decoding.stats.aggregate_predictions_for_inference()`` handles this: + +- ``unit_of_inference="group_mean"``: soft-vote by averaging subject class + probabilities across epochs, then hard-classify. +- ``unit_of_inference="group_majority"``: hard-vote by majority of epoch labels. +- ``unit_of_inference="sample"``: no aggregation (correct when each row is + already an independent subject). + +The statistical assessment machinery uses this aggregation automatically: + +.. code-block:: python + + from coco_pipe.decoding.configs import StatisticalAssessmentConfig, ChanceAssessmentConfig + + eval_cfg = StatisticalAssessmentConfig( + enabled=True, + chance=ChanceAssessmentConfig( + method="permutation", + n_permutations=1000, + unit_of_inference="group_mean", + ), + ) + +--- + +3. Full-Pipeline Permutation Testing +====================================== + +3.1 Why "Post-Hoc" Permutations Are Biased +-------------------------------------------- + +The easiest permutation test shuffles labels and scores the **already-fitted** +model's predictions. This is fast but biased: it does not reshuffle labels +during hyperparameter search, feature selection, or calibration. If any of +these steps use the labels (which they all do), the null distribution is too +narrow. + +3.2 The Correct Null: Full-Pipeline Permutation +------------------------------------------------- + +The correct null distribution is obtained by rerunning the **complete** training +pipeline — scaler, feature selector, inner CV, hyperparameter search, calibration, +and the final model fit — on permuted labels. This is what +``ChanceAssessmentConfig(method="permutation")`` does. + +Each permutation: + +1. Shuffles ``y`` within each group (or globally for ``unit="sample"``). +2. Reruns the complete outer CV with the shuffled labels. +3. Aggregates the permuted predictions to the unit of inference. +4. Scores the aggregated permuted predictions. + +The observed score is then compared against this null distribution: + +.. math:: + + p = \frac{\#\{\text{null scores} \geq \text{observed score}\} + 1}{B + 1} + +where :math:`B` is the number of permutations. + +3.3 Multiple Comparison Correction for Temporal Data +------------------------------------------------------ + +For sliding/generalizing temporal decoders, one p-value is produced per +timepoint. Multiple testing correction is required. ``coco_pipe.decoding`` +supports: + +- ``temporal_correction="max_stat"`` (default): permutation-based Max-Stat + correction. The null at each timepoint is the *global maximum* of the + permutation distribution, yielding family-wise error rate (FWER) control. + Recommended for temporal decoding with moderate-to-high correlations. +- ``temporal_correction="fdr_bh"``: Benjamini-Hochberg FDR control. +- ``temporal_correction="fdr_by"``: Benjamini-Yekutieli FDR control (more + conservative, valid under positive dependence). +- ``temporal_correction="none"``: no correction (exploratory use only). + +.. math:: + + p_t^{\text{max\_stat}} = \frac{\#\{B_b : \max_{t'} s_b(t') \geq s(t)\} + 1}{B + 1} + +where :math:`s(t)` is the observed score at time :math:`t` and :math:`s_b(t')` +is the permuted score at any timepoint. + +--- + +4. Probability Calibration +============================ + +A classifier is *calibrated* if its predicted probability of class 1 matches +the empirical fraction of class-1 samples at that probability level. Poor +calibration does not affect accuracy but matters for: + +- Log-loss and Brier score interpretation. +- Clinical decision thresholds. +- Ensemble averaging across models. + +``coco_pipe.decoding`` supports ``sklearn.calibration.CalibratedClassifierCV`` +inside the training path. The calibration estimator uses **disjoint inner folds** +within each outer training partition, so the test set is never used for +calibration fitting. + +Enabling calibration also makes probability metrics (``log_loss``, ``brier_score``) +available for models that do not natively provide ``predict_proba`` (e.g., +``LinearSVC``). + +--- + +5. Feature Importance and Stability +====================================== + +Feature importances are extracted per fold (when the fitted model supports them) +and aggregated: + +- ``get_feature_importances(fold_level=False)``: mean importance ± std across folds. +- ``get_feature_importances(fold_level=True)``: per-fold importances in long form. +- ``get_feature_stability()``: proportion of folds in which each feature was + selected (for SFS) or had positive importance. + +.. warning:: + + Fold-averaged importance is **not** the same as importance computed on the + full dataset. Because each fold trains on a subset of subjects, the importance + estimate has higher variance than whole-dataset importance. Always report the + fold-level standard deviation alongside the mean. + +--- + +6. Temporal Decoding Concepts +================================ + +6.1 Sliding Decoding +---------------------- + +A ``SlidingEstimator`` (MNE) fits one independent model per timepoint. Each +model sees the channel-space snapshot at its timepoint across all epochs in the +training fold. The result is a score curve over time. + +- *Assumption*: The most discriminative time window is narrow relative to the + total window length. +- *Strength*: Identifies when (not just whether) neural representations are + discriminative. + +6.2 Generalizing Decoding (Temporal Generalization) +------------------------------------------------------ + +A ``GeneralizingEstimator`` (MNE) fits one model per training timepoint and +tests each model at **every** test timepoint. The result is a +``(n_train_times, n_test_times)`` matrix of scores. + +Off-diagonal entries answer: "Does the representation learned at time :math:`t_1` +generalize to predict the label at time :math:`t_2`?" A diagonal band indicates +a rapidly changing representation; an extended off-diagonal band indicates a +stable neural code. + +- *Scientific interpretation*: Off-diagonal generalization is evidence of a + sustained, format-stable neural representation. +- *Statistical note*: The generalizing matrix has ``n_train × n_test`` cells. + Temporal correction (Max-Stat) is strongly recommended to control the + family-wise error rate. diff --git a/docs/source/decoding/configs.rst b/docs/source/decoding/configs.rst new file mode 100644 index 0000000..db64abe --- /dev/null +++ b/docs/source/decoding/configs.rst @@ -0,0 +1,232 @@ +.. _decoding-configs: + +========================== +Configuration Reference +========================== + +All experiment configuration is declarative and Pydantic-validated. Every +config class uses ``extra="forbid"`` so misspelled or unsupported field names +raise a ``ValidationError`` immediately — before any training starts. + +--- + +1. ``ExperimentConfig`` +======================== + +Top-level configuration for a decoding experiment. + +.. code-block:: python + + from coco_pipe.decoding.configs import ExperimentConfig + + config = ExperimentConfig( + task="classification", # required: "classification" or "regression" + models={"lr": ...}, # required: dict of model configs + metrics=["accuracy"], # default: task-appropriate defaults + cv=CVConfig(...), # default: StratifiedKFold(5) + tuning=TuningConfig(...), # default: disabled + feature_selection=FeatureSelectionConfig(...), # default: disabled + calibration=CalibrationConfig(...), # default: disabled + evaluation=StatisticalAssessmentConfig(...), # default: disabled + grids={"lr": {"C": [0.1, 1.0]}}, # hyperparameter grids for tuning + use_scaler=True, # prepend StandardScaler to pipeline + n_jobs=1, # outer CV parallelism + verbose=False, + tag="my_experiment", # descriptive label in result metadata + random_state=42, + ) + +--- + +2. ``CVConfig`` +================ + +Controls the outer cross-validation loop. + +.. code-block:: python + + from coco_pipe.decoding.configs import CVConfig + + cv = CVConfig( + strategy="stratified_group_kfold", + n_splits=5, + group_key="Subject", # column name in sample_metadata + test_size=0.2, # for "split" strategy only + stratify=True, # for "split" + classification only + n_groups=2, # for "leave_p_out" only + random_state=42, + ) + +See :ref:`decoding-cv` for a complete strategy guide. + +--- + +3. ``ClassicalModelConfig`` +============================ + +Configures a classical scikit-learn estimator. + +.. code-block:: python + + from coco_pipe.decoding.configs import ClassicalModelConfig + + model = ClassicalModelConfig( + estimator="LogisticRegression", # key in ESTIMATOR_SPECS + params={"C": 1.0, "max_iter": 200}, + ) + +Short-form aliases are also available for common estimators: + +.. code-block:: python + + from coco_pipe.decoding.configs import LogisticRegressionConfig + + model = LogisticRegressionConfig(C=1.0, max_iter=200) + +--- + +4. ``TemporalDecoderConfig`` +============================== + +Wraps a classical base estimator for 3D temporal inputs. + +.. code-block:: python + + from coco_pipe.decoding.configs import TemporalDecoderConfig, ClassicalModelConfig + + model = TemporalDecoderConfig( + wrapper="sliding", # or "generalizing" + base=ClassicalModelConfig(estimator="LogisticRegression"), + scoring="accuracy", + n_jobs=-1, + ) + +Requires ``mne`` as an optional dependency. + +--- + +5. ``TuningConfig`` +===================== + +Controls hyperparameter search. + +.. code-block:: python + + from coco_pipe.decoding.configs import TuningConfig, CVConfig + + tuning = TuningConfig( + enabled=True, + search_type="grid", # or "random" + scoring="accuracy", + n_iter=20, # for "random" search only + n_jobs=1, + refit=True, + cv=CVConfig(strategy="stratified", n_splits=3), # inner CV + allow_nongroup_inner_cv=False, # leakage guard + random_state=42, + ) + +--- + +6. ``FeatureSelectionConfig`` +=============================== + +.. code-block:: python + + from coco_pipe.decoding.configs import FeatureSelectionConfig, CVConfig + + fs = FeatureSelectionConfig( + enabled=True, + method="k_best", # or "sfs" + n_features=20, + scoring="accuracy", # scoring criterion for SFS inner CV + cv=CVConfig(strategy="stratified", n_splits=3), # SFS inner CV + direction="forward", # for SFS: "forward" or "backward" + allow_nongroup_inner_cv=False, + ) + +--- + +7. ``CalibrationConfig`` +========================== + +Enables probability calibration inside the training path. + +.. code-block:: python + + from coco_pipe.decoding.configs import CalibrationConfig, CVConfig + + calibration = CalibrationConfig( + enabled=True, + method="sigmoid", # or "isotonic" + cv=CVConfig(strategy="stratified", n_splits=3), + allow_nongroup_inner_cv=False, + ) + +--- + +8. ``StatisticalAssessmentConfig`` +==================================== + +.. code-block:: python + + from coco_pipe.decoding.configs import ( + StatisticalAssessmentConfig, ChanceAssessmentConfig, ConfidenceIntervalConfig + ) + + evaluation = StatisticalAssessmentConfig( + enabled=True, + random_state=42, + unit_of_inference="group_mean", # "sample", "group_mean", "group_majority", "custom" + chance=ChanceAssessmentConfig( + method="permutation", # or "binomial", "auto" + n_permutations=1000, + p0=None, # required for "binomial" + temporal_correction="max_stat", # "max_stat", "fdr_bh", "fdr_by", "none" + store_null_distribution=False, + ), + confidence_intervals=ConfidenceIntervalConfig( + alpha=0.05, + method="clopper_pearson", # or "wilson" + n_bootstraps=1000, + ), + ) + +--- + +9. Foundation Model Configs +============================== + +.. code-block:: python + + from coco_pipe.decoding.configs import ( + FoundationEmbeddingModelConfig, + FrozenBackboneDecoderConfig, + NeuralFineTuneConfig, + LoRAConfig, + QuantizationConfig, + DeviceConfig, + CheckpointConfig, + ) + + # Frozen embedding + embed_cfg = FoundationEmbeddingModelConfig( + provider="braindecode", # "dummy", "braindecode", "huggingface", "reve" + model_name="labram-pretrained", + input_kind="epoched", # "tabular", "temporal", "epoched", "embeddings", "tokens" + pooling="mean", # "mean", "flatten", "last" + cache_embeddings=True, + normalize_embeddings=True, + ) + + # Full neural fine-tuning + ft_cfg = NeuralFineTuneConfig( + provider="huggingface", + model_name="brain-bzh/reve-base", + input_kind="epoched", + train_mode="qlora", # "full", "lora", "qlora" + lora=LoRAConfig(r=16, alpha=32), + quantization=QuantizationConfig(enabled=True, bits=4), + device=DeviceConfig(device="auto", precision="bf16"), + checkpoints=CheckpointConfig(save="best"), + ) diff --git a/docs/source/decoding/cv_strategies.rst b/docs/source/decoding/cv_strategies.rst new file mode 100644 index 0000000..2cc2489 --- /dev/null +++ b/docs/source/decoding/cv_strategies.rst @@ -0,0 +1,222 @@ +.. _decoding-cv: + +================================= +Cross-Validation Strategies Guide +================================= + +The cross-validation strategy is the most consequential choice in a decoding +experiment. It determines whether the performance estimate is statistically valid, +whether group independence is preserved, and whether the inner model-selection +loops can correctly inherit the outer splitting logic. + +--- + +1. Available Strategies +======================== + +All strategies are configured via ``CVConfig.strategy``: + +.. list-table:: + :header-rows: 1 + :widths: 30 25 20 25 + + * - Strategy + - Group-aware + - Use case + - scikit-learn equivalent + * - ``"stratified"`` + - ❌ + - Balanced class folds (classification) + - ``StratifiedKFold`` + * - ``"kfold"`` + - ❌ + - Regression, or classification without imbalance + - ``KFold`` + * - ``"group_kfold"`` + - ✅ + - K folds, subjects exclusive to test + - ``GroupKFold`` + * - ``"stratified_group_kfold"`` + - ✅ + - K folds, class-balanced, subjects exclusive + - ``StratifiedGroupKFold`` + * - ``"leave_one_group_out"`` + - ✅ + - Leave-one-subject-out (LOSO) + - ``LeaveOneGroupOut`` + * - ``"leave_p_out"`` + - ✅ + - Leave-P-subjects-out + - ``LeavePGroupsOut`` + * - ``"timeseries"`` + - ❌ + - Ordered splits for time-series data + - ``TimeSeriesSplit`` + * - ``"split"`` + - ❌ + - Single train/test holdout + - Custom ``ShuffleSplit`` + +--- + +2. Group-Based Strategies +=========================== + +2.1 When Groups Are Required +------------------------------ + +Use a group-based strategy whenever your data contains **multiple observations +per independent unit** (e.g., multiple epochs per subject). Failure to do so +means the model trains and tests on data from the **same subjects**, producing +inflated accuracy estimates. + +Provide groups via sample metadata: + +.. code-block:: python + + from coco_pipe.decoding.configs import CVConfig + + cv = CVConfig( + strategy="stratified_group_kfold", + n_splits=5, + group_key="Subject", # must match a column in sample_metadata + ) + + result = Experiment(config).run( + X, y, + sample_metadata={"Subject": subject_ids, "Session": session_ids} + ) + +2.2 LOSO (Leave-One-Subject-Out) +---------------------------------- + +LOSO leaves all epochs from one subject out of training per fold. It is the most +conservative and the most clinically-relevant evaluation strategy, but it has +as many folds as subjects, which can be computationally expensive. + +.. code-block:: python + + cv = CVConfig(strategy="leave_one_group_out", group_key="Subject") + +.. note:: + + ``leave_one_group_out`` does not accept an ``n_splits`` parameter. The number + of folds equals the number of unique subjects. + +2.3 Leave-P-Subjects-Out +-------------------------- + +Leave-P-groups-out leaves ``p`` subjects out per fold. More powerful than LOSO +when ``p > 1``, but substantially increases the number of folds. + +.. code-block:: python + + cv = CVConfig(strategy="leave_p_out", n_groups=2, group_key="Subject") + +--- + +3. Group Propagation to Inner CV +================================== + +When the outer CV is group-based, ``coco_pipe.decoding`` automatically propagates +group constraints to all inner CV loops: + +- **Hyperparameter tuning** (``TuningConfig``): uses a group-based inner CV by default. +- **Sequential Feature Selection** (``FeatureSelectionConfig(method="sfs")``): uses a + group-based inner CV by default. +- **Calibration** (``CalibrationConfig``): uses a group-based inner calibration + split by default. + +Overriding this requires explicitly setting ``allow_nongroup_inner_cv=True`` on +the relevant config object: + +.. code-block:: python + + from coco_pipe.decoding.configs import TuningConfig, CVConfig + + tuning = TuningConfig( + enabled=True, + cv=CVConfig(strategy="stratified", n_splits=3), + allow_nongroup_inner_cv=True, # explicit acknowledgement of leakage risk + ) + +--- + +4. Stratified Strategies +========================== + +Stratified strategies ensure that class proportions are approximately equal +across folds. This is important for imbalanced datasets where some folds might +contain no minority-class examples. + +- ``"stratified"`` and ``"stratified_group_kfold"`` are only valid for classification. +- For regression tasks, use ``"kfold"`` or group-based strategies. + +.. code-block:: python + + cv = CVConfig( + strategy="stratified_group_kfold", + n_splits=5, + group_key="Subject", + random_state=42, # reproducibility for stratification + ) + +--- + +5. Holdout Split +================= + +For large datasets or as a quick sanity check, a single train/test holdout +avoids the overhead of K outer folds: + +.. code-block:: python + + cv = CVConfig( + strategy="split", + test_size=0.2, + stratify=True, # stratified split for classification + random_state=42, + ) + +The ``n_splits`` field is ignored for ``"split"`` — it always produces exactly +one fold. + +--- + +6. Time Series Split +====================== + +For EEG/MEG data that is **not epoched** (e.g., continuous recordings), or for +temporal regression, use ``"timeseries"``: + +.. code-block:: python + + cv = CVConfig( + strategy="timeseries", + n_splits=5, + test_size=0.2, # optional, overrides sklearn default + ) + +``TimeSeriesSplit`` ensures that training data always comes **before** test data +in time, preventing future data from leaking into the model. + +--- + +7. Random State and Reproducibility +====================================== + +``CVConfig.random_state`` seeds the splitter. For full reproducibility, also +set ``ExperimentConfig.random_state``, which propagates derived seeds to the CV, +tuning, feature selection, and calibration configs via a ``SeedSequence``. + +.. code-block:: python + + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified", n_splits=5), + random_state=42, # propagated to all sub-components + ) + +See :ref:`decoding-reproducibility` for the full seed propagation architecture. diff --git a/docs/source/decoding/examples/basic_classification.rst b/docs/source/decoding/examples/basic_classification.rst new file mode 100644 index 0000000..0a2da50 --- /dev/null +++ b/docs/source/decoding/examples/basic_classification.rst @@ -0,0 +1,146 @@ +.. _decoding-example-basic: + +=========================== +Example: Basic Classification +=========================== + +This example walks through a minimal reproducible decoding experiment: binary +classification from a 2D feature matrix, with stratified group-based +cross-validation and post-hoc statistical assessment. + +**Scientific context**: Classifying task conditions (e.g., face vs. object) from +subject-level EEG power spectral density features. Each row in ``X`` is one +epoch's feature vector. Multiple epochs from the same subject share a subject ID. + +--- + +1. Prepare Data +================ + +.. code-block:: python + + import numpy as np + import pandas as pd + from sklearn.datasets import make_classification + + # Simulate: 200 epochs, 64 features, 20 subjects (10 per class) + n_subjects = 20 + n_epochs_per_subject = 10 + n_features = 64 + + rng = np.random.default_rng(42) + X = rng.standard_normal((n_subjects * n_epochs_per_subject, n_features)) + y = np.repeat([0, 1], n_subjects // 2 * n_epochs_per_subject) + subject_ids = np.repeat(np.arange(n_subjects), n_epochs_per_subject) + +--- + +2. Configure the Experiment +============================ + +.. code-block:: python + + from coco_pipe.decoding import Experiment, ExperimentConfig + from coco_pipe.decoding.configs import ( + ClassicalModelConfig, CVConfig, + StatisticalAssessmentConfig, ChanceAssessmentConfig, + ) + + config = ExperimentConfig( + task="classification", + models={ + "logistic_regression": ClassicalModelConfig( + estimator="LogisticRegression", + params={"max_iter": 500, "solver": "liblinear"}, + ), + "random_forest": ClassicalModelConfig( + estimator="RandomForestClassifier", + params={"n_estimators": 100}, + ), + }, + metrics=["accuracy", "balanced_accuracy", "roc_auc"], + cv=CVConfig( + strategy="stratified_group_kfold", + n_splits=5, + group_key="Subject", + ), + evaluation=StatisticalAssessmentConfig( + enabled=True, + chance=ChanceAssessmentConfig( + method="permutation", + n_permutations=1000, + unit_of_inference="group_mean", + ), + ), + use_scaler=True, + random_state=42, + ) + +--- + +3. Run the Experiment +====================== + +.. code-block:: python + + result = Experiment(config).run( + X, + y, + sample_metadata={"Subject": subject_ids}, + observation_level="epoch", + ) + +--- + +4. Inspect Results +=================== + +.. code-block:: python + + # Per-fold scores + scores = result.get_detailed_scores() + print(scores.groupby(["Model", "Metric"])["Score"].agg(["mean", "std"])) + + # Confusion matrices + cm = result.get_confusion_matrices(normalize=True) + + # ROC curves + roc = result.get_roc_curve() + + # Calibration curves + cal = result.get_calibration_curve() + +--- + +5. Statistical Assessment +========================== + +.. code-block:: python + + assessment = result.get_statistical_assessment() + print(assessment[["Model", "Metric", "Observed", "PValue", "Significant"]]) + +--- + +6. Compare Models +================== + +.. code-block:: python + + comparison = result.compare_models_paired( + "logistic_regression", "random_forest", + metric="accuracy", + unit="Subject", + n_permutations=5000, + ) + print(comparison[["Difference", "PValue", "Significant"]]) + +--- + +7. Persist and Load +==================== + +.. code-block:: python + + path = result.save("results/basic_classification.json") + loaded = ExperimentResult.load(path) diff --git a/docs/source/decoding/examples/grouped_cv.rst b/docs/source/decoding/examples/grouped_cv.rst new file mode 100644 index 0000000..dc1d844 --- /dev/null +++ b/docs/source/decoding/examples/grouped_cv.rst @@ -0,0 +1,129 @@ +.. _decoding-example-grouped: + +======================== +Example: Grouped CV +======================== + +This example demonstrates the correct use of group-based cross-validation +for multi-session EEG data, where multiple epochs per subject and multiple +sessions per subject must remain exclusive to training or test folds. + +**Scientific context**: EEG data from 30 subjects, 2 sessions each, 20 epochs +per session. Goal: classify cognitive states while ensuring the model cannot +exploit within-subject patterns that would inflate test accuracy. + +--- + +1. Prepare Multi-Session Data +================================ + +.. code-block:: python + + import numpy as np + + n_subjects = 30 + n_sessions_per_subject = 2 + n_epochs_per_session = 20 + n_features = 64 + + rng = np.random.default_rng(42) + total_epochs = n_subjects * n_sessions_per_subject * n_epochs_per_session + + X = rng.standard_normal((total_epochs, n_features)) + y = np.tile([0, 1], total_epochs // 2) + + # Build metadata with Subject, Session, Site columns + subject_ids = np.repeat(np.arange(n_subjects), n_sessions_per_subject * n_epochs_per_session) + session_ids = np.tile( + np.repeat(np.arange(n_sessions_per_subject), n_epochs_per_session), + n_subjects + ) + site_ids = np.where(subject_ids < 15, "SiteA", "SiteB") + +--- + +2. Configure Grouped CV +========================== + +.. code-block:: python + + from coco_pipe.decoding import Experiment, ExperimentConfig + from coco_pipe.decoding.configs import ( + ClassicalModelConfig, CVConfig, TuningConfig, + StatisticalAssessmentConfig, ChanceAssessmentConfig, + ) + + config = ExperimentConfig( + task="classification", + models={ + "lr": ClassicalModelConfig( + estimator="LogisticRegression", + params={"solver": "liblinear"}, + ) + }, + metrics=["balanced_accuracy", "roc_auc"], + cv=CVConfig( + strategy="stratified_group_kfold", + n_splits=5, + group_key="Subject", # subjects exclusive to test folds + ), + tuning=TuningConfig( + enabled=True, + scoring="balanced_accuracy", + # inner CV automatically group-based (Subject) to prevent leakage + ), + grids={"lr": {"C": [0.01, 0.1, 1.0, 10.0]}}, + use_scaler=True, + random_state=42, + ) + +--- + +3. Run with Metadata +===================== + +.. code-block:: python + + result = Experiment(config).run( + X, + y, + sample_metadata={ + "Subject": subject_ids, + "Session": session_ids, + "Site": site_ids, + }, + observation_level="epoch", + ) + +--- + +4. Site-Stratified Analysis +============================= + +.. code-block:: python + + # Bootstrap CI by site + ci_site_a = result.get_bootstrap_confidence_intervals( + metric="balanced_accuracy", unit="Subject", + n_bootstraps=2000, + ) + + # Predictions include Subject, Session, Site columns + preds = result.get_predictions() + by_site = preds.groupby("Site")["y_pred"].value_counts() + print(by_site) + +--- + +5. Hyperparameter Diagnostics +================================ + +.. code-block:: python + + # Best regularization parameter per fold + best_params = result.get_best_params() + print(best_params[["Fold", "C"]]) + + # Full grid search diagnostics + search_results = result.get_search_results() + print(search_results.sort_values("MeanTestScore", ascending=False).head()) diff --git a/docs/source/decoding/examples/model_comparison.rst b/docs/source/decoding/examples/model_comparison.rst new file mode 100644 index 0000000..63569e0 --- /dev/null +++ b/docs/source/decoding/examples/model_comparison.rst @@ -0,0 +1,104 @@ +.. _decoding-example-model-comparison: + +============================== +Example: Model Comparison +============================== + +This example demonstrates the correct workflow for comparing multiple decoding +models using paired permutation tests, with FDR correction for multiple +comparisons. + +**Scientific context**: Three candidate classifiers compete to decode cognitive +state from EEG features. We want to know: (1) which models perform above chance, +and (2) which model is significantly better than the others. + +--- + +1. Configure Multi-Model Experiment +====================================== + +.. code-block:: python + + import numpy as np + from coco_pipe.decoding import Experiment, ExperimentConfig + from coco_pipe.decoding.configs import ( + ClassicalModelConfig, CVConfig, + StatisticalAssessmentConfig, ChanceAssessmentConfig, + ) + + rng = np.random.default_rng(42) + n_subjects = 20 + X = rng.standard_normal((n_subjects * 10, 64)) + y = np.tile([0, 1], n_subjects * 5) + subject_ids = np.repeat(np.arange(n_subjects), 10) + + config = ExperimentConfig( + task="classification", + models={ + "lr": ClassicalModelConfig(estimator="LogisticRegression"), + "svm": ClassicalModelConfig(estimator="LinearSVC"), + "rf": ClassicalModelConfig(estimator="RandomForestClassifier"), + }, + metrics=["accuracy", "balanced_accuracy"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + evaluation=StatisticalAssessmentConfig( + enabled=True, + chance=ChanceAssessmentConfig( + method="permutation", + n_permutations=1000, + unit_of_inference="group_mean", + ), + ), + use_scaler=True, + random_state=42, + ) + +--- + +2. Run and Assess Significance +================================= + +.. code-block:: python + + result = Experiment(config).run( + X, y, + sample_metadata={"Subject": subject_ids}, + observation_level="epoch", + ) + + # Which models are above chance? + assessment = result.get_statistical_assessment() + print(assessment[["Model", "Metric", "Observed", "PValue", "Significant"]]) + +--- + +3. Pairwise Model Comparison +============================== + +.. code-block:: python + + # Paired comparison: LR vs SVM + lr_vs_svm = result.compare_models_paired("lr", "svm", metric="accuracy", unit="Subject") + print(lr_vs_svm[["Difference", "PValue", "Significant"]]) + + # All pairwise comparisons with FDR correction + all_pairs = result.compare_models( + metric="accuracy", + unit="Subject", + correction="fdr_bh", + n_permutations=5000, + ) + print(all_pairs[["ModelA", "ModelB", "Difference", "CorrectedPValue", "Significant"]]) + +--- + +4. Score Distribution Visualization +===================================== + +.. code-block:: python + + from coco_pipe.viz import plot_fold_score_dispersion + + # Visualize fold-level score distributions across models + fig = plot_fold_score_dispersion(result, metric="accuracy") + fig.savefig("score_dispersion.png", dpi=150) diff --git a/docs/source/decoding/examples/temporal_eeg.rst b/docs/source/decoding/examples/temporal_eeg.rst new file mode 100644 index 0000000..3d86225 --- /dev/null +++ b/docs/source/decoding/examples/temporal_eeg.rst @@ -0,0 +1,139 @@ +.. _decoding-example-temporal: + +========================= +Example: Temporal EEG Decoding +========================= + +This example demonstrates sliding and generalizing temporal decoding from a +3D EEG epoch array, including time-resolved scoring, statistical inference +with Max-Stat correction, and visualization of the generalization matrix. + +**Scientific context**: Decoding face vs. object from EEG signal amplitude at +each timepoint post-stimulus. The generalizing decoder tests whether the neural +representation learned at one time generalizes to other times. + +--- + +1. Prepare Data +================ + +.. code-block:: python + + import numpy as np + import mne # required for temporal estimators + + # Simulate: 200 epochs, 64 channels, 100 timepoints + n_epochs = 200 + n_channels = 64 + n_times = 100 + sfreq = 250.0 # Hz + tmin = -0.1 # seconds pre-stimulus + + rng = np.random.default_rng(42) + X = rng.standard_normal((n_epochs, n_channels, n_times)) + y = rng.choice([0, 1], size=n_epochs) + subject_ids = np.repeat(np.arange(20), n_epochs // 20) + times = np.linspace(tmin, tmin + (n_times - 1) / sfreq, n_times) + +--- + +2. Sliding Decoding +===================== + +.. code-block:: python + + from coco_pipe.decoding import Experiment, ExperimentConfig + from coco_pipe.decoding.configs import ( + ClassicalModelConfig, TemporalDecoderConfig, CVConfig, + StatisticalAssessmentConfig, ChanceAssessmentConfig, + ) + + sliding_config = ExperimentConfig( + task="classification", + models={ + "sliding_lr": TemporalDecoderConfig( + wrapper="sliding", + base=ClassicalModelConfig( + estimator="LogisticRegression", + params={"max_iter": 200}, + ), + n_jobs=-1, + ) + }, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + evaluation=StatisticalAssessmentConfig( + enabled=True, + chance=ChanceAssessmentConfig( + method="permutation", + n_permutations=500, + temporal_correction="max_stat", + unit_of_inference="group_mean", + ), + ), + random_state=42, + ) + + sliding_result = Experiment(sliding_config).run( + X, y, + sample_metadata={"Subject": subject_ids}, + time_axis=times, + observation_level="epoch", + ) + + # Score curve: one accuracy per timepoint + temporal = sliding_result.get_temporal_score_summary() + assessment = sliding_result.get_statistical_assessment() + # Significant timepoints after Max-Stat correction + sig = assessment[assessment["Significant"]] + print(f"Significant window: {sig['Time'].min():.3f}s — {sig['Time'].max():.3f}s") + +--- + +3. Generalization Matrix +========================== + +.. code-block:: python + + gen_config = ExperimentConfig( + task="classification", + models={ + "generalizing_lr": TemporalDecoderConfig( + wrapper="generalizing", + base=ClassicalModelConfig( + estimator="LogisticRegression", + params={"max_iter": 200}, + ), + n_jobs=-1, + ) + }, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + random_state=42, + ) + + gen_result = Experiment(gen_config).run( + X, y, + sample_metadata={"Subject": subject_ids}, + time_axis=times, + observation_level="epoch", + ) + + # 2D score matrix: shape (n_train_times, n_test_times) + matrix_df = gen_result.get_generalization_matrix("accuracy") + +--- + +4. Visualization +================= + +.. code-block:: python + + from coco_pipe.viz import plot_temporal_score_curve, plot_temporal_generalization_matrix + + # Sliding accuracy curve with significant timepoints highlighted + fig_curve = plot_temporal_score_curve(sliding_result, metric="accuracy") + + # Generalization matrix heatmap + fig_matrix = plot_temporal_generalization_matrix(gen_result, metric="accuracy") + fig_matrix.savefig("generalization_matrix.png", dpi=150) diff --git a/docs/source/decoding/experiment.rst b/docs/source/decoding/experiment.rst new file mode 100644 index 0000000..75ac9ab --- /dev/null +++ b/docs/source/decoding/experiment.rst @@ -0,0 +1,180 @@ +.. _decoding-experiment: + +=============================== +The ``Experiment`` Orchestrator +=============================== + +``coco_pipe.decoding.Experiment`` is the main entry point for all decoding +experiments. It validates configuration, orchestrates the outer CV loop, +and returns a fully populated ``ExperimentResult``. + +--- + +1. Initialization +================== + +.. code-block:: python + + from coco_pipe.decoding import Experiment, ExperimentConfig + from coco_pipe.decoding.configs import ClassicalModelConfig, CVConfig + + config = ExperimentConfig( + task="classification", + models={"lr": ClassicalModelConfig(estimator="LogisticRegression")}, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified", n_splits=5), + ) + + exp = Experiment(config) + +At construction time, ``Experiment.__init__`` immediately: + +1. Resolves all model specs from ``ESTIMATOR_SPECS``. +2. Validates task/metric/model compatibility (raises ``ValueError`` if any combination is invalid). +3. Propagates the master ``random_state`` to all sub-configs. + +--- + +2. Running an Experiment +========================== + +.. code-block:: python + + result = exp.run( + X, + y, + groups=None, # or np.ndarray of group labels + sample_ids=None, # or array of unique sample identifiers + sample_metadata=None, # or dict/DataFrame with Subject, Session, ... + feature_names=None, # or list of feature name strings + time_axis=None, # or np.ndarray of timepoints for 3D inputs + observation_level="epoch", # or "trial", "subject", etc. + inferential_unit=None, # auto-inferred from metadata + ) + +2.1 ``X`` and ``y`` +--------------------- + +- ``X``: 2D array ``(n_samples, n_features)`` for classical models, or 3D array + ``(n_samples, n_channels, n_times)`` for temporal estimators. +- ``y``: 1D array ``(n_samples,)`` of class labels (classification) or continuous + values (regression). + +2.2 ``sample_metadata`` +------------------------- + +A dict or DataFrame with columns for each metadata variable. **Must include +``Subject`` and ``Session``** (capitalized) when the outer CV uses a group key. +Additional columns (e.g., ``Site``, ``Age``) are stored in predictions and splits +for downstream analysis. + +.. code-block:: python + + sample_metadata = { + "Subject": subject_ids, # unique subject identifiers + "Session": session_ids, # recording session identifiers + "Site": site_ids, # optional acquisition site + } + +2.3 ``observation_level`` +--------------------------- + +A string label stored in ``result.meta["observation_level"]``. It describes what +each row of ``X`` represents (``"epoch"``, ``"trial"``, ``"subject"``, etc.). +This metadata does not affect fitting but documents the result for downstream +analysis and reporting. + +--- + +3. Per-Fold Pipeline +====================== + +For each outer CV fold, ``Experiment`` executes the following sequence: + +1. **Split**: divide ``X``, ``y``, and metadata into training and test partitions. +2. **Validate fold integrity**: check for degenerate folds (empty partitions, + single-class training sets for classification). +3. **Build pipeline**: create a ``sklearn.pipeline.Pipeline`` with steps: + ``scaler → feature_selector → model``. Each step is instantiated fresh for + this fold. +4. **Wrap with tuning**: if ``TuningConfig.enabled``, wrap the pipeline in + ``GridSearchCV`` or ``RandomizedSearchCV``. +5. **Fit**: call ``pipeline.fit(X_train, y_train)`` (with groups if required). +6. **Calibrate**: if ``CalibrationConfig.enabled``, wrap in + ``CalibratedClassifierCV`` and refit with calibration folds. +7. **Score**: compute all requested metrics on ``X_test``. +8. **Extract diagnostics**: feature importances, predictions, timing, warnings. + +--- + +4. Parallel Execution +======================= + +.. code-block:: python + + config = ExperimentConfig( + ..., + n_jobs=4, # number of parallel outer CV jobs + ) + + result = Experiment(config).run(X, y) + +``n_jobs`` controls the number of parallel outer-fold evaluations via ``joblib``. +For exact reproducibility, use ``n_jobs=1`` (see :ref:`decoding-reproducibility`). + +--- + +5. Save and Load +================== + +.. code-block:: python + + # Save result to JSON + path = result.save("results/my_experiment.json") + + # Load from JSON + from coco_pipe.decoding.result import ExperimentResult + loaded = ExperimentResult.load(path) + +The result is serialized as a self-contained JSON payload (schema version +``decoding_result_v1``), including the config, metadata, per-fold outputs, +and provenance information. + +--- + +6. Configuration Reference +============================ + +See :ref:`decoding-configs` for a full listing of all configuration classes. +The most important fields on ``ExperimentConfig``: + +.. list-table:: + :header-rows: 1 + :widths: 25 75 + + * - Field + - Description + * - ``task`` + - ``"classification"`` or ``"regression"``. + * - ``models`` + - Dict mapping model names to model configs. + * - ``metrics`` + - List of metric keys (validated against the task and model capabilities). + * - ``cv`` + - ``CVConfig`` controlling the outer cross-validation loop. + * - ``tuning`` + - ``TuningConfig`` for hyperparameter search. + * - ``feature_selection`` + - ``FeatureSelectionConfig`` for filter/wrapper feature selection. + * - ``calibration`` + - ``CalibrationConfig`` for probability calibration. + * - ``evaluation`` + - ``StatisticalAssessmentConfig`` for permutation/binomial testing. + * - ``use_scaler`` + - Whether to prepend a ``StandardScaler`` to the pipeline. + * - ``n_jobs`` + - Number of parallel outer CV jobs. + * - ``random_state`` + - Master seed for reproducibility. + * - ``tag`` + - Descriptive label stored in the result metadata. diff --git a/docs/source/decoding/feature_selection.rst b/docs/source/decoding/feature_selection.rst new file mode 100644 index 0000000..22092ed --- /dev/null +++ b/docs/source/decoding/feature_selection.rst @@ -0,0 +1,174 @@ +.. _decoding-feature-selection: + +============================ +Feature Selection +============================ + +``coco_pipe.decoding`` supports two feature selection strategies that execute +**inside** each outer CV fold on the training partition only, guaranteeing +zero test-set leakage. + +--- + +1. Filter Selection (``k_best``) +================================== + +``SelectKBest`` selects the top-``k`` features based on a univariate statistical +test. It has no inner CV loop. It is fast and well-suited for high-dimensional data +(e.g., many EEG channels/frequency bins) where a quick, interpretable feature +ranking is desired. + +.. code-block:: python + + from coco_pipe.decoding.configs import ( + ExperimentConfig, CVConfig, ClassicalModelConfig, FeatureSelectionConfig + ) + + config = ExperimentConfig( + task="classification", + models={"lr": ClassicalModelConfig(estimator="LogisticRegression")}, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + feature_selection=FeatureSelectionConfig( + enabled=True, + method="k_best", + n_features=20, + scoring="accuracy", # optional; defaults to task-appropriate test + ), + ) + +1.1 Score Function +------------------- + +For classification, the default univariate test is ``f_classif`` (ANOVA F-value). +For regression, it is ``f_regression``. Override via ``feature_selection.scoring``. + +1.2 Accessing Feature Scores +------------------------------ + +After fitting, retrieve per-fold and per-feature scores: + +.. code-block:: python + + feature_scores = result.get_feature_scores() + # columns: FeatureName, Fold, Score, PValue + + # Mean score across folds + mean_scores = feature_scores.groupby("FeatureName")["Score"].mean().sort_values(ascending=False) + +--- + +2. Sequential Feature Selection (``sfs``) +========================================== + +``SequentialFeatureSelector`` is a wrapper-based method. It iteratively adds +(forward SFS) or removes (backward SFS) features by evaluating the model's +cross-validated performance on each candidate feature set. Because it uses the +model's predictive performance as the selection criterion, it is more powerful +than filter methods but significantly more expensive. + +.. code-block:: python + + config = ExperimentConfig( + task="classification", + models={"lr": ClassicalModelConfig(estimator="LogisticRegression")}, + metrics=["balanced_accuracy"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + feature_selection=FeatureSelectionConfig( + enabled=True, + method="sfs", + n_features=10, + scoring="balanced_accuracy", # criterion for SFS inner evaluation + cv=CVConfig(strategy="stratified_group_kfold", n_splits=3, group_key="Subject"), + direction="forward", # or "backward" + ), + ) + +2.1 Inner CV for SFS +---------------------- + +SFS requires an inner CV loop to evaluate candidate feature sets. When omitted, +``coco_pipe.decoding`` derives the inner SFS CV from: + +1. ``tuning.cv`` if tuning is enabled. +2. The outer CV family (group-based if outer is group-based). + +When the outer CV is group-based, the SFS inner CV is automatically group-based. +Overriding requires ``allow_nongroup_inner_cv=True``. + +2.2 Group-Aware SFS +-------------------- + +``coco_pipe.decoding`` uses scikit-learn metadata routing to pass the +outer-fold training groups into the SFS inner CV. This requires +``scikit-learn >= 1.6``. + +2.3 SFS with Tuning +-------------------- + +SFS combined with hyperparameter tuning evaluates feature subsets inside the +tuning inner folds. ``coco_pipe.decoding`` uses a ``sklearn.pipeline.Pipeline`` +cache to avoid redundant refitting: + +.. code-block:: python + + config = ExperimentConfig( + ..., + feature_selection=FeatureSelectionConfig(enabled=True, method="sfs", n_features=10), + tuning=TuningConfig(enabled=True, scoring="accuracy"), + grids={"lr": {"C": [0.1, 1.0, 10.0]}}, + ) + +.. warning:: + + SFS + tuning is computationally intensive. Reduce the outer ``n_splits`` or + the SFS inner ``n_splits`` for development runs. + +--- + +3. Feature Stability Analysis +================================ + +For both ``k_best`` and ``sfs``, ``coco_pipe.decoding`` tracks which features +were selected in each fold. The stability score is the proportion of folds in +which a feature was selected: + +.. code-block:: python + + stability = result.get_feature_stability() + # columns: FeatureName, SelectionRate, MeanRank, StdRank + + # Most stable features + top = stability.sort_values("SelectionRate", ascending=False).head(20) + +.. note:: + + Feature stability across folds is a measure of **generalizability**, not + importance. A feature selected in all folds is a robust signal across the + sampled subjects, regardless of its average selection score. + +--- + +4. Selected Features per Fold +================================ + +.. code-block:: python + + selected = result.get_selected_features() + # columns: FeatureName, Fold, Rank + + # Features selected in every fold + universal = selected.groupby("FeatureName")["Fold"].count() + universal = universal[universal == config.cv.n_splits].index.tolist() + +--- + +5. Compatibility Notes +======================== + +- Feature selection is only valid for 2D tabular inputs (``input_kind in {"tabular_2d", "embedding_2d"}``). +- Feature selection is **incompatible** with temporal estimators (``SlidingEstimator``, ``GeneralizingEstimator``). + The registry blocks this at validation time. +- ``k_best`` does not support ranked importances beyond fold scores/p-values. + For importance-based selection, use tree ensemble importances via + ``result.get_feature_importances()``. diff --git a/docs/source/decoding/index.rst b/docs/source/decoding/index.rst new file mode 100644 index 0000000..09917e2 --- /dev/null +++ b/docs/source/decoding/index.rst @@ -0,0 +1,103 @@ +.. _decoding: + +=========================================== +Decoding Module — Scientific User Guide +=========================================== + +The ``coco_pipe.decoding`` module provides a rigorous, reproducible framework for +neural decoding — the statistical inference of cognitive, perceptual, or clinical +states from multivariate brain recordings. It is designed from first principles +around **zero data leakage**, **independence-aware inference**, and a +**declarative configuration** API. + +.. admonition:: Design Philosophy + + Every preprocessing step (scaling, feature selection, hyperparameter search, + calibration) executes **inside** each outer cross-validation fold on the + training partition only. This eliminates the most common source of inflated + decoding accuracy in neuroimaging pipelines. + +.. rubric:: Key Features + +- Outer CV with zero preprocessing leakage guaranteed by architecture. +- Registry-based estimator + metric compatibility contracts (blocked before training). +- Full-pipeline permutation testing with optional Max-Stat temporal correction. +- Sliding and generalizing temporal estimators (MNE meta-estimators) with tidy output. +- Foundation model integration (frozen backbone, fine-tune, LoRA, QLoRA). +- Group-aware cross-validation with automatic inner CV propagation. +- Comprehensive ``ExperimentResult`` API with 20+ diagnostic accessors. + +--- + +.. rubric:: Quickstart + +.. code-block:: python + + from coco_pipe.decoding import Experiment, ExperimentConfig + from coco_pipe.decoding.configs import ClassicalModelConfig, CVConfig + + config = ExperimentConfig( + task="classification", + models={ + "lr": ClassicalModelConfig( + estimator="LogisticRegression", + params={"max_iter": 200} + ) + }, + metrics=["accuracy", "roc_auc"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + ) + + result = Experiment(config).run( + X, y, + sample_metadata={"Subject": subject_ids, "Session": session_ids}, + observation_level="epoch", + ) + + print(result.get_detailed_scores()) + print(result.get_predictions().head()) + +--- + +.. toctree:: + :maxdepth: 2 + :caption: Core Concepts + + concepts + configs + experiment + result + +.. toctree:: + :maxdepth: 2 + :caption: Statistical Inference + + stats + cv_strategies + model_comparison + +.. toctree:: + :maxdepth: 2 + :caption: Model Reference + + models + metrics + feature_selection + temporal_decoding + +.. toctree:: + :maxdepth: 2 + :caption: Advanced Topics + + advanced/foundation_models + advanced/custom_estimators + advanced/reproducibility + +.. toctree:: + :maxdepth: 2 + :caption: Worked Examples + + examples/basic_classification + examples/grouped_cv + examples/temporal_eeg + examples/model_comparison diff --git a/docs/source/decoding/metrics.rst b/docs/source/decoding/metrics.rst new file mode 100644 index 0000000..6560d9b --- /dev/null +++ b/docs/source/decoding/metrics.rst @@ -0,0 +1,232 @@ +.. _decoding-metrics: + +================ +Metric Registry +================ + +All metrics are registered in ``coco_pipe.decoding._metrics.METRIC_REGISTRY``. +Metric/task compatibility is enforced at config validation time — before any +model is trained — preventing silent misuse of classification metrics for +regression tasks (or vice versa). + +--- + +1. Registry API +================ + +.. code-block:: python + + from coco_pipe.decoding._metrics import ( + get_metric_spec, + get_metric_names, + get_metric_families, + get_scorer, + METRIC_REGISTRY, + ) + + # Inspect a single metric + spec = get_metric_spec("accuracy") + print(spec.name) # "accuracy" + print(spec.task) # "classification" + print(spec.family) # "label" + print(spec.response_method) # "predict" + print(spec.greater_is_better) # True + + # List all classification metrics in the "threshold_sweep" family + names = get_metric_names(task="classification", family="threshold_sweep") + + # Get a callable scorer + scorer = get_scorer("f1") # sklearn-compatible callable + +Each ``MetricSpec`` contains: + +.. list-table:: + :header-rows: 1 + :widths: 20 15 65 + + * - Field + - Type + - Description + * - ``name`` + - ``str`` + - Unique key in the registry. + * - ``task`` + - ``str`` + - ``"classification"`` or ``"regression"``. + * - ``scorer`` + - ``Callable`` + - ``(y_true, y_pred) → float``. + * - ``response_method`` + - ``str`` + - ``"predict"`` | ``"proba"`` | ``"score"`` | ``"proba_or_score"``. + * - ``family`` + - ``str`` + - Grouping for reporting (see below). + * - ``greater_is_better`` + - ``bool`` + - Directionality for permutation p-values and Max-Stat correction. + +--- + +2. Classification Metrics +========================== + +2.1 Label Metrics (``family="label"``) +--------------------------------------- + +Require only ``predict`` output. Work with any classifier. + +.. list-table:: + :header-rows: 1 + :widths: 30 70 + + * - Metric + - Description + * - ``accuracy`` + - Fraction of correctly classified samples. Sensitive to class imbalance. + * - ``balanced_accuracy`` + - Mean recall per class. Recommended over ``accuracy`` for imbalanced data. + * - ``zero_one_loss`` + - Fraction misclassified. ``1 - accuracy``. ``greater_is_better=False``. + * - ``hamming_loss`` + - Per-label Hamming loss (fraction of labels incorrectly predicted). + +2.2 Confusion-Derived Metrics (``family="confusion"``) +-------------------------------------------------------- + +Derived from the confusion matrix. Require only ``predict``. + +.. list-table:: + :header-rows: 1 + :widths: 30 70 + + * - Metric + - Description + * - ``f1`` + - Binary F1 score (harmonic mean of precision and recall). + * - ``f1_macro`` + - Unweighted macro-average F1 across classes. + * - ``f1_micro`` + - Global precision/recall pooled across classes. + * - ``precision`` + - Positive predictive value. TP / (TP + FP). + * - ``recall`` + - Sensitivity / true positive rate. TP / (TP + FN). + * - ``sensitivity`` + - Synonym for recall. Binary only; raises ``ValueError`` for multiclass. + * - ``specificity`` + - True negative rate. TN / (TN + FP). Binary only. + * - ``jaccard`` + - Intersection-over-union for binary labels. + * - ``matthews_corrcoef`` + - Matthews correlation coefficient. Balanced for all class distributions. + * - ``cohen_kappa`` + - Agreement corrected for chance. Range [-1, 1]. + +2.3 Threshold-Sweep Metrics (``family="threshold_sweep"``) +------------------------------------------------------------ + +Require probability or decision scores. Use ``predict_proba`` when available, +``decision_function`` as fallback for binary classifiers. + +.. list-table:: + :header-rows: 1 + :widths: 30 70 + + * - Metric + - Description + * - ``roc_auc`` + - Area under the ROC curve (binary OvR). Insensitive to class threshold. + * - ``roc_auc_ovr_weighted`` + - Macro-weighted one-vs-rest AUC for multiclass. + * - ``average_precision`` + - Area under the PR curve using sklearn's interpolated AP (binary). + * - ``pr_auc`` + - Trapezoidal AUC of the precision-recall curve. Preferred over AP when + positive fraction is small. + +2.4 Probability-Score Metrics (``family="score_probability"``) +--------------------------------------------------------------- + +Require ``predict_proba``. Enable calibration diagnostics. + +.. list-table:: + :header-rows: 1 + :widths: 30 70 + + * - Metric + - Description + * - ``log_loss`` + - Cross-entropy loss. Lower is better (``greater_is_better=False``). + * - ``brier_score`` + - Mean squared error of probability predictions. Lower is better. + +--- + +3. Regression Metrics (``family="regression"``) +================================================= + +Require only ``predict`` output. + +.. list-table:: + :header-rows: 1 + :widths: 30 70 + + * - Metric + - Description + * - ``r2`` + - Coefficient of determination. 1.0 is perfect fit; can be negative. + * - ``neg_mean_squared_error`` + - Negative MSE. Negated so higher = better for optimization consistency. + * - ``neg_mean_absolute_error`` + - Negative MAE. More robust than MSE to outliers. + * - ``neg_root_mean_squared_error`` + - Negative RMSE. Same units as the target variable. + * - ``explained_variance`` + - Proportion of variance explained. Similar to R² but not penalized for bias. + +--- + +4. Compatibility Rules +======================== + +The registry enforces three compatibility checks at ``ExperimentConfig`` +validation time: + +1. **Task mismatch**: A metric's ``task`` must match ``ExperimentConfig.task``. +2. **Proba requirement**: If ``response_method == "proba"``, the model must + declare ``predict_proba`` **or** calibration must be enabled. +3. **Score requirement**: If ``response_method == "proba_or_score"``, the model + must declare ``predict_proba`` **or** ``decision_function``. + +These checks fire before any model is trained, producing a clear ``ValueError`` +with the specific metric and model name. + +--- + +5. Custom Metrics +================== + +You can extend the registry for project-specific metrics: + +.. code-block:: python + + from coco_pipe.decoding._metrics import METRIC_REGISTRY, MetricSpec + from sklearn.metrics import top_k_accuracy_score + from functools import partial + + top2 = partial(top_k_accuracy_score, k=2, labels=[0, 1, 2]) + METRIC_REGISTRY["top2_accuracy"] = MetricSpec( + name="top2_accuracy", + task="classification", + scorer=top2, + response_method="proba", + family="label", + greater_is_better=True, + ) + +.. warning:: + + Custom metrics are added to the in-process registry only. They are not + persisted in saved ``ExperimentResult`` payloads and must be re-registered + in any new Python process that loads existing results. diff --git a/docs/source/decoding/model_comparison.rst b/docs/source/decoding/model_comparison.rst new file mode 100644 index 0000000..d60c80b --- /dev/null +++ b/docs/source/decoding/model_comparison.rst @@ -0,0 +1,181 @@ +.. _decoding-model-comparison: + +==================== +Model Comparison +==================== + +After running a decoding experiment with multiple models, ``coco_pipe.decoding`` +provides rigorous paired statistical tests to determine whether observed +performance differences are beyond chance. All comparison methods use +within-subject label swaps to control for subject-specific baseline variance. + +--- + +1. Why Paired Tests? +===================== + +Independent-sample tests compare two models assuming the samples are drawn +independently. In a within-subject decoding design, the **same subjects** appear +in both models' test folds, making the samples positively correlated. A paired +test exploits this correlation to achieve higher statistical power. + +**Paired permutation test**: randomly swap model assignments within each +independent unit (subject) and recompute the observed difference. The resulting +null distribution represents the expected difference under no true effect. + +--- + +2. Quick Paired Comparison +============================ + +For a fast paired comparison using existing outer-fold predictions: + +.. code-block:: python + + from coco_pipe.decoding import Experiment, ExperimentConfig + from coco_pipe.decoding.configs import ClassicalModelConfig, CVConfig, SVMConfig + + config = ExperimentConfig( + task="classification", + models={ + "lr": ClassicalModelConfig(estimator="LogisticRegression"), + "svm": ClassicalModelConfig(estimator="SVC"), + }, + metrics=["accuracy", "roc_auc"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + ) + + result = Experiment(config).run( + X, y, + sample_metadata={"Subject": subject_ids, "Session": session_ids} + ) + + paired = result.compare_models_paired( + "lr", "svm", + metric="accuracy", + unit="Subject", + n_permutations=5000, + random_state=42, + ) + + print(paired[["Metric", "ScoreA", "ScoreB", "Difference", "PValue", "Significant"]]) + +The returned DataFrame has one row per (temporal coordinate or scalar) with: + +.. list-table:: + :header-rows: 1 + :widths: 25 75 + + * - Column + - Description + * - ``ScoreA`` + - Observed score for model A. + * - ``ScoreB`` + - Observed score for model B. + * - ``Difference`` + - ``ScoreA - ScoreB``. + * - ``PValue`` + - Two-sided p-value from the sign-swap permutation distribution. + * - ``Significant`` + - Boolean: ``PValue <= 0.05``. + * - ``NUnits`` + - Number of independent units used for swapping. + * - ``NPermutations`` + - Number of permutations used. + +--- + +3. Multiple Model Comparison +============================== + +When comparing more than two models, use ``compare_models`` to compare all +pairs with optional multiple-comparison correction: + +.. code-block:: python + + comparison = result.compare_models( + metric="accuracy", + unit="Subject", + correction="fdr_bh", # or "bonferroni", "none" + n_permutations=5000, + ) + + print(comparison[["ModelA", "ModelB", "Difference", "PValue", "CorrectedPValue"]]) + +--- + +4. Full-Pipeline Paired Permutation Test +========================================== + +For rigorous inference where preprocessing, feature selection, and tuning must +be included in the null distribution, use ``run_paired_permutation_assessment``: + +.. code-block:: python + + from coco_pipe.decoding.stats import run_paired_permutation_assessment + from coco_pipe.decoding.configs import StatisticalAssessmentConfig, ChanceAssessmentConfig + + # Run two separate experiments with the same CV folds + result_a = Experiment(config_a).run(X, y, sample_metadata=meta) + result_b = Experiment(config_b).run(X, y, sample_metadata=meta) + + eval_config = StatisticalAssessmentConfig( + chance=ChanceAssessmentConfig( + n_permutations=1000, + temporal_correction="max_stat", # for temporal outputs + ), + unit_of_inference="sample", + random_state=42, + ) + + paired_df = run_paired_permutation_assessment( + result_a, result_b, "model_name", "accuracy", eval_config + ) + +.. note:: + + The two experiments must have been run with the **same outer CV configuration** + and the **same subjects**. The function aligns predictions at the ``SampleID`` + level before computing the difference. + +--- + +5. Interpreting Results +======================== + +.. rubric:: Effect Size + +The ``Difference`` column is the primary effect size. A small but significant +difference is not necessarily scientifically meaningful. Always report both the +magnitude and statistical significance. + +.. rubric:: Temporal Generalization Comparison + +For generalizing decoders, one comparison row is produced per +``(TrainTime, TestTime)`` cell. Apply temporal correction (``max_stat`` or +``fdr_bh``) to control the family-wise error rate across the matrix. + +.. rubric:: Multiple Model Pitfall + +If you run ``K`` pairwise comparisons without correction, the expected number of +false positives is ``0.05 × K``. Always apply correction when comparing more than +two models. + +--- + +6. Post-Hoc vs Full-Pipeline Comparison +========================================= + +.. list-table:: + :header-rows: 1 + :widths: 30 35 35 + + * - Method + - Speed + - Validity + * - ``compare_models_paired`` + - Fast (uses existing predictions) + - Valid if preprocessing did not use the comparison metric during fitting. + * - ``run_paired_permutation_assessment`` + - Slow (reruns full CV per permutation) + - Fully valid; recommended for publications. diff --git a/docs/source/decoding/models.rst b/docs/source/decoding/models.rst new file mode 100644 index 0000000..9115855 --- /dev/null +++ b/docs/source/decoding/models.rst @@ -0,0 +1,104 @@ +.. _decoding-models: + +================================ +Model Registry and Capabilities +================================ + +All estimators available in ``coco_pipe.decoding`` are registered in +``ESTIMATOR_SPECS`` (``coco_pipe.decoding._specs``). The registry is the +**single source of truth** for estimator class lookup, task support, input kind, +prediction interface, temporal compatibility, importance extraction, and default +hyperparameter search spaces. + +--- + +1. Registry API +================ + +.. code-block:: python + + from coco_pipe.decoding import ( + get_estimator_spec, + list_estimator_specs, + resolve_estimator_capabilities, + ) + + # Inspect a specific estimator + spec = get_estimator_spec("LogisticRegression") + print(spec.name) # "LogisticRegression" + print(spec.family) # "classical" + print(spec.task) # ["classification"] + print(spec.input_kinds) # ["tabular_2d", "embedding_2d"] + print(spec.supports_calibration) # True + print(spec.supports_feature_selection) # True + + # List all estimators + all_specs = list_estimator_specs() + + # Get capability metadata for model selection + caps = resolve_estimator_capabilities("SVC") + print(caps.has_response("predict_proba")) # False (LinearSVC) + +--- + +2. Auto-Generated Capability Table +===================================== + +The following table is generated automatically at documentation build time +from ``ESTIMATOR_SPECS`` in ``coco_pipe.decoding._specs``. It reflects the +exact state of the registry — no manual maintenance required. + +.. rubric:: Classification Estimators + +.. capability-table:: + :task: classification + +.. rubric:: Regression Estimators + +.. capability-table:: + :task: regression + +.. rubric:: All Estimators (including temporal and foundation models) + +.. capability-table:: + :task: all + :show-search-space: + + +The registry blocks incompatible combinations at configuration time: + +.. list-table:: + :header-rows: 1 + :widths: 50 50 + + * - Blocked combination + - Error message + * - ``log_loss`` with ``LinearSVC`` (no calibration) + - ``"Metric requires probabilities, but model doesn't provide them"`` + * - ``roc_auc`` with a model that lacks both ``predict_proba`` and ``decision_function`` + - ``"Metric requires probabilities or decision scores, but model provides neither"`` + * - Feature selection on 3D temporal input + - ``"Feature selection is only valid for classical 2D tabular inputs"`` + * - Regression model with classification task + - ``"Model does not support task 'classification'"`` + +These checks prevent wasted compute time and silent errors in downstream statistics. + +--- + +6. Default Hyperparameter Search Spaces +========================================= + +Each ``EstimatorSpec`` includes a ``default_search_space`` dict for +``GridSearchCV``/``RandomizedSearchCV``. These are reasonable starting points: + +.. code-block:: python + + spec = get_estimator_spec("LogisticRegression") + print(spec.default_search_space) + # {"C": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]} + +When a ``TuningConfig.grids`` key is provided, it overrides the default. The raw +grid keys are automatically prefixed with ``clf__`` to match the pipeline step +name (e.g., ``{"C": [...]}`` becomes ``{"clf__C": [...]}``) unless the key +already contains ``__``. diff --git a/docs/source/decoding/result.rst b/docs/source/decoding/result.rst new file mode 100644 index 0000000..cc7e85a --- /dev/null +++ b/docs/source/decoding/result.rst @@ -0,0 +1,218 @@ +.. _decoding-result: + +============================== +``ExperimentResult`` API +============================== + +``ExperimentResult`` is the structured container returned by ``Experiment.run()``. +It provides 20+ accessor methods for tidy-data inspection, diagnostic reporting, +and statistical inference — all without rerunning the experiment. + +--- + +1. Structure +============= + +.. code-block:: python + + result.raw # per-model dict of fold outputs + result.meta # environment provenance, task, model names, capabilities + result.config # original ExperimentConfig + +--- + +2. Prediction Accessors +========================= + +.. code-block:: python + + # All out-of-fold predictions in tidy long form + preds = result.get_predictions() + # columns: Model, Fold, SampleIndex, SampleID, Group, y_true, y_pred + # + y_proba_0, y_proba_1, ... (if probabilities available) + # + Subject, Session, Site (from sample_metadata) + # + Time (sliding) or TrainTime, TestTime (generalizing) + +--- + +3. Score Accessors +=================== + +.. code-block:: python + + # Per-fold, per-metric scores + scores = result.get_detailed_scores() + # columns: Model, Fold, Metric, Score, Time (if temporal) + + # Fold-level split information + splits = result.get_splits(with_metadata=True) + + # Fit/predict/score timing and convergence warnings + fit_diag = result.get_fit_diagnostics() + +--- + +4. Curve Diagnostics +===================== + +.. code-block:: python + + # ROC curves (binary or one-vs-rest multiclass) + roc = result.get_roc_curve() + # columns: Model, Fold, Class, FPR, TPR, Threshold, AUC + + # Precision-recall curves + pr = result.get_pr_curve() + # columns: Model, Fold, Class, Precision, Recall, Threshold + + # Calibration (reliability) curves + cal = result.get_calibration_curve() + + # Probability quality summary (log-loss + Brier per fold) + prob_diag = result.get_probability_diagnostics() + + # Summary statistics for ROC AUC + roc_summary = result.get_roc_auc_summary() + + # Summary statistics for PR AUC + pr_summary = result.get_pr_auc_summary() + +--- + +5. Confusion Matrices +====================== + +.. code-block:: python + + # Per-fold confusion matrices in long form + cm = result.get_confusion_matrices(normalize=True) + # columns: Model, Fold, TrueLabel, PredLabel, Count + + # Pooled (over folds) confusion matrix + pooled_cm = result.get_pooled_confusion_matrix(normalize="true") + +--- + +6. Temporal Accessors +====================== + +.. code-block:: python + + # Score summary per timepoint (sliding only) + temporal = result.get_temporal_score_summary() + # columns: Model, Metric, Time, MeanScore, StdScore + + # Generalization matrix: shape (n_train_times, n_test_times) + matrix = result.get_generalization_matrix("accuracy") + # or long form: + matrix_long = result.get_generalization_matrix("accuracy", long=True) + +--- + +7. Statistical Inference +========================= + +.. code-block:: python + + # Full-pipeline or lightweight permutation/binomial assessment + assessment = result.get_statistical_assessment() + + # Lightweight (fixed-prediction, fast, biased) + assessment_fast = result.get_statistical_assessment(lightweight=True, metric="accuracy") + + # Bootstrap CI over independent units + ci = result.get_bootstrap_confidence_intervals( + metric="accuracy", + unit="Subject", + n_bootstraps=2000, + ci=0.95, + ) + + # Null distribution (if stored via store_null_distribution=True) + nulls = result.get_statistical_nulls() + +--- + +8. Model Comparison +==================== + +.. code-block:: python + + # Paired permutation test between two models (in-result) + paired = result.compare_models_paired("lr", "svm", metric="accuracy", unit="Subject") + + # All pairwise comparisons with correction + all_pairs = result.compare_models(metric="accuracy", correction="fdr_bh") + +--- + +9. Feature Importances +======================= + +.. code-block:: python + + # Mean ± std feature importance across folds + importances = result.get_feature_importances() + # columns: FeatureName, MeanImportance, StdImportance + + # Per-fold importances + fold_imp = result.get_feature_importances(fold_level=True) + + # Ranked importances (descending by mean) + ranked = result.get_feature_importances(rank=True) + +--- + +10. Feature Selection Accessors +================================= + +.. code-block:: python + + # Selected features per fold + selected = result.get_selected_features(ordered=True) + + # Feature stability: selection rate across folds + stability = result.get_feature_stability() + + # Per-fold univariate feature scores (k_best only) + scores = result.get_feature_scores(with_pvalues=True) + +--- + +11. Hyperparameter Tuning +=========================== + +.. code-block:: python + + # Best hyperparameters per fold + best = result.get_best_params() + + # Full grid search results + grid = result.get_search_results() + +--- + +12. Model Artifact Metadata +============================= + +.. code-block:: python + + # Neural model training history, checkpoints, etc. + artifacts = result.get_model_artifacts() + +--- + +13. Serialization +================== + +.. code-block:: python + + # Serialize to JSON-compatible payload + payload = result.to_payload() + + # Save to file + path = result.save("results/my_result.json") + + # Load from file + from coco_pipe.decoding.result import ExperimentResult + loaded = ExperimentResult.load("results/my_result.json") diff --git a/docs/source/decoding/stats.rst b/docs/source/decoding/stats.rst new file mode 100644 index 0000000..495a583 --- /dev/null +++ b/docs/source/decoding/stats.rst @@ -0,0 +1,273 @@ +.. _decoding-stats: + +================================== +Statistical Assessment Guide +================================== + +``coco_pipe.decoding`` cleanly separates **descriptive** CV performance from +**inferential** claims. Descriptive metrics (fold scores, confusion matrices, +curves) are always computed. Inferential statistics are opt-in and require +explicit configuration of ``StatisticalAssessmentConfig``. + +--- + +1. Two Levels of Assessment +============================= + +1.1 Descriptive Performance +----------------------------- + +Every ``ExperimentResult`` provides fold-level and summary scores without +any statistical testing: + +.. code-block:: python + + scores = result.get_detailed_scores() + print(scores[["Model", "Fold", "Metric", "Score"]]) + + # Per-model summary: mean ± std across folds + summary = scores.groupby(["Model", "Metric"])["Score"].agg(["mean", "std"]) + +This is the correct starting point for all decoding reports. Always report +fold-level variability alongside the mean. + +1.2 Finite-Sample Inferential Assessment +------------------------------------------ + +Statistical significance claims require a null distribution. ``coco_pipe.decoding`` +supports two null-generation strategies: + +.. list-table:: + :header-rows: 1 + :widths: 25 35 40 + + * - Method + - How the null is generated + - When to use + * - ``"permutation"`` + - Full outer CV rerun under label permutations. + - Gold standard. Correct for any preprocessing pipeline. + * - ``"binomial"`` + - Analytical Clopper-Pearson interval on hard accuracy. + - Only valid for scalar accuracy, one prediction per independent unit. + +--- + +2. Full-Pipeline Permutation Assessment +========================================= + +.. code-block:: python + + from coco_pipe.decoding.configs import ( + ExperimentConfig, CVConfig, ClassicalModelConfig, + StatisticalAssessmentConfig, ChanceAssessmentConfig, + ) + + config = ExperimentConfig( + task="classification", + models={"lr": ClassicalModelConfig(estimator="LogisticRegression")}, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + evaluation=StatisticalAssessmentConfig( + enabled=True, + chance=ChanceAssessmentConfig( + method="permutation", + n_permutations=1000, + unit_of_inference="group_mean", + ), + ), + ) + + result = Experiment(config).run( + X, y, + sample_metadata={"Subject": subject_ids, "Session": session_ids}, + observation_level="epoch", + ) + + assessment = result.get_statistical_assessment() + +The returned DataFrame contains, per model and metric: + +.. list-table:: + :header-rows: 1 + :widths: 25 75 + + * - Column + - Description + * - ``Observed`` + - Observed score on the true labels. + * - ``PValue`` + - Empirical p-value from permutation distribution. + * - ``CorrectedPValue`` + - Multiple-comparison corrected p-value. + * - ``Significant`` + - Boolean: ``CorrectedPValue <= alpha``. + * - ``CILower``, ``CIUpper`` + - Bootstrap CI for the observed score. + * - ``NullMedian``, ``NullLower``, ``NullUpper`` + - Null distribution percentiles. + * - ``NPermutations`` + - Number of permutations used. + * - ``NEff`` + - Effective sample size (number of independent units). + * - ``Time`` / ``TrainTime`` / ``TestTime`` + - Present only for temporal outputs. + +2.1 Label Permutation Inside Groups +------------------------------------- + +When ``unit_of_inference="group_mean"`` and ``cv.group_key`` is set, labels are +permuted **across groups**, not within them. This preserves within-subject epoch +correlations in the null distribution, yielding a correctly-calibrated p-value +for the group-level null hypothesis. + +.. note:: + + Permuting only within subjects (swapping epochs within a subject) would be + the wrong null for testing whether the model performs above chance at the + **population** level. + +--- + +3. Binomial Assessment +======================== + +Binomial testing uses the Clopper-Pearson exact interval. It is valid only when: + +- Task is classification. +- Metric is plain ``accuracy``. +- Each independent unit contributes **exactly one** prediction (no aggregation needed). +- An explicit chance level ``p0`` is provided. + +.. code-block:: python + + evaluation=StatisticalAssessmentConfig( + enabled=True, + chance=ChanceAssessmentConfig( + method="binomial", + p0=0.5, # chance level for binary classification + ), + confidence_intervals=ConfidenceIntervalConfig( + method="clopper_pearson", # or "wilson" + alpha=0.05, + ), + ) + +The test statistic is: + +.. math:: + + p = 1 - F(k - 1; n, p_0) + +where :math:`F` is the binomial CDF, :math:`k` is the number of correct +predictions, and :math:`n` is the number of independent observations. + +--- + +4. Bootstrap Confidence Intervals +=================================== + +Confidence intervals for any metric can be computed independently of the +permutation test, using non-parametric bootstrap over independent units: + +.. code-block:: python + + ci = result.get_bootstrap_confidence_intervals( + metric="accuracy", + unit="Subject", # or "Session", "sample", etc. + n_bootstraps=2000, + ci=0.95, + ) + +Bootstrap CI is also automatically included in the permutation assessment output +(``CILower``, ``CIUpper`` columns). + +--- + +5. Temporal Correction Methods +================================ + +For sliding/generalizing decoders, one p-value per timepoint must be corrected +for multiple comparisons. Set ``temporal_correction`` in ``ChanceAssessmentConfig``: + +.. list-table:: + :header-rows: 1 + :widths: 20 80 + + * - Method + - Description + * - ``"max_stat"`` + - Permutation Max-Stat (default). FWER control. Uses the global maximum + of the permutation null at each timepoint. Recommended for temporal data + with moderate-to-high positive correlation between timepoints. + * - ``"fdr_bh"`` + - Benjamini-Hochberg FDR. Controls the expected proportion of false + discoveries. More powerful than Max-Stat but weaker guarantees. + * - ``"fdr_by"`` + - Benjamini-Yekutieli FDR. Valid under arbitrary dependence. More + conservative than BH. + * - ``"none"`` + - No correction. For exploratory analysis only. + +--- + +6. Lightweight Post-Hoc Diagnostics +====================================== + +For quick exploratory inspection without rerunning training: + +.. code-block:: python + + # Lightweight label permutation over fixed predictions (fast but biased) + null = result.get_statistical_assessment(lightweight=True, metric="accuracy") + + # Direct post-hoc permutation (bypasses full retraining) + from coco_pipe.decoding.stats import assess_post_hoc_permutation + posthoc = assess_post_hoc_permutation(result.raw["lr"], metric="accuracy", n_permutations=500) + +.. warning:: + + Post-hoc permutations that shuffle labels over **fixed** predictions do not + account for preprocessing, feature selection, or hyperparameter search. They + underestimate the null and can produce overly optimistic p-values if any of + these steps used the labels. Use ``method="permutation"`` (full-pipeline) for + any claim of statistical significance in publications. + +--- + +7. Paired Model Comparison +============================ + +See :ref:`decoding-model-comparison` for a full guide. Quick reference: + +.. code-block:: python + + # Paired permutation test: does model A outperform model B? + paired = result.compare_models_paired("lr", "svm", metric="accuracy") + print(paired[["Difference", "PValue", "Significant"]]) + + # Full-pipeline paired assessment across two result objects + from coco_pipe.decoding.stats import run_paired_permutation_assessment + df = run_paired_permutation_assessment( + result_a, result_b, "lr", "accuracy", config=eval_config + ) + +--- + +8. Unit of Inference Options +============================== + +.. list-table:: + :header-rows: 1 + :widths: 25 75 + + * - Value + - Aggregation behavior + * - ``"sample"`` + - No aggregation. Each prediction row is treated as independent. + * - ``"group_mean"`` + - Average probabilities per group, then classify. Recommended for epoch-level EEG. + * - ``"group_majority"`` + - Majority vote of hard labels per group. + * - ``"custom"`` + - Aggregate by a named column in ``sample_metadata``. diff --git a/docs/source/decoding/temporal_decoding.rst b/docs/source/decoding/temporal_decoding.rst new file mode 100644 index 0000000..d5e17ad --- /dev/null +++ b/docs/source/decoding/temporal_decoding.rst @@ -0,0 +1,227 @@ +.. _decoding-temporal: + +============================ +Temporal Decoding +============================ + +Temporal decoding applies a classifier or regressor independently at each +timepoint (or pair of timepoints) of an EEG/MEG epoch. The input array has +shape ``(n_samples, n_channels, n_times)`` — sometimes called a 3D or epochs +array. MNE meta-estimators (``SlidingEstimator``, ``GeneralizingEstimator``) +orchestrate the per-timepoint fitting inside the ``coco_pipe.decoding`` outer +CV loop. + +--- + +1. Data Format Requirements +============================= + +Temporal decoding expects 3D arrays: + +- **Axis 0**: samples (trials / epochs). +- **Axis 1**: channels or features. +- **Axis 2**: timepoints. + +.. code-block:: python + + X.shape # (n_epochs, n_channels, n_times) + y.shape # (n_epochs,) + +Pass the physical time axis (in seconds) for meaningful temporal labels in outputs: + +.. code-block:: python + + time_axis = epochs.times # NumPy array, shape (n_times,) + +When omitted, integer timepoint indices are used. + +--- + +2. Sliding Estimator +====================== + +A ``SlidingEstimator`` fits one independent model per timepoint. It is +equivalent to looping over the time axis, extracting each timepoint's +channel-space snapshot, and training a model on it. + +.. code-block:: python + + from coco_pipe.decoding import Experiment, ExperimentConfig + from coco_pipe.decoding.configs import ( + ClassicalModelConfig, TemporalDecoderConfig, CVConfig + ) + + config = ExperimentConfig( + task="classification", + models={ + "sliding_lr": TemporalDecoderConfig( + wrapper="sliding", + base=ClassicalModelConfig( + estimator="LogisticRegression", + params={"max_iter": 200}, + ), + scoring="accuracy", + n_jobs=-1, + ) + }, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + ) + + result = Experiment(config).run( + X_epochs, # shape (n_epochs, n_channels, n_times) + y, + sample_metadata={"Subject": subject_ids, "Session": session_ids}, + time_axis=epochs.times, + ) + +2.1 Outputs +----------- + +.. code-block:: python + + # Score curve: one score per timepoint per fold + scores = result.get_detailed_scores() + # columns: Model, Fold, Metric, Score, Time + + # Summary over folds: mean ± std per timepoint + temporal = result.get_temporal_score_summary() + + # Long-form predictions with Time column + preds = result.get_predictions() + +2.2 Plotting +------------ + +.. code-block:: python + + from coco_pipe.viz import plot_temporal_score_curve + + fig = plot_temporal_score_curve(result, metric="accuracy") + fig.savefig("sliding_accuracy.png") + +--- + +3. Generalizing Estimator (Temporal Generalization) +====================================================== + +A ``GeneralizingEstimator`` fits one model per training timepoint and evaluates +it at **every** test timepoint. The result is a +``(n_train_times, n_test_times)`` matrix of scores. + +.. code-block:: python + + config = ExperimentConfig( + task="classification", + models={ + "generalizing_lr": TemporalDecoderConfig( + wrapper="generalizing", + base=ClassicalModelConfig( + estimator="LogisticRegression", + params={"max_iter": 200}, + ), + n_jobs=-1, + ) + }, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"), + ) + + result = Experiment(config).run( + X_epochs, y, + sample_metadata={"Subject": subject_ids, "Session": session_ids}, + time_axis=epochs.times, + ) + +3.1 Outputs +----------- + +.. code-block:: python + + # Long-form predictions with TrainTime and TestTime columns + preds = result.get_predictions() + + # Score matrix: mean over folds, shape (n_train_times, n_test_times) + matrix = result.get_generalization_matrix("accuracy") + # or long form: matrix = result.get_generalization_matrix("accuracy", long=True) + +3.2 Scientific Interpretation +------------------------------ + +- **Main diagonal** (``TrainTime == TestTime``): equivalent to the sliding + decoder result. +- **Off-diagonal generalization**: a classifier trained at time :math:`t_1` + tested at :math:`t_2`. High off-diagonal scores indicate a **sustained neural + representation** whose format is preserved across the generalizing time window. +- **Asymmetric generalization**: training at late times generalizes to early + times but not vice versa — suggests temporal ordering of information flow. + +3.3 Plotting +------------ + +.. code-block:: python + + from coco_pipe.viz import plot_temporal_generalization_matrix + + fig = plot_temporal_generalization_matrix(result, metric="accuracy") + fig.savefig("generalization_matrix.png") + +--- + +4. Statistical Assessment for Temporal Outputs +================================================ + +Each timepoint (or train-time/test-time cell) produces an independent score. +Multiple comparison correction is required. + +.. code-block:: python + + from coco_pipe.decoding.configs import StatisticalAssessmentConfig, ChanceAssessmentConfig + + config = ExperimentConfig( + ..., + evaluation=StatisticalAssessmentConfig( + enabled=True, + chance=ChanceAssessmentConfig( + method="permutation", + n_permutations=1000, + temporal_correction="max_stat", # FWER control over timepoints + unit_of_inference="group_mean", + ), + ), + ) + + result = Experiment(config).run(X_epochs, y, ...) + assessment = result.get_statistical_assessment() + # One row per (Model, Metric, Time) or (Model, Metric, TrainTime, TestTime) + +The ``CorrectedPValue`` and ``Significant`` columns reflect the chosen +temporal correction. See :ref:`decoding-stats` for correction method details. + +--- + +5. Feature Importances in Temporal Models +========================================== + +Feature importances from temporal models are averaged across timepoints and +folds. Each timepoint contributes one importance vector (of length ``n_channels``): + +.. code-block:: python + + importances = result.get_feature_importances() + # columns: FeatureName, MeanImportance, StdImportance, Time + +Temporal importance patterns can reveal which channels drive decoding at each +timepoint — a form of spatiotemporal source localization. + +--- + +6. Compatibility Notes +======================= + +- Feature selection (``k_best``, ``sfs``) is **not compatible** with 3D temporal + inputs. The registry blocks this combination at validation time. +- Standard scalers are not applied to 3D inputs (they expect 2D arrays). Use + channel-wise normalization within the ``base`` estimator's pipeline if needed. +- ``n_jobs`` inside ``TemporalDecoderConfig`` controls parallelism for the + per-timepoint model fitting, separate from the outer CV ``n_jobs``. diff --git a/docs/source/index.rst b/docs/source/index.rst index f290d87..d324997 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -6,8 +6,9 @@ Welcome to coco-pipe's documentation! :caption: Contents: README.md + api_reference.md vision.md dim_reduction.md + decoding/index auto_examples/index.rst - autoapi/index.rst GitHub Repository diff --git a/docs/source/sg_execution_times.rst b/docs/source/sg_execution_times.rst index 6e59185..f2ade99 100644 --- a/docs/source/sg_execution_times.rst +++ b/docs/source/sg_execution_times.rst @@ -6,7 +6,7 @@ Computation times ================= -**00:38.939** total execution time for 10 files **from all galleries**: +**00:23.339** total execution time for 11 files **from all galleries**: .. container:: @@ -33,32 +33,35 @@ Computation times - Time - Mem (MB) * - :ref:`sphx_glr_auto_examples_plot_quality_metrics.py` (``../../examples/plot_quality_metrics.py``) - - 00:28.831 + - 00:12.160 - 0.0 * - :ref:`sphx_glr_auto_examples_benchmark_dim_reduction.py` (``../../examples/benchmark_dim_reduction.py``) - - 00:04.345 + - 00:03.607 - 0.0 - * - :ref:`sphx_glr_auto_examples_plot_velocity_embedding.py` (``../../examples/plot_velocity_embedding.py``) - - 00:03.669 + * - :ref:`sphx_glr_auto_examples_descriptors_example.py` (``../../examples/descriptors_example.py``) + - 00:03.553 - 0.0 - * - :ref:`sphx_glr_auto_examples_plot_scientific_dimred.py` (``../../examples/plot_scientific_dimred.py``) - - 00:01.047 + * - :ref:`sphx_glr_auto_examples_plot_velocity_embedding.py` (``../../examples/plot_velocity_embedding.py``) + - 00:03.213 - 0.0 * - :ref:`sphx_glr_auto_examples_demo_report.py` (``../../examples/demo_report.py``) - - 00:00.395 + - 00:00.505 - 0.0 * - :ref:`sphx_glr_auto_examples_compare_phate_umap.py` (``../../examples/compare_phate_umap.py``) - - 00:00.337 + - 00:00.184 - 0.0 - * - :ref:`sphx_glr_auto_examples_demo_pipeline.py` (``../../examples/demo_pipeline.py``) - - 00:00.296 + * - :ref:`sphx_glr_auto_examples_plot_scientific_dimred.py` (``../../examples/plot_scientific_dimred.py``) + - 00:00.084 - 0.0 * - :ref:`sphx_glr_auto_examples_demo_iterative_balancing.py` (``../../examples/demo_iterative_balancing.py``) - - 00:00.011 + - 00:00.013 - 0.0 - * - :ref:`sphx_glr_auto_examples_demo_structures.py` (``../../examples/demo_structures.py``) - - 00:00.006 + * - :ref:`sphx_glr_auto_examples_demo_pipeline.py` (``../../examples/demo_pipeline.py``) + - 00:00.013 - 0.0 * - :ref:`sphx_glr_auto_examples_compare_dim_reduction.py` (``../../examples/compare_dim_reduction.py``) - - 00:00.002 + - 00:00.005 + - 0.0 + * - :ref:`sphx_glr_auto_examples_demo_structures.py` (``../../examples/demo_structures.py``) + - 00:00.003 - 0.0 diff --git a/pyproject.toml b/pyproject.toml index 24f32e9..da1e592 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,10 +18,10 @@ classifiers = [ dependencies = [ "numpy>=2.0.0", "pandas", - "scikit-learn", + "scikit-learn>1.6", "matplotlib", "seaborn", - "scipy", + "scipy>=1.11.0", "pyyaml", "pydantic", "hydra-core", @@ -88,6 +88,35 @@ topology = [ "skorch", "gudhi", ] +torch = [ + "torch", +] +braindecode = [ + "braindecode", +] +hf = [ + "transformers", + "huggingface-hub", +] +peft = [ + "peft", + "accelerate", +] +quant = [ + "bitsandbytes", +] +moabb = [ + "moabb", + "mne", +] +torcheeg = [ + "torcheeg", +] +dev-gpu = [ + "pytest>=7.0", + "torch", + "accelerate", +] spatiotemporal = [ "pydmd", "meegkit", @@ -149,9 +178,6 @@ Homepage = "https://github.com/BabaSanfour/coco-pipe" Documentation = "https://cocopipe.readthedocs.io/" Repository = "https://github.com/BabaSanfour/coco-pipe.git" -[project.scripts] -run-coco-pipe = "scripts.run_ml:main" - [tool.setuptools.packages.find] where = ["."] include = ["coco_pipe*", "scripts*"] @@ -161,7 +187,7 @@ include = ["coco_pipe*", "scripts*"] [tool.pytest.ini_options] minversion = "6.0" -addopts = "-ra -q --ignore-glob=tests/test_ml_*.py" +addopts = "-ra -q" testpaths = [ "tests", ] @@ -169,6 +195,9 @@ testpaths = [ [tool.ruff] line-length = 88 target-version = "py310" +extend-exclude = [ + "coco_pipe/decoding/fm_hub/cbramod_src", +] [tool.ruff.lint] select = [ diff --git a/scripts/plot_feature_analysis.py b/scripts/plot_feature_analysis.py deleted file mode 100644 index 62ad4b2..0000000 --- a/scripts/plot_feature_analysis.py +++ /dev/null @@ -1,340 +0,0 @@ -#!/usr/bin/env python3 -""" -Plot feature-level results (analysis unit = feature) from an aggregated coco_pipe run. - -For each computed feature (one analysis per feature using all sensors as inputs), -this script: - - Plots a topomap of sensor importances (one topomap per computed feature). - - Aggregates the per-feature accuracy (or chosen metric) into a bar plot. - -Inputs ------- -- Aggregated pickle from scripts/run_ml.py (results/.pkl) -- Sensor coordinates via either: - * --coords CSV/TSV/JSON file with columns name,x,y (case-insensitive), or - * --use-mne and --montage to generate from an MNE standard montage - -Outputs -------- -- One topomap image per computed feature. -- A bar plot summarizing per-feature metric (e.g., accuracy). - -Example -------- -python scripts/plot_feature_analysis.py \ - --results results/toy_ml_config.pkl \ - --use-mne --montage standard_1020 \ - --model "Random Forest" \ - --metric accuracy \ - --out-dir results/feature_plots -""" - -import argparse -import json -import os -import re -from typing import Dict, Mapping, Optional, Sequence - -import matplotlib.pyplot as plt -import pandas as pd - -from coco_pipe.viz import plot_bar, plot_topomap - - -def load_coords(path: str) -> pd.DataFrame: - lower = path.lower() - if lower.endswith((".json", ".jsn")): - with open(path, "r") as f: - raw = json.load(f) - rows = [] - for name, val in raw.items(): - if isinstance(val, (list, tuple)) and len(val) >= 2: - x, y = float(val[0]), float(val[1]) - elif isinstance(val, dict) and "x" in val and "y" in val: - x, y = float(val["x"]), float(val["y"]) - else: - raise ValueError(f"Invalid JSON coord entry for {name}: {val}") - rows.append((name, x, y)) - return pd.DataFrame(rows, columns=["name", "x", "y"]).set_index("name") - - df = pd.read_csv(path, sep=None, engine="python") - cols = {c.lower(): c for c in df.columns} - name_col = next( - (cols[c] for c in ("name", "sensor", "channel", "id") if c in cols), None - ) - if name_col is None: - name_col = df.columns[0] - x_col = next((cols[c] for c in ("x", "xs", "xpos", "x_coord") if c in cols), None) - y_col = next((cols[c] for c in ("y", "ys", "ypos", "y_coord") if c in cols), None) - if x_col is None or y_col is None: - if df.shape[1] >= 3: - x_col = x_col or df.columns[1] - y_col = y_col or df.columns[2] - else: - raise ValueError("Coordinates file must include numeric x,y columns.") - cdf = df[[name_col, x_col, y_col]].copy() - cdf.columns = ["name", "x", "y"] - cdf = cdf.set_index("name") - cdf["x"] = pd.to_numeric(cdf["x"], errors="coerce") - cdf["y"] = pd.to_numeric(cdf["y"], errors="coerce") - return cdf.dropna() - - -def generate_coords_from_mne( - montage: str = "standard_1020", restrict_to: Optional[Sequence[str]] = None -) -> pd.DataFrame: - try: - import mne # type: ignore - except Exception as e: - raise ImportError( - "This feature requires 'mne'. Install it via 'pip install mne'." - ) from e - std_montage = mne.channels.make_standard_montage(montage) - pos = std_montage.get_positions() - ch_pos = pos.get("ch_pos", {}) - rows = [] - names = list(restrict_to) if restrict_to else list(ch_pos.keys()) - for name in names: - key = name - if key not in ch_pos: - if name.upper() in ch_pos: - key = name.upper() - elif name.capitalize() in ch_pos: - key = name.capitalize() - else: - continue - xyz = ch_pos[key] - rows.append((name, float(xyz[0]), float(xyz[1]))) - if not rows: - rows = [(n, float(v[0]), float(v[1])) for n, v in ch_pos.items()] - return pd.DataFrame(rows, columns=["name", "x", "y"]).set_index("name") - - -def pick_model( - results_per_model: Dict[str, dict], preferred: Optional[Sequence[str]] = None -) -> str: - preferred = preferred or ("Logistic Regression",) - for m in preferred: - if m in results_per_model: - return m - return next(iter(results_per_model)) - - -def sensor_from_column( - col: str, sensors: Sequence[str], sep: str, reverse: bool -) -> Optional[str]: - sensors_set = set(sensors) - if sep in col: - left, right = col.split(sep, 1) - cand = right if reverse else left - if cand in sensors_set: - return cand - up = cand.upper() - cap = cand.capitalize() - if up in sensors_set: - return up - if cap in sensors_set: - return cap - # fallback: look for a sensor token within the column name - for s in sensors: - if re.search(rf"\b{re.escape(s)}\b", col, flags=re.IGNORECASE): - return s - return None - - -def feature_from_columns(cols: Sequence[str], sep: str, reverse: bool) -> Optional[str]: - tokens = [] - for c in cols: - if sep not in c: - continue - left, right = c.split(sep, 1) - tokens.append(left if reverse else right) - if not tokens: - return None - # return most common token - return pd.Series(tokens).value_counts().idxmax() - - -def main(): - parser = argparse.ArgumentParser( - description="Plot per-feature topomaps of sensor importances and bar " - "plot of per-feature metric." - ) - parser.add_argument( - "--results", - required=True, - help="Path to aggregated results pickle (from run_ml.py)", - ) - parser.add_argument( - "--coords", required=False, help="Path to sensor coordinates (CSV/TSV/JSON)" - ) - parser.add_argument( - "--use-mne", - action="store_true", - help="Generate sensor coordinates from MNE standard montage", - ) - parser.add_argument( - "--montage", - default="standard_1020", - help="MNE montage name (default: standard_1020)", - ) - parser.add_argument( - "--model", - default=None, - help="Model name to use (default: try 'Logistic Regression' else first)", - ) - parser.add_argument( - "--metric", - default="accuracy", - help="Metric to plot per feature (default: accuracy)", - ) - parser.add_argument( - "--sep", - default="_", - help="Separator between unit and feature in column names (default: _)", - ) - parser.add_argument( - "--reverse", - action="store_true", - help="If set, interpret columns as ", - ) - parser.add_argument( - "--out-dir", - default="results/feature_analysis_plots", - help="Directory to save plots", - ) - parser.add_argument( - "--label-map", - default=None, - help="Optional JSON mapping from feature name to display label", - ) - parser.add_argument( - "--no-show", - action="store_true", - help="Do not open interactive windows; save only", - ) - - args = parser.parse_args() - - if args.no_show: - import matplotlib - - matplotlib.use("Agg") - - if not os.path.exists(args.results): - raise FileNotFoundError(args.results) - - os.makedirs(args.out_dir, exist_ok=True) - - all_results: Dict[str, Dict[str, dict]] = pd.read_pickle(args.results) - - # Attempt to infer sensor names from columns later; for MNE restriction, - # we can pass None - if args.use_mne or not args.coords: - coords_df = generate_coords_from_mne(args.montage) - else: - if not os.path.exists(args.coords): - raise FileNotFoundError(args.coords) - coords_df = load_coords(args.coords) - - sensors = coords_df.index.tolist() - - # Optional feature label mapping - feature_label_map: Mapping[str, str] = {} - if args.label_map: - with open(args.label_map, "r") as f: - feature_label_map = json.load(f) - - # Collect per-feature metric values and produce topomaps - feature_metric: Dict[str, float] = {} - - for analysis_id, results_per_model in all_results.items(): - model_name = pick_model( - results_per_model, preferred=(args.model,) if args.model else None - ) - res = results_per_model[model_name] - - # metric - metrics = res.get("metric_scores", {}) - metric_name = ( - args.metric - if args.metric in metrics - else (next(iter(metrics)) if metrics else None) - ) - if metric_name is None: - continue - - # importances - fi = res.get("feature_importances", {}) - if not fi: - continue - - col_names = list(fi.keys()) - # Derive computed feature name from columns - feat_name = feature_from_columns( - col_names, sep=args.sep, reverse=args.reverse - ) or str(analysis_id) - - # sensor importances: prefer weighted_mean, else mean - sensor_imp: Dict[str, float] = {} - for col, stats in fi.items(): - sname = sensor_from_column(col, sensors, args.sep, args.reverse) - if not sname: - continue - val = stats.get("weighted_mean") - if val is None: - val = stats.get("mean") - if val is None: - continue - sensor_imp[sname] = float(val) - - if not sensor_imp: - continue - - # Plot topomap for this computed feature - disp_name = feature_label_map.get(feat_name, feat_name) - fig, ax = plot_topomap( - sensor_imp, - coords_df[["x", "y"]], - title=f"{disp_name} – Sensor Importances ({model_name})", - cbar_label="Importance", - sensors="markers", - cmap="magma", - ) - out_path = os.path.join( - args.out_dir, f"topomap_{re.sub(r'[^A-Za-z0-9_.-]+', '_', disp_name)}.png" - ) - fig.savefig(out_path, dpi=150) - plt.close(fig) - - # Store metric - mean_val = ( - float(metrics[metric_name]["mean"]) - if isinstance(metrics[metric_name], dict) - else float(metrics[metric_name]) - ) - feature_metric[disp_name] = mean_val - - if not feature_metric: - raise RuntimeError( - "No per-feature metrics collected; check results structure and options." - ) - - # Bar plot of per-feature metric - # Sort descending for visibility - s = pd.Series(feature_metric).sort_values(ascending=False) - fig2, ax2 = plot_bar( - s, - orientation="horizontal", - title=f"Per-Feature {args.metric.capitalize()} ({args.model or 'auto'})", - xlabel=args.metric.capitalize(), - cmap="viridis", - ) - fig2.savefig(os.path.join(args.out_dir, f"features_{args.metric}_bar.png"), dpi=150) - if not args.no_show: - plt.show() - plt.close(fig2) - - -if __name__ == "__main__": - main() diff --git a/scripts/plot_lasso_importances.py b/scripts/plot_lasso_importances.py deleted file mode 100644 index ccef67a..0000000 --- a/scripts/plot_lasso_importances.py +++ /dev/null @@ -1,387 +0,0 @@ -#!/usr/bin/env python3 -""" -Plot top-N feature importances for a Lasso-regularized Logistic Regression run, -and annotate accuracy and the number of zeroed features on the plot. - -Inputs ------- -- Aggregated pickle from scripts/run_ml.py (results/.pkl) - -Behavior --------- -- Select an analysis (by --analysis-id or first found) and a model - (default 'Logistic Regression'). -- Extract per-feature importances from results['feature_importances']. -- Rank by absolute importance by default (configurable) and plot top-N as a - bar chart. -- Count zeroed features (all fold importances ~ 0 within tolerance) and include in title - along with mean accuracy. - -Example -------- -python scripts/plot_lasso_importances.py \ - --results results/toy_ml_config.pkl \ - --analysis-id classification_baseline \ - --model "Logistic Regression" \ - --metric accuracy \ - --top-n 20 \ - --abs \ - --save results/lasso_importances.png -""" - -import argparse -import json -import os -import re -from typing import Dict, Mapping, Optional, Sequence - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from coco_pipe.io import load_data -from coco_pipe.viz import plot_bar, plot_scatter2d - - -def pick_analysis( - all_results: Dict[str, Dict[str, dict]], analysis_id: Optional[str] -) -> str: - if analysis_id: - if analysis_id not in all_results: - raise KeyError(f"Analysis id '{analysis_id}' not found in results.") - return analysis_id - # fallback: first key - return next(iter(all_results)) - - -def pick_model( - results_per_model: Dict[str, dict], preferred: Optional[Sequence[str]] = None -) -> str: - preferred = preferred or ("Logistic Regression",) - for m in preferred: - if m in results_per_model: - return m - return next(iter(results_per_model)) - - -def _greekify(text: str) -> str: - rep = {"alpha": "α", "beta": "β", "gamma": "γ", "theta": "θ", "delta": "δ"} - for k, v in rep.items(): - text = re.sub(rf"\b{k}\b", v, text, flags=re.IGNORECASE) - return text - - -def make_label_map_keep_sensor(cols: Sequence[str]) -> Dict[str, str]: - """Build compact labels per column, preserving the sensor name. - - Expected column format: 'feature-.spaces-'. - Keeps '' and abbreviates ''. - """ - abbrev = { - "BandRatiosFromAverageFooof": "BR Fooof", - "BandRatiosFromAverageSpectrum": "BR Spec", - "RelativeBandPowerFromAverageFooof": "RBP Fooof", - "RelativeBandPowerFromAverageSpectrum": "RBP Spec", - "higuchiFd": "Higuchi FD", - "katzFd": "Katz FD", - "petrosianFd": "Petrosian FD", - "hjorthComplexity": "Hjorth Complexity", - "hjorthMobility": "Hjorth Mobility", - "numZerocross": "ZeroCross", - "svdEntropy": "SVD Entropy", - "spectralEntropy": "Spectral Entropy", - "sampleEntropy": "Sample Entropy", - "permEntropy": "Perm Entropy", - "entropyMultiscale": "MSE", - "fooofExponent": "FOOOF Exp", - "foofOffset": "FOOOF Off", - "fooofOffset": "FOOOF Off", - } - - out: Dict[str, str] = {} - for raw in cols: - s = raw - if s.startswith("feature-"): - s = s[len("feature-") :] - # Extract sensor from trailing '.spaces-' - sensor = None - m_sensor = re.search(r"\.spaces-([A-Za-z0-9]+)$", s) - if m_sensor: - sensor = m_sensor.group(1) - s = s[: m_sensor.start()] # strip the sensor suffix - # Abbreviate feature part - m_pair = re.search(r"bands_pairs-\((.+)\)", s) - if m_pair: - pair = m_pair.group(1).replace("'", "").replace(" ", "").replace(",", "/") - pair = _greekify(pair) - head = s[: m_pair.start()].rstrip(".") - for k, v in abbrev.items(): - head = head.replace(k, v) - feat_label = f"{head} {pair}".strip() - else: - feat_label = s - for k, v in abbrev.items(): - feat_label = feat_label.replace(k, v) - feat_label = _greekify(feat_label) - feat_label = feat_label.replace("MeanEpochs", "") - feat_label = re.sub(r"[_.-]+$", "", feat_label).strip() - - out[raw] = f"{sensor or ''} — {feat_label}".strip(" —") - return out - - -def main(): - parser = argparse.ArgumentParser( - description="Plot L1-LR feature importances (top-N) with zeroed " - "count and accuracy." - ) - parser.add_argument( - "--results", - required=True, - help="Path to aggregated results pickle (from run_ml.py)", - ) - parser.add_argument( - "--analysis-id", - default=None, - help="Analysis id key from the aggregated results dict", - ) - parser.add_argument( - "--model", - default=None, - help="Model name (default: try 'Logistic Regression' else first)", - ) - parser.add_argument( - "--metric", - default="accuracy", - help="Metric name to show in title (default: accuracy)", - ) - parser.add_argument( - "--top-n", type=int, default=20, help="Top-N features to plot (default: 20)" - ) - parser.add_argument( - "--abs", - dest="use_abs", - action="store_true", - help="Rank by absolute importance (default)", - ) - parser.add_argument( - "--no-abs", - dest="use_abs", - action="store_false", - help="Rank by signed importance", - ) - parser.set_defaults(use_abs=True) - parser.add_argument( - "--zero-threshold", - type=float, - default=1e-12, - help="Threshold to treat coefficients as zero across folds", - ) - parser.add_argument( - "--label-map", - default=None, - help="Optional JSON mapping from feature name to display label", - ) - parser.add_argument( - "--xlabel", - default=None, - help="X-axis label (default: 'Coefficient magnitude' " - "if --abs else 'Coefficient')", - ) - parser.add_argument( - "--save", default=None, help="Path to save the figure (optional)" - ) - # For scatter plots - parser.add_argument( - "--data", - required=False, - help="Path to original dataset (CSV/TSV/Excel) for scatter plots", - ) - parser.add_argument( - "--target", required=False, help="Target column name in the dataset" - ) - parser.add_argument( - "--sep", - dest="csv_sep", - default=None, - help="CSV separator if needed (auto by extension otherwise)", - ) - parser.add_argument( - "--sheet", - dest="sheet_name", - default=None, - help="Excel sheet name if applicable", - ) - parser.add_argument( - "--scatter-dir", default=None, help="Directory to save scatter plots (optional)" - ) - parser.add_argument( - "--no-show", - action="store_true", - help="Do not open interactive windows; save only", - ) - - args = parser.parse_args() - - if args.no_show: - import matplotlib - - matplotlib.use("Agg") - - if not os.path.exists(args.results): - raise FileNotFoundError(args.results) - - all_results: Dict[str, Dict[str, dict]] = pd.read_pickle(args.results) - aid = pick_analysis(all_results, args.analysis_id) - res_per_model = all_results[aid] - model_name = pick_model( - res_per_model, preferred=(args.model,) if args.model else None - ) - res = res_per_model[model_name] - - # Accuracy (or desired metric) - metrics = res.get("metric_scores", {}) - metric_name = ( - args.metric - if args.metric in metrics - else (next(iter(metrics)) if metrics else None) - ) - acc_mean = ( - float(metrics[metric_name]["mean"]) - if metric_name and isinstance(metrics[metric_name], dict) - else None - ) - - # Feature importances - fi = res.get("feature_importances", {}) - if not fi: - raise RuntimeError( - "No feature_importances found in results for the selected model." - ) - - # Build series of importance (weighted_mean if present else mean) - values = {} - zeros = 0 - for fname, stats in fi.items(): - imp = stats.get("weighted_mean", stats.get("mean", 0.0)) - # zero detection using fold_importances - folds = np.asarray(stats.get("fold_importances", []), dtype=float) - if folds.size > 0 and np.all(np.abs(folds) <= args.zero_threshold): - zeros += 1 - values[fname] = float(imp) - - s = pd.Series(values) - rank_vals = s.abs() if args.use_abs else s - # Take top-N indices by ranking - top_idx = rank_vals.sort_values(ascending=False).head(args.top_n).index - s_top = s.loc[top_idx] - rank_top = (s_top.abs() if args.use_abs else s_top).sort_values(ascending=False) - s_top = s.loc[rank_top.index] - - # Optional label mapping (JSON) merged with auto-compact labels - auto_map = make_label_map_keep_sensor(s_top.index.tolist()) - label_map: Mapping[str, str] = dict(auto_map) - if args.label_map: - with open(args.label_map, "r") as f: - label_map.update(json.load(f)) - - xlabel = args.xlabel - if xlabel is None: - xlabel = "Coefficient magnitude" if args.use_abs else "Coefficient" - - title_parts = [f"{model_name}"] - if acc_mean is not None: - title_parts.append(f"{metric_name.capitalize()}: {acc_mean:.3f}") - title_parts.append(f"Zeroed features: {zeros}") - title = " — ".join(title_parts) - - fig, ax = plot_bar( - s_top, - labels=s_top.index.tolist(), - label_map=label_map, - top_n=None, # already trimmed - ascending=False, - orientation="vertical", - title=title, - xlabel=xlabel, - cmap="magma", - figsize=(10, 4), - ) - - if args.save: - fig.savefig(args.save, dpi=150, bbox_inches="tight") - if not args.no_show: - plt.show() - plt.close(fig) - - # Scatter plots for top importances if data is available - if args.data and args.target: - os.makedirs( - args.scatter_dir or os.path.dirname(args.save or "") or ".", exist_ok=True - ) - df = load_data( - "tabular", args.data, sheet_name=args.sheet_name, sep=args.csv_sep - ) - if not isinstance(df, pd.DataFrame): - raise RuntimeError("Expected a DataFrame from data loader.") - - if args.target not in df.columns: - raise KeyError(f"Target column '{args.target}' not found in dataset") - - # Determine top positive and negative features by signed mean coefficients - s_signed = pd.Series({k: v.get("mean", 0.0) for k, v in fi.items()}) - # Filter to columns present in df - s_signed = s_signed[[c for c in s_signed.index if c in df.columns]] - pos_feats = s_signed.sort_values(ascending=False).head(2).index.tolist() - neg_feats = s_signed.sort_values(ascending=True).head(2).index.tolist() - # Best positive and best negative - best_pos = pos_feats[0] if pos_feats else None - best_neg = neg_feats[0] if neg_feats else None - - y = df[args.target] - - # Helper to plot a pair if valid - def _scatter_pair(fx: Optional[str], fy: Optional[str], name: str): - if not fx or not fy: - return - if fx not in df.columns or fy not in df.columns: - return - out_path = None - if args.scatter_dir: - out_path = os.path.join(args.scatter_dir, f"scatter_{name}.png") - title = ( - f"{model_name} – {name} (top features)\n" - f"{metric_name.capitalize()}: {acc_mean:.3f}" - if acc_mean is not None - else f"{model_name} – {name} (top features)" - ) - fig_s, ax_s = plot_scatter2d( - df[fx].values, - df[fy].values, - labels=y.values, - title=title, - xlabel=fx, - ylabel=fy, - save=out_path, - ) - if not args.no_show and not args.save: - plt.show() - plt.close(fig_s) - - # (i) top two positive - if len(pos_feats) >= 2: - _scatter_pair(pos_feats[0], pos_feats[1], "top2_positive") - # (ii) top two negative - if len(neg_feats) >= 2: - _scatter_pair(neg_feats[0], neg_feats[1], "top2_negative") - # (iii) top positive and top negative - if best_pos and best_neg: - _scatter_pair(best_pos, best_neg, "top_pos_neg") - else: - # If not provided, inform how to enable scatter - if not args.no_show: - print("Skipping scatter plots: provide --data and --target to enable.") - - -if __name__ == "__main__": - main() diff --git a/scripts/plot_sensor_analysis.py b/scripts/plot_sensor_analysis.py deleted file mode 100644 index 3ce68de..0000000 --- a/scripts/plot_sensor_analysis.py +++ /dev/null @@ -1,366 +0,0 @@ -#!/usr/bin/env python3 -""" -Plot sensor-level results from an aggregated coco_pipe run. - -This script expects that you ran scripts/run_ml.py and produced an aggregated -pickle at results/.pkl containing a dict: - { analysis_id -> { model_name -> result_dict } } - -Assumptions for this visualization: - - The analysis unit is sensor (1 model per sensor using all features). - - Each analysis_id encodes the sensor name (e.g., "..._Fz_...", "sensor-F3", etc.). - - You provide a coordinates file with 2D positions for each sensor. - -Outputs: - - Topomap of sensor accuracies (or chosen metric) across sensors. - - Bar plot of feature importances (with error bars if available) for the best sensor. - -Usage example: - python scripts/plot_sensor_analysis.py \ - --results results/toy_ml_config.pkl \ - --coords coords/sensor_coords.csv \ - --model "Logistic Regression" \ - --metric accuracy \ - --save-topo results/topomap.png \ - --save-bar results/best_sensor_features.png -""" - -import argparse -import json -import os -from typing import Dict, Optional, Sequence - -import matplotlib.pyplot as plt -import pandas as pd - -from coco_pipe.viz import plot_bar, plot_topomap - - -def load_coords(path: str) -> pd.DataFrame: - """Load sensor coordinates from CSV/TSV/TXT/JSON into a DataFrame with index=name - and columns ['x','y'].""" - lower = path.lower() - if lower.endswith((".json", ".jsn")): - with open(path, "r") as f: - data = json.load(f) - # Expect {name: [x,y]} or {name: {"x": x, "y": y}} - rows = [] - for name, val in data.items(): - if isinstance(val, (list, tuple)) and len(val) >= 2: - x, y = float(val[0]), float(val[1]) - elif isinstance(val, dict) and "x" in val and "y" in val: - x, y = float(val["x"]), float(val["y"]) - else: - raise ValueError(f"Invalid JSON coord entry for {name}: {val}") - rows.append((name, x, y)) - df = pd.DataFrame(rows, columns=["name", "x", "y"]).set_index("name") - return df - - # CSV/TSV/TXT: detect separator - df = pd.read_csv(path, sep=None, engine="python") - cols = {c.lower(): c for c in df.columns} - - # Find name column - name_col = None - for cand in ("name", "sensor", "channel", "id"): - if cand in cols: - name_col = cols[cand] - break - if name_col is None: - # if exactly three columns, assume first is name - if df.shape[1] >= 3: - name_col = df.columns[0] - else: - raise ValueError( - "Coordinates file must include a sensor name column (e.g., 'name')." - ) - - # Find x/y columns - x_col = None - y_col = None - for cand in ("x", "xs", "xpos", "x_coord"): - if cand in cols: - x_col = cols[cand] - break - for cand in ("y", "ys", "ypos", "y_coord"): - if cand in cols: - y_col = cols[cand] - break - if x_col is None or y_col is None: - # Try second/third columns - if df.shape[1] >= 3: - x_col = x_col or df.columns[1] - y_col = y_col or df.columns[2] - else: - raise ValueError("Coordinates file must include numeric x,y columns.") - - cdf = df[[name_col, x_col, y_col]].copy() - cdf.columns = ["name", "x", "y"] - cdf = cdf.set_index("name") - # coerce to float - cdf["x"] = pd.to_numeric(cdf["x"], errors="coerce") - cdf["y"] = pd.to_numeric(cdf["y"], errors="coerce") - return cdf.dropna() - - -def generate_coords_from_mne( - montage: str = "standard_1020", restrict_to: Optional[Sequence[str]] = None -) -> pd.DataFrame: - """Generate 2D sensor coordinates from an MNE template montage. - - Parameters - ---------- - montage : str - Name of the standard montage to use (e.g., 'standard_1020', - 'standard_1005', 'biosemi64'). - restrict_to : list of str, optional - If provided, only return coordinates for these channel names. - - Returns - ------- - DataFrame - Index = sensor names, columns ['x','y'] with 2D coordinates derived - from montage. - - Notes - ----- - - This function requires the optional dependency 'mne'. If not installed, - an ImportError is raised with guidance. - - We use the x,y components from the montage's 3D positions as a top-view - projection. - """ - try: - import mne # type: ignore - except Exception as e: - raise ImportError( - "This feature requires 'mne'. Install it via 'pip install mne'." - ) from e - - std_montage = mne.channels.make_standard_montage(montage) - pos = std_montage.get_positions() - ch_pos = pos.get("ch_pos", {}) - rows = [] - if restrict_to is None: - names = list(ch_pos.keys()) - else: - names = list(restrict_to) - for name in names: - if name not in ch_pos: - # Try relaxed matching (upper/capitalize) - alt = None - if name.upper() in ch_pos: - alt = name.upper() - elif name.capitalize() in ch_pos: - alt = name.capitalize() - if alt is None: - continue - xyz = ch_pos[alt] - rows.append((name, float(xyz[0]), float(xyz[1]))) - else: - xyz = ch_pos[name] - rows.append((name, float(xyz[0]), float(xyz[1]))) - if not rows: - # Fall back to all channels in montage - rows = [(n, float(v[0]), float(v[1])) for n, v in ch_pos.items()] - df = pd.DataFrame(rows, columns=["name", "x", "y"]).set_index("name") - return df - - -def pick_model( - results_per_model: Dict[str, dict], preferred: Optional[Sequence[str]] = None -) -> str: - preferred = preferred or ("Logistic Regression",) - for m in preferred: - if m in results_per_model: - return m - return next(iter(results_per_model)) - - -def analysis_to_sensor(analysis_id: str, known_sensors: Sequence[str]) -> Optional[str]: - toks = analysis_id.replace("-", "_").replace(" ", "_").split("_") - known = set(known_sensors) - for t in toks: - if t in known: - return t - if t.upper() in known: - return t.upper() - c = t.capitalize() - if c in known: - return c - return None - - -def main(): - parser = argparse.ArgumentParser( - description="Plot sensor accuracies topomap and best sensor feature " - "importances." - ) - parser.add_argument( - "--results", - required=True, - help="Path to aggregated results pickle (from run_ml.py)", - ) - parser.add_argument( - "--coords", required=False, help="Path to sensor coordinates (CSV/TSV/JSON)" - ) - parser.add_argument( - "--use-mne", - action="store_true", - help="Generate sensor coordinates from MNE standard montage", - ) - parser.add_argument( - "--montage", - default="standard_1020", - help="MNE montage name (default: standard_1020)", - ) - parser.add_argument( - "--model", - default=None, - help="Model name to use (default: try 'Logistic Regression' else first)", - ) - parser.add_argument( - "--metric", - default="accuracy", - help="Metric to plot for sensors (default: accuracy)", - ) - parser.add_argument( - "--top-n", type=int, default=20, help="Top-N features in bar plot (default: 20)" - ) - parser.add_argument( - "--save-topo", default=None, help="Path to save topomap image (optional)" - ) - parser.add_argument( - "--save-bar", default=None, help="Path to save barplot image (optional)" - ) - parser.add_argument( - "--no-show", - action="store_true", - help="Do not open interactive windows; save only", - ) - - args = parser.parse_args() - - # Non-interactive backend if --no-show and saving - if args.no_show: - import matplotlib - - matplotlib.use("Agg") - - if not os.path.exists(args.results): - raise FileNotFoundError(args.results) - all_results: Dict[str, Dict[str, dict]] = pd.read_pickle(args.results) - - # Determine sensor names from results (by parsing IDs) to optionally restrict - # MNE coords. Quick heuristic: collect all tokens from analysis ids that - # look like EEG names (letters+digits+optional z) - candidate_tokens = set() - for aid in all_results.keys(): - toks = aid.replace("-", "_").replace(" ", "_").split("_") - for t in toks: - if ( - len(t) <= 5 - and any(c.isalpha() for c in t) - and any(c.isdigit() for c in t) - ): - candidate_tokens.add(t) - - if args.use_mne or not args.coords: - coords_df = generate_coords_from_mne( - args.montage, - restrict_to=sorted(candidate_tokens) if candidate_tokens else None, - ) - else: - if not os.path.exists(args.coords): - raise FileNotFoundError(args.coords) - coords_df = load_coords(args.coords) - - sensor_names = coords_df.index.tolist() - - sensor_acc: Dict[str, float] = {} - model_name_used: Optional[str] = None - metric_name = args.metric - - for analysis_id, results_per_model in all_results.items(): - sensor = analysis_to_sensor(analysis_id, sensor_names) - if not sensor: - continue - model_name = pick_model( - results_per_model, preferred=(args.model,) if args.model else None - ) - metrics = results_per_model[model_name].get("metric_scores", {}) - if metric_name not in metrics: - # fallback to first metric if requested not found - if metrics: - metric_name = next(iter(metrics.keys())) - else: - continue - mean_val = ( - float(metrics[metric_name]["mean"]) - if isinstance(metrics[metric_name], dict) - else float(metrics[metric_name]) - ) - sensor_acc[sensor] = mean_val - model_name_used = model_name - - if not sensor_acc: - raise RuntimeError( - "No sensor accuracies found. Ensure analysis IDs include sensor names " - "and coords match." - ) - - # Topomap - fig1, ax1 = plot_topomap( - sensor_acc, - coords_df[["x", "y"]], - title=f"Sensor {metric_name.capitalize()} ({model_name_used})", - cbar_label=metric_name.capitalize(), - sensors="markers", - cmap="viridis", - symmetric=False, - ) - if args.save_topo: - fig1.savefig(args.save_topo, dpi=150) - if not args.no_show: - plt.show() - plt.close(fig1) - - # Best sensor feature importances - best_sensor = max(sensor_acc, key=sensor_acc.get) - # find the analysis id for this sensor - best_analysis_id = next( - a for a in all_results if analysis_to_sensor(a, sensor_names) == best_sensor - ) - best_model_results = all_results[best_analysis_id][model_name_used] - fi_dict = best_model_results.get("feature_importances", {}) - - if not fi_dict: - print( - f"No feature_importances available for {best_sensor} / {model_name_used}." - ) - return - - imp_mean = pd.Series( - {k: v.get("mean", 0.0) for k, v in fi_dict.items()} - ).sort_values(ascending=False) - imp_std = pd.Series({k: v.get("std", 0.0) for k, v in fi_dict.items()}).reindex( - imp_mean.index - ) - - fig2, ax2 = plot_bar( - imp_mean, - errors=imp_std, - orientation="horizontal", - top_n=args.top_n, - title=f"{best_sensor} Feature Importance ({model_name_used})", - xlabel="Importance", - cmap="magma", - ) - if args.save_bar: - fig2.savefig(args.save_bar, dpi=150) - if not args.no_show: - plt.show() - plt.close(fig2) - - -if __name__ == "__main__": - main() diff --git a/scripts/run_ml.py b/scripts/run_ml.py deleted file mode 100644 index c496099..0000000 --- a/scripts/run_ml.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import logging -import os -from copy import deepcopy - -import pandas as pd -import yaml - -from coco_pipe.io import TabularDataset -from coco_pipe.ml.pipeline import MLPipeline - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def run_analysis(X, y, analysis_cfg): - """Run a single analysis with the given config, passing through the new `mode`.""" - # scikit-learn pipelines expect numpy arrays - X_arr = X.values if hasattr(X, "values") else X - y_arr = y.values if hasattr(y, "values") else y - # ADD blabla blaq - # Build the MLPipeline config dict - pipeline_config = { - "task": analysis_cfg.get("task"), - "analysis_type": analysis_cfg.get("analysis_type"), - "models": analysis_cfg.get("models"), - "metrics": analysis_cfg.get("metrics"), - "cv_strategy": analysis_cfg.get("cv_kwargs", {}).get("cv_strategy"), - "n_splits": analysis_cfg.get("cv_kwargs", {}).get("n_splits"), - "cv_kwargs": analysis_cfg.get("cv_kwargs"), - "n_features": analysis_cfg.get("n_features"), - "direction": analysis_cfg.get("direction"), - "search_type": analysis_cfg.get("search_type"), - "n_iter": analysis_cfg.get("n_iter"), - "scoring": analysis_cfg.get("scoring"), - "n_jobs": analysis_cfg.get("n_jobs"), - "save_intermediate": analysis_cfg.get("save_intermediate"), - "results_dir": analysis_cfg.get("results_dir"), - "results_file": analysis_cfg.get("results_file"), - # **NEW** univariate vs. multivariate mode - "mode": analysis_cfg.get("mode"), - } - - # strip out any None so pipeline defaults apply - pipeline_config = {k: v for k, v in pipeline_config.items() if v is not None} - - logger.info( - f"Launching {pipeline_config['task']} pipeline " - f"({pipeline_config.get('mode', 'multivariate')}) – " - f"{pipeline_config['analysis_type']} on " - f"{X_arr.shape[0]}×{X_arr.shape[1]} data" - ) - - pipeline = MLPipeline(X=X_arr, y=y_arr, config=pipeline_config) - results = pipeline.run() - - logger.info(f"Analysis {analysis_cfg['id']} completed") - return results - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--config", "-c", required=True, help="YAML file with defaults+analyses" - ) - args = parser.parse_args() - - # 0) Load config & data - # 0) Load data - cfg = yaml.safe_load(open(args.config)) - # Using TabularDataset directly to get full container - # Assuming config provides necessary kwargs for reshaping if applicable - data_path = cfg["data_path"] - load_kwargs = cfg.get("loader_kwargs", {}) - - # Check if we should reshape based on config hint or defaults - # For now, simplistic loading unless config specifies columns_to_dims - ds = TabularDataset(data_path, **load_kwargs) - full_container = ds.load() - - all_results = {} - defaults = cfg.get("defaults", {}) - - for analysis in cfg["analyses"]: - # merge defaults + specific - analysis_cfg = deepcopy(defaults) - analysis_cfg.update(analysis) - - # 1) Select relevant data using DataContainer - # Map generic 'select_features' config to DataContainer.select() - # "spatial_units" -> usually 'channel' dimension - # "feature_names" -> 'feature' dimension - # "covariates" -> handled at load time or aux selection - - selection_query = {} - - if "spatial_units" in analysis_cfg and analysis_cfg["spatial_units"] != "all": - # Assuming 'channel' dim exists if reshaping happened - # If strictly flat 2D but logical spatial units exist, - # user should have loaded with columns_to_dims=['channel', 'feature'] - if "channel" in full_container.dims: - selection_query["channel"] = analysis_cfg["spatial_units"] - - if "feature_names" in analysis_cfg and analysis_cfg["feature_names"] != "all": - if "feature" in full_container.dims: - selection_query["feature"] = analysis_cfg["feature_names"] - - # Row filters (e.g. subjects) - if "row_filter" in analysis_cfg: - # This requires parsing row_filter dicts to container.select queries - # Simplified: assume simple kwargs for now - # row_filter: {'group': 1} - pass - - if "target_columns" in analysis_cfg: - # If target was not set at load time, it might be in X - # But DataContainer should ideally handle y. - # If y IS loaded, we effectively just use it. - # If we need to SWAP target (rare), we'd need to manipulate container. - pass - - # Apply Selection - # If no query, we get the whole container - sub_container = ( - full_container.select(**selection_query) - if selection_query - else full_container - ) - - X = sub_container.X - y = sub_container.y - - logger.info( - f"Analysis {analysis['id']} selected shape {X.shape}, " - f"target available: {y is not None}" - ) - - # 1.5) Concatenate covariates if requested and separated? - # If covariates are in coords, we might need to add them back to X for some - # ML pipelines? - # Or MLPipeline handles coords? For now assume standard X, y. - - # 2) Run - # ensure results_dir/file come from global defaults if not overwritten - analysis_cfg["results_dir"] = cfg.get( - "results_dir", analysis_cfg.get("results_dir") - ) - analysis_cfg["results_file"] = cfg.get( - "results_file", analysis_cfg.get("results_file") - ) - - results = run_analysis(sub_container.X, sub_container.y, analysis_cfg) - all_results[analysis["id"]] = results - - # 3) Save all results - out_path = os.path.join(cfg["results_dir"], f"{cfg['global_experiment_id']}.pkl") - logger.info(f"Saving aggregated results to {out_path}") - pd.to_pickle(all_results, out_path) - - -if __name__ == "__main__": - main() diff --git a/tests/test_decoding_cache.py b/tests/test_decoding_cache.py new file mode 100644 index 0000000..176330b --- /dev/null +++ b/tests/test_decoding_cache.py @@ -0,0 +1,112 @@ +from coco_pipe.decoding._cache import make_feature_cache_key + + +def test_make_feature_cache_key_is_stable(): + train = ["s1", "s2"] + test = ["s3"] + prep = "p1" + backbone = "b1" + meta = {"task": "classify"} + + key1 = make_feature_cache_key(train, test, prep, backbone, meta) + key2 = make_feature_cache_key(train, test, prep, backbone, meta) + + assert key1 == key2 + assert isinstance(key1, str) + assert len(key1) == 64 # SHA256 length + + +def test_make_feature_cache_key_is_order_insensitive_by_default(): + train1 = ["s1", "s2"] + train2 = ["s2", "s1"] + test = ["s3"] + prep = "p1" + backbone = "b1" + + key1 = make_feature_cache_key(train1, test, prep, backbone) + key2 = make_feature_cache_key(train2, test, prep, backbone) + + # Default is now order-insensitive + assert key1 == key2 + + +def test_make_feature_cache_key_can_be_order_sensitive(): + train1 = ["s1", "s2"] + train2 = ["s2", "s1"] + test = ["s3"] + prep = "p1" + backbone = "b1" + + key1 = make_feature_cache_key(train1, test, prep, backbone, sort_ids=False) + key2 = make_feature_cache_key(train2, test, prep, backbone, sort_ids=False) + + # With sort_ids=False, order matters + assert key1 != key2 + + +def test_make_feature_cache_key_is_sensitive_to_fingerprints(): + train = ["s1"] + test = ["s2"] + + key_base = make_feature_cache_key(train, test, "p1", "b1") + key_diff_prep = make_feature_cache_key(train, test, "p2", "b1") + key_diff_back = make_feature_cache_key(train, test, "p1", "b2") + + assert key_base != key_diff_prep + assert key_base != key_diff_back + + +def test_make_feature_cache_key_is_sensitive_to_samples(): + prep = "p1" + backbone = "b1" + + key1 = make_feature_cache_key(["s1"], ["s2"], prep, backbone) + key2 = make_feature_cache_key(["s1", "s2"], ["s3"], prep, backbone) + key3 = make_feature_cache_key(["s1"], ["s3"], prep, backbone) + + assert key1 != key2 + assert key1 != key3 + assert key2 != key3 + + +def test_make_feature_cache_key_handles_extra_metadata(): + train = ["s1"] + test = ["s2"] + prep = "p1" + backbone = "b1" + + key_no_meta = make_feature_cache_key(train, test, prep, backbone) + key_meta1 = make_feature_cache_key(train, test, prep, backbone, {"time": 0}) + key_meta2 = make_feature_cache_key(train, test, prep, backbone, {"time": 1}) + + assert key_no_meta != key_meta1 + assert key_meta1 != key_meta2 + + +def test_make_feature_cache_key_converts_ids_to_strings(): + # Mixing types should work because they are converted to str + train = [1, 2.0, "3"] + test = [4] + + key1 = make_feature_cache_key(train, test, "p", "b") + key2 = make_feature_cache_key(["1", "2.0", "3"], ["4"], "p", "b") + + assert key1 == key2 + + +def test_make_feature_cache_key_extra_metadata_paths(): + """Verify both None and dict paths for coverage.""" + train, test = ["s1"], ["s2"] + p, b = "p", "b" + + # Path 1: None (results in empty dict) + key_none = make_feature_cache_key(train, test, p, b, extra_metadata=None) + + # Path 2: Explicit empty dict + key_empty = make_feature_cache_key(train, test, p, b, extra_metadata={}) + + # Path 3: Non-empty dict + key_data = make_feature_cache_key(train, test, p, b, extra_metadata={"a": 1}) + + assert key_none == key_empty + assert key_none != key_data diff --git a/tests/test_decoding_configs.py b/tests/test_decoding_configs.py new file mode 100644 index 0000000..451ac42 --- /dev/null +++ b/tests/test_decoding_configs.py @@ -0,0 +1,307 @@ +import pytest +from pydantic import ValidationError + +from coco_pipe.decoding import Experiment +from coco_pipe.decoding.configs import ( + AdaBoostClassifierConfig, + AdaBoostRegressorConfig, + ARDRegressionConfig, + BayesianRidgeConfig, + CalibrationConfig, + ConfidenceIntervalConfig, + CVConfig, + DecisionTreeRegressorConfig, + DummyClassifierConfig, + DummyRegressorConfig, + ElasticNetConfig, + ExperimentConfig, + ExtraTreesRegressorConfig, + FeatureSelectionConfig, + GaussianNBConfig, + GradientBoostingClassifierConfig, + GradientBoostingRegressorConfig, + HistGradientBoostingRegressorConfig, + KNeighborsClassifierConfig, + KNeighborsRegressorConfig, + LassoConfig, + LDAConfig, + LinearRegressionConfig, + LinearSVCConfig, + LogisticRegressionConfig, + MLPClassifierConfig, + MLPRegressorConfig, + RandomForestClassifierConfig, + RandomForestRegressorConfig, + RidgeConfig, + SGDClassifierConfig, + SGDRegressorConfig, + StatisticalAssessmentConfig, + SVCConfig, + SVRConfig, + TuningConfig, +) +from coco_pipe.decoding.registry import ( + EstimatorSpec, + get_estimator_cls, + register_estimator, + register_estimator_spec, +) + +ACTIVE_SKLEARN_CONFIGS = [ + LogisticRegressionConfig, + RandomForestClassifierConfig, + SVCConfig, + LinearSVCConfig, + KNeighborsClassifierConfig, + GradientBoostingClassifierConfig, + SGDClassifierConfig, + MLPClassifierConfig, + GaussianNBConfig, + LDAConfig, + AdaBoostClassifierConfig, + DummyClassifierConfig, + LinearRegressionConfig, + RidgeConfig, + LassoConfig, + ElasticNetConfig, + RandomForestRegressorConfig, + SVRConfig, + GradientBoostingRegressorConfig, + SGDRegressorConfig, + MLPRegressorConfig, + DummyRegressorConfig, + DecisionTreeRegressorConfig, + KNeighborsRegressorConfig, + ExtraTreesRegressorConfig, + HistGradientBoostingRegressorConfig, + AdaBoostRegressorConfig, + BayesianRidgeConfig, + ARDRegressionConfig, +] + + +def _experiment_for_instantiation(): + return Experiment( + ExperimentConfig( + task="classification", + models={"lr": {"kind": "classical", "method": "LogisticRegression"}}, + metrics=["accuracy"], + n_jobs=1, + verbose=False, + ) + ) + + +def test_experiment_config_task_consistency(): + # Valid classification + cfg = ExperimentConfig( + task="classification", + models={"lr": {"kind": "classical", "method": "LogisticRegression"}}, + metrics=["accuracy", "f1"], + ) + assert cfg.task == "classification" + + # Invalid classification (regression metric) + with pytest.raises(ValidationError, match="Metric 'r2' is for regression"): + ExperimentConfig( + task="classification", + models={"lr": {"kind": "classical", "method": "LogisticRegression"}}, + metrics=["r2"], + ) + + # Invalid regression (stratified CV) + with pytest.raises( + ValidationError, match="CV strategy 'stratified' is not valid for regression" + ): + ExperimentConfig( + task="regression", + models={"ridge": {"kind": "classical", "method": "Ridge"}}, + cv=CVConfig(strategy="stratified"), + metrics=["r2"], + ) + + +def test_model_discriminator(): + # Test that 'kind' and 'method' correctly resolve the subclass + data = { + "lr": { + "kind": "classical", + "method": "ClassicalModel", + "estimator": "LogisticRegression", + "params": {"C": 0.1}, + } + } + cfg = ExperimentConfig(task="classification", models=data) + assert cfg.models["lr"].kind == "classical" + # ClassicalModelConfig has .params + assert cfg.models["lr"].params["C"] == 0.1 + + +def test_scientific_defaults(): + cfg = ExperimentConfig( + task="classification", + models={"lr": {"kind": "classical", "method": "LogisticRegression"}}, + ) + assert cfg.cv.n_splits == 5 + assert cfg.evaluation.chance.n_permutations == 1000 + assert cfg.evaluation.confidence_intervals.alpha == 0.05 + + +def test_field_constraints(): + # Negative splits + with pytest.raises(ValidationError): + CVConfig(n_splits=0) + + # Invalid alpha + with pytest.raises(ValidationError): + ConfidenceIntervalConfig(alpha=1.5) + + # Feature selection invalid features + with pytest.raises(ValidationError): + FeatureSelectionConfig( + enabled=True, method="sfs", n_features=0, cv=CVConfig(strategy="stratified") + ) + + # Tuning iterations + with pytest.raises(ValidationError): + TuningConfig(enabled=True, n_iter=0, cv=CVConfig(strategy="stratified")) + + +def test_tuning_metric_consistency(): + # Invalid tuning metric (task mismatch) + with pytest.raises(ValidationError, match="Tuning metric 'r2' is for regression"): + ExperimentConfig( + task="classification", + models={"lr": {"kind": "classical", "method": "LogisticRegression"}}, + tuning={"enabled": True, "scoring": "r2", "cv": {"strategy": "stratified"}}, + ) + + +def test_statistical_assessment_nesting(): + cfg = StatisticalAssessmentConfig(enabled=True) + assert cfg.chance.n_permutations == 1000 + assert cfg.confidence_intervals.alpha == 0.05 + + # Test custom unit column + cfg_unit = StatisticalAssessmentConfig( + unit_of_inference="custom", custom_unit_column="session" + ) + assert cfg_unit.custom_unit_column == "session" + + +def test_every_active_sklearn_config_method_resolves(): + for config_cls in ACTIVE_SKLEARN_CONFIGS: + config = config_cls() + assert get_estimator_cls(config.method) is not None + + +def test_every_active_sklearn_default_config_instantiates(): + experiment = _experiment_for_instantiation() + for config_cls in ACTIVE_SKLEARN_CONFIGS: + config = config_cls() + estimator = experiment._instantiate_model(config.method, config) + assert estimator is not None + + +def test_experiment_config_forbids_extra_fields(): + with pytest.raises(ValidationError): + ExperimentConfig( + task="classification", + models={"lr": {"kind": "classical", "method": "LogisticRegression"}}, + unexpected=True, + ) + + +def test_removed_deprecated_config_fields_are_rejected(): + with pytest.raises(ValidationError): + LogisticRegressionConfig(multiclass="ovr") + + with pytest.raises(ValidationError): + AdaBoostClassifierConfig(algorithm="SAMME") + + with pytest.raises(ValidationError): + BayesianRidgeConfig(n_iter=10) + + with pytest.raises(ValidationError): + ARDRegressionConfig(n_iter=10) + + +def test_modern_iteration_fields_are_exposed(): + bayes = BayesianRidgeConfig(max_iter=12) + ard = ARDRegressionConfig(max_iter=13) + assert bayes.model_dump()["max_iter"] == 12 + assert ard.model_dump()["max_iter"] == 13 + assert "n_iter" not in bayes.model_dump() + assert "n_iter" not in ard.model_dump() + + +def test_sgd_penalty_accepts_none_not_null_string(): + assert SGDClassifierConfig(penalty=None).penalty is None + with pytest.raises(ValidationError): + SGDClassifierConfig(penalty="null") + + +def test_invalid_constructor_params_are_not_silently_dropped(): + @register_estimator("StrictFakeEstimator") + class StrictFakeEstimator: + def __init__(self, known=1): + self.known = known + + # Register spec for it + spec = EstimatorSpec( + name="StrictFakeEstimator", + import_path="fake", + family="linear", + task=("classification",), + ) + register_estimator_spec(spec) + + class FakeConfig: + method = "StrictFakeEstimator" + kind = "classical" + + def model_dump(self, exclude=None): + data = {"method": self.method, "kind": self.kind, "unknown": 2} + for key in exclude or set(): + data.pop(key, None) + return data + + experiment = _experiment_for_instantiation() + with pytest.raises(ValueError, match="Failed to instantiate model 'fake'"): + experiment._instantiate_model("fake", FakeConfig()) + + +def test_experiment_validation_hardening(): + # calibration for regression + with pytest.raises( + (ValueError, ValidationError), + match="calibration is only available for classification", + ): + ExperimentConfig( + task="regression", + models={"lr": LogisticRegressionConfig()}, + metrics=["neg_mean_squared_error"], + cv=CVConfig(strategy="kfold"), + calibration=CalibrationConfig(enabled=True), + ) + + # stratified for regression + with pytest.raises((ValueError, ValidationError), match="not valid for regression"): + ExperimentConfig( + task="regression", + models={"lr": LogisticRegressionConfig()}, + metrics=["neg_mean_squared_error"], + cv=CVConfig(strategy="stratified"), + ) + + # FS metric mismatch + with pytest.raises( + (ValueError, ValidationError), match="is not defined|incompatible with task" + ): + ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + feature_selection=FeatureSelectionConfig( + enabled=True, scoring="neg_mean_squared_error" + ), + ) diff --git a/tests/test_decoding_diagnostics.py b/tests/test_decoding_diagnostics.py new file mode 100644 index 0000000..3f517b8 --- /dev/null +++ b/tests/test_decoding_diagnostics.py @@ -0,0 +1,273 @@ +import numpy as np +import pandas as pd +import pytest + +from coco_pipe.decoding._diagnostics import ( + confusion_matrix_frame, + curve_score_groups, + optional_values, + paired_unit_indices, + prediction_rows, + proba_matrix, + row_value, + scalar_prediction_frame, + score_frame, + score_rows, + time_value, + unit_indices, +) + + +def test_time_value(): + assert time_value(0, [10, 20]) == 10 + assert time_value(5, [10, 20]) == 5 + assert time_value(0, None) == 0 + + +def test_score_rows(): + assert len(score_rows("m", 0, "a", 0.5)) == 1 + assert len(score_rows("m", 0, "a", [0.5, 0.6])) == 2 + assert len(score_rows("m", 0, "a", [[0.5, 0.6], [0.7, 0.8]])) == 4 + assert len(score_rows("m", 0, "a", np.zeros((2, 2, 2)))) == 1 + + +def test_prediction_rows_all_paths(): + # 1. Standard Binary Proba (1D) + p_bin = {"y_true": [0, 1], "y_pred": [0, 1], "y_proba": np.array([0.2, 0.8])} + r_bin = prediction_rows("m", 0, p_bin) + assert r_bin[0]["y_proba_0"] == 0.8 + assert r_bin[0]["y_proba_1"] == 0.2 + + # 2. Standard Multiclass Proba (2D) + p_multi = { + "y_true": [0, 1], + "y_pred": [0, 1], + "y_proba": np.array([[0.8, 0.2], [0.1, 0.9]]), + } + r_multi = prediction_rows("m", 0, p_multi) + assert r_multi[0]["y_proba_0"] == 0.8 + + # 3. Sliding Proba (3D) + p_sl = { + "y_true": [0, 1], + "y_pred": np.zeros((2, 2)), + "y_proba": np.zeros((2, 2, 2)), + } + r_sl = prediction_rows("m", 0, p_sl) + assert "y_proba_0" in r_sl[0] + + # 4. Generalizing Proba (4D) + p_gen = { + "y_true": [0, 1], + "y_pred": np.zeros((2, 2, 2)), + "y_proba": np.zeros((2, 2, 2, 2)), + } + r_gen = prediction_rows("m", 0, p_gen) + assert "y_proba_0" in r_gen[0] + + # 5. Standard Binary Score (1D) + p_s1 = {"y_true": [0, 1], "y_pred": [0, 1], "y_score": np.array([0.5, 0.8])} + r_s1 = prediction_rows("m", 0, p_s1) + assert r_s1[0]["y_score"] == 0.5 + + # 6. Standard Multiclass Score (2D) + p_sm = { + "y_true": [0, 1], + "y_pred": [0, 1], + "y_score": np.array([[1.0, -1.0], [-1.0, 1.0]]), + } + r_sm = prediction_rows("m", 0, p_sm) + assert r_sm[0]["y_score_0"] == 1.0 + + # 7. Metadata and Groups + p_meta = { + "y_true": [0], + "y_pred": [0], + "group": [1], + "sample_metadata": {"m": [10]}, + } + r_meta = prediction_rows("m", 0, p_meta) + assert r_meta[0]["m"] == 10 + assert r_meta[0]["Group"] == 1 + + +def test_row_value_ndarray(): + assert row_value(np.array([np.array([1])], dtype=object), 0) == [1] + + +def test_optional_values(): + assert optional_values(None, 2).tolist() == [None, None] + assert optional_values([1, 2], 2).tolist() == [1, 2] + + +def test_proba_matrix(): + assert proba_matrix(pd.DataFrame({"y_proba_0": [0.8]}), 2) is None + assert ( + proba_matrix( + pd.DataFrame({"y_proba_0": [0.8, np.nan], "y_proba_1": [0.2, 0.8]}), 2 + ) + is None + ) + assert ( + proba_matrix(pd.DataFrame({"y_proba_0": [0.8], "y_proba_1": [0.2]}), 2) + is not None + ) + + +def test_unit_indices(): + df = pd.DataFrame( + { + "SampleID": [1, 2], + "Group": [1, 2], + "Subject": [1, 2], + "Session": [1, 2], + "Site": [1, 2], + } + ) + for u in ["sample", "epoch", "group", "subject", "session", "site"]: + assert len(unit_indices(df, u)) == 2 + with pytest.raises(ValueError): + unit_indices(df, "unknown") + df_err = pd.DataFrame({"SampleID": [1], "Site": [np.nan]}) + with pytest.raises(ValueError): + unit_indices(df_err, "site") + + +def test_paired_unit_indices(): + df = pd.DataFrame( + { + "SampleID": [1, 1], + "Group_A": [1, 2], + "Group_B": [1, 2], + "Subject_A": [1, 2], + "Subject_B": [1, 2], + } + ) + assert len(paired_unit_indices(df, "group")) == 2 + assert len(paired_unit_indices(df, "subject")) == 2 + + +def test_score_frame_all(): + df_l = pd.DataFrame({"y_true": [0, 1], "y_pred": [0, 1]}) + assert score_frame(df_l, "accuracy") == 1.0 + df_p = pd.DataFrame( + {"y_true": [0, 1], "y_proba_0": [0.8, 0.2], "y_proba_1": [0.2, 0.8]} + ) + assert score_frame(df_p, "roc_auc") == 1.0 + assert score_frame(df_p, "brier_score") < 1.0 + df_s = pd.DataFrame({"y_true": [0, 1], "y_score": [0.5, 0.8]}) + assert score_frame(df_s, "roc_auc") == 1.0 + df_m = pd.DataFrame( + { + "y_true": [0, 1, 2], + "y_proba_0": [0.8, 0.1, 0.1], + "y_proba_1": [0.1, 0.8, 0.1], + "y_proba_2": [0.1, 0.1, 0.8], + } + ) + assert score_frame(df_m, "roc_auc") == 1.0 + assert score_frame(df_m, "average_precision") == 1.0 + with pytest.raises(ValueError): + score_frame(df_l, "roc_auc") + + +def test_scalar_prediction_frame(): + df = pd.DataFrame({"Time": [0.1, np.nan]}) + assert len(scalar_prediction_frame(df)) == 1 + assert scalar_prediction_frame(pd.DataFrame()).empty + + +def test_confusion_matrix_frame(): + df = pd.DataFrame( + {"Model": ["m", "m"], "Fold": [0, 0], "y_true": [0, 1], "y_pred": [0, 1]} + ) + assert not confusion_matrix_frame(df, [0, 1]).empty + assert confusion_matrix_frame(pd.DataFrame(columns=["Model", "Fold"]), [0, 1]).empty + assert "Model" in confusion_matrix_frame(df, [0, 1], group_cols=["Model"]).columns + + +def test_curve_score_groups(): + df = pd.DataFrame( + { + "Model": ["m", "m"], + "Fold": [0, 0], + "y_true": [0, 1], + "y_proba_0": [0.5, 0.5], + "y_proba_1": [0.5, 0.5], + } + ) + assert len(list(curve_score_groups(df))) == 1 + assert len(list(curve_score_groups(df, model="m2"))) == 0 + df_s0 = pd.DataFrame( + {"Model": ["m", "m"], "Fold": [0, 0], "y_true": [0, 1], "y_score": [0.5, 0.8]} + ) + assert len(list(curve_score_groups(df_s0, pos_label=0))) == 1 + df_m = pd.DataFrame( + { + "Model": ["m"] * 3, + "Fold": [0] * 3, + "y_true": [0, 1, 2], + "y_proba_0": [0.8, 0.1, 0.1], + "y_proba_1": [0.1, 0.8, 0.1], + "y_proba_2": [0.1, 0.1, 0.8], + } + ) + assert len(list(curve_score_groups(df_m))) == 3 + + +def test_paired_unit_indices_exhaustive(): + df = pd.DataFrame( + { + "SampleID": [1], + "Group_A": [1], + "Subject_A": [1], + "Session_A": [1], + "Site_A": [1], + "Group_B": [1], + "Subject_B": [1], + "Session_B": [1], + "Site_B": [1], + } + ) + for u in ["sample", "epoch", "group", "subject", "session", "site"]: + assert len(paired_unit_indices(df, u)) == 1 + + +def test_score_frame_error_paths(): + df = pd.DataFrame({"y_true": [0, 1], "y_pred": [0, 1]}) + with pytest.raises(ValueError, match="cannot be scored"): + score_frame(df, "roc_auc") + + # Brier score multiclass error + df_m = pd.DataFrame( + { + "y_true": [0, 1, 2], + "y_proba_0": [0.3] * 3, + "y_proba_1": [0.3] * 3, + "y_proba_2": [0.4] * 3, + } + ) + with pytest.raises(ValueError, match="binary classification only"): + score_frame(df_m, "brier_score") + + +def test_prediction_rows_temporal_multiclass(): + # Sliding Multiclass + p_sl = { + "y_true": [0], + "y_pred": np.zeros((1, 2)), + "y_proba": np.zeros((1, 2, 3)), # (samples, times, classes) + } + r_sl = prediction_rows("m", 0, p_sl) + assert len(r_sl) == 2 + assert all(f"y_proba_{c}" in r_sl[0] for c in range(3)) + + # Generalizing Multiclass + p_gen = { + "y_true": [0], + "y_pred": np.zeros((1, 2, 2)), + "y_proba": np.zeros((1, 2, 2, 3)), # (samples, tr, te, classes) + } + r_gen = prediction_rows("m", 0, p_gen) + assert len(r_gen) == 4 + assert all(f"y_proba_{c}" in r_gen[0] for c in range(3)) diff --git a/tests/test_decoding_engine.py b/tests/test_decoding_engine.py new file mode 100644 index 0000000..cf17d93 --- /dev/null +++ b/tests/test_decoding_engine.py @@ -0,0 +1,324 @@ +from types import SimpleNamespace + +import numpy as np +from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.datasets import make_classification +from sklearn.pipeline import Pipeline + +from coco_pipe.decoding import Experiment, ExperimentConfig +from coco_pipe.decoding._engine import ( + compact_search_results, + compute_metric_safe, + extract_feature_importances, + extract_metadata, + fit_and_score_fold, + fit_estimator, + metadata_slice, + warning_records_to_dict, +) +from coco_pipe.decoding.configs import ( + CalibrationConfig, + CVConfig, + LinearSVCConfig, +) +from coco_pipe.decoding.interfaces import NeuralTrainable + +# --- Mock Objects --- + + +class MockConfig: + def __init__(self, **kwargs): + self.enabled = kwargs.get("enabled", False) + self.method = kwargs.get("method", "none") + self.cv = kwargs.get("cv", SimpleNamespace(strategy="group_kfold", n_splits=2)) + self.n_splits = kwargs.get("n_splits", 2) + + +class MockEstimator(BaseEstimator, ClassifierMixin): + _estimator_type = "classifier" + + def __init__(self, **kwargs): + self._estimator_type = "classifier" + for k, v in kwargs.items(): + setattr(self, k, v) + # Defaults for coverage + if not hasattr(self, "coef_"): + self.coef_ = np.zeros(2) + if not hasattr(self, "feature_importances_"): + self.feature_importances_ = np.zeros(2) + if not hasattr(self, "best_estimator_"): + self.best_estimator_ = self + if not hasattr(self, "best_params_"): + self.best_params_ = {} + if not hasattr(self, "best_score_"): + self.best_score_ = 0.9 + if not hasattr(self, "best_index_"): + self.best_index_ = 0 + if not hasattr(self, "cv_results_"): + self.cv_results_ = { + "params": [{}], + "mean_test_score": [0.9], + "rank_test_score": [1], + "std_test_score": [0.1], + } + if not hasattr(self, "classes_"): + self.classes_ = np.array([0, 1]) + + def fit(self, X, y=None, **kwargs): + self.fit_kwargs = kwargs + return self + + def predict(self, X): + return getattr(self, "y_pred_val", np.zeros(len(X))) + + def predict_proba(self, X): + return getattr(self, "y_proba_val", np.zeros((len(X), 2))) + + def decision_function(self, X): + return getattr(self, "y_score_val", np.zeros(len(X))) + + def get_support(self): + return getattr(self, "support_", np.array([True, True])) + + +# --- Tests --- + + +def test_diagnostics_basics(): + meta = {"a": np.array([10, 20, 30])} + assert metadata_slice(meta, np.array([0, 2])) == {"a": [10, 30]} + assert metadata_slice(None, [0]) is None + record = SimpleNamespace(category=UserWarning, message="test") + assert len(warning_records_to_dict("fit", [record])) == 1 + + +def test_importance_extraction_comprehensive(): + spec = SimpleNamespace(importance=("coefficients",), is_sparse_capable=True) + clf = MockEstimator(coef_=np.array([1.0, 2.0])) + assert np.allclose(extract_feature_importances(clf, spec), [1.0, 2.0]) + + # Missing attribute (now safe thanks to hardening) + assert extract_feature_importances(BaseEstimator(), spec) is None + + # Pipeline + FS + fs = MockEstimator(support_=np.array([True, False])) + pipe = Pipeline([("fs", fs), ("clf", MockEstimator(coef_=np.array([5.0])))]) + assert np.allclose( + extract_feature_importances(pipe, spec, fs_enabled=True), [5.0, 0.0] + ) + + +def test_compute_metric_safe_variants(): + def scorer(yt, yp, **kw): + return yp.mean() + + y_true = np.array([0, 1]) + assert np.isnan(compute_metric_safe(scorer, y_true, None, False)) + + # Sliding + y_sl = np.zeros((2, 5)) + assert compute_metric_safe(scorer, y_true, y_sl, False).shape == (5,) + + # Generalizing + y_gen = np.zeros((2, 3, 3)) + assert compute_metric_safe(scorer, y_true, y_gen, False).shape == (3, 3) + + +def test_extract_metadata_exhaustive(): + # Search enabled + search = MockEstimator() + meta = extract_metadata(search, None, MockConfig(), search_enabled=True) + assert "search_results" in meta + + # FS with ranking + fs = MockEstimator(ranking_=np.array([1, 2])) + pipe = Pipeline([("fs", fs), ("clf", MockEstimator())]) + meta_fs = extract_metadata(pipe, None, MockConfig(enabled=True, method="sfs")) + assert "selection_order" in meta_fs + + +def test_fit_and_score_fold_response_logic(): + spec = SimpleNamespace( + supports_proba=True, + supports_decision_function=True, + importance=("coefficients",), + supports_groups=True, + grouped_metadata="none", + is_sparse_capable=False, + family="linear", + ) + X, y = np.zeros((4, 2)), np.array([0, 0, 1, 1]) + ids = np.array(["a", "b", "c", "d"]) + + import coco_pipe.decoding._engine as engine + from coco_pipe.decoding._metrics import MetricSpec + + old_get = engine.get_metric_spec + + try: + # Use positional arguments for MetricSpec to be safe + # MetricSpec(name, task, scorer, response_method) + engine.get_metric_spec = lambda m: MetricSpec( + m, "classification", lambda yt, yp: yp.mean(), "predict" + ) + res1 = fit_and_score_fold( + MockEstimator(), + X, + y, + None, + ids, + None, + train_idx=np.array([0, 2]), + test_idx=np.array([1, 3]), + metrics=["m1"], + feature_selection_config=MockConfig(), + calibration_config=MockConfig(), + spec=spec, + ) + assert "m1" in res1["scores"] + + # Proba missing path + engine.get_metric_spec = lambda m: MetricSpec( + m, "classification", lambda yt, yp: yp.mean(), "proba_or_score" + ) + res2 = fit_and_score_fold( + MockEstimator(y_score_val=np.zeros(2)), + X, + y, + None, + ids, + None, + train_idx=np.array([0, 2]), + test_idx=np.array([1, 3]), + metrics=["m2"], + feature_selection_config=MockConfig(), + calibration_config=MockConfig(), + spec=SimpleNamespace(**{**spec.__dict__, "supports_proba": False}), + ) + assert "y_score" in res2["preds"] + finally: + engine.get_metric_spec = old_get + + +def test_fit_estimator_complex(): + from sklearn.calibration import CalibratedClassifierCV + from sklearn.linear_model import LogisticRegression + + X, y = make_classification( + n_samples=20, + n_features=2, + n_informative=2, + n_redundant=0, + n_repeated=0, + random_state=42, + ) + groups = np.repeat(np.arange(10), 2) + fit_estimator( + CalibratedClassifierCV(LogisticRegression(), cv=2), + X, + y, + groups, + MockConfig(), + MockConfig( + enabled=True, cv=SimpleNamespace(strategy="group_kfold", n_splits=2) + ), + ) + + +def test_calibration_integration(): + config = ExperimentConfig( + task="classification", + models={"svm": LinearSVCConfig(max_iter=500, kind="classical")}, + metrics=["log_loss"], + cv=CVConfig(strategy="stratified", n_splits=2), + calibration=CalibrationConfig( + enabled=True, + method="sigmoid", + cv=CVConfig(strategy="stratified", n_splits=2), + ), + n_jobs=1, + verbose=False, + ) + estimator = Experiment(config)._prepare_estimator("svm", config.models["svm"]) + assert estimator.__class__.__name__ == "CalibratedClassifierCV" + + +def test_fit_estimator_calibration_group_cv(): + from sklearn.calibration import CalibratedClassifierCV + from sklearn.linear_model import LogisticRegression + from sklearn.model_selection import GroupKFold + + X, y = make_classification( + n_samples=10, + n_features=2, + n_informative=2, + n_redundant=0, + n_repeated=0, + random_state=42, + ) + groups = np.repeat([0, 1], 5) + + cal_cfg = SimpleNamespace(cv=SimpleNamespace(strategy="group_kfold", n_splits=2)) + cal = CalibratedClassifierCV(LogisticRegression(), cv=GroupKFold(n_splits=2)) + + fit_estimator(cal, X, y, groups, MockConfig(), cal_cfg) + + from coco_pipe.decoding._engine import _CVWithGroups + + assert isinstance(cal.cv, _CVWithGroups) + + +def test_importance_extraction_calibration_averaging(): + from sklearn.calibration import CalibratedClassifierCV + from sklearn.linear_model import LogisticRegression + + X, y = make_classification( + n_samples=10, + n_features=2, + n_informative=2, + n_redundant=0, + n_repeated=0, + random_state=42, + ) + cal = CalibratedClassifierCV(LogisticRegression(), cv=2).fit(X, y) + + spec = SimpleNamespace(importance=("coefficients",), is_sparse_capable=True) + + # We need to ensure calibrated_classifiers_ is present + assert hasattr(cal, "calibrated_classifiers_") + + # This should call the recursive path in extract_feature_importances + imp = extract_feature_importances(cal, spec, calibration_enabled=True) + assert imp.shape == (2,) + + +def test_compute_metric_safe_2d_generalizing(): + def scorer(yt, yp, **kw): + return np.mean((yt - yp) ** 2) + + y_true = np.array([0, 1]) + y_gen = np.zeros((2, 2, 2)) # (n_samples, n_tr, n_te) + y_gen[1, :, :] = 1.0 # Perfect predictions for y_true=1 + + score = compute_metric_safe(scorer, y_true, y_gen, False, name="mse") + assert score.shape == (2, 2) + assert np.all(score == 0.0) + + +def test_extract_metadata_neural(): + class MockNeural(MockEstimator, NeuralTrainable): + def get_artifact_metadata(self): + return {"weight_norm": 1.0} + + def get_train_stage(self): + return "final" + + est = MockNeural() + meta = extract_metadata(est, None, MockConfig()) + assert meta["artifacts"] == {"weight_norm": 1.0} + + +def test_compact_search_results_missing_keys(): + est = SimpleNamespace(cv_results_={"params": [{"C": 1}]}) + res = compact_search_results(est) + assert res == [{"candidate": 0, "params": {"C": 1}}] diff --git a/tests/test_decoding_experiment.py b/tests/test_decoding_experiment.py new file mode 100644 index 0000000..efcbd5e --- /dev/null +++ b/tests/test_decoding_experiment.py @@ -0,0 +1,698 @@ +import warnings +from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd +import pytest +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression + +from coco_pipe.decoding import Experiment, ExperimentConfig +from coco_pipe.decoding.configs import ( + AdaBoostClassifierConfig, + AdaBoostRegressorConfig, + ARDRegressionConfig, + BayesianRidgeConfig, + CalibrationConfig, + ClassicalModelConfig, + CVConfig, + DecisionTreeRegressorConfig, + DummyClassifierConfig, + DummyRegressorConfig, + ElasticNetConfig, + ExtraTreesRegressorConfig, + FeatureSelectionConfig, + FoundationEmbeddingModelConfig, + GaussianNBConfig, + GradientBoostingClassifierConfig, + GradientBoostingRegressorConfig, + HistGradientBoostingClassifierConfig, + HistGradientBoostingRegressorConfig, + KNeighborsClassifierConfig, + KNeighborsRegressorConfig, + LassoConfig, + LDAConfig, + LinearRegressionConfig, + LogisticRegressionConfig, + MLPClassifierConfig, + MLPRegressorConfig, + RandomForestClassifierConfig, + RandomForestRegressorConfig, + RidgeConfig, + SGDClassifierConfig, + SGDRegressorConfig, + StatisticalAssessmentConfig, + SVCConfig, + SVRConfig, + TemporalDecoderConfig, + TuningConfig, +) +from coco_pipe.decoding.registry import ( + get_estimator_spec, +) + +# ============================================================================= +# FIXTURES & DATA GENERATORS +# ============================================================================= + + +def _classification_data(n_samples=40, n_features=6): + rng = np.random.default_rng(10) + y = np.tile([0, 1], n_samples // 2) + X = rng.normal(size=(len(y), n_features)) + X[:, 0] += y * 2.0 + X[:, 1] -= y * 1.0 + return X, y + + +def _regression_data(n_samples=28, n_features=5): + rng = np.random.default_rng(11) + X = rng.normal(size=(n_samples, n_features)) + y = X[:, 0] * 1.5 - X[:, 1] * 0.5 + rng.normal(scale=0.05, size=X.shape[0]) + return X, y + + +def _temporal_data(n_samples=12, n_channels=3, n_times=4): + rng = np.random.default_rng(42) + y = np.array([0, 1] * (n_samples // 2)) + X = rng.normal(scale=0.1, size=(n_samples, n_channels, n_times)) + X[y == 1, 0, :] += 1.0 + X[y == 0, 0, :] -= 1.0 + return X, y + + +# ============================================================================= +# SMOKE TEST CONFIGURATIONS +# ============================================================================= + +CLASSIFIER_SMOKE_CONFIGS = { + "DummyClassifier": DummyClassifierConfig(strategy="prior"), + "LogisticRegression": LogisticRegressionConfig(solver="liblinear", max_iter=200), + "LinearSVC": SVCConfig(kernel="linear", max_iter=500), # LinearSVC alias + "LinearDiscriminantAnalysis": LDAConfig(), + "RandomForestClassifier": RandomForestClassifierConfig( + n_estimators=5, max_depth=3, n_jobs=1 + ), + "SVC": SVCConfig(kernel="linear", probability=True, max_iter=500), + "KNeighborsClassifier": KNeighborsClassifierConfig(n_neighbors=3), + "GradientBoostingClassifier": GradientBoostingClassifierConfig(n_estimators=5), + "SGDClassifier": SGDClassifierConfig(max_iter=500), + "MLPClassifier": MLPClassifierConfig(hidden_layer_sizes=(4,), max_iter=50), + "GaussianNB": GaussianNBConfig(), + "AdaBoostClassifier": AdaBoostClassifierConfig(n_estimators=5), + "ExtraTreesClassifier": RandomForestClassifierConfig( + n_estimators=5, kind="classical" + ), + "HistGradientBoostingClassifier": HistGradientBoostingClassifierConfig(max_iter=5), +} + +REGRESSOR_SMOKE_CONFIGS = { + "DummyRegressor": DummyRegressorConfig(strategy="mean"), + "Ridge": RidgeConfig(), + "RandomForestRegressor": RandomForestRegressorConfig( + n_estimators=5, max_depth=3, n_jobs=1 + ), + "LinearRegression": LinearRegressionConfig(), + "Lasso": LassoConfig(max_iter=500), + "ElasticNet": ElasticNetConfig(max_iter=500), + "SVR": SVRConfig(kernel="linear"), + "GradientBoostingRegressor": GradientBoostingRegressorConfig(n_estimators=5), + "SGDRegressor": SGDRegressorConfig(max_iter=500), + "MLPRegressor": MLPRegressorConfig(hidden_layer_sizes=(4,), max_iter=50), + "DecisionTreeRegressor": DecisionTreeRegressorConfig(max_depth=3), + "KNeighborsRegressor": KNeighborsRegressorConfig(n_neighbors=3), + "ExtraTreesRegressor": ExtraTreesRegressorConfig( + n_estimators=5, max_depth=3, n_jobs=1 + ), + "HistGradientBoostingRegressor": HistGradientBoostingRegressorConfig( + max_iter=5, min_samples_leaf=2 + ), + "AdaBoostRegressor": AdaBoostRegressorConfig(n_estimators=5), + "BayesianRidge": BayesianRidgeConfig(), + "ARDRegression": ARDRegressionConfig(), +} + +SMOKE_CONFIGS = {**CLASSIFIER_SMOKE_CONFIGS, **REGRESSOR_SMOKE_CONFIGS} + + +@pytest.mark.parametrize("method", sorted(SMOKE_CONFIGS)) +def test_registered_estimator_survives_fit_predict_and_declared_responses(method): + config = SMOKE_CONFIGS[method] + spec = get_estimator_spec(method) + is_clf = "classification" in spec.task + X, y = _classification_data() if is_clf else _regression_data() + X_test = X[:5] + + exp = Experiment( + ExperimentConfig( + task="classification" if is_clf else "regression", + models={method: config}, + metrics=["accuracy" if is_clf else "r2"], + cv=CVConfig(strategy="stratified" if is_clf else "kfold", n_splits=2), + verbose=False, + ) + ) + estimator = exp._prepare_estimator(method, config) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + estimator.fit(X, y) + y_pred = estimator.predict(X_test) + assert y_pred.shape[0] == X_test.shape[0] + if spec.supports_proba: + assert estimator.predict_proba(X_test).shape[0] == X_test.shape[0] + if spec.supports_decision_function: + assert estimator.decision_function(X_test).shape[0] == X_test.shape[0] + + +# ============================================================================= +# SCIENTIFIC VALIDITY & LEAKAGE GUARDS +# ============================================================================= + + +def test_grouped_outer_cv_respects_boundaries(): + X, y = _classification_data(n_samples=20) + groups = np.repeat(np.arange(5), 4) # 5 subjects, 4 trials each + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + cv=CVConfig(strategy="group_kfold", n_splits=5), + verbose=False, + n_jobs=1, + ) + result = Experiment(config).run(X, y, groups=groups) + assert result.raw["lr"]["status"] == "success" + # Verify no subject overlap across any fold + for split_meta in result.raw["lr"]["splits"]: + train_idx = split_meta["train_idx"] + test_idx = split_meta["test_idx"] + assert set(groups[train_idx]).isdisjoint(set(groups[test_idx])) + + +def test_group_leakage_guard_raises_error(): + X, y = _classification_data() + groups = np.repeat([0, 1, 2, 3], len(y) // 4) + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + cv=CVConfig(strategy="group_kfold", n_splits=2), + tuning=TuningConfig( + enabled=True, cv=CVConfig(strategy="stratified", n_splits=2) + ), + grids={"lr": {"C": [0.1, 1.0]}}, + ) + with pytest.raises( + ValueError, + match=( + "Outer CV strategy is group-based, but tuning.cv strategy " + "'stratified' is not" + ), + ): + Experiment(config).run(X, y, groups=groups) + + +# ============================================================================= +# TEMPORAL DECODING (Comprehensive) +# ============================================================================= + + +def test_sliding_and_generalizing_estimators_full_workflow(): + pytest.importorskip("mne") + X, y = _temporal_data() + times = np.array([-0.1, 0.0, 0.1, 0.2]) + # Use real instances to avoid Pydantic issues + base_cfg = ClassicalModelConfig(estimator="LogisticRegression") + config = ExperimentConfig( + task="classification", + models={ + "sliding": TemporalDecoderConfig(wrapper="sliding", base=base_cfg), + "generalizing": TemporalDecoderConfig( + wrapper="generalizing", base=base_cfg + ), + }, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified", n_splits=2), + n_jobs=1, + ) + result = Experiment(config).run(X, y, time_axis=times) + assert "Time" in result.get_predictions().columns + # Specify model name to get the matrix (4, 4), otherwise it returns long format + assert result.get_generalization_matrix( + "generalizing", metric="accuracy" + ).shape == (4, 4) + + +# ============================================================================= +# FEATURE SELECTION & TUNING +# ============================================================================= + + +def test_sfs_with_tuning_and_groups_routing(): + X, y = _classification_data(n_samples=24) + groups = np.repeat(np.arange(6), 4) + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + cv=CVConfig(strategy="group_kfold", n_splits=2), + feature_selection=FeatureSelectionConfig( + enabled=True, + method="sfs", + n_features=2, + cv=CVConfig(strategy="group_kfold", n_splits=2), + ), + tuning=TuningConfig( + enabled=True, n_iter=2, cv=CVConfig(strategy="group_kfold", n_splits=2) + ), + grids={"lr": {"clf__C": [0.1, 1.0]}}, + n_jobs=1, + verbose=False, + ) + result = Experiment(config).run(X, y, groups=groups) + assert result.raw["lr"]["status"] == "success" + assert len(result.get_selected_features()["FeatureName"].unique()) == 6 + + +# ============================================================================= +# EDGE CASES & COVERAGE GAPS +# ============================================================================= + + +def test_experiment_config_validation_errors(): + # 1. Metric mismatch + cfg = ExperimentConfig.model_construct( + task="regression", + models={"dummy": DummyRegressorConfig()}, + metrics=["accuracy"], + cv=CVConfig(strategy="kfold"), + tuning=TuningConfig(), + feature_selection=FeatureSelectionConfig(), + calibration=CalibrationConfig(), + evaluation=StatisticalAssessmentConfig(), + ) + with pytest.raises( + ValueError, match="is for classification but experiment task is regression" + ): + Experiment(cfg) + + # 2. Calibration for regression + cfg = ExperimentConfig.model_construct( + task="regression", + models={"dummy": DummyRegressorConfig()}, + calibration=CalibrationConfig(enabled=True), + metrics=["r2"], + cv=CVConfig(strategy="kfold"), + tuning=TuningConfig(), + feature_selection=FeatureSelectionConfig(), + evaluation=StatisticalAssessmentConfig(), + ) + with pytest.raises( + ValueError, match="calibration is only available for classification" + ): + Experiment(cfg) + + # 3. Stratified regression + cfg = ExperimentConfig.model_construct( + task="regression", + models={"dummy": DummyRegressorConfig()}, + cv=CVConfig(strategy="stratified"), + metrics=["r2"], + tuning=TuningConfig(), + feature_selection=FeatureSelectionConfig(), + calibration=CalibrationConfig(), + evaluation=StatisticalAssessmentConfig(), + ) + with pytest.raises(ValueError, match="invalid for regression"): + Experiment(cfg) + + +def test_resolve_metadata_and_groups_mismatch(): + X, y = _classification_data(n_samples=10) + exp = Experiment( + ExperimentConfig( + task="classification", models={"lr": LogisticRegressionConfig()} + ) + ) + # 1. Length mismatch + with pytest.raises(ValueError, match="sample_metadata length mismatch"): + exp._resolve_metadata_and_groups(10, pd.DataFrame({"a": [1]}), None) + + # 2. Missing Subject/Session (with capitalized error message check) + with pytest.raises(ValueError, match="must include Subject and Session"): + exp._resolve_metadata_and_groups(10, pd.DataFrame({"a": range(10)}), None) + + # 3. Missing group_key (with Subject/Session provided) + with pytest.raises(ValueError, match="group_key 'missing' not found"): + exp.config.cv.group_key = "missing" + exp._resolve_metadata_and_groups( + 10, pd.DataFrame({"Subject": range(10), "Session": range(10)}), None + ) + + +def test_feature_names_alignment(): + X, y = _classification_data(n_samples=10, n_features=2) + exp = Experiment( + ExperimentConfig( + task="classification", models={"lr": LogisticRegressionConfig()} + ) + ) + with pytest.raises(ValueError, match="feature_names length mismatch"): + exp._resolve_feature_names(X, ["only_one"]) + + +def test_degenerate_fold_integrity(): + from coco_pipe.decoding.experiment import _validate_fold_integrity + + with pytest.raises(ValueError, match="Empty fold"): + _validate_fold_integrity(np.array([]), np.array([1]), ("classification",)) + with pytest.raises(ValueError, match="Degenerate Test Fold"): + _validate_fold_integrity( + np.array([0, 1]), np.array([1, 1]), ("classification",) + ) + + +def test_random_state_propagation_none(): + config = ExperimentConfig( + task="classification", + random_state=None, + models={"lr": LogisticRegressionConfig()}, + ) + assert config.cv.random_state == 42 + + +def test_instantiate_foundation_model_mock(): + # Use a dictionary to bypass Pydantic literal restrictions for the mock + mock_model = { + "kind": "foundation_embedding", + "provider": "reve", + "model_name": "dummy", + "checkpoint": None, + } + + config = ExperimentConfig.model_construct( + task="classification", + models={"reve": mock_model}, + metrics=["accuracy"], + cv=CVConfig(), + tuning=TuningConfig(), + feature_selection=FeatureSelectionConfig(), + calibration=CalibrationConfig(), + evaluation=StatisticalAssessmentConfig(), + verbose=False, + ) + exp = Experiment(config) + + pytest.importorskip("torch") + from coco_pipe.decoding.fm_hub import REVEModel + + with patch( + "coco_pipe.decoding.experiment.Experiment._instantiate_model" + ) as mock_inst: + mock_inst.return_value = MagicMock(spec=REVEModel) + # Should NOT raise spec error anymore + est = exp._prepare_estimator("reve", mock_model) + assert est is not None + + +def test_wrap_with_tuning_grid_invalid_key(): + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + tuning=TuningConfig(enabled=True, cv=CVConfig(strategy="stratified")), + grids={"lr": {"invalid_key": [1, 2]}}, + ) + exp = Experiment(config) + with pytest.raises(ValueError, match="Invalid tuning keys"): + exp._prepare_estimator("lr", config.models["lr"]) + + +def test_build_result_meta_time_axis_mismatch(): + X, y = _classification_data(n_samples=10, n_features=5) + exp = Experiment( + ExperimentConfig( + task="classification", models={"lr": LogisticRegressionConfig()} + ) + ) + X3 = np.zeros((10, 2, 5)) + with pytest.raises(ValueError, match="time_axis length mismatch"): + exp.run(X3, y, time_axis=np.arange(10)) + + +def test_importance_aggregation_shape_mismatch_recovery(): + X, y = _classification_data() + Experiment( + ExperimentConfig( + task="classification", models={"lr": LogisticRegressionConfig()} + ) + ) + valid_imps = [np.array([1, 2]), np.array([1, 2, 3])] + assert not all(imp.shape == valid_imps[0].shape for imp in valid_imps) + + +def test_observation_level_validation(): + X, y = _classification_data() + exp = Experiment( + ExperimentConfig( + task="classification", models={"lr": LogisticRegressionConfig()} + ) + ) + with pytest.raises( + ValueError, match="observation_level must be 'sample' or 'epoch'" + ): + exp.run(X, y, observation_level="invalid") + + +# ============================================================================= +# REPRODUCIBILITY & GROUPED CV SCIENTIFIC VALIDITY +# ============================================================================= + + +def test_grouped_cv_requires_at_least_two_groups(): + X, y = _classification_data(n_samples=10) + groups = np.zeros(10) # All same group + # Force strategy to be in GROUP_CV_STRATEGIES + cv_cfg = CVConfig() + cv_cfg.__dict__["strategy"] = "group_kfold" + + config = ExperimentConfig.model_construct( + task="classification", + models={"lr": LogisticRegressionConfig()}, + cv=cv_cfg, + metrics=["accuracy"], + tuning=TuningConfig(), + feature_selection=FeatureSelectionConfig(), + calibration=CalibrationConfig(), + evaluation=StatisticalAssessmentConfig(), + verbose=False, + ) + # The guard should raise BEFORE sklearn. + with pytest.raises( + ValueError, match="Grouped CV requires at least 2 unique groups" + ): + Experiment(config).run(X, y, groups=groups) + + +def test_inject_seed_recursion_depth(): + from types import SimpleNamespace + + cfg = SimpleNamespace(base=SimpleNamespace(random_state=None)) + exp = Experiment( + ExperimentConfig( + task="classification", models={"lr": LogisticRegressionConfig()} + ) + ) + exp._inject_seed(cfg, 123) + assert cfg.base.random_state == 123 + + +# ============================================================================= +# PIPELINE COMBINATIONS (Tuning, SFS, Calibration) +# ============================================================================= + + +def test_tuning_only_workflow(): + X, y = _classification_data(n_samples=20) + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + tuning=TuningConfig( + enabled=True, n_iter=2, cv=CVConfig(strategy="stratified", n_splits=2) + ), + grids={"lr": {"clf__C": [0.1, 1.0]}}, + cv=CVConfig(strategy="stratified", n_splits=2), + n_jobs=1, + verbose=False, + ) + result = Experiment(config).run(X, y) + assert result.raw["lr"]["status"] == "success" + + +def test_calibration_only_workflow(): + X, y = _classification_data(n_samples=20) + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + calibration=CalibrationConfig( + enabled=True, + method="sigmoid", + cv=CVConfig(strategy="stratified", n_splits=2), + ), + cv=CVConfig(strategy="stratified", n_splits=2), + n_jobs=1, + verbose=False, + ) + result = Experiment(config).run(X, y) + assert result.raw["lr"]["status"] == "success" + + +def test_sfs_tuning_and_calibration_combined(): + X, y = _classification_data(n_samples=30) + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + feature_selection=FeatureSelectionConfig( + enabled=True, + method="sfs", + n_features=2, + cv=CVConfig(strategy="stratified", n_splits=2), + ), + tuning=TuningConfig( + enabled=True, n_iter=2, cv=CVConfig(strategy="stratified", n_splits=2) + ), + calibration=CalibrationConfig( + enabled=True, cv=CVConfig(strategy="stratified", n_splits=2) + ), + grids={"lr": {"clf__C": [0.1, 1.0]}}, + cv=CVConfig(strategy="stratified", n_splits=2), + n_jobs=1, + verbose=False, + ) + result = Experiment(config).run(X, y) + assert result.raw["lr"]["status"] == "success" + # Ensure probabilities are still produced after SFS and Tuning + assert "y_proba_1" in result.get_predictions().columns + + +def test_tuning_with_calibration(): + X, y = _classification_data(n_samples=20) + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + tuning=TuningConfig( + enabled=True, n_iter=2, cv=CVConfig(strategy="stratified", n_splits=2) + ), + calibration=CalibrationConfig( + enabled=True, cv=CVConfig(strategy="stratified", n_splits=2) + ), + grids={"lr": {"clf__C": [0.1, 1.0]}}, + cv=CVConfig(strategy="stratified", n_splits=2), + n_jobs=1, + verbose=False, + ) + result = Experiment(config).run(X, y) + assert result.raw["lr"]["status"] == "success" + + +def test_sfs_with_calibration(): + X, y = _classification_data(n_samples=20) + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + feature_selection=FeatureSelectionConfig( + enabled=True, + method="sfs", + n_features=2, + cv=CVConfig(strategy="stratified", n_splits=2), + ), + calibration=CalibrationConfig( + enabled=True, cv=CVConfig(strategy="stratified", n_splits=2) + ), + cv=CVConfig(strategy="stratified", n_splits=2), + n_jobs=1, + verbose=False, + ) + result = Experiment(config).run(X, y) + assert result.raw["lr"]["status"] == "success" + + +def test_instantiate_temporal_model_explicit(): + pytest.importorskip("mne") + base_cfg = ClassicalModelConfig(estimator="LogisticRegression", params={"C": 1.0}) + config = TemporalDecoderConfig(wrapper="sliding", base=base_cfg) + exp = Experiment(ExperimentConfig(task="classification", models={"sl": config})) + est = exp._instantiate_model("sl", config) + assert est.__class__.__name__ == "SlidingEstimator" + + +def test_create_fs_step_sfs_path(): + exp = Experiment( + ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + feature_selection=FeatureSelectionConfig( + enabled=True, + method="sfs", + n_features=2, + cv=CVConfig(strategy="kfold", n_splits=2), + ), + ) + ) + step = exp._create_fs_step(LogisticRegression()) + assert step[0] == "fs" + assert step[1].__class__.__name__ == "GroupedSequentialFeatureSelector" + + +def test_build_result_meta_time_axis(): + exp = Experiment( + ExperimentConfig( + task="classification", models={"lr": LogisticRegressionConfig()} + ) + ) + # Mock some internal state needed by _build_result_meta + exp._sample_metadata = pd.DataFrame({"Subject": range(10)}) + exp._observation_level = "sample" + exp._inferential_unit = "subject" + + meta = exp._build_result_meta(np.zeros((10, 5)), t_axis=np.array([0, 1])) + assert meta["time_axis"] == [0, 1] + + +def test_capability_payload_with_fs_enabled(): + exp = Experiment( + ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + feature_selection=FeatureSelectionConfig(enabled=True, method="k_best"), + ) + ) + payload = exp._capability_payload() + assert "k_best" in payload["feature_selectors"] + + +def test_instantiate_foundation_model_fm_hub(): + # Use valid config object to avoid pydantic issues + fm_config = FoundationEmbeddingModelConfig( + kind="foundation_embedding", provider="reve", model_name="dummy" + ) + exp = Experiment( + ExperimentConfig.model_construct( + task="classification", + models={"fm": fm_config}, + metrics=["accuracy"], + cv=CVConfig(), + ) + ) + with patch("coco_pipe.decoding.fm_hub.build_foundation_model") as mock_build: + exp._instantiate_model("fm", fm_config) + mock_build.assert_called_once() + + +def test_experiment_calibration_run(): + X, y = make_classification(n_samples=40, random_state=42) + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig()}, + calibration=CalibrationConfig(enabled=True, cv=CVConfig(n_splits=2)), + cv=CVConfig(n_splits=2), + ) + res = Experiment(config).run(X, y) + assert "lr" in res.raw diff --git a/tests/test_decoding_fm_hub.py b/tests/test_decoding_fm_hub.py new file mode 100644 index 0000000..43ea29d --- /dev/null +++ b/tests/test_decoding_fm_hub.py @@ -0,0 +1,98 @@ +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from coco_pipe.decoding.fm_hub._factory import build_foundation_model +from coco_pipe.decoding.fm_hub.base import BaseFoundationModel +from coco_pipe.decoding.fm_hub.cbramod import CBraModModel, CBraModModule +from coco_pipe.decoding.fm_hub.reve import REVEModel + + +def test_fm_hub_base_hardening(): + class MockFM(BaseFoundationModel): + def get_module_cls(self): + return MagicMock() + + def _get_net_params(self): + return {} + + fm = MockFM("test-model", train_mode="frozen") + assert fm.model_name == "test-model" + + with pytest.raises(ValueError, match="train_mode must be one of"): + MockFM("test-model", train_mode="invalid") + + # Fit/Predict/Transform runtime errors + with pytest.raises(RuntimeError, match="Model must be fitted before transform"): + fm.transform(np.zeros((2, 10))) + with pytest.raises(RuntimeError, match="Model must be fitted before predict"): + fm.predict(np.zeros((2, 10))) + + +def test_fm_hub_factory_hardening(): + class MockConfig: + provider = "reve" + model_name = "reve-large" + sfreq = 250.0 + + # Mock REVEModel import to avoid torch/transformers issues + with patch("importlib.import_module") as mock_import: + mock_module = MagicMock() + mock_import.return_value = mock_module + mock_module.REVEModel = MagicMock() + + build_foundation_model(MockConfig()) + assert mock_module.REVEModel.called + + # Unknown provider + class BadConfig: + provider = "unknown" + + with pytest.raises(ValueError, match="Unknown foundation model provider"): + build_foundation_model(BadConfig()) + + +def test_fm_hub_reve_hardening(): + # We test REVEModel without actually loading it + fm = REVEModel(model_name="reve-test", sfreq=100.0) + assert fm.provider == "reve" + info = fm.get_embedding_info() + assert info.n_embeddings == 1024 + assert info.model_name == "reve-test" + + +def test_fm_hub_cbramod_hardening(): + # Test CBraModModel without actually loading it + fm = CBraModModel(model_name="cbramod-test", sfreq=100.0) + assert fm.provider == "cbramod" + info = fm.get_embedding_info() + assert info.n_embeddings == 200 + assert info.model_name == "cbramod-test" + + +@patch("coco_pipe.decoding.fm_hub.cbramod.hf_hub_download") +@patch("coco_pipe.decoding.fm_hub.cbramod.torch.load") +def test_fm_hub_cbramod_module_forward(mock_load, mock_download): + import torch + + # Mock huggingface load + mock_download.return_value = "fake/path/model.bin" + mock_load.return_value = {} + + # Initialize the module + module = CBraModModule(model_name="cbramod-test", patch_size=200, output_dim=2) + module.eval() + + # Fake input: batch=4, channels=16, time=400 (which is 2 patches of size 200) + # CBraMod default expects 200 out_dim, etc. + X = torch.zeros((4, 16, 400)) + + with torch.no_grad(): + # Test feature extraction + out_emb = module.forward(X, return_embeddings=True) + assert out_emb.shape == (4, 200) + + # Test classification logits + out_logits = module.forward(X, return_embeddings=False) + assert out_logits.shape == (4, 2) diff --git a/tests/test_decoding_interfaces.py b/tests/test_decoding_interfaces.py new file mode 100644 index 0000000..584a4c9 --- /dev/null +++ b/tests/test_decoding_interfaces.py @@ -0,0 +1,96 @@ +from coco_pipe.decoding.interfaces import ( + DecoderEstimator, + EmbeddingExtractor, + NeuralTrainable, + StagedTrainable, +) + + +def test_decoder_estimator_protocol(): + class ValidDecoder: + def fit(self, X, y=None, **kwargs): + return self + + def predict(self, X): + return X + + def get_params(self, deep=True): + return {} + + def set_params(self, **params): + return self + + class InvalidDecoder: + def fit(self, X, y=None): + return self + + # Missing predict + def get_params(self, deep=True): + return {} + + def set_params(self, **params): + return self + + assert isinstance(ValidDecoder(), DecoderEstimator) + assert not isinstance(InvalidDecoder(), DecoderEstimator) + + +def test_embedding_extractor_protocol(): + class ValidExtractor: + def transform(self, X): + return X + + def get_embedding_info(self): + return {} + + class InvalidExtractor: + def transform(self, X): + return X + + # Missing get_embedding_info + + assert isinstance(ValidExtractor(), EmbeddingExtractor) + assert not isinstance(InvalidExtractor(), EmbeddingExtractor) + + +def test_neural_trainable_protocol(): + class ValidNeural: + def get_training_history(self): + return [] + + def get_checkpoint_manifest(self): + return {} + + def get_model_card_info(self): + return {} + + def get_failure_diagnostics(self): + return {} + + def get_artifact_metadata(self): + return {} + + class PartialNeural: + def get_training_history(self): + return [] + + # Missing others + + assert isinstance(ValidNeural(), NeuralTrainable) + assert not isinstance(PartialNeural(), NeuralTrainable) + + +def test_staged_trainable_protocol(): + class ValidStaged: + def set_train_stage(self, stage: str): + return self + + def get_train_stage(self): + return "pretrain" + + class InvalidStaged: + def something_else(self): + pass + + assert isinstance(ValidStaged(), StagedTrainable) + assert not isinstance(InvalidStaged(), StagedTrainable) diff --git a/tests/test_decoding_metrics.py b/tests/test_decoding_metrics.py new file mode 100644 index 0000000..2600e09 --- /dev/null +++ b/tests/test_decoding_metrics.py @@ -0,0 +1,100 @@ +import numpy as np +import pytest +from sklearn.metrics import average_precision_score + +from coco_pipe.decoding._metrics import ( + METRIC_REGISTRY, + _pr_auc_score, + _sensitivity_score, + _specificity_score, + get_metric_families, + get_metric_names, + get_metric_spec, + get_scorer, +) + + +def test_all_metrics_runable(): + """Verify every registered metric can be called with appropriate data.""" + y_true_cls = np.array([0, 1, 0, 1]) + y_pred_cls = np.array([0, 1, 1, 1]) + y_proba_cls = np.array([0.1, 0.9, 0.4, 0.6]) + + y_true_reg = np.array([1.0, 2.0, 3.0]) + y_pred_reg = np.array([1.1, 1.9, 3.2]) + + for name, spec in METRIC_REGISTRY.items(): + if spec.task == "classification": + if spec.response_method == "predict": + val = spec.scorer(y_true_cls, y_pred_cls) + else: + val = spec.scorer(y_true_cls, y_proba_cls) + else: + val = spec.scorer(y_true_reg, y_pred_reg) + + assert isinstance(val, (float, np.float64, np.float32)) + + +def test_pr_auc_calculation(): + """Verify PR-AUC uses trapezoidal integration and differs from AP if imbalanced.""" + # Highly imbalanced + y_true = np.array([0, 0, 0, 0, 1]) + y_score = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) + + average_precision_score(y_true, y_score) + pr_auc = _pr_auc_score(y_true, y_score) + + # In some versions of sklearn, AP and trapezoidal PR-AUC can differ + # depending on how thresholds are handled. + assert isinstance(pr_auc, float) + assert 0 <= pr_auc <= 1 + + +def test_sensitivity_guards(): + """Verify sensitivity enforces binary data and robust zero-division.""" + # Multiclass should fail + with pytest.raises(ValueError, match="binary classification"): + _sensitivity_score(np.array([0, 1, 2]), np.array([0, 1, 0])) + + # Binary should pass + val = _sensitivity_score(np.array([0, 1]), np.array([0, 1])) + assert val == 1.0 + + # Zero division (no positive samples) + val = _sensitivity_score(np.array([0, 0]), np.array([0, 0])) + assert val == 0.0 + + +def test_metric_registry_accessors(): + """Verify registry lookups.""" + spec = get_metric_spec("accuracy") + assert spec.name == "accuracy" + + scorer = get_scorer("f1") + assert callable(scorer) + + with pytest.raises(ValueError, match="Unknown metric"): + get_metric_spec("non_existent") + + +def test_get_metric_names_filters(): + """Verify task and family filters in name retrieval.""" + names = get_metric_names(task="classification", family="confusion") + assert "precision" in names + assert "recall" in names + assert "neg_mean_squared_error" not in names + + +def test_get_metric_families_filter(): + """Verify task filter in family retrieval.""" + families = get_metric_families(task="regression") + assert "regression" in families + assert "threshold_sweep" not in families + + +def test_specificity_standalone(): + """Verify specificity calculation (TN / (TN + FP)).""" + y_true = np.array([0, 0, 1, 1]) + y_pred = np.array([0, 1, 1, 1]) + # Specificity = TN / (TN + FP) = 1 / (1 + 1) = 0.5 + assert _specificity_score(y_true, y_pred) == 0.5 diff --git a/tests/test_decoding_registry.py b/tests/test_decoding_registry.py new file mode 100644 index 0000000..f219da3 --- /dev/null +++ b/tests/test_decoding_registry.py @@ -0,0 +1,256 @@ +import re + +import pytest + +from coco_pipe.decoding._specs import canonical_estimator_name +from coco_pipe.decoding.registry import ( + EstimatorSpec, + get_capabilities, + get_estimator_cls, + get_estimator_spec, + get_selector_capabilities, + list_capabilities, + list_estimator_specs, + register_estimator, + register_estimator_spec, + resolve_estimator_capabilities, + resolve_estimator_spec, +) + + +def test_manual_registration(): + @register_estimator("TestModel") + class TestModel: + pass + + cls = get_estimator_cls("TestModel") + assert cls is TestModel + + +def test_registration_overwrite_warning(): + @register_estimator("WarningModel") + class Model1: + pass + + with pytest.warns(UserWarning, match="Overwriting existing estimator registry"): + + @register_estimator("WarningModel") + class Model2: + pass + + +def test_get_estimator_cls_not_found(): + # Direct string check on the exception message + with pytest.raises( + pytest.importorskip("coco_pipe.decoding.registry").EstimatorNotFoundError + ) as excinfo: + get_estimator_cls("LogisticRegresion") # Typo + err_msg = str(excinfo.value) + assert "Did you mean:" in err_msg + assert "LogisticRegression" in err_msg + + +def test_lazy_load_from_spec(): + # LogisticRegression should be loadable from spec even if not in registry dict yet + cls = get_estimator_cls("LogisticRegression") + assert cls.__name__ == "LogisticRegression" + + +def test_capabilities_methods(): + caps = get_capabilities("LogisticRegression") + assert caps.method == "LogisticRegression" + assert "classification" in caps.tasks + assert not caps.supports_task("regression") + assert caps.has_response("predict") + assert caps.to_dict()["method"] == "LogisticRegression" + + # Test canonical lookup + spec = get_estimator_spec("LogisticRegression") + assert spec.to_dict()["name"] == "LogisticRegression" + + # Test list_capabilities + all_caps = list_capabilities() + assert "LogisticRegression" in all_caps + + +def test_selector_capabilities(): + caps = get_selector_capabilities("k_best") + assert caps.method == "k_best" + assert caps.to_dict()["method"] == "k_best" + + with pytest.raises( + ValueError, match="No decoding capabilities registered for selector" + ): + get_selector_capabilities("invalid_selector") + + +def test_spec_lookup(): + spec = get_estimator_spec("LogisticRegression") + assert spec.name == "LogisticRegression" + assert spec.import_path == "sklearn.linear_model" + + all_specs = list_estimator_specs() + assert "LogisticRegression" in all_specs + + # Test missing spec + with pytest.raises(ValueError, match="No decoding estimator spec registered"): + get_estimator_spec("InvalidModel") + + +def test_register_new_spec(): + new_spec = EstimatorSpec( + name="NewModel", + import_path="sklearn.dummy:DummyClassifier", + family="dummy", + task=("classification",), + ) + register_estimator_spec(new_spec) + cls = get_estimator_cls("NewModel") + assert cls.__name__ == "DummyClassifier" + + +def test_invalid_import_path(): + new_spec = EstimatorSpec( + name="InvalidPath", + import_path="nonexistent.module", + family="linear", + task=("regression",), + ) + register_estimator_spec(new_spec) + with pytest.raises(ImportError): + get_estimator_cls("InvalidPath") + + +def test_missing_class_in_module(): + new_spec = EstimatorSpec( + name="MissingClass", + import_path="sklearn.linear_model:NonExistentClass", + family="linear", + task=("regression",), + ) + register_estimator_spec(new_spec) + with pytest.raises(Exception): # EstimatorNotFoundError + get_estimator_cls("MissingClass") + + +def test_resolve_estimator_spec(): + from types import SimpleNamespace + + # Classical + cfg = SimpleNamespace( + kind="classical", estimator="logistic_regression", method="LogisticRegression" + ) + spec = resolve_estimator_spec(cfg) + assert spec.name == "LogisticRegression" + + # Foundation (reve) + cfg_f = SimpleNamespace(kind="reve", method="REVEModel") + spec_f = resolve_estimator_spec(cfg_f) + assert spec_f.name == "REVEModel" + + # Temporal + cfg_t = SimpleNamespace( + kind="temporal", + wrapper="sliding", + base=cfg, + method="SlidingEstimator", + base_estimator=cfg, + ) + spec_t = resolve_estimator_spec(cfg_t) + assert spec_t.name == "SlidingEstimator" + + # Canonical + assert canonical_estimator_name("lda") == "LinearDiscriminantAnalysis" + assert canonical_estimator_name("unknown") == "unknown" + + +def test_resolve_estimator_spec_variants(): + from types import SimpleNamespace + + # SVC with probability=False + cfg_svc = SimpleNamespace( + method="SVC", kind="classical", estimator="SVC", probability=False + ) + spec_svc = resolve_estimator_spec(cfg_svc) + assert not spec_svc.supports_proba + + # SGD with log_loss + cfg_sgd = SimpleNamespace( + method="SGDClassifier", + kind="classical", + estimator="SGDClassifier", + loss="log_loss", + ) + spec_sgd = resolve_estimator_spec(cfg_sgd) + assert spec_sgd.supports_proba + + # resolve_capabilities + caps = resolve_estimator_capabilities(cfg_svc) + assert caps.method == "SVC" + + +def test_resolve_estimator_spec_temporal_foundation_fixups(): + # 1. Temporal with dict config + temporal_cfg = { + "kind": "temporal", + "wrapper": "sliding", + "base": {"kind": "classical", "method": "LogisticRegression"}, + } + spec = resolve_estimator_spec(temporal_cfg) + assert spec.name == "SlidingEstimator" + assert spec.supports_proba is True + + # 2. Foundation with dict config + foundation_cfg = {"kind": "foundation_embedding", "provider": "reve"} + spec = resolve_estimator_spec(foundation_cfg) + assert spec.name == "REVEModel" + + +def test_resolve_estimator_spec_runtime_fixups(): + # SVC probability=False (using dict) + # Note: registry.py must be updated to use _get_val for this to pass with dicts + from types import SimpleNamespace + + svc_cfg = SimpleNamespace(kind="classical", method="SVC", probability=False) + spec = resolve_estimator_spec(svc_cfg) + assert spec.supports_proba is False + assert spec.supports_decision_function is True + + # SGDClassifier loss="log_loss" + sgd_cfg = SimpleNamespace(kind="classical", method="SGDClassifier", loss="log_loss") + spec = resolve_estimator_spec(sgd_cfg) + assert spec.supports_proba is True + + +def test_get_selector_capabilities_error(): + with pytest.raises( + ValueError, match="No decoding capabilities registered for selector" + ): + get_selector_capabilities("non_existent_selector") + + +def test_register_spec_overwrite_warning(): + spec = EstimatorSpec( + name="OverwriteModel", import_path="fake", family="linear", task=("regression",) + ) + register_estimator_spec(spec) + pass + + +def test_get_estimator_cls_import_error(): + spec = EstimatorSpec( + name="ImportErrorModel", + import_path="non.existent.module:Class", + family="linear", + task=("regression",), + ) + register_estimator_spec(spec) + with pytest.raises(ImportError, match="Could not load estimator"): + get_estimator_cls("ImportErrorModel") + + +def test_get_estimator_cls_not_found_with_matches(): + # Use re.escape to avoid invalid escape sequence warning + expected = re.escape("Did you mean: ['LogisticRegression']") + with pytest.raises(Exception, match=expected): + get_estimator_cls("LogisticRegres") diff --git a/tests/test_decoding_results.py b/tests/test_decoding_results.py new file mode 100644 index 0000000..dbeb5a2 --- /dev/null +++ b/tests/test_decoding_results.py @@ -0,0 +1,909 @@ +import numpy as np +import pandas as pd +import pytest +from sklearn.datasets import make_classification + +from coco_pipe.decoding import Experiment, ExperimentResult +from coco_pipe.decoding._cache import make_feature_cache_key +from coco_pipe.decoding._constants import RESULT_SCHEMA_VERSION +from coco_pipe.decoding.configs import ( + CVConfig, + ExperimentConfig, + LogisticRegressionConfig, +) + + +def _classification_data(): + X = np.array( + [ + [-2.0, -1.0], + [-1.5, -0.8], + [-1.0, -0.6], + [-0.8, -0.4], + [0.8, 0.4], + [1.0, 0.6], + [1.5, 0.8], + [2.0, 1.0], + ] + ) + y = np.array([0, 0, 0, 0, 1, 1, 1, 1]) + return X, y + + +def _config(output_dir=None): + return ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig(max_iter=200)}, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified", n_splits=2, shuffle=True, random_state=0), + output_dir=output_dir, + tag="result_schema_test", + n_jobs=1, + verbose=False, + ) + + +def test_run_result_payload_stores_config_provenance_sample_ids_and_groups(): + X, y = _classification_data() + sample_ids = np.array([f"sample_{idx}" for idx in range(len(y))]) + groups = np.array(["g0", "g0", "g1", "g1", "g2", "g2", "g3", "g3"]) + sample_metadata = { + "subject": ["s0", "s0", "s1", "s1", "s2", "s2", "s3", "s3"], + "session": ["visit1"] * len(y), + "site": ["site1"] * len(y), + } + + result = Experiment(_config()).run( + X, + y, + groups=groups, + sample_ids=sample_ids, + sample_metadata=sample_metadata, + observation_level="epoch", + feature_names=["left", "right"], + ) + + payload = result.to_payload() + assert payload["schema_version"] == RESULT_SCHEMA_VERSION + assert payload["config"]["tag"] == "result_schema_test" + assert payload["meta"]["task"] == "classification" + assert payload["meta"]["n_samples"] == len(y) + assert payload["meta"]["n_features"] == X.shape[1] + assert payload["meta"]["observation_level"] == "epoch" + assert payload["meta"]["inferential_unit"] == "subject" + assert payload["meta"]["sample_metadata_columns"] == [ + "Subject", + "Session", + "Site", + ] + assert "versions" in payload["meta"] + + predictions = result.get_predictions() + assert { + "SampleIndex", + "SampleID", + "Group", + "Subject", + "Session", + "Site", + }.issubset(predictions.columns) + assert set(predictions["SampleID"]) == set(sample_ids) + assert set(predictions["Group"]) == set(groups) + assert set(predictions["Subject"]) == {"s0", "s1", "s2", "s3"} + assert set(predictions["Session"]) == {"visit1"} + + splits = result.get_splits() + assert set(splits["Set"]) == {"train", "test"} + assert set(splits["SampleID"]) == set(sample_ids) + assert set(splits["Group"]) == set(groups) + assert {"Subject", "Session", "Site"}.issubset(splits.columns) + + ci = result.get_bootstrap_confidence_intervals( + n_bootstraps=10, + random_state=0, + ) + assert set(ci["Unit"]) == {"subject"} + + +def test_duplicate_sample_ids_are_rejected(): + X, y = _classification_data() + + try: + Experiment(_config()).run(X, y, sample_ids=["s0"] * len(y)) + except ValueError as exc: + assert "sample_ids must be unique" in str(exc) + else: + raise AssertionError("Expected duplicate sample_ids to fail.") + + +def test_observation_level_is_explicitly_limited(): + X, y = _classification_data() + + try: + Experiment(_config()).run(X, y, observation_level="trial") + except ValueError as exc: + assert "observation_level must be 'sample' or 'epoch'" in str(exc) + else: + raise AssertionError("Expected invalid observation_level to fail.") + + +def test_sample_metadata_requires_subject_and_session(): + X, y = _classification_data() + + try: + Experiment(_config()).run( + X, + y, + sample_metadata={"subject": [f"s{idx}" for idx in range(len(y))]}, + ) + except ValueError as exc: + assert "sample_metadata must include Subject and Session" in str(exc) + else: + raise AssertionError("Expected incomplete sample_metadata to fail.") + + +def test_explicit_inferential_unit_overrides_epoch_default(): + X, y = _classification_data() + metadata = { + "subject": [f"s{idx // 2}" for idx in range(len(y))], + "session": ["visit1"] * len(y), + } + + result = Experiment(_config()).run( + X, + y, + sample_metadata=metadata, + observation_level="epoch", + inferential_unit="epoch", + ) + + assert result.meta["inferential_unit"] == "epoch" + ci = result.get_bootstrap_confidence_intervals( + n_bootstraps=10, + random_state=0, + ) + assert set(ci["Unit"]) == {"epoch"} + + +def test_save_load_roundtrip_preserves_decoding_payload(tmp_path): + X, y = _classification_data() + exp = Experiment(_config(output_dir=tmp_path)) + result = exp.run(X, y, sample_ids=[f"s{idx}" for idx in range(len(y))]) + + path = result.save() + loaded = ExperimentResult.load(path) + + assert loaded.schema_version == result.schema_version + assert loaded.config["tag"] == result.config["tag"] + assert loaded.meta["n_samples"] == result.meta["n_samples"] + assert loaded.raw.keys() == result.raw.keys() + assert loaded.to_payload()["schema_version"] == RESULT_SCHEMA_VERSION + + +def test_get_predictions_expands_temporal_arrays(): + raw = { + "sliding": { + "metrics": {}, + "predictions": [ + { + "sample_index": np.array([0, 1]), + "sample_id": np.array(["s0", "s1"]), + "group": np.array(["g0", "g1"]), + "y_true": np.array([0, 1]), + "y_pred": np.array([[0, 1], [1, 1]]), + "y_proba": np.array( + [ + [[0.8, 0.4], [0.2, 0.6]], + [[0.3, 0.2], [0.7, 0.8]], + ] + ), + } + ], + }, + "generalizing": { + "metrics": {}, + "predictions": [ + { + "sample_index": np.array([0, 1]), + "sample_id": np.array(["s0", "s1"]), + "group": None, + "y_true": np.array([0, 1]), + "y_pred": np.array( + [ + [[0, 1], [1, 0]], + [[1, 1], [0, 1]], + ] + ), + "y_proba": np.ones((2, 2, 2, 2)) * 0.5, + } + ], + }, + } + + predictions = ExperimentResult(raw).get_predictions() + + sliding = predictions[predictions["Model"] == "sliding"] + assert len(sliding) == 4 + assert set(sliding["Time"]) == {0, 1} + assert {"y_proba_0", "y_proba_1"}.issubset(sliding.columns) + + generalizing = predictions[predictions["Model"] == "generalizing"] + assert len(generalizing) == 8 + assert set(generalizing["TrainTime"]) == {0, 1} + assert set(generalizing["TestTime"]) == {0, 1} + assert generalizing["Group"].isna().all() + + +def test_get_detailed_scores_expands_temporal_scores(): + result = ExperimentResult( + { + "model": { + "metrics": { + "accuracy": { + "mean": 0.5, + "std": 0.1, + "folds": [ + 0.75, + np.array([0.1, 0.2]), + np.array([[0.1, 0.2], [0.3, 0.4]]), + ], + } + } + } + } + ) + + scores = result.get_detailed_scores() + + assert len(scores[scores["Fold"] == 0]) == 1 + assert set(scores[scores["Fold"] == 1]["Time"]) == {0, 1} + matrix_scores = scores[scores["Fold"] == 2] + assert len(matrix_scores) == 4 + assert set(matrix_scores["TrainTime"]) == {0, 1} + assert set(matrix_scores["TestTime"]) == {0, 1} + + +def test_get_feature_importances_returns_named_aggregate_and_fold_tables(): + result = ExperimentResult( + { + "rf": { + "metrics": {}, + "importances": { + "mean": np.array([0.25, 0.75]), + "std": np.array([0.05, 0.10]), + "raw": np.array([[0.2, 0.8], [0.3, 0.7]]), + "feature_names": ["alpha", "beta"], + }, + "metadata": [{}], + } + } + ) + + aggregate = result.get_feature_importances() + assert aggregate.columns.tolist() == [ + "Model", + "Feature", + "FeatureName", + "Mean", + "Std", + "Rank", + ] + assert aggregate["FeatureName"].tolist() == ["alpha", "beta"] + assert aggregate["Mean"].tolist() == [0.25, 0.75] + + fold_level = result.get_feature_importances(fold_level=True) + assert fold_level.columns.tolist() == [ + "Model", + "Fold", + "Feature", + "FeatureName", + "Importance", + "Rank", + ] + assert len(fold_level) == 4 + assert set(fold_level["Fold"]) == {0, 1} + + +def test_feature_cache_key_tracks_split_preprocessing_and_backbone_identity(): + base = make_feature_cache_key( + train_sample_ids=["s0", "s1"], + test_sample_ids=["s2"], + preprocessing_fingerprint="prep-a", + backbone_fingerprint="backbone-a", + ) + + assert base == make_feature_cache_key( + train_sample_ids=["s0", "s1"], + test_sample_ids=["s2"], + preprocessing_fingerprint="prep-a", + backbone_fingerprint="backbone-a", + ) + assert base != make_feature_cache_key( + train_sample_ids=["s0"], + test_sample_ids=["s1", "s2"], + preprocessing_fingerprint="prep-a", + backbone_fingerprint="backbone-a", + ) + assert base != make_feature_cache_key( + train_sample_ids=["s0", "s1"], + test_sample_ids=["s2"], + preprocessing_fingerprint="prep-b", + backbone_fingerprint="backbone-a", + ) + assert base != make_feature_cache_key( + train_sample_ids=["s0", "s1"], + test_sample_ids=["s2"], + preprocessing_fingerprint="prep-a", + backbone_fingerprint="backbone-b", + ) + + +def _diagnostic_result(): + X, y = make_classification( + n_samples=40, + n_features=5, + n_informative=3, + n_redundant=0, + random_state=7, + ) + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig(max_iter=200, kind="classical")}, + metrics=["accuracy", "roc_auc", "brier_score"], + cv=CVConfig(strategy="stratified", n_splits=2, shuffle=True, random_state=7), + n_jobs=1, + verbose=False, + ) + return Experiment(config).run(X, y) + + +def test_fit_diagnostics_are_recorded_per_fold(): + result = _diagnostic_result() + diagnostics = result.get_fit_diagnostics() + assert len(diagnostics) >= 2 + assert len(pd.unique(diagnostics["Fold"])) == 2 + assert {"FitTime", "PredictTime", "ScoreTime", "TotalTime"}.issubset( + diagnostics.columns + ) + assert (diagnostics["FitTime"] >= 0).all() + + +def test_fit_diagnostics_expands_warning_records(): + result = ExperimentResult( + { + "model": { + "metrics": {}, + "predictions": [], + "diagnostics": [ + { + "fit_time": 0.1, + "predict_time": 0.2, + "score_time": 0.3, + "total_time": 0.6, + "warnings": [ + { + "stage": "fit", + "category": "ConvergenceWarning", + "message": "did not converge", + } + ], + } + ], + } + } + ) + diagnostics = result.get_fit_diagnostics() + assert diagnostics.loc[0, "Stage"] == "fit" + assert diagnostics.loc[0, "WarningCategory"] == "ConvergenceWarning" + assert "did not converge" in diagnostics.loc[0, "WarningMessage"] + + +def test_confusion_roc_pr_and_calibration_accessors(): + result = _diagnostic_result() + confusion = result.get_confusion_matrices() + counts = result.get_confusion_counts() + pooled = result.get_pooled_confusion_matrix() + roc = result.get_roc_curve() + pr = result.get_pr_curve() + calibration = result.get_calibration_curve(n_bins=3) + proba = result.get_probability_diagnostics() + + assert {"TrueLabel", "PredictedLabel", "Value"}.issubset(confusion.columns) + assert {"TrueLabel", "PredictedLabel", "Value"}.issubset(counts.columns) + assert {"TrueLabel", "PredictedLabel", "Value"}.issubset(pooled.columns) + assert {"FPR", "TPR", "Threshold"}.issubset(roc.columns) + assert {"Precision", "Recall", "Threshold"}.issubset(pr.columns) + assert { + "MeanPredictedProbability", + "FractionPositive", + }.issubset(calibration.columns) + assert {"Metric", "Class", "Value"}.issubset(proba.columns) + assert not confusion.empty + assert not counts.empty + assert not pooled.empty + assert not roc.empty + assert not pr.empty + assert not calibration.empty + assert {"log_loss", "brier_score_macro"}.issubset(set(proba["Metric"])) + + +def test_multiclass_curves_use_one_vs_rest_rows(): + raw = { + "m": { + "metrics": {}, + "predictions": [ + { + "sample_index": np.arange(6), + "sample_id": np.arange(6), + "group": None, + "y_true": np.array([0, 1, 2, 0, 1, 2]), + "y_pred": np.array([0, 1, 2, 1, 1, 0]), + "y_proba": np.array( + [ + [0.8, 0.1, 0.1], + [0.1, 0.7, 0.2], + [0.1, 0.2, 0.7], + [0.3, 0.5, 0.2], + [0.2, 0.6, 0.2], + [0.4, 0.2, 0.4], + ] + ), + } + ], + } + } + result = ExperimentResult(raw) + assert set(result.get_roc_curve()["Class"]) == {0, 1, 2} + assert set(result.get_pr_curve()["Class"]) == {0, 1, 2} + assert set(result.get_calibration_curve(n_bins=2)["Class"]) == {0, 1, 2} + assert not result.get_probability_diagnostics().empty + + +def test_permutation_bootstrap_and_paired_comparison_helpers(): + X, y = make_classification( + n_samples=36, + n_features=5, + n_informative=3, + n_redundant=0, + random_state=9, + ) + groups = np.repeat(np.arange(12), 3) + config = ExperimentConfig( + task="classification", + models={ + "lr": LogisticRegressionConfig(max_iter=200, kind="classical"), + "dummy": { + "kind": "classical", + "method": "DummyClassifier", + "strategy": "prior", + }, + }, + metrics=["accuracy"], + cv=CVConfig(strategy="stratified", n_splits=2, shuffle=True, random_state=9), + n_jobs=1, + verbose=False, + ) + result = Experiment(config).run(X, y, groups=groups) + + null = result.get_statistical_assessment( + lightweight=True, n_permutations=20, random_state=1 + ) + ci = result.get_bootstrap_confidence_intervals( + n_bootstraps=20, + unit="group", + random_state=1, + ) + paired = result.compare_models_paired( + "lr", + "dummy", + n_permutations=20, + unit="group", + random_state=1, + ) + + assert {"Observed", "NullLower", "NullUpper", "PValue"}.issubset(null.columns) + assert {"Estimate", "CILower", "CIUpper", "NUnits"}.issubset(ci.columns) + assert paired.loc[0, "ModelA"] == "lr" + assert paired.loc[0, "ModelB"] == "dummy" + assert 0 <= paired.loc[0, "PValue"] <= 1 + + +def test_make_serializable_all_types(): + from coco_pipe.decoding.result import make_serializable + + data = { + "arr": np.array([1, 2, 3]), + "int": np.int64(42), + "float": np.float32(3.14), + "bool": np.bool_(True), + "nested": { + "list": [np.int32(1), np.float64(2.0)], + "tuple": (np.int16(3),), + }, + } + + serialized = make_serializable(data) + + assert isinstance(serialized["arr"], list) + assert isinstance(serialized["int"], int) + assert isinstance(serialized["float"], float) + assert isinstance(serialized["bool"], bool) + assert isinstance(serialized["nested"]["list"][0], int) + assert isinstance(serialized["nested"]["list"][1], float) + assert isinstance(serialized["nested"]["tuple"][0], int) + # Check JSON compatibility + import json + + json.dumps(serialized) + + +def test_save_load_json_roundtrip(tmp_path): + X, y = _classification_data() + result = Experiment(_config(output_dir=tmp_path)).run(X, y) + + json_path = tmp_path / "result.json" + saved_path = result.save(json_path) + + assert saved_path == json_path + assert saved_path.exists() + + loaded = ExperimentResult.load(json_path) + assert loaded.config["tag"] == result.config["tag"] + assert loaded.raw.keys() == result.raw.keys() + + # Verify it was indeed JSON + with open(json_path, "r") as f: + import json + + data = json.load(f) + assert data["schema_version"] == RESULT_SCHEMA_VERSION + + +def test_get_predictions_with_model_filter(): + raw = { + "m1": { + "metrics": {}, + "predictions": [ + {"sample_index": [0], "sample_id": ["s0"], "y_true": [0], "y_pred": [0]} + ], + }, + "m2": { + "metrics": {}, + "predictions": [ + {"sample_index": [1], "sample_id": ["s1"], "y_true": [1], "y_pred": [1]} + ], + }, + } + result = ExperimentResult(raw) + + preds_m1 = result.get_predictions(model="m1") + assert (preds_m1["Model"] == "m1").all() + assert len(preds_m1) == 1 + + preds_all = result.get_predictions() + assert set(preds_all["Model"]) == {"m1", "m2"} + + +def test_get_temporal_score_summary_1d_2d(): + raw = { + "m1": { + "metrics": { + "acc": { + "folds": [np.array([0.5, 0.6]), np.array([0.7, 0.8])], + } + }, + "statistical_assessment": [ + {"Metric": "acc", "Time": 0.0, "PValue": 0.01, "Significant": True}, + {"Metric": "acc", "Time": 1.0, "PValue": 0.05, "Significant": False}, + ], + }, + "m2": { + "metrics": { + "acc": { + "folds": [np.array([[0.5, 0.6], [0.7, 0.8]])], + } + } + }, + } + result = ExperimentResult(raw, time_axis=[0.0, 1.0]) + + summary = result.get_temporal_score_summary() + + # 1D check + m1_sum = summary[summary["Model"] == "m1"] + assert len(m1_sum) == 2 + assert m1_sum.iloc[0]["Mean"] == 0.6 # (0.5 + 0.7) / 2 + assert m1_sum.iloc[0]["PValue"] == 0.01 + assert m1_sum.iloc[0]["Significant"] + + # 2D check + m2_sum = summary[summary["Model"] == "m2"] + assert len(m2_sum) == 4 + assert set(m2_sum["TrainTime"]) == {0.0, 1.0} + assert set(m2_sum["TestTime"]) == {0.0, 1.0} + + +def test_get_splits_with_metadata_flattening(): + raw = { + "m": { + "splits": [ + { + "train_idx": [0], + "train_sample_id": ["s0"], + "train_metadata": {"sub": ["sub1"], "site": ["site1"]}, + "test_idx": [1], + "test_sample_id": ["s1"], + "test_metadata": {"sub": ["sub2"], "site": ["site1"]}, + } + ] + } + } + result = ExperimentResult(raw) + splits = result.get_splits() + + assert "sub" in splits.columns + assert "site" in splits.columns + assert splits.loc[splits["SampleID"] == "s0", "sub"].iloc[0] == "sub1" + assert splits.loc[splits["SampleID"] == "s1", "sub"].iloc[0] == "sub2" + + +def test_summary_with_statistical_rows(): + raw = { + "m": { + "metrics": {"acc": {"mean": 0.8, "std": 0.05}}, + "statistical_assessment": [ + { + "Metric": "acc", + "Time": None, + "TrainTime": None, + "PValue": 0.001, + "Significant": True, + } + ], + } + } + result = ExperimentResult(raw) + summ = result.summary() + + assert summ.loc["m", "acc_mean"] == 0.8 + assert summ.loc["m", "acc_p_val"] == 0.001 + assert summ.loc["m", "acc_sig"] == "*" + + +def test_get_roc_pr_auc_summaries_multiclass(): + # Setup multiclass probas + y_true = np.array([0, 0, 1, 1, 2, 2]) + # Model 1: perfect + y_proba = np.array( + [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]] + ) + + raw = { + "m": { + "predictions": [ + { + "sample_index": np.arange(6), + "sample_id": np.arange(6), + "y_true": y_true, + "y_pred": y_true, + "y_proba": y_proba, + } + ] + } + } + result = ExperimentResult(raw) + + roc_auc = result.get_roc_auc_summary() + pr_auc = result.get_pr_auc_summary() + + assert roc_auc.loc[0, "MacroROCAUC"] == 1.0 + assert pr_auc.loc[0, "MacroPRAUC"] == 1.0 + + +def test_get_generalization_matrix_formats(): + raw = { + "m": {"metrics": {"accuracy": {"folds": [np.array([[0.5, 0.6], [0.7, 0.8]])]}}} + } + result = ExperimentResult(raw, time_axis=["t1", "t2"]) + + # Wide format (model specified) + wide = result.get_generalization_matrix(model="m") + assert wide.shape == (2, 2) + assert wide.index.name == "TrainTime" + assert list(wide.index) == ["t1", "t2"] + + # Long format (no model specified) + long = result.get_generalization_matrix() + assert len(long) == 4 + assert set(long["TrainTime"]) == {"t1", "t2"} + assert "Value" in long.columns + + +def test_get_best_params_and_search_results(): + raw = { + "m": { + "metadata": [ + { + "best_params": {"C": 1.0, "penalty": "l2"}, + "search_results": [ + { + "candidate": 0, + "rank_test_score": 1, + "mean_test_score": 0.8, + "params": {"C": 1.0}, + }, + { + "candidate": 1, + "rank_test_score": 2, + "mean_test_score": 0.7, + "params": {"C": 0.1}, + }, + ], + } + ] + } + } + result = ExperimentResult(raw) + + best = result.get_best_params() + assert len(best) == 2 + assert set(best["Param"]) == {"C", "penalty"} + + search = result.get_search_results() + assert len(search) == 2 + assert search.loc[0, "Rank"] == 1 + + +def test_get_selected_features_with_order_and_stability(): + raw = { + "m": { + "metadata": [ + { + "selected_features": [True, False, True], + "selection_order": [2, 0], # feature 2 first, then feature 0 + "feature_names": ["f1", "f2", "f3"], + }, + { + "selected_features": [True, False, False], + "feature_names": ["f1", "f2", "f3"], + }, + ] + } + } + result = ExperimentResult(raw) + + sel = result.get_selected_features() + assert len(sel) == 6 + # Check order for first fold + fold0 = sel[sel["Fold"] == 0] + assert fold0.loc[fold0["FeatureName"] == "f3", "Order"].iloc[0] == 1 + assert fold0.loc[fold0["FeatureName"] == "f1", "Order"].iloc[0] == 2 + + stability = result.get_feature_stability() + assert ( + stability.loc[stability["FeatureName"] == "f1", "SelectionFrequency"].iloc[0] + == 1.0 + ) + assert ( + stability.loc[stability["FeatureName"] == "f3", "SelectionFrequency"].iloc[0] + == 0.5 + ) + + +def test_result_error_paths(tmp_path): + # load non-existent + with pytest.raises(FileNotFoundError): + ExperimentResult.load(tmp_path / "non_existent.pkl") + + # summary empty + res = ExperimentResult({}) + assert res.summary().empty + + # get_detailed_scores with error in raw + res_err = ExperimentResult({"m1": {"error": "failed"}}) + assert res_err.get_detailed_scores().empty + assert res_err.get_temporal_score_summary().empty + assert res_err.get_predictions().empty + assert res_err.get_splits().empty + assert res_err.get_fit_diagnostics().empty + + # empty curves + assert res.get_roc_curve().empty + assert res.get_pr_curve().empty + assert res.get_roc_auc_summary().empty + assert res.get_pr_auc_summary().empty + + +def test_result_save_default_path(): + res = ExperimentResult({}, config={"output_dir": "."}) + path = res.save() + assert path.exists() + path.unlink() + + +def test_compare_models_all_pairs(): + raw = { + "m1": { + "metrics": {"accuracy": {"mean": 0.8, "std": 0.1, "folds": [0.8, 0.8]}}, + "predictions": [ + {"sample_index": [0], "sample_id": ["s0"], "y_true": [0], "y_pred": [0]} + ], + }, + "m2": { + "metrics": {"accuracy": {"mean": 0.7, "std": 0.1, "folds": [0.7, 0.7]}}, + "predictions": [ + {"sample_index": [0], "sample_id": ["s0"], "y_true": [0], "y_pred": [1]} + ], + }, + "m3": { + "metrics": {"accuracy": {"mean": 0.6, "std": 0.1, "folds": [0.6, 0.6]}}, + "predictions": [ + {"sample_index": [0], "sample_id": ["s0"], "y_true": [0], "y_pred": [1]} + ], + }, + } + res = ExperimentResult(raw) + comp = res.compare_models(metric="accuracy", n_permutations=5) + assert not comp.empty + + +def test_get_feature_scores_with_pvalues(): + raw = { + "m": { + "metadata": [ + { + "feature_scores": [10.0, 5.0], + "feature_pvalues": [0.001, 0.05], + "feature_names": ["a", "b"], + "feature_selection_method": "f_classif", + } + ] + } + } + result = ExperimentResult(raw) + scores = result.get_feature_scores() + assert scores.loc[0, "Score"] == 10.0 + assert scores.loc[0, "PValue"] == 0.001 + + +def test_get_statistical_nulls_and_model_artifacts(): + raw = { + "m": { + "statistical_nulls": {"acc": [0.4, 0.5, 0.6]}, + "metadata": [{"artifacts": {"coef": [1, 2], "intercept": 0.5}}], + } + } + result = ExperimentResult(raw) + + nulls = result.get_statistical_nulls() + assert "m" in nulls + assert "acc" in nulls["m"] + + artifacts = result.get_model_artifacts() + assert len(artifacts) == 2 + assert set(artifacts["Key"]) == {"coef", "intercept"} + + +def test_generalization_matrix_formatting(): + # Mock result with 2D temporal scores + raw = { + "tg_model": { + "status": "success", + "metrics": { + "accuracy": {"folds": [np.ones((2, 2)) * 0.8, np.ones((2, 2)) * 0.9]} + }, + } + } + result = ExperimentResult(raw, config={}, meta={"time_axis": [0.1, 0.2]}) + + # 1. Long format + long_df = result.get_generalization_matrix() + assert len(long_df) == 4 + assert set(long_df.columns) == {"Model", "Metric", "TrainTime", "TestTime", "Value"} + + # 2. Wide format (matrix) for specific model + wide_df = result.get_generalization_matrix(model="tg_model") + assert wide_df.shape == (2, 2) + assert wide_df.index.tolist() == [0.1, 0.2] + assert wide_df.columns.tolist() == [0.1, 0.2] + assert np.allclose(wide_df.values, 0.85) diff --git a/tests/test_decoding_specs.py b/tests/test_decoding_specs.py new file mode 100644 index 0000000..9e9ffbd --- /dev/null +++ b/tests/test_decoding_specs.py @@ -0,0 +1,72 @@ +from coco_pipe.decoding._specs import ( + ESTIMATOR_SPECS, + SELECTOR_CAPABILITIES, + EstimatorCapabilities, + EstimatorSpec, + canonical_estimator_name, +) + + +def test_estimator_capabilities_methods(): + caps = EstimatorCapabilities( + method="test", + tasks=("classification",), + prediction_interfaces=("predict", "predict_proba"), + ) + # Check supports_task + assert caps.supports_task("classification") is True + assert caps.supports_task("regression") is False + + # Check has_response + assert caps.has_response("predict") is True + assert caps.has_response("predict_proba") is True + assert caps.has_response("decision_function") is False + + # Check to_dict + d = caps.to_dict() + assert d["method"] == "test" + assert d["tasks"] == ("classification",) + + +def test_estimator_spec_to_capabilities_variants(): + # 1. Temporal Sliding + spec_sliding = ESTIMATOR_SPECS["SlidingEstimator"] + caps_sliding = spec_sliding.to_capabilities() + assert caps_sliding.temporal == "sliding" + assert caps_sliding.input_ranks == ("3d_temporal",) + + # 2. Temporal Generalizing + spec_gen = ESTIMATOR_SPECS["GeneralizingEstimator"] + caps_gen = spec_gen.to_capabilities() + assert caps_gen.temporal == "generalizing" + assert caps_gen.input_ranks == ("3d_temporal",) + + # 3. Tokens Rank (Simulated if not in registry) + spec_tokens = EstimatorSpec( + name="TokenModel", + import_path="test:TokenModel", + family="neural", + task=("classification",), + input_kinds=("tokens",), + ) + caps_tokens = spec_tokens.to_capabilities() + assert caps_tokens.input_ranks == ("tokens",) + + +def test_canonical_name_mapping(): + # Test common aliases + assert canonical_estimator_name("lda") == "LinearDiscriminantAnalysis" + assert canonical_estimator_name("logistic_regression") == "LogisticRegression" + assert canonical_estimator_name("ridge") == "Ridge" + + # Test unknown passthrough + assert canonical_estimator_name("CustomModel") == "CustomModel" + assert canonical_estimator_name("RandomStuff") == "RandomStuff" + + +def test_selector_capabilities_to_dict(): + cap = SELECTOR_CAPABILITIES["sfs"] + d = cap.to_dict() + assert d["method"] == "sfs" + assert "input_ranks" in d + assert "support" in d diff --git a/tests/test_decoding_splitters.py b/tests/test_decoding_splitters.py new file mode 100644 index 0000000..027874c --- /dev/null +++ b/tests/test_decoding_splitters.py @@ -0,0 +1,254 @@ +import numpy as np +import pandas as pd +import pytest +from pydantic import ValidationError +from sklearn.model_selection import ( + GroupKFold, + KFold, + LeaveOneGroupOut, + StratifiedGroupKFold, + StratifiedKFold, +) + +from coco_pipe.decoding._splitters import ( + SimpleSplit, + _CVWithGroups, + cv_uses_groups, + get_cv_splitter, +) +from coco_pipe.decoding.configs import CVConfig + + +def test_simple_split_basic(): + X = np.zeros((100, 10)) + y = np.zeros(100) + ss = SimpleSplit(test_size=0.2, shuffle=True, random_state=42) + + splits = list(ss.split(X, y)) + assert len(splits) == 1 + train_idx, test_idx = splits[0] + assert len(train_idx) == 80 + assert len(test_idx) == 20 + assert ss.get_n_splits() == 1 + assert ss._get_tags()["non_deterministic"] is True + assert "SimpleSplit" in repr(ss) + + +def test_simple_split_invalid_size(): + with pytest.raises(ValueError, match="test_size must be between 0 and 1"): + SimpleSplit(test_size=1.5) + + +def test_simple_split_stratify(): + X = np.zeros((10, 1)) + y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + + # Stratify = True (uses y) + ss = SimpleSplit(test_size=0.2, stratify=True) + train_idx, test_idx = next(ss.split(X, y)) + assert y[test_idx].sum() == 1 # 20% of 5 is 1 + + # Stratify = False + ss_none = SimpleSplit(test_size=0.2, stratify=False) + next(ss_none.split(X, y)) + + # Stratify = array + ss_arr = SimpleSplit(test_size=0.2, stratify=y) + next(ss_arr.split(X, y)) + + +def test_cv_with_groups_wrapper(): + base_cv = KFold(n_splits=5) + groups = np.repeat(np.arange(10), 10) + X = np.zeros((100, 10)) + + wrapped = _CVWithGroups(base_cv, groups) + assert wrapped.get_n_splits(X) == 5 + assert len(list(wrapped.split(X))) == 5 + assert "non_deterministic" in wrapped._get_tags() + assert wrapped.get_params()["groups"] is groups + assert "_CVWithGroups" in repr(wrapped) + + +def test_cv_with_groups_rejects_unsliced_nested_groups(): + base_cv = KFold(n_splits=2) + groups = np.repeat(np.arange(5), 2) + wrapped = _CVWithGroups(base_cv, groups) + + with pytest.raises(ValueError, match="Bound groups length does not match X"): + list(wrapped.split(np.zeros((6, 2)))) + + +def test_cv_uses_groups_detects_runtime_splitters(): + wrapped = _CVWithGroups(KFold(n_splits=2), np.arange(6)) + + assert cv_uses_groups(GroupKFold(n_splits=2)) + assert cv_uses_groups(wrapped) + assert not cv_uses_groups(KFold(n_splits=2)) + + +def test_get_cv_splitter_factory(): + # KFold + cfg = CVConfig(strategy="kfold", n_splits=5, shuffle=True, random_state=42) + splitter = get_cv_splitter(cfg) + assert isinstance(splitter, KFold) + assert splitter.n_splits == 5 + + # Stratified + cfg_s = CVConfig(strategy="stratified", n_splits=3) + splitter_s = get_cv_splitter(cfg_s, task="classification") + assert isinstance(splitter_s, StratifiedKFold) + + # GroupKFold + cfg_g = CVConfig(strategy="group_kfold", n_splits=5) + groups = np.repeat(np.arange(5), 20) + splitter_g = get_cv_splitter(cfg_g, groups=groups) + assert isinstance(splitter_g, _CVWithGroups) + assert isinstance(splitter_g.cv, GroupKFold) + + +def test_get_cv_splitter_errors(): + cfg_g = CVConfig(strategy="group_kfold") + + # Missing groups + with pytest.raises(ValueError, match="requires groups"): + get_cv_splitter(cfg_g, groups=None) + + # Stratified on regression + cfg_s = CVConfig(strategy="stratified") + with pytest.raises(ValueError, match="not supported for regression"): + get_cv_splitter(cfg_s, task="regression") + + # Unknown strategy (Pydantic catches this at config level) + with pytest.raises(ValidationError): + CVConfig(strategy="invalid") + + # Bypassing Pydantic to hit the ValueError in the factory + from types import SimpleNamespace + + fake_cfg = SimpleNamespace( + strategy="invalid", n_splits=5, shuffle=True, random_state=42 + ) + with pytest.raises(ValueError, match="Unknown CV strategy"): + get_cv_splitter(fake_cfg) + + +def test_get_cv_splitter_all_strategies(): + strategies = [ + ("stratified_group_kfold", 5), + ("leave_p_out", 2), + ("leave_one_group_out", 1), + ("group_shuffle_split", 5), + ("timeseries", 5), + ("split", 1), + ] + groups = np.repeat(np.arange(10), 10) + for strat, n in strategies: + cfg = CVConfig(strategy=strat, n_splits=n) + splitter = get_cv_splitter(cfg, groups=groups, task="classification") + assert splitter is not None + + +def test_get_cv_splitter_unshuffled(): + # Timeseries or KFold without shuffle + cfg = CVConfig(strategy="kfold", shuffle=False) + splitter = get_cv_splitter(cfg) + assert splitter.random_state is None + + cfg_t = CVConfig(strategy="timeseries") + splitter_t = get_cv_splitter(cfg_t) + assert splitter_t.n_splits == 5 + + +def test_simple_split_shuffle_false(): + """Verify shuffle=False ignores random_state.""" + X = np.zeros((10, 2)) + y = np.zeros(10) + ss = SimpleSplit(shuffle=False, random_state=42) + train, test = next(ss.split(X, y)) + assert np.all(test == [8, 9]) # Deterministic split from the end + + +def test_cv_with_groups_pandas_alignment(): + """Verify group binding aligns to pandas index.""" + cv = KFold(n_splits=2) + df = pd.DataFrame({"a": [1, 2]}, index=[2, 3]) + gp_ser = pd.Series([10, 20, 30, 40], index=[0, 1, 2, 3]) + wrapper_pd = _CVWithGroups(cv, gp_ser) + aligned = wrapper_pd._get_effective_groups(df) + assert np.all(aligned.values == [30, 40]) + + # Matching explicit groups + explicit = np.array([30, 40]) + assert np.all(wrapper_pd._get_effective_groups(df, groups=explicit) == explicit) + + +def test_get_cv_splitter_all_strategies_extended(): + """Verify factory handles all supported strategies.""" + # 1. Timeseries already tested in test_get_cv_splitter_unshuffled + + # 2. Leave P Out + cfg = CVConfig(strategy="leave_p_out", n_splits=2) + s = get_cv_splitter(cfg, groups=[1, 2, 3]) + assert s.cv.n_groups == 2 + + # 3. Group Shuffle Split + cfg = CVConfig(strategy="group_shuffle_split", n_splits=5, test_size=0.1) + s = get_cv_splitter(cfg, groups=[1, 2, 3, 4, 5]) + assert s.cv.n_splits == 5 + + # 4. Stratified Group KFold + cfg = CVConfig(strategy="stratified_group_kfold", n_splits=3) + s = get_cv_splitter(cfg, groups=[1, 2, 3]) + assert isinstance(s.cv, StratifiedGroupKFold) + + # 5. Leave One Group Out + cfg = CVConfig(strategy="leave_one_group_out") + s = get_cv_splitter(cfg, groups=[1, 2, 3]) + assert isinstance(s.cv, LeaveOneGroupOut) + + # 6. Split (SimpleSplit) + cfg = CVConfig(strategy="split", test_size=0.3) + s = get_cv_splitter(cfg) + assert isinstance(s, SimpleSplit) + assert s.test_size == 0.3 + + # 7. Regression + Stratified -> ValueError + cfg = CVConfig(strategy="stratified") + with pytest.raises(ValueError, match="not supported for regression"): + get_cv_splitter(cfg, task="regression") + + # 8. require_groups=False + cfg = CVConfig(strategy="group_kfold") + s = get_cv_splitter(cfg, groups=None, require_groups=False) + assert s.n_splits == 5 + + # 9. Warning for groups in non-group strategy (covers 322-323) + cfg = CVConfig(strategy="kfold") + with pytest.warns(UserWarning, match="not group-aware"): + get_cv_splitter(cfg, groups=[1, 2, 3]) + + +def test_cv_with_groups_more_coverage(): + """Hit the remaining edge cases in _CVWithGroups.""" + cv = KFold(n_splits=2) + groups = np.array([1, 1, 2, 2]) + wrapper = _CVWithGroups(cv, groups) + + # X is None (covers 75 and 104) + assert np.all(wrapper._get_effective_groups(None) == groups) + assert wrapper.get_n_splits(X=None, groups=[1, 1, 2, 2]) == 2 + + # Matching explicit groups (covers return on line 72 or similar) + # Actually, let's re-verify line numbers. + # In my local view, line 68 was the ValueError. + + # Explicit groups that match length (covers 72) + explicit = np.array([3, 3, 4, 4]) + assert np.all( + wrapper._get_effective_groups(np.zeros(4), groups=explicit) == explicit + ) + + # Mismatched explicit groups (covers 68-71) + with pytest.raises(ValueError, match="Explicit groups length does not match X"): + wrapper._get_effective_groups(np.zeros(4), groups=[1, 2]) diff --git a/tests/test_decoding_stats.py b/tests/test_decoding_stats.py new file mode 100644 index 0000000..9aa1fb1 --- /dev/null +++ b/tests/test_decoding_stats.py @@ -0,0 +1,318 @@ +import numpy as np +import pandas as pd +import pytest +from sklearn.datasets import make_classification + +from coco_pipe.decoding import Experiment, ExperimentConfig +from coco_pipe.decoding.configs import ( + ChanceAssessmentConfig, + CVConfig, + DummyClassifierConfig, + LogisticRegressionConfig, + StatisticalAssessmentConfig, +) +from coco_pipe.decoding.stats import ( + _accuracy_ci, + _coord_dict, + _correct_p_values, + _empirical_p_values, + _run_binomial_assessment, + aggregate_predictions_for_inference, + apply_multiple_comparison_correction, + assess_paired_comparison, + assess_post_hoc_permutation, + binomial_accuracy_test, + run_paired_permutation_assessment, +) + +# --- Unit Tests for Core Stats Functions --- + + +def test_aggregate_predictions_regression_mean(): + df = pd.DataFrame( + { + "Subject": ["S1", "S1", "S2"], + "y_true": [10.0, 10.0, 20.0], + "y_pred": [12.0, 8.0, 22.0], + "SampleID": [0, 1, 2], + } + ) + res = aggregate_predictions_for_inference( + df, + metric="mse", + task="regression", + unit_of_inference="Subject", + custom_aggregation="mean", + ) + assert len(res) == 2 + assert res[res["InferentialUnitID"] == "S1"]["y_pred"].iloc[0] == 10.0 + + +def test_aggregate_predictions_custom_unit(): + df = pd.DataFrame( + { + "Session": ["A", "A", "B"], + "y_true": [1, 1, 0], + "y_pred": [1, 0, 0], + "y_proba_0": [0.2, 0.8, 0.9], + "y_proba_1": [0.8, 0.2, 0.1], + "SampleID": [0, 1, 2], + } + ) + res = aggregate_predictions_for_inference( + df, + metric="accuracy", + task="classification", + unit_of_inference="custom", + custom_unit_column="Session", + custom_aggregation="mean", + ) + assert len(res) == 2 + assert "InferentialUnitID" in res.columns + assert res[res["InferentialUnitID"] == "A"]["y_pred"].iloc[0] == 0 + + +def test_aggregate_predictions_empty(): + df = pd.DataFrame() + res = aggregate_predictions_for_inference(df, metric="accuracy") + assert res.empty + + +def test_binomial_accuracy_test_clopper(): + y_true = [1, 1, 1, 0] + y_pred = [1, 1, 0, 0] # 3/4 correct + res = binomial_accuracy_test(y_true, y_pred, p0=0.5, ci_method="clopper_pearson") + assert res["observed"] == 0.75 + assert "ci_lower" in res + assert "ci_upper" in res + + +def test_binomial_accuracy_test_errors(): + with pytest.raises(ValueError, match="requires an explicit p0"): + binomial_accuracy_test([1], [1], p0=None) + with pytest.raises(ValueError, match="zero predictions"): + binomial_accuracy_test([], [], p0=0.5) + + +def test_empirical_p_values_two_sided(): + observed = np.array([0.8, 0.2]) + null = np.array([[0.5, 0.5], [0.6, 0.4], [0.7, 0.3]]) + p_vals = _empirical_p_values(observed, null, greater_is_better=True, two_sided=True) + assert len(p_vals) == 2 + assert np.all(p_vals <= 1.0) + + +def test_correct_p_values_all_methods(): + observed = np.array([0.9, 0.1]) + null = np.array([[0.5, 0.5], [0.8, 0.8]]) + p_vals_raw = np.array([0.1, 0.1]) + + # Bonferroni + p_bonf = _correct_p_values( + observed, null, p_vals_raw, method="bonferroni", greater_is_better=True + ) + assert p_bonf[0] == 0.2 + + # Max-Stat + p_max = _correct_p_values( + observed, null, p_vals_raw, method="max_stat", greater_is_better=True + ) + assert p_max[0] == pytest.approx(1 / 3) + + # FDR + p_fdr = _correct_p_values( + observed, null, p_vals_raw, method="fdr_bh", greater_is_better=True + ) + assert len(p_fdr) == 2 + + +def test_accuracy_ci_boundary(): + # k=0 + low, high = _accuracy_ci(np.array([0]), np.array([10]), 0.05, "clopper_pearson") + assert low[0] == 0.0 + # k=n + low, high = _accuracy_ci(np.array([10]), np.array([10]), 0.05, "clopper_pearson") + assert high[0] == 1.0 + + +def test_coord_dict_temporal(): + res = _coord_dict((10.0, 20.0), ["TrainTime", "TestTime"]) + assert res["TrainTime"] == 10.0 + assert res["TestTime"] == 20.0 + assert res["Time"] is None + + +def test_apply_multiple_comparison_correction_short(): + df = pd.DataFrame({"PValue": [0.01]}) + res = apply_multiple_comparison_correction(df) + assert "Significant" in res.columns + assert res["Significant"].iloc[0] + + +# --- Integration Tests for Statistical Assessment --- + + +def test_run_binomial_assessment_temporal(): + predictions = pd.DataFrame( + { + "SampleID": [0, 1, 0, 1], + "Time": [0.0, 0.0, 1.0, 1.0], + "y_true": [1, 0, 1, 0], + "y_pred": [1, 0, 0, 1], # Time 0: 100%, Time 1: 0% + } + ) + # Use proper ChanceAssessmentConfig + chance_cfg = ChanceAssessmentConfig(p0=0.5) + config = StatisticalAssessmentConfig(chance=chance_cfg) + rows = _run_binomial_assessment( + model="LR", + metric="accuracy", + predictions=predictions, + task="classification", + config=config, + unit="sample", + ) + assert len(rows) == 2 # One per timepoint + assert rows[0]["Time"] == 0.0 + assert rows[1]["Time"] == 1.0 + assert rows[0]["Observed"] == 1.0 + assert rows[1]["Observed"] == 0.0 + + +def test_assess_paired_comparison_temporal(): + df = pd.DataFrame( + { + "SampleID": [0, 1, 0, 1], + "Time": [0.0, 0.0, 1.0, 1.0], + "y_true": [1, 0, 1, 0], + "y_pred_A": [1, 0, 1, 0], + "y_pred_B": [0, 1, 0, 1], + } + ) + res = assess_paired_comparison( + df, metric="accuracy", unit="sample", n_permutations=10 + ) + assert len(res) == 2 # One per timepoint + assert "Difference" in res.columns + assert res.iloc[0]["Difference"] == 1.0 + + +def test_correct_p_values_less_is_better(): + observed = np.array([0.1, 0.2]) # less-is-better + null = np.array([[0.5, 0.5], [0.1, 0.1]]) + p_vals_raw = np.array([0.1, 0.1]) + + # Max-Stat with greater_is_better=False + p_max = _correct_p_values( + observed, null, p_vals_raw, method="max_stat", greater_is_better=False + ) + assert p_max[0] == pytest.approx(2 / 3) + + +def test_assess_post_hoc_permutation_with_groups(): + res = { + "predictions": [ + { + "y_true": np.array([0, 0, 1, 1]), + "y_pred": np.array([0, 1, 1, 0]), + "group": np.array([0, 0, 1, 1]), + "sample_index": np.array([0, 1, 2, 3]), + } + ] + } + df = assess_post_hoc_permutation( + res, metric="accuracy", unit="group", n_permutations=10 + ) + assert "Observed" in df.columns + assert df["Observed"].iloc[0] == 0.5 + + +def test_run_paired_permutation_assessment_full(): + from unittest.mock import MagicMock + + res_a = MagicMock() + res_b = MagicMock() + + df_a = pd.DataFrame( + { + "Model": ["m"], + "Fold": [0], + "SampleID": [0], + "Group": [0], + "y_true": [1], + "y_pred": [1], + "InferentialUnitID": [0], + } + ) + res_a.get_predictions.return_value = df_a + res_b.get_predictions.return_value = df_a.copy() + + config = StatisticalAssessmentConfig( + chance=ChanceAssessmentConfig(n_permutations=10), unit_of_inference="sample" + ) + res = run_paired_permutation_assessment(res_a, res_b, "m", "accuracy", config) + assert not res.empty + assert res.iloc[0]["Observed"] == 0.0 + + +def test_aggregate_predictions_error_paths(): + df = pd.DataFrame({"y_true": [0, 1], "y_pred": [0, 0]}) + + # Missing unit column + with pytest.raises(ValueError, match="Inference unit 'Subject' not found"): + aggregate_predictions_for_inference(df, "accuracy", unit_of_inference="Subject") + + # Majority aggregation for regression + df_reg = pd.DataFrame( + {"y_true": [1.0, 2.0], "y_pred": [1.1, 1.9], "Subject": ["S1", "S1"]} + ) + with pytest.raises( + ValueError, match="majority aggregation is only valid for classification" + ): + aggregate_predictions_for_inference( + df_reg, + "neg_mean_squared_error", + task="regression", + unit_of_inference="Subject", + custom_aggregation="majority", + ) + + # Mean aggregation without probas + df_class = pd.DataFrame( + {"y_true": [0, 1], "y_pred": [0, 0], "Subject": ["S1", "S1"]} + ) + with pytest.raises( + ValueError, match="mean aggregation requires probability columns" + ): + aggregate_predictions_for_inference( + df_class, + "accuracy", + task="classification", + unit_of_inference="Subject", + custom_aggregation="mean", + ) + + +def test_binomial_test_hardening(): + # k_alpha adjustment path (line 236) + res = binomial_accuracy_test([0] * 100, [0] * 100, p0=0.5, alpha=0.05) + assert "chance_threshold" in res + + +def test_run_paired_permutation_assessment_hardening(): + # Create two results to compare + X, y = make_classification(n_samples=20, random_state=42) + config = ExperimentConfig( + task="classification", + models={ + "lr": LogisticRegressionConfig(), + "dummy": DummyClassifierConfig(), + }, + cv=CVConfig(n_splits=2), + ) + res = Experiment(config).run(X, y) + + # Note: run_paired_permutation_assessment expects ExperimentResult objects + # But it is often called via result.compare_models + paired = res.compare_models(n_permutations=5, random_state=42) + assert not paired.empty diff --git a/tests/test_ml_base_pipeline.py b/tests/test_ml_base_pipeline.py deleted file mode 100644 index 5ac3af7..0000000 --- a/tests/test_ml_base_pipeline.py +++ /dev/null @@ -1,724 +0,0 @@ -from copy import deepcopy - -import numpy as np -import pandas as pd -import pytest -from sklearn.dummy import DummyClassifier, DummyRegressor -from sklearn.ensemble import RandomForestClassifier -from sklearn.linear_model import LinearRegression, LogisticRegression -from sklearn.metrics import accuracy_score, mean_squared_error - -from coco_pipe.ml.base import BasePipeline -from coco_pipe.ml.config import CLASSIFICATION_METRICS, REGRESSION_METRICS - - -class DummyPipeline(BasePipeline): - """Concrete subclass for testing BasePipeline.""" - - pass - - -def test_validate_input_errors(): - with pytest.raises(ValueError): - DummyPipeline( - X="not array", - y=np.zeros(3), - metric_funcs=CLASSIFICATION_METRICS, - model_configs={}, - default_metrics=["accuracy"], - ) - with pytest.raises(ValueError): - DummyPipeline( - X=np.zeros((3, 2)), - y=np.zeros(4), - metric_funcs=CLASSIFICATION_METRICS, - model_configs={}, - default_metrics=["accuracy"], - ) - with pytest.raises(ValueError): - DummyPipeline( - X=np.zeros((3, 2)), - y=np.zeros(3), - metric_funcs=CLASSIFICATION_METRICS, - model_configs={}, - default_metrics=["accuracy"], - groups=np.zeros(10), - ) - - -def test_select_columns(): - df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) - selected = DummyPipeline._select_columns(df, [True, False, True]) - expected = pd.DataFrame({"a": [1, 2, 3], "c": [7, 8, 9]}) - pd.testing.assert_frame_equal(selected, expected) - - # Test with numpy array - arr = np.array([[1, 2], [3, 4], [5, 6]]) - selected_arr = DummyPipeline._select_columns(arr, [0]) - expected_arr = np.array([[1], [3], [5]]) - np.testing.assert_array_equal(selected_arr, expected_arr) - - -def test_validate_metrics_error(): - X = np.zeros((5, 2)) - y = np.zeros(5) - with pytest.raises(ValueError): - DummyPipeline( - X, - y, - metric_funcs={"acc": lambda a, b: 0}, - model_configs={}, - default_metrics=["bad"], - ) - - -def test_feature_names_and_importances(): - X = np.random.randn(20, 3) - y = np.concatenate([np.zeros(10), np.ones(10)]) - np.random.shuffle(y) - - lr = LogisticRegression(solver="saga").fit(X, y) - imp = DummyPipeline._extract_feature_importances(lr) - assert isinstance(imp, np.ndarray) and imp.shape == (3,) - - dc = DummyClassifier().fit(X, y) - assert DummyPipeline._extract_feature_importances(dc) is None - - df = pd.DataFrame(X, columns=["a", "b", "c"]) - assert DummyPipeline._get_feature_names(df) == ["a", "b", "c"] - arr_names = DummyPipeline._get_feature_names(X) - assert arr_names == ["feature_0", "feature_1", "feature_2"] - - -def test_cross_val_and_baseline_evaluation_classification(): - X = np.arange(40).reshape(-1, 2) - y = np.array([0] * 10 + [1] * 10) - model_configs = { - "dummy": {"estimator": DummyClassifier(strategy="most_frequent"), "params": {}} - } - - def acc(y_t, y_p): - return float(np.mean(y_t == y_p)) - - pipe = DummyPipeline( - X, - y, - metric_funcs={"accuracy": acc}, - model_configs=model_configs, - default_metrics=["accuracy"], - cv_kwargs={ - "cv_strategy": "stratified", - "n_splits": 5, - "shuffle": True, - "random_state": 0, - }, - n_jobs=1, - ) - - # cross_val - cv_res = pipe.cross_val(DummyClassifier(strategy="most_frequent"), X, y) - for key in [ - "cv_fold_predictions", - "cv_fold_scores", - "cv_fold_estimators", - "cv_fold_importances", - ]: - assert key in cv_res - assert np.all(cv_res["cv_fold_scores"]["accuracy"] == pytest.approx(0.5)) - - # baseline_evaluation - eval_res = pipe.baseline_evaluation("dummy") - assert eval_res["model_name"] == "dummy" - # compare fold scores - assert np.all( - eval_res["metric_scores"]["accuracy"]["fold_scores"] - == cv_res["cv_fold_scores"]["accuracy"] - ) - # predictions key - preds = eval_res["predictions"] - assert set(preds.keys()) >= {"y_true", "y_pred", "fold_preds"} - - -def test_cross_val_requires_groups_for_group_kfold(): - X = np.zeros((6, 1)) - y = np.zeros(6) - pipe = DummyPipeline( - X, - y, - metric_funcs=CLASSIFICATION_METRICS, - model_configs={"dummy": {"estimator": DummyClassifier(), "params": {}}}, - default_metrics=["accuracy"], - cv_kwargs={ - "cv_strategy": "group_kfold", - "n_splits": 2, - "shuffle": True, - "random_state": 0, - }, - n_jobs=1, - ) - with pytest.raises(ValueError): - pipe.cross_val(DummyClassifier(), X, y) - - -def test_baseline_evaluation_errors_and_params(): - X = np.vstack([np.random.randn(5, 2), np.random.randn(5, 2) + 2]) - y = np.array([0] * 5 + [1] * 5) - pipe = DummyPipeline( - X, - y, - metric_funcs=CLASSIFICATION_METRICS, - model_configs={"clf": {"estimator": LogisticRegression(), "params": {"C": 1}}}, - default_metrics=["accuracy"], - cv_kwargs={ - "cv_strategy": "stratified", - "n_splits": 2, - "shuffle": True, - "random_state": 0, - }, - ) - with pytest.raises(KeyError): - pipe.baseline_evaluation("bad_model") - - out = pipe.baseline_evaluation("clf") - # params comes back empty dict (no default_params in config) - assert out["params"] == {} - - -def test_baseline_evaluation_regression(): - X = np.arange(30).reshape(-1, 3) - y = np.arange(10) - model_configs = { - "dummy": {"estimator": DummyRegressor(), "params": {"constant": 0.2}} - } - pipe = DummyPipeline( - X, - y, - metric_funcs={"mse": mean_squared_error}, - model_configs=model_configs, - default_metrics=["mse"], - cv_kwargs={ - "cv_strategy": "kfold", - "n_splits": 3, - "shuffle": True, - "random_state": 0, - }, - ) - out = pipe.baseline_evaluation("dummy") - assert np.all(out["metric_scores"]["mse"]["fold_scores"] >= 0) - - -def test_feature_selection_regression(): - X = np.arange(30).reshape(-1, 3) - y = X[:, 0] * 2.0 + 1.0 - model_configs = {"lr": {"estimator": LinearRegression(), "params": {}}} - pipe = DummyPipeline( - X, - y, - metric_funcs={"mse": mean_squared_error}, - model_configs=model_configs, - default_metrics=["mse"], - cv_kwargs={ - "cv_strategy": "kfold", - "n_splits": 3, - "shuffle": True, - "random_state": 0, - }, - n_jobs=1, - ) - out = pipe.feature_selection("lr", n_features=1, direction="forward", scoring="mse") - assert isinstance(out["selected_per_fold"], dict) - # feature_importances keys - fi = out["feature_importances"] - assert "feature_0" in fi - keys = set(fi["feature_0"].keys()) - assert {"mean", "std", "weighted_mean", "weighted_std", "fold_importances"} <= keys - # weighted mean consistency - freq = out["feature_frequency"]["feature_0"] - assert fi["feature_0"]["weighted_mean"] == pytest.approx( - fi["feature_0"]["mean"] * freq - ) - # fold count - spp = out["selected_per_fold"] - assert len(spp) == pipe.cv_kwargs["n_splits"] - best = out["best_fold"] - # best fold has metric 'mse' - assert "mse" in best and "features" in best and "fold" in best - - -def test_feature_selection_backward(): - X = np.arange(20).reshape(-1, 2) - y = X[:, 0] * 3.0 - 2.0 - model_configs = {"lr": {"estimator": LinearRegression(), "params": {}}} - pipe = DummyPipeline( - X, - y, - metric_funcs={"mse": mean_squared_error}, - model_configs=model_configs, - default_metrics=["mse"], - cv_kwargs={"cv_strategy": "kfold", "n_splits": 4, "shuffle": False}, - n_jobs=1, - ) - out_bw = pipe.feature_selection( - "lr", n_features=1, direction="backward", scoring="mse" - ) - # backward, single feature remains or both if equal score - assert isinstance(out_bw["selected_features"], set) - assert out_bw["selected_features"] - - -def test_feature_selection_missing_model_error(): - X = np.random.randn(10, 4) - y = np.random.randn(10) - with pytest.raises(KeyError): - DummyPipeline( - X, - y, - metric_funcs=REGRESSION_METRICS, - model_configs={"a": {"estimator": DummyRegressor(), "params": {}}}, - default_metrics=["mse"], - cv_kwargs={"cv_strategy": "kfold", "n_splits": 2}, - ).feature_selection("b") - - -def _make_pipeline(): - X = np.vstack([np.zeros((5, 2)), np.ones((5, 2))]) - y = np.array([0] * 5 + [1] * 5) - model_configs = { - "dummy": { - "estimator": LogisticRegression(solver="saga"), - "params": {"C": [0.1, 1.0]}, - } - } - return ( - DummyPipeline( - X, - y, - metric_funcs={"accuracy": accuracy_score}, - model_configs=model_configs, - default_metrics=["accuracy"], - cv_kwargs={ - "cv_strategy": "kfold", - "n_splits": 2, - "shuffle": True, - "random_state": 42, - }, - n_jobs=1, - ), - X, - y, - ) - - -def test_build_search_estimator_grid(): - pipe, X, y = _make_pipeline() - grid_est, metric = pipe._build_search_estimator( - "dummy", "grid", None, n_iter=10, scoring="accuracy" - ) - from sklearn.model_selection import GridSearchCV - - assert isinstance(grid_est, GridSearchCV) - assert metric == "accuracy" - assert grid_est.param_grid == pipe.model_configs["dummy"].param_grid - assert grid_est.cv.get_n_splits(X, y) == pipe.cv_kwargs["n_splits"] - - -def test_build_search_estimator_random(): - pipe, X, y = _make_pipeline() - rand_est, metric = pipe._build_search_estimator( - "dummy", "random", None, n_iter=5, scoring="accuracy" - ) - from sklearn.model_selection import RandomizedSearchCV - - assert isinstance(rand_est, RandomizedSearchCV) - assert metric == "accuracy" - assert rand_est.n_iter == 5 - - -def test_extract_hp_search_params(): - pipe, X, y = _make_pipeline() - - class FakeEst: - def __init__(self, best_params): - self.best_params_ = best_params - - cv_estimators = [FakeEst({"C": 0.1}), FakeEst({"C": 1.0})] - bp_per_fold, bp, freq = pipe._extract_hp_search_params(cv_estimators) - assert isinstance(bp_per_fold, dict) and len(bp_per_fold) == 2 - assert isinstance(bp, dict) and "C" in bp - assert isinstance(freq, dict) and "C" in freq - - -def test_hp_search_grid(): - pipe, X, y = _make_pipeline() - res = pipe.hp_search("dummy", search_type="grid") - for key in [ - "model_name", - "hp search parameters", - "best_params", - "param_frequency", - "best_fold", - "predictions", - "metric_scores", - "folds_estimators", - ]: - # 'outer_results' is renamed to 'best_params_per_fold' - alt = "best_params_per_fold" if key == "outer_results" else key - assert alt in res - assert res["hp search parameters"]["search type"] == "grid" - assert res["best_params"]["C"] in pipe.model_configs["dummy"].param_grid["C"] - assert pytest.approx(sum(res["param_frequency"]["C"].values()), rel=1e-6) == 1.0 - assert len(res["best_params_per_fold"]) == pipe.cv_kwargs["n_splits"] - assert 0 <= res["best_fold"]["fold"] < pipe.cv_kwargs["n_splits"] - - -def test_hp_search_random_and_invalid_grid(): - pipe, X, y = _make_pipeline() - # invalid grid update should fail - with pytest.raises(ValueError): - pipe.update_model_params( - "dummy", - {"C": 0.5}, - update_estimator=False, - update_config=True, - param_type="hp_search", - ) - # valid random search after correcting grid - pipe.model_configs["dummy"].param_grid = {"C": [0.1, 1.0], "max_iter": [50, 100]} - res = pipe.hp_search("dummy", search_type="random", n_iter=4) - assert res["hp search parameters"]["search type"] == "random" - assert set(res["param_frequency"].keys()) == set( - pipe.model_configs["dummy"].param_grid.keys() - ) - - -def test_build_combined_fs_hp_pipeline(): - X = np.random.randn(10, 3) - y = np.random.randint(0, 2, 10) - model_configs = { - "rf": { - "estimator": RandomForestClassifier(random_state=0), - "params": {"n_estimators": [5, 10]}, - } - } - pipe = DummyPipeline( - X, - y, - metric_funcs={"accuracy": accuracy_score}, - model_configs=model_configs, - default_metrics=["accuracy"], - cv_kwargs={ - "cv_strategy": "kfold", - "n_splits": 2, - "shuffle": True, - "random_state": 0, - }, - n_jobs=1, - ) - search_est, feat_names, metric = pipe._build_combined_fs_hp_pipeline( - "rf", "grid", None, 2, "forward", 3, "accuracy" - ) - from sklearn.model_selection import GridSearchCV - - assert isinstance(search_est, GridSearchCV) - assert isinstance(feat_names, np.ndarray) and feat_names.shape == (3,) - assert metric == "accuracy" - - -def test_hp_search_fs_end_to_end(): - X = np.vstack([np.random.randn(10, 2) - 2, np.random.randn(10, 2) + 2]) - y = np.array([0] * 10 + [1] * 10) - idx = np.random.RandomState(42).permutation(len(y)) - X, y = X[idx], y[idx] - - model_configs = { - "lr": { - "estimator": LogisticRegression(solver="saga"), - "params": {"C": [0.1, 1.0, 10.0], "penalty": ["l2", "l1"]}, - } - } - pipe = DummyPipeline( - X, - y, - metric_funcs={"accuracy": accuracy_score}, - model_configs=model_configs, - default_metrics=["accuracy"], - cv_kwargs={ - "cv_strategy": "kfold", - "n_splits": 2, - "shuffle": True, - "random_state": 0, - }, - n_jobs=1, - ) - res = pipe.hp_search_fs( - "lr", - search_type="grid", - n_features=1, - direction="forward", - n_iter=1, - scoring="accuracy", - ) - expected = { - "model_name", - "metric_scores", - "selected_features", - "feature_frequency", - "feature_importances", - "best_params", - "selected_per_fold", - "best_params_per_fold", - "best_fold", - "folds_estimators", - "hp search and fs parameters", - } - assert expected.issubset(res.keys()) - assert len(res["selected_features"]) == 1 - - -def test_execute_method(): - X = np.vstack([np.random.randn(10, 3) - 2, np.random.randn(10, 3) + 2]) - y = np.array([0] * 10 + [1] * 10) - - model_configs = { - "lr": { - "estimator": LogisticRegression(solver="saga"), - "default_params": {"C": 0.1, "penalty": "l2"}, - "hp_search_params": {"C": [0.1, 1.0], "penalty": ["l2", "l1"]}, - } - } - pipe = DummyPipeline( - X, - y, - metric_funcs={"accuracy": accuracy_score}, - model_configs=model_configs, - default_metrics=["accuracy"], - cv_kwargs={ - "cv_strategy": "kfold", - "n_splits": 2, - "shuffle": True, - "random_state": 0, - }, - n_jobs=1, - ) - - # baseline - r1 = pipe.execute(type="baseline", model_name="lr") - assert "metric_scores" in r1 and "accuracy" in r1["metric_scores"] - # feature_selection - r2 = pipe.execute( - type="feature_selection", model_name="lr", n_features=2, direction="forward" - ) - assert "selected_features" in r2 and "feature_frequency" in r2 - # hp_search - r3 = pipe.execute(type="hp_search", model_name="lr", search_type="grid") - assert "best_params" in r3 and "param_frequency" in r3 - # hp_search_fs - r4 = pipe.execute( - type="hp_search_fs", model_name="lr", search_type="grid", n_features=2 - ) - assert "selected_features" in r4 and "best_params" in r4 - - with pytest.raises(ValueError): - pipe.execute(type="invalid", model_name="lr") - with pytest.raises(TypeError): - pipe.execute(type="baseline") - with pytest.raises(KeyError): - pipe.execute(type="baseline", model_name="nope") - - -def test_get_model_params_single_model(): - X = np.random.randn(20, 4) - y = np.random.randint(0, 2, 20) - model_configs = { - "lr": { - "estimator": LogisticRegression(), - "default_params": {"C": 1.0, "random_state": 42}, - "hp_search_params": {"C": [0.1, 1.0, 10.0], "penalty": ["l1", "l2"]}, - } - } - pipe = DummyPipeline( - X, - y, - metric_funcs={"accuracy": accuracy_score}, - model_configs=model_configs, - default_metrics=["accuracy"], - n_jobs=1, - ) - params = pipe.get_model_params("lr") - assert params["estimator_type"] == "LogisticRegression" - assert params["init_params"] == {"C": 1.0, "random_state": 42} - assert params["param_grid"] == {"C": [0.1, 1.0, 10.0], "penalty": ["l1", "l2"]} - - -def test_get_model_params_all_models_and_update(): - X = np.random.randn(20, 4) - y = np.random.randint(0, 2, 20) - model_configs = { - "lr": { - "estimator": LogisticRegression(), - "default_params": {"random_state": 42}, - "hp_search_params": {"C": [0.1, 1.0, 10.0]}, - } - } - pipe = DummyPipeline( - X, - y, - metric_funcs={"accuracy": accuracy_score}, - model_configs=model_configs, - default_metrics=["accuracy"], - n_jobs=1, - ) - - pipe.update_model_params("lr", {"C": 5.0}, param_type="default") - p1 = pipe.get_model_params("lr") - assert p1["init_params"]["C"] == 5.0 - - pipe.update_model_params("lr", {"penalty": ["l1", "l2"]}, param_type="hp_search") - p2 = pipe.get_model_params("lr") - assert p2["param_grid"]["penalty"] == ["l1", "l2"] - - -def test_get_model_params_nonexistent_model(): - X = np.random.randn(20, 4) - y = np.random.randint(0, 2, 20) - pipe = DummyPipeline( - X, - y, - metric_funcs={"accuracy": accuracy_score}, - model_configs={"lr": {"estimator": LogisticRegression(), "params": {}}}, - default_metrics=["accuracy"], - n_jobs=1, - ) - with pytest.raises(KeyError): - pipe.get_model_params("nonexistent") - - -def test_update_model_params_invalid_grid(): - X = np.random.randn(10, 3) - y = np.random.randint(0, 2, 10) - pipe = DummyPipeline( - X, - y, - metric_funcs={"accuracy": accuracy_score}, - model_configs={"d": {"estimator": LogisticRegression(), "params": {}}}, - default_metrics=["accuracy"], - n_jobs=1, - ) - with pytest.raises(ValueError): - pipe.update_model_params( - "d", - {"C": 1.0}, - update_estimator=False, - update_config=True, - param_type="hp_search", - ) - - -def test_reset_and_list_models(): - X = np.random.randn(10, 4) - y = np.random.randint(0, 2, 10) - model_configs = { - "lr": { - "estimator": LogisticRegression(C=1.0, solver="saga"), - "params": {"C": [1, 2]}, - }, - "rf": {"estimator": RandomForestClassifier(n_estimators=5), "params": {}}, - } - pipe = DummyPipeline( - X, - y, - metric_funcs={"accuracy": accuracy_score}, - model_configs=model_configs, - default_metrics=["accuracy"], - n_jobs=1, - ) - - # list_models - names = pipe.list_models() - assert set(names.keys()) == {"lr", "rf"} - assert names["lr"] == "LogisticRegression" - - # reset_model_params error - with pytest.raises(KeyError): - pipe.reset_model_params("nope") - - # update & reset roundtrip - orig = deepcopy(pipe.get_model_params("lr")) - pipe.update_model_params("lr", {"C": 3.0}, param_type="default") - assert pipe.get_model_params("lr")["init_params"]["C"] == 3.0 - pipe.reset_model_params("lr") - after = pipe.get_model_params("lr") - assert after["init_params"] == orig["init_params"] - assert after["param_grid"] == orig["param_grid"] - - -def test_cross_val_scaler_injection(): - import numpy as np - from sklearn.dummy import DummyClassifier - from sklearn.pipeline import Pipeline - from sklearn.preprocessing import StandardScaler - - # simple binary classification data - X = np.array([[0.0], [1.0], [2.0], [3.0]]) - y = np.array([0, 0, 1, 1]) - - def acc(y_true, y_pred): - return float((y_true == y_pred).mean()) - - # pipeline with use_scaler=True - model_configs = { - "d": {"estimator": DummyClassifier(strategy="most_frequent"), "params": {}} - } - pipe = DummyPipeline( - X, - y, - metric_funcs={"acc": acc}, - model_configs=model_configs, - use_scaler=True, - default_metrics=["acc"], - cv_kwargs={"cv_strategy": "kfold", "n_splits": 2, "shuffle": False}, - n_jobs=1, - ) - - cv_res = pipe.cross_val(DummyClassifier(), X, y) - # each fold estimator should be a Pipeline with a StandardScaler step - for est in cv_res["cv_fold_estimators"]: - assert isinstance(est, Pipeline) - assert "scaler" in est.named_steps - assert isinstance(est.named_steps["scaler"], StandardScaler) - - -def test_baseline_evaluation_scaler(): - import numpy as np - from sklearn.dummy import DummyClassifier - from sklearn.pipeline import Pipeline - from sklearn.preprocessing import StandardScaler - - # same data, test baseline_evaluation path - X = np.array([[0.0], [1.0], [2.0], [3.0]]) - y = np.array([0, 0, 1, 1]) - - def acc(y_true, y_pred): - return float((y_true == y_pred).mean()) - - model_configs = { - "d": {"estimator": DummyClassifier(strategy="most_frequent"), "params": {}} - } - - pipe = DummyPipeline( - X, - y, - metric_funcs={"acc": acc}, - model_configs=model_configs, - use_scaler=True, - default_metrics=["acc"], - cv_kwargs={"cv_strategy": "kfold", "n_splits": 2, "shuffle": False}, - n_jobs=1, - ) - - res = pipe.baseline_evaluation("d") - # and the returned fold estimators must also include the scaler - for est in res["folds_estimators"]: - assert isinstance(est, Pipeline) - assert "scaler" in est.named_steps - assert isinstance(est.named_steps["scaler"], StandardScaler) diff --git a/tests/test_ml_classification.py b/tests/test_ml_classification.py deleted file mode 100644 index cc24679..0000000 --- a/tests/test_ml_classification.py +++ /dev/null @@ -1,327 +0,0 @@ -import numpy as np -import pytest -from sklearn.datasets import make_classification, make_multilabel_classification -from sklearn.metrics import average_precision_score, roc_auc_score - -from coco_pipe.ml.classification import ( - BinaryClassificationPipeline, - ClassificationPipeline, - MultiClassClassificationPipeline, - MultiOutputClassificationPipeline, -) -from coco_pipe.ml.config import ( - BINARY_MODELS, - DEFAULT_CV, - MULTICLASS_MODELS, - MULTIOUTPUT_CLASS_MODELS, -) - -# ───────────────────────────────────────────────────────────────────────────── -# small datasets -# ───────────────────────────────────────────────────────────────────────────── -X_binary = np.vstack([np.zeros((5, 2)), np.ones((5, 2))]) -y_binary = np.array([0] * 5 + [1] * 5) - -X_multi = np.arange(30).reshape(10, 3) -y_multi = np.array([0, 1, 2, 1, 0, 2, 1, 0, 2, 1]) - -X_multiout, y_multiout = make_multilabel_classification( - n_samples=20, - n_features=4, - n_classes=3, - n_labels=2, - random_state=0, - allow_unlabeled=False, -) - - -@pytest.fixture(autouse=True) -def tmp_working_dir(tmp_path, monkeypatch): - """Use a clean working dir to avoid polluting the repo.""" - monkeypatch.chdir(tmp_path) - yield - - -# ───────────────────────────────────────────────────────────────────────────── -# ClassificationPipeline wrapper -# ───────────────────────────────────────────────────────────────────────────── -@pytest.mark.parametrize( - "analysis_type, model_list, metrics, expected_task", - [ - ("baseline", [list(BINARY_MODELS.keys())[0]], ["accuracy"], "binary"), - ("baseline", [list(MULTICLASS_MODELS.keys())[0]], ["accuracy"], "multiclass"), - ( - "baseline", - [list(MULTIOUTPUT_CLASS_MODELS.keys())[0]], - ["subset_accuracy"], - "multioutput", - ), - ], -) -def test_pipeline_detect_and_run_baseline( - analysis_type, model_list, metrics, expected_task, monkeypatch -): - if expected_task == "binary": - X, y = X_binary, y_binary - elif expected_task == "multiclass": - X, y = X_multi, y_multi - else: - X, y = X_multiout, y_multiout - - # capture save() calls - saved = [] - - def fake_save(self, name, res): - saved.append(name) - - monkeypatch.setattr(ClassificationPipeline, "save", fake_save) - - pipe = ClassificationPipeline( - X=X, - y=y, - analysis_type=analysis_type, - models=model_list, - metrics=metrics, - random_state=0, - cv_strategy="kfold", - n_jobs=1, - save_intermediate=True, - results_file="testres", - ) - results = pipe.run() - - # top‐level shape - assert isinstance(results, dict) - assert set(results.keys()) == set(model_list) - - # the wrapper picked the correct sub‐pipeline - cls_name = type(pipe.pipeline).__name__.lower() - assert expected_task in cls_name - - # each result now has 'predictions' + 'metric_scores' - for res in results.values(): - assert "predictions" in res - assert "metric_scores" in res - - # save() called once per model + final - assert len(saved) == len(model_list) + 1 - assert any(n.startswith("testres") for n in saved) - - -def test_invalid_analysis_type_raises(): - with pytest.raises(ValueError): - ClassificationPipeline(X=X_binary, y=y_binary, analysis_type="foo").run() - - -# ───────────────────────────────────────────────────────────────────────────── -# BinaryClassificationPipeline -# ───────────────────────────────────────────────────────────────────────────── -def test_binary_aggregate_metrics_correctness(): - fold_preds = [ - { - "y_true": np.array([0, 1, 0]), - "y_pred": np.array([0, 1, 1]), - "y_proba": np.array([[0.8, 0.2], [0.3, 0.7], [0.4, 0.6]]), - }, - { - "y_true": np.array([1, 0]), - "y_pred": np.array([1, 0]), - "y_proba": np.array([[0.2, 0.8], [0.6, 0.4]]), - }, - ] - # compute per-fold scores - aucs = [roc_auc_score(fp["y_true"], fp["y_proba"][:, 1]) for fp in fold_preds] - aps = [ - average_precision_score(fp["y_true"], fp["y_proba"][:, 1]) for fp in fold_preds - ] - fold_scores = {"roc_auc": np.array(aucs), "average_precision": np.array(aps)} - # no importances in this dummy test - fold_imps = {} - - pipe = BinaryClassificationPipeline( - X=np.zeros((5, 2)), - y=np.array([0, 1, 0, 1, 0]), - models="all", - metrics=["roc_auc", "average_precision"], - cv_kwargs={**DEFAULT_CV, "n_splits": 2, "cv_strategy": "kfold"}, - n_jobs=1, - ) - predictions, metrics, feature_importances = pipe._aggregate( - fold_preds, fold_scores, fold_imps - ) - - expected_auc = np.mean(aucs) - expected_ap = np.mean(aps) - assert "roc_auc" in metrics - assert pytest.approx(expected_auc, rel=1e-6) == metrics["roc_auc"]["mean"] - assert pytest.approx(expected_ap, rel=1e-6) == metrics["average_precision"]["mean"] - - # predictions concatenated - assert predictions["y_true"].shape[0] == 5 - assert "y_proba" in predictions - - -def test_baseline_all_models_run_binary(): - X, y = make_classification( - n_samples=100, n_features=5, n_informative=2, n_classes=2, random_state=0 - ) - metrics = ["accuracy", "roc_auc", "average_precision"] - for name in BINARY_MODELS: - pipe = BinaryClassificationPipeline( - X=X, - y=y, - models=[name], - metrics=metrics, - random_state=0, - n_jobs=1, - cv_kwargs={**DEFAULT_CV, "n_splits": 3, "shuffle": True}, - ) - out = pipe.baseline_evaluation(name) - for key in ( - "predictions", - "metric_scores", - "feature_importances", - "model_name", - "params", - "folds_estimators", - ): - assert key in out - # check y_true length and score bounds - assert out["predictions"]["y_true"].shape[0] == len(y) - for m in metrics: - mv = out["metric_scores"][m]["mean"] - assert 0.0 <= mv <= 1.0 - - -# ───────────────────────────────────────────────────────────────────────────── -# MultiClassClassificationPipeline -# ───────────────────────────────────────────────────────────────────────────── -def test_multiclass_aggregate_and_per_class(): - fold_preds = [ - { - "y_true": np.array([0, 1, 2]), - "y_pred": np.array([0, 2, 1]), - "y_proba": np.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8], [0.2, 0.6, 0.2]]), - }, - { - "y_true": np.array([1, 2]), - "y_pred": np.array([1, 2]), - "y_proba": np.array([[0.1, 0.8, 0.1], [0.9, 0.05, 0.05]]), - }, - ] - # compute fold scores - accs = [(fp["y_true"] == fp["y_pred"]).mean() for fp in fold_preds] - # aucs = [roc_auc_score(fp["y_true"], fp["y_proba"], multi_class='ovo') for fp - # in fold_preds] - fold_scores = { - "accuracy": np.array(accs), - # "roc_auc": np.array(aucs) - } - pipe = MultiClassClassificationPipeline( - X=np.zeros((5, 4)), - y=np.array([0, 1, 2, 1, 0]), - models="all", - metrics=["accuracy"], # ,"roc_auc"], - per_class=True, - cv_kwargs={**DEFAULT_CV, "n_splits": 2}, - n_jobs=1, - ) - predictions, metrics, feature_importances = pipe._aggregate( - fold_preds, fold_scores, {} - ) - - assert pytest.approx(np.mean(accs), rel=1e-6) == metrics["accuracy"]["mean"] - # assert "roc_auc" in agg["metrics"] - pcm = metrics["per_class_metrics"] - assert set(pcm.keys()) == set(pipe.classes_) - for stats in pcm.values(): - assert all(k in stats for k in ("precision", "recall", "f1")) - - -def test_baseline_all_multiclass_models(): - X, y = make_classification( - n_samples=150, n_features=6, n_informative=3, n_classes=3, random_state=1 - ) - metrics = ["accuracy", "roc_auc"] - for name in MULTICLASS_MODELS: - pipe = MultiClassClassificationPipeline( - X=X, - y=y, - models=[name], - metrics=metrics, - per_class=False, - random_state=0, - n_jobs=1, - cv_kwargs={**DEFAULT_CV, "n_splits": 4}, - ) - out = pipe.baseline_evaluation(name) - assert "predictions" in out - assert "metric_scores" in out - assert "y_true" in out["predictions"] - assert out["predictions"]["y_true"].shape # [0] == len(y) - for m in metrics: - assert 0.0 <= out["metric_scores"][m]["mean"] <= 1.0 - - -# ───────────────────────────────────────────────────────────────────────────── -# MultiOutputClassificationPipeline -# ───────────────────────────────────────────────────────────────────────────── -def test_multioutput_aggregate_per_output(): - fold_preds = [ - { - "y_true": np.array([[1, 0], [0, 1], [1, 1]]), - "y_pred": np.array([[1, 0], [1, 0], [1, 1]]), - }, - {"y_true": np.array([[0, 1], [1, 0]]), "y_pred": np.array([[0, 1], [0, 0]])}, - ] - # subset_accuracy only uses y_true and y_pred - # so fold_scores can be dummy - subset_scores = [ - (fp["y_true"] == fp["y_pred"]).all(axis=1).mean() for fp in fold_preds - ] - fold_scores = {"subset_accuracy": np.array(subset_scores)} - pipe = MultiOutputClassificationPipeline( - X=np.zeros((5, 2)), - y=np.zeros((5, 2)), - models=list(MULTIOUTPUT_CLASS_MODELS.keys()), - metrics=["subset_accuracy"], - cv_kwargs={**DEFAULT_CV, "n_splits": 2, "cv_strategy": "kfold"}, - n_jobs=1, - ) - predictions, metrics, feature_importances = pipe._aggregate( - fold_preds, fold_scores, {} - ) - - for m in pipe.metrics: - assert m in metrics - # pom = metrics["per_output_metrics"] - # assert set(pom.keys()) == {0,1} - # for stats in pom.values(): - # assert all(k in stats for k in ("precision","recall","f1")) - - -def test_baseline_all_models_multioutput(): - X, y = make_multilabel_classification( - n_samples=80, - n_features=5, - n_classes=3, - n_labels=2, - random_state=0, - allow_unlabeled=False, - ) - metrics = ["subset_accuracy", "hamming_loss"] - for name in MULTIOUTPUT_CLASS_MODELS: - pipe = MultiOutputClassificationPipeline( - X=X, - y=y, - models=[name], - metrics=metrics, - random_state=42, - n_jobs=1, - cv_kwargs={**DEFAULT_CV, "n_splits": 3, "cv_strategy": "kfold"}, - ) - out = pipe.baseline_evaluation(name) - assert "predictions" in out - # assert "y_true" in out["predictions"] - # for m in metrics: - # assert 0.0 <= out["metric_scores"][m]["mean"] <= 1.0 diff --git a/tests/test_ml_pipeline.py b/tests/test_ml_pipeline.py deleted file mode 100644 index 6c04cec..0000000 --- a/tests/test_ml_pipeline.py +++ /dev/null @@ -1,204 +0,0 @@ -import numpy as np -import pytest - -from coco_pipe.ml.classification import ClassificationPipeline -from coco_pipe.ml.pipeline import MLPipeline -from coco_pipe.ml.regression import RegressionPipeline - - -# Fixtures for dummy data -def make_dummy_multi_output(n_samples=10, n_targets=3): - X = np.arange(n_samples * 2).reshape(n_samples, 2) - # create multi-output y with simple relationships - y = np.vstack([X[:, 0] * (i + 1) for i in range(n_targets)]).T - return X, y - - -@pytest.fixture -def multi_output_data(): - return make_dummy_multi_output(n_samples=8, n_targets=3) - - -@pytest.fixture -def single_output_data(): - X = np.arange(20).reshape(10, 2) - # single-output y as 1D array - y = X[:, 0] * 2.0 + 1.0 - return X, y - - -# Test invalid task now raises when run() is invoked -def test_invalid_task_raises(single_output_data): - X, y = single_output_data - cfg = {"task": "unknown"} - with pytest.raises(ValueError, match="Invalid task"): - MLPipeline(X, y, None, cfg) - - -# Test invalid mode now raises when run() is invoked -def test_invalid_mode_raises(single_output_data): - X, y = single_output_data - cfg = {"task": "regression", "mode": "invalid"} - with pytest.raises(ValueError, match="Invalid mode"): - MLPipeline(X, y, None, cfg) - - -# Test multivariate mode calls pipeline once returning dict -def test_multivariate_mode(single_output_data, monkeypatch): - X, y = single_output_data - cfg = {"task": "regression", "mode": "multivariate"} - - def fake_run(self): - return {"y_shape": self.y.shape} - - monkeypatch.setattr(RegressionPipeline, "run", fake_run) - - mlp = MLPipeline(X, y, None, cfg) - out = mlp.run() - assert isinstance(out, dict) - assert out["y_shape"] == y.shape - - -# Test univariate mode runs per feature for multi-output data -def test_univariate_mode_runs_per_feature(monkeypatch, multi_output_data): - X, y = multi_output_data - cfg = {"task": "regression", "mode": "univariate"} - - def fake_run(self): - return {"y_shape": self.y.shape} - - monkeypatch.setattr(RegressionPipeline, "run", fake_run) - - mlp = MLPipeline(X, y, None, cfg) - out = mlp.run() - assert isinstance(out, dict) - expected_keys = set(range(X.shape[1])) - assert set(out.keys()) == expected_keys - for res in out.values(): - assert res["y_shape"] == y.shape - - -# Univariate mode with single-output still loops over features -def test_univariate_mode_single_output(monkeypatch, single_output_data): - X, y = single_output_data - cfg = {"task": "regression", "mode": "univariate"} - - def fake_run(self): - return {"y_shape": self.y.shape} - - monkeypatch.setattr(RegressionPipeline, "run", fake_run) - - mlp = MLPipeline(X, y, None, cfg) - out = mlp.run() - assert isinstance(out, dict) - expected_keys = set(range(X.shape[1])) - assert set(out.keys()) == expected_keys - for res in out.values(): - assert res["y_shape"] == y.shape - - -# Feature-selection not allowed in univariate -def test_univariate_feature_selection_error(multi_output_data): - X, y = multi_output_data - cfg = { - "task": "regression", - "mode": "univariate", - "analysis_type": "feature_selection", - } - mlp = MLPipeline(X, y, None, cfg) - with pytest.raises( - ValueError, match="Cannot perform feature_selection in univariate mode" - ): - mlp.run() - - -# Classification analogs -def test_classification_modes(monkeypatch): - X = np.random.rand(6, 3) - y = np.array( - [ - [0, 1, 0], - [1, 0, 1], - [0, 1, 1], - [1, 1, 0], - [0, 0, 1], - [1, 0, 0], - ] - ) - - def fake_run(self): - return {"y_shape": self.y.shape} - - monkeypatch.setattr(ClassificationPipeline, "run", fake_run) - - # Multivariate - cfg_mv = {"task": "classification", "mode": "multivariate"} - mlp_mv = MLPipeline(X, y, None, cfg_mv) - out_mv = mlp_mv.run() - assert out_mv["y_shape"] == y.shape - - # Univariate - cfg_uv = {"task": "classification", "mode": "univariate"} - mlp_uv = MLPipeline(X, y, None, cfg_uv) - out_uv = mlp_uv.run() - expected_keys = set(range(X.shape[1])) - assert set(out_uv.keys()) == expected_keys - for res in out_uv.values(): - assert res["y_shape"] == y.shape - - -# Classification FS+HP search not allowed in univariate -def test_classification_univariate_fs_error(): - X, y = np.zeros((5, 2)), np.zeros((5, 2)) - cfg = { - "task": "classification", - "mode": "univariate", - "analysis_type": "hp_search_fs", - } - mlp = MLPipeline(X, y, None, cfg) - with pytest.raises( - ValueError, match="Cannot perform hp_search_fs in univariate mode" - ): - mlp.run() - - -# def test_update_model_config(monkeypatch): -# import numpy as np -# from sklearn.linear_model import LogisticRegression - -# # create simple binary data -# X = np.random.rand(12, 4) -# y = np.random.choice([0, 1], size=(12,)) - -# # custom model_configs: LR with default C=2.5 and a small grid -# custom_models = { -# 'Logistic Regression': { -# 'default_params': {'C': 2.5}, -# 'params': {'C': [0.5, 2.5]} -# } -# } -# cfg = { -# 'task': 'classification', -# 'mode': 'multivariate', -# 'model_configs': custom_models -# } - -# # stub out the actual run to just return get_model_params for 'lr' -# def fake_run(self): -# return 0 -# monkeypatch.setattr(ClassificationPipeline, 'run', fake_run) - -# mlp = MLPipeline(X, y, None, cfg) -# out = mlp.run() - -# from coco_pipe.ml.base import ModelConfig - -# # ensure the pipeline saw our custom settings via ModelConfig -# mc = mlp.pipeline.pipeline.model_configs.get('Logistic Regression') -# assert isinstance(mc, ModelConfig) -# # estimator should be a LogisticRegression instance with C=2.5 -# from sklearn.linear_model import LogisticRegression -# assert isinstance(mc.estimator, LogisticRegression) -# # assert mc.init_params['C'] == 2.5 -# # hyperparameter grid lives in param_grid -# assert mc.param_grid['C'] == [0.5, 2.5] diff --git a/tests/test_ml_regression.py b/tests/test_ml_regression.py deleted file mode 100644 index 23360e6..0000000 --- a/tests/test_ml_regression.py +++ /dev/null @@ -1,199 +0,0 @@ -import numpy as np -import pytest -from sklearn.datasets import make_regression - -from coco_pipe.ml.config import DEFAULT_CV, MULTIOUTPUT_REG_MODELS, REGRESSION_MODELS -from coco_pipe.ml.regression import ( - MultiOutputRegressionPipeline, - RegressionPipeline, - SingleOutputRegressionPipeline, -) - -# ───────────────────────────────────────────────────────────────────────────── -# Helper small datasets -# ───────────────────────────────────────────────────────────────────────────── -X_single = np.arange(40).reshape(20, 2) # enough samples for 2-fold CV -y_single = X_single[:, 0] * 2.0 + 1.0 - -X_multi, y_multi = make_regression( - n_samples=50, # enough samples for 2-fold CV - n_features=4, - n_targets=3, - noise=0.1, - random_state=0, -) - - -@pytest.fixture(autouse=True) -def tmp_working_dir(tmp_path, monkeypatch): - """Use a clean working dir to avoid polluting the repo.""" - monkeypatch.chdir(tmp_path) - yield - - -# ───────────────────────────────────────────────────────────────────────────── -# RegressionPipeline wrapper -# ───────────────────────────────────────────────────────────────────────────── -@pytest.mark.parametrize( - "analysis_type, model_list, metrics, expected_task", - [ - ("baseline", ["Linear Regression"], ["r2"], "singleoutput"), - ("baseline", ["Random Forest"], ["r2"], "singleoutput"), - ("baseline", ["Linear Regression"], ["mean_r2", "neg_mean_mse"], "multioutput"), - ], -) -def test_pipeline_detect_and_run_baseline( - analysis_type, model_list, metrics, expected_task, monkeypatch -): - if expected_task == "singleoutput": - X, y = X_single, y_single - else: - X, y = X_multi, y_multi - - saved = [] - monkeypatch.setattr( - RegressionPipeline, "save", lambda self, name, res: saved.append(name) - ) - - pipe = RegressionPipeline( - X=X, - y=y, - analysis_type=analysis_type, - models=model_list, - metrics=metrics, - random_state=0, - cv_strategy="kfold", - n_splits=2, - n_jobs=1, - save_intermediate=True, - results_file="testres", - ) - results = pipe.run() - - # basic structure - assert isinstance(results, dict) - assert set(results.keys()) == set(model_list) - cls_name = type(pipe.pipeline).__name__.lower() - assert expected_task in cls_name - - # each result has predictions and metric_scores - for out in results.values(): - assert "predictions" in out - assert "metric_scores" in out - - # save called once per model + final - assert len(saved) == len(model_list) + 1 - assert any(n.startswith("testres") for n in saved) - - -def test_regression_pipeline_invalid_type(): - with pytest.raises(ValueError): - RegressionPipeline(X=X_single, y=y_single, analysis_type="foo").run() - - -# ───────────────────────────────────────────────────────────────────────────── -# SingleOutputRegressionPipeline -# ───────────────────────────────────────────────────────────────────────────── -def test_single_output_metrics_correctness(): - # two folds with simple preds - fold_preds = [ - {"y_true": np.array([1.0, 2.0]), "y_pred": np.array([1.0, 2.0])}, - {"y_true": np.array([3.0, 4.0]), "y_pred": np.array([2.5, 4.5])}, - ] - # define fold-level scores manually - fold_scores = { - "r2": np.array([1.0, 0.5]), # perfect then half explained - "mse": np.array([0.0, 0.25]), # zero then (0.5^2 + 0.5^2)/2 = 0.25 - } - fold_importances = {} - - pipe = SingleOutputRegressionPipeline( - X=np.zeros((4, 2)), - y=np.zeros(4), - models="all", - metrics=["r2", "mse"], - cv_kwargs={**DEFAULT_CV, "n_splits": 2, "cv_strategy": "kfold"}, - n_jobs=1, - ) - # BasePipeline._aggregate returns tuple - predictions, metrics, feature_importances = pipe._aggregate( - fold_preds, fold_scores, fold_importances - ) - - # check concatenated predictions - assert np.array_equal( - predictions["y_true"], np.concatenate([fp["y_true"] for fp in fold_preds]) - ) - assert "y_pred" in predictions - - # metric means - assert pytest.approx(np.mean(fold_scores["r2"])) == metrics["r2"]["mean"] - assert pytest.approx(np.mean(fold_scores["mse"])) == metrics["mse"]["mean"] - - # feature_importances empty - assert feature_importances == {} - - -def test_baseline_all_models_run_single(): - X, y = make_regression( - n_samples=100, n_features=10, n_informative=3, noise=0.1, random_state=0 - ) - metrics = ["r2", "mse", "mae"] - seen = 0 - for name in REGRESSION_MODELS: - pipe = SingleOutputRegressionPipeline( - X=X, - y=y, - models=[name], - metrics=metrics, - random_state=0, - n_jobs=1, - cv_kwargs={**DEFAULT_CV, "n_splits": 5, "cv_strategy": "kfold"}, - ) - out = pipe.baseline_evaluation(name) - # essential keys - for key in ("model_name", "params", "predictions", "metric_scores"): - assert key in out - # prediction length matches - assert out["predictions"]["y_true"].shape[0] == y.shape[0] - # each metric present - for m in metrics: - assert m in out["metric_scores"] - seen += 1 - assert seen == len(REGRESSION_MODELS) - - -def test_target_validation_error_multioutput(): - X = np.zeros((10, 5)) - y = np.zeros(10) - with pytest.raises(ValueError, match="Target must be 2D array"): - MultiOutputRegressionPipeline(X=X, y=y) - - -def test_baseline_all_models_multioutput(): - X, y = make_regression( - n_samples=100, n_features=6, n_targets=3, noise=0.1, random_state=0 - ) - metrics = ["mean_r2", "neg_mean_mse", "neg_mean_mae"] - seen = 0 - for name in MULTIOUTPUT_REG_MODELS: - pipe = MultiOutputRegressionPipeline( - X=X, - y=y, - models=[name], - metrics=metrics, - random_state=42, - n_jobs=1, - cv_kwargs={**DEFAULT_CV, "n_splits": 4, "cv_strategy": "kfold"}, - ) - out = pipe.baseline_evaluation(name) - # essential keys - for key in ("model_name", "params", "predictions", "metric_scores"): - assert key in out - # shape preserved - assert out["predictions"]["y_true"].shape == y.shape - # each metric present - for m in metrics: - assert m in out["metric_scores"] - seen += 1 - assert seen == len(MULTIOUTPUT_REG_MODELS) diff --git a/tests/test_ml_utils.py b/tests/test_ml_utils.py deleted file mode 100644 index 15fe866..0000000 --- a/tests/test_ml_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -import numpy as np -import pytest -from sklearn.model_selection import ( - GroupKFold, - KFold, - LeaveOneGroupOut, - LeavePGroupsOut, - StratifiedKFold, -) - -from coco_pipe.ml.config import DEFAULT_CV -from coco_pipe.ml.utils import SimpleSplit, get_cv_splitter - - -def test_stratified_no_shuffle_behavior(): - # shuffle=False should force random_state=None - splitter = get_cv_splitter( - "stratified", n_splits=3, shuffle=False, random_state=123 - ) - assert isinstance(splitter, StratifiedKFold) - assert splitter.n_splits == 3 - assert splitter.shuffle is False - assert splitter.random_state is None - - -def test_stratified_with_shuffle_preserves_seed(): - splitter = get_cv_splitter("stratified", n_splits=4, shuffle=True, random_state=99) - assert isinstance(splitter, StratifiedKFold) - assert splitter.n_splits == 4 - assert splitter.shuffle is True - assert splitter.random_state == 99 - - -def test_kfold_defaults_and_override(): - # default KFold - kf = get_cv_splitter("kfold") - assert isinstance(kf, KFold) - assert kf.n_splits == DEFAULT_CV["n_splits"] - assert kf.shuffle == DEFAULT_CV["shuffle"] - # seed only when shuffle=True - if DEFAULT_CV["shuffle"]: - assert kf.random_state == DEFAULT_CV["random_state"] - else: - assert kf.random_state is None - - # override shuffle=False - kf2 = get_cv_splitter("kfold", n_splits=5, shuffle=False, random_state=42) - assert isinstance(kf2, KFold) - assert kf2.n_splits == 5 - assert kf2.shuffle is False - assert kf2.random_state is None - - -def test_group_kfold_n_splits(): - gk = get_cv_splitter("group_kfold", n_splits=7) - assert isinstance(gk, GroupKFold) - assert gk.n_splits == 7 - - -def test_leave_p_out_requires_n_groups_and_behavior(): - with pytest.raises(ValueError): - get_cv_splitter("leave_p_out") - lp = get_cv_splitter("leave_p_out", n_groups=2) - assert isinstance(lp, LeavePGroupsOut) - # scikit-learn leaves n_groups stored privately, but repr shows the parameter - assert "n_groups=2" in repr(lp) - - -def test_leave_one_out_behavior(): - loo = get_cv_splitter("leave_one_out") - assert isinstance(loo, LeaveOneGroupOut) - - -def test_simple_split_default_vs_custom(): - # default test_size=0.2 - sp = get_cv_splitter("split") - assert isinstance(sp, SimpleSplit) - idx = np.arange(50) - train_idx, test_idx = next(sp.split(idx, y=None, groups=None)) - assert len(test_idx) == int(0.2 * 50) - assert len(train_idx) == 50 - len(test_idx) - - # custom test_size - sp2 = get_cv_splitter("split", test_size=0.3, shuffle=False, random_state=123) - assert isinstance(sp2, SimpleSplit) - idx2 = np.arange(100) - # no stratify by default - train2, test2 = next(sp2.split(idx2)) - assert len(test2) == 30 - assert len(train2) == 70 - - -def test_unknown_strategy_raises(): - with pytest.raises(ValueError): - get_cv_splitter("this_does_not_exist") diff --git a/tests/test_report_core.py b/tests/test_report_core.py index 4f0d553..41da504 100644 --- a/tests/test_report_core.py +++ b/tests/test_report_core.py @@ -6,6 +6,7 @@ import pytest from coco_pipe.report.core import ( + ContainerElement, HtmlElement, ImageElement, InteractiveTableElement, @@ -18,6 +19,7 @@ _metrics_summary_table, _trajectory_times, ) +from coco_pipe.report.quality import CheckResult @pytest.fixture @@ -330,3 +332,65 @@ def test_fluent_interface_structure(): rep = Report("Fluency") rep.add_element("Start").add_section(Section("Middle")).add_markdown("End") assert len(rep.children) == 3 + + +def test_report_elements_hardening(tmp_path): + # ImageElement + img_data = b"fake-image-data" + elem = ImageElement(img_data) + assert "data:image/png;base64" in elem.render() + + p = tmp_path / "test.png" + p.write_bytes(img_data) + elem_p = ImageElement(p) + assert "data:image/png;base64" in elem_p.render() + + with pytest.raises(ValueError, match="Unsupported image source type"): + ImageElement(123)._encode_image() + + # PlotlyElement binary decoding + class MockFig: + def to_json(self): + return '{"data": [{"y": {"dtype": "f8", "bdata": "AAAAAAAAAAA="}}]}' + + def to_dict(self): + return {"data": []} + + elem_plotly = PlotlyElement(MockFig()) + registry = {} + elem_plotly.collect_payload(registry) + assert elem_plotly.registry_id in registry + + # TableElement normalization + assert isinstance(TableElement._to_frame({"a": 1, "b": 2}), pd.DataFrame) + + # MetricsTableElement directions + df = pd.DataFrame({"m": ["a", "b"], "score": [0.8, 0.9], "error": [0.1, 0.05]}) + elem_metrics = MetricsTableElement( + df, higher_is_better=["score"], highlight_cols=["score", "error"] + ) + assert elem_metrics.best_vals["score"] == 0.9 + + elem_metrics_low = MetricsTableElement(df, higher_is_better=False) + assert elem_metrics_low.best_vals["score"] == 0.8 + + # ContainerElement markdown fallback + cont = ContainerElement() + cont.add_markdown("# Title") + assert "Title" in cont.render() + + # Section status upgrades + sec = Section("Test") + sec.add_finding(CheckResult("c1", "WARN", "w", 4)) + assert sec.status == "WARN" + sec.add_finding(CheckResult("c2", "FAIL", "f", 9)) + assert sec.status == "FAIL" + sec.add_finding(CheckResult("c3", "WARN", "w2", 4)) + assert sec.status == "FAIL" + + # Report config coercion + rep = Report(title="T", config={"some_param": 1}) + assert rep.title == "T" + # In Pydantic 2, extra fields are allowed but might be on the object + # if extra='allow' + assert getattr(rep.config, "some_param") == 1 diff --git a/tests/test_report_decoding.py b/tests/test_report_decoding.py new file mode 100644 index 0000000..22de6d4 --- /dev/null +++ b/tests/test_report_decoding.py @@ -0,0 +1,37 @@ +from sklearn.datasets import make_classification + +from coco_pipe.decoding import Experiment, ExperimentConfig +from coco_pipe.decoding.configs import CVConfig, LogisticRegressionConfig +from coco_pipe.report.core import Report + + +def _diagnostic_result(): + X, y = make_classification( + n_samples=40, + n_features=5, + n_informative=3, + n_redundant=0, + random_state=7, + ) + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig(max_iter=200, kind="classical")}, + metrics=["accuracy", "roc_auc", "brier_score"], + cv=CVConfig(strategy="stratified", n_splits=2, shuffle=True, random_state=7), + n_jobs=1, + verbose=False, + ) + return Experiment(config).run(X, y) + + +def test_decoding_diagnostics_report_section_renders(): + result = _diagnostic_result() + report = Report("Diagnostics") + + report.add_decoding_diagnostics(result) + html = report.render() + + assert "Decoding Diagnostics" in html + assert "Inference Context" in html + assert "Fold Scores" in html + assert "Fit Diagnostics" in html diff --git a/tests/test_report_provenance.py b/tests/test_report_provenance.py index 3a9a7ec..f2e9360 100644 --- a/tests/test_report_provenance.py +++ b/tests/test_report_provenance.py @@ -33,35 +33,31 @@ def test_get_package_version(): assert v_fake == "Unknown" -def test_experiment_provenance_metadata_integration(tmp_path): +def test_experiment_provenance_metadata_integration(): """Verify that decoded results contain the dynamic version.""" - import joblib + import numpy as np + import pandas as pd from coco_pipe.decoding.configs import ( + ClassicalModelConfig, CVConfig, ExperimentConfig, - LogisticRegressionConfig, ) - from coco_pipe.decoding.core import Experiment + from coco_pipe.decoding.experiment import Experiment config = ExperimentConfig( task="classification", - models={"lr": LogisticRegressionConfig(method="LogisticRegression")}, + models={"lr": ClassicalModelConfig(estimator="LogisticRegression")}, metrics=["accuracy"], cv=CVConfig(strategy="kfold", n_splits=2), - output_dir=str(tmp_path), tag="test_meta", ) exp = Experiment(config) - exp.results = {"dummy": "data"} + exp._observation_level = "epoch" + exp._inferential_unit = "sample" + exp._sample_metadata = pd.DataFrame() - save_path = exp.save_results() - assert save_path.exists() - - payload = joblib.load(save_path) - assert "meta" in payload - assert "coco_pipe_version" in payload["meta"] - - version = payload["meta"]["coco_pipe_version"] - assert isinstance(version, str) + meta = exp._build_result_meta(np.zeros((10, 5)), None) + assert "coco_pipe_version" in meta + assert isinstance(meta["coco_pipe_version"], str) diff --git a/tests/test_viz_decoding.py b/tests/test_viz_decoding.py new file mode 100644 index 0000000..184a0a4 --- /dev/null +++ b/tests/test_viz_decoding.py @@ -0,0 +1,148 @@ +import matplotlib.pyplot as plt +import pandas as pd +from sklearn.datasets import make_classification + +from coco_pipe.decoding import Experiment, ExperimentConfig +from coco_pipe.decoding.configs import CVConfig, LogisticRegressionConfig +from coco_pipe.viz.decoding import ( + plot_calibration_curve, + plot_confusion_matrix, + plot_fold_score_dispersion, + plot_pr_curve, + plot_roc_curve, + plot_statistical_null_distribution, + plot_temporal_generalization_matrix, + plot_temporal_score_curve, + plot_temporal_statistical_assessment, + plot_training_history, +) + + +def _diagnostic_result(): + X, y = make_classification( + n_samples=40, + n_features=5, + n_informative=3, + n_redundant=0, + random_state=7, + ) + config = ExperimentConfig( + task="classification", + models={"lr": LogisticRegressionConfig(max_iter=200, kind="classical")}, + metrics=["accuracy", "roc_auc", "brier_score"], + cv=CVConfig(strategy="stratified", n_splits=2, shuffle=True, random_state=7), + n_jobs=1, + verbose=False, + ) + return Experiment(config).run(X, y) + + +def test_diagnostic_plots_return_matplotlib_figures(): + result = _diagnostic_result() + + figures = [ + plot_confusion_matrix(result), + plot_roc_curve(result), + plot_pr_curve(result), + plot_calibration_curve(result), + plot_fold_score_dispersion(result), + ] + + for fig in figures: + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +def test_viz_temporal_plots(): + # Mock data for temporal plots + summary_data = { + "Model": ["LR", "LR"], + "Metric": ["accuracy", "accuracy"], + "Time": [0.0, 0.1], + "Mean": [0.6, 0.7], + "Std": [0.05, 0.05], + "Significant": [True, False], + } + summary_df = pd.DataFrame(summary_data) + + # plot_temporal_score_curve + fig = plot_temporal_score_curve(summary_df) + assert isinstance(fig, plt.Figure) + plt.close(fig) + + # Non-numeric time + summary_non_numeric = summary_df.copy() + summary_non_numeric["Time"] = ["T1", "T2"] + fig = plot_temporal_score_curve(summary_non_numeric) + assert isinstance(fig, plt.Figure) + plt.close(fig) + + # plot_temporal_generalization_matrix + tg_data = { + "Model": ["LR"] * 4, + "Metric": ["accuracy"] * 4, + "TrainTime": [0, 0, 1, 1], + "TestTime": [0, 1, 0, 1], + "Mean": [0.5, 0.6, 0.7, 0.8], + } + tg_df = pd.DataFrame(tg_data) + fig = plot_temporal_generalization_matrix(tg_df) + assert isinstance(fig, plt.Figure) + plt.close(fig) + + # plot_temporal_statistical_assessment + stats_data = { + "Model": ["LR", "LR"], + "Metric": ["accuracy", "accuracy"], + "Time": [0.0, 0.1], + "Observed": [0.6, 0.7], + "NullLower": [0.4, 0.4], + "NullUpper": [0.6, 0.6], + "Significant": [True, False], + } + stats_df = pd.DataFrame(stats_data) + fig = plot_temporal_statistical_assessment(stats_df) + assert isinstance(fig, plt.Figure) + plt.close(fig) + + # plot_statistical_null_distribution + fig = plot_statistical_null_distribution(stats_df) + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +def test_viz_curves_mean_only(): + # Mock ROC/PR curve data with multiple folds + curve_data = { + "Model": ["LR"] * 4, + "Fold": [0, 0, 1, 1], + "FPR": [0, 1, 0, 1], + "TPR": [0, 1, 0, 1], + "Recall": [0, 1, 0, 1], + "Precision": [1, 0, 1, 0], + "Class": [0, 0, 0, 0], + } + curve_df = pd.DataFrame(curve_data) + + fig = plot_roc_curve(curve_df, mean_only=True) + assert isinstance(fig, plt.Figure) + plt.close(fig) + + fig = plot_pr_curve(curve_df, mean_only=True) + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +def test_plot_training_history_hardening(): + # Mock history artifacts + history = [ + {"epoch": 1, "loss": 0.5, "val_loss": 0.6}, + {"epoch": 2, "loss": 0.4, "val_loss": 0.5}, + ] + artifacts_df = pd.DataFrame( + [{"Model": "NN", "Key": "history", "ArtifactType": "history", "Value": history}] + ) + + fig = plot_training_history(artifacts_df) + assert isinstance(fig, plt.Figure) + plt.close(fig)