From 4ef5142831ab0beca9944cb1d027f12307c986bc Mon Sep 17 00:00:00 2001 From: Sepideh Abedini Date: Wed, 7 Jan 2026 18:31:37 -0500 Subject: [PATCH 1/4] Delete deprecated codes --- DESIGN.md | 16 +- EA_failure_analysis.py | 77 ----- pyproject.toml | 5 +- resdsql/evaluate_text2sql_ckpts.py | 6 +- resdsql/get_schema_ranks.py | 2 +- resdsql/schema_item_classifier.py | 2 +- resdsql/text2sql.py | 2 +- src/config.py | 2 +- src/data_cache/json_cache.py | 2 +- src/{models => data_models}/__init__.py | 2 +- src/{models => data_models}/base_object.py | 0 src/{models => data_models}/data_point.py | 0 src/{models => data_models}/masksql_input.py | 2 +- src/{models => data_models}/masksql_output.py | 4 +- src/{models => data_models}/masksql_result.py | 2 +- src/masksql.py | 55 ++-- src/pipe/abl_prompts/__init__.py | 1 - src/pipe/abl_prompts/schema_value_link.py | 120 -------- src/pipe/add_masked_terms.py | 38 --- src/pipe/attack_prompts/add_masked_terms.py | 32 --- src/pipe/attack_prompts/attack_raw_v1.py | 42 --- src/pipe/deprecated/__init__.py | 5 - src/pipe/deprecated/add_fks.py | 39 --- src/pipe/deprecated/add_full_schema.py | 132 --------- src/pipe/deprecated/add_masked_terms_det.py | 226 --------------- .../add_value_links_from_schema_links.py | 30 -- src/pipe/deprecated/add_value_symbol_table.py | 39 --- src/pipe/deprecated/filtered_symb_schema.py | 265 ------------------ src/pipe/deprecated/utility.py | 164 ----------- src/pipe/deprecated/value_link_eval.py | 46 --- src/pipe/estimate_sql.py | 85 ------ src/pipe/filer_schema_items.py | 26 -- src/pipe/filer_schema_links.py | 32 --- src/pipe/filer_value_links.py | 27 -- src/pipe/gen_gold_mask.py | 18 -- src/pipe/gen_gold_schema.py | 22 -- src/pipe/gen_masked_sql_raw.py | 22 -- src/pipe/gen_sql.py | 70 ----- src/pipe/gen_sql_prompts.py | 30 -- src/pipe/gold_mask/__init__.py | 1 - src/pipe/gold_mask/gold_mask_v1.py | 37 --- src/pipe/gold_schema_link/__init__.py | 1 - src/pipe/gold_schema_link/repair.py | 74 ----- 
src/pipe/gold_schema_link/v1.py | 72 ----- src/pipe/link_schema_and_value.py | 51 ---- src/pipe/monitor/__init__.py | 1 - src/pipe/monitor/lib.py | 134 --------- src/pipe/processor/print_results.py | 116 -------- src/pipe/repair_link_schema.py | 62 ---- .../filter_annotated_links.py | 41 --- .../schema_items_filter_prompts/__init__.py | 1 - src/pipe/schema_items_filter_prompts/v1.py | 34 --- src/pipe/schema_link_prompts/repair.py | 54 ---- src/pipe/slm_mask.py | 39 --- src/pipe/slm_mask_for_det_unmask.py | 27 -- src/pipe/slm_sql.py | 49 ---- src/pipe/slm_sql_prompt/v1.py | 70 ----- src/pipe/slm_unmask_repair.py | 30 -- src/pipe/utils.py | 151 ---------- src/{pipe => pipeline}/__init__.py | 0 src/{pipe => pipeline}/add_schema.py | 6 +- src/{pipe => pipeline}/add_symb_schema.py | 8 +- .../add_symbolic_question.py} | 6 +- .../attack}/__init__.py | 0 .../attack/add_inference_attack.py} | 24 +- src/pipeline/attack/prompts/__init__.py | 4 + .../attack/prompts}/attack_v1.py | 2 +- .../attack/prompts}/attack_v2.py | 2 +- .../base_processor}/__init__.py | 0 .../base_processor}/limit_list.py | 6 +- .../base_processor}/list_processor.py | 8 +- .../base_processor}/list_transformer.py | 4 +- src/pipeline/base_processor/print_results.py | 22 ++ .../base_processor}/printer.py | 4 +- .../base_processor}/prompt_processor.py | 18 +- .../base_processor}/prop_printer.py | 2 +- .../detect_values}/__init__.py | 0 .../detect_values/detect_values.py} | 10 +- .../detect_values/prompts/__init__.py | 5 + .../detect_values/prompts}/v1.py | 2 +- .../detect_values/prompts}/v2.py | 2 +- .../detect_values/prompts}/v3.py | 2 +- src/{pipe => pipeline}/exec_acc.py | 6 +- src/{pipe => pipeline}/exec_conc_sql.py | 6 +- .../filter_schema_links}/__init__.py | 0 .../filter_schema_links.py | 36 +++ .../filter_schema_links/prompts/__init__.py | 5 + .../filter_schema_links/prompts}/v1.py | 2 +- .../filter_schema_links/prompts}/v2.py | 2 +- .../filter_value_links}/__init__.py | 0 
.../filter_value_links/filter_value_links.py | 34 +++ .../filter_value_links/prompts/__init__.py | 5 + .../filter_value_links/prompts}/v1.py | 2 +- .../filter_value_links/prompts}/v2.py | 2 +- .../gen_sql}/__init__.py | 0 .../gen_sql}/gen_masked_sql.py | 8 +- src/pipeline/gen_sql/prompts/__init__.py | 5 + .../gen_sql/prompts}/masked_v1.py | 2 +- .../gen_sql/prompts}/masked_v2.py | 2 +- .../gen_sql/prompts}/masked_v3.py | 2 +- .../gen_sql/prompts}/masked_v3_raw.py | 2 +- .../gen_sql/prompts}/masked_v4.py | 2 +- .../gen_sql/prompts}/unmasked_v1.py | 2 +- .../init_data.py} | 4 +- .../link_schema}/__init__.py | 0 .../link_schema}/link_schema.py | 8 +- src/pipeline/link_schema/prompts/__init__.py | 5 + .../link_schema/prompts}/v1.py | 2 +- .../link_schema/prompts}/v2.py | 2 +- .../link_schema/prompts}/v3.py | 2 +- .../link_schema/prompts}/v4.py | 2 +- .../link_schema/prompts}/v5.py | 2 +- .../link_values}/__init__.py | 0 .../link_values/link_values.py} | 14 +- src/pipeline/link_values/prompts/__init__.py | 5 + .../link_values/prompts}/v1.py | 2 +- src/{pipe => pipeline}/pipeline.py | 8 +- src/{pipe => pipeline}/rank_schema.py | 6 +- src/{pipe => pipeline}/rank_schema_llm.py | 12 +- .../rank_schema_prompts/__init__.py | 0 .../rank_schema_prompts/v1.py | 2 +- .../repair_sql}/__init__.py | 0 src/pipeline/repair_sql/prompts/__init__.py | 5 + .../repair_sql/prompts}/v1.py | 2 +- .../repair_sql/prompts}/v2.py | 2 +- .../repair_sql/prompts}/v3.py | 2 +- .../repair_sql/prompts}/v4.py | 2 +- .../repair_sql/prompts}/v5.py | 2 +- .../repair_sql}/repair_sql.py | 8 +- .../repair_symb_sql}/__init__.py | 0 .../repair_symb_sql/prompts/__init__.py | 5 + .../repair_symb_sql/prompts}/v1.py | 2 +- .../repair_symb_sql/prompts}/v2.py | 2 +- .../repair_symb_sql}/raw_v2.py | 2 +- .../repair_symb_sql}/repair_symb_sql.py | 24 +- src/pipeline/resd/__init__.py | 5 + .../resdsql.py => pipeline/resd/add_resd.py} | 8 +- src/{pipe => pipeline/resd}/run_resdsql.py | 6 +- src/{pipe => 
pipeline}/results.py | 4 +- .../slm_mask_prompts/__init__.py | 0 .../mask_and_schema_link_v1.py | 2 +- .../mask_and_schema_link_v2.py | 2 +- .../slm_mask_prompts/mask_v1.py | 2 +- .../slm_mask_prompts/unmask_and_repair_v1.py | 2 +- .../slm_mask_prompts/unmask_v1.py | 2 +- src/{pipe => pipeline}/symb_table.py | 6 +- src/{pipe => pipeline}/unmask.py | 6 +- src/pipeline/util_processors/__init__.py | 5 + .../util_processors}/copy_transformer.py | 2 +- src/{pipe => utils}/async_utils.py | 0 src/{pipe => utils}/llm_util.py | 14 +- src/{pipe/monitor => utils}/mem.py | 0 src/{pipe => utils}/schema_repo.py | 0 src/{pipe => utils}/sqlite_facade.py | 0 src/utils/strings.py | 153 ++++++++++ src/utils/timer.py | 45 +++ test.json | 6 + tests/e2e/test_pipeline.py | 12 +- tests/e2e/test_processor.py | 14 +- uv.lock | 8 +- 160 files changed, 564 insertions(+), 2871 deletions(-) delete mode 100644 EA_failure_analysis.py rename src/{models => data_models}/__init__.py (62%) rename src/{models => data_models}/base_object.py (100%) rename src/{models => data_models}/data_point.py (100%) rename src/{models => data_models}/masksql_input.py (93%) rename src/{models => data_models}/masksql_output.py (86%) rename src/{models => data_models}/masksql_result.py (92%) delete mode 100644 src/pipe/abl_prompts/__init__.py delete mode 100644 src/pipe/abl_prompts/schema_value_link.py delete mode 100644 src/pipe/add_masked_terms.py delete mode 100644 src/pipe/attack_prompts/add_masked_terms.py delete mode 100644 src/pipe/attack_prompts/attack_raw_v1.py delete mode 100644 src/pipe/deprecated/__init__.py delete mode 100644 src/pipe/deprecated/add_fks.py delete mode 100644 src/pipe/deprecated/add_full_schema.py delete mode 100644 src/pipe/deprecated/add_masked_terms_det.py delete mode 100644 src/pipe/deprecated/add_value_links_from_schema_links.py delete mode 100644 src/pipe/deprecated/add_value_symbol_table.py delete mode 100644 src/pipe/deprecated/filtered_symb_schema.py delete mode 100644 
src/pipe/deprecated/utility.py delete mode 100644 src/pipe/deprecated/value_link_eval.py delete mode 100644 src/pipe/estimate_sql.py delete mode 100644 src/pipe/filer_schema_items.py delete mode 100644 src/pipe/filer_schema_links.py delete mode 100644 src/pipe/filer_value_links.py delete mode 100644 src/pipe/gen_gold_mask.py delete mode 100644 src/pipe/gen_gold_schema.py delete mode 100644 src/pipe/gen_masked_sql_raw.py delete mode 100644 src/pipe/gen_sql.py delete mode 100644 src/pipe/gen_sql_prompts.py delete mode 100644 src/pipe/gold_mask/__init__.py delete mode 100644 src/pipe/gold_mask/gold_mask_v1.py delete mode 100644 src/pipe/gold_schema_link/__init__.py delete mode 100644 src/pipe/gold_schema_link/repair.py delete mode 100644 src/pipe/gold_schema_link/v1.py delete mode 100644 src/pipe/link_schema_and_value.py delete mode 100644 src/pipe/monitor/__init__.py delete mode 100644 src/pipe/monitor/lib.py delete mode 100644 src/pipe/processor/print_results.py delete mode 100644 src/pipe/repair_link_schema.py delete mode 100644 src/pipe/schema_filter_prompts/filter_annotated_links.py delete mode 100644 src/pipe/schema_items_filter_prompts/__init__.py delete mode 100644 src/pipe/schema_items_filter_prompts/v1.py delete mode 100644 src/pipe/schema_link_prompts/repair.py delete mode 100644 src/pipe/slm_mask.py delete mode 100644 src/pipe/slm_mask_for_det_unmask.py delete mode 100644 src/pipe/slm_sql.py delete mode 100644 src/pipe/slm_sql_prompt/v1.py delete mode 100644 src/pipe/slm_unmask_repair.py delete mode 100644 src/pipe/utils.py rename src/{pipe => pipeline}/__init__.py (100%) rename src/{pipe => pipeline}/add_schema.py (93%) rename src/{pipe => pipeline}/add_symb_schema.py (96%) rename src/{pipe/det_mask.py => pipeline/add_symbolic_question.py} (97%) rename src/{pipe/attack_prompts => pipeline/attack}/__init__.py (100%) rename src/{pipe/attack.py => pipeline/attack/add_inference_attack.py} (65%) create mode 100644 src/pipeline/attack/prompts/__init__.py rename 
src/{pipe/attack_prompts => pipeline/attack/prompts}/attack_v1.py (97%) rename src/{pipe/attack_prompts => pipeline/attack/prompts}/attack_v2.py (97%) rename src/{pipe/processor => pipeline/base_processor}/__init__.py (100%) rename src/{pipe/processor => pipeline/base_processor}/limit_list.py (85%) rename src/{pipe/processor => pipeline/base_processor}/list_processor.py (94%) rename src/{pipe/processor => pipeline/base_processor}/list_transformer.py (79%) create mode 100644 src/pipeline/base_processor/print_results.py rename src/{pipe/processor => pipeline/base_processor}/printer.py (81%) rename src/{pipe/detect_values_prompts => pipeline/base_processor}/prompt_processor.py (82%) rename src/{pipe/processor => pipeline/base_processor}/prop_printer.py (93%) rename src/{pipe/detect_values_prompts => pipeline/detect_values}/__init__.py (100%) rename src/{pipe/detect_entities.py => pipeline/detect_values/detect_values.py} (78%) create mode 100644 src/pipeline/detect_values/prompts/__init__.py rename src/{pipe/detect_values_prompts => pipeline/detect_values/prompts}/v1.py (95%) rename src/{pipe/detect_values_prompts => pipeline/detect_values/prompts}/v2.py (97%) rename src/{pipe/detect_values_prompts => pipeline/detect_values/prompts}/v3.py (96%) rename src/{pipe => pipeline}/exec_acc.py (92%) rename src/{pipe => pipeline}/exec_conc_sql.py (95%) rename src/{pipe/schema_filter_prompts => pipeline/filter_schema_links}/__init__.py (100%) create mode 100644 src/pipeline/filter_schema_links/filter_schema_links.py create mode 100644 src/pipeline/filter_schema_links/prompts/__init__.py rename src/{pipe/schema_filter_prompts => pipeline/filter_schema_links/prompts}/v1.py (94%) rename src/{pipe/schema_filter_prompts => pipeline/filter_schema_links/prompts}/v2.py (96%) rename src/{pipe/value_filter_prompts => pipeline/filter_value_links}/__init__.py (100%) create mode 100644 src/pipeline/filter_value_links/filter_value_links.py create mode 100644 
src/pipeline/filter_value_links/prompts/__init__.py rename src/{pipe/value_filter_prompts => pipeline/filter_value_links/prompts}/v1.py (97%) rename src/{pipe/value_filter_prompts => pipeline/filter_value_links/prompts}/v2.py (96%) rename src/{pipe/sql_gen_prompts => pipeline/gen_sql}/__init__.py (100%) rename src/{pipe => pipeline/gen_sql}/gen_masked_sql.py (84%) create mode 100644 src/pipeline/gen_sql/prompts/__init__.py rename src/{pipe/sql_gen_prompts => pipeline/gen_sql/prompts}/masked_v1.py (98%) rename src/{pipe/sql_gen_prompts => pipeline/gen_sql/prompts}/masked_v2.py (97%) rename src/{pipe/sql_gen_prompts => pipeline/gen_sql/prompts}/masked_v3.py (97%) rename src/{pipe/sql_gen_prompts => pipeline/gen_sql/prompts}/masked_v3_raw.py (97%) rename src/{pipe/sql_gen_prompts => pipeline/gen_sql/prompts}/masked_v4.py (98%) rename src/{pipe/sql_gen_prompts => pipeline/gen_sql/prompts}/unmasked_v1.py (96%) rename src/{pipe/util_processors.py => pipeline/init_data.py} (85%) rename src/{pipe/schema_link_prompts => pipeline/link_schema}/__init__.py (100%) rename src/{pipe => pipeline/link_schema}/link_schema.py (90%) create mode 100644 src/pipeline/link_schema/prompts/__init__.py rename src/{pipe/schema_link_prompts => pipeline/link_schema/prompts}/v1.py (94%) rename src/{pipe/schema_link_prompts => pipeline/link_schema/prompts}/v2.py (92%) rename src/{pipe/schema_link_prompts => pipeline/link_schema/prompts}/v3.py (95%) rename src/{pipe/schema_link_prompts => pipeline/link_schema/prompts}/v4.py (98%) rename src/{pipe/schema_link_prompts => pipeline/link_schema/prompts}/v5.py (97%) rename src/{pipe/value_linking_prompts => pipeline/link_values}/__init__.py (100%) rename src/{pipe/value_links.py => pipeline/link_values/link_values.py} (76%) create mode 100644 src/pipeline/link_values/prompts/__init__.py rename src/{pipe/value_linking_prompts => pipeline/link_values/prompts}/v1.py (98%) rename src/{pipe => pipeline}/pipeline.py (92%) rename src/{pipe => 
pipeline}/rank_schema.py (90%) rename src/{pipe => pipeline}/rank_schema_llm.py (85%) rename src/{pipe => pipeline}/rank_schema_prompts/__init__.py (100%) rename src/{pipe => pipeline}/rank_schema_prompts/v1.py (95%) rename src/{pipe/sql_repair_prompts => pipeline/repair_sql}/__init__.py (100%) create mode 100644 src/pipeline/repair_sql/prompts/__init__.py rename src/{pipe/sql_repair_prompts => pipeline/repair_sql/prompts}/v1.py (90%) rename src/{pipe/sql_repair_prompts => pipeline/repair_sql/prompts}/v2.py (95%) rename src/{pipe/sql_repair_prompts => pipeline/repair_sql/prompts}/v3.py (99%) rename src/{pipe/sql_repair_prompts => pipeline/repair_sql/prompts}/v4.py (99%) rename src/{pipe/sql_repair_prompts => pipeline/repair_sql/prompts}/v5.py (98%) rename src/{pipe => pipeline/repair_sql}/repair_sql.py (87%) rename src/{pipe/symb_sql_repair_prompts => pipeline/repair_symb_sql}/__init__.py (100%) create mode 100644 src/pipeline/repair_symb_sql/prompts/__init__.py rename src/{pipe/symb_sql_repair_prompts => pipeline/repair_symb_sql/prompts}/v1.py (98%) rename src/{pipe/symb_sql_repair_prompts => pipeline/repair_symb_sql/prompts}/v2.py (98%) rename src/{pipe/symb_sql_repair_prompts => pipeline/repair_symb_sql}/raw_v2.py (98%) rename src/{pipe => pipeline/repair_symb_sql}/repair_symb_sql.py (63%) create mode 100644 src/pipeline/resd/__init__.py rename src/{pipe/resdsql.py => pipeline/resd/add_resd.py} (83%) rename src/{pipe => pipeline/resd}/run_resdsql.py (97%) rename src/{pipe => pipeline}/results.py (95%) rename src/{pipe => pipeline}/slm_mask_prompts/__init__.py (100%) rename src/{pipe => pipeline}/slm_mask_prompts/mask_and_schema_link_v1.py (96%) rename src/{pipe => pipeline}/slm_mask_prompts/mask_and_schema_link_v2.py (97%) rename src/{pipe => pipeline}/slm_mask_prompts/mask_v1.py (96%) rename src/{pipe => pipeline}/slm_mask_prompts/unmask_and_repair_v1.py (98%) rename src/{pipe => pipeline}/slm_mask_prompts/unmask_v1.py (97%) rename src/{pipe => 
pipeline}/symb_table.py (93%) rename src/{pipe => pipeline}/unmask.py (92%) create mode 100644 src/pipeline/util_processors/__init__.py rename src/{pipe => pipeline/util_processors}/copy_transformer.py (89%) rename src/{pipe => utils}/async_utils.py (100%) rename src/{pipe => utils}/llm_util.py (91%) rename src/{pipe/monitor => utils}/mem.py (100%) rename src/{pipe => utils}/schema_repo.py (100%) rename src/{pipe => utils}/sqlite_facade.py (100%) create mode 100644 src/utils/timer.py create mode 100644 test.json diff --git a/DESIGN.md b/DESIGN.md index b6c0093..e1aebd3 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -240,10 +240,10 @@ class TextToSQLModelConfig: """Text-to-SQL model configuration. Supports multiple model types: - - "openai": OpenAI models (GPT-4, GPT-3.5, etc.) - - "anthropic": Anthropic models (Claude, etc.) - - "custom": Custom models via REST API (proprietary models, custom endpoints) - - "local": Local models or self-hosted endpoints + - "openai": OpenAI models (GPT-4, GPT-3.5, etc.) + - "anthropic": Anthropic models (Claude, etc.) + - "custom": Custom models via REST API (proprietary models, custom endpoints) + - "local": Local models or self-hosted endpoints - "none": Disable built-in SQL generation (use custom steps instead) When model_type="none", you must provide custom SQL generation steps. @@ -263,7 +263,7 @@ class TextToSQLModelConfig: custom_endpoint: Optional[str] = None custom_headers: Optional[dict] = None - # Request/response format for custom models + # Request/response format for custom models custom_request_format: Optional[str] = None # "openai", "anthropic", "custom" custom_response_parser: Optional[callable] = None # Custom parser function @@ -306,7 +306,7 @@ class MaskSQLConfig: """Main configuration for MaskSQL.
Configuration Philosophy: - - Use config_file for standard settings (models, databases, privacy policy) + - Use config_file for standard settings (models, databases, privacy policy) - Use the `steps` parameter in MaskSQL() for pipeline customization - Custom steps provide more flexibility than PipelineConfig flags @@ -481,7 +481,7 @@ class MaskSQL: Use cases: - Domain-specific preprocessing (e.g., medical terminology normalization) - - Custom Text-to-SQL models + - Custom Text-to-SQL models - Compliance validation - Custom logging or monitoring hooks - Additional SQL repair or optimization steps @@ -506,7 +506,7 @@ class MaskSQL: ... }) ... ] >>> config = MaskSQLConfig.from_yaml("config.yaml") - >>> config.models.text_to_sql.model_type = "none" # Disable built-in generation + >>> config.models.text_to_sql.model_type = "none" # Disable built-in generation >>> masksql = MaskSQL(config=config, steps=custom_steps) """ diff --git a/EA_failure_analysis.py b/EA_failure_analysis.py deleted file mode 100644 index b57b1bb..0000000 --- a/EA_failure_analysis.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Execution Accuracy failure analysis utilities. - -This module provides tools for analyzing failed test cases by comparing -JSON result files and extracting SQL query details for further investigation. -""" - -from typing import Any - -from src.utils.json_io import read_json_raw, write_json_raw - - -def finder(path1: str, path2: str) -> list[Any]: - """Find differences between two JSON files and write results. - - Parameters - ---------- - path1 : str - Path to the first JSON file (full dataset). - path2 : str - Path to the second JSON file (category dataset). - - Returns - ------- - list - Items present in path1 but not in path2.
- """ - full: list[dict] = read_json_raw(path1) - category: list[dict] = read_json_raw(path2) - diff = [] - for items in full: - if items not in category: - diff.append(items) - write_json_raw("data/EA_diff", diff) - - for items in category: - if items not in full: - print(items) - return diff - - -def analyser(arr: list[Any]) -> None: - """Analyze failure cases and extract SQL details. - - Parameters - ---------- - arr : list - List of question IDs to analyze. - """ - path = "data/full/19_RepairSQL.json" - file = read_json_raw(path) - res = [] - for items in arr: - for records in file: - if records["question_id"] == items: - res.append( - { - "id": records["question_id"], - "question": records["question"], - "gold": records["SQL"], - "pred": records["pred_sql"], - } - ) - - write_json_raw("data/EA_sql_diff", res) - - -def main() -> None: - """Run the EA failure analysis workflow.""" - path1 = "data/full/EA_failures.json" - path2 = "data/category/EA_failures.json" - - res = finder(path1, path2) - analyser(res) - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index e0c2400..480d940 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,8 +8,8 @@ license = "Apache-2.0" repository = "https://github.com/VectorInstitute/masksql" requires-python = ">=3.11" dependencies = [ - "filelock>=3.20.1", # Pinning version to address vulnerability CVE-2025-68146 - "fonttools>=4.60.2", # Pinning version to address vulnerability GHSA-768j-98cg-p3fv + "filelock>=3.20.1", # Pinning version to address vulnerability CVE-2025-68146 + "fonttools>=4.60.2", # Pinning version to address vulnerability GHSA-768j-98cg-p3fv "werkzeug>=3.1.4", # Pinning version to address vulnerability GHSA-hgf8-39gv-g3f2 "absl-py>=2.3.1", "python-dotenv>=1.0.0", @@ -45,6 +45,7 @@ dependencies = [ "transformers>=4.30.0", "vcrpy>=7.0.0", "ply>=3.11", + "urllib3==2.6.3", ] [build-system] diff --git a/resdsql/evaluate_text2sql_ckpts.py b/resdsql/evaluate_text2sql_ckpts.py index 
d3855a7..e512642 100644 --- a/resdsql/evaluate_text2sql_ckpts.py +++ b/resdsql/evaluate_text2sql_ckpts.py @@ -18,14 +18,14 @@ def parse_option(): parser.add_argument( "--save_path", type=str, - default="./models/text2sql", - help="save path of fine-tuned text2sql models.", + default="./models/text2sql", + help="save path of fine-tuned text2sql models.", ) parser.add_argument( "--eval_results_path", type=str, default="./eval_results/text2sql", - help="the evaluation results of fine-tuned text2sql models.", + help="the evaluation results of fine-tuned text2sql models.", ) parser.add_argument("--mode", type=str, default="eval", help="eval.") parser.add_argument( diff --git a/resdsql/get_schema_ranks.py b/resdsql/get_schema_ranks.py index d375f42..edf2baf 100644 --- a/resdsql/get_schema_ranks.py +++ b/resdsql/get_schema_ranks.py @@ -20,7 +20,7 @@ "epochs": 128, "patience": 32, "seed": 42, - "save_path": "./models/text2sql_schema_item_classifier", + "save_path": "./models/text2sql_schema_item_classifier", "tensorboard_save_path": None, "train_filepath": "data/pre-processing/preprocessed_train_spider.json", "dev_filepath": "./data/preprocessed_data/preprocessed_test.json", diff --git a/resdsql/schema_item_classifier.py b/resdsql/schema_item_classifier.py index eda936f..4f4ddcd 100644 --- a/resdsql/schema_item_classifier.py +++ b/resdsql/schema_item_classifier.py @@ -62,7 +62,7 @@ def parse_option(): parser.add_argument( "--save_path", type=str, - default="models/schema_item_classifier", + default="models/schema_item_classifier", help="save path of best fine-tuned model on validation set.", ) parser.add_argument( diff --git a/resdsql/text2sql.py b/resdsql/text2sql.py index b6ee07d..30908e4 100644 --- a/resdsql/text2sql.py +++ b/resdsql/text2sql.py @@ -48,7 +48,7 @@ def parse_option(): parser.add_argument( "--save_path", type=str, - default="models/text2sql", + default="models/text2sql", help="save path of best fine-tuned text2sql model.", )
parser.add_argument( diff --git a/src/config.py b/src/config.py index e8674a1..d318c67 100644 --- a/src/config.py +++ b/src/config.py @@ -1,7 +1,7 @@ """Configuration management for MaskSQL. This module defines the configuration dataclass used throughout the MaskSQL -project for managing paths, models, and runtime settings. +project for managing paths, models, and runtime settings. """ import os diff --git a/src/data_cache/json_cache.py b/src/data_cache/json_cache.py index ef62fbd..26ec0e1 100644 --- a/src/data_cache/json_cache.py +++ b/src/data_cache/json_cache.py @@ -7,7 +7,7 @@ import os from typing import Generic, Type, TypeVar -from src.models.base_object import BaseObject +from src.data_models.base_object import BaseObject from src.utils.json_io import read_json, write_json diff --git a/src/models/__init__.py b/src/data_models/__init__.py similarity index 62% rename from src/models/__init__.py rename to src/data_models/__init__.py index c923b47..d440f2b 100644 --- a/src/models/__init__.py +++ b/src/data_models/__init__.py @@ -1,5 +1,5 @@ """Models package for data structures used in the MaskSQL project. -This package contains data models and structures that represent various +This package contains data models and structures that represent various entities and concepts used throughout the MaskSQL system.
""" diff --git a/src/models/base_object.py b/src/data_models/base_object.py similarity index 100% rename from src/models/base_object.py rename to src/data_models/base_object.py diff --git a/src/models/data_point.py b/src/data_models/data_point.py similarity index 100% rename from src/models/data_point.py rename to src/data_models/data_point.py diff --git a/src/models/masksql_input.py b/src/data_models/masksql_input.py similarity index 93% rename from src/models/masksql_input.py rename to src/data_models/masksql_input.py index b110ef2..e59b49a 100644 --- a/src/models/masksql_input.py +++ b/src/data_models/masksql_input.py @@ -6,7 +6,7 @@ from typing import Any -from src.models.base_object import BaseObject +from src.data_models.base_object import BaseObject class MaskSqlInput(BaseObject): diff --git a/src/models/masksql_output.py b/src/data_models/masksql_output.py similarity index 86% rename from src/models/masksql_output.py rename to src/data_models/masksql_output.py index b7636cc..6d5b68a 100644 --- a/src/models/masksql_output.py +++ b/src/data_models/masksql_output.py @@ -4,8 +4,8 @@ representing the result of processing a natural language question. """ -from src.models.base_object import BaseObject -from src.pipe.exec_acc import EvaluationData +from src.data_models.base_object import BaseObject +from src.pipeline.exec_acc import EvaluationData class MaskSqlOutput(BaseObject): diff --git a/src/models/masksql_result.py b/src/data_models/masksql_result.py similarity index 92% rename from src/models/masksql_result.py rename to src/data_models/masksql_result.py index a7a0d21..295a3f6 100644 --- a/src/models/masksql_result.py +++ b/src/data_models/masksql_result.py @@ -4,7 +4,7 @@ representing the outcome of processing a database question. 
""" -from src.models.data_point import DataPoint +from src.data_models.data_point import DataPoint class MaskSqlResult(DataPoint): diff --git a/src/masksql.py b/src/masksql.py index dad9678..6821b72 100644 --- a/src/masksql.py +++ b/src/masksql.py @@ -8,32 +8,35 @@ import uuid from src.config import MaskSqlConfig, OpenAIConfig -from src.models.masksql_input import MaskSqlInput -from src.models.masksql_output import MaskSqlOutput -from src.pipe.add_schema import AddFilteredSchema -from src.pipe.add_symb_schema import AddSymbolicSchema -from src.pipe.attack import AddInferenceAttack -from src.pipe.copy_transformer import CopyTransformer -from src.pipe.det_mask import AddSymbolicQuestion -from src.pipe.detect_entities import DetectValues -from src.pipe.exec_acc import CalcExecAcc -from src.pipe.exec_conc_sql import ExecuteConcreteSql -from src.pipe.gen_masked_sql import GenerateSymbolicSql -from src.pipe.link_schema import FilterSchemaLinksModel, LinkSchema -from src.pipe.pipeline import Pipeline -from src.pipe.processor.limit_list import LimitJson -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.rank_schema import RankSchemaResd -from src.pipe.rank_schema_llm import RankSchemaItems -from src.pipe.repair_sql import RepairSQL -from src.pipe.repair_symb_sql import RepairSymbolicSQL -from src.pipe.resdsql import AddResd -from src.pipe.results import Results -from src.pipe.run_resdsql import RunResdsql -from src.pipe.symb_table import AddSymbolTable -from src.pipe.unmask import AddConcreteSql -from src.pipe.util_processors import InitData -from src.pipe.value_links import FilterValueLinksModel, LinkValues +from src.data_models.masksql_input import MaskSqlInput +from src.data_models.masksql_output import MaskSqlOutput +from src.pipeline.add_schema import AddFilteredSchema +from src.pipeline.add_symb_schema import AddSymbolicSchema +from src.pipeline.add_symbolic_question import AddSymbolicQuestion +from 
src.pipeline.attack.add_inference_attack import AddInferenceAttack +from src.pipeline.base_processor.limit_list import LimitJson +from src.pipeline.base_processor.list_processor import JsonListProcessor +from src.pipeline.detect_values.detect_values import DetectValues +from src.pipeline.exec_acc import CalcExecAcc +from src.pipeline.exec_conc_sql import ExecuteConcreteSql +from src.pipeline.gen_sql.gen_masked_sql import GenerateSymbolicSql +from src.pipeline.init_data import InitData +from src.pipeline.link_schema.link_schema import ( + FilterSchemaLinksModel, + LinkSchema, +) +from src.pipeline.link_values.link_values import FilterValueLinksModel, LinkValues +from src.pipeline.pipeline import Pipeline +from src.pipeline.rank_schema import RankSchemaResd +from src.pipeline.rank_schema_llm import RankSchemaItems +from src.pipeline.repair_sql.repair_sql import RepairSQL +from src.pipeline.repair_symb_sql.repair_symb_sql import RepairSymbolicSQL +from src.pipeline.resd.add_resd import AddResd +from src.pipeline.resd.run_resdsql import RunResdsql +from src.pipeline.results import Results +from src.pipeline.symb_table import AddSymbolTable +from src.pipeline.unmask import AddConcreteSql +from src.pipeline.util_processors.copy_transformer import CopyTransformer from src.utils.json_io import read_json, write_json diff --git a/src/pipe/abl_prompts/__init__.py b/src/pipe/abl_prompts/__init__.py deleted file mode 100644 index 6a60ece..0000000 --- a/src/pipe/abl_prompts/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Ablation study prompts.""" diff --git a/src/pipe/abl_prompts/schema_value_link.py b/src/pipe/abl_prompts/schema_value_link.py deleted file mode 100644 index ccbbb45..0000000 --- a/src/pipe/abl_prompts/schema_value_link.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Prompt templates for schema and value linking ablation studies.""" - -SCHEMA_VALUE_LINK_PROMPT_V1 = """ -You are an assistant that links n-grams (sub-sequences of up to 3 consecutive words) -of a natural-language 
question to database schema items (tables or fully qualified columns). - -You are given: -- Question: a natural language question. -- Schema Items: a list of schema items (table names or fully qualified column names). -- Value List: a list of n-grams in the question that represent literal values, entities, constants, etc in the question. - -Input Format: -- Each schema item is either "TABLE:[table]", "COLUMN:[table].[column]" or "VALUE:[table].[column]" -- Not all listed schema items are relevant. Your goal is to identify the relevant ones. - -Goal -Return a JSON object mapping relevant n-grams (contiguous word sequences of length 1–3 taken from the question text) -to the single most relevant schema item. - -Mapping Rules: -- Based on the given list of values, you should find the most relevant column that is related to the value and add a mapping -of the form {{"literal value in question":"VALUE:[table]:[column]}} -- Consider all 1-, 2-, and 3-word spans. -- Include a mapping only if the n-gram refers to a schema item. -- Prefer the most specific applicable item: column beats table when the question refers to an attribute. -- If nothing maps, return an empty JSON object. -- Chose the shortest n-gram that maps to the schema item. -- If removing a word from an n-gram still points to the same schema item, use the shorter version. - -Output Rules: -- Output only a JSON object representing the mapping. -- The value of each entry should only be selected from the given list of Schema Items. -- Generating new values or using different values is not allowed. -- You can only select values form the given list of schema items. -- Each entry should be a key-value pair where the key is an n-gram and the value is a schema item. -- Value of each entry can only be a single string of the form "COLUMN:[table].[column]" or "TABLE:[table]". -- Do not include any additional text, explanations, or formatting. -- All json key and values should be in double quotes. 
-- Output should be a top-level JSON object. No nested keys. - -Here are some examples: - ---------------------------------------------- -Example 1: -Question: -“What is the name of the instructor who has the lowest salary and located in London?” -Schema items: -["TABLE:[instructor]", "COLUMN:[instructor].[name]", "COLUMN:[instructor].[city], "COLUMN:[instructor].[salary]", "TABLE:[department]", "COLUMN[department].[name]"] -Value List: -[ "London" ] - -Output: -{{ - "name": "COLUMN:[instructor].[name]", - "salary": "COLUMN:[instructor].[salary]", - "instructor": "TABLE:[instructor]", - "London": "VALUE:[instructor].[city]" -}} - ---------------------------------------------- -Example 2: -Question: -"Please calculate the total payment amount of customers who come from the USA. USA is a country; total amount payment refers to SUM(amount);", -Schema items: [ - "TABLE:[customers]", - "COLUMN:[customers].[customernumber]", - "COLUMN:[customers].[country]", - "TABLE:[payments]", - "COLUMN:[payments].[customernumber]", - "COLUMN:[payments].[amount]" -] -Value List: -[ "USA" ] - -Output: -{{ - "customers": "TABLE:[customers]", - "country": "COLUMN:[customers].[country]", - "amount": "COLUMN:[payments].[amount]", - "USA": "VALUE:[customers].[country]" -}} - ---------------------------------------------- -Example 3: -Question: -"What are the total payments of customers with no credit limit in 2003? 
total payment refers to SUM(amount)", -Schema Items: [ - "TABLE:[customers]", - "COLUMN:[customers].[customernumber]", - "COLUMN:[customers].[creditlimit]", - "TABLE:[payments]", - "COLUMN:[payments].[customernumber]", - "COLUMN:[payments].[paymentdate]", - "COLUMN:[payments].[amount]", - "COLUMN:[payments].[year]" -] -Value List: -[ "2003" ] - -Output: -{{ - "payments": "TABLE:[payments]", - "customers": "TABLE:[customers]", - "credit limit": "COLUMN:[customers].[creditlimit]", - "amount": "COLUMN:[payments].[amount]", - "2003": "VALUE:[payments].[year]" -}} - - - -Now generate the JSON object of mapping for the following question, schema items, and value list: -Question: {question} -Schema items: {schema_items} -Value List: {value_List} - -Iterate through each key,value pair of the answer and make sure that: -- MAKE SURE THAT EACH KEY OF THE MAPPING SHOULD BE A TERM OF THE QUESTION -- MAKE SURE THAT EACH VALUE OF THE MAPPING SHOULD BE A VALID SCHEMA ITEM INCLUDED IN THE GIVEN LIST OF SCHEMA ITEMS -- MAKE SURE THAT EACH KEY BE MINIMAL, IF ANY WORD CAN BE DELETED WHILE THE RELATION STILL HOLDS, IT SHOULD BE REMOVED -""" diff --git a/src/pipe/add_masked_terms.py b/src/pipe/add_masked_terms.py deleted file mode 100644 index c58a7ba..0000000 --- a/src/pipe/add_masked_terms.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Module for identifying and adding masked terms to questions.""" - -from typing import Any - -from src.pipe.attack_prompts.add_masked_terms import ADD_MASKED_TERMS_PROMPT_V1 -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.llm_util import extract_object -from src.utils.logging import logger - - -class AddMaskedTerms(PromptProcessor): - """ - Processor for identifying and extracting masked terms from questions. - - This class uses LLM prompts to identify terms in natural language questions - that should be masked, comparing the original and symbolic representations. 
- """ - - def _process_output(self, row: dict[str, Any], output: str) -> list[str]: - output_obj = extract_object(output) - if output_obj is None: - return [] - masked_terms = list(output_obj.keys()) - q = row["question"] - filtered_terms = [] - for m in masked_terms: - if m.lower() in q.lower(): - filtered_terms.append(m) - else: - logger.error(f"{m} not in question: {q}") - return masked_terms - - def _get_prompt(self, row: dict[str, Any]) -> str: - question = row["question"] - symbolic_question = row["symbolic"]["question"] - return ADD_MASKED_TERMS_PROMPT_V1.format( - question=question, masked_question=symbolic_question - ) diff --git a/src/pipe/attack_prompts/add_masked_terms.py b/src/pipe/attack_prompts/add_masked_terms.py deleted file mode 100644 index 886beba..0000000 --- a/src/pipe/attack_prompts/add_masked_terms.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Prompt templates for adding masked terms during attacks.""" - -ADD_MASKED_TERMS_PROMPT_V1 = """ -You are given a text in its masked version, your goal is to return a mapping from -each unmasked n-grams in the original text to their masked symbol in the masked text. -Masked symbols are wrapped in [] for instance [C1] or [T1].[C1]. -You should consider [T1].[C1] as a single symbol. -You should return a json object with keys being terms in the original text and values being -symbols used in the masked text. -Do not include any comments or extra text in the output. -Output should be a valid json object. - -Here are some examples: - ------------------------------------ -Example 1: -Text: What is the name of people who have a car with horse power higher than 1000? -Masked Text: What is the [T1].[C1] of [T1] who have a [T2] with [T2].[C4] higher than [V1]? 
[V1] is a value of [T2].[C4] -OUTPUT: -{{ - "name": "[T1].[C1]", - "people": "[T1]", - "car": "[T2]", - "horse power": "[T2].[C4]", - "1000": "[V1]" -}} - -Now, return the mapping for the following text and its masked version: -Inputs: -Text: {question} -Masked Text: {masked_question} -""" diff --git a/src/pipe/attack_prompts/attack_raw_v1.py b/src/pipe/attack_prompts/attack_raw_v1.py deleted file mode 100644 index d715c84..0000000 --- a/src/pipe/attack_prompts/attack_raw_v1.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Raw attack prompt template version 1.""" - -ATTACK_PROMPT_RAW_V1 = """ -Your goal is to guess the words in a masked question. Given -(1) NL Question: a natural-language question about a dataset and -(2) DB Schema: the database’s schema expressed in YAML -guess the original each masked symbol exists in the question. -Masked symbols are wrapped in brackets. - -Input Format: -- DB Schema: given in YAML format where top-level keys are table names; each table lists its columns and their data types. -- Each column might be primary key or a foreign key. -- For foreign key columns, fully qualified name of the referenced column is given - -Output Rules -- C1,C2,... are symbols used to mask column names. -- T1,T2,... are symbols used to mask table names. -- V1,V2,... are symbols used to mask literal values. -- Output ONLY the unmasked question. -- Do not include any comments. -- For columns names with spaces, wrap them in backticks, e.g. "WHERE `car model` = 'bar'" instead of "WHERE car model = 'bar'". - -Here are some examples: - ------------------------------------ -Example 1: -NL Question: What is the [T1].[C1] of [T1] who have a [T2] with [T2].[C4] higher than [V1]? [V1] is a value of [T2].[C4] -Database Schema: - [T1]: - [C1]: text - [C2]: text - [T2]: - [C3]: text - [C4]: number - -OUTPUT: -What is the name of people who have a car with horse power higher than 1000? 
- -Now, unmask the following question considering the following DB schema -Inputs: -{symbolic_raw} -""" diff --git a/src/pipe/deprecated/__init__.py b/src/pipe/deprecated/__init__.py deleted file mode 100644 index e346b79..0000000 --- a/src/pipe/deprecated/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Deprecated pipeline components package. - -This package contains pipeline components that are no longer actively used -but are kept for reference or backward compatibility. -""" diff --git a/src/pipe/deprecated/add_fks.py b/src/pipe/deprecated/add_fks.py deleted file mode 100644 index dedbb2d..0000000 --- a/src/pipe/deprecated/add_fks.py +++ /dev/null @@ -1,39 +0,0 @@ -# mypy: ignore-errors - -"""Module for adding foreign key relationships to database schemas.""" - -from typing import Any - -from src.pipe.processor.list_transformer import JsonListTransformer -from src.pipe.schema_repo import DatabaseSchemaRepo - - -class AddForeignKeys(JsonListTransformer): - """ - Processor for adding foreign key relationships to database schemas. - - This class extracts and formats foreign key relationships from database schemas, - making them available for downstream processing. - - Parameters - ---------- - prop_name : str - The property name where foreign keys will be stored in the row. - tables_path : str - Path to the database tables/schemas repository. 
- """ - - def __init__(self, prop_name: str, tables_path: str) -> None: - super().__init__(force=True) - self.prop_name = prop_name - self.schema_repo = DatabaseSchemaRepo(tables_path) - - async def __process_row_internal(self, row: dict[str, Any]) -> dict[str, Any]: - fks = [] - schema = self.schema_repo.dbs[row["db_id"]] - for table_name, table_columns in schema.tables.items(): - for col_name, col_data in table_columns.items(): - if isinstance(col_data, dict) and "foreign_key" in col_data: - fks.append(f"{table_name}.{col_name}={col_data['foreign_key']}") - row[self.prop_name] = fks - return row diff --git a/src/pipe/deprecated/add_full_schema.py b/src/pipe/deprecated/add_full_schema.py deleted file mode 100644 index d47a30e..0000000 --- a/src/pipe/deprecated/add_full_schema.py +++ /dev/null @@ -1,132 +0,0 @@ -# mypy: ignore-errors - -"""Module for adding database schema information to data rows.""" - -from typing import Any - -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.processor.list_transformer import JsonListTransformer -from src.pipe.rank_schema import RankSchemaResd -from src.pipe.schema_repo import DatabaseSchema, DatabaseSchemaRepo - - -def filter_schema(schema: DatabaseSchema, schema_items: list[str]) -> DatabaseSchema: - """ - Filter database schema to include only specified schema items. - - Parameters - ---------- - schema : DatabaseSchema - The full database schema to filter. - schema_items : List[str] - List of schema item references (e.g., 'COLUMN:table.column'). - - Returns - ------- - DatabaseSchema - Filtered schema containing only the specified items and their foreign keys. 
- """ - columns = set() - for item in schema_items: - item_ref = item.split(":")[1] - if "[*]" in item_ref: - continue - if item.split(":")[0] == "COLUMN": - columns.add(item_ref) - - for col_ref in list(columns): - table_name = col_ref.split(".")[0] - col_name = col_ref.split(".")[1] - col_data = schema.tables[table_name][col_name] - if isinstance(col_data, dict) and "foreign_key" in col_data: - fk_ref = col_data["foreign_key"] - if isinstance(fk_ref, str): - columns.add(fk_ref) - - filtered_schema = DatabaseSchema() - for table_name, table_columns in schema.tables.items(): - filtered_table_columns = {} - for col_name, col_data in table_columns.items(): - if f"{table_name}.{col_name}" in columns: - filtered_table_columns[col_name] = col_data - if len(filtered_table_columns) > 0: - filtered_schema.tables[table_name] = filtered_table_columns - return filtered_schema - - -class AddFullSchema(JsonListTransformer): - """ - Processor for adding full database schema to data rows. - - Parameters - ---------- - tables_path : str - Path to the database tables/schemas repository. - """ - - def __init__(self, tables_path: str) -> None: - super().__init__() - self.schema_repo = DatabaseSchemaRepo(tables_path) - - async def __process_row_internal(self, row: dict[str, Any]) -> dict[str, Any]: - schema = self.schema_repo.dbs[row["db_id"]] - row["schema"] = schema.to_yaml() - return row - - -class AddFilteredSchema( - JsonListProcessor[RankSchemaResd.Model, "AddFilteredSchema.Model"] -): - """ - Processor for adding filtered database schema to data rows. - - Only includes schema items that are referenced in the row's schema_items list. - - Parameters - ---------- - tables_path : str - Path to the database tables/schemas repository. 
- """ - - class Model(RankSchemaResd.Model): - """Data model for filtered schema processing with schema field.""" - - schema: str - - def __init__(self, tables_path: str) -> None: - super().__init__(self.Model) - self.schema_repo = DatabaseSchemaRepo(tables_path) - - async def _process_row(self, row: RankSchemaResd.Model) -> Model: - schema = self.schema_repo.dbs[row.db_id] - schema_items = row.schema_items - filtered_schema = filter_schema(schema, schema_items) - return self.Model(schema=filtered_schema.to_yaml(), **row.dict()) - - -class AddSchemaItems(JsonListTransformer): - """ - Processor for extracting all schema items from database schema. - - Creates a list of all tables and columns in the database schema. - - Parameters - ---------- - tables_path : str - Path to the database tables/schemas repository. - """ - - def __init__(self, tables_path: str) -> None: - super().__init__(force=True) - self.schema_repo = DatabaseSchemaRepo(tables_path) - - async def __process_row_internal(self, row: dict[str, Any]) -> dict[str, Any]: - schema = self.schema_repo.dbs[row["db_id"]] - schema_items = [] - for table, columns in schema.tables.items(): - schema_items.append(f"TABLE:{table}") - for col, _col_data in columns.items(): - schema_items.append(f"COLUMN:{table}.{col}") - schema_items.append(f"COLUMN:{table}.[*]") - row["schema_items"] = schema_items - return row diff --git a/src/pipe/deprecated/add_masked_terms_det.py b/src/pipe/deprecated/add_masked_terms_det.py deleted file mode 100644 index 2086343..0000000 --- a/src/pipe/deprecated/add_masked_terms_det.py +++ /dev/null @@ -1,226 +0,0 @@ -# mypy: ignore-errors - -"""Module for deterministic masking of terms in natural language questions.""" - -from typing import Any - -from src.pipe.processor.list_transformer import JsonListTransformer -from src.pipe.utils import replace_str -from src.utils.logging import logger - - -class AddMaskedTermsDeterministic(JsonListTransformer): - """ - Deterministic processor for masking 
terms in natural language questions. - - This class performs rule-based masking of schema and value references in questions, - replacing them with symbolic representations based on schema and value links. - """ - - def __init__(self) -> None: - super().__init__(force=True) - - def get_symbol( - self, schema_items: list[str] | str, symbol_table: dict[str, str] - ) -> str: - """ - Get symbolic representation for schema items. - - Parameters - ---------- - schema_items : list[str] | str - Schema item(s) to get symbols for. - symbol_table : dict[str, str] - Mapping from schema items to their symbolic representations. - - Returns - ------- - str - Comma-separated symbolic representations. - """ - if not isinstance(schema_items, list): - schema_items = [schema_items] - symbols: list[str | None] = [] - for schema_item in schema_items: - schema_item_parts = schema_item.split(":") - schema_item_name = schema_item_parts[1] - symbol = symbol_table.get(schema_item_name) - symbols.append(symbol) - return ",".join(str(s) for s in symbols if s is not None) - - def symbolize_term( - self, - question: str, - question_term: str, - schema_items: str, - symbol_table: dict[str, str], - ) -> str: - """ - Replace a question term with its symbolic representation. - - Parameters - ---------- - question : str - The question text to modify. - question_term : str - The term in the question to replace. - schema_items : str - Schema item(s) corresponding to the term. - symbol_table : dict[str, str] - Mapping from schema items to symbols. - - Returns - ------- - str - Question with term replaced by symbol. - """ - symbol = self.get_symbol(schema_items, symbol_table) - return replace_str(question, question_term, symbol) - - def symbolize_value( - self, - question: str, - question_term: str, - column_ref: str, - updated_schema_links: dict[str, str], - filtered_value_links: dict[str, str], - symbol_table: dict[str, str], - ) -> str: - """ - Replace a value term with its symbolic representation. 
- - Parameters - ---------- - question : str - The question text to modify. - question_term : str - The value term in the question to replace. - column_ref : str - Reference to the column containing this value. - updated_schema_links : dict[str, str] - Updated schema links mapping. - filtered_value_links : dict[str, str] - Filtered value links mapping. - symbol_table : dict[str, str] - Mapping from schema items to symbols. - - Returns - ------- - str - Question with value replaced by symbol and evidence added. - """ - value_symbol = f"[V{self.vid}]" - if ( - column_ref in filtered_value_links.values() - or f"COLUMN:{column_ref}" in updated_schema_links.values() - ): - column_symbol = symbol_table[column_ref] - else: - column_symbol = column_ref - self.vid += 1 - evidence = f"{value_symbol} is a value of the column {column_symbol}" - self.value_dict[value_symbol] = question_term - symbolic_question = replace_str(question, question_term, value_symbol) - return f"{symbolic_question}; {evidence}" - - def add_tables_of_columns( - self, schema_links: dict[str, str], filtered_schema_links: dict[str, str] - ) -> dict[str, str]: - """ - Add table references for columns in filtered schema links. - - Parameters - ---------- - schema_links : dict[str, str] - All schema links from question terms to schema items. - filtered_schema_links : dict[str, str] - Subset of schema links to use. - - Returns - ------- - dict[str, str] - Updated schema links with tables included. 
- """ - updated_schema_links = filtered_schema_links.copy() - tables = set() - for schema_items in filtered_schema_links.values(): - if schema_items is None: - logger.error(f"Invalid schema item: {schema_items}") - continue - items = ( - [schema_items] if not isinstance(schema_items, list) else schema_items - ) - for schema_item in items: - if schema_item.startswith("COLUMN"): - col_ref = schema_item.split(":")[1] - table_name = col_ref.split(".")[0] - tables.add(table_name) - - for question_term, schema_items in schema_links.items(): - items = ( - [schema_items] if not isinstance(schema_items, list) else schema_items - ) - for schema_item in items: - if schema_item.startswith("TABLE"): - assert len(schema_items) == 1 - table_name = schema_item.split(":")[1] - if table_name in tables: - updated_schema_links[question_term] = schema_item - return updated_schema_links - - async def __process_row_internal(self, row: dict[str, Any]) -> dict[str, Any]: - self.vid: int = 1 - self.value_dict: dict[str, str] = {} - filtered_schema_links = row["filtered_schema_links"] - schema_links = row["schema_links"] - question = row["question"] - symbol_table = row["symbolic"]["to_symbol"] - updated_schema_links = self.add_tables_of_columns( - schema_links, filtered_schema_links - ) - masked_terms = [] - - symbolic_question = question - masked = 0 - - value_links = row["value_links"] - filtered_value_links = row["filtered_value_links"] - - if isinstance(value_links, (list, str)): - logger.error(f"Invalid value links: {value_links}") - value_links = {} - - if isinstance(filtered_value_links, (list, str)): - logger.error(f"Invalid value links: {filtered_value_links}") - filtered_value_links = {} - - for question_term, schema_item in value_links.items(): - try: - symbolic_question = self.symbolize_value( - symbolic_question, - question_term, - schema_item, - updated_schema_links, - filtered_value_links, - symbol_table, - ) - masked_terms.append(question_term) - masked += 1 - except 
Exception as e: - logger.error( - f"Failed to mask {question_term}:{schema_item}, error={e} " - ) - - for question_term, schema_items in updated_schema_links.items(): - try: - symbolic_question = self.symbolize_term( - symbolic_question, question_term, schema_items, symbol_table - ) - masked_terms.append(question_term) - masked += 1 - except Exception as e: - logger.error( - f"Failed to mask {question_term}:{schema_items}, error={e} " - ) - row["symbolic"].update({"masked_terms": masked_terms}) - return row diff --git a/src/pipe/deprecated/add_value_links_from_schema_links.py b/src/pipe/deprecated/add_value_links_from_schema_links.py deleted file mode 100644 index b99ef40..0000000 --- a/src/pipe/deprecated/add_value_links_from_schema_links.py +++ /dev/null @@ -1,30 +0,0 @@ -# mypy: ignore-errors - -"""Module for extracting value links from schema links.""" - -from typing import Any - -from src.pipe.processor.list_transformer import JsonListTransformer - - -class AddValueLinksFromSchemaLinks(JsonListTransformer): - """ - Processor for extracting value links from schema links. - - Separates value links (prefixed with 'VALUE:') from schema links, - creating separate mappings for each. 
- """ - - async def __process_row_internal(self, row: dict[str, Any]) -> dict[str, Any]: - schema_links = row["schema_links"] - value_links = {} - updated_schema_links = {} - for q_term, item in schema_links.items(): - if "VALUE:" in item: - value = item.replace("VALUE:", "") - value_links[q_term] = value - else: - updated_schema_links[q_term] = item - row["schema_links"] = updated_schema_links - row["value_links"] = value_links - return row diff --git a/src/pipe/deprecated/add_value_symbol_table.py b/src/pipe/deprecated/add_value_symbol_table.py deleted file mode 100644 index 06be074..0000000 --- a/src/pipe/deprecated/add_value_symbol_table.py +++ /dev/null @@ -1,39 +0,0 @@ -# mypy: ignore-errors - -"""Module for adding symbolic representations of values to the symbol table.""" - -from typing import Any - -from src.pipe.processor.list_transformer import JsonListTransformer -from src.pipe.schema_repo import DatabaseSchemaRepo - - -class AddValueSymbolTable(JsonListTransformer): - """ - Add symbolic representations for values to the symbol table. - - Extends the symbol table with value symbols for values detected in questions. 
- - Parameters - ---------- - tables_path : str - Path to the database schema definitions file - """ - - def __init__(self, tables_path: str) -> None: - super().__init__(True) - self.schema_repo = DatabaseSchemaRepo(tables_path) - - async def __process_row_internal(self, row: dict[str, Any]) -> dict[str, Any]: - vid = 1 - value_links = row["value_links"] - symbol_table = row["symbolic"]["to_symbol"] - to_value = {} - for value in value_links: - symbol = f"[V{vid}]" - symbol_table[value] = symbol - to_value[symbol] = value - vid += 1 - row["symbolic"]["to_symbol"] = symbol_table - row["symbolic"]["to_value"] = to_value - return row diff --git a/src/pipe/deprecated/filtered_symb_schema.py b/src/pipe/deprecated/filtered_symb_schema.py deleted file mode 100644 index 2e20cde..0000000 --- a/src/pipe/deprecated/filtered_symb_schema.py +++ /dev/null @@ -1,265 +0,0 @@ -# mypy: ignore-errors - -"""Filtered symbolic schema generation.""" - -from typing import Any - -from src.models.base_object import BaseObject -from src.pipe.processor.list_transformer import JsonListTransformer -from src.pipe.schema_repo import DatabaseSchema, DatabaseSchemaRepo -from src.utils.logging import logger - - -class AddFilteredSymbolicSchema(JsonListTransformer): - """ - Add symbolic schema with filtered columns and tables. 
- - Parameters - ---------- - tables_path : str - Path to tables JSON file - """ - - def __init__(self, tables_path: str) -> None: - super().__init__(BaseObject) - self.schema_repo = DatabaseSchemaRepo(tables_path) - - async def __process_row_internal(self, row: dict[str, Any]) -> dict[str, Any]: - tables, col_refs = self.get_items_to_symbolize(row) - - schema = DatabaseSchema.from_yaml(row["schema"]) - symbol_table = row["symbolic"]["to_symbol"] - - symbolic_schema = self.get_symb_schema(schema, symbol_table, tables, col_refs) - - reverse_dict = self.get_reverse_dict(tables, col_refs, symbol_table) - - row["symbolic"]["schema"] = symbolic_schema.to_yaml() - row["symbolic"]["reverse_dict"] = reverse_dict - - return row - - def get_col_symbol( - self, - table_name: str, - col_name: str, - col_refs: set[str], - symbol_table: dict[str, str], - ) -> str: - """ - Get symbolic representation for a column. - - Parameters - ---------- - table_name : str - Table name - col_name : str - Column name - col_refs : Set[str] - Set of column references to symbolize - symbol_table : dict[str, str] - Mapping from original names to symbols - - Returns - ------- - str - Column symbol or original name - """ - col_ref = f"{table_name}.{col_name}" - if col_ref in col_refs: - return symbol_table[col_ref] - return col_name - - def get_table_symbol( - self, table_name: str, tables: set[str], symbol_table: dict[str, str] - ) -> str: - """ - Get symbolic representation for a table. 
- - Parameters - ---------- - table_name : str - Table name - tables : Set[str] - Set of table names to symbolize - symbol_table : dict[str, str] - Mapping from original names to symbols - - Returns - ------- - str - Table symbol or original name - """ - if table_name in tables: - return symbol_table[table_name] - return table_name - - def get_symbolic_col_data( - self, - col_data: str | dict[str, Any], - tables: set[str], - col_refs: set[str], - symbol_table: dict[str, str], - ) -> str | dict[str, Any]: - """ - Convert column data to symbolic form. - - Parameters - ---------- - col_data : Union[str, dict[str, str]] - Column data including foreign key information - tables : Set[str] - Set of table names to symbolize - col_refs : Set[str] - Set of column references to symbolize - symbol_table : dict[str, str] - Mapping from original names to symbols - - Returns - ------- - str - Symbolic column data - """ - symbolic_col_data: str | dict[str, Any] - if isinstance(col_data, dict) and "foreign_key" in col_data: - symbolic_col_data = col_data.copy() - foreign_col_ref = symbolic_col_data["foreign_key"] - table_name = foreign_col_ref.split(".")[0] - table_symbol = self.get_table_symbol(table_name, tables, symbol_table) - column_name = foreign_col_ref.split(".")[1] - column_symbol = self.get_col_symbol( - table_name, column_name, col_refs, symbol_table - ) - symbolic_col_data["foreign_key"] = f"{table_symbol}.{column_symbol}" - else: - symbolic_col_data = col_data - return symbolic_col_data - - def get_items_to_symbolize(self, row: dict[str, Any]) -> tuple[set[str], set[str]]: - """ - Extract tables and columns to symbolize from row. 
- - Parameters - ---------- - row : dict - Data row with schema and value links - - Returns - ------- - Tuple[Set[str], Set[str]] - Tables and columns to symbolize - """ - schema_items = row["filtered_schema_links"] - value_links = row["filtered_value_links"] - tables = set() - columns = set() - - for item in schema_items.values(): - if not item or item.strip() == "{}": - continue - if ":" not in item: - logger.error(f"Invalid schema item: {item}") - continue - item_type = item.split(":")[0] - item_ref = item.split(":")[1] - if item_type.startswith("TABLE"): - tables.add(item_ref) - if item_type.startswith("COLUMN"): - table_name = item_ref.split(".")[0] - tables.add(table_name) - columns.add(item_ref) - - if isinstance(value_links, dict): - for item in value_links.values(): - columns.add(item) - else: - logger.error(f"Invalid value links: {value_links}") - return tables, columns - - def get_symb_schema( - self, - schema: DatabaseSchema, - symbol_table: dict[str, str], - tables: set[str], - col_refs: set[str], - ) -> DatabaseSchema: - """ - Create symbolic database schema. 
- - Parameters - ---------- - schema : DatabaseSchema - Original database schema - symbol_table : dict[str, str] - Mapping from original names to symbols - tables : Set[str] - Set of table names to symbolize - col_refs : Set[str] - Set of column references to symbolize - - Returns - ------- - DatabaseSchema - Symbolic version of schema - """ - symbolic_schema = DatabaseSchema() - - for table_name, columns in list(schema.tables.items()): - symbolic_columns = {} - for col_name, col_data in columns.items(): - col_symbol = self.get_col_symbol( - table_name, col_name, col_refs, symbol_table - ) - symbolic_col_data = self.get_symbolic_col_data( - col_data, tables, col_refs, symbol_table - ) - symbolic_columns[col_symbol] = symbolic_col_data - table_symbol = self.get_table_symbol(table_name, tables, symbol_table) - symbolic_schema.tables[table_symbol] = symbolic_columns - return symbolic_schema - - def get_reverse_dict( - self, tables: set[str], col_refs: set[str], symbol_table: dict[str, str] - ) -> dict[str, str]: - """ - Create reverse mapping from symbols to original names. - - Parameters - ---------- - tables : Set[str] - Set of table names - col_refs : Set[str] - Set of column references - symbol_table : dict[str, str] - Mapping from original names to symbols - - Returns - ------- - dict[str, str] - Mapping from symbols back to original names - """ - reverse_dict = {} - for table in tables: - if table not in symbol_table: - logger.error( - f"Table {table} not found in symbol table: {symbol_table.keys()}" - ) - continue - table_symbol = symbol_table[table] - reverse_dict[table_symbol] = table - - for col_ref in col_refs: - if "." 
not in col_ref: - logger.error(f"Invalid col ref: {col_ref}") - continue - if col_ref not in symbol_table: - logger.error(f"Invalid col ref: {col_ref}") - continue - table = col_ref.split(".")[0] - - table_symbol = symbol_table[table] - col_symbol = symbol_table[col_ref] - - reverse_dict[col_symbol] = col_ref - reverse_dict[f"{table_symbol}.{col_symbol}"] = col_ref - return reverse_dict diff --git a/src/pipe/deprecated/utility.py b/src/pipe/deprecated/utility.py deleted file mode 100644 index ad0197f..0000000 --- a/src/pipe/deprecated/utility.py +++ /dev/null @@ -1,164 +0,0 @@ -# mypy: ignore-errors - -"""Utility transformers for data processing. - -This module contains various utility transformers for processing data -in the pipeline, such as filtering, copying, and property manipulation. -""" - -import json -import os -from typing import Any, Callable - -from src.pipe.processor.list_transformer import JsonListTransformer - - -class DeleteProp(JsonListTransformer): - """ - Transformer for deleting a property from data rows. - - Parameters - ---------- - prop : str - Property name to delete. - """ - - def __init__(self, prop: str) -> None: - super().__init__() - self.prop = prop - - async def __process_row_internal(self, row: dict[str, Any]) -> dict[str, Any]: - del row[self.prop] - return row - - -class CopyFromPrevStage(JsonListTransformer): - """ - Transformer for copying values from a previous pipeline stage. - - Parameters - ---------- - stage : str - Name of the previous stage to copy from. - src : str - Source field to copy. - """ - - def __init__(self, stage: str, src: str) -> None: - super().__init__(force=True) - self.stage = stage - self.src = src - - def get_prev_stage(self, input_file: str) -> list[dict[str, Any]]: - """ - Load data from previous pipeline stage. 
- - Parameters - ---------- - input_file : str - Path to current input file - - Returns - ------- - list - Data from previous stage - """ - dir_path = os.path.dirname(input_file) - prev_stage_file_path = os.path.join(dir_path, f"{self.stage}.json") - return super()._get_input_data(prev_stage_file_path) - - async def run(self, input_file: str) -> str: - """ - Run the transformer and copy values from previous stage. - - Parameters - ---------- - input_file : str - Path to input file - - Returns - ------- - str - Path to output file - """ - output_file = await super().run(input_file) - - with open(output_file) as f: - data = json.load(f) - - prev_stage = self.get_prev_stage(input_file) - - updated_rows = [] - for i, row in enumerate(data): - row[self.src] = prev_stage[i][self.src] - updated_rows.append(row) - - with open(output_file, "w") as f: - f.write(json.dumps(updated_rows, indent=4)) - return output_file - - async def __process_row_internal(self, row: dict[str, Any]) -> dict[str, Any]: - return row - - -class AddGoldValues(JsonListTransformer): - """ - Transformer for extracting gold value links as a list. - - Extracts keys from gold_value_links and stores them in a 'values' field. - """ - - def __init__(self) -> None: - super().__init__(force=True) - - async def __process_row_internal(self, row: dict[str, Any]) -> dict[str, Any]: - value_links = row["gold_value_links"] - keys = list(value_links.keys()) - row["values"] = keys - return row - - -class FilterList(JsonListTransformer): - """ - Filter JSON list based on predicate function. - - Parameters - ---------- - predicate : callable, optional - Function to test each row, default returns all rows - """ - - def __init__(self, predicate: Callable[[Any], Any] = lambda r: r) -> None: - super().__init__(True) - self.predicate = predicate - - async def run(self, input_file: str) -> str: - """ - Filter input file based on predicate. 
- - Parameters - ---------- - input_file : str - Path to input JSON file - - Returns - ------- - str - Path to filtered output file - """ - output_file = self.get_output_file(input_file) - - with open(input_file) as f: - in_data = json.load(f) - - out_data = [] - for row in in_data: - if self.predicate(row): - out_data.append(row) - - with open(output_file, "w") as f: - f.write(json.dumps(out_data, indent=4)) - return output_file - - async def _process_row(self, row: Any) -> Any: - return row diff --git a/src/pipe/deprecated/value_link_eval.py b/src/pipe/deprecated/value_link_eval.py deleted file mode 100644 index ee6d8be..0000000 --- a/src/pipe/deprecated/value_link_eval.py +++ /dev/null @@ -1,46 +0,0 @@ -# mypy: ignore-errors - -"""Value linking evaluation processor.""" - -from typing import Any - -import pandas as pd - -from src.pipe.processor.list_processor import JsonListProcessor - - -class ValueLinkEval(JsonListProcessor[dict[str, Any], dict[str, Any]]): - """Evaluate value linking accuracy.""" - - def __init__(self) -> None: - super().__init__(dict) - self.scores: list[dict[str, int]] = [] - - def _post_run(self) -> None: - df = pd.DataFrame(self.scores) - df["bin"] = (df["score"] == df["total"]).astype(int) - avg = df["score"].sum() / df["total"].sum() - overall_avg = (df["score"] / df["total"]).mean() - bin_avg = df["bin"].mean() - print(f"Score: {df['score'].sum()}/{df['total'].sum()}") - print(f"AVG Score: {avg}") - print(f"Overall AVG Score: {overall_avg}") - print(f"Binary Score: {bin_avg}") - - async def _process_row(self, row: dict[str, Any]) -> dict[str, Any]: - gold = row["gold_value_links"] - pred = row["filtered_value_links"] - print("##############################") - print(f"GOLD: {gold}") - print("------------------------------") - print(f"PRED: {pred}") - print("##############################") - score = 0 - total = 0 - for gk, gv in gold.items(): - total += 1 - if gk in pred and pred[gk] == gv: - score += 1 - - self.scores.append({"score": 
score, "total": total}) - return row diff --git a/src/pipe/estimate_sql.py b/src/pipe/estimate_sql.py deleted file mode 100644 index b2734fc..0000000 --- a/src/pipe/estimate_sql.py +++ /dev/null @@ -1,85 +0,0 @@ -"""SQL query estimation and scoring.""" - -import re -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor - - -PROMPT = """ -I give you a natural language question and a database schema. -Give me the SQL that can answer the given question. - -Example: -NL Question: "What is the name of the instructor who has the lowest salary?" -DB Schema: -tables: - instructor: - - id: text - - name: text - - dept_name: text - - salary: number - department: - - dept_name: text - - building: text - - budget: number - -SQL: "SELECT name FROM instructor ORDER BY salary LIMIT 1" - -Now generate the SQL for the following data: -NL Question: {question} -DB Schema: {schema} -""" - -N = 3 - - -class EstimateSQL(PromptProcessor): - """ - Estimate SQL queries from natural language questions. - - Uses an LLM to generate SQL queries based on natural language questions - and database schemas. - """ - - def _process_output(self, row: dict[str, Any], output: str) -> str: - """ - Extract SQL query from LLM output. - - Parameters - ---------- - row : dict - Data row (unused in this implementation) - output : str - Raw LLM output containing SQL in markdown code blocks - - Returns - ------- - str - Cleaned SQL query - """ - masked = re.findall(r"```([\s\S]*?)```", output) - final_answer = masked[0] - final_answer = final_answer.strip() - final_answer = final_answer.replace("\n", " ") - if final_answer.startswith("sql"): - final_answer = final_answer[3:] - return final_answer - - def _get_prompt(self, row: dict[str, Any]) -> str: - """ - Generate prompt for SQL estimation. 
- - Parameters - ---------- - row : dict - Data row containing question and schema - - Returns - ------- - str - Formatted prompt for LLM - """ - schema = row["schema"] - question = row["question"] - return PROMPT.format(question=question, schema=schema) diff --git a/src/pipe/filer_schema_items.py b/src/pipe/filer_schema_items.py deleted file mode 100644 index 119f338..0000000 --- a/src/pipe/filer_schema_items.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Filter schema items based on relevance.""" - -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.filer_schema_links import CONCEPTS -from src.pipe.llm_util import extract_object -from src.pipe.schema_items_filter_prompts.v1 import FILTER_SCHEMA_ITEMS_PROMPT_V1 - - -class FilterSchemaItems(PromptProcessor): - """ - Filter schema items based on relevance to predefined concepts. - - Uses LLM prompts to determine which schema items are relevant to - specific concepts like person names, locations, and occupations. - """ - - def _process_output(self, row: dict[str, Any], output: str) -> Any: - return extract_object(output) - - def _get_prompt(self, row: dict[str, Any]) -> str: - schema_items = row["schema_items"] - return FILTER_SCHEMA_ITEMS_PROMPT_V1.format( - concepts=CONCEPTS, schema_items=schema_items - ) diff --git a/src/pipe/filer_schema_links.py b/src/pipe/filer_schema_links.py deleted file mode 100644 index b9d24cf..0000000 --- a/src/pipe/filer_schema_links.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Filter schema links based on relevance.""" - -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.llm_util import extract_object -from src.pipe.schema_filter_prompts.v2 import FILTER_SCHEMA_LINKS_PROMPT_V2 - - -CONCEPTS = ["Person's name", "Location", "Occupation"] - - -class FilterSchemaLinks(PromptProcessor): - """ - Filter schema links based on relevance to predefined concepts. 
- - Uses LLM prompts to determine which schema links (mappings from question - terms to schema items) are relevant to specific concepts. - """ - - def _process_output(self, row: dict[str, Any], output: str) -> dict[str, Any]: - obj = extract_object(output) - if obj is None: - return {} - return obj - - def _get_prompt(self, row: dict[str, Any]) -> str: - schema_links = row["schema_links"] - question = row["question"] - return FILTER_SCHEMA_LINKS_PROMPT_V2.format( - concepts=CONCEPTS, question=question, schema_links=schema_links - ) diff --git a/src/pipe/filer_value_links.py b/src/pipe/filer_value_links.py deleted file mode 100644 index e16ae8e..0000000 --- a/src/pipe/filer_value_links.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Filter value links based on relevance.""" - -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.filer_schema_links import CONCEPTS -from src.pipe.llm_util import extract_object -from src.pipe.value_filter_prompts.v1 import VALUE_LINKS_FILTER_PROMPT_V1 - - -class FilterValueLinks(PromptProcessor): - """ - Filter value links based on relevance to predefined concepts. - - Uses LLM prompts to determine which value links (mappings from question - values to database columns) are relevant to specific concepts. 
- """ - - def _process_output(self, row: dict[str, Any], output: str) -> Any: - return extract_object(output) - - def _get_prompt(self, row: dict[str, Any]) -> str: - question = row["question"] - value_links = row["value_links"] - return VALUE_LINKS_FILTER_PROMPT_V1.format( - concepts=CONCEPTS, question=question, value_links=value_links - ) diff --git a/src/pipe/gen_gold_mask.py b/src/pipe/gen_gold_mask.py deleted file mode 100644 index d286f18..0000000 --- a/src/pipe/gen_gold_mask.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Generate gold standard masked questions.""" - -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.gold_mask.gold_mask_v1 import GOLD_MASK_V1 - - -class GenGoldMask(PromptProcessor): - """Generate gold standard masked questions for evaluation.""" - - def _process_output(self, row: dict[str, Any], output: str) -> str: - return output - - def _get_prompt(self, row: dict[str, Any]) -> str: - question = row["question"] - schema = row["schema"] - return GOLD_MASK_V1.format(question=question, schema=schema) diff --git a/src/pipe/gen_gold_schema.py b/src/pipe/gen_gold_schema.py deleted file mode 100644 index bf888b7..0000000 --- a/src/pipe/gen_gold_schema.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Generate gold standard schema links.""" - -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.gold_schema_link.v1 import GOLD_SCHEMA_LINKING_PROMPT_V1 -from src.pipe.llm_util import extract_object - - -class GenGoldLinks(PromptProcessor): - """Generate gold standard schema links from SQL queries.""" - - def _process_output(self, row: dict[str, Any], output: str) -> Any: - return extract_object(output) - - def _get_prompt(self, row: dict[str, Any]) -> str: - question = row["question"] - schema = row["schema"] - sql = row["query"] - return GOLD_SCHEMA_LINKING_PROMPT_V1.format( - question=question, schema=schema, sql=sql - ) diff --git 
a/src/pipe/gen_masked_sql_raw.py b/src/pipe/gen_masked_sql_raw.py deleted file mode 100644 index cf3a9ac..0000000 --- a/src/pipe/gen_masked_sql_raw.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Generate masked SQL without post-processing.""" - -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.gen_sql import extract_sql -from src.pipe.sql_gen_prompts.masked_v3_raw import MASKED_GEN_SQL_RAW_PROMPT_V3 - - -DATA_DIR = "data" - - -class GenerateSymbolicSqlRaw(PromptProcessor): - """Generate SQL from raw symbolic inputs without post-processing.""" - - def _process_output(self, row: dict[str, Any], output: str) -> dict[str, str]: - masked_sql = extract_sql(output) - return {"sql": masked_sql} - - def _get_prompt(self, row: dict[str, Any]) -> str: - inputs = row["symbolic"]["raw"] - return MASKED_GEN_SQL_RAW_PROMPT_V3.format(inputs=inputs) diff --git a/src/pipe/gen_sql.py b/src/pipe/gen_sql.py deleted file mode 100644 index 4965bf3..0000000 --- a/src/pipe/gen_sql.py +++ /dev/null @@ -1,70 +0,0 @@ -"""SQL query generation from natural language.""" - -import re -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.sql_gen_prompts.masked_v3 import MASKED_GEN_SQL_PROMPT_V3 -from src.utils.logging import logger - - -DATA_DIR = "data" - - -def extract_sql(output: str) -> str: - """ - Extract SQL query from LLM output. 
- - Parameters - ---------- - output : str - Raw LLM output containing SQL - - Returns - ------- - str - Extracted SQL query - """ - output = output.strip() - output = output.strip('"') - sql = "SELECT" - if output.startswith("SELECT"): - sql = output - elif "```sql" in output: - res = re.findall(r"```sql([\s\S]*?)```", output) - if res: - sql = res[0] - else: - logger.error( - f"Failed to extract sql from output with ```sql marker: {output}" - ) - elif "```" in output: - res = re.findall(r"```([\s\S]*?)```", output) - if res: - sql = res[0] - else: - logger.error(f"Failed to extract sql from output with ``` marker: {output}") - elif "`" in output: - res = re.findall(r"`([\s\S]*?)`", output) - if res: - sql = res[0] - else: - logger.error(f"Failed to extract sql from output with ` marker: {output}") - else: - logger.error(f"Failed to extract sql from output: {output}") - sql = sql.strip() - return sql.replace("\n", " ") - - -class GenSql(PromptProcessor): - """Generate SQL queries from natural language questions.""" - - def _process_output(self, row: dict[str, Any], output: str) -> str: - return extract_sql(output) - - def _get_prompt(self, row: dict[str, Any]) -> str: - question = row["question"] - # schema_items = row['schema_items'] - # return GEN_SQL_PROMPT_V1.format(question=question, schema_items=schema_items) - schema = row["schema"] - return MASKED_GEN_SQL_PROMPT_V3.format(question=question, schema=schema) diff --git a/src/pipe/gen_sql_prompts.py b/src/pipe/gen_sql_prompts.py deleted file mode 100644 index 6e0f5fc..0000000 --- a/src/pipe/gen_sql_prompts.py +++ /dev/null @@ -1,30 +0,0 @@ -"""SQL generation prompt utilities.""" - -GEN_SQL_PROMPT_V1 = """ -I give you a natural language question where I replaced some n-grams that reference a column name of the database -with symbolic variables like [T1].[C1] for columns and [T1] for tables. -Each of these variables represents a database schema item. -Schema items are also symbolic variables. 
-I will give the database schema based on these symbolic variables. -You should generate a symbolic SQL query that can be used to answer the question. -You should use the symbolic variables to generate the SQL query. - -Example: -Symbolic Question: "What is the T1.C1 of the T1 who has the lowest T1.C2?" -Symbolic Schema: - [T1]: - [T1].[C1]: text - [T2].[C2]: number - [T2].[C3]: - type: text - foreign_key: [T2].[C4] - [T2]: - [T2].[C4]: text -Symbolic SQL: "SELECT T1.C1 FROM T1 ORDER BY T1.C2 LIMIT 1" - -Now give me the symbolic SQL query for the following data: -Symbolic Question: {symbolic_question} -Symbolic Schema: {symbolic_schema} - -Go step by step but give the final answer wrapped in ``` ``` -""" diff --git a/src/pipe/gold_mask/__init__.py b/src/pipe/gold_mask/__init__.py deleted file mode 100644 index b990fd2..0000000 --- a/src/pipe/gold_mask/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Gold standard masking data and utilities.""" diff --git a/src/pipe/gold_mask/gold_mask_v1.py b/src/pipe/gold_mask/gold_mask_v1.py deleted file mode 100644 index ec690b1..0000000 --- a/src/pipe/gold_mask/gold_mask_v1.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Gold standard mask template version 1.""" - -GOLD_MASK_V1 = """ -I'll give you a natural language question and the schema of the underlying database. -Your task is to mask every references to the database items in the given question. -Database items means columns or tables. -References to database have three types: -1- Tables: using the name of table or some term referencing a column -2- Columns: using the name of a column or some term referencing a column -3- Literal Values: literal values that are related to database columns - -Your goal is to find all such references and mask them using place holders. -You should use [M1],[M2],... symbols to replace all such references. 
- - -Here are some examples: ------------------------------------ -Example 1: -Question: -Among the German customers, how many of the them has credit limit of zero? - -Database Schema: -'[customers]': - '[country]': text - '[creditlimit]': real - '[customernumber]': - primary_key: true - type: integer - -OUTPUT: -Among the [M1] [M2], how many of them has [M3] of [M4]? ------------------------------------ - -Now, based on the given question and database schema returned the masked question: -Question: {question} -DB Schema: {schema} -""" diff --git a/src/pipe/gold_schema_link/__init__.py b/src/pipe/gold_schema_link/__init__.py deleted file mode 100644 index 53caebb..0000000 --- a/src/pipe/gold_schema_link/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Gold standard schema linking data and utilities.""" diff --git a/src/pipe/gold_schema_link/repair.py b/src/pipe/gold_schema_link/repair.py deleted file mode 100644 index 5382eac..0000000 --- a/src/pipe/gold_schema_link/repair.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Gold standard schema link repair utilities.""" - -GOLD_SCHEMA_LINKING_REPAIR_PROMPT_V1 = """ -You are an assistant that links n-grams (sub-sequences of up to 3 consecutive words) -of a natural-language question to database schema items (tables or fully qualified columns). -Schema links is a mapping from question terms to the database schema items. - -You are given: -- A natural language question. -- A database schema -- A SQL query -- Schema Links - -Goal -Inspect the given schema links -Return a JSON object mapping relevant n-grams (contiguous word sequences of length 1–3 taken from the question text) -to the single most relevant schema item or a list of relevant schema items. -You should look at the SQL query and extract the mapping between the question terms and database -schema items. -The keys of the mapping are n-grams of the question and values should be a schema item. 
-Each schema item has one of the following forms: -- "TABLE:[table]": if n-gram references a table -- "COLUMN:[table].[column]": if n-gram references a column -- "VALUE:[table].[column]": if n-gram is a literal value related to a column - -Schema items should be valid with respect to the given database schema. - -Mapping Rules: -- Consider all 1-, 2-, and 3-word spans. -- Include a mapping only if the n-gram refers to a schema item. -- Prefer the most specific applicable item: column beats table when the question refers to an attribute. -- Chose the shortest n-gram that maps to the schema item. -- If removing a word from an n-gram still points to the same schema item, use the shorter version. -- Exclude stop words from the n-grams - -Output Rules: -- Output only a JSON object representing the mapping. -- Each entry should be a key-value pair where the key is an n-gram and the value is a schema item. -- Value of each entry can only be a single string of the form "COLUMN:[table].[column]" or "TABLE:[table]". -- All json key and values should be in double quotes. -- Output should be a top-level JSON object. No nested keys. - -Example: -NL Question: What is the release title of the music that was released by Ron Hunt in 1979 that was downloaded 239 times? 
-Database Schema: - songs: - rt: text - artist: text - releasetype: text - year: number - totalsnatched: number - tags: - tag: text - index: number - id: number - -SQL: -SELECT [rt] FROM [songs] WHERE [artist] LIKE 'ron hunt' AND [groupYear] = 1979 AND [totalSnatched] = 239 - -OUTPUT: -{{ - "release title": "COLUMN:[torrents].[rt]", - "music": "TABLE:[songs]", - "Ron Hunt": "COLUMN:[songs].[artist]", - "1979": "VALUE:[songs].[year]", - "downloaded": "COLUMN:[songs].[totalsnatched]" - "239": "VALUE:[songs].[totalsnatched]" -}} - -Now generate the JSON object of mapping for the following question and schema items: -Question: {question} -DB Schema: {schema} -SQL: {sql} -""" diff --git a/src/pipe/gold_schema_link/v1.py b/src/pipe/gold_schema_link/v1.py deleted file mode 100644 index 32e8d15..0000000 --- a/src/pipe/gold_schema_link/v1.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Gold standard schema link template version 1.""" - -GOLD_SCHEMA_LINKING_PROMPT_V1 = """ -You are an assistant that links n-grams (sub-sequences of up to 3 consecutive words) -of a natural-language question to database schema items (tables or fully qualified columns). - -You are given: -- A natural language question. -- A database schema -- A SQL query - -Goal -Return a JSON object mapping relevant n-grams (contiguous word sequences of length 1–3 taken from the question text) -to the single most relevant schema item or a list of relevant schema items. -You should look at the SQL query and extract the mapping between the question terms and database -schema items. -The keys of the mapping are n-grams of the question and values should be a schema item. -Each schema item has one of the following forms: -- "TABLE:[table]": if n-gram references a table -- "COLUMN:[table].[column]": if n-gram references a column -- "VALUE:[table].[column]": if n-gram is a literal value related to a column - -Schema items should be valid with respect to the given database schema. 
- -Mapping Rules: -- Consider all 1-, 2-, and 3-word spans. -- Include a mapping only if the n-gram refers to a schema item. -- Prefer the most specific applicable item: column beats table when the question refers to an attribute. -- Chose the shortest n-gram that maps to the schema item. -- If removing a word from an n-gram still points to the same schema item, use the shorter version. -- Exclude stop words from the n-grams -- Look for typos and other mistakes, user might meant to reference a table or column but having a typo in the question - -Output Rules: -- Output only a JSON object representing the mapping. -- Each entry should be a key-value pair where the key is an n-gram and the value is a schema item. -- Value of each entry can only be a single string of the form "COLUMN:[table].[column]" or "TABLE:[table]". -- All json key and values should be in double quotes. -- Output should be a top-level JSON object. No nested keys. - -Example: -NL Question: What is the release title of the music that was released by Ron Hunt in 1979 that was downloaded 239 times? 
-Database Schema: - songs: - rt: text - artist: text - releasetype: text - year: number - totalsnatched: number - tags: - tag: text - index: number - id: number - -SQL: -SELECT [rt] FROM [songs] WHERE [artist] LIKE 'ron hunt' AND [groupYear] = 1979 AND [totalSnatched] = 239 - -OUTPUT: -{{ - "release title": "COLUMN:[torrents].[rt]", - "music": "TABLE:[songs]", - "Ron Hunt": "COLUMN:[songs].[artist]", - "1979": "VALUE:[songs].[year]", - "downloaded": "COLUMN:[songs].[totalsnatched]" - "239": "VALUE:[songs].[totalsnatched]" -}} - -Now generate the JSON object of mapping for the following question and schema items: -Question: {question} -DB Schema: {schema} -SQL: {sql} -""" diff --git a/src/pipe/link_schema_and_value.py b/src/pipe/link_schema_and_value.py deleted file mode 100644 index e4959a9..0000000 --- a/src/pipe/link_schema_and_value.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Combined schema and value linking.""" - -from typing import Any - -from src.pipe.able_prompts.schema_value_link import SCHEMA_VALUE_LINK_PROMPT_V1 -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.llm_util import extract_object -from src.utils.logging import logger - - -class LinkSchemaAndValue(PromptProcessor): - """Link question terms to both schema items and values.""" - - def _process_output(self, row: dict[str, Any], output: str) -> dict[str, Any]: - schema_links = extract_object(output) - if schema_links is None: - schema_links = {} - question = row["question"] - schema_items = row["schema_items"] - refined_links: dict[str, Any] = {} - if isinstance(schema_links, (list, str)): - logger.error(f"Invalid schema links: {schema_links}") - refined_links = {} - - for question_term, schema_item in schema_links.items(): - if question_term.lower() not in question.lower(): - logger.error( - f"Invalid schema link {question_term} -> {schema_item}, term not found in question" - ) - continue - orig_schema_item = schema_item - normalized_item = ( - 
schema_item.replace("VALUE:", "COLUMN:") - if "VALUE:" in schema_item - else schema_item - ) - if normalized_item.lower() not in [i.lower() for i in schema_items]: - logger.error( - f"Invalid schema link {question_term} -> {orig_schema_item}, schema item not exists" - ) - continue - refined_links[question_term] = orig_schema_item - return refined_links - - def _get_prompt(self, row: dict[str, Any]) -> str: - question = row["question"] - schema_items = row["schema_items"] - value_list = row["values"] - return SCHEMA_VALUE_LINK_PROMPT_V1.format( - schema_items=schema_items, question=question, value_List=value_list - ) diff --git a/src/pipe/monitor/__init__.py b/src/pipe/monitor/__init__.py deleted file mode 100644 index 655d1fe..0000000 --- a/src/pipe/monitor/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Monitoring utilities for pipeline execution.""" diff --git a/src/pipe/monitor/lib.py b/src/pipe/monitor/lib.py deleted file mode 100644 index bce43f0..0000000 --- a/src/pipe/monitor/lib.py +++ /dev/null @@ -1,134 +0,0 @@ -"""Monitoring and logging utilities.""" - -import math -import subprocess -from datetime import datetime - -import pandas as pd - - -class TimeLogger: - """ - Logger for tracking operation timing. - - Parameters - ---------- - idx : str - Identifier for the logged operation - """ - - idx: str - - def __init__(self, idx: str): - self.idx = idx - - @staticmethod - def start(idx: str) -> "TimeLogger": - """ - Start timing an operation. 
- - Parameters - ---------- - idx : str - Operation identifier - - Returns - ------- - TimeLogger - Timer instance - """ - # logger.info(f"started", idx=f"{idx}", start=True) - return TimeLogger(idx) - - def lap(self) -> None: - """Record lap time for operation.""" - pass - # logger.info(f"finished", idx=f"{self.idx}", finish=True) - - -class Timer: - """Simple timer for measuring elapsed time.""" - - start_time: datetime - - def __init__(self) -> None: - self.start_time = datetime.now() - - @staticmethod - def start() -> "Timer": - """ - Start a new timer. - - Returns - ------- - Timer - New timer instance - """ - return Timer() - - def lap(self) -> float: - """ - Get elapsed time since timer start. - - Returns - ------- - float - Elapsed time in seconds - """ - return (datetime.now() - self.start_time).total_seconds() - - -def confidence_interval(column: pd.Series) -> str: - """ - Calculate confidence interval for numeric column. - - Parameters - ---------- - column : pd.Series - Numeric data series - - Returns - ------- - str - Formatted confidence interval string - """ - if not pd.api.types.is_numeric_dtype(column): - return "NA" - z = 1.65 - se = column.std() / math.sqrt(column.size) - err_margin = z * se - mean = column.mean() - interval_start = mean - err_margin - interval_end = mean + err_margin - if ( - interval_start >= 0 - and interval_end <= 1 - and interval_start >= 0 - and interval_end <= 1 - ): - return "({:.2f}%, {:.2f}%)".format(interval_start * 100, interval_end * 100) - return "({:.2f}, {:.2f})".format(interval_start * 100, interval_end * 100) - # return f"({interval_start}, {interval_end})" - - -def execute_command(command: str) -> None: - """ - Execute shell command and capture output. 
- - Parameters - ---------- - command : str - Shell command to execute - - Raises - ------ - subprocess.CalledProcessError - If command execution fails - """ - with subprocess.Popen( - command, shell=True, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True - ) as p: - output, errors = p.communicate() - print(output, errors) - if p.returncode != 0: - raise subprocess.CalledProcessError(p.returncode, p.args) diff --git a/src/pipe/processor/print_results.py b/src/pipe/processor/print_results.py deleted file mode 100644 index 6c3590f..0000000 --- a/src/pipe/processor/print_results.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Results printing processor.""" - -from src.pipe.attack import AddInferenceAttack -from src.pipe.processor.list_processor import JsonListProcessor - - -def print_color(text: str, color: str = "green") -> None: - """ - Print text with ANSI color codes. - - Parameters - ---------- - text : str - Text to print - color : str, optional - Color name (red, green, blue), default green - """ - colors = { - "red": "\033[91m", - "green": "\033[92m", - "blue": "\033[94m", - } - reset = "\033[0m" - color_code = colors.get(color.lower(), "") - print(f"{color_code}{text}{reset}") - - -class PrintResults( - JsonListProcessor[AddInferenceAttack.Model, AddInferenceAttack.Model] -): - """Print execution accuracy and privacy metrics.""" - - def __init__(self) -> None: - super().__init__(AddInferenceAttack.Model) - self.score = 0.0 - self.pre_score = 0 - self.total = 0 - self.total_toks = 0 - self.total_gold_masks = 0 - self.masks = 0 - self.leakage = 0 - self.total_masks = 0 - - def _post_run(self) -> None: - print(f"PreScore: {self.pre_score}/{self.total}") - print(f"Accuracy: {self.score}/{self.total}") - print(f"Masked: {self.masks}/{self.total_gold_masks}") - print(f"Leak: {self.leakage}/{self.total_masks}") - # print(f"Toks: {self.total_toks}/{self.total}") - - async def _process_row( - self, row: AddInferenceAttack.Model - ) -> AddInferenceAttack.Model: - 
self.total += 1 - # self.total_toks += row['total_toks'] - exec_acc = row.eval.acc - if exec_acc == 0: - print(f"#{row.idx}") - print(f"Q: {row.question}") - - self.score += exec_acc - # pre_score = row['pre_eval']['acc'] - # self.pre_score += pre_score - - # Uncomment and modify as needed for additional functionality - # masked_terms = row["symbolic"]["masked_terms"] - # gold_links = row["gold_links"] - # masks = 0 - # for q_term, _schema_item in gold_links.items(): - # for p_term in masked_terms: - # if similar(p_term, q_term): - # masks += 1 - # - # self.total_masks += len(masked_terms) - # self.masks += masks - # self.total_gold_masks += len(gold_links.keys()) - # - # if "attack" in row: - # guess = row["attack"] - # leakage = 0 - # leak_terms = [] - # for term in masked_terms: - # if term.lower() in guess.lower(): - # leakage += 1 - # leak_terms.append(term) - # self.leakage += leakage - - # Additional debugging output can be added here if needed - # print(f"MASKED: {row['symbolic']['masked']}") - # if "symbolic" in row: - # print(f"Masked Question: {row['symbolic']['question']}") - # print_color(f"Question: {row['question']}", "green") - # print_color(f"Gold: {row['query']}", "green") - # print(f"Pred: {row['pred_sql']}") - # print(f"Conc: {row['concrete_sql']}") - # print(f"Masked SQL: {row['symbolic']['sql']}") - # print(f"Schema Items: {row['schema_items']}") - # print(f"Schema Links: {row['schema_links']}") - # print(f"Filtered Schema Links: {row['filtered_schema_links']}") - # print(f"Value Links: {row['value_links']}") - # print(f"Filtered Value Links: {row['filtered_value_links']}") - # print("\n") - # print("RESULTS: ") - # if row['eval']['acc'] == 0: - # print_color(f"GOLD RES: {row['eval']['gold']}", "green") - # print_color(f"PRED RES: {row['eval']['pred']}", "red") - # print_color(f"PRED ERR: {row['eval']['pred_err']}", "red") - # print("\n") - # print("#" * 10) - # print(f"Schema:\n {row['schema']}") - # print("#" * 10) - # print("#" * 10) - # 
print(f"Symbolic Schema:\n {row['symbolic']['schema']}") - # print("#" * 10) - - return row diff --git a/src/pipe/repair_link_schema.py b/src/pipe/repair_link_schema.py deleted file mode 100644 index 4f4f2a2..0000000 --- a/src/pipe/repair_link_schema.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Schema link repair utilities.""" - -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.llm_util import extract_object -from src.pipe.schema_link_prompts.repair import REPAIR_SCHEMA_LINK_PROMPT_V1 -from src.utils.logging import logger - - -class RepairSchemaLinks(PromptProcessor): - """Repair and refine schema links based on validation.""" - - def _process_output(self, row: dict[str, Any], output: str) -> dict[str, Any]: - schema_links = extract_object(output) - question = row["question"] - schema_items = row["schema_items"] - refined_links: dict[str, Any] = {} - if isinstance(schema_links, (list, str)): - logger.error(f"Invalid schema links: {schema_links}") - refined_links = {} - - if schema_links is not None: - for question_term, schema_item in schema_links.items(): - if question_term not in question or schema_item not in schema_items: - logger.error( - f"Invalid schema link {question_term} -> {schema_item}" - ) - continue - refined_links[question_term] = schema_item - return refined_links - - def get_n_grams(self, text: str, n: int) -> list[list[str]]: - """ - Extract n-grams from text. 
- - Parameters - ---------- - text : str - Input text - n : int - Size of n-grams - - Returns - ------- - list - List of n-grams as word lists - """ - words = text.split(" ") - return [words[i : i + n] for i in range(len(words) - n + 1)] - - def _get_prompt(self, row: dict[str, Any]) -> str: - question = row["question"] - schema_items = row["schema_items"] - value_list = row["values"] - schema_links = row["schema_links"] - return REPAIR_SCHEMA_LINK_PROMPT_V1.format( - schema_items=schema_items, - question=question, - value_List=value_list, - schema_links=schema_links, - ) diff --git a/src/pipe/schema_filter_prompts/filter_annotated_links.py b/src/pipe/schema_filter_prompts/filter_annotated_links.py deleted file mode 100644 index 2945e54..0000000 --- a/src/pipe/schema_filter_prompts/filter_annotated_links.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Filter annotated schema links prompt.""" - -FILTER_ANNOTATED_LINKS = """ -You are an assistant that filters a given dictionary to retain items that are related to a set of concepts. - -You are given: - 1. A natural language question. - 2. A mapping (SchemaLinks) from n-grams in the question to relevant table or column names - in a database schema. - -Goal: -Return a filtered JSON object that contains only those key-value pairs from SchemaLinks that are -related to at least one the following concepts: -{concepts} - -Output Rules -- Do not add, alter, or rename keys or values. Only delete non-matching entries. -- Output valid JSON only: double quotes around all keys and string values; no trailing commas. -- If no entries match, return an empty JSON object. -- Do not include any additional text, explanations, or formatting. 
- -Example: -Question: “What is the name of the instructor who has the lowest salary?” -SchemaLinks: -{{ - "name": "COLUMN:[instructor].[name]", - "salary": "COLUMN:[instructor].[salary]", - "instructor": "TABLE:[instructor]" -}} - -Output: -{{ - "name": "COLUMN:[instructor].[name]", -}} - -Now filter the following SchemaLinks mapping based on the given question and concepts. -You should generate a valid JSON object. -Question: {question} -SchemaLinks: {schema_links} - -""" diff --git a/src/pipe/schema_items_filter_prompts/__init__.py b/src/pipe/schema_items_filter_prompts/__init__.py deleted file mode 100644 index 3e433ff..0000000 --- a/src/pipe/schema_items_filter_prompts/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Prompt templates for schema items filtering.""" diff --git a/src/pipe/schema_items_filter_prompts/v1.py b/src/pipe/schema_items_filter_prompts/v1.py deleted file mode 100644 index c86cbce..0000000 --- a/src/pipe/schema_items_filter_prompts/v1.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Schema items filter prompt template version 1.""" - -FILTER_SCHEMA_ITEMS_PROMPT_V1 = """ -You are an assistant that filters a list of a database schema items based on their relevance to a set of concepts. - -Input: - Schema Items: a list of table names and fully qualified column names. Each Schema Item is of the form - "TABLE:[table_name]" or "COLUMN:[table_name].[column_name]". - -Goal: -Return a filtered list of schema items that contains only those that are related to at least one the following concepts: -{concepts} - -Output Rules -- Do not add, alter, or rename any item. -- Output should be a valid JSON list. -- Do not include any additional text, explanations, or formatting. 
- -Example: -Schema Items: -[ - "COLUMN:[instructor].[name]", - "COLUMN:[instructor].[salary]", - "TABLE:[instructor]" -] - -Output: -[ - "name": "COLUMN:[instructor].[name]", -] - -Now filter the following Schema Items: -Schema Items: {schema_items} -""" diff --git a/src/pipe/schema_link_prompts/repair.py b/src/pipe/schema_link_prompts/repair.py deleted file mode 100644 index 853f140..0000000 --- a/src/pipe/schema_link_prompts/repair.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Schema link repair prompt utilities.""" - -REPAIR_SCHEMA_LINK_PROMPT_V1 = """ -You are an assistant that links n-grams (sub-sequences of up to 3 consecutive words) -of a natural-language question to database schema items (tables or fully qualified columns). -Your goal is to repair a given schema links. -Keys of the dictionary are n-grams in the question and values or schema items. -Your goal is to make sure that the given schema links is correct according to the following rules. - -You are given: -- Question: a natural language question. -- Schema Items: a list of schema items (table names or fully qualified column names). -- Value List: a list of n-grams in the question that represent literal values, entities, constants, etc in the question. -- Schema Links: a dictionary that maps schema n-grams of the question to the schema items - -Goal: -Iterate through each key,value of the mapping and verify: -- Each key s a n-gram of the question -- Each value is a schema item included in the provided schema items list -- Each key doesn't exist in the given value list -- Do not change the key-value if its is correct - -If you found any errors, fix the issues with the minimum change by replacing some words. -Return the repaired schema link mapping. 
- -Here are some examples: - ---------------------------------------------- -Example 1: -Question: -“What is the name of the instructor who has the lowest salary and located in London?” -Schema items: -["TABLE:[instructor]", "COLUMN:[instructor].[name]", "COLUMN:[instructor].[salary]", "TABLE:[department]", "COLUMN[department].[name]"] -Value List: -[ "London" ] -Schema Links: -{{ - "lowest salary": "COLUMN:[instructor].[salary]", - "who": "TABLE:[instructor]" -}} - -Output: -{{ - "name": "COLUMN:[instructor].[name]", - "salary": "COLUMN:[instructor].[salary]", - "instructor": "TABLE:[instructor]" -}} - -Now generate the repaired JSON object of mapping for the following question, schema items, value list, and schema links: -Question: {question} -Schema items: {schema_items} -Value List: {value_List} -Schema Links: {schema_links} -""" diff --git a/src/pipe/slm_mask.py b/src/pipe/slm_mask.py deleted file mode 100644 index d2f340a..0000000 --- a/src/pipe/slm_mask.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Small language model masking.""" - -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.gen_sql import extract_sql -from src.pipe.slm_mask_prompts.mask_v1 import SLM_MASK_PROMPT_V1 -from src.pipe.slm_mask_prompts.unmask_v1 import SLM_UNMASK_PROMPT_V1 - - -class SlmMask(PromptProcessor): - """Generate masked questions using small language model.""" - - def _process_output(self, row: dict[str, Any], output: str) -> dict[str, str]: - return {"raw": output} - - def _get_prompt(self, row: dict[str, Any]) -> str: - question = row["question"] - schema = row["schema"] - return SLM_MASK_PROMPT_V1.format(question=question, schema=schema) - - -class SlmUnmask(PromptProcessor): - """Unmask questions and generate SQL using small language model.""" - - def _process_output(self, row: dict[str, Any], output: str) -> str: - return extract_sql(output) - - def _get_prompt(self, row: dict[str, Any]) -> str: - question = 
row["question"] - schema = row["schema"] - masked_raw = row["symbolic"]["raw"] - masked_sql = row["symbolic"]["sql"] - return SLM_UNMASK_PROMPT_V1.format( - question=question, - schema=schema, - masked_raw=masked_raw, - masked_sql=masked_sql, - ) diff --git a/src/pipe/slm_mask_for_det_unmask.py b/src/pipe/slm_mask_for_det_unmask.py deleted file mode 100644 index 67305dd..0000000 --- a/src/pipe/slm_mask_for_det_unmask.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Deterministic unmasking for small language models.""" - -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.slm_mask_prompts.mask_and_schema_link_v2 import ( - SLM_MASK_AND_LINK_PROMPT_V2, -) - - -class SlmMaskWithSymbolTable(PromptProcessor): - """Generate masked questions with symbol table for deterministic unmasking.""" - - def _process_output(self, row: dict[str, Any], output: str) -> dict[str, str]: - return {"question": output} - - def _get_prompt(self, row: dict[str, Any]) -> str: - question = row["question"] - schema = row["schema"] - symbol_table = row["symbolic"]["to_symbol"] - value_links = row["value_links"] - return SLM_MASK_AND_LINK_PROMPT_V2.format( - question=question, - schema=schema, - symbol_table=symbol_table, - value_links=value_links, - ) diff --git a/src/pipe/slm_sql.py b/src/pipe/slm_sql.py deleted file mode 100644 index 93fffc1..0000000 --- a/src/pipe/slm_sql.py +++ /dev/null @@ -1,49 +0,0 @@ -"""SLM (Small Language Model) SQL generation module.""" - -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.gen_sql import extract_sql -from src.pipe.slm_sql_prompt.v1 import GENERATE_SQL_PROMPT_V1 - - -class SlmSQL(PromptProcessor): - """Generate SQL using a small language model. - - This class processes natural language questions and database schemas - to generate SQL queries using a small language model. 
- """ - - def _process_output(self, row: dict[str, Any], output: str) -> str: - """Process the LLM output to extract SQL. - - Parameters - ---------- - row : dict[str, Any] - The input row data. - output : str - The raw output from the language model. - - Returns - ------- - str - The extracted SQL query. - """ - return extract_sql(output) - - def _get_prompt(self, row: dict[str, Any]) -> str: - """Generate the prompt for SQL generation. - - Parameters - ---------- - row : dict[str, Any] - The input row containing 'question' and 'schema'. - - Returns - ------- - str - The formatted prompt for the language model. - """ - question = row["question"] - schema = row["schema"] - return GENERATE_SQL_PROMPT_V1.format(question=question, schema=schema) diff --git a/src/pipe/slm_sql_prompt/v1.py b/src/pipe/slm_sql_prompt/v1.py deleted file mode 100644 index f97b504..0000000 --- a/src/pipe/slm_sql_prompt/v1.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Small language model SQL generation prompt version 1.""" - -GENERATE_SQL_PROMPT_V1 = """ -You are a SQL generation assistant. Given -(1) NL Question: a natural-language question about a dataset and -(2) DB Schema: the database’s schema expressed in YAML -produce a single SQL SELECT statement that answers the question. -We also provided the immediate generated SQL query as your reference to find the related database items needed for the given question. -The provided SQL query is not always complete and accurate and you have to use it as a hint only. - -Input Format: -- DB Schema: given in YAML format where top-level keys are table names; each table lists its columns and their data types. -- Column names are case-sensitive exactly as shown in the schema. -- Each column might be primary key or a foreign key. -- For foreign key columns, fully qualified name of the referenced column is given - -Output Rules -- Table and column names specified in the database schema already wrapped in brackets. You should use them with the brackets. 
-You should not remove the brackets when using them in the SQL. -- Each reference to a table or column name should be of the form [table_name] or [table_name].[column_name]. -- Output ONLY the SQL query (no extra explanation or text). -- Use fully qualified column names: table.column. -- Only reference tables/columns that exist in the provided schema. -- Do not include any comments. -- For columns names with spaces, wrap them in backticks, e.g. "WHERE `car model` = 'bar'" instead of "WHERE car model = 'bar'". - -Here are some examples: - ------------------------------------ -Example 1: -NL Question: What is the release title of the single that was released by Ron Hunt in 1979 that was downloaded 239 times? release title refers to groupName; Ron Hunt is an artist; groupYear = 1979; releaseType = 'single'; downloaded 239 times refer to totalSnatched = 239; -Database Schema: - torrents: - groupname: text - artist: text - releasetype: text - groupyear: number - totalsnatched: number - tags: - tag: text - index: number - id: number - -OUTPUT: -SELECT [groupName] FROM [torrents] WHERE [artist] LIKE 'ron hunt & ronnie g & the sm crew' AND [groupYear] = 1979 AND [releaseType] LIKE 'single' AND [totalSnatched] = 239 - - ------------------------------------ -Example 2: -NL Question: How many times was the album released by blowfly in 1980 downloaded? 
blowfly is an artist; groupYear = 1980; album refers to releaseType; downloaded refers to totalSnatched; -Database Schema: - torrents: - groupname: text - artist: text - releasetype: text - groupyear: number - totalsnatched: number - tags: - tag: text - index: number - id: number - -OUTPUT: -SELECT [totalSnatched] FROM [torrents] WHERE [artist] LIKE 'blowfly' AND [groupYear] = 1980 - -Now, generate a SQL query for the following question and database schema: -Inputs: -question: {question} -schema: {schema} -""" diff --git a/src/pipe/slm_unmask_repair.py b/src/pipe/slm_unmask_repair.py deleted file mode 100644 index 7a95263..0000000 --- a/src/pipe/slm_unmask_repair.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Small language model unmasking and repair.""" - -from typing import Any - -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.gen_sql import extract_sql -from src.pipe.slm_mask_prompts.unmask_and_repair_v1 import ( - SLM_UNMASK_AND_REPAIR_PROMPT_V1, -) - - -class SlmUnmaskAndRepair(PromptProcessor): - """Unmask questions and repair SQL using small language model.""" - - def _process_output(self, row: dict[str, Any], output: str) -> str: - return extract_sql(output) - - def _get_prompt(self, row: dict[str, Any]) -> str: - question = row["question"] - schema = row["schema"] - masked_question = row["symbolic"]["question"] - masked_schema = row["symbolic"]["schema"] - masked_sql = row["symbolic"]["sql"] - return SLM_UNMASK_AND_REPAIR_PROMPT_V1.format( - question=question, - schema=schema, - masked_question=masked_question, - masked_schema=masked_schema, - masked_sql=masked_sql, - ) diff --git a/src/pipe/utils.py b/src/pipe/utils.py deleted file mode 100644 index 85b5a5e..0000000 --- a/src/pipe/utils.py +++ /dev/null @@ -1,151 +0,0 @@ -"""General utility functions for pipeline processing.""" - -import re -from datetime import datetime - -from src.utils.logging import logger - - -def replace_str(text: str, src: str, dst: str) -> str: - 
""" - Replace a substring in text with word boundaries. - - Parameters - ---------- - text : str - The text to search in - src : str - The substring to replace - dst : str - The replacement substring - - Returns - ------- - str - Text with replacements made - """ - try: - result = re.sub( - r"\b{}\b".format(re.escape(src)), dst, text, flags=re.IGNORECASE - ) - except Exception: - logger.error(f"Failed to replace {src} -> {dst} in {text}") - result = text - return result - - -def check_str(text: str, src: str) -> bool: - """ - Check if a substring exists in text with word boundaries. - - Parameters - ---------- - text : str - The text to search in - src : str - The substring to search for - - Returns - ------- - bool - True if substring found with word boundaries, False otherwise - """ - try: - pattern = r"\b{}\b".format(re.escape(src)) - if re.search(pattern, text, flags=re.IGNORECASE): - return True - except Exception: - logger.error(f"Failed to search {src} in {text}") - return False - - -def replace_str_punc(text: str, src: str, dst: str) -> str: - """ - Replace a substring in text with punctuation-aware boundaries. - - Parameters - ---------- - text : str - The text to search in - src : str - The substring to replace - dst : str - The replacement substring - - Returns - ------- - str - Text with replacements made - """ - try: - result = re.sub( - r"(? {dst} in {text}") - result = text - return result - - -def check_str_punc(text: str, src: str) -> bool: - """ - Check if a substring exists in text with punctuation-aware boundaries. - - Parameters - ---------- - text : str - The text to search in - src : str - The substring to search for - - Returns - ------- - bool - True if substring found with punctuation boundaries, False otherwise - """ - try: - pattern = r"(? None: - self.start_time = datetime.now() - - @staticmethod - def start() -> "Timer": - """ - Create and start a new timer. 
- - Returns - ------- - Timer - A new timer instance - """ - return Timer() - - def lap(self) -> float: - """ - Get elapsed time since timer started. - - Returns - ------- - float - Elapsed time in seconds - """ - return (datetime.now() - self.start_time).total_seconds() diff --git a/src/pipe/__init__.py b/src/pipeline/__init__.py similarity index 100% rename from src/pipe/__init__.py rename to src/pipeline/__init__.py diff --git a/src/pipe/add_schema.py b/src/pipeline/add_schema.py similarity index 93% rename from src/pipe/add_schema.py rename to src/pipeline/add_schema.py index ec0a034..6f1c7b4 100644 --- a/src/pipe/add_schema.py +++ b/src/pipeline/add_schema.py @@ -1,8 +1,8 @@ """Module for adding database schema information to data rows.""" -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.rank_schema import RankSchemaResd -from src.pipe.schema_repo import DatabaseSchema, DatabaseSchemaRepo +from src.pipeline.base_processor.list_processor import JsonListProcessor +from src.pipeline.rank_schema import RankSchemaResd +from src.utils.schema_repo import DatabaseSchema, DatabaseSchemaRepo def filter_schema(schema: DatabaseSchema, schema_items: list[str]) -> DatabaseSchema: diff --git a/src/pipe/add_symb_schema.py b/src/pipeline/add_symb_schema.py similarity index 96% rename from src/pipe/add_symb_schema.py rename to src/pipeline/add_symb_schema.py index b9605c5..246677f 100644 --- a/src/pipe/add_symb_schema.py +++ b/src/pipeline/add_symb_schema.py @@ -2,10 +2,10 @@ from typing import Any -from src.pipe.link_schema import FilterSchemaLinksModel -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.schema_repo import DatabaseSchema, DatabaseSchemaRepo -from src.pipe.symb_table import SymbolTableDicts +from src.pipeline.base_processor.list_processor import JsonListProcessor +from src.pipeline.link_schema.link_schema import FilterSchemaLinksModel +from src.pipeline.symb_table import SymbolTableDicts +from 
src.utils.schema_repo import DatabaseSchema, DatabaseSchemaRepo class SymbolicSchema(SymbolTableDicts): diff --git a/src/pipe/det_mask.py b/src/pipeline/add_symbolic_question.py similarity index 97% rename from src/pipe/det_mask.py rename to src/pipeline/add_symbolic_question.py index 84dae25..d83759b 100644 --- a/src/pipe/det_mask.py +++ b/src/pipeline/add_symbolic_question.py @@ -2,9 +2,9 @@ import logging -from src.pipe.add_symb_schema import AddSymbolicSchema, SymbolicSchema -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.utils import replace_str +from src.pipeline.add_symb_schema import AddSymbolicSchema, SymbolicSchema +from src.pipeline.base_processor.list_processor import JsonListProcessor +from src.utils.strings import replace_str logger = logging.getLogger(__name__) diff --git a/src/pipe/attack_prompts/__init__.py b/src/pipeline/attack/__init__.py similarity index 100% rename from src/pipe/attack_prompts/__init__.py rename to src/pipeline/attack/__init__.py diff --git a/src/pipe/attack.py b/src/pipeline/attack/add_inference_attack.py similarity index 65% rename from src/pipe/attack.py rename to src/pipeline/attack/add_inference_attack.py index bb7bc94..8134084 100644 --- a/src/pipe/attack.py +++ b/src/pipeline/attack/add_inference_attack.py @@ -3,11 +3,10 @@ from typing import Any from src.config import OpenAIConfig -from src.pipe.attack_prompts.attack_raw_v1 import ATTACK_PROMPT_RAW_V1 -from src.pipe.attack_prompts.attack_v2 import ATTACK_PROMPT_V2 -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.exec_acc import CalcExecAcc -from src.pipe.repair_sql import RepairSQL +from src.pipeline.attack.prompts.attack_v2 import ATTACK_PROMPT_V2 +from src.pipeline.base_processor.prompt_processor import PromptProcessor +from src.pipeline.exec_acc import CalcExecAcc +from src.pipeline.repair_sql.repair_sql import RepairSQL # TODO[X]: ask to infer all tokens not only in question evidence is being 
ignored @@ -49,18 +48,3 @@ def _get_prompt(self, row: RepairSQL.Model) -> str: return ATTACK_PROMPT_V2.format( question=symbolic_question, schema=symbolic_schema ) - - -class AttackRaw(PromptProcessor): - """ - Processor for testing attacks on raw symbolic data. - - Tests inference attacks on raw symbolic representations without schema context. - """ - - def _process_output(self, row: dict[str, Any], output: str) -> str: - return output - - def _get_prompt(self, row: dict[str, Any]) -> str: - symbolic_raw = row["symbolic"]["raw"] - return ATTACK_PROMPT_RAW_V1.format(symbolic_raw=symbolic_raw) diff --git a/src/pipeline/attack/prompts/__init__.py b/src/pipeline/attack/prompts/__init__.py new file mode 100644 index 0000000..e0022fb --- /dev/null +++ b/src/pipeline/attack/prompts/__init__.py @@ -0,0 +1,4 @@ +"""Package containing prompt templates for the attack module. + +This package provides prompt templates used by the inference attack components. +""" diff --git a/src/pipe/attack_prompts/attack_v1.py b/src/pipeline/attack/prompts/attack_v1.py similarity index 97% rename from src/pipe/attack_prompts/attack_v1.py rename to src/pipeline/attack/prompts/attack_v1.py index 3ddfff5..f179a36 100644 --- a/src/pipe/attack_prompts/attack_v1.py +++ b/src/pipeline/attack/prompts/attack_v1.py @@ -1,4 +1,4 @@ -"""Attack prompt template version 1.""" +"""Attack prompts template version 1.""" ATTACK_PROMPT_V1 = """ Your goal is to guess the words in a masked question. Given diff --git a/src/pipe/attack_prompts/attack_v2.py b/src/pipeline/attack/prompts/attack_v2.py similarity index 97% rename from src/pipe/attack_prompts/attack_v2.py rename to src/pipeline/attack/prompts/attack_v2.py index 0abece6..32eebf8 100644 --- a/src/pipe/attack_prompts/attack_v2.py +++ b/src/pipeline/attack/prompts/attack_v2.py @@ -1,4 +1,4 @@ -"""Attack prompt template version 2.""" +"""Attack prompts template version 2.""" ATTACK_PROMPT_V2 = """ Your goal is to guess the words in a masked text. 
Given diff --git a/src/pipe/processor/__init__.py b/src/pipeline/base_processor/__init__.py similarity index 100% rename from src/pipe/processor/__init__.py rename to src/pipeline/base_processor/__init__.py diff --git a/src/pipe/processor/limit_list.py b/src/pipeline/base_processor/limit_list.py similarity index 85% rename from src/pipe/processor/limit_list.py rename to src/pipeline/base_processor/limit_list.py index eabead8..0166a97 100644 --- a/src/pipe/processor/limit_list.py +++ b/src/pipeline/base_processor/limit_list.py @@ -1,9 +1,9 @@ -"""List length limiting processor.""" +"""List length limiting base_processor.""" import os -from src.models.masksql_input import MaskSqlInput -from src.pipe.processor.list_processor import JsonListProcessor +from src.data_models.masksql_input import MaskSqlInput +from src.pipeline.base_processor.list_processor import JsonListProcessor START = int(os.environ.get("START", "0")) diff --git a/src/pipe/processor/list_processor.py b/src/pipeline/base_processor/list_processor.py similarity index 94% rename from src/pipe/processor/list_processor.py rename to src/pipeline/base_processor/list_processor.py index 1117dd6..bb06477 100644 --- a/src/pipe/processor/list_processor.py +++ b/src/pipeline/base_processor/list_processor.py @@ -5,8 +5,8 @@ from typing import Generic, Type, TypeVar from src.data_cache.json_cache import JsonCache -from src.models.base_object import BaseObject -from src.pipe.async_utils import apply_async +from src.data_models.base_object import BaseObject +from src.utils.async_utils import apply_async T = TypeVar("T", bound=BaseObject) @@ -74,12 +74,12 @@ async def _process_row(self, row: T) -> U: @property def name(self) -> str: """ - Get processor name. + Get base_processor name. 
Returns ------- str - Class name of processor + Class name of base_processor """ return self.__class__.__name__ diff --git a/src/pipe/processor/list_transformer.py b/src/pipeline/base_processor/list_transformer.py similarity index 79% rename from src/pipe/processor/list_transformer.py rename to src/pipeline/base_processor/list_transformer.py index 214fe93..a899a76 100644 --- a/src/pipe/processor/list_transformer.py +++ b/src/pipeline/base_processor/list_transformer.py @@ -4,8 +4,8 @@ import os from abc import ABC -from src.models.base_object import BaseObject -from src.pipe.processor.list_processor import JsonListProcessor +from src.data_models.base_object import BaseObject +from src.pipeline.base_processor.list_processor import JsonListProcessor logger = logging.getLogger(__name__) diff --git a/src/pipeline/base_processor/print_results.py b/src/pipeline/base_processor/print_results.py new file mode 100644 index 0000000..c479b3c --- /dev/null +++ b/src/pipeline/base_processor/print_results.py @@ -0,0 +1,22 @@ +"""Results printing base_processor.""" + + +def print_color(text: str, color: str = "green") -> None: + """ + Print text with ANSI color codes. 
+ + Parameters + ---------- + text : str + Text to print + color : str, optional + Color name (red, green, blue), default green + """ + colors = { + "red": "\033[91m", + "green": "\033[92m", + "blue": "\033[94m", + } + reset = "\033[0m" + color_code = colors.get(color.lower(), "") + print(f"{color_code}{text}{reset}") diff --git a/src/pipe/processor/printer.py b/src/pipeline/base_processor/printer.py similarity index 81% rename from src/pipe/processor/printer.py rename to src/pipeline/base_processor/printer.py index 0e5cfa1..b20a678 100644 --- a/src/pipe/processor/printer.py +++ b/src/pipeline/base_processor/printer.py @@ -3,8 +3,8 @@ from abc import ABC from collections.abc import Callable -from src.models.masksql_input import MaskSqlInput -from src.pipe.processor.list_processor import JsonListProcessor +from src.data_models.masksql_input import MaskSqlInput +from src.pipeline.base_processor.list_processor import JsonListProcessor class LambdaPrinter(JsonListProcessor[MaskSqlInput, MaskSqlInput], ABC): diff --git a/src/pipe/detect_values_prompts/prompt_processor.py b/src/pipeline/base_processor/prompt_processor.py similarity index 82% rename from src/pipe/detect_values_prompts/prompt_processor.py rename to src/pipeline/base_processor/prompt_processor.py index 52df0f1..ee5c9e6 100644 --- a/src/pipe/detect_values_prompts/prompt_processor.py +++ b/src/pipeline/base_processor/prompt_processor.py @@ -1,15 +1,15 @@ -"""Base processor for LLM-based value detection.""" +"""Base base_processor for LLM-based value detection.""" from abc import ABC, abstractmethod from json import JSONDecodeError from typing import Any, Generic, Type, TypeVar from src.config import OpenAIConfig -from src.pipe.llm_util import send_prompt -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.util_processors import InitData -from src.pipe.utils import Timer +from src.pipeline.base_processor.list_processor import JsonListProcessor +from src.pipeline.init_data import 
InitData +from src.utils.llm_util import send_prompt from src.utils.logging import logger +from src.utils.timer import Timer T = TypeVar("T", bound=InitData.Model) @@ -18,10 +18,10 @@ class PromptProcessor(JsonListProcessor[T, U], ABC, Generic[T, U]): """ - Base processor for LLM-based prompt processing. + Base base_processor for LLM-based prompts processing. This abstract class provides common functionality for processors that use - LLM prompts to transform data rows, including prompt logging and statistics + LLM prompts to transform data rows, including prompts logging and statistics tracking. Parameters @@ -53,10 +53,10 @@ async def _prompt_llm(self, row: T, prompt: str) -> tuple[Any, str]: try: res, toks = await send_prompt(prompt, self.openai_config, model=self.model) except JSONDecodeError as e: - logger.error(f"Sending prompt failed: {e}") + logger.error(f"Sending prompts failed: {e}") return "", "0" except Exception as e: - logger.error(f"Sending prompt failed: {e}") + logger.error(f"Sending prompts failed: {e}") raise e processed_res = self._process_output(row, res) return processed_res, toks diff --git a/src/pipe/processor/prop_printer.py b/src/pipeline/base_processor/prop_printer.py similarity index 93% rename from src/pipe/processor/prop_printer.py rename to src/pipeline/base_processor/prop_printer.py index ababaec..5310cbc 100644 --- a/src/pipe/processor/prop_printer.py +++ b/src/pipeline/base_processor/prop_printer.py @@ -2,7 +2,7 @@ from typing import Any -from src.pipe.processor.list_processor import JsonListProcessor +from src.pipeline.base_processor.list_processor import JsonListProcessor class PrintProps(JsonListProcessor[Any, Any]): diff --git a/src/pipe/detect_values_prompts/__init__.py b/src/pipeline/detect_values/__init__.py similarity index 100% rename from src/pipe/detect_values_prompts/__init__.py rename to src/pipeline/detect_values/__init__.py diff --git a/src/pipe/detect_entities.py b/src/pipeline/detect_values/detect_values.py 
similarity index 78% rename from src/pipe/detect_entities.py rename to src/pipeline/detect_values/detect_values.py index 77ae5fe..c19363d 100644 --- a/src/pipe/detect_entities.py +++ b/src/pipeline/detect_values/detect_values.py @@ -1,10 +1,10 @@ """Entity detection in natural language questions.""" from src.config import OpenAIConfig -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor, T -from src.pipe.detect_values_prompts.v3 import DETECT_VALUES_PROMPT_V3 -from src.pipe.llm_util import extract_object -from src.pipe.symb_table import AddSymbolTable +from src.pipeline.base_processor.prompt_processor import PromptProcessor +from src.pipeline.detect_values.prompts.v3 import DETECT_VALUES_PROMPT_V3 +from src.pipeline.symb_table import AddSymbolTable +from src.utils.llm_util import extract_object class DetectValues(PromptProcessor[AddSymbolTable.Model, "DetectValues.Model"]): @@ -28,7 +28,7 @@ def _get_result_data( def __init__(self, openai_config: OpenAIConfig, model: str) -> None: super().__init__(self.Model, openai_config, model) - def _process_output(self, row: T, output: str) -> list[str]: + def _process_output(self, row: AddSymbolTable.Model, output: str) -> list[str]: obj = extract_object(output) if obj is None: return [] diff --git a/src/pipeline/detect_values/prompts/__init__.py b/src/pipeline/detect_values/prompts/__init__.py new file mode 100644 index 0000000..d199b61 --- /dev/null +++ b/src/pipeline/detect_values/prompts/__init__.py @@ -0,0 +1,5 @@ +"""Package containing prompt templates for the detect_values module. + +This package provides prompt templates used for detecting values in +natural language queries. 
+""" diff --git a/src/pipe/detect_values_prompts/v1.py b/src/pipeline/detect_values/prompts/v1.py similarity index 95% rename from src/pipe/detect_values_prompts/v1.py rename to src/pipeline/detect_values/prompts/v1.py index daf809f..82ae470 100644 --- a/src/pipe/detect_values_prompts/v1.py +++ b/src/pipeline/detect_values/prompts/v1.py @@ -1,4 +1,4 @@ -"""Value detection prompt template version 1.""" +"""Value detection prompts template version 1.""" PROMPT = """ You are given: diff --git a/src/pipe/detect_values_prompts/v2.py b/src/pipeline/detect_values/prompts/v2.py similarity index 97% rename from src/pipe/detect_values_prompts/v2.py rename to src/pipeline/detect_values/prompts/v2.py index a402ce6..c4ebb80 100644 --- a/src/pipe/detect_values_prompts/v2.py +++ b/src/pipeline/detect_values/prompts/v2.py @@ -1,4 +1,4 @@ -"""Value detection prompt template version 2.""" +"""Value detection prompts template version 2.""" DETECT_VALUES_PROMPT_V2 = """ You are given: diff --git a/src/pipe/detect_values_prompts/v3.py b/src/pipeline/detect_values/prompts/v3.py similarity index 96% rename from src/pipe/detect_values_prompts/v3.py rename to src/pipeline/detect_values/prompts/v3.py index 271f7f4..85885ee 100644 --- a/src/pipe/detect_values_prompts/v3.py +++ b/src/pipeline/detect_values/prompts/v3.py @@ -1,4 +1,4 @@ -"""Value detection prompt template version 3.""" +"""Value detection prompts template version 3.""" DETECT_VALUES_PROMPT_V3 = """ You are given a natural language question and a list of schema items diff --git a/src/pipe/exec_acc.py b/src/pipeline/exec_acc.py similarity index 92% rename from src/pipe/exec_acc.py rename to src/pipeline/exec_acc.py index 9261db6..adfac60 100644 --- a/src/pipe/exec_acc.py +++ b/src/pipeline/exec_acc.py @@ -4,9 +4,9 @@ from pydantic import BaseModel -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.repair_sql import RepairSQL -from src.pipe.sqlite_facade import SqliteFacade +from 
src.pipeline.base_processor.list_processor import JsonListProcessor +from src.pipeline.repair_sql.repair_sql import RepairSQL +from src.utils.sqlite_facade import SqliteFacade class EvaluationData(BaseModel): diff --git a/src/pipe/exec_conc_sql.py b/src/pipeline/exec_conc_sql.py similarity index 95% rename from src/pipe/exec_conc_sql.py rename to src/pipeline/exec_conc_sql.py index c6cbc04..634a877 100644 --- a/src/pipe/exec_conc_sql.py +++ b/src/pipeline/exec_conc_sql.py @@ -4,10 +4,10 @@ from pydantic import BaseModel -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.sqlite_facade import SqliteFacade -from src.pipe.unmask import AddConcreteSql +from src.pipeline.base_processor.list_processor import JsonListProcessor +from src.pipeline.unmask import AddConcreteSql from src.utils.logging import logger +from src.utils.sqlite_facade import SqliteFacade class PreEvaluation(BaseModel): diff --git a/src/pipe/schema_filter_prompts/__init__.py b/src/pipeline/filter_schema_links/__init__.py similarity index 100% rename from src/pipe/schema_filter_prompts/__init__.py rename to src/pipeline/filter_schema_links/__init__.py diff --git a/src/pipeline/filter_schema_links/filter_schema_links.py b/src/pipeline/filter_schema_links/filter_schema_links.py new file mode 100644 index 0000000..03b211a --- /dev/null +++ b/src/pipeline/filter_schema_links/filter_schema_links.py @@ -0,0 +1,36 @@ +"""Filter schema links based on relevance.""" + +from src.pipeline.base_processor.prompt_processor import PromptProcessor +from src.pipeline.filter_schema_links.prompts.v2 import FILTER_SCHEMA_LINKS_PROMPT_V2 +from src.pipeline.link_schema.link_schema import FilterSchemaLinksModel, LinkSchema +from src.utils.llm_util import extract_object + + +CONCEPTS = ["Person's name", "Location", "Occupation"] + + +class FilterSchemaLinks(PromptProcessor[LinkSchema.Model, FilterSchemaLinksModel]): + """ + Filter schema links based on relevance to predefined concepts. 
+ + Uses LLM prompts to determine which schema links (mappings from question + terms to schema items) are relevant to specific concepts. + """ + + def _get_result_data( + self, row: LinkSchema.Model, llm_output: dict[str, str] + ) -> FilterSchemaLinksModel: + return FilterSchemaLinksModel(filtered_schema_links=llm_output, **row.dict()) + + def _process_output(self, row: LinkSchema.Model, output: str) -> dict[str, str]: + obj = extract_object(output) + if obj is None: + return {} + return obj + + def _get_prompt(self, row: LinkSchema.Model) -> str: + schema_links = row.schema_links + question = row.question + return FILTER_SCHEMA_LINKS_PROMPT_V2.format( + concepts=CONCEPTS, question=question, schema_links=schema_links + ) diff --git a/src/pipeline/filter_schema_links/prompts/__init__.py b/src/pipeline/filter_schema_links/prompts/__init__.py new file mode 100644 index 0000000..8dfd53e --- /dev/null +++ b/src/pipeline/filter_schema_links/prompts/__init__.py @@ -0,0 +1,5 @@ +"""Package containing prompt templates for the filter_schema_links module. + +This package provides prompt templates used for filtering schema links +in database queries. 
+""" diff --git a/src/pipe/schema_filter_prompts/v1.py b/src/pipeline/filter_schema_links/prompts/v1.py similarity index 94% rename from src/pipe/schema_filter_prompts/v1.py rename to src/pipeline/filter_schema_links/prompts/v1.py index 8a6b397..230b574 100644 --- a/src/pipe/schema_filter_prompts/v1.py +++ b/src/pipeline/filter_schema_links/prompts/v1.py @@ -1,4 +1,4 @@ -"""Schema filter prompt template version 1.""" +"""Schema filter prompts template version 1.""" FILTER_SCHEMA_LINKS_PROMPT_V1 = """ You are given: diff --git a/src/pipe/schema_filter_prompts/v2.py b/src/pipeline/filter_schema_links/prompts/v2.py similarity index 96% rename from src/pipe/schema_filter_prompts/v2.py rename to src/pipeline/filter_schema_links/prompts/v2.py index 244a92a..b5b6747 100644 --- a/src/pipe/schema_filter_prompts/v2.py +++ b/src/pipeline/filter_schema_links/prompts/v2.py @@ -1,4 +1,4 @@ -"""Schema filter prompt template version 2.""" +"""Schema filter prompts template version 2.""" FILTER_SCHEMA_LINKS_PROMPT_V2 = """ You are an assistant that filters a given dictionary to retain items that are related to a set of concepts. 
diff --git a/src/pipe/value_filter_prompts/__init__.py b/src/pipeline/filter_value_links/__init__.py similarity index 100% rename from src/pipe/value_filter_prompts/__init__.py rename to src/pipeline/filter_value_links/__init__.py diff --git a/src/pipeline/filter_value_links/filter_value_links.py b/src/pipeline/filter_value_links/filter_value_links.py new file mode 100644 index 0000000..6d5a04f --- /dev/null +++ b/src/pipeline/filter_value_links/filter_value_links.py @@ -0,0 +1,34 @@ +"""Filter value links based on relevance.""" + +from src.pipeline.base_processor.prompt_processor import PromptProcessor +from src.pipeline.filter_schema_links.filter_schema_links import CONCEPTS +from src.pipeline.filter_value_links.prompts.v1 import VALUE_LINKS_FILTER_PROMPT_V1 +from src.pipeline.link_values.link_values import FilterValueLinksModel, LinkValues +from src.utils.llm_util import extract_object + + +class FilterValueLinks(PromptProcessor[LinkValues.Model, FilterValueLinksModel]): + """ + Filter value links based on relevance to predefined concepts. + + Uses LLM prompts to determine which value links (mappings from question + values to database columns) are relevant to specific concepts. 
+ """ + + def _get_result_data( + self, row: LinkValues.Model, llm_output: dict[str, str] + ) -> FilterValueLinksModel: + return FilterValueLinksModel(filtered_value_links=llm_output, **row.dict()) + + def _process_output(self, row: LinkValues.Model, output: str) -> dict[str, str]: + result = extract_object(output) + if result is None: + return {} + return result + + def _get_prompt(self, row: LinkValues.Model) -> str: + question = row.question + value_links = row.values + return VALUE_LINKS_FILTER_PROMPT_V1.format( + concepts=CONCEPTS, question=question, value_links=value_links + ) diff --git a/src/pipeline/filter_value_links/prompts/__init__.py b/src/pipeline/filter_value_links/prompts/__init__.py new file mode 100644 index 0000000..4d4314a --- /dev/null +++ b/src/pipeline/filter_value_links/prompts/__init__.py @@ -0,0 +1,5 @@ +"""Package containing prompt templates for the filter_value_links module. + +This package provides prompt templates used for filtering value links +in database queries. 
+""" diff --git a/src/pipe/value_filter_prompts/v1.py b/src/pipeline/filter_value_links/prompts/v1.py similarity index 97% rename from src/pipe/value_filter_prompts/v1.py rename to src/pipeline/filter_value_links/prompts/v1.py index e87d099..80d284d 100644 --- a/src/pipe/value_filter_prompts/v1.py +++ b/src/pipeline/filter_value_links/prompts/v1.py @@ -1,4 +1,4 @@ -"""Value filtering prompt template version 1.""" +"""Value filtering prompts template version 1.""" VALUE_LINKS_FILTER_PROMPT_V1 = """ You are given: diff --git a/src/pipe/value_filter_prompts/v2.py b/src/pipeline/filter_value_links/prompts/v2.py similarity index 96% rename from src/pipe/value_filter_prompts/v2.py rename to src/pipeline/filter_value_links/prompts/v2.py index 0095847..7bf1fcb 100644 --- a/src/pipe/value_filter_prompts/v2.py +++ b/src/pipeline/filter_value_links/prompts/v2.py @@ -1,4 +1,4 @@ -"""Value filtering prompt template version 2.""" +"""Value filtering prompts template version 2.""" FILTER_SCHEMA_LINKS_PROMPT_V2 = """ You are an assistant that filters a given dictionary to retain items that are related to a set of concepts. 
diff --git a/src/pipe/sql_gen_prompts/__init__.py b/src/pipeline/gen_sql/__init__.py similarity index 100% rename from src/pipe/sql_gen_prompts/__init__.py rename to src/pipeline/gen_sql/__init__.py diff --git a/src/pipe/gen_masked_sql.py b/src/pipeline/gen_sql/gen_masked_sql.py similarity index 84% rename from src/pipe/gen_masked_sql.py rename to src/pipeline/gen_sql/gen_masked_sql.py index 9923ee6..cc05b88 100644 --- a/src/pipe/gen_masked_sql.py +++ b/src/pipeline/gen_sql/gen_masked_sql.py @@ -3,10 +3,10 @@ from typing import Any from src.config import OpenAIConfig -from src.pipe.det_mask import AddSymbolicQuestion, SymbolicQuestion -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.gen_sql import extract_sql -from src.pipe.sql_gen_prompts.masked_v4 import MASKED_GEN_SQL_PROMPT_V4 +from src.pipeline.add_symbolic_question import AddSymbolicQuestion, SymbolicQuestion +from src.pipeline.base_processor.prompt_processor import PromptProcessor +from src.pipeline.gen_sql.prompts.masked_v4 import MASKED_GEN_SQL_PROMPT_V4 +from src.utils.strings import extract_sql DATA_DIR = "data" diff --git a/src/pipeline/gen_sql/prompts/__init__.py b/src/pipeline/gen_sql/prompts/__init__.py new file mode 100644 index 0000000..1f64eff --- /dev/null +++ b/src/pipeline/gen_sql/prompts/__init__.py @@ -0,0 +1,5 @@ +"""Package containing prompt templates for the gen_sql module. + +This package provides prompt templates used for generating +SQL queries from natural language. 
+""" diff --git a/src/pipe/sql_gen_prompts/masked_v1.py b/src/pipeline/gen_sql/prompts/masked_v1.py similarity index 98% rename from src/pipe/sql_gen_prompts/masked_v1.py rename to src/pipeline/gen_sql/prompts/masked_v1.py index 5cd7183..7de3f6c 100644 --- a/src/pipe/sql_gen_prompts/masked_v1.py +++ b/src/pipeline/gen_sql/prompts/masked_v1.py @@ -1,4 +1,4 @@ -"""Masked SQL generation prompt template version 1.""" +"""Masked SQL generation prompts template version 1.""" GEN_SQL_PROMPT_V1 = """ I'll give you a natural language question and the schema of the underlying database. diff --git a/src/pipe/sql_gen_prompts/masked_v2.py b/src/pipeline/gen_sql/prompts/masked_v2.py similarity index 97% rename from src/pipe/sql_gen_prompts/masked_v2.py rename to src/pipeline/gen_sql/prompts/masked_v2.py index fa35381..a538c49 100644 --- a/src/pipe/sql_gen_prompts/masked_v2.py +++ b/src/pipeline/gen_sql/prompts/masked_v2.py @@ -1,4 +1,4 @@ -"""Masked SQL generation prompt template version 2.""" +"""Masked SQL generation prompts template version 2.""" MASKED_GEN_SQL_PROMPT_V2 = """ I'll give you a natural language question and the schema of the underlying database diff --git a/src/pipe/sql_gen_prompts/masked_v3.py b/src/pipeline/gen_sql/prompts/masked_v3.py similarity index 97% rename from src/pipe/sql_gen_prompts/masked_v3.py rename to src/pipeline/gen_sql/prompts/masked_v3.py index fcda265..13cd49b 100644 --- a/src/pipe/sql_gen_prompts/masked_v3.py +++ b/src/pipeline/gen_sql/prompts/masked_v3.py @@ -1,4 +1,4 @@ -"""Masked SQL generation prompt template version 3.""" +"""Masked SQL generation prompts template version 3.""" MASKED_GEN_SQL_PROMPT_V3 = """ You are a SQL generation assistant. 
Given diff --git a/src/pipe/sql_gen_prompts/masked_v3_raw.py b/src/pipeline/gen_sql/prompts/masked_v3_raw.py similarity index 97% rename from src/pipe/sql_gen_prompts/masked_v3_raw.py rename to src/pipeline/gen_sql/prompts/masked_v3_raw.py index 3b730fa..f35398b 100644 --- a/src/pipe/sql_gen_prompts/masked_v3_raw.py +++ b/src/pipeline/gen_sql/prompts/masked_v3_raw.py @@ -1,4 +1,4 @@ -"""Raw masked SQL generation prompt template version 3.""" +"""Raw masked SQL generation prompts template version 3.""" MASKED_GEN_SQL_RAW_PROMPT_V3 = """ You are a SQL generation assistant. Given diff --git a/src/pipe/sql_gen_prompts/masked_v4.py b/src/pipeline/gen_sql/prompts/masked_v4.py similarity index 98% rename from src/pipe/sql_gen_prompts/masked_v4.py rename to src/pipeline/gen_sql/prompts/masked_v4.py index 1e1990b..1ee9c52 100644 --- a/src/pipe/sql_gen_prompts/masked_v4.py +++ b/src/pipeline/gen_sql/prompts/masked_v4.py @@ -1,4 +1,4 @@ -"""Masked SQL generation prompt template version 4.""" +"""Masked SQL generation prompts template version 4.""" MASKED_GEN_SQL_PROMPT_V4 = """ You are a SQL generation assistant. Given diff --git a/src/pipe/sql_gen_prompts/unmasked_v1.py b/src/pipeline/gen_sql/prompts/unmasked_v1.py similarity index 96% rename from src/pipe/sql_gen_prompts/unmasked_v1.py rename to src/pipeline/gen_sql/prompts/unmasked_v1.py index 7019def..3159f96 100644 --- a/src/pipe/sql_gen_prompts/unmasked_v1.py +++ b/src/pipeline/gen_sql/prompts/unmasked_v1.py @@ -1,4 +1,4 @@ -"""Unmasked SQL generation prompt template version 1.""" +"""Unmasked SQL generation prompts template version 1.""" GEN_UNMASKED_SQL_PROMPT_V1 = """ I'll give you a natural language question and the schema of the underlying database. 
diff --git a/src/pipe/util_processors.py b/src/pipeline/init_data.py similarity index 85% rename from src/pipe/util_processors.py rename to src/pipeline/init_data.py index 8ca683c..7776c26 100644 --- a/src/pipe/util_processors.py +++ b/src/pipeline/init_data.py @@ -1,7 +1,7 @@ """Utility processors for MaskSQL pipeline.""" -from src.models.masksql_input import MaskSqlInput -from src.pipe.processor.list_processor import JsonListProcessor +from src.data_models.masksql_input import MaskSqlInput +from src.pipeline.base_processor.list_processor import JsonListProcessor class InitData(JsonListProcessor[MaskSqlInput, "InitData.Model"]): diff --git a/src/pipe/schema_link_prompts/__init__.py b/src/pipeline/link_schema/__init__.py similarity index 100% rename from src/pipe/schema_link_prompts/__init__.py rename to src/pipeline/link_schema/__init__.py diff --git a/src/pipe/link_schema.py b/src/pipeline/link_schema/link_schema.py similarity index 90% rename from src/pipe/link_schema.py rename to src/pipeline/link_schema/link_schema.py index 56af553..f2b9803 100644 --- a/src/pipe/link_schema.py +++ b/src/pipeline/link_schema/link_schema.py @@ -1,10 +1,10 @@ """Schema linking from questions to database schemas.""" from src.config import OpenAIConfig -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.llm_util import extract_object -from src.pipe.schema_link_prompts.v4 import SCHEMA_LINK_PROMPT_V4 -from src.pipe.value_links import FilterValueLinksModel +from src.pipeline.base_processor.prompt_processor import PromptProcessor +from src.pipeline.link_schema.prompts.v4 import SCHEMA_LINK_PROMPT_V4 +from src.pipeline.link_values.link_values import FilterValueLinksModel +from src.utils.llm_util import extract_object from src.utils.logging import logger diff --git a/src/pipeline/link_schema/prompts/__init__.py b/src/pipeline/link_schema/prompts/__init__.py new file mode 100644 index 0000000..3d2d814 --- /dev/null +++ 
b/src/pipeline/link_schema/prompts/__init__.py @@ -0,0 +1,5 @@ +"""Package containing prompt templates for the link_schema module. + +This package provides prompt templates used for linking schema +elements in database queries. +""" diff --git a/src/pipe/schema_link_prompts/v1.py b/src/pipeline/link_schema/prompts/v1.py similarity index 94% rename from src/pipe/schema_link_prompts/v1.py rename to src/pipeline/link_schema/prompts/v1.py index ab4a613..aa49d85 100644 --- a/src/pipe/schema_link_prompts/v1.py +++ b/src/pipeline/link_schema/prompts/v1.py @@ -1,4 +1,4 @@ -"""Schema linking prompt template version 1.""" +"""Schema linking prompts template version 1.""" SCHEMA_LINK_PROMPT_V1 = """ Consider the following question: diff --git a/src/pipe/schema_link_prompts/v2.py b/src/pipeline/link_schema/prompts/v2.py similarity index 92% rename from src/pipe/schema_link_prompts/v2.py rename to src/pipeline/link_schema/prompts/v2.py index 4adcf6d..a4c2490 100644 --- a/src/pipe/schema_link_prompts/v2.py +++ b/src/pipeline/link_schema/prompts/v2.py @@ -1,4 +1,4 @@ -"""Schema linking prompt template version 2.""" +"""Schema linking prompts template version 2.""" SCHEMA_LINK_PROMPT_V2 = """ Consider the following question: diff --git a/src/pipe/schema_link_prompts/v3.py b/src/pipeline/link_schema/prompts/v3.py similarity index 95% rename from src/pipe/schema_link_prompts/v3.py rename to src/pipeline/link_schema/prompts/v3.py index 8ea6c0b..7e41f94 100644 --- a/src/pipe/schema_link_prompts/v3.py +++ b/src/pipeline/link_schema/prompts/v3.py @@ -1,4 +1,4 @@ -"""Schema linking prompt template version 3.""" +"""Schema linking prompts template version 3.""" SCHEMA_LINK_PROMPT_V3 = """ You are given a natural language question and a list of schema items diff --git a/src/pipe/schema_link_prompts/v4.py b/src/pipeline/link_schema/prompts/v4.py similarity index 98% rename from src/pipe/schema_link_prompts/v4.py rename to src/pipeline/link_schema/prompts/v4.py index a1bca3b..d3b2db3 100644 
--- a/src/pipe/schema_link_prompts/v4.py +++ b/src/pipeline/link_schema/prompts/v4.py @@ -1,4 +1,4 @@ -"""Schema linking prompt template version 4.""" +"""Schema linking prompts template version 4.""" SCHEMA_LINK_PROMPT_V4 = """ You are an assistant that links n-grams (sub-sequences of up to 3 consecutive words) diff --git a/src/pipe/schema_link_prompts/v5.py b/src/pipeline/link_schema/prompts/v5.py similarity index 97% rename from src/pipe/schema_link_prompts/v5.py rename to src/pipeline/link_schema/prompts/v5.py index 81d4c30..2e684b8 100644 --- a/src/pipe/schema_link_prompts/v5.py +++ b/src/pipeline/link_schema/prompts/v5.py @@ -1,4 +1,4 @@ -"""Schema linking prompt template version 5.""" +"""Schema linking prompts template version 5.""" SCHEMA_LINK_PROMPT_V4 = """ You are an assistant that links n-grams (sub-sequences of up to 3 consecutive words) diff --git a/src/pipe/value_linking_prompts/__init__.py b/src/pipeline/link_values/__init__.py similarity index 100% rename from src/pipe/value_linking_prompts/__init__.py rename to src/pipeline/link_values/__init__.py diff --git a/src/pipe/value_links.py b/src/pipeline/link_values/link_values.py similarity index 76% rename from src/pipe/value_links.py rename to src/pipeline/link_values/link_values.py index b1b0a41..2e40e77 100644 --- a/src/pipe/value_links.py +++ b/src/pipeline/link_values/link_values.py @@ -1,10 +1,10 @@ """Value link data structures and utilities.""" from src.config import OpenAIConfig -from src.pipe.detect_entities import DetectValues -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor, T -from src.pipe.llm_util import extract_object -from src.pipe.value_linking_prompts.v1 import VALUE_LINKING_PROMPT_V1 +from src.pipeline.base_processor.prompt_processor import PromptProcessor +from src.pipeline.detect_values.detect_values import DetectValues +from src.pipeline.link_values.prompts.v1 import VALUE_LINKING_PROMPT_V1 +from src.utils.llm_util import extract_object class 
LinkValues(PromptProcessor[DetectValues.Model, "LinkValues.Model"]): @@ -27,10 +27,12 @@ class Model(DetectValues.Model): def __init__(self, openai_config: OpenAIConfig, model: str) -> None: super().__init__(self.Model, openai_config, model) - def _get_result_data(self, row: T, llm_processed_output: dict[str, str]) -> Model: + def _get_result_data( + self, row: DetectValues.Model, llm_processed_output: dict[str, str] + ) -> Model: return self.Model(value_links=llm_processed_output, **row.dict()) - def _process_output(self, row: T, output: str) -> dict[str, str]: + def _process_output(self, row: DetectValues.Model, output: str) -> dict[str, str]: obj = extract_object(output) if obj is None: return {} diff --git a/src/pipeline/link_values/prompts/__init__.py b/src/pipeline/link_values/prompts/__init__.py new file mode 100644 index 0000000..3fa83e2 --- /dev/null +++ b/src/pipeline/link_values/prompts/__init__.py @@ -0,0 +1,5 @@ +"""Package containing prompt templates for the link_values module. + +This package provides prompt templates used for linking values +in database queries. 
+""" diff --git a/src/pipe/value_linking_prompts/v1.py b/src/pipeline/link_values/prompts/v1.py similarity index 98% rename from src/pipe/value_linking_prompts/v1.py rename to src/pipeline/link_values/prompts/v1.py index 6f58913..a609d96 100644 --- a/src/pipe/value_linking_prompts/v1.py +++ b/src/pipeline/link_values/prompts/v1.py @@ -1,4 +1,4 @@ -"""Value linking prompt template version 1.""" +"""Value linking prompts template version 1.""" VALUE_LINKING_PROMPT_V1 = """ You are given: diff --git a/src/pipe/pipeline.py b/src/pipeline/pipeline.py similarity index 92% rename from src/pipe/pipeline.py rename to src/pipeline/pipeline.py index 6b6af65..0bca29c 100644 --- a/src/pipe/pipeline.py +++ b/src/pipeline/pipeline.py @@ -3,16 +3,16 @@ import os from typing import Any, Generic, TypeVar -from src.models.base_object import BaseObject -from src.pipe.monitor.lib import Timer -from src.pipe.monitor.mem import track_memory_async -from src.pipe.processor.list_processor import JsonListProcessor +from src.data_models.base_object import BaseObject +from src.pipeline.base_processor.list_processor import JsonListProcessor from src.utils.logging import ( log_pipeline_summary, log_stage_complete, log_stage_start, reset_stage_timings, ) +from src.utils.mem import track_memory_async +from src.utils.timer import Timer T = TypeVar("T", bound=BaseObject) diff --git a/src/pipe/rank_schema.py b/src/pipeline/rank_schema.py similarity index 90% rename from src/pipe/rank_schema.py rename to src/pipeline/rank_schema.py index ced9e7c..47ad593 100644 --- a/src/pipe/rank_schema.py +++ b/src/pipeline/rank_schema.py @@ -1,8 +1,8 @@ """Schema ranking utilities.""" -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.resdsql import AddResd -from src.pipe.schema_repo import DatabaseSchemaRepo +from src.pipeline.base_processor.list_processor import JsonListProcessor +from src.pipeline.resd.add_resd import AddResd +from src.utils.schema_repo import DatabaseSchemaRepo 
class RankSchemaResd(JsonListProcessor[AddResd.Model, "RankSchemaResd.Model"]): diff --git a/src/pipe/rank_schema_llm.py b/src/pipeline/rank_schema_llm.py similarity index 85% rename from src/pipe/rank_schema_llm.py rename to src/pipeline/rank_schema_llm.py index e74d223..dc326d5 100644 --- a/src/pipe/rank_schema_llm.py +++ b/src/pipeline/rank_schema_llm.py @@ -3,12 +3,12 @@ from typing import Any from src.config import OpenAIConfig -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.llm_util import extract_object -from src.pipe.rank_schema import RankSchemaResd -from src.pipe.rank_schema_prompts.v1 import RANK_SCHEMA_ITEMS_V1 -from src.pipe.schema_repo import DatabaseSchemaRepo -from src.pipe.util_processors import InitData +from src.pipeline.base_processor.prompt_processor import PromptProcessor +from src.pipeline.init_data import InitData +from src.pipeline.rank_schema import RankSchemaResd +from src.pipeline.rank_schema_prompts.v1 import RANK_SCHEMA_ITEMS_V1 +from src.utils.llm_util import extract_object +from src.utils.schema_repo import DatabaseSchemaRepo class RankSchemaItems(PromptProcessor[InitData.Model, RankSchemaResd.Model]): diff --git a/src/pipe/rank_schema_prompts/__init__.py b/src/pipeline/rank_schema_prompts/__init__.py similarity index 100% rename from src/pipe/rank_schema_prompts/__init__.py rename to src/pipeline/rank_schema_prompts/__init__.py diff --git a/src/pipe/rank_schema_prompts/v1.py b/src/pipeline/rank_schema_prompts/v1.py similarity index 95% rename from src/pipe/rank_schema_prompts/v1.py rename to src/pipeline/rank_schema_prompts/v1.py index 9843298..288f4c3 100644 --- a/src/pipe/rank_schema_prompts/v1.py +++ b/src/pipeline/rank_schema_prompts/v1.py @@ -1,4 +1,4 @@ -"""Schema ranking prompt template version 1.""" +"""Schema ranking prompts template version 1.""" RANK_SCHEMA_ITEMS_V1 = """ You are given: diff --git a/src/pipe/sql_repair_prompts/__init__.py b/src/pipeline/repair_sql/__init__.py 
similarity index 100% rename from src/pipe/sql_repair_prompts/__init__.py rename to src/pipeline/repair_sql/__init__.py diff --git a/src/pipeline/repair_sql/prompts/__init__.py b/src/pipeline/repair_sql/prompts/__init__.py new file mode 100644 index 0000000..83fc7e6 --- /dev/null +++ b/src/pipeline/repair_sql/prompts/__init__.py @@ -0,0 +1,5 @@ +"""Package containing prompt templates for the repair_sql module. + +This package provides prompt templates used for repairing and +correcting SQL queries. +""" diff --git a/src/pipe/sql_repair_prompts/v1.py b/src/pipeline/repair_sql/prompts/v1.py similarity index 90% rename from src/pipe/sql_repair_prompts/v1.py rename to src/pipeline/repair_sql/prompts/v1.py index e33ba3a..d670f50 100644 --- a/src/pipe/sql_repair_prompts/v1.py +++ b/src/pipeline/repair_sql/prompts/v1.py @@ -1,4 +1,4 @@ -"""SQL repair prompt template version 1.""" +"""SQL repair prompts template version 1.""" REPAIR_SQL_PROMPT_V1 = """ I'll give you a natural language question, the schema of the underlying database, and a candidate SQL query. diff --git a/src/pipe/sql_repair_prompts/v2.py b/src/pipeline/repair_sql/prompts/v2.py similarity index 95% rename from src/pipe/sql_repair_prompts/v2.py rename to src/pipeline/repair_sql/prompts/v2.py index 36dbfd9..a6fc1ef 100644 --- a/src/pipe/sql_repair_prompts/v2.py +++ b/src/pipeline/repair_sql/prompts/v2.py @@ -1,4 +1,4 @@ -"""SQL repair prompt template version 2.""" +"""SQL repair prompts template version 2.""" REPAIR_SQL_PROMPT_V2 = """ You are a SQL repair assistant. 
Given a natural language question, a database schema, and a candidate SQL query, diff --git a/src/pipe/sql_repair_prompts/v3.py b/src/pipeline/repair_sql/prompts/v3.py similarity index 99% rename from src/pipe/sql_repair_prompts/v3.py rename to src/pipeline/repair_sql/prompts/v3.py index 72043e1..f1490ff 100644 --- a/src/pipe/sql_repair_prompts/v3.py +++ b/src/pipeline/repair_sql/prompts/v3.py @@ -1,4 +1,4 @@ -"""SQL repair prompt template version 3.""" +"""SQL repair prompts template version 3.""" REPAIR_SQL_PROMPT_V3 = """ You are an SQL database expert tasked with correcting a SQL query. diff --git a/src/pipe/sql_repair_prompts/v4.py b/src/pipeline/repair_sql/prompts/v4.py similarity index 99% rename from src/pipe/sql_repair_prompts/v4.py rename to src/pipeline/repair_sql/prompts/v4.py index e586d0b..f836f27 100644 --- a/src/pipe/sql_repair_prompts/v4.py +++ b/src/pipeline/repair_sql/prompts/v4.py @@ -1,4 +1,4 @@ -"""SQL repair prompt template version 4.""" +"""SQL repair prompts template version 4.""" REPAIR_SQL_PROMPT_V4 = """ You are an SQL database expert tasked with correcting a SQL query that corresponds to a natural language question. diff --git a/src/pipe/sql_repair_prompts/v5.py b/src/pipeline/repair_sql/prompts/v5.py similarity index 98% rename from src/pipe/sql_repair_prompts/v5.py rename to src/pipeline/repair_sql/prompts/v5.py index fc96403..e7e8375 100644 --- a/src/pipe/sql_repair_prompts/v5.py +++ b/src/pipeline/repair_sql/prompts/v5.py @@ -1,4 +1,4 @@ -"""SQL repair prompt template version 5.""" +"""SQL repair prompts template version 5.""" REPAIR_SQL_PROMPT_V5 = """ You are an expert SQL database assistant. 
diff --git a/src/pipe/repair_sql.py b/src/pipeline/repair_sql/repair_sql.py similarity index 87% rename from src/pipe/repair_sql.py rename to src/pipeline/repair_sql/repair_sql.py index 8ec263a..f0a152f 100644 --- a/src/pipe/repair_sql.py +++ b/src/pipeline/repair_sql/repair_sql.py @@ -3,11 +3,11 @@ from typing import Any from src.config import OpenAIConfig -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.exec_conc_sql import ExecuteConcreteSql -from src.pipe.gen_sql import extract_sql -from src.pipe.sql_repair_prompts.v3 import REPAIR_SQL_PROMPT_V3 +from src.pipeline.base_processor.prompt_processor import PromptProcessor +from src.pipeline.exec_conc_sql import ExecuteConcreteSql +from src.pipeline.repair_sql.prompts.v3 import REPAIR_SQL_PROMPT_V3 from src.utils.logging import logger +from src.utils.strings import extract_sql class RepairSQL(PromptProcessor[ExecuteConcreteSql.Model, "RepairSQL.Model"]): diff --git a/src/pipe/symb_sql_repair_prompts/__init__.py b/src/pipeline/repair_symb_sql/__init__.py similarity index 100% rename from src/pipe/symb_sql_repair_prompts/__init__.py rename to src/pipeline/repair_symb_sql/__init__.py diff --git a/src/pipeline/repair_symb_sql/prompts/__init__.py b/src/pipeline/repair_symb_sql/prompts/__init__.py new file mode 100644 index 0000000..4f561cb --- /dev/null +++ b/src/pipeline/repair_symb_sql/prompts/__init__.py @@ -0,0 +1,5 @@ +"""Package containing prompt templates for the repair_symb_sql module. + +This package provides prompt templates used for repairing and +correcting symbolic SQL queries. 
+""" diff --git a/src/pipe/symb_sql_repair_prompts/v1.py b/src/pipeline/repair_symb_sql/prompts/v1.py similarity index 98% rename from src/pipe/symb_sql_repair_prompts/v1.py rename to src/pipeline/repair_symb_sql/prompts/v1.py index ed4460f..677f569 100644 --- a/src/pipe/symb_sql_repair_prompts/v1.py +++ b/src/pipeline/repair_symb_sql/prompts/v1.py @@ -1,4 +1,4 @@ -"""Symbolic SQL repair prompt template version 1.""" +"""Symbolic SQL repair prompts template version 1.""" REPAIR_SYMBOLIC_SQL_PROMPT_V1 = """ You are an SQL database expert tasked with correcting a SQL query. diff --git a/src/pipe/symb_sql_repair_prompts/v2.py b/src/pipeline/repair_symb_sql/prompts/v2.py similarity index 98% rename from src/pipe/symb_sql_repair_prompts/v2.py rename to src/pipeline/repair_symb_sql/prompts/v2.py index 8f44649..10a864c 100644 --- a/src/pipe/symb_sql_repair_prompts/v2.py +++ b/src/pipeline/repair_symb_sql/prompts/v2.py @@ -1,4 +1,4 @@ -"""Symbolic SQL repair prompt template version 2.""" +"""Symbolic SQL repair prompts template version 2.""" REPAIR_SYMBOLIC_SQL_PROMPT_V2 = """ You are an SQL database expert tasked with debugging a SQL query. diff --git a/src/pipe/symb_sql_repair_prompts/raw_v2.py b/src/pipeline/repair_symb_sql/raw_v2.py similarity index 98% rename from src/pipe/symb_sql_repair_prompts/raw_v2.py rename to src/pipeline/repair_symb_sql/raw_v2.py index 9c51986..b916173 100644 --- a/src/pipe/symb_sql_repair_prompts/raw_v2.py +++ b/src/pipeline/repair_symb_sql/raw_v2.py @@ -1,4 +1,4 @@ -"""Raw symbolic SQL repair prompt template version 2.""" +"""Raw symbolic SQL repair prompts template version 2.""" REPAIR_SYMBOLIC_SQL_RAW_PROMPT_V2 = """ You are an SQL database expert tasked with debugging a SQL query. 
diff --git a/src/pipe/repair_symb_sql.py b/src/pipeline/repair_symb_sql/repair_symb_sql.py similarity index 63% rename from src/pipe/repair_symb_sql.py rename to src/pipeline/repair_symb_sql/repair_symb_sql.py index ec2dcc8..84c809e 100644 --- a/src/pipe/repair_symb_sql.py +++ b/src/pipeline/repair_symb_sql/repair_symb_sql.py @@ -3,11 +3,10 @@ from typing import Any from src.config import OpenAIConfig -from src.pipe.detect_values_prompts.prompt_processor import PromptProcessor -from src.pipe.gen_masked_sql import GenerateSymbolicSql, SymbolicSql -from src.pipe.gen_sql import extract_sql -from src.pipe.symb_sql_repair_prompts.raw_v2 import REPAIR_SYMBOLIC_SQL_RAW_PROMPT_V2 -from src.pipe.symb_sql_repair_prompts.v2 import REPAIR_SYMBOLIC_SQL_PROMPT_V2 +from src.pipeline.base_processor.prompt_processor import PromptProcessor +from src.pipeline.gen_sql.gen_masked_sql import GenerateSymbolicSql, SymbolicSql +from src.pipeline.repair_symb_sql.prompts.v2 import REPAIR_SYMBOLIC_SQL_PROMPT_V2 +from src.utils.strings import extract_sql class RepairedSymbolicSql(SymbolicSql): @@ -50,18 +49,3 @@ def _get_prompt(self, row: "GenerateSymbolicSql.Model") -> str: return REPAIR_SYMBOLIC_SQL_PROMPT_V2.format( question=symbolic_question, schema=symbolic_schema, sql=symbolic_sql ) - - -class RepairSymbolicSQLRaw(PromptProcessor): - """Repair symbolic SQL from raw inputs without schema.""" - - def _process_output(self, row: dict[str, Any], output: str) -> dict[str, str]: - sql = extract_sql(output) - return {"repaired_sql": sql} - - def _get_prompt(self, row: dict[str, Any]) -> str: - symbolic_raw = row["symbolic"]["raw"] - symbolic_sql = row["symbolic"]["sql"] - return REPAIR_SYMBOLIC_SQL_RAW_PROMPT_V2.format( - symbolic_raw=symbolic_raw, sql=symbolic_sql - ) diff --git a/src/pipeline/resd/__init__.py b/src/pipeline/resd/__init__.py new file mode 100644 index 0000000..f0b5a28 --- /dev/null +++ b/src/pipeline/resd/__init__.py @@ -0,0 +1,5 @@ +"""Package for RESidual Disambiguation 
(RESD) functionality. + +This package provides components for handling residual +disambiguation in SQL generation. +""" diff --git a/src/pipe/resdsql.py b/src/pipeline/resd/add_resd.py similarity index 83% rename from src/pipe/resdsql.py rename to src/pipeline/resd/add_resd.py index f5876a6..4afa8f1 100644 --- a/src/pipe/resdsql.py +++ b/src/pipeline/resd/add_resd.py @@ -1,8 +1,8 @@ """RESDSQL model integration for SQL generation.""" -from src.models.masksql_input import MaskSqlInput -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.util_processors import InitData +from src.data_models.masksql_input import MaskSqlInput +from src.pipeline.base_processor.list_processor import JsonListProcessor +from src.pipeline.init_data import InitData from src.utils.json_io import read_json_raw @@ -24,7 +24,7 @@ class Model(InitData.Model): def __init__(self, resd_path: str) -> None: super().__init__(self.Model, force=True) self.resd_path = resd_path - self.resd : list[dict] = [] + self.resd: list[dict] = [] def _pre_run(self) -> None: """Load RESDSQL predictions before processing rows.""" diff --git a/src/pipe/run_resdsql.py b/src/pipeline/resd/run_resdsql.py similarity index 97% rename from src/pipe/run_resdsql.py rename to src/pipeline/resd/run_resdsql.py index 4123148..53d81ad 100644 --- a/src/pipe/run_resdsql.py +++ b/src/pipeline/resd/run_resdsql.py @@ -6,8 +6,8 @@ from pathlib import Path from typing import Any -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.util_processors import InitData +from src.pipeline.base_processor.list_processor import JsonListProcessor +from src.pipeline.init_data import InitData from src.utils.json_io import write_json from src.utils.logging import console, log_error, log_success, logger @@ -209,7 +209,7 @@ def _classify_schema_items(self) -> None: console.print( f"[bold cyan] → Schema classification[/bold cyan] [dim]({self.device})[/dim]" ) - model_path = 
Path("resdsql/models/text2sql_schema_item_classifier").absolute() + model_path = Path("resdsql/data_models/text2sql_schema_item_classifier").absolute() self._run_step( "schema_item_classifier", diff --git a/src/pipe/results.py b/src/pipeline/results.py similarity index 95% rename from src/pipe/results.py rename to src/pipeline/results.py index 8fe61c9..f5a8da5 100644 --- a/src/pipe/results.py +++ b/src/pipeline/results.py @@ -5,8 +5,8 @@ import pandas as pd from rich.table import Table -from src.pipe.attack import AddInferenceAttack -from src.pipe.processor.list_processor import JsonListProcessor +from src.pipeline.attack.add_inference_attack import AddInferenceAttack +from src.pipeline.base_processor.list_processor import JsonListProcessor from src.utils.logging import console diff --git a/src/pipe/slm_mask_prompts/__init__.py b/src/pipeline/slm_mask_prompts/__init__.py similarity index 100% rename from src/pipe/slm_mask_prompts/__init__.py rename to src/pipeline/slm_mask_prompts/__init__.py diff --git a/src/pipe/slm_mask_prompts/mask_and_schema_link_v1.py b/src/pipeline/slm_mask_prompts/mask_and_schema_link_v1.py similarity index 96% rename from src/pipe/slm_mask_prompts/mask_and_schema_link_v1.py rename to src/pipeline/slm_mask_prompts/mask_and_schema_link_v1.py index cbc3dc8..19ff98f 100644 --- a/src/pipe/slm_mask_prompts/mask_and_schema_link_v1.py +++ b/src/pipeline/slm_mask_prompts/mask_and_schema_link_v1.py @@ -1,4 +1,4 @@ -"""Combined masking and schema linking prompt version 1.""" +"""Combined masking and schema linking prompts version 1.""" SLM_MASK_AND_LINK_PROMPT_V1 = """ You are a database expert. 
diff --git a/src/pipe/slm_mask_prompts/mask_and_schema_link_v2.py b/src/pipeline/slm_mask_prompts/mask_and_schema_link_v2.py similarity index 97% rename from src/pipe/slm_mask_prompts/mask_and_schema_link_v2.py rename to src/pipeline/slm_mask_prompts/mask_and_schema_link_v2.py index d6ab970..081eb7d 100644 --- a/src/pipe/slm_mask_prompts/mask_and_schema_link_v2.py +++ b/src/pipeline/slm_mask_prompts/mask_and_schema_link_v2.py @@ -1,4 +1,4 @@ -"""Combined masking and schema linking prompt version 2.""" +"""Combined masking and schema linking prompts version 2.""" SLM_MASK_AND_LINK_PROMPT_V2 = """ You are a database expert. diff --git a/src/pipe/slm_mask_prompts/mask_v1.py b/src/pipeline/slm_mask_prompts/mask_v1.py similarity index 96% rename from src/pipe/slm_mask_prompts/mask_v1.py rename to src/pipeline/slm_mask_prompts/mask_v1.py index b822b5a..7044087 100644 --- a/src/pipe/slm_mask_prompts/mask_v1.py +++ b/src/pipeline/slm_mask_prompts/mask_v1.py @@ -1,4 +1,4 @@ -"""Small language model mask prompt version 1.""" +"""Small language model mask prompts version 1.""" SLM_MASK_PROMPT_V1 = """ You are a database expert. 
You are given: diff --git a/src/pipe/slm_mask_prompts/unmask_and_repair_v1.py b/src/pipeline/slm_mask_prompts/unmask_and_repair_v1.py similarity index 98% rename from src/pipe/slm_mask_prompts/unmask_and_repair_v1.py rename to src/pipeline/slm_mask_prompts/unmask_and_repair_v1.py index 2362ddb..c9abc1f 100644 --- a/src/pipe/slm_mask_prompts/unmask_and_repair_v1.py +++ b/src/pipeline/slm_mask_prompts/unmask_and_repair_v1.py @@ -1,4 +1,4 @@ -"""Combined unmasking and repair prompt version 1.""" +"""Combined unmasking and repair prompts version 1.""" SLM_UNMASK_AND_REPAIR_PROMPT_V1 = """ Your task is to restore a complete and correct SQL query from a masked version, diff --git a/src/pipe/slm_mask_prompts/unmask_v1.py b/src/pipeline/slm_mask_prompts/unmask_v1.py similarity index 97% rename from src/pipe/slm_mask_prompts/unmask_v1.py rename to src/pipeline/slm_mask_prompts/unmask_v1.py index 3d61853..a716544 100644 --- a/src/pipe/slm_mask_prompts/unmask_v1.py +++ b/src/pipeline/slm_mask_prompts/unmask_v1.py @@ -1,4 +1,4 @@ -"""Small language model unmask prompt version 1.""" +"""Small language model unmask prompts version 1.""" SLM_UNMASK_PROMPT_V1 = """ You are a database expert. 
You are given diff --git a/src/pipe/symb_table.py b/src/pipeline/symb_table.py similarity index 93% rename from src/pipe/symb_table.py rename to src/pipeline/symb_table.py index 01e28e2..c1a24eb 100644 --- a/src/pipe/symb_table.py +++ b/src/pipeline/symb_table.py @@ -2,9 +2,9 @@ from pydantic import BaseModel -from src.pipe.add_schema import AddFilteredSchema -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.schema_repo import DatabaseSchema, DatabaseSchemaRepo +from src.pipeline.add_schema import AddFilteredSchema +from src.pipeline.base_processor.list_processor import JsonListProcessor +from src.utils.schema_repo import DatabaseSchema, DatabaseSchemaRepo class SymbolTableDicts(BaseModel): diff --git a/src/pipe/unmask.py b/src/pipeline/unmask.py similarity index 92% rename from src/pipe/unmask.py rename to src/pipeline/unmask.py index 8e30da2..46c9419 100644 --- a/src/pipe/unmask.py +++ b/src/pipeline/unmask.py @@ -2,9 +2,9 @@ import re -from src.pipe.processor.list_processor import JsonListProcessor -from src.pipe.repair_symb_sql import RepairSymbolicSQL -from src.pipe.utils import replace_str_punc +from src.pipeline.base_processor.list_processor import JsonListProcessor +from src.pipeline.repair_symb_sql.repair_symb_sql import RepairSymbolicSQL +from src.utils.strings import replace_str_punc class AddConcreteSql( diff --git a/src/pipeline/util_processors/__init__.py b/src/pipeline/util_processors/__init__.py new file mode 100644 index 0000000..0f527b1 --- /dev/null +++ b/src/pipeline/util_processors/__init__.py @@ -0,0 +1,5 @@ +"""Package containing utility processors for the pipeline. + +This package provides various utility processors and +transformers used across the pipeline. 
+""" diff --git a/src/pipe/copy_transformer.py b/src/pipeline/util_processors/copy_transformer.py similarity index 89% rename from src/pipe/copy_transformer.py rename to src/pipeline/util_processors/copy_transformer.py index 763cf51..d6c40a6 100644 --- a/src/pipe/copy_transformer.py +++ b/src/pipeline/util_processors/copy_transformer.py @@ -2,7 +2,7 @@ from typing import Type -from src.pipe.processor.list_processor import JsonListProcessor, T, U +from src.pipeline.base_processor.list_processor import JsonListProcessor, T, U class CopyTransformer(JsonListProcessor[T, U]): diff --git a/src/pipe/async_utils.py b/src/utils/async_utils.py similarity index 100% rename from src/pipe/async_utils.py rename to src/utils/async_utils.py diff --git a/src/pipe/llm_util.py b/src/utils/llm_util.py similarity index 91% rename from src/pipe/llm_util.py rename to src/utils/llm_util.py index 07d1262..b6aa868 100644 --- a/src/pipe/llm_util.py +++ b/src/utils/llm_util.py @@ -1,4 +1,4 @@ -"""Utilities for working with language models.""" +"""Utilities for working with language data_models.""" import ast import json @@ -26,20 +26,20 @@ def wrap_prompt(prompt: str) -> str: """ - Wrap prompt with architecture-specific formatting. + Wrap prompts with architecture-specific formatting. Parameters ---------- prompt : str - Raw prompt text + Raw prompts text Returns ------- str - Formatted prompt for specific model architecture + Formatted prompts for specific model architecture """ if VLM_ARCH in wrappers: - print("Wrapping prompt for", VLM_ARCH) + print("Wrapping prompts for", VLM_ARCH) return wrappers[VLM_ARCH](prompt) return prompt @@ -48,7 +48,7 @@ async def send_prompt( prompt: str, openai_config: OpenAIConfig, model: str ) -> tuple[str, str]: """ - Send prompt to language model and get response. + Send prompts to language model and get response. 
Parameters ---------- @@ -90,7 +90,7 @@ async def send_prompt( ) if response.choices is None: print(prompt) - raise Exception(f"LM prompt failed: {response.model_extra}") + raise Exception(f"LM prompts failed: {response.model_extra}") usage = "0" if response.usage: usage = str(response.usage.total_tokens) diff --git a/src/pipe/monitor/mem.py b/src/utils/mem.py similarity index 100% rename from src/pipe/monitor/mem.py rename to src/utils/mem.py diff --git a/src/pipe/schema_repo.py b/src/utils/schema_repo.py similarity index 100% rename from src/pipe/schema_repo.py rename to src/utils/schema_repo.py diff --git a/src/pipe/sqlite_facade.py b/src/utils/sqlite_facade.py similarity index 100% rename from src/pipe/sqlite_facade.py rename to src/utils/sqlite_facade.py diff --git a/src/utils/strings.py b/src/utils/strings.py index b15f434..4d1637b 100644 --- a/src/utils/strings.py +++ b/src/utils/strings.py @@ -4,6 +4,8 @@ from difflib import SequenceMatcher from enum import Enum +from src.utils.logging import logger + def delete_whitespace(content: str) -> str: """ @@ -210,3 +212,154 @@ def get_colored_diff(a: str, b: str) -> str: if tag == "replace": result += colored(b[j1:j2], Color.BLUE) return result + + +def replace_str(text: str, src: str, dst: str) -> str: + """ + Replace a substring in text with word boundaries. + + Parameters + ---------- + text : str + The text to search in + src : str + The substring to replace + dst : str + The replacement substring + + Returns + ------- + str + Text with replacements made + """ + try: + result = re.sub( + r"\b{}\b".format(re.escape(src)), dst, text, flags=re.IGNORECASE + ) + except Exception: + logger.error(f"Failed to replace {src} -> {dst} in {text}") + result = text + return result + + +def check_str(text: str, src: str) -> bool: + """ + Check if a substring exists in text with word boundaries. 
+ + Parameters + ---------- + text : str + The text to search in + src : str + The substring to search for + + Returns + ------- + bool + True if substring found with word boundaries, False otherwise + """ + try: + pattern = r"\b{}\b".format(re.escape(src)) + if re.search(pattern, text, flags=re.IGNORECASE): + return True + except Exception: + logger.error(f"Failed to search {src} in {text}") + return False + + +def replace_str_punc(text: str, src: str, dst: str) -> str: + """ + Replace a substring in text with punctuation-aware boundaries. + + Parameters + ---------- + text : str + The text to search in + src : str + The substring to replace + dst : str + The replacement substring + + Returns + ------- + str + Text with replacements made + """ + try: + result = re.sub( + r"(? {dst} in {text}") + result = text + return result + + +def check_str_punc(text: str, src: str) -> bool: + """ + Check if a substring exists in text with punctuation-aware boundaries. + + Parameters + ---------- + text : str + The text to search in + src : str + The substring to search for + + Returns + ------- + bool + True if substring found with punctuation boundaries, False otherwise + """ + try: + pattern = r"(? str: + """ + Extract SQL query from LLM output. 
+ + Parameters + ---------- + output : str + Raw LLM output containing SQL + + Returns + ------- + str + Extracted SQL query + """ + output = output.strip() + output = output.strip('"') + sql = "SELECT" + if output.startswith("SELECT"): + sql = output + elif "```sql" in output: + res = re.findall(r"```sql([\s\S]*?)```", output) + if res: + sql = res[0] + else: + logger.error( + f"Failed to extract sql from output with ```sql marker: {output}" + ) + elif "```" in output: + res = re.findall(r"```([\s\S]*?)```", output) + if res: + sql = res[0] + else: + logger.error(f"Failed to extract sql from output with ``` marker: {output}") + elif "`" in output: + res = re.findall(r"`([\s\S]*?)`", output) + if res: + sql = res[0] + else: + logger.error(f"Failed to extract sql from output with ` marker: {output}") + else: + logger.error(f"Failed to extract sql from output: {output}") + sql = sql.strip() + return sql.replace("\n", " ") diff --git a/src/utils/timer.py b/src/utils/timer.py new file mode 100644 index 0000000..b682a78 --- /dev/null +++ b/src/utils/timer.py @@ -0,0 +1,45 @@ +"""Utility module for time measurement and tracking. + +This module provides a simple Timer class for measuring elapsed time in seconds. +""" + +from datetime import datetime + + +class Timer: + """ + Simple timer for measuring elapsed time. + + Attributes + ---------- + start_time : datetime + The time when the timer was created + """ + + start_time: datetime + + def __init__(self) -> None: + self.start_time = datetime.now() + + @staticmethod + def start() -> "Timer": + """ + Create and start a new timer. + + Returns + ------- + Timer + A new timer instance + """ + return Timer() + + def lap(self) -> float: + """ + Get elapsed time since timer started. 
+ + Returns + ------- + float + Elapsed time in seconds + """ + return (datetime.now() - self.start_time).total_seconds() diff --git a/test.json b/test.json new file mode 100644 index 0000000..ff4c037 --- /dev/null +++ b/test.json @@ -0,0 +1,6 @@ +[ + { + "idx": "i1", + "a": 1 + } +] \ No newline at end of file diff --git a/tests/e2e/test_pipeline.py b/tests/e2e/test_pipeline.py index dc7144d..b143e39 100644 --- a/tests/e2e/test_pipeline.py +++ b/tests/e2e/test_pipeline.py @@ -10,9 +10,9 @@ import pytest -from src.models.base_object import BaseObject -from src.pipe.pipeline import Pipeline -from src.pipe.processor.list_processor import JsonListProcessor, T, U +from src.data_models.base_object import BaseObject +from src.pipeline.base_processor.list_processor import JsonListProcessor, T, U +from src.pipeline.pipeline import Pipeline from src.utils.json_io import read_json, write_json_raw @@ -38,7 +38,7 @@ class DataModel(BaseObject): class Plus2(JsonListProcessor[DataModel, DataModel]): """Processor that adds 2 to the 'a' attribute of each DataModel. - This processor is used in the pipeline test to demonstrate + This base_processor is used in the pipeline test to demonstrate sequential processing of data. """ @@ -53,8 +53,8 @@ async def _process_row(self, row: T) -> U: class Times5(JsonListProcessor[DataModel, DataModel]): """Processor that multiplies the 'a' attribute of each DataModel by 5. - This processor is used in the pipeline test to demonstrate - sequential processing of data after the Plus2 processor. + This base_processor is used in the pipeline test to demonstrate + sequential processing of data after the Plus2 base_processor. """ def __init__(self): diff --git a/tests/e2e/test_processor.py b/tests/e2e/test_processor.py index cbc85bb..10c6bb0 100644 --- a/tests/e2e/test_processor.py +++ b/tests/e2e/test_processor.py @@ -1,6 +1,6 @@ """Tests for the JsonListProcessor class functionality. 
-This module tests the list processor functionality, including caching behavior +This module tests the list base_processor functionality, including caching behavior and proper processing of data items. """ @@ -11,18 +11,18 @@ import pytest -from src.models.base_object import BaseObject -from src.pipe.processor.list_processor import JsonListProcessor, T, U +from src.data_models.base_object import BaseObject +from src.pipeline.base_processor.list_processor import JsonListProcessor, T, U from src.utils.json_io import read_json, write_json_raw class DataModel(BaseObject): - """Simple data model for testing processor functionality. + """Simple data model for testing base_processor functionality. Attributes ---------- a : int - An integer value to be processed by the test processor. + An integer value to be processed by the test base_processor. """ a: int @@ -32,9 +32,9 @@ class DataModel(BaseObject): class PlusPlusProcessor(JsonListProcessor[DataModel, DataModel]): - """Test processor that increments the 'a' attribute of each DataModel. + """Test base_processor that increments the 'a' attribute of each DataModel. - This processor is used to test the JsonListProcessor functionality, + This base_processor is used to test the JsonListProcessor functionality, particularly its caching behavior when processing the same data multiple times. 
""" diff --git a/uv.lock b/uv.lock index 19291b6..fb04e7e 100644 --- a/uv.lock +++ b/uv.lock @@ -1481,6 +1481,7 @@ dependencies = [ { name = "torch" }, { name = "tqdm" }, { name = "transformers" }, + { name = "urllib3" }, { name = "vcrpy" }, { name = "werkzeug" }, ] @@ -1546,6 +1547,7 @@ requires-dist = [ { name = "torch", specifier = ">=2.0.0" }, { name = "tqdm", specifier = ">=4.67.1" }, { name = "transformers", specifier = ">=4.30.0" }, + { name = "urllib3", specifier = "==2.6.3" }, { name = "vcrpy", specifier = ">=7.0.0" }, { name = "werkzeug", specifier = ">=3.1.4" }, ] @@ -3954,11 +3956,11 @@ wheels = [ [[package]] name = "urllib3" -version = "2.6.1" +version = "2.6.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5e/1d/0f3a93cca1ac5e8287842ed4eebbd0f7a991315089b1a0b01c7788aa7b63/urllib3-2.6.1.tar.gz", hash = "sha256:5379eb6e1aba4088bae84f8242960017ec8d8e3decf30480b3a1abdaa9671a3f", size = 432678, upload-time = "2025-12-08T15:25:26.773Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bc/56/190ceb8cb10511b730b564fb1e0293fa468363dbad26145c34928a60cb0c/urllib3-2.6.1-py3-none-any.whl", hash = "sha256:e67d06fe947c36a7ca39f4994b08d73922d40e6cca949907be05efa6fd75110b", size = 131138, upload-time = "2025-12-08T15:25:25.51Z" }, + { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, ] [[package]] From a8c580593b7414e6524e9b6a64c8815542a3f131 Mon Sep 17 00:00:00 2001 From: 
Sepideh Abedini Date: Wed, 7 Jan 2026 18:42:33 -0500 Subject: [PATCH 2/4] deleting test file --- test.json | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 test.json diff --git a/test.json b/test.json deleted file mode 100644 index ff4c037..0000000 --- a/test.json +++ /dev/null @@ -1,6 +0,0 @@ -[ - { - "idx": "i1", - "a": 1 - } -] \ No newline at end of file From ee17ad3968f69cacb29283b8234ec7460b468f7a Mon Sep 17 00:00:00 2001 From: Sepideh Abedini Date: Wed, 7 Jan 2026 18:50:10 -0500 Subject: [PATCH 3/4] Add proper error handling --- main.py | 33 +- pyproject.toml | 1 + src/data_models/masksql_input.py | 4 +- src/masksql.py | 10 +- src/pipeline/add_symbolic_question.py | 5 +- src/pipeline/base_processor/list_processor.py | 9 + .../base_processor/list_transformer.py | 4 - .../base_processor/prompt_processor.py | 8 +- src/pipeline/exec_conc_sql.py | 2 +- src/pipeline/link_schema/link_schema.py | 3 +- src/pipeline/pipeline.py | 19 - src/pipeline/repair_sql/repair_sql.py | 3 +- src/pipeline/resd/run_resdsql.py | 15 +- src/pipeline/results.py | 2 +- src/utils/json_io.py | 6 +- src/utils/llm_util.py | 16 +- src/utils/logging.py | 235 +++--------- src/utils/mem.py | 5 +- src/utils/sqlite_facade.py | 2 +- src/utils/strings.py | 2 +- tests/e2e/test_data/1_input.json | 5 - tests/utils/test_logging.py | 356 ------------------ uv.lock | 24 ++ 23 files changed, 117 insertions(+), 652 deletions(-) delete mode 100644 tests/utils/test_logging.py diff --git a/main.py b/main.py index 8b797d7..7a15dc3 100644 --- a/main.py +++ b/main.py @@ -2,43 +2,16 @@ import argparse import asyncio -import logging import shutil -from pathlib import Path from dotenv import load_dotenv +from src.utils.logging import configure_logging + load_dotenv() from src.masksql import MaskSQL # noqa: E402 -from src.utils.logging import configure_logging # noqa: E402 - - -logger = logging.getLogger(__name__) - - -def clean_cache_directory(cache_dir: str) -> None: - """Clean intermediate files from 
the data directory. - - Removes files matching the pattern [0-9]*_* but excludes files starting with 1_*. - This is used to clean up intermediate pipeline output files while preserving - the initial input files. - - Parameters - ---------- - cache_dir : str - Path to the cache directory to clean. - """ - cache_path = Path(cache_dir) - - if not cache_path.exists(): - logger.error(f"Data directory does not exist: {cache_dir}") - return - - shutil.rmtree(cache_path) - - logger.info("Cleanup complete") async def main() -> None: @@ -58,7 +31,7 @@ async def main() -> None: mask_sql = MaskSQL.from_config(args.config) if args.clean: - clean_cache_directory(mask_sql.conf.cache_dir) + shutil.rmtree(mask_sql.conf.cache_dir, ignore_errors=True) else: await mask_sql.evaluate() diff --git a/pyproject.toml b/pyproject.toml index 480d940..652f512 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ packages = ["src"] [dependency-groups] dev = [ "codecov>=2.1.13", + "loguru>=0.7.3", "mypy>=1.14.1", "nbqa>=1.9.1", "pip>=25.3", # Pinning version to address vulnerability GHSA-4xh5-x5gv-qwph diff --git a/src/data_models/masksql_input.py b/src/data_models/masksql_input.py index e59b49a..5351913 100644 --- a/src/data_models/masksql_input.py +++ b/src/data_models/masksql_input.py @@ -20,10 +20,10 @@ class MaskSqlInput(BaseObject): db_id: Identifier of the database the question is about question: Natural language question text query: Optional SQL query (may be empty for new inputs) - annotated_links: Dictionary of annotations for the question + gold_schema_links: Dictionary of annotations for the question """ db_id: str question: str query: str - annotated_links: dict[str, Any] + gold_schema_links: dict[str, Any] diff --git a/src/masksql.py b/src/masksql.py index 6821b72..05ecea3 100644 --- a/src/masksql.py +++ b/src/masksql.py @@ -21,10 +21,7 @@ from src.pipeline.exec_conc_sql import ExecuteConcreteSql from src.pipeline.gen_sql.gen_masked_sql import GenerateSymbolicSql from 
src.pipeline.init_data import InitData -from src.pipeline.link_schema.link_schema import ( - FilterSchemaLinksModel, - LinkSchema, -) +from src.pipeline.link_schema.link_schema import FilterSchemaLinksModel, LinkSchema from src.pipeline.link_values.link_values import FilterValueLinksModel, LinkValues from src.pipeline.pipeline import Pipeline from src.pipeline.rank_schema import RankSchemaResd @@ -74,10 +71,8 @@ def create_pipeline_stages(conf: MaskSqlConfig) -> list[JsonListProcessor]: LimitJson(), InitData(), *rank_schema, - # ResdItemCount(), AddFilteredSchema(conf.tables_path), AddSymbolTable(conf.tables_path), - # SlmSQL("slm_sql", conf.openai, model=conf.slm), DetectValues(conf.openai, model=conf.slm), LinkValues(conf.openai, model=conf.slm), CopyTransformer("value_links", "filtered_value_links", FilterValueLinksModel), @@ -94,7 +89,6 @@ def create_pipeline_stages(conf: MaskSqlConfig) -> list[JsonListProcessor]: RepairSQL(conf.openai, model=conf.slm), CalcExecAcc(conf.db_path, conf.policy), AddInferenceAttack(conf.openai, model=conf.llm), - # # PrintProps(['question', 'symbolic.question', 'attack']) Results(), ] @@ -175,7 +169,7 @@ async def query(self, db_id: str, question: str) -> MaskSqlOutput: db_id=db_id, question=question, query="", - annotated_links={}, + gold_schema_links={}, ) results = await self.pipeline.run([data]) return results[0] diff --git a/src/pipeline/add_symbolic_question.py b/src/pipeline/add_symbolic_question.py index d83759b..3833f53 100644 --- a/src/pipeline/add_symbolic_question.py +++ b/src/pipeline/add_symbolic_question.py @@ -1,15 +1,12 @@ """Deterministic masking of terms in questions.""" -import logging +from loguru import logger from src.pipeline.add_symb_schema import AddSymbolicSchema, SymbolicSchema from src.pipeline.base_processor.list_processor import JsonListProcessor from src.utils.strings import replace_str -logger = logging.getLogger(__name__) - - class SymbolicQuestion(SymbolicSchema): """ Data model for questions with 
symbolic representations. diff --git a/src/pipeline/base_processor/list_processor.py b/src/pipeline/base_processor/list_processor.py index bb06477..6136003 100644 --- a/src/pipeline/base_processor/list_processor.py +++ b/src/pipeline/base_processor/list_processor.py @@ -4,9 +4,12 @@ from abc import ABC, abstractmethod from typing import Generic, Type, TypeVar +from loguru import logger + from src.data_cache.json_cache import JsonCache from src.data_models.base_object import BaseObject from src.utils.async_utils import apply_async +from src.utils.logging import along T = TypeVar("T", bound=BaseObject) @@ -57,6 +60,7 @@ def get_cache_file_path(self, cache_dir: str, sequence: int) -> str: """ return os.path.join(cache_dir, f"{sequence}_{self.name}.json") + @logger.catch(message="Failed to process row", reraise=True) async def __process_row_internal(self, row: T) -> U: if self.cache and not self.force and row.idx in self.cache: return self.cache[row.idx] @@ -89,6 +93,7 @@ def _pre_run(self) -> None: # noqa: B027 def _post_run(self) -> None: # noqa: B027 """Override to add post-processing logic after run.""" + @along("Processor completed: {0}") async def run(self, input_data: list[T]) -> list[U]: """ Process input file and return output_data. 
@@ -111,3 +116,7 @@ async def run(self, input_data: list[T]) -> list[U]: self._post_run() return output_data + + def __repr__(self) -> str: + """Name of the processor.""" + return self.name diff --git a/src/pipeline/base_processor/list_transformer.py b/src/pipeline/base_processor/list_transformer.py index a899a76..6275ccb 100644 --- a/src/pipeline/base_processor/list_transformer.py +++ b/src/pipeline/base_processor/list_transformer.py @@ -1,6 +1,5 @@ """List transformation base classes.""" -import logging import os from abc import ABC @@ -8,9 +7,6 @@ from src.pipeline.base_processor.list_processor import JsonListProcessor -logger = logging.getLogger(__name__) - - FORCE = int(os.environ.get("FORCE", "0")) > 0 diff --git a/src/pipeline/base_processor/prompt_processor.py b/src/pipeline/base_processor/prompt_processor.py index ee5c9e6..f28bdf8 100644 --- a/src/pipeline/base_processor/prompt_processor.py +++ b/src/pipeline/base_processor/prompt_processor.py @@ -1,14 +1,16 @@ """Base base_processor for LLM-based value detection.""" +import uuid from abc import ABC, abstractmethod from json import JSONDecodeError from typing import Any, Generic, Type, TypeVar +from loguru import logger + from src.config import OpenAIConfig from src.pipeline.base_processor.list_processor import JsonListProcessor from src.pipeline.init_data import InitData from src.utils.llm_util import send_prompt -from src.utils.logging import logger from src.utils.timer import Timer @@ -48,10 +50,14 @@ def __init__( self.openai_config = openai_config self.model = model self.include_stats = include_stats + self.prompt_logger = logger.bind(type="prompt", name=self.name) async def _prompt_llm(self, row: T, prompt: str) -> tuple[Any, str]: + prompt_logger = self.prompt_logger.bind(prompt_id=uuid.uuid4()) try: + prompt_logger.bind(is_req=True).debug(prompt) res, toks = await send_prompt(prompt, self.openai_config, model=self.model) + prompt_logger.bind(is_req=False).debug(res) except JSONDecodeError as e: 
logger.error(f"Sending prompts failed: {e}") return "", "0" diff --git a/src/pipeline/exec_conc_sql.py b/src/pipeline/exec_conc_sql.py index 634a877..daf2538 100644 --- a/src/pipeline/exec_conc_sql.py +++ b/src/pipeline/exec_conc_sql.py @@ -2,11 +2,11 @@ from typing import Any, Optional +from loguru import logger from pydantic import BaseModel from src.pipeline.base_processor.list_processor import JsonListProcessor from src.pipeline.unmask import AddConcreteSql -from src.utils.logging import logger from src.utils.sqlite_facade import SqliteFacade diff --git a/src/pipeline/link_schema/link_schema.py b/src/pipeline/link_schema/link_schema.py index f2b9803..1919937 100644 --- a/src/pipeline/link_schema/link_schema.py +++ b/src/pipeline/link_schema/link_schema.py @@ -1,11 +1,12 @@ """Schema linking from questions to database schemas.""" +from loguru import logger + from src.config import OpenAIConfig from src.pipeline.base_processor.prompt_processor import PromptProcessor from src.pipeline.link_schema.prompts.v4 import SCHEMA_LINK_PROMPT_V4 from src.pipeline.link_values.link_values import FilterValueLinksModel from src.utils.llm_util import extract_object -from src.utils.logging import logger class LinkSchema(PromptProcessor[FilterValueLinksModel, "LinkSchema.Model"]): diff --git a/src/pipeline/pipeline.py b/src/pipeline/pipeline.py index 0bca29c..2b5afd4 100644 --- a/src/pipeline/pipeline.py +++ b/src/pipeline/pipeline.py @@ -7,9 +7,6 @@ from src.pipeline.base_processor.list_processor import JsonListProcessor from src.utils.logging import ( log_pipeline_summary, - log_stage_complete, - log_stage_start, - reset_stage_timings, ) from src.utils.mem import track_memory_async from src.utils.timer import Timer @@ -40,22 +37,9 @@ def __init__( stage.set_cache_file(pipeline_cache_dir, i + 1) async def __run_internal(self, input_data: list[T]) -> list[Any]: - # tmp_file: Any = input_file tmp_data = input_data - timer: Timer = Timer() - timer.start() - last_lap_time = 0.0 - for 
stage in self.stages: - log_stage_start(stage.name) tmp_data = await stage.run(tmp_data) - - # Get cumulative time and calculate stage time - cumulative_time = timer.lap() - stage_time = cumulative_time - last_lap_time - last_lap_time = cumulative_time - - log_stage_complete(stage.name, stage_time) return tmp_data async def run(self, input_data: list[T]) -> list[Any]: @@ -72,9 +56,6 @@ async def run(self, input_data: list[T]) -> list[Any]: tuple[Any, float, float] Tuple of (result, average_memory_mb, peak_memory_mb) """ - # Reset timing tracker for new pipeline run - reset_stage_timings() - timer: Timer = Timer.start() result, avg_mem, peak_mem = await track_memory_async( self.__run_internal, input_data diff --git a/src/pipeline/repair_sql/repair_sql.py b/src/pipeline/repair_sql/repair_sql.py index f0a152f..19fa9e4 100644 --- a/src/pipeline/repair_sql/repair_sql.py +++ b/src/pipeline/repair_sql/repair_sql.py @@ -2,11 +2,12 @@ from typing import Any +from loguru import logger + from src.config import OpenAIConfig from src.pipeline.base_processor.prompt_processor import PromptProcessor from src.pipeline.exec_conc_sql import ExecuteConcreteSql from src.pipeline.repair_sql.prompts.v3 import REPAIR_SQL_PROMPT_V3 -from src.utils.logging import logger from src.utils.strings import extract_sql diff --git a/src/pipeline/resd/run_resdsql.py b/src/pipeline/resd/run_resdsql.py index 53d81ad..b923043 100644 --- a/src/pipeline/resd/run_resdsql.py +++ b/src/pipeline/resd/run_resdsql.py @@ -6,10 +6,12 @@ from pathlib import Path from typing import Any +from loguru import logger + from src.pipeline.base_processor.list_processor import JsonListProcessor from src.pipeline.init_data import InitData from src.utils.json_io import write_json -from src.utils.logging import console, log_error, log_success, logger +from src.utils.logging import console class RunResdsql(JsonListProcessor[InitData.Model, InitData.Model]): @@ -85,7 +87,7 @@ async def run(self, input_data: list[InitData.Model]) 
-> list[InitData.Model]: # Skip if RESDSQL output already exists and force is not enabled if not self.force and self.resd_output_path.exists(): logger.info( - f"[dim]⏭ Skipping RESDSQL pipeline - output exists:[/dim] " + f"Skipping RESDSQL pipeline - output exists:" f"{self.resd_output_path.name}" ) return input_data @@ -107,13 +109,10 @@ async def run(self, input_data: list[InitData.Model]) -> list[InitData.Model]: self._generate_text2sql_data() self._add_question_ids() - log_success( - "RESDSQL pipeline completed", - output_file=str(self.resd_output_path), - ) + logger.info(f"RESDSQL pipeline completed: {self.resd_output_path}") except Exception as e: - log_error(f"RESDSQL pipeline failed: {e}") + logger.error(f"RESDSQL pipeline failed: {e}") raise return input_data @@ -147,7 +146,7 @@ def _run_step(self, step_name: str, script: str, args: list[str]) -> None: ) if result.returncode != 0: - log_error( + logger.error( f"RESDSQL step failed: {step_name}", script=script, exit_code=result.returncode, diff --git a/src/pipeline/results.py b/src/pipeline/results.py index f5a8da5..afde9bb 100644 --- a/src/pipeline/results.py +++ b/src/pipeline/results.py @@ -45,7 +45,7 @@ async def _process_row( # if "attack" in row and "annotated_links" in row: masked_terms = row.symbolic.masked_terms attack = row.attack - a_links = row.annotated_links + a_links = row.gold_schema_links ri_terms = 0 num_masks = len(masked_terms) diff --git a/src/utils/json_io.py b/src/utils/json_io.py index cc9c7d3..76e6abe 100644 --- a/src/utils/json_io.py +++ b/src/utils/json_io.py @@ -3,7 +3,8 @@ import json from typing import Any, Type, TypeVar -from pydantic import BaseModel +from loguru import logger +from pydantic import BaseModel, ValidationError T = TypeVar("T", bound=BaseModel) @@ -26,6 +27,9 @@ def read_json_raw(path: str) -> Any: return json.load(f) +@logger.catch( + message="Failed to validate data", reraise=True, exception=ValidationError +) def read_json(path: str, cls: Type[T]) -> list[T]: 
""" Read and parse a JSON file. diff --git a/src/utils/llm_util.py b/src/utils/llm_util.py index b6aa868..92ca004 100644 --- a/src/utils/llm_util.py +++ b/src/utils/llm_util.py @@ -2,18 +2,16 @@ import ast import json -import logging import os import re from typing import Any +from loguru import logger from openai import AsyncClient from src.config import OpenAIConfig -logger = logging.getLogger(__name__) - VLM_ARCH = os.environ.get("VLM_ARCH") MAX_COMPLETION_TOKENS = os.environ.get("MAX_COMPLETION_TOKENS") @@ -72,12 +70,6 @@ async def send_prompt( timeout=openai_config.timeout, ) - # Concise logging with rich markup - logger.debug( - f"[cyan]LLM Request[/cyan] → [bold]{model}[/bold] " - f"([dim]{len(prompt)} chars[/dim])" - ) - response = await client.chat.completions.create( model=model, messages=[ @@ -89,18 +81,12 @@ async def send_prompt( max_completion_tokens=openai_config.max_completion_tokens, ) if response.choices is None: - print(prompt) raise Exception(f"LM prompts failed: {response.model_extra}") usage = "0" if response.usage: usage = str(response.usage.total_tokens) content = response.choices[0].message.content or "" - logger.debug( - f"[green]LLM Response[/green] ← [bold]{usage}[/bold] tokens " - f"([dim]{len(content)} chars[/dim])" - ) - return content, usage diff --git a/src/utils/logging.py b/src/utils/logging.py index 3fc7dbf..17e8d06 100644 --- a/src/utils/logging.py +++ b/src/utils/logging.py @@ -1,11 +1,12 @@ """Logging configuration utilities using rich library.""" -import logging import os -from typing import Any +import sys +from functools import wraps +from typing import Any, Awaitable, Callable, ParamSpec, TypeVar +from loguru import logger from rich.console import Console -from rich.logging import RichHandler from rich.panel import Panel from rich.table import Table from rich.text import Text @@ -25,52 +26,7 @@ install_rich_traceback(console=console, show_locals=False, width=100, word_wrap=True) -def configure_logging() -> None: - 
"""Configure Python logging with rich formatting and custom handlers. - - This function sets up logging with: - - Rich colored console output - - Concise timestamp format - - Different colors for different log levels - - Enhanced traceback formatting for exceptions - """ - # Remove existing handlers to avoid duplicates - root_logger = logging.getLogger() - for handler in root_logger.handlers[:]: - root_logger.removeHandler(handler) - - # Create rich handler with custom formatting - rich_handler = RichHandler( - console=console, - show_time=True, - show_path=False, - show_level=True, - rich_tracebacks=True, - tracebacks_show_locals=False, - markup=True, - log_time_format="[%H:%M:%S]", - omit_repeated_times=False, - ) - - # Minimal formatter for cleaner output - rich_handler.setFormatter( - logging.Formatter( - fmt="%(message)s", - datefmt="[%X]", - ) - ) - - # Configure root logger - root_logger.addHandler(rich_handler) - root_logger.setLevel(LOG_LEVEL) - - # Silence verbose HTTP request logs from third-party libraries - logging.getLogger("httpx").setLevel(logging.WARNING) - logging.getLogger("openai").setLevel(logging.WARNING) - logging.getLogger("urllib3").setLevel(logging.WARNING) - - -def log_error(message: str, **kwargs: Any) -> None: +def log_panel(message: str, **kwargs: Any) -> None: """Log an error message in a styled box. Parameters @@ -101,143 +57,29 @@ def log_error(message: str, **kwargs: Any) -> None: console.print(panel) -def log_warning(message: str, **kwargs: Any) -> None: - """Log a warning message in a styled box. 
- - Parameters - ---------- - message : str - The warning message to display - **kwargs : Any - Additional context information to include in the warning box - """ - # Build warning content - warning_text = Text() - warning_text.append(message, style="bold yellow") - - if kwargs: - warning_text.append("\n\n", style="") - warning_text.append("Context:\n", style="bold cyan") - for key, value in kwargs.items(): - warning_text.append(f" {key}: ", style="cyan") - warning_text.append(f"{value}\n", style="white") - - # Display in a yellow panel - panel = Panel( - warning_text, - title="[bold yellow]WARNING", - border_style="yellow", - padding=(1, 2), - ) - console.print(panel) - - -def log_success(message: str, **kwargs: Any) -> None: - """Log a success message in a styled box. - - Parameters - ---------- - message : str - The success message to display - **kwargs : Any - Additional context information to include in the success box - """ - # Build success content - success_text = Text() - success_text.append(message, style="bold green") - - if kwargs: - success_text.append("\n\n", style="") - success_text.append("Details:\n", style="bold cyan") - for key, value in kwargs.items(): - success_text.append(f" {key}: ", style="cyan") - success_text.append(f"{value}\n", style="white") - - # Display in a green panel - panel = Panel( - success_text, - title="[bold green]SUCCESS", - border_style="green", - padding=(1, 2), - ) - console.print(panel) - - -def log_info(message: str, **kwargs: Any) -> None: - """Log an info message in a styled box. 
- - Parameters - ---------- - message : str - The info message to display - **kwargs : Any - Additional context information to include in the info box - """ - # Build info content - info_text = Text() - info_text.append(message, style="bold blue") - - if kwargs: - info_text.append("\n\n", style="") - info_text.append("Details:\n", style="bold cyan") - for key, value in kwargs.items(): - info_text.append(f" {key}: ", style="cyan") - info_text.append(f"{value}\n", style="white") - - # Display in a blue panel - panel = Panel( - info_text, - title="[bold blue]INFO", - border_style="blue", - padding=(1, 2), +def configure_logging() -> None: + """Configure Python logging with rich formatting and custom handlers.""" + logger.remove() + logger.add( + sys.stdout, + level=LOG_LEVEL, + colorize=True, + backtrace=False, + catch=False, + format="[{level:>7}]: {message}", ) - console.print(panel) - - -# Create a logger instance that can be imported -logger = logging.getLogger("masksql") - - -def log_stage_start(stage_name: str) -> None: - """Log the start of a pipeline stage. - - Parameters - ---------- - stage_name : str - Name of the pipeline stage starting - """ - console.print( - f"\n[bold cyan]▶ Starting Stage:[/bold cyan] [bold white]{stage_name}[/bold white]" + logger.add( + "logs/debug-{time:MMMD-HH-mm}.log", + level="DEBUG", + format="[{time:HH:mm:ss}]-[{level:<7}]-[{name:>20} | {function:<25}:{line:<3}]: {message}", ) - -def log_stage_complete(stage_name: str, elapsed_time: float) -> None: - """Log the completion of a pipeline stage with timing. 
- - Parameters - ---------- - stage_name : str - Name of the pipeline stage completed - elapsed_time : float - Time taken to complete the stage in seconds - """ - # Store timing for summary - _stage_timings.append((stage_name, elapsed_time)) - - # Format time with appropriate precision and color - if elapsed_time < 1.0: - time_str = f"{elapsed_time:.3f}s" - time_color = "green" - elif elapsed_time < 10.0: - time_str = f"{elapsed_time:.2f}s" - time_color = "yellow" - else: - time_str = f"{elapsed_time:.2f}s" - time_color = "red" - - console.print( - f"[bold green]✓ Done Stage:[/bold green] [bold white]{stage_name}[/bold white] " - f"[dim]│[/dim] [{time_color}]{time_str}[/{time_color}]" + logger.add( + "logs/prompts-{time:MMMD-HH-mm}.jsonl", + level="DEBUG", + filter=lambda record: record["extra"].get("type") == "prompt", + format="{message}", + serialize=True, ) @@ -339,10 +181,25 @@ def log_pipeline_summary( console.print("\n") -def reset_stage_timings() -> None: - """Reset the stage timings tracker. +P = ParamSpec("P") +R = TypeVar("R") +AR = Awaitable[R] - This should be called at the start of a pipeline run. 
- """ - global _stage_timings # noqa: PLW0603 - _stage_timings = [] + +def along( + message: str = "", before: str | None = None +) -> Callable[[Callable[P, AR]], Callable[P, AR]]: + """Log messages before and after an async function execution.""" + + def decorator(func: Callable[P, AR]) -> Callable[P, AR]: + @wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> AR: + if before is not None: + logger.info(before.format(*args, **kwargs)) + result = await func(*args, **kwargs) + logger.info(message.format(*args, **kwargs)) + return result + + return wrapper + + return decorator diff --git a/src/utils/mem.py b/src/utils/mem.py index 39c0241..e10585b 100644 --- a/src/utils/mem.py +++ b/src/utils/mem.py @@ -1,6 +1,5 @@ """Memory usage monitoring utilities.""" -import logging import os import threading import time @@ -8,9 +7,7 @@ from typing import Any import psutil - - -logger = logging.getLogger(__name__) +from loguru import logger def _monitor_memory( diff --git a/src/utils/sqlite_facade.py b/src/utils/sqlite_facade.py index ce947e6..36896d1 100644 --- a/src/utils/sqlite_facade.py +++ b/src/utils/sqlite_facade.py @@ -8,7 +8,7 @@ from sqlite3 import Connection from typing import Any -from src.utils.logging import logger +from loguru import logger DB_TIMEOUT = 10000 diff --git a/src/utils/strings.py b/src/utils/strings.py index 4d1637b..7adb4f7 100644 --- a/src/utils/strings.py +++ b/src/utils/strings.py @@ -4,7 +4,7 @@ from difflib import SequenceMatcher from enum import Enum -from src.utils.logging import logger +from loguru import logger def delete_whitespace(content: str) -> str: diff --git a/tests/e2e/test_data/1_input.json b/tests/e2e/test_data/1_input.json index 9d01a16..0cd5a28 100644 --- a/tests/e2e/test_data/1_input.json +++ b/tests/e2e/test_data/1_input.json @@ -14,11 +14,6 @@ "pets": "TABLE:pets", "weight": "COLUMN:pets.weight" }, - "annotated_links": { - "10": "VALUE:pets.weight", - "pets": "TABLE:pets", - "weight": "COLUMN:pets.weight" - }, 
"idx": "spider_45" } ] diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py deleted file mode 100644 index c5114fe..0000000 --- a/tests/utils/test_logging.py +++ /dev/null @@ -1,356 +0,0 @@ -"""Tests for logging configuration utilities.""" - -from unittest.mock import MagicMock, call, patch - -from rich.console import Console -from rich.panel import Panel -from rich.table import Table - -from src.utils import logging as logging_module -from src.utils.logging import ( - configure_logging, - log_error, - log_info, - log_pipeline_summary, - log_stage_complete, - log_stage_start, - log_success, - log_warning, - reset_stage_timings, -) - - -class TestConfigureLogging: - """Test suite for configure_logging function.""" - - def test_configure_logging_default_level(self, monkeypatch): - """Test logging configuration with default INFO level.""" - monkeypatch.delenv("LOG_LEVEL", raising=False) - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - # Should add handler to root logger - mock_logger.addHandler.assert_called_once() - # Check that root logger setLevel was called with INFO - assert call("INFO") in mock_logger.setLevel.call_args_list - - def test_configure_logging_custom_level(self, monkeypatch): - """Test logging configuration with custom LOG_LEVEL.""" - monkeypatch.setattr(logging_module, "LOG_LEVEL", "DEBUG") - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - # Should set logger to DEBUG level - assert call("DEBUG") in mock_logger.setLevel.call_args_list - - def test_configure_logging_error_level(self, monkeypatch): - """Test logging configuration with ERROR level.""" - monkeypatch.setattr(logging_module, "LOG_LEVEL", "ERROR") - - with patch("logging.getLogger") as mock_get_logger: - 
mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - assert call("ERROR") in mock_logger.setLevel.call_args_list - - def test_configure_logging_warning_level(self, monkeypatch): - """Test logging configuration with WARNING level.""" - monkeypatch.setattr(logging_module, "LOG_LEVEL", "WARNING") - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - assert call("WARNING") in mock_logger.setLevel.call_args_list - - def test_configure_logging_removes_existing_handlers(self, monkeypatch): - """Test that existing handlers are removed before adding new ones.""" - monkeypatch.delenv("LOG_LEVEL", raising=False) - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_handler1 = MagicMock() - mock_handler2 = MagicMock() - mock_logger.handlers = [mock_handler1, mock_handler2] - mock_get_logger.return_value = mock_logger - - configure_logging() - - # Should remove existing handlers - assert mock_logger.removeHandler.call_count == 2 - mock_logger.removeHandler.assert_any_call(mock_handler1) - mock_logger.removeHandler.assert_any_call(mock_handler2) - - def test_configure_logging_adds_rich_handler(self, monkeypatch): - """Test that RichHandler is added with correct configuration.""" - monkeypatch.delenv("LOG_LEVEL", raising=False) - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - # Should add a handler - mock_logger.addHandler.assert_called_once() - added_handler = mock_logger.addHandler.call_args[0][0] - - # Verify it's a RichHandler by checking its type name - assert added_handler.__class__.__name__ == "RichHandler" - - def test_configure_logging_formatter(self, monkeypatch): - """Test that the formatter is 
configured correctly.""" - monkeypatch.delenv("LOG_LEVEL", raising=False) - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - # Get the handler that was added - added_handler = mock_logger.addHandler.call_args[0][0] - - # Check the formatter is minimal - formatter = added_handler.formatter - assert formatter._fmt == "%(message)s" - - -class TestLogError: - """Test suite for log_error function.""" - - def test_log_error_simple_message(self): - """Test logging a simple error message.""" - with patch.object(Console, "print") as mock_print: - log_error("Test error message") - - # Should print a Panel - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - assert panel.border_style == "red" - - def test_log_error_with_context(self): - """Test logging an error message with context.""" - with patch.object(Console, "print") as mock_print: - log_error("Test error", file="test.py", line=42) - - # Should print a Panel with context - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - - -class TestLogWarning: - """Test suite for log_warning function.""" - - def test_log_warning_simple_message(self): - """Test logging a simple warning message.""" - with patch.object(Console, "print") as mock_print: - log_warning("Test warning message") - - # Should print a Panel - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - assert panel.border_style == "yellow" - - def test_log_warning_with_context(self): - """Test logging a warning message with context.""" - with patch.object(Console, "print") as mock_print: - log_warning("Deprecated function", function="old_func") - - # Should print a Panel with context - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, 
Panel) - - -class TestLogSuccess: - """Test suite for log_success function.""" - - def test_log_success_simple_message(self): - """Test logging a simple success message.""" - with patch.object(Console, "print") as mock_print: - log_success("Operation completed") - - # Should print a Panel - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - assert panel.border_style == "green" - - def test_log_success_with_details(self): - """Test logging a success message with details.""" - with patch.object(Console, "print") as mock_print: - log_success("File saved", path="/tmp/file.txt", size=1024) - - # Should print a Panel with details - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - - -class TestLogInfo: - """Test suite for log_info function.""" - - def test_log_info_simple_message(self): - """Test logging a simple info message.""" - with patch.object(Console, "print") as mock_print: - log_info("Processing data") - - # Should print a Panel - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - assert panel.border_style == "blue" - - def test_log_info_with_details(self): - """Test logging an info message with details.""" - with patch.object(Console, "print") as mock_print: - log_info("Processing", items=100, status="active") - - # Should print a Panel with details - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - - -class TestLogStageStart: - """Test suite for log_stage_start function.""" - - def test_log_stage_start(self): - """Test logging stage start.""" - with patch.object(Console, "print") as mock_print: - log_stage_start("TestStage") - - # Should print formatted output - mock_print.assert_called_once() - call_args = mock_print.call_args[0][0] - assert "TestStage" in call_args - assert "Starting Stage" in call_args - - -class TestLogStageComplete: - 
"""Test suite for log_stage_complete function.""" - - def test_log_stage_complete_fast(self): - """Test logging stage completion with fast time (< 1s).""" - with patch.object(Console, "print") as mock_print: - log_stage_complete("TestStage", 0.5) - - # Should print formatted output with green timing - mock_print.assert_called_once() - call_args = mock_print.call_args[0][0] - assert "TestStage" in call_args - assert "Done Stage" in call_args - assert "0.500s" in call_args - - def test_log_stage_complete_medium(self): - """Test logging stage completion with medium time (1-10s).""" - with patch.object(Console, "print") as mock_print: - log_stage_complete("TestStage", 5.0) - - # Should print formatted output with yellow timing - mock_print.assert_called_once() - call_args = mock_print.call_args[0][0] - assert "TestStage" in call_args - assert "5.00s" in call_args - - def test_log_stage_complete_slow(self): - """Test logging stage completion with slow time (> 10s).""" - with patch.object(Console, "print") as mock_print: - log_stage_complete("TestStage", 15.5) - - # Should print formatted output with red timing - mock_print.assert_called_once() - call_args = mock_print.call_args[0][0] - assert "TestStage" in call_args - assert "15.50s" in call_args - - -class TestResetStageTimings: - """Test suite for reset_stage_timings function.""" - - def test_reset_stage_timings(self): - """Test resetting stage timings.""" - # Add some timings - with patch.object(Console, "print"): - log_stage_complete("Stage1", 1.0) - log_stage_complete("Stage2", 2.0) - - # Reset timings - reset_stage_timings() - - # Verify timings are empty by checking summary output - with patch.object(Console, "print") as mock_print: - log_pipeline_summary(10.0, 100.0, 150.0) - - # Should have been called multiple times (for table, memory, etc.) 
- assert mock_print.call_count > 0 - - -class TestLogPipelineSummary: - """Test suite for log_pipeline_summary function.""" - - def test_log_pipeline_summary_basic(self): - """Test logging pipeline summary without results.""" - reset_stage_timings() - - with patch.object(Console, "print") as mock_print: - # Add some stage timings first - log_stage_complete("Stage1", 1.0) - log_stage_complete("Stage2", 2.0) - - mock_print.reset_mock() - - # Log summary - log_pipeline_summary(3.0, 100.0, 150.0) - - # Should print multiple times (table, memory, etc.) - assert mock_print.call_count >= 2 - - # Check that at least one call contains a Table - has_table = False - for call_args in mock_print.call_args_list: - if call_args[0]: # Check positional args - arg = call_args[0][0] - if isinstance(arg, Table): - has_table = True - break - assert has_table - - def test_log_pipeline_summary_with_results(self): - """Test logging pipeline summary with results.""" - reset_stage_timings() - - with patch.object(Console, "print") as mock_print: - # Add some stage timings - log_stage_complete("Stage1", 1.0) - - mock_print.reset_mock() - - # Log summary with results - results = {"accuracy": 0.95, "latency": 2.5, "count": 100} - log_pipeline_summary(3.0, 100.0, 150.0, results) - - # Should print multiple times - assert mock_print.call_count >= 2 diff --git a/uv.lock b/uv.lock index fb04e7e..47257b6 100644 --- a/uv.lock +++ b/uv.lock @@ -1346,6 +1346,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/40/791891d4c0c4dab4c5e187c17261cedc26285fd41541577f900470a45a4d/license_expression-30.4.4-py3-none-any.whl", hash = "sha256:421788fdcadb41f049d2dc934ce666626265aeccefddd25e162a26f23bcbf8a4", size = 120615, upload-time = "2025-07-22T11:13:31.217Z" }, ] +[[package]] +name = "loguru" +version = "0.7.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "win32-setctime", marker = "sys_platform == 
'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/05/a1dae3dffd1116099471c643b8924f5aa6524411dc6c63fdae648c4f1aca/loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6", size = 63559, upload-time = "2024-12-06T11:20:56.608Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/29/0348de65b8cc732daa3e33e67806420b2ae89bdce2b04af740289c5c6c8c/loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c", size = 61595, upload-time = "2024-12-06T11:20:54.538Z" }, +] + [[package]] name = "markdown" version = "3.10" @@ -1489,6 +1502,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "codecov" }, + { name = "loguru" }, { name = "mypy" }, { name = "nbqa" }, { name = "pip" }, @@ -1555,6 +1569,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "codecov", specifier = ">=2.1.13" }, + { name = "loguru", specifier = ">=0.7.3" }, { name = "mypy", specifier = ">=1.14.1" }, { name = "nbqa", specifier = ">=1.9.1" }, { name = "pip", specifier = ">=25.3" }, @@ -4070,6 +4085,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2f/f9/9e082990c2585c744734f85bec79b5dae5df9c974ffee58fe421652c8e91/werkzeug-3.1.4-py3-none-any.whl", hash = "sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905", size = 224960, upload-time = "2025-11-29T02:15:21.13Z" }, ] +[[package]] +name = "win32-setctime" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/8f/705086c9d734d3b663af0e9bb3d4de6578d08f46b1b101c2442fd9aecaa2/win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0", size = 4867, upload-time = "2024-12-07T15:28:28.314Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/e1/07/c6fe3ad3e685340704d314d765b7912993bcb8dc198f0e7a89382d37974b/win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390", size = 4083, upload-time = "2024-12-07T15:28:26.465Z" }, +] + [[package]] name = "wrapt" version = "2.0.1" From b19d55700e866732d1e8ffc6fc4644c3259db25b Mon Sep 17 00:00:00 2001 From: Sepideh Abedini Date: Wed, 7 Jan 2026 19:09:56 -0500 Subject: [PATCH 4/4] renaming along to log --- src/pipeline/base_processor/list_processor.py | 4 ++-- src/utils/logging.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pipeline/base_processor/list_processor.py b/src/pipeline/base_processor/list_processor.py index 6136003..bc1f390 100644 --- a/src/pipeline/base_processor/list_processor.py +++ b/src/pipeline/base_processor/list_processor.py @@ -9,7 +9,7 @@ from src.data_cache.json_cache import JsonCache from src.data_models.base_object import BaseObject from src.utils.async_utils import apply_async -from src.utils.logging import along +from src.utils.logging import log T = TypeVar("T", bound=BaseObject) @@ -93,7 +93,7 @@ def _pre_run(self) -> None: # noqa: B027 def _post_run(self) -> None: # noqa: B027 """Override to add post-processing logic after run.""" - @along("Processor completed: {0}") + @log("Processor completed: {0}") async def run(self, input_data: list[T]) -> list[U]: """ Process input file and return output_data. diff --git a/src/utils/logging.py b/src/utils/logging.py index 17e8d06..7f19063 100644 --- a/src/utils/logging.py +++ b/src/utils/logging.py @@ -186,7 +186,7 @@ def log_pipeline_summary( AR = Awaitable[R] -def along( +def log( message: str = "", before: str | None = None ) -> Callable[[Callable[P, AR]], Callable[P, AR]]: """Log messages before and after an async function execution."""