title: Cumulative Means with Lets-Plot
marimo-version: 0.4.6
width: full
author: Gemini
description: A notebook to visualize the distribution of cumulative means and compare them to the normal distribution.
pyproject: |
  requires-python = ">=3.11"
  dependencies = [
      "marimo",
      "lets-plot==4.7.0",
      "numpy==2.3.1",
      "scipy==1.16.0",
      "polars==1.31.0",
  ]

Distribution of Cumulative Means

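This notebook draws a matrix of independent Uniform(0, 1) random values and visualizes the distribution of their cumulative means. For each row, the cumulative mean at position k is the mean of that row's first k values, for k from 1 up to the number of columns. By the Central Limit Theorem, the mean of k independent values approaches a normal distribution as k grows. The panels for larger k should therefore match the normal curve more closely.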
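The sketch below gives a rough guide to what the panels should show. It assumes, as np.random.rand does, that every entry is an independent Uniform(0, 1) draw with mean 1/2 and variance 1/12. Under that assumption, the mean of the first k entries has expected value 1/2 and standard deviation sqrt(1/(12k)). The code compares those predictions with a simulation. The helper name clt_prediction_vs_simulation and the fixed seed are illustrative and not part of the notebook.

```python
import numpy as np


def clt_prediction_vs_simulation(n_rows: int, n_cols: int, seed: int = 0):
    """Compare the empirical mean/std of each cumulative mean with the CLT prediction.

    Assumes every entry is an independent Uniform(0, 1) draw, whose variance is 1/12.
    """
    rng = np.random.default_rng(seed)
    data = rng.random((n_rows, n_cols))
    # Column k-1 holds the mean of the first k values in each row
    cum_means = data.cumsum(axis=1) / np.arange(1, n_cols + 1)
    k = np.arange(1, n_cols + 1)
    # Predicted standard deviation of a mean of k uniform draws
    predicted_std = np.sqrt(1.0 / (12.0 * k))
    return cum_means.mean(axis=0), cum_means.std(axis=0, ddof=1), predicted_std


emp_mean, emp_std, pred_std = clt_prediction_vs_simulation(10_000, 6)
for k, (m, s, p) in enumerate(zip(emp_mean, emp_std, pred_std), start=1):
    print(f"k={k}: mean={m:.3f} (expected 0.500), std={s:.4f} (expected {p:.4f})")
```

The shrinking predicted standard deviation is why the panels for larger k look narrower when they share an x-axis.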

We will perform the following steps:

  1. Generate a matrix of random numbers.
  2. Calculate the cumulative means for each row.
  3. Plot histograms of these cumulative means.
  4. Overlay the Probability Density Function (PDF) of a normal distribution for comparison.
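Step 2 relies on NumPy broadcasting. Dividing the (n_rows, n_cols) array of row-wise cumulative sums by [1, 2, ..., n_cols] divides column j by j + 1. This is the same computation calculate_cumulative_means performs before reshaping the result. The sketch below checks it against an explicit prefix mean for every cell, using a hypothetical helper cumulative_means and a made-up 2 x 4 input.

```python
import numpy as np


def cumulative_means(data: np.ndarray) -> np.ndarray:
    """Element [i, j] is the mean of data[i, 0], ..., data[i, j]."""
    # Divisors [1, 2, ..., n_cols] broadcast across every row
    return data.cumsum(axis=1) / np.arange(1, data.shape[1] + 1)


data = np.array([[1.0, 3.0, 2.0, 6.0],
                 [4.0, 0.0, 2.0, 2.0]])
result = cumulative_means(data)
# row 0 -> [1, 2, 2, 3], row 1 -> [4, 2, 2, 2]
print(result)

# Cross-check against an explicit prefix mean for every cell
expected = np.array([[data[i, : j + 1].mean() for j in range(data.shape[1])]
                     for i in range(data.shape[0])])
print(np.allclose(result, expected))
```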

Configuration

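Use the sliders to set the dimensions of the random data matrix, then submit the form. Until the form is submitted, form.value is None and the downstream cells produce no data.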

form = mo.md(
    """
    - **Number of samples (rows):** {n_rows}
    - **Number of variables (columns):** {n_cols}
    """
).batch(
    n_rows=mo.ui.slider(1_000, 20_000, step=1_000, value=10_000, label=" "),
    n_cols=mo.ui.slider(2, 10, value=6, label=" ")
).form()

form

Visualization

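We use lets-plot to create the visualization, with facet_wrap producing a separate panel for each column of cumulative means. Each panel shows two layers:

- a histogram of the cumulative means, scaled to a density (y="..density..") so that it is on the same scale as a PDF;
- a red line for the normal PDF, whose mean and standard deviation are estimated from that panel's data.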
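The density scaling is what makes the histogram and the PDF comparable. Assuming lets-plot's ..density.. follows the usual ggplot convention (count divided by total count times bin width), it matches NumPy's np.histogram(..., density=True). The sketch below reproduces that normalization in NumPy, checks that the bar areas sum to 1, and evaluates a fitted normal PDF at the bin centers. The helper name and the simulated sample are illustrative only.

```python
import math

import numpy as np


def density_histogram_vs_pdf(values: np.ndarray, bins: int = 100):
    """Normalize a histogram to a density and evaluate a fitted normal PDF at the bin centers."""
    # density=True divides each count by (total count * bin width), so the bar areas sum to 1
    density, edges = np.histogram(values, bins=bins, density=True)
    centers = (edges[:-1] + edges[1:]) / 2
    mu, sigma = values.mean(), values.std(ddof=1)
    # Normal PDF written out explicitly: exp(-(x - mu)^2 / (2 sigma^2)) / (sigma * sqrt(2 pi))
    pdf = np.exp(-((centers - mu) ** 2) / (2 * sigma**2)) / (sigma * math.sqrt(2 * math.pi))
    return density, edges, pdf


rng = np.random.default_rng(1)
# Means of 6 uniform draws, as in the last panel of the default configuration
sample = rng.random((10_000, 6)).mean(axis=1)
density, edges, pdf = density_histogram_vs_pdf(sample)
print("total bar area:", float(np.sum(density * np.diff(edges))))
print("max |density - pdf|:", float(np.max(np.abs(density - pdf))))
```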

LetsPlot.setup_html()

if cum_means_df is None or pdf_curves_df is None:
    # No data yet: prompt the user to submit the form
    _output = mo.callout(
        mo.md("### Configure and submit the form to generate the visualization.")
    )
else:
    # One panel per column: density-scaled histogram plus the fitted normal PDF
    _output = mo.center(
        ggplot()
        + geom_histogram(
            data=cum_means_df,
            mapping=aes(x="value", y="..density.."),
            bins=100,
            size=0.25,
            fill="#4682B4",
            alpha=0.7,
        )
        + geom_line(
            data=pdf_curves_df,
            mapping=aes(x="value", y="pdf"),
            color="red",
            size=2,
            alpha=0.5,
        )
        + facet_wrap("column_id", ncol=3)
        + ggtitle("Distribution of Cumulative Means vs. Normal PDF")
        + xlab("Cumulative Mean Value")
        + ylab("Density")
        + flavor_darcula()
        + ggsize(1024, 600)
    )

# The last expression is the cell's output
_output

Data Processing

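The following cells generate and process the data. marimo runs cells in the order of their variable dependencies rather than their position on the page. These cells therefore work even though each one appears before the cell that defines its input.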

pdf_curves_df = generate_pdf_curves(cum_means_df) if cum_means_df is not None else None
pdf_curves_df
cum_means_df = calculate_cumulative_means(random_data) if random_data is not None else None
cum_means_df
random_data = (
    np.random.rand(form.value["n_rows"], form.value["n_cols"])
    if form.value is not None
    else None
)
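np.random.rand draws from NumPy's legacy global generator, so every run and every form submission produces different data. If reproducible figures are wanted, one option is a seeded numpy.random.Generator. The helper make_random_data and its seed parameter below are hypothetical and not part of the notebook.

```python
from typing import Optional

import numpy as np


def make_random_data(n_rows: int, n_cols: int, seed: Optional[int] = None) -> np.ndarray:
    """Return an (n_rows, n_cols) array of Uniform[0, 1) draws from a seeded Generator."""
    rng = np.random.default_rng(seed)
    return rng.random((n_rows, n_cols))


a = make_random_data(1_000, 6, seed=42)
b = make_random_data(1_000, 6, seed=42)
print(a.shape, np.array_equal(a, b))
```

In the data cell this would read: random_data = make_random_data(form.value["n_rows"], form.value["n_cols"], seed=0) if form.value is not None else None. Passing seed=None restores fresh randomness on every submission.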

Definitions & Imports

This section contains the imports and the helper function definitions.

import marimo as mo
import numpy as np
import polars as pl
from lets_plot import (
    LetsPlot,
    aes,
    facet_wrap,
    geom_histogram,
    geom_line,
    ggplot,
    ggtitle,
    xlab,
    ylab,
    flavor_darcula,
    ggsize,
)
from scipy.stats import norm
from typing import Optional


def calculate_cumulative_means(data: Optional[np.ndarray]) -> Optional[pl.DataFrame]:
    """
    Calculates the cumulative mean along each row and returns a long-format DataFrame.

    Element (i, j) of the intermediate array is the mean of the first j + 1
    values in row i of data.

    Args:
        data: A 2D NumPy array of shape (n_rows, n_cols), or None.

    Returns:
        A Polars DataFrame with columns 'row_count', 'column_id', and 'value',
        or None if data is None.
    """
    if data is None:
        return None
    n_rows, n_cols = data.shape
    # Cumulative sum along each row (axis=1)
    cumulative_sum = data.cumsum(axis=1)
    # Divisors 1, 2, ..., n_cols broadcast across every row
    divisors = np.arange(1, n_cols + 1)

    # Cumulative mean: column j is divided by j + 1
    cumulative_mean_array = cumulative_sum / divisors

    # Build a wide DataFrame; orient="row" keeps each array row as a DataFrame row
    df = pl.from_numpy(
        cumulative_mean_array,
        schema=[f"col_{i}" for i in range(n_cols)],
        orient="row",
    )
    # Add a row index to use as the id column when unpivoting
    df = df.with_columns(pl.Series("row_count", np.arange(n_rows)))
    # Unpivot to long format: one row per (row_count, column_id) pair
    return df.unpivot(index="row_count", variable_name="column_id", value_name="value")


def generate_pdf_curves(df: Optional[pl.DataFrame]) -> Optional[pl.DataFrame]:
    """
    Generates normal PDF curves whose mean and standard deviation are
    estimated separately for each column_id group.

    Args:
        df: A long-format DataFrame with 'column_id' and 'value' columns, or None.

    Returns:
        A DataFrame with 'column_id', 'value', and 'pdf' columns for plotting,
        or None if there is no data to fit.
    """
    if df is None or df.is_empty():
        return None

    # Sample mean and standard deviation (ddof=1) of each group
    stats = df.group_by("column_id").agg(
        pl.mean("value").alias("mean"), pl.std("value").alias("std")
    )

    curves = []
    for row in stats.iter_rows(named=True):
        column_id, mean, std = row["column_id"], row["mean"], row["std"]

        # std is None for a single-value group and 0 for a constant one;
        # no PDF can be fitted in either case
        if std is None or std == 0:
            continue

        # Filter the DataFrame to get the data for the current group
        group_df = df.filter(pl.col("column_id") == column_id)

        # Evaluate the PDF at 200 points spanning this group's observed range
        x_vals = np.linspace(group_df["value"].min(), group_df["value"].max(), 200)
        pdf = norm.pdf(x_vals, mean, std)
        # Use 'value' as the x column name to match the histogram data
        curves.append(
            pl.DataFrame(
                {"column_id": [column_id] * len(x_vals), "value": x_vals, "pdf": pdf}
            )
        )

    return pl.concat(curves) if curves else None
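The red curve comes from norm.pdf(x_vals, mean, std), which evaluates the normal density exp(-(x - mean)^2 / (2 std^2)) / (std sqrt(2 pi)). The sketch below writes that formula out with NumPy so it is easy to see what the curve represents. It then checks numerically that the density integrates to about 1 over mean ± 6 std. The function name normal_pdf is illustrative. The parameters assume the CLT values for a mean of six Uniform(0, 1) draws.

```python
import math

import numpy as np


def normal_pdf(x: np.ndarray, mean: float, std: float) -> np.ndarray:
    """Density of N(mean, std**2); same formula as scipy.stats.norm.pdf(x, mean, std)."""
    z = (x - mean) / std
    return np.exp(-0.5 * z**2) / (std * math.sqrt(2 * math.pi))


# Parameters for the mean of k = 6 Uniform(0, 1) draws: mean 1/2, std sqrt(1 / 72)
mean, std = 0.5, math.sqrt(1 / 72)
x = np.linspace(mean - 6 * std, mean + 6 * std, 2001)
y = normal_pdf(x, mean, std)
# Trapezoidal rule written out, to avoid depending on a specific NumPy version's API
area = float(np.sum((y[1:] + y[:-1]) / 2 * np.diff(x)))
print(f"peak = {y.max():.4f}, area = {area:.6f}")
```

The notebook evaluates the curve only over each group's observed range rather than ± 6 std. The plotted line therefore covers nearly, but not exactly, all of the density's area.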