diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0061816 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +config/database.ini +__pycache__/ +.venv/ \ No newline at end of file diff --git a/README.md b/README.md index a1fdc78..3961e0c 100644 --- a/README.md +++ b/README.md @@ -1,119 +1,93 @@ -# Differential Privacy over SQL - -## Table of Contents -* [About the Project](#about-the-project) -* [Prerequisites](#prerequisites) - * [Tools](#tools) - * [Python Dependency](#python-dependency) - * [Database Permission](#database-permission) -* [system structure](#system-structure) -* [Demo System](#demo-system) -* [Instruction for Collecting Result](#collect-result) -* [Future Plan](#future-plan) +# Differential Privacy over SQL (DPSQL) + +DPSQL is a system designed for answering SQL queries while satisfying differential privacy guarantees. ## About The Project -Differential Privacy over SQL (DPSQL) is a system for answering queries over differential privacy. -The file structure is as below -``` +The file and directory structure of the project is organized as follows: + +```text project -│ -└───config -└───docs -└───Profile -└───src -│ └───algorithm -└───Test -│ └───TPCH -│ └───Graph -└───Sample +├── config/ # Configuration files required for the system +├── docs/ # Reference information and documentation +├── Profile/ # Profile information/licenses (e.g., mosek.lic) +├── src/ # Main source code files +│ └── algorithm/ # Core algorithms integrated into the system (e.g., FastSJA, OptSJA) +├── Test/ # Queries used in system experiments (TPCH, Graph) +└── Sample/ # Scripts for database setup and collecting experiment results ``` -`./config` stores the configuration files users need for the system. -`./docs` stores the reference information users need to work with DPSQL: +## Prerequisites -`./Profile` stores the Profile information for using `mosek` in the system. +### Tools +* **[PostgreSQL](https://www.postgresql.org/)**: Database engine. +* **[Python3](https://www.python.org/download/releases/3.0/)**: Ensure version 3.0 or higher. +* **[Mosek](https://www.mosek.com/downloads/)**: License file must be placed in `./Profile`. +* **CPLEX (Full Edition)**: Required for large datasets. Note: Do not rely on `pip install cplex` alone, as it has a 1,000-variable limit. + * [Detailed CPLEX Installation & Python Linking Guide](docs/cplex_setup.md) -`./src` stores main source files. -* `./src/algorithm` stores 3 algorithm we integrated into this system. +### Python Dependencies -`./Test` stores the queries used in the experiments of the system. +Install the required Python packages using the provided `requirements.txt` file: -`./Sample` stores the script for setting up database and collecting experiment results. +```bash +pip install -r requirements.txt +``` +### Database Permissions +The user running the system must have read permissions for the target database schema. -## Prerequisites -### Tools -Before running this project, please install below tools -* [PostgreSQL](https://www.postgresql.org/) -* [Python3](https://www.python.org/download/releases/3.0/) -* [Cplex](https://www.ibm.com/analytics/cplex-optimizer) -* [Mosek](https://www.mosek.com/downloads/) and the licence is under `./Profile`. - -Please do not install `Cplex` dependency, which can only handle a small dataset, but download the `Cplex API` and import that to python with this [instruction](https://www.ibm.com/docs/zh/icos/12.9.0?topic=cplex-setting-up-python-api). -(We are aware that this link is expired and are working on a substitute solution.) - -### Python Dependency -Here are dependencies used in python programs: -* `matplotlib` -* `numpy` -* `sys` -* `os` -* `collections` -* `configparser` -* `math` -* `psycopg2` -* `pglast`v4.4 -* `argparser` - -### Database permission -The user should have the permission to read the schema of the database to use this system. - -## System structure -TODO - -## Demo System - -To run the system, run `main.py`. There are seven parameters - - `--d`: path to database initialization file; - - `--q`: path to query file; - - `--r`: path to private relation file; - - `--c`: path to the configuration file; - - `--o`: path to the output file; - - `--debug`: debug mode for more information; - - `--optimal`: choose to use optimal algorithm for SJA queries; - -One can use `--h` to get help for parameter instruction. - -For more information about input file, users can consult [here](./docs/system-input.md) - -For the SQL syntax used in this system, users can consult [here](./docs/query-syntax.md) - -Example: -``` +## Usage / Demo System + +The main entry point for the system is `main.py`. + +### Command-Line Arguments +| Parameter | Description | +| :--- | :--- | +| `--d` | Path to the database initialization file | +| `--q` | Path to the query file | +| `--r` | Path to the private relation file | +| `--c` | Path to the configuration file | +| `--o` | Path to the output file | +| `--debug` | Enable debug mode for more detailed logging | +| `--optimal` | Use the optimal algorithm for SJA queries | + +*Use `python main.py --h` to view complete help instructions.* + +**Documentation Links:** +* [Input File Configuration](./docs/system-input.md) +* [Supported SQL Syntax](./docs/query-syntax.md) + +**Example Run:** +```bash python main.py --d ./config/database.ini --q ./test.txt --r ./test_relation.txt --c ./config/parameter.config --o out.txt ``` -## collect result +## Collecting Results -1. install the dependency +Follow these steps to set up the data and collect experiment results: -2. create an empty database in `PosgreSQL` -3. generate `tbl` data files by using dbgen from [TPCH website](https://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp) -and store them in `/Sample/data/TPCH` -4. run script we provide in `/Sample/setupDBTPCH.py` -``` -python setupDBTPCH.py --db databasename -``` -5. run script we provide in `/Sample/collectResult.py` -```commandline -python collectResult.py -``` -6. find the result in `/Sample/result/TPCH` +1. **Install Dependencies**: Ensure tools and Python requirements are installed as per the [Prerequisites](#prerequisites). +2. **Database Setup**: Create an empty database in PostgreSQL. +3. **Data Generation**: Generate `.tbl` data files using `dbgen` from the [TPC-H website](https://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp), and store them in `./Sample/data/TPCH`. +4. **Database Initialization**: Run the setup script provided in `./Sample/setupDBTPCH.py`: + ```bash + python Sample/setupDBTPCH.py --db + ``` +5. **Run Collection Script**: + ```bash + cd Sample + python collectResult.py + ``` +6. **View Results**: The output will be available in `./Sample/result/TPCH`. + +## Query Rewriting & Subquery Unnesting + +DPSQL automatically rewrites and unnests subqueries to standard relational joins to ensure differential privacy mechanisms can be seamlessly applied. Through a custom Abstract Syntax Tree (AST) visitor (`UnnestSubqueries` in `src/parser.py`) built using `pglast`, the system traverses the AST and flattens nested `IN`, `ANY`, and `EXISTS` subqueries found in the `WHERE` clause into standard multi-table joins, while automatically preserving and linking the original filtering conditions. -## Future Plan +## Future Plans -- Distinct count queries type (projection); -- User Interface -- Better user experience; -- Optimization; +* Support for distinct count queries (projection). +* Develop a User Interface (UI). +* Improve overall user experience. +* General performance optimization. diff --git a/app.py b/app.py new file mode 100644 index 0000000..72c42d6 --- /dev/null +++ b/app.py @@ -0,0 +1,256 @@ +import streamlit as st +import subprocess +import configparser +import os +import ast +import pandas as pd +import matplotlib.pyplot as plt +import sys +import re + +st.set_page_config(layout="wide", page_title="DOP-SQL Interface") + +# --- Header --- +st.title("DOP-SQL: Differentially Private SQL System") +st.markdown("A General-purpose, High-utility, and Extensible Private SQL System") + +# --- Layout: Two columns for Inputs --- +col1, col2 = st.columns([1.2, 1]) + +with col1: + st.subheader("1. Input Query") + default_query = """select count(*) from supplier, lineitem, orders, customer, nation where supplier.S_SUPPKEY=lineitem.L_SUPPKEY and lineitem.L_ORDERKEY=orders.O_ORDERKEY and orders.O_CUSTKEY=customer.C_CUSTKEY and customer.C_NATIONKEY=nation.N_NATIONKEY and nation.N_NATIONKEY=supplier.S_NATIONKEY; +""" + + query = st.text_area("SQL Query", default_query, height=250) + relations_input = st.text_input( + "Primary Private Relations (comma-separated)", "customer" + ) + +with col2: + st.subheader("2. Parameter Configuration") + + st.markdown("**Global Parameters**") + col2a, col2b = st.columns(2) + with col2a: + epsilon = st.number_input( + "Privacy Budget (ε)", min_value=0.01, max_value=10.0, value=1.0, step=0.1 + ) + with col2b: + beta = st.number_input( + "Error Probability (β)", + min_value=0.001, + max_value=1.0, + value=0.1, + step=0.01, + ) + + p_num = st.number_input("Processor Count (Parallelism)", min_value=1, value=5) + recursion_bound = st.number_input("Recursion Bound", min_value=1, value=3) + with st.expander("Advanced Algorithm Parameters", expanded=False): + c1, c2 = st.columns(2) + with c1: + fast_global_sensitivity = st.number_input( + "FastSJA global_sensitivity", min_value=0, value=1000000 + ) + fast_approximate_factor = st.number_input( + "FastSJA approximate_factor", min_value=0.0, value=0.0, format="%f" + ) + delta = st.number_input( + "Relaxation (δ) [MultiQ]", min_value=0.0, value=0.000001, format="%f" + ) + with c2: + max_upper_bound = st.number_input( + "MaxSJA upper_bound", min_value=1, value=200 + ) + error_level = st.number_input( + "Error Level [MaxSJA]", min_value=0.01, value=0.1 + ) + +# --- Execution --- +st.divider() + +if st.button("Execute DP-SQL Query", type="primary"): + with st.spinner("Rewriting query and applying DP mechanisms..."): + + # 1. Prepare temporary files for main.py execution + with open("ui_test_query.txt", "w") as f: + f.write(query) + + with open("ui_test_relations.txt", "w") as f: + f.write("\n".join([r.strip() for r in relations_input.split(",")])) + + # 2. Generate configuration file dynamically + config = configparser.ConfigParser() + config.read("config/parameter.config") # Load base config to keep DB defaults + + if not config.has_section("global"): + config.add_section("global") + config.set("global", "epsilon", str(epsilon)) + config.set("global", "beta", str(beta)) + config.set("global", "processor_num", str(p_num)) + config.set("global", "recursion_bound", str(recursion_bound)) + + if not config.has_section("FastSJA"): + config.add_section("FastSJA") + # Use values from UI (defaults provided above) + config.set("FastSJA", "global_sensitivity", str(fast_global_sensitivity)) + config.set("FastSJA", "approximate_factor", str(fast_approximate_factor)) + + if not config.has_section("MultiQ"): + config.add_section("MultiQ") + config.set("MultiQ", "delta", str(delta)) + + if not config.has_section("MaxSJA"): + config.add_section("MaxSJA") + config.set("MaxSJA", "error_level", str(error_level)) + config.set("MaxSJA", "upper_bound", str(max_upper_bound)) + + with open("ui_parameter.config", "w") as f: + config.write(f) + + # 3. Call the existing backend (main.py) + cmd = [ + sys.executable, + "main.py", + "--d", + "config/database.ini", + "--q", + "ui_test_query.txt", + "--r", + "ui_test_relations.txt", + "--c", + "ui_parameter.config", + "--o", + "ui_out.txt", + "--debug", + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + st.error("Execution Failed. Check database connection and syntax.") + st.code(result.stderr) + else: + # 4. Parse the output text file generated by main.py + with open("ui_out.txt", "r") as f: + output_content = f.read() + + st.success("Query Evaluated Successfully!") + + # Extract metrics directly from the text file output + out_lines = output_content.split("\n") + q_type = ( + out_lines[0].replace("Query type:", "").strip() + if len(out_lines) > 0 + else "Unknown" + ) + + true_res_str = "" + noise_res_str = "" + rewrite_time_str = "N/A" + process_time_str = "N/A" + + for line in out_lines: + if line.startswith("true result:"): + true_res_str = line.split("true result:")[1].strip() + if line.startswith("noise result:"): + noise_res_str = line.split("noise result:")[1].strip() + if line.startswith("rewrite time:"): + rewrite_time_str = line.split("rewrite time:")[1].strip() + if line.startswith("process time:"): + process_time_str = line.split("process time:")[1].strip() + + # --- Visualization (Mirroring Fig 2c of the paper) --- + st.subheader(f"3. Results Overview (Mechanism: {q_type})") + + try: + # Clean up numpy specific formats before parsing + if true_res_str: + true_res_str = re.sub(r"np\.float64\((.*?)\)", r"\1", true_res_str) + if noise_res_str: + noise_res_str = re.sub( + r"np\.float64\((.*?)\)", r"\1", noise_res_str + ) + + # Try to parse python lists if output is group-by (e.g. [(val1, grp1), (val2, grp2)]) + true_vals = ast.literal_eval(true_res_str) + noise_vals = ast.literal_eval(noise_res_str) + + if ( + isinstance(true_vals, list) + and len(true_vals) > 0 + and isinstance(true_vals[0], tuple) + ): + # Group by query + df = pd.DataFrame( + { + "Group": [str(x[1]) for x in true_vals], + "True Result": [float(x[0]) for x in true_vals], + "Privatized Result": [float(x[0]) for x in noise_vals], + } + ).set_index("Group") + st.dataframe(df) + st.bar_chart(df) + else: + # Single Aggregate + true_val_f = float(true_vals) + noise_val_f = float(noise_vals) + rel_error = ( + abs(true_val_f - noise_val_f) / abs(true_val_f) + if true_val_f != 0 + else 0 + ) + + col_metric1, col_metric2 = st.columns(2) + with col_metric1: + st.metric( + "Privatized Output", + f"{noise_val_f:,.4f}", + delta=f"True: {true_val_f:,.4f}", + delta_color="off", + ) + with col_metric2: + st.metric("Relative Error", f"{rel_error:.4%}") + + except Exception as e: + # Fallback if output parsing fails (shows raw logs) + st.warning( + "Could not parse result for Bar Chart visualization. Displaying raw output." + ) + st.text(output_content) + + st.divider() + t_col1, t_col2 = st.columns(2) + try: + t_col1.metric( + "Rewrite Time", + ( + f"{float(rewrite_time_str):.4f} s" + if rewrite_time_str != "N/A" + else "N/A" + ), + ) + t_col2.metric( + "Processing Time", + ( + f"{float(process_time_str):.4f} s" + if process_time_str != "N/A" + else "N/A" + ), + ) + except Exception: + pass + + with st.expander("View Debug Logs / Rewritten Query"): + st.code(output_content) + + # Cleanup temp files + for temp_file in [ + "ui_test_query.txt", + "ui_test_relations.txt", + "ui_parameter.config", + "ui_out.txt", + ]: + if os.path.exists(temp_file): + os.remove(temp_file) diff --git a/config/parameter.config b/config/parameter.config index 14cc97f..8ae66bd 100644 --- a/config/parameter.config +++ b/config/parameter.config @@ -2,6 +2,7 @@ epsilon=1 beta=0.1 processor_num=5 +recursion_bound=3 [FastSJA] global_sensitivity=1000000 approximate_factor=0 diff --git a/docs/cplex_setup.md b/docs/cplex_setup.md new file mode 100644 index 0000000..b2c9688 --- /dev/null +++ b/docs/cplex_setup.md @@ -0,0 +1,61 @@ +# CPLEX Full Version Setup Guide + +To handle large-scale database workloads (like TPC-H and SSB), you must use the Full Edition of CPLEX. The standard `pip install cplex` is a Community Edition limited to 1,000 variables and 1,000 constraints. + +## 1. Download the Installer + +* **Students/Academics:** Register at the [IBM Academic Initiative](https://ibm.biz/academic) using your university email to download CPLEX for free. +* **Commercial:** Download via the IBM Passport Advantage portal. +* **Version:** This project is tested with CPLEX Studio 22.1.1 or newer. + +## 2. Install the Studio + +### Linux (Fedora/Ubuntu) + +1. Make the installer executable: `chmod +x cplex_studioXXXX.linux_x86_64.bin` +2. Run with sudo: `sudo ./cplex_studioXXXX.linux_x86_64.bin` +3. Default path: `/opt/ibm/ILOG/CPLEX_StudioXXXX` + +### Windows + +1. Run the `.exe` installer as Administrator. +2. Default path: `C:\Program Files\IBM\ILOG\CPLEX_StudioXXXX` + +## 3. Python API Integration + +IBM no longer includes a python folder in the installation. Follow these steps to link the Full Version to your Python environment: + +1. **Activate your virtual environment:** + ```bash + source .venv/bin/activate # Linux + .venv\Scripts\activate # Windows + ``` + +2. **Install the base packages:** + ```bash + pip install cplex docplex + ``` + +3. **Link to Local Binaries:** + Run the docplex utility to upgrade your pip installation to the Full Version: + + **Linux:** + ```bash + docplex config --upgrade /opt/ibm/ILOG/CPLEX_StudioXXXX + ``` + + **Windows:** + ```powershell + docplex config --upgrade "C:\Program Files\IBM\ILOG\CPLEX_StudioXXXX" + ``` + +## 4. Verification + +Run the following to ensure the unlimited version is active: + +```python +import cplex +c = cplex.Cplex() +print(f"CPLEX Version: {c.get_version()}") +# If this succeeds without a "Promotional Version" warning, you are ready. +``` \ No newline at end of file diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..1a54139 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,205 @@ +import subprocess +import configparser +import os +import ast +import re +import sys +import pandas as pd +import argparse + + +def parse_output(output_content): + """ + Parses the out.txt content generated by main.py + and extracts metrics along with calculating relative errors. + """ + out_lines = output_content.split("\n") + + q_type = "Unknown" + for line in out_lines: + if line.startswith("Query type:"): + q_type = line.replace("Query type:", "").strip() + break + + true_res_str, noise_res_str = "", "" + rewrite_time_str, process_time_str = "N/A", "N/A" + + for line in out_lines: + if line.startswith("true result:"): + true_res_str = line.split("true result:")[1].strip() + if line.startswith("noise result:"): + noise_res_str = line.split("noise result:")[1].strip() + if line.startswith("rewrite time:"): + rewrite_time_str = line.split("rewrite time:")[1].strip() + if line.startswith("process time:"): + process_time_str = line.split("process time:")[1].strip() + + # Clean numpy types formatting + if true_res_str: + true_res_str = re.sub(r"np\.float64\((.*?)\)", r"\1", true_res_str) + if noise_res_str: + noise_res_str = re.sub(r"np\.float64\((.*?)\)", r"\1", noise_res_str) + + rel_error = None + try: + true_vals = ast.literal_eval(true_res_str) + noise_vals = ast.literal_eval(noise_res_str) + + # If result is a list of tuples (e.g., from Group By queries) + if ( + isinstance(true_vals, list) + and len(true_vals) > 0 + and isinstance(true_vals[0], tuple) + ): + errors = [] + for t, n in zip(true_vals, noise_vals): + t_val = float(t[0]) + n_val = float(n[0]) + if t_val != 0: + errors.append(abs(t_val - n_val) / abs(t_val)) + else: + errors.append(0) + # Calculate Average Relative Error for Group By + rel_error = sum(errors) / len(errors) if errors else 0 + else: + # Single aggregate result + true_val_f = float(true_vals) + noise_val_f = float(noise_vals) + if true_val_f != 0: + rel_error = abs(true_val_f - noise_val_f) / abs(true_val_f) + else: + rel_error = 0 + except Exception: + pass # Leaving relative error as None if unparsable + + return { + "query_type": q_type, + "true_result": true_res_str, + "noise_result": noise_res_str, + "relative_error": rel_error, + "rewrite_time": rewrite_time_str, + "process_time": process_time_str, + } + + +def run_experiment( + query_file, relation_file, base_config_file, section, param, values, output_csv +): + """ + Runs the main.py pipeline over a list of parameter values, logs them, + and exports them to a CSV. + """ + print(f"Starting evaluation: Varying {section}.{param} for {query_file}") + results = [] + + config = configparser.ConfigParser() + config.read(base_config_file) + + for val in values: + print(f" -> Testing {param} = {val}...", end="", flush=True) + + # Add section if not present & set specific config variable + if not config.has_section(section): + config.add_section(section) + config.set(section, param, str(val)) + + # Temporary config/output files specifically for evaluating + temp_config = f"temp_eval_{param}.config" + temp_out = f"temp_eval_out_{param}.txt" + + with open(temp_config, "w") as f: + config.write(f) + + cmd = [ + sys.executable, + "main.py", + "--d", + "config/database.ini", + "--q", + query_file, + "--r", + relation_file, + "--c", + temp_config, + "--o", + temp_out, + "--debug", + ] + + # Run main.py with modified configs + run_result = subprocess.run(cmd, capture_output=True, text=True) + + if os.path.exists(temp_out): + with open(temp_out, "r") as f: + out_content = f.read() + + parsed = parse_output(out_content) + parsed["changed_parameter"] = param + parsed["parameter_value"] = val + parsed["success"] = run_result.returncode == 0 + if not parsed["success"]: + parsed["error_log"] = run_result.stderr.strip() + + results.append(parsed) + os.remove(temp_out) + print(" Done") + else: + print(" Failed") + results.append( + { + "changed_parameter": param, + "parameter_value": val, + "success": False, + "error_log": run_result.stderr.strip(), + } + ) + + if os.path.exists(temp_config): + os.remove(temp_config) + + # Dump to CSV + df = pd.DataFrame(results) + + # Reorder columns slightly for better readability + cols = [ + "success", + "changed_parameter", + "parameter_value", + "query_type", + "relative_error", + "true_result", + "noise_result", + "rewrite_time", + "process_time", + ] + # add remaining columns (like error_log if any failed) + cols.extend([c for c in df.columns if c not in cols]) + df = df[cols] + + df.to_csv(output_csv, index=False) + print(f"Results successfully saved to {output_csv}\n") + + +if __name__ == "__main__": + # --- Example 1: Vary Epsilon for a Nested Query --- + run_experiment( + query_file="test_rec.txt", + relation_file="test_relation.txt", + base_config_file="config/parameter.config", + section="global", + param="epsilon", + values=[0.1, 0.2, 0.5, 0.75, 1.0, 2.0, 5.0, 10.0], + output_csv="eval_rec6_epsilon.csv", + ) + + # # --- Example 2: Vary Recursion Bound for a Recursive Query --- + # # Using 'test.txt' as a placeholder since it might contain a recursive statement + # run_experiment( + # query_file="test_rec.txt", + # relation_file="test_relation.txt", + # base_config_file="config/parameter.config", + # section="global", + # param="recursion_bound", + # values=[2, 3, 4, 5], + # output_csv="eval_recursion_bound.csv", + # diff --git a/main.py b/main.py index cdb0538..825fbcd 100644 --- a/main.py +++ b/main.py @@ -8,104 +8,145 @@ from pglast import parser, prettify from pglast import ast import src.process +from src.recursive import is_recursive_query def get_project_root() -> Path: return Path(__file__).parent -def main(): - argparser = argparse.ArgumentParser(description='sql over DP') - argparser.add_argument('--db', '--d', type=str, default='./config/database.ini', - help='path to database initialization file') - argparser.add_argument('--query', '--q', type=str, default='./test.txt', help='path to query file') - argparser.add_argument('--relation', '--r', type=str, help='path to private relation file', default="./test_relation.txt") - argparser.add_argument('--config', '--c', type=str, help='path to the configuration file', default="./config/parameter.config") - argparser.add_argument('--output', '--o', type=str, help='path to output file', - default="./out.txt") - argparser.add_argument('--debug', action='store_true', help='debug mode, print more information') - argparser.add_argument('--optimal', action='store_true', help='optimal mode for SJA') - +def main(): + argparser = argparse.ArgumentParser(description="sql over DP") + argparser.add_argument( + "--db", + "--d", + type=str, + default="./config/database.ini", + help="path to database initialization file", + ) + argparser.add_argument( + "--query", "--q", type=str, default="./test.txt", help="path to query file" + ) + argparser.add_argument( + "--relation", + "--r", + type=str, + help="path to private relation file", + default="./test_relation.txt", + ) + argparser.add_argument( + "--config", + "--c", + type=str, + help="path to the configuration file", + default="./config/parameter.config", + ) + argparser.add_argument( + "--output", "--o", type=str, help="path to output file", default="./out.txt" + ) + argparser.add_argument( + "--debug", action="store_true", help="debug mode, print more information" + ) + argparser.add_argument( + "--optimal", action="store_true", help="optimal mode for SJA" + ) opt = argparser.parse_args() # load the config file dbsetting = config(opt.db) - global_para = config(opt.config, 'global') - fast_para = config(opt.config, 'FastSJA') - multi_para = config(opt.config, 'MultiQ') - max_para = config(opt.config, 'MaxSJA') + global_para = config(opt.config, "global") + fast_para = config(opt.config, "FastSJA") + multi_para = config(opt.config, "MultiQ") + max_para = config(opt.config, "MaxSJA") # load the input query query = "" - query_file = open(opt.query, 'r') + query_file = open(opt.query, "r") for line in query_file.readlines(): query = query + line if ";" in query: break # load the private relation - relation_file = open(opt.relation, 'r') + relation_file = open(opt.relation, "r") private_relations = "" for line in relation_file.readlines(): private_relations = private_relations + line + "," # first parsing for type check - root = parser.parse_sql(query) - selectstmt = root[0].stmt - if not isinstance(selectstmt, ast.SelectStmt): - raise Exception - check = check_type(private_relations) - check(selectstmt) + # + # The original check_type visitor assumes that every SelectStmt it visits + # has a fromClause. A WITH RECURSIVE query contains inner SELECT nodes + # such as constants/depth expressions that can violate that assumption. + # For recursive queries we choose the FastSJA path here and let + # process.rewrite(...) convert the recursive CTE into bounded row-level + # input before execution. + if is_recursive_query(query): + check = check_type(private_relations) + else: + root = parser.parse_sql(query) + selectstmt = root[0].stmt + if not isinstance(selectstmt, ast.SelectStmt): + raise Exception + check = check_type(private_relations) + check(selectstmt) filepath = get_project_root() - output_file = open(opt.output, 'w') - if opt.debug: - pg_test(dbsetting) + output_file = open(opt.output, "w") + if not pg_test(dbsetting): + raise SystemExit( + "Database connection failed. Check config/database.ini and make sure PostgreSQL is running/reachable." + ) # set up misc - multiprocessing.set_start_method("fork") + try: + multiprocessing.set_start_method("fork") + except (RuntimeError, ValueError): + # Windows does not support fork; keep the platform default. + pass para = dict(global_para) + para["recursion_bound"] = global_para.get("recursion_bound", "3") # load misc pks = pg_single(dbsetting, str(filepath) + "/config/primary_keys.txt") - fks = pg_single(dbsetting, str(filepath) + "/config/foreign_keys.txt") - table_file = open(str(filepath) + "/config/table.txt", 'r') + fks = pg_single(dbsetting, str(filepath) + "/config/foreign_keys.txt") + table_file = open(str(filepath) + "/config/table.txt", "r") q = table_file.read() schema = get_schema(dbsetting, q) if check.max is not None and check.groupby: para.update(multi_para) para.update(max_para) - output_file.write('Query type: MultiMax' + "\n") + output_file.write("Query type: MultiMax" + "\n") process = src.process.MultiMax(check.l, pks, fks, schema, para, dbsetting) elif check.max is not None: para.update(max_para) # shiftedinverse1 if check.l == 1: - output_file.write('Query type: MaxSJA1' + "\n") + output_file.write("Query type: MaxSJA1" + "\n") process = src.process.MaxSJA1(pks, fks, schema, para, dbsetting) # shiftedinverse2 if check.l > 1: - output_file.write('Query type: MaxSJA2' + "\n") + output_file.write("Query type: MaxSJA2" + "\n") process = src.process.MaxSJA2(pks, fks, schema, para, dbsetting) elif check.groupby: - para.update(multi_para) - if check.selfjoin: - # multiSJA - output_file.write('Query type: multiSJA' + "\n") - process = src.process.MultiSJA(pks, fks, schema, para, dbsetting) - else: - # multiSJF - output_file.write('Query type: multiSJF' + "\n") - process = src.process.MultiSJF(pks, fks, schema, para, dbsetting) + para.update(multi_para) + if check.selfjoin: + # multiSJA + output_file.write("Query type: multiSJA" + "\n") + process = src.process.MultiSJA(pks, fks, schema, para, dbsetting) + else: + # multiSJF + output_file.write("Query type: multiSJF" + "\n") + process = src.process.MultiSJF(pks, fks, schema, para, dbsetting) else: - # R2T + # R2T - para.update(fast_para) - if opt.optimal: - output_file.write('Query type: OptSJA' + "\n") - process = src.process.OptSJA(pks, fks, schema, para, dbsetting) - else: - output_file.write('Query type: FastSJA' + "\n") - process = src.process.FastSJA(pks, fks, schema, para, dbsetting) + para.update(fast_para) + if opt.optimal: + output_file.write("Query type: OptSJA" + "\n") + process = src.process.OptSJA(pks, fks, schema, para, dbsetting) + else: + output_file.write("Query type: FastSJA" + "\n") + process = src.process.FastSJA(pks, fks, schema, para, dbsetting) start = time.time() @@ -120,14 +161,17 @@ def main(): if opt.debug: output_file.write("original Query:" + "\n") output_file.write(prettify(query)) - output_file.write("\n" + "rewritten Query:" + "\n") + output_file.write("\n\n" + "rewritten Query:" + "\n") output_file.write(prettify(process.rewrite_query)) - output_file.write("\n" + "true result:") + output_file.write("\n\n" + "true result:") output_file.write(str(process.true_result)) if process.error is not None: output_file.write("\n" + "error:") output_file.write(str(process.error)) + output_file.write("\n" + "actual result:") + output_file.write(str(process.true_result)) + output_file.write("\n" + "noise result:") output_file.write(str(process.noise_result)) output_file.write("\n" + "rewrite time:") @@ -135,7 +179,9 @@ def main(): output_file.write("\n" + "process time:") output_file.write(str(end2 - end1)) + print("Processing completed. Check output file for results.") + # Press the green button in the gutter to run the script. -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2c7a3e1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +contourpy==1.3.3 +cplex==22.1.2.1 +cycler==0.12.1 +docplex==2.32.259 +fonttools==4.62.1 +hashable-list==0.2.0 +kiwisolver==1.5.0 +matplotlib==3.10.8 +Mosek==11.1.10 +numpy==2.4.3 +ordered-set==4.1.0 +packaging==26.0 +pglast==7.5 +pillow==12.1.1 +psycopg2-binary==2.9.11 +pyparsing==3.3.2 +python-dateutil==2.9.0.post0 +setuptools==82.0.1 +six==1.17.0 +streamlit==1.57.0 \ No newline at end of file diff --git a/src/__pycache__/parser.cpython-310.pyc b/src/__pycache__/parser.cpython-310.pyc deleted file mode 100644 index 041cc6e..0000000 Binary files a/src/__pycache__/parser.cpython-310.pyc and /dev/null differ diff --git a/src/__pycache__/process.cpython-310.pyc b/src/__pycache__/process.cpython-310.pyc deleted file mode 100644 index cebb75b..0000000 Binary files a/src/__pycache__/process.cpython-310.pyc and /dev/null differ diff --git a/src/__pycache__/util.cpython-310.pyc b/src/__pycache__/util.cpython-310.pyc deleted file mode 100644 index a735e48..0000000 Binary files a/src/__pycache__/util.cpython-310.pyc and /dev/null differ diff --git a/src/algorithm/FastSJA.py b/src/algorithm/FastSJA.py index 37626d0..4c29492 100644 --- a/src/algorithm/FastSJA.py +++ b/src/algorithm/FastSJA.py @@ -70,6 +70,8 @@ def ReadInput(): for line in input_result: ids = line[1:] for id in ids: + if id is None: + continue temp_id = id if temp_id not in reorder_ids: reorder_ids[temp_id] = num_id @@ -84,6 +86,8 @@ def ReadInput(): #print(aggregation_value) # For each entity contribution to that join result for element in elements[1:]: + if element is None: + continue element = reorder_ids[element] if element in id_dic.keys(): element = id_dic[element] diff --git a/src/algorithm/MaxSJA1.py b/src/algorithm/MaxSJA1.py index a220cf1..25377e8 100644 --- a/src/algorithm/MaxSJA1.py +++ b/src/algorithm/MaxSJA1.py @@ -21,6 +21,8 @@ def ReadInput(): tuple_value = float(elements[0]) user_id = elements[1] + if user_id is None: + continue if user_id in id_dict.keys(): user_id = id_dict[user_id] diff --git a/src/algorithm/MaxSJA2.py b/src/algorithm/MaxSJA2.py index d38e38e..c7223a0 100644 --- a/src/algorithm/MaxSJA2.py +++ b/src/algorithm/MaxSJA2.py @@ -28,6 +28,8 @@ def ReadInput(): value = float(elements[0]) for element in elements[1:]: + if element is None: + continue user_id = element if user_id in id_dict.keys(): diff --git a/src/algorithm/OptSJA.py b/src/algorithm/OptSJA.py index 515c05b..0de5811 100644 --- a/src/algorithm/OptSJA.py +++ b/src/algorithm/OptSJA.py @@ -47,6 +47,8 @@ def ReadInput(): aggregation_value = float(elements[0]) # For each entity contribution to that join result for element in elements[1:]: + if element is None: + continue # element = int(element) # Re-order the IDs if element in id_dic.keys(): diff --git a/src/algorithm/__pycache__/FastSJA.cpython-310.pyc b/src/algorithm/__pycache__/FastSJA.cpython-310.pyc deleted file mode 100644 index a4a2016..0000000 Binary files a/src/algorithm/__pycache__/FastSJA.cpython-310.pyc and /dev/null differ diff --git a/src/algorithm/__pycache__/MaxSJA1.cpython-310.pyc b/src/algorithm/__pycache__/MaxSJA1.cpython-310.pyc deleted file mode 100644 index 191339e..0000000 Binary files a/src/algorithm/__pycache__/MaxSJA1.cpython-310.pyc and /dev/null differ diff --git a/src/algorithm/__pycache__/MaxSJA2.cpython-310.pyc b/src/algorithm/__pycache__/MaxSJA2.cpython-310.pyc deleted file mode 100644 index bb8e8cf..0000000 Binary files a/src/algorithm/__pycache__/MaxSJA2.cpython-310.pyc and /dev/null differ diff --git a/src/algorithm/__pycache__/MultiSJA.cpython-310.pyc b/src/algorithm/__pycache__/MultiSJA.cpython-310.pyc deleted file mode 100644 index 5f7c35c..0000000 Binary files a/src/algorithm/__pycache__/MultiSJA.cpython-310.pyc and /dev/null differ diff --git a/src/algorithm/__pycache__/MultiSJF.cpython-310.pyc b/src/algorithm/__pycache__/MultiSJF.cpython-310.pyc deleted file mode 100644 index 2c9522b..0000000 Binary files a/src/algorithm/__pycache__/MultiSJF.cpython-310.pyc and /dev/null differ diff --git a/src/algorithm/__pycache__/OptSJA.cpython-310.pyc b/src/algorithm/__pycache__/OptSJA.cpython-310.pyc deleted file mode 100644 index d040193..0000000 Binary files a/src/algorithm/__pycache__/OptSJA.cpython-310.pyc and /dev/null differ diff --git a/src/algorithm/__pycache__/R2T.cpython-310.pyc b/src/algorithm/__pycache__/R2T.cpython-310.pyc deleted file mode 100644 index 9da80b5..0000000 Binary files a/src/algorithm/__pycache__/R2T.cpython-310.pyc and /dev/null differ diff --git a/src/algorithm/__pycache__/ShiftedInverse1.cpython-310.pyc b/src/algorithm/__pycache__/ShiftedInverse1.cpython-310.pyc deleted file mode 100644 index 12366c9..0000000 Binary files a/src/algorithm/__pycache__/ShiftedInverse1.cpython-310.pyc and /dev/null differ diff --git a/src/algorithm/__pycache__/ShiftedInverse2.cpython-310.pyc b/src/algorithm/__pycache__/ShiftedInverse2.cpython-310.pyc deleted file mode 100644 index ecbf89c..0000000 Binary files a/src/algorithm/__pycache__/ShiftedInverse2.cpython-310.pyc and /dev/null differ diff --git a/src/parser.py b/src/parser.py index 069171c..c41e737 100644 --- a/src/parser.py +++ b/src/parser.py @@ -28,7 +28,7 @@ def iterate(self, node): if isinstance(node, (tuple, ast.Node)): todo.append((Ancestor(), node)) else: - raise ValueError('Bad argument, expected a ast.Node instance or a tuple') + raise ValueError("Bad argument, expected a ast.Node instance or a tuple") while todo: ancestors, node = todo.popleft() @@ -93,8 +93,8 @@ def iterate(self, node): def visit_JoinExpr(self, ancestors, node): """ - we keep all the table name including renaming, and all the predicate - in the join condition + we keep all the table name including renaming, and all the predicate + in the join condition """ idx = 0 # left is table name @@ -137,12 +137,16 @@ def visit_SelectStmt(self, ancestors, node): if len(self.qual) == 1: node.whereClause = self.qual[0] elif len(self.qual) > 1: - node.whereClause = ast.BoolExpr(boolop=enums.BoolExprType.AND_EXPR, args=()) + node.whereClause = ast.BoolExpr( + boolop=enums.BoolExprType.AND_EXPR, args=() + ) for item in self.qual: node.whereClause.args += (item,) else: temp = node.whereClause - node.whereClause = ast.BoolExpr(boolop=enums.BoolExprType.AND_EXPR, args=(temp,)) + node.whereClause = ast.BoolExpr( + boolop=enums.BoolExprType.AND_EXPR, args=(temp,) + ) for item in self.qual: node.whereClause.args += (item,) @@ -162,7 +166,11 @@ def visit_ResTarget(self, ancestors, node): # all the written code does not have aggregation # may be revisited for max later if isinstance(node.val, ast.FuncCall): - if node.val.funcname[0].val == 'sum' or node.val.funcname[0].val == 'count' or node.val.funcname[0].val == 'max': + if ( + node.val.funcname[0].sval == "sum" + or node.val.funcname[0].sval == "count" + or node.val.funcname[0].sval == "max" + ): return visitors.Delete def visit_SelectStmt(self, ancestors, node): @@ -174,14 +182,16 @@ def visit_SelectStmt(self, ancestors, node): for item in node.targetList: if isinstance(item.val, ast.FuncCall): if len(item.val.funcname) == 1: - if item.val.funcname[0].val == 'count': - node.targetList += (ast.ResTarget(val=ast.A_Const(val=ast.Integer(1))),) - elif item.val.funcname[0].val == 'sum': + if item.val.funcname[0].sval == "count": + node.targetList += ( + ast.ResTarget(val=ast.A_Const(val=ast.Integer(1))), + ) + elif item.val.funcname[0].sval == "sum": temp = item.val.args[0] node.targetList += (ast.ResTarget(val=temp),) - elif item.val.funcname[0].val == 'max': + elif item.val.funcname[0].sval == "max": temp = item.val.args[0] - self.index = item.val.args[1].val.val + self.index = item.val.args[1].val.ival node.targetList += (ast.ResTarget(val=temp),) else: pass @@ -199,11 +209,34 @@ def get_primary_keys(pks, relations): if pk[0] in relation: left = pk[2].find("(") right = pk[2].find(")") - key = pk[2][left + 1:right] + key = pk[2][left + 1 : right] res.append(pk[0] + "." + key) return res +def _split_pk_columns(pk_text): + """Return individual PK column names from strings like 'id' or 'id1, id2'.""" + return [c.strip().strip('\"') for c in pk_text.split(",") if c.strip()] + + +def _build_tuple_id_expr(alias, columns, id_label): + """ + Build concat('id0', alias.col1, ':', alias.col2, ...) for tuple identity. + This fixes composite primary keys that were previously treated as one + quoted column name, e.g. lineitem."l_orderkey, l_linenumber". + """ + args = [ast.String(sval=id_label)] + for i, col in enumerate(columns): + if i > 0: + args.append(ast.String(sval=":")) + args.append( + ast.ColumnRef( + fields=(ast.String(sval=alias), ast.String(sval=col)) + ) + ) + return ast.FuncCall(funcname=(ast.String(sval="concat"),), args=tuple(args)) + + class userAdder(visitors.Visitor): def __init__(self, keys): @@ -223,15 +256,21 @@ def visit_SelectStmt(self, ancestors, node): renaming = renaming.rename_dict # add all the private keys into the select statement for r in self.keys: - table_attribute = r.split(".") - for name in renaming[table_attribute[0]]: - rename = 'id' + str(idx) - table = ast.String(value=rename) - attri = ast.ColumnRef(fields=(name, table_attribute[1])) - node.targetList += (ast.ResTarget(val=ast.FuncCall(funcname=(ast.String(value='concat'),), - args=(table, attri)), name=rename),) - # node.targetList += (ast.ResTarget(val=ast.ColumnRef(fields=(name, - # table_attribute[1])), name=rename),) + table_attribute = r.split(".", 1) + if len(table_attribute) != 2: + continue + table_name, pk_text = table_attribute + pk_columns = _split_pk_columns(pk_text) + if table_name not in renaming: + continue + for name in renaming[table_name]: + rename = "id" + str(idx) + node.targetList += ( + ast.ResTarget( + val=_build_tuple_id_expr(name, pk_columns, rename), + name=rename, + ), + ) idx += 1 @@ -266,10 +305,10 @@ def __init__(self, fks): split = fk[2].split("REFERENCES ") left1 = split[0].find("(") right1 = split[0].find(")") - src_key = split[0][left1 + 1:right1] + src_key = split[0][left1 + 1 : right1] left2 = split[1].find("(") right2 = split[1].find(")") - dest_key = split[1][left2 + 1:right2] + dest_key = split[1][left2 + 1 : right2] dest_table = split[1][0:left2] self.fk_dic[(src_table, src_key)] = (dest_table, dest_key) @@ -278,32 +317,53 @@ def check_condition(self, node, src_rename, dest_rename, src_key, dest_key): # for each dest table, there are four cases for foreign key condition for dest in dest_rename: # r1.k1 = r2.k2 - c1 = ast.A_Expr(kind=enums.A_Expr_Kind.AEXPR_OP, name=(ast.String(value="="),), - lexpr=ast.ColumnRef( - fields=(ast.String(value=src_rename), ast.String(value=src_key[0].strip()))), - rexpr=ast.ColumnRef( - fields=(ast.String(value=dest), ast.String(value=dest_key[0].strip())))) + c1 = ast.A_Expr( + kind=enums.A_Expr_Kind.AEXPR_OP, + name=(ast.String(sval="="),), + lexpr=ast.ColumnRef( + fields=( + ast.String(sval=src_rename), + ast.String(sval=src_key[0].strip()), + ) + ), + rexpr=ast.ColumnRef( + fields=(ast.String(sval=dest), ast.String(sval=dest_key[0].strip())) + ), + ) # r2.k2 = r1.k1 - c2 = ast.A_Expr(kind=enums.A_Expr_Kind.AEXPR_OP, name=(ast.String(value="="),), - lexpr=ast.ColumnRef( - fields=(ast.String(value=dest), ast.String(value=dest_key[0].strip()),)), - rexpr=ast.ColumnRef( - fields=(ast.String(value=src_rename), ast.String(value=src_key[0].strip()),))) + c2 = ast.A_Expr( + kind=enums.A_Expr_Kind.AEXPR_OP, + name=(ast.String(sval="="),), + lexpr=ast.ColumnRef( + fields=( + ast.String(sval=dest), + ast.String(sval=dest_key[0].strip()), + ) + ), + rexpr=ast.ColumnRef( + fields=( + ast.String(sval=src_rename), + ast.String(sval=src_key[0].strip()), + ) + ), + ) # k1 = k2 - c3 = ast.A_Expr(kind=enums.A_Expr_Kind.AEXPR_OP, name=(ast.String(value="="),), - lexpr=ast.ColumnRef( - fields=(ast.String(value=src_key[0].strip()),)), - rexpr=ast.ColumnRef( - fields=(ast.String(value=dest_key[0].strip()),))) + c3 = ast.A_Expr( + kind=enums.A_Expr_Kind.AEXPR_OP, + name=(ast.String(sval="="),), + lexpr=ast.ColumnRef(fields=(ast.String(sval=src_key[0].strip()),)), + rexpr=ast.ColumnRef(fields=(ast.String(sval=dest_key[0].strip()),)), + ) # k2 = k1 - c4 = ast.A_Expr(kind=enums.A_Expr_Kind.AEXPR_OP, name=(ast.String(value="="),), - lexpr=ast.ColumnRef( - fields=(ast.String(value=dest_key[0].strip()),)), - rexpr=ast.ColumnRef( - fields=(ast.String(value=src_key[0].strip()),))) + c4 = ast.A_Expr( + kind=enums.A_Expr_Kind.AEXPR_OP, + name=(ast.String(sval="="),), + lexpr=ast.ColumnRef(fields=(ast.String(sval=dest_key[0].strip()),)), + rexpr=ast.ColumnRef(fields=(ast.String(sval=src_key[0].strip()),)), + ) r1 = check_expr(c1) r1(node.whereClause) @@ -334,11 +394,18 @@ def visit_SelectStmt(self, ancestors, node): for fk in self.fk_dic.keys(): # if src table is in the current query, and destination table is not in # we add the destination table into the query, and the join condition - if fk[0] in relation_dict.keys() and self.fk_dic[fk][0] not in relation_dict.keys(): + if ( + fk[0] in relation_dict.keys() + and self.fk_dic[fk][0] not in relation_dict.keys() + ): # renaming the upcoming table rename = self.fk_dic[fk][0] + str(0) # syntax node for this table - item = ast.RangeVar(relname=self.fk_dic[fk][0], inh=True, alias=ast.Alias(aliasname=rename)) + item = ast.RangeVar( + relname=self.fk_dic[fk][0], + inh=True, + alias=ast.Alias(aliasname=rename), + ) # add to the select statement node.fromClause += (item,) # update renaming @@ -348,11 +415,22 @@ def visit_SelectStmt(self, ancestors, node): src_key = fk[1].split(",") dest_key = self.fk_dic[fk][1].split(",") for i in range(len(src_key)): - c = ast.A_Expr(kind=enums.A_Expr_Kind.AEXPR_OP, name=(ast.String(value="="),), - lexpr=ast.ColumnRef(fields=(ast.String(value=relation_dict[fk[0]][0]), - ast.String(value=src_key[i].strip()))), - rexpr=ast.ColumnRef( - fields=(ast.String(value=rename), ast.String(value=dest_key[i].strip())))) + c = ast.A_Expr( + kind=enums.A_Expr_Kind.AEXPR_OP, + name=(ast.String(sval="="),), + lexpr=ast.ColumnRef( + fields=( + ast.String(sval=relation_dict[fk[0]][0]), + ast.String(sval=src_key[i].strip()), + ) + ), + rexpr=ast.ColumnRef( + fields=( + ast.String(sval=rename), + ast.String(sval=dest_key[i].strip()), + ) + ), + ) # print(stream.RawStream()(c)) conditions += (c,) # add the join condition to the select statement @@ -360,18 +438,25 @@ def visit_SelectStmt(self, ancestors, node): if len(conditions) == 1: node.whereClause = conditions[0] elif len(conditions) > 1: - node.whereClause = ast.BoolExpr(boolop=enums.BoolExprType.AND_EXPR, args=()) + node.whereClause = ast.BoolExpr( + boolop=enums.BoolExprType.AND_EXPR, args=() + ) for condition in conditions: node.whereClause.args += (condition,) else: temp = node.whereClause - node.whereClause = ast.BoolExpr(boolop=enums.BoolExprType.AND_EXPR, args=(temp,)) + node.whereClause = ast.BoolExpr( + boolop=enums.BoolExprType.AND_EXPR, args=(temp,) + ) for condition in conditions: node.whereClause.args += (condition,) visited = True # if both src and dest tables are in the query # we check if the foreign key condition is in the whereClause - if fk[0] in relation_dict.keys() and self.fk_dic[fk][0] in relation_dict.keys(): + if ( + fk[0] in relation_dict.keys() + and self.fk_dic[fk][0] in relation_dict.keys() + ): src_rename = relation_dict[fk[0]] dest_rename = relation_dict[self.fk_dic[fk][0]] src_key = fk[1].split(",") @@ -379,34 +464,57 @@ def visit_SelectStmt(self, ancestors, node): # for each renaming of the src table # we check there is at least one dest table renaming connecting to this src table for src in src_rename: - if not (self.check_condition(node, src, dest_rename, src_key, dest_key)): + if not ( + self.check_condition( + node, src, dest_rename, src_key, dest_key + ) + ): # add another destination table renaming - rename = self.fk_dic[fk][0] + str(len(relation_dict[self.fk_dic[fk][0]])) + rename = self.fk_dic[fk][0] + str( + len(relation_dict[self.fk_dic[fk][0]]) + ) relation_dict[self.fk_dic[fk][0]].append(rename) - item = ast.RangeVar(relname=self.fk_dic[fk][0], inh=True, alias=ast.Alias(aliasname=rename)) + item = ast.RangeVar( + relname=self.fk_dic[fk][0], + inh=True, + alias=ast.Alias(aliasname=rename), + ) # add to the select statement node.fromClause += (item,) conditions = () for i in range(len(src_key)): - c = ast.A_Expr(kind=enums.A_Expr_Kind.AEXPR_OP, name=(ast.String(value="="),), - lexpr=ast.ColumnRef( - fields=(ast.String(value=src), - ast.String(value=src_key[i].strip()))), - rexpr=ast.ColumnRef( - fields=(ast.String(value=rename), - ast.String(value=dest_key[i].strip())))) + c = ast.A_Expr( + kind=enums.A_Expr_Kind.AEXPR_OP, + name=(ast.String(sval="="),), + lexpr=ast.ColumnRef( + fields=( + ast.String(sval=src), + ast.String(sval=src_key[i].strip()), + ) + ), + rexpr=ast.ColumnRef( + fields=( + ast.String(sval=rename), + ast.String(sval=dest_key[i].strip()), + ) + ), + ) conditions += (c,) # add the join condition to the select statement if node.whereClause is None: if len(conditions) == 1: node.whereClause = conditions elif len(conditions) > 1: - node.whereClause = ast.BoolExpr(boolop=enums.BoolExprType.AND_EXPR, args=()) + node.whereClause = ast.BoolExpr( + boolop=enums.BoolExprType.AND_EXPR, args=() + ) for condition in conditions: node.whereClause.args += (condition,) else: temp = node.whereClause - node.whereClause = ast.BoolExpr(boolop=enums.BoolExprType.AND_EXPR, args=(temp,)) + node.whereClause = ast.BoolExpr( + boolop=enums.BoolExprType.AND_EXPR, args=(temp,) + ) for condition in conditions: node.whereClause.args += (condition,) @@ -414,9 +522,10 @@ def visit_SelectStmt(self, ancestors, node): class add_table_name(visitors.Visitor): - ''' + """ add the table either renaming or not to the column - ''' + """ + def __init__(self, select_node, schema): # get the renaming of the current node renaming = get_rename() @@ -428,25 +537,29 @@ def visit_ColumnRef(self, ancestors, node): # get fields fields = node.fields if len(fields) == 1: - attribute = fields[0].val + attribute = fields[0].sval for table in self.rename_dict.keys(): table_attr = self.schema[table] if attribute in table_attr: # self.rename_dict[table] length should be 1, otherwise this # is an ambiguous column - node.fields = (ast.String(value=self.rename_dict[table][0]),) + node.fields + node.fields = ( + ast.String(sval=self.rename_dict[table][0]), + ) + node.fields class get_rename(visitors.Visitor): - ''' + """ get all the current renaming in tis select statement this is mainly a helper visitor for other functional visitor - ''' + """ def __init__(self): self.rename_dict = {} def visit_SelectStmt(self, ancestors, node): + if not node.fromClause: + return for item in node.fromClause: if isinstance(item, ast.RangeVar): if item.relname not in self.rename_dict.keys(): @@ -478,7 +591,9 @@ def visit_SelectStmt(self, ancestors, node): # remove the group column to the selection (targetList) self.root = node self.group = node.groupClause - new_selection = ast.FuncCall(funcname=(ast.String(value='concat'),), args=self.group) + new_selection = ast.FuncCall( + funcname=(ast.String(sval="concat"),), args=self.group + ) node.targetList = (ast.ResTarget(val=new_selection),) + node.targetList node.groupClause = None @@ -489,7 +604,7 @@ def __init__(self): self.node = None def visit_SelectStmt(self, ancestors, node): - if isinstance(node.fromClause[0], ast.RangeSubselect): + if node.fromClause and isinstance(node.fromClause[0], ast.RangeSubselect): self.node = node.fromClause[0].subquery @@ -498,6 +613,7 @@ class check_type(visitors.Visitor): this visitor will check the type of the input query and then decide which algorithm to process the input query """ + def __init__(self, relations): self.subquery = False self.groupby = False @@ -507,12 +623,14 @@ def __init__(self, relations): self._relation = relations def visit_SelectStmt(self, ancestors, node): + if not node.fromClause: + return if isinstance(node.fromClause[0], ast.RangeSubselect): self.subquery = True for item in node.targetList: if isinstance(item.val, ast.FuncCall): if len(item.val.funcname) == 1: - if item.val.funcname[0].val == 'max': + if item.val.funcname[0].sval == "max": self.max = True else: # if there is a subquery, we have to get implicit join first @@ -522,7 +640,7 @@ def visit_SelectStmt(self, ancestors, node): for item in node.targetList: if isinstance(item.val, ast.FuncCall): if len(item.val.funcname) == 1: - if item.val.funcname[0].val == 'max': + if item.val.funcname[0].sval == "max": self.max = True # check l renaming = get_rename() @@ -541,5 +659,49 @@ def visit_SelectStmt(self, ancestors, node): self.groupby = True -if __name__ == '__main__': - pass \ No newline at end of file +if __name__ == "__main__": + pass + +class UnnestSubqueries(visitors.Visitor): + def __init__(self): + self.sub_from_clauses = () + self.sub_where_clauses = [] + + def visit_SubLink(self, ancestors, node): + if node.subLinkType in (enums.SubLinkType.ANY_SUBLINK, enums.SubLinkType.EXISTS_SUBLINK): + self.sub_from_clauses += node.subselect.fromClause + + if node.subLinkType == enums.SubLinkType.ANY_SUBLINK: + if node.subselect.whereClause: + self.sub_where_clauses.append(node.subselect.whereClause) + return ast.A_Expr( + kind=enums.A_Expr_Kind.AEXPR_OP, + name=(ast.String('='),), + lexpr=node.testexpr, + rexpr=node.subselect.targetList[0].val + ) + elif node.subLinkType == enums.SubLinkType.EXISTS_SUBLINK: + if node.subselect.whereClause: + return node.subselect.whereClause + else: + return ast.A_Const(val=ast.Integer(1)) + +def apply_unnest_subqueries(selectstmt): + while True: + unnester = UnnestSubqueries() + unnester(selectstmt) + if not unnester.sub_from_clauses: + break + + if selectstmt.fromClause is None: + selectstmt.fromClause = () + selectstmt.fromClause += unnester.sub_from_clauses + + for sub_where in unnester.sub_where_clauses: + if selectstmt.whereClause: + selectstmt.whereClause = ast.BoolExpr( + boolop=enums.BoolExprType.AND_EXPR, + args=(selectstmt.whereClause, sub_where) + ) + else: + selectstmt.whereClause = sub_where diff --git a/src/process.py b/src/process.py index 0b378ed..a472a3e 100644 --- a/src/process.py +++ b/src/process.py @@ -11,9 +11,18 @@ import src.algorithm.MaxSJA1 import src.algorithm.MaxSJA2 import src.algorithm.OptSJA -from src.parser import userAdder, ImplicitJoin, complete_query, aggregationVisit, get_primary_keys, add_table_name, \ - group_by +from src.parser import ( + userAdder, + ImplicitJoin, + complete_query, + aggregationVisit, + get_primary_keys, + add_table_name, + group_by, + apply_unnest_subqueries, +) from src.util import pg_exec +from src.recursive import rewrite_bounded_recursive_query class algorithm(ABC): @@ -32,6 +41,16 @@ def __init__(self, pks, fks, schema, parameters, dbsetting): def get_input_result(self): self.input_result = pg_exec(self.dbsetting, self.rewrite_query) + def rewrite_recursive_if_needed(self, query, private_relations): + recursion_bound = int(self.parameters.get("recursion_bound", 10)) + rewritten = rewrite_bounded_recursive_query( + query, private_relations, self.pks, recursion_bound + ) + if rewritten is not None: + self.rewrite_query = rewritten + return True + return False + @abstractmethod def rewrite(self, query, private_relations): pass @@ -44,33 +63,46 @@ def process(self): class FastSJA(algorithm): def rewrite(self, query, private_relations): + if self.rewrite_recursive_if_needed(query, private_relations): + return private_pk = get_primary_keys(self.pks, private_relations) root = parser.parse_sql(query) selectstmt = root[0].stmt if not isinstance(selectstmt, ast.SelectStmt): raise Exception + print("Original Query:") + print(stream.RawStream()(selectstmt)) + apply_unnest_subqueries(selectstmt) + print("After unnesting subqueries:") + print(stream.RawStream()(selectstmt)) ImplicitJoin()(selectstmt) add_table_name(selectstmt, self.schema)(selectstmt) aggregationVisit()(selectstmt) complete_query(self.fks)(selectstmt) userAdder(private_pk)(selectstmt) - self.rewrite_query = (stream.RawStream()(selectstmt)) + self.rewrite_query = stream.RawStream()(selectstmt) def process(self): - epsilon = float(self.parameters['epsilon']) - beta = float(self.parameters['beta']) - processor_num = int(self.parameters['processor_num']) - global_sensitivity = float(self.parameters['global_sensitivity']) - approximate_factor = float(self.parameters['approximate_factor']) - src.algorithm.FastSJA.processFastSJA(self.input_result, e=epsilon, b=beta, gs=global_sensitivity, - p_num=processor_num, afactor=approximate_factor) + epsilon = float(self.parameters["epsilon"]) + beta = float(self.parameters["beta"]) + processor_num = int(self.parameters["processor_num"]) + global_sensitivity = float(self.parameters["global_sensitivity"]) + approximate_factor = float(self.parameters["approximate_factor"]) + src.algorithm.FastSJA.processFastSJA( + self.input_result, + e=epsilon, + b=beta, + gs=global_sensitivity, + p_num=processor_num, + afactor=approximate_factor, + ) self.true_result, self.noise_result = src.algorithm.FastSJA.get_result() class OptSJA(FastSJA): def process(self): - epsilon = float(self.parameters['epsilon']) - beta = float(self.parameters['beta']) + epsilon = float(self.parameters["epsilon"]) + beta = float(self.parameters["beta"]) src.algorithm.OptSJA.processOpt(self.input_result, e=epsilon, b=beta) self.true_result, self.noise_result = src.algorithm.OptSJA.get_result() @@ -78,35 +110,46 @@ def process(self): class MultiSJF(algorithm): def rewrite(self, query, private_relations): + if self.rewrite_recursive_if_needed(query, private_relations): + return private_pk = get_primary_keys(self.pks, private_relations) root = parser.parse_sql(query) selectstmt = root[0].stmt if not isinstance(selectstmt, ast.SelectStmt): raise Exception + apply_unnest_subqueries(selectstmt) ImplicitJoin()(selectstmt) add_table_name(selectstmt, self.schema)(selectstmt) group_by()(selectstmt) aggregationVisit()(selectstmt) complete_query(self.fks)(selectstmt) userAdder(private_pk)(selectstmt) - self.rewrite_query = (stream.RawStream()(selectstmt)) + self.rewrite_query = stream.RawStream()(selectstmt) def process(self): - epsilon = float(self.parameters['epsilon']) - beta = float(self.parameters['beta']) - delta = float(self.parameters['delta']) - src.algorithm.MultiSJF.ProcessMultiQSJF(self.input_result, e=epsilon, b=beta, d=delta) - self.true_result, self.noise_result, self.error = src.algorithm.MultiSJF.get_result() + epsilon = float(self.parameters["epsilon"]) + beta = float(self.parameters["beta"]) + delta = float(self.parameters["delta"]) + src.algorithm.MultiSJF.ProcessMultiQSJF( + self.input_result, e=epsilon, b=beta, d=delta + ) + self.true_result, self.noise_result, self.error = ( + src.algorithm.MultiSJF.get_result() + ) class MultiSJA(MultiSJF): def process(self): - epsilon = float(self.parameters['epsilon']) - beta = float(self.parameters['beta']) - delta = float(self.parameters['delta']) - src.algorithm.MultiSJA.ProcessMultiQSJA(self.input_result, e=epsilon, b=beta, Del=delta) - self.true_result, self.noise_result, self.error = src.algorithm.MultiSJA.get_result() + epsilon = float(self.parameters["epsilon"]) + beta = float(self.parameters["beta"]) + delta = float(self.parameters["delta"]) + src.algorithm.MultiSJA.ProcessMultiQSJA( + self.input_result, e=epsilon, b=beta, Del=delta + ) + self.true_result, self.noise_result, self.error = ( + src.algorithm.MultiSJA.get_result() + ) class MaxSJA1(algorithm): @@ -115,11 +158,14 @@ def __init__(self, pks, fks, schema, parameters, dbsetting): self.k = None def rewrite(self, query, private_relations): + if self.rewrite_recursive_if_needed(query, private_relations): + return private_pk = get_primary_keys(self.pks, private_relations) root = parser.parse_sql(query) selectstmt = root[0].stmt if not isinstance(selectstmt, ast.SelectStmt): raise Exception + apply_unnest_subqueries(selectstmt) ImplicitJoin()(selectstmt) add_table_name(selectstmt, self.schema)(selectstmt) agg = aggregationVisit() @@ -127,18 +173,26 @@ def rewrite(self, query, private_relations): complete_query(self.fks)(selectstmt) userAdder(private_pk)(selectstmt) self.k = agg.index - self.rewrite_query = (stream.RawStream()(selectstmt)) + self.rewrite_query = stream.RawStream()(selectstmt) def process(self): if self.k == 0: self.k = len(self.input_result) - epsilon = float(self.parameters['epsilon']) - beta = float(self.parameters['beta']) - error_level = float(self.parameters['error_level']) - upper_bound = float(self.parameters['upper_bound']) - src.algorithm.MaxSJA1.processMaxSJA1(self.input_result, self.k, e=epsilon, b=beta, - error=error_level, ub=upper_bound) - self.true_result, self.noise_result, self.error = src.algorithm.MaxSJA1.get_result() + epsilon = float(self.parameters["epsilon"]) + beta = float(self.parameters["beta"]) + error_level = float(self.parameters["error_level"]) + upper_bound = float(self.parameters["upper_bound"]) + src.algorithm.MaxSJA1.processMaxSJA1( + self.input_result, + self.k, + e=epsilon, + b=beta, + error=error_level, + ub=upper_bound, + ) + self.true_result, self.noise_result, self.error = ( + src.algorithm.MaxSJA1.get_result() + ) class MaxSJA2(MaxSJA1): @@ -146,14 +200,23 @@ class MaxSJA2(MaxSJA1): def process(self): if self.k == 0: self.k = len(self.input_result) - epsilon = float(self.parameters['epsilon']) - beta = float(self.parameters['beta']) - processor_num = int(self.parameters['processor_num']) - error_level = float(self.parameters['error_level']) - upper_bound = float(self.parameters['upper_bound']) - src.algorithm.MaxSJA2.processMaxSJA2(self.input_result, self.k, e=epsilon, b=beta, - error=error_level, ub=upper_bound, p_num=processor_num) - self.true_result, self.noise_result, self.error = src.algorithm.MaxSJA2.get_result() + epsilon = float(self.parameters["epsilon"]) + beta = float(self.parameters["beta"]) + processor_num = int(self.parameters["processor_num"]) + error_level = float(self.parameters["error_level"]) + upper_bound = float(self.parameters["upper_bound"]) + src.algorithm.MaxSJA2.processMaxSJA2( + self.input_result, + self.k, + e=epsilon, + b=beta, + error=error_level, + ub=upper_bound, + p_num=processor_num, + ) + self.true_result, self.noise_result, self.error = ( + src.algorithm.MaxSJA2.get_result() + ) class MultiMax(algorithm): @@ -167,11 +230,14 @@ def __init__(self, input_l, pks, fks, schema, parameters, dbsetting): self.num_query = None def rewrite(self, query, private_relations): + if self.rewrite_recursive_if_needed(query, private_relations): + return private_pk = get_primary_keys(self.pks, private_relations) root = parser.parse_sql(query) selectstmt = root[0].stmt if not isinstance(selectstmt, ast.SelectStmt): raise Exception + apply_unnest_subqueries(selectstmt) ImplicitJoin()(selectstmt) add_table_name(selectstmt, self.schema)(selectstmt) group_by()(selectstmt) @@ -180,7 +246,7 @@ def rewrite(self, query, private_relations): complete_query(self.fks)(selectstmt) userAdder(private_pk)(selectstmt) self.k = agg.index - self.rewrite_query = (stream.RawStream()(selectstmt)) + self.rewrite_query = stream.RawStream()(selectstmt) def get_input_result(self): self.input_result = pg_exec(self.dbsetting, self.rewrite_query) @@ -200,26 +266,38 @@ def get_input_result(self): self.error = 0 def process(self): - epsilon = float(self.parameters['epsilon']) - beta = float(self.parameters['beta']) - error_level = float(self.parameters['error_level']) - upper_bound = float(self.parameters['upper_bound']) - processor_num = int(self.parameters['processor_num']) - delta = float(self.parameters['delta']) + epsilon = float(self.parameters["epsilon"]) + beta = float(self.parameters["beta"]) + error_level = float(self.parameters["error_level"]) + upper_bound = float(self.parameters["upper_bound"]) + processor_num = int(self.parameters["processor_num"]) + delta = float(self.parameters["delta"]) # advanced composition beta = beta / self.num_query - epsilon = (math.sqrt(2 * self.num_query * math.log(1 / delta) + 4 * epsilon * self.num_query) - math.sqrt( - 2 * self.num_query * math.log(1 / delta))) / (2 * self.num_query) + epsilon = ( + math.sqrt( + 2 * self.num_query * math.log(1 / delta) + 4 * epsilon * self.num_query + ) + - math.sqrt(2 * self.num_query * math.log(1 / delta)) + ) / (2 * self.num_query) if self.input_l == 1: for g_id in self.input_final_result.keys(): group = self.group_ids[g_id] next_input = self.input_final_result[g_id] if self.k == 0: - self.k = len(next_input)-1 - src.algorithm.MaxSJA1.processMaxSJA1(next_input, self.k, e=epsilon, b=beta, - error=error_level, ub=upper_bound) - true_result_k, noise_result_k, error_k = src.algorithm.MaxSJA1.get_result() + self.k = len(next_input) - 1 + src.algorithm.MaxSJA1.processMaxSJA1( + next_input, + self.k, + e=epsilon, + b=beta, + error=error_level, + ub=upper_bound, + ) + true_result_k, noise_result_k, error_k = ( + src.algorithm.MaxSJA1.get_result() + ) self.true_result.append((true_result_k, group)) self.noise_result.append((noise_result_k, group)) self.error += error_k @@ -228,11 +306,19 @@ def process(self): group = self.group_ids[g_id] next_input = self.input_final_result[g_id] if self.k == 0: - self.k = len(next_input)-1 - src.algorithm.MaxSJA2.processMaxSJA2(next_input, self.k, e=epsilon, b=beta, - error=error_level, ub=upper_bound, - p_num=processor_num) - true_result_k, noise_result_k, error_k = src.algorithm.MaxSJA2.get_result() + self.k = len(next_input) - 1 + src.algorithm.MaxSJA2.processMaxSJA2( + next_input, + self.k, + e=epsilon, + b=beta, + error=error_level, + ub=upper_bound, + p_num=processor_num, + ) + true_result_k, noise_result_k, error_k = ( + src.algorithm.MaxSJA2.get_result() + ) self.true_result.append((true_result_k, group)) self.noise_result.append((noise_result_k, group)) self.error += error_k diff --git a/src/recursive.py b/src/recursive.py new file mode 100644 index 0000000..3971c7e --- /dev/null +++ b/src/recursive.py @@ -0,0 +1,346 @@ +"""Bounded recursive CTE support for DPSQL. + +The original DPSQL rewrite pipeline expects a flat SelectStmt. Recursive CTEs +(`WITH RECURSIVE ...`) break that assumption because table aliases inside the +recursive term are not visible from the outer query. This module handles the +common linear-recursive pattern separately by producing the row-level input that +DPSQL algorithms need directly: + + WITH RECURSIVE r AS (...) + SELECT count(*) FROM r + +becomes roughly: + + WITH RECURSIVE r(..., id0, ...) AS (... id columns ...) + SELECT 1, id0, ... FROM r + +The recursive term is also bounded with `--recursion-bound` when the query does +not already contain a tighter depth predicate. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Iterable, List, Optional, Sequence, Tuple + + +@dataclass(frozen=True) +class IdColumn: + table: str + alias: str + pk: str + name: str + + +def is_recursive_query(query: str) -> bool: + return bool(re.search(r"\bWITH\s+RECURSIVE\b", query, flags=re.IGNORECASE)) + + +def _strip_semicolon(query: str) -> str: + return query.strip().rstrip(";").strip() + + +def _split_cte(query: str) -> Optional[Tuple[str, List[str], str, str]]: + """Return (cte_name, column_names, cte_body, outer_select) for one CTE. + + This intentionally supports the project use case: a single recursive CTE + followed by a final SELECT. It is conservative and raises a helpful error + through the caller for unsupported shapes instead of silently producing + invalid SQL. + """ + q = _strip_semicolon(query) + m = re.match( + r"\s*WITH\s+RECURSIVE\s+(?P[A-Za-z_][\w]*)\s*(?P\([^)]*\))?\s+AS\s*\(", + q, + flags=re.IGNORECASE | re.DOTALL, + ) + if not m: + return None + + open_idx = m.end() - 1 + depth = 0 + close_idx = None + for i in range(open_idx, len(q)): + ch = q[i] + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + close_idx = i + break + if close_idx is None: + return None + + cols_raw = m.group("cols") + cols = [] + if cols_raw: + cols = [c.strip() for c in cols_raw[1:-1].split(",") if c.strip()] + + return m.group("name"), cols, q[open_idx + 1 : close_idx].strip(), q[close_idx + 1 :].strip() + + +def _split_union_all(cte_body: str) -> Tuple[str, str]: + # Split at top-level UNION ALL only. + depth = 0 + upper = cte_body.upper() + i = 0 + while i < len(cte_body): + ch = cte_body[i] + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + elif depth == 0 and upper.startswith("UNION ALL", i): + return cte_body[:i].strip(), cte_body[i + len("UNION ALL") :].strip() + i += 1 + raise ValueError("Recursive CTE must contain a top-level UNION ALL.") + + +def _pk_lookup(pks: Sequence[Sequence[str]], private_relations: str) -> dict: + relation_set = {r.strip() for r in private_relations.split(",") if r.strip()} + out = {} + for pk in pks: + table = str(pk[0]) + if table not in relation_set: + continue + definition = str(pk[2]) + left = definition.find("(") + right = definition.find(")") + if left != -1 and right != -1 and right > left: + out[table] = definition[left + 1 : right].strip() + return out + + +def _infer_pk_from_recursive_sql(table: str, anchor_sql: str, recursive_sql: str) -> Optional[str]: + """Best-effort fallback when PostgreSQL primary key metadata is absent. + + DPSQL needs a stable tuple identifier for each private tuple contribution. + Some demo graph tables do not declare a DB primary key, so primary_keys.txt + returns nothing. For bounded graph recursion, the edge tuple is usually + identified by the columns used as alias.column in the anchor/recursive terms, + e.g. graph_edges(src, dst). + """ + cols: List[str] = [] + combined = anchor_sql + "\n" + recursive_sql + pat = re.compile( + rf"\b(?:FROM|JOIN)\s+{re.escape(table)}(?:\s+(?:AS\s+)?([A-Za-z_]\w*))?", + flags=re.IGNORECASE, + ) + aliases = [] + for m in pat.finditer(combined): + alias = m.group(1) or table + if alias.upper() in {"WHERE", "JOIN", "INNER", "LEFT", "RIGHT", "FULL", "CROSS", "ON", "GROUP", "ORDER", "UNION"}: + alias = table + aliases.append(alias) + + for alias in aliases: + for cm in re.finditer(rf"\b{re.escape(alias)}\.([A-Za-z_]\w*)\b", combined): + col = cm.group(1) + if col not in cols: + cols.append(col) + + # Prefer the natural edge key when present. + lowered = {c.lower(): c for c in cols} + if "src" in lowered and "dst" in lowered: + return f"{lowered['src']}, {lowered['dst']}" + if cols: + return ", ".join(cols[:2]) + return None + + +def _fill_missing_pk_fallbacks( + pk_by_table: dict, + private_relations: str, + anchor_sql: str, + recursive_sql: str, +) -> dict: + """Fill missing private relation keys using recursive SQL as fallback.""" + out = dict(pk_by_table) + for table in [r.strip() for r in private_relations.split(",") if r.strip()]: + if table in out: + continue + inferred = _infer_pk_from_recursive_sql(table, anchor_sql, recursive_sql) + if inferred: + out[table] = inferred + return out + + +def _qualified_pk_expr(alias: str, pk: str) -> str: + """Return a SQL expression identifying a private tuple. + + Supports both single-column and composite primary keys: + - Single: _qualified_pk_expr("e", "id") → "e.id" + - Composite: _qualified_pk_expr("e", "src,dst") → "concat(e.src, ':', e.dst)" + + Composite keys are concatenated with ':' as separator to produce + a single stable tuple identifier for differential privacy. + """ + cols = [c.strip() for c in pk.split(",") if c.strip()] + if not cols: + raise ValueError("Primary key definition did not contain any columns.") + + # Single column: return qualified name directly + if len(cols) == 1: + return f"{alias}.{cols[0]}" + + # Composite key: concat with ':' separator + qualified = [f"{alias}.{col}" for col in cols] + parts = [qualified[0]] + for col in qualified[1:]: + parts.append("':'") # String literal ':' to be concatenated + parts.append(col) + return "concat(" + ", ".join(parts) + ")" + + +def _find_private_aliases(sql: str, pk_by_table: dict, cte_name: str, start_index: int = 0) -> List[IdColumn]: + found: List[IdColumn] = [] + idx = start_index + for table, pk in pk_by_table.items(): + pat = re.compile( + rf"\b(?:FROM|JOIN)\s+{re.escape(table)}(?:\s+(?:AS\s+)?([A-Za-z_]\w*))?", + flags=re.IGNORECASE, + ) + for m in pat.finditer(sql): + alias = m.group(1) or table + if alias.upper() in {"WHERE", "JOIN", "INNER", "LEFT", "RIGHT", "FULL", "CROSS", "ON", "GROUP", "ORDER", "UNION"}: + alias = table + if alias.lower() == cte_name.lower(): + continue + found.append(IdColumn(table=table, alias=alias, pk=pk, name=f"id{idx}")) + idx += 1 + return found + + +def _insert_targets(select_sql: str, targets: Iterable[str]) -> str: + targets = list(targets) + if not targets: + return select_sql + m = re.search(r"\bFROM\b", select_sql, flags=re.IGNORECASE) + if not m: + raise ValueError("Could not add DPSQL id columns: SELECT term has no FROM clause.") + return select_sql[: m.start()].rstrip() + ", " + ", ".join(targets) + " " + select_sql[m.start() :].lstrip() + + +def _depth_alias_and_column(recursive_sql: str, cte_cols: Sequence[str], cte_name: str) -> Tuple[Optional[str], Optional[str]]: + # Prefer an explicit alias in patterns such as r.depth + 1. + m = re.search(r"\b([A-Za-z_]\w*)\.([A-Za-z_]\w*)\s*\+\s*1\b", recursive_sql, flags=re.IGNORECASE) + if m and m.group(1).lower() != cte_name.lower(): + return m.group(1), m.group(2) + # Fall back to a common CTE column name. + for c in cte_cols: + if c.lower() in {"depth", "level", "hop", "hops"}: + return None, c + return None, None + + +def _apply_recursion_bound(recursive_sql: str, cte_cols: Sequence[str], cte_name: str, bound: int) -> str: + if bound < 1: + raise ValueError("recursion_bound must be >= 1.") + alias, depth_col = _depth_alias_and_column(recursive_sql, cte_cols, cte_name) + if not depth_col: + # The query may already be naturally bounded by another predicate. Keep it unchanged. + return recursive_sql + depth_ref = f"{alias}.{depth_col}" if alias else depth_col + predicate = f"{depth_ref} < {int(bound)}" + + # Always append the CLI bound. If the query already has a stricter bound, + # this is redundant; if it has a looser bound, this safely tightens it. + if re.search(r"\bWHERE\b", recursive_sql, flags=re.IGNORECASE): + return recursive_sql.rstrip() + f" AND {predicate}" + return recursive_sql.rstrip() + f" WHERE {predicate}" + + +def _outer_select_to_row_input(outer_select: str, cte_name: str, id_columns: Sequence[IdColumn]) -> str: + ids = ", ".join(c.name for c in id_columns) + suffix = f", {ids}" if ids else "" + # COUNT/SUM/MAX become row-level rows consumed by DPSQL algorithms. + m_count = re.match(rf"\s*SELECT\s+count\s*\(\s*\*\s*\)\s+FROM\s+{re.escape(cte_name)}\b.*", outer_select, flags=re.IGNORECASE | re.DOTALL) + if m_count: + return f"SELECT 1{suffix} FROM {cte_name}" + + m_sum = re.match(rf"\s*SELECT\s+sum\s*\((?P.*?)\)\s+FROM\s+{re.escape(cte_name)}\b.*", outer_select, flags=re.IGNORECASE | re.DOTALL) + if m_sum: + return f"SELECT {m_sum.group('expr')}{suffix} FROM {cte_name}" + + # Non-aggregate final select: append id columns so the DP layer still has user ids. + return _insert_targets(outer_select, [c.name for c in id_columns]) + + +def rewrite_bounded_recursive_query( + query: str, + private_relations: str, + pks: Sequence[Sequence[str]], + recursion_bound: int, +) -> Optional[str]: + """Rewrite a bounded linear recursive CTE into DPSQL row-level input. + + Returns None when the query is not recursive. Raises ValueError for a + recursive query shape that is not supported by this project extension. + """ + if not is_recursive_query(query): + return None + + split = _split_cte(query) + if split is None: + raise ValueError("Only single WITH RECURSIVE cte AS (...) SELECT ... queries are supported.") + + cte_name, cte_cols, cte_body, outer_select = split + anchor_sql, recursive_sql = _split_union_all(cte_body) + pk_by_table = _pk_lookup(pks, private_relations) + pk_by_table = _fill_missing_pk_fallbacks(pk_by_table, private_relations, anchor_sql, recursive_sql) + if not pk_by_table: + raise ValueError( + "No primary key was found or inferred for the private recursive relation. " + "Either declare a DB primary key or use edge columns such as src/dst in the recursive SQL." + ) + + anchor_aliases = _find_private_aliases(anchor_sql, pk_by_table, cte_name, 0) + recursive_aliases = _find_private_aliases(recursive_sql, pk_by_table, cte_name, 0) + if not anchor_aliases or not recursive_aliases: + raise ValueError("Could not find private table references in both anchor and recursive CTE terms.") + if len(anchor_aliases) != 1 or len(recursive_aliases) != 1: + raise ValueError("Bounded recursion currently supports one private table reference per anchor/recursive term.") + + anchor_edge = anchor_aliases[0] + recursive_edge = recursive_aliases[0] + recursive_cte_alias, depth_col = _depth_alias_and_column(recursive_sql, cte_cols, cte_name) + if not recursive_cte_alias or not depth_col: + raise ValueError("Could not infer recursive alias/depth column. Use a pattern such as r.depth + 1.") + + # One id column per possible hop. This preserves the full path contribution: + # depth 1 -> id0, depth 2 -> id0,id1, ..., depth B -> id0..id(B-1). + id_names = [f"id{i}" for i in range(int(recursion_bound))] + anchor_targets = [ + f"concat('id0', {_qualified_pk_expr(anchor_edge.alias, anchor_edge.pk)}) AS id0", + *[f"NULL AS {name}" for name in id_names[1:]], + ] + recursive_targets = [f"{recursive_cte_alias}.id0 AS id0"] + current_id_expr = f"concat('id', {recursive_cte_alias}.{depth_col}, {_qualified_pk_expr(recursive_edge.alias, recursive_edge.pk)})" + for i, name in enumerate(id_names[1:], start=1): + recursive_targets.append( + f"CASE WHEN {recursive_cte_alias}.{depth_col} = {i} " + f"THEN {current_id_expr} ELSE {recursive_cte_alias}.{name} END AS {name}" + ) + + bounded_recursive_sql = _apply_recursion_bound(recursive_sql, cte_cols, cte_name, recursion_bound) + anchor_sql = _insert_targets(anchor_sql, anchor_targets) + bounded_recursive_sql = _insert_targets(bounded_recursive_sql, recursive_targets) + + cte_column_suffix = "" + if cte_cols: + cte_column_suffix = "(" + ", ".join(list(cte_cols) + id_names) + ")" + + row_input_select = _outer_select_to_row_input( + outer_select, + cte_name, + [IdColumn(table=anchor_edge.table, alias=cte_name, pk="", name=name) for name in id_names], + ) + return ( + f"WITH RECURSIVE {cte_name}{cte_column_suffix} AS (\n" + f" {anchor_sql}\n" + f" UNION ALL\n" + f" {bounded_recursive_sql}\n" + f")\n{row_input_select}" + ) diff --git a/src/util.py b/src/util.py index a4a7e4b..ceca915 100644 --- a/src/util.py +++ b/src/util.py @@ -7,37 +7,26 @@ # test the connection to pgsql def pg_test(dbsetting): - """ Connect to the PostgreSQL database server """ + """Connect to the PostgreSQL database server and return True/False.""" + conn = None try: - # read connection parameters params = dbsetting - print("testing database connection ") - # connect to the PostgreSQL server conn = psql.connect(**params) - - # create a cursor cur = conn.cursor() - - # execute a statement cur.execute("select 1") - # get the result - res = cur.fetchall() - - - # close the communication with the PostgreSQL + cur.fetchall() cur.close() print("connection ok ") + return True except (Exception, psql.DatabaseError) as error: print(error) - return False finally: if conn is not None: conn.close() - return True # read the configuration file diff --git a/test_exists.txt b/test_exists.txt new file mode 100644 index 0000000..96fff99 --- /dev/null +++ b/test_exists.txt @@ -0,0 +1,8 @@ +SELECT sum(l_extendedprice) +FROM lineitem +WHERE EXISTS ( + SELECT 1 + FROM orders + WHERE o_orderkey = l_orderkey + AND o_custkey = 123 +); diff --git a/test_nested.txt b/test_nested.txt new file mode 100644 index 0000000..b08dd86 --- /dev/null +++ b/test_nested.txt @@ -0,0 +1,7 @@ +SELECT sum(l_extendedprice) +FROM lineitem +WHERE l_orderkey IN ( + SELECT o_orderkey + FROM orders + WHERE o_custkey > 1 AND o_custkey < 200 +); diff --git a/test_rec.txt b/test_rec.txt new file mode 100644 index 0000000..05ef0e0 --- /dev/null +++ b/test_rec.txt @@ -0,0 +1,12 @@ +WITH RECURSIVE order_path AS ( + SELECT o_orderkey, o_custkey, 1 AS depth + FROM orders + WHERE o_orderkey > 1 AND o_orderkey < 500 + + UNION ALL + + SELECT orders.o_orderkey, orders.o_custkey, r.depth + 1 + FROM orders, order_path r + WHERE orders.o_custkey = r.o_custkey +) +SELECT COUNT(*) FROM order_path; \ No newline at end of file