Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions coco_pipe/decoding/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def fit_and_score_fold(
feature_names: Optional[list[str]] = None,
search_enabled: bool = False,
force_serial: bool = False,
sample_weight: Optional[np.ndarray] = None,
) -> Dict[str, Any]:
"""
Execute a single Cross-Validation fold: Fit, Predict, and Score.
Expand Down Expand Up @@ -135,6 +136,9 @@ def fit_and_score_fold(
Original names of the features, used for importance labeling.
force_serial : bool, default=False
If True, forces the internal estimator fit to be serial.
sample_weight : np.ndarray, optional
Per-sample weights for the full dataset. Only the training-fold
slice is forwarded to the classifier; test samples are never weighted.

Returns
-------
Expand All @@ -145,6 +149,7 @@ def fit_and_score_fold(
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

sw_train = sample_weight[train_idx] if sample_weight is not None else None
groups_train = groups[train_idx] if groups is not None else None
test_groups = groups[test_idx] if groups is not None else None
_needs_group_routing = (
Expand Down Expand Up @@ -173,6 +178,7 @@ def fit_and_score_fold(
feature_selection_config=feature_selection_config,
calibration_config=calibration_config,
tuning_config=tuning_config,
sample_weight=sw_train,
)
fit_time = time.perf_counter() - fit_start
captured_warnings.extend(warning_records_to_dict("fit", warning_records))
Expand Down Expand Up @@ -309,6 +315,7 @@ def fit_estimator(
feature_selection_config: Any,
calibration_config: Any,
tuning_config: Any = None,
sample_weight: Optional[np.ndarray] = None,
) -> None:
"""
Fit an estimator with intelligent metadata and group routing.
Expand All @@ -335,6 +342,10 @@ def fit_estimator(
Probability calibration settings.
tuning_config : Any
Hyperparameter tuning settings.
sample_weight : np.ndarray, optional
Per-sample weights for the training fold. Forwarded as
``clf__sample_weight`` when the pipeline's ``clf`` step exposes
a ``sample_weight`` parameter in its ``fit`` signature.
"""
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
Expand Down Expand Up @@ -375,6 +386,12 @@ def fit_estimator(
scaler_step = pipeline.named_steps["scaler"]
if "groups" in inspect.signature(scaler_step.fit).parameters:
fit_params["scaler__groups"] = groups_train

if sample_weight is not None and isinstance(pipeline, Pipeline) and "clf" in pipeline.named_steps:
clf_step = pipeline.named_steps["clf"]
if "sample_weight" in inspect.signature(clf_step.fit).parameters:
fit_params["clf__sample_weight"] = sample_weight

estimator.fit(X_train, y_train, **fit_params)


Expand Down
10 changes: 10 additions & 0 deletions coco_pipe/decoding/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ def run(
observation_level: str = "sample",
inferential_unit: Optional[str] = None,
time_axis: Optional[Sequence[Any]] = None,
sample_weight: Optional[np.ndarray] = None,
) -> ExperimentResult:
"""
Execute the complete decoding experiment pipeline.
Expand Down Expand Up @@ -462,6 +463,12 @@ def run(
raise ValueError("X is empty.")
if len(y) != len(X):
raise ValueError("Length mismatch between X and y.")
if sample_weight is not None:
sample_weight = np.asarray(sample_weight, dtype=float)
if len(sample_weight) != len(X):
raise ValueError(
f"sample_weight length {len(sample_weight)} != X length {len(X)}."
)

# 1. Scientific Guard: Double-Normalization Warning
if self.config.use_scaler and X.ndim == 2:
Expand Down Expand Up @@ -569,6 +576,7 @@ def run(
n_jobs=model_n_jobs,
spec=spec,
model_name=name,
sample_weight=sample_weight,
)
except Exception as e:
logger.error(f"Failed model '{name}': {e}", exc_info=True)
Expand Down Expand Up @@ -619,6 +627,7 @@ def _cross_validate(
n_jobs: int = 1,
spec: Optional[Any] = None,
model_name: Optional[str] = None,
sample_weight: Optional[np.ndarray] = None,
) -> Dict[str, Any]:
"""Perform parallel cross-validation for a single estimator."""
cv = get_cv_splitter(self.config.cv, groups=groups, y=y)
Expand Down Expand Up @@ -659,6 +668,7 @@ def _cross_validate(
and model_name in self.config.grids
),
force_serial=(n_jobs == 1),
sample_weight=sample_weight,
)
for train_idx, test_idx in splits
)
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,8 @@ quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"

[dependency-groups]
dev = [
"pytest>=9.0.3",
]
139 changes: 139 additions & 0 deletions tests/test_decoding_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,142 @@ def test_compact_search_results_missing_keys():
est = SimpleNamespace(cv_results_={"params": [{"C": 1}]})
res = compact_search_results(est)
assert res == [{"candidate": 0, "params": {"C": 1}}]


# --- sample_weight tests ---


class _WeightCapturingClassifier(BaseEstimator, ClassifierMixin):
"""Records the sample_weight passed to fit(); predict always returns zeros."""

_estimator_type = "classifier"
classes_ = np.array([0, 1])

def fit(self, X, y, sample_weight=None):
self.recorded_weight_ = sample_weight
self.classes_ = np.array([0, 1])
return self

def predict(self, X):
return np.zeros(len(X), dtype=int)

def predict_proba(self, X):
return np.column_stack([np.ones(len(X)), np.zeros(len(X))])


def _make_spec(**overrides):
base = dict(
supports_proba=False,
supports_decision_function=False,
importance=("unavailable",),
supports_groups=False,
grouped_metadata="none",
is_sparse_capable=False,
family="linear",
)
base.update(overrides)
return SimpleNamespace(**base)


def test_fit_estimator_routes_sample_weight():
"""fit_estimator forwards sample_weight to clf step inside a Pipeline."""
from sklearn.linear_model import LogisticRegression

clf = _WeightCapturingClassifier()
pipe = Pipeline([("clf", clf)])
sw = np.array([1.0, 2.0, 3.0])

fit_estimator(pipe, np.zeros((3, 2)), np.array([0, 1, 0]), None,
MockConfig(), MockConfig(), sample_weight=sw)

assert np.allclose(clf.recorded_weight_, sw)


def test_fit_estimator_no_sample_weight_when_none():
"""fit_estimator passes None → clf.recorded_weight_ is None."""
clf = _WeightCapturingClassifier()
pipe = Pipeline([("clf", clf)])

fit_estimator(pipe, np.zeros((3, 2)), np.array([0, 1, 0]), None,
MockConfig(), MockConfig(), sample_weight=None)

assert clf.recorded_weight_ is None


def test_fit_estimator_skips_unsupported_clf():
"""fit_estimator does NOT crash when clf.fit lacks sample_weight param."""

class NoWeightClf(BaseEstimator, ClassifierMixin):
_estimator_type = "classifier"
classes_ = np.array([0, 1])

def fit(self, X, y):
return self

def predict(self, X):
return np.zeros(len(X), dtype=int)

pipe = Pipeline([("clf", NoWeightClf())])
sw = np.array([1.0, 2.0, 3.0])
# should not raise
fit_estimator(pipe, np.zeros((3, 2)), np.array([0, 1, 0]), None,
MockConfig(), MockConfig(), sample_weight=sw)


def test_fit_and_score_fold_sample_weight_train_only():
"""Only the training-fold slice of sample_weight reaches the classifier."""
import coco_pipe.decoding._engine as engine
from coco_pipe.decoding._metrics import MetricSpec

clf = _WeightCapturingClassifier()
pipe = Pipeline([("clf", clf)])
X = np.zeros((6, 2))
y = np.array([0, 0, 0, 1, 1, 1])
ids = np.arange(6).astype(str)
sw = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
train_idx = np.array([0, 1, 3, 4])
test_idx = np.array([2, 5])

old_get = engine.get_metric_spec
try:
engine.get_metric_spec = lambda m: MetricSpec(
m, "classification", lambda yt, yp: float(yp.mean()), "predict"
)
fit_and_score_fold(
pipe, X, y, None, ids, None,
train_idx=train_idx, test_idx=test_idx,
metrics=["acc"],
feature_selection_config=MockConfig(),
calibration_config=MockConfig(),
spec=_make_spec(),
sample_weight=sw,
)
finally:
engine.get_metric_spec = old_get

# Only train-fold weights should have been forwarded
expected = sw[train_idx]
assert np.allclose(clf.recorded_weight_, expected), (
f"Expected {expected}, got {clf.recorded_weight_}"
)


def test_experiment_run_rejects_length_mismatch():
"""Experiment.run raises ValueError when sample_weight length != len(X)."""
import pytest
from coco_pipe.decoding import Experiment, ExperimentConfig
from coco_pipe.decoding.configs import CVConfig, LogisticRegressionConfig

config = ExperimentConfig(
task="classification",
models={"lr": LogisticRegressionConfig()},
cv=CVConfig(strategy="stratified", n_splits=2),
n_jobs=1,
verbose=False,
)
X = np.zeros((10, 2))
y = np.zeros(10, dtype=int)
y[5:] = 1

with pytest.raises(ValueError, match="sample_weight length"):
Experiment(config).run(X, y, sample_weight=np.ones(5))
Loading