| title | Cumulative Means with Lets-Plot |
|---|---|
| marimo-version | 0.4.6 |
| width | full |
| author | Gemini |
| description | A notebook to visualize the distribution of cumulative means and compare them to the normal distribution. |
| pyproject | requires-python = ">=3.11" dependencies = [ "marimo", "lets-plot==4.7.0", "numpy==2.3.1", "scipy==1.16.0", "polars==1.31.0", ] |
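This notebook generates a matrix of uniform random numbers and visualizes the distribution of their cumulative means. For each row, the k-th cumulative mean is the average of that row's first k values. According to the Central Limit Theorem, the distribution of a sample mean approaches a normal distribution as the number of averaged values grows, so the later columns should look increasingly normal.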
We will perform the following steps:
- Generate a matrix of random numbers.
- Calculate the cumulative means across the columns of each row.
- Plot histograms of these cumulative means.
- Overlay the Probability Density Function (PDF) of a normal distribution for comparison.
Use the sliders to control the dimensions of the random data matrix, then submit the form.
# Sliders for the size of the random matrix. The underscore prefix keeps them
# private to this cell; only `form` is exported. label=" " leaves the slider's
# own label blank (presumably because the markdown line already describes it).
_n_rows_slider = mo.ui.slider(1_000, 20_000, step=1_000, value=10_000, label=" ")
_n_cols_slider = mo.ui.slider(2, 10, value=6, label=" ")

# Markdown template whose {placeholders} are filled by the sliders, wrapped in a
# submit form. Downstream cells treat `form.value` being None as "no data yet".
form = (
    mo.md(
        """
- **Number of samples (rows):** {n_rows}
- **Number of variables (columns):** {n_cols}
"""
    )
    .batch(n_rows=_n_rows_slider, n_cols=_n_cols_slider)
    .form()
)
form

Finally, we use lets-plot to create the visualization. We use `facet_wrap` to create a separate panel for each column. Each panel contains a histogram of the cumulative means and a line representing the corresponding normal distribution PDF.
LetsPlot.setup_html()

# Histogram of the cumulative means, drawn on the density scale so it can be
# compared directly with the PDF curve.
_histogram_layer = geom_histogram(
    data=cum_means_df,
    mapping=aes(x="value", y="..density.."),
    bins=100,
    size=0.25,
    fill="#4682B4",
    alpha=0.7,
)

# Normal PDF curve from pdf_curves_df, overlaid on the matching facet.
_pdf_layer = geom_line(
    data=pdf_curves_df,
    mapping=aes(x="value", y="pdf"),
    color="red",
    size=2,
    alpha=0.5,
)

# Compose the figure: one facet per column_id, three facets per row.
_figure = (
    ggplot()
    + _histogram_layer
    + _pdf_layer
    + facet_wrap("column_id", ncol=3)
    + ggtitle("Distribution of Cumulative Means vs. Normal PDF")
    + xlab("Cumulative Mean Value")
    + ylab("Density")
    + flavor_darcula()
    + ggsize(1024, 600)
)
_plot = mo.center(_figure)

# Placeholder displayed instead of the plot while there is no data to show.
_message = mo.md("### Configure and submit the form to generate the visualization.")
mo.callout(_message if cum_means_df is None or pdf_curves_df is None else _plot)

The following cells perform the data generation and processing steps.
pdf_curves_df = generate_pdf_curves(cum_means_df) if cum_means_df is not None else None
pdf_curves_dfcum_means_df = calculate_cumulative_means(random_data) if random_data is not None else None
cum_means_dfrandom_data = (
np.random.rand(form.value["n_rows"], form.value["n_cols"])
if form.value is not None
else None
)This section contains all imports, configuration, and function definitions.
import marimo as mo
import numpy as np
import polars as pl
from lets_plot import (
LetsPlot,
aes,
facet_wrap,
geom_histogram,
geom_line,
ggplot,
ggtitle,
xlab,
ylab,
flavor_darcula,
ggsize,
)
from scipy.stats import norm
from typing import Optional, Tuple


def calculate_cumulative_means(data: Optional[np.ndarray]) -> Optional[pl.DataFrame]:
    """
    Calculate the cumulative (running) mean along each row of a 2D array.

    Column ``col_k`` of the result holds, for every row, the mean of that
    row's first ``k + 1`` entries.

    Args:
        data: A 2D array of shape (n_rows, n_cols), or None.

    Returns:
        A long-format Polars DataFrame with columns 'row_count' (row index),
        'column_id' ('col_0', 'col_1', ...) and 'value' (the cumulative mean),
        or None if ``data`` is None.

    Raises:
        ValueError: If ``data`` is not two-dimensional.
    """
    if data is None:
        return None
    data = np.asarray(data)
    if data.ndim != 2:
        # A per-column divisor only lines up with cumsum(axis=1) for 2D input;
        # other ranks would fail obscurely or divide along the wrong axis.
        raise ValueError(f"data must be a 2D array, got {data.ndim} dimension(s)")
    n_rows, n_cols = data.shape
    # Running sum along each row divided by 1, 2, ..., n_cols (broadcast over rows).
    cumulative_mean_array = data.cumsum(axis=1) / np.arange(1, n_cols + 1)
    df = pl.from_numpy(
        cumulative_mean_array, schema=[f"col_{i}" for i in range(n_cols)]
    )
    # unpivot needs an id column to remember which row each value came from.
    df = df.with_columns(pl.Series("row_count", np.arange(n_rows)))
    return df.unpivot(index="row_count", variable_name="column_id", value_name="value")


def generate_pdf_curves(
    df: Optional[pl.DataFrame], num_points: int = 200
) -> Optional[pl.DataFrame]:
    """
    Generate normal-distribution PDF curves matching each group's mean and std.

    For every ``column_id`` group, a normal PDF with the group's sample mean
    and sample standard deviation is evaluated at ``num_points`` evenly spaced
    x-values spanning the group's observed min..max range.

    Args:
        df: A long-format DataFrame with 'column_id' and 'value' columns, or None.
        num_points: Number of x-values per curve (default 200).

    Returns:
        A DataFrame with 'column_id', 'value' (x-axis) and 'pdf' columns, in
        order of first appearance of each column_id. Returns None if ``df`` is
        None or empty, or if no group has a positive standard deviation.
    """
    if df is None or df.is_empty():
        return None
    # One aggregation pass yields both the moments and the x-range per group,
    # instead of re-filtering the whole frame for every group.
    # maintain_order=True makes the output order deterministic.
    stats = df.group_by("column_id", maintain_order=True).agg(
        pl.mean("value").alias("mean"),
        pl.std("value").alias("std"),
        pl.min("value").alias("min"),
        pl.max("value").alias("max"),
    )
    curves = []
    for row in stats.iter_rows(named=True):
        std = row["std"]
        # std is None for single-value groups and 0 for constant groups:
        # no meaningful normal curve exists in either case.
        if std is None or std == 0:
            continue
        x_vals = np.linspace(row["min"], row["max"], num_points)
        curves.append(
            pl.DataFrame(
                {
                    "column_id": [row["column_id"]] * len(x_vals),
                    # Named 'value' so it shares the x aesthetic with the histogram data.
                    "value": x_vals,
                    "pdf": norm.pdf(x_vals, row["mean"], std),
                }
            )
        )
    return pl.concat(curves) if curves else None