Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions meridian/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,13 +815,13 @@
BIC = 'bic'
EBIC = 'ebic'

# Posterior downsampling constants.
POSTERIOR_IS_DOWNSAMPLED = 'posterior_is_downsampled'
POSTERIOR_DOWNSAMPLE_METHOD = 'posterior_downsample_method'
POSTERIOR_DOWNSAMPLE_SAMPLING_RATE = 'posterior_downsample_sampling_rate'
# Posterior thinning constants.
POSTERIOR_IS_THINNED = 'posterior_is_thinned'
POSTERIOR_THINNING_METHOD = 'posterior_thinning_method'
POSTERIOR_THINNING_SAMPLING_RATE = 'posterior_thinning_sampling_rate'
POSTERIOR_ORIGINAL_CHAIN_COUNT = 'posterior_original_chain_count'
POSTERIOR_ORIGINAL_DRAW_COUNT = 'posterior_original_draw_count'
POSTERIOR_SELECTED_DRAW_COUNT_PER_CHAIN = (
'posterior_selected_draw_count_per_chain'
)
POSTERIOR_DOWNSAMPLE_SEED = 'posterior_downsample_seed'
POSTERIOR_THINNING_SEED = 'posterior_thinning_seed'
123 changes: 75 additions & 48 deletions meridian/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@


@enum.unique
class DownsampleMethod(enum.Enum):
"""Posterior draw downsampling methods."""
class ThinningMethod(enum.Enum):
"""Posterior draw thinning methods."""

SYSTEMATIC = "systematic"

Expand Down Expand Up @@ -217,7 +217,7 @@ def __init__(
self._computation_precision = backend.computation_precision().name
self._eda_spec = eda_spec
self._health_summary = health_summary
self._full_posterior = None
self._full_datasets = None

self._validate_injected_inference_data()
self._validate_injected_health_summary()
Expand Down Expand Up @@ -1219,29 +1219,25 @@ def sample_posterior_and_review(
)
self.review()

def downsample_posterior(
def thin_posterior(
self,
sampling_rate: float | None = None,
n_draws: int | None = None,
method: DownsampleMethod = DownsampleMethod.SYSTEMATIC,
method: ThinningMethod = ThinningMethod.SYSTEMATIC,
seed: int | Sequence[int] | None = None,
preserve_original: bool = True,
) -> xr.Dataset:
"""Downsamples `inference_data.posterior` while preserving chains.
"""Thins all groups in `inference_data` with `chain` and `draw` dimensions.

This method replaces `self.inference_data.posterior` with a chain-preserving
subset of posterior draws. For example, a posterior with shape
`chain=10, draw=1000` and `sampling_rate=0.1` becomes
`chain=10, draw=100`.
Thinning is a process of sub-sampling MCMC draws to reduce autocorrelation,
reduce file size, and speed up downstream analysis. This method supports
systematic thinning, which selects draws at regular intervals.

The main use case is accelerating posterior workflows such as budget
optimization while continuing to use Meridian's existing APIs unchanged.
Outputs produced after downsampling are approximate with respect to the full
posterior.
To maintain the consistency of the `inference_data` container, this method
applies the same draw selection to all groups that contain both `chain` and
`draw` dimensions (e.g., `posterior`, `sample_stats`, `log_likelihood`).
Groups whose name contains `prior` are skipped and left unchanged.

Systematic sampling selects posterior samples from each MCMC chain at a
fixed, regular interval (such as every 10th sample) starting from a randomly
chosen point. This is done to minimize the auto-correlation of the sample.
Exactly one of `sampling_rate` or `n_draws` must be provided.

Args:
sampling_rate: Fraction of draws to keep per chain. Must be in `(0, 1]`.
Expand All @@ -1250,33 +1246,34 @@ def downsample_posterior(
original_n_draws]`. Exactly one of `sampling_rate` or `n_draws` must be
provided.
method: Draw selection method. Currently only
`DownsampleMethod.SYSTEMATIC` is supported.
`ThinningMethod.SYSTEMATIC` is supported.
seed: Optional random seed for reproducible draw selection. This is used
only for selecting posterior draw indices.
preserve_original: If `True`, stores a copy of the full posterior on this
model so `restore_full_posterior()` can restore it.
preserve_original: If `True`, stores a copy of the full datasets on this
model so `restore_full_posterior()` can restore them.

Returns:
The downsampled posterior `xarray.Dataset`.
The thinned posterior `xarray.Dataset`.

Raises:
NotFittedModelError: If the model does not have posterior samples.
ValueError: If arguments are invalid.
ValueError: If arguments are invalid or if a group has already been
thinned.
"""
if not hasattr(self.inference_data, constants.POSTERIOR):
raise NotFittedModelError(
"sample_posterior() must be called before downsample_posterior()."
"sample_posterior() must be called before thin_posterior()."
)
if (sampling_rate is None) == (n_draws is None):
raise ValueError(
"Exactly one of `sampling_rate` or `n_draws` must be provided."
)
if method is not DownsampleMethod.SYSTEMATIC:
raise ValueError(f"Unsupported posterior downsample method: {method}.")
if method is not ThinningMethod.SYSTEMATIC:
raise ValueError(f"Unsupported posterior thinning method: {method}.")

posterior = self.inference_data.posterior
if posterior.attrs.get(constants.POSTERIOR_IS_DOWNSAMPLED):
raise ValueError("Posterior has already been downsampled.")
if posterior.attrs.get(constants.POSTERIOR_IS_THINNED):
raise ValueError("Posterior has already been thinned.")
if (
constants.CHAIN not in posterior.sizes
or constants.DRAW not in posterior.sizes
Expand All @@ -1298,8 +1295,26 @@ def downsample_posterior(
if n_selected_draws == n_original_draws:
return posterior

if preserve_original and self._full_posterior is None:
self._full_posterior = posterior.copy(deep=True)
# Check all groups for pre-existing thinning state to fail early
for group in self.inference_data.groups():
if "prior" in group:
continue
dataset = getattr(self.inference_data, group)
if (
constants.CHAIN in dataset.sizes
and constants.DRAW in dataset.sizes
and dataset.attrs.get(constants.POSTERIOR_IS_THINNED)
):
raise ValueError(f"Group {group} has already been thinned.")

if preserve_original and self._full_datasets is None:
self._full_datasets = {}
for group in self.inference_data.groups():
if "prior" in group:
continue
dataset = getattr(self.inference_data, group)
if constants.CHAIN in dataset.sizes and constants.DRAW in dataset.sizes:
self._full_datasets[group] = dataset.copy(deep=True)

rng = np.random.default_rng(seed)
selected_draw_indices = np.stack([
Expand All @@ -1315,41 +1330,53 @@ def downsample_posterior(
selected_draw_indices,
dims=(constants.CHAIN, constants.DRAW),
)
downsampled_posterior = posterior.isel(
{constants.DRAW: draw_indexer}
).assign_coords({constants.DRAW: np.arange(n_selected_draws)})
attrs = dict(posterior.attrs)
attrs.update({
constants.POSTERIOR_IS_DOWNSAMPLED: True,
constants.POSTERIOR_DOWNSAMPLE_METHOD: method.value,
constants.POSTERIOR_DOWNSAMPLE_SAMPLING_RATE: (

common_attrs = {
constants.POSTERIOR_IS_THINNED: True,
constants.POSTERIOR_THINNING_METHOD: method.value,
constants.POSTERIOR_THINNING_SAMPLING_RATE: (
float(sampling_rate)
if sampling_rate is not None
else n_selected_draws / n_original_draws
),
constants.POSTERIOR_ORIGINAL_CHAIN_COUNT: n_chains,
constants.POSTERIOR_ORIGINAL_DRAW_COUNT: n_original_draws,
constants.POSTERIOR_SELECTED_DRAW_COUNT_PER_CHAIN: n_selected_draws,
})
}
if seed is not None:
attrs[constants.POSTERIOR_DOWNSAMPLE_SEED] = (
common_attrs[constants.POSTERIOR_THINNING_SEED] = (
list(seed)
if isinstance(seed, Sequence) and not isinstance(seed, (str, bytes))
else int(seed)
)
downsampled_posterior.attrs = attrs
self.inference_data.posterior = downsampled_posterior
return downsampled_posterior

for group in self.inference_data.groups():
if "prior" in group:
continue
dataset = getattr(self.inference_data, group)
if constants.CHAIN in dataset.sizes and constants.DRAW in dataset.sizes:
thinned_dataset = dataset.isel(
{constants.DRAW: draw_indexer}
).assign_coords({constants.DRAW: np.arange(n_selected_draws)})

attrs = dict(dataset.attrs)
attrs.update(common_attrs)
thinned_dataset.attrs = attrs

setattr(self.inference_data, group, thinned_dataset)

return self.inference_data.posterior

def restore_full_posterior(self) -> xr.Dataset:
"""Restores the full posterior saved by `downsample_posterior()`."""
if self._full_posterior is None:
"""Restores the full datasets saved by `thin_posterior()`."""
if self._full_datasets is None:
raise ValueError(
"No preserved full posterior is available. Call "
"downsample_posterior(..., preserve_original=True) first."
"No preserved full datasets are available. Call "
"thin_posterior(..., preserve_original=True) first."
)
self.inference_data.posterior = self._full_posterior
self._full_posterior = None
for group, full_dataset in self._full_datasets.items():
setattr(self.inference_data, group, full_dataset)
self._full_datasets = None
return self.inference_data.posterior


Expand Down
Loading
Loading