From ca447a42a38b2bebe29fe95c378fdd8af7a510d8 Mon Sep 17 00:00:00 2001 From: Vivek Beniwal Date: Mon, 1 Jun 2026 11:52:00 -0700 Subject: [PATCH] Rename downsample_posterior posterior to thin_posterior. PiperOrigin-RevId: 924831982 --- meridian/constants.py | 14 +++---- meridian/model/model.py | 55 +++++++++++++------------- meridian/model/model_test.py | 76 ++++++++++++++++++------------------ 3 files changed, 73 insertions(+), 72 deletions(-) diff --git a/meridian/constants.py b/meridian/constants.py index d57f29a33..9dd93d728 100644 --- a/meridian/constants.py +++ b/meridian/constants.py @@ -815,13 +815,13 @@ BIC = 'bic' EBIC = 'ebic' -# Posterior downsampling constants. -POSTERIOR_IS_DOWNSAMPLED = 'posterior_is_downsampled' -POSTERIOR_DOWNSAMPLE_METHOD = 'posterior_downsample_method' -POSTERIOR_DOWNSAMPLE_SAMPLING_RATE = 'posterior_downsample_sampling_rate' +# Posterior thinning constants. +POSTERIOR_IS_THINNED = 'posterior_is_thinned' +POSTERIOR_THINNING_METHOD = 'posterior_thinning_method' +POSTERIOR_THINNING_RATE = 'posterior_thinning_rate' POSTERIOR_ORIGINAL_CHAIN_COUNT = 'posterior_original_chain_count' POSTERIOR_ORIGINAL_DRAW_COUNT = 'posterior_original_draw_count' -POSTERIOR_SELECTED_DRAW_COUNT_PER_CHAIN = ( - 'posterior_selected_draw_count_per_chain' +POSTERIOR_THINNED_DRAW_COUNT_PER_CHAIN = ( + 'posterior_thinned_draw_count_per_chain' ) -POSTERIOR_DOWNSAMPLE_SEED = 'posterior_downsample_seed' +POSTERIOR_THINNING_SEED = 'posterior_thinning_seed' diff --git a/meridian/model/model.py b/meridian/model/model.py index b2ab5cb0d..6ce2df3ab 100644 --- a/meridian/model/model.py +++ b/meridian/model/model.py @@ -60,8 +60,8 @@ @enum.unique -class DownsampleMethod(enum.Enum): - """Posterior draw downsampling methods.""" +class ThinningMethod(enum.Enum): + """Posterior draw thinning methods.""" SYSTEMATIC = "systematic" @@ -1219,15 +1219,15 @@ def sample_posterior_and_review( ) self.review() - def downsample_posterior( + def thin_posterior( self, sampling_rate: float | None = None, n_draws: int | None = None, - method: DownsampleMethod = DownsampleMethod.SYSTEMATIC, + method: ThinningMethod = ThinningMethod.SYSTEMATIC, seed: int | Sequence[int] | None = None, preserve_original: bool = True, ) -> xr.Dataset: - """Downsamples `inference_data.posterior` while preserving chains. + """Thins `inference_data.posterior` while preserving chains. This method replaces `self.inference_data.posterior` with a chain-preserving subset of posterior draws. For example, a posterior with shape @@ -1236,12 +1236,13 @@ def downsample_posterior( The main use case is accelerating posterior workflows such as budget optimization while continuing to use Meridian's existing APIs unchanged. - Outputs produced after downsampling are approximate with respect to the full + Outputs produced after thinning are approximate with respect to the full posterior. - Systematic sampling selects posterior samples from each MCMC chain at a - fixed, regular interval (such as every 10th sample) starting from a randomly - chosen point. This is done to minimize the auto-correlation of the sample. + Systematic MCMC posterior thinning selects posterior samples from each + MCMC chain at a fixed, regular interval (such as every 10th sample) starting + from a randomly chosen point. This is done to minimize the auto-correlation + of the sample. Args: sampling_rate: Fraction of draws to keep per chain. Must be in `(0, 1]`. @@ -1250,14 +1251,14 @@ def downsample_posterior( original_n_draws]`. Exactly one of `sampling_rate` or `n_draws` must be provided. method: Draw selection method. Currently only - `DownsampleMethod.SYSTEMATIC` is supported. + `ThinningMethod.SYSTEMATIC` is supported. seed: Optional random seed for reproducible draw selection. This is used only for selecting posterior draw indices. preserve_original: If `True`, stores a copy of the full posterior on this model so `restore_full_posterior()` can restore it. Returns: - The downsampled posterior `xarray.Dataset`. + The thinned posterior `xarray.Dataset`. Raises: NotFittedModelError: If the model does not have posterior samples. @@ -1265,18 +1266,18 @@ def downsample_posterior( """ if not hasattr(self.inference_data, constants.POSTERIOR): raise NotFittedModelError( - "sample_posterior() must be called before downsample_posterior()." + "sample_posterior() must be called before thin_posterior()." ) if (sampling_rate is None) == (n_draws is None): raise ValueError( "Exactly one of `sampling_rate` or `n_draws` must be provided." ) - if method is not DownsampleMethod.SYSTEMATIC: - raise ValueError(f"Unsupported posterior downsample method: {method}.") + if method is not ThinningMethod.SYSTEMATIC: + raise ValueError(f"Unsupported posterior thinning method: {method}.") posterior = self.inference_data.posterior - if posterior.attrs.get(constants.POSTERIOR_IS_DOWNSAMPLED): - raise ValueError("Posterior has already been downsampled.") + if posterior.attrs.get(constants.POSTERIOR_IS_THINNED): + raise ValueError("Posterior has already been thinned.") if ( constants.CHAIN not in posterior.sizes or constants.DRAW not in posterior.sizes @@ -1315,38 +1316,38 @@ def downsample_posterior( selected_draw_indices, dims=(constants.CHAIN, constants.DRAW), ) - downsampled_posterior = posterior.isel( + thinned_posterior = posterior.isel( {constants.DRAW: draw_indexer} ).assign_coords({constants.DRAW: np.arange(n_selected_draws)}) attrs = dict(posterior.attrs) attrs.update({ - constants.POSTERIOR_IS_DOWNSAMPLED: True, - constants.POSTERIOR_DOWNSAMPLE_METHOD: method.value, - constants.POSTERIOR_DOWNSAMPLE_SAMPLING_RATE: ( + constants.POSTERIOR_IS_THINNED: True, + constants.POSTERIOR_THINNING_METHOD: method.value, + constants.POSTERIOR_THINNING_RATE: ( float(sampling_rate) if sampling_rate is not None else n_selected_draws / n_original_draws ), constants.POSTERIOR_ORIGINAL_CHAIN_COUNT: n_chains, constants.POSTERIOR_ORIGINAL_DRAW_COUNT: n_original_draws, - constants.POSTERIOR_SELECTED_DRAW_COUNT_PER_CHAIN: n_selected_draws, + constants.POSTERIOR_THINNED_DRAW_COUNT_PER_CHAIN: n_selected_draws, }) if seed is not None: - attrs[constants.POSTERIOR_DOWNSAMPLE_SEED] = ( + attrs[constants.POSTERIOR_THINNING_SEED] = ( list(seed) if isinstance(seed, Sequence) and not isinstance(seed, (str, bytes)) else int(seed) ) - downsampled_posterior.attrs = attrs - self.inference_data.posterior = downsampled_posterior - return downsampled_posterior + thinned_posterior.attrs = attrs + self.inference_data.posterior = thinned_posterior + return thinned_posterior def restore_full_posterior(self) -> xr.Dataset: - """Restores the full posterior saved by `downsample_posterior()`.""" + """Restores the full posterior saved by `thin_posterior()`.""" if self._full_posterior is None: raise ValueError( "No preserved full posterior is available. Call " - "downsample_posterior(..., preserve_original=True) first." + "thin_posterior(..., preserve_original=True) first." ) self.inference_data.posterior = self._full_posterior self._full_posterior = None diff --git a/meridian/model/model_test.py b/meridian/model/model_test.py index 20e63df85..84226a003 100644 --- a/meridian/model/model_test.py +++ b/meridian/model/model_test.py @@ -724,7 +724,7 @@ def _meridian_with_posterior(self, posterior: xr.Dataset) -> model.Meridian: meridian.inference_data.posterior = posterior return meridian - def test_downsample_posterior_preserves_chains(self): + def test_thin_posterior_preserves_chains(self): values = np.arange(3 * 10 * 2).reshape((3, 10, 2)) meridian = self._meridian_with_posterior(xr.Dataset( data_vars={ @@ -744,26 +744,26 @@ def test_downsample_posterior_preserves_chains(self): }, )) - downsampled = meridian.downsample_posterior(n_draws=4, seed=7) + thinned = meridian.thin_posterior(n_draws=4, seed=7) - self.assertEqual(downsampled.sizes[constants.CHAIN], 3) - self.assertEqual(downsampled.sizes[constants.DRAW], 4) - self.assertEqual(downsampled.sizes["channel"], 2) + self.assertEqual(thinned.sizes[constants.CHAIN], 3) + self.assertEqual(thinned.sizes[constants.DRAW], 4) + self.assertEqual(thinned.sizes["channel"], 2) self.assertEqual( - downsampled.attrs["posterior_selected_draw_count_per_chain"], 4 + thinned.attrs["posterior_thinned_draw_count_per_chain"], 4 ) - self.assertTrue(downsampled.attrs["posterior_is_downsampled"]) + self.assertTrue(thinned.attrs["posterior_is_thinned"]) self.assertEqual( - downsampled.attrs["posterior_downsample_method"], "systematic" + thinned.attrs["posterior_thinning_method"], "systematic" ) for chain in range(3): with self.subTest(chain=chain): - selected_draws = downsampled["draw_id"].sel(chain=chain).values + selected_draws = thinned["draw_id"].sel(chain=chain).values self.assertLen(set(selected_draws.tolist()), 4) self.assertTrue(np.all(np.diff(selected_draws) >= 1)) self.assertTrue(np.all(selected_draws >= 0)) self.assertTrue(np.all(selected_draws < 10)) - selected_values = downsampled["param"].sel(chain=chain).values[:, 0] + selected_values = thinned["param"].sel(chain=chain).values[:, 0] self.assertTrue(np.all(selected_values >= chain * 20)) self.assertTrue(np.all(selected_values < (chain + 1) * 20)) @@ -773,18 +773,18 @@ def test_downsample_posterior_preserves_chains(self): self.assertEqual(restored.sizes[constants.DRAW], 10) np.testing.assert_array_equal(restored["param"].values, values) - def test_downsample_posterior_accepts_downsample_method_enum(self): + def test_thin_posterior_accepts_thinning_method_enum(self): meridian = self._meridian_with_posterior(_simple_posterior()) - downsampled = meridian.downsample_posterior( - n_draws=4, method=model.DownsampleMethod.SYSTEMATIC, seed=7 + thinned = meridian.thin_posterior( + n_draws=4, method=model.ThinningMethod.SYSTEMATIC, seed=7 ) self.assertEqual( - downsampled.attrs["posterior_downsample_method"], "systematic" + thinned.attrs["posterior_thinning_method"], "systematic" ) - def test_downsample_posterior_supports_non_leading_draw_dimension(self): + def test_thin_posterior_supports_non_leading_draw_dimension(self): values = np.arange(2 * 3 * 10).reshape((2, 3, 10)) meridian = self._meridian_with_posterior(xr.Dataset( data_vars={ @@ -800,14 +800,14 @@ def test_downsample_posterior_supports_non_leading_draw_dimension(self): }, )) - downsampled = meridian.downsample_posterior(n_draws=4, seed=7) + thinned = meridian.thin_posterior(n_draws=4, seed=7) self.assertEqual( - downsampled["param"].dims, ("channel", constants.CHAIN, constants.DRAW) + thinned["param"].dims, ("channel", constants.CHAIN, constants.DRAW) ) - self.assertEqual(downsampled.sizes[constants.CHAIN], 3) - self.assertEqual(downsampled.sizes[constants.DRAW], 4) - self.assertEqual(downsampled.sizes["channel"], 2) + self.assertEqual(thinned.sizes[constants.CHAIN], 3) + self.assertEqual(thinned.sizes[constants.DRAW], 4) + self.assertEqual(thinned.sizes["channel"], 2) @parameterized.named_parameters( dict( @@ -833,7 +833,7 @@ def test_systematic_draw_indices_returns_exact_count( self.assertTrue(np.all(selected >= 0)) self.assertTrue(np.all(selected < n_original_draws)) - def test_downsample_posterior_seed_reproducible(self): + def test_thin_posterior_seed_reproducible(self): values = np.arange(2 * 30).reshape((2, 30)) first_meridian = self._meridian_with_posterior(xr.Dataset( data_vars={ @@ -851,63 +851,63 @@ def test_downsample_posterior_seed_reproducible(self): first_meridian.inference_data.posterior.copy(deep=True) ) - first = first_meridian.downsample_posterior(n_draws=5, seed=7) - second = second_meridian.downsample_posterior(n_draws=5, seed=7) + first = first_meridian.thin_posterior(n_draws=5, seed=7) + second = second_meridian.thin_posterior(n_draws=5, seed=7) - self.assertEqual(first.attrs["posterior_downsample_seed"], 7) + self.assertEqual(first.attrs["posterior_thinning_seed"], 7) np.testing.assert_array_equal(first["param"].values, second["param"].values) - def test_downsample_posterior_requires_posterior(self): + def test_thin_posterior_requires_posterior(self): meridian = model.Meridian(input_data=self.input_data_with_media_only) with self.assertRaises(model.NotFittedModelError): - meridian.downsample_posterior(sampling_rate=0.1) + meridian.thin_posterior(sampling_rate=0.1) @parameterized.named_parameters( dict(testcase_name="missing", kwargs={}), dict(testcase_name="both", kwargs={"sampling_rate": 0.1, "n_draws": 2}), ) - def test_downsample_posterior_requires_exactly_one_draw_argument( + def test_thin_posterior_requires_exactly_one_draw_argument( self, kwargs ): meridian = self._meridian_with_posterior(_simple_posterior()) with self.assertRaisesRegex(ValueError, "Exactly one"): - meridian.downsample_posterior(**kwargs) + meridian.thin_posterior(**kwargs) @parameterized.named_parameters( dict(testcase_name="zero", kwargs={"n_draws": 0}), dict(testcase_name="too_many", kwargs={"n_draws": 11}), ) - def test_downsample_posterior_rejects_invalid_n_draws(self, kwargs): + def test_thin_posterior_rejects_invalid_n_draws(self, kwargs): meridian = self._meridian_with_posterior(_simple_posterior()) with self.assertRaisesRegex(ValueError, "`n_draws`"): - meridian.downsample_posterior(**kwargs) + meridian.thin_posterior(**kwargs) @parameterized.named_parameters( dict(testcase_name="zero", kwargs={"sampling_rate": 0}), dict(testcase_name="too_large", kwargs={"sampling_rate": 1.1}), ) - def test_downsample_posterior_rejects_invalid_sampling_rate(self, kwargs): + def test_thin_posterior_rejects_invalid_sampling_rate(self, kwargs): meridian = self._meridian_with_posterior(_simple_posterior()) with self.assertRaisesRegex(ValueError, "`sampling_rate`"): - meridian.downsample_posterior(**kwargs) + meridian.thin_posterior(**kwargs) - def test_downsample_posterior_rejects_invalid_method(self): + def test_thin_posterior_rejects_invalid_method(self): meridian = self._meridian_with_posterior(_simple_posterior()) with self.assertRaisesRegex(ValueError, "Unsupported"): - meridian.downsample_posterior(n_draws=4, method=mock.MagicMock()) + meridian.thin_posterior(n_draws=4, method=mock.MagicMock()) - def test_downsample_posterior_rejects_downsampling_twice(self): + def test_thin_posterior_rejects_thinning_twice(self): meridian = self._meridian_with_posterior(_simple_posterior()) - meridian.downsample_posterior(n_draws=4, seed=7) + meridian.thin_posterior(n_draws=4, seed=7) - with self.assertRaisesRegex(ValueError, "already been downsampled"): - meridian.downsample_posterior(n_draws=2) + with self.assertRaisesRegex(ValueError, "already been thinned"): + meridian.thin_posterior(n_draws=2) class ModelPersistenceTest(