diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py
index 37b2eb5..01d1148 100644
--- a/diff_diff/staggered.py
+++ b/diff_diff/staggered.py
@@ -7,7 +7,7 @@
 import warnings
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 import numpy as np
 import pandas as pd
@@ -1117,7 +1117,7 @@ def fit(
         # Compute overall ATT (simple aggregation)
         overall_att, overall_se = self._aggregate_simple(
-            group_time_effects, influence_func_info, df, unit
+            group_time_effects, influence_func_info, df, unit, precomputed
         )
         overall_t = overall_att / overall_se if overall_se > 0 else 0.0
         overall_p = compute_p_value(overall_t)
@@ -1471,6 +1471,7 @@ def _aggregate_simple(
         influence_func_info: Dict,
         df: pd.DataFrame,
         unit: str,
+        precomputed: Optional[PrecomputedData] = None,
     ) -> Tuple[float, float]:
         """
         Compute simple weighted average of ATT(g,t).
@@ -1508,7 +1509,7 @@
         # Compute SE using influence function aggregation with wif adjustment
         overall_se = self._compute_aggregated_se_with_wif(
             gt_pairs, weights_norm, effects, groups_for_gt,
-            influence_func_info, df, unit
+            influence_func_info, df, unit, precomputed
         )
 
         return overall_att, overall_se
@@ -1582,6 +1583,7 @@ def _compute_aggregated_se_with_wif(
         influence_func_info: Dict,
         df: pd.DataFrame,
         unit: str,
+        precomputed: Optional[PrecomputedData] = None,
     ) -> float:
         """
         Compute SE with weight influence function (wif) adjustment.
@@ -1605,22 +1607,23 @@
             return 0.0
 
         # Build unit index mapping
-        all_units = set()
+        all_units_set: Set[Any] = set()
         for (g, t) in gt_pairs:
             if (g, t) in influence_func_info:
                 info = influence_func_info[(g, t)]
-                all_units.update(info['treated_units'])
-                all_units.update(info['control_units'])
+                all_units_set.update(info['treated_units'])
+                all_units_set.update(info['control_units'])
 
-        if not all_units:
+        if not all_units_set:
             return 0.0
 
-        all_units = sorted(all_units)
+        all_units = sorted(all_units_set)
         n_units = len(all_units)
         unit_to_idx = {u: i for i, u in enumerate(all_units)}
 
         # Get unique groups and their information
         unique_groups = sorted(set(groups_for_gt))
+        unique_groups_set = set(unique_groups)
         group_to_idx = {g: i for i, g in enumerate(unique_groups)}
 
         # Compute group-level probabilities matching R's formula:
@@ -1639,6 +1642,10 @@
         pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in groups_for_gt])
         sum_pg_keepers = np.sum(pg_keepers)
 
+        # Guard against zero weights (no keepers = no variance)
+        if sum_pg_keepers == 0:
+            return 0.0
+
         # Standard aggregated influence (without wif)
         psi_standard = np.zeros(n_units)
@@ -1649,62 +1656,66 @@
             info = influence_func_info[(g, t)]
             w = weights[j]
 
-            for i, uid in enumerate(info['treated_units']):
-                idx = unit_to_idx[uid]
-                psi_standard[idx] += w * info['treated_inf'][i]
-
-            for i, uid in enumerate(info['control_units']):
-                idx = unit_to_idx[uid]
-                psi_standard[idx] += w * info['control_inf'][i]
-
-        # Build unit-group membership indicator
-        unit_groups = {}
-        for uid in all_units:
-            unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0]
-            if unit_first_treat in unique_groups:
-                unit_groups[uid] = unit_first_treat
-            else:
-                unit_groups[uid] = None  # Never-treated or other
-
-        # Compute wif using R's exact formula (iterate over keepers, not groups)
-        # R's wif function:
+            # Vectorized influence function aggregation for treated units
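+            # (np.add.at accumulates in place without buffering, so any
+            # repeated unit indices are summed correctly)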
+            treated_indices = np.array([unit_to_idx[uid] for uid in info['treated_units']])
+            if len(treated_indices) > 0:
+                np.add.at(psi_standard, treated_indices, w * info['treated_inf'])
+
+            # Vectorized influence function aggregation for control units
+            control_indices = np.array([unit_to_idx[uid] for uid in info['control_units']])
+            if len(control_indices) > 0:
+                np.add.at(psi_standard, control_indices, w * info['control_inf'])
+
+        # Build unit-group array using precomputed data if available
+        # This is O(n_units) instead of O(n_units × n_obs) DataFrame lookups
+        if precomputed is not None:
+            # Use precomputed cohort mapping
+            precomputed_units = precomputed['all_units']
+            precomputed_cohorts = precomputed['unit_cohorts']
+            precomputed_unit_to_idx = precomputed['unit_to_idx']
+
+            # Build unit_groups_array for the units in this SE computation
+            # A value of -1 indicates never-treated or other (not in unique_groups)
+            unit_groups_array = np.full(n_units, -1, dtype=np.float64)
+            for i, uid in enumerate(all_units):
+                if uid in precomputed_unit_to_idx:
+                    cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]]
+                    if cohort in unique_groups_set:
+                        unit_groups_array[i] = cohort
+        else:
+            # Fallback: build from DataFrame (slow path for backward compatibility)
+            unit_groups_array = np.full(n_units, -1, dtype=np.float64)
+            for i, uid in enumerate(all_units):
+                unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0]
+                if unit_first_treat in unique_groups_set:
+                    unit_groups_array[i] = unit_first_treat
+
+        # Vectorized WIF computation
+        # R's wif formula:
         # if1[i,k] = (indicator(G_i == group_k) - pg[k]) / sum(pg[keepers])
         # if2[i,k] = indicator_sum[i] * pg[k] / sum(pg[keepers])^2
         # wif[i,k] = if1[i,k] - if2[i,k]
-        #
-        # Then: wif_contrib[i] = sum_k(wif[i,k] * att[k])
+        # wif_contrib[i] = sum_k(wif[i,k] * att[k])
 
-        n_keepers = len(gt_pairs)
-        wif_contrib = np.zeros(n_units)
+        # Build indicator matrix: (n_units, n_keepers)
+        # indicator_matrix[i, k] = 1.0 if unit i belongs to group for keeper k
+        groups_for_gt_array = np.array(groups_for_gt)
+        indicator_matrix = (unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :]).astype(np.float64)
 
-        # Pre-compute indicator_sum for each unit
+        # Vectorized indicator_sum: sum over keepers
         # indicator_sum[i] = sum_k(indicator(G_i == group_k) - pg[k])
-        indicator_sum = np.zeros(n_units)
-        for j, g in enumerate(groups_for_gt):
-            pg_k = pg_keepers[j]
-            for uid in all_units:
-                i = unit_to_idx[uid]
-                unit_g = unit_groups[uid]
-                indicator = 1.0 if unit_g == g else 0.0
-                indicator_sum[i] += (indicator - pg_k)
-
-        # Compute wif contribution for each keeper
-        for j, (g, t) in enumerate(gt_pairs):
-            pg_k = pg_keepers[j]
-            att_k = effects[j]
-
-            for uid in all_units:
-                i = unit_to_idx[uid]
-                unit_g = unit_groups[uid]
-                indicator = 1.0 if unit_g == g else 0.0
-
-                # R's formula for wif
-                if1_ik = (indicator - pg_k) / sum_pg_keepers
-                if2_ik = indicator_sum[i] * pg_k / (sum_pg_keepers ** 2)
-                wif_ik = if1_ik - if2_ik
-
-                # Add contribution: wif[i,k] * att[k]
-                wif_contrib[i] += wif_ik * att_k
+        indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1)
+
+        # Vectorized wif matrix computation
+        # if1_matrix[i,k] = (indicator[i,k] - pg[k]) / sum_pg
+        if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
+        # if2_matrix[i,k] = indicator_sum[i] * pg[k] / sum_pg^2
+        if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers ** 2)
+        wif_matrix = if1_matrix - if2_matrix
+
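+        # Shape note: wif_matrix is (n_units, n_keepers) and effects holds one
+        # ATT(g,t) per keeper, so the product below yields one value per unit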
+        # Single matrix-vector multiply for all contributions
+        # wif_contrib[i] = sum_k(wif[i,k] * att[k])
+        wif_contrib = wif_matrix @ effects
 
         # Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
         psi_wif = wif_contrib / n_units
diff --git a/docs/benchmarks.rst b/docs/benchmarks.rst
index 8880210..4f91704 100644
--- a/docs/benchmarks.rst
+++ b/docs/benchmarks.rst
@@ -76,7 +76,7 @@ Summary Table
      - **PASS**
    * - CallawaySantAnna
      - < 1e-10
-     - < 1%
+     - 0.0%
      - Yes
      - **PASS**
    * - SyntheticDiD
@@ -171,17 +171,17 @@ Callaway-Sant'Anna Results
      - 2.519
      - < 1e-10
    * - SE
-     - 0.062
-     - 0.062
      - 0.063
-     - 2.3%
+     - 0.063
+     - 0.063
+     - 0.0%
    * - Time (s)
-     - 0.005
-     - 0.005
-     - 0.071
-     - **14x faster**
+     - 0.007 ± 0.000
+     - 0.007 ± 0.000
+     - 0.070 ± 0.001
+     - **10x faster**
 
-**Validation**: PASS - Both point estimates and standard errors match R closely.
+**Validation**: PASS - Both point estimates and standard errors match R exactly.
 
 **Key findings from investigation:**
 
 1. ...
 
@@ -189,9 +189,10 @@
 2. **Never-treated coding**: R's ``did`` package requires ``first_treat=Inf``
    for never-treated units. diff-diff accepts ``first_treat=0``. The benchmark
    converts 0 to Inf for R compatibility.
 
-3. **Standard errors**: As of v1.5.0, analytical SEs use influence function
-   aggregation (matching R's approach), resulting in < 3% SE difference across
-   all scales. Both analytical and bootstrap inference now match R closely.
+3. **Standard errors**: As of v2.0.2, analytical SEs match R's ``did`` package
+   exactly (0.0% difference). The weight influence function (wif) formula was
+   corrected to match R's implementation, achieving numerical equivalence across
+   all dataset scales.
 
 Performance Comparison
 ----------------------
@@ -270,37 +271,37 @@ Three-Way Performance Summary
      - R (s)
      - Python Pure (s)
      - Python Rust (s)
-     - Rust/R
+     - Pure/R
      - Rust/Pure
    * - small
-     - 0.071
-     - 0.005
-     - 0.005
-     - **14.1x**
+     - 0.070
+     - 0.007
+     - 0.007
+     - **10x**
      - 1.0x
    * - 1k
      - 0.114
-     - 0.012
-     - 0.012
-     - **9.4x**
+     - 0.013
+     - 0.013
+     - **9x**
      - 1.0x
    * - 5k
-     - 0.341
-     - 0.055
-     - 0.056
-     - **6.1x**
+     - 0.345
+     - 0.053
+     - 0.051
+     - **7x**
      - 1.0x
    * - 10k
-     - 0.726
-     - 0.156
-     - 0.155
-     - **4.7x**
+     - 0.727
+     - 0.134
+     - 0.138
+     - **5x**
      - 1.0x
    * - 20k
-     - 1.464
-     - 0.404
-     - 0.411
-     - **3.6x**
+     - 1.490
+     - 0.352
+     - 0.358
+     - **4x**
      - 1.0x
 
 **SyntheticDiD Results:**
@@ -391,10 +392,10 @@ Dataset Sizes
 Key Observations
 ~~~~~~~~~~~~~~~~
 
-1. **diff-diff is dramatically faster than R**:
+1. **Performance varies by estimator and scale**:
 
-   - **BasicDiD/TWFE**: 2-18x faster than R
-   - **CallawaySantAnna**: 4-14x faster than R
+   - **BasicDiD/TWFE**: 2-18x faster than R at all scales
+   - **CallawaySantAnna**: 4-10x faster than R at all scales (vectorized WIF computation)
    - **SyntheticDiD**: 565-2234x faster than R (R takes 24 minutes at 10k scale!)
 
 2. **Rust backend benefit depends on the estimator**:
@@ -410,15 +411,20 @@ Key Observations
    - **Bootstrap inference**: May help with parallelized iterations
    - **BasicDiD/CallawaySantAnna**: Optional - pure Python is equally fast
 
-4. **Scaling behavior**: Both Python implementations show excellent scaling.
-   At 10K scale (500K observations for SyntheticDiD), Rust completes in
-   ~2.6 seconds vs ~20 seconds for pure Python vs ~24 minutes for R.
+4. **Scaling behavior**: Python implementations scale well across all
+   estimators. SyntheticDiD is 565x faster than R at 10k scale.
+   CallawaySantAnna achieves **exact SE accuracy** (0.0% difference) while
+   being 4-10x faster than R through vectorized NumPy operations.
 
 5. **No Rust required for most use cases**: Users without Rust/maturin can
    install diff-diff and get full functionality with excellent performance.
-   For BasicDiD and CallawaySantAnna, pure Python achieves the same speed
-   as Rust. Only SyntheticDiD benefits significantly from the Rust backend.
+   Only SyntheticDiD benefits significantly from the Rust backend.
 
+6. **CallawaySantAnna accuracy and speed**: As of v2.0.3, CallawaySantAnna
+   achieves both exact numerical accuracy (0.0% SE difference from R) AND
+   superior performance (4-10x faster than R) through vectorized weight
+   influence function (WIF) computation using NumPy matrix operations.
+
 Performance Optimization Details
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 The performance improvements come from:
 
@@ -436,7 +442,12 @@
 4. **Vectorized bootstrap** (CallawaySantAnna): Matrix operations instead of
    nested loops, batch weight generation
 
-5. **Optional Rust backend** (v2.0.0): PyO3-based Rust extension for compute-intensive
+5. **Vectorized WIF computation** (CallawaySantAnna, v2.0.3): Weight influence
+   function computation uses NumPy matrix operations instead of O(n_units × n_keepers)
+   nested loops. The indicator matrix, if1/if2 matrices, and wif contribution are
+   computed using broadcasting and matrix multiplication: ``wif_contrib = wif_matrix @ effects``
+   (see the sketch below).
+
+6. **Optional Rust backend** (v2.0.0): PyO3-based Rust extension for compute-intensive
    operations (OLS, robust variance, bootstrap weights, simplex projection)
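+
+A condensed, illustrative sketch of that vectorized step (variable names follow
+``diff_diff/staggered.py`` above; the toy inputs below are hypothetical and used
+only for illustration):
+
+.. code-block:: python
+
+   import numpy as np
+
+   # Toy inputs: 4 units (cohorts 2004/2006, -1 = never-treated) and 3 kept (g, t) cells
+   unit_groups_array = np.array([2004., 2004., 2006., -1.])
+   groups_for_gt_array = np.array([2004., 2004., 2006.])   # cohort behind each keeper
+   pg_keepers = np.array([0.5, 0.5, 0.25])                 # group probability per keeper
+   effects = np.array([-0.02, -0.03, -0.05])               # ATT(g, t) per keeper
+
+   # indicator_matrix[i, k] = 1.0 if unit i belongs to the group of keeper k
+   indicator_matrix = (unit_groups_array[:, None] == groups_for_gt_array[None, :]).astype(float)
+   indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1)
+   sum_pg_keepers = pg_keepers.sum()
+   if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
+   if2_matrix = np.outer(indicator_sum, pg_keepers) / sum_pg_keepers ** 2
+   wif_contrib = (if1_matrix - if2_matrix) @ effects   # one wif contribution per unit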
 
 Why is diff-diff Fast?
@@ -496,12 +507,13 @@ Results Comparison
 1. **Point estimates match exactly**: The overall ATT of -0.039951 is identical
    between diff-diff and R's ``did`` package, validating the core estimation logic.
 
-2. **Standard errors match**: As of v1.5.0, analytical SEs use influence function
-   aggregation (matching R's approach), resulting in < 1% difference. Both point
-   estimates and standard errors now match R's ``did`` package.
+2. **Standard errors match exactly**: As of v2.0.2, analytical SEs use the corrected
+   weight influence function formula, achieving 0.0% difference from R's ``did``
+   package. Both point estimates and standard errors are numerically equivalent.
 
-3. **Performance**: diff-diff is ~14x faster than R on this real-world dataset,
-   consistent with the synthetic data benchmarks at small scale.
+3. **Performance**: diff-diff is ~14x faster than R on this real-world dataset
+   at small scale. The speedup narrows at larger sizes (see the performance
+   tables above).
 
 This validation on real-world data with known published results confirms that
 diff-diff produces correct estimates that match the reference R implementation.
@@ -576,9 +588,9 @@ When to Trust Results
    match R closely. Use ``variance_method="placebo"`` (default) to match
    R's inference. Results are fully validated.
 
-- **CallawaySantAnna**: Group-time effects (ATT(g,t)) are reliable. Overall
-  ATT aggregation may differ from R due to weighting choices. When comparing
-  to R ``did`` package, verify aggregation settings match.
+- **CallawaySantAnna**: Both group-time effects (ATT(g,t)) and overall ATT
+  aggregation match R exactly. Standard errors are numerically equivalent
+  (0.0% difference) as of v2.0.2.
 
 Known Differences
 ~~~~~~~~~~~~~~~~~