Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 157 additions & 6 deletions diff_diff/staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,29 +1479,36 @@ def _aggregate_simple(

Standard errors are computed using influence function aggregation,
which properly accounts for covariances across (g,t) pairs due to
shared control units. This matches R's `did` package approach.
shared control units. This includes the wif (weight influence function)
adjustment from R's `did` package that accounts for uncertainty in
estimating the group-size weights.
"""
effects = []
weights_list = []
gt_pairs = []
groups_for_gt = []

for (g, t), data in group_time_effects.items():
effects.append(data['effect'])
weights_list.append(data['n_treated'])
gt_pairs.append((g, t))
groups_for_gt.append(g)

effects = np.array(effects)
weights = np.array(weights_list, dtype=float)
groups_for_gt = np.array(groups_for_gt)

# Normalize weights
weights = weights / np.sum(weights)
total_weight = np.sum(weights)
weights_norm = weights / total_weight

# Weighted average
overall_att = np.sum(weights * effects)
overall_att = np.sum(weights_norm * effects)

# Compute SE using influence function aggregation
overall_se = self._compute_aggregated_se(
gt_pairs, weights, influence_func_info
# Compute SE using influence function aggregation with wif adjustment
overall_se = self._compute_aggregated_se_with_wif(
gt_pairs, weights_norm, effects, groups_for_gt,
influence_func_info, df, unit
)

return overall_att, overall_se
Expand Down Expand Up @@ -1566,6 +1573,150 @@ def _compute_aggregated_se(
variance = np.sum(psi_overall ** 2)
return np.sqrt(variance)

def _compute_aggregated_se_with_wif(
    self,
    gt_pairs: List[Tuple[Any, Any]],
    weights: np.ndarray,
    effects: np.ndarray,
    groups_for_gt: np.ndarray,
    influence_func_info: Dict,
    df: pd.DataFrame,
    unit: str,
) -> float:
    """
    Compute the SE of the "simple" aggregate with the wif adjustment.

    The weight influence function (wif) term accounts for the fact that
    the aggregation weights are built from estimated group sizes. The
    formulas follow the ones quoted from R's `did::aggte`:

        if1[i,k] = (1{G_i == group_k} - pg[k]) / sum(pg[keepers])
        if2[i,k] = sum_k'(1{G_i == group_k'} - pg[k'])
                   * pg[k] / sum(pg[keepers])^2
        wif[i,k] = if1[i,k] - if2[i,k]
        psi_i    = sum_k w_k * inf_i_k + (sum_k wif[i,k] * ATT_k) / n

    and the returned value is sqrt(sum_i psi_i^2).

    Parameters
    ----------
    gt_pairs : list of (g, t) tuples
        The "keepers": group-time cells entering the aggregate.
    weights : np.ndarray
        Aggregation weight of each keeper, aligned with ``gt_pairs``.
    effects : np.ndarray
        ATT(g, t) estimate of each keeper, aligned with ``gt_pairs``.
    groups_for_gt : np.ndarray
        Cohort ``g`` of each keeper, aligned with ``gt_pairs``.
    influence_func_info : dict
        Maps (g, t) to a dict with keys ``'treated_units'``,
        ``'control_units'``, ``'treated_inf'`` and ``'control_inf'``.
        Keepers missing from this dict contribute nothing to the standard
        term but still enter the wif term.
    df : pd.DataFrame
        Panel data; must contain ``unit`` and a ``'first_treat'`` column
        (the column name is hard-coded here).
    unit : str
        Name of the unit identifier column in ``df``.

    Returns
    -------
    float
        Standard error of the aggregate, or 0.0 when no influence
        function information (or no unit) is available.

    Raises
    ------
    KeyError
        If a unit listed in ``influence_func_info`` does not occur in
        ``df[unit]``.

    Notes
    -----
    ``n`` is the number of distinct units appearing in the influence
    function info of the keepers, not the number of units in ``df``.
    NOTE(review): the final sqrt(sum(psi^2)) only equals R's
    sqrt(mean(IF^2) / n) if the stored influence values are already
    scaled by 1/n -- confirm against the code that fills
    ``influence_func_info``.
    """
    if not influence_func_info:
        return 0.0

    # Every unit that shows up in at least one keeper's influence function.
    all_units = set()
    for (g, t) in gt_pairs:
        info = influence_func_info.get((g, t))
        if info is None:
            continue
        all_units.update(info['treated_units'])
        all_units.update(info['control_units'])

    if not all_units:
        return 0.0

    all_units = sorted(all_units)
    n_units = len(all_units)
    unit_to_idx = {u: i for i, u in enumerate(all_units)}

    unique_groups = sorted(set(groups_for_gt))
    group_to_idx = {g: i for i, g in enumerate(unique_groups)}

    # pg[g] = (number of units first treated at g) / n_units
    first_treat_col = df['first_treat']
    group_sizes = np.array(
        [df.loc[first_treat_col == g, unit].nunique() for g in unique_groups],
        dtype=float,
    )
    pg_by_group = group_sizes / n_units

    # Each keeper inherits the pg of its cohort
    # (R: pg <- pgg[match(group, originalglist)]).
    keeper_group_idx = np.array(
        [group_to_idx[g] for g in groups_for_gt], dtype=np.intp
    )
    pg_keepers = pg_by_group[keeper_group_idx]
    sum_pg_keepers = np.sum(pg_keepers)

    # Standard term: weighted sum of the per-cell influence functions.
    psi_standard = np.zeros(n_units)
    for j, (g, t) in enumerate(gt_pairs):
        info = influence_func_info.get((g, t))
        if info is None:
            continue
        w = weights[j]
        for units_key, inf_key in (
            ('treated_units', 'treated_inf'),
            ('control_units', 'control_inf'),
        ):
            rows = np.array(
                [unit_to_idx[uid] for uid in info[units_key]], dtype=np.intp
            )
            values = np.asarray(info[inf_key], dtype=float)[:len(rows)]
            # np.add.at accumulates correctly even if a unit id repeats.
            np.add.at(psi_standard, rows, w * values)

    # Cohort of every unit, taken from its first row in df in a single
    # pass (instead of filtering the whole frame once per unit).
    first_rows = df.drop_duplicates(subset=unit)
    first_treat_of = dict(zip(first_rows[unit], first_rows['first_treat']))
    # -1 marks units whose first_treat is not one of the keeper cohorts
    # (never-treated or other); it never matches a keeper's group index.
    unit_group_idx = np.array(
        [group_to_idx.get(first_treat_of[uid], -1) for uid in all_units],
        dtype=np.intp,
    )

    # centered[i, k] = 1{G_i == group_k} - pg[k]
    indicator = (
        unit_group_idx[:, None] == keeper_group_idx[None, :]
    ).astype(float)
    centered = indicator - pg_keepers[None, :]
    indicator_sum = centered.sum(axis=1)

    att = np.asarray(effects, dtype=float)
    # sum_k wif[i,k] * att[k], with if1 and if2 contracted over k up front
    # so the (units x keepers) if2 matrix is never materialised.
    wif_contrib = (
        centered @ att / sum_pg_keepers
        - indicator_sum * (pg_keepers @ att) / (sum_pg_keepers ** 2)
    )

    # Scale the wif term by 1/n_units before combining with the standard term.
    psi_total = psi_standard + wif_contrib / n_units

    variance = np.sum(psi_total ** 2)
    return np.sqrt(variance)

def _aggregate_event_study(
self,
group_time_effects: Dict,
Expand Down
Loading