Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/pycorpdiff/semantic/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ class SenseDriftResult:
The margin-density flag threshold --- null-calibrated (a high
percentile of the label-shuffle null) when ``n_permutations > 0``,
otherwise the in-sample control chart.
threshold_method
How ``threshold`` was derived: ``"permutation_null"`` when
``n_permutations > 0``, else ``"control_chart"``. **The two methods
can flag different periods**, so ``change_type`` and ``drift_terms``
depend on it --- enabling permutations is not merely "add a p-value"
(see the ``n_permutations`` note on :func:`sense_drift`).
p_value
Permutation p-value for the overall drift (real max margin density
vs the label-shuffle null max); ``None`` unless
Expand All @@ -222,6 +228,7 @@ class SenseDriftResult:
reference: list[Any]
k: int
threshold: float
threshold_method: str
p_value: float | None
embedding_meta: dict[str, Any]
_records: pd.DataFrame = field(repr=False)
Expand Down Expand Up @@ -536,6 +543,12 @@ def sense_drift(
to a reference fitted on themselves; the shuffle null removes that
bias. Costs one model re-fit per permutation. ``0`` (default) uses
the fast in-sample chart, fine for exploration.

**This switches the *thresholding method*, not just the p-value.**
The null-calibrated threshold can flag a *different* set of periods
than the control chart, so ``change_type`` and ``drift_terms`` may
change too. Inspect :attr:`SenseDriftResult.threshold_method` to see
which regime produced the flags.
null_pctile
Percentile of the label-shuffle null margin-density (and JSD)
distribution used as the flag threshold when ``n_permutations > 0``.
Expand Down Expand Up @@ -655,6 +668,7 @@ def sense_drift(
float(np.mean(jsd_ref)) + k_sigma * float(np.std(jsd_ref, ddof=1))
if len(jsd_ref) >= 2 else np.inf)
threshold = md_threshold
threshold_method = "permutation_null" if n_permutations > 0 else "control_chart"

table = pd.DataFrame({
"period": periods,
Expand Down Expand Up @@ -719,6 +733,7 @@ def sense_drift(
reference=ref_labels_set,
k=k,
threshold=threshold,
threshold_method=threshold_method,
p_value=p_value,
embedding_meta=dict(embedding_meta or {}),
_records=recs,
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/test_semantic_drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,37 @@ def test_bad_novelty_raises():
df, X = _stable()
with pytest.raises(ValueError, match="mahalanobis.*cosine"):
pcd.sense_drift(df, X, time_col="year", reference=REF, k=3, novelty="bogus")


def test_threshold_method_is_surfaced():
    """The result names the thresholding regime that produced its flags.

    Without ``n_permutations`` the result must report ``"control_chart"`` and
    carry no p-value; with ``n_permutations=50`` it must report
    ``"permutation_null"`` and carry a p-value.
    """
    frame, embeddings = _broadening()
    shared_kwargs = {"time_col": "year", "reference": REF, "k": 3}

    chart_result = pcd.sense_drift(frame, embeddings, **shared_kwargs)
    null_result = pcd.sense_drift(
        frame, embeddings, n_permutations=50, **shared_kwargs
    )

    # Default call: in-sample control chart, no permutation p-value.
    assert chart_result.threshold_method == "control_chart"
    assert chart_result.p_value is None
    # Permutation call: null-calibrated threshold plus a p-value.
    assert null_result.threshold_method == "permutation_null"
    assert null_result.p_value is not None


def test_permutation_switches_threshold_not_just_pvalue():
    """Turning on permutations changes the threshold value itself.

    Runs the same data once with the default settings and once with
    ``n_permutations=50`` and checks that the two reported thresholds differ.
    """
    frame, embeddings = _broadening()
    # Evaluated in order: default (control chart) first, then permutations.
    chart_threshold, null_threshold = [
        pcd.sense_drift(
            frame, embeddings, time_col="year", reference=REF, k=3, **extra
        ).threshold
        for extra in ({}, {"n_permutations": 50})
    ]
    assert chart_threshold != null_threshold


def test_permutation_mode_is_deterministic():
    """Two identical permutation-mode runs give identical results.

    Calls ``sense_drift`` twice with the same arguments
    (``n_permutations=50``) and compares the table, threshold, change type
    and p-value. No ``random_state`` is passed here, so reproducibility
    relies on whatever seeding ``sense_drift`` applies by default.
    """
    frame, embeddings = _broadening()
    first, second = (
        pcd.sense_drift(
            frame,
            embeddings,
            time_col="year",
            reference=REF,
            k=3,
            n_permutations=50,
        )
        for _ in range(2)
    )
    pd.testing.assert_frame_equal(first.table, second.table)
    assert first.threshold == second.threshold
    assert first.change_type == second.change_type
    assert first.p_value == second.p_value
Loading