Merged
127 changes: 69 additions & 58 deletions diff_diff/staggered.py
@@ -7,7 +7,7 @@

import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
@@ -1117,7 +1117,7 @@ def fit(

# Compute overall ATT (simple aggregation)
overall_att, overall_se = self._aggregate_simple(
group_time_effects, influence_func_info, df, unit
group_time_effects, influence_func_info, df, unit, precomputed
)
overall_t = overall_att / overall_se if overall_se > 0 else 0.0
overall_p = compute_p_value(overall_t)
@@ -1471,6 +1471,7 @@ def _aggregate_simple(
influence_func_info: Dict,
df: pd.DataFrame,
unit: str,
precomputed: Optional[PrecomputedData] = None,
) -> Tuple[float, float]:
"""
Compute simple weighted average of ATT(g,t).
@@ -1508,7 +1509,7 @@
# Compute SE using influence function aggregation with wif adjustment
overall_se = self._compute_aggregated_se_with_wif(
gt_pairs, weights_norm, effects, groups_for_gt,
influence_func_info, df, unit
influence_func_info, df, unit, precomputed
)

return overall_att, overall_se
@@ -1582,6 +1583,7 @@ def _compute_aggregated_se_with_wif(
influence_func_info: Dict,
df: pd.DataFrame,
unit: str,
precomputed: Optional[PrecomputedData] = None,
) -> float:
"""
Compute SE with weight influence function (wif) adjustment.
@@ -1605,22 +1607,23 @@
return 0.0

# Build unit index mapping
all_units = set()
all_units_set: Set[Any] = set()
for (g, t) in gt_pairs:
if (g, t) in influence_func_info:
info = influence_func_info[(g, t)]
all_units.update(info['treated_units'])
all_units.update(info['control_units'])
all_units_set.update(info['treated_units'])
all_units_set.update(info['control_units'])

if not all_units:
if not all_units_set:
return 0.0

all_units = sorted(all_units)
all_units = sorted(all_units_set)
n_units = len(all_units)
unit_to_idx = {u: i for i, u in enumerate(all_units)}

# Get unique groups and their information
unique_groups = sorted(set(groups_for_gt))
unique_groups_set = set(unique_groups)
group_to_idx = {g: i for i, g in enumerate(unique_groups)}

# Compute group-level probabilities matching R's formula:
@@ -1639,6 +1642,10 @@
pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in groups_for_gt])
sum_pg_keepers = np.sum(pg_keepers)

# Guard against zero weights (no keepers = no variance)
if sum_pg_keepers == 0:
return 0.0

# Standard aggregated influence (without wif)
psi_standard = np.zeros(n_units)

@@ -1649,62 +1656,66 @@
info = influence_func_info[(g, t)]
w = weights[j]

for i, uid in enumerate(info['treated_units']):
idx = unit_to_idx[uid]
psi_standard[idx] += w * info['treated_inf'][i]

for i, uid in enumerate(info['control_units']):
idx = unit_to_idx[uid]
psi_standard[idx] += w * info['control_inf'][i]

# Build unit-group membership indicator
unit_groups = {}
for uid in all_units:
unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0]
if unit_first_treat in unique_groups:
unit_groups[uid] = unit_first_treat
else:
unit_groups[uid] = None # Never-treated or other

# Compute wif using R's exact formula (iterate over keepers, not groups)
# R's wif function:
# Vectorized influence function aggregation for treated units
treated_indices = np.array([unit_to_idx[uid] for uid in info['treated_units']])
if len(treated_indices) > 0:
np.add.at(psi_standard, treated_indices, w * info['treated_inf'])

# Vectorized influence function aggregation for control units
control_indices = np.array([unit_to_idx[uid] for uid in info['control_units']])
if len(control_indices) > 0:
np.add.at(psi_standard, control_indices, w * info['control_inf'])

# Build unit-group array using precomputed data if available
# This is O(n_units) instead of O(n_units × n_obs) DataFrame lookups
if precomputed is not None:
# Use precomputed cohort mapping
precomputed_units = precomputed['all_units']
precomputed_cohorts = precomputed['unit_cohorts']
precomputed_unit_to_idx = precomputed['unit_to_idx']

# Build unit_groups_array for the units in this SE computation
# A value of -1 indicates never-treated or other (not in unique_groups)
unit_groups_array = np.full(n_units, -1, dtype=np.float64)
for i, uid in enumerate(all_units):
if uid in precomputed_unit_to_idx:
cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]]
if cohort in unique_groups_set:
unit_groups_array[i] = cohort
else:
# Fallback: build from DataFrame (slow path for backward compatibility)
unit_groups_array = np.full(n_units, -1, dtype=np.float64)
for i, uid in enumerate(all_units):
unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0]
if unit_first_treat in unique_groups_set:
unit_groups_array[i] = unit_first_treat

# Vectorized WIF computation
# R's wif formula:
# if1[i,k] = (indicator(G_i == group_k) - pg[k]) / sum(pg[keepers])
# if2[i,k] = indicator_sum[i] * pg[k] / sum(pg[keepers])^2
# wif[i,k] = if1[i,k] - if2[i,k]
#
# Then: wif_contrib[i] = sum_k(wif[i,k] * att[k])
# wif_contrib[i] = sum_k(wif[i,k] * att[k])

n_keepers = len(gt_pairs)
wif_contrib = np.zeros(n_units)
# Build indicator matrix: (n_units, n_keepers)
# indicator_matrix[i, k] = 1.0 if unit i belongs to group for keeper k
groups_for_gt_array = np.array(groups_for_gt)
indicator_matrix = (unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :]).astype(np.float64)

# Pre-compute indicator_sum for each unit
# Vectorized indicator_sum: sum over keepers
# indicator_sum[i] = sum_k(indicator(G_i == group_k) - pg[k])
indicator_sum = np.zeros(n_units)
for j, g in enumerate(groups_for_gt):
pg_k = pg_keepers[j]
for uid in all_units:
i = unit_to_idx[uid]
unit_g = unit_groups[uid]
indicator = 1.0 if unit_g == g else 0.0
indicator_sum[i] += (indicator - pg_k)

# Compute wif contribution for each keeper
for j, (g, t) in enumerate(gt_pairs):
pg_k = pg_keepers[j]
att_k = effects[j]

for uid in all_units:
i = unit_to_idx[uid]
unit_g = unit_groups[uid]
indicator = 1.0 if unit_g == g else 0.0

# R's formula for wif
if1_ik = (indicator - pg_k) / sum_pg_keepers
if2_ik = indicator_sum[i] * pg_k / (sum_pg_keepers ** 2)
wif_ik = if1_ik - if2_ik

# Add contribution: wif[i,k] * att[k]
wif_contrib[i] += wif_ik * att_k
indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1)

# Vectorized wif matrix computation
# if1_matrix[i,k] = (indicator[i,k] - pg[k]) / sum_pg
if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
# if2_matrix[i,k] = indicator_sum[i] * pg[k] / sum_pg^2
if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers ** 2)
wif_matrix = if1_matrix - if2_matrix

# Single matrix-vector multiply for all contributions
# wif_contrib[i] = sum_k(wif[i,k] * att[k])
wif_contrib = wif_matrix @ effects

# Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
psi_wif = wif_contrib / n_units
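Editor's note: a self-contained sketch of the two techniques this diff introduces — scatter-add aggregation of influence values via ``np.add.at``, and the matrix form of the weight influence function (WIF). All array values are illustrative, and the final combination step is an assumption (the corresponding lines are truncated above); only the ``sqrt(mean(IF^2)/n)`` scaling convention is taken from the diff.

import numpy as np

# Illustrative shapes: 5 units, 2 kept (g, t) cells. Values are made up.
n_units = 5
unit_groups = np.array([2.0, 2.0, 3.0, -1.0, 3.0])   # -1 = never-treated
groups_for_gt = np.array([2.0, 3.0])                  # group of each keeper
pg_keepers = np.array([0.4, 0.6])                     # P(G = g) per keeper
effects = np.array([1.5, 2.0])                        # ATT(g, t) per keeper
sum_pg = pg_keepers.sum()

# Scatter-add: accumulate weighted influence values into per-unit slots.
# np.add.at handles repeated indices correctly, unlike fancy-indexed +=.
psi_standard = np.zeros(n_units)
w = 0.5                                               # illustrative weight
treated_indices = np.array([0, 1])
treated_inf = np.array([0.3, -0.1])
np.add.at(psi_standard, treated_indices, w * treated_inf)

# WIF in matrix form: indicator[i, k] = 1 if unit i is in keeper k's group.
indicator = (unit_groups[:, None] == groups_for_gt[None, :]).astype(float)
indicator_sum = (indicator - pg_keepers).sum(axis=1)
if1 = (indicator - pg_keepers) / sum_pg
if2 = np.outer(indicator_sum, pg_keepers) / sum_pg ** 2
wif_contrib = (if1 - if2) @ effects                   # one matrix-vector product

# Assumed combination step (the real one is in the truncated lines above),
# scaled to match R's getSE convention sqrt(mean(IF^2)/n).
psi = psi_standard + wif_contrib / n_units
se = float(np.sqrt(np.mean(psi ** 2) / n_units))
print(se)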
108 changes: 60 additions & 48 deletions docs/benchmarks.rst
@@ -76,7 +76,7 @@ Summary Table
- **PASS**
* - CallawaySantAnna
- < 1e-10
- < 1%
- 0.0%
- Yes
- **PASS**
* - SyntheticDiD
@@ -171,27 +171,28 @@ Callaway-Sant'Anna Results
- 2.519
- < 1e-10
* - SE
- 0.062
- 0.062
- 0.063
- 2.3%
- 0.063
- 0.063
- 0.0%
* - Time (s)
- 0.005
- 0.005
- 0.071
- **14x faster**
- 0.007 ± 0.000
- 0.007 ± 0.000
- 0.070 ± 0.001
- **10x faster**

**Validation**: PASS - Both point estimates and standard errors match R closely.
**Validation**: PASS - Both point estimates and standard errors match R exactly.

**Key findings from investigation:**

1. **Individual ATT(g,t) effects match perfectly** (~1e-11 difference)
2. **Never-treated coding**: R's ``did`` package requires ``first_treat=Inf``
for never-treated units. diff-diff accepts ``first_treat=0``. The benchmark
converts 0 to Inf for R compatibility; a minimal sketch of this recoding
follows this list.
3. **Standard errors**: As of v1.5.0, analytical SEs use influence function
aggregation (matching R's approach), resulting in < 3% SE difference across
all scales. Both analytical and bootstrap inference now match R closely.
3. **Standard errors**: As of v2.0.2, analytical SEs match R's ``did`` package
exactly (0.0% difference). The weight influence function (wif) formula was
corrected to match R's implementation, achieving numerical equivalence across
all dataset scales.
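
Editor's note: the recoding sketch referenced in point 2. The DataFrame is illustrative; only the ``first_treat`` column name and the 0-to-Inf convention are taken from the docs.

import numpy as np
import pandas as pd

# first_treat == 0 marks never-treated units in diff-diff's convention.
df = pd.DataFrame({"unit": [1, 2, 3], "first_treat": [2004, 0, 2006]})

# Recode 0 -> Inf so R's `did` package treats these units as never-treated.
df["first_treat"] = df["first_treat"].replace(0, np.inf)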

Performance Comparison
----------------------
@@ -270,37 +271,37 @@ Three-Way Performance Summary
- R (s)
- Python Pure (s)
- Python Rust (s)
- Rust/R
- Pure/R
- Rust/Pure
* - small
- 0.071
- 0.005
- 0.005
- **14.1x**
- 0.070
- 0.007
- 0.007
- **10x**
- 1.0x
* - 1k
- 0.114
- 0.012
- 0.012
- **9.4x**
- 0.013
- 0.013
- **9x**
- 1.0x
* - 5k
- 0.341
- 0.055
- 0.056
- **6.1x**
- 0.345
- 0.053
- 0.051
- **7x**
- 1.0x
* - 10k
- 0.726
- 0.156
- 0.155
- **4.7x**
- 0.727
- 0.134
- 0.138
- **5x**
- 1.0x
* - 20k
- 1.464
- 0.404
- 0.411
- **3.6x**
- 1.490
- 0.352
- 0.358
- **4x**
- 1.0x

**SyntheticDiD Results:**
@@ -391,10 +392,10 @@ Dataset Sizes
Key Observations
~~~~~~~~~~~~~~~~

1. **diff-diff is dramatically faster than R**:
1. **Performance varies by estimator and scale**:

- **BasicDiD/TWFE**: 2-18x faster than R
- **CallawaySantAnna**: 4-14x faster than R
- **BasicDiD/TWFE**: 2-18x faster than R at all scales
- **CallawaySantAnna**: 4-10x faster than R at all scales (vectorized WIF computation)
- **SyntheticDiD**: 565-2234x faster than R (R takes 24 minutes at 10k scale!)

2. **Rust backend benefit depends on the estimator**:
@@ -410,15 +411,20 @@ Key Observations
- **Bootstrap inference**: May help with parallelized iterations
- **BasicDiD/CallawaySantAnna**: Optional - pure Python is equally fast

4. **Scaling behavior**: Both Python implementations show excellent scaling.
At 10K scale (500K observations for SyntheticDiD), Rust completes in
~2.6 seconds vs ~20 seconds for pure Python vs ~24 minutes for R.
4. **Scaling behavior**: Both Python implementations scale well across all
estimators. SyntheticDiD is 565x faster than R at 10k scale, and
CallawaySantAnna achieves **exact SE accuracy** (0.0% difference) while
being 4-10x faster than R through vectorized NumPy operations.

5. **No Rust required for most use cases**: Users without Rust/maturin can
install diff-diff and get full functionality with excellent performance.
For BasicDiD and CallawaySantAnna, pure Python achieves the same speed as Rust.
Only SyntheticDiD benefits significantly from the Rust backend.

6. **CallawaySantAnna accuracy and speed**: As of v2.0.3, CallawaySantAnna
achieves both exact numerical accuracy (0.0% SE difference from R) AND
superior performance (4-10x faster than R) through vectorized weight
influence function (WIF) computation using NumPy matrix operations.

Performance Optimization Details
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -436,7 +442,12 @@ The performance improvements come from:
4. **Vectorized bootstrap** (CallawaySantAnna): Matrix operations instead of
nested loops, batch weight generation (see the sketch after this list)

5. **Optional Rust backend** (v2.0.0): PyO3-based Rust extension for compute-intensive
5. **Vectorized WIF computation** (CallawaySantAnna, v2.0.3): Weight influence
function computation uses NumPy matrix operations instead of O(n_units × n_keepers)
nested loops. The indicator matrix, if1/if2 matrices, and wif contribution are
computed using broadcasting and matrix multiplication: ``wif_contrib = wif_matrix @ effects``

6. **Optional Rust backend** (v2.0.0): PyO3-based Rust extension for compute-intensive
operations (OLS, robust variance, bootstrap weights, simplex projection)
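
Editor's note: the sketch referenced in point 4 — a rough illustration of batched multiplier-bootstrap draws. Rademacher multipliers are an assumption here; the package's actual multiplier distribution and aggregation may differ.

import numpy as np

rng = np.random.default_rng(0)
n_units, n_boot = 500, 999

psi = rng.normal(size=n_units)                        # stand-in influence values
V = rng.choice([-1.0, 1.0], size=(n_boot, n_units))   # all multipliers at once

# Every bootstrap replicate in a single matrix-vector product,
# rather than one Python-level loop iteration per replicate.
boot_draws = V @ psi / n_units
boot_se = float(boot_draws.std())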

Why is diff-diff Fast?
@@ -496,12 +507,13 @@ Results Comparison
1. **Point estimates match exactly**: The overall ATT of -0.039951 is identical
between diff-diff and R's ``did`` package, validating the core estimation logic.

2. **Standard errors match**: As of v1.5.0, analytical SEs use influence function
aggregation (matching R's approach), resulting in < 1% difference. Both point
estimates and standard errors now match R's ``did`` package.
2. **Standard errors match exactly**: As of v2.0.2, analytical SEs use the corrected
weight influence function formula, achieving 0.0% difference from R's ``did``
package. Both point estimates and standard errors are numerically equivalent.

3. **Performance**: diff-diff is ~14x faster than R on this real-world dataset,
consistent with the synthetic data benchmarks at small scale.
3. **Performance**: diff-diff is ~14x faster than R on this real-world dataset
at small scale; the speedup narrows at larger scales (see the performance
tables above).

This validation on real-world data with known published results confirms that
diff-diff produces correct estimates that match the reference R implementation.
@@ -576,9 +588,9 @@ When to Trust Results
match R closely. Use ``variance_method="placebo"`` (default) to match R's
inference. Results are fully validated.

- **CallawaySantAnna**: Group-time effects (ATT(g,t)) are reliable. Overall
ATT aggregation may differ from R due to weighting choices. When comparing
to R ``did`` package, verify aggregation settings match.
- **CallawaySantAnna**: Both group-time effects (ATT(g,t)) and overall ATT
aggregation match R exactly. Standard errors are numerically equivalent
(0.0% difference) as of v2.0.2.

Known Differences
~~~~~~~~~~~~~~~~~