From 4e22a37e92b4a7bbda0f41f408a1c4c0b2f9c1ea Mon Sep 17 00:00:00 2001 From: Lukasz Mazurek Date: Fri, 12 Jun 2026 06:59:11 -0700 Subject: [PATCH] internal PiperOrigin-RevId: 931129969 --- meridian/constants.py | 6 ++++++ meridian/model/prior_distribution.py | 24 +++++++++++++++++----- meridian/model/prior_distribution_test.py | 25 ++++++++++++++++++++++- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/meridian/constants.py b/meridian/constants.py index 05b79b046..eb37f210b 100644 --- a/meridian/constants.py +++ b/meridian/constants.py @@ -825,3 +825,9 @@ 'posterior_selected_draw_count_per_chain' ) POSTERIOR_THINNING_SEED = 'posterior_thinning_seed' + + +# Calibration constants. +IS_CALIBRATED = 'is_calibrated' +CALIBRATION_OUTPUTS = 'calibration_outputs' + diff --git a/meridian/model/prior_distribution.py b/meridian/model/prior_distribution.py index 47e0b11d4..84e589074 100644 --- a/meridian/model/prior_distribution.py +++ b/meridian/model/prior_distribution.py @@ -29,7 +29,6 @@ from meridian import constants import numpy as np - __all__ = [ 'IndependentMultivariateDistribution', 'PriorDistribution', @@ -92,7 +91,7 @@ class PriorDistribution: | `contribution_m` | `n_media_channels` | | `contribution_rf` | `n_rf_channels` | | `contribution_om` | `n_organic_media_channels` | - | `contribution_orf` | `n_organic_f_channels` | + | `contribution_orf` | `n_organic_rf_channels` | | `contribution_n` | `n_non_media_channels` | (σ) `n_geos` if `unique_sigma_for_each_geo`, otherwise this is `1` @@ -1206,9 +1205,23 @@ def distributions_are_equal( del a_params[constants.DISTRIBUTION] del b_params[constants.DISTRIBUTION] + if 'distributions' in a_params and 'distributions' in b_params: + a_dists = a_params['distributions'] + b_dists = b_params['distributions'] + if len(a_dists) != len(b_dists): + return False + for a_d, b_d in zip(a_dists, b_dists): + if not distributions_are_equal(a_d, b_d): + return False + del a_params['distributions'] + del b_params['distributions'] + if constants.DISTRIBUTION in a_params or constants.DISTRIBUTION in b_params: return False + if 'distributions' in a_params or 'distributions' in b_params: + return False + if a_params.keys() != b_params.keys(): return False @@ -1286,13 +1299,14 @@ def lognormal_dist_from_range( mass_percent = np.asarray(mass_percent) if not ((0.0 < low).all() and (low < high).all()): # pytype: disable=attribute-error - raise ValueError("'low' and 'high' values must be non-negative and satisfy " - "high > low.") + raise ValueError( + "'low' and 'high' values must be non-negative and satisfy high > low." + ) if not ((0.0 < mass_percent).all() and (mass_percent < 1.0).all()): # pytype: disable=attribute-error raise ValueError( "'mass_percent' values must be between 0 and 1, exclusive." - ) + ) normal = backend.tfd.Normal(0, 1) mass_lower = 0.5 - (mass_percent / 2) diff --git a/meridian/model/prior_distribution_test.py b/meridian/model/prior_distribution_test.py index 32c6ab08c..cae8ec92f 100644 --- a/meridian/model/prior_distribution_test.py +++ b/meridian/model/prior_distribution_test.py @@ -24,7 +24,6 @@ from meridian.model import prior_distribution import numpy as np - _N_GEOS = 10 _N_GEOS_NATIONAL = 1 _N_MEDIA_CHANNELS = 6 @@ -1529,6 +1528,30 @@ def test_get_total_media_contribution_prior(self): ), expected_result=True, ), + dict( + testcase_name='same_independent_multivariate_distributions', + get_a=lambda: prior_distribution.IndependentMultivariateDistribution([ + backend.tfd.LogNormal(0.7, 0.4), + backend.tfd.Normal(0.0, 1.0), + ]), + get_b=lambda: prior_distribution.IndependentMultivariateDistribution([ + backend.tfd.LogNormal(0.7, 0.4), + backend.tfd.Normal(0.0, 1.0), + ]), + expected_result=True, + ), + dict( + testcase_name='different_independent_multivariate_distributions', + get_a=lambda: prior_distribution.IndependentMultivariateDistribution([ + backend.tfd.LogNormal(0.7, 0.4), + backend.tfd.Normal(0.0, 1.0), + ]), + get_b=lambda: prior_distribution.IndependentMultivariateDistribution([ + backend.tfd.LogNormal(0.7, 0.4), + backend.tfd.Normal(0.0, 2.0), + ]), + expected_result=False, + ), dict( testcase_name='different_outer_complex_distributions', get_a=lambda: backend.tfd.BatchBroadcast(