Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/documentation.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
name: documentation

on: [pull_request, workflow_dispatch]
on:
release:
types: [published] # Runs only on official releases
workflow_dispatch:

permissions:
contents: write
Expand Down
10 changes: 6 additions & 4 deletions src/pysimmmulator/geos.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

class Geos:
"""Provides randomized generation of population subsets"""
def __init__(self, total_population: int, random_seed: Optional[int] = None) -> None:
def __init__(self, total_population: int, random_seed: Optional[int] = None, rng: Optional[np.random.Generator] = None) -> None:
self.total_population = total_population
self.rng = self._create_random_factory(seed=random_seed)
self.rng = rng if rng is not None else self._create_random_factory(seed=random_seed)

def __call__(self,
geo_specs: Optional[dict] = None,
Expand All @@ -18,7 +18,7 @@ def __call__(self,
geo_specs (Optional[dict]): Geography names coupled with a dict of parameters for the normal distribution of that geos population
(ie {"California":{"loc": 3.0, "scale": 0.5}}). 'loc' in this case is the multiplicative bias relative to an
equal apportionment of the total population.
universal_scale (Optional[flaot]): Scale parameter to be used universally for all geographies. Increased value means increased
universal_scale (Optional[float]): Scale parameter to be used universally for all geographies. Increased value means increased
spread in the distribution of all geos
count (int): in the absence of specified geographies, this is the number of geos to be created using the `create_random_geos` function.
Returns:
Expand Down Expand Up @@ -116,6 +116,7 @@ def distribute_to_geos(
mmm_input: 'pd.DataFrame',
geo_details: dict,
random_seed: Optional[int] = None,
rng: Optional[np.random.Generator] = None,
dist_spec: tuple[float, float] = (0.0, 0.25),
media_cost_spec: tuple[float, float] = (0.0, 0.069),
perf_spec: tuple[float, float] = (0.0, 0.069)
Expand All @@ -126,6 +127,7 @@ def distribute_to_geos(
mmm_input (pd.DataFrame): simulated MMM data that was generated as part of a prior process
geo_details (dict): formulated dict or output of the `geos` creation call (ie `geos(count=50)`)
random_seed (int): random seed for rng--if needed
rng (np.random.Generator): optional random number generator
dist_spec (tuple[float, float]): Parameters to control the normal distribution function for populations of the geographies
media_cost_spec (tuple[float, float]): Parameters to control the normal distribution function for allocation of spend across geographies
perf_spec (tuple[float, float]): Parameters to control the normal distribution function for allocation of performance across geographies
Expand All @@ -137,7 +139,7 @@ def distribute_to_geos(

geo_dataframes = []
total_population: int = sum(geo_details.values())
rng = np.random.default_rng(seed=random_seed)
rng = rng if rng is not None else np.random.default_rng(seed=random_seed)
media_cols = [w for w in mmm_input.columns if "impressions" in w or "clicks" in w]
for geo_name, geo_pop in geo_details.items():
pop_pct = geo_pop / total_population
Expand Down
3 changes: 2 additions & 1 deletion src/pysimmmulator/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,13 @@ def simulate_geos(self, mmm_df: pd.DataFrame, params: GeoParameters) -> pd.DataF
params (GeoParameters): Parameters for geographic distribution.
Returns:
pd.DataFrame: MMM DataFrame with geographic distribution"""
geos = Geos(total_population=params.total_population, random_seed=None)
geos = Geos(total_population=params.total_population, random_seed=None, rng=self.rng)
geo_details = geos(geo_specs=params.geo_specs, universal_scale=params.universal_scale, count=params.count)

mmm_df = distribute_to_geos(
mmm_input=mmm_df,
geo_details=geo_details,
rng=self.rng,
dist_spec=params.dist_spec,
media_cost_spec=params.media_cost_spec,
perf_spec=params.perf_spec
Expand Down
5 changes: 4 additions & 1 deletion src/pysimmmulator/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ def __init__(self,
channel_name: str,
true_roi: float,
random_seed: int = None,
rng: Optional[np.random.Generator] = None,
bias: float = DEFAULT_STUDY_BIAS,
stdev: float = DEFAULT_STUDY_SCALE) -> None:
self.channel_name = channel_name
self._true_roi = true_roi
self.rng = self._create_random_factory(seed=random_seed)
self.rng = rng if rng is not None else self._create_random_factory(seed=random_seed)
self._bias = bias
self._stdev = stdev

Expand Down Expand Up @@ -94,12 +95,14 @@ def __init__(self,
channel_rois: dict,
channel_distributions: dict[str, dict] = dict(),
random_seed: int = None,
rng: Optional[np.random.Generator] = None,
bias: float = DEFAULT_STUDY_BIAS,
stdev: float = DEFAULT_STUDY_SCALE) -> None:
self._study_hold = {
k: Study(channel_name=k,
true_roi=v,
random_seed=random_seed,
rng=rng,
bias=channel_distributions.get(k, {}).get("bias", bias),
stdev=channel_distributions.get(k, {}).get("stdev", stdev))
for k, v in channel_rois.items()
Expand Down
19 changes: 19 additions & 0 deletions tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,22 @@ def test_multisim_get_data_coverage():
ms = Multisim()
ms.data = "test_data"
assert ms.get_data == "test_data"

def test_reproducibility():
    """Check that two simulations built with the same seed agree exactly.

    Loads the example configuration, adds a geo section to it, runs the
    configuration through two independently constructed ``Simulate``
    instances that share a seed, and compares both returned objects.
    """
    seed = 42

    with open("examples/example_config.yaml", "r") as handle:
        config = yaml.safe_load(handle)
    config["geo_params"] = {"total_population": 1000000, "count": 5}

    # Each run gets a fresh Simulate so no generator state carries over
    # from the first run into the second.
    first_df, first_roi = Simulate(random_seed=seed).run_with_config(config)
    second_df, second_roi = Simulate(random_seed=seed).run_with_config(config)

    pd.testing.assert_frame_equal(first_df, second_df)
    assert first_roi == second_roi
Loading