Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions meridian/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,3 +825,9 @@
'posterior_selected_draw_count_per_chain'
)
POSTERIOR_THINNING_SEED = 'posterior_thinning_seed'


# Calibration constants.
IS_CALIBRATED = 'is_calibrated'
CALIBRATION_OUTPUTS = 'calibration_outputs'

24 changes: 19 additions & 5 deletions meridian/model/prior_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from meridian import constants
import numpy as np


__all__ = [
'IndependentMultivariateDistribution',
'PriorDistribution',
Expand Down Expand Up @@ -92,7 +91,7 @@ class PriorDistribution:
| `contribution_m` | `n_media_channels` |
| `contribution_rf` | `n_rf_channels` |
| `contribution_om` | `n_organic_media_channels` |
| `contribution_orf` | `n_organic_f_channels` |
| `contribution_orf` | `n_organic_rf_channels` |
| `contribution_n` | `n_non_media_channels` |

(σ) `n_geos` if `unique_sigma_for_each_geo`, otherwise this is `1`
Expand Down Expand Up @@ -1206,9 +1205,23 @@ def distributions_are_equal(
del a_params[constants.DISTRIBUTION]
del b_params[constants.DISTRIBUTION]

if 'distributions' in a_params and 'distributions' in b_params:
a_dists = a_params['distributions']
b_dists = b_params['distributions']
if len(a_dists) != len(b_dists):
return False
for a_d, b_d in zip(a_dists, b_dists):
if not distributions_are_equal(a_d, b_d):
return False
del a_params['distributions']
del b_params['distributions']

if constants.DISTRIBUTION in a_params or constants.DISTRIBUTION in b_params:
return False

if 'distributions' in a_params or 'distributions' in b_params:
return False

if a_params.keys() != b_params.keys():
return False

Expand Down Expand Up @@ -1286,13 +1299,14 @@ def lognormal_dist_from_range(
mass_percent = np.asarray(mass_percent)

if not ((0.0 < low).all() and (low < high).all()): # pytype: disable=attribute-error
raise ValueError("'low' and 'high' values must be non-negative and satisfy "
"high > low.")
raise ValueError(
"'low' and 'high' values must be non-negative and satisfy high > low."
)

if not ((0.0 < mass_percent).all() and (mass_percent < 1.0).all()): # pytype: disable=attribute-error
raise ValueError(
"'mass_percent' values must be between 0 and 1, exclusive."
)
)

normal = backend.tfd.Normal(0, 1)
mass_lower = 0.5 - (mass_percent / 2)
Expand Down
25 changes: 24 additions & 1 deletion meridian/model/prior_distribution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from meridian.model import prior_distribution
import numpy as np


_N_GEOS = 10
_N_GEOS_NATIONAL = 1
_N_MEDIA_CHANNELS = 6
Expand Down Expand Up @@ -1529,6 +1528,30 @@ def test_get_total_media_contribution_prior(self):
),
expected_result=True,
),
dict(
testcase_name='same_independent_multivariate_distributions',
get_a=lambda: prior_distribution.IndependentMultivariateDistribution([
backend.tfd.LogNormal(0.7, 0.4),
backend.tfd.Normal(0.0, 1.0),
]),
get_b=lambda: prior_distribution.IndependentMultivariateDistribution([
backend.tfd.LogNormal(0.7, 0.4),
backend.tfd.Normal(0.0, 1.0),
]),
expected_result=True,
),
dict(
testcase_name='different_independent_multivariate_distributions',
get_a=lambda: prior_distribution.IndependentMultivariateDistribution([
backend.tfd.LogNormal(0.7, 0.4),
backend.tfd.Normal(0.0, 1.0),
]),
get_b=lambda: prior_distribution.IndependentMultivariateDistribution([
backend.tfd.LogNormal(0.7, 0.4),
backend.tfd.Normal(0.0, 2.0),
]),
expected_result=False,
),
dict(
testcase_name='different_outer_complex_distributions',
get_a=lambda: backend.tfd.BatchBroadcast(
Expand Down
Loading