From 5b45897453a3c2af79740cbd0ac9e2eed2a14d49 Mon Sep 17 00:00:00 2001 From: Aryan Amit Barsainyan Date: Mon, 11 May 2026 19:37:14 +0530 Subject: [PATCH] update decision tree to check group var hypothesis to fix RDD identification --- cais/components/decision_tree.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cais/components/decision_tree.py b/cais/components/decision_tree.py index 81a2ef1..3e987c9 100644 --- a/cais/components/decision_tree.py +++ b/cais/components/decision_tree.py @@ -160,6 +160,7 @@ def select_method(dataset_properties: Dict[str, Any], excluded_methods: Optional running_var = dataset_properties.get("running_variable") cutoff_val = dataset_properties.get("cutoff_value") time_var = dataset_properties.get("time_variable") + group_var = dataset_properties.get("group_variable") is_rct = dataset_properties.get("is_rct", False) has_temporal = dataset_properties.get("has_temporal_structure", False) frontdoor = dataset_properties.get("frontdoor_criterion", False) @@ -220,9 +221,9 @@ def add(method: str, justification: str, prio_order: List[str]): ] # Common early structural signals first (still only add as candidates) - if has_temporal and time_var: + if has_temporal and time_var and group_var: add(DIFF_IN_DIFF, - f"Temporal structure via '{time_var}'—consider Difference-in-Differences (assumes parallel trends).", + f"Temporal structure via '{time_var}' for state structure via '{group_var}' —consider Difference-in-Differences (assumes parallel trends).", [DIFF_IN_DIFF]) # highest among itself if running_var and cutoff_val is not None: @@ -361,7 +362,9 @@ def rule_based_select_method(dataset_analysis, variables, is_rct, llm, dataset_d properties = {"treatment_variable": variables.get("treatment_variable"), "instrument_variable":variables.get("instrument_variable"), "covariates": variables.get("covariates", []), "outcome_variable": variables.get("outcome_variable"), - "time_variable": variables.get("time_variable"), "running_variable": variables.get("running_variable"), + "time_variable": variables.get("time_variable"), + "group_variable": variables.get("group_variable"), + "running_variable": variables.get("running_variable"), "treatment_variable_type": variables.get("treatment_variable_type", "binary"), "has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False), "frontdoor_criterion": variables.get("frontdoor_criterion", False),