From d2d54a33645eb2690e2ba6ef5438e86d80dbc11b Mon Sep 17 00:00:00 2001 From: The Meridian Authors Date: Mon, 1 Jun 2026 11:51:14 -0700 Subject: [PATCH] Rename downsample_posterior to thin_posterior and generalize thinning to all matching MCMC groups in InferenceData to maintain container-wide consistency, and update serialization tests. PiperOrigin-RevId: 924831568 --- meridian/constants.py | 10 +- meridian/model/model.py | 123 +++++++++++-------- meridian/model/model_test.py | 119 ++++++++++++------ meridian/schema/serde/meridian_serde_test.py | 69 +++++++++++ 4 files changed, 229 insertions(+), 92 deletions(-) diff --git a/meridian/constants.py b/meridian/constants.py index d57f29a33..05b79b046 100644 --- a/meridian/constants.py +++ b/meridian/constants.py @@ -815,13 +815,13 @@ BIC = 'bic' EBIC = 'ebic' -# Posterior downsampling constants. -POSTERIOR_IS_DOWNSAMPLED = 'posterior_is_downsampled' -POSTERIOR_DOWNSAMPLE_METHOD = 'posterior_downsample_method' -POSTERIOR_DOWNSAMPLE_SAMPLING_RATE = 'posterior_downsample_sampling_rate' +# Posterior thinning constants. +POSTERIOR_IS_THINNED = 'posterior_is_thinned' +POSTERIOR_THINNING_METHOD = 'posterior_thinning_method' +POSTERIOR_THINNING_SAMPLING_RATE = 'posterior_thinning_sampling_rate' POSTERIOR_ORIGINAL_CHAIN_COUNT = 'posterior_original_chain_count' POSTERIOR_ORIGINAL_DRAW_COUNT = 'posterior_original_draw_count' POSTERIOR_SELECTED_DRAW_COUNT_PER_CHAIN = ( 'posterior_selected_draw_count_per_chain' ) -POSTERIOR_DOWNSAMPLE_SEED = 'posterior_downsample_seed' +POSTERIOR_THINNING_SEED = 'posterior_thinning_seed' diff --git a/meridian/model/model.py b/meridian/model/model.py index b2ab5cb0d..b48f836a8 100644 --- a/meridian/model/model.py +++ b/meridian/model/model.py @@ -60,8 +60,8 @@ @enum.unique -class DownsampleMethod(enum.Enum): - """Posterior draw downsampling methods.""" +class ThinningMethod(enum.Enum): + """Posterior draw thinning methods.""" SYSTEMATIC = "systematic" @@ -217,7 +217,7 @@ def __init__( self._computation_precision = backend.computation_precision().name self._eda_spec = eda_spec self._health_summary = health_summary - self._full_posterior = None + self._full_datasets = None self._validate_injected_inference_data() self._validate_injected_health_summary() @@ -1219,29 +1219,25 @@ def sample_posterior_and_review( ) self.review() - def downsample_posterior( + def thin_posterior( self, sampling_rate: float | None = None, n_draws: int | None = None, - method: DownsampleMethod = DownsampleMethod.SYSTEMATIC, + method: ThinningMethod = ThinningMethod.SYSTEMATIC, seed: int | Sequence[int] | None = None, preserve_original: bool = True, ) -> xr.Dataset: - """Downsamples `inference_data.posterior` while preserving chains. + """Thins all groups in `inference_data` with `chain` and `draw` dimensions. - This method replaces `self.inference_data.posterior` with a chain-preserving - subset of posterior draws. For example, a posterior with shape - `chain=10, draw=1000` and `sampling_rate=0.1` becomes - `chain=10, draw=100`. + Thinning is a process of sub-sampling MCMC draws to reduce autocorrelation, + reduce file size, and speed up downstream analysis. This method supports + systematic thinning, which selects draws at regular intervals. - The main use case is accelerating posterior workflows such as budget - optimization while continuing to use Meridian's existing APIs unchanged. - Outputs produced after downsampling are approximate with respect to the full - posterior. + To maintain the consistency of the `inference_data` container, this method + applies the same draw selection to all groups that contain both `chain` and + `draw` dimensions (e.g., `posterior`, `sample_stats`, `log_likelihood`). - Systematic sampling selects posterior samples from each MCMC chain at a - fixed, regular interval (such as every 10th sample) starting from a randomly - chosen point. This is done to minimize the auto-correlation of the sample. + Exactly one of `sampling_rate` or `n_draws` must be provided. Args: sampling_rate: Fraction of draws to keep per chain. Must be in `(0, 1]`. @@ -1250,33 +1246,34 @@ def downsample_posterior( original_n_draws]`. Exactly one of `sampling_rate` or `n_draws` must be provided. method: Draw selection method. Currently only - `DownsampleMethod.SYSTEMATIC` is supported. + `ThinningMethod.SYSTEMATIC` is supported. seed: Optional random seed for reproducible draw selection. This is used only for selecting posterior draw indices. - preserve_original: If `True`, stores a copy of the full posterior on this - model so `restore_full_posterior()` can restore it. + preserve_original: If `True`, stores a copy of the full datasets on this + model so `restore_full_posterior()` can restore them. Returns: - The downsampled posterior `xarray.Dataset`. + The thinned posterior `xarray.Dataset`. Raises: NotFittedModelError: If the model does not have posterior samples. - ValueError: If arguments are invalid. + ValueError: If arguments are invalid or if a group has already been + thinned. """ if not hasattr(self.inference_data, constants.POSTERIOR): raise NotFittedModelError( - "sample_posterior() must be called before downsample_posterior()." + "sample_posterior() must be called before thin_posterior()." ) if (sampling_rate is None) == (n_draws is None): raise ValueError( "Exactly one of `sampling_rate` or `n_draws` must be provided." ) - if method is not DownsampleMethod.SYSTEMATIC: - raise ValueError(f"Unsupported posterior downsample method: {method}.") + if method is not ThinningMethod.SYSTEMATIC: + raise ValueError(f"Unsupported posterior thinning method: {method}.") posterior = self.inference_data.posterior - if posterior.attrs.get(constants.POSTERIOR_IS_DOWNSAMPLED): - raise ValueError("Posterior has already been downsampled.") + if posterior.attrs.get(constants.POSTERIOR_IS_THINNED): + raise ValueError("Posterior has already been thinned.") if ( constants.CHAIN not in posterior.sizes or constants.DRAW not in posterior.sizes @@ -1298,8 +1295,26 @@ def downsample_posterior( if n_selected_draws == n_original_draws: return posterior - if preserve_original and self._full_posterior is None: - self._full_posterior = posterior.copy(deep=True) + # Check all groups for pre-existing thinning state to fail early + for group in self.inference_data.groups(): + if "prior" in group: + continue + dataset = getattr(self.inference_data, group) + if ( + constants.CHAIN in dataset.sizes + and constants.DRAW in dataset.sizes + and dataset.attrs.get(constants.POSTERIOR_IS_THINNED) + ): + raise ValueError(f"Group {group} has already been thinned.") + + if preserve_original and self._full_datasets is None: + self._full_datasets = {} + for group in self.inference_data.groups(): + if "prior" in group: + continue + dataset = getattr(self.inference_data, group) + if constants.CHAIN in dataset.sizes and constants.DRAW in dataset.sizes: + self._full_datasets[group] = dataset.copy(deep=True) rng = np.random.default_rng(seed) selected_draw_indices = np.stack([ @@ -1315,14 +1330,11 @@ def downsample_posterior( selected_draw_indices, dims=(constants.CHAIN, constants.DRAW), ) - downsampled_posterior = posterior.isel( - {constants.DRAW: draw_indexer} - ).assign_coords({constants.DRAW: np.arange(n_selected_draws)}) - attrs = dict(posterior.attrs) - attrs.update({ - constants.POSTERIOR_IS_DOWNSAMPLED: True, - constants.POSTERIOR_DOWNSAMPLE_METHOD: method.value, - constants.POSTERIOR_DOWNSAMPLE_SAMPLING_RATE: ( + + common_attrs = { + constants.POSTERIOR_IS_THINNED: True, + constants.POSTERIOR_THINNING_METHOD: method.value, + constants.POSTERIOR_THINNING_SAMPLING_RATE: ( float(sampling_rate) if sampling_rate is not None else n_selected_draws / n_original_draws @@ -1330,26 +1342,41 @@ def downsample_posterior( constants.POSTERIOR_ORIGINAL_CHAIN_COUNT: n_chains, constants.POSTERIOR_ORIGINAL_DRAW_COUNT: n_original_draws, constants.POSTERIOR_SELECTED_DRAW_COUNT_PER_CHAIN: n_selected_draws, - }) + } if seed is not None: - attrs[constants.POSTERIOR_DOWNSAMPLE_SEED] = ( + common_attrs[constants.POSTERIOR_THINNING_SEED] = ( list(seed) if isinstance(seed, Sequence) and not isinstance(seed, (str, bytes)) else int(seed) ) - downsampled_posterior.attrs = attrs - self.inference_data.posterior = downsampled_posterior - return downsampled_posterior + + for group in self.inference_data.groups(): + if "prior" in group: + continue + dataset = getattr(self.inference_data, group) + if constants.CHAIN in dataset.sizes and constants.DRAW in dataset.sizes: + thinned_dataset = dataset.isel( + {constants.DRAW: draw_indexer} + ).assign_coords({constants.DRAW: np.arange(n_selected_draws)}) + + attrs = dict(dataset.attrs) + attrs.update(common_attrs) + thinned_dataset.attrs = attrs + + setattr(self.inference_data, group, thinned_dataset) + + return self.inference_data.posterior def restore_full_posterior(self) -> xr.Dataset: - """Restores the full posterior saved by `downsample_posterior()`.""" - if self._full_posterior is None: + """Restores the full datasets saved by `thin_posterior()`.""" + if self._full_datasets is None: raise ValueError( - "No preserved full posterior is available. Call " - "downsample_posterior(..., preserve_original=True) first." + "No preserved full datasets are available. Call " + "thin_posterior(..., preserve_original=True) first." ) - self.inference_data.posterior = self._full_posterior - self._full_posterior = None + for group, full_dataset in self._full_datasets.items(): + setattr(self.inference_data, group, full_dataset) + self._full_datasets = None return self.inference_data.posterior diff --git a/meridian/model/model_test.py b/meridian/model/model_test.py index 20e63df85..6d7c9cc5e 100644 --- a/meridian/model/model_test.py +++ b/meridian/model/model_test.py @@ -721,10 +721,10 @@ def test_sample_posterior_and_review_method(self): def _meridian_with_posterior(self, posterior: xr.Dataset) -> model.Meridian: meridian = model.Meridian(input_data=self.input_data_with_media_only) - meridian.inference_data.posterior = posterior + meridian.inference_data.add_groups({"posterior": posterior}) return meridian - def test_downsample_posterior_preserves_chains(self): + def test_thin_posterior_preserves_chains(self): values = np.arange(3 * 10 * 2).reshape((3, 10, 2)) meridian = self._meridian_with_posterior(xr.Dataset( data_vars={ @@ -744,26 +744,26 @@ def test_downsample_posterior_preserves_chains(self): }, )) - downsampled = meridian.downsample_posterior(n_draws=4, seed=7) + thinned = meridian.thin_posterior(n_draws=4, seed=7) - self.assertEqual(downsampled.sizes[constants.CHAIN], 3) - self.assertEqual(downsampled.sizes[constants.DRAW], 4) - self.assertEqual(downsampled.sizes["channel"], 2) + self.assertEqual(thinned.sizes[constants.CHAIN], 3) + self.assertEqual(thinned.sizes[constants.DRAW], 4) + self.assertEqual(thinned.sizes["channel"], 2) self.assertEqual( - downsampled.attrs["posterior_selected_draw_count_per_chain"], 4 + thinned.attrs["posterior_selected_draw_count_per_chain"], 4 ) - self.assertTrue(downsampled.attrs["posterior_is_downsampled"]) + self.assertTrue(thinned.attrs["posterior_is_thinned"]) self.assertEqual( - downsampled.attrs["posterior_downsample_method"], "systematic" + thinned.attrs["posterior_thinning_method"], "systematic" ) for chain in range(3): with self.subTest(chain=chain): - selected_draws = downsampled["draw_id"].sel(chain=chain).values + selected_draws = thinned["draw_id"].sel(chain=chain).values self.assertLen(set(selected_draws.tolist()), 4) self.assertTrue(np.all(np.diff(selected_draws) >= 1)) self.assertTrue(np.all(selected_draws >= 0)) self.assertTrue(np.all(selected_draws < 10)) - selected_values = downsampled["param"].sel(chain=chain).values[:, 0] + selected_values = thinned["param"].sel(chain=chain).values[:, 0] self.assertTrue(np.all(selected_values >= chain * 20)) self.assertTrue(np.all(selected_values < (chain + 1) * 20)) @@ -773,18 +773,18 @@ def test_downsample_posterior_preserves_chains(self): self.assertEqual(restored.sizes[constants.DRAW], 10) np.testing.assert_array_equal(restored["param"].values, values) - def test_downsample_posterior_accepts_downsample_method_enum(self): + def test_thin_posterior_accepts_thinning_method_enum(self): meridian = self._meridian_with_posterior(_simple_posterior()) - downsampled = meridian.downsample_posterior( - n_draws=4, method=model.DownsampleMethod.SYSTEMATIC, seed=7 + thinned = meridian.thin_posterior( + n_draws=4, method=model.ThinningMethod.SYSTEMATIC, seed=7 ) self.assertEqual( - downsampled.attrs["posterior_downsample_method"], "systematic" + thinned.attrs["posterior_thinning_method"], "systematic" ) - def test_downsample_posterior_supports_non_leading_draw_dimension(self): + def test_thin_posterior_supports_non_leading_draw_dimension(self): values = np.arange(2 * 3 * 10).reshape((2, 3, 10)) meridian = self._meridian_with_posterior(xr.Dataset( data_vars={ @@ -800,14 +800,14 @@ def test_downsample_posterior_supports_non_leading_draw_dimension(self): }, )) - downsampled = meridian.downsample_posterior(n_draws=4, seed=7) + thinned = meridian.thin_posterior(n_draws=4, seed=7) self.assertEqual( - downsampled["param"].dims, ("channel", constants.CHAIN, constants.DRAW) + thinned["param"].dims, ("channel", constants.CHAIN, constants.DRAW) ) - self.assertEqual(downsampled.sizes[constants.CHAIN], 3) - self.assertEqual(downsampled.sizes[constants.DRAW], 4) - self.assertEqual(downsampled.sizes["channel"], 2) + self.assertEqual(thinned.sizes[constants.CHAIN], 3) + self.assertEqual(thinned.sizes[constants.DRAW], 4) + self.assertEqual(thinned.sizes["channel"], 2) @parameterized.named_parameters( dict( @@ -833,7 +833,7 @@ def test_systematic_draw_indices_returns_exact_count( self.assertTrue(np.all(selected >= 0)) self.assertTrue(np.all(selected < n_original_draws)) - def test_downsample_posterior_seed_reproducible(self): + def test_thin_posterior_seed_reproducible(self): values = np.arange(2 * 30).reshape((2, 30)) first_meridian = self._meridian_with_posterior(xr.Dataset( data_vars={ @@ -851,63 +851,104 @@ def test_downsample_posterior_seed_reproducible(self): first_meridian.inference_data.posterior.copy(deep=True) ) - first = first_meridian.downsample_posterior(n_draws=5, seed=7) - second = second_meridian.downsample_posterior(n_draws=5, seed=7) + first = first_meridian.thin_posterior(n_draws=5, seed=7) + second = second_meridian.thin_posterior(n_draws=5, seed=7) - self.assertEqual(first.attrs["posterior_downsample_seed"], 7) + self.assertEqual(first.attrs["posterior_thinning_seed"], 7) np.testing.assert_array_equal(first["param"].values, second["param"].values) - def test_downsample_posterior_requires_posterior(self): + def test_thin_posterior_requires_posterior(self): meridian = model.Meridian(input_data=self.input_data_with_media_only) with self.assertRaises(model.NotFittedModelError): - meridian.downsample_posterior(sampling_rate=0.1) + meridian.thin_posterior(sampling_rate=0.1) @parameterized.named_parameters( dict(testcase_name="missing", kwargs={}), dict(testcase_name="both", kwargs={"sampling_rate": 0.1, "n_draws": 2}), ) - def test_downsample_posterior_requires_exactly_one_draw_argument( + def test_thin_posterior_requires_exactly_one_draw_argument( self, kwargs ): meridian = self._meridian_with_posterior(_simple_posterior()) with self.assertRaisesRegex(ValueError, "Exactly one"): - meridian.downsample_posterior(**kwargs) + meridian.thin_posterior(**kwargs) @parameterized.named_parameters( dict(testcase_name="zero", kwargs={"n_draws": 0}), dict(testcase_name="too_many", kwargs={"n_draws": 11}), ) - def test_downsample_posterior_rejects_invalid_n_draws(self, kwargs): + def test_thin_posterior_rejects_invalid_n_draws(self, kwargs): meridian = self._meridian_with_posterior(_simple_posterior()) with self.assertRaisesRegex(ValueError, "`n_draws`"): - meridian.downsample_posterior(**kwargs) + meridian.thin_posterior(**kwargs) @parameterized.named_parameters( dict(testcase_name="zero", kwargs={"sampling_rate": 0}), dict(testcase_name="too_large", kwargs={"sampling_rate": 1.1}), ) - def test_downsample_posterior_rejects_invalid_sampling_rate(self, kwargs): + def test_thin_posterior_rejects_invalid_sampling_rate(self, kwargs): meridian = self._meridian_with_posterior(_simple_posterior()) with self.assertRaisesRegex(ValueError, "`sampling_rate`"): - meridian.downsample_posterior(**kwargs) + meridian.thin_posterior(**kwargs) - def test_downsample_posterior_rejects_invalid_method(self): + def test_thin_posterior_rejects_invalid_method(self): meridian = self._meridian_with_posterior(_simple_posterior()) with self.assertRaisesRegex(ValueError, "Unsupported"): - meridian.downsample_posterior(n_draws=4, method=mock.MagicMock()) + meridian.thin_posterior(n_draws=4, method=mock.MagicMock()) - def test_downsample_posterior_rejects_downsampling_twice(self): + def test_thin_posterior_rejects_thinning_twice(self): meridian = self._meridian_with_posterior(_simple_posterior()) - meridian.downsample_posterior(n_draws=4, seed=7) + meridian.thin_posterior(n_draws=4, seed=7) - with self.assertRaisesRegex(ValueError, "already been downsampled"): - meridian.downsample_posterior(n_draws=2) + with self.assertRaisesRegex(ValueError, "already been thinned"): + meridian.thin_posterior(n_draws=2) + + def test_thin_posterior_thins_all_matching_groups(self): + posterior = _simple_posterior(n_chains=3, n_draws=10) + sample_stats = xr.Dataset( + data_vars={ + "diverging": ( + (constants.CHAIN, constants.DRAW), + np.zeros((3, 10), dtype=bool), + ), + }, + coords={ + constants.CHAIN: np.arange(3), + constants.DRAW: np.arange(10), + }, + ) + meridian = self._meridian_with_posterior(posterior) + meridian.inference_data.add_groups({"sample_stats": sample_stats}) + + thinned = meridian.thin_posterior(n_draws=4, seed=7) + + self.assertEqual(thinned.sizes[constants.DRAW], 4) + self.assertTrue(thinned.attrs[constants.POSTERIOR_IS_THINNED]) + + thinned_sample_stats = meridian.inference_data.sample_stats + self.assertEqual(thinned_sample_stats.sizes[constants.DRAW], 4) + self.assertTrue(thinned_sample_stats.attrs[constants.POSTERIOR_IS_THINNED]) + self.assertEqual( + thinned_sample_stats.attrs[constants.POSTERIOR_THINNING_SEED], 7 + ) + + restored_posterior = meridian.restore_full_posterior() + self.assertEqual(restored_posterior.sizes[constants.DRAW], 10) + self.assertNotIn( + constants.POSTERIOR_IS_THINNED, restored_posterior.attrs + ) + + restored_sample_stats = meridian.inference_data.sample_stats + self.assertEqual(restored_sample_stats.sizes[constants.DRAW], 10) + self.assertNotIn( + constants.POSTERIOR_IS_THINNED, restored_sample_stats.attrs + ) class ModelPersistenceTest( diff --git a/meridian/schema/serde/meridian_serde_test.py b/meridian/schema/serde/meridian_serde_test.py index 5dee81530..14d61fdff 100644 --- a/meridian/schema/serde/meridian_serde_test.py +++ b/meridian/schema/serde/meridian_serde_test.py @@ -874,6 +874,75 @@ def test_serialize_deserialize_round_trip( deserialized_model.inference_data.posterior, ) + def test_serialize_deserialize_thinned_model(self): + model_spec = spec.ModelSpec(knots=49) + sample_stats = xr.Dataset( + data_vars={ + 'diverging': ( + (constants.CHAIN, constants.DRAW), + np.zeros( + (_POSTERIOR_DATASET_CHAINS, _POSTERIOR_DATASET_DRAWS), + dtype=bool, + ), + ), + }, + coords={ + constants.CHAIN: np.arange(_POSTERIOR_DATASET_CHAINS), + constants.DRAW: np.arange(_POSTERIOR_DATASET_DRAWS), + }, + ) + inf_data = az.InferenceData( + prior=_PRIOR_DATASET, + posterior=_POSTERIOR_DATASET, + sample_stats=sample_stats, + ) + with mock.patch.object( + context.ModelContext, '_validate_geo_invariants', autospec=True + ): + original_model = model.Meridian( + input_data=_INPUT_DATA, + model_spec=model_spec, + inference_data=inf_data, + ) + original_model.thin_posterior(n_draws=4, seed=7) + serialized_model = self.serde.serialize(original_model, 'test_model') + deserialized_model = self.serde.deserialize(serialized_model) + + self.assertIsInstance(deserialized_model, model.Meridian) + self.assertTrue( + deserialized_model.inference_data.posterior.attrs[ + constants.POSTERIOR_IS_THINNED + ] + ) + self.assertEqual( + deserialized_model.inference_data.posterior.sizes[constants.DRAW], 4 + ) + self.assertEqual( + deserialized_model.inference_data.posterior.attrs[ + constants.POSTERIOR_THINNING_SEED + ], + 7, + ) + xrt.assert_allclose( + original_model.inference_data.posterior, + deserialized_model.inference_data.posterior, + ) + self.assertTrue( + deserialized_model.inference_data.sample_stats.attrs[ + constants.POSTERIOR_IS_THINNED + ] + ) + self.assertEqual( + deserialized_model.inference_data.sample_stats.sizes[ + constants.DRAW + ], + 4, + ) + xrt.assert_allclose( + original_model.inference_data.sample_stats, + deserialized_model.inference_data.sample_stats, + ) + def test_save_load_meridian_binpb(self): # The create_tempdir() method below internally uses command line flag # (--test_tmpdir) and such flags are not marked as parsed by default